Skip to content

Utility Functions

fdtdx.utils.Logger

Logger for managing experiment outputs and visualization.

Handles experiment logging, metrics tracking, and visualization of simulation results. Creates a working directory structure, initializes logging, and provides methods for saving figures, metrics, and device parameters.

Parameters:

Name Type Description Default
experiment_name str

Name of the experiment

required
name str | None

Optional specific name for the working directory. If None, uses timestamp.

None
Source code in src/fdtdx/utils/logger.py
class Logger:
    """Logger for managing experiment outputs and visualization.

    Handles experiment logging, metrics tracking, and visualization of simulation results.
    Creates a working directory structure, initializes logging, and provides methods for
    saving figures, metrics, and device parameters.

    Args:
        experiment_name: Name of the experiment
        name: Optional specific name for the working directory. If None, uses timestamp.
    """

    def __init__(self, experiment_name: str, name: str | None = None):
        sns.set_theme(context="paper", style="white", palette="colorblind")
        self.cwd = init_working_directory(experiment_name, wd_name=name)
        self.console = Console()
        # The progress display is entered manually (no `with` block) because it
        # lives as long as the logger; it is stopped at interpreter exit.
        self.progress = Progress(
            SpinnerColumn(),
            *Progress.get_default_columns(),
            TimeElapsedColumn(),
            console=self.console,
        ).__enter__()
        atexit.register(self.progress.stop)
        # Replace the current log sinks with a console sink and a file sink
        # located inside the working directory.
        logger.remove()
        logger.add(
            self.console.print,
            level="TRACE",
            format=_log_formatter,
            colorize=True,
        )
        logger.add(
            self.cwd / "logs.log",
            level="TRACE",
            format="{time:DD.MM.YYYY HH:mm:ss:ssss} | {level} - {message}",
        )
        logger.info(f"Starting experiment {experiment_name} in {self.cwd}")
        snapshot_python_files(self.cwd / "code")
        # CSV state: header and writer are created lazily on the first write().
        self.fieldnames = None
        self.writer = None
        self.csvfile = open(self.cwd / "metrics.csv", "w", newline="")
        # Material indices seen in the previous log_params() call, per device name.
        self.last_indices: dict[str, jax.Array | None] = defaultdict(lambda: None)
        atexit.register(self.csvfile.close)

    @property
    def stl_dir(self) -> Path:
        """Directory for storing STL files.

        The directory is created on access if it does not exist yet.

        Returns:
            Path: Directory for STL file outputs
        """
        directory = self.cwd / "device" / "stl"
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    @property
    def params_dir(self) -> Path:
        """Directory for storing parameter files.

        The directory is created on access if it does not exist yet.

        Returns:
            Path: Directory for parameter file outputs
        """
        directory = self.cwd / "device" / "params"
        directory.mkdir(parents=True, exist_ok=True)
        return directory

    def savefig(self, directory: Path, filename: str, fig: Figure, dpi: int = 300):
        """Save a matplotlib figure to file.

        Creates a figures subdirectory if needed and saves the figure with specified settings.
        The figure is closed after saving.

        Args:
            directory: Base directory to save in
            filename: Name for the figure file
            fig: Matplotlib figure to save
            dpi: Resolution in dots per inch
        """
        figure_directory = directory / "figures"
        figure_directory.mkdir(parents=True, exist_ok=True)
        fig.savefig(figure_directory / filename, dpi=dpi, bbox_inches="tight")
        plt.close(fig)

    def write(self, stats: dict, do_print: bool = True):
        """Write statistics to CSV file and optionally print them.

        Records metrics in a CSV file and optionally displays them in a formatted table.
        Automatically initializes CSV headers on first write. Only int/float values and
        single-element jax arrays are recorded; all other entries are dropped.

        Args:
            stats: Dictionary of statistics to record
            do_print: Whether to print stats to console
        """
        stats = {
            k: v.item() if isinstance(v, jax.Array) else v
            for k, v in stats.items()
            if isinstance(v, (int, float)) or (isinstance(v, jax.Array) and v.size == 1)
        }
        if self.fieldnames is None:
            # The keys of the first call define the CSV header.
            self.fieldnames = list(stats.keys())
            self.writer = csv.DictWriter(self.csvfile, fieldnames=self.fieldnames)
            self.writer.writeheader()
        assert self.writer is not None
        self.writer.writerow(stats)
        # Flush after every row so the file is up to date while the run continues.
        self.csvfile.flush()
        if do_print:
            # Each key and its value become adjacent column headers; no rows are
            # added, so the stats are shown as a single line.
            table = Table(box=None)
            for k, v in stats.items():
                table.add_column(k)
                table.add_column(str(v))
            self.console.print(table)

    def log_detectors(
        self,
        iter_idx: int,
        objects: ObjectContainer,
        detector_states: dict[str, DetectorState],
        exclude: list[str] = [],
    ):
        """Log detector states and generate visualization plots.

        Creates plots for each detector's state and saves them to the detector's output directory.
        Handles both figure outputs and other detector-specific file formats. Detectors that
        are excluded or have plotting disabled are skipped before their state is fetched.

        Args:
            iter_idx: Current iteration index
            objects: Container with simulation objects
            detector_states: Dictionary mapping detector names to their states
            exclude: List of detector names to exclude from logging (only read, never modified)

        Raises:
            Exception: If a detector returns a plot entry that is neither a Figure nor a str.
        """
        for detector in objects.detectors:
            # Decide first whether anything will be drawn, so the state of a
            # skipped detector is never fetched or converted.
            if detector.name in exclude or not detector.plot:
                continue
            cur_state = jax.device_get(detector_states[detector.name])
            cur_state = cast_floating_to_numpy(cur_state, float)

            figure_dict = detector.draw_plot(
                state=cur_state,
                progress=self.progress,
            )

            detector_dir = self.cwd / "detectors" / detector.name
            detector_dir.mkdir(parents=True, exist_ok=True)

            for k, v in figure_dict.items():
                if isinstance(v, Figure):
                    self.savefig(
                        detector_dir,
                        f"{detector.name}_{k}_{iter_idx}.png",
                        v,
                        dpi=detector.plot_dpi,  # type: ignore
                    )
                elif isinstance(v, str):
                    # A string is treated as the path of an existing file, which is
                    # copied next to the figures while keeping its suffix.
                    shutil.copy(
                        v,
                        detector_dir / f"{detector.name}_{k}_{iter_idx}{Path(v).suffix}",
                    )
                else:
                    raise Exception(f"invalid detector output for plotting: {k}, {v}")

    def log_params(
        self,
        iter_idx: int,
        params: ParameterContainer,
        objects: ObjectContainer,
        export_figure: bool = False,
        export_stl: bool = False,
        export_air_stl: bool = False,
    ) -> int:
        """Log parameter states and export device visualizations.

        Saves device parameters and optionally exports visualizations as figures or STL files.
        Tracks changes in device voxels between iterations. STL and figure export only
        happens for DiscreteDevice instances, and is skipped for a device whose indices
        did not change since the previous call.

        Args:
            iter_idx: Current iteration index
            params: Container with device parameters
            objects: Container with simulation objects
            export_figure: Whether to export index matrix figures
            export_stl: Whether to export device geometry as STL
            export_air_stl: Whether to export air regions as STL

        Returns:
            int: Number of voxels that changed since last iteration
        """
        changed_voxels = 0
        for device in objects.devices:
            device_params = params[device.name]
            indices = device.get_material_mapping(device_params)

            # raw parameters and indices
            if isinstance(device_params, dict):
                for k, v in device_params.items():
                    jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}_{k}.npy", v)
            else:
                jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}.npy", device_params)
            jnp.save(self.params_dir / f"matrix_{iter_idx}_{device.name}.npy", indices)

            if not isinstance(device, DiscreteDevice):
                continue
            has_previous = self.last_indices[device.name] is not None
            cur_changed_voxels = 0
            if has_previous:
                last_device_indices = self.last_indices[device.name]
                cur_changed_voxels = int(jnp.sum(indices != last_device_indices))
            changed_voxels += cur_changed_voxels
            self.last_indices[device.name] = indices
            # Nothing changed since the last call: the exports would be duplicates.
            if cur_changed_voxels == 0 and has_previous:
                continue
            if export_stl:
                air_name = get_air_name(device.material)
                ordered_name_list = compute_ordered_names(device.material)
                air_idx = ordered_name_list.index(air_name)
                # Convert once; the same array is compared against every material index.
                indices_np = np.asarray(indices)
                for idx in range(len(device.material)):
                    if idx == air_idx and not export_air_stl:
                        continue
                    name = ordered_name_list[idx]
                    export_stl_fn(
                        matrix=indices_np == idx,
                        stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_{name}.stl",
                        voxel_grid_size=device.single_voxel_grid_shape,
                    )
                # With more than two materials, also export everything that is not air.
                if len(device.material) > 2:
                    export_stl_fn(
                        matrix=indices_np != air_idx,
                        stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_non_air.stl",
                        voxel_grid_size=device.single_voxel_grid_shape,
                    )

            # image of indices
            if export_figure:
                fig = device_matrix_index_figure(
                    device_matrix_indices=indices,
                    material=device.material,
                )
                self.savefig(
                    self.cwd / "device",
                    f"matrix_indices_{iter_idx}_{device.name}.png",
                    fig,
                    dpi=72,
                )

        return changed_voxels

params_dir: Path property

Directory for storing parameter files.

Returns:

Name Type Description
Path Path

Directory for parameter file outputs

stl_dir: Path property

Directory for storing STL files.

Returns:

Name Type Description
Path Path

Directory for STL file outputs

log_detectors(iter_idx, objects, detector_states, exclude=[])

Log detector states and generate visualization plots.

Creates plots for each detector's state and saves them to the detector's output directory. Handles both figure outputs and other detector-specific file formats.

Parameters:

Name Type Description Default
iter_idx int

Current iteration index

required
objects ObjectContainer

Container with simulation objects

required
detector_states dict[str, DetectorState]

Dictionary mapping detector names to their states

required
exclude list[str]

List of detector names to exclude from logging

[]
Source code in src/fdtdx/utils/logger.py
def log_detectors(
    self,
    iter_idx: int,
    objects: ObjectContainer,
    detector_states: dict[str, DetectorState],
    exclude: list[str] = [],
):
    """Log detector states and generate visualization plots.

    Creates plots for each detector's state and saves them to the detector's output directory.
    Handles both figure outputs and other detector-specific file formats. Detectors that
    are excluded or have plotting disabled are skipped before their state is fetched.

    Args:
        iter_idx: Current iteration index
        objects: Container with simulation objects
        detector_states: Dictionary mapping detector names to their states
        exclude: List of detector names to exclude from logging (only read, never modified)

    Raises:
        Exception: If a detector returns a plot entry that is neither a Figure nor a str.
    """
    for detector in objects.detectors:
        # Decide first whether anything will be drawn, so the state of a
        # skipped detector is never fetched or converted.
        if detector.name in exclude or not detector.plot:
            continue
        cur_state = jax.device_get(detector_states[detector.name])
        cur_state = cast_floating_to_numpy(cur_state, float)

        figure_dict = detector.draw_plot(
            state=cur_state,
            progress=self.progress,
        )

        detector_dir = self.cwd / "detectors" / detector.name
        detector_dir.mkdir(parents=True, exist_ok=True)

        for k, v in figure_dict.items():
            if isinstance(v, Figure):
                self.savefig(
                    detector_dir,
                    f"{detector.name}_{k}_{iter_idx}.png",
                    v,
                    dpi=detector.plot_dpi,  # type: ignore
                )
            elif isinstance(v, str):
                # A string is treated as the path of an existing file, which is
                # copied next to the figures while keeping its suffix.
                shutil.copy(
                    v,
                    detector_dir / f"{detector.name}_{k}_{iter_idx}{Path(v).suffix}",
                )
            else:
                raise Exception(f"invalid detector output for plotting: {k}, {v}")

log_params(iter_idx, params, objects, export_figure=False, export_stl=False, export_air_stl=False)

Log parameter states and export device visualizations.

Saves device parameters and optionally exports visualizations as figures or STL files. Tracks changes in device voxels between iterations.

Parameters:

Name Type Description Default
iter_idx int

Current iteration index

required
params ParameterContainer

Container with device parameters

required
objects ObjectContainer

Container with simulation objects

required
export_figure bool

Whether to export index matrix figures

False
export_stl bool

Whether to export device geometry as STL

False
export_air_stl bool

Whether to export air regions as STL

False

Returns:

Name Type Description
int int

Number of voxels that changed since last iteration

Source code in src/fdtdx/utils/logger.py
def log_params(
    self,
    iter_idx: int,
    params: ParameterContainer,
    objects: ObjectContainer,
    export_figure: bool = False,
    export_stl: bool = False,
    export_air_stl: bool = False,
) -> int:
    """Log parameter states and export device visualizations.

    Saves device parameters and optionally exports visualizations as figures or STL files.
    Tracks changes in device voxels between iterations. STL and figure export only
    happens for DiscreteDevice instances, and is skipped for a device whose indices
    did not change since the previous call.

    Args:
        iter_idx: Current iteration index
        params: Container with device parameters
        objects: Container with simulation objects
        export_figure: Whether to export index matrix figures
        export_stl: Whether to export device geometry as STL
        export_air_stl: Whether to export air regions as STL

    Returns:
        int: Number of voxels that changed since last iteration
    """
    changed_voxels = 0
    for device in objects.devices:
        device_params = params[device.name]
        indices = device.get_material_mapping(device_params)

        # raw parameters and indices
        if isinstance(device_params, dict):
            for k, v in device_params.items():
                jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}_{k}.npy", v)
        else:
            jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}.npy", device_params)
        jnp.save(self.params_dir / f"matrix_{iter_idx}_{device.name}.npy", indices)

        if not isinstance(device, DiscreteDevice):
            continue
        has_previous = self.last_indices[device.name] is not None
        cur_changed_voxels = 0
        if has_previous:
            last_device_indices = self.last_indices[device.name]
            cur_changed_voxels = int(jnp.sum(indices != last_device_indices))
        changed_voxels += cur_changed_voxels
        self.last_indices[device.name] = indices
        # Nothing changed since the last call: the exports would be duplicates.
        if cur_changed_voxels == 0 and has_previous:
            continue
        if export_stl:
            air_name = get_air_name(device.material)
            ordered_name_list = compute_ordered_names(device.material)
            air_idx = ordered_name_list.index(air_name)
            # Convert once; the same array is compared against every material index.
            indices_np = np.asarray(indices)
            for idx in range(len(device.material)):
                if idx == air_idx and not export_air_stl:
                    continue
                name = ordered_name_list[idx]
                export_stl_fn(
                    matrix=indices_np == idx,
                    stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_{name}.stl",
                    voxel_grid_size=device.single_voxel_grid_shape,
                )
            # With more than two materials, also export everything that is not air.
            if len(device.material) > 2:
                export_stl_fn(
                    matrix=indices_np != air_idx,
                    stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_non_air.stl",
                    voxel_grid_size=device.single_voxel_grid_shape,
                )

        # image of indices
        if export_figure:
            fig = device_matrix_index_figure(
                device_matrix_indices=indices,
                material=device.material,
            )
            self.savefig(
                self.cwd / "device",
                f"matrix_indices_{iter_idx}_{device.name}.png",
                fig,
                dpi=72,
            )

    return changed_voxels

savefig(directory, filename, fig, dpi=300)

Save a matplotlib figure to file.

Creates a figures subdirectory if needed and saves the figure with specified settings.

Parameters:

Name Type Description Default
directory Path

Base directory to save in

required
filename str

Name for the figure file

required
fig Figure

Matplotlib figure to save

required
dpi int

Resolution in dots per inch

300
Source code in src/fdtdx/utils/logger.py
def savefig(self, directory: Path, filename: str, fig: Figure, dpi: int = 300):
    """Write a matplotlib figure into ``<directory>/figures`` and close it.

    The ``figures`` subdirectory (including missing parents) is created when
    necessary. The figure is saved with a tight bounding box.

    Args:
        directory: Base directory; the file ends up in its ``figures`` subfolder
        filename: File name of the saved figure
        fig: Matplotlib figure to write and then close
        dpi: Resolution in dots per inch
    """
    target_dir = directory / "figures"
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / filename
    fig.savefig(target_file, dpi=dpi, bbox_inches="tight")
    # Release the figure so it does not stay registered with pyplot.
    plt.close(fig)

write(stats, do_print=True)

Write statistics to CSV file and optionally print them.

Records metrics in a CSV file and optionally displays them in a formatted table. Automatically initializes CSV headers on first write.

Parameters:

Name Type Description Default
stats dict

Dictionary of statistics to record

required
do_print bool

Whether to print stats to console

True
Source code in src/fdtdx/utils/logger.py
def write(self, stats: dict, do_print: bool = True):
    """Append one row of scalar statistics to the metrics CSV.

    Only int/float values and single-element jax arrays are kept (arrays are
    converted with ``.item()``); every other entry is dropped. The CSV header is
    written on the first call, using the keys of that call. Optionally the row
    is also printed to the console.

    Args:
        stats: Dictionary of statistics to record
        do_print: Whether to print stats to console
    """
    scalars = {}
    for key, value in stats.items():
        if isinstance(value, (int, float)):
            scalars[key] = value
        elif isinstance(value, jax.Array) and value.size == 1:
            scalars[key] = value.item()

    if self.fieldnames is None:
        # First row: its keys become the header of the file.
        self.fieldnames = list(scalars)
        self.writer = csv.DictWriter(self.csvfile, fieldnames=self.fieldnames)
        self.writer.writeheader()

    row_writer = self.writer
    assert row_writer is not None
    row_writer.writerow(scalars)
    self.csvfile.flush()

    if not do_print:
        return

    # Key and value are added as adjacent column headers; no rows are added.
    table = Table(box=None)
    for key, value in scalars.items():
        table.add_column(key)
        table.add_column(str(value))
    self.console.print(table)

A logger that automatically records important metrics during optimization. For a detailed usage guide, see the example scripts in the GitHub repository.

fdtdx.utils.plot_setup.plot_setup(config, objects, exclude_object_list=[], filename=None, axs=None, plot_legend=True, exclude_xy_plane_object_list=[], exclude_yz_plane_object_list=[], exclude_xz_plane_object_list=[])

Creates a visualization of the simulation setup showing objects in XY, XZ and YZ planes.

Generates three subplots showing cross-sections of the simulation volume and the objects within it. Objects are drawn as colored rectangles with optional legends. The visualization helps verify the correct positioning and sizing of objects in the simulation setup.

Parameters:

Name Type Description Default
config SimulationConfig

Configuration object containing simulation parameters like resolution

required
objects ObjectContainer

Container holding all simulation objects to be plotted

required
exclude_object_list list[SimulationObject]

List of objects to exclude from all plots

[]
filename str | Path | None

If provided, saves the plot to this file instead of displaying

None
axs Sequence[Any] | None

Optional matplotlib axes to plot on. If None, creates new figure

None
plot_legend bool

Whether to add a legend showing object names/types

True
exclude_xy_plane_object_list list[SimulationObject]

Objects to exclude from XY plane plot

[]
exclude_yz_plane_object_list list[SimulationObject]

Objects to exclude from YZ plane plot

[]
exclude_xz_plane_object_list list[SimulationObject]

Objects to exclude from XZ plane plot

[]

Returns:

Type Description
Figure

matplotlib.figure.Figure: The generated figure object

Note

The plots show object positions in micrometers, converting from simulation units. PML objects are automatically excluded from their respective boundary planes.

Source code in src/fdtdx/utils/plot_setup.py
def plot_setup(
    config: SimulationConfig,
    objects: ObjectContainer,
    exclude_object_list: list[SimulationObject] = [],
    filename: str | Path | None = None,
    axs: Sequence[Any] | None = None,
    plot_legend: bool = True,
    exclude_xy_plane_object_list: list[SimulationObject] = [],
    exclude_yz_plane_object_list: list[SimulationObject] = [],
    exclude_xz_plane_object_list: list[SimulationObject] = [],
) -> Figure:
    """Creates a visualization of the simulation setup showing objects in XY, XZ and YZ planes.

    Generates three subplots showing cross-sections of the simulation volume and the objects
    within it. Objects are drawn as colored rectangles with optional legends. The visualization
    helps verify the correct positioning and sizing of objects in the simulation setup.

    The exclude lists passed by the caller are never modified; the function works on
    private copies, so repeated calls do not influence each other.

    Args:
        config: Configuration object containing simulation parameters like resolution
        objects: Container holding all simulation objects to be plotted
        exclude_object_list: List of objects to exclude from all plots
        filename: If provided, saves the plot to this file instead of displaying
        axs: Optional matplotlib axes to plot on. If None, creates new figure
        plot_legend: Whether to add a legend showing object names/types
        exclude_xy_plane_object_list: Objects to exclude from XY plane plot
        exclude_yz_plane_object_list: Objects to exclude from YZ plane plot
        exclude_xz_plane_object_list: Objects to exclude from XZ plane plot

    Returns:
        matplotlib.figure.Figure: The generated figure object

    Note:
        The plots show object positions in micrometers, converting from simulation units.
        PML objects are automatically excluded from their respective boundary planes.
    """
    # Copy every exclude list before extending it. Appending to the arguments
    # directly would grow the shared default lists from call to call and would
    # also modify lists owned by the caller. None is accepted as "no exclusions".
    excluded_everywhere = list(exclude_object_list or [])
    excluded_xy = list(exclude_xy_plane_object_list or [])
    excluded_yz = list(exclude_yz_plane_object_list or [])
    excluded_xz = list(exclude_xz_plane_object_list or [])

    # add boundaries to exclude lists
    for o in objects.objects:
        if not isinstance(o, (PerfectlyMatchedLayer, PeriodicBoundary)):
            continue
        if o.axis == 0:
            excluded_yz.append(o)
        elif o.axis == 1:
            excluded_xz.append(o)
        elif o.axis == 2:
            excluded_xy.append(o)
    # add volume to exclude list
    volume = objects.volume
    excluded_everywhere.append(volume)

    object_list = [o for o in objects.objects if o not in excluded_everywhere]
    if axs is None:
        fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    else:
        fig = None
    assert axs is not None
    resolution = config.resolution / 1.0e-6  # Convert to µm

    # only objects with a color are drawn
    colored_objects: list[SimulationObject] = [o for o in object_list if o.color is not None]

    if plot_legend:
        handles = []
        used_classes = set()
        for o in colored_objects:
            has_default_name = o.name.startswith("Object")
            # An object gets its own legend entry if it was named explicitly or if
            # another object of the same class uses a different color. Otherwise a
            # single entry per class is enough.
            print_single = (not has_default_name) or any(
                o.__class__ == o2.__class__ and o.color != o2.color for o2 in colored_objects
            )
            label = o.__class__.__name__ if has_default_name else o.name
            patch = Patch(color=o.color, label=label)
            if print_single:
                handles.append(patch)
            elif o.__class__.__name__ not in used_classes:
                used_classes.add(o.__class__.__name__)
                handles.append(patch)

        plt.legend(
            handles=handles,
            loc="upper right",
            bbox_to_anchor=(1.75, 0.75),
            frameon=False,
        )

    # (axes, horizontal axis index, vertical axis index, objects hidden in this plane)
    planes = (
        (axs[0], 0, 1, excluded_xy),
        (axs[1], 0, 2, excluded_xz),
        (axs[2], 1, 2, excluded_yz),
    )
    axis_names = ("x", "y", "z")

    # Plot each object on the corresponding subplot
    for obj in colored_objects:
        slices = obj.grid_slice_tuple
        linestyle = "--" if isinstance(obj, PeriodicBoundary) else "-"
        for ax, h, v, excluded in planes:
            if obj in excluded:
                continue
            ax.add_patch(
                Rectangle(
                    (slices[h][0] * resolution, slices[v][0] * resolution),
                    (slices[h][1] - slices[h][0]) * resolution,
                    (slices[v][1] - slices[v][0]) * resolution,
                    color=obj.color,
                    alpha=0.5,
                    linestyle=linestyle,
                )
            )

    # Set labels and titles
    for ax, h, v, _ in planes:
        ax.set_xlabel(f"{axis_names[h]} (µm)")
        ax.set_ylabel(f"{axis_names[v]} (µm)")
        ax.set_title(f"{axis_names[h].upper()}{axis_names[v].upper()} plane")
        ax.set_xlim([0, volume.grid_shape[h] * resolution])
        ax.set_ylim([0, volume.grid_shape[v] * resolution])

    # Adjust the plots for better visualization
    for ax in axs:
        ax.set_aspect("equal")
        ax.grid(True)

    if filename is not None:
        plt.savefig(filename, bbox_inches="tight", dpi=300)
        plt.close()
    return plt.gcf() if fig is None else fig

Plots an image of the simulation scene using matplotlib. This is very helpful for verifying the correct positions of all objects in the simulation scene.