apply_params(
arrays: ArrayContainer,
objects: ObjectContainer,
params: ParameterContainer,
key: Array,
**transform_kwargs
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]
Applies device parameters to the inverse-permittivity array and updates source states by applying a fresh random subkey to every simulation object.
Parameters:
- arrays (ArrayContainer): Container with field arrays.
- objects (ObjectContainer): Container with simulation objects.
- params (ParameterContainer): Container with device parameters.
- key (Array): JAX random key for source updates.
- **transform_kwargs: Keyword arguments passed to the parameter transformation.
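In `apply_params`, `key` is not used directly. The function splits it once per entry in `objects.object_list` and passes each object its own subkey. A minimal sketch of that splitting pattern in plain JAX, assuming a hypothetical list of object names in place of real simulation objects (the random phase draw is likewise only illustrative):

```python
import jax
import jax.numpy as jnp


def split_key_per_object(key: jax.Array, names: list[str]) -> dict[str, jax.Array]:
    """Give each (hypothetical) object its own subkey, mirroring the loop in apply_params."""
    subkeys = {}
    for name in names:
        # Split off a fresh subkey and carry the remaining key forward.
        key, subkey = jax.random.split(key)
        subkeys[name] = subkey
    return subkeys


subkeys = split_key_per_object(jax.random.PRNGKey(0), ["source_a", "source_b", "detector"])

# Distinct subkeys give independent draws, e.g. an illustrative random phase per object.
phases = {
    name: jax.random.uniform(k, minval=0.0, maxval=2 * jnp.pi)
    for name, k in subkeys.items()
}
```

Carrying the remaining key forward, instead of reusing one key, keeps each object's randomness independent of the others.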
Returns:
tuple[ArrayContainer, ObjectContainer, dict[str, Any]]: A tuple containing:
- Updated ArrayContainer with applied device parameters
- Updated ObjectContainer with new source states
- Dictionary with parameter application info (always empty in this version of the function)
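The returned containers are new values, not in-place modifications, because JAX arrays are immutable. For each device, the function writes its values into the region `device.grid_slice` with the `.at[...].set(...)` indexed-update idiom. A minimal sketch of that idiom on a plain array. The grid shape, the slice, and the material value are all hypothetical:

```python
import jax.numpy as jnp

# Hypothetical 3-D grid filled with the inverse permittivity of vacuum (1 / 1.0).
inv_permittivities = jnp.ones((8, 8, 8))

# Hypothetical device region, playing the role of device.grid_slice.
grid_slice = (slice(2, 4), slice(2, 4), slice(0, 8))

# Inverse permittivity of a hypothetical material with relative permittivity 4.0.
device_values = jnp.full((2, 2, 8), 1 / 4.0)

# .at[...].set(...) returns a new array; inv_permittivities itself is unchanged.
updated = inv_permittivities.at[grid_slice].set(device_values)
```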
Source code in src/fdtdx/fdtd/initialization.py
from typing import Any

import jax
import jax.numpy as jnp

# ArrayContainer, ObjectContainer, ParameterContainer, ParameterType,
# compute_allowed_permittivities and straight_through_estimator are fdtdx
# internals imported at the top of initialization.py (not shown in this excerpt).


def apply_params(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    params: ParameterContainer,
    key: jax.Array,
    **transform_kwargs,
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]:
    """Applies parameters to devices and updates source states.

    Args:
        arrays (ArrayContainer): Container with field arrays
        objects (ObjectContainer): Container with simulation objects
        params (ParameterContainer): Container with device parameters
        key (jax.Array): JAX random key for source updates
        **transform_kwargs: Keyword arguments passed to the parameter transformation.

    Returns:
        tuple[ArrayContainer, ObjectContainer, dict[str, Any]]: A tuple containing:
            - Updated ArrayContainer with applied device parameters
            - Updated ObjectContainer with new source states
            - Dictionary with parameter application info
    """
    info = {}
    # apply parameter to devices
    for device in objects.devices:
        # Transform the raw parameters into the device output on the simulation grid
        # (a mixing fraction for continuous devices, material indices otherwise).
        cur_material_indices = device(params[device.name], expand_to_sim_grid=True, **transform_kwargs)
        allowed_perm_list = compute_allowed_permittivities(device.materials)
        if device.output_type == ParameterType.CONTINUOUS:
            # Linearly interpolate between the inverse permittivities of the two materials.
            first_term = (1 - cur_material_indices) * (1 / allowed_perm_list[0])
            second_term = cur_material_indices * (1 / allowed_perm_list[1])
            new_perm_slice = first_term + second_term
        else:
            # Look up the permittivity of each material index,
            new_perm_slice = jnp.asarray(allowed_perm_list)[cur_material_indices.astype(jnp.int32)]
            # keep a gradient path to the indices via a straight-through estimator,
            new_perm_slice = straight_through_estimator(cur_material_indices, new_perm_slice)
            # and convert to inverse permittivity.
            new_perm_slice = 1 / new_perm_slice
        # Write the device region into the global inverse-permittivity array.
        new_perm = arrays.inv_permittivities.at[*device.grid_slice].set(new_perm_slice)
        arrays = arrays.at["inv_permittivities"].set(new_perm)
    # apply random key to sources
    new_objects = []
    for obj in objects.object_list:
        key, subkey = jax.random.split(key)
        new_obj = obj.apply(
            key=subkey,
            # Objects see the material arrays, but no gradient flows back through them.
            inv_permittivities=jax.lax.stop_gradient(arrays.inv_permittivities),
            inv_permeabilities=jax.lax.stop_gradient(arrays.inv_permeabilities),
        )
        new_objects.append(new_obj)
    new_objects = ObjectContainer(
        object_list=new_objects,
        volume_idx=objects.volume_idx,
    )
    return arrays, new_objects, info
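The device loop maps the device output to inverse permittivity in one of two ways. For `ParameterType.CONTINUOUS`, an output x linearly interpolates between the inverse permittivities of the two allowed materials, (1 - x) / eps_0 + x / eps_1. Otherwise the output is cast to an integer material index, the matching permittivity is looked up, and the result is inverted. A minimal sketch of both branches on plain arrays, assuming two hypothetical materials with relative permittivities 1.0 and 4.0 (the straight-through estimator is omitted here):

```python
import jax.numpy as jnp

# Hypothetical allowed relative permittivities, e.g. air and a dielectric.
allowed_perm_list = [1.0, 4.0]


def continuous_inv_perm(x: jnp.ndarray) -> jnp.ndarray:
    """Interpolate inverse permittivity: x = 0 gives material 0, x = 1 gives material 1."""
    return (1 - x) * (1 / allowed_perm_list[0]) + x * (1 / allowed_perm_list[1])


def discrete_inv_perm(indices: jnp.ndarray) -> jnp.ndarray:
    """Look up the permittivity for each material index, then invert it."""
    perm = jnp.asarray(allowed_perm_list)[indices.astype(jnp.int32)]
    return 1 / perm


inv_cont = continuous_inv_perm(jnp.array([0.0, 0.5, 1.0]))
inv_disc = discrete_inv_perm(jnp.array([0, 1, 1]))
```

Because the interpolation acts on inverse permittivity, the midpoint x = 0.5 gives 0.625, an effective relative permittivity of 1 / 0.625 = 1.6 rather than the arithmetic mean 2.5.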
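The discrete branch passes the looked-up permittivities through `straight_through_estimator`, whose definition is not part of this excerpt. A common straight-through formulation, assumed here and named `ste_sketch` to make clear it is not the fdtdx function, returns the hard values in the forward pass while differentiating as if the operation were the identity in the continuous input:

```python
import jax
import jax.numpy as jnp


def ste_sketch(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Assumed STE form: the forward value is y, the gradient w.r.t. x is the identity."""
    return x + jax.lax.stop_gradient(y - x)


# Hypothetical allowed relative permittivities.
allowed_perm = jnp.array([1.0, 4.0])


def discrete_perm_sum(x: jnp.ndarray) -> jnp.ndarray:
    # Hard lookup by material index, mirroring the astype(jnp.int32) cast in apply_params.
    hard = allowed_perm[x.astype(jnp.int32)]
    # Forward pass uses `hard`; the backward pass treats the lookup as identity in x.
    return jnp.sum(ste_sketch(x, hard))


x = jnp.array([0.0, 1.0])
value = discrete_perm_sum(x)            # 5.0 (1.0 + 4.0 from the lookup)
grads = jax.grad(discrete_perm_sum)(x)  # [1., 1.]
```

Without the estimator, the integer cast would leave the lookup with no useful gradient with respect to x.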