JAX in FDTDX
Introduction to JAX
JAX is a high-performance numerical computing library developed by Google that brings together the familiar NumPy API with powerful features like automatic differentiation, just-in-time (JIT) compilation, and seamless GPU/TPU acceleration. Originally designed for machine learning research, JAX has become popular across scientific computing applications due to its speed and flexibility.
JAX itself provides good introductions here and here. Otherwise, the following is a short crash course.
Functional Programming Paradigm
JAX's transformations assume a functional programming style. The functions you pass to them must be pure: their outputs depend only on their inputs, and they have no side effects. This has several important implications:
Pure Functions Only
JAX arrays are immutable, so functions cannot modify them in place or maintain hidden internal state. Instead of an assignment like array[0] = 5, which raises a TypeError on a JAX array, you use a functional equivalent like array.at[0].set(5), which returns a new array.
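```python
import jax.numpy as jnp

# This won't work in JAX: JAX arrays are immutable
def bad_function(x):
    x[0] = x[0] + 1  # In-place modification raises a TypeError
    return x

# This is the JAX way
def good_function(x):
    return x.at[0].add(1)  # Returns a new array; x is unchanged
```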
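No Side Effects
Functions should not print to the console, write to files, or modify global variables. JAX's JIT compiler assumes that functions are deterministic and free of side effects, so Python side effects run only while a function is traced, not on every call of the compiled version. For debugging output inside jitted code, use jax.debug.print.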
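A minimal sketch of why side effects misbehave under jit: the global counter below is a deliberate side effect, and it changes only when JAX traces the function, not on every call. Calls with the same input shape and dtype reuse the cached compiled version, while a new shape triggers a re-trace. The names trace_count and scaled are illustrative only.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # global state mutated as a (discouraged) side effect


@jax.jit
def scaled(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces the function
    jax.debug.print("sum = {s}", s=jnp.sum(x))  # runtime print: fires on every call
    return 2.0 * x


scaled(jnp.ones(3))
scaled(jnp.ones(3))  # same shape and dtype: cached executable, no re-trace
print(trace_count)   # 1
scaled(jnp.ones(4))  # new shape: JAX traces and compiles again
print(trace_count)   # 2
```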
Immutable Data
Arrays and other data structures are treated as immutable. Operations create new objects rather than modifying existing ones, much as NumPy arithmetic like a + b returns a new array instead of overwriting its inputs.
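A short sketch of the .at[] update family on a plain JAX array: each call returns a new array and leaves its input untouched, and index expressions such as slices work as in NumPy. Inside jit-compiled code, XLA can still perform such an update in place when the input array is not used again, so the functional style need not cost extra memory.

```python
import jax.numpy as jnp

x = jnp.zeros(4)
y = x.at[1].set(5.0)    # new array with index 1 replaced
z = y.at[1:3].add(1.0)  # slices work too: adds 1 to indices 1 and 2
first = z.at[0].get()   # functional read, equivalent to z[0]

print(x)  # [0. 0. 0. 0.]  (original unchanged)
print(z)  # [0. 6. 1. 0.]
```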
This functional constraint enables JAX's powerful transformations like jit (compilation), grad (automatic differentiation), vmap (vectorization), and pmap (parallelization). While the functional style requires some adjustment if you're used to imperative programming, it unlocks JAX's ability to automatically optimize and transform your numerical code in ways that would be impossible with stateful operations.
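Because a pure function's output depends only on its inputs, these transformations compose freely. The sketch below uses a made-up quadratic loss: grad differentiates it, vmap maps the gradient over a batch of inputs, and jit compiles the composed result.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Pure function: the result depends only on w and x
    return jnp.sum((w * x) ** 2)


grad_loss = jax.grad(loss)                        # gradient w.r.t. the first argument, w
batched = jax.vmap(grad_loss, in_axes=(None, 0))  # share w, map over the rows of xs
fast = jax.jit(batched)                           # compile the composed function

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0]])
result = fast(w, xs)
# Each row is 2 * w * x**2, so the rows are [2, 0] and [0, 4]
print(result)
```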
TreeClass Objects in FDTDX
FDTDX leverages JAX's functional programming paradigm through a specialized TreeClass system that makes it easy to work with complex hierarchical data structures while maintaining JAX compatibility. The TreeClass provides a clean, object-oriented interface that automatically integrates with JAX's pytree system, allowing for seamless use with JAX transformations.
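To see what pytree registration buys you, here is a plain-JAX sketch that registers a small class by hand with jax.tree_util.register_pytree_node_class. The class Point is hypothetical, and this is not how FDTDX implements TreeClass; it only illustrates the kind of flatten/unflatten registration that TreeClass handles automatically, after which tree utilities and jit accept instances directly.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Point:
    """Hypothetical hand-registered pytree; TreeClass automates this step."""

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        children = (self.x, self.y)  # leaves that JAX traces and differentiates
        aux_data = None              # static metadata (none in this example)
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


p = Point(jnp.array(1.0), jnp.array(2.0))
leaves = jax.tree_util.tree_leaves(p)                 # [x, y]
doubled = jax.tree_util.tree_map(lambda v: 2 * v, p)  # a new Point with doubled leaves


@jax.jit
def norm2(pt):
    # jit flattens the Point into its leaves, traces, and rebuilds it inside
    return pt.x ** 2 + pt.y ** 2


print(norm2(p))  # 5.0
```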
TreeClass Structure
The TreeClass system uses dataclass-like syntax with the @fdtdx.autoinit decorator to automatically generate initialization methods. Here's how it works:
```python
import fdtdx

@fdtdx.autoinit
class A(fdtdx.TreeClass):
    a: float = 2
    x: int = 5

@fdtdx.autoinit
class B(fdtdx.TreeClass):
    a1: A
    z: int = 7

@fdtdx.autoinit
class C(fdtdx.TreeClass):
    b_list: list[B]
    c: float = 2
```
Working with TreeClass Instances
```python
# Create instances with default or custom values
b = B(a1=A())  # Uses defaults: A(a=2, x=5), z=7
b = b.aset("z", 8)  # Functional update: returns a new B with z=8

# Create more complex nested structures
b2 = B(a1=A(a=10, x=11), z=12)
b3 = B(a1=A(a=20, x=21), z=22)

# Collections of TreeClass instances
c = C(b_list=[b, b2])

# Deep nested updates using path syntax
c2 = c.aset("b_list->[0]->a1->a", 100)
```
The aset Method: Functional Updates Made Easy
The aset method is the cornerstone of FDTDX's functional approach. JAX's .at[].set() only updates elements inside a single array, whereas aset can replace any attribute at any level of nesting within a TreeClass hierarchy.
Path syntax: the method uses a string-based path syntax to navigate nested structures:
- "attribute" - Direct attribute access
- "a->b" - Nested attribute access (a.b)
- "a->[0]" - List indexing
- "a->['key']" - Dictionary key access
- "b_list->[0]->a1->a" - Complex nested path
In the example c2 = c.aset("b_list->[0]->a1->a", 100), this path means:
- Access the b_list attribute of c
- Get the first element [0] of that list
- Access the a1 attribute of that element
- Access the a attribute of a1
- Set that value to 100
The method returns a completely new instance with the updated value, maintaining JAX's functional programming requirements. This allows FDTDX data structures to be used seamlessly with JAX transformations like jit, grad, and vmap, while providing a much more intuitive interface than manually reconstructing nested data structures. This approach bridges the gap between JAX's powerful functional capabilities and the practical need for complex, hierarchical data management in scientific computing applications.
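The path walk can be sketched in plain Python to show the idea behind such a functional update: parse the path, recurse to the target, and rebuild only the containers along the way while sharing everything else. The helper set_path and the frozen dataclasses below are hypothetical stand-ins for illustration. This is not FDTDX's implementation of aset, and it supports only the attribute, list-index, and dictionary-key path forms listed above.

```python
import dataclasses
import re
from typing import Any


def parse_path(path: str) -> list[tuple[str, Any]]:
    """Turn 'b_list->[0]->a1->a' into [('attr', 'b_list'), ('item', 0), ...]."""
    steps = []
    for part in path.split("->"):
        match = re.fullmatch(r"\[(.+)\]", part)
        if match is None:
            steps.append(("attr", part))                   # attribute access
        elif match.group(1)[0] in "'\"":
            steps.append(("item", match.group(1)[1:-1]))   # dict key, e.g. ['key']
        else:
            steps.append(("item", int(match.group(1))))    # list index, e.g. [0]
    return steps


def _set(obj: Any, steps: list[tuple[str, Any]], value: Any) -> Any:
    if not steps:
        return value
    (kind, key), rest = steps[0], steps[1:]
    if kind == "attr":
        new_child = _set(getattr(obj, key), rest, value)
        return dataclasses.replace(obj, **{key: new_child})  # new object; old one untouched
    new_container = list(obj) if isinstance(obj, list) else dict(obj)  # shallow copy
    new_container[key] = _set(obj[key], rest, value)
    return new_container


def set_path(obj: Any, path: str, value: Any) -> Any:
    return _set(obj, parse_path(path), value)


@dataclasses.dataclass(frozen=True)
class A:
    a: float = 2
    x: int = 5


@dataclasses.dataclass(frozen=True)
class B:
    a1: A
    z: int = 7


@dataclasses.dataclass(frozen=True)
class C:
    b_list: list
    c: float = 2


c = C(b_list=[B(a1=A()), B(a1=A(a=10, x=11), z=12)])
c2 = set_path(c, "b_list->[0]->a1->a", 100)
print(c2.b_list[0].a1.a)  # 100
print(c.b_list[0].a1.a)   # 2: the original is unchanged
```

Note that c2.b_list[1] is the very same object as c.b_list[1]: only the path from the root to the updated value is copied.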
How JAX is used in FDTDX
For a full example of how to use JAX with FDTDX, check out this example or this example. The example script demonstrates FDTDX's integration with JAX's jit transformation. Its core simulation function sim_fn takes FDTDX TreeClass structures as arguments and is JIT-compiled:
```python
import jax
import fdtdx

# Excerpt: objects, config, params, arrays and key are created earlier in the example script
def sim_fn(
    params: fdtdx.ParameterContainer,
    arrays: fdtdx.ArrayContainer,
    key: jax.Array,
):
    # Complex FDTD simulation logic with TreeClass structures
    arrays, new_objects, info = fdtdx.apply_params(arrays, objects, params, key)
    final_state = fdtdx.run_fdtd(arrays=arrays, objects=new_objects, config=config, key=key)
    # ... more operations (elided; they also produce new_info)
    return arrays, new_info

jitted_loss = jax.jit(sim_fn, donate_argnames=["arrays"]).lower(params, arrays, key).compile()
```
JIT compilation with TreeClass arguments
Key Features:
- TreeClass Compatibility: ParameterContainer and ArrayContainer are FDTDX TreeClass structures. Because TreeClass registers them as JAX pytrees, jit can flatten these nested structures into their array leaves, trace them, and compile the function efficiently.
- Memory Optimization: The donate_argnames=["arrays"] parameter tells JAX it may reuse the memory of the arrays argument for the outputs, which is crucial for the large electromagnetic field arrays in FDTD simulations. A donated input must not be used again after the call.
- Compilation Pipeline: The script uses .lower().compile() to compile ahead of time as an explicit step, separate from execution, so that compilation time can be measured on its own for performance analysis.
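The donation and ahead-of-time compilation pattern described in these bullets can be tried on a toy function in plain JAX. The step function and array sizes are made up for illustration. The point is that .lower(...).compile() returns a compiled executable whose build time can be measured separately, and that a donated argument must not be read after the call. On backends that cannot reuse a donated buffer, JAX emits a warning and simply does not reuse it.

```python
import time

import jax
import jax.numpy as jnp


def step(field, scale):
    # Toy stand-in for one update of a large field array
    return field * scale + 1.0


field = jnp.zeros((256, 256))
scale = jnp.array(0.5, dtype=jnp.float32)

t0 = time.perf_counter()
compiled = jax.jit(step, donate_argnames=["field"]).lower(field, scale).compile()
t1 = time.perf_counter()
print(f"compile: {t1 - t0:.3f} s")

new_field = compiled(field, scale)  # `field` is donated: do not use it afterwards
new_field.block_until_ready()       # wait for asynchronous dispatch before stopping the clock
print(f"run: {time.perf_counter() - t1:.3f} s")
```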
While this specific example focuses on forward simulation, FDTDX is designed for gradient-based optimization, and the script's GradientConfig setup shows how gradients would be computed.
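Gradients with respect to a whole pytree of parameters work the same way in plain JAX. In the sketch below, a dict stands in for a parameter container and loss_fn is a made-up toy model rather than an FDTD simulation. This is not FDTDX's GradientConfig mechanism: jax.value_and_grad with has_aux=True returns the loss, an auxiliary info output that is not differentiated, and gradients with the same structure as params.

```python
import jax
import jax.numpy as jnp


def loss_fn(params, target):
    # Hypothetical toy model standing in for a simulation
    pred = params["weight"] * jnp.arange(3.0) + params["bias"]
    loss = jnp.mean((pred - target) ** 2)
    info = {"max_pred": jnp.max(pred)}  # auxiliary output: returned, not differentiated
    return loss, info


params = {"weight": jnp.array(1.0), "bias": jnp.array(0.0)}
target = jnp.array([0.0, 2.0, 4.0])

value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))
(loss, info), grads = value_and_grad_fn(params, target)

# One gradient-descent step, applied leaf by leaf over the params pytree
new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
print(grads)  # a dict with the same keys as params
```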
For gradient computation, you would typically use: