Bases: TreeClass, ABC
Abstract base class for filtering and processing time steps in FDTD simulations.
This class provides an interface for filters that process simulation data at specific
time steps. Implementations can perform operations like downsampling, collation, or
other temporal processing of field data.
compress
abstractmethod
compress(
values: dict[str, Array],
state: RecordingState,
time_idx: Array,
key: Array,
) -> tuple[dict[str, jax.Array], RecordingState]
Compress field values at a given time step.
Parameters:
- values (dict[str, Array]) – Dictionary mapping field names to their values.
- state (RecordingState) – Current recording state.
- time_idx (Array) – Current time step index (scalar).
- key (Array) – Random key for stochastic operations.
Returns:
- tuple[dict[str, Array], RecordingState] – Tuple containing:
  - Dictionary of compressed field values
  - Updated recording state
Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def compress(
    self,
    values: dict[str, jax.Array],
    state: RecordingState,
    time_idx: jax.Array,  # scalar
    key: jax.Array,
) -> tuple[
    dict[str, jax.Array],
    RecordingState,  # updated recording state
]:
    """Compress field values at a given time step.

    Args:
        values (dict[str, jax.Array]): Dictionary mapping field names to their values.
        state (RecordingState): Current recording state.
        time_idx (jax.Array): Current time step index.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        tuple[dict[str, jax.Array], RecordingState]: Tuple containing:
            - Dictionary of compressed field values
            - Updated recording state
    """
    del values, state, time_idx, key
    raise NotImplementedError()
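To make the contract of compress more concrete, here is a minimal sketch of a pass-through step that returns the fields unchanged and records the last seen time index in the state. It is not taken from fdtdx: `compress_passthrough` is a hypothetical free function, the `last_time_idx` entry is invented for illustration, and a plain `dict` stands in for `RecordingState`, whose fields are not shown on this page.

```python
import jax
import jax.numpy as jnp


def compress_passthrough(values, state, time_idx, key):
    """Hypothetical compress step: keep the fields, note the time index."""
    del key  # this sketch is deterministic, so the random key is unused
    # Assumption: the state is a plain dict here, not fdtdx's RecordingState.
    new_state = {**state, "last_time_idx": jnp.asarray(time_idx)}
    return values, new_state


values = {"E": jnp.ones((2, 2))}
state = {"last_time_idx": jnp.asarray(-1)}
out, new_state = compress_passthrough(
    values, state, jnp.asarray(3), jax.random.PRNGKey(0)
)
```

A filter that only drops time steps may need nothing more than this in compress, with the selection itself expressed through time_to_array_index.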
decompress
abstractmethod
decompress(
values: list[dict[str, Array]],
state: RecordingState,
arr_indices: Array,
time_idx: Array,
key: Array,
) -> dict[str, jax.Array]
Decompress field values to reconstruct data for a time step.
Parameters:
- values (list[dict[str, Array]]) – List of dictionaries containing array values needed for reconstruction.
- state (RecordingState) – Current recording state.
- arr_indices (Array) – Array indices needed for reconstruction.
- time_idx (Array) – Time step index to reconstruct (scalar).
- key (Array) – Random key for stochastic operations.
Returns:
- dict[str, Array] – Dictionary of reconstructed field values.
Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def decompress(
    self,
    values: list[dict[str, jax.Array]],  # array values requested above
    state: RecordingState,
    arr_indices: jax.Array,
    time_idx: jax.Array,  # scalar
    key: jax.Array,
) -> dict[str, jax.Array]:
    """Decompress field values to reconstruct data for a time step.

    Args:
        values (list[dict[str, jax.Array]]): List of dictionaries containing array values needed for reconstruction.
        state (RecordingState): Current recording state.
        arr_indices (jax.Array): Array indices needed for reconstruction.
        time_idx (jax.Array): Time step index to reconstruct. scalar value.
        key (jax.Array): Random key for stochastic operations.

    Returns:
        dict[str, jax.Array]: Dictionary of reconstructed field values.
    """
    del values, state, arr_indices, time_idx, key
    raise NotImplementedError()
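One way a decompress implementation could use its inputs is linear interpolation between two stored snapshots. The sketch below assumes a filter that keeps every `K`-th time step and that `values` is ordered like `arr_indices`; `decompress_linear` and `K` are hypothetical names, and this is not necessarily how any fdtdx filter reconstructs data.

```python
import jax
import jax.numpy as jnp

K = 4  # hypothetical downsampling factor: every K-th step was stored


def decompress_linear(values, state, arr_indices, time_idx, key):
    """Hypothetical decompress: interpolate between two stored snapshots."""
    del state, key  # unused in this deterministic sketch
    # Assumption: values[0] and values[1] belong to arr_indices[0] and [1].
    before, after = values
    t0 = arr_indices[0] * K  # time step at which `before` was recorded
    weight = (time_idx - t0) / K  # 0 at `before`, approaching 1 at `after`
    return {
        name: (1 - weight) * before[name] + weight * after[name]
        for name in before
    }


stored = [{"E": jnp.zeros(3)}, {"E": jnp.full(3, 4.0)}]
out = decompress_linear(
    stored, {}, jnp.array([1, 2]), jnp.asarray(5), jax.random.PRNGKey(0)
)
```

Here step 5 lies a quarter of the way between stored steps 4 and 8, so each reconstructed entry is a quarter of the way from 0.0 to 4.0.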
indices_to_decompress
abstractmethod
indices_to_decompress(time_idx: Array) -> jax.Array
Get array indices needed to reconstruct data for a given time step.
Parameters:
- time_idx (Array) – Time step index to reconstruct.
Returns:
- Array – Array of indices needed to reconstruct the data for this time step.
Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def indices_to_decompress(
    self,
    time_idx: jax.Array,  # scalar
) -> jax.Array:  # 1d-list of array indices necessary to reconstruct
    """Get array indices needed to reconstruct data for a given time step.

    Args:
        time_idx (jax.Array): Time step index to reconstruct.

    Returns:
        jax.Array: Array of indices needed to reconstruct the data for this time step.
    """
    del time_idx
    raise NotImplementedError()
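As an illustration of what the returned index array might look like, the sketch below picks the two stored slots that bracket a time step, assuming a filter that keeps every `K`-th step in `NUM_SLOTS` slots. `bracketing_indices`, `K`, and `NUM_SLOTS` are hypothetical names chosen for this example only.

```python
import jax.numpy as jnp

K = 4  # hypothetical downsampling factor
NUM_SLOTS = 8  # hypothetical number of stored snapshots


def bracketing_indices(time_idx):
    """Hypothetical index lookup: the slots just before and after time_idx."""
    lower = time_idx // K
    # Clamp so the last time steps do not point past the stored array.
    upper = jnp.minimum(lower + 1, NUM_SLOTS - 1)
    return jnp.stack([lower, upper])


idx = bracketing_indices(jnp.asarray(5))
```

Clamping the upper index keeps the result a fixed-length 1-D array for every time step, which is convenient when the function is traced by JAX.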
init_shapes
abstractmethod
init_shapes(
input_shape_dtypes: dict[str, ShapeDtypeStruct],
time_steps_max: int,
) -> tuple[
Self,
int,
dict[str, jax.ShapeDtypeStruct],
dict[str, jax.ShapeDtypeStruct],
]
Initialize shapes and sizes for the time step filter.
Parameters:
- input_shape_dtypes (dict[str, ShapeDtypeStruct]) – Dictionary mapping field names to their shape/dtype information.
- time_steps_max (int) – Maximum number of time steps in the simulation.
Returns:
- tuple[Self, int, dict[str, ShapeDtypeStruct], dict[str, ShapeDtypeStruct]] – A tuple containing:
  - Updated filter instance
  - Size of array for storing filtered data
  - Dictionary of data shapes/dtypes
  - Dictionary of state shapes/dtypes
Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def init_shapes(
    self,
    input_shape_dtypes: dict[str, jax.ShapeDtypeStruct],
    time_steps_max: int,  # maximum number of time steps
) -> tuple[
    Self,
    int,  # array size (number of latent time steps)
    dict[str, jax.ShapeDtypeStruct],  # data
    dict[str, jax.ShapeDtypeStruct],  # state shapes
]:
    """Initialize shapes and sizes for the time step filter.

    Args:
        input_shape_dtypes (dict[str, jax.ShapeDtypeStruct]): Dictionary mapping field names to their
            shape/dtype information.
        time_steps_max (int): Maximum number of time steps in the simulation.

    Returns:
        tuple[Self, int, dict[str, jax.ShapeDtypeStruct], dict[str, jax.ShapeDtypeStruct]]: A tuple containing:
            - Updated filter instance
            - Size of array for storing filtered data
            - Dictionary of data shapes/dtypes
            - Dictionary of state shapes/dtypes
    """
    del input_shape_dtypes, time_steps_max
    raise NotImplementedError()
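The size and shape bookkeeping can be sketched for a filter that stores every `k`-th time step. The helper below is a hypothetical free function, not part of fdtdx: it omits the updated filter instance that the real method returns first, and it assumes the filter stores fields unchanged and keeps no extra state.

```python
import math

import jax
import jax.numpy as jnp


def every_k_shapes(input_shape_dtypes, time_steps_max, k):
    """Hypothetical shape setup for a filter that stores every k-th step."""
    # Steps 0, k, 2k, ... below time_steps_max are stored.
    array_size = math.ceil(time_steps_max / k)
    data_shapes = dict(input_shape_dtypes)  # fields are stored as-is
    state_shapes = {}  # assumption: this filter carries no state
    return array_size, data_shapes, state_shapes


field_shapes = {"E": jax.ShapeDtypeStruct((3, 8, 8), jnp.float32)}
size, data_shapes, state_shapes = every_k_shapes(field_shapes, 10, 4)
```

With 10 time steps and `k = 4`, steps 0, 4, and 8 are stored, so the array needs three slots.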
time_to_array_index
abstractmethod
time_to_array_index(time_idx: int) -> int
Convert a time step index to its corresponding array index.
Parameters:
- time_idx (int) – Time step index to convert.
Returns:
- int – The corresponding array index if the time step is not filtered, or -1 if the time step is filtered out.
Source code in src/fdtdx/interfaces/time_filter.py
@abstractmethod
def time_to_array_index(
    self,
    time_idx: int,  # scalar
) -> int:  # array index if not filtered, else -1
    """Convert a time step index to its corresponding array index.

    Args:
        time_idx (int): Time step index to convert.

    Returns:
        int: The corresponding array index if the time step is not filtered,
            or -1 if the time step is filtered out.
    """
    del time_idx
    raise NotImplementedError()
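The -1 convention is easiest to see with a small example. The sketch below assumes a filter that keeps every `k`-th time step; `every_k_index` is a hypothetical helper written for this page, not an fdtdx function.

```python
def every_k_index(time_idx: int, k: int = 4) -> int:
    """Hypothetical mapping for a filter that keeps every k-th time step."""
    if time_idx % k == 0:
        return time_idx // k  # kept: position in the storage array
    return -1  # filtered out: this step has no slot


kept = every_k_index(8)
dropped = every_k_index(5)
```

With the default `k = 4`, step 8 lands in slot 2, while step 5 is filtered out and maps to -1.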