Skip to content

Sources

A collection of source objects for injecting light into a simulation. The spatial profile of a source can be a plane, a Gaussian, or a waveguide mode; the temporal profile is set through the corresponding attribute. All sources are implemented using the Total-Field/Scattered-Field (TFSF) formulation (also known as soft sources), except for those explicitly marked as "hard". Hard sources directly overwrite the electric/magnetic field with a prescribed (time-varying) value.

fdtdx.objects.sources.GaussianPlaneSource

Bases: LinearlyPolarizedPlaneSource

Source code in src/fdtdx/objects/sources/plane_source.py
@extended_autoinit
class GaussianPlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a truncated Gaussian amplitude profile.

    The amplitude is a Gaussian around ``center`` whose standard deviation is
    ``std * radius``. It is set to zero outside the ellipse/circle of the given
    radius and normalized so that all entries sum to one. ``radius`` is divided
    by ``_config.resolution`` to express it in grid cells.
    """

    radius: float = frozen_field(kind="KW_ONLY")  # type: ignore
    std: float = frozen_field(kind="KW_ONLY", default=1 / 3)  # relative to radius

    @staticmethod
    def _gauss_profile(
        width: int,
        height: int,
        axis: int,
        center: tuple[float, float] | jax.Array,
        radii: tuple[float, float],
        std: float,
    ) -> jax.Array:  # shape (*grid_shape)
        """Build a normalized, circularly truncated Gaussian on a 2D grid.

        Args:
            width: Number of cells along the first output dimension.
            height: Number of cells along the second output dimension.
            axis: Position at which a singleton dimension is inserted, so the
                2D profile can be broadcast onto a 3D plane.
            center: Gaussian centre in cell coordinates (same component order
                as the stacked meshgrid below).
            radii: Truncation radii (in cells) per component.
            std: Standard deviation relative to the radii.

        Returns:
            Array whose entries sum to one. If no cell lies strictly inside
            the radius, the sum is zero and the result is NaN.
        """
        # indexing="xy" with inputs (range(height), range(width)) produces
        # arrays of shape (width, height); component 0 runs over range(height).
        first_axis_idx, second_axis_idx = jnp.meshgrid(
            jnp.arange(height),
            jnp.arange(width),
            indexing="xy",
        )
        offsets = jnp.stack([first_axis_idx, second_axis_idx], axis=-1) - jnp.asarray(center)
        scaled = offsets / jnp.asarray(radii)
        squared_distance = jnp.sum(scaled**2, axis=-1)

        # Both the truncation mask and the Gaussian get the same singleton axis.
        inside_radius = jnp.expand_dims(squared_distance < 1, axis=axis)
        gaussian = jnp.expand_dims(jnp.exp(-0.5 * squared_distance / std**2), axis=axis)

        truncated = jnp.where(inside_radius, gaussian, 0)
        return truncated / truncated.sum()

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return the Gaussian amplitude profile over the source plane around ``center``."""
        radius_in_cells = self.radius / self._config.resolution
        return self._gauss_profile(
            width=self.grid_shape[self.horizontal_axis],
            height=self.grid_shape[self.vertical_axis],
            axis=self.propagation_axis,
            center=center,
            radii=(radius_in_cells, radius_in_cells),
            std=self.std,
        )

A source with a Gaussian spatial profile.

fdtdx.objects.sources.ConstantAmplitudePlaneSource

Bases: LinearlyPolarizedPlaneSource

Source code in src/fdtdx/objects/sources/plane_source.py
@extended_autoinit
class ConstantAmplitudePlaneSource(LinearlyPolarizedPlaneSource):
    """Linearly polarized plane source with a flat (all-ones) spatial profile."""

    # NOTE(review): not read by ``_get_amplitude_raw`` below; presumably the
    # base class applies it — confirm.
    amplitude: float = 1.0

    def _get_amplitude_raw(
        self,
        center: jax.Array,
    ) -> jax.Array:
        """Return a float32 array of ones shaped like ``self.grid_shape``.

        ``center`` is accepted for interface compatibility and ignored, since
        a flat profile has no centre.
        """
        del center
        return jnp.ones(self.grid_shape, dtype=jnp.float32)

A source with a uniform (constant-amplitude) spatial profile across the plane.

fdtdx.objects.sources.ModePlaneSource

Bases: PlaneSource

Source code in src/fdtdx/objects/sources/plane_source.py
@extended_autoinit
class ModePlaneSource(PlaneSource):
    mode_index: int = frozen_field(default=0)
    filter_pol: Literal["te", "tm"] | None = frozen_field(default=None)

    _inv_permittivity: jax.Array = frozen_field(default=None, init=False)  # type: ignore
    _inv_permeability: jax.Array | float = frozen_field(default=None, init=False)  # type: ignore

    def apply(
        self: Self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
    ) -> Self:
        if (
            self.azimuth_angle != 0
            or self.elevation_angle != 0
            or self.max_angle_random_offset != 0
            or self.max_vertical_offset != 0
            or self.max_horizontal_offset != 0
        ):
            raise NotImplementedError()

        self = super().apply(
            key=key,
            inv_permittivities=inv_permittivities,
            inv_permeabilities=inv_permeabilities,
        )
        inv_permittivity_slice = inv_permittivities[*self.grid_slice]
        if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
            inv_permeability_slice = inv_permeabilities[*self.grid_slice]
        else:
            inv_permeability_slice = inv_permeabilities

        self = self.aset("_inv_permittivity", inv_permittivity_slice)
        self = self.aset("_inv_permeability", inv_permeability_slice)

        return self

    def _compute_modes(
        self,
        inv_permittivity_slice: jax.Array,
    ) -> tuple[jax.Array, jax.Array]:
        input_dtype = inv_permittivity_slice.dtype

        permittivity_cross_section = 1 / inv_permittivity_slice
        other_axes = [a for a in range(3) if a != self.propagation_axis]
        coords = [np.arange(permittivity_cross_section.shape[dim] + 1) / 1e-6 for dim in other_axes]
        permittivity_squeezed = jnp.take(
            permittivity_cross_section,
            indices=0,
            axis=self.propagation_axis,
        )

        def mode_helper(permittivity):
            modes = compute_modes(
                frequency=self.wave_character.frequency,
                permittivity_cross_section=permittivity,  # type: ignore
                coords=coords,
                num_modes=self.mode_index + 1,
                filter_pol=self.filter_pol,
                direction=self.direction,
            )

            mode_E_list, mode_H_list = [], []
            for mode in modes:
                if self.propagation_axis == 0:
                    mode_E, mode_H = (
                        np.stack([mode.Ez, mode.Ex, mode.Ey], axis=0).astype(np.complex64),
                        np.stack([mode.Hz, mode.Hx, mode.Hy], axis=0).astype(np.complex64),
                    )
                elif self.propagation_axis == 1:
                    # mode_E, mode_H = (  # untested
                    #     np.stack([mode.Ex, mode.Ez, mode.Ey], axis=0).astype(np.complex64),
                    #     np.stack([mode.Hx, mode.Hz, mode.Hy], axis=0).astype(np.complex64),
                    # )
                    raise NotImplementedError()
                elif self.propagation_axis == 2:
                    # mode_E, mode_H = (  # untested
                    #     np.stack([mode.Ez, mode.Ex, mode.Ey], axis=0).astype(np.complex64),
                    #     np.stack([mode.Hz, mode.Hx, mode.Hy], axis=0).astype(np.complex64),
                    # )
                    raise NotImplementedError()
                else:
                    raise Exception(f"Invalid popagation axis: {self.propagation_axis}")
                mode_E_list.append(mode_E)
                mode_H_list.append(mode_H)

                return mode_E_list[self.mode_index], mode_H_list[self.mode_index]

        result_shape_dtype = (
            jnp.zeros((3, *permittivity_squeezed.shape), dtype=jnp.complex64),
            jnp.zeros((3, *permittivity_squeezed.shape), dtype=jnp.complex64),
        )
        mode_E_raw, mode_H_raw = jax.pure_callback(
            mode_helper,
            result_shape_dtype,
            jax.lax.stop_gradient(permittivity_squeezed),
        )

        mode_E = jnp.real(jnp.expand_dims(mode_E_raw, axis=self.propagation_axis + 1)).astype(input_dtype)

        mode_H = jnp.real(jnp.expand_dims(mode_H_raw, axis=self.propagation_axis + 1)).astype(input_dtype)
        return mode_E, mode_H

    def get_EH_variation(
        self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
    ) -> tuple[
        jax.Array,  # E: (3, *grid_shape)
        jax.Array,  # H: (3, *grid_shape)
        jax.Array,  # time_offset_E: (3, *grid_shape)
        jax.Array,  # time_offset_H: (3, *grid_shape)
    ]:
        del key

        center = jnp.asarray(
            [round(self.grid_shape[self.horizontal_axis]), round(self.grid_shape[self.vertical_axis])], dtype=jnp.int32
        )

        inv_permittivity_slice = inv_permittivities[*self.grid_slice]
        if isinstance(inv_permeabilities, jax.Array) and inv_permeabilities.ndim > 0:
            inv_permeability_slice = inv_permeabilities[*self.grid_slice]
        else:
            inv_permeability_slice = inv_permeabilities

        time_offset_E, time_offset_H = self._calculate_time_offset_yee(
            center=center,
            wave_vector=self._get_wave_vector_raw(),
            inv_permittivities=jnp.ones_like(inv_permittivity_slice),
            inv_permeabilities=jnp.ones_like(inv_permeability_slice),
        )

        mode_E, mode_H = self._compute_modes(inv_permittivity_slice=inv_permittivity_slice)

        mode_E2_norm, mode_H2_norm = normalize_by_energy(
            E=mode_E,
            H=mode_H,
            inv_permittivity=inv_permittivity_slice,
            inv_permeability=inv_permeability_slice,
        )

        return mode_E2_norm, mode_H2_norm, time_offset_E, time_offset_H

    def plot(self, save_path: str | Path):
        if self._H is None or self._E is None:
            raise Exception("Cannot plot mode without init to grid and apply params first")

        energy = compute_energy(
            E=self._E,
            H=self._H,
            inv_permittivity=self._inv_permittivity,
            inv_permeability=self._inv_permeability,
        )
        energy_2d = energy.squeeze().T

        plt.clf()
        fig = plt.figure(figsize=(10, 10))
        levels = jnp.linspace(energy_2d.min(), energy_2d.max(), 11)[1:]
        mode_cmap = "inferno"

        # Add contour lines on top of the imshow plot
        plt.contour(energy_2d, cmap=mode_cmap, alpha=0.5, levels=levels)
        plt.gca().set_aspect("equal")

        plt.colorbar()

        # Ensure the plot takes up the entire figure
        plt.tight_layout(pad=0)

        plt.savefig(save_path, bbox_inches="tight", pad_inches=0)
        plt.close(fig)

A source with the spatial profile of a waveguide mode. The mode is computed automatically from the material cross-section under the source. By default (`mode_index=0`), the first mode returned by the mode solver is used. In the future, we will develop a better interface to support other modes as well.

fdtdx.objects.sources.HardConstantAmplitudePlanceSource

Bases: DirectionalPlaneSourceBase

Source code in src/fdtdx/objects/sources/plane_source.py
@extended_autoinit
class HardConstantAmplitudePlanceSource(DirectionalPlaneSourceBase):
    amplitude: float = 1.0
    fixed_E_polarization_vector: tuple[float, float, float] | None = None
    fixed_H_polarization_vector: tuple[float, float, float] | None = None

    def _get_raw_EH_polarization(
        self,
    ) -> tuple[jax.Array, jax.Array]:
        # determine E/H polarization
        e_pol = self.fixed_E_polarization_vector
        h_pol = self.fixed_H_polarization_vector
        if h_pol is not None:
            h_pol = jnp.asarray(self.fixed_H_polarization_vector, dtype=jnp.float32)
            h_pol = h_pol / jnp.linalg.norm(h_pol)
        if e_pol is not None:
            e_pol = jnp.asarray(self.fixed_E_polarization_vector, dtype=jnp.float32)
            e_pol = e_pol / jnp.linalg.norm(e_pol)
        if e_pol is None:
            if h_pol is None:
                raise Exception("Need to specify either E or H polarization")
            e_pol = self._orthogonal_vector(v_H=h_pol)
        if h_pol is None:
            if e_pol is None:
                raise Exception("Need to specify either E or H polarization")
            h_pol = self._orthogonal_vector(v_E=e_pol)
        return e_pol, h_pol

    def update_E(
        self,
        E: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
        time_step: jax.Array,
        inverse: bool,
    ) -> jax.Array:
        del inv_permittivities, inv_permeabilities
        if inverse:
            return E
        delta_t = self._config.time_step_duration
        time_phase = 2 * jnp.pi * time_step * delta_t / self.wave_character.period + self.wave_character.phase_shift
        magnitude = jnp.real(self.amplitude * jnp.exp(-1j * time_phase))
        e_pol, _ = self._get_raw_EH_polarization()
        E_update = e_pol[:, None, None, None] * magnitude

        E = E.at[:, *self.grid_slice].set(E_update.astype(E.dtype))
        return E

    def update_H(
        self,
        H: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
        time_step: jax.Array,
        inverse: bool,
    ):
        del inv_permeabilities, inv_permittivities
        if inverse:
            return H
        delta_t = self._config.time_step_duration
        time_phase = 2 * jnp.pi * time_step * delta_t / self.wave_character.period + self.wave_character.phase_shift
        magnitude = jnp.real(self.amplitude * jnp.exp(-1j * time_phase))
        _, h_pol = self._get_raw_EH_polarization()
        H_update = h_pol[:, None, None, None] * magnitude

        H = H.at[:, *self.grid_slice].set(H_update.astype(H.dtype))
        return H

    def apply(
        self,
        key: jax.Array,
        inv_permittivities: jax.Array,
        inv_permeabilities: jax.Array | float,
    ) -> Self:
        del key, inv_permittivities, inv_permeabilities
        return self

A hard source with a uniform (constant-amplitude) spatial profile across the plane.

Temporal Profiles

fdtdx.objects.sources.SingleFrequencyProfile

Bases: TemporalProfile

Simple sinusoidal temporal profile at a single frequency.

Source code in src/fdtdx/objects/sources/profile.py
@extended_autoinit
class SingleFrequencyProfile(TemporalProfile):
    """Continuous cosine wave at a single frequency."""

    def get_amplitude(
        self,
        time: jax.Array,
        period: float,
        phase_shift: float = 0.0,
    ) -> jax.Array:
        """Return ``cos(2*pi*time/period + phase_shift)``.

        Args:
            time: Time value(s) at which to evaluate the profile.
            period: Oscillation period, in the same units as ``time``.
            phase_shift: Constant phase offset in radians.
        """
        return jnp.cos(2 * jnp.pi * time / period + phase_shift)

A temporal profile consisting of a single-frequency sinusoid that lasts for the whole simulation time.

fdtdx.objects.sources.GaussianPulseProfile

Bases: TemporalProfile

Gaussian pulse temporal profile with carrier wave.

Source code in src/fdtdx/objects/sources/profile.py
@extended_autoinit
class GaussianPulseProfile(TemporalProfile):
    """Cosine carrier at ``center_frequency`` modulated by a Gaussian envelope."""

    spectral_width: float  # Width of the Gaussian envelope in frequency domain
    center_frequency: float  # Center frequency of the pulse

    def get_amplitude(
        self,
        time: jax.Array,
        period: float,
        phase_shift: float = 0.0,
    ) -> jax.Array:
        """Return the pulse amplitude at ``time``.

        The envelope has temporal width ``sigma_t = 1 / (2*pi*spectral_width)``
        and peaks at ``6 * sigma_t``, so the pulse starts near zero at t=0.
        ``period`` is ignored; the carrier uses ``center_frequency``.
        """
        del period
        sigma_t = 1.0 / (2 * jnp.pi * self.spectral_width)
        peak_time = 6 * sigma_t

        envelope = jnp.exp(-((time - peak_time) ** 2) / (2 * sigma_t**2))
        carrier = jnp.cos(2 * jnp.pi * self.center_frequency * time + phase_shift)
        return envelope * carrier

A temporal pulse consisting of a sinusoidal carrier modulated by a Gaussian envelope.