Distributed Computing
fdtdx.core.jax.sharding.get_named_sharding_from_shape(shape, sharding_axis)
Creates a NamedSharding object for distributing an array across devices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `shape` | `tuple[int, ...]` | Shape of the array to be sharded | required |
| `sharding_axis` | `int` | Axis along which to shard the array | required |

Returns:

| Type | Description |
|---|---|
| `NamedSharding` | NamedSharding object specifying how to distribute the array across available devices |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `shape[sharding_axis]` is not divisible by the number of devices |
Source code in src/fdtdx/core/jax/sharding.py
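The source file is referenced but not inlined here. A minimal sketch of how such a helper is typically built with `jax.sharding` (assuming a 1D device mesh over all available devices, not the library's actual implementation):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def get_named_sharding_from_shape(shape, sharding_axis):
    """Sketch: shard an array of `shape` along `sharding_axis` over all devices."""
    devices = jax.devices()
    if shape[sharding_axis] % len(devices) != 0:
        raise ValueError(
            f"shape[{sharding_axis}]={shape[sharding_axis]} is not divisible "
            f"by the number of devices ({len(devices)})"
        )
    # 1D mesh covering every available device.
    mesh = Mesh(np.array(devices), axis_names=("devices",))
    # Shard only the requested axis; None entries mean "replicate this axis".
    spec = [None] * len(shape)
    spec[sharding_axis] = "devices"
    return NamedSharding(mesh, PartitionSpec(*spec))
```

The `PartitionSpec` pattern above (one named axis, the rest `None`) is the standard way in JAX to shard a single dimension while replicating the others.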
fdtdx.core.jax.sharding.get_dtype_bytes(dtype)
Get the size in bytes of a JAX dtype.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dtype` | `dtype` | JAX dtype to get the size of | required |

Returns:

| Type | Description |
|---|---|
| `int` | Number of bytes used by the dtype |
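A one-line sketch of how a dtype-size helper like this can be implemented (an assumption about the implementation, via NumPy's `itemsize`, which also understands JAX dtypes):

```python
import numpy as np


def get_dtype_bytes(dtype) -> int:
    """Sketch: per-element size in bytes of a (JAX-compatible) dtype."""
    # np.dtype(...) normalizes dtype-like inputs; itemsize is the byte width.
    return np.dtype(dtype).itemsize
```

For example, `float32` occupies 4 bytes per element and `uint8` occupies 1.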
fdtdx.core.jax.sharding.pretty_print_sharding(sharding)
Returns a human-readable string representation of a sharding specification.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `sharding` | `Sharding` | JAX sharding object to format | required |

Returns:

| Type | Description |
|---|---|
| `str` | String representation showing the sharding type and configuration |
Source code in src/fdtdx/core/jax/sharding.py
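A plausible sketch of such a formatter (assumed, not the library's actual code): dispatch on the concrete sharding type and fall back to `repr` for anything unrecognized.

```python
from jax.sharding import NamedSharding, SingleDeviceSharding


def pretty_print_sharding(sharding) -> str:
    """Sketch: human-readable summary of a JAX sharding object."""
    if isinstance(sharding, NamedSharding):
        # Show the mesh axis sizes and the partition spec.
        return f"NamedSharding(mesh={dict(sharding.mesh.shape)}, spec={sharding.spec})"
    if isinstance(sharding, SingleDeviceSharding):
        return f"SingleDeviceSharding({sharding.device_set})"
    # Unknown sharding type: fall back to the default repr.
    return repr(sharding)
```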
fdtdx.core.jax.sharding.create_named_sharded_matrix(shape, value, sharding_axis, dtype, backend)
Creates a sharded matrix distributed across available devices.
Creates a matrix of the given shape filled with the specified value, sharded across available devices along the specified axis.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `shape` | `tuple[int, ...]` | Shape of the matrix to create | required |
| `value` | `float` | Value to fill the matrix with | required |
| `sharding_axis` | `int` | Axis along which to shard | required |
| `dtype` | `dtype` | Data type of the matrix elements | required |
| `backend` | `Literal['gpu', 'tpu', 'cpu']` | Which device backend to use (`"gpu"`, `"tpu"`, or `"cpu"`) | required |

Returns:

| Type | Description |
|---|---|
| `Array` | Sharded matrix distributed across devices |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If `shape[sharding_axis]` is not divisible by the number of devices |
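The steps this function describes (build a sharding over the chosen backend's devices, create the filled array, distribute it) can be sketched as follows. This is an assumed implementation using standard JAX primitives, not the library's source:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def create_named_sharded_matrix(shape, value, sharding_axis, dtype, backend):
    """Sketch: a `value`-filled matrix sharded along `sharding_axis`."""
    devices = jax.devices(backend)
    if shape[sharding_axis] % len(devices) != 0:
        raise ValueError(
            f"Axis {sharding_axis} (size {shape[sharding_axis]}) is not "
            f"divisible by the number of devices ({len(devices)})"
        )
    # 1D mesh over the backend's devices; shard one axis, replicate the rest.
    mesh = Mesh(np.array(devices), axis_names=("devices",))
    spec = [None] * len(shape)
    spec[sharding_axis] = "devices"
    sharding = NamedSharding(mesh, PartitionSpec(*spec))
    # device_put places each shard of the filled array on its target device.
    matrix = jnp.full(shape, value, dtype=dtype)
    return jax.device_put(matrix, sharding)
```

Usage: `create_named_sharded_matrix((1024, 512), 0.0, 0, jnp.float32, "gpu")` would split the first axis evenly across the available GPUs.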