
Image Processing

fdtdx.core.gaussian_blur.gaussian_blur_3d(image, sigma, kernel_size, *, padding='SAME', channel_axis=-1)

Applies 3D Gaussian blur by convolving the image with separable Gaussian kernels.

This function implements an efficient 3D Gaussian blur by decomposing the 3D Gaussian kernel into three 1D kernels and applying them sequentially along each axis. This is mathematically equivalent to a full 3D convolution but much more computationally efficient.
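The separability claim can be checked directly: because the 3D Gaussian factors into a product of 1D Gaussians, the outer product of three normalized 1D kernels reproduces the normalized 3D kernel exactly. A minimal NumPy sketch (the sigma and radius values are illustrative, not from the source):

```python
import numpy as np

sigma, radius = 1.5, 4  # illustrative values
x = np.arange(-radius, radius + 1, dtype=np.float64)

# 1D Gaussian kernel, normalized to sum to one (as in gaussian_blur_3d)
g1 = np.exp(-(x**2) / (2.0 * sigma**2))
g1 /= g1.sum()

# Full 3D Gaussian kernel, normalized the same way
xx, yy, zz = np.meshgrid(x, x, x, indexing="ij")
g3 = np.exp(-(xx**2 + yy**2 + zz**2) / (2.0 * sigma**2))
g3 /= g3.sum()

# Outer product of three 1D kernels equals the full 3D kernel
sep = g1[:, None, None] * g1[None, :, None] * g1[None, None, :]
assert np.allclose(sep, g3)
```

Applying three 1D convolutions therefore costs O(k) per voxel per axis instead of O(k^3) for the full 3D kernel.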

Parameters:

- image (Array, required): The input image as a float tensor with values in [0, 1]. Should have 4 or 5 dimensions: [batch?, depth, height, width, channels] or [batch?, channels, depth, height, width]. The batch dimension is optional.
- sigma (float, required): Standard deviation (in pixels) of the Gaussian kernel. Controls the amount of blurring; larger values produce more blurring.
- kernel_size (float, required): Size (in pixels) of the cubic Gaussian kernel. Will be rounded up to the nearest odd integer to ensure the kernel is symmetric. Should be at least 2 * ceil(3 * sigma) + 1 to avoid truncating the Gaussian significantly.
- padding (str, default 'SAME'): Either "SAME" or "VALID". With "SAME" padding the output has the same spatial dimensions as the input; with "VALID" padding no padding is added and the output is smaller.
- channel_axis (int, default -1): The axis containing the channels in the input tensor. Use -1 for channels-last format (the default) or 1 for channels-first format.
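The minimum recommended kernel_size above follows a simple formula. A hypothetical helper (not part of fdtdx) that computes it:

```python
import math

def min_kernel_size(sigma: float) -> int:
    # Smallest odd kernel size covering +/- 3 sigma,
    # per the rule 2 * ceil(3 * sigma) + 1 from the docstring.
    return 2 * math.ceil(3 * sigma) + 1

print(min_kernel_size(1.0))  # 7
print(min_kernel_size(2.5))  # 17
```

Using a smaller kernel truncates the Gaussian tails, which shows up as visible ringing or energy loss after normalization.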

Returns:

- jax.Array: The blurred image tensor, in the same data format as the input. With "SAME" padding the spatial dimensions match the input; with "VALID" padding they are reduced by (kernel_size - 1) in each dimension.
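The SAME/VALID shape rule can be illustrated in 1D with NumPy's analogous convolution modes (a sketch of the general rule, not fdtdx code):

```python
import numpy as np

signal = np.ones(10)
kernel = np.ones(5) / 5.0  # kernel_size = 5

# "same" keeps the input length; "valid" shrinks it by kernel_size - 1
same = np.convolve(signal, kernel, mode="same")
valid = np.convolve(signal, kernel, mode="valid")
print(same.shape)   # (10,)
print(valid.shape)  # (6,) == 10 - (5 - 1)
```

The same reduction applies independently along depth, height, and width in the 3D case.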

Source code in src/fdtdx/core/gaussian_blur.py
def gaussian_blur_3d(
    image: jax.Array,
    sigma: float,
    kernel_size: float,
    *,
    padding: str = "SAME",
    channel_axis: int = -1,
) -> jax.Array:
    """Applies 3D Gaussian blur by convolving the image with separable Gaussian kernels.

    This function implements an efficient 3D Gaussian blur by decomposing the 3D Gaussian
    kernel into three 1D kernels and applying them sequentially along each axis. This is
    mathematically equivalent to a full 3D convolution but much more computationally efficient.

    Args:
        image: The input image as a float tensor with values in [0,1]. Should have 4 or 5
            dimensions: [batch?, depth, height, width, channels] or [batch?, channels, depth,
            height, width]. The batch dimension is optional.
        sigma: Standard deviation (in pixels) of the Gaussian kernel. Controls the amount
            of blurring - larger values produce more blurring.
        kernel_size: Size (in pixels) of the cubic Gaussian kernel. Will be rounded up to
            the nearest odd integer to ensure the kernel is symmetric. Should be at least
            2 * ceil(3 * sigma) + 1 to avoid truncating the Gaussian significantly.
        padding: Either "SAME" or "VALID". With "SAME" padding the output has the same
            spatial dimensions as the input. With "VALID" padding the output is smaller
            due to no padding being added.
        channel_axis: The axis containing the channels in the input tensor. Use -1 for
            channels-last format (default) or 1 for channels-first format.

    Returns:
        jax.Array: The blurred image tensor, in the same data format as the input.
            With "SAME" padding the spatial dimensions match the input. With "VALID"
            padding they are reduced by (kernel_size - 1) in each dimension.
    """
    chex.assert_rank(image, {4, 5})
    data_format = "NDHWC" if _channels_last(image, channel_axis) else "NCDHW"
    dimension_numbers = (data_format, "DHWIO", data_format)
    num_channels = image.shape[channel_axis]
    # Round the kernel size up to the nearest odd integer so it is symmetric.
    radius = int(kernel_size / 2)
    kernel_size_ = 2 * radius + 1
    # Build a single 1D Gaussian kernel, normalized to sum to one.
    x = jnp.arange(-radius, radius + 1).astype(image.dtype)
    blur_filter = jnp.exp(-(x**2) / (2.0 * sigma**2))
    blur_filter = blur_filter / jnp.sum(blur_filter)
    # Orient the 1D kernel along each spatial axis (depth, height, width) and
    # tile it over channels for the depthwise convolutions below.
    blur_d = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1, 1])
    blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1, 1])
    blur_w = jnp.reshape(blur_filter, [1, 1, kernel_size_, 1, 1])
    blur_h = jnp.tile(blur_h, [1, 1, 1, 1, num_channels])
    blur_w = jnp.tile(blur_w, [1, 1, 1, 1, num_channels])
    blur_d = jnp.tile(blur_d, [1, 1, 1, 1, num_channels])

    # Temporarily add a batch dimension so 4D inputs share the 5D code path.
    expand_batch_dim = image.ndim == 4
    if expand_batch_dim:
        image = image[jnp.newaxis, ...]
    blurred = _depthwise_conv3d(
        image,
        kernel=blur_h,
        strides=(1, 1, 1),
        padding=padding,
        channel_axis=channel_axis,
        dimension_numbers=dimension_numbers,
    )
    blurred = _depthwise_conv3d(
        blurred,
        kernel=blur_w,
        strides=(1, 1, 1),
        padding=padding,
        channel_axis=channel_axis,
        dimension_numbers=dimension_numbers,
    )
    blurred = _depthwise_conv3d(
        blurred,
        kernel=blur_d,
        strides=(1, 1, 1),
        padding=padding,
        channel_axis=channel_axis,
        dimension_numbers=dimension_numbers,
    )
    if expand_batch_dim:
        blurred = jnp.squeeze(blurred, axis=0)
    return blurred