Logger for managing experiment outputs and visualization.
Handles experiment logging, metrics tracking, and visualization of simulation results.
Creates a working directory structure, initializes logging, and provides methods for
saving figures, metrics, and device parameters.
Parameters:
- experiment_name (str): Name of the experiment. This is the name of the parent directory in which the experiment will be saved.
- name (str | None, default: None): Optional specific name for the working directory. If None, a timestamp is used.
Source code in src/fdtdx/utils/logger.py (lines 102–131)
def __init__(self, experiment_name: str, name: str | None = None):
    """Set up the experiment working directory, console output, logging sinks and metrics CSV.

    Args:
        experiment_name (str): Name of the experiment. This is the name of the parent
            directory in which the experiment will be saved.
        name (str | None, optional): Specific name for the working directory. If None,
            the name is chosen by ``init_working_directory`` (a timestamp).
    """
    sns.set_theme(context="paper", style="white", palette="colorblind")
    self.cwd = init_working_directory(experiment_name, wd_name=name)
    self.console = Console()
    # The progress display is entered manually (not through ``with``) so it stays active
    # for the lifetime of the logger; it is stopped at interpreter exit.
    self.progress = Progress(
        SpinnerColumn(),
        *Progress.get_default_columns(),
        TimeElapsedColumn(),
        console=self.console,
    ).__enter__()
    atexit.register(self.progress.stop)
    # Replace loguru's default sinks: one sink printing through the rich console and one
    # plain-text log file inside the working directory.
    logger.remove()
    logger.add(
        self.console.print,
        level="TRACE",
        format=_log_formatter,
        colorize=True,
    )
    logger.add(
        self.cwd / "logs.log",
        level="TRACE",
        # In loguru's time format "ss" means seconds and "SSS" milliseconds; the previous
        # "ss:ssss" printed the seconds three times instead of a sub-second fraction.
        format="{time:DD.MM.YYYY HH:mm:ss.SSS} | {level} - {message}",
    )
    logger.info(f"Starting experiment {experiment_name} in {self.cwd}")
    snapshot_python_files(self.cwd / "code")
    # CSV header and writer are created lazily by ``write`` from the keys of the first row.
    self.fieldnames = None
    self.writer = None
    # Explicit encoding so metric names are written the same way on every platform.
    self.csvfile = open(self.cwd / "metrics.csv", "w", newline="", encoding="utf-8")
    # Device indices from the previous ``log_params`` call; None until a device is first logged.
    self.last_indices: dict[str, jax.Array | None] = defaultdict(lambda: None)
    atexit.register(self.csvfile.close)
|
params_dir (property)
Directory for storing parameter files.
Returns:
- Path: Directory for parameter file outputs.
stl_dir (property)
Directory for storing STL files.
Returns:
- Path: Directory for STL file outputs.
log_detectors
log_detectors(
iter_idx: int,
objects: ObjectContainer,
detector_states: dict[str, DetectorState],
exclude: Sequence[str] = [],
)
Log detector states and generate visualization plots.
Creates plots for each detector's state and saves them to the detector's output directory.
Handles both figure outputs and other detector-specific file formats.
Parameters:
- iter_idx (int): Current iteration index.
- objects (ObjectContainer): Container with simulation objects.
- detector_states (dict[str, DetectorState]): Dictionary mapping detector names to their states.
- exclude (Sequence[str], default: []): List of detector names to exclude from logging.
Source code in src/fdtdx/utils/logger.py (lines 200–246)
def log_detectors(
    self,
    iter_idx: int,
    objects: ObjectContainer,
    detector_states: dict[str, DetectorState],
    exclude: Sequence[str] = (),
):
    """Log detector states and generate visualization plots.

    Creates plots for each detector's state and saves them to the detector's output directory.
    Handles both figure outputs and other detector-specific file formats.

    Args:
        iter_idx (int): Current iteration index, used in the output file names.
        objects (ObjectContainer): Container with simulation objects.
        detector_states (dict[str, DetectorState]): Dictionary mapping detector names to their
            states. Only detectors that are actually plotted need an entry.
        exclude (Sequence[str], optional): Names of detectors to exclude from logging.
            Defaults to no exclusions.

    Raises:
        TypeError: If ``detector.draw_plot`` returns a value that is neither a matplotlib
            ``Figure`` nor a string file path.
    """
    for detector in objects.detectors:
        # Filter first: excluded or non-plotting detectors never need their state
        # transferred from the device (or even present in ``detector_states``).
        if detector.name in exclude or not detector.plot:
            continue
        cur_state = jax.device_get(detector_states[detector.name])
        cur_state = cast_floating_to_numpy(cur_state, float)
        figure_dict = detector.draw_plot(
            state=cur_state,
            progress=self.progress,
        )
        detector_dir = self.cwd / "detectors" / detector.name
        detector_dir.mkdir(parents=True, exist_ok=True)
        for k, v in figure_dict.items():
            if isinstance(v, Figure):
                self.savefig(
                    detector_dir,
                    f"{detector.name}_{k}_{iter_idx}.png",
                    v,
                    dpi=detector.plot_dpi,  # type: ignore
                )
            elif isinstance(v, str):
                # String outputs are treated as paths to existing files; copy them while
                # keeping their original extension.
                shutil.copy(
                    v,
                    detector_dir / f"{detector.name}_{k}_{iter_idx}{Path(v).suffix}",
                )
            else:
                # TypeError subclasses Exception, so existing ``except Exception`` handlers still apply.
                raise TypeError(f"invalid detector output for plotting: {k}, {v}")
|
log_params
log_params(
iter_idx: int,
params: ParameterContainer,
objects: ObjectContainer,
export_figure: bool = False,
export_stl: bool = False,
export_background_stl: bool = False,
**transformation_kwargs
) -> int
Log parameter states and export device visualizations.
Saves device parameters and optionally exports visualizations as figures or STL files.
Tracks changes in device voxels between iterations.
Parameters:
- iter_idx (int): Current iteration index.
- params (ParameterContainer): Container with device parameters.
- objects (ObjectContainer): Container with simulation objects.
- export_figure (bool, default: False): Whether to export index matrix figures.
- export_stl (bool, default: False): Whether to export device geometry as STL.
- export_background_stl (bool, default: False): Whether to export air regions as STL.
- **transformation_kwargs: Keyword arguments passed to the parameter transformation.

Returns:
- int: Number of voxels that changed since the last iteration.
Source code in src/fdtdx/utils/logger.py (lines 248–330)
def log_params(
    self,
    iter_idx: int,
    params: ParameterContainer,
    objects: ObjectContainer,
    export_figure: bool = False,
    export_stl: bool = False,
    export_background_stl: bool = False,
    **transformation_kwargs,
) -> int:
    """Log parameter states and export device visualizations.

    Saves device parameters and optionally exports visualizations as figures or STL files.
    Tracks changes in device voxels between iterations. Exports are skipped for a device
    whose indices did not change since the previous call.

    Args:
        iter_idx (int): Current iteration index, used in the output file names.
        params (ParameterContainer): Container with device parameters, indexed by device name.
        objects (ObjectContainer): Container with simulation objects.
        export_figure (bool, optional): Whether to export index matrix figures.
        export_stl (bool, optional): Whether to export device geometry as STL.
        export_background_stl (bool, optional): Whether to export air regions as STL.
        **transformation_kwargs: Keyword arguments passed to the parameter transformation.

    Returns:
        int: Number of voxels that changed since the last iteration, summed over all devices
        (devices logged for the first time contribute 0).
    """
    changed_voxels = 0
    for device in objects.devices:
        device_params = params[device.name]
        indices = device(device_params, **transformation_kwargs)

        # Raw parameters (one file per entry for dict-valued parameters) and resulting indices.
        if isinstance(device_params, dict):
            for k, v in device_params.items():
                jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}_{k}.npy", v)
        else:
            jnp.save(self.params_dir / f"params_{iter_idx}_{device.name}.npy", device_params)
        jnp.save(self.params_dir / f"matrix_{iter_idx}_{device.name}.npy", indices)

        # Compare with the indices stored for this device by the previous call (None at first).
        last_device_indices = self.last_indices[device.name]
        has_previous = last_device_indices is not None
        cur_changed_voxels = int(jnp.sum(indices != last_device_indices)) if has_previous else 0
        changed_voxels += cur_changed_voxels
        self.last_indices[device.name] = indices
        if has_previous and cur_changed_voxels == 0:
            # Device unchanged since the last export: nothing new to write.
            continue

        if export_stl:
            # Round once and reuse the result for every per-material mask below.
            rounded_indices = np.round(indices)
            background_name = get_background_material_name(device.materials)
            ordered_name_list = compute_ordered_names(device.materials)
            background_idx = ordered_name_list.index(background_name)
            for idx, name in enumerate(ordered_name_list):
                if idx == background_idx and not export_background_stl:
                    continue
                export_stl_fn(
                    matrix=rounded_indices == idx,
                    stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_{name}.stl",
                    voxel_grid_size=device.single_voxel_grid_shape,
                )
            # With only two materials the combined non-background mask equals the single
            # non-background material exported above, so it is only written for more.
            if len(device.materials) > 2:
                export_stl_fn(
                    matrix=rounded_indices != background_idx,
                    stl_filename=self.stl_dir / f"matrix_{iter_idx}_{device.name}_non_air.stl",
                    voxel_grid_size=device.single_voxel_grid_shape,
                )

        # image of indices
        if export_figure:
            fig = device_matrix_index_figure(
                device_matrix_indices=indices,
                material=device.materials,
                parameter_type=device.output_type,
            )
            self.savefig(
                self.cwd / "device",
                f"matrix_indices_{iter_idx}_{device.name}.png",
                fig,
            )
    return changed_voxels
|
savefig
savefig(
directory: Path,
filename: str,
fig: Figure,
dpi: int = 300,
)
Save a matplotlib figure to file.
Creates a figures subdirectory if needed and saves the figure with specified settings.
Parameters:
- directory (Path): Base directory to save in.
- filename (str): Name for the figure file.
- fig (Figure): Matplotlib figure to save.
- dpi (int, default: 300): Resolution in dots per inch. Defaults to 300.
Source code in src/fdtdx/utils/logger.py (lines 155–169)
def savefig(self, directory: Path, filename: str, fig: Figure, dpi: int = 300):
    """Save a matplotlib figure to ``directory / "figures" / filename`` and close it.

    Creates the ``figures`` subdirectory if needed and saves the figure with a tight
    bounding box.

    Args:
        directory (Path): Base directory to save in.
        filename (str): Name for the figure file.
        fig (Figure): Matplotlib figure to save.
        dpi (int, optional): Resolution in dots per inch. Defaults to 300.
    """
    figure_directory = directory / "figures"
    figure_directory.mkdir(parents=True, exist_ok=True)
    try:
        fig.savefig(figure_directory / filename, dpi=dpi, bbox_inches="tight")
    finally:
        # Close the figure even if saving fails; pyplot keeps figures registered until
        # they are closed, so a failed save would otherwise leak the figure.
        plt.close(fig)
|
write
write(stats: dict, do_print: bool = True)
Write statistics to CSV file and optionally print them.
Records metrics in a CSV file and optionally displays them in a formatted table.
Automatically initializes CSV headers on first write.
Parameters:
- stats (dict): Dictionary of statistics to record.
- do_print (bool, default: True): Whether to print the stats to the console. Defaults to True.
Source code in src/fdtdx/utils/logger.py (lines 171–198)
def write(self, stats: dict, do_print: bool = True):
    """Append one row of scalar metrics to the CSV file and optionally echo it.

    Only scalar entries are kept: plain ``int``/``float`` values, and ``jax.Array`` values
    with exactly one element (converted to Python scalars via ``.item()``). All other
    entries are dropped. The keys of the first recorded row become the CSV header; with
    ``csv.DictWriter``'s defaults, later rows missing a key get an empty cell, and rows
    with keys outside the header raise ``ValueError``.

    Args:
        stats (dict): Dictionary of statistics to record.
        do_print (bool, optional): Whether to print the stats to the console. Defaults to True.
    """
    row = {}
    for key, value in stats.items():
        if isinstance(value, jax.Array):
            if value.size == 1:
                row[key] = value.item()
        elif isinstance(value, (int, float)):
            row[key] = value

    # Lazily create the writer so the header matches the first row's keys.
    if self.fieldnames is None:
        self.fieldnames = list(row)
        self.writer = csv.DictWriter(self.csvfile, fieldnames=self.fieldnames)
        self.writer.writeheader()
    assert self.writer is not None
    self.writer.writerow(row)
    # Flush after every row so the file on disk always holds all rows written so far.
    self.csvfile.flush()

    if not do_print:
        return
    # Every metric contributes two header cells to one table: its name, then its value.
    table = Table(box=None)
    for key, value in row.items():
        table.add_column(key)
        table.add_column(str(value))
    self.console.print(table)
|