JAX Utilities

Shape and Type Checking

fdtdx.core.jax.utils.check_shape_dtype(partial_value_dict, shape_dtype_dict)

Validates that arrays match their expected shapes and dtypes.

Checks each array in partial_value_dict against its corresponding shape and dtype specification in shape_dtype_dict. This is useful for validating that arrays match their expected specifications before using them in computations.

Parameters:

partial_value_dict : dict[str, Array], required
    Dictionary mapping names to JAX arrays to validate.
shape_dtype_dict : dict[str, ShapeDtypeStruct], required
    Dictionary mapping names to ShapeDtypeStruct objects specifying the expected shape and dtype for each array.

Raises:

Exception
    If any array's shape or dtype doesn't match its specification in shape_dtype_dict. The error message indicates which array failed and how its shape/dtype differed from expected.

Example

shapes = {"x": jax.ShapeDtypeStruct((2, 3), jnp.float32)}
arrays = {"x": jnp.zeros((2, 3), dtype=jnp.float32)}
check_shape_dtype(arrays, shapes)  # Passes
bad = {"x": jnp.zeros((3, 2))}  # Wrong shape
check_shape_dtype(bad, shapes)  # Raises Exception

Source code in src/fdtdx/core/jax/utils.py
def check_shape_dtype(
    partial_value_dict: dict[str, jax.Array],
    shape_dtype_dict: dict[str, jax.ShapeDtypeStruct],
):
    """Validates that arrays match their expected shapes and dtypes.

    Checks each array in partial_value_dict against its corresponding shape and dtype
    specification in shape_dtype_dict. This is useful for validating that arrays
    match their expected specifications before using them in computations.

    Args:
        partial_value_dict: Dictionary mapping names to JAX arrays to validate
        shape_dtype_dict: Dictionary mapping names to ShapeDtypeStruct objects
            specifying the expected shape and dtype for each array

    Raises:
        Exception: If any array's shape or dtype doesn't match its specification
            in shape_dtype_dict. The error message indicates which array failed
            and how its shape/dtype differed from expected.

    Example:
        >>> shapes = {"x": jax.ShapeDtypeStruct((2,3), jnp.float32)}
        >>> arrays = {"x": jnp.zeros((2,3), dtype=jnp.float32)}
        >>> check_shape_dtype(arrays, shapes)  # Passes
        >>> bad = {"x": jnp.zeros((3,2))}  # Wrong shape
        >>> check_shape_dtype(bad, shapes)  # Raises Exception
    """
    for k, arr in partial_value_dict.items():
        shape_dtype = shape_dtype_dict[k]
        if arr.dtype != shape_dtype.dtype:
            raise Exception(f"Wrong dtype: {shape_dtype.dtype} != {arr.dtype}")
        if arr.shape != shape_dtype.shape:
            raise Exception(f"Wrong shape: {shape_dtype.shape} != {arr.shape}")
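Note that the loop above iterates over partial_value_dict, so only the keys present in that dictionary are checked against their specs. A self-contained sketch (the function body is inlined here from the source above so the snippet runs without fdtdx installed):

```python
import jax
import jax.numpy as jnp


def check_shape_dtype(partial_value_dict, shape_dtype_dict):
    # Mirrors the source above: compare each provided array against its spec.
    for k, arr in partial_value_dict.items():
        spec = shape_dtype_dict[k]
        if arr.dtype != spec.dtype:
            raise Exception(f"Wrong dtype: {spec.dtype} != {arr.dtype}")
        if arr.shape != spec.shape:
            raise Exception(f"Wrong shape: {spec.shape} != {arr.shape}")


shapes = {
    "x": jax.ShapeDtypeStruct((2, 3), jnp.float32),
    "y": jax.ShapeDtypeStruct((4,), jnp.int32),
}

# A partial dict is fine: only "x" is validated here, "y" is skipped.
check_shape_dtype({"x": jnp.zeros((2, 3), dtype=jnp.float32)}, shapes)

# A mismatched shape raises with a message naming the difference.
try:
    check_shape_dtype({"x": jnp.zeros((3, 2), dtype=jnp.float32)}, shapes)
    error_message = ""
except Exception as e:
    error_message = str(e)
```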

Utility Classes

fdtdx.core.misc.PaddingConfig

Bases: ExtendedTreeClass

Padding configuration. The order is: minx, maxx, miny, maxy, minz, maxz, ..., or a single value that is applied to all boundaries.

Source code in src/fdtdx/core/misc.py
@extended_autoinit
class PaddingConfig(ExtendedTreeClass):
    """
    Padding configuration. The order is:
    minx, maxx, miny, maxy, minz, maxz, ...
    or just single value that can be used for all
    """

    widths: Sequence[int] = frozen_field()
    modes: Sequence[str] = frozen_field()
    values: Sequence[float] = frozen_field(
        default=None,  # type: ignore
    )
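The widths ordering (minx, maxx, miny, maxy, ...) can be made concrete with a small helper that expands such a flat sequence into the per-axis (before, after) pairs that jnp.pad expects. The helper widths_to_pad_width below is not part of fdtdx; it is only an illustrative sketch of the convention:

```python
import jax.numpy as jnp


def widths_to_pad_width(widths, ndim):
    # Hypothetical helper: expand a flat (minx, maxx, miny, maxy, ...)
    # sequence, or a single broadcast value, into jnp.pad-style pairs.
    widths = list(widths)
    if len(widths) == 1:
        widths = widths * (2 * ndim)
    return tuple((widths[2 * i], widths[2 * i + 1]) for i in range(ndim))


arr = jnp.ones((2, 3))
# minx=1, maxx=0, miny=2, maxy=2
pad_width = widths_to_pad_width([1, 0, 2, 2], arr.ndim)
padded = jnp.pad(arr, pad_width, mode="constant", constant_values=0.0)
```

A single-element widths sequence is broadcast to every boundary, matching the "single value that can be used for all" case in the docstring.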

Gradient Estimators

fdtdx.core.jax.ste.straight_through_estimator(x, y)

Straight Through Estimator for gradient estimation with discrete variables.

This function applies the straight through estimator (STE) by taking the gradient with respect to the continuous input x, while using the discrete values y in the forward pass. STE is commonly used in training neural networks with discrete/quantized values where standard backpropagation is not possible.

The implementation uses JAX's stop_gradient to control gradient flow: output = x - stop_gradient(x) + stop_gradient(y)

This ensures that the forward pass uses y, while gradients flow through x during backpropagation.

Parameters:

x : Array, required
    The original continuous values before quantization/discretization. Gradients will be computed with respect to these values.
y : Array, required
    The discrete/quantized values used in the forward pass. Must have the same shape as x.

Returns:

Array
    The result of applying the straight through estimator, with the same shape as x and y. In the forward pass this equals y, but gradients flow through x.

Raises:

ValueError
    If x and y have different shapes.

Source code in src/fdtdx/core/jax/ste.py
def straight_through_estimator(x: jax.Array, y: jax.Array) -> jax.Array:
    """Straight Through Estimator for gradient estimation with discrete variables.

    This function applies the straight through estimator (STE) by taking the gradient
    with respect to the continuous input `x`, while using the discrete values `y`
    in the forward pass. STE is commonly used in training neural networks with
    discrete/quantized values where standard backpropagation is not possible.

    The implementation uses JAX's stop_gradient to control gradient flow:
        output = x - stop_gradient(x) + stop_gradient(y)

    This ensures during the forward pass we use y, but during backprop the
    gradient flows through x.

    Args:
        x: jax.Array, the original continuous values before quantization/discretization.
            Gradients will be computed with respect to these values.
        y: jax.Array, the discrete/quantized values used in the forward pass.
            Must have the same shape as x.

    Returns:
        jax.Array: The result of applying the straight through estimator, which
        is the same shape as `x` and `y`. In the forward pass this equals y,
        but gradients flow through x.

    Raises:
        ValueError: If x and y have different shapes.
    """

    return x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(y)
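The one-line identity above can be checked numerically: when y is a non-differentiable discretization of x (here jnp.round, chosen purely for illustration), the loss is evaluated on the rounded values, yet jax.grad returns the gradient as if the rounding were the identity map:

```python
import jax
import jax.numpy as jnp


def straight_through_estimator(x: jax.Array, y: jax.Array) -> jax.Array:
    # Same identity as the source above: forward value is y, gradient is d/dx.
    return x - jax.lax.stop_gradient(x) + jax.lax.stop_gradient(y)


def quantized_loss(x):
    y = jnp.round(x)  # discretization with zero gradient almost everywhere
    return jnp.sum(straight_through_estimator(x, y) ** 2)


x = jnp.array([0.4, 1.6])
loss = quantized_loss(x)          # forward pass uses rounded values: 0^2 + 2^2
grad = jax.grad(quantized_loss)(x)  # gradient flows through x: 2 * round(x)
```

Without the estimator, jax.grad of a loss built directly on jnp.round(x) would be zero everywhere, which is exactly the failure mode STE works around.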