Sets an attribute of this class. In contrast to the usual `.at[].set()`, this method updates the class
attribute directly and does not only operate on jax pytree leaf nodes; instead, it replaces the full attribute
with the new value.
The attribute can either be the name of an attribute of this class or, for nested classes, the name of an
attribute of a class which itself is an attribute of this class. For example, the path could look like this:
`"a->b->[0]->['name']"`. Here, the current class has an attribute `a`, which has an attribute `b`; `b` is a
list, which we index at position 0; that element is a dictionary, which we index using the key `'name'`.
Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
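Parameters:

- `attr_name` (`str`) – Name of the attribute to set. This may also be a nested path using the syntax described above.
- `val` (`Any`) – Value to set the attribute to.
- `create_new_ok` (`bool`, default: `False`) – If false (default), throw an error if the attribute does not exist.
  If true, creates a new attribute if the attribute name does not exist yet. Only the last element of the path
  can be created; all intermediate elements must already exist.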
Returns:

- `Self` – Updated instance with the new attribute value.
Source code in `src/fdtdx/core/jax/pytrees.py` (lines 131–219):
219 | def aset(
self,
attr_name: str,
val: Any,
create_new_ok: bool = False,
) -> Self:
"""Sets an attribute of this class. In contrast to the classical .at[].set(), this method updates the class
attribute directly and does not only operate on jax pytree leaf nodes. Instead, replaces the full attribute
with the new value.
The attribute can either be the attribute name of this class, or for nested classes it can also be the
attribute name of a class, which itself is an attribute of this class. The syntax for this operation could
look like this: "a->b->[0]->['name']". Here, the current class has an attribute a, which has an attribute b,
which is a list, which we index at index 0, which is an element of type dictionary, which we index using
the dictionary key 'name'.
Note that dictionary keys cannot contain square brackets or single quotes (even if they are escaped).
Args:
attr_name (str): Name of attribute to set
val (Any): Value to set the attribute to
create_new_ok (bool, optional): If false (default), throw an error if the attribute does not exist.
If true, creates a new attribute if the attribute name does not exist yet.
Returns:
Self: Updated instance with new attribute value
"""
# parse operations
ops = self._parse_operations(attr_name)
# find final attribute and save intermediate attributes
attr_list = [self]
current_parent = self
for idx, (op, op_type) in enumerate(ops):
if op_type == "attribute":
if not safe_hasattr(current_parent, op):
if idx != len(ops) - 1 or not create_new_ok:
raise Exception(f"Attribute: {op} does not exist for {current_parent.__class__}")
current_parent = None
else:
current_parent = getattr(current_parent, op)
elif op_type == "index":
if "__getitem__" not in dir(current_parent):
raise Exception(f"{current_parent.__class__} does not implement __getitem__")
current_parent = current_parent[int(op)] # type: ignore
elif op_type == "key":
if "__getitem__" not in dir(current_parent):
raise Exception(f"{current_parent.__class__} does not implement __getitem__")
if op not in current_parent: # type: ignore
if idx != len(ops) - 1 or not create_new_ok:
raise Exception(f"Key: {op} does not exist for {current_parent}")
current_parent = None
else:
current_parent = current_parent[op] # type: ignore
else:
raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")
if idx != len(ops) - 1:
attr_list.append(current_parent) # type: ignore
# from bottom-up set attributes and update
cur_attr = val
for idx in list(range(len(attr_list)))[::-1]:
op, op_type = ops[idx]
current_parent = attr_list[idx]
if op_type == "attribute":
if not isinstance(current_parent, TreeClass):
raise Exception(f"Can only set attribute on ExtendedTreeClass, but got {current_parent.__class__}")
_, cur_attr = current_parent.at["_aset"](op, cur_attr)
elif op_type == "index":
if "__setitem__" not in dir(current_parent):
raise Exception(
f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
)
cpy = current_parent.copy() # type: ignore
cpy[int(op)] = cur_attr # type: ignore
cur_attr = cpy
elif op_type == "key":
if "__setitem__" not in dir(current_parent):
raise Exception(
f"Can only update by index if __setitem__ is implemented, but got {current_parent.__class__}"
)
cpy = current_parent.copy() # type: ignore
cpy[op] = cur_attr # type: ignore
cur_attr = cpy
else:
raise Exception(f"Invalid operation type: {op_type}. This is an internal bug!")
assert cur_attr.__class__ == self.__class__
return cur_attr
|