Skip to content

Object Placement and Parameters

fdtdx.fdtd.place_objects(volume, config, constraints, key)

Places simulation objects according to specified constraints and initializes containers.

Parameters:

Name Type Description Default
volume SimulationObject

The volume object defining the simulation boundaries

required
config SimulationConfig

The simulation configuration

required
constraints Sequence[PositionConstraint | SizeConstraint | SizeExtensionConstraint | GridCoordinateConstraint | RealCoordinateConstraint]

Sequence of positioning and sizing constraints for objects

required
key Array

JAX random key for initialization

required

Returns:

Type Description
tuple[ObjectContainer, ArrayContainer, ParameterContainer, SimulationConfig, dict[str, Any]]

A tuple containing: - ObjectContainer with placed simulation objects - ArrayContainer with initialized field arrays - ParameterContainer with device parameters - Updated SimulationConfig - Dictionary with additional initialization info

Source code in src/fdtdx/fdtd/initialization.py
def place_objects(
    volume: SimulationObject,
    config: SimulationConfig,
    constraints: Sequence[
        (
            PositionConstraint
            | SizeConstraint
            | SizeExtensionConstraint
            | GridCoordinateConstraint
            | RealCoordinateConstraint
        )
    ],
    key: jax.Array,
) -> tuple[
    ObjectContainer,
    ArrayContainer,
    ParameterContainer,
    SimulationConfig,
    dict[str, Any],
]:
    """Resolve placement constraints, place every object on the grid and build the containers.

    The volume always ends up at index 0 of the returned ``ObjectContainer``.
    All other objects follow in the order produced by the constraint resolver.

    Args:
        volume: The volume object defining the simulation boundaries
        config: The simulation configuration
        constraints: Sequence of positioning and sizing constraints for objects
        key: JAX random key for initialization

    Returns:
        A tuple containing:
            - ObjectContainer with placed simulation objects
            - ArrayContainer with initialized field arrays
            - ParameterContainer with device parameters
            - Updated SimulationConfig
            - Dictionary with additional initialization info
    """
    grid_slices = _resolve_object_constraints(
        volume=volume,
        constraints=constraints,
        config=config,
    )

    def _place(obj, obj_key):
        # Place a single object on the grid slice computed for it.
        return obj.place_on_grid(
            grid_slice_tuple=grid_slices[obj],
            config=config,
            key=obj_key,
        )

    # Non-volume objects are placed first and the volume receives the last
    # key split, so every object gets a reproducible subkey.
    others = []
    for obj in list(grid_slices.keys()):
        if obj == volume:
            continue
        key, obj_key = jax.random.split(key)
        others.append(_place(obj, obj_key))
    key, volume_key = jax.random.split(key)
    placed = [_place(volume, volume_key), *others]

    objects = ObjectContainer(object_list=placed, volume_idx=0)
    params = _init_params(objects=objects, key=key)
    arrays, config, info = _init_arrays(objects=objects, config=config)

    # _init_arrays returns an updated config; propagate it into every object
    # so that the objects and the returned config agree.
    objects = ObjectContainer(
        object_list=[obj.aset("_config", config) for obj in objects.objects],
        volume_idx=0,
    )
    return objects, arrays, params, config, info

Main entry point for placing and initializing simulation objects.

fdtdx.fdtd.apply_params(arrays, objects, params, key)

Applies parameters to devices and updates source states.

Parameters:

Name Type Description Default
arrays ArrayContainer

Container with field arrays

required
objects ObjectContainer

Container with simulation objects

required
params ParameterContainer

Container with device parameters

required
key Array

JAX random key for source updates

required

Returns:

Type Description
tuple[ArrayContainer, ObjectContainer, dict[str, Any]]

A tuple containing: - Updated ArrayContainer with applied device parameters - Updated ObjectContainer with new source states - Dictionary with parameter application info

Source code in src/fdtdx/fdtd/initialization.py
def apply_params(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    params: ParameterContainer,
    key: jax.Array,
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]:
    """Applies parameters to devices and updates source states.

    Args:
        arrays: Container with field arrays
        objects: Container with simulation objects
        params: Container with device parameters
        key: JAX random key for source updates

    Returns:
        A tuple containing:
            - Updated ArrayContainer with applied device parameters
            - Updated ObjectContainer with new source states
            - Dictionary with parameter application info
    """
    info = {}
    # apply parameter to devices
    for device in objects.discrete_devices:
        cur_material_indices = device.get_expanded_material_mapping(params[device.name])
        allowed_perm_list = compute_allowed_permittivities(device.material)
        new_perm_slice = (1.0 / jnp.asarray(allowed_perm_list))[cur_material_indices.astype(jnp.int32)]
        new_perm_slice = straight_through_estimator(cur_material_indices, new_perm_slice)
        new_perm = arrays.inv_permittivities.at[*device.grid_slice].set(new_perm_slice)
        arrays = arrays.at["inv_permittivities"].set(new_perm)

    # apply parameters to continous devices
    if not objects.all_objects_non_magnetic:
        for device in objects.continous_devices:
            cur_material_indices = device.get_expanded_material_mapping(params[device.name])
            new_perm_slice = (1 - cur_material_indices) * (
                1 / device.material.start_material.permittivity
            ) + cur_material_indices * (1 / device.material.end_material.permittivity)
            new_perm = arrays.inv_permittivities.at[*device.grid_slice].set(new_perm_slice)
            arrays = arrays.at["inv_permittivities"].set(new_perm)

    # apply random key to sources
    new_sources = []
    for source in objects.sources:
        key, subkey = jax.random.split(key)
        new_source = source.apply(
            key=subkey,
            inv_permittivities=jax.lax.stop_gradient(arrays.inv_permittivities),
            inv_permeabilities=jax.lax.stop_gradient(arrays.inv_permeabilities),
        )
        new_sources.append(new_source)
    objects = objects.replace_sources(new_sources)

    return arrays, objects, info

Applies parameters to devices and updates source states to be ready for simulation.

Core FDTD Algorithms

fdtdx.fdtd.reversible_fdtd(arrays, objects, config, key)

Run a memory-efficient differentiable FDTD simulation leveraging time-reversal symmetry.

This implementation exploits the time-reversal symmetry of Maxwell's equations to perform backpropagation without storing the electromagnetic fields at each time step. During the backward pass, the fields are reconstructed by running the simulation in reverse, only requiring O(1) memory storage instead of O(T) where T is the number of time steps.

The only exception is boundary conditions which break time-reversal symmetry - these are recorded during the forward pass and replayed during backpropagation.

Parameters:

Name Type Description Default
arrays ArrayContainer

Initial state of the simulation containing: - E, H: Electric and magnetic field arrays - inv_permittivities, inv_permeabilities: Material properties - boundary_states: Dictionary of boundary conditions - detector_states: Dictionary of field detectors - recording_state: Optional state for recording field evolution

required
objects ObjectContainer

Collection of physical objects in the simulation (sources, detectors, boundaries, etc.)

required
config SimulationConfig

Simulation parameters including: - time_steps_total: Total number of steps to simulate - invertible_optimization: Whether to record boundaries for backprop

required
key Array

JAX PRNGKey for any stochastic operations

required

Returns:

Name Type Description
SimulationState SimulationState

Tuple containing: - Final time step (int) - ArrayContainer with the final state of all fields and components

Notes

The implementation uses custom vector-Jacobian products (VJPs) to enable efficient backpropagation through the entire simulation while maintaining numerical stability. This makes it suitable for gradient-based optimization of electromagnetic designs.

Source code in src/fdtdx/fdtd/fdtd.py
def reversible_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
) -> SimulationState:
    """Run a memory-efficient differentiable FDTD simulation leveraging time-reversal symmetry.

    This implementation exploits the time-reversal symmetry of Maxwell's equations to perform
    backpropagation without storing the electromagnetic fields at each time step. During the
    backward pass, the fields are reconstructed by running the simulation in reverse, only
    requiring O(1) memory storage instead of O(T) where T is the number of time steps.

    The only exception is boundary conditions which break time-reversal symmetry - these are
    recorded during the forward pass and replayed during backpropagation.

    Gradients are only propagated to ``inv_permittivities`` and ``inv_permeabilities``;
    the custom backward pass returns ``None`` for all other inputs.

    Args:
        arrays (ArrayContainer): Initial state of the simulation containing:
            - E, H: Electric and magnetic field arrays
            - inv_permittivities, inv_permeabilities: Material properties
            - boundary_states: Dictionary of boundary conditions
            - detector_states: Dictionary of field detectors
            - recording_state: Optional state for recording field evolution
        objects (ObjectContainer): Collection of physical objects in the simulation
            (sources, detectors, boundaries, etc.)
        config (SimulationConfig): Simulation parameters including:
            - time_steps_total: Total number of steps to simulate
            - invertible_optimization: Whether to record boundaries for backprop
        key (jax.Array): JAX PRNGKey for any stochastic operations

    Returns:
        SimulationState: Tuple containing:
            - Final time step (int)
            - ArrayContainer with the final state of all fields and components

    Notes:
        The implementation uses custom vector-Jacobian products (VJPs) to enable
        efficient backpropagation through the entire simulation while maintaining
        numerical stability. This makes it suitable for gradient-based optimization
        of electromagnetic designs.
    """
    # Reset the container (via reset_array_container) before time stepping starts.
    arrays = reset_array_container(
        arrays,
        objects,
    )

    def reversible_fdtd_base(
        arr: ArrayContainer,
    ) -> SimulationState:
        """Core implementation of reversible FDTD simulation.

        Performs the main FDTD time-stepping loop using a while loop that respects
        JAX's functional programming model.

        Args:
            arr: ArrayContainer with initial field state and material properties

        Returns:
            SimulationState tuple containing:
                - Final time step
                - ArrayContainer with final simulation state
        """
        state = (jnp.asarray(0, dtype=jnp.int32), arr)
        # Step forward from t=0 while t < time_steps_total. Boundary recording is
        # only enabled when invertible_optimization is set; the backward pass
        # needs those recordings to undo the non-reversible boundary updates.
        state = eqxi.while_loop(
            max_steps=config.time_steps_total,
            cond_fun=lambda s: config.time_steps_total > s[0],
            body_fun=partial(
                forward,
                config=config,
                objects=objects,
                key=key,
                record_detectors=True,
                record_boundaries=config.invertible_optimization,
                simulate_boundaries=True,
            ),
            init_val=state,
            kind="lax",
        )
        return (state[0], state[1])

    # Primal function whose derivative is supplied by fdtd_fwd / fdtd_bwd
    # (registered via defvjp below). The ArrayContainer is passed as separate
    # arguments so that the custom VJP can return one cotangent per field.
    @jax.custom_vjp
    def reversible_fdtd_primal(
        E: jax.Array,
        H: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array,
        boundary_states: dict[str, BaseBoundaryState],
        detector_states: dict[str, DetectorState],
        recording_state: RecordingState | None,
    ):
        arr = ArrayContainer(
            E=E,
            H=H,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
            boundary_states=boundary_states,
            detector_states=detector_states,
            recording_state=recording_state,
        )
        state = reversible_fdtd_base(arr)
        return (
            state[0],
            state[1].E,
            state[1].H,
            state[1].inv_permittivities,
            state[1].inv_permeabilities,
            state[1].boundary_states,
            state[1].detector_states,
            state[1].recording_state,
        )

    def body_fn(
        sr_tuple,
    ):
        """One reverse step of the backward pass.

        First reconstructs the previous simulation state with ``backward`` (detectors
        are not recorded, fields are not reset). Then pulls the cotangent back
        through a single ``forward`` step evaluated at that reconstructed state.

        Args:
            sr_tuple: ``(state, cot)``, where ``state`` is a SimulationState and
                ``cot`` is the cotangent w.r.t. the outputs of one forward step.

        Returns:
            ``(state, cot)`` moved one step back in time.
        """
        state, cot = sr_tuple
        state = backward(
            state=state,
            config=config,
            objects=objects,
            key=key,
            record_detectors=False,
            reset_fields=False,
        )
        # Linearize one forward step around the reconstructed state. Cotangent
        # order follows the positional args below: (time_step, E, H,
        # inv_permittivities, inv_permeabilities, boundary_states,
        # detector_states, recording_state).
        _, update_vjp = jax.vjp(
            partial(
                forward_single_args_wrapper,
                config=config,
                objects=objects,
                key=key,
                record_detectors=True,
                record_boundaries=False,
                simulate_boundaries=True,
            ),
            state[0],
            state[1].E,
            state[1].H,
            state[1].inv_permittivities,
            state[1].inv_permeabilities,
            state[1].boundary_states,
            state[1].detector_states,
            state[1].recording_state,
        )

        cot = update_vjp(cot)
        return state, cot

    def cond_fun(
        sr_tuple,
        start_time_step: int,
    ):
        """Loop condition for the backward pass: continue while time_step >= start_time_step.

        NOTE(review): because of ``>=``, the body also runs when the time step equals
        ``start_time_step``. Whether this gives exactly ``time_steps_total`` reverse
        steps depends on how ``backward()`` updates the step counter, which is not
        visible here. Please confirm.
        """
        s_k, r_k = sr_tuple
        del r_k
        time_step = s_k[0]
        return time_step >= start_time_step

    def fdtd_bwd(
        residual,
        cot,
    ):
        """Backward pass for reversible FDTD simulation.

        Implements the custom vector-Jacobian product for backpropagation through
        the FDTD simulation by leveraging time-reversibility.

        Args:
            residual: Tuple containing the final simulation state including:
                - Time step
                - E, H field arrays
                - Material properties
                - Boundary and detector states
                - Recording state
            cot: Cotangent values for gradient computation

        Returns:
            Tuple of cotangent values for each input parameter. Only
            inv_permittivities and inv_permeabilities receive a gradient; all
            other entries are None.
        """
        (
            res_time_step,
            res_E,
            res_H,
            res_inv_permittivities,
            res_inv_permeabilities,
            res_boundary_states,
            res_detector_states,
            res_recording_state,
        ) = residual

        s_k = ArrayContainer(
            E=res_E,
            H=res_H,
            inv_permittivities=res_inv_permittivities,
            inv_permeabilities=res_inv_permeabilities,
            boundary_states=res_boundary_states,
            detector_states=res_detector_states,
            recording_state=res_recording_state,
        )

        # Walk backwards in time from the final state, accumulating cotangents.
        _, cot = eqxi.while_loop(
            cond_fun=partial(cond_fun, start_time_step=0),
            body_fun=body_fn,
            init_val=((res_time_step, s_k), cot),
            kind="lax",
        )
        # cot is indexed like the vjp args in body_fn (index 0 = time_step), so
        # cot[3] / cot[4] are the material cotangents for the 7 primal inputs.
        return (
            None,  # cot[1],
            None,  # cot[2],
            cot[3],
            cot[4],
            None,  # cot[5]
            None,  # cot[6],
            None,  # cot[7],
        )

    def fdtd_fwd(
        E: jax.Array,
        H: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array,
        boundary_states: dict[str, BaseBoundaryState],
        detector_states: dict[str, DetectorState],
        recording_state: RecordingState | None,
    ):
        """Forward pass for reversible FDTD simulation.

        Performs the forward FDTD simulation and prepares residuals for the backward pass.

        Args:
            E: Electric field array
            H: Magnetic field array
            inv_permittivities: Inverse permittivity values
            inv_permeabilities: Inverse permeability values
            boundary_states: Dictionary mapping boundary names to their states
            detector_states: Dictionary mapping detector names to their states
            recording_state: Optional state for recording field evolution

        Returns:
            Tuple containing:
                - Primal outputs (final simulation state)
                - Residuals for backward pass
        """
        arr = ArrayContainer(
            E=E,
            H=H,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
            boundary_states=boundary_states,
            detector_states=detector_states,
            recording_state=recording_state,
        )
        s_k = reversible_fdtd_base(arr)

        primal_out = (
            s_k[0],
            s_k[1].E,
            s_k[1].H,
            s_k[1].inv_permittivities,
            s_k[1].inv_permeabilities,
            s_k[1].boundary_states,
            s_k[1].detector_states,
            s_k[1].recording_state,  # None
        )
        # Only the final state is kept as residual; intermediate states are
        # reconstructed in reverse during fdtd_bwd instead of being stored.
        residual = (
            s_k[0],
            s_k[1].E,
            s_k[1].H,
            s_k[1].inv_permittivities,
            s_k[1].inv_permeabilities,
            s_k[1].boundary_states,
            s_k[1].detector_states,
            s_k[1].recording_state,
        )
        return primal_out, residual

    reversible_fdtd_primal.defvjp(fdtd_fwd, fdtd_bwd)

    (
        time_step,
        E,
        H,
        inv_permittivities,
        inv_permeabilities,
        boundary_states,
        detector_states,
        recording_state,
    ) = reversible_fdtd_primal(
        E=arrays.E,
        H=arrays.H,
        inv_permittivities=arrays.inv_permittivities,
        inv_permeabilities=arrays.inv_permeabilities,
        boundary_states=arrays.boundary_states,
        detector_states=arrays.detector_states,
        recording_state=arrays.recording_state,
    )
    # Re-assemble the flattened outputs into an ArrayContainer for the caller.
    out_arrs = ArrayContainer(
        E=E,
        H=H,
        inv_permittivities=inv_permittivities,
        inv_permeabilities=inv_permeabilities,
        boundary_states=boundary_states,
        detector_states=detector_states,
        recording_state=recording_state,
    )
    return time_step, out_arrs

Time-reversal symmetric FDTD implementation with memory-efficient autodiff.

fdtdx.fdtd.checkpointed_fdtd(arrays, objects, config, key)

Run an FDTD simulation with gradient checkpointing for memory efficiency.

This implementation uses checkpointing to reduce memory usage during backpropagation by only storing the field state at certain intervals and recomputing intermediate states as needed.

Parameters:

Name Type Description Default
arrays ArrayContainer

Initial state of the simulation containing fields and materials

required
objects ObjectContainer

Collection of physical objects in the simulation

required
config SimulationConfig

Simulation parameters including checkpointing settings

required
key Array

JAX PRNGKey for any stochastic operations

required

Returns:

Name Type Description
SimulationState SimulationState

Tuple containing final time step and ArrayContainer with final state

Notes

The number of checkpoints can be configured through config.gradient_config.num_checkpoints. More checkpoints reduce recomputation but increase memory usage.

Source code in src/fdtdx/fdtd/fdtd.py
def checkpointed_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
) -> SimulationState:
    """Run an FDTD simulation with gradient checkpointing for memory efficiency.

    The container is reset first. The simulation is then stepped forward from
    time step 0 until ``config.time_steps_total`` is reached. Checkpointing
    stores the field state only at certain intervals and recomputes the
    intermediate states during backpropagation.

    Args:
        arrays (ArrayContainer): Initial state of the simulation containing fields and materials
        objects (ObjectContainer): Collection of physical objects in the simulation
        config (SimulationConfig): Simulation parameters including checkpointing settings
        key (jax.Array): JAX PRNGKey for any stochastic operations

    Returns:
        SimulationState: Tuple containing final time step and ArrayContainer with final state

    Notes:
        The number of checkpoints can be configured through config.gradient_config.num_checkpoints.
        More checkpoints reduce recomputation but increase memory usage.
    """
    initial_state = (
        jnp.asarray(0, dtype=jnp.int32),
        reset_array_container(arrays, objects),
    )

    step_fn = partial(
        forward,
        config=config,
        objects=objects,
        key=key,
        record_detectors=True,
        record_boundaries=config.invertible_optimization,
        simulate_boundaries=True,
    )

    # NOTE(review): ``only_forward`` is compared with ``is None``. If it is a
    # plain bool, this always selects "checkpointed". Confirm that this is intended.
    loop_kind = "checkpointed" if config.only_forward is not None else "lax"
    gradient_config = config.gradient_config
    num_checkpoints = None if gradient_config is None else gradient_config.num_checkpoints

    return eqxi.while_loop(
        max_steps=config.time_steps_total,
        cond_fun=lambda sim_state: config.time_steps_total > sim_state[0],
        body_fun=step_fn,
        init_val=initial_state,
        kind=loop_kind,
        checkpoints=num_checkpoints,
    )

Gradient-checkpointing FDTD implementation that trades memory for recomputation when using autodiff. In most use cases it performs worse than the reversible FDTD.

fdtdx.fdtd.full_backward(state, objects, config, key, record_detectors, reset_fields, start_time_step=0)

Perform full backward FDTD propagation from current state to start time.

Uses a while loop to repeatedly call backward() until reaching start_time_step. Leverages time-reversibility of Maxwell's equations.

Parameters:

Name Type Description Default
state SimulationState

Current simulation state tuple (time_step, arrays)

required
objects ObjectContainer

Container with simulation objects (sources, detectors, etc)

required
config SimulationConfig

Simulation configuration parameters

required
key Array

JAX PRNG key for random operations

required
record_detectors bool

Whether to record detector states

required
reset_fields bool

Whether to reset fields after each step

required
start_time_step int

Time step to propagate back to (default: 0)

0

Returns:

Name Type Description
SimulationState SimulationState

Final state after backward propagation

Source code in src/fdtdx/fdtd/backward.py
def full_backward(
    state: SimulationState,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
    record_detectors: bool,
    reset_fields: bool,
    start_time_step: int = 0,
) -> SimulationState:
    """Propagate a simulation state backwards in time down to ``start_time_step``.

    Repeatedly applies ``backward()`` inside an ``eqxi.while_loop`` for as long as
    ``cond_fn`` (bound to ``start_time_step``) allows. This relies on the
    time-reversibility of Maxwell's equations.

    Args:
        state: Current simulation state tuple (time_step, arrays)
        objects: Container with simulation objects (sources, detectors, etc)
        config: Simulation configuration parameters
        key: JAX PRNG key for random operations
        record_detectors: Whether to record detector states
        reset_fields: Whether to reset fields after each step (forwarded to backward())
        start_time_step: Time step to propagate back to (default: 0)

    Returns:
        SimulationState: Final state after backward propagation
    """
    should_continue = partial(cond_fn, start_time_step=start_time_step)
    reverse_step = partial(
        backward,
        config=config,
        objects=objects,
        key=key,
        record_detectors=record_detectors,
        reset_fields=reset_fields,
    )
    return eqxi.while_loop(
        cond_fun=should_continue,
        body_fun=reverse_step,
        init_val=state,
        kind="lax",
    )

Complete backward FDTD propagation from the current state to the start time. This can be used to check whether compressing the boundary interfaces still leads to a physically accurate backward pass.

Custom Time Evolution

fdtdx.fdtd.custom_fdtd_forward(arrays, objects, config, key, reset_container, record_detectors, start_time, end_time)

Run a customizable forward FDTD simulation between specified time steps.

This function provides fine-grained control over the simulation execution, allowing partial time evolution and customization of recording behavior.

Parameters:

Name Type Description Default
arrays ArrayContainer

Initial state of the simulation

required
objects ObjectContainer

Collection of physical objects

required
config SimulationConfig

Simulation parameters

required
key Array

JAX PRNGKey for stochastic operations

required
reset_container bool

Whether to reset the array container before starting

required
record_detectors bool

Whether to record detector readings

required
start_time int | Array

Time step to start from

required
end_time int | Array

Time step to end at

required

Returns:

Name Type Description
SimulationState SimulationState

Tuple containing final time step and ArrayContainer with final state

Notes

This function is useful for implementing custom simulation strategies or running partial simulations for analysis purposes.

Source code in src/fdtdx/fdtd/fdtd.py
def custom_fdtd_forward(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
    reset_container: bool,
    record_detectors: bool,
    start_time: int | jax.Array,
    end_time: int | jax.Array,
) -> SimulationState:
    """Run a customizable forward FDTD simulation between specified time steps.

    Starting at ``start_time``, ``forward`` is applied until the time step reaches
    ``end_time``. Boundary states are simulated but never recorded, and the loop
    is a plain (non-checkpointed) while loop capped at ``config.time_steps_total``
    iterations.

    Args:
        arrays (ArrayContainer): Initial state of the simulation
        objects (ObjectContainer): Collection of physical objects
        config (SimulationConfig): Simulation parameters
        key (jax.Array): JAX PRNGKey for stochastic operations
        reset_container (bool): Whether to reset the array container before starting
        record_detectors (bool): Whether to record detector readings
        start_time (int | jax.Array): Time step to start from
        end_time (int | jax.Array): Time step to end at

    Returns:
        SimulationState: Tuple containing final time step and ArrayContainer with final state

    Notes:
        This function is useful for implementing custom simulation strategies or
        running partial simulations for analysis purposes.
    """
    initial_arrays = reset_array_container(arrays, objects) if reset_container else arrays
    initial_step = jnp.asarray(start_time, dtype=jnp.int32)

    step_fn = partial(
        forward,
        config=config,
        objects=objects,
        key=key,
        record_detectors=record_detectors,
        record_boundaries=False,
        simulate_boundaries=True,
    )

    return eqxi.while_loop(
        max_steps=config.time_steps_total,
        cond_fun=lambda sim_state: end_time > sim_state[0],
        body_fun=step_fn,
        init_val=(initial_step, initial_arrays),
        kind="lax",
        checkpoints=None,
    )

Customizable FDTD implementation for partial time evolution and analysis. Used carefully, this can make simulations somewhat faster, but in most use cases it is not necessary.

Python Objects used for FDTD simulation

fdtdx.fdtd.ArrayContainer

Bases: ExtendedTreeClass

Container for simulation field arrays and states.

This class holds the electromagnetic field arrays and various state information needed during FDTD simulation. It includes the E and H fields, material properties, and states for boundaries, detectors and recordings.

Attributes:

Name Type Description
E Array

Electric field array.

H Array

Magnetic field array.

inv_permittivities Array

Inverse permittivity values array.

inv_permeabilities Array | float

Inverse permeability values array.

boundary_states dict[str, BaseBoundaryState]

Dictionary mapping boundary names to their states.

detector_states dict[str, DetectorState]

Dictionary mapping detector names to their states.

recording_state RecordingState | None

Optional state for recording simulation data.

Source code in src/fdtdx/fdtd/container.py
@extended_autoinit
class ArrayContainer(ExtendedTreeClass):
    """Container for simulation field arrays and states.

    This class holds the electromagnetic field arrays and various state information
    needed during FDTD simulation. It includes the E and H fields, material properties,
    and states for boundaries, detectors and recordings.

    The field order below defines the generated constructor's positional order,
    so do not reorder the fields.

    Attributes:
        E: Electric field array.
        H: Magnetic field array.
        inv_permittivities: Inverse permittivity values array.
        inv_permeabilities: Inverse permeability values array (may also be a scalar float).
        boundary_states: Dictionary mapping boundary names to their states.
        detector_states: Dictionary mapping detector names to their states.
        recording_state: Optional state for recording simulation data.
    """

    # Electric field.
    E: jax.Array
    # Magnetic field.
    H: jax.Array
    # Inverse permittivities (1 / epsilon); written by apply_params for devices.
    inv_permittivities: jax.Array
    # Inverse permeabilities; the annotation allows a plain float as well as an array.
    inv_permeabilities: jax.Array | float
    # Per-boundary state, keyed by boundary name.
    boundary_states: dict[str, BaseBoundaryState]
    # Per-detector state, keyed by detector name.
    detector_states: dict[str, DetectorState]
    # Optional recording state; None when nothing is recorded.
    recording_state: RecordingState | None

Container holding the electric/magnetic fields as well as permittivity/permeability arrays for simulation

fdtdx.fdtd.ObjectContainer

Bases: ExtendedTreeClass

Container for managing simulation objects and their relationships.

This class provides a structured way to organize and access different types of simulation objects like sources, detectors, PML/periodic boundaries and devices. It maintains object lists and provides filtered access to specific object types.

Attributes:

Name Type Description
object_list list[SimulationObject]

List of all simulation objects in the container.

volume_idx int

Index of the volume object in the object list.

Source code in src/fdtdx/fdtd/container.py
@extended_autoinit
class ObjectContainer(ExtendedTreeClass):
    """Container for managing simulation objects and their relationships.

    This class provides a structured way to organize and access different types of simulation
    objects like sources, detectors, PML/periodic boundaries and devices. It maintains object lists
    and provides filtered access to specific object types. Each filtered property is recomputed
    from ``object_list`` on every access.

    Attributes:
        object_list: List of all simulation objects in the container.
        volume_idx: Index of the volume object in the object list.
    """

    object_list: list[SimulationObject]
    volume_idx: int

    @property
    def volume(self) -> SimulationObject:
        """The volume object, i.e. ``object_list[volume_idx]``."""
        return self.object_list[self.volume_idx]

    @property
    def objects(self) -> list[SimulationObject]:
        """All objects, in container order."""
        return self.object_list

    @property
    def static_material_objects(self) -> list[StaticMaterialObject]:
        return [o for o in self.objects if isinstance(o, StaticMaterialObject)]

    @property
    def sources(self) -> list[Source]:
        return [o for o in self.objects if isinstance(o, Source)]

    @property
    def devices(self) -> list[BaseDevice]:
        return [o for o in self.objects if isinstance(o, BaseDevice)]

    @property
    def discrete_devices(self) -> list[DiscreteDevice]:
        return [o for o in self.objects if isinstance(o, DiscreteDevice)]

    @property
    def continous_devices(self) -> list[ContinuousDevice]:
        return [o for o in self.objects if isinstance(o, ContinuousDevice)]

    @property
    def detectors(self) -> list[Detector]:
        return [o for o in self.objects if isinstance(o, Detector)]

    @property
    def forward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is falsy."""
        return [o for o in self.detectors if not o.inverse]

    @property
    def backward_detectors(self) -> list[Detector]:
        """Detectors whose ``inverse`` flag is truthy."""
        return [o for o in self.detectors if o.inverse]

    @property
    def pml_objects(self) -> list[PerfectlyMatchedLayer]:
        return [o for o in self.objects if isinstance(o, PerfectlyMatchedLayer)]

    @property
    def periodic_objects(self) -> list[PeriodicBoundary]:
        return [o for o in self.objects if isinstance(o, PeriodicBoundary)]

    @property
    def boundary_objects(self) -> list[BaseBoundary]:
        return [o for o in self.objects if isinstance(o, (PerfectlyMatchedLayer, PeriodicBoundary))]

    @property
    def all_objects_non_magnetic(self) -> bool:
        """True if ``_is_material_fn_true_for_all`` reports no magnetic material."""

        def _fn(m: Material):
            return not m.is_magnetic

        return self._is_material_fn_true_for_all(_fn)

    @property
    def all_objects_non_conductive(self) -> bool:
        """True if ``_is_material_fn_true_for_all`` reports no conductive material."""

        def _fn(m: Material):
            return not m.is_conductive

        return self._is_material_fn_true_for_all(_fn)

    def _is_material_fn_true_for_all(
        self,
        fn: Callable[[Material], bool],
    ) -> bool:
        """Check ``fn`` against the materials of static-material objects and devices.

        Other object types are skipped. A material may be a single ``Material``,
        a dict of materials, or (otherwise) an object with ``start_material`` and
        ``end_material``.
        """
        for o in self.objects:
            if not isinstance(o, StaticMaterialObject) and not isinstance(o, BaseDevice):
                continue
            if isinstance(o.material, Material):
                if not fn(o.material):
                    return False
            elif isinstance(o.material, dict):
                for v in o.material.values():
                    if not fn(v):
                        return False
            else:
                # NOTE(review): unlike the two branches above, these checks are not
                # negated. A start/end material that satisfies ``fn`` makes this return
                # False, and one that violates ``fn`` is ignored. This looks inverted.
                # It is left unchanged pending confirmation that no caller relies on
                # the current result.
                if fn(o.material.start_material):
                    return False
                if fn(o.material.end_material):
                    return False
        return True

    def __iter__(self):
        return iter(self.object_list)

    def __getitem__(
        self,
        key: str,
    ) -> SimulationObject:
        """Return the first object whose ``name`` equals ``key``.

        Raises:
            ValueError: If no object has that name.
        """
        for o in self.objects:
            if o.name == key:
                return o
        raise ValueError(f"Key {key} does not exist in object list: {[o.name for o in self.objects]}")

    def replace_sources(
        self,
        sources: list[Source],
    ) -> Self:
        """Return a copy in which all current sources are replaced by ``sources``.

        Non-source objects keep their relative order, and the new sources are
        appended at the end. ``volume_idx`` is not changed, so this assumes no
        source precedes the volume in ``object_list``.

        Filtering by type avoids rebuilding the ``sources`` list for every
        element. It also avoids ``==`` comparisons between tree objects, whose
        fields may be arrays.
        """
        kept = [o for o in self.objects if not isinstance(o, Source)]
        return self.aset("object_list", kept + list(sources))

Container holding all the objects in a simulation scene

fdtdx.fdtd.ParameterContainer = dict[str, dict[str, jax.Array] | jax.Array] module-attribute

Dictionary holding the parameters for every device in the simulation

fdtdx.fdtd.SimulationState = tuple[jax.Array, ArrayContainer] module-attribute

Simulation state returned by the FDTD simulations. This is a tuple of the simulation time step and an array container.