Skip to content

Interface Compression

This API can be used for automatic differentiation (autodiff) with time-reversibility, which is more memory efficient than other approaches. Additionally, some basic compression modules are implemented to reduce the memory footprint even further. However, they should be used with care since too much compression can reduce the gradient accuracy.

fdtdx.interfaces.Recorder

Bases: ExtendedTreeClass

Records and compresses simulation data over time using a sequence of processing modules.

The Recorder manages a pipeline of modules that process simulation data at each timestep. It supports both compression modules that reduce data size and time filters that control when data is recorded. The recorder handles initialization, compression and decompression of simulation data through its module pipeline.

Attributes:

Name Type Description
modules Sequence[CompressionModule | TimeStepFilter]

Sequence of processing modules to apply to the simulation data. Can be either CompressionModule for data reduction or TimeStepFilter for controlling recording frequency.

Source code in src/fdtdx/interfaces/recorder.py
@extended_autoinit
class Recorder(ExtendedTreeClass):
    """Records and compresses simulation data over time using a sequence of processing modules.

    The Recorder manages a pipeline of modules that process simulation data at each timestep.
    It supports both compression modules that reduce data size and time filters that control
    when data is recorded. The recorder handles initialization, compression and decompression
    of simulation data through its module pipeline.

    Typical usage is ``init_state`` once, then ``compress`` per time step during the
    forward pass and ``decompress`` per time step when the values are needed again.

    Attributes:
        modules: Sequence of processing modules to apply to the simulation data.
            Can be either CompressionModule for data reduction or TimeStepFilter
            for controlling recording frequency.
    """

    modules: Sequence[CompressionModule | TimeStepFilter]
    # Shapes/dtypes of the raw values handed to ``compress``; set by ``init_state``.
    _input_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type:ignore
    # Shapes/dtypes produced by the last module of the pipeline; set by ``init_state``.
    _output_shape_dtypes: dict[str, jax.ShapeDtypeStruct] = frozen_private_field(default=None)  # type:ignore
    # Number of simulation time steps given to ``init_state`` (-1 until initialised).
    _max_time_steps: int = frozen_private_field(default=-1)
    # Length of the leading (time) axis of the stored arrays (-1 until initialised).
    _latent_array_size: int = frozen_private_field(default=-1, init=False)

    def init_state(
        self: Self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
        max_time_steps: int,
        backend: BackendOption,
    ) -> tuple[Self, RecordingState]:
        """Initialise every module of the pipeline and allocate the recording state.

        Args:
            input_shape_dtypes: Shape and dtype of each value that will be passed
                to ``compress``, keyed by name.
            max_time_steps: Number of time steps; used as the initial size of the
                leading axis of the stored arrays before any time filter is applied.
            backend: Forwarded unchanged to ``init_recording_state``.

        Returns:
            A tuple of the updated recorder (with initialised modules and shape
            bookkeeping) and the freshly created recording state.
        """
        self = self.aset("_max_time_steps", max_time_steps)
        self = self.aset("_input_shape_dtypes", input_shape_dtypes)

        # Thread shapes through the pipeline: the output shapes of one module are
        # the input shapes of the next one.
        latent_arr_size, out_shapes = max_time_steps, input_shape_dtypes
        state_sizes: dict[str, jax.ShapeDtypeStruct] = {}
        new_modules = []
        for m in self.modules:
            if isinstance(m, CompressionModule):
                m, out_shapes, state_shapes = m.init_shapes(out_shapes)
            else:
                # Every non-CompressionModule is treated as a time filter, which
                # additionally returns the new size of the leading (time) axis.
                m, latent_arr_size, out_shapes, state_shapes = m.init_shapes(out_shapes, latent_arr_size)
            state_sizes.update(state_shapes)
            new_modules.append(m)

        self = self.aset("modules", new_modules)
        self = self.aset("_output_shape_dtypes", out_shapes)
        self = self.aset("_latent_array_size", latent_arr_size)

        # The stored arrays carry one extra leading axis of length
        # ``_latent_array_size`` in front of the per-step output shape.
        expanded_out_shapes = {
            k: jax.ShapeDtypeStruct(
                shape=(self._latent_array_size, *v.shape),
                dtype=v.dtype,
            )
            for k, v in self._output_shape_dtypes.items()
        }
        state = init_recording_state(
            data_shape_dtypes=expanded_out_shapes,
            state_shape_dtypes=state_sizes,
            backend=backend,
        )
        return self, state

    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        time_step: jax.Array,
        key: jax.Array,
    ) -> RecordingState:
        """Run ``values`` through all modules and store the result in ``state``.

        A latent index of ``-1`` marks a time step that a time filter chose not to
        record; from then on the remaining modules are skipped and nothing is
        written to ``state.data``.

        Args:
            values: Values of the current time step, keyed by name. Checked
                against the shapes/dtypes given to ``init_state``.
            state: Current recording state.
            time_step: Current time step; used as the initial latent index.
            key: Random key; a fresh subkey is split off for every module.

        Returns:
            The updated recording state.
        """
        check_shape_dtype(values, self._input_shape_dtypes)
        latent_idx = time_step

        def helper_fn(m, values, state, latent_idx, key):
            # Regular branch: apply module ``m``. Time filters also translate the
            # latent index into the index space of their output array.
            if isinstance(m, CompressionModule):
                values, state = m.compress(values, state, key=key)
            elif isinstance(m, TimeStepFilter):
                values, state = m.compress(values, state, latent_idx, key=key)
                latent_idx = m.time_to_array_index(latent_idx)
            else:
                raise Exception(f"Invalid module: {m}")
            check_shape_dtype(values, m._output_shape_dtypes)
            return values, state, latent_idx

        def dummy_fn(m, values, state, latent_idx, key):
            # Skip branch: jax.lax.cond requires both branches to return the same
            # structure, so produce zero placeholders with the module's output shapes.
            del key
            # Only create zero arrays for keys that exist in the input values
            # This ensures structure matching with helper_fn for periodic boundaries
            values = {k: jnp.zeros(v.shape, v.dtype) for k, v in m._output_shape_dtypes.items() if k in values}
            check_shape_dtype(values, m._output_shape_dtypes)
            return values, state, latent_idx

        for m in self.modules:
            key, subkey = jax.random.split(key)
            # Once an earlier filter has set latent_idx to -1, all later modules
            # take the dummy branch.
            values, state, latent_idx = jax.lax.cond(
                latent_idx == -1, dummy_fn, helper_fn, m, values, state, latent_idx, subkey
            )
            check_shape_dtype(values, m._output_shape_dtypes)

        def update_state_fn(state, values, latent_idx):
            # Write the fully processed values into row ``latent_idx`` of the stored arrays.
            # Only update state data for keys that exist in values
            for k in state.data.keys():
                if k in values:
                    state.data[k] = state.data[k].at[latent_idx].set(values[k])
            return state

        def update_dummy_fn(state, values, latent_idx):
            # Filtered-out time step: leave the state untouched.
            del values, latent_idx
            return state

        state = jax.lax.cond(latent_idx == -1, update_dummy_fn, update_state_fn, state, values, latent_idx)

        return state

    def decompress(
        self,
        state: RecordingState,
        time_step: jax.Array,
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],
        RecordingState,
    ]:
        """Reconstruct the values of ``time_step`` from the recording state.

        The modules are applied in reverse order. Time filters may need several
        stored entries to reconstruct a single time step, so the required array
        indices are gathered first and the stored entries are then merged level
        by level.

        Args:
            state: Recording state filled by ``compress``.
            time_step: Time step to reconstruct.
            key: Random key; a fresh subkey is split off for every module.

        Returns:
            A tuple of the reconstructed values (keyed by name) and the recording state.

        Raises:
            Exception: If the reconstruction does not end with exactly one value dict.
        """
        # gather indices necessary to reconstruct
        time_filters = [m for m in self.modules if isinstance(m, TimeStepFilter)]
        # indices[0] holds the requested time step; indices[i + 1] holds, for each
        # entry of the flattened indices[i], the array indices that the i-th time
        # filter needs for its reconstruction.
        indices: list[jax.Array] = [jnp.asarray([time_step])]
        # time_indices[i] is the flattened index list that was fed into the i-th time filter.
        time_indices: list[jax.Array] = []

        for tf in time_filters:
            cur_time_indices = indices[-1].flatten()
            cur_indices = jnp.asarray([tf.indices_to_decompress(idx) for idx in cur_time_indices])
            time_indices.append(cur_time_indices)
            indices.append(cur_indices)

        def reconstruction_iteration(
            m: CompressionModule | TimeStepFilter,
            state: RecordingState,
            key: jax.Array,
            latent: list[dict[str, jax.Array]],
            cur_tf_idx: int,
        ) -> tuple[
            int,
            list[dict[str, jax.Array]],
            RecordingState,
        ]:
            # Undo a single module. ``cur_tf_idx`` selects the level of ``indices``
            # belonging to the next time filter to undo; it is decremented after
            # each time filter.
            if isinstance(m, CompressionModule):
                # Compression modules act element-wise on every latent entry.
                latent = [m.decompress(v, state, key=key) for v in latent]
            else:
                # Time filter: merge each group of ``num_arr_idx`` consecutive
                # latent entries into one reconstructed entry.
                num_time_idx = indices[cur_tf_idx].shape[0]
                num_arr_idx = indices[cur_tf_idx].shape[1]
                next_latent = []
                for cur_idx in range(0, num_time_idx):
                    # for idx in range(start_idx, start_idx + num_idx):
                    start_idx = cur_idx * num_arr_idx
                    cur_v = [latent[i] for i in range(start_idx, start_idx + num_arr_idx)]
                    arr_indices = indices[cur_tf_idx][cur_idx]
                    time_idx = time_indices[cur_tf_idx - 1][cur_idx]
                    next_v = m.decompress(
                        values=cur_v,
                        state=state,
                        arr_indices=arr_indices,
                        time_idx=time_idx,
                        key=key,
                    )
                    next_latent.append(next_v)
                latent = next_latent
                cur_tf_idx = cur_tf_idx - 1
            # After undoing ``m`` the entries must match the module's input shapes.
            for v in latent:
                check_shape_dtype(v, m._input_shape_dtypes)
            return cur_tf_idx, latent, state

        def bottom_up_reconstruction(state: RecordingState, key):
            cur_tf_idx = len(time_filters)
            # Start from the stored arrays: fetch one row (axis 0) for every array
            # index of the deepest level of ``indices``.
            latent: list[dict[str, jax.Array]] = [
                {k: jnp.take(v, indices=idx.reshape(1), axis=0).squeeze(axis=0) for k, v in state.data.items()}
                for idx in indices[cur_tf_idx].flatten()
            ]
            # Undo the pipeline back to front.
            for m in self.modules[::-1]:
                key, subkey = jax.random.split(key)
                cur_tf_idx, latent, state = reconstruction_iteration(
                    m=m,
                    state=state,
                    key=subkey,
                    latent=latent,
                    cur_tf_idx=cur_tf_idx,
                )
            return latent, state

        values, state = bottom_up_reconstruction(state, key)

        # All groups must have been merged down to the single requested time step.
        if len(values) != 1:
            raise Exception("This should never happen")
        return values[0], state

A recorder object that records the interfaces between the simulation volume and the PML boundary during the forward simulation.

Compression Modules

fdtdx.interfaces.LinearReconstructEveryK

Bases: TimeStepFilter

Time step filter that performs linear reconstruction between sampled steps.

This filter saves field values every k time steps and uses linear interpolation to reconstruct values at intermediate time steps.

Attributes:

Name Type Description
k int

Number of time steps between saved values.

start_recording_after int

Time step to start recording from.

_save_time_steps Array

Array of time steps that are saved.

_time_to_arr_idx Array

Mapping from time steps to array indices.

Source code in src/fdtdx/interfaces/time_filter.py
@extended_autoinit
class LinearReconstructEveryK(TimeStepFilter):
    """Time step filter that performs linear reconstruction between sampled steps.

    This filter saves field values every k time steps and uses linear interpolation
    to reconstruct values at intermediate time steps. The last time step is always
    saved so that every intermediate step lies between two saved steps.

    Attributes:
        k: Number of time steps between saved values.
        start_recording_after: Time step to start recording from.
        _save_time_steps: Array of time steps that are saved.
        _time_to_arr_idx: Mapping from time steps to array indices.
    """

    k: int = frozen_field()
    start_recording_after: int = 0
    _save_time_steps: jax.Array = frozen_private_field(default=None)  # type: ignore
    _time_to_arr_idx: jax.Array = frozen_private_field(default=None)  # type: ignore

    def init_shapes(
        self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
        time_steps_max: int,  # maximum number of time steps
    ) -> tuple[
        Self,
        int,
        dict[str, jax.ShapeDtypeStruct],  # data
        dict[str, jax.ShapeDtypeStruct],  # state shapes
    ]:
        """Precompute the saved time steps and the time-step -> array-index mapping.

        Args:
            input_shape_dtypes: Shape and dtype of each incoming value. The filter
                does not change them, so they are also the output shapes.
            time_steps_max: Number of time steps of the incoming sequence.

        Returns:
            A tuple of the initialised filter, the number of saved time steps,
            the (unchanged) output shapes and an empty dict of extra state shapes.

        Raises:
            ValueError: If the configuration yields no time step to save, i.e. the
                range from ``start_recording_after`` to ``time_steps_max`` with
                step ``k`` is empty.
        """
        self = self.aset("_time_steps_max", time_steps_max)
        self = self.aset("_input_shape_dtypes", input_shape_dtypes)
        self = self.aset("_output_shape_dtypes", input_shape_dtypes)

        # init list of all time steps to save
        all_time_steps = jnp.arange(self.start_recording_after, self._time_steps_max, self.k).tolist()
        if not all_time_steps:
            # Previously this surfaced as an opaque IndexError on the [-1] access below.
            raise ValueError(
                "No time steps to record: "
                f"start_recording_after={self.start_recording_after}, k={self.k}, "
                f"time_steps_max={self._time_steps_max}"
            )
        # Always save the final time step so interpolation has an upper neighbour.
        if all_time_steps[-1] != self._time_steps_max - 1:
            all_time_steps.append(self._time_steps_max - 1)

        self = self.aset("_save_time_steps", jnp.asarray(all_time_steps, dtype=jnp.int32))
        self = self.aset("_array_size", len(all_time_steps))

        # mapping between time steps and array indices
        index_tmp = jnp.arange(0, self._array_size, dtype=jnp.int32)
        time_indices = jnp.zeros(shape=(self._time_steps_max,), dtype=jnp.int32)
        time_indices = time_indices.at[self._save_time_steps].set(index_tmp)
        # Propagate each saved step's array index forward to the unsaved steps
        # that follow it (at most k - 1 of them), one position per iteration.
        for _ in range(self.k - 1):
            rolled = jnp.roll(time_indices, 1)
            time_indices = jnp.where(
                time_indices == 0,
                rolled,
                time_indices,
            )
            # jnp.roll wraps around, so values from the end of the array leak into
            # the first positions; reset them to index 0.
            time_indices = time_indices.at[: self.k].set(0)
        self = self.aset("_time_to_arr_idx", time_indices)
        return self, self._array_size, input_shape_dtypes, {}

    def time_to_array_index(
        self,
        time_idx: int,  # scalar
    ) -> int:  # scalar, array index if not filtered, else -1
        """Return the array index of ``time_idx`` if that step is saved, else -1."""
        result = jax.lax.cond(
            jnp.any(time_idx == self._save_time_steps),
            lambda: self._time_to_arr_idx[time_idx],
            lambda: jnp.asarray(-1, dtype=jnp.int32),
        )
        return result

    def indices_to_decompress(
        self,
        time_idx: jax.Array,  # scalar
    ) -> jax.Array:  # 1d-list of array indices necessary to reconstruct
        """Return the two array indices used to reconstruct ``time_idx``.

        These are the index mapped to ``time_idx`` and its successor. The
        successor is returned even when ``time_idx`` itself is a saved step;
        ``decompress`` ignores it in that case.
        """
        arr_idx = self._time_to_arr_idx[time_idx]
        result = jnp.asarray([arr_idx, arr_idx + 1], dtype=jnp.int32)
        return result

    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        time_idx: jax.Array,  # scalar
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],
        RecordingState,  # updated recording state
    ]:
        """Pass ``values`` and ``state`` through unchanged.

        The filter does not alter the data itself; which steps are stored is
        decided via ``time_to_array_index``.
        """
        del time_idx, key
        return values, state

    def decompress(
        self,
        values: list[dict[str, jax.Array]],  # array values requested above
        state: RecordingState,
        arr_indices: jax.Array,
        time_idx: jax.Array,  # scalar
        key: jax.Array,
    ) -> dict[str, jax.Array]:  # reconstructed value
        """Reconstruct the values at ``time_idx`` from two stored entries.

        Args:
            values: The stored entries for the array indices returned by
                ``indices_to_decompress`` (previous entry first, next entry second).
            state: Unused.
            arr_indices: The array indices belonging to ``values``.
            time_idx: Time step to reconstruct.
            key: Unused.

        Returns:
            ``values[0]`` if ``time_idx`` is a saved step, otherwise the linear
            interpolation between ``values[0]`` and ``values[1]``.
        """
        del key, state

        def value_was_saved():
            return values[0]

        def linear_reconstruct():
            arr_idx = arr_indices[0]

            # NOTE(review): index_1d_array is defined elsewhere; presumably it returns
            # the first position of the array that equals the given value, i.e. the
            # time step at which this array index was saved - confirm against the helper.
            prev_save_time = index_1d_array(self._time_to_arr_idx, arr_idx)
            next_save_time = index_1d_array(self._time_to_arr_idx, arr_idx + 1)
            interp_factor = (time_idx - prev_save_time) / (next_save_time - prev_save_time)

            prev_vals, next_vals = values[0], values[1]
            res = {}
            for name, prev_val in prev_vals.items():
                next_val = next_vals[name]
                # Cast the factor to the data dtype so the result keeps that dtype.
                res[name] = prev_val + interp_factor.astype(next_val.dtype) * (next_val - prev_val)
            return res

        result = jax.lax.cond(
            jnp.any(time_idx == self._save_time_steps),
            value_was_saved,
            linear_reconstruct,
        )
        return result

Compression module that records only every k-th time step during the forward simulation. During reconstruction, values at intermediate time steps are linearly interpolated between the neighbouring recorded time steps.

fdtdx.interfaces.DtypeConversion

Bases: CompressionModule

Compression module that converts data types of field values.

This module changes the data type of field values while preserving their shape, useful for reducing memory usage or meeting precision requirements.

Attributes:

Name Type Description
dtype dtype

Target data type for conversion.

exclude_filter Sequence[str]

List of field names to exclude from conversion.

Source code in src/fdtdx/interfaces/modules.py
@extended_autoinit
class DtypeConversion(CompressionModule):
    """Compression module that converts data types of field values.

    This module changes the data type of field values while preserving their shape,
    useful for reducing memory usage or meeting precision requirements.

    Attributes:
        dtype: Target data type for conversion.
        exclude_filter: List of field names to exclude from conversion. A field is
            excluded when any entry occurs as a substring of its name.
    """

    dtype: jnp.dtype = frozen_field(kind="KW_ONLY")  # type: ignore
    exclude_filter: Sequence[str] = frozen_field(default=(), kind="KW_ONLY")

    def _is_excluded(self, name: str) -> bool:
        """Return True if ``name`` contains any entry of ``exclude_filter``.

        A ``None`` filter is treated as empty so that ``init_shapes`` and
        ``compress`` agree on which fields are converted.
        """
        return any(entry in name for entry in (self.exclude_filter or ()))

    def init_shapes(
        self,
        input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    ) -> tuple[
        Self,
        dict[str, jax.ShapeDtypeStruct],  # data
        dict[str, jax.ShapeDtypeStruct],  # state shapes/dtypes
    ]:
        """Record the input shapes and derive the output shapes.

        Args:
            input_shape_dtypes: Shape and dtype of each incoming value.

        Returns:
            A tuple of the initialised module, the output shapes (same shapes,
            with ``dtype`` substituted for every non-excluded field) and an empty
            dict of extra state shapes.
        """
        self = self.aset("_input_shape_dtypes", input_shape_dtypes)
        out_shape_dtypes = {
            name: (spec if self._is_excluded(name) else jax.ShapeDtypeStruct(spec.shape, self.dtype))
            for name, spec in input_shape_dtypes.items()
        }
        self = self.aset("_output_shape_dtypes", out_shape_dtypes)
        return self, self._output_shape_dtypes, {}

    def compress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> tuple[
        dict[str, jax.Array],
        RecordingState,
    ]:
        """Cast every non-excluded value to ``dtype``; the state is returned unchanged."""
        del key
        out_vals = {
            name: (val if self._is_excluded(name) else val.astype(self.dtype)) for name, val in values.items()
        }
        return out_vals, state

    def decompress(
        self,
        values: dict[str, jax.Array],
        state: RecordingState,
        key: jax.Array,
    ) -> dict[str, jax.Array]:
        """Cast every value back to the dtype recorded for it in ``init_shapes``.

        Excluded fields were never converted, so the cast leaves their dtype as is.
        """
        del key, state
        out_vals = {name: val.astype(self._input_shape_dtypes[name].dtype) for name, val in values.items()}
        return out_vals

Compression module that stores the interfaces at a lower-precision data type. From experience, saving the interfaces in jnp.float16 or jnp.float8_e4m3fnuz is sufficient for most applications.