Detector Base Class

fdtdx.objects.detectors.detector.Detector

Bases: SimulationObject, ABC

Base class for electromagnetic field detectors in FDTD simulations.

This class provides core functionality for recording and analyzing electromagnetic field data during FDTD simulations. It lets you control at which time steps data is recorded, and it can visualize the recorded results.

Attributes:
  • dtype (dtype) –

    Data type for detector arrays. Defaults to float32.

  • exact_interpolation (bool) –

    Whether to use exact field interpolation. Defaults to True.

  • inverse (bool) –

    Whether to record fields in inverse time order. Defaults to False.

  • switch (OnOffSwitch) –

    Controls the time steps at which the detector is on, i.e. records data. By default, the detector records at every time step.

  • plot (bool) –

    Whether to generate plots of recorded data. Defaults to True.

  • if_inverse_plot_backwards (bool) –

    If True and inverse is set, plot the recorded data in reverse time order. Only applies when more than one time step is recorded.

  • num_video_workers (int | None) –

    Number of workers for video generation. If None (default), no multiprocessing is used. Multiprocessing makes video generation much faster, but combining it with matplotlib is known to cause problems and can freeze the entire system.

  • color (tuple[float, float, float] | None) –

    RGB color for plotting. Defaults to light green.

  • plot_interpolation (str) –

    Interpolation method for plots. Defaults to "gaussian".

  • plot_dpi (int | None) –

    DPI resolution for plots. Defaults to None.
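The inverse and if_inverse_plot_backwards attributes only take effect when recorded data is prepared for plotting. The following is a minimal numpy-only sketch of that preprocessing step as it appears at the start of draw_plot: squeeze singleton axes, optionally reverse the time order, and require all arrays to share the same number of dimensions. The function name prepare_for_plot is hypothetical, and the sketch assumes the time axis is axis 0 of every recorded array.

```python
import numpy as np


def prepare_for_plot(
    state: dict[str, np.ndarray],
    num_time_steps_recorded: int,
    inverse: bool,
    plot_backwards: bool,
) -> tuple[dict[str, np.ndarray], int]:
    """Squeeze recorded arrays and optionally flip them along the time axis.

    Hypothetical sketch; assumes the time axis is axis 0 of every array.
    """
    prepared: dict[str, np.ndarray] = {}
    ndim = None
    for name, arr in state.items():
        squeezed = arr.squeeze()  # drop singleton axes, e.g. a 1x1x1 cross-section
        if inverse and plot_backwards and num_time_steps_recorded > 1:
            squeezed = squeezed[::-1, ...]  # reverse the time order
        prepared[name] = squeezed
        if ndim is None:
            ndim = squeezed.ndim
        elif squeezed.ndim != ndim:
            raise ValueError("cannot plot arrays with different numbers of dimensions")
    if ndim is None:
        raise ValueError("state is empty")
    return prepared, ndim


# Example: five recorded time steps at a single grid point.
state = {"energy": np.arange(5, dtype=np.float32).reshape(5, 1, 1, 1)}
arrays, ndim = prepare_for_plot(state, num_time_steps_recorded=5, inverse=True, plot_backwards=True)
print(ndim)                       # 1
print(arrays["energy"].tolist())  # [4.0, 3.0, 2.0, 1.0, 0.0]
```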

num_time_steps_recorded property

num_time_steps_recorded: int

Gets the total number of time steps that will be recorded.

Returns:
  • int –

    Number of time steps at which the detector will record data.

Raises:
  • Exception

    If the detector is not yet initialized.
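Assuming the recorded count equals the number of True entries in the detector's on/off array (the same array draw_plot turns into a physical time axis by multiplying step indices with the time step duration), the relationship can be sketched with numpy. The on_off_mask helper is a hypothetical stand-in for an OnOffSwitch, not the fdtdx API, and the time step duration is an illustrative value.

```python
from __future__ import annotations

import numpy as np


def on_off_mask(num_steps: int, start: int = 0, end: int | None = None, interval: int = 1) -> np.ndarray:
    """Hypothetical stand-in for an OnOffSwitch: True where the detector records.

    Records every `interval`-th step in the half-open range [start, end).
    """
    steps = np.arange(num_steps)
    stop = num_steps if end is None else end
    return (steps >= start) & (steps < stop) & ((steps - start) % interval == 0)


mask = on_off_mask(num_steps=10, start=2, interval=3)
num_time_steps_recorded = int(mask.sum())  # 3

# Physical time of each recorded step, as draw_plot computes it for line plots.
time_step_duration = 1e-15  # hypothetical value in seconds
recorded_times = np.where(mask)[0] * time_step_duration
print(np.where(mask)[0].tolist())  # [2, 5, 8]
```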

draw_plot

draw_plot(
    state: dict[str, ndarray],
    progress: Progress | None = None,
) -> dict[str, Figure | str]

Generates plots or videos from recorded detector data.

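Creates visualizations based on the dimensionality of the recorded data (after squeezing out singleton axes) and on whether more than one time step was recorded. Supports line plots for 1D data, waterfall plots for 1D data recorded over time, XY/XZ/YZ slice plots for 2D and 3D data recorded at a single time step, and slice videos for 2D and 3D data recorded over multiple time steps.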

Parameters:
  • state (dict[str, ndarray]) –

    Dictionary containing recorded field data arrays.

  • progress (Progress | None, default: None) –

    Optional progress bar for video generation.

Returns:
  • dict[str, Figure | str] –

    Dictionary mapping plot names to either matplotlib Figure objects or paths to generated video files.
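The branching in draw_plot depends only on the squeezed number of dimensions and on whether more than one time step was recorded. This small lookup table is a sketch of that dispatch, paraphrasing the branches in the source below; the function name plot_kind and the returned labels are hypothetical. Note that the slice plot and slice video cases for 2D slices additionally require the recorded arrays to be named "XY Plane", "XZ Plane" and "YZ Plane".

```python
def plot_kind(squeezed_ndim: int, num_time_steps_recorded: int) -> str:
    """Name the kind of output draw_plot produces for data of a given shape.

    squeezed_ndim includes the time axis when more than one step is recorded.
    """
    if num_time_steps_recorded < 1:
        raise ValueError("nothing was recorded")
    over_time = num_time_steps_recorded > 1
    kinds = {
        (1, True): "line plot over time",
        (1, False): "line plot over space",
        (2, True): "waterfall plot (time vs. space)",
        (2, False): "XY/XZ/YZ slice plot",
        (3, True): "XY/XZ/YZ slice video",
        (3, False): "slice plot of the plane-averaged volume",
        (4, True): "slice video of the plane-averaged volume",
    }
    key = (squeezed_ndim, over_time)
    if key not in kinds:
        raise ValueError(f"cannot plot {squeezed_ndim}-D data with {num_time_steps_recorded} time step(s)")
    return kinds[key]


print(plot_kind(2, 100))  # waterfall plot (time vs. space)
print(plot_kind(3, 1))    # slice plot of the plane-averaged volume
```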

Source code in src/fdtdx/objects/detectors/detector.py
def draw_plot(
    self,
    state: dict[str, np.ndarray],
    progress: Progress | None = None,
) -> dict[str, Figure | str]:
    """Generates plots or videos from recorded detector data.

    Creates visualizations based on dimensionality of recorded data and detector
    configuration. Supports 1D line plots, 2D heatmaps, and video generation
    for time-varying data.

    Args:
        state (dict[str, np.ndarray]): Dictionary containing recorded field data arrays.
        progress (Progress | None, optional): Optional progress bar for video generation.

    Returns:
        dict[str, Figure | str]: Dictionary mapping plot names to either
            matplotlib Figure objects or paths to generated video files.
    """
    squeezed_arrs = {}
    squeezed_ndim = None
    for k, v in state.items():
        v_squeezed = v.squeeze()
        if self.inverse and self.if_inverse_plot_backwards and self.num_time_steps_recorded > 1:
            squeezed_arrs[k] = v_squeezed[::-1, ...]
        else:
            squeezed_arrs[k] = v_squeezed
        if squeezed_ndim is None:
            squeezed_ndim = len(v_squeezed.shape)
        else:
            if len(v_squeezed.shape) != squeezed_ndim:
                raise Exception("Cannot plot multiple arrays with different ndim")
    if squeezed_ndim is None:
        raise Exception(f"empty state: {state}")

    figs = {}
    if squeezed_ndim == 1 and self.num_time_steps_recorded > 1:
        # do line plot
        time_steps = np.where(np.asarray(self._is_on_at_time_step_arr))[0]
        time_steps = time_steps * self._config.time_step_duration
        for k, v in squeezed_arrs.items():
            fig = plot_line_over_time(arr=v, time_steps=time_steps.tolist(), metric_name=f"{self.name}: {k}")
            figs[k] = fig
    elif squeezed_ndim == 1 and self.num_time_steps_recorded == 1:
        SCALE = 10  # hard-coded: 10 grid points per μm (0.1 μm grid spacing)
        xlabel = None
        if self.grid_shape[0] > 1 and self.grid_shape[1] <= 1 and self.grid_shape[2] <= 1:
            xlabel = "X axis (μm)"
        elif self.grid_shape[0] <= 1 and self.grid_shape[1] > 1 and self.grid_shape[2] <= 1:
            xlabel = "Y axis (μm)"
        elif self.grid_shape[0] <= 1 and self.grid_shape[1] <= 1 and self.grid_shape[2] > 1:
            xlabel = "Z axis (μm)"
        assert xlabel is not None, "This should never happen"
        for k, v in squeezed_arrs.items():
            spatial_axis = np.arange(len(v)) / SCALE
            fig = plot_line_over_time(
                arr=v, time_steps=spatial_axis, metric_name=f"{self.name}: {k}", xlabel=xlabel
            )
            figs[k] = fig
    elif squeezed_ndim == 2 and self.num_time_steps_recorded > 1:
        # multiple time steps, 1d spatial data - visualize as 2D waterfall plot
        time_steps = np.where(np.asarray(self._is_on_at_time_step_arr))[0]
        time_steps = time_steps * self._config.time_step_duration

        # Determine spatial axis based on which dimension has size > 1
        SCALE = 10  # hard-coded: 10 grid points per μm (0.1 μm grid spacing)

        for k, v in squeezed_arrs.items():
            # Determine which dimension is spatial (not time)
            spatial_dim = 1 if v.shape[1] > 1 else 0
            if spatial_dim == 0:
                # Transpose if needed so time is always first dimension
                v = v.T

            # Create spatial axis in μm
            spatial_points = np.arange(v.shape[1]) / SCALE

            fig = plot_waterfall_over_time(
                arr=v,
                time_steps=time_steps,
                spatial_steps=spatial_points,
                metric_name=f"{self.name}: {k}",
                spatial_unit="μm",
            )
            figs[k] = fig
    elif squeezed_ndim == 2 and self.num_time_steps_recorded == 1:
        # single time step, 2D data: plot the XY, XZ and YZ slices
        if all([x in squeezed_arrs.keys() for x in ["XY Plane", "XZ Plane", "YZ Plane"]]):
            fig = plot_2d_from_slices(
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                resolutions=(
                    self._config.resolution,
                    self._config.resolution,
                    self._config.resolution,
                ),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
            figs["sliced_plot"] = fig
        else:
            raise Exception(f"Cannot plot {squeezed_arrs.keys()}")
    elif squeezed_ndim == 3 and self.num_time_steps_recorded > 1:
        # multiple time steps, 2d-plots
        if all([x in squeezed_arrs.keys() for x in ["XY Plane", "XZ Plane", "YZ Plane"]]):
            path = generate_video_from_slices(
                plt_fn=plot_from_slices,
                xy_slice=squeezed_arrs["XY Plane"],
                xz_slice=squeezed_arrs["XZ Plane"],
                yz_slice=squeezed_arrs["YZ Plane"],
                progress=progress,
                num_worker=self.num_video_workers,
                resolutions=(
                    self._config.resolution,
                    self._config.resolution,
                    self._config.resolution,
                ),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
            figs["sliced_video"] = path
        else:
            raise Exception(
                f"Cannot plot {squeezed_arrs.keys()}. "
                f"Consider setting plot=False for Object {self.name} ({self.__class__=})"
            )
    elif squeezed_ndim == 3 and self.num_time_steps_recorded == 1:
        # single time step, 3D volume (x, y, z): average over each axis to get plane projections
        for k, v in squeezed_arrs.items():
            xy_slice = v.mean(axis=2)  # average out z
            xz_slice = v.mean(axis=1)  # average out y
            yz_slice = v.mean(axis=0)  # average out x
            fig = plot_2d_from_slices(
                xy_slice=xy_slice,
                xz_slice=xz_slice,
                yz_slice=yz_slice,
                resolutions=(
                    self._config.resolution,
                    self._config.resolution,
                    self._config.resolution,
                ),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
            figs[k] = fig
    elif squeezed_ndim == 4 and self.num_time_steps_recorded > 1:
        # video with a 3D volume per time step (t, x, y, z): average each frame onto the planes
        for k, v in squeezed_arrs.items():
            xy_slice = v.mean(axis=3)  # average out z
            xz_slice = v.mean(axis=2)  # average out y
            yz_slice = v.mean(axis=1)  # average out x
            path = generate_video_from_slices(
                plt_fn=plot_from_slices,
                xy_slice=xy_slice,
                xz_slice=xz_slice,
                yz_slice=yz_slice,
                progress=progress,
                num_worker=self.num_video_workers,
                resolutions=(
                    self._config.resolution,
                    self._config.resolution,
                    self._config.resolution,
                ),
                plot_dpi=self.plot_dpi,
                plot_interpolation=self.plot_interpolation,
            )
            figs[k] = path
    else:
        raise Exception(
            f"Cannot plot {squeezed_ndim}-dimensional data with "
            f"{self.num_time_steps_recorded} recorded time step(s)"
        )
    return figs
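The plane names used for 3D volumes follow from which spatial axis is averaged out: assuming arrays are ordered (x, y, z) for a single step or (t, x, y, z) over time, the XY plane is the mean over z, the XZ plane the mean over y, and the YZ plane the mean over x. This numpy sketch (project_to_planes is a hypothetical helper) checks that mapping via the resulting shapes, addressing spatial axes from the end so one function covers both cases.

```python
import numpy as np


def project_to_planes(volume: np.ndarray) -> dict[str, np.ndarray]:
    """Average a (..., Nx, Ny, Nz) array onto the XY, XZ and YZ planes.

    Spatial axes are addressed from the end, so the same code handles a single
    volume (Nx, Ny, Nz) and a time series of volumes (T, Nx, Ny, Nz).
    """
    if volume.ndim < 3:
        raise ValueError("expected at least three spatial dimensions")
    return {
        "XY Plane": volume.mean(axis=-1),  # average out z
        "XZ Plane": volume.mean(axis=-2),  # average out y
        "YZ Plane": volume.mean(axis=-3),  # average out x
    }


single = project_to_planes(np.zeros((4, 5, 6)))
print({k: v.shape for k, v in single.items()})
# {'XY Plane': (4, 5), 'XZ Plane': (4, 6), 'YZ Plane': (5, 6)}

series = project_to_planes(np.zeros((7, 4, 5, 6)))
print(series["XY Plane"].shape)  # (7, 4, 5)
```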

update abstractmethod

update(
    time_step: Array,
    E: Array,
    H: Array,
    state: DetectorState,
    inv_permittivity: Array,
    inv_permeability: Array | float,
) -> DetectorState

Updates detector state with current field values.

Parameters:
  • time_step (Array) –

    Current simulation time step.

  • E (Array) –

    Electric field array.

  • H (Array) –

    Magnetic field array.

  • state (DetectorState) –

    Current detector state.

  • inv_permittivity (Array) –

    Inverse permittivity array.

  • inv_permeability (Array | float) –

    Inverse permeability, given either as an array or as a single scalar value.

Returns:
  • DetectorState –

    Updated detector state.
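Concrete detectors implement update and return a new state instead of modifying the old one. The following numpy-only sketch illustrates that contract with hypothetical classes (ToyDetector, ToyEnergyDetector) and a plain dict in place of DetectorState. It is not the fdtdx implementation: the real method receives jax.Array inputs, and because JAX arrays are immutable, an actual update would use x.at[i].set(value) rather than copying. The recorded quantity, 0.5 * (eps|E|^2 + mu|H|^2) with eps = 1/inv_permittivity and mu = 1/inv_permeability, is illustrative only, and the time step is used directly as the record index for simplicity. The sketch also shows why a scalar inv_permeability is convenient, since it broadcasts like a uniform array.

```python
from abc import ABC, abstractmethod

import numpy as np


class ToyDetector(ABC):
    """Hypothetical numpy-only analogue of the Detector.update contract."""

    @abstractmethod
    def update(self, time_step, E, H, state, inv_permittivity, inv_permeability):
        """Return a new state that includes the fields at `time_step`."""
        raise NotImplementedError()


class ToyEnergyDetector(ToyDetector):
    """Records a summed energy-like quantity, 0.5 * (eps|E|^2 + mu|H|^2)."""

    def update(self, time_step, E, H, state, inv_permittivity, inv_permeability):
        # E and H have shape (3, Nx, Ny, Nz); the inverse material parameters may be
        # arrays that broadcast against them or plain scalars (e.g. inv_permeability=1.0).
        e_term = (E**2 / inv_permittivity).sum(axis=0)
        h_term = (H**2 / inv_permeability).sum(axis=0)
        energy = 0.5 * (e_term + h_term)
        # Build a new record instead of mutating the input, mimicking a functional update.
        record = state["energy"].copy()
        record[time_step] = energy.sum()
        return {**state, "energy": record}


detector = ToyEnergyDetector()
state = {"energy": np.zeros(4)}
E = np.ones((3, 2, 2, 2))
H = np.zeros((3, 2, 2, 2))
new_state = detector.update(1, E, H, state, inv_permittivity=np.full((2, 2, 2), 0.5), inv_permeability=1.0)
print(new_state["energy"].tolist())  # [0.0, 24.0, 0.0, 0.0]
```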

Source code in src/fdtdx/objects/detectors/detector.py
@abstractmethod
def update(
    self,
    time_step: jax.Array,
    E: jax.Array,
    H: jax.Array,
    state: DetectorState,
    inv_permittivity: jax.Array,
    inv_permeability: jax.Array | float,
) -> DetectorState:
    """Updates detector state with current field values.

    Args:
        time_step (jax.Array): Current simulation time step.
        E (jax.Array): Electric field array.
        H (jax.Array): Magnetic field array.
        state (DetectorState): Current detector state.
        inv_permittivity (jax.Array): Inverse permittivity array.
        inv_permeability (jax.Array | float): Inverse permeability, given either as an array or as a single scalar value.

    Returns:
        DetectorState: Updated detector state.
    """
    del (
        time_step,
        E,
        H,
        state,
        inv_permittivity,
        inv_permeability,
    )
    raise NotImplementedError()