Core Functions

These core functions are used throughout the FDTDX package.

fdtdx.core.WaveCharacter

Bases: ExtendedTreeClass

Class describing a wave by its wavelength, period, or frequency in free space. Conversions to and from the wavelength assume a free-space wave, i.e. a refractive index of 1.

Attributes:

_period (float | None): Optional period in seconds. Mutually exclusive with _wavelength and _frequency.

_wavelength (float | None): Optional wavelength in meters. Mutually exclusive with _period and _frequency.

_frequency (float | None): Optional frequency in Hz. Mutually exclusive with _period and _wavelength.

Source code in src/fdtdx/core/wavelength.py
@extended_autoinit
class WaveCharacter(ExtendedTreeClass):
    """Class describing a wavelength/period/frequency in free space. Importantly, the wave characteristic conversion is
    based on a free space wave when using the wavelength (For conversion, a refractive index of 1 is used).



    Attributes:
        _period: Optional period in seconds. Mutually exclusive with _wavelength and _frequency.
        _wavelength: Optional wavelength in meters. Mutually exclusive with _period and _frequency.
        _frequency: Optional frequency in Hz. Mutually exclusive with _period and _wavelength.
    """

    phase_shift: float = 0.0
    _period: float | None = field(default=None, alias="period")
    _wavelength: float | None = field(default=None, alias="wavelength")
    _frequency: float | None = field(default=None, alias="frequency")

    def __post_init__(
        self,
    ):
        self._check_input()

    def _check_input(self):
        if sum([self._period is not None, self._frequency is not None, self._wavelength is not None]) != 1:
            raise Exception("Need to set exactly one of Period, Frequency or Wavelength in WaveCharacter")

    @property
    def period(self) -> float:
        """Gets the period in seconds.

        Returns:
            float: The period in seconds.

        Raises:
            Exception: If not exactly one of period, wavelength, or frequency is set.
        """
        self._check_input()
        if self._period is not None:
            return self._period
        if self._wavelength is not None:
            return self._wavelength / constants.c
        if self._frequency is not None:
            return 1.0 / self._frequency
        raise Exception("This should never happen")

    @property
    def wavelength(self) -> float:
        """Gets the wavelength in meters.

        Returns:
            float: The wavelength in meters.

        Raises:
            Exception: If not exactly one of period, wavelength, or frequency is set.
        """
        self._check_input()
        if self._wavelength is not None:
            return self._wavelength
        if self._period is not None:
            return self._period * constants.c
        if self._frequency is not None:
            return constants.c / self._frequency
        raise Exception("This should never happen")

    @property
    def frequency(self) -> float:
        """Gets the frequency in Hz.

        Returns:
            float: The frequency in Hz.
        """
        self._check_input()
        if self._period is not None:
            return 1.0 / self._period
        if self._wavelength is not None:
            return constants.c / self._wavelength
        if self._frequency is not None:
            return self._frequency
        raise Exception("This should never happen")

frequency: float property

Gets the frequency in Hz.

Returns:

float: The frequency in Hz.

period: float property

Gets the period in seconds.

Returns:

float: The period in seconds.

Raises:

Exception: If not exactly one of period, wavelength, or frequency is set.

wavelength: float property

Gets the wavelength in meters.

Returns:

float: The wavelength in meters.

Raises:

Exception: If not exactly one of period, wavelength, or frequency is set.

A container for specifying the character of a wave by either frequency, period or wavelength. Additionally, a phase offset can be set.
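The conversions in the properties above follow the free-space relations wavelength = c * period and frequency = 1 / period. The sketch below reproduces that arithmetic in plain Python without importing FDTDX. The helper name `convert_wave` and the hard-coded speed of light are assumptions made for illustration; they are not part of the library.

```python
# Standalone sketch of the free-space conversions (refractive index 1).
# `convert_wave` is a hypothetical helper, not part of FDTDX.
C = 299_792_458.0  # speed of light in vacuum, m/s (exact SI value)


def convert_wave(period=None, wavelength=None, frequency=None):
    """Return (period [s], wavelength [m], frequency [Hz]) from exactly one input."""
    given = [v is not None for v in (period, wavelength, frequency)]
    if sum(given) != 1:
        raise ValueError("Set exactly one of period, wavelength or frequency")
    # Reduce every case to the period, then derive the other two quantities.
    if wavelength is not None:
        period = wavelength / C
    elif frequency is not None:
        period = 1.0 / frequency
    return period, period * C, 1.0 / period


# 1550 nm, a common telecom wavelength
period, wavelength, frequency = convert_wave(wavelength=1.55e-6)
print(f"{period:.3e} s, {wavelength:.3e} m, {frequency:.3e} Hz")
```

The `alias` arguments in the source listing suggest that the class itself accepts the same keywords, for example `WaveCharacter(wavelength=1.55e-6)`, but that call is inferred from the listing rather than confirmed here.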

Physical metrics

fdtdx.core.metric_efficiency(detector_states, in_names, out_names, metric_name)

Calculate efficiency metrics between input and output detectors.

Computes efficiency ratios between input and output detectors by comparing their metric values (e.g. energy, power). For each input-output detector pair, calculates the ratio of output/input metric values.

Parameters:

detector_states (dict[str, dict[str, Array]], required): Dictionary mapping detector names to their state dictionaries, which contain metric values as JAX arrays

in_names (Sequence[str], required): Names of input detectors to use as reference

out_names (Sequence[str], required): Names of output detectors to compare against inputs

metric_name (str, required): Name of the metric to compare between detectors (e.g. "energy")

Returns:

tuple[Array, dict[str, Any]]: Tuple containing:
- jax.Array: Mean efficiency across all input-output pairs
- dict: Additional info including individual metric values and efficiencies, with keys like "{detector}_{metric}" for raw metric values and "{out}_by_{in}_efficiency" for individual efficiency ratios

Source code in src/fdtdx/core/physics/losses.py
def metric_efficiency(
    detector_states: dict[str, dict[str, jax.Array]],
    in_names: Sequence[str],
    out_names: Sequence[str],
    metric_name: str,
) -> tuple[jax.Array, dict[str, Any]]:
    """Calculate efficiency metrics between input and output detectors.

    Computes efficiency ratios between input and output detectors by comparing their
    metric values (e.g. energy, power). For each input-output detector pair, calculates
    the ratio of output/input metric values.

    Args:
        detector_states: Dictionary mapping detector names to their state dictionaries,
            which contain metric values as JAX arrays
        in_names: Names of input detectors to use as reference
        out_names: Names of output detectors to compare against inputs
        metric_name: Name of the metric to compare between detectors (e.g. "energy")

    Returns:
        tuple containing:
            - jax.Array: Mean efficiency across all input-output pairs
            - dict: Additional info including individual metric values and efficiencies
              with keys like:
                "{detector}_{metric}" for raw metric values
                "{out}_by_{in}_efficiency" for individual efficiency ratios
    """
    efficiencies, info = [], {}
    for in_name in in_names:
        in_value = jax.lax.stop_gradient(detector_states[in_name][metric_name].mean())
        info[f"{in_name}_{metric_name}"] = in_value
        for out_name in out_names:
            out_value = detector_states[out_name][metric_name].mean()
            eff = jnp.where(in_value == 0, 0, out_value / in_value)
            efficiencies.append(eff)
            info[f"{out_name}_{metric_name}"] = out_value
            info[f"{out_name}_by_{in_name}_efficiency"] = eff
    objective = jnp.mean(jnp.asarray(efficiencies))
    return objective, info

Convenience function that computes the ratio between values of a physical metric measured at different places in the simulation.
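To make the key naming and the averaging concrete, here is a plain-Python mirror of the loop in the source listing. It uses floats in place of JAX arrays and starts from values that are already averaged per detector. The detector names and energy values are made up for illustration.

```python
# Plain-Python mirror of the efficiency loop, using floats instead of JAX arrays.
# Detector names and values are hypothetical; each value stands for the mean
# of that detector's metric array.
detector_means = {
    "in_detector": 2.0,
    "out_top": 0.5,
    "out_bottom": 1.0,
}
in_names = ["in_detector"]
out_names = ["out_top", "out_bottom"]
metric_name = "energy"

efficiencies, info = [], {}
for in_name in in_names:
    in_value = detector_means[in_name]
    info[f"{in_name}_{metric_name}"] = in_value
    for out_name in out_names:
        out_value = detector_means[out_name]
        # A zero input yields an efficiency of 0 instead of a division error.
        eff = 0.0 if in_value == 0 else out_value / in_value
        efficiencies.append(eff)
        info[f"{out_name}_{metric_name}"] = out_value
        info[f"{out_name}_by_{in_name}_efficiency"] = eff

# Mean over all input-output pairs: (0.25 + 0.5) / 2
objective = sum(efficiencies) / len(efficiencies)
print(objective)  # 0.375
print(sorted(info))
```

One difference matters for optimization: the library version wraps the input value in `jax.lax.stop_gradient`, so gradients of the objective flow only through the output detectors.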

fdtdx.core.compute_energy(E, H, inv_permittivity, inv_permeability, axis=0)

Computes the total electromagnetic energy density of the field.

Parameters:

E (Array, required): Electric field array with shape (3, nx, ny, nz)

H (Array, required): Magnetic field array with shape (3, nx, ny, nz)

inv_permittivity (Array, required): Inverse of the electric permittivity array

inv_permeability (Array | float, required): Inverse of the magnetic permeability array

axis (int, default 0): Axis index of the X,Y,Z component for the E and H field

Returns: Total energy density array with shape (nx, ny, nz)

Source code in src/fdtdx/core/physics/metrics.py
def compute_energy(
    E: jax.Array,
    H: jax.Array,
    inv_permittivity: jax.Array,
    inv_permeability: jax.Array | float,
    axis: int = 0,
) -> jax.Array:
    """Computes the total electromagnetic energy density of the field.

    Args:
        E: Electric field array with shape (3, nx, ny, nz)
        H: Magnetic field array with shape (3, nx, ny, nz)
        inv_permittivity: Inverse of the electric permittivity array
        inv_permeability: Inverse of the magnetic permeability array
        axis: Axis index of the X,Y,Z component for the E and H field
    Returns:
        Total energy density array with shape (nx, ny, nz)
    """
    abs_E = jnp.sum(jnp.square(jnp.abs(E)), axis=axis)
    energy_E = 0.5 * (1 / inv_permittivity) * abs_E

    abs_H = jnp.sum(jnp.square(jnp.abs(H)), axis=axis)
    energy_H = 0.5 * (1 / inv_permeability) * abs_H

    total_energy = energy_E + energy_H
    return total_energy

Computes the electromagnetic energy density at each point of the simulation grid.
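The function body corresponds to the energy density u = 0.5 * eps * |E|^2 + 0.5 * mu * |H|^2, with eps and mu obtained by inverting the two inverse-material arguments. The single-cell sketch below works through that formula in plain Python for a vacuum plane wave. The SI constants, the helper name `energy_density`, and the field values are assumptions for illustration; FDTDX may use a different internal unit convention.

```python
import math

# Single-cell sketch of u = 0.5*eps*|E|^2 + 0.5*mu*|H|^2 in SI units.
# Constants and field values are illustrative assumptions, not library values.
EPS0 = 8.8541878128e-12  # vacuum permittivity, F/m
MU0 = 1.25663706212e-6   # vacuum permeability, H/m


def energy_density(E, H, inv_permittivity, inv_permeability):
    """Energy density of one grid cell from 3-component E and H (may be complex)."""
    abs_E = sum(abs(c) ** 2 for c in E)  # |E|^2 summed over x, y, z
    abs_H = sum(abs(c) ** 2 for c in H)  # |H|^2 summed over x, y, z
    energy_E = 0.5 * (1 / inv_permittivity) * abs_E
    energy_H = 0.5 * (1 / inv_permeability) * abs_H
    return energy_E + energy_H


# Plane wave in vacuum: E along x, H along y with |H| = |E| / eta0.
eta0 = math.sqrt(MU0 / EPS0)  # vacuum impedance, about 376.73 ohm
E = (1.0, 0.0, 0.0)
H = (0.0, 1.0 / eta0, 0.0)
u = energy_density(E, H, 1 / EPS0, 1 / MU0)
print(u)
```

For this plane wave the electric and magnetic contributions are equal, so the total comes out as eps0 * |E|^2, which is a quick sanity check on the formula.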

fdtdx.core.poynting_flux(E, H, axis=0)

Calculates the Poynting vector (energy flux) from E and H fields.

Parameters:

E (Array, required): Electric field array with shape (3, nx, ny, nz)

H (Array, required): Magnetic field array with shape (3, nx, ny, nz)

Returns:

Array: Poynting vector array with shape (3, nx, ny, nz) representing energy flux in each direction

Source code in src/fdtdx/core/physics/metrics.py
def poynting_flux(E: jax.Array, H: jax.Array, axis: int = 0) -> jax.Array:
    """Calculates the Poynting vector (energy flux) from E and H fields.

    Args:
        E: Electric field array with shape (3, nx, ny, nz)
        H: Magnetic field array with shape (3, nx, ny, nz)

    Returns:
        Poynting vector array with shape (3, nx, ny, nz) representing
        energy flux in each direction
    """
    return jnp.cross(
        E,
        jnp.conj(H),
        axisa=axis,
        axisb=axis,
        axisc=axis,
    )

Computes the Poynting flux.
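The source listing evaluates the cross product of E with the complex conjugate of H along the component axis. As a worked example for a single grid cell, the sketch below computes the same product with Python complex numbers. The helper name `cross_conj` and the field values are assumptions for illustration.

```python
# Single-cell sketch of S = E x conj(H) using Python complex numbers.
# `cross_conj` is a hypothetical helper; the field values are illustrative.
def cross_conj(E, H):
    """Return E x conj(H) for 3-component sequences E and H."""
    Hc = [complex(h).conjugate() for h in H]
    return (
        E[1] * Hc[2] - E[2] * Hc[1],
        E[2] * Hc[0] - E[0] * Hc[2],
        E[0] * Hc[1] - E[1] * Hc[0],
    )


# E along x and H along y with a shared phase factor: the flux points along +z.
phase = 1j
E = (2.0 * phase, 0.0, 0.0)
H = (0.0, 0.5 * phase, 0.0)
S = cross_conj(E, H)
print(S)
```

The conjugate cancels the shared phase, so the z component is real and equal to 2.0 * 0.5. Note that the listed function returns the raw product; the common time-averaged convention for phasor fields, 0.5 * Re(E x conj(H)), would need the factor and the real part applied separately.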