Skip to content

Devices

In FDTDX, devices are objects whose shape can be optimized. A device has a corresponding set of latent parameters, which are mapped to produce the current shape of the device.

fdtdx.objects.device.DiscreteDevice

Bases: BaseDevice

A device with discrete material states.

This class represents a simulation object whose permittivity distribution can be optimized through gradient-based methods, with discrete transitions between materials. The permittivity values are controlled by parameters that are mapped through constraints to produce the final device structure.

Attributes:

Name Type Description
name

Optional name identifier for the device

material

Dictionary of the materials the device can be composed of

parameter_mapping

Maps the latent optimization parameters to discrete material indices

dtype

Data type for device parameters, defaults to float32

color

RGB color tuple for visualization, defaults to pink

Source code in src/fdtdx/objects/device/device.py
@extended_autoinit
class DiscreteDevice(BaseDevice):
    """A device with discrete material states.

    This class represents a simulation object whose material distribution can be
    optimized through gradient-based methods, with discrete transitions between
    materials. Latent parameters are passed through ``parameter_mapping``, which
    (once initialized in ``place_on_grid``) produces an int32 array of shape
    ``matrix_voxel_grid_shape`` selecting a material per voxel.

    Attributes:
        material: Mapping from material name to ``Material`` describing the
            materials the device may be composed of.
        parameter_mapping: Maps the latent optimization parameters to discrete
            material indices. Its sub-modules are initialized in ``place_on_grid``.

    NOTE(review): further attributes (e.g. name, dtype, color) are not defined in
    this class; they presumably come from ``BaseDevice`` — confirm their defaults there.
    """

    material: dict[str, Material] = frozen_field(kind="KW_ONLY")  # type: ignore
    parameter_mapping: DiscreteParameterMapping = frozen_field(kind="KW_ONLY")  # type:ignore

    def place_on_grid(
        self: Self,
        grid_slice_tuple: SliceTuple3D,
        config: SimulationConfig,
        key: jax.Array,
    ) -> Self:
        """Place the device on the grid and initialize its parameter mapping.

        Args:
            grid_slice_tuple: Grid slices occupied by the device.
            config: Simulation configuration.
            key: JAX random key, forwarded to the parent ``place_on_grid``.

        Returns:
            The placed device with ``parameter_mapping`` replaced by its
            initialized version.
        """
        placed = super().place_on_grid(
            grid_slice_tuple=grid_slice_tuple,
            config=config,
            key=key,
        )
        # The mapping's output shape is only known once the device is placed,
        # so its modules are initialized here rather than at construction time.
        output_shape_dtype = jax.ShapeDtypeStruct(
            shape=placed.matrix_voxel_grid_shape,
            dtype=jnp.int32,
        )
        mapping = placed.parameter_mapping.init_modules(
            config=config,
            material=placed.material,
            output_shape_dtype=output_shape_dtype,
        )
        return placed.aset("parameter_mapping", mapping)

Parameter Mapping

fdtdx.objects.device.DiscreteParameterMapping

Bases: LatentParameterMapping

Source code in src/fdtdx/objects/device/parameters/mapping.py
@extended_autoinit
class DiscreteParameterMapping(LatentParameterMapping):
    """Latent parameter mapping that ends in a discretization step.

    The forward pass runs the inherited latent mapping, discretizes its output
    with ``discretization`` and finally applies every entry of
    ``post_transforms`` in sequence.

    Attributes:
        discretization: Module that turns the latent array into discrete values;
            defaults to ``ClosestIndex()``.
        post_transforms: Discrete transformations applied, in order, after the
            discretization.
    """

    discretization: Discretization = frozen_field(kind="KW_ONLY", default=ClosestIndex())
    post_transforms: Sequence[DiscreteTransformation] = frozen_field(default=tuple([]), kind="KW_ONLY")

    def init_modules(
        self: Self,
        config: SimulationConfig,
        material: dict[str, Material],
        output_shape_dtype: jax.ShapeDtypeStruct,
    ) -> Self:
        """Initialize post transforms, discretization and latent transforms.

        Args:
            config: Simulation configuration passed to every sub-module.
            material: Material dictionary passed to every sub-module.
            output_shape_dtype: Shape/dtype the discretization must produce.

        Returns:
            A mapping whose ``latent_transforms``, ``discretization`` and
            ``post_transforms`` are the initialized versions.
        """
        initialized_post = [
            post.init_module(config=config, material=material)
            for post in self.post_transforms
        ]
        initialized_discretization = self.discretization.init_module(
            config=config,
            material=material,
            output_shape_dtype=output_shape_dtype,
        )

        # Chain the required input shape/dtype of each initialized module into the
        # next one, starting from the discretization's input.
        # NOTE(review): the chain starts with the FIRST list entry; confirm this
        # matches the application order in LatentParameterMapping.__call__.
        required_shape_dtypes = initialized_discretization._input_shape_dtypes
        initialized_latent = []
        for latent_module in self.latent_transforms:
            ready_module = latent_module.init_module(
                config=config,
                material=material,
                output_shape_dtypes=required_shape_dtypes,
            )
            initialized_latent.append(ready_module)
            required_shape_dtypes = ready_module._input_shape_dtypes

        updated = super().init_modules(
            config=config,
            material=material,
            output_shape_dtypes=required_shape_dtypes,
        )
        updated = updated.aset("latent_transforms", initialized_latent)
        updated = updated.aset("discretization", initialized_discretization)
        return updated.aset("post_transforms", initialized_post)

    def __call__(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> jax.Array:
        """Map latent parameters to discrete values.

        Args:
            input_params: Latent parameters (single array or dict of arrays).

        Returns:
            Output of the last post transform, or of the discretization if there
            are no post transforms.
        """
        values = self.discretization(super().__call__(input_params=input_params))
        for post in self.post_transforms:
            values = post(values)
        return values

Transformation of latent parameters

fdtdx.objects.device.StandardToInversePermittivityRange

Bases: SameShapeDtypeLatentTransform

Maps standard [0,1] range to inverse permittivity range.

Linearly maps values from [0,1] to the range between minimum and maximum inverse permittivity values allowed by the material configuration.

Source code in src/fdtdx/objects/device/parameters/latent.py
@extended_autoinit
class StandardToInversePermittivityRange(SameShapeDtypeLatentTransform):
    """Maps standard [0,1] range to inverse permittivity range.

    Linearly maps values from [0,1] to the range between minimum and maximum
    inverse permittivity values allowed by the material configuration.
    """

    def transform(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> dict[str, jax.Array] | jax.Array:
        """Affinely rescale ``input_params`` from [0, 1] to [min, max] inverse permittivity.

        Args:
            input_params: A single array or a dict of arrays; every value is
                rescaled with the same affine map.

        Returns:
            The rescaled array, or a dict with the same keys holding the
            rescaled arrays.
        """
        # determine minimum and maximum allowed inverse permittivity
        max_inv_perm, min_inv_perm = -math.inf, math.inf
        if isinstance(self._material, dict):
            for v in self._material.values():
                p = 1 / v.permittivity
                if p > max_inv_perm:
                    max_inv_perm = p
                if p < min_inv_perm:
                    min_inv_perm = p
        elif isinstance(self._material, ContinuousMaterialRange):
            # The bounds must be INVERSE permittivities, consistent with the dict
            # branch above and with the class contract.
            start_inv_perm = 1 / self._material.start_material.permittivity
            end_inv_perm = 1 / self._material.end_material.permittivity
            max_inv_perm = max(start_inv_perm, end_inv_perm)
            min_inv_perm = min(start_inv_perm, end_inv_perm)
        # NOTE(review): any other material type leaves the bounds at +/-inf,
        # which produces non-finite outputs — consider raising instead.

        # transform
        span = max_inv_perm - min_inv_perm
        if isinstance(input_params, dict):
            return {k: v * span + min_inv_perm for k, v in input_params.items()}
        return input_params * span + min_inv_perm

fdtdx.objects.device.StandardToPlusOneMinusOneRange

Bases: StandardToCustomRange

Maps standard [0,1] range to [-1,1] range.

Special case of StandardToCustomRange that maps to [-1,1] range. Used for symmetric value ranges around zero.

Attributes:

Name Type Description
min_value float

Fixed to -1

max_value float

Fixed to 1

Source code in src/fdtdx/objects/device/parameters/latent.py
@extended_autoinit
class StandardToPlusOneMinusOneRange(StandardToCustomRange):
    """Maps the standard [0, 1] interval onto the symmetric interval [-1, 1].

    A fixed specialization of ``StandardToCustomRange`` for value ranges that
    are symmetric around zero. Both bounds are declared with ``init=False``.

    Attributes:
        min_value: Lower bound of the target range, always -1.
        max_value: Upper bound of the target range, always 1.
    """

    min_value: float = frozen_field(init=False, default=-1)
    max_value: float = frozen_field(init=False, default=1)

fdtdx.objects.device.StandardToCustomRange

Bases: SameShapeDtypeLatentTransform

Maps standard [0,1] range to custom range [min_value, max_value].

Linearly maps values from [0,1] to a custom range specified by min_value and max_value parameters.

Attributes:

Name Type Description
min_value float

Minimum value of target range

max_value float

Maximum value of target range

Source code in src/fdtdx/objects/device/parameters/latent.py
@extended_autoinit
class StandardToCustomRange(SameShapeDtypeLatentTransform):
    """Affinely maps the standard [0, 1] interval onto [min_value, max_value].

    Attributes:
        min_value: Value that an input of 0 is mapped to.
        max_value: Value that an input of 1 is mapped to.
    """

    min_value: float = frozen_field(default=0)
    max_value: float = frozen_field(default=1)

    def transform(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> dict[str, jax.Array] | jax.Array:
        """Rescale a single array, or every array in a dict, into the target range.

        Args:
            input_params: Array or dict of arrays with values nominally in [0, 1].

        Returns:
            Rescaled array, or a dict with the same keys and rescaled values.
        """
        lower = self.min_value
        width = self.max_value - self.min_value

        if not isinstance(input_params, dict):
            return input_params * width + lower
        return {name: values * width + lower for name, values in input_params.items()}

Discretizations

fdtdx.objects.device.ClosestIndex

Bases: Discretization

Maps continuous latent values to nearest allowed material indices.

For each input value, finds the index of the closest allowed inverse permittivity value. Uses straight-through gradient estimation to maintain differentiability.

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class ClosestIndex(Discretization):
    """Snaps each latent value to the index of the nearest allowed inverse permittivity.

    The discrete indices are combined with the continuous input through
    ``straight_through_estimator`` so the mapping stays usable for gradients.
    """

    def __call__(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> jax.Array:
        """Return, per element, the index of the closest allowed inverse permittivity.

        Raises:
            Exception: If ``input_params`` is not a single ``jax.Array``.
        """
        if isinstance(input_params, jax.Array):
            latent = input_params
        else:
            raise Exception("Closest Index cannot be used with latent parameters that contain multiple entries")
        candidate_inv_perms = 1 / jnp.asarray(compute_allowed_permittivities(self._material))
        # distance of every latent value to every candidate; candidates on the trailing axis
        distances = jnp.abs(jnp.expand_dims(latent, axis=-1) - candidate_inv_perms)
        nearest = jnp.argmin(distances, axis=-1)
        return straight_through_estimator(latent, nearest)

fdtdx.objects.device.PillarDiscretization

Bases: Discretization

Constraint module for mapping pillar structures to allowed configurations.

Maps arbitrary pillar structures to the nearest allowed configurations based on material constraints and geometry requirements. Ensures structures meet fabrication rules like single polymer columns and no trapped air holes.

Attributes:

Name Type Description
axis int

Axis along which to enforce pillar constraints (0=x, 1=y, 2=z).

single_polymer_columns bool

If True, restrict to single polymer columns.

distance_metric Literal['euclidean', 'permittivity_differences_plus_average_permittivity']

Method to compute distances between material distributions: - "euclidean": Standard Euclidean distance between permittivity values - "permittivity_differences_plus_average_permittivity": Weighted combination of permittivity differences and average permittivity values, optimized for material distribution comparisons

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class PillarDiscretization(Discretization):
    """Constraint module for mapping pillar structures to allowed configurations.

    Maps arbitrary pillar structures to the nearest allowed configurations based on
    material constraints and geometry requirements. Ensures structures meet fabrication
    rules like single polymer columns and no trapped air holes.

    Attributes:
        axis: Axis along which to enforce pillar constraints (0=x, 1=y, 2=z).
        single_polymer_columns: If True, restrict to single polymer columns.
        distance_metric: Method to compute distances between material distributions:
            - "euclidean": Standard Euclidean distance between permittivity values
            - "permittivity_differences_plus_average_permittivity": Weighted combination
              of permittivity differences and average permittivity values, optimized
              for material distribution comparisons
    """

    axis: int = frozen_field(kind="KW_ONLY")
    single_polymer_columns: bool = frozen_field(kind="KW_ONLY")

    distance_metric: Literal["euclidean", "permittivity_differences_plus_average_permittivity"] = frozen_field(
        default="permittivity_differences_plus_average_permittivity",
    )
    _allowed_indices: jax.Array = frozen_private_field()

    def init_module(
        self: Self,
        config: SimulationConfig,
        material: dict[str, Material],
        output_shape_dtype: jax.ShapeDtypeStruct,
    ) -> Self:
        """Initialize the module and precompute the allowed column configurations.

        Args:
            config: Simulation configuration.
            material: Material dictionary of the device.
            output_shape_dtype: Output shape/dtype; its extent along ``self.axis``
                is used as the number of layers per column.

        Returns:
            The initialized module with ``_allowed_indices`` set.
        """
        self = super().init_module(
            config=config,
            material=material,
            output_shape_dtype=output_shape_dtype,
        )
        air_name = get_air_name(self._material)
        ordered_name_list = compute_ordered_names(self._material)
        air_idx = ordered_name_list.index(air_name)

        # Enumerate the permitted material columns along `axis`
        # (fill_holes_with_index is set to the air index).
        allowed_columns = compute_allowed_indices(
            num_layers=output_shape_dtype.shape[self.axis],
            indices=list(range(len(material))),
            fill_holes_with_index=[air_idx],
            single_polymer_columns=self.single_polymer_columns,
        )
        self = self.aset("_allowed_indices", allowed_columns)
        logger.info(f"{allowed_columns=}")
        logger.info(f"{allowed_columns.shape=}")
        return self

    def __call__(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> jax.Array:
        """Replace every column of the latent array with its nearest allowed column.

        Args:
            input_params: Latent array (a dict of arrays is rejected).

        Returns:
            Material indices, combined with ``input_params`` via
            ``straight_through_estimator``.

        Raises:
            Exception: If ``input_params`` is not a single array or ``self.axis``
                is not 0, 1 or 2.
        """
        if not isinstance(input_params, jax.Array):
            raise Exception("PillarDiscretization cannot be used with latent parameters that contain multiple entries")

        allowed_inv_perms = 1 / jnp.asarray(compute_allowed_permittivities(self._material))

        nearest_allowed_index = nearest_index(
            values=input_params,
            allowed_values=allowed_inv_perms,
            axis=self.axis,
            distance_metric=self.distance_metric,
            allowed_indices=self._allowed_indices,
            return_distances=False,
        )
        result_index = self._allowed_indices[nearest_allowed_index]
        # NOTE(review): the gathered columns presumably have the pillar axis last;
        # these permutations restore the original axis order — confirm in nearest_index.
        if self.axis == 2:
            pass  # no transposition needed
        elif self.axis == 1:
            result_index = jnp.transpose(result_index, axes=(0, 2, 1))
        elif self.axis == 0:
            result_index = jnp.transpose(result_index, axes=(2, 0, 1))
        else:
            raise Exception(f"invalid axis: {self.axis}")
        result_index = straight_through_estimator(input_params, result_index)
        return result_index

fdtdx.objects.device.BrushConstraint2D

Bases: Discretization

Applies 2D brush-based constraints to ensure minimum feature sizes.

Implements the brush-based constraint method described in: https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313

This ensures minimum feature sizes and connectivity in 2D designs by using morphological operations with a brush kernel.

Attributes:

Name Type Description
brush Array

JAX array defining the brush kernel for morphological operations.

axis int

Axis along which to apply the 2D constraint (perpendicular plane).

Source code in src/fdtdx/objects/device/parameters/discretization.py
@extended_autoinit
class BrushConstraint2D(Discretization):
    """Applies 2D brush-based constraints to ensure minimum feature sizes.

    Implements the brush-based constraint method described in:
    https://pubs.acs.org/doi/10.1021/acsphotonics.2c00313

    This ensures minimum feature sizes and connectivity in 2D designs by using
    morphological operations with a brush kernel.

    Only two materials (one of them air) are supported, and the latent array must
    be rank 3 with size 1 along ``axis``.

    Attributes:
        brush: JAX array defining the brush kernel for morphological operations.
        axis: Axis along which to apply the 2D constraint (perpendicular plane).
    """

    brush: jax.Array = frozen_field()
    axis: int = frozen_field()

    def __call__(
        self,
        input_params: dict[str, jax.Array] | jax.Array,
    ) -> jax.Array:
        """Discretize a single-layer latent array into a brush-feasible binary design.

        Args:
            input_params: Latent array of rank 3 with size 1 along ``self.axis``.
                A dict of arrays is rejected.

        Returns:
            Array of 0/1 material indices with the same shape as ``input_params``,
            combined with the input via ``straight_through_estimator``.

        Raises:
            Exception: If the input is not a single array, more than two materials
                are configured, the input is not rank 3, or its size along
                ``self.axis`` is not 1.
        """
        if not isinstance(input_params, jax.Array):
            raise Exception("BrushConstraint2D cannot be used with latent parameters that contain multiple entries")
        if len(self._material) > 2:
            raise Exception("BrushConstraint2D currently only implemented for single material and air")
        s = input_params.shape
        if len(s) != 3:
            raise Exception(f"BrushConstraint2D Generator can only work with 2D-Arrays, got {s=}")
        if s[self.axis] != 1:
            raise Exception(f"BrushConstraint2D Generator needs array size 1 in axis, but got {s=}")
        # drop the singleton axis to obtain the 2D design plane
        arr_2d = jnp.take(
            input_params,
            jnp.asarray(0),
            axis=self.axis,
        )

        # generator "solid" pixels become 0 here, all other pixels 1
        cur_result = 1 - self._generator(arr_2d)

        air_name = get_air_name(self._material)
        ordered_name_list = compute_ordered_names(self._material)
        air_idx = ordered_name_list.index(air_name)
        # After this optional flip, generator "solid" pixels always carry air_idx.
        # NOTE(review): presumably because "solid" touches are chosen where the
        # latent value (inverse permittivity) is large, i.e. air-like — confirm.
        if air_idx != 0:
            cur_result = 1 - cur_result
        # restore the singleton axis removed above
        cur_result = jnp.expand_dims(cur_result, axis=self.axis)
        result = straight_through_estimator(input_params, cur_result)
        return result

    def _generator(
        self,
        arr: jax.Array,
    ) -> jax.Array:
        """Run the brush touch-selection loop (Algorithm 1 of the referenced paper).

        Starting from empty solid/void touch masks, each iteration either adds all
        "free" touches, or else a single best "resolving" touch, or else a single
        best "valid" touch. The loop ends once every pixel is covered by the
        brush dilation of either the solid or the void touches.

        Args:
            arr: 2D latent array; larger values favor solid touches, smaller
                values favor void touches.

        Returns:
            Mask of solid pixels (the final solid touches dilated by the brush).
        """
        touches_void = jnp.zeros_like(arr, dtype=jnp.bool)
        touches_solid = jnp.zeros_like(touches_void)

        # keep iterating while some pixel is covered by neither solid nor void
        def cond_fn(arrs):
            touch_v, touch_s = arrs[0], arrs[1]
            pixel_existing_solid = dilate_jax(touch_s, self.brush)
            pixel_existing_void = dilate_jax(touch_v, self.brush)
            return ~jnp.all(pixel_existing_solid | pixel_existing_void)

        def body_fn(sv_arrs: tuple[jax.Array, jax.Array]):
            # see Algorithm 1 in paper
            touch_v, touch_s = sv_arrs[0], sv_arrs[1]
            # compute touches and pixel arrays
            pixel_existing_solid = dilate_jax(touch_s, self.brush)
            pixel_existing_void = dilate_jax(touch_v, self.brush)
            # a touch of one kind is impossible where it would overlap existing pixels of the other kind
            touch_impossible_solid = dilate_jax(pixel_existing_void, self.brush)
            touch_impossible_void = dilate_jax(pixel_existing_solid, self.brush)
            touch_valid_solid = ~touch_impossible_solid & ~touch_s
            touch_valid_void = ~touch_impossible_void & ~touch_v
            pixel_possible_solid = dilate_jax(touch_s | touch_valid_solid, self.brush)
            pixel_possible_void = dilate_jax(touch_v | touch_valid_void, self.brush)
            # pixels that can only still be reached by one kind of touch
            pixel_required_solid = ~pixel_existing_solid & ~pixel_possible_void
            pixel_required_void = ~pixel_existing_void & ~pixel_possible_solid
            touch_resolving_solid = dilate_jax(pixel_required_solid, self.brush) & touch_valid_solid
            touch_resolving_void = dilate_jax(pixel_required_void, self.brush) & touch_valid_void
            # touches that cannot conflict with any existing or possible pixel of the other kind
            touch_free_solid = ~dilate_jax(pixel_possible_void | pixel_existing_void, self.brush) & touch_valid_solid
            touch_free_void = ~dilate_jax(pixel_possible_solid | pixel_existing_solid, self.brush) & touch_valid_void

            # case 1
            def select_all_free_touches():
                new_v = touch_v | touch_free_void
                new_s = touch_s | touch_free_solid
                return new_v, new_s

            # case 2: add the single resolving touch with the highest score
            # (solid scores are arr, void scores are -arr; non-candidates are -inf)
            def select_best_resolving_touch():
                values_solid = jnp.where(touch_resolving_solid, arr, -jnp.inf)
                values_void = jnp.where(touch_resolving_void, -arr, -jnp.inf)

                def select_void():
                    max_idx = jnp.argmax(values_void)
                    new_v = touch_v.flatten().at[max_idx].set(True).reshape(touch_s.shape)
                    return new_v, touch_s

                def select_solid():
                    max_idx = jnp.argmax(values_solid)
                    new_s = touch_s.flatten().at[max_idx].set(True).reshape(touch_v.shape)
                    return touch_v, new_s

                # strict ">" means ties are resolved in favor of void
                return jax.lax.cond(
                    jnp.max(values_solid) > jnp.max(values_void),
                    select_solid,
                    select_void,
                )

            # case 3: same as case 2, but over all valid touches
            def select_best_valid_touch():
                values_solid = jnp.where(touch_valid_solid, arr, -jnp.inf)
                values_void = jnp.where(touch_valid_void, -arr, -jnp.inf)

                def select_void():
                    max_idx = jnp.argmax(values_void)
                    new_v = touch_v.flatten().at[max_idx].set(True).reshape(touch_s.shape)
                    return new_v, touch_s

                def select_solid():
                    max_idx = jnp.argmax(values_solid)
                    new_s = touch_s.flatten().at[max_idx].set(True).reshape(touch_v.shape)
                    return touch_v, new_s

                return jax.lax.cond(
                    jnp.max(values_solid) > jnp.max(values_void),
                    select_solid,
                    select_void,
                )

            # case 2 and 3
            def case_2_and_3_function():
                resolving_exists = jnp.any(touch_resolving_solid | touch_resolving_void)

                return jax.lax.cond(
                    resolving_exists,
                    select_best_resolving_touch,
                    select_best_valid_touch,
                )

            # case 1 takes priority whenever any free touch exists
            free_touches_exist = jnp.any(touch_free_solid | touch_free_void)
            new_v, new_s = jax.lax.cond(
                free_touches_exist,
                select_all_free_touches,
                case_2_and_3_function,
            )
            return new_v, new_s

        arrs = (touches_void, touches_solid)

        res_arrs = eqxi.while_loop(
            cond_fun=cond_fn,
            body_fun=body_fn,
            init_val=arrs,
            kind="lax",
        )
        # res_arrs[1] holds the final solid touches
        pixel_existing_solid = dilate_jax(res_arrs[1], self.brush)
        return pixel_existing_solid

fdtdx.objects.device.circular_brush(diameter, size=None)

Creates a circular binary mask/brush for morphological operations.

Parameters:

Name Type Description Default
diameter float

Diameter of the circle in grid units.

required
size int | None

Optional size of the output array. If None, uses ceil(diameter), incremented by one if it is even, so that the size is always odd.

None

Returns:

Type Description
Array

Binary JAX array containing a circular mask where True indicates points within the circle diameter.

Source code in src/fdtdx/objects/device/parameters/discretization.py
def circular_brush(
    diameter: float,
    size: int | None = None,
) -> jax.Array:
    """Creates a circular binary mask/brush for morphological operations.

    Args:
        diameter: Diameter of the circle in grid units.
        size: Optional size of the output array. If None, uses ceil(diameter) rounded
            up to next odd number.

    Returns:
        Binary JAX array containing a circular mask where True indicates points
        within the circle diameter.
    """
    if size is None:
        s = math.ceil(diameter)
        if s % 2 == 0:
            s += 1
        size = s
    xy = jnp.stack(jnp.meshgrid(*map(jnp.arange, (size, size)), indexing="xy"), axis=-1) - jnp.asarray((size / 2) - 0.5)
    euc_dist = jnp.sqrt((xy**2).sum(axis=-1))
    # the less EQUAL here is important, because otherwise design may be infeasible due to discretization errors
    mask = euc_dist <= (diameter / 2)
    return mask

Discrete PostProcessing

fdtdx.objects.device.BinaryMedianFilterModule

Bases: DiscreteTransformation

Performs 3D binary median filtering on the design.

Applies a 3D median filter to smooth and clean up binary material distributions. This helps remove small features and noise while preserving larger structures.

Attributes:

Name Type Description
padding_cfg PaddingConfig

Configuration for padding behavior at boundaries.

kernel_sizes tuple[int, int, int]

3-tuple of kernel sizes for each dimension.

num_repeats int

Number of times to apply the filter consecutively.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class BinaryMedianFilterModule(DiscreteTransformation):
    """Performs 3D binary median filtering on the design.

    Applies a 3D median filter to smooth and clean up binary material distributions.
    This helps remove small features and noise while preserving larger structures.

    Attributes:
        padding_cfg: Configuration for padding behavior at boundaries.
        kernel_sizes: 3-tuple of kernel sizes for each dimension.
        num_repeats: Number of times to apply the filter consecutively.
    """

    padding_cfg: PaddingConfig = frozen_field()
    kernel_sizes: tuple[int, int, int] = frozen_field()
    num_repeats: int = frozen_field(default=1)

    def __call__(
        self,
        material_indices: jax.Array,
    ) -> jax.Array:
        if len(self._material) != 2:
            raise Exception("BinaryMedianFilterModule only works for two materials!")
        cur_arr = material_indices
        for _ in range(self.num_repeats):
            cur_arr = binary_median_filter(
                arr_3d=cur_arr,
                kernel_sizes=self.kernel_sizes,
                padding_cfg=self.padding_cfg,
            )
        result = straight_through_estimator(material_indices, cur_arr)
        return result

fdtdx.objects.device.ConnectHolesAndStructures

Bases: DiscreteTransformation

Connects floating polymer regions and ensures air holes connect to outside.

This constraint module ensures physical realizability of designs by:

1. Either connecting floating polymer regions to the substrate or removing them
2. Ensuring all air holes are connected to the outside (no trapped air)

The bottom (lower z) is treated as the substrate reference.

Attributes:

Name Type Description
fill_material str | None

Name of material to use for filling gaps when connecting regions. Required when working with more than 2 materials.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class ConnectHolesAndStructures(DiscreteTransformation):
    """Connects floating polymer regions and ensures air holes connect to outside.

    This constraint module ensures physical realizability of designs by:
    1. Either connecting floating polymer regions to the substrate or removing them
    2. Ensuring all air holes are connected to the outside (no trapped air)

    The bottom (lower z) is treated as the substrate reference.

    Attributes:
        fill_material: Name of material to use for filling gaps when connecting regions.
            Required when working with more than 2 materials.
    """

    fill_material: str | None = frozen_field(default=None)

    def __call__(
        self,
        material_indices: jax.Array,
    ) -> jax.Array:
        if len(self._material) > 2 and self.fill_material is None:
            raise Exception(
                "ConnectHolesAndStructures: Need to specify fill material when working with more than a single material"
            )
        air_name = get_air_name(self._material)
        ordered_name_list = compute_ordered_names(self._material)
        air_idx = ordered_name_list.index(air_name)
        is_material_matrix = material_indices != air_idx
        feasible_material_matrix = connect_holes_and_structures(is_material_matrix)

        result = jnp.empty_like(material_indices)
        # set air
        result = jnp.where(
            feasible_material_matrix,
            -1,  # this is set below
            air_idx,
        )
        # material where previously was material
        result = jnp.where(feasible_material_matrix & is_material_matrix, material_indices, result)

        # material, where previously was air
        fill_name = self.fill_material
        if fill_name is None:
            fill_name = ordered_name_list[1 - air_idx]
        fill_idx = ordered_name_list.index(fill_name)
        result = jnp.where(
            feasible_material_matrix & ~is_material_matrix,
            fill_idx,
            result,
        )
        return result

fdtdx.objects.device.RemoveFloatingMaterial

Bases: DiscreteTransformation

Finds all material that floats in the air and sets their permittivity to air.

This constraint module identifies regions of material that are not connected to any substrate or boundary and converts them to air. This helps ensure physically realizable designs by eliminating floating/disconnected material regions.

The module only works with binary material systems (2 permittivities) where one material represents air.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class RemoveFloatingMaterial(DiscreteTransformation):
    """Finds all material that floats in the air and sets their permittivity to air.

    This constraint module identifies regions of material that are not connected to any
    substrate or boundary and converts them to air. This helps ensure physically
    realizable designs by eliminating floating/disconnected material regions.

    The module only works with binary material systems (2 permittivities) where one
    material represents air.
    """

    def __call__(
        self,
        material_indices: jax.Array,
    ) -> jax.Array:
        if len(self._material) != 2:
            raise NotImplementedError("Remove floating material currently only implemented for single material")
        air_name = get_air_name(self._material)
        ordered_name_list = compute_ordered_names(self._material)
        air_idx = ordered_name_list.index(air_name)

        is_material_matrix = material_indices != air_idx
        is_material_after_removal = remove_floating_polymer(is_material_matrix)
        result = (1 - air_idx) * is_material_after_removal + air_idx * ~is_material_after_removal
        result = straight_through_estimator(material_indices, result)
        return result

fdtdx.objects.device.BinaryMedianFilterModule

Bases: DiscreteTransformation

Performs 3D binary median filtering on the design.

Applies a 3D median filter to smooth and clean up binary material distributions. This helps remove small features and noise while preserving larger structures.

Attributes:

Name Type Description
padding_cfg PaddingConfig

Configuration for padding behavior at boundaries.

kernel_sizes tuple[int, int, int]

3-tuple of kernel sizes for each dimension.

num_repeats int

Number of times to apply the filter consecutively.

Source code in src/fdtdx/objects/device/parameters/discrete.py
@extended_autoinit
class BinaryMedianFilterModule(DiscreteTransformation):
    """Performs 3D binary median filtering on the design.

    Applies a 3D median filter to smooth and clean up binary material distributions.
    This helps remove small features and noise while preserving larger structures.

    Attributes:
        padding_cfg: Configuration for padding behavior at boundaries.
        kernel_sizes: 3-tuple of kernel sizes for each dimension.
        num_repeats: Number of times to apply the filter consecutively.
    """

    padding_cfg: PaddingConfig = frozen_field()
    kernel_sizes: tuple[int, int, int] = frozen_field()
    num_repeats: int = frozen_field(default=1)

    def __call__(
        self,
        material_indices: jax.Array,
    ) -> jax.Array:
        if len(self._material) != 2:
            raise Exception("BinaryMedianFilterModule only works for two materials!")
        cur_arr = material_indices
        for _ in range(self.num_repeats):
            cur_arr = binary_median_filter(
                arr_3d=cur_arr,
                kernel_sizes=self.kernel_sizes,
                padding_cfg=self.padding_cfg,
            )
        result = straight_through_estimator(material_indices, cur_arr)
        return result