fdtdx.apply_params

apply_params(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    params: ParameterContainer,
    key: Array,
    **transform_kwargs
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]

Applies parameters to devices and updates source states.

Parameters:
  • arrays (ArrayContainer) –

    Container with field arrays

  • objects (ObjectContainer) –

    Container with simulation objects

  • params (ParameterContainer) –

    Container with device parameters

  • key (Array) –

    JAX random key for source updates

  • **transform_kwargs

    Keyword arguments passed to the parameter transformation.

Returns:
  • tuple[ArrayContainer, ObjectContainer, dict[str, Any]] –

    A tuple containing:
      - the updated ArrayContainer with the device parameters applied
      - the updated ObjectContainer with new source states
      - a dictionary with parameter-application info

Source code in src/fdtdx/fdtd/initialization.py (lines 123–175):
def apply_params(
    arrays: ArrayContainer,
    objects: ObjectContainer,
    params: ParameterContainer,
    key: jax.Array,
    **transform_kwargs,
) -> tuple[ArrayContainer, ObjectContainer, dict[str, Any]]:
    """Applies parameters to devices and updates source states.

    Args:
        arrays (ArrayContainer): Container with field arrays
        objects (ObjectContainer): Container with simulation objects
        params (ParameterContainer): Container with device parameters
        key (jax.Array): JAX random key for source updates
        **transform_kwargs: Keyword arguments passed to the parameter transformation.
    Returns:
        tuple[ArrayContainer, ObjectContainer, dict[str, Any]]: A tuple containing:
            - Updated ArrayContainer with applied device parameters
            - Updated ObjectContainer with new source states
            - Dictionary with parameter application info
    """
    info = {}
    # apply parameter to devices
    for device in objects.devices:
        cur_material_indices = device(params[device.name], expand_to_sim_grid=True, **transform_kwargs)
        allowed_perm_list = compute_allowed_permittivities(device.materials)
        if device.output_type == ParameterType.CONTINUOUS:
            first_term = (1 - cur_material_indices) * (1 / allowed_perm_list[0])
            second_term = cur_material_indices * (1 / allowed_perm_list[1])
            new_perm_slice = first_term + second_term
        else:
            new_perm_slice = jnp.asarray(allowed_perm_list)[cur_material_indices.astype(jnp.int32)]
            new_perm_slice = straight_through_estimator(cur_material_indices, new_perm_slice)
            new_perm_slice = 1 / new_perm_slice
        new_perm = arrays.inv_permittivities.at[*device.grid_slice].set(new_perm_slice)
        arrays = arrays.at["inv_permittivities"].set(new_perm)

    # apply random key to sources
    new_objects = []
    for obj in objects.object_list:
        key, subkey = jax.random.split(key)
        new_obj = obj.apply(
            key=subkey,
            inv_permittivities=jax.lax.stop_gradient(arrays.inv_permittivities),
            inv_permeabilities=jax.lax.stop_gradient(arrays.inv_permeabilities),
        )
        new_objects.append(new_obj)
    new_objects = ObjectContainer(
        object_list=new_objects,
        volume_idx=objects.volume_idx,
    )

    return arrays, new_objects, info