Skip to content

PyTree Utilities

Core Utilities

fdtdx.core.jax.pytrees.extended_autoinit(klass)

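Wrapper around `tc.autoinit` that builds an `__init__` from the class's fields (leaving the class untouched if it defines its own `__init__`). It registers the custom field specifiers via `dataclass_transform`, so static type checkers can tell which constructor parameters are required.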

Source code in src/fdtdx/core/jax/pytrees.py
@dataclass_transform(field_specifiers=(Field, tc_field, frozen_field, frozen_private_field))
def extended_autoinit(klass: type[T]) -> type[T]:
    """Generate an ``__init__`` for ``klass`` from its fields, like ``tc.autoinit``.

    If ``klass`` defines ``__init__`` directly in its own namespace, it is
    returned untouched. Otherwise its type hints are converted to fields and an
    ``__init__`` is built from those fields plus the fields of any base classes
    decorated with ``autoinit``.

    The ``dataclass_transform`` decorator registers ``Field``, ``tc_field``,
    ``frozen_field`` and ``frozen_private_field`` as field specifiers. Static
    type checkers can then tell which constructor parameters are required.

    Args:
        klass: The class to decorate.

    Returns:
        The (possibly modified) class.
    """
    # NOTE(review): the plain ``field`` wrapper from this module is not listed
    # in ``field_specifiers`` above — confirm whether type checkers should
    # treat it as a field specifier too.

    # Only an __init__ in the class's own __dict__ counts; inherited ones do not.
    if "__init__" in vars(klass):
        return klass

    # Hints -> fields first, then synthesize __init__ from the combined fields.
    klass_with_fields = convert_hints_to_fields(klass)
    return build_init_method(klass_with_fields)

fdtdx.core.jax.pytrees.field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None)

field(
    *,
    default: T,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> T
field(
    *,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> Any

Creates a regular pytreeclass field.

This is a thin pass-through to pytreeclass's `field` with the same parameters. Unlike `frozen_field`, it does not freeze values when stored or unfreeze them when accessed.

Parameters:

Name Type Description Default
default Any

The default value for the field

NULL
init bool

Whether to include the field in `__init__`

True
repr bool

Whether to include the field in `__repr__`

True
kind ArgKindType

The argument kind (POS_ONLY, POS_OR_KW, etc.)

'POS_OR_KW'
metadata dict[str, Any] | None

Additional metadata for the field

None
on_setattr Sequence[Any]

Setattr callbacks, forwarded unchanged (no freezing is applied)

()
on_getattr Sequence[Any]

Getattr callbacks, forwarded unchanged (no unfreezing is applied)

()
alias str | None

Alternative name for the field in `__init__`

None

Returns:

Type Description
Any

A Field instance, without any freeze/unfreeze behavior

Source code in src/fdtdx/core/jax/pytrees.py
def field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Creates a regular pytreeclass field.

    This is a thin pass-through to ``tc_field`` with the same parameters.
    Unlike ``frozen_field``, it does NOT freeze values on set or unfreeze them
    on get. ``on_setattr`` and ``on_getattr`` are forwarded unchanged.

    Args:
        default: The default value for the field. The ``NULL`` sentinel leaves
            the field without a default.
        init: Whether to include the field in ``__init__``
        repr: Whether to include the field in ``__repr__``
        kind: The argument kind (POS_ONLY, POS_OR_KW, etc.)
        metadata: Additional metadata for the field
        on_setattr: Setattr callbacks, forwarded to ``tc_field`` as-is
        on_getattr: Getattr callbacks, forwarded to ``tc_field`` as-is
        alias: Alternative name for the field in ``__init__``

    Returns:
        The Field instance created by ``tc_field``, with no freeze/unfreeze
        behavior added.
    """
    return tc_field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,
        on_setattr=on_setattr,
        on_getattr=on_getattr,
        alias=alias,
    )

fdtdx.core.jax.pytrees.frozen_field(*, default=NULL, init=True, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None)

frozen_field(
    *,
    default: T,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> T
frozen_field(
    *,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> Any

Creates a field that automatically freezes on set and unfreezes on get.

This field behaves like a regular pytreeclass field but ensures values are frozen when stored and unfrozen when accessed.

Parameters:

Name Type Description Default
default Any

The default value for the field

NULL
init bool

Whether to include the field in `__init__`

True
repr bool

Whether to include the field in `__repr__`

True
kind ArgKindType

The argument kind (POS_ONLY, POS_OR_KW, etc.)

'POS_OR_KW'
metadata dict[str, Any] | None

Additional metadata for the field

None
on_setattr Sequence[Any]

Additional setattr callbacks (applied before freezing, so they receive the unfrozen value)

()
on_getattr Sequence[Any]

Additional getattr callbacks (applied after unfreezing)

()
alias str | None

Alternative name for the field in `__init__`

None

Returns:

Type Description
Any

A Field instance configured with freeze/unfreeze behavior

Source code in src/fdtdx/core/jax/pytrees.py
def frozen_field(
    *,
    default: Any = NULL,
    init: bool = True,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Creates a field that automatically freezes on set and unfreezes on get.

    This field behaves like a regular pytreeclass field but ensures values are
    frozen (``tc.freeze``) when stored and unfrozen (``tc.unfreeze``) when
    accessed.

    Args:
        default: The default value for the field
        init: Whether to include the field in ``__init__``
        repr: Whether to include the field in ``__repr__``
        kind: The argument kind (POS_ONLY, POS_OR_KW, etc.)
        metadata: Additional metadata for the field
        on_setattr: Additional setattr callbacks. They are applied BEFORE
            freezing, so they receive the raw, unfrozen value.
        on_getattr: Additional getattr callbacks. They are applied AFTER
            unfreezing, so they receive the unfrozen value.
        alias: Alternative name for the field in ``__init__``

    Returns:
        A Field instance configured with freeze/unfreeze behavior
    """
    # pytreeclass applies callbacks in sequence order: on set, the user
    # callbacks run first and ``tc.freeze`` last; on get, ``tc.unfreeze`` runs
    # first and the user callbacks see the unfrozen value.
    return tc_field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,
        on_setattr=[*on_setattr, tc.freeze],
        on_getattr=[tc.unfreeze, *on_getattr],
        alias=alias,
    )

fdtdx.core.jax.pytrees.frozen_private_field(*, default=None, init=False, repr=True, kind='POS_OR_KW', metadata=None, on_setattr=(), on_getattr=(), alias=None)

frozen_private_field(
    *,
    default: T,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> T
frozen_private_field(
    *,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None
) -> Any

Creates a field that automatically freezes on set and unfreezes on get, with `default=None` and `init=False` as its (overridable) defaults.

This field behaves like a regular pytreeclass field but ensures values are frozen when stored and unfrozen when accessed.

Parameters:

Name Type Description Default
default Any

The default value for the field

None
init bool

Whether to include the field in `__init__`

False
repr bool

Whether to include the field in `__repr__`

True
kind ArgKindType

The argument kind (POS_ONLY, POS_OR_KW, etc.)

'POS_OR_KW'
metadata dict[str, Any] | None

Additional metadata for the field

None
on_setattr Sequence[Any]

Additional setattr callbacks (applied before freezing, so they receive the unfrozen value)

()
on_getattr Sequence[Any]

Additional getattr callbacks (applied after unfreezing)

()
alias str | None

Alternative name for the field in `__init__`

None

Returns:

Type Description
Any

A Field instance configured with freeze/unfreeze behavior

Source code in src/fdtdx/core/jax/pytrees.py
def frozen_private_field(
    *,
    default: Any = None,
    init: bool = False,
    repr: bool = True,
    kind: ArgKindType = "POS_OR_KW",
    metadata: dict[str, Any] | None = None,
    on_setattr: Sequence[Any] = (),
    on_getattr: Sequence[Any] = (),
    alias: str | None = None,
) -> Any:
    """Creates a frozen field that defaults to ``default=None`` and ``init=False``.

    Identical to ``frozen_field`` except for those two defaults, which suit
    private/internal attributes that are not passed to ``__init__``. Both
    defaults can still be overridden. Values are frozen when stored and
    unfrozen when accessed.

    Args:
        default: The default value for the field (``None`` unless overridden)
        init: Whether to include the field in ``__init__`` (``False`` unless
            overridden)
        repr: Whether to include the field in ``__repr__``
        kind: The argument kind (POS_ONLY, POS_OR_KW, etc.)
        metadata: Additional metadata for the field
        on_setattr: Additional setattr callbacks, applied BEFORE freezing
        on_getattr: Additional getattr callbacks, applied AFTER unfreezing
        alias: Alternative name for the field in ``__init__``

    Returns:
        A Field instance configured with freeze/unfreeze behavior
    """
    return frozen_field(
        default=default,
        init=init,
        repr=repr,
        kind=kind,
        metadata=metadata,
        on_setattr=on_setattr,
        on_getattr=on_getattr,
        alias=alias,
    )

Tree Classes and Fields

fdtdx.core.jax.pytrees.ExtendedTreeClassIndexer

Bases: TreeClassIndexer

Extended indexer for tree class that preserves type information.

Extends TreeClassIndexer to properly handle type hints and return Self type.

Source code in src/fdtdx/core/jax/pytrees.py
class ExtendedTreeClassIndexer(TreeClassIndexer):
    """``TreeClassIndexer`` whose subscript result is annotated as ``Self``.

    Runtime behavior is inherited unchanged. Only the static return type of
    ``__getitem__`` is narrowed, so type checkers keep the concrete type.
    """

    def __getitem__(self, where: Any) -> Self:
        """Index into the tree while keeping the static type.

        Args:
            where: Index or key selecting the part of the tree to access

        Returns:
            Self: The base indexer's result for ``where``, typed as ``Self``
        """
        selected = super().__getitem__(where)
        return selected  # type: ignore

__getitem__(where)

Gets item at specified index while preserving type information.

Parameters:

Name Type Description Default
where Any

Index or key to access

required

Returns:

Name Type Description
Self Self

The indexed item with proper type information preserved

Source code in src/fdtdx/core/jax/pytrees.py
def __getitem__(self, where: Any) -> Self:
    """Index into the tree while keeping the static type.

    Args:
        where: Index or key selecting the part of the tree to access

    Returns:
        Self: The base indexer's result for ``where``, typed as ``Self``
    """
    selected = super().__getitem__(where)
    return selected  # type: ignore

fdtdx.core.jax.pytrees.ExtendedTreeClass

Bases: TreeClass

Extended tree class with improved attribute setting functionality.

Extends TreeClass to provide more flexible attribute setting capabilities, particularly for handling non-recursive attribute updates.

Source code in src/fdtdx/core/jax/pytrees.py
class ExtendedTreeClass(tc.TreeClass):
    """A ``tc.TreeClass`` with a typed indexer and whole-attribute replacement.

    Adds ``aset``, which replaces an entire attribute in one step, instead of
    applying the new value to each leaf the way ``.at[...].set(...)`` does.
    """

    @property
    def at(self) -> ExtendedTreeClassIndexer:
        """Return this tree's indexer, typed as ``ExtendedTreeClassIndexer``.

        Returns:
            ExtendedTreeClassIndexer: The parent class's indexer, re-annotated
            so indexing keeps the concrete tree type
        """
        indexer = super().at
        return indexer  # type: ignore

    def _aset(
        self,
        attr_name: str,
        val: Any,
    ):
        """Assign ``val`` to the attribute ``attr_name`` on this instance.

        Intended to be called through ``self.at["_aset"](...)`` by ``aset``,
        not directly.

        Args:
            attr_name: Name of the attribute to assign
            val: Value to store in the attribute
        """
        setattr(self, attr_name, val)

    def aset(
        self,
        attr_name: str,
        val: Any,
    ) -> Self:
        """Replace an entire attribute and return the resulting tree.

        Unlike ``self.at[attr_name].set(val)``, which applies ``val`` to each
        leaf beneath ``attr_name``, this swaps the whole attribute for ``val``.

        Args:
            attr_name: Name of the attribute to replace
            val: New value for the attribute

        Returns:
            Self: The updated instance carrying the new attribute value
        """
        # Calling a method through ``.at[...]`` yields (method_result, tree).
        # ``_aset`` has no return value, so only the tree is kept.
        _result, updated = self.at["_aset"](attr_name, val)
        return updated

at: ExtendedTreeClassIndexer property

Gets the extended indexer for this tree.

Returns:

Name Type Description
ExtendedTreeClassIndexer ExtendedTreeClassIndexer

Indexer that preserves type information

aset(attr_name, val)

Sets an attribute directly without recursive application.

Similar to Self.at[attr_name].set(val), but without recursively applying to each tree leaf. Instead, replaces the full attribute with the new value.

Parameters:

Name Type Description Default
attr_name str

Name of attribute to set

required
val Any

Value to set the attribute to

required

Returns:

Name Type Description
Self Self

Updated instance with new attribute value

Source code in src/fdtdx/core/jax/pytrees.py
def aset(
    self,
    attr_name: str,
    val: Any,
) -> Self:
    """Replace an entire attribute and return the resulting tree.

    Unlike ``self.at[attr_name].set(val)``, which applies ``val`` to each
    leaf beneath ``attr_name``, this swaps the whole attribute for ``val``.

    Args:
        attr_name: Name of the attribute to replace
        val: New value for the attribute

    Returns:
        Self: The updated instance carrying the new attribute value
    """
    # Calling a method through ``.at[...]`` yields (method_result, tree).
    # ``_aset`` has no return value, so only the tree is kept.
    _result, updated = self.at["_aset"](attr_name, val)
    return updated