Skip to content

Devices

In FDTDX, devices are objects whose shape can be optimized. A device has a corresponding set of latent parameters, which are mapped to produce the current shape of the device.

fdtdx.Device

Bases: OrderableObject, ABC

Abstract base class for devices with optimizable permittivity distributions.

This class defines the common interface and functionality for both discrete and continuous devices that can be optimized through gradient-based methods.

Attributes:

Name Type Description
name str

Optional name identifier for the device

dtype str

Data type for device parameters, defaults to float32

color tuple[float, float, float] | None

RGB color tuple for visualization, defaults to pink

Source code in src/fdtdx/objects/device/device.py
@extended_autoinit
class Device(OrderableObject, ABC):
    """Abstract base class for devices with optimizable permittivity distributions.

    This class defines the common interface and functionality for both discrete and
    continuous devices that can be optimized through gradient-based methods.

    Attributes:
        name: Optional name identifier for the device
        dtype: Data type for device parameters, defaults to float32
        materials: Mapping from material name to material. It is handed to every
            parameter transformation when the device is placed on the grid.
        param_transforms: Parameter transformations, applied in order to the latent
            parameters when the device is called.
        color: RGB color tuple for visualization, defaults to pink
        partial_voxel_grid_shape: Per-axis size of a single voxel in grid cells. An
            entry that is None must be given through partial_voxel_real_shape instead.
        partial_voxel_real_shape: Per-axis size of a single voxel in real units. An
            entry that is None must be given through partial_voxel_grid_shape instead.
    """

    materials: dict[str, Material] = field()
    param_transforms: Sequence[ParameterTransformation] = field()
    color: tuple[float, float, float] | None = frozen_field(default=PINK)
    partial_voxel_grid_shape: PartialGridShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)
    partial_voxel_real_shape: PartialRealShape3D = frozen_field(default=UNDEFINED_SHAPE_3D)

    # Resolved in place_on_grid; stays INVALID_SHAPE_3D until then.
    _single_voxel_grid_shape: GridShape3D = frozen_private_field(default=INVALID_SHAPE_3D)

    @property
    def matrix_voxel_grid_shape(self) -> GridShape3D:
        """Calculate the shape of the voxel matrix in grid coordinates.

        Returns:
            Tuple of (x,y,z) dimensions representing how many voxels fit in each direction
            of the grid shape when divided by the single voxel shape.
        """
        return (
            round(self.grid_shape[0] / self.single_voxel_grid_shape[0]),
            round(self.grid_shape[1] / self.single_voxel_grid_shape[1]),
            round(self.grid_shape[2] / self.single_voxel_grid_shape[2]),
        )

    @property
    def single_voxel_grid_shape(self) -> GridShape3D:
        """Get the shape of a single voxel in grid coordinates.

        Returns:
            Tuple of (x,y,z) dimensions for one voxel.

        Raises:
            Exception: If the object has not been initialized yet.
        """
        if self._single_voxel_grid_shape == INVALID_SHAPE_3D:
            raise Exception(f"{self} is not initialized yet")
        return self._single_voxel_grid_shape

    @property
    def single_voxel_real_shape(self) -> RealShape3D:
        """Calculate the shape of a single voxel in real (physical) coordinates.

        Returns:
            Tuple of (x,y,z) dimensions in real units, computed by multiplying
            the grid shape by the simulation resolution.
        """
        grid_shape = self.single_voxel_grid_shape
        return (
            grid_shape[0] * self._config.resolution,
            grid_shape[1] * self._config.resolution,
            grid_shape[2] * self._config.resolution,
        )

    @property
    def output_type(self) -> ParameterType:
        """Get the parameter type produced by the transformation sequence.

        Returns:
            ParameterType.CONTINUOUS if there are no transformations, otherwise the
            output type of the last transformation. A dictionary with a single
            entry is unwrapped to that entry.

        Raises:
            Exception: If the last transformation does not output a single array.
        """
        if not self.param_transforms:
            return ParameterType.CONTINUOUS
        out_type = self.param_transforms[-1]._output_type
        if isinstance(out_type, dict) and len(out_type) == 1:
            out_type = next(iter(out_type.values()))
        if not isinstance(out_type, ParameterType):
            raise Exception(
                "Output of Parameter transformation sequence (last module) needs to be a single array, but got: "
                f"{out_type}"
            )
        return out_type

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the device on the grid and initialize its parameter transformations.

        Resolves the single-voxel grid shape from the partial grid/real voxel shapes,
        validates it against the device grid shape and then initializes the shapes
        and types of all parameter transformations.

        Args:
            grid_slice_tuple: Grid slice of the device, forwarded to the parent class.
            config: Simulation configuration; its resolution converts real voxel
                sizes to grid cells.
            key: Random key, forwarded to the parent class.

        Returns:
            A copy of the device with resolved voxel shape and initialized
            parameter transformations.

        Raises:
            Exception: If the voxel size of an axis is specified twice or not at all,
                if a voxel spans less than one grid cell, if the voxels do not tile
                the device grid shape exactly, or if the parameter mapping outputs
                continuous values while the device does not have exactly two materials.
        """
        self = super().place_on_grid(grid_slice_tuple=grid_slice_tuple, config=config, key=key)
        # determine voxel shape
        voxel_grid_shape = []
        for axis in range(3):
            partial_grid = self.partial_voxel_grid_shape[axis]
            partial_real = self.partial_voxel_real_shape[axis]
            if partial_grid is not None and partial_real is not None:
                raise Exception(f"Multi-Material voxels overspecified in axis: {axis=}")
            if partial_grid is not None:
                voxel_grid_shape.append(partial_grid)
            elif partial_real is not None:
                voxel_grid_shape.append(round(partial_real / config.resolution))
            else:
                raise Exception(f"Multi-Material voxels not specified in axis: {axis=}")

        self = self.aset("_single_voxel_grid_shape", tuple(voxel_grid_shape))

        # sanity checks on the voxel shape
        for axis in range(3):
            voxel_cells = self.single_voxel_grid_shape[axis]
            # A real voxel size below half the resolution rounds to zero cells, which
            # would otherwise only surface as a ZeroDivisionError further down.
            if voxel_cells < 1:
                raise Exception(
                    f"Voxel needs to span at least one grid cell for {axis=}, but got {self.single_voxel_grid_shape=}"
                )
            # NOTE(review): single_voxel_real_shape is derived as grid cells * resolution,
            # so this inspects the rounded value rather than the user-supplied
            # partial_voxel_real_shape - confirm that this is intended.
            float_div = is_float_divisible(
                self.single_voxel_real_shape[axis],
                self._config.resolution,
            )
            if not float_div:
                raise Exception(f"Not divisible: {self.single_voxel_real_shape[axis]=}, {self._config.resolution=}")
            # The voxel matrix has to tile the device exactly. Testing divisibility by the
            # (rounded) voxel count instead would accept e.g. 10 cells with 4-cell voxels.
            if self.matrix_voxel_grid_shape[axis] * voxel_cells != self.grid_shape[axis]:
                raise Exception(
                    f"Due to discretization, matrix got skewered for {axis=}. "
                    f"{self.grid_shape=}, {self.matrix_voxel_grid_shape=}"
                )

        # init parameter transformations
        # We need to go once backwards through the transformations to determine the shape of the latent parameters
        # then we need to go forward through the transformations again to determine the parameter type of the
        # output
        new_t_list: list[ParameterTransformation] = []
        cur_shape = {"params": self.matrix_voxel_grid_shape}
        for transform in self.param_transforms[::-1]:
            t_new = transform.init_module(
                config=config,
                materials=self.materials,
                matrix_voxel_grid_shape=self.matrix_voxel_grid_shape,
                single_voxel_size=self.single_voxel_real_shape,
                output_shape=cur_shape,
            )
            new_t_list.append(t_new)
            cur_shape = t_new._input_shape

        # new_t_list is in reversed order; walking it backwards restores the forward
        # order, in which the parameter types are propagated from input to output
        module_list: list[ParameterTransformation] = []
        cur_input_type = {"params": ParameterType.CONTINUOUS}
        for transform in new_t_list[::-1]:
            t_new = transform.init_type(
                input_type=cur_input_type,
            )
            module_list.append(t_new)
            cur_input_type = t_new._output_type

        # store the initialized transformations and validate the material count
        self = self.aset("param_transforms", module_list)
        if self.output_type == ParameterType.CONTINUOUS and len(self.materials) != 2:
            raise Exception(
                f"Need exactly two materials in device when parameter mapping outputs continuous permittivity indices, "
                f"but got {self.materials}"
            )
        return self

    def init_params(
        self,
        key: jax.Array,
    ) -> dict[str, jax.Array] | jax.Array:
        """Draw random initial latent parameters.

        The shapes are the input shapes of the first parameter transformation, or
        the voxel matrix shape if the device has no transformations.

        Args:
            key: Random key; it is split once per parameter array.

        Returns:
            A single array if there is exactly one parameter array, otherwise a
            dictionary mapping parameter names to arrays. All values are float32
            samples drawn uniformly between 0 and 1.
        """
        if len(self.param_transforms) > 0:
            shapes = self.param_transforms[0]._input_shape
        else:
            shapes = self.matrix_voxel_grid_shape
        if not isinstance(shapes, dict):
            shapes = {"params": shapes}
        params = {}
        for k, cur_shape in shapes.items():
            key, subkey = jax.random.split(key)
            p = jax.random.uniform(
                key=subkey,
                shape=cur_shape,
                minval=0,  # parameter always live between 0 and 1
                maxval=1,
                dtype=jnp.float32,
            )
            params[k] = p
        if len(params) == 1:
            params = next(iter(params.values()))
        return params

    def __call__(
        self,
        params: dict[str, jax.Array] | jax.Array,
        expand_to_sim_grid: bool = False,
        **transform_kwargs,
    ) -> jax.Array:
        """Map latent parameters through all parameter transformations.

        Args:
            params: Latent parameters, either a single array or a dictionary of arrays.
            expand_to_sim_grid: If True, the result is passed through expand_matrix
                using the single-voxel grid shape as grid points per voxel.
            **transform_kwargs: Keyword arguments forwarded to every transformation.

        Returns:
            The single array produced by the last transformation.

        Raises:
            Exception: If the transformations do not end in exactly one array.
        """
        if not isinstance(params, dict):
            params = {"params": params}
        # walk through modules, validating shapes before and after each of them
        for transform in self.param_transforms:
            check_specs(params, transform._input_shape)
            params = transform(params, **transform_kwargs)
            check_specs(params, transform._output_shape)
        if len(params) != 1:
            raise Exception(
                "The parameter mapping should return a single array of indices. If using a continuous device, please"
                " make sure that the latent transformations abide to this rule."
            )
        params = next(iter(params.values()))
        if expand_to_sim_grid:
            params = expand_matrix(
                matrix=params,
                grid_points_per_voxel=self.single_voxel_grid_shape,
            )
        return params

matrix_voxel_grid_shape property

Calculate the shape of the voxel matrix in grid coordinates.

Returns:

Type Description
GridShape3D

Tuple of (x,y,z) dimensions representing how many voxels fit in each direction of the grid shape when divided by the single voxel shape.

single_voxel_grid_shape property

Get the shape of a single voxel in grid coordinates.

Returns:

Type Description
GridShape3D

Tuple of (x,y,z) dimensions for one voxel.

Raises:

Type Description
Exception

If the object has not been initialized yet.

single_voxel_real_shape property

Calculate the shape of a single voxel in real (physical) coordinates.

Returns:

Type Description
RealShape3D

Tuple of (x,y,z) dimensions in real units, computed by multiplying the grid shape by the simulation resolution.

Parameter Mapping

fdtdx.ParameterTransformation

Bases: ExtendedTreeClass, ABC

Source code in src/fdtdx/objects/device/parameters/transform.py
@extended_autoinit
class ParameterTransformation(ExtendedTreeClass, ABC):
    """Abstract base class for a single step of a parameter mapping.

    A transformation is initialized in two passes: init_module stores the
    simulation context and derives the input shape from the requested output
    shape, init_type stores the input type and derives the output type.
    Subclasses implement the shape/type derivation and the actual mapping.
    """

    # filled in by init_type / init_module
    _input_type: dict[str, ParameterType] = frozen_private_field()
    _input_shape: dict[str, tuple[int, ...]] = frozen_private_field()
    _output_type: dict[str, ParameterType] = frozen_private_field()
    _output_shape: dict[str, tuple[int, ...]] = frozen_private_field()
    _materials: dict[str, Material] = frozen_private_field()
    _config: SimulationConfig = frozen_private_field()
    _matrix_voxel_grid_shape: tuple[int, int, int] = frozen_private_field()
    _single_voxel_size: tuple[float, float, float] = frozen_private_field()

    # settings
    # require exactly one input array in get_output_type
    _check_single_array: bool = frozen_private_field(default=False)
    # allowed input type(s); None disables the check
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(default=None)
    # require every shape to be 3d with exactly one axis of size 1
    _all_arrays_2d: bool = frozen_private_field(default=False)

    def init_module(
        self: Self,
        config: SimulationConfig,
        materials: dict[str, Material],
        matrix_voxel_grid_shape: tuple[int, int, int],
        single_voxel_size: tuple[float, float, float],
        output_shape: dict[str, tuple[int, ...]],
    ) -> Self:
        """Store the simulation context and derive the input shape.

        Args:
            config: Simulation configuration.
            materials: Materials of the owning device.
            matrix_voxel_grid_shape: Shape of the voxel matrix of the device.
            single_voxel_size: Size of a single voxel.
            output_shape: Shapes this transformation has to produce.

        Returns:
            A copy of the transformation with context, output shape and the
            derived input shape set.
        """
        updated = self
        # the context has to be in place before the input shape is derived from it
        for attr_name, value in (
            ("_config", config),
            ("_materials", materials),
            ("_matrix_voxel_grid_shape", matrix_voxel_grid_shape),
            ("_single_voxel_size", single_voxel_size),
            ("_output_shape", output_shape),
        ):
            updated = updated.aset(attr_name, value)
        return updated.aset("_input_shape", updated.get_input_shape(output_shape))

    def init_type(
        self,
        input_type: dict[str, ParameterType],
    ) -> Self:
        """Store the input type and derive the output type.

        Args:
            input_type: Parameter types of the arrays fed into this transformation.

        Returns:
            A copy of the transformation with input and output type set.
        """
        typed = self.aset("_input_type", input_type)
        return typed.aset("_output_type", typed.get_output_type(input_type))

    def get_output_type(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Validate the input types and compute the resulting output types.

        Args:
            input_type: Parameter types of the input arrays.

        Returns:
            The output types as computed by _get_output_type_impl.

        Raises:
            Exception: If a single input array is required but not given, or if an
                input type is not among the allowed fixed input types.
        """
        if self._check_single_array and len(input_type) != 1:
            raise Exception(
                f"ParameterTransform {self.__class__} expects input to be a single array, but got: {input_type}"
            )

        allowed = self._fixed_input_type
        if allowed is not None:
            err_msg = (
                f"ParameterTransform {self.__class__} expects all input types to be {self._fixed_input_type}"
                f", but got {input_type}"
            )
            for given in input_type.values():
                # a sequence lists several admissible types, otherwise exactly one is allowed
                if isinstance(allowed, Sequence):
                    mismatch = given not in allowed
                else:
                    mismatch = given != allowed
                if mismatch:
                    raise Exception(err_msg)

        return self._get_output_type_impl(input_type)

    def get_input_shape(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Validate the output shapes and compute the required input shapes.

        Args:
            output_shape: Shapes this transformation has to produce.

        Returns:
            The input shapes as computed by _get_input_shape_impl.

        Raises:
            Exception: If 2d arrays are required and a shape is not 3d with exactly
                one axis of size 1.
        """
        if self._all_arrays_2d:
            err_msg = (
                f"ParameterTransform {self.__class__} expects to work with 2d arrays, so exactly one axis of the "
                f"3d array needs to have size of 1, but got: {output_shape}"
            )
            for shape in output_shape.values():
                if len(shape) != 3 or 1 not in shape or sum([n != 1 for n in shape]) != 2:
                    raise Exception(err_msg)

        return self._get_input_shape_impl(output_shape)

    @abstractmethod
    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Return the input shapes needed to produce the given output shapes."""
        raise NotImplementedError()

    @abstractmethod
    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return the output types resulting from the given input types."""
        raise NotImplementedError()

    @abstractmethod
    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Apply the transformation to the given parameter arrays."""
        raise NotImplementedError()

Projections

fdtdx.TanhProjection

Bases: SameShapeTypeParameterTransform

Tanh projection filter.

This needs the steepness parameter \(\beta\) as a keyword-argument in apply_params

Ref: F. Wang, B. S. Lazarov, & O. Sigmund, On projection methods, convergence and robust formulations in topology optimization. Structural and Multidisciplinary Optimization, 43(6), pp. 767-784 (2011).

Source code in src/fdtdx/objects/device/parameters/projection.py
@extended_autoinit
class TanhProjection(SameShapeTypeParameterTransform):
    # Raw docstring: in a normal string literal the "\b" of "\beta" is a backspace
    # escape, which corrupts the rendered documentation to "eta".
    r"""
    Tanh projection filter.

    This needs the steepness parameter $\beta$ as a keyword-argument in
    apply_params. It is read from the keyword argument named ``beta``.

    Ref: F. Wang, B. S. Lazarov, & O. Sigmund, On projection methods,
    convergence and robust formulations in topology optimization.
    Structural and Multidisciplinary Optimization, 43(6), pp. 767-784 (2011).

    Attributes:
        projection_midpoint: Midpoint passed to the tanh projection,
            defaults to 0.5.
    """

    projection_midpoint: float = frozen_field(default=0.5)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Apply the tanh projection to every parameter array.

        Args:
            params: Parameter arrays to project.
            **kwargs: Must contain ``beta``, the steepness of the projection.

        Returns:
            Dictionary with the same keys holding the projected arrays.

        Raises:
            Exception: If ``beta`` is not passed as keyword argument.
        """
        if "beta" not in kwargs:
            raise Exception("TanhProjection needs the beta parameter as additional keyword argument!")
        beta = kwargs["beta"]

        return {k: tanh_projection(v, beta, self.projection_midpoint) for k, v in params.items()}

fdtdx.SubpixelSmoothedProjection

Bases: SameShapeTypeParameterTransform

This function is adapted from the Meep repository: https://github.com/NanoComp/meep/blob/master/python/adjoint/filters.py

The details of this projection are described in the paper by Alec Hammond: https://arxiv.org/pdf/2503.20189

Project using subpixel smoothing, which allows for β→∞. This technique integrates out the discontinuity within the projection function, allowing the user to smoothly increase β from 0 to ∞ without losing the gradient. Effectively, a level set is created, and from this level set, first-order subpixel smoothing is applied to the interfaces (if any are present).

In order for this to work, the input array must already be smooth (e.g. by filtering).

While the original approach involves numerical quadrature, this approach performs a "trick" by assuming that the user is always infinitely projecting (β=∞). In this case, the expensive quadrature simplifies to an analytic fill-factor expression. When to use this fill factor requires some careful logic.

For one, we want to make sure that the user can indeed project at any level (not just infinity). So in these cases, we simply check if an interface is within the pixel. If not, we revert to the standard filter plus project technique.

If there is an interface, we want to make sure the derivative remains continuous both as the interface leaves the cell, and as it crosses the center. To ensure this, we need to account for the different possibilities.

Source code in src/fdtdx/objects/device/parameters/projection.py
@extended_autoinit
class SubpixelSmoothedProjection(SameShapeTypeParameterTransform):
    """
    Projection with subpixel smoothing, adapted from the Meep repository:
    https://github.com/NanoComp/meep/blob/master/python/adjoint/filters.py

    The details of this projection are described in the paper by Alec Hammond:
    https://arxiv.org/pdf/2503.20189

    Project using subpixel smoothing, which allows for β→∞.
    This technique integrates out the discontinuity within the projection
    function, allowing the user to smoothly increase β from 0 to ∞ without
    losing the gradient. Effectively, a level set is created, and from this
    level set, first-order subpixel smoothing is applied to the interfaces (if
    any are present).

    In order for this to work, the input array must already be smooth (e.g. by
    filtering).

    While the original approach involves numerical quadrature, this approach
    performs a "trick" by assuming that the user is always infinitely projecting
    (β=∞). In this case, the expensive quadrature simplifies to an analytic
    fill-factor expression. When to use this fill factor requires some careful
    logic.

    For one, we want to make sure that the user can indeed project at any level
    (not just infinity). So in these cases, we simply check if an interface is
    within the pixel. If not, we revert to the standard filter plus project
    technique.

    If there is an interface, we want to make sure the derivative remains
    continuous both as the interface leaves the cell, *and* as it crosses the
    center. To ensure this, we need to account for the different possibilities.

    Attributes:
        projection_midpoint: Passed as ``eta`` to the smoothed projection,
            defaults to 0.5.
    """

    projection_midpoint: float = frozen_field(default=0.5)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )
    _check_single_array: bool = frozen_private_field(default=True)
    _all_arrays_2d: bool = frozen_private_field(default=True)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Project every parameter array with subpixel smoothing.

        Args:
            params: Parameter arrays, each with one axis of size 1.
            **kwargs: Must contain ``beta``, the steepness of the projection.

        Returns:
            Dictionary with the same keys holding the projected arrays in
            their original shape.

        Raises:
            Exception: If ``beta`` is missing or the voxel sizes of the two
                in-plane axes differ.
        """
        if "beta" not in kwargs:
            raise Exception("SubpixelSmoothedProjection needs the beta parameter as additional keyword argument!")
        beta = kwargs["beta"]

        return {name: self._project_single(arr, beta) for name, arr in params.items()}

    def _project_single(self, arr: jax.Array, beta) -> jax.Array:
        """Project one array: drop the flat axis, project in 2d, restore the axis."""
        # the first axis of size 1 is treated as the flat axis
        flat_axis = arr.shape.index(1)
        axis_a = 1 if flat_axis == 0 else 0
        axis_b = 1 if flat_axis == 2 else 2
        if self._single_voxel_size[axis_a] != self._single_voxel_size[axis_b]:
            raise Exception(
                "SubpixelSmoothedProjection expects voxel size to be equal in "
                f"two axes, but got {self._single_voxel_size}"
            )
        voxel_size = self._single_voxel_size[axis_a]

        projected = smoothed_projection(
            arr.squeeze(flat_axis),
            beta=beta,
            eta=self.projection_midpoint,
            # expects resolution as pixels / µm
            resolution=1 / (voxel_size / 1e-6),
        )
        return jnp.expand_dims(projected, flat_axis)

Transformation of latent parameters

fdtdx.StandardToPlusOneMinusOneRange

Bases: StandardToCustomRange

Maps standard [0,1] range to [-1,1] range.

Special case of StandardToCustomRange that maps to [-1,1] range. Used for symmetric value ranges around zero.

Attributes:

Name Type Description
min_value float

Fixed to -1

max_value float

Fixed to 1

Source code in src/fdtdx/objects/device/parameters/continuous.py
@extended_autoinit
class StandardToPlusOneMinusOneRange(StandardToCustomRange):
    """Maps standard [0,1] range to [-1,1] range.

    Special case of StandardToCustomRange that maps to [-1,1] range.
    Used for symmetric value ranges around zero.

    Attributes:
        min_value: Fixed to -1
        max_value: Fixed to 1
    """

    # Both bounds are redeclared as private fields with fixed defaults
    # (presumably so they cannot be passed to the constructor - confirm against
    # frozen_private_field).
    min_value: float = frozen_private_field(default=-1)
    max_value: float = frozen_private_field(default=1)

fdtdx.StandardToCustomRange

Bases: SameShapeTypeParameterTransform

Maps standard [0,1] range to custom range [min_value, max_value].

Linearly maps values from [0,1] to a custom range specified by min_value and max_value parameters.

Attributes:

Name Type Description
min_value float

Minimum value of target range

max_value float

Maximum value of target range

Source code in src/fdtdx/objects/device/parameters/continuous.py
@extended_autoinit
class StandardToCustomRange(SameShapeTypeParameterTransform):
    """Linear rescaling from the standard [0,1] range to [min_value, max_value].

    Every input value ``v`` is mapped to ``v * (max_value - min_value) + min_value``.

    Attributes:
        min_value: Lower bound of the target range, defaults to 0
        max_value: Upper bound of the target range, defaults to 1
    """

    min_value: float = frozen_field(default=0)
    max_value: float = frozen_field(default=1)
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Rescale every parameter array to the target range.

        Args:
            params: Parameter arrays to rescale.
            **kwargs: Ignored.

        Returns:
            Dictionary with the same keys holding the rescaled arrays.
        """
        del kwargs
        span = self.max_value - self.min_value
        return {name: arr * span + self.min_value for name, arr in params.items()}

fdtdx.GaussianSmoothing2D

Bases: SameShapeTypeParameterTransform

Applies Gaussian smoothing to 2D parameter arrays.

This transform convolves the input with a 2D Gaussian kernel, which helps reduce noise and smooth the data.

Attributes:

Name Type Description
std_discrete int

Integer specifying the standard deviation of the Gaussian kernel in discrete units.

Source code in src/fdtdx/objects/device/parameters/continuous.py
@extended_autoinit
class GaussianSmoothing2D(SameShapeTypeParameterTransform):
    """
    Applies Gaussian smoothing to 2D parameter arrays.

    This transform convolves the input with a 2D Gaussian kernel,
    which helps reduce noise and smooth the data.

    Attributes:
        std_discrete: Integer specifying the standard deviation of the
                     Gaussian kernel in discrete units.
    """

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )
    _all_arrays_2d: bool = frozen_private_field(default=True)

    std_discrete: int

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        del kwargs
        return {k: self._apply_smoothing(v) for k, v in params.items()}

    def _apply_smoothing(self, x: jax.Array) -> jax.Array:
        vertical_axis = x.shape.index(1)
        x_squeezed = x.squeeze(vertical_axis)
        # Check if the array is 2D
        if x_squeezed.ndim != 2:
            raise ValueError(f"Expected 2D array, got shape {x_squeezed.shape}")

        # Create Gaussian kernel
        kernel_size = 6 * self.std_discrete + 1  # Ensure kernel covers 3 std on each side
        kernel = self._create_gaussian_kernel(kernel_size, self.std_discrete)

        # pad array with edge values
        padding_cfg = PaddingConfig(widths=(kernel_size // 2,), modes=("edge",))
        padded_arr, orig_slice = advanced_padding(x_squeezed, padding_cfg)

        result = jax.scipy.signal.convolve(
            padded_arr,
            kernel,
            mode="same",
        )
        result = result[*orig_slice]

        # Reshape back to original dimensions
        return result.reshape(x.shape)

    def _create_gaussian_kernel(self, size: int, sigma: float) -> jax.Array:
        # Create a coordinate grid
        coords = jnp.arange(-(size // 2), size // 2 + 1)
        x, y = jnp.meshgrid(coords, coords)

        # Create the Gaussian kernel
        kernel = jnp.exp(-(x**2 + y**2) / (2 * sigma**2))

        # Normalize the kernel to sum to 1
        kernel = kernel / jnp.sum(kernel)

        return kernel

Discretizations

fdtdx.ClosestIndex

Bases: ParameterTransformation

Maps continuous latent values to nearest allowed material indices.

For each input value, finds the index of the closest allowed inverse permittivity value. Uses straight-through gradient estimation to maintain differentiability. If mapping_from_inverse_permittivities is set to False (default), then the transform only quantizes the latent parameters to the closest integer value.

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class ClosestIndex(ParameterTransformation):
    """
    Quantizes continuous latent values to material indices.

    If mapping_from_inverse_permittivities is True, each input value is mapped
    to the index of the closest allowed inverse permittivity value. If it is
    False (default), the latent parameters are only rounded to the closest
    integer and clipped to the valid index range. In both cases a
    straight-through gradient estimator keeps the mapping differentiable.

    Attributes:
        mapping_from_inverse_permittivities: Whether inputs are interpreted as
            inverse permittivities, defaults to False.
    """

    mapping_from_inverse_permittivities: bool = frozen_field(default=False)
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Return the output shapes unchanged, as quantization keeps the shape."""
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Return BINARY for two materials and DISCRETE for more.

        Raises:
            Exception: If fewer than two materials are available.
        """
        num_materials = len(self._materials)
        if num_materials <= 1:
            raise Exception(f"Invalid materials (need two or more): {self._materials}")
        output_type = ParameterType.BINARY if num_materials == 2 else ParameterType.DISCRETE
        return dict.fromkeys(input_type, output_type)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Quantize every parameter array; keyword arguments are ignored."""
        del kwargs
        return {name: self._quantize(arr) for name, arr in params.items()}

    def _quantize(self, arr: jax.Array) -> jax.Array:
        """Quantize a single array and wrap it in the straight-through estimator."""
        if not self.mapping_from_inverse_permittivities:
            # plain rounding, limited to the valid material indices
            indices = jnp.clip(jnp.round(arr), 0, len(self._materials) - 1)
        else:
            # index of the nearest allowed inverse permittivity per element
            allowed_inv_perms = 1 / jnp.asarray(compute_allowed_permittivities(self._materials))
            indices = jnp.argmin(jnp.abs(arr[..., None] - allowed_inv_perms), axis=-1)
        return straight_through_estimator(arr, indices)

fdtdx.PillarDiscretization

Bases: ParameterTransformation

Constraint module for mapping pillar structures to allowed configurations.

Maps arbitrary pillar structures to the nearest allowed configurations based on material constraints and geometry requirements. Ensures structures meet fabrication rules like single polymer columns and no trapped air holes.

Attributes:

Name Type Description
axis int

Axis along which to enforce pillar constraints (0=x, 1=y, 2=z).

single_polymer_columns bool

If True, restrict to single polymer columns.

distance_metric Literal['euclidean', 'permittivity_differences_plus_average_permittivity']

Method to compute distances between material distributions: - "euclidean": Standard Euclidean distance between permittivity values - "permittivity_differences_plus_average_permittivity": Weighted combination of permittivity differences and average permittivity values, optimized for material distribution comparisons

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class PillarDiscretization(ParameterTransformation):
    """Constraint module that snaps pillar structures onto allowed configurations.

    Arbitrary pillar structures are mapped to the nearest allowed configuration,
    based on material constraints and geometry requirements. This enforces
    fabrication rules such as single polymer columns and no trapped air holes.

    Attributes:
        axis: Axis along which to enforce pillar constraints (0=x, 1=y, 2=z).
        single_polymer_columns: If True, restrict to single polymer columns.
        distance_metric: Method to compute distances between material distributions:
            - "euclidean": Standard Euclidean distance between permittivity values
            - "permittivity_differences_plus_average_permittivity": Weighted combination
              of permittivity differences and average permittivity values, optimized
              for material distribution comparisons
        background_material: Name of the material whose index is passed as
            ``fill_holes_with_index`` when the allowed columns are computed. If None,
            the name returned by ``get_background_material_name`` is used.
    """

    axis: int = frozen_field()
    single_polymer_columns: bool = frozen_field()
    distance_metric: Literal["euclidean", "permittivity_differences_plus_average_permittivity"] = frozen_field(
        default="permittivity_differences_plus_average_permittivity",
    )
    background_material: str | None = frozen_field(default=None)
    # Table of allowed index columns; filled in by init_module.
    _allowed_indices: jax.Array = frozen_private_field()

    _check_single_array: bool = frozen_private_field(default=True)
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Return ``output_shape`` unchanged: this transform keeps array shapes."""
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Report BINARY output for two materials and DISCRETE for more.

        Raises:
            Exception: If fewer than two materials are configured.
        """
        material_count = len(self._materials)
        if material_count <= 1:
            raise Exception(f"Invalid materials (need two or more): {self._materials}")
        kind = ParameterType.BINARY if material_count == 2 else ParameterType.DISCRETE
        return dict.fromkeys(input_type, kind)

    def init_module(
        self: Self,
        config: SimulationConfig,
        materials: dict[str, Material],
        matrix_voxel_grid_shape: tuple[int, int, int],
        single_voxel_size: tuple[float, float, float],
        output_shape: dict[str, tuple[int, ...]],
    ) -> Self:
        """Run the parent initialisation and precompute the allowed index columns.

        Returns:
            A copy of this module with ``_allowed_indices`` set.
        """
        self = super().init_module(
            config=config,
            materials=materials,
            matrix_voxel_grid_shape=matrix_voxel_grid_shape,
            single_voxel_size=single_voxel_size,
            output_shape=output_shape,
        )

        # An explicitly configured background material wins over the derived one.
        background_name = (
            get_background_material_name(self._materials)
            if self.background_material is None
            else self.background_material
        )
        background_idx = compute_ordered_names(self._materials).index(background_name)

        # One candidate column per allowed stacking of material indices along `axis`.
        allowed_columns = compute_allowed_indices(
            num_layers=matrix_voxel_grid_shape[self.axis],
            indices=list(range(len(materials))),
            fill_holes_with_index=[background_idx],
            single_polymer_columns=self.single_polymer_columns,
        )
        self = self.aset("_allowed_indices", allowed_columns)
        logger.info(f"{allowed_columns=}")
        logger.info(f"{allowed_columns.shape=}")
        return self

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Replace the single parameter array by its nearest allowed pillar layout.

        Args:
            params: Mapping holding the latent array; only the first key is used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the selected material indices.

        Raises:
            Exception: If ``axis`` is not 0, 1 or 2.
        """
        del kwargs

        single_key = list(params)[0]
        latent = params[single_key]

        inverse_perms = 1 / jnp.asarray(compute_allowed_permittivities(self._materials))
        column_choice = nearest_index(
            values=latent,
            allowed_values=inverse_perms,
            axis=self.axis,
            distance_metric=self.distance_metric,
            allowed_indices=self._allowed_indices,
            return_distances=False,
        )
        discrete = self._allowed_indices[column_choice]

        # Axis 2 needs no transposition; the other two axes are permuted back.
        # NOTE(review): presumably the looked-up columns end up on the last axis.
        axis_permutations = {0: (2, 0, 1), 1: (0, 2, 1)}
        if self.axis in axis_permutations:
            discrete = jnp.transpose(discrete, axes=axis_permutations[self.axis])
        elif self.axis != 2:
            raise Exception(f"invalid axis: {self.axis}")

        return {single_key: straight_through_estimator(latent, discrete)}

fdtdx.BrushConstraint2D

Bases: ParameterTransformation

Applies 2D brush-based constraints to ensure minimum feature sizes.

Implements the brush-based constraint method described in: https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313

This ensures minimum feature sizes and connectivity in 2D designs by using morphological operations with a brush kernel.

Attributes:

Name Type Description
brush Array

JAX array defining the brush kernel for morphological operations.

axis int

Axis along which to apply the 2D constraint (perpendicular plane).

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class BrushConstraint2D(ParameterTransformation):
    """Applies 2D brush-based constraints to ensure minimum feature sizes.

    Implements the brush-based constraint method described in:
    https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313

    This ensures minimum feature sizes and connectivity in 2D designs by using
    morphological operations with a brush kernel.

    Attributes:
        brush: JAX array defining the brush kernel for morphological operations.
        axis: Axis along which to apply the 2D constraint (perpendicular plane).
            The parameter array must have size 1 along this axis.
        background_material: Name of the background material. If None, the name
            returned by ``get_background_material_name`` is used.
    """

    brush: jax.Array = frozen_field()
    axis: int = frozen_field()
    background_material: str | None = frozen_field(default=None)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.CONTINUOUS
    )
    _check_single_array: bool = frozen_private_field(default=True)
    _all_arrays_2d: bool = frozen_private_field(default=True)

    def _get_input_shape_impl(
        self,
        output_shape: dict[str, tuple[int, ...]],
    ) -> dict[str, tuple[int, ...]]:
        """Return ``output_shape`` unchanged: this transform keeps array shapes."""
        return output_shape

    def _get_output_type_impl(
        self,
        input_type: dict[str, ParameterType],
    ) -> dict[str, ParameterType]:
        """Report BINARY output for every input key.

        Raises:
            Exception: If the number of configured materials is not exactly two.
        """
        if len(self._materials) != 2:
            raise Exception(
                f"BrushConstraint2D currently only implemented for exactly two materials, but got {self._materials}"
            )
        return {k: ParameterType.BINARY for k in input_type.keys()}

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Generate a brush-feasible design from the single latent array.

        Args:
            params: Mapping holding the latent array; only the first key is used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the generated design, with the size-1 axis restored.

        Raises:
            Exception: If the array does not have size 1 along ``axis``.
        """
        del kwargs

        single_key = list(params.keys())[0]
        param_arr = params[single_key]
        s = param_arr.shape
        if s[self.axis] != 1:
            raise Exception(f"BrushConstraint2D Generator needs array size 1 in axis, but got {s=}")
        # Drop the size-1 axis so that the generator works on a 2D array.
        arr_2d = jnp.take(
            param_arr,
            jnp.asarray(0),
            axis=self.axis,
        )

        if self.background_material is None:
            background_name = get_background_material_name(self._materials)
        else:
            background_name = self.background_material

        ordered_name_list = compute_ordered_names(self._materials)
        background_idx = ordered_name_list.index(background_name)
        # When the background is not at index 0, flip the sign of the latent values
        # before generating and invert the generated mask afterwards.
        if background_idx != 0:
            arr_2d = -arr_2d
        cur_result = self._generator(arr_2d)
        if background_idx != 0:
            cur_result = 1 - cur_result

        # Restore the axis that was removed above.
        cur_result = jnp.expand_dims(cur_result, axis=self.axis)
        result = straight_through_estimator(param_arr, cur_result)
        return {single_key: result}

    def _generator(
        self,
        arr: jax.Array,
    ) -> jax.Array:
        """Iteratively place solid and void brush touches guided by ``arr``.

        Starts with no touches and loops until every pixel is covered by the
        dilation of either the solid or the void touches.
        NOTE(review): ``dilate_jax`` is presumably a binary morphological dilation
        with the given kernel; confirm against its definition.

        Args:
            arr: 2D latent array. Larger values favour solid touches and smaller
                values favour void touches (void candidates are scored with ``-arr``).

        Returns:
            The solid touches dilated with the brush.
        """
        touches_void = jnp.zeros_like(arr, dtype=jnp.bool)
        touches_solid = jnp.zeros_like(touches_void)

        def cond_fn(arrs):
            """Continue while some pixel is neither existing solid nor existing void."""
            touch_v, touch_s = arrs[0], arrs[1]
            pixel_existing_solid = dilate_jax(touch_s, self.brush)
            pixel_existing_void = dilate_jax(touch_v, self.brush)
            return ~jnp.all(pixel_existing_solid | pixel_existing_void)

        def body_fn(sv_arrs: tuple[jax.Array, jax.Array]):
            """Perform one update of the (void, solid) touch arrays."""
            # see Algorithm 1 in paper
            touch_v, touch_s = sv_arrs[0], sv_arrs[1]
            # compute touches and pixel arrays
            # existing pixels: covered by the brush at an already placed touch
            pixel_existing_solid = dilate_jax(touch_s, self.brush)
            pixel_existing_void = dilate_jax(touch_v, self.brush)
            # impossible touches: the brush there would reach pixels of the other phase
            touch_impossible_solid = dilate_jax(pixel_existing_void, self.brush)
            touch_impossible_void = dilate_jax(pixel_existing_solid, self.brush)
            # valid touches: not impossible and not placed yet
            touch_valid_solid = ~touch_impossible_solid & ~touch_s
            touch_valid_void = ~touch_impossible_void & ~touch_v
            # possible pixels: reachable from placed or still valid touches
            pixel_possible_solid = dilate_jax(touch_s | touch_valid_solid, self.brush)
            pixel_possible_void = dilate_jax(touch_v | touch_valid_void, self.brush)
            # required pixels: not yet covered by this phase and unreachable for the other
            pixel_required_solid = ~pixel_existing_solid & ~pixel_possible_void
            pixel_required_void = ~pixel_existing_void & ~pixel_possible_solid
            # resolving touches: valid touches whose brush covers a required pixel
            touch_resolving_solid = dilate_jax(pixel_required_solid, self.brush) & touch_valid_solid
            touch_resolving_void = dilate_jax(pixel_required_void, self.brush) & touch_valid_void
            # free touches: valid touches whose brush reaches no pixel that is or could become the other phase
            touch_free_solid = ~dilate_jax(pixel_possible_void | pixel_existing_void, self.brush) & touch_valid_solid
            touch_free_void = ~dilate_jax(pixel_possible_solid | pixel_existing_solid, self.brush) & touch_valid_void

            # case 1
            def select_all_free_touches():
                """Add every free touch of both phases at once."""
                new_v = touch_v | touch_free_void
                new_s = touch_s | touch_free_solid
                return new_v, new_s

            # case 2
            def select_best_resolving_touch():
                """Add the single best-scoring resolving touch."""
                # Non-candidates get -inf so that they can never win the argmax.
                values_solid = jnp.where(touch_resolving_solid, arr, -jnp.inf)
                values_void = jnp.where(touch_resolving_void, -arr, -jnp.inf)

                def select_void():
                    max_idx = jnp.argmax(values_void)
                    # argmax returns a flat index, so set it on the flattened array
                    new_v = touch_v.flatten().at[max_idx].set(True).reshape(touch_s.shape)
                    return new_v, touch_s

                def select_solid():
                    max_idx = jnp.argmax(values_solid)
                    new_s = touch_s.flatten().at[max_idx].set(True).reshape(touch_v.shape)
                    return touch_v, new_s

                # Solid wins only on a strictly larger score; ties go to void.
                return jax.lax.cond(
                    jnp.max(values_solid) > jnp.max(values_void),
                    select_solid,
                    select_void,
                )

            # case 3
            def select_best_valid_touch():
                """Add the single best-scoring valid touch."""
                values_solid = jnp.where(touch_valid_solid, arr, -jnp.inf)
                values_void = jnp.where(touch_valid_void, -arr, -jnp.inf)

                def select_void():
                    max_idx = jnp.argmax(values_void)
                    new_v = touch_v.flatten().at[max_idx].set(True).reshape(touch_s.shape)
                    return new_v, touch_s

                def select_solid():
                    max_idx = jnp.argmax(values_solid)
                    new_s = touch_s.flatten().at[max_idx].set(True).reshape(touch_v.shape)
                    return touch_v, new_s

                # Solid wins only on a strictly larger score; ties go to void.
                return jax.lax.cond(
                    jnp.max(values_solid) > jnp.max(values_void),
                    select_solid,
                    select_void,
                )

            # case 2 and 3
            def case_2_and_3_function():
                """Use case 2 if any resolving touch exists, otherwise case 3."""
                resolving_exists = jnp.any(touch_resolving_solid | touch_resolving_void)

                return jax.lax.cond(
                    resolving_exists,
                    select_best_resolving_touch,
                    select_best_valid_touch,
                )

            # Case 1 takes priority whenever at least one free touch exists.
            free_touches_exist = jnp.any(touch_free_solid | touch_free_void)
            new_v, new_s = jax.lax.cond(
                free_touches_exist,
                select_all_free_touches,
                case_2_and_3_function,
            )
            return new_v, new_s

        # Loop state is the tuple (void touches, solid touches).
        arrs = (touches_void, touches_solid)

        res_arrs = eqxi.while_loop(
            cond_fun=cond_fn,
            body_fun=body_fn,
            init_val=arrs,
            kind="lax",
        )
        # The design is the area covered by the brush at the final solid touches.
        pixel_existing_solid = dilate_jax(res_arrs[1], self.brush)
        return pixel_existing_solid

fdtdx.circular_brush(diameter, size=None)

Creates a circular binary mask/brush for morphological operations.

Parameters:

Name Type Description Default
diameter float

Diameter of the circle in grid units.

required
size int | None

Optional size of the output array. If None, uses ceil(diameter) rounded up to next odd number.

None

Returns:

Type Description
Array

Binary JAX array containing a circular mask where True indicates points within the circle diameter.

Source code in src/fdtdx/objects/device/parameters/discretization.py
def circular_brush(
    diameter: float,
    size: int | None = None,
) -> jax.Array:
    """Creates a circular binary mask/brush for morphological operations.

    Args:
        diameter: Diameter of the circle in grid units.
        size: Optional size of the output array. If None, uses ceil(diameter) rounded
            up to next odd number.

    Returns:
        Binary JAX array containing a circular mask where True indicates points
        within the circle diameter.
    """
    if size is None:
        s = math.ceil(diameter)
        if s % 2 == 0:
            s += 1
        size = s
    xy = jnp.stack(jnp.meshgrid(*map(jnp.arange, (size, size)), indexing="xy"), axis=-1) - jnp.asarray((size / 2) - 0.5)
    euc_dist = jnp.sqrt((xy**2).sum(axis=-1))
    # the less EQUAL here is important, because otherwise design may be infeasible due to discretization errors
    mask = euc_dist <= (diameter / 2)
    return mask

Discrete PostProcessing

fdtdx.BinaryMedianFilterModule

Bases: SameShapeTypeParameterTransform

Performs 3D binary median filtering on the design.

Applies a 3D median filter to smooth and clean up binary material distributions. This helps remove small features and noise while preserving larger structures.

Attributes:

Name Type Description
padding_cfg PaddingConfig

Configuration for padding behavior at boundaries.

kernel_sizes tuple[int, int, int]

3-tuple of kernel sizes for each dimension.

num_repeats int

Number of times to apply the filter consecutively.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class BinaryMedianFilterModule(SameShapeTypeParameterTransform):
    """Smooths a binary design with a repeated 3D median filter.

    The 3D median filter cleans up binary material distributions, removing small
    features and noise while preserving larger structures.

    Attributes:
        padding_cfg: Configuration for padding behavior at boundaries.
        kernel_sizes: 3-tuple of kernel sizes for each dimension.
        num_repeats: Number of times to apply the filter consecutively.
    """

    padding_cfg: PaddingConfig = frozen_field()
    kernel_sizes: tuple[int, int, int] = frozen_field()
    num_repeats: int = frozen_field(default=1)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.BINARY,
    )
    _check_single_array: bool = frozen_private_field(default=True)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Filter the single parameter array ``num_repeats`` times.

        Args:
            params: Mapping holding the array to filter; only the first key is used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the filtered array.
        """
        del kwargs

        def _filter_once(arr: jax.Array) -> jax.Array:
            return binary_median_filter(
                arr_3d=arr,
                kernel_sizes=self.kernel_sizes,
                padding_cfg=self.padding_cfg,
            )

        single_key = list(params)[0]
        unfiltered = params[single_key]

        filtered = unfiltered
        for _ in range(self.num_repeats):
            filtered = _filter_once(filtered)

        return {single_key: straight_through_estimator(unfiltered, filtered)}

fdtdx.ConnectHolesAndStructures

Bases: SameShapeTypeParameterTransform

Connects floating polymer regions and ensures air holes connect to outside.

This constraint module ensures physical realizability of designs by (1) either connecting floating polymer regions to the substrate or removing them, and (2) ensuring all air holes are connected to the outside (no trapped air).

The bottom (lower z) is treated as the substrate reference.

Attributes:

Name Type Description
fill_material str | None

Name of material to use for filling gaps when connecting regions. Required when working with more than 2 materials.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class ConnectHolesAndStructures(SameShapeTypeParameterTransform):
    """Connects floating polymer regions and ensures air holes connect to outside.

    This constraint module ensures physical realizability of designs by:
    1. Either connecting floating polymer regions to the substrate or removing them
    2. Ensuring all air holes are connected to the outside (no trapped air)

    The bottom (lower z) is treated as the substrate reference.

    Attributes:
        fill_material: Name of material to use for filling gaps when connecting regions.
            Required when working with more than 2 materials. With two materials and
            no explicit value, the non-background material is used.
        background_material: Name of the material treated as background (air). If
            None, the name returned by ``get_background_material_name`` is used.
    """

    fill_material: str | None = frozen_field(default=None)
    background_material: str | None = frozen_field(default=None)
    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=(ParameterType.DISCRETE, ParameterType.BINARY),
    )
    _check_single_array: bool = frozen_private_field(default=True)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Repair the single material-index array so that it is connected.

        Args:
            params: Mapping holding the material-index array; only the first key is
                used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the repaired index array.

        Raises:
            Exception: If more than two materials are configured and
                ``fill_material`` is None.
        """
        del kwargs
        if len(self._materials) > 2 and self.fill_material is None:
            raise Exception(
                "ConnectHolesAndStructures: Need to specify fill_material when working with more than two materials"
            )
        if self.background_material is None:
            background_name = get_background_material_name(self._materials)
        else:
            background_name = self.background_material
        ordered_name_list = compute_ordered_names(self._materials)
        background_idx = ordered_name_list.index(background_name)

        single_key = list(params.keys())[0]
        param_arr = params[single_key]
        # Every voxel that is not background counts as material.
        is_material_matrix = param_arr != background_idx
        feasible_material_matrix = connect_holes_and_structures(is_material_matrix)

        # Background wherever the repaired design has no material. The -1 is only a
        # placeholder: both where-calls below together overwrite every material voxel.
        result = jnp.where(
            feasible_material_matrix,
            -1,
            background_idx,
        )
        # material where previously was material: keep the original index
        result = jnp.where(feasible_material_matrix & is_material_matrix, param_arr, result)

        # material, where previously was background material (air)
        fill_name = self.fill_material
        if fill_name is None:
            # Only reachable with at most two materials: use the non-background one.
            fill_name = ordered_name_list[1 - background_idx]
        fill_idx = ordered_name_list.index(fill_name)
        result = jnp.where(
            feasible_material_matrix & ~is_material_matrix,
            fill_idx,
            result,
        )
        result = straight_through_estimator(param_arr, result)
        return {single_key: result}

fdtdx.RemoveFloatingMaterial

Bases: SameShapeTypeParameterTransform

Finds all material that floats in the air and sets their permittivity to air.

This constraint module identifies regions of material that are not connected to any substrate or boundary and converts them to air. This helps ensure physically realizable designs by eliminating floating/disconnected material regions.

The module only works with binary material systems (2 permittivities) where one material represents air.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class RemoveFloatingMaterial(SameShapeTypeParameterTransform):
    """Replaces material that floats in the air by the background material.

    This constraint module identifies regions of material that are not connected to any
    substrate or boundary and converts them to air. This helps ensure physically
    realizable designs by eliminating floating/disconnected material regions.

    The module only works with binary material systems (2 permittivities) where one
    material represents air: remaining material is written as index
    ``1 - background_idx``, which is only a valid index for two materials.

    Attributes:
        background_material: Name of the material treated as background (air). If
            None, the name returned by ``get_background_material_name`` is used.
    """

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=(ParameterType.DISCRETE, ParameterType.BINARY),
    )
    _check_single_array: bool = frozen_private_field(default=True)

    background_material: str | None = frozen_field(default=None)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Remove floating material from the single material-index array.

        Args:
            params: Mapping holding the material-index array; only the first key is
                used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the cleaned index array.
        """
        del kwargs
        background_name = (
            get_background_material_name(self._materials)
            if self.background_material is None
            else self.background_material
        )
        background_idx = compute_ordered_names(self._materials).index(background_name)

        single_key = list(params)[0]
        indices = params[single_key]

        # Every voxel that is not background counts as material.
        material_mask = indices != background_idx
        kept_mask = remove_floating_polymer(material_mask)

        # Kept material gets the non-background index, everything else the background index.
        material_part = (1 - background_idx) * kept_mask
        background_part = background_idx * ~kept_mask
        cleaned = material_part + background_part

        return {single_key: straight_through_estimator(indices, cleaned)}

fdtdx.BinaryMedianFilterModule

Bases: SameShapeTypeParameterTransform

Performs 3D binary median filtering on the design.

Applies a 3D median filter to smooth and clean up binary material distributions. This helps remove small features and noise while preserving larger structures.

Attributes:

Name Type Description
padding_cfg PaddingConfig

Configuration for padding behavior at boundaries.

kernel_sizes tuple[int, int, int]

3-tuple of kernel sizes for each dimension.

num_repeats int

Number of times to apply the filter consecutively.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class BinaryMedianFilterModule(SameShapeTypeParameterTransform):
    """Smooths a binary design with a repeated 3D median filter.

    The 3D median filter cleans up binary material distributions, removing small
    features and noise while preserving larger structures.

    Attributes:
        padding_cfg: Configuration for padding behavior at boundaries.
        kernel_sizes: 3-tuple of kernel sizes for each dimension.
        num_repeats: Number of times to apply the filter consecutively.
    """

    padding_cfg: PaddingConfig = frozen_field()
    kernel_sizes: tuple[int, int, int] = frozen_field()
    num_repeats: int = frozen_field(default=1)

    _fixed_input_type: ParameterType | Sequence[ParameterType] | None = frozen_private_field(
        default=ParameterType.BINARY,
    )
    _check_single_array: bool = frozen_private_field(default=True)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Filter the single parameter array ``num_repeats`` times.

        Args:
            params: Mapping holding the array to filter; only the first key is used.
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Single-entry mapping from that key to the straight-through estimate of
            the filtered array.
        """
        del kwargs

        def _filter_once(arr: jax.Array) -> jax.Array:
            return binary_median_filter(
                arr_3d=arr,
                kernel_sizes=self.kernel_sizes,
                padding_cfg=self.padding_cfg,
            )

        single_key = list(params)[0]
        unfiltered = params[single_key]

        filtered = unfiltered
        for _ in range(self.num_repeats):
            filtered = _filter_once(filtered)

        return {single_key: straight_through_estimator(unfiltered, filtered)}

Symmetries

fdtdx.DiagonalSymmetry2D

Bases: SameShapeTypeParameterTransform

Enforce symmetries by effectively halving the parameter space. The symmetry is enforced by mirroring the image across a diagonal (a transpose) and taking the mean of the original and the mirrored image. Attributes: min_min_to_max_max: if true, the symmetry axis runs from (x_min, y_min) to (x_max, y_max). If false, the other diagonal is used.

Source code in src/fdtdx/objects/device/parameters/symmetries.py
@extended_autoinit
class DiagonalSymmetry2D(SameShapeTypeParameterTransform):
    """
    Enforce a diagonal mirror symmetry, effectively halving the parameter space. Each array is averaged with
    its mirror image across the chosen diagonal.
    Attributes:
        min_min_to_max_max: if true, the symmetry axis runs from (x_min, y_min) to (x_max, y_max). If false, the
            other diagonal is used.
    """

    min_min_to_max_max: bool = frozen_field()

    _all_arrays_2d: bool = frozen_private_field(default=True)

    def __call__(
        self,
        params: dict[str, jax.Array],
        **kwargs,
    ) -> dict[str, jax.Array]:
        """Symmetrize every array in ``params``.

        Args:
            params: Mapping from name to array; each array needs at least one axis
                of size 1 (the first such axis is squeezed away and restored).
            **kwargs: Accepted for interface compatibility and discarded.

        Returns:
            Mapping with the same keys, holding the mean of each array and its
            mirror image, in the original shape.
        """
        del kwargs
        symmetric = {}
        for name, array in params.items():
            # Work on the 2D plane obtained by dropping the first size-1 axis.
            flat_axis = array.shape.index(1)
            plane = array.squeeze(flat_axis)

            # A plain transpose mirrors across the main diagonal; flipping both axes
            # first mirrors across the other diagonal instead.
            mirrored = plane.T if self.min_min_to_max_max else plane[::-1, ::-1].T

            # Average, then put the dropped axis back.
            symmetric[name] = jnp.expand_dims((plane + mirrored) / 2, flat_axis)
        return symmetric