Compression Base Classes

fdtdx.interfaces.modules.CompressionModule

Bases: TreeClass, ABC

Abstract base class for compression modules that process simulation data.

This class provides an interface for modules that compress and decompress field data during FDTD simulations. Implementations can perform operations like quantization, dimensionality reduction, or other compression techniques.

compress abstractmethod

compress(
    values: dict[str, Array],
    state: RecordingState,
    key: Array,
) -> tuple[dict[str, jax.Array], RecordingState]

Compress field values at the current time step.

Parameters:
  • values (dict[str, Array]) –

    Dictionary mapping field names to their values.

  • state (RecordingState) –

    Current recording state.

  • key (Array) –

    Random key for stochastic operations.

Returns:
  • tuple[dict[str, Array], RecordingState]

    Tuple containing:
    - Dictionary of compressed field values
    - Updated recording state

Source code in src/fdtdx/interfaces/modules.py
@abstractmethod
def compress(
    self,
    values: dict[str, jax.Array],
    state: RecordingState,
    key: jax.Array,
) -> tuple[
    dict[str, jax.Array],  # compressed data
    RecordingState,  # updated recording state
]:
    """Compress field values at the current time step.

    Args:
        values (dict[str, jax.Array]): Dictionary mapping field names to their values.
        state (RecordingState): Current recording state.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        tuple[dict[str, jax.Array], RecordingState]: Tuple containing:
            - Dictionary of compressed field values
            - Updated recording state
    """
    del values, state, key
    raise NotImplementedError()
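The compress contract takes a dict of named arrays and returns a dict of compressed arrays together with the recording state. The minimal sketch below shows a quantization-style compress step, written as a plain function rather than a CompressionModule subclass. The function name quantize_compress and the "<name>_scale" key convention are hypothetical, and the sketch passes the state through unchanged.

```python
import jax.numpy as jnp


def quantize_compress(values, state, key):
    """Hypothetical int8 quantization mirroring the compress() signature."""
    del key  # deterministic scheme: the random key is unused
    out = {}
    for name, v in values.items():
        # Symmetric per-field scale; the floor avoids division by zero for all-zero fields.
        scale = jnp.maximum(jnp.max(jnp.abs(v)), 1e-12) / 127.0
        out[name] = jnp.round(v / scale).astype(jnp.int8)
        out[name + "_scale"] = scale  # hypothetical key convention for the scale
    return out, state  # state is returned unchanged in this sketch
```

Because the compressed dictionary gains extra "_scale" entries, a module using this scheme would also have to report them in the output shapes returned by init_shapes.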

decompress abstractmethod

decompress(
    values: dict[str, Array],
    state: RecordingState,
    key: Array,
) -> dict[str, jax.Array]

Decompress field values back to their original form.

Parameters:
  • values (dict[str, Array]) –

    Dictionary mapping field names to their compressed values.

  • state (RecordingState) –

    Current recording state.

  • key (Array) –

    Random key for stochastic operations.

Returns:
  • dict[str, Array]

    Dictionary mapping field names to their decompressed values.

Source code in src/fdtdx/interfaces/modules.py
@abstractmethod
def decompress(
    self,
    values: dict[str, jax.Array],
    state: RecordingState,
    key: jax.Array,
) -> dict[str, jax.Array]:
    """Decompress field values back to their original form.

    Args:
        values (dict[str, jax.Array]): Dictionary mapping field names to their compressed values.
        state (RecordingState): Current recording state.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        dict[str, jax.Array]: Dictionary mapping field names to their decompressed values.
    """
    del (
        values,
        state,
        key,
    )
    raise NotImplementedError()
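The inverse direction can be sketched the same way. This sketch assumes a hypothetical layout in which each field is stored as int8 next to a float "<name>_scale" entry. Under that layout, decompress only has to rescale. The helper name quantize_decompress is illustrative and not part of fdtdx.

```python
import jax.numpy as jnp


def quantize_decompress(values, state, key):
    """Hypothetical inverse of an int8 + per-field-scale compression."""
    del state, key  # neither is needed for a deterministic rescale
    return {
        name: values[name].astype(jnp.float32) * values[name + "_scale"]
        for name in values
        if not name.endswith("_scale")  # scale entries are metadata, not fields
    }
```

The reconstruction is lossy. Each value is recovered only to within half a quantization step (scale / 2).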

init_shapes abstractmethod

init_shapes(
    input_shape_dtypes: dict[str, ShapeDtypeStruct],
) -> tuple[
    Self,
    dict[str, jax.ShapeDtypeStruct],
    dict[str, jax.ShapeDtypeStruct],
]

Initialize shapes and sizes for the compression module.

Parameters:
  • input_shape_dtypes (dict[str, ShapeDtypeStruct]) –

    Dictionary mapping field names to their input shapes/dtypes.

Returns:
  • tuple[Self, dict[str, ShapeDtypeStruct], dict[str, ShapeDtypeStruct]]

    Tuple containing:
    - Self: Updated instance of the compression module
    - Dictionary mapping field names to their output shapes/dtypes
    - Dictionary mapping field names to their state shapes/dtypes

Source code in src/fdtdx/interfaces/modules.py
@abstractmethod
def init_shapes(
    self,
    input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
) -> tuple[
    Self,
    dict[str, jax.ShapeDtypeStruct],  # data
    dict[str, jax.ShapeDtypeStruct],  # state shapes/dtypes
]:
    """Initialize shapes and sizes for the compression module.

    Args:
        input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping field names to their input
            shapes/dtypes.

    Returns:
        tuple[Self, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]: Tuple containing:
            - Self: Updated instance of the compression module
            - Dictionary mapping field names to their output shapes/dtypes
            - Dictionary mapping field names to their state shapes/dtypes
    """
    del input_shape_dtypes
    raise NotImplementedError()
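Because init_shapes works only on shapes and dtypes, one way to keep it consistent with compress is to trace the compress function abstractly with jax.eval_shape. That call returns jax.ShapeDtypeStruct leaves without allocating any field data. The sketch below assumes a hypothetical float16 down-cast scheme. It returns only the two shape dictionaries, whereas the real method also returns the updated module instance as its first element. The state dictionary is empty because the scheme keeps no state.

```python
import jax
import jax.numpy as jnp


def fp16_compress(values):
    # Hypothetical compression: store every field in half precision.
    return {k: v.astype(jnp.float16) for k, v in values.items()}


def fp16_init_shapes(input_shape_dtypes):
    # Derive output shapes by tracing compress abstractly (no data is allocated).
    out_shapes = jax.eval_shape(fp16_compress, input_shape_dtypes)
    state_shapes = {}  # stateless scheme: nothing to carry between steps
    return out_shapes, state_shapes
```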

fdtdx.interfaces.time_filter.TimeStepFilter

Bases: TreeClass, ABC

Abstract base class for filtering and processing time steps in FDTD simulations.

This class provides an interface for filters that process simulation data at specific time steps. Implementations can perform operations like downsampling, collation, or other temporal processing of field data.

compress abstractmethod

compress(
    values: dict[str, Array],
    state: RecordingState,
    time_idx: Array,
    key: Array,
) -> tuple[dict[str, jax.Array], RecordingState]

Compress field values at a given time step.

Parameters:
  • values (dict[str, Array]) –

    Dictionary mapping field names to their values.

  • state (RecordingState) –

    Current recording state.

  • time_idx (Array) –

    Current time step index.

  • key (Array) –

    Random key for stochastic operations.

Returns:
  • tuple[dict[str, Array], RecordingState]

    Tuple containing:
    - Dictionary of compressed field values
    - Updated recording state

Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def compress(
    self,
    values: dict[str, jax.Array],
    state: RecordingState,
    time_idx: jax.Array,  # scalar
    key: jax.Array,
) -> tuple[
    dict[str, jax.Array],
    RecordingState,  # updated recording state
]:
    """Compress field values at a given time step.

    Args:
        values (dict[str, jax.Array]): Dictionary mapping field names to their values.
        state (RecordingState): Current recording state.
        time_idx (jax.Array): Current time step index.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        tuple[dict[str, jax.Array], RecordingState]: Tuple containing:
            - Dictionary of compressed field values
            - Updated recording state
    """
    del values, state, time_idx, key
    raise NotImplementedError()
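Together, the time_idx argument and the returned state let a filter carry information across steps. The minimal sketch below shows a collation-style filter: it keeps a running sum per field and emits the average over each window of WINDOW steps. It makes two assumptions. First, a plain dict stands in for RecordingState. Second, only the output produced at the last step of each window is kept. The names WINDOW and window_average_compress are hypothetical.

```python
import jax.numpy as jnp

WINDOW = 3  # hypothetical number of time steps averaged into one stored frame


def window_average_compress(values, state, time_idx, key):
    """Hypothetical collation filter; `state` is a plain dict of running sums."""
    del key  # deterministic filter
    new_sum = {k: state[k] + v for k, v in values.items()}
    # True at the last step of each window (positions 0..WINDOW-1).
    is_last = time_idx % WINDOW == WINDOW - 1
    out = {k: s / WINDOW for k, s in new_sum.items()}
    # Reset the accumulator once a window is complete; jnp.where keeps this jit-friendly.
    new_state = {k: jnp.where(is_last, jnp.zeros_like(s), s) for k, s in new_sum.items()}
    return out, new_state
```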

decompress abstractmethod

decompress(
    values: list[dict[str, Array]],
    state: RecordingState,
    arr_indices: Array,
    time_idx: Array,
    key: Array,
) -> dict[str, jax.Array]

Decompress field values to reconstruct data for a time step.

Parameters:
  • values (list[dict[str, Array]]) –

    List of dictionaries containing the stored array values requested for reconstruction (one per index in arr_indices).

  • state (RecordingState) –

    Current recording state.

  • arr_indices (Array) –

    Array indices of the stored entries needed for reconstruction, as returned by indices_to_decompress.

  • time_idx (Array) –

    Time step index to reconstruct (a scalar).

  • key (Array) –

    Random key for stochastic operations.

Returns:
  • dict[str, Array]

    Dictionary of reconstructed field values.

Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def decompress(
    self,
    values: list[dict[str, jax.Array]],  # array values requested above
    state: RecordingState,
    arr_indices: jax.Array,
    time_idx: jax.Array,  # scalar
    key: jax.Array,
) -> dict[str, jax.Array]:
    """Decompress field values to reconstruct data for a time step.

    Args:
        values (list[dict[str, jax.Array]]): List of dictionaries containing array values needed for reconstruction.
        state (RecordingState): Current recording state.
        arr_indices (jax.Array): Array indices needed for reconstruction.
        time_idx (jax.Array): Time step index to reconstruct. scalar value.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        dict[str, jax.Array]: Dictionary of reconstructed field values.
    """
    del values, state, arr_indices, time_idx, key
    raise NotImplementedError()
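A filter that keeps only every STRIDE-th step can reconstruct the intermediate steps from the neighboring stored snapshots. The sketch below assumes that values holds exactly two dictionaries, for the stored entries at arr_indices[0] and arr_indices[1], and linearly interpolates between them. STRIDE and interp_decompress are hypothetical names. Linear interpolation is just one possible reconstruction choice.

```python
import jax.numpy as jnp

STRIDE = 4  # hypothetical: one snapshot stored every STRIDE time steps


def interp_decompress(values, state, arr_indices, time_idx, key):
    """Hypothetical linear interpolation between two stored snapshots."""
    del state, key  # not needed for deterministic interpolation
    t0 = arr_indices[0] * STRIDE  # time step of the earlier snapshot
    w = (time_idx - t0) / STRIDE  # fractional position between the snapshots
    # If both indices coincide (end of the record), both terms use the same snapshot.
    return {k: (1.0 - w) * values[0][k] + w * values[1][k] for k in values[0]}
```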

indices_to_decompress abstractmethod

indices_to_decompress(time_idx: Array) -> jax.Array

Get array indices needed to reconstruct data for a given time step.

Parameters:
  • time_idx (Array) –

    Time step index to reconstruct.

Returns:
  • Array

    Array of indices needed to reconstruct the data for this time step.

Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def indices_to_decompress(
    self,
    time_idx: jax.Array,  # scalar
) -> jax.Array:  # 1d-list of array indices necessary to reconstruct
    """Get array indices needed to reconstruct data for a given time step.

    Args:
        time_idx (jax.Array): Time step index to reconstruct.

    Returns:
        jax.Array: Array of indices needed to reconstruct the data for this time step.
    """
    del time_idx
    raise NotImplementedError()
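For an interpolating stride filter, the indices needed for time step t are the stored entry at or before t and the one after it. The index is clipped so it never runs past the last stored entry. The sketch below illustrates this under assumed values: STRIDE, NUM_STORED, and interp_indices are hypothetical, and NUM_STORED stands in for the array size a real filter would compute in init_shapes.

```python
import jax.numpy as jnp

STRIDE = 4      # hypothetical: one snapshot every STRIDE steps
NUM_STORED = 3  # hypothetical array size, e.g. steps 0, 4, 8 of a 10-step run


def interp_indices(time_idx):
    """Hypothetical 1-D array of the two stored entries bracketing time_idx."""
    lo = time_idx // STRIDE
    hi = jnp.minimum(lo + 1, NUM_STORED - 1)  # clip at the last stored entry
    return jnp.stack([lo, hi])
```

A fixed output length (two indices here) keeps the result shape static, which matters if the call is traced under jax.jit.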

init_shapes abstractmethod

init_shapes(
    input_shape_dtypes: dict[str, ShapeDtypeStruct],
    time_steps_max: int,
) -> tuple[
    Self,
    int,
    dict[str, jax.ShapeDtypeStruct],
    dict[str, jax.ShapeDtypeStruct],
]

Initialize shapes and sizes for the time step filter.

Parameters:
  • input_shape_dtypes (dict[str, ShapeDtypeStruct]) –

    Dictionary mapping field names to their shape/dtype information.

  • time_steps_max (int) –

    Maximum number of time steps in the simulation.

Returns:
  • tuple[Self, int, dict[str, ShapeDtypeStruct], dict[str, ShapeDtypeStruct]]

    Tuple containing:
    - Updated filter instance
    - Size of the array for storing filtered data (number of latent time steps)
    - Dictionary of data shapes/dtypes
    - Dictionary of state shapes/dtypes

Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def init_shapes(
    self,
    input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    time_steps_max: int,  # maximum number of time steps
) -> tuple[
    Self,
    int,  # array size (number of latent time steps)
    dict[str, jax.ShapeDtypeStruct],  # data
    dict[str, jax.ShapeDtypeStruct],  # state shapes
]:
    """Initialize shapes and sizes for the time step filter.

    Args:
        input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping field names to their
            shape/dtype information.
        time_steps_max (int): Maximum number of time steps in the simulation.

    Returns:
        tuple[Self, int, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]: A tuple containing:
            - Updated filter instance
            - Size of array for storing filtered data
            - Dictionary of data shapes/dtypes
            - Dictionary of state shapes/dtypes
    """
    del input_shape_dtypes, time_steps_max
    raise NotImplementedError()
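For a simple stride filter, the array size is the number of steps t in [0, time_steps_max) with t % stride == 0, which is ceil(time_steps_max / stride). The sketch below assumes a hypothetical stride filter. It stores frames unchanged and keeps no state. It also omits the updated instance that the real method returns first.

```python
import jax
import jax.numpy as jnp


def stride_init_shapes(input_shape_dtypes, time_steps_max, stride):
    """Hypothetical init_shapes for a filter that keeps every `stride`-th step."""
    # Integer ceil(time_steps_max / stride): count of kept steps 0, stride, 2*stride, ...
    array_size = (time_steps_max + stride - 1) // stride
    data_shapes = dict(input_shape_dtypes)  # kept frames are stored as-is
    state_shapes = {}  # a pure stride filter needs no extra state
    return array_size, data_shapes, state_shapes
```

For time_steps_max=10 and stride=3 the kept steps are 0, 3, 6 and 9, so four slots are needed.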

time_to_array_index abstractmethod

time_to_array_index(time_idx: int) -> int

Convert a time step index to its corresponding array index.

Parameters:
  • time_idx (int) –

    Time step index to convert.

Returns:
  • int

    The corresponding array index if the time step is not filtered, or -1 if the time step is filtered out.

Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def time_to_array_index(
    self,
    time_idx: int,  # scalar
) -> int:  # array index if not filtered, else -1
    """Convert a time step index to its corresponding array index.

    Args:
        time_idx (int): Time step index to convert.

    Returns:
        int: The corresponding array index if the time step is not filtered,
            or -1 if the time step is filtered out.
    """
    del time_idx
    raise NotImplementedError()
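The -1 convention can be illustrated with a hypothetical stride filter that keeps every STRIDE-th step. The sketch uses jnp.where instead of a Python if, which keeps it usable when time_idx is a traced scalar inside jax.jit or a jax.lax loop, and it also works element-wise on an array of indices. STRIDE and stride_time_to_array_index are illustrative names.

```python
import jax.numpy as jnp

STRIDE = 2  # hypothetical: keep every second time step


def stride_time_to_array_index(time_idx):
    """Hypothetical mapping: array index for kept steps, -1 for filtered ones."""
    kept = time_idx % STRIDE == 0
    return jnp.where(kept, time_idx // STRIDE, -1)
```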