Skip to content

Sources

A collection of source objects that inject light into a simulation. The spatial profile of a source can be a plane, a Gaussian, or a mode. The temporal profile is set through the corresponding attribute. All sources are implemented using the Total-Field/Scattered-Field (TFSF) formulation (also known as soft sources), except for the sources explicitly marked as "hard". Hard sources directly set the electric/magnetic field to a fixed value.

fdtdx.GaussianPlaneSource

Bases: LinearlyPolarizedPlaneSource

Source code in src/fdtdx/objects/sources/linear_polarization.py
@extended_autoinit
class GaussianPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a truncated, normalized Gaussian profile.

    The amplitude is a Gaussian restricted to a disc of radius ``radius`` around the
    center passed to ``_get_amplitude_raw``. Cells outside the disc are zero, and the
    whole profile is normalized so that it sums to one.
    """

    # Radius of the disc supporting the profile. It is divided by
    # ``self._config.resolution`` to express it in grid cells.
    radius: float = frozen_field()
    # Standard deviation of the Gaussian, as a fraction of ``radius``.
    std: float = frozen_field(default=1 / 3)  # relative to radius

    @staticmethod
    def _gauss_profile(
        width: int,
        height: int,
        axis: int,
        center: tuple[float, float] | jax.Array,
        radii: tuple[float, float],
        std: float,
    ) -> jax.Array:  # shape (*grid_shape)
        """Build a truncated Gaussian on a ``width`` x ``height`` plane.

        Args:
            width: Number of cells along output dimension 0 (before ``axis`` is inserted).
            height: Number of cells along output dimension 1 (before ``axis`` is inserted).
            axis: Position at which a singleton dimension is inserted into the output.
            center: Gaussian center in grid cells. Component 0 is compared against the
                ``height`` index and component 1 against the ``width`` index.
                NOTE(review): callers should pass it in that order — confirm.
            radii: Support radii in grid cells, in the same component order as ``center``.
            std: Standard deviation of the Gaussian, relative to ``radii``.

        Returns:
            An array of shape ``(width, height)`` with a singleton dimension inserted at
            ``axis``, normalized to sum to one. If no cell falls strictly inside the
            support, the sum is zero and the result is NaN.
        """
        # With "xy" indexing both grids have shape (width, height); the first one holds
        # the height index and the second one holds the width index.
        height_idx, width_idx = jnp.meshgrid(jnp.arange(height), jnp.arange(width), indexing="xy")
        coords = jnp.stack([height_idx, width_idx], axis=-1)
        scaled = (coords - jnp.asarray(center)) / jnp.asarray(radii)
        # Squared distance from the center in units of the radii (1.0 == on the border).
        sq_dist = jnp.sum(scaled**2, axis=-1)

        inside = jnp.expand_dims(sq_dist < 1, axis=axis)
        weights = jnp.expand_dims(jnp.exp(-0.5 * sq_dist / std**2), axis=axis)

        truncated = jnp.where(inside, weights, 0)
        return truncated / truncated.sum()

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return the normalized Gaussian amplitude on this source's grid.

        Args:
            center: Center of the Gaussian in grid cells, forwarded unchanged to
                ``_gauss_profile``.

        Returns:
            The profile produced by ``_gauss_profile``, with a singleton dimension at
            ``self.propagation_axis``.
        """
        radius_in_cells = self.radius / self._config.resolution
        return self._gauss_profile(
            width=self.grid_shape[self.horizontal_axis],
            height=self.grid_shape[self.vertical_axis],
            axis=self.propagation_axis,
            center=center,
            radii=(radius_in_cells, radius_in_cells),
            std=self.std,
        )

A source with a Gaussian spatial profile.

fdtdx.UniformPlaneSource

Bases: LinearlyPolarizedPlaneSource

Source code in src/fdtdx/objects/sources/linear_polarization.py
@extended_autoinit
class UniformPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a constant amplitude over the whole plane."""

    # Constant amplitude assigned to every cell of the source plane.
    amplitude: float = frozen_field(default=1.0)

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return an array of shape ``self.grid_shape`` filled with ``self.amplitude``.

        Args:
            center: Accepted for interface compatibility and ignored, because a uniform
                profile has no center.
        """
        del center
        unit_plane = jnp.ones(shape=self.grid_shape, dtype=jnp.float32)
        return self.amplitude * unit_plane

A source with a uniform spatial profile, i.e. a plane wave.

fdtdx.ModePlaneSource

Bases: TFSFPlaneSource

Source code in src/fdtdx/objects/sources/mode.py
@extended_autoinit
class ModePlaneSource(TFSFPlaneSource):
    mode_index: int = frozen_field(default=0)
    filter_pol: Literal["te", "tm"] | None = frozen_field(default=None)

    _inv_permittivity: jax.Array = private_field()
    _inv_permeability: jax.Array | float = private_field()

    def apply(
        self: Self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
    ) -> Self:
        self = super().apply(
            key=key,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
        )
        if (
            self.azimuth_angle != 0
            or self.elevation_angle != 0
            or self.max_angle_random_offset != 0
            or self.max_vertical_offset != 0
            or self.max_horizontal_offset != 0
        ):
            raise NotImplementedError()

        inv_permittivity_slice = inv_permittivities[*self.grid_slice]
        if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
            inv_permeability_slice = inv_permeabilities[*self.grid_slice]
        else:
            inv_permeability_slice = inv_permeabilities

        self = self.aset("_inv_permittivity", inv_permittivity_slice)
        self = self.aset("_inv_permeability", inv_permeability_slice)

        return self

    def get_EH_variation(
        self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
    ) -> tuple[
        jax.Array,  # E: (3, *grid_shape)
        jax.Array,  # H: (3, *grid_shape)
        jax.Array,  # time_offset_E: (3, *grid_shape)
        jax.Array,  # time_offset_H: (3, *grid_shape)
    ]:
        del key

        center = jnp.asarray(
            [round(self.grid_shape[self.horizontal_axis]), round(self.grid_shape[self.vertical_axis])], dtype=jnp.int32
        )

        inv_permittivity_slice = inv_permittivities[*self.grid_slice]
        if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
            inv_permeability_slice = inv_permeabilities[*self.grid_slice]
        else:
            inv_permeability_slice = inv_permeabilities

        raw_wave_vector = get_wave_vector_raw(
            direction=self.direction,
            propagation_axis=self.propagation_axis,
        )

        # compute mode
        mode_E, mode_H, eff_index = compute_mode(
            frequency=self.wave_character.frequency,
            inv_permittivities=inv_permittivity_slice,
            inv_permeabilities=inv_permeability_slice,  # type: ignore
            resolution=self._config.resolution,
            direction=self.direction,
            mode_index=self.mode_index,
            filter_pol=self.filter_pol,
        )

        time_offset_E, time_offset_H = calculate_time_offset_yee(
            center=center,
            wave_vector=raw_wave_vector,
            inv_permittivities=inv_permittivity_slice,
            inv_permeabilities=jnp.ones_like(inv_permeability_slice),
            resolution=self._config.resolution,
            time_step_duration=self._config.time_step_duration,
            effective_index=jnp.real(eff_index),
        )

        return mode_E, mode_H, time_offset_E, time_offset_H

    def plot(self, save_path: str | Path):
        if self._H is None or self._E is None:
            raise Exception("Cannot plot mode without init to grid and apply params first")

        energy = compute_energy(
            E=self._E,
            H=self._H,
            inv_permittivity=self._inv_permittivity,
            inv_permeability=self._inv_permeability,
        )

        energy_2d = energy.squeeze().T

        plt.clf()
        fig = plt.figure(figsize=(10, 10))
        mode_cmap = "inferno"

        im = plt.imshow(
            energy_2d,
            cmap=mode_cmap,
            origin="lower",
            aspect="equal",
        )
        plt.colorbar(im)

        # Ensure the plot takes up the entire figure
        plt.tight_layout(pad=0)

        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

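A source with the spatial profile of a mode. The mode is computed automatically. By default (`mode_index=0`) the first mode is used; other modes can be selected through `mode_index`, and `filter_pol` (`"te"` or `"tm"`) optionally filters by polarization. In the future, we plan to provide a more convenient interface for selecting other modes.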

Temporal Profiles

fdtdx.SingleFrequencyProfile

Bases: TemporalProfile

Simple sinusoidal temporal profile at a single frequency.

Source code in src/fdtdx/objects/sources/profile.py
@extended_autoinit
class SingleFrequencyProfile(TemporalProfile):
    """Simple sinusoidal temporal profile at a single frequency."""

    def get_amplitude(
        self,
        time: jax.Array,
        period: float,
        phase_shift: float = 0.0,
    ) -> jax.Array:
        """Return ``cos(2*pi*time/period + phase_shift)``.

        This equals ``Re(exp(-1j * phase))``. Evaluating the cosine directly gives the
        same real dtype without building a complex intermediate array.

        Args:
            time: Time value(s) at which to evaluate the profile.
            period: Period of the sinusoid, in the same unit as ``time``.
            phase_shift: Additional phase in radians.

        Returns:
            Real-valued amplitude, broadcast from ``time`` and ``phase_shift``.
        """
        time_phase = 2 * jnp.pi * time / period + phase_shift
        return jnp.cos(time_phase)

A temporal profile consisting of a single-frequency sinusoid that persists throughout the whole simulation time.

fdtdx.GaussianPulseProfile

Bases: TemporalProfile

Gaussian pulse temporal profile with carrier wave.

Source code in src/fdtdx/objects/sources/profile.py
@extended_autoinit
class GaussianPulseProfile(TemporalProfile):
    """Gaussian pulse temporal profile with carrier wave."""

    spectral_width: float = frozen_field()  # Width of the Gaussian envelope in frequency domain
    center_frequency: float = frozen_field()  # Center frequency of the pulse

    def get_amplitude(
        self,
        time: jax.Array,
        period: float,
        phase_shift: float = 0.0,
    ) -> jax.Array:
        """Return a cosine carrier at ``center_frequency`` under a Gaussian envelope.

        Args:
            time: Time value(s) at which to evaluate the pulse.
            period: Ignored; the carrier frequency is taken from ``center_frequency``.
            phase_shift: Additional carrier phase in radians.

        Returns:
            Real-valued amplitude, broadcast from ``time`` and ``phase_shift``.
        """
        del period  # the carrier is defined by ``center_frequency`` instead
        # Temporal standard deviation of a Gaussian whose spectral standard deviation
        # is ``spectral_width`` (sigma_t * sigma_f = 1 / (2*pi)).
        sigma_t = 1.0 / (2 * jnp.pi * self.spectral_width)
        # Offset the peak by 6 sigma so the envelope is ~exp(-18) at t=0, avoiding a
        # discontinuity when the source switches on.
        t0 = 6 * sigma_t

        # Gaussian envelope
        envelope = jnp.exp(-((time - t0) ** 2) / (2 * sigma_t**2))

        # Carrier wave: cos(phase) == Re(exp(-1j * phase)), without a complex intermediate.
        carrier_phase = 2 * jnp.pi * self.center_frequency * time + phase_shift
        carrier = jnp.cos(carrier_phase)

        return envelope * carrier

A temporal pulse: a carrier wave modulated by a Gaussian envelope.