# JAX Utilities
## Shape and Type Checking
### fdtdx.core.jax.utils.check_shape_dtype(partial_value_dict, shape_dtype_dict)
Validates that arrays match their expected shapes and dtypes.

Checks each array in `partial_value_dict` against its corresponding shape and dtype specification in `shape_dtype_dict`. This is useful for confirming that arrays meet their expected specifications before they are used in computations.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `partial_value_dict` | `dict[str, Array]` | Dictionary mapping names to JAX arrays to validate. | *required* |
| `shape_dtype_dict` | `dict[str, ShapeDtypeStruct]` | Dictionary mapping names to `ShapeDtypeStruct` objects specifying the expected shape and dtype for each array. | *required* |
Raises:

| Type | Description |
|---|---|
| `Exception` | If any array's shape or dtype doesn't match its specification in `shape_dtype_dict`. The error message indicates which array failed and how its shape/dtype differed from the expected specification. |
Example

```python
shapes = {"x": jax.ShapeDtypeStruct((2, 3), jnp.float32)}
arrays = {"x": jnp.zeros((2, 3), dtype=jnp.float32)}
check_shape_dtype(arrays, shapes)  # Passes

bad = {"x": jnp.zeros((3, 2))}  # Wrong shape
check_shape_dtype(bad, shapes)  # Raises Exception
```
Source code in src/fdtdx/core/jax/utils.py
## Utility Classes
### fdtdx.core.misc.PaddingConfig

Bases: `ExtendedTreeClass`
Padding configuration. Widths are given per side in the order `minx, maxx, miny, maxy, minz, maxz, ...`, or as a single value that is applied to all sides. The ordering convention is illustrated below.
Source code in src/fdtdx/core/misc.py
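The per-side ordering pairs naturally with the per-axis `(before, after)` widths that `jnp.pad` expects. A minimal sketch of the convention only (illustrative; `PaddingConfig`'s exact constructor fields are not shown here and may differ):

```python
import jax.numpy as jnp

# (minx, maxx, miny, maxy, minz, maxz) -- the ordering described above.
widths = (1, 2, 0, 0, 3, 3)

arr = jnp.zeros((4, 4, 4))
padded = jnp.pad(
    arr,
    (
        (widths[0], widths[1]),  # x: 1 cell before, 2 after
        (widths[2], widths[3]),  # y: no padding
        (widths[4], widths[5]),  # z: 3 cells on each side
    ),
)
print(padded.shape)  # (7, 4, 10)
```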
## Gradient Estimators
### fdtdx.core.jax.ste.straight_through_estimator(x, y)
Straight Through Estimator for gradient estimation with discrete variables.

This function applies the straight through estimator (STE) by taking the gradient with respect to the continuous input `x` while using the discrete values `y` in the forward pass. STE is commonly used when training neural networks with discrete or quantized values, where standard backpropagation is not possible.

The implementation uses JAX's `stop_gradient` to control gradient flow:

`output = x - stop_gradient(x) + stop_gradient(y)`

This ensures that the forward pass uses `y`, while during backpropagation the gradient flows through `x` (see the example below).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` | The original continuous values before quantization/discretization. Gradients will be computed with respect to these values. | *required* |
| `y` | `Array` | The discrete/quantized values used in the forward pass. Must have the same shape as `x`. | *required* |
Returns:

| Type | Description |
|---|---|
| `Array` | The result of applying the straight through estimator: an array with the same shape as `x` and `y` that takes the values of `y` in the forward pass, while gradients flow through `x`. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If `x` and `y` have different shapes. |
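Example (not from the library docs): a minimal sketch of the documented formula in action, using rounding as an illustrative choice of quantizer for `y`:

```python
import jax
import jax.numpy as jnp
from fdtdx.core.jax.ste import straight_through_estimator

def loss(x):
    y = jnp.round(x)  # discrete forward values (example quantizer)
    return jnp.sum(straight_through_estimator(x, y))

x = jnp.array([0.2, 0.7, 0.4])
value, grad = jax.value_and_grad(loss)(x)
print(value)  # 1.0 -- forward pass sums the rounded values [0., 1., 0.]
print(grad)   # [1. 1. 1.] -- backward pass treats the op as identity in x
```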