
JAX in FDTDX

Introduction to JAX

JAX is a high-performance numerical computing library developed by Google that brings together the familiar NumPy API with powerful features like automatic differentiation, just-in-time (JIT) compilation, and seamless GPU/TPU acceleration. Originally designed for machine learning research, JAX has become popular across scientific computing applications due to its speed and flexibility.

JAX's official documentation provides good introductory material. The following is a short crash course.

Functional Programming Paradigm

JAX's transformations assume a functional programming style: you write pure functions, whose outputs depend only on their inputs and which have no side effects. This functional approach has several important implications:

Pure Functions Only

JAX arrays cannot be modified in place, and JAX-transformed functions should not rely on hidden internal state. Instead of operations like array[0] = 5, you must use functional equivalents like array.at[0].set(5), which return new arrays.

# This fails: JAX arrays are immutable, so item assignment raises a TypeError
def bad_function(x):
    x[0] = x[0] + 1  # In-place modification
    return x

# This is the JAX way
def good_function(x):
    return x.at[0].add(1)  # Returns new array
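
To see the difference in practice, the sketch below (assuming only that jax is installed) calls good_function, checks that the input array is unchanged, tries two other .at[] update methods, and confirms that direct item assignment raises a TypeError:

```python
import jax.numpy as jnp


def good_function(x):
    return x.at[0].add(1)  # returns a new array


x = jnp.array([1, 2, 3])
y = good_function(x)
print(x)  # [1 2 3]  the input is unchanged
print(y)  # [2 2 3]

# Other .at[] methods follow the same functional pattern
z = x.at[1:].set(0)       # [1 0 0]
w = x.at[2].multiply(10)  # [ 1  2 30]

assignment_failed = False
try:
    x[0] = 5  # in-place item assignment is not supported
except TypeError:
    assignment_failed = True
print(assignment_failed)  # True
```
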

No Side Effects

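Functions should not print to the console, write to files, or modify global variables. JAX's JIT compiler assumes that functions are deterministic and side-effect free: Python side effects run only while a function is being traced, so they happen once and are skipped on later calls that reuse the compiled code. To print values from inside a jitted function, use jax.debug.print; randomness is handled through explicit PRNG keys rather than hidden global state.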
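
The following minimal sketch illustrates why side effects are unreliable under jit. It appends to a global list (a side effect) inside a jitted function; the names double and trace_log are made up for this illustration:

```python
import jax
import jax.numpy as jnp

trace_log = []  # global list mutated as a side effect


@jax.jit
def double(x):
    trace_log.append(x.shape)  # Python side effect: runs only while tracing
    return 2 * x


double(jnp.ones(3))  # first call: traced and compiled
double(jnp.ones(3))  # same shape/dtype: reuses the cached compilation
double(jnp.ones(4))  # new shape: traced and compiled again
print(trace_log)  # [(3,), (4,)]
```

The second call reuses the cached compilation, so the append does not run again; only the new input shape triggers a fresh trace.
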

Immutable Data

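Arrays and other data structures are treated as immutable. Operations create new objects rather than modifying existing ones, much like a NumPy expression such as a + b returns a new array. Unlike NumPy, however, JAX has no in-place variants: even a += b rebinds a to a new array.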
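
A small comparison using plain NumPy and jax.numpy makes the difference concrete: the same += statement mutates a NumPy array in place, but only rebinds the name for a JAX array:

```python
import numpy as np
import jax.numpy as jnp

# NumPy: += mutates the array in place, so every alias sees the change
n = np.arange(3.0)
n_alias = n
n += 1.0
print(n_alias)  # [1. 2. 3.]

# JAX: += builds a new array and rebinds the name; the alias is untouched
j = jnp.arange(3.0)
j_alias = j
j += 1.0
print(j_alias)  # [0. 1. 2.]
print(j)        # [1. 2. 3.]
```
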

This functional constraint enables JAX's transformations: jit (compilation), grad (automatic differentiation), vmap (vectorization), and pmap (parallelization). Because a pure function's behavior is fully determined by its inputs, JAX can trace it once and then safely compile, differentiate, batch, or parallelize it. The functional style takes some adjustment if you are used to imperative programming, but it is what allows JAX to optimize and transform your numerical code automatically, which hidden mutable state would make unsafe.
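
As a minimal illustration of how these transformations compose on a pure function, the sketch below differentiates, compiles, and vectorizes a toy energy function (energy is an illustrative name, not part of FDTDX):

```python
import jax
import jax.numpy as jnp


def energy(x):
    # Pure: the result depends only on the input array
    return jnp.sum(x ** 2)


grad_energy = jax.jit(jax.grad(energy))  # compiled gradient: d/dx sum(x^2) = 2x
batched_energy = jax.vmap(energy)        # evaluates energy row by row

x = jnp.array([1.0, 2.0, 3.0])
print(grad_energy(x))                         # [2. 4. 6.]
print(batched_energy(jnp.stack([x, 2 * x])))  # [14. 56.]
```
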

TreeClass Objects in FDTDX

FDTDX leverages JAX's functional programming paradigm through a specialized TreeClass system that makes it easy to work with complex hierarchical data structures while maintaining JAX compatibility. The TreeClass provides a clean, object-oriented interface that automatically integrates with JAX's pytree system, allowing for seamless use with JAX transformations.
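
To make the pytree integration concrete, here is a hand-written pytree class registered with jax.tree_util.register_pytree_node_class. This is only an illustration of the kind of boilerplate TreeClass is described as handling for you; it is not how FDTDX implements TreeClass, and the Fields class is hypothetical:

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Fields:
    """Hand-written pytree: the E and H arrays are its leaves."""

    def __init__(self, E, H):
        self.E = E
        self.H = H

    def tree_flatten(self):
        # (children that JAX traces, auxiliary static data)
        return (self.E, self.H), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


@jax.jit
def total_energy(fields):
    # jit flattens `fields` into its leaves and rebuilds a Fields inside the trace
    return jnp.sum(fields.E ** 2) + jnp.sum(fields.H ** 2)


f = Fields(E=jnp.ones(3), H=2 * jnp.ones(3))
print(len(jax.tree_util.tree_leaves(f)))  # 2
print(total_energy(f))  # 15.0  (3 * 1 + 3 * 4)

# tree_map returns a new Fields; f itself is unchanged
half = jax.tree_util.tree_map(lambda a: 0.5 * a, f)
```

Once a class is a pytree, jit, grad, vmap and the jax.tree_util functions can all see through it to its array leaves.
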

TreeClass Structure

The TreeClass system uses dataclass-like syntax with the @fdtdx.autoinit decorator to automatically generate initialization methods. Here's how it works:

import fdtdx

@fdtdx.autoinit
class A(fdtdx.TreeClass):
    a: float = 2
    x: int = 5

@fdtdx.autoinit
class B(fdtdx.TreeClass):
    a1: A
    z: int = 7

@fdtdx.autoinit
class C(fdtdx.TreeClass):
    b_list: list[B]
    c: float = 2

These classes can be nested arbitrarily deep and contain lists, dictionaries, or other complex data structures. The @fdtdx.autoinit decorator automatically generates __init__ methods that handle default values and type checking.

Working with TreeClass Instances

# Create instances with default or custom values
b = B(a1=A())  # Uses defaults: A(a=2, x=5), z=7
b = b.aset("z", 8)  # Functional update: returns a new B with z=8

# Create more complex nested structures
b2 = B(a1=A(a=10, x=11), z=12)
b3 = B(a1=A(a=20, x=21), z=22)

# Collections of TreeClass instances
c = C(b_list=[b, b2, b3])

# Deep nested update using path syntax; c itself is left unchanged
c2 = c.aset("b_list->[0]->a1->a", 100)

The aset Method: Functional Updates Made Easy

The aset method is the cornerstone of FDTDX's functional approach. JAX's .at[].set() only updates elements inside a single array; aset, by contrast, can update any attribute at any level of nesting within a TreeClass hierarchy.

Path Syntax: The method uses an intuitive string-based path syntax to navigate nested structures:

  • "attribute" - Direct attribute access
  • "a->b" - Nested attribute access (a.b)
  • "a->[0]" - List indexing
  • "a->['key']" - Dictionary key access
  • "b_list->[0]->a1->a" - Complex nested path

In the example c2 = c.aset("b_list->[0]->a1->a", 100), the path is resolved step by step:

  • Access the b_list attribute of c
  • Get the first element [0] of that list
  • Access the a1 attribute of that element
  • Access the a attribute of a1
  • Set that value to 100
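
To clarify what such a path update does, here is a hypothetical pure-Python sketch that applies the same path syntax to frozen dataclasses, lists, and dicts. It is not FDTDX's implementation of aset; parse_path, set_in, and aset_sketch are illustrative names, and frozen dataclasses stand in for TreeClass:

```python
import ast
import dataclasses
from typing import Any


def parse_path(path: str) -> list:
    """Split 'b_list->[0]->a1->a' into ['b_list', 0, 'a1', 'a']."""
    keys: list = []
    for part in path.split("->"):
        if part.startswith("[") and part.endswith("]"):
            keys.append(ast.literal_eval(part[1:-1]))  # "0" -> 0, "'key'" -> "key"
        else:
            keys.append(part)  # attribute name
    return keys


def set_in(obj: Any, keys: list, value: Any) -> Any:
    """Return a copy of obj with the value at keys replaced; obj is not modified."""
    if not keys:
        return value
    head, rest = keys[0], keys[1:]
    if isinstance(obj, (list, dict)):
        new = obj.copy()  # shallow copy: untouched entries are shared
        new[head] = set_in(obj[head], rest, value)
        return new
    child = set_in(getattr(obj, head), rest, value)
    return dataclasses.replace(obj, **{head: child})


def aset_sketch(obj: Any, path: str, value: Any) -> Any:
    return set_in(obj, parse_path(path), value)


@dataclasses.dataclass(frozen=True)
class A:
    a: float = 2
    x: int = 5


@dataclasses.dataclass(frozen=True)
class B:
    a1: A
    z: int = 7


@dataclasses.dataclass(frozen=True)
class C:
    b_list: list
    c: float = 2


c = C(b_list=[B(a1=A()), B(a1=A(a=10, x=11), z=12)])
c2 = aset_sketch(c, "b_list->[0]->a1->a", 100)
print(c2.b_list[0].a1.a)  # 100
print(c.b_list[0].a1.a)   # 2 (original unchanged)
```

Each level along the path is copied, while branches the path does not touch (such as c.b_list[1] here) are shared between the old and new objects. This is the manual reconstruction that a path-based update saves you from writing.
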

The method returns a completely new instance with the updated value, maintaining JAX's functional programming requirements. This allows FDTDX data structures to be used seamlessly with JAX transformations like jit, grad, and vmap, while providing a much more intuitive interface than manually reconstructing nested data structures. This approach bridges the gap between JAX's powerful functional capabilities and the practical need for complex, hierarchical data management in scientific computing applications.

How JAX is used in FDTDX

For complete examples of using JAX with FDTDX, see the example scripts in the FDTDX documentation. The excerpt below, abridged from one of them, shows FDTDX's integration with JAX's jit transformation: the core simulation function sim_fn takes FDTDX TreeClass structures as arguments and is JIT-compiled:

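import jax
import fdtdx

# `objects` and `config` are created earlier in the example script; sim_fn
# captures them from the enclosing scope instead of taking them as arguments.
def sim_fn(
    params: fdtdx.ParameterContainer,
    arrays: fdtdx.ArrayContainer,
    key: jax.Array,
):
    # Apply params to the simulation objects
    arrays, new_objects, info = fdtdx.apply_params(arrays, objects, params, key)
    # Run the FDTD time stepping
    final_state = fdtdx.run_fdtd(arrays=arrays, objects=new_objects, config=config, key=key)
    # ... more operations that extract the updated arrays from final_state
    # and collect diagnostics into new_info
    return arrays, new_info

# Compile ahead of time; the memory of `arrays` may be donated to the outputs
jitted_loss = jax.jit(sim_fn, donate_argnames=["arrays"]).lower(params, arrays, key).compile()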

JIT compilation with TreeClass arguments

Key Features:

  • TreeClass Compatibility: The ParameterContainer and ArrayContainer are FDTDX TreeClass structures that work seamlessly with jit. Because they are registered as JAX pytrees, jit flattens these nested structures into their array leaves, traces through them, and rebuilds them on output.
  • Memory Optimization: The donate_argnames=["arrays"] parameter tells JAX it may reuse the memory of the arrays argument for the outputs, which is crucial for the large electromagnetic field arrays in FDTD simulations. A donated input must not be used after the call; continue with the returned arrays instead.
  • Compilation Pipeline: The script calls .lower(...).compile() to compile ahead of time, before the first run, so compilation time can be measured separately from simulation time.
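
A self-contained sketch of the same jit options, using only JAX on a toy array (update is a made-up stand-in for a simulation step), shows ahead-of-time compilation, separate timing of compilation and execution, and argument donation:

```python
import time

import jax
import jax.numpy as jnp


def update(fields, scale):
    # Hypothetical stand-in for one simulation update on a large field array
    return fields * scale + 1.0


fields = jnp.ones((512, 512), dtype=jnp.float32)
scale = jnp.float32(0.5)

t0 = time.perf_counter()
compiled = jax.jit(update, donate_argnames=["fields"]).lower(fields, scale).compile()
t1 = time.perf_counter()
fields = compiled(fields, scale)  # the donated input buffer may be reused for the output
fields.block_until_ready()        # wait for asynchronous dispatch before timing
t2 = time.perf_counter()

print(f"compile: {t1 - t0:.4f} s, run: {t2 - t1:.4f} s")
```

block_until_ready() is needed because JAX dispatches work asynchronously; without it, the run time would only measure the dispatch. On backends that cannot donate a given buffer, JAX emits a warning and simply does not reuse the memory.
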

While this specific example focuses on forward simulation, FDTDX is designed for gradient-based optimization. Gradient computation is configured through a GradientConfig; in this setup, its Recorder is given a DtypeConversion module that targets bfloat16:

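import jax.numpy as jnp
import fdtdx

# The recorder is configured with a single module that converts to bfloat16
gradient_config = fdtdx.GradientConfig(
    recorder=fdtdx.Recorder(
        modules=[fdtdx.DtypeConversion(dtype=jnp.bfloat16)]
    )
)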
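
The trade-off behind converting data to bfloat16 can be seen with plain JAX. Assuming DtypeConversion casts the recorded arrays to the given dtype (an assumption based on its name, not verified against FDTDX), each value takes half the memory of float32 at the cost of precision:

```python
import jax.numpy as jnp

field = jnp.full((1000,), 1.001, dtype=jnp.float32)
stored = field.astype(jnp.bfloat16)  # reduced-precision copy

# Bytes per element: 4 for float32, 2 for bfloat16
print(field.dtype.itemsize, stored.dtype.itemsize)  # 4 2

# bfloat16 has an 8-bit significand, so 1.001 rounds to the nearest representable value, 1.0
print(float(stored[0]))  # 1.0

restored = stored.astype(jnp.float32)
round_trip_error = float(jnp.max(jnp.abs(restored - field)))
print(round_trip_error)  # roughly 0.001
```
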
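
For gradient computation, you would typically use something like the following. Note that jax.grad requires a scalar-valued function, so sim_fn, which returns arrays and an info dictionary, has to be wrapped in a loss function first:

# Hypothetical gradient computation; compute_objective is a placeholder helper
def loss_fn(params, arrays, key):
    new_arrays, info = sim_fn(params, arrays, key)
    return compute_objective(new_arrays), info  # scalar loss plus auxiliary info
grad_fn = jax.grad(loss_fn, argnums=0, has_aux=True)  # Gradient w.r.t. params
gradients, info = grad_fn(params, arrays, key)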
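
The wrapping pattern can be tried end to end on a toy problem. In this sketch, toy_sim is a hypothetical stand-in for sim_fn that returns new fields plus an info dictionary, and the parameters are a dictionary pytree, similar in spirit to a ParameterContainer:

```python
import jax
import jax.numpy as jnp


def toy_sim(params, fields):
    # Hypothetical stand-in for sim_fn: returns new fields plus an info dict
    new_fields = fields * params["permittivity"]
    info = {"mean_field": jnp.mean(new_fields)}
    return new_fields, info


def loss_fn(params, fields):
    new_fields, info = toy_sim(params, fields)
    loss = -jnp.sum(new_fields ** 2)  # scalar objective (maximize field energy)
    return loss, info


# Differentiate w.r.t. params only; info is passed through as auxiliary output
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0, has_aux=True))

params = {"permittivity": jnp.array([1.0, 2.0])}
fields = jnp.array([1.0, 1.0])
grads, info = grad_fn(params, fields)
print(grads["permittivity"])  # [-2. -4.]  (d/dp of -sum((f*p)^2) = -2 f^2 p)

# One plain gradient descent step on every leaf of the params pytree
learning_rate = 0.1
params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
print(params["permittivity"])  # [1.2 2.4]
```

has_aux=True tells jax.grad that loss_fn returns (loss, aux); only the first element is differentiated and the second is passed through unchanged. The final tree_map applies one plain gradient descent step to every leaf of the parameter pytree.
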