Skip to content

Object Placement and Parameters

fdtdx.place_objects(volume, config, constraints, key)

Places simulation objects according to specified constraints and initializes containers.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `volume` | `SimulationObject` | The volume object defining the simulation boundaries | required |
| `config` | `SimulationConfig` | The simulation configuration | required |
| `constraints` | `Sequence[PositionConstraint \| SizeConstraint \| SizeExtensionConstraint \| GridCoordinateConstraint \| RealCoordinateConstraint]` | Sequence of positioning and sizing constraints for objects | required |
| `key` | `Array` | JAX random key for initialization | required |

Returns:

| Type | Description |
|------|-------------|
| `tuple[ObjectContainer, ArrayContainer, ParameterContainer, SimulationConfig, dict[str, Any]]` | A tuple containing: (1) the `ObjectContainer` with the placed simulation objects, (2) the `ArrayContainer` with the initialized field arrays, (3) the `ParameterContainer` with the device parameters, (4) the updated `SimulationConfig`, and (5) a dictionary with additional initialization info. |

Source code in src/fdtdx/fdtd/initialization.py
def place_objects(
    volume: SimulationObject,
    config: SimulationConfig,
    constraints: Sequence[
        (
            PositionConstraint
            | SizeConstraint
            | SizeExtensionConstraint
            | GridCoordinateConstraint
            | RealCoordinateConstraint
        )
    ],
    key: jax.Array,
) -> tuple[
    ObjectContainer,
    ArrayContainer,
    ParameterContainer,
    SimulationConfig,
    dict[str, Any],
]:
    """Resolve placement constraints, put every object on the grid and build the containers.

    The constraints are resolved into one grid slice tuple per object. Every
    object except the volume is placed first, each with its own subkey. The
    volume is placed last but stored at index 0 of the resulting
    ``ObjectContainer``. Parameters and arrays are then initialized from the
    placed objects. Finally, every object is re-bound to the config returned by
    the array initialization.

    Args:
        volume: The volume object defining the simulation boundaries.
        config: The simulation configuration.
        constraints: Positioning and sizing constraints for the objects.
        key: JAX random key used for object placement and parameter initialization.

    Returns:
        A tuple ``(objects, arrays, params, config, info)``:
            - ObjectContainer with the placed simulation objects (volume at index 0)
            - ArrayContainer with the initialized field arrays
            - ParameterContainer with the device parameters
            - The updated SimulationConfig
            - Dictionary with additional initialization info
    """
    grid_slices = _resolve_object_constraints(
        volume=volume,
        constraints=constraints,
        config=config,
    )

    # Place all non-volume objects in dictionary order, one fresh subkey each.
    others_placed = []
    for obj in grid_slices:
        if obj == volume:
            continue
        key, obj_key = jax.random.split(key)
        others_placed.append(
            obj.place_on_grid(
                grid_slice_tuple=grid_slices[obj],
                config=config,
                key=obj_key,
            )
        )

    # The volume consumes the next subkey after all other objects, but goes first in the list.
    key, volume_key = jax.random.split(key)
    volume_placed = volume.place_on_grid(
        grid_slice_tuple=grid_slices[volume],
        config=config,
        key=volume_key,
    )

    objects = ObjectContainer(
        object_list=[volume_placed] + others_placed,
        volume_idx=0,
    )
    params = _init_params(
        objects=objects,
        key=key,
    )
    arrays, config, info = _init_arrays(
        objects=objects,
        config=config,
    )

    # _init_arrays returns a new config; every object must hold that one.
    objects = ObjectContainer(
        object_list=[obj.aset("_config", config) for obj in objects.objects],
        volume_idx=0,
    )
    return objects, arrays, params, config, info

Main entry point for placing and initializing simulation objects.

fdtdx.apply_params(arrays, objects, params, key, **transform_kwargs)

Applies parameters to devices and updates source states.

Parameters:

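| Name | Type | Description | Default |
|------|------|-------------|---------|
| `arrays` | `ArrayContainer` | Container with field arrays | required |
| `objects` | `ObjectContainer` | Container with simulation objects | required |
| `params` | `ParameterContainer` | Container with device parameters | required |
| `key` | `Array` | JAX random key for source updates | required |
| `**transform_kwargs` | | Extra keyword arguments forwarded to each device's call when computing its material indices | `{}` |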

Returns:

| Type | Description |
|------|-------------|
| `tuple[ArrayContainer, ObjectContainer, dict[str, Any]]` | A tuple containing: (1) the updated `ArrayContainer` with the device parameters applied, (2) the updated `ObjectContainer` with new source states, and (3) a dictionary with parameter application info. |

Source code in src/fdtdx/fdtd/initialization.py
def apply_params(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    params: ParameterContainer,
    key: jax.Array,
    **transform_kwargs,
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]:
    """Applies parameters to devices and updates source states.

    Args:
        arrays: Container with field arrays
        objects: Container with simulation objects
        params: Container with device parameters
        key: JAX random key for source updates

    Returns:
        A tuple containing:
            - Updated ArrayContainer with applied device parameters
            - Updated ObjectContainer with new source states
            - Dictionary with parameter application info
    """
    info = {}
    # apply parameter to devices
    for device in objects.devices:
        cur_material_indices = device(params[device.name], expand_to_sim_grid=True, **transform_kwargs)
        allowed_perm_list = compute_allowed_permittivities(device.materials)
        if device.output_type == ParameterType.CONTINUOUS:
            first_term = (1 - cur_material_indices) * (1 / allowed_perm_list[0])
            second_term = cur_material_indices * (1 / allowed_perm_list[1])
            new_perm_slice = first_term + second_term
        else:
            new_perm_slice = jnp.asarray(allowed_perm_list)[cur_material_indices.astype(jnp.int32)]
            new_perm_slice = straight_through_estimator(cur_material_indices, new_perm_slice)
            new_perm_slice = 1 / new_perm_slice
        new_perm = arrays.inv_permittivities.at[*device.grid_slice].set(new_perm_slice)
        arrays = arrays.at["inv_permittivities"].set(new_perm)

    # apply random key to sources
    new_objects = []
    for obj in objects.object_list:
        key, subkey = jax.random.split(key)
        new_obj = obj.apply(
            key=subkey,
            inv_permittivities=jax.lax.stop_gradient(arrays.inv_permittivities),
            inv_permeabilities=jax.lax.stop_gradient(arrays.inv_permeabilities),
        )
        new_objects.append(new_obj)
    new_objects = ObjectContainer(
        object_list=new_objects,
        volume_idx=objects.volume_idx,
    )

    return arrays, new_objects, info

Applies parameters to devices and updates source states to be ready for simulation.

Core FDTD Algorithm

fdtdx.run_fdtd(arrays, objects, config, key)

Source code in src/fdtdx/fdtd/wrapper.py
def run_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
) -> SimulationState:
    """Run an FDTD simulation, choosing the implementation from the gradient config.

    If ``config.gradient_config`` is ``None``, only a forward simulation is
    needed and the checkpointed implementation is used. Otherwise
    ``config.gradient_config.method`` selects the implementation:
    ``"reversible"`` uses the time-reversal symmetric implementation and
    ``"checkpointed"`` uses the checkpointed one.

    Args:
        arrays: Container with the initial field and material arrays.
        objects: Container with the simulation objects.
        config: Simulation configuration, including the optional gradient config.
        key: JAX random key passed through to the selected implementation.

    Returns:
        The simulation state returned by the selected implementation.

    Raises:
        ValueError: If the gradient config's ``method`` is neither
            ``"reversible"`` nor ``"checkpointed"``. ``ValueError`` subclasses
            ``Exception``, so existing ``except Exception`` handlers still apply.
    """
    gradient_config = config.gradient_config
    if gradient_config is None or gradient_config.method == "checkpointed":
        fdtd_impl = checkpointed_fdtd
    elif gradient_config.method == "reversible":
        fdtd_impl = reversible_fdtd
    else:
        raise ValueError(f"Unknown gradient computation method: {gradient_config.method}")
    return fdtd_impl(
        arrays=arrays,
        objects=objects,
        config=config,
        key=key,
    )

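Runs the FDTD simulation. If `config.gradient_config` is `None` or its `method` is `"checkpointed"`, the checkpointed implementation is used. If `method` is `"reversible"`, the time-reversal symmetric implementation with memory-efficient autodiff is used. Any other method raises an error.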

Python Objects used for FDTD simulation

fdtdx.ArrayContainer

Bases: ExtendedTreeClass

Container for simulation field arrays and states.

This class holds the electromagnetic field arrays and various state information needed during FDTD simulation. It includes the E and H fields, material properties, and states for boundaries, detectors and recordings.

Attributes:

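| Name | Type | Description |
|------|------|-------------|
| `E` | `Array` | Electric field array. |
| `H` | `Array` | Magnetic field array. |
| `inv_permittivities` | `Array` | Inverse permittivity values array. |
| `inv_permeabilities` | `Array \| float` | Inverse permeability values, either an array or a single scalar. |
| `boundary_states` | `dict[str, BaseBoundaryState]` | Dictionary mapping boundary names to their states. |
| `detector_states` | `dict[str, DetectorState]` | Dictionary mapping detector names to their states. |
| `recording_state` | `RecordingState \| None` | Optional state for recording simulation data. |
| `electric_conductivity` | `Array \| None` | Optional electric conductivity array (defaults to `None`). |
| `magnetic_conductivity` | `Array \| None` | Optional magnetic conductivity array (defaults to `None`). |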

Source code in src/fdtdx/fdtd/container.py
@extended_autoinit
class ArrayContainer(ExtendedTreeClass):
    """Container for simulation field arrays and states.

    This class holds the electromagnetic field arrays and various state information
    needed during FDTD simulation. It includes the E and H fields, material properties,
    and states for boundaries, detectors and recordings.

    Attributes:
        E: Electric field array.
        H: Magnetic field array.
        inv_permittivities: Inverse permittivity values array.
        inv_permeabilities: Inverse permeability values, either an array or a
            single scalar (``float``).
        boundary_states: Dictionary mapping boundary names to their states.
        detector_states: Dictionary mapping detector names to their states.
        recording_state: Optional state for recording simulation data.
        electric_conductivity: Optional electric conductivity array; ``None`` by default.
        magnetic_conductivity: Optional magnetic conductivity array; ``None`` by default.
    """

    E: jax.Array
    H: jax.Array
    inv_permittivities: jax.Array
    # May be a plain scalar instead of a full array (see annotation).
    inv_permeabilities: jax.Array | float
    boundary_states: dict[str, BaseBoundaryState]
    detector_states: dict[str, DetectorState]
    recording_state: RecordingState | None
    # NOTE(review): presumably left as None when no conductive material is present — confirm in _init_arrays.
    electric_conductivity: jax.Array | None = None
    magnetic_conductivity: jax.Array | None = None

Container holding the electric and magnetic fields, the permittivity/permeability (and optional conductivity) arrays, and the boundary, detector, and recording states used during simulation.

fdtdx.ObjectContainer

Bases: ExtendedTreeClass

Container for managing simulation objects and their relationships.

This class provides a structured way to organize and access different types of simulation objects like sources, detectors, PML/periodic boundaries and devices. It maintains object lists and provides filtered access to specific object types.

Attributes:

| Name | Type | Description |
|------|------|-------------|
| `object_list` | `list[SimulationObject]` | List of all simulation objects in the container. |
| `volume_idx` | `int` | Index of the volume object in the object list. |

Source code in src/fdtdx/fdtd/container.py
@extended_autoinit
class ObjectContainer(ExtendedTreeClass):
    """Container for managing simulation objects and their relationships.

    This class provides a structured way to organize and access different types of simulation
    objects like sources, detectors, PML/periodic boundaries and devices. It maintains object lists
    and provides filtered access to specific object types.

    Attributes:
        object_list: List of all simulation objects in the container.
        volume_idx: Index of the volume object in the object list.
    """

    object_list: list[SimulationObject]
    volume_idx: int = frozen_field()

    @property
    def volume(self) -> SimulationObject:
        """The simulation volume, i.e. ``object_list[volume_idx]``."""
        return self.object_list[self.volume_idx]

    @property
    def objects(self) -> list[SimulationObject]:
        """All simulation objects (alias of ``object_list``)."""
        return self.object_list

    @property
    def static_material_objects(self) -> list[UniformMaterialObject | StaticMultiMaterialObject]:
        """Objects that are UniformMaterialObject or StaticMultiMaterialObject instances."""
        return [o for o in self.objects if isinstance(o, (UniformMaterialObject, StaticMultiMaterialObject))]

    @property
    def sources(self) -> list[Source]:
        """All Source objects, in list order."""
        return [o for o in self.objects if isinstance(o, Source)]

    @property
    def devices(self) -> list[Device]:
        """All Device objects, in list order."""
        return [o for o in self.objects if isinstance(o, Device)]

    @property
    def detectors(self) -> list[Detector]:
        """All Detector objects, in list order."""
        return [o for o in self.objects if isinstance(o, Detector)]

    @property
    def forward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is falsy."""
        return [o for o in self.detectors if not o.inverse]

    @property
    def backward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is truthy."""
        return [o for o in self.detectors if o.inverse]

    @property
    def pml_objects(self) -> list[PerfectlyMatchedLayer]:
        """All PerfectlyMatchedLayer boundary objects."""
        return [o for o in self.objects if isinstance(o, PerfectlyMatchedLayer)]

    @property
    def periodic_objects(self) -> list[PeriodicBoundary]:
        """All PeriodicBoundary boundary objects."""
        return [o for o in self.objects if isinstance(o, PeriodicBoundary)]

    @property
    def boundary_objects(self) -> list[BaseBoundary]:
        """All PML and periodic boundary objects, in list order."""
        return [o for o in self.objects if isinstance(o, (PerfectlyMatchedLayer, PeriodicBoundary))]

    @property
    def all_objects_non_magnetic(self) -> bool:
        """True if no material of any material-carrying object is magnetic."""

        def _fn(m: Material):
            return not m.is_magnetic

        return self._is_material_fn_true_for_all(_fn)

    @property
    def all_objects_non_electrically_conductive(self) -> bool:
        """True if no material of any material-carrying object is electrically conductive."""

        def _fn(m: Material):
            return not m.is_electrically_conductive

        return self._is_material_fn_true_for_all(_fn)

    @property
    def all_objects_non_magnetically_conductive(self) -> bool:
        """True if no material of any material-carrying object is magnetically conductive."""

        def _fn(m: Material):
            return not m.is_magnetically_conductive

        return self._is_material_fn_true_for_all(_fn)

    def _is_material_fn_true_for_all(
        self,
        fn: Callable[[Material], bool],
    ) -> bool:
        """Return True if ``fn`` holds for every material of every material-carrying object.

        The materials come from ``UniformMaterialObject.material`` and from
        ``Device.materials`` / ``StaticMultiMaterialObject.materials``, which may
        be a single Material or a dict of Materials. Other objects, and material
        values of any other type, are skipped.
        """
        for o in self.objects:
            if isinstance(o, UniformMaterialObject):
                m = o.material
            elif isinstance(o, (Device, StaticMultiMaterialObject)):
                m = o.materials
            else:
                continue
            if isinstance(m, Material):
                if not fn(m):
                    return False
            elif isinstance(m, dict):
                if not all(fn(v) for v in m.values()):
                    return False
        return True

    def __iter__(self):
        """Iterate over ``object_list``."""
        return iter(self.object_list)

    def __getitem__(
        self,
        key: str,
    ) -> SimulationObject:
        """Return the first object whose ``name`` equals ``key``.

        Raises:
            ValueError: If no object has that name.
        """
        for o in self.objects:
            if o.name == key:
                return o
        raise ValueError(f"Key {key} does not exist in object list: {[o.name for o in self.objects]}")

    def replace_sources(
        self,
        sources: list[Source],
    ) -> Self:
        """Return a copy whose Source objects are replaced by ``sources``.

        Non-source objects keep their relative order and the new sources are
        appended at the end. If removing the old sources shifts the volume to
        another position, ``volume_idx`` is updated so that ``volume`` still
        refers to the same object.

        Args:
            sources: The new sources (any sequence is accepted).

        Returns:
            The updated container.
        """
        old_volume = self.volume
        # Filter by type rather than `o not in self.sources`. That version rebuilt the
        # source list for every object (quadratic) and fell back to `==` on pytree objects.
        kept = [o for o in self.objects if not isinstance(o, Source)]
        new_objects = kept + list(sources)
        new_volume_idx = next(
            (i for i, o in enumerate(new_objects) if o is old_volume),
            self.volume_idx,
        )
        self = self.aset("object_list", new_objects)
        if new_volume_idx != self.volume_idx:
            # Sources that preceded the volume were removed; keep `volume` pointing at it.
            self = self.aset("volume_idx", new_volume_idx)
        return self

Container holding all the objects in a simulation scene

fdtdx.ParameterContainer = dict[str, dict[str, jax.Array] | jax.Array] module-attribute

Dictionary mapping each device's name to its parameters, stored either as a single array or as a dictionary of named arrays.

fdtdx.SimulationState = tuple[jax.Array, ArrayContainer] module-attribute

Simulation state returned by the FDTD simulations. This is a tuple of the simulation time step and an array container.