
Detectors

In FDTDX, detectors can be used to perform measurements within a simulation. In conjunction with the Logger, these detectors can also automatically produce plots or even videos.

fdtdx.objects.detectors.EnergyDetector

Bases: Detector

Detector for measuring electromagnetic energy distribution.

This detector computes and records the electromagnetic energy density in the detector volume. It can record the full 3D energy distribution, three 2D planes obtained by averaging the energy along each axis, or a single value summed over the whole detector volume.

Attributes:

as_slices (bool, default False): If True, records three 2D planes ("XY Plane", "XZ Plane", "YZ Plane"), each obtained by averaging the energy along the axis normal to that plane. If False, records the full 3D volume or the reduced value.

reduce_volume (bool, default False): If True, reduces the volume data to a single energy value by summing. If False, keeps the spatial distribution of energy. Cannot be combined with as_slices.
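The three recording modes differ only in how the per-cell energy array is post-processed. The sketch below reproduces that post-processing with NumPy (the same calls exist in jax.numpy) on random stand-in fields. The `energy_density` helper and its formula u = ½(ε|E|² + μ|H|²) are assumptions for illustration, not FDTDX's `compute_energy`.

```python
import numpy as np

def energy_density(E, H, inv_permittivity, inv_permeability):
    """Hypothetical stand-in for compute_energy: u = 0.5 * (eps*|E|^2 + mu*|H|^2).

    E and H have shape (3, nx, ny, nz); the material arrays broadcast to (nx, ny, nz).
    """
    eps = 1.0 / inv_permittivity
    mu = 1.0 / inv_permeability
    return 0.5 * (eps * np.sum(np.abs(E) ** 2, axis=0) + mu * np.sum(np.abs(H) ** 2, axis=0))

def record_energy(energy, as_slices=False, reduce_volume=False):
    """Mimic the three EnergyDetector output modes for a single time step."""
    if as_slices and reduce_volume:
        raise ValueError("Cannot both reduce volume and save mean slices!")
    if as_slices:
        # Averaging along one axis collapses the 3D volume onto the plane normal to it
        return {
            "XY Plane": energy.mean(axis=2),
            "XZ Plane": energy.mean(axis=1),
            "YZ Plane": energy.mean(axis=0),
        }
    if reduce_volume:
        # Total energy in the detector volume, stored as a length-1 array
        return {"energy": np.array([energy.sum()])}
    return {"energy": energy}

rng = np.random.default_rng(0)
shape = (4, 5, 6)
E = rng.normal(size=(3, *shape))
H = rng.normal(size=(3, *shape))
inv_eps = np.full(shape, 1 / 2.25)  # uniform medium with relative permittivity 2.25
u = energy_density(E, H, inv_eps, 1.0)

for name, arr in record_energy(u, as_slices=True).items():
    print(name, arr.shape)  # XY Plane (4, 5), XZ Plane (4, 6), YZ Plane (5, 6)
print(record_energy(u, reduce_volume=True)["energy"].shape)  # (1,)
```

The plane shapes match the `jax.ShapeDtypeStruct` entries declared in `_shape_dtype_single_time_step` below.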

Source code in src/fdtdx/objects/detectors/energy.py
@extended_autoinit
class EnergyDetector(Detector):
    """Detector for measuring electromagnetic energy distribution.

    This detector computes and records the electromagnetic energy density at specified
    points in the simulation volume. It can operate in different modes to either record
    full 3D data, 2D slices, or reduced volume measurements.

    Attributes:
        as_slices: If True, returns three 2D planes, each the mean of the energy along
            the axis normal to that plane. If False, returns full 3D volume or reduced measurements.
        reduce_volume: If True, reduces the volume data to a single energy value by
            summing. If False, maintains spatial distribution of energy.
    """

    as_slices: bool = False
    reduce_volume: bool = False

    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
        if self.as_slices and self.reduce_volume:
            raise Exception("Cannot both reduce volume and save mean slices!")
        if self.as_slices:
            gs = self.grid_shape
            return {
                "XY Plane": jax.ShapeDtypeStruct((gs[0], gs[1]), self.dtype),
                "XZ Plane": jax.ShapeDtypeStruct((gs[0], gs[2]), self.dtype),
                "YZ Plane": jax.ShapeDtypeStruct((gs[1], gs[2]), self.dtype),
            }
        if self.reduce_volume:
            return {"energy": jax.ShapeDtypeStruct((1,), self.dtype)}
        return {"energy": jax.ShapeDtypeStruct(self.grid_shape, self.dtype)}

    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        cur_E = E[:, *self.grid_slice]
        cur_H = H[:, *self.grid_slice]
        cur_inv_permittivity = inv_permittivity[self.grid_slice]
        if isinstance(inv_permeability, jax.Array) and inv_permeability.ndim > 0:
            cur_inv_permeability = inv_permeability[self.grid_slice]
        else:
            cur_inv_permeability = inv_permeability

        energy = compute_energy(
            E=cur_E,
            H=cur_H,
            inv_permittivity=cur_inv_permittivity,
            inv_permeability=cur_inv_permeability,
        )
        arr_idx = self._time_step_to_arr_idx[time_step]
        if self.as_slices:
            energy_xy = energy.mean(axis=2)
            new_xy = state["XY Plane"].at[arr_idx].set(energy_xy)
            energy_xz = energy.mean(axis=1)
            new_xz = state["XZ Plane"].at[arr_idx].set(energy_xz)
            energy_yz = energy.mean(axis=0)
            new_yz = state["YZ Plane"].at[arr_idx].set(energy_yz)
            return {
                "XY Plane": new_xy,
                "XZ Plane": new_xz,
                "YZ Plane": new_yz,
            }
        if self.reduce_volume:
            energy = energy.sum()
        new_full_arr = state["energy"].at[arr_idx].set(energy)
        new_state = {"energy": new_full_arr}
        return new_state
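
`update` writes each measurement into a preallocated array whose leading axis indexes recorded time steps, using `_time_step_to_arr_idx` to map a simulation step to a row. The NumPy sketch below imitates that bookkeeping. The lookup-table construction (recording every `interval`-th step and sending unrecorded steps to a discarded scratch row) is a hypothetical choice, not necessarily what FDTDX does internally; in JAX, the in-place assignment becomes `state["energy"].at[arr_idx].set(value)`.

```python
import numpy as np

num_time_steps = 10
interval = 3  # hypothetical: record every 3rd step

# Hypothetical lookup table: recorded steps map to rows 0..n-1,
# all other steps map to an extra scratch row that is discarded afterwards
recorded_steps = np.arange(0, num_time_steps, interval)
num_rows = len(recorded_steps)
time_step_to_arr_idx = np.full(num_time_steps, num_rows)
time_step_to_arr_idx[recorded_steps] = np.arange(num_rows)

# One row per recorded step plus the scratch row (reduce_volume=True -> shape (1,) per step)
state = {"energy": np.zeros((num_rows + 1, 1))}

for t in range(num_time_steps):
    energy_value = float(t) ** 2  # placeholder measurement
    arr_idx = time_step_to_arr_idx[t]
    # JAX equivalent: state["energy"] = state["energy"].at[arr_idx].set(energy_value)
    state["energy"][arr_idx] = energy_value

recorded = state["energy"][:num_rows, 0]
print(recorded)  # energies recorded at steps 0, 3, 6 and 9
```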

fdtdx.objects.detectors.PoyntingFluxDetector

Bases: Detector

Detector for measuring Poynting flux in electromagnetic simulations.

This detector computes the Poynting flux (power flow) through a specified surface in the simulation volume. It can measure flux in either the positive or negative direction along the propagation axis, and optionally reduce measurements to a single value by summing over the detection surface.

Attributes:

direction (Literal["+", "-"]): Direction of flux measurement along the propagation axis. With "-", the computed flux is negated, so power flowing toward decreasing coordinates is reported as positive.

reduce_volume (bool, default True): If True, reduces measurements to a single value by summing over the detection surface. If False, keeps the spatial distribution.
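The flux computation reduces to selecting one Cartesian component of S = E × H and optionally negating and summing it. The sketch below uses a plain NumPy cross product at co-located points; this is an assumption, since FDTDX's `poynting_flux` helper operates on Yee-grid fields and may interpolate or scale differently. `flux_through_plane` is a hypothetical name.

```python
import numpy as np

def flux_through_plane(E, H, direction="+", reduce_volume=True):
    """Sketch of PoyntingFluxDetector's reduction for fields of shape (3, nx, ny, nz).

    Assumes E and H are sampled at the same points (no Yee staggering).
    """
    grid_shape = E.shape[1:]
    if sum(n == 1 for n in grid_shape) != 1:
        raise ValueError(f"Invalid poynting flux detector shape: {grid_shape}")
    axis = grid_shape.index(1)  # the plane's normal is the size-1 axis
    S = np.cross(E, H, axis=0)  # Poynting vector, shape (3, nx, ny, nz)
    pf = S[axis]
    if direction == "-":
        pf = -pf  # report power flowing toward decreasing coordinates as positive
    return pf.sum() if reduce_volume else pf

# A plane wave travelling in +z: Ex and Hy in phase
shape = (8, 8, 1)
E = np.zeros((3, *shape))
E[0] = 1.0
H = np.zeros((3, *shape))
H[1] = 0.5
print(flux_through_plane(E, H, "+"))  # 64 cells * (1.0 * 0.5) = 32.0
print(flux_through_plane(E, H, "-"))  # -32.0
```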

Source code in src/fdtdx/objects/detectors/poynting_flux.py
@extended_autoinit
class PoyntingFluxDetector(Detector):
    """Detector for measuring Poynting flux in electromagnetic simulations.

    This detector computes the Poynting flux (power flow) through a specified surface
    in the simulation volume. It can measure flux in either the positive or negative
    direction along the propagation axis, and optionally reduce measurements to a
    single value by summing over the detection surface.

    Attributes:
        direction: Direction of flux measurement, either "+" for positive or "-" for
            negative along the propagation axis.
        reduce_volume: If True, reduces measurements to a single value by summing
            over the detection surface. If False, maintains spatial distribution.
    """

    direction: Literal["+", "-"] = frozen_field(kind="KW_ONLY")  # type: ignore
    reduce_volume: bool = True

    @property
    def propagation_axis(self) -> int:
        """Determines the axis along which Poynting flux is measured.

        The propagation axis is identified as the dimension with size 1 in the
        detector's grid shape, representing a plane perpendicular to the flux
        measurement direction.

        Returns:
            int: Index of the propagation axis (0 for x, 1 for y, 2 for z)

        Raises:
            Exception: If detector shape does not have exactly one dimension of size 1
        """
        if sum([a == 1 for a in self.grid_shape]) != 1:
            raise Exception(f"Invalid poynting flux detector shape: {self.grid_shape}")
        return self.grid_shape.index(1)

    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
        if self.reduce_volume:
            return {"poynting_flux": jax.ShapeDtypeStruct((1,), self.dtype)}
        return {"poynting_flux": jax.ShapeDtypeStruct(self.grid_shape, self.dtype)}

    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        del inv_permeability, inv_permittivity
        cur_E = E[:, *self.grid_slice]
        cur_H = H[:, *self.grid_slice]

        pf = poynting_flux(cur_E, cur_H)[self.propagation_axis]
        if self.direction == "-":
            pf = -pf
        if self.reduce_volume:
            pf = pf.sum()
        arr_idx = self._time_step_to_arr_idx[time_step]
        new_full_arr = state["poynting_flux"].at[arr_idx].set(pf)
        new_state = {"poynting_flux": new_full_arr}
        return new_state

propagation_axis: int property

Determines the axis along which Poynting flux is measured.

The propagation axis is identified as the dimension with size 1 in the detector's grid shape, representing a plane perpendicular to the flux measurement direction.

Returns:

int: Index of the propagation axis (0 for x, 1 for y, 2 for z).

Raises:

Exception: If the detector shape does not have exactly one dimension of size 1.
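The rule is easy to check in isolation: the detector must be exactly one cell thick along one axis. A line-shaped detector such as (1, 1, 50) is rejected because two axes have size 1, and a full volume is rejected because none does. The hypothetical helper below mirrors the property; it raises `ValueError` where the library raises a bare `Exception`.

```python
def propagation_axis(grid_shape):
    """Return the index of the single size-1 axis, mirroring the property above."""
    if sum(n == 1 for n in grid_shape) != 1:
        raise ValueError(f"Invalid detector shape: {grid_shape}")
    return tuple(grid_shape).index(1)

print(propagation_axis((100, 80, 1)))  # 2 -> plane normal to z
print(propagation_axis((1, 50, 50)))   # 0 -> plane normal to x
for bad in [(1, 1, 50), (10, 10, 10)]:
    try:
        propagation_axis(bad)
    except ValueError as err:
        print(err)
```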

fdtdx.objects.detectors.PhasorDetector

Bases: Detector

Detector for measuring phasor components of electromagnetic fields.

This detector computes complex phasor representations of the selected field components at the specified frequencies by accumulating a running discrete Fourier sum over all time steps, enabling frequency-domain analysis of the electromagnetic fields.

Attributes:

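frequencies (Sequence[float]): Frequencies to analyze, in Hz.

as_slices (bool, default False): If True, returns results as slices rather than the full volume.

reduce_volume (bool, default False): If True, averages the phasors over all spatial dimensions, storing one complex value per frequency and component.

components (Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]], default all six): Field components to measure. They are stored in the fixed order Ex, Ey, Ez, Hx, Hy, Hz, regardless of the order in which they are listed.

dtype (jnp.dtype, default jnp.complex64): Complex dtype of the stored phasors; must be jnp.complex64 or jnp.complex128.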
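At every time step the detector adds the current field values, weighted by exp(iωt)·Δt/√(2π), to a running sum, so after the simulation the stored value approximates the continuous Fourier transform ∫ f(t) e^{iωt} dt / √(2π). The scalar NumPy sketch below checks this on a test cosine; the signal, step size, and frequencies are illustrative assumptions, and the window is chosen to contain an integer number of periods.

```python
import numpy as np

dt = 1e-17  # time step duration in seconds (illustrative)
num_steps = 20000
f0 = 2e14  # frequency of the test signal in Hz
t = np.arange(num_steps) * dt
signal = np.cos(2 * np.pi * f0 * t)  # stand-in for one field component at one cell

frequencies = np.array([f0, 1.5 * f0])
omega = 2 * np.pi * frequencies
scale = dt / np.sqrt(2 * np.pi)

phasor = np.zeros(len(frequencies), dtype=np.complex128)
for step in range(num_steps):
    time_passed = step * dt
    # Same per-step increment as PhasorDetector.update, for a single value
    phasor += signal[step] * np.exp(1j * omega * time_passed) * scale

T = num_steps * dt
expected_peak = T / (2 * np.sqrt(2 * np.pi))  # |FT| of a cosine at its own frequency over window T
print(abs(phasor[0]) / expected_peak)  # close to 1
print(abs(phasor[1]) / expected_peak)  # close to 0: no signal at 1.5 * f0
```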

Source code in src/fdtdx/objects/detectors/phasor.py
@extended_autoinit
class PhasorDetector(Detector):
    """Detector for measuring phasor components of electromagnetic fields.

    This detector computes complex phasor representations of the field components at specified
    frequencies, enabling frequency-domain analysis of the electromagnetic fields.

    Attributes:
        frequencies: Sequence of frequencies to analyze (in Hz)
        as_slices: If True, returns results as slices rather than full volume.
        reduce_volume: If True, averages the phasors over all spatial dimensions.
        components: Sequence of field components to measure. Can include any of:
            "Ex", "Ey", "Ez", "Hx", "Hy", "Hz". Stored in this canonical order.
    """

    frequencies: Sequence[float] = (None,)  # type: ignore
    as_slices: bool = False
    reduce_volume: bool = False
    components: Sequence[Literal["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]] = frozen_field(
        default=("Ex", "Ey", "Ez", "Hx", "Hy", "Hz"),
    )
    dtype: jnp.dtype = frozen_field(
        default=jnp.complex64,
        kind="KW_ONLY",
    )

    def __post_init__(
        self,
    ):
        if self.dtype not in [jnp.complex64, jnp.complex128]:
            raise Exception(f"Invalid dtype in PhasorDetector: {self.dtype}")

        # Precompute angular frequencies for vectorization
        self._angular_frequencies = 2 * jnp.pi * jnp.array(self.frequencies)
        self._scale = self._config.time_step_duration / jnp.sqrt(2 * jnp.pi)

    def _num_latent_time_steps(self) -> int:
        return 1

    def _shape_dtype_single_time_step(
        self,
    ) -> dict[str, jax.ShapeDtypeStruct]:
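        # dtype is validated in __post_init__ to be complex64 or complex128
        field_dtype = jnp.complex128 if self.dtype == jnp.complex128 else jnp.complex64
        num_components = len(self.components)
        num_frequencies = len(self.frequencies)
        if self.reduce_volume:
            # update() averages over all spatial axes, leaving (freqs, components)
            phasor_shape = (num_frequencies, num_components)
        else:
            phasor_shape = (num_frequencies, num_components, *self.grid_shape)
        return {"phasor": jax.ShapeDtypeStruct(shape=phasor_shape, dtype=field_dtype)}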

    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        del inv_permeability, inv_permittivity
        time_passed = time_step * self._config.time_step_duration

        E, H = E[:, *self.grid_slice], H[:, *self.grid_slice]
        fields = []
        if "Ex" in self.components:
            fields.append(E[0])
        if "Ey" in self.components:
            fields.append(E[1])
        if "Ez" in self.components:
            fields.append(E[2])
        if "Hx" in self.components:
            fields.append(H[0])
        if "Hy" in self.components:
            fields.append(H[1])
        if "Hz" in self.components:
            fields.append(H[2])

        EH = jnp.stack(fields, axis=0)

        # Vectorized phasor calculation for all frequencies
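        phase_angles = self._angular_frequencies * time_passed  # Shape: (num_freqs,)
        phasors = jnp.exp(1j * phase_angles)  # Shape: (num_freqs,)
        # Reshape to (num_freqs, 1, 1, 1, 1) so each frequency broadcasts over
        # the component axis and the three spatial axes of EH
        phasors = phasors.reshape((-1,) + (1,) * EH.ndim)
        new_phasors = EH[None, ...] * phasors * self._scale  # Shape: (num_freqs, num_components, *grid)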

        if self.reduce_volume:
            # Average over all spatial dimensions
            spatial_axes = tuple(range(2, new_phasors.ndim))  # Skip freq and component axes
            new_phasors = new_phasors.mean(axis=spatial_axes) if spatial_axes else new_phasors

        if self.inverse:
            result = state["phasor"] - new_phasors[None, ...]
        else:
            result = state["phasor"] + new_phasors[None, ...]
        return {"phasor": result.astype(self.dtype)}
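
With several frequencies, the per-frequency phase factor has to broadcast over the component axis and all three spatial axes, so it needs shape (num_freqs, 1, 1, 1, 1). The NumPy sketch below performs one update increment with explicit shapes, including the `reduce_volume` average. The helper name `phasor_increment` and the array sizes are hypothetical; components are stacked in the fixed order Ex, Ey, Ez, Hx, Hy, Hz as in the source above.

```python
import numpy as np

def phasor_increment(E, H, components, angular_frequencies, time_passed, scale, reduce_volume=False):
    """One PhasorDetector-style update increment.

    E, H: real fields of shape (3, nx, ny, nz), already sliced to the detector.
    Returns shape (num_freqs, num_components, nx, ny, nz),
    or (num_freqs, num_components) if reduce_volume is True.
    """
    order = ["Ex", "Ey", "Ez", "Hx", "Hy", "Hz"]
    all_fields = [E[0], E[1], E[2], H[0], H[1], H[2]]
    # Components are stacked in canonical order, whatever order `components` lists them in
    EH = np.stack([f for name, f in zip(order, all_fields) if name in components], axis=0)

    phasors = np.exp(1j * angular_frequencies * time_passed)  # (num_freqs,)
    phasors = phasors.reshape((-1,) + (1,) * EH.ndim)  # (num_freqs, 1, 1, 1, 1)
    new = EH[None, ...] * phasors * scale  # (num_freqs, num_components, nx, ny, nz)
    if reduce_volume:
        new = new.mean(axis=tuple(range(2, new.ndim)))  # average over the spatial axes
    return new

rng = np.random.default_rng(1)
E = rng.normal(size=(3, 4, 5, 6))
H = rng.normal(size=(3, 4, 5, 6))
omega = 2 * np.pi * np.array([1.0e14, 2.0e14, 3.0e14])

full = phasor_increment(E, H, ("Hz", "Ex"), omega, 1e-15, 1.0)
reduced = phasor_increment(E, H, ("Hz", "Ex"), omega, 1e-15, 1.0, reduce_volume=True)
print(full.shape)     # (3, 2, 4, 5, 6)
print(reduced.shape)  # (3, 2)
```

Note that `("Hz", "Ex")` still yields Ex at component index 0 and Hz at index 1.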

fdtdx.objects.detectors.DiffractiveDetector

Bases: Detector

Detector for computing Fourier transforms of fields at specific frequencies and diffraction orders.

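This detector computes the power carried by specific diffraction orders at the requested frequencies through a specified plane in the simulation volume. It can measure diffraction in either the positive or negative direction along the propagation axis. The current implementation uses only the first entry of frequencies to compute the free-space wavenumber k0 used for all orders.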

Attributes:

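frequencies (Sequence[float]): Frequencies to analyze, in Hz.

orders (Sequence[Tuple[int, int]], default ((0, 0),)): (nx, ny) pairs specifying the diffraction orders to compute. A negative order n is looked up at FFT index N + n.

direction (Literal["+", "-"]): Direction of diffraction analysis along the propagation axis; "-" negates the computed power.

dtype (jnp.dtype, default jnp.complex64): Complex dtype of the stored values; must be jnp.complex64 or jnp.complex128.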
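Each order (nx, ny) is looked up in the unshifted 2D FFT of the in-plane fields: non-negative orders use index n directly and negative orders wrap to N + n, which is where `numpy.fft.fftfreq` places negative frequencies. The transverse wavevector of order n is then 2πn/(N·Δx). Reading this as a grating order assumes the detector plane spans exactly one lateral period of the structure. The grid sizes and spacing below are illustrative.

```python
import numpy as np

Nx, Ny = 16, 12  # in-plane grid size of the detector
dx = 50e-9  # grid spacing in meters (illustrative)
orders = np.array([(0, 0), (1, 0), (-1, 0), (2, -3)])

# Wrap negative orders to the end of the FFT axis, as in DiffractiveDetector.update
kx_idx = np.where(orders[:, 0] >= 0, orders[:, 0], Nx + orders[:, 0])
ky_idx = np.where(orders[:, 1] >= 0, orders[:, 1], Ny + orders[:, 1])

kx = 2 * np.pi * np.fft.fftfreq(Nx, dx)
ky = 2 * np.pi * np.fft.fftfreq(Ny, dx)

for (nx, ny), i, j in zip(orders, kx_idx, ky_idx):
    # The FFT bin reproduces the grating relation k = 2*pi*n / (N * dx)
    assert np.isclose(kx[i], 2 * np.pi * nx / (Nx * dx))
    assert np.isclose(ky[j], 2 * np.pi * ny / (Ny * dx))
    print((int(nx), int(ny)), "-> FFT bin", (int(i), int(j)))
```

Orders must stay below the Nyquist limit (|n| < N/2), otherwise the wrapped index aliases to a different order.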

Source code in src/fdtdx/objects/detectors/diffractive.py
@extended_autoinit
class DiffractiveDetector(Detector):
    """Detector for computing Fourier transforms of fields at specific frequencies and diffraction orders.

    This detector computes field amplitudes for specific diffraction orders and frequencies through
    a specified plane in the simulation volume. It can measure diffraction in either positive or negative
    direction along the propagation axis.

    Attributes:
        frequencies: Sequence of frequencies to analyze (in Hz)
        orders: Sequence of (nx, ny) pairs specifying diffraction orders to compute
        direction: Direction of diffraction analysis ("+" or "-") along propagation axis
    """

    frequencies: Sequence[float] = field(kind="KW_ONLY")  # type: ignore
    orders: Sequence[Tuple[int, int]] = ((0, 0),)
    direction: Literal["+", "-"] = frozen_field(kind="KW_ONLY")  # type: ignore
    dtype: jnp.dtype = frozen_field(default=jnp.complex64, kind="KW_ONLY")

    def __post_init__(self):
        if self.dtype not in [jnp.complex64, jnp.complex128]:
            raise Exception(f"Invalid dtype in DiffractiveDetector: {self.dtype}")

    @property
    def propagation_axis(self) -> int:
        """Determines the axis along which diffraction is measured.

        The propagation axis is identified as the dimension with size 1 in the
        detector's grid shape, representing a plane perpendicular to the diffraction
        measurement direction.

        Returns:
            int: Index of the propagation axis (0 for x, 1 for y, 2 for z)

        Raises:
            Exception: If detector shape does not have exactly one dimension of size 1
        """
        if sum([a == 1 for a in self.grid_shape]) != 1:
            raise Exception(f"Invalid diffractive detector shape: {self.grid_shape}")
        return self.grid_shape.index(1)

    # def _validate_orders(self, wavelength: float) -> None:
    #     """Validate that requested diffraction orders are physically realizable.

    #     Args:
    #         wavelength: Wavelength of the light in meters

    #     Raises:
    #         Exception: If any requested order is not physically realizable
    #     """
    #     if self._Nx is None:
    #         raise Exception("Order info not yet computed. Run update first.")

    #     # Maximum possible orders based on grid
    #     max_nx = self._Nx // 2
    #     max_ny = self._Ny // 2

    #     # Check Nyquist limits for all orders at once
    #     nx_valid = jnp.all(jnp.abs(jnp.array([o[0] for o in self.orders])) <= max_nx)
    #     ny_valid = jnp.all(jnp.abs(jnp.array([o[1] for o in self.orders])) <= max_ny)

    #     if not (nx_valid and ny_valid):
    #         raise Exception(f"Some orders exceed Nyquist limit for grid size ({self._Nx}, {self._Ny})")

    #     # Check physical realizability for all orders at once
    #     k0 = 2 * jnp.pi / wavelength
    #     kt_squared = self._kx_normalized**2 + self._ky_normalized**2

    #     if jnp.any(kt_squared > k0**2):
    #         raise Exception(f"Some orders are evanescent at wavelength {wavelength*1e9:.1f}nm")

    def _shape_dtype_single_time_step(self) -> dict[str, jax.ShapeDtypeStruct]:
        num_freqs = len(self.frequencies)
        num_orders = len(self.orders)

        shape = (num_freqs, num_orders)

        # Ensure we're using a complex dtype
        field_dtype = jnp.complex128 if self.dtype == jnp.complex128 else jnp.complex64
        return {"diffractive": jax.ShapeDtypeStruct(shape=shape, dtype=field_dtype)}

    def _num_latent_time_steps(self) -> int:
        return 1

    def update(
        self,
        time_step: jax.Array,
        E: jax.Array,
        H: jax.Array,
        state: DetectorState,
        inv_permittivity: jax.Array,
        inv_permeability: jax.Array | float,
    ) -> DetectorState:
        del inv_permittivity, inv_permeability

        # Get grid dimensions for the plane perpendicular to propagation axis
        prop_axis = self.propagation_axis
        plane_dims = [i for i in range(3) if i != prop_axis]
        Nx, Ny = [self.grid_shape[i] for i in plane_dims]

        # Get current field values at the specified plane
        cur_E = E[:, *self.grid_slice]  # Shape: (3, *grid_shape), size 1 along prop_axis
        cur_H = H[:, *self.grid_slice]  # Shape: (3, *grid_shape), size 1 along prop_axis

        # Remove the normal axis dimension since it should be 1
        cur_E = jnp.squeeze(cur_E, axis=prop_axis + 1)  # Shape: (3, Nx, Ny)
        cur_H = jnp.squeeze(cur_H, axis=prop_axis + 1)  # Shape: (3, Nx, Ny)

        # Compute the 2D FFT of each field component over the two in-plane axes,
        # which are always axes 1 and 2 after squeezing out the normal axis
        E_k = jnp.fft.fft2(cur_E, axes=(1, 2))
        H_k = jnp.fft.fft2(cur_H, axes=(1, 2))

        # Convert orders to array for vectorization
        orders = jnp.array(self.orders)  # Shape: (num_orders, 2)

        # Compute FFT indices for all orders
        kx_indices = jnp.where(orders[:, 0] >= 0, orders[:, 0], Nx + orders[:, 0])
        ky_indices = jnp.where(orders[:, 1] >= 0, orders[:, 1], Ny + orders[:, 1])

        # Compute wavevectors
        dx = dy = self._config.resolution
        kx = 2 * jnp.pi * jnp.fft.fftfreq(Nx, dx)
        ky = 2 * jnp.pi * jnp.fft.fftfreq(Ny, dy)
        k0 = 2 * jnp.pi * self.frequencies[0] / constants.c  # Use first frequency for now

        # For each requested order, compute the diffracted power
        order_amplitudes = []
        for kx_idx, ky_idx in zip(kx_indices, ky_indices):
            # Get the field components for this k-point
            E_order = E_k[:, kx_idx, ky_idx]
            H_order = H_k[:, kx_idx, ky_idx]

            # Compute kz for propagating waves
            kz = jnp.sqrt(k0**2 - kx[kx_idx] ** 2 - ky[ky_idx] ** 2 + 0j)
            k_vec = jnp.array([kx[kx_idx], ky[ky_idx], kz])

            # Project fields to be transverse to k
            E_t = E_order - jnp.dot(E_order, k_vec) * k_vec / jnp.dot(k_vec, k_vec)
            H_t = H_order - jnp.dot(H_order, k_vec) * k_vec / jnp.dot(k_vec, k_vec)

            # Compute power in this order
            P_order = jnp.abs(jnp.cross(E_t, jnp.conj(H_t)).sum())
            if self.direction == "-":
                P_order = -P_order
            order_amplitudes.append(P_order)

        order_amplitudes = jnp.array(order_amplitudes)

        # Time domain analysis - vectorized for all frequencies
        t = time_step * self._config.time_step_duration
        angular_frequencies = 2 * jnp.pi * jnp.array(self.frequencies)
        phase_angles = angular_frequencies[:, None] * t  # Shape: (num_freqs, 1)
        phasors = jnp.exp(-1j * phase_angles)  # Shape: (num_freqs, 1)

        # Compute all frequency components for all orders at once
        order_amplitudes = order_amplitudes[None, :]  # Shape: (1, num_orders)
        new_values = order_amplitudes * phasors  # Shape: (num_freqs, num_orders)

        # Update state
        arr_idx = self._time_step_to_arr_idx[time_step]
        new_state = state.copy()
        new_state["diffractive"] = new_state["diffractive"].at[arr_idx].set(new_values)

        return new_state
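
The longitudinal wavenumber of each order follows from k0² = kx² + ky² + kz². Adding `0j` before the square root, as the source does, keeps evanescent orders (kx² + ky² > k0²) from producing NaN: they get a purely imaginary kz instead. The sketch below classifies one-dimensional grating orders this way, which is equivalent to the grating condition |n|·λ/period < 1. The wavelength and period are illustrative assumptions, and k0 is the free-space value, matching the source's use of `constants.c`.

```python
import numpy as np

c = 299_792_458.0  # speed of light in vacuum, m/s
wavelength = 1.0e-6  # illustrative free-space wavelength, m
f = c / wavelength
k0 = 2 * np.pi * f / c  # free-space wavenumber, computed as in the source

period = 2.5e-6  # assumed lateral extent of the detector plane = one grating period, m

kinds = {}
for n in range(-4, 5):
    kx = 2 * np.pi * n / period  # transverse wavevector of order n
    kz = np.sqrt(k0**2 - kx**2 + 0j)  # complex sqrt: imaginary for evanescent orders
    kinds[n] = "propagating" if kz.imag == 0 else "evanescent"

print([n for n, kind in kinds.items() if kind == "propagating"])  # [-2, -1, 0, 1, 2]
```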

propagation_axis: int property

Determines the axis along which diffraction is measured.

The propagation axis is identified as the dimension with size 1 in the detector's grid shape, representing a plane perpendicular to the diffraction measurement direction.

Returns:

int: Index of the propagation axis (0 for x, 1 for y, 2 for z).

Raises:

Exception: If the detector shape does not have exactly one dimension of size 1.
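Once the propagation axis is known, the detector squeezes that axis out of the (3, *grid_shape) field slice, which leaves the two in-plane axes at positions 1 and 2 for every plane orientation. The NumPy sketch below checks this for all three orientations; `plane_spectrum` is a hypothetical helper, not part of FDTDX.

```python
import numpy as np

def plane_spectrum(field, prop_axis):
    """2D FFT over the in-plane axes of a (3, *grid_shape) field slice."""
    plane = np.squeeze(field, axis=prop_axis + 1)  # (3, N_a, N_b)
    # After squeezing, the two in-plane axes are always 1 and 2
    return np.fft.fft2(plane, axes=(1, 2))

rng = np.random.default_rng(2)
for grid_shape in [(1, 6, 8), (6, 1, 8), (6, 8, 1)]:
    prop_axis = grid_shape.index(1)
    field = rng.normal(size=(3, *grid_shape))
    spec = plane_spectrum(field, prop_axis)
    print(grid_shape, "->", spec.shape)  # always (3, 6, 8)
```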