Skip to content

Visualization

fdtdx.core.plotting.device_permittivity_index_utils.index_matrix_to_str(indices)

Converts a 2D matrix of indices to a formatted string representation.

Parameters:

Name Type Description Default
indices Array

A 2D JAX array containing numerical indices.

required

Returns:

Type Description
str

A string representation of the matrix where each row is space-separated and rows are separated by newlines.

Source code in src/fdtdx/core/plotting/device_permittivity_index_utils.py
def index_matrix_to_str(indices: jax.Array) -> str:
    """Converts a 2D matrix of indices to a formatted string representation.

    Every entry is followed by a single space and every row (including the last)
    is followed by a newline, e.g. ``[[1, 2], [3, 4]]`` becomes ``"1 2 \\n3 4 \\n"``.

    Args:
        indices: A 2D JAX array containing numerical indices. Each entry is squeezed
            before formatting, so an array with a trailing singleton axis
            (shape ``(H, W, 1)``) produces the same output as its 2D counterpart.

    Returns:
        A string representation of the matrix where each row is space-separated
        and rows are separated by newlines.
    """
    # Pull the data to the host once: indexing a JAX array element by element
    # dispatches a separate operation per entry.
    host_indices = np.asarray(indices)
    n_rows, n_cols = host_indices.shape[0], host_indices.shape[1]
    # str() (not format()/f-string) is deliberate: format() on NumPy float scalars
    # goes through Python float and can print extra digits (e.g. float32 0.1).
    return "".join(
        "".join(str(host_indices[i, j].squeeze()) + " " for j in range(n_cols)) + "\n"
        for i in range(n_rows)
    )

fdtdx.core.plotting.device_permittivity_index_utils.device_matrix_index_figure(device_matrix_indices, permittivity_configs)

Creates a visualization figure of device matrix indices with permittivity configurations.

Parameters:

Name Type Description Default
device_matrix_indices Array

A 3D JAX array containing the device matrix indices. Shape should be (height, width, channels) where channels is typically 1.

required
permittivity_configs tuple[tuple[str, float], ...]

A tuple of (name, value) pairs defining the permittivity configurations, where name is a string identifier (e.g., "Air") and value is the corresponding permittivity value.

required

Returns:

Type Description
Figure

A matplotlib Figure object containing the visualization with:

  • A heatmap of the device matrix indices
  • Color-coded regions based on permittivity configurations
  • Optional text labels showing index values (for smaller matrices)
  • A legend mapping colors to permittivity configurations
  • Proper axis labels and grid settings

Raises:

Type Description
AssertionError

If device_matrix_indices is not 3-dimensional.

Source code in src/fdtdx/core/plotting/device_permittivity_index_utils.py
def device_matrix_index_figure(
    device_matrix_indices: jax.Array,
    permittivity_configs: tuple[tuple[str, float], ...],
) -> Figure:
    """Creates a visualization figure of device matrix indices with permittivity configurations.

    Axis 0 of ``device_matrix_indices`` is drawn along the horizontal axis and axis 1
    along the vertical axis. The matrix is transposed for ``imshow`` with
    ``origin="lower"``, so entry ``[x, y]`` appears at plot coordinates ``(x, y)``.

    With a single channel, the plotted value of a pixel is its permittivity index.
    With several channels, every distinct tuple of per-channel indices becomes one
    category. Categories are ordered by how many of their entries equal the index of
    the configuration named "Air" (fewest first). The plotted value is the category's
    position in that ordering. The legend lists each category with the names of its
    configurations joined by "|".

    Args:
        device_matrix_indices: A 3D JAX array containing the device matrix indices.
            Shape should be (height, width, channels) where channels is typically 1.
        permittivity_configs: A tuple of (name, value) pairs defining the permittivity
            configurations, where name is a string identifier (e.g., "Air") and value
            is the corresponding permittivity value.

    Returns:
        A matplotlib Figure object containing the visualization with:
        - A heatmap of the device matrix indices
        - Color-coded regions based on permittivity configurations
        - Optional text labels showing index values (for smaller matrices)
        - A legend mapping colors to permittivity configurations
        - Proper axis labels and grid settings

    Raises:
        AssertionError: If device_matrix_indices is not 3-dimensional.
    """
    assert device_matrix_indices.ndim == 3
    # Everything below is NumPy/matplotlib, so move the data to the host once.
    device_matrix_indices = np.asarray(device_matrix_indices).astype(np.int32)
    fig, ax = cast(tuple[Figure, Axes], plt.subplots(figsize=(12, 12)))
    image_palette = sns.color_palette("YlOrBr", as_cmap=True)
    if device_matrix_indices.shape[-1] == 1:
        # Single channel: the plotted value is the permittivity index itself.
        plot_matrix = device_matrix_indices[..., 0]
        indices = np.unique(plot_matrix)
    else:
        air_index = next(
            (i for i, cfg in enumerate(permittivity_configs) if cfg[0] == "Air"),
            None,
        )
        flat = device_matrix_indices.reshape(-1, device_matrix_indices.shape[-1])
        # inverse[p] is the row of unique_rows that pixel p carries.
        unique_rows, inverse = np.unique(flat, axis=0, return_inverse=True)
        if air_index is None:
            air_count = np.zeros(unique_rows.shape[0], dtype=np.int64)
        else:
            air_count = np.count_nonzero(unique_rows == air_index, axis=-1)
        order = np.argsort(air_count)
        # `indices` holds the categories in plotting order; the legend indexes into it.
        indices = unique_rows[order]
        # rank[k] = position of unique row k within the sorted order (inverse permutation).
        rank = np.empty_like(order)
        rank[order] = np.arange(order.size)
        # reshape(-1) guards against NumPy versions returning a 2D `inverse` when axis is set.
        plot_matrix = rank[inverse.reshape(-1)].reshape(device_matrix_indices.shape[:-1])

    cax = ax.imshow(
        plot_matrix.T,
        cmap=image_palette,
        aspect="auto",
        origin="lower",
    )
    ax.set_xlabel("X Axis")
    ax.set_ylabel("Y Axis")
    # imshow receives the transpose with origin="lower", so plot_matrix[x, y] is
    # drawn at (x, y) with x spanning axis 0 and y spanning axis 1.
    size_x, size_y = plot_matrix.shape
    if size_x * size_y < 1500:
        for x in range(size_x):
            for y in range(size_y):
                value = plot_matrix[x, y]
                text_color = "w" if cax.norm(value) > 0.5 else "k"  # type: ignore
                ax.text(x, y, str(int(value)), ha="center", va="center", color=text_color)
    assert cax.cmap is not None
    if indices.ndim == 1:
        legend_elements = [
            Patch(
                facecolor=cax.cmap(cax.norm(int(i))),
                label=f"({i}) {permittivity_configs[int(i)][0]}",
            )
            for i in indices
        ]
    else:
        legend_elements = [
            Patch(
                facecolor=cax.cmap(cax.norm(int(i))),
                label=f"({i}) " + "|".join([permittivity_configs[int(e)][0] for e in indices[i]]),
            )
            for i in np.unique(plot_matrix)
        ]

    legend_cols = max(1, int(len(legend_elements) / size_x))
    if len(legend_elements) < 100:
        ax.legend(
            handles=legend_elements,
            loc="center left",
            frameon=False,
            bbox_to_anchor=(1, 0.5),
            ncols=legend_cols,
        )
    ax.set_aspect("equal")
    # Grid lines are kept but made fully transparent so they do not obscure cells.
    for line in ax.get_xgridlines() + ax.get_ygridlines():
        line.set_alpha(0.0)
    return fig

fdtdx.core.plotting.utils.plot_filled_std_curves(x, mean, color, lighter_color, std=None, upper=None, lower=None, linestyle='-', marker=None, label=None, alpha=0.2, min_val=None, max_val=None)

Plots a curve with filled standard deviation or confidence intervals.

Creates a plot showing a mean curve with a filled region representing either standard deviation bounds or custom upper/lower bounds. The filled region uses a lighter color with transparency.

The function supports two modes:

1. Standard deviation mode: provide the std parameter to create bounds at mean ± std.
2. Custom bounds mode: provide explicit upper and lower bound arrays.

The plotted curves can be optionally clipped to minimum/maximum values.

Parameters:

Name Type Description Default
x ndarray

Array of x-axis values.

required
mean ndarray

Array of y-axis values for the mean curve.

required
color Any

Color for the mean curve line.

required
lighter_color Any

Color for the filled standard deviation region.

required
std Optional[ndarray]

Optional standard deviation array. If provided, used to compute upper/lower bounds.

None
upper Optional[ndarray]

Optional array of upper bound values. Must be provided with lower.

None
lower Optional[ndarray]

Optional array of lower bound values. Must be provided with upper.

None
linestyle str

Style of the mean curve line. Defaults to solid line "-".

'-'
marker Optional[str]

Optional marker style for data points on the mean curve.

None
label Optional[str]

Optional label for the plot legend.

None
alpha float

Transparency value for the filled region. Defaults to 0.2.

0.2
min_val Optional[float]

Optional minimum value to clip the curves.

None
max_val Optional[float]

Optional maximum value to clip the curves.

None

Raises:

Type Description
ValueError

If neither std nor both upper/lower bounds are provided, or if only one of upper/lower is provided.

Example

>>> x = np.linspace(0, 10, 100)
>>> mean = np.sin(x)
>>> std = 0.1 * np.ones_like(x)
>>> plot_filled_std_curves(x, mean, 'blue', 'lightblue', std=std)

Source code in src/fdtdx/core/plotting/utils.py
def plot_filled_std_curves(
    x: np.ndarray,
    mean: np.ndarray,
    color: Any,
    lighter_color: Any,
    std: Optional[np.ndarray] = None,
    upper: Optional[np.ndarray] = None,
    lower: Optional[np.ndarray] = None,
    linestyle: str = "-",
    marker: Optional[str] = None,
    label: Optional[str] = None,
    alpha: float = 0.2,
    min_val: Optional[float] = None,
    max_val: Optional[float] = None,
):
    """Draws a mean curve surrounded by a shaded band on the current pyplot axes.

    The band is given either symmetrically through ``std`` (band = mean ± std) or
    explicitly through the ``upper``/``lower`` pair. The band's edges and fill use
    ``lighter_color`` at transparency ``alpha``. The mean line is drawn last, so it
    sits on top of the band.

    Optionally, the mean and both band edges are clamped from below by ``min_val``
    and then from above by ``max_val``.

    Args:
        x: Array of x-axis values.
        mean: Array of y-axis values for the mean curve.
        color: Color for the mean curve line.
        lighter_color: Color for the filled standard deviation region.
        std: Optional standard deviation array. If provided, used to compute upper/lower bounds.
        upper: Optional array of upper bound values. Must be provided with lower.
        lower: Optional array of lower bound values. Must be provided with upper.
        linestyle: Style of the mean curve line. Defaults to solid line "-".
        marker: Optional marker style for data points on the mean curve.
        label: Optional label for the plot legend.
        alpha: Transparency value for the filled region. Defaults to 0.2.
        min_val: Optional minimum value to clip the curves.
        max_val: Optional maximum value to clip the curves.

    Raises:
        ValueError: If exactly one of upper/lower is given, or if std and the
            upper/lower pair are both given or both omitted.

    Example:
        >>> x = np.linspace(0, 10, 100)
        >>> mean = np.sin(x)
        >>> std = 0.1 * np.ones_like(x)
        >>> plot_filled_std_curves(x, mean, 'blue', 'lightblue', std=std)
    """
    has_upper = upper is not None
    has_lower = lower is not None
    if has_upper != has_lower:
        raise ValueError("Need to specify both upper and lower")
    # Exactly one of the two modes must be selected.
    if (std is not None) == has_upper:
        raise ValueError("Need to specify either std or upper/lower")

    if std is not None:
        lower, upper = mean - std, mean + std
    # Unreachable after the checks above; kept so the type checker narrows the bounds.
    if upper is None or lower is None:
        raise Exception("This should never happen")

    if min_val is not None:
        mean, lower, upper = (np.maximum(curve, min_val) for curve in (mean, lower, upper))
    if max_val is not None:
        mean, lower, upper = (np.minimum(curve, max_val) for curve in (mean, lower, upper))

    for edge in (upper, lower):
        plt.plot(x, edge, color=lighter_color, alpha=alpha)
    plt.fill_between(x, lower, upper, color=lighter_color, alpha=alpha)
    plt.plot(x, mean, color=color, label=label, linestyle=linestyle, marker=marker, markersize=4)

fdtdx.core.plotting.debug.generate_unique_filename(prefix='file', extension=None)

Generate a unique filename using timestamp and UUID.

Parameters:

prefix : str, optional — Prefix for the filename.
extension : str, optional — File extension (without the dot).

Returns:

str : Unique filename

Source code in src/fdtdx/core/plotting/debug.py
def generate_unique_filename(prefix="file", extension=None):
    """
    Build a filename that is unique per call from a timestamp and a random suffix.

    The result has the form ``{prefix}_{YYYYmmdd}_{HHMMSS}_{8 hex chars}`` followed
    by ``.{extension}`` when a (truthy) extension is given.

    Parameters:
    -----------
    prefix : str, optional
        Leading part of the filename.
    extension : str, optional
        File extension without the leading dot; omitted when falsy.

    Returns:
    --------
    str : The generated filename
    """
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    # The first 8 hex digits of a random UUID disambiguate calls within the same second.
    short_id = uuid.uuid4().hex[:8]
    stem = f"{prefix}_{stamp}_{short_id}"
    return f"{stem}.{extension}" if extension else stem

fdtdx.core.plotting.debug.debug_plot_2d(array, cmap='viridis', show_values=False, tmp_dir='outputs/tmp/debug', filename=None)

Creates a debug visualization of a 2D array and saves it to disk.

This function is useful for debugging array values during development and testing. It creates a heatmap visualization with optional value annotations and automatically saves it to a specified directory.

Parameters:

Name Type Description Default
array ndarray | Array

The 2D array to visualize. Can be either a numpy array or JAX array.

required
cmap str

The matplotlib colormap to use for the visualization. Defaults to "viridis".

'viridis'
show_values bool

If True, overlays the numerical values on each cell. Defaults to False.

False
tmp_dir str | Path

Directory where the plot will be saved. Will be created if it doesn't exist. Defaults to "outputs/tmp/debug".

'outputs/tmp/debug'
filename str | None

Name for the output file. If None, generates a unique name using timestamp. The .png extension will be added automatically.

None
The resulting plot includes
  • A heatmap visualization of the array values
  • A colorbar showing the value scale
  • Grid lines for better readability
  • Axis labels indicating array dimensions
  • Optional numerical value annotations in each cell
Source code in src/fdtdx/core/plotting/debug.py
def debug_plot_2d(
    array: np.ndarray | jax.Array,
    cmap: str = "viridis",
    show_values: bool = False,
    tmp_dir: str | Path = "outputs/tmp/debug",
    filename: str | None = None,
) -> None:
    """Creates a debug visualization of a 2D array and saves it to disk.

    This function is useful for debugging array values during development and testing.
    It creates a heatmap visualization with optional value annotations and automatically
    saves it to a specified directory. The first array axis is drawn horizontally and
    the second vertically, so ``array[i, j]`` appears at plot coordinates ``(i, j)``.
    The figure is closed after saving.

    Args:
        array: The 2D array to visualize. Can be either a numpy array or JAX array.
        cmap: The matplotlib colormap to use for the visualization. Defaults to "viridis".
        show_values: If True, overlays the numerical values on each cell. Defaults to False.
        tmp_dir: Directory where the plot will be saved. Will be created if it doesn't exist.
            Defaults to "outputs/tmp/debug".
        filename: Name for the output file. If None, generates a unique name using timestamp.
            If the given name has no extension, ".png" is appended.

    The resulting plot includes:
        - A heatmap visualization of the array values
        - A colorbar showing the value scale
        - Grid lines for better readability
        - Axis labels indicating array dimensions
        - Optional numerical value annotations in each cell
    """
    if not isinstance(array, np.ndarray):
        array = np.asarray(array)

    if filename is None:
        filename = generate_unique_filename("debug", "png")
    elif not Path(filename).suffix:
        filename = f"{filename}.png"

    out_dir = Path(tmp_dir)
    # The docstring promises the directory is created; savefig would otherwise fail.
    out_dir.mkdir(parents=True, exist_ok=True)

    fig = plt.figure(figsize=(10, 8))
    try:
        # Transposed with origin="lower": axis 0 runs along x, axis 1 along y.
        im = plt.imshow(
            array.T,
            cmap=cmap,
            origin="lower",
            aspect="equal",
        )

        plt.colorbar(im)
        plt.xlabel("First Array axis (x)")
        plt.ylabel("Second Array axis (y)")

        if show_values:
            for i in range(array.shape[0]):
                for j in range(array.shape[1]):
                    text_color = "white" if im.norm(array[i, j]) > 0.5 else "black"  # type: ignore
                    # array[i, j] is drawn at (x=i, y=j) because of the transpose above.
                    plt.text(i, j, f"{array[i, j]:.2f}", ha="center", va="center", color=text_color)

        plt.grid(True)
        plt.savefig(out_dir / filename, dpi=400, bbox_inches="tight")
    finally:
        # Release the figure so repeated debug calls do not accumulate open figures.
        plt.close(fig)

fdtdx.core.plotting.colors

Color constants for visualization and plotting.

This module provides a collection of predefined RGB color tuples normalized to the range [0,1]. Colors are organized into categories: primary/bright colors, grayscale, and earth tones. These colors are designed to provide a consistent and visually appealing palette for plotting and visualization tasks throughout the FDTDX framework.

Each color is defined as a tuple of (red, green, blue) values normalized to [0,1]. The normalization is done by dividing 8-bit RGB values (0-255) by 255.

Categories
  • Bright and primary colors: Vibrant colors for emphasis and contrast
  • Grayscale colors: Various shades of gray for backgrounds and subtle elements
  • Earth tones: Natural, warm colors for material representations
Example

import matplotlib.pyplot as plt
plt.plot([0, 1], [0, 1], color=GREEN)  # Plot a line in vibrant green
plt.fill_between([0, 1], [0, 1], color=LIGHT_BLUE, alpha=0.3)  # Fill with pale blue