fdtdx.TreeClass

Bases: TreeClass

Extended tree class with improved attribute setting functionality.

Extends TreeClass to provide more flexible attribute setting capabilities, particularly for handling non-recursive attribute updates.

at property

at: ExtendedTreeClassIndexer

Gets the extended indexer for this tree.

Returns:
  • ExtendedTreeClassIndexer( ExtendedTreeClassIndexer ) –

    Indexer that preserves type information

aset

aset(
    attr_name: str, val: Any, create_new_ok: bool = False
) -> Self

Sets an attribute of this class. Unlike the usual .at[].set(), which operates only on JAX pytree leaf nodes, this method replaces the full attribute with the new value and returns the updated instance.

The attribute can be a direct attribute of this class or, for nested classes, an attribute of a class that is itself an attribute of this class. Nested targets are written as a path whose segments are separated by "->", for example: "a->b->[0]->['name']". Here, the current class has an attribute a, which has an attribute b. b is a list, whose element at index 0 is a dictionary, which is in turn indexed with the key 'name'.

Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
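
To make the path syntax concrete, the following is a minimal sketch of how such a string could be split into typed segments. It is not the library's parser: the function parse_path and its regular expressions are hypothetical names introduced here for illustration only. It assumes that segments are separated by "->" and that keys are single-quoted, as in the example above.

```python
import re

# Hypothetical parser for illustration; this is NOT fdtdx's internal
# _parse_operations. It splits a path such as "a->b->[0]->['name']"
# into (operand, kind) pairs.
_INDEX = re.compile(r"\[(-?\d+)\]")       # e.g. [0]
_KEY = re.compile(r"\['([^\[\]']*)'\]")   # e.g. ['name'], no brackets or quotes inside
_ATTR = re.compile(r"[A-Za-z_]\w*")       # e.g. a


def parse_path(path: str) -> list[tuple[str, str]]:
    ops = []
    for part in path.split("->"):
        part = part.strip()
        if m := _INDEX.fullmatch(part):
            ops.append((m.group(1), "index"))
        elif m := _KEY.fullmatch(part):
            ops.append((m.group(1), "key"))
        elif _ATTR.fullmatch(part):
            ops.append((part, "attribute"))
        else:
            raise ValueError(f"Cannot parse path segment: {part!r}")
    return ops


print(parse_path("a->b->[0]->['name']"))
# [('a', 'attribute'), ('b', 'attribute'), ('0', 'index'), ('name', 'key')]
```

In this sketch the key pattern excludes square brackets and single quotes, so a key containing either one is rejected, in line with the restriction noted above.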

Parameters:
  • attr_name (str) –

    Name of attribute to set

  • val (Any) –

    Value to set the attribute to

  • create_new_ok (bool, default: False ) –

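    If False (default), raises an exception if the attribute does not exist. If True, creates the attribute if it does not exist yet. Only the final segment of the path can be created this way; a missing intermediate attribute or key always raises an exception.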

Returns:
  • Self( Self ) –

    Updated instance with new attribute value
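
The update strategy can be sketched with plain frozen dataclasses: walk down the path while remembering each parent, then rebuild from the bottom up so that the original object is left untouched. This is a simplified illustration under stated assumptions, not the library's implementation: Inner, Outer, and set_path are hypothetical names, dataclasses.replace stands in for the tree class's own attribute update, and the create_new_ok handling is omitted.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Inner:  # hypothetical nested class
    b: list


@dataclass(frozen=True)
class Outer:  # hypothetical top-level class
    a: Inner


def set_path(root: Any, ops: list[tuple[str, str]], val: Any) -> Any:
    """Return a copy of root with the target of ops replaced by val."""
    # Walk down the path, remembering the parent of every segment.
    parents = [root]
    cur = root
    for op, kind in ops[:-1]:
        if kind == "attribute":
            cur = getattr(cur, op)
        else:
            cur = cur[int(op) if kind == "index" else op]
        parents.append(cur)

    # Rebuild from the bottom up; every parent on the path is copied.
    new = val
    for parent, (op, kind) in zip(reversed(parents), reversed(ops)):
        if kind == "attribute":
            new = replace(parent, **{op: new})
        else:
            cpy = parent.copy()  # shallow copy of the list or dict
            cpy[int(op) if kind == "index" else op] = new
            new = cpy
    return new


old = Outer(a=Inner(b=[{"name": "x"}]))
ops = [("a", "attribute"), ("b", "attribute"), ("0", "index"), ("name", "key")]
new = set_path(old, ops, "y")
print(new.a.b[0]["name"], old.a.b[0]["name"])  # y x
```

Because each container on the path is copied before it is modified, the original instance keeps its old value, which matches the documented behavior of returning an updated instance.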

Source code in src/fdtdx/core/jax/pytrees.py
def aset(
    self,
    attr_name: str,
    val: Any,
    create_new_ok: bool = False,
) -> Self:
    """Sets an attribute of this class. In contrast to the classical .at[].set(), this method updates the class
    attribute directly and does not only operate on jax pytree leaf nodes. Instead, replaces the full attribute
    with the new value.

    The attribute can either be the attribute name of this class, or for nested classes it can also be the
    attribute name of a class, which itself is an attribute of this class. The syntax for this operation could
    look like this: "a->b->[0]->['name']". Here, the current class has an attribute a, which has an attribute b,
    which is a list, which we index at index 0, which is an element of type dictionary, which we index using
    the dictionary key 'name'.

    Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).

    Args:
        attr_name (str): Name of attribute to set
        val (Any): Value to set the attribute to
        create_new_ok (bool, optional): If false (default), throw an error if the attribute does not exist.
            If true, creates a new attribute if the attribute name does not exist yet.

    Returns:
        Self: Updated instance with new attribute value
    """
    # parse operations
    ops = self._parse_operations(attr_name)

    # find final attribute and save intermediate attributes
    attr_list = [self]
    current_parent = self
    for idx, (op, op_type) in enumerate(ops):
        if op_type == "attribute":
            if not safe_hasattr(current_parent, op):
                if idx != len(ops) - 1 or not create_new_ok:
                    raise Exception(f"Attribute: {op} does not exist for {current_parent.__class__}")
                current_parent = None
            else:
                current_parent = getattr(current_parent, op)
        elif op_type == "index":
            if "__getitem__" not in dir(current_parent):
                raise Exception(f"{current_parent.__class__} does not implement __getitem__")
            current_parent = current_parent[int(op)]  # type: ignore
        elif op_type == "key":
            if "__getitem__" not in dir(current_parent):
                raise Exception(f"{current_parent.__class__} does not implement __getitem__")
            if op not in current_parent:  # type: ignore
                if idx != len(ops) - 1 or not create_new_ok:
                    raise Exception(f"Key: {op} does not exist for {current_parent}")
                current_parent = None
            else:
                current_parent = current_parent[op]  # type: ignore
        else:
            raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")
        if idx != len(ops) - 1:
            attr_list.append(current_parent)  # type: ignore

    # from bottom-up set attributes and update
    cur_attr = val
    for idx in list(range(len(attr_list)))[::-1]:
        op, op_type = ops[idx]
        current_parent = attr_list[idx]
        if op_type == "attribute":
            if not isinstance(current_parent, TreeClass):
                raise Exception(f"Can only set attribute on ExtendedTreeClass, but got {current_parent.__class__}")
            _, cur_attr = current_parent.at["_aset"](op, cur_attr)
        elif op_type == "index":
            if "__setitem__" not in dir(current_parent):
                raise Exception(
                    f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
                )
            cpy = current_parent.copy()  # type: ignore
            cpy[int(op)] = cur_attr  # type: ignore
            cur_attr = cpy
        elif op_type == "key":
            if "__setitem__" not in dir(current_parent):
                raise Exception(
                    f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
                )
            cpy = current_parent.copy()  # type: ignore
            cpy[op] = cur_attr  # type: ignore
            cur_attr = cpy
        else:
            raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")

    assert cur_attr.__class__ == self.__class__
    return cur_attr