metric_efficiency(
detector_states: dict[str, dict[str, Array]],
in_names: Sequence[str],
out_names: Sequence[str],
metric_name: str,
) -> tuple[jax.Array, dict[str, Any]]
Calculate efficiency metrics between input and output detectors.
Computes efficiency ratios between input and output detectors by comparing their
metric values (e.g. energy, power). For each input-output detector pair, calculates
the ratio of output/input metric values.
Parameters: |
-
detector_states
(dict[str, dict[str, Array]] )
–
Dictionary mapping detector names to their state dictionaries,
which contain metric values as JAX arrays
-
in_names
(Sequence[str] )
–
Names of input detectors to use as reference
-
out_names
(Sequence[str] )
–
Names of output detectors to compare against inputs
-
metric_name
(str )
–
Name of the metric to compare between detectors (e.g. "energy")
|
Returns: |
-
tuple[Array, dict[str, Any]]
–
tuple[jax.Array, dict[str, Any]]: tuple containing:
- jax.Array: Mean efficiency across all input-output pairs
- dict: Additional info including individual metric values and efficiencies
with keys like:
"{detector}{metric}" for raw metric values
"{out}{by}_{in}_efficiency" for individual efficiency ratios
|
Source code in src/fdtdx/core/physics/losses.py
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45 | def metric_efficiency(
detector_states: dict[str, dict[str, jax.Array]],
in_names: Sequence[str],
out_names: Sequence[str],
metric_name: str,
) -> tuple[jax.Array, dict[str, Any]]:
"""Calculate efficiency metrics between input and output detectors.
Computes efficiency ratios between input and output detectors by comparing their
metric values (e.g. energy, power). For each input-output detector pair, calculates
the ratio of output/input metric values.
Args:
detector_states (dict[str, dict[str, jax.Array]]): Dictionary mapping detector names to their state dictionaries,
which contain metric values as JAX arrays
in_names (Sequence[str]): Names of input detectors to use as reference
out_names (Sequence[str]): Names of output detectors to compare against inputs
metric_name (str): Name of the metric to compare between detectors (e.g. "energy")
Returns:
tuple[jax.Array, dict[str, Any]]: tuple containing:
- jax.Array: Mean efficiency across all input-output pairs
- dict: Additional info including individual metric values and efficiencies
with keys like:
"{detector}_{metric}" for raw metric values
"{out}_{by}_{in}_efficiency" for individual efficiency ratios
"""
efficiencies, info = [], {}
for in_name in in_names:
in_value = jax.lax.stop_gradient(detector_states[in_name][metric_name].mean())
info[f"{in_name}_{metric_name}"] = in_value
for out_name in out_names:
out_value = detector_states[out_name][metric_name].mean()
eff = jnp.where(in_value == 0, 0, out_value / in_value)
efficiencies.append(eff)
info[f"{out_name}_{metric_name}"] = out_value
info[f"{out_name}_by_{in_name}_efficiency"] = eff
objective = jnp.mean(jnp.asarray(efficiencies))
return objective, info
|