Core FDTD Algorithms

Memory-Efficient Implementations

Reversible FDTD

fdtdx.fdtd.reversible_fdtd(arrays, objects, config, key)

Run a memory-efficient differentiable FDTD simulation leveraging time-reversal symmetry.

This implementation exploits the time-reversal symmetry of Maxwell's equations to perform backpropagation without storing the electromagnetic fields at each time step. During the backward pass, the fields are reconstructed by running the simulation in reverse, so field storage stays O(1) in the number of time steps T instead of the O(T) needed to keep every intermediate field.

The exception is boundary conditions that break time-reversal symmetry, such as absorbing PML boundaries: their interface values are recorded during the forward pass and replayed during backpropagation.
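The reversal idea can be illustrated outside fdtdx with a toy leapfrog scheme. The sketch below is not the library's implementation: step_forward and step_backward are hypothetical helpers on a periodic 1D grid with no sources and no absorbing boundaries, which is the lossless case where each update can be undone exactly (up to floating-point rounding).

```python
import jax.numpy as jnp


def step_forward(e, h, c=0.5):
    # Leapfrog update: advance E from H, then H from the new E.
    e = e + c * (h - jnp.roll(h, 1))
    h = h + c * (jnp.roll(e, -1) - e)
    return e, h


def step_backward(e, h, c=0.5):
    # Inverse update: undo the H step first, then the E step.
    h = h - c * (jnp.roll(e, -1) - e)
    e = e - c * (h - jnp.roll(h, 1))
    return e, h


x = jnp.arange(64)
e0 = jnp.exp(-0.05 * (x - 32.0) ** 2)  # Gaussian pulse as initial E field
h0 = jnp.zeros(64)

e, h = e0, h0
for _ in range(100):
    e, h = step_forward(e, h)
for _ in range(100):
    e, h = step_backward(e, h)

# Only rounding error separates the reconstructed field from the original.
max_error = float(jnp.max(jnp.abs(e - e0)))
```

Once a lossy element such as an absorbing layer is added, the inverse step no longer recovers the discarded energy, which is why values at such boundaries have to be recorded on the way forward.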

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| arrays | ArrayContainer | Initial state of the simulation, containing: E, H (electric and magnetic field arrays); inv_permittivities, inv_permeabilities (material properties); boundary_states (dictionary of boundary conditions); detector_states (dictionary of field detectors); recording_state (optional state for recording field evolution) | required |
| objects | ObjectContainer | Collection of physical objects in the simulation (sources, detectors, boundaries, etc.) | required |
| config | SimulationConfig | Simulation parameters, including: time_steps_total (total number of steps to simulate); invertible_optimization (whether to record boundaries for backprop) | required |
| key | Array | JAX PRNGKey for any stochastic operations | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SimulationState | SimulationState | Tuple containing the final time step (a scalar int32 array) and an ArrayContainer with the final state of all fields and components |

Notes

The implementation uses custom vector-Jacobian products (VJPs) to enable efficient backpropagation through the entire simulation while maintaining numerical stability. This makes it suitable for gradient-based optimization of electromagnetic designs.
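To make the custom-VJP pattern concrete, here is a minimal sketch on a toy problem. It is an assumption-laden illustration, not fdtdx code: step, step_inverse, simulate and the fixed step count T are hypothetical, and the grid is periodic so no boundary recording is needed. The forward rule keeps only the final state as residual; the backward rule reconstructs each earlier state with the inverse step and pulls the cotangent back through one forward step at a time.

```python
import jax
import jax.numpy as jnp

T = 50  # number of time steps (illustrative)


def step(e, h, inv_eps):
    e = e + 0.5 * inv_eps * (h - jnp.roll(h, 1))
    h = h + 0.5 * (jnp.roll(e, -1) - e)
    return e, h


def step_inverse(e, h, inv_eps):
    h = h - 0.5 * (jnp.roll(e, -1) - e)
    e = e - 0.5 * inv_eps * (h - jnp.roll(h, 1))
    return e, h


def simulate_plain(e, h, inv_eps):
    # Reference version: ordinary autodiff stores every intermediate state.
    return jax.lax.fori_loop(0, T, lambda _, s: step(*s, inv_eps), (e, h))


@jax.custom_vjp
def simulate(e, h, inv_eps):
    return simulate_plain(e, h, inv_eps)


def simulate_fwd(e, h, inv_eps):
    e_T, h_T = simulate_plain(e, h, inv_eps)
    return (e_T, h_T), (e_T, h_T, inv_eps)  # residual: final state only


def simulate_bwd(residual, cot):
    e_T, h_T, inv_eps = residual

    def body(_, carry):
        (e, h), (ce, ch), c_eps = carry
        e, h = step_inverse(e, h, inv_eps)      # reconstruct the previous state
        _, pullback = jax.vjp(step, e, h, inv_eps)
        ce, ch, d_eps = pullback((ce, ch))      # pull cotangent through one step
        return (e, h), (ce, ch), c_eps + d_eps

    init = ((e_T, h_T), cot, jnp.zeros_like(inv_eps))
    _, (ce, ch), c_eps = jax.lax.fori_loop(0, T, body, init)
    return ce, ch, c_eps


simulate.defvjp(simulate_fwd, simulate_bwd)

x = jnp.arange(32)
e0 = jnp.exp(-0.1 * (x - 16.0) ** 2)
h0 = jnp.zeros(32)
inv_eps0 = jnp.full(32, 0.8)


def loss(sim, inv_eps):
    e_T, _ = sim(e0, h0, inv_eps)
    return jnp.sum(e_T**2)


g_rev = jax.grad(lambda p: loss(simulate, p))(inv_eps0)
g_ref = jax.grad(lambda p: loss(simulate_plain, p))(inv_eps0)
```

Comparing g_rev against g_ref is a useful sanity check for any hand-written backward rule of this kind, since the two should agree up to rounding.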

Source code in src/fdtdx/fdtd/fdtd.py
def reversible_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
) -> SimulationState:
    """Run a memory-efficient differentiable FDTD simulation leveraging time-reversal symmetry.

    This implementation exploits the time-reversal symmetry of Maxwell's equations to perform
    backpropagation without storing the electromagnetic fields at each time step. During the
    backward pass, the fields are reconstructed by running the simulation in reverse, only
    requiring O(1) memory storage instead of O(T) where T is the number of time steps.

    The only exception is boundary conditions which break time-reversal symmetry - these are
    recorded during the forward pass and replayed during backpropagation.

    Args:
        arrays (ArrayContainer): Initial state of the simulation containing:
            - E, H: Electric and magnetic field arrays
            - inv_permittivities, inv_permeabilities: Material properties
            - boundary_states: Dictionary of boundary conditions
            - detector_states: Dictionary of field detectors
            - recording_state: Optional state for recording field evolution
        objects (ObjectContainer): Collection of physical objects in the simulation
            (sources, detectors, boundaries, etc.)
        config (SimulationConfig): Simulation parameters including:
            - time_steps_total: Total number of steps to simulate
            - invertible_optimization: Whether to record boundaries for backprop
        key (jax.Array): JAX PRNGKey for any stochastic operations

    Returns:
        SimulationState: Tuple containing:
            - Final time step (int)
            - ArrayContainer with the final state of all fields and components

    Notes:
        The implementation uses custom vector-Jacobian products (VJPs) to enable
        efficient backpropagation through the entire simulation while maintaining
        numerical stability. This makes it suitable for gradient-based optimization
        of electromagnetic designs.
    """
    arrays = reset_array_container(
        arrays,
        objects,
    )

    def reversible_fdtd_base(
        arr: ArrayContainer,
    ) -> SimulationState:
        """Core implementation of reversible FDTD simulation.

        Performs the main FDTD time-stepping loop using a while loop that respects
        JAX's functional programming model.

        Args:
            arr: ArrayContainer with initial field state and material properties

        Returns:
            SimulationState tuple containing:
                - Final time step
                - ArrayContainer with final simulation state
        """
        state = (jnp.asarray(0, dtype=jnp.int32), arr)
        state = eqxi.while_loop(
            max_steps=config.time_steps_total,
            cond_fun=lambda s: config.time_steps_total > s[0],
            body_fun=partial(
                forward,
                config=config,
                objects=objects,
                key=key,
                record_detectors=True,
                record_boundaries=config.invertible_optimization,
                simulate_boundaries=True,
            ),
            init_val=state,
            kind="lax",
        )
        return (state[0], state[1])

    @jax.custom_vjp
    def reversible_fdtd_primal(
        E: jax.Array,
        H: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array,
        boundary_states: dict[str, BoundaryState],
        detector_states: dict[str, DetectorState],
        recording_state: RecordingState | None,
    ):
        arr = ArrayContainer(
            E=E,
            H=H,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
            boundary_states=boundary_states,
            detector_states=detector_states,
            recording_state=recording_state,
        )
        state = reversible_fdtd_base(arr)
        return (
            state[0],
            state[1].E,
            state[1].H,
            state[1].inv_permittivities,
            state[1].inv_permeabilities,
            state[1].boundary_states,
            state[1].detector_states,
            state[1].recording_state,
        )

    def body_fn(
        sr_tuple,
    ):
        state, cot = sr_tuple
        state = backward(
            state=state,
            config=config,
            objects=objects,
            key=key,
            record_detectors=False,
            reset_fields=False,
        )
        _, update_vjp = jax.vjp(
            partial(
                forward_single_args_wrapper,
                config=config,
                objects=objects,
                key=key,
                record_detectors=True,
                record_boundaries=False,
                simulate_boundaries=True,
            ),
            state[0],
            state[1].E,
            state[1].H,
            state[1].inv_permittivities,
            state[1].inv_permeabilities,
            state[1].boundary_states,
            state[1].detector_states,
            state[1].recording_state,
        )

        cot = update_vjp(cot)
        return state, cot

    def cond_fun(
        sr_tuple,
        start_time_step: int,
    ):
        s_k, r_k = sr_tuple
        del r_k
        time_step = s_k[0]
        return time_step >= start_time_step

    def fdtd_bwd(
        residual,
        cot,
    ):
        """Backward pass for reversible FDTD simulation.

        Implements the custom vector-Jacobian product for backpropagation through
        the FDTD simulation by leveraging time-reversibility.

        Args:
            residual: Tuple containing the final simulation state including:
                - Time step
                - E, H field arrays
                - Material properties
                - Boundary and detector states
                - Recording state
            cot: Cotangent values for gradient computation

        Returns:
            Tuple of cotangent values for each input parameter
        """
        (
            res_time_step,
            res_E,
            res_H,
            res_inv_permittivities,
            res_inv_permeabilities,
            res_boundary_states,
            res_detector_states,
            res_recording_state,
        ) = residual

        s_k = ArrayContainer(
            E=res_E,
            H=res_H,
            inv_permittivities=res_inv_permittivities,
            inv_permeabilities=res_inv_permeabilities,
            boundary_states=res_boundary_states,
            detector_states=res_detector_states,
            recording_state=res_recording_state,
        )

        _, cot = eqxi.while_loop(
            cond_fun=partial(cond_fun, start_time_step=0),
            body_fun=body_fn,
            init_val=((res_time_step, s_k), cot),
            kind="lax",
        )
        return (
            None,  # cot[1],
            None,  # cot[2],
            cot[3],
            cot[4],
            None,  # cot[5]
            None,  # cot[6],
            None,  # cot[7],
        )

    def fdtd_fwd(
        E: jax.Array,
        H: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array,
        boundary_states: dict[str, BoundaryState],
        detector_states: dict[str, DetectorState],
        recording_state: RecordingState | None,
    ):
        """Forward pass for reversible FDTD simulation.

        Performs the forward FDTD simulation and prepares residuals for the backward pass.

        Args:
            E: Electric field array
            H: Magnetic field array
            inv_permittivities: Inverse permittivity values
            inv_permeabilities: Inverse permeability values
            boundary_states: Dictionary mapping boundary names to their states
            detector_states: Dictionary mapping detector names to their states
            recording_state: Optional state for recording field evolution

        Returns:
            Tuple containing:
                - Primal outputs (final simulation state)
                - Residuals for backward pass
        """
        arr = ArrayContainer(
            E=E,
            H=H,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
            boundary_states=boundary_states,
            detector_states=detector_states,
            recording_state=recording_state,
        )
        s_k = reversible_fdtd_base(arr)

        primal_out = (
            s_k[0],
            s_k[1].E,
            s_k[1].H,
            s_k[1].inv_permittivities,
            s_k[1].inv_permeabilities,
            s_k[1].boundary_states,
            s_k[1].detector_states,
            s_k[1].recording_state,  # None
        )
        residual = (
            s_k[0],
            s_k[1].E,
            s_k[1].H,
            s_k[1].inv_permittivities,
            s_k[1].inv_permeabilities,
            s_k[1].boundary_states,
            s_k[1].detector_states,
            s_k[1].recording_state,
        )
        return primal_out, residual

    reversible_fdtd_primal.defvjp(fdtd_fwd, fdtd_bwd)

    (
        time_step,
        E,
        H,
        inv_permittivities,
        inv_permeabilities,
        boundary_states,
        detector_states,
        recording_state,
    ) = reversible_fdtd_primal(
        E=arrays.E,
        H=arrays.H,
        inv_permittivities=arrays.inv_permittivities,
        inv_permeabilities=arrays.inv_permeabilities,
        boundary_states=arrays.boundary_states,
        detector_states=arrays.detector_states,
        recording_state=arrays.recording_state,
    )
    out_arrs = ArrayContainer(
        E=E,
        H=H,
        inv_permittivities=inv_permittivities,
        inv_permeabilities=inv_permeabilities,
        boundary_states=boundary_states,
        detector_states=detector_states,
        recording_state=recording_state,
    )
    return time_step, out_arrs

Time-reversal symmetric FDTD implementation with O(1) memory usage.

Checkpointed FDTD

fdtdx.fdtd.checkpointed_fdtd(arrays, objects, config, key)

Run an FDTD simulation with gradient checkpointing for memory efficiency.

This implementation uses checkpointing to reduce memory usage during backpropagation by only storing the field state at certain intervals and recomputing intermediate states as needed.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| arrays | ArrayContainer | Initial state of the simulation containing fields and materials | required |
| objects | ObjectContainer | Collection of physical objects in the simulation | required |
| config | SimulationConfig | Simulation parameters including checkpointing settings | required |
| key | Array | JAX PRNGKey for any stochastic operations | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SimulationState | SimulationState | Tuple containing final time step and ArrayContainer with final state |

Notes

The number of checkpoints can be configured through config.gradient_config.num_checkpoints. More checkpoints reduce recomputation but increase memory usage.
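The memory and recomputation trade-off can be sketched with plain JAX. Note the swap: this example uses jax.checkpoint around chunks of a jax.lax.scan, not the checkpointed while loop that the fdtdx source calls, and the names step, chunk and run are hypothetical. Each chunk stores only its input state and recomputes its inner steps during the backward pass.

```python
from functools import partial

import jax
import jax.numpy as jnp


def step(state, _):
    e, h = state
    e = e + 0.5 * (h - jnp.roll(h, 1))
    h = h + 0.5 * (jnp.roll(e, -1) - e)
    return (e, h), None


def run(e0, h0, num_chunks, steps_per_chunk, checkpoint):
    def chunk(state, _):
        # Inner loop: recomputed in the backward pass when checkpointed.
        state, _ = jax.lax.scan(step, state, None, length=steps_per_chunk)
        return state, None

    if checkpoint:
        chunk = jax.checkpoint(chunk)
    (e, _), _ = jax.lax.scan(chunk, (e0, h0), None, length=num_chunks)
    return jnp.sum(e**2)


x = jnp.arange(32)
e0 = jnp.exp(-0.1 * (x - 16.0) ** 2)
h0 = jnp.zeros(32)

# 10 chunks of 10 steps: roughly 10 stored states instead of 100.
grad_ckpt = jax.grad(partial(run, num_chunks=10, steps_per_chunk=10, checkpoint=True))(e0, h0)
grad_full = jax.grad(partial(run, num_chunks=10, steps_per_chunk=10, checkpoint=False))(e0, h0)
```

Both calls produce the same gradient; only the peak memory and the amount of recomputation differ.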

Source code in src/fdtdx/fdtd/fdtd.py
def checkpointed_fdtd(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
) -> SimulationState:
    """Run an FDTD simulation with gradient checkpointing for memory efficiency.

    This implementation uses checkpointing to reduce memory usage during backpropagation
    by only storing the field state at certain intervals and recomputing intermediate
    states as needed.

    Args:
        arrays (ArrayContainer): Initial state of the simulation containing fields and materials
        objects (ObjectContainer): Collection of physical objects in the simulation
        config (SimulationConfig): Simulation parameters including checkpointing settings
        key (jax.Array): JAX PRNGKey for any stochastic operations

    Returns:
        SimulationState: Tuple containing final time step and ArrayContainer with final state

    Notes:
        The number of checkpoints can be configured through config.gradient_config.num_checkpoints.
        More checkpoints reduce recomputation but increase memory usage.
    """
    arrays = reset_array_container(arrays, objects)
    state = (jnp.asarray(0, dtype=jnp.int32), arrays)
    state = eqxi.while_loop(
        max_steps=config.time_steps_total,
        cond_fun=lambda s: config.time_steps_total > s[0],
        body_fun=partial(
            forward,
            config=config,
            objects=objects,
            key=key,
            record_detectors=True,
            record_boundaries=config.invertible_optimization,
            simulate_boundaries=True,
        ),
        init_val=state,
        kind="lax" if config.only_forward is None else "checkpointed",
        checkpoints=(None if config.gradient_config is None else config.gradient_config.num_checkpoints),
    )

    return state

Gradient checkpointing FDTD implementation for memory-performance tradeoff.

Custom Time Evolution

fdtdx.fdtd.custom_fdtd_forward(arrays, objects, config, key, reset_container, record_detectors, start_time, end_time)

Run a customizable forward FDTD simulation between specified time steps.

This function provides fine-grained control over the simulation execution, allowing partial time evolution and customization of recording behavior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| arrays | ArrayContainer | Initial state of the simulation | required |
| objects | ObjectContainer | Collection of physical objects | required |
| config | SimulationConfig | Simulation parameters | required |
| key | Array | JAX PRNGKey for stochastic operations | required |
| reset_container | bool | Whether to reset the array container before starting | required |
| record_detectors | bool | Whether to record detector readings | required |
| start_time | int \| Array | Time step to start from | required |
| end_time | int \| Array | Time step to end at | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SimulationState | SimulationState | Tuple containing final time step and ArrayContainer with final state |

Notes

This function is useful for implementing custom simulation strategies or running partial simulations for analysis purposes.
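The start/end mechanics can be shown with a toy loop. This is a sketch under assumptions, not a call into fdtdx: toy_step and run_between are hypothetical, and the loop is a plain jax.lax.while_loop whose state is a (time_step, value) tuple, mirroring the shape of SimulationState. Running from 0 to 4 and then from 4 to 10 gives the same result as a single run from 0 to 10.

```python
import jax
import jax.numpy as jnp


def toy_step(x):
    # Stand-in for one simulation step.
    return 0.9 * x + 1.0


def run_between(x, start_time, end_time):
    # The counter starts at start_time and the loop stops once it reaches end_time.
    state = (jnp.asarray(start_time, dtype=jnp.int32), x)
    return jax.lax.while_loop(
        lambda s: end_time > s[0],
        lambda s: (s[0] + 1, toy_step(s[1])),
        state,
    )


x0 = jnp.zeros(3)
t_full, x_full = run_between(x0, 0, 10)

# Split the same evolution into two partial runs.
t_mid, x_mid = run_between(x0, 0, 4)
t_end, x_end = run_between(x_mid, t_mid, 10)
```

When chaining partial runs like this with the real function, the second call would presumably need reset_container=False so that the carried state is not cleared.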

Source code in src/fdtdx/fdtd/fdtd.py
def custom_fdtd_forward(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
    reset_container: bool,
    record_detectors: bool,
    start_time: int | jax.Array,
    end_time: int | jax.Array,
) -> SimulationState:
    """Run a customizable forward FDTD simulation between specified time steps.

    This function provides fine-grained control over the simulation execution,
    allowing partial time evolution and customization of recording behavior.

    Args:
        arrays (ArrayContainer): Initial state of the simulation
        objects (ObjectContainer): Collection of physical objects
        config (SimulationConfig): Simulation parameters
        key (jax.Array): JAX PRNGKey for stochastic operations
        reset_container (bool): Whether to reset the array container before starting
        record_detectors (bool): Whether to record detector readings
        start_time (int | jax.Array): Time step to start from
        end_time (int | jax.Array): Time step to end at

    Returns:
        SimulationState: Tuple containing final time step and ArrayContainer with final state

    Notes:
        This function is useful for implementing custom simulation strategies or
        running partial simulations for analysis purposes.
    """
    if reset_container:
        arrays = reset_array_container(arrays, objects)
    state = (jnp.asarray(start_time, dtype=jnp.int32), arrays)
    state = eqxi.while_loop(
        max_steps=config.time_steps_total,
        cond_fun=lambda s: end_time > s[0],
        body_fun=partial(
            forward,
            config=config,
            objects=objects,
            key=key,
            record_detectors=record_detectors,
            record_boundaries=False,
            simulate_boundaries=True,
        ),
        init_val=state,
        kind="lax",
        checkpoints=None,
    )

    return state

Customizable FDTD implementation for partial time evolution and analysis.

Forward Propagation

fdtdx.fdtd.forward.forward(state, config, objects, key, record_detectors, record_boundaries, simulate_boundaries)

Performs one forward time step of the FDTD simulation.

Implements the core FDTD update scheme based on Maxwell's equations discretized on the Yee grid. Updates include:

1. Electric field update using curl of H field
2. Magnetic field update using curl of E field
3. Optional PML boundary conditions
4. Optional detector state updates
5. Optional recording of boundary values for gradient computation

The implementation leverages JAX for automatic compilation and GPU acceleration. Field updates follow the standard staggered time stepping of the Yee scheme.
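A one-dimensional analogue shows the staggered update order. This is a simplified sketch, not the library's update_E and update_H: the grid size N, the Courant number, the soft Gaussian source and the perfectly conducting end nodes are all assumptions made for illustration. E lives on N nodes, H on the N - 1 half-cells between them, and the state is a (time_step, fields) tuple.

```python
import jax
import jax.numpy as jnp

N = 200        # number of E nodes (illustrative)
COURANT = 0.5  # c * dt / dx in normalized units (illustrative)


def update_E(t, E, H, inv_eps):
    # Interior E nodes are driven by the difference of the neighbouring H values.
    curl_h = H[1:] - H[:-1]
    E = E.at[1:-1].add(COURANT * inv_eps[1:-1] * curl_h)
    # Soft source: add a Gaussian pulse in time at a single node.
    return E.at[N // 4].add(jnp.exp(-(((t - 30.0) / 10.0) ** 2)))


def update_H(E, H):
    # H is updated from the freshly updated E (staggered in time and space).
    return H + COURANT * (E[1:] - E[:-1])


def forward(state, inv_eps):
    t, (E, H) = state
    E = update_E(t, E, H, inv_eps)
    H = update_H(E, H)
    return t + 1, (E, H)


inv_eps = jnp.ones(N)
state = (jnp.asarray(0, dtype=jnp.int32), (jnp.zeros(N), jnp.zeros(N - 1)))
state = jax.lax.fori_loop(0, 150, lambda _, s: forward(s, inv_eps), state)
final_step, (E, H) = state
```

The end nodes E[0] and E[-1] are never updated, so they act as reflecting walls; an absorbing layer would replace them in a realistic setup.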

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SimulationState | Current simulation state (time step and field values) | required |
| config | SimulationConfig | Simulation configuration parameters | required |
| objects | ObjectContainer | Container with sources, PML and other simulation objects | required |
| key | Array | JAX PRNG key, used when recorded boundary values are compressed | required |
| record_detectors | bool | Whether to record detector values | required |
| record_boundaries | bool | Whether to record boundary values for gradients | required |
| simulate_boundaries | bool | Whether to apply PML boundary conditions | required |

Returns:

| Type | Description |
| --- | --- |
| SimulationState | Updated simulation state for the next time step |

Source code in src/fdtdx/fdtd/forward.py
def forward(
    state: SimulationState,
    config: SimulationConfig,
    objects: ObjectContainer,
    key: jax.Array,
    record_detectors: bool,
    record_boundaries: bool,
    simulate_boundaries: bool,
) -> SimulationState:
    """Performs one forward time step of the FDTD simulation.

    Implements the core FDTD update scheme based on Maxwell's equations discretized on the Yee grid.
    Updates include:
    1. Electric field update using curl of H field
    2. Magnetic field update using curl of E field
    3. Optional PML boundary conditions
    4. Optional detector state updates
    5. Optional recording of boundary values for gradient computation

    The implementation leverages JAX for automatic compilation and GPU acceleration.
    Field updates follow the standard staggered time stepping of the Yee scheme.

    Args:
        state: Current simulation state (time step and field values)
        config: Simulation configuration parameters
        objects: Container with sources, PML and other simulation objects
        key: JAX PRNG key, used when recorded boundary values are compressed
        record_detectors: Whether to record detector values
        record_boundaries: Whether to record boundary values for gradients
        simulate_boundaries: Whether to apply PML boundary conditions

    Returns:
        Updated simulation state for the next time step
    """
    time_step, arrays = state
    H_prev = arrays.H
    arrays = update_E(
        time_step=time_step,
        arrays=arrays,
        objects=objects,
        config=config,
        simulate_boundaries=simulate_boundaries,
    )
    arrays = update_H(
        time_step=time_step,
        arrays=arrays,
        objects=objects,
        config=config,
        simulate_boundaries=simulate_boundaries,
    )

    if record_boundaries:
        arrays = jax.lax.stop_gradient(
            collect_interfaces(
                time_step=time_step,
                arrays=arrays,
                objects=objects,
                config=config,
                key=key,
            )
        )

    if record_detectors:
        arrays = update_detector_states(
            time_step=time_step,
            arrays=arrays,
            objects=objects,
            H_prev=H_prev,
            inverse=False,
        )

    next_state = (time_step + 1, arrays)
    return next_state

Standard forward FDTD time stepping implementation.

fdtdx.fdtd.forward.forward_single_args_wrapper(time_step, E, H, inv_permittivities, inv_permeabilities, boundary_states, detector_states, recording_state, config, objects, key, record_detectors, record_boundaries, simulate_boundaries)

Wrapper function that unpacks ArrayContainer into individual arrays for JAX transformations.

This function provides a JAX-compatible interface by handling individual arrays instead of container objects. It converts between the array-based interface required by JAX and the object-oriented ArrayContainer interface used by the rest of the FDTD implementation.
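The flat-argument pattern can be sketched with a small stand-in container. Everything here is hypothetical (Fields, step, step_single_args), and the time step is bound with partial rather than passed as a differentiable argument, which is a simplification. With a flat signature, jax.vjp returns one cotangent per positional array, so each can be picked out by position.

```python
from functools import partial
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Fields(NamedTuple):
    # Hypothetical stand-in for ArrayContainer.
    E: jax.Array
    H: jax.Array


def step(state):
    # Container-based step, analogous in shape to forward().
    t, f = state
    E = f.E + 0.5 * f.H
    H = f.H - 0.5 * E
    return t + 1, Fields(E=E, H=H)


def step_single_args(E, H, t):
    # Pack flat arrays into the container, step, and unpack again.
    _, f = step((t, Fields(E=E, H=H)))
    return f.E, f.H


E0 = jnp.arange(4.0)
H0 = jnp.ones(4)

_, pullback = jax.vjp(partial(step_single_args, t=0), E0, H0)
# Cotangent of 1 on the new E field, 0 on the new H field.
cot_E, cot_H = pullback((jnp.ones(4), jnp.zeros(4)))
```

Since the new E equals E + 0.5 * H, the pulled-back cotangents are 1 for E and 0.5 for H.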

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| time_step | Array | Current simulation time step | required |
| E | Array | Electric field array | required |
| H | Array | Magnetic field array | required |
| inv_permittivities | Array | Inverse permittivity values | required |
| inv_permeabilities | Array | Inverse permeability values | required |
| boundary_states | dict[str, BoundaryState] | PML boundary conditions state | required |
| detector_states | dict[str, DetectorState] | States of field detectors | required |
| recording_state | RecordingState \| None | Optional state for recording field values | required |
| config | SimulationConfig | Simulation configuration parameters | required |
| objects | ObjectContainer | Container with sources and other simulation objects | required |
| key | Array | JAX PRNG key, used when recorded boundary values are compressed | required |
| record_detectors | bool | Whether to record detector values | required |
| record_boundaries | bool | Whether to record boundary values | required |
| simulate_boundaries | bool | Whether to apply PML boundary conditions | required |

Returns:

| Type | Description |
| --- | --- |
| tuple[Array, Array, Array, Array, Array, dict[str, BoundaryState], dict[str, DetectorState], RecordingState \| None] | Tuple containing, in order: updated time step, updated E field array, updated H field array, updated inverse permittivities, updated inverse permeabilities, updated boundary states, updated detector states, updated recording state |

Source code in src/fdtdx/fdtd/forward.py
def forward_single_args_wrapper(
    time_step: jax.Array,
    E: jax.Array,
    H: jax.Array,
    inv_permittivities: jax.Array,
    inv_permeabilities: jax.Array,
    boundary_states: dict[str, BoundaryState],
    detector_states: dict[str, DetectorState],
    recording_state: RecordingState | None,
    config: SimulationConfig,
    objects: ObjectContainer,
    key: jax.Array,
    record_detectors: bool,
    record_boundaries: bool,
    simulate_boundaries: bool,
) -> tuple[
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    jax.Array,
    dict[str, BoundaryState],
    dict[str, DetectorState],
    RecordingState | None,
]:
    """Wrapper function that unpacks ArrayContainer into individual arrays for JAX transformations.

    This function provides a JAX-compatible interface by handling individual arrays instead of
    container objects. It converts between the array-based interface required by JAX and the
    object-oriented ArrayContainer interface used by the rest of the FDTD implementation.

    Args:
        time_step: Current simulation time step
        E: Electric field array
        H: Magnetic field array
        inv_permittivities: Inverse permittivity values
        inv_permeabilities: Inverse permeability values
        boundary_states: PML boundary conditions state
        detector_states: States of field detectors
        recording_state: Optional state for recording field values
        config: Simulation configuration parameters
        objects: Container with sources and other simulation objects
        key: JAX PRNG key, used when recorded boundary values are compressed
        record_detectors: Whether to record detector values
        record_boundaries: Whether to record boundary values
        simulate_boundaries: Whether to apply PML boundary conditions

    Returns:
        Tuple containing:
            - Updated time step
            - Updated E field array
            - Updated H field array
            - Updated inverse permittivities
            - Updated inverse permeabilities
            - Updated boundary states
            - Updated detector states
            - Updated recording state
    """
    arr = ArrayContainer(
        E=E,
        H=H,
        inv_permittivities=inv_permittivities,
        inv_permeabilities=inv_permeabilities,
        boundary_states=boundary_states,
        detector_states=detector_states,
        recording_state=recording_state,
    )
    state = forward(
        state=(time_step, arr),
        config=config,
        objects=objects,
        key=key,
        record_detectors=record_detectors,
        record_boundaries=record_boundaries,
        simulate_boundaries=simulate_boundaries,
    )
    return (
        state[0],
        state[1].E,
        state[1].H,
        state[1].inv_permittivities,
        state[1].inv_permeabilities,
        state[1].boundary_states,
        state[1].detector_states,
        state[1].recording_state,
    )

JAX-compatible wrapper for forward propagation.

Backward Propagation

fdtdx.fdtd.backward.full_backward(state, objects, config, key, record_detectors, reset_fields, start_time_step=0)

Perform full backward FDTD propagation from current state to start time.

Uses a while loop to repeatedly call backward() until reaching start_time_step. Leverages time-reversibility of Maxwell's equations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SimulationState | Current simulation state tuple (time_step, arrays) | required |
| objects | ObjectContainer | Container with simulation objects (sources, detectors, etc) | required |
| config | SimulationConfig | Simulation configuration parameters | required |
| key | Array | JAX PRNG key for random operations | required |
| record_detectors | bool | Whether to record detector states | required |
| reset_fields | bool | Whether to zero the E and H fields inside PML regions after each backward step | required |
| start_time_step | int | Time step to propagate back to | 0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SimulationState | SimulationState | Final state after backward propagation |

Source code in src/fdtdx/fdtd/backward.py
def full_backward(
    state: SimulationState,
    objects: ObjectContainer,
    config: SimulationConfig,
    key: jax.Array,
    record_detectors: bool,
    reset_fields: bool,
    start_time_step: int = 0,
) -> SimulationState:
    """Perform full backward FDTD propagation from current state to start time.

    Uses a while loop to repeatedly call backward() until reaching start_time_step.
    Leverages time-reversibility of Maxwell's equations.

    Args:
        state: Current simulation state tuple (time_step, arrays)
        objects: Container with simulation objects (sources, detectors, etc)
        config: Simulation configuration parameters
        key: JAX PRNG key for random operations
        record_detectors: Whether to record detector states
        reset_fields: Whether to zero the E and H fields inside PML regions after each backward step
        start_time_step: Time step to propagate back to (default: 0)

    Returns:
        SimulationState: Final state after backward propagation
    """
    s0 = eqxi.while_loop(
        cond_fun=partial(cond_fn, start_time_step=start_time_step),
        body_fun=partial(
            backward,
            config=config,
            objects=objects,
            key=key,
            record_detectors=record_detectors,
            reset_fields=reset_fields,
        ),
        init_val=state,
        kind="lax",
    )
    return s0

Complete backward FDTD propagation from current state to start time.

fdtdx.fdtd.backward.backward(state, config, objects, key, record_detectors, reset_fields, fields_to_reset=('E', 'H'))

Perform one step of backward FDTD propagation.

Updates fields from time step t to t-1 using time-reversed Maxwell's equations. Handles interfaces, field updates, optional field resetting, and detector recording.
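The optional field reset amounts to zeroing the field inside each PML region with a functional array update. The sketch below assumes a field array with a leading component axis of size 3 and PML regions described by a tuple of three slices; pml_slices is a hypothetical stand-in for the grid slices of the PML objects. The index is built as an explicit tuple, which also runs on Python versions older than 3.11, where star-unpacking inside a subscript is not valid syntax.

```python
import jax.numpy as jnp

# Field with 3 components on an 8 x 8 x 8 grid (illustrative sizes).
E = jnp.ones((3, 8, 8, 8))

# Hypothetical PML regions: two slabs at the low and high ends of the x axis.
pml_slices = [
    (slice(0, 2), slice(None), slice(None)),
    (slice(6, 8), slice(None), slice(None)),
]

for grid_slice in pml_slices:
    # Select all components, then the PML region, and set it to zero.
    # JAX arrays are immutable, so .at[...].set returns a new array.
    E = E.at[(slice(None), *grid_slice)].set(0)

remaining = float(E.sum())
```

After the loop only the four interior x-planes are non-zero, so the remaining sum is 3 * 4 * 8 * 8 = 768.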

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| state | SimulationState | Current simulation state tuple (time_step, arrays) | required |
| config | SimulationConfig | Simulation configuration parameters | required |
| objects | ObjectContainer | Container with simulation objects (sources, detectors, etc) | required |
| key | Array | JAX PRNG key for random operations | required |
| record_detectors | bool | Whether to record detector states | required |
| reset_fields | bool | Whether to zero fields inside PML regions after the update | required |
| fields_to_reset | Sequence[str] | Names of the field arrays to zero if reset_fields is True | ('E', 'H') |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| SimulationState | SimulationState | Updated state after one backward step |

Source code in src/fdtdx/fdtd/backward.py
def backward(
    state: SimulationState,
    config: SimulationConfig,
    objects: ObjectContainer,
    key: jax.Array,
    record_detectors: bool,
    reset_fields: bool,
    fields_to_reset: Sequence[str] = ("E", "H"),
) -> SimulationState:
    """Perform one step of backward FDTD propagation.

    Updates fields from time step t to t-1 using time-reversed Maxwell's equations.
    Handles interfaces, field updates, optional field resetting, and detector recording.

    Args:
        state: Current simulation state tuple (time_step, arrays)
        config: Simulation configuration parameters
        objects: Container with simulation objects (sources, detectors, etc)
        key: JAX PRNG key for random operations
        record_detectors: Whether to record detector states
        reset_fields: Whether to zero fields inside PML regions after the update
        fields_to_reset: Names of the field arrays to zero if reset_fields is True

    Returns:
        SimulationState: Updated state after one backward step
    """
    time_step, arrays = state
    time_step = time_step - 1

    arrays = add_interfaces(
        time_step=time_step,
        arrays=arrays,
        objects=objects,
        config=config,
        key=key,
    )

    H = arrays.H

    arrays = update_H_reverse(
        time_step=time_step,
        arrays=arrays,
        config=config,
        objects=objects,
    )

    arrays = update_E_reverse(
        time_step=time_step,
        arrays=arrays,
        config=config,
        objects=objects,
    )

    if reset_fields:
        new_fields = {f: getattr(arrays, f) for f in fields_to_reset}
        for pml in objects.pml_objects:
            for name in fields_to_reset:
                new_fields[name] = new_fields[name].at[:, *pml.grid_slice].set(0)
        for name, f in new_fields.items():
            arrays = arrays.aset(name, f)

    if record_detectors:
        arrays = update_detector_states(
            time_step=time_step,
            arrays=arrays,
            objects=objects,
            H_prev=H,
            inverse=True,
        )

    next_state = (time_step, arrays)
    return next_state

Single step backward FDTD propagation.