Skip to content

Array Operations

fdtdx.core.misc.expand_matrix(matrix, grid_points_per_voxel, add_channels=True)

Expands a matrix by repeating values along spatial dimensions and optionally adding channels.

Used to upsample a coarse grid to a finer simulation grid by repeating values. Can also add vector field components as channels.

Parameters:

Name Type Description Default
matrix Array

Input matrix to expand

required
grid_points_per_voxel tuple[int, ...]

Number of grid points to expand each voxel into along each dimension

required
add_channels bool

If True, adds and tiles 3 channels for vector field components

True

Returns:

Type Description
Array

jax.Array: Expanded matrix with repeated values and optional channels

Source code in src/fdtdx/core/misc.py
def expand_matrix(matrix: jax.Array, grid_points_per_voxel: tuple[int, ...], add_channels: bool = True) -> jax.Array:
    """Upsample a coarse matrix onto a finer grid by repeating its entries.

    A 2D input first receives a trailing singleton axis. The first three axes
    are then repeated by the first three entries of ``grid_points_per_voxel``.
    When ``add_channels`` is set, 3D data additionally receives a trailing
    axis, and the last axis is tiled three times (one copy per vector field
    component).

    Args:
        matrix: Input matrix to expand.
        grid_points_per_voxel: Number of grid points each voxel is expanded
            into, per spatial axis. Entries 0, 1 and 2 are used for the repeats.
        add_channels: If True, append/tile a trailing axis by a factor of 3.

    Returns:
        jax.Array: The expanded matrix, with channels if requested.
    """
    if matrix.ndim == 2:
        matrix = jnp.expand_dims(matrix, axis=-1)

    expanded = matrix
    for axis in range(3):
        expanded = jnp.repeat(expanded, grid_points_per_voxel[axis], axis=axis)

    if not add_channels:
        return expanded

    # Only 3D data needs a fresh channel axis; higher-rank input is tiled
    # along its existing last axis.
    if matrix.ndim == 3:
        expanded = jnp.expand_dims(expanded, axis=-1)
    channel_reps = (1,) * len(grid_points_per_voxel) + (3,)
    return jnp.tile(expanded, channel_reps)

fdtdx.core.misc.ensure_slice_tuple(t)

Ensures that all elements of the input sequence are converted to slices.

This function takes a sequence of elements that can be slices, integers, or tuples of integers and returns a tuple of slices. Integers are converted to slices that select a single item, and tuples are converted to slices that select a range of items.

Parameters:

Name Type Description Default
t Sequence[slice | int | Tuple[int, int]]

A sequence of elements where each element is either a slice, an integer, or a tuple of two integers representing the start and end of a slice range.

required

Returns:

Type Description
Tuple[slice, ...]

A tuple of slices corresponding to the input sequence.

Source code in src/fdtdx/core/misc.py
def ensure_slice_tuple(t: Sequence[slice | int | Tuple[int, int]]) -> Tuple[slice, ...]:
    """
    Ensures that all elements of the input sequence are converted to slices.

    This function takes a sequence of elements that can be slices, integers,
    or tuples of integers and returns a tuple of slices. Integers are converted
    to slices that select a single item, and tuples are converted to slices
    that select a range of items.

    Args:
        t: A sequence of elements where each element is either a slice, an
            integer, or a tuple of two integers representing the start and end
            of a slice range.

    Returns:
        A tuple of slices corresponding to the input sequence.
    """

    def to_slice(loc):
        if isinstance(loc, int):
            return slice(loc, loc + 1)
        elif isinstance(loc, slice):
            return loc
        elif isinstance(loc, (tuple, list)) and len(loc) == 2 and all(isinstance(i, int) for i in loc):
            return slice(*loc)
        else:
            raise ValueError(f"Invalid location type: {type(loc)}. Expected int, slice, or tuple of ints.")

    return tuple(to_slice(loc) for loc in t)

fdtdx.core.misc.index_1d_array(arr, val)

Finds the first index where a 1D array equals a given value.

Parameters:

Name Type Description Default
arr Array

1D input array to search

required
val Array

Value to find in the array

required

Returns:

Name Type Description
int Array

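Index of the first occurrence of val in arr. If val does not occur, the result is 0, because the index is computed as the argmax of the equality mask.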

Raises:

Type Description
Exception

If input array is not 1D

Source code in src/fdtdx/core/misc.py
def index_1d_array(arr: jax.Array, val: jax.Array) -> jax.Array:
    """Locate the first position at which a 1D array equals ``val``.

    The lookup is an ``argmax`` over the element-wise equality mask, so when no
    element matches the result is 0 rather than an error.

    Args:
        arr: 1D input array to search
        val: Value to find in the array

    Returns:
        int: Index of first occurrence of val in arr

    Raises:
        Exception: If input array is not 1D
    """
    if arr.ndim != 1:
        raise Exception(f"index only works on 1d-array, got shape: {arr.shape}")
    matches = arr == val
    return jnp.argmax(matches)

fdtdx.core.misc.index_by_slice(arr, start, stop, axis, step=1)

Indexes an array along a specified axis using slice notation.

Parameters:

Name Type Description Default
arr Array

Input array to slice

required
start int | None

Starting index

required
stop int | None

Stopping index

required
axis int

Axis along which to slice

required
step int

Step size between elements

1

Returns:

Type Description
Array

jax.Array: Sliced array

Source code in src/fdtdx/core/misc.py
def index_by_slice(
    arr: jax.Array,
    start: int | None,
    stop: int | None,
    axis: int,
    step: int = 1,
) -> jax.Array:
    """Indexes an array along a specified axis using slice notation.

    Args:
        arr: Input array to slice
        start: Starting index
        stop: Stopping index
        axis: Axis along which to slice
        step: Step size between elements

    Returns:
        jax.Array: Sliced array
    """
    slice_list = [slice(None) for _ in range(arr.ndim)]
    slice_list[axis] = slice(start, stop, step)
    return arr[tuple(slice_list)]

fdtdx.core.misc.index_by_slice_take_1d(arr, slice, axis)

Takes elements from an array along one axis using a slice and JAX's take operation.

Optimized version of array slicing that uses JAX's take operation for better performance when taking elements along a single axis.

Parameters:

Name Type Description Default
arr Array

Input array

required
slice slice

Slice object specifying which elements to take

required
axis int

Axis along which to take elements

required

Returns:

Type Description
Array

jax.Array: Array with selected elements

Raises:

Type Description
Exception

If slice would result in empty array

Source code in src/fdtdx/core/misc.py
def index_by_slice_take_1d(
    arr: jax.Array,
    slice: slice,
    axis: int,
) -> jax.Array:
    """Takes elements from an array along one axis using a slice and JAX's take operation.

    The slice is normalised against the length of ``axis`` and expanded into an
    explicit index array that is passed to ``jnp.take``. A slice that covers
    the full axis with step 1 returns ``arr`` unchanged.

    Args:
        arr: Input array
        slice: Slice object specifying which elements to take
        axis: Axis along which to take elements

    Returns:
        jax.Array: Array with selected elements

    Raises:
        Exception: If slice would result in empty array
    """
    start, stop, step = slice.indices(arr.shape[axis])
    if start == 0 and stop == arr.shape[axis] and step == 1:
        return arr
    indices = jnp.arange(start, stop, step)
    if len(indices) == 0:
        raise Exception(f"Invalid slice: {slice}")
    # A negative step produces descending indices, so the "sorted" hint may
    # only be given for positive steps; arange indices are always unique.
    arr = jnp.take(arr, indices, axis=axis, unique_indices=True, indices_are_sorted=step > 0)
    return arr

fdtdx.core.misc.index_by_slice_take(arr, slices)

Takes elements from an array using multiple slices and JAX's take operation.

Optimized version of array slicing that uses JAX's take operation for better performance when taking elements along multiple axes.

Parameters:

Name Type Description Default
arr Array

Input array

required
slices Sequence[slice]

Sequence of slice objects, one for each dimension

required

Returns:

Type Description
Array

jax.Array: Array with selected elements

Raises:

Type Description
Exception

If any slice would result in empty array

Source code in src/fdtdx/core/misc.py
def index_by_slice_take(
    arr: jax.Array,
    slices: Sequence[slice],
) -> jax.Array:
    """Takes elements from an array using multiple slices and JAX's take operation.

    The i-th slice is applied to axis i. Each slice is normalised against the
    current axis length and expanded into an explicit index array for
    ``jnp.take``; slices that cover a whole axis with step 1 are skipped.

    Args:
        arr: Input array
        slices: Sequence of slice objects, one for each dimension

    Returns:
        jax.Array: Array with selected elements

    Raises:
        Exception: If any slice would result in empty array
    """
    for axis, s in enumerate(slices):
        start, stop, step = s.indices(arr.shape[axis])
        if start == 0 and stop == arr.shape[axis] and step == 1:
            continue
        indices = jnp.arange(start, stop, step)
        if len(indices) == 0:
            raise Exception(f"Invalid slice: {s}")
        # A negative step produces descending indices, so the "sorted" hint may
        # only be given for positive steps; arange indices are always unique.
        arr = jnp.take(arr, indices, axis=axis, unique_indices=True, indices_are_sorted=step > 0)
    return arr

fdtdx.core.misc.mask_1d_from_slice(s, axis_size)

Creates a boolean mask array from a slice specification.

Parameters:

Name Type Description Default
s slice

Slice object defining which elements should be True

required
axis_size int

Size of the axis to create mask for

required

Returns:

Type Description
Array

jax.Array: Boolean mask array with True values where slice selects elements

Source code in src/fdtdx/core/misc.py
def mask_1d_from_slice(
    s: slice,
    axis_size: int,
) -> jax.Array:
    """Creates a boolean mask array from a slice specification.

    Args:
        s: Slice object defining which elements should be True
        axis_size: Size of the axis to create mask for

    Returns:
        jax.Array: Boolean mask array with True values where slice selects elements
    """
    start, stop, step = s.indices(axis_size)
    # The normalised triple must not be fed back into slicing syntax: for a
    # negative step ``stop`` may be -1, which slicing would interpret as "last
    # element" and select nothing. Expanding to explicit positions is
    # unambiguous for either step direction.
    selected = jnp.arange(start, stop, step)
    mask = jnp.zeros(shape=(axis_size,), dtype=jnp.bool_)
    return mask.at[selected].set(True)

fdtdx.core.misc.assimilate_shape(arr, ref_arr, ref_axes, repeat_single_dims=False)

Reshapes array to match reference array's dimensions for broadcasting.

Inserts new dimensions of size 1 such that arr has same dimensions as ref_arr and can be broadcasted. Optionally repeats single dimensions to match ref_arr's shape.

Parameters:

Name Type Description Default
arr Array

Array to reshape

required
ref_arr Array

Reference array whose shape to match

required
ref_axes tuple[int, ...]

Tuple mapping arr's axes to ref_arr's axes

required
repeat_single_dims bool

If True, repeats size-1 dimensions to match ref_arr

False

Returns:

Type Description
Array

jax.Array: Reshaped array that can be broadcasted with ref_arr

Raises:

Type Description
Exception

If shapes are incompatible or axes mapping is invalid

Source code in src/fdtdx/core/misc.py
def assimilate_shape(
    arr: jax.Array,
    ref_arr: jax.Array,
    ref_axes: tuple[int, ...],
    repeat_single_dims: bool = False,
) -> jax.Array:
    """Reshape ``arr`` so that it has as many axes as ``ref_arr`` and broadcasts against it.

    Axis ``a`` of ``arr`` is placed at axis ``ref_axes[a]`` of the result; every
    other axis of the result has size 1. With ``repeat_single_dims`` set, axes
    listed in ``ref_axes`` that are still of size 1 after the reshape are
    repeated up to the corresponding size of ``ref_arr``.

    Args:
        arr: Array to reshape
        ref_arr: Reference array whose shape to match
        ref_axes: Tuple mapping arr's axes to ref_arr's axes
        repeat_single_dims: If True, repeats size-1 dimensions to match ref_arr

    Returns:
        jax.Array: Reshaped array that can be broadcasted with ref_arr

    Raises:
        Exception: If shapes are incompatible or axes mapping is invalid
    """
    if len(ref_axes) != arr.ndim:
        raise Exception(f"Invalid axes: {arr.ndim=}, {ref_axes=}")
    if max(ref_axes) >= ref_arr.ndim:
        raise Exception(f"Invalid axes: {ref_arr.ndim=}, {ref_axes=}")

    # Every mapped axis must either match the reference size or be broadcastable.
    for own_size, ref_axis in zip(arr.shape, ref_axes):
        if ref_arr.shape[ref_axis] != own_size and own_size != 1:
            raise Exception(f"Invalid shapes: {arr.shape=}, {ref_arr.shape=}")

    target_shape = [1] * ref_arr.ndim
    for own_size, ref_axis in zip(arr.shape, ref_axes):
        target_shape[ref_axis] = own_size
    arr = jnp.reshape(arr, target_shape)

    if repeat_single_dims:
        for ref_axis in ref_axes:
            if arr.shape[ref_axis] == 1:
                arr = jnp.repeat(arr, ref_arr.shape[ref_axis], axis=ref_axis)
    return arr

fdtdx.core.misc.linear_interpolated_indexing(point, arr)

Performs linear interpolation at a point in an array.

Parameters:

Name Type Description Default
point Array

Coordinates at which to interpolate

required
arr Array

Array to interpolate from

required

Returns:

Type Description
Array

jax.Array: Interpolated value at the specified point

Raises:

Type Description
Exception

If point dimensions don't match array dimensions

Source code in src/fdtdx/core/misc.py
def linear_interpolated_indexing(
    point: jax.Array,
    arr: jax.Array,
) -> jax.Array:
    """Evaluate ``arr`` at a fractional index by multilinear interpolation.

    The floor/ceil neighbours of ``point`` along every axis are combined into
    all corner indices. Each corner is weighted by the product over axes of
    ``1 - |corner - point|``; corners outside the array get weight zero. The
    result is the weighted sum of the corner values divided by the total weight
    (plus a small epsilon to avoid division by zero).

    Args:
        point: Coordinates at which to interpolate
        arr: Array to interpolate from

    Returns:
        jax.Array: Interpolated value at the specified point

    Raises:
        Exception: If point dimensions don't match array dimensions
    """
    if not (point.ndim == 1 and point.shape[0] == arr.ndim):
        raise Exception(f"Invalid shape of point ({point.shape}) or arr {arr.shape}")

    # Lower and upper integer neighbour for each coordinate.
    neighbours = [[jnp.floor(point[axis]), jnp.ceil(point[axis])] for axis in range(point.shape[0])]
    corners = jnp.asarray(list(itertools.product(*neighbours)), dtype=jnp.int32)

    distance = jnp.abs(corners - point[None, :])
    weights = (1 - distance).prod(axis=-1)

    for axis, size in enumerate(arr.shape):
        coord = corners[:, axis]
        out_of_bounds = (coord < 0) | (coord >= size)
        weights = jnp.where(out_of_bounds, 0, weights)
        # Redirect invalid corners to index 0 so the gather below stays in
        # bounds; their weight is already zero.
        corners = jnp.where(out_of_bounds[:, None], 0, corners)

    samples = arr[tuple(corners.T)]
    return (weights * samples).sum() / (weights.sum() + 1e-8)

fdtdx.core.misc.advanced_padding(arr, padding_cfg)

Pads the input array with configurable padding modes and widths.

Parameters:

Name Type Description Default
arr Array

Input array to pad

required
padding_cfg PaddingConfig

Configuration object containing: widths (padding widths for each edge), modes (padding modes such as constant, edge, or reflect), and values (values to use for constant padding).

required

Returns:

Type Description
tuple[Array, tuple[slice, ...]]

tuple[jax.Array, tuple[slice, ...]]: Tuple containing the padded array and a slice tuple that extracts the original array region from it.

Source code in src/fdtdx/core/misc.py
def advanced_padding(
    arr: jax.Array,
    padding_cfg: PaddingConfig,
) -> tuple[jax.Array, tuple[slice, ...]]:
    """Pads the input array with configurable padding modes and widths.

    The configuration lists are indexed per edge: entry ``2 * axis`` describes
    the start edge of ``axis`` and entry ``2 * axis + 1`` its end edge. A list
    with a single entry is broadcast to every edge, and missing values default
    to 0. Edges are padded one after another with ``jnp.pad``.

    Args:
        arr: Input array to pad
        padding_cfg: Configuration object containing:
            - widths: Padding widths for each edge
            - modes: Padding modes (constant, edge, reflect etc)
            - values: Values to use for constant padding

    Returns:
        tuple[jax.Array, tuple[slice, ...]]: Tuple containing:
            - Padded array
            - Slice tuple to extract original array region

    Raises:
        Exception: If widths, modes or values do not hold exactly two entries
            per array dimension after defaults are applied.
    """
    num_edges = 2 * arr.ndim

    # default values: broadcast single entries to all edges
    if len(padding_cfg.widths) == 1:
        padding_cfg = padding_cfg.aset("widths", [padding_cfg.widths[0] for _ in range(num_edges)])
    if len(padding_cfg.modes) == 1:
        padding_cfg = padding_cfg.aset("modes", [padding_cfg.modes[0] for _ in range(num_edges)])
    if padding_cfg.values is None:
        padding_cfg = padding_cfg.aset("values", [0 for _ in range(num_edges)])
    if len(padding_cfg.values) == 1:
        padding_cfg = padding_cfg.aset("values", [padding_cfg.values[0] for _ in range(num_edges)])

    # sanity checks: each setting needs exactly one entry per edge
    if len(padding_cfg.widths) != num_edges:
        raise Exception(f"Invalid padding width: {padding_cfg.widths} for array with {arr.ndim} dimensions")
    if len(padding_cfg.modes) != num_edges:
        raise Exception(f"Invalid padding modes: {padding_cfg.modes} for array with {arr.ndim} dimensions")
    if len(padding_cfg.values) != num_edges:
        raise Exception(f"Invalid padding values: {padding_cfg.values} for array with {arr.ndim} dimensions")

    # [start, stop] of the original data inside the padded array, per axis
    slices = [[0, arr.shape[ax]] for ax in range(arr.ndim)]
    for edge in range(num_edges):
        axis, remainder = divmod(edge, 2)
        is_end = remainder != 0
        cur_width = padding_cfg.widths[edge]
        cur_mode = padding_cfg.modes[edge]
        cur_value = padding_cfg.values[edge]

        kwargs = {}
        if cur_mode == "constant":
            kwargs["constant_values"] = cur_value
        edge_pad = (0, cur_width) if is_end else (cur_width, 0)
        pad_width_tuple = tuple(edge_pad if ax == axis else (0, 0) for ax in range(arr.ndim))
        if not is_end:
            # Padding in front of the data shifts the original region back.
            slices[axis][0] = cur_width
            slices[axis][1] += cur_width
        arr = jnp.pad(array=arr, pad_width=pad_width_tuple, mode=cur_mode, **kwargs)
    slices = ensure_slice_tuple(slices)  # type: ignore
    return arr, slices