import einops
import matplotlib
import matplotlib.colors as mcolors
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F

import wandb
from adversarial_superposition.constants import BASE_COLOURS
from adversarial_superposition.toy_models.utils.utils import plot_features_in_2d


def activations(model, features, instance_idx: int):
    """Compute the hidden-layer activations of one model instance.

    The last axis of ``features`` is contracted against the last axis of
    ``model.W[instance_idx]`` (laid out as ``hidden x features``).

    Args:
        model: Object exposing ``W`` (indexed by instance on its first axis),
            the flags ``bias`` and ``privileged``, and ``b_enc`` when ``bias``
            is truthy.
        features: Input whose last axis is the feature axis; any leading axes
            are carried through to the result.
        instance_idx: Index of the model instance whose weights are used.

    Returns:
        The hidden activations, with ``b_enc[instance_idx]`` added when
        ``model.bias`` is truthy and a ReLU applied when ``model.privileged``
        is truthy. A leading axis of size one is squeezed away.
    """
    hidden = einops.einsum(
        features,
        model.W[instance_idx, ...],
        "... features, hidden features -> ... hidden",
    )

    if model.bias:
        hidden = hidden + model.b_enc[instance_idx, ...]

    # "Privileged" models apply a ReLU in the hidden basis; others stay linear.
    if model.privileged:
        hidden = F.relu(hidden)

    return hidden.squeeze(0)


def classify(model, activations, instance_idx: int):
    """Project hidden-layer activations through one instance's output layer.

    Args:
        model: Object exposing ``W_projection`` (indexed by instance on its
            first axis), the flags ``projection_bias`` and ``privileged_out``,
            and ``b_projection`` when ``projection_bias`` is truthy.
        activations: Hidden activations whose last axis is contracted against
            the last axis of ``model.W_projection[instance_idx]``.
        instance_idx: Index of the model instance whose projection is used.

    Returns:
        One score per class along the last axis. When ``model.privileged_out``
        is truthy, a ReLU is applied to non-empty results.
    """
    projection = model.W_projection

    # The inputs may live on another device than the weights; move them over.
    hidden = activations.to(projection.device)

    scores = einops.einsum(
        hidden,
        projection[instance_idx, ...],
        "... features, n_classes features -> ... n_classes",
    )

    if model.projection_bias:
        scores = scores + model.b_projection[instance_idx, ...]

    # Leave the scores untouched for linear outputs and for empty results.
    if not model.privileged_out or scores.numel() == 0:
        return scores
    return F.relu(scores)


def plot_decision_boundary(
    model,
    acts,
    batch_labels,
    i_idx,
    log_to_wandb=True,
    tag="decision_boundary",
    additional_points=None,
    graph_axes=(-1.0, 1.0, 1.0, 1.0),
    ax: plt.Axes | None = None,
    title: str | None = None,
    create_legend: bool = True,
    color_palette: list | None = None,  # Optional palette override
):
    """Plot the decision boundary of a classification model in 2D activation space.

    This function visualizes how a model separates different classes in its 2D activation space by:
    1. Creating a fine mesh grid over the activation space
    2. Getting model predictions for each point in the grid
    3. Plotting the decision boundaries between classes as filled contours
    4. Optionally overlaying additional points in the activation space

    Args:
        model: The classification model to analyze
        acts: Tensor of shape (batch_size, 2) containing 2D activation values - used for grid boundaries
        batch_labels: Tensor of shape (batch_size,) containing true class labels
        i_idx: Integer index specifying which instance of the model to analyze
        log_to_wandb (bool, optional): Whether to log the plot to Weights & Biases. Defaults to True.
        tag (str, optional): Tag/name for the generated plot. Defaults to "decision_boundary".
        additional_points: Tensor or numpy array of shape (n_points, 2) containing additional points
                          to plot in activation space. Defaults to None.
        graph_axes (tuple, optional): (x_min, x_max, y_min, y_max) for the grid boundaries. If None, determined from acts.
        ax (matplotlib.axes.Axes, optional): The axes to plot on. If None, a new figure and axes are created. Defaults to None.
        title (str, optional): Title for the plot. If None, a default title is used.
        create_legend (bool, optional): If False and ax is provided, returns legend handles/labels instead of creating legend. Defaults to True.
        color_palette (list, optional): List of hex color codes. If None, uses tab10. Defaults to None.

    Returns:
        matplotlib.axes.Axes or tuple(matplotlib.axes.Axes, list, list):
            The axes object containing the plot. If ax is provided and create_legend is False,
            returns (ax, legend_handles, legend_labels).

    The resulting plot shows:
    - Colored regions indicating which class the model predicts for each point in space
    - A colorbar mapping colors to class indices
    - Any additional points highlighted in the activation space (if provided)

    The plot is either logged to Weights & Biases or saved locally based on log_to_wandb and whether an ax is provided.
    """
    # --- Initial Checks ---
    if acts is None or acts.numel() == 0 or acts.shape[0] < 2:
        print(
            "Warning: 'acts' tensor is empty or too small. Cannot plot decision boundary."
        )
        if ax is None:
            ax = plt.subplots(figsize=(8, 8))[1]
        ax.text(0.5, 0.5, "Insufficient data for boundary", ha="center", va="center")
        if not create_figure and not create_legend:
            return ax, [], []
        else:
            return ax

    if batch_labels is None or batch_labels.numel() == 0:
        print(
            "Warning: 'batch_labels' tensor is empty. Cannot determine classes for boundary."
        )
        if ax is None:
            ax = plt.subplots(figsize=(8, 8))[1]
        ax.text(0.5, 0.5, "Missing labels for boundary", ha="center", va="center")
        if not create_figure and not create_legend:
            return ax, [], []
        else:
            return ax
    # --------------------

    create_figure = ax is None

    if create_figure:
        fig, ax = plt.subplots(figsize=(10, 8))
    else:
        fig = ax.figure  # Get the figure from the axes

    # Step 1: Create a mesh grid in the 2D activation space
    if not graph_axes:
        # Calculate limits from acts
        try:
            # Check for NaNs/Infs in acts first
            if torch.isnan(acts).any() or torch.isinf(acts).any():
                print(
                    "Warning: acts contains NaN/Inf values. Using default axes limits for decision boundary."
                )
                x_min, x_max, y_min, y_max = -2.0, 2.0, -2.0, 2.0  # Default limits
            else:
                x_min = acts[:, 0].min().item() - 0.5
                x_max = acts[:, 0].max().item() + 0.5
                y_min = acts[:, 1].min().item() - 0.5
                y_max = acts[:, 1].max().item() + 0.5

                # Validate calculated limits
                if (
                    not (
                        np.isfinite(x_min)
                        and np.isfinite(x_max)
                        and np.isfinite(y_min)
                        and np.isfinite(y_max)
                    )
                    or x_min >= x_max
                    or y_min >= y_max
                ):
                    print(
                        f"Warning: Invalid calculated axes limits (xmin={x_min}, xmax={x_max}, ymin={y_min}, ymax={y_max}). Using default limits [-2, 2]."
                    )
                    x_min, x_max, y_min, y_max = -2.0, 2.0, -2.0, 2.0  # Default limits
        except Exception as e:
            print(
                f"Warning: Error calculating axes limits from acts: {e}. Using default limits [-2, 2]."
            )
            x_min, x_max, y_min, y_max = -2.0, 2.0, -2.0, 2.0  # Default limits
    else:
        x_min, x_max, y_min, y_max = graph_axes

        # Also validate graph_axes if provided
        if (
            not (
                np.isfinite(x_min)
                and np.isfinite(x_max)
                and np.isfinite(y_min)
                and np.isfinite(y_max)
            )
            or x_min >= x_max
            or y_min >= y_max
        ):
            print(
                f"Warning: Provided graph_axes limits are invalid (xmin={x_min}, xmax={x_max}, ymin={y_min}, ymax={y_max}). Using default limits [-2, 2]."
            )
            x_min, x_max, y_min, y_max = -2.0, 2.0, -2.0, 2.0  # Default limits

    h = 0.01  # grid step size
    print(
        f"[Debug plot_decision_boundary] Meshgrid limits: x=[{x_min:.2f}, {x_max:.2f}], y=[{y_min:.2f}, {y_max:.2f}]"
    )
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    print(
        f"[Debug plot_decision_boundary] Meshgrid shape: xx={xx.shape}, yy={yy.shape}"
    )

    # Step 2: Create all grid points and convert to torch tensor
    grid_points_np = np.c_[xx.ravel(), yy.ravel()]
    grid_points = torch.tensor(grid_points_np, dtype=torch.float32)  # Convert to tensor

    # Move to CPU explicitly to avoid MPS backend issues
    grid_points = grid_points.cpu()

    if grid_points.numel() == 0:
        print("Warning: Grid points tensor is empty. Cannot calculate classifications.")
        # Handle empty grid case gracefully
        if ax is None:
            ax = plt.subplots(figsize=(8, 8))[1]
        ax.text(0.5, 0.5, "Cannot generate boundary grid", ha="center", va="center")
        if not create_figure and not create_legend:
            return ax, [], []
        else:
            return ax

    # Step 3: Get logits for each grid point
    with torch.no_grad():  # Ensure no gradients are computed
        logits = classify(model, grid_points, i_idx)

    # Convert logits to class predictions
    Z = torch.argmax(logits, dim=1).cpu().numpy()
    print(f"[Debug plot_decision_boundary] Z shape before reshape: {Z.shape}")
    # Z += 6 # Removed this mysterious shift - assuming classes are 0-indexed
    # Reshape Z only if it's not empty and grid is valid
    if Z.size > 0 and xx.size > 0 and yy.size > 0:
        Z = Z.reshape(xx.shape)
    else:
        print("Warning: Cannot reshape Z, either Z or grid was empty.")
        # Handle empty Z/grid case gracefully
        if ax is None:
            ax = plt.subplots(figsize=(8, 8))[1]
        ax.text(0.5, 0.5, "Cannot classify grid points", ha="center", va="center")
        if not create_figure and not create_legend:
            return ax, [], []
        else:
            return ax

    # Step 4: Plot decision boundary

    # Create colormaps
    n_classes = len(torch.unique(batch_labels))

    # Use provided palette or default to tab10
    if color_palette is None:
        cmap = plt.get_cmap("tab10")
        eff_palette = [
            matplotlib.colors.rgb2hex(cmap(i)) for i in range(min(n_classes, 10))
        ]
        # Cycle if more than 10 classes needed
        if n_classes > 10:
            eff_palette = [eff_palette[i % 10] for i in range(n_classes)]
    else:
        # Use provided palette, cycling if necessary
        eff_palette = [color_palette[i % len(color_palette)] for i in range(n_classes)]

    cmap_bold = matplotlib.colors.ListedColormap(eff_palette)

    # Plot the decision boundary contours
    contour = ax.contourf(
        xx, yy, Z, alpha=0.4, cmap=cmap_bold, levels=np.arange(n_classes + 1) - 0.5
    )

    # Create legend elements for the classes
    class_handles = []
    class_labels = []
    for class_idx in range(n_classes):
        # Use actual class index, assuming they are 0 to n_classes-1
        color_idx = class_idx % len(eff_palette)
        patch = plt.Rectangle(
            (0, 0), 1, 1, color=eff_palette[color_idx], alpha=0.4
        )  # Match contour alpha
        class_handles.append(patch)
        class_labels.append(f"Class {class_idx + 1}")

    # Plot additional points if provided (but don't add to legend handles here)
    point_handles = []
    point_labels = []
    if additional_points is not None:
        if isinstance(additional_points, torch.Tensor):
            additional_points = additional_points.detach().cpu().numpy()
        if additional_points.ndim == 1:
            additional_points = additional_points.reshape(1, -1)

        # Plot points, maybe with a default label
        points_scatter = ax.scatter(
            additional_points[:, 0],
            additional_points[:, 1],
            c="red",
            marker="*",
            s=200,
            edgecolors="black",
            linewidths=1.5,
            zorder=5,
        )

        # Only add a single legend entry for all additional points if requested
        # point_handles.append(points_scatter)
        # point_labels.append('Additional Points') # Or pass a specific label

    # Combine handles for legend creation if requested
    all_handles = class_handles + point_handles
    all_labels = class_labels + point_labels

    # Create legend only if requested and if we are creating the figure OR if ax was provided
    if create_legend and all_handles:
        legend_title = "Classes" if not point_handles else "Classes & Points"
        ax.legend(
            handles=all_handles,
            labels=all_labels,
            title=legend_title,
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
        )

    ax.set_xlim(xx.min(), xx.max())
    ax.set_ylim(yy.min(), yy.max())
    plot_title = title if title else f"Decision Boundary (Instance {i_idx})"
    ax.set_title(plot_title, fontsize=24)  # Increased from 20
    ax.set_xlabel("Dimension 1", fontsize=22)
    ax.set_ylabel("Dimension 2", fontsize=22)
    ax.tick_params(axis="both", which="major", labelsize=16)  # Increased tick labelsize
    ax.set_aspect("equal", adjustable="box")

    # Adjust layout only if a new figure was created
    if create_figure:
        # Adjust layout only if we also created the legend
        if create_legend and all_handles:
            fig.tight_layout(
                rect=[0, 0, 0.85, 1]
            )  # Adjust rect to make space for legend
        else:
            fig.tight_layout()

    # Logging/Saving behavior: Only log if creating a new figure and log_to_wandb is True
    if create_figure and log_to_wandb:
        print(f"Logging {tag} figure")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
        plt.close(fig)  # Close the figure after logging
    elif create_figure:  # If created figure but not logging, show it
        plt.show()
        plt.close(fig)

    # Return behavior depends on whether ax was provided and legend requested
    if not create_figure and not create_legend:
        # Return axis and the handles/labels needed to create legend externally
        return ax, class_handles, class_labels
    else:
        # Otherwise, just return the axis
        return ax


def visualize_feature_vector_by_class(
    feature_vector,
    n_classes,
    log_to_wandb=False,
    tag="feature_vector_diagram",
    overlay_vectors=None,
    ax=None,
    title=None,
    show_class_label: bool = True,  # New flag
    color_palette=None,  # Optional color palette
    class_feature_deltas: np.ndarray | None = None,  # For displaying deltas on bars
    show_individual_feature_values: bool = True,  # New: Control individual value labels
    y_margin_factor: float = 0.25,  # New: Control vertical margin size
):
    """
    Create a diagram-style visualization of a single feature vector, showing class contributions.

    Args:
        feature_vector (torch.Tensor or np.ndarray): Shape (n_features,)
        n_classes (int): Number of classes.
        log_to_wandb (bool, optional): Log plot to W&B. Defaults to False.
        tag (str, optional): Tag for W&B logging. Defaults to "feature_vector_diagram".
        overlay_vectors (torch.Tensor or np.ndarray, optional): Shape (n_vectors, 2) [index, magnitude].
        ax (matplotlib.axes.Axes, optional): Axes to plot on. If None, creates new figure/axes. Defaults to None.
        title (str, optional): Plot title. Defaults to a generic title.
        show_class_label (bool, optional): Whether to show the 'Class X / Sum: Y' text label. Defaults to True.
        color_palette (list, optional): List of colors to use for classes. Defaults to BASE_COLOURS.
        class_feature_deltas (np.ndarray, optional): Array of shape (n_classes,) containing sum of deltas for each class.
                                                   If provided, these are displayed near class bars. Defaults to None.
        show_individual_feature_values (bool, optional): Whether to show fine-grained labels for each feature value
                                                       within a class bar (if features_per_class is small). Defaults to True.
        y_margin_factor (float, optional): Factor to control vertical margin size. Defaults to 0.25.

    Returns:
        matplotlib.axes.Axes: The axes object containing the plot.

    Assumes features are grouped by class, equal features per class.
    """
    create_figure = ax is None

    # --- Ensure Input is NumPy ---
    if isinstance(feature_vector, torch.Tensor):
        feature_vector_np = feature_vector.detach().cpu().numpy()
    elif isinstance(feature_vector, np.ndarray):
        feature_vector_np = feature_vector
    else:
        raise TypeError("feature_vector must be a torch.Tensor or np.ndarray")

    if overlay_vectors is not None:
        if isinstance(overlay_vectors, torch.Tensor):
            overlay_vectors_np = overlay_vectors.detach().cpu().numpy()
        elif isinstance(overlay_vectors, np.ndarray):
            overlay_vectors_np = overlay_vectors
        else:
            print("Warning: overlay_vectors type not recognized, ignoring.")
            overlay_vectors_np = None
    else:
        overlay_vectors_np = None
    # ----------------------------

    feature_vector_np = feature_vector_np.flatten()
    n_features = len(feature_vector_np)

    if n_features == 0:
        print("Warning: feature_vector is empty. Cannot plot.")
        if create_figure:
            plt.close(fig)  # Close if we created one
        return ax  # Return the (potentially empty) axes

    if n_classes <= 0:
        print("Warning: n_classes must be positive. Cannot plot.")
        if create_figure:
            plt.close(fig)
        return ax

    if n_features % n_classes != 0:
        print(
            f"Warning: n_features ({n_features}) is not divisible by n_classes ({n_classes})"
        )

    features_per_class = n_features // n_classes

    if create_figure:
        fig_width = max(8, n_features * 0.5)
        fig, ax = plt.subplots(figsize=(fig_width, 4))
    else:
        fig = ax.figure

    # Use provided palette or default to BASE_COLOURS
    if color_palette is None:
        color_palette = BASE_COLOURS
    # Ensure we have enough colors, cycle if necessary
    class_colors = [color_palette[i % len(color_palette)] for i in range(n_classes)]

    max_abs_sum = 0  # Use max absolute sum for centering considerations
    min_val = np.inf
    max_val = -np.inf
    vector_max_abs_mag = 0

    # Calculate scale bounds first using NumPy array
    for i in range(n_classes):
        start_idx = i * features_per_class
        end_idx = min((i + 1) * features_per_class, n_features)
        if start_idx >= n_features:
            break
        class_values_np = feature_vector_np[start_idx:end_idx]
        if class_values_np.size > 0:
            class_sum = np.sum(class_values_np)
            max_abs_sum = max(max_abs_sum, abs(class_sum))
            min_val = min(min_val, np.min(class_values_np))
            max_val = max(max_val, np.max(class_values_np))

    # Check overlay vectors for max magnitude - Keep for y-limit calculation for now
    if (
        overlay_vectors_np is not None
        and overlay_vectors_np.ndim == 2
        and overlay_vectors_np.shape[1] == 2
        and overlay_vectors_np.size > 0
    ):
        vector_magnitudes = overlay_vectors_np[:, 1]
        vector_max_abs_mag = max(vector_max_abs_mag, np.max(np.abs(vector_magnitudes)))
        # Also consider the final position after overlay for y-limits
        overlay_indices = overlay_vectors_np[:, 0].astype(int)
        valid_indices = (overlay_indices >= 0) & (overlay_indices < n_features)
        if np.any(valid_indices):
            original_vals = feature_vector_np[overlay_indices[valid_indices]]
            final_vals = original_vals + vector_magnitudes[valid_indices]
            min_val = min(min_val, np.min(final_vals))
            max_val = max(max_val, np.max(final_vals))

    # Check if min/max are still infinite (e.g., empty feature vector)
    if min_val == np.inf or max_val == -np.inf:
        min_val, max_val = -1.0, 1.0  # Default limits
        print(
            "Warning: Could not determine data range for y-axis. Using default [-1, 1]."
        )

    # Plotting loop
    for i in range(n_classes):
        start_idx = i * features_per_class
        end_idx = min((i + 1) * features_per_class, n_features)
        if start_idx >= n_features:
            break

        class_values_np = feature_vector_np[start_idx:end_idx]
        x_positions = np.arange(start_idx, end_idx)

        if class_values_np.size == 0:
            continue

        class_sum = np.sum(class_values_np)

        x_min_rect, x_max_rect = start_idx - 0.5, end_idx - 0.5
        y_margin_for_text = max(
            0.1, max_abs_sum * y_margin_factor
        )  # Use y_margin_factor here

        zero_bar_thickness_factor = 0.20
        zero_bar_half_height = (y_margin_for_text * zero_bar_thickness_factor) / 2.0

        # Define bar extents (y_min_bar, y_max_bar) for the actual colored rectangle
        if abs(class_sum) < 1e-5:  # Threshold for "zero" sum - Center bar around 0
            y_min_bar = -zero_bar_half_height
            y_max_bar = zero_bar_half_height
        elif class_sum > 0:  # Strictly positive sum
            y_min_bar = -zero_bar_half_height
            y_max_bar = class_sum
        else:  # class_sum < 0, strictly negative sum
            y_min_bar = class_sum
            y_max_bar = zero_bar_half_height

        rect_color_idx = i % len(class_colors)
        rect = plt.Rectangle(
            (x_min_rect, y_min_bar),
            x_max_rect - x_min_rect,
            y_max_bar - y_min_bar,
            fill=True,
            color=class_colors[rect_color_idx],
            alpha=0.6,
            zorder=1,
            linewidth=1.5,
            edgecolor=class_colors[rect_color_idx],
        )
        ax.add_patch(rect)

        # --- Text Label Positioning ---
        sum_label_offset = (
            y_margin_for_text * 0.15
        )  # Offset for sum label from bar edge
        delta_label_stack_offset = (
            y_margin_for_text * 0.20
        )  # Additional offset for delta if sum label exists

        primary_label_y_pos = 0
        primary_va_align = "bottom"

        # Determine base position for the primary label (either sum or delta if sum is hidden)
        if abs(class_sum) < 1e-5:  # Centered zero bar
            primary_label_y_pos = y_max_bar + sum_label_offset
            primary_va_align = "bottom"
        elif class_sum > 0:  # Positive sum, label above bar's top (class_sum)
            primary_label_y_pos = y_max_bar + sum_label_offset
            primary_va_align = "bottom"
        else:  # Negative sum, label below bar's bottom (class_sum)
            primary_label_y_pos = y_min_bar - sum_label_offset
            primary_va_align = "top"

        # Conditionally add class/sum label if requested AND if deltas are not meant to replace it
        if show_class_label:
            sum_text_obj = ax.text(
                (x_min_rect + x_max_rect) / 2,
                primary_label_y_pos,
                f"Class {i + 1}\\nSum: {class_sum:.2f}",
                horizontalalignment="center",
                verticalalignment=primary_va_align,
                fontsize=16,
                fontweight="bold",
                color="black",
                zorder=5,
            )

        # Display class feature deltas if provided
        if class_feature_deltas is not None and i < len(class_feature_deltas):
            delta_val = class_feature_deltas[i]
            delta_text_y = primary_label_y_pos  # Default to sum label position
            va_delta = primary_va_align

            if (
                show_class_label
            ):  # If sum label is also shown, stack delta on top/bottom of it
                if (
                    primary_va_align == "bottom"
                ):  # sum label is above bar or at bottom of text block
                    delta_text_y = primary_label_y_pos + delta_label_stack_offset
                else:  # sum label is below bar or at top of text block
                    delta_text_y = primary_label_y_pos - delta_label_stack_offset
            # If show_class_label is False, delta_text_y and va_delta are already set to primary positions

            delta_text_obj = ax.text(
                (x_min_rect + x_max_rect) / 2,
                delta_text_y,
                f"{delta_val:+.2f}",
                horizontalalignment="center",
                verticalalignment=va_delta,
                fontsize=23,
                fontweight="normal",
                color="black",
                zorder=6,
            )  # Increased fontsize from 22 to 26

        scatter_color_idx = i % len(class_colors)
        ax.scatter(
            x_positions,
            class_values_np,
            color=class_colors[scatter_color_idx],
            s=60,
            zorder=3,
            edgecolors="black",
            linewidth=0.5,
            alpha=0.85,
        )

        if show_individual_feature_values and features_per_class <= 8:
            for x, y in zip(x_positions, class_values_np):
                ax.text(
                    x,
                    y + (0.05 if y >= 0 else -0.05),
                    f"{y:.2f}",
                    horizontalalignment="center",
                    verticalalignment="bottom" if y >= 0 else "top",
                    fontsize=8,
                    zorder=4,
                )

    # --- Customization ---
    ax.grid(True, linestyle="--", alpha=0.5, zorder=0)
    ax.set_xticks(np.arange(n_features))
    ax.set_xticklabels(
        [f"{i}" for i in range(n_features)], fontsize=16
    )  # Increased tick label size
    ax.tick_params(
        axis="y", which="major", labelsize=16
    )  # Increased y-axis tick labelsize
    ax.set_xlabel("Feature Index", fontsize=28)  # Increased x-axis label fontsize
    ax.set_ylabel("Feature Value", fontsize=28)  # Increased y-axis label fontsize
    plot_title = title if title else "Feature Vector Values by Class"
    ax.set_title(plot_title, fontsize=34)  # Increased from 20

    ax.axhline(y=0, color="black", linestyle="-", linewidth=0.7, alpha=0.5, zorder=0)

    # Set y-axis limits dynamically based on content
    y_range = max_val - min_val
    y_margin = max(y_range * y_margin_factor, 0.1)  # Use y_margin_factor here

    # Determine min/max extent needed for text labels relative to class_sum
    max_text_y = -np.inf
    min_text_y = np.inf
    text_total_padding_factor = (
        0.5  # Estimate of y_margin_for_text factor needed for all text
    )

    for i in range(n_classes):
        start_idx = i * features_per_class
        end_idx = min((i + 1) * features_per_class, n_features)
        if start_idx >= n_features:
            break
        current_class_sum = np.sum(feature_vector_np[start_idx:end_idx])

        # Estimate positive extent for labels
        est_pos_label_extent = (
            current_class_sum + y_margin_for_text * text_total_padding_factor
        )
        max_text_y = max(max_text_y, est_pos_label_extent)
        # Estimate negative extent for labels
        est_neg_label_extent = (
            current_class_sum - y_margin_for_text * text_total_padding_factor
        )
        min_text_y = min(min_text_y, est_neg_label_extent)

    # If no classes, use default text extents to avoid issues with -inf/inf
    if n_classes == 0:
        max_text_y = 0.5
        min_text_y = -0.5

    # Determine bounds based on data points AND text label extents
    # Max of (scatter_max, text_max_y_estimate) + some_margin
    # Min of (scatter_min, text_min_y_estimate) - some_margin
    upper_bound = max(max_val, max_text_y) + y_margin
    lower_bound = min(min_val, min_text_y) - y_margin

    # Ensure the range is not collapsed
    min_range = 0.2  # Minimum y-axis range
    current_range = upper_bound - lower_bound
    if current_range < min_range:
        center = (upper_bound + lower_bound) / 2
        upper_bound = center + min_range / 2
        lower_bound = center - min_range / 2

    ax.set_ylim(lower_bound, upper_bound)

    # --- Layout & Output ---
    if create_figure:
        plt.tight_layout()

    if create_figure and log_to_wandb:
        print(f"Logging {tag} figure")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
    elif create_figure:
        plt.show()
        plt.close(fig)

    return ax


def _feature_matrix_overlay_array(overlay):
    """Convert one overlay to a non-empty (k, 2) NumPy array; return None if it is unusable."""
    if overlay is None:
        return None
    if isinstance(overlay, torch.Tensor):
        overlay = overlay.detach().cpu().numpy()
    else:
        overlay = np.asarray(overlay)
    if overlay.ndim == 2 and overlay.shape[1] == 2 and overlay.size > 0:
        return overlay
    return None


def _feature_matrix_overlays(overlay_vectors, n_samples, n_samples_to_plot):
    """Expand ``overlay_vectors`` into one validated overlay (or None) per plotted sample."""
    overlays = [None] * n_samples_to_plot
    if overlay_vectors is None:
        return overlays

    if isinstance(overlay_vectors, (torch.Tensor, np.ndarray)):
        if overlay_vectors.ndim == 2 and overlay_vectors.shape[1] == 2:
            # A single (k, 2) set of vectors is drawn on every subplot.
            shared = _feature_matrix_overlay_array(overlay_vectors)
            return [shared] * n_samples_to_plot
        if (
            overlay_vectors.ndim == 3
            and overlay_vectors.shape[0] == n_samples
            and overlay_vectors.shape[2] == 2
        ):
            # Already one (k, 2) slab per sample.
            return [
                _feature_matrix_overlay_array(overlay_vectors[idx])
                for idx in range(n_samples_to_plot)
            ]
        print("Warning: overlay_vectors has unexpected shape. Ignoring.")
        return overlays

    if isinstance(overlay_vectors, (list, tuple)):
        # Sequence of per-sample overlays; missing or invalid entries stay None.
        for idx, item in enumerate(overlay_vectors[:n_samples_to_plot]):
            overlays[idx] = _feature_matrix_overlay_array(item)

    return overlays


def visualize_feature_matrix_by_class(
    feature_matrix,
    n_classes,
    max_samples=5,
    log_to_wandb=False,
    tag="feature_matrix_diagram",
    overlay_vectors=None,
    show_class_label=True,
    class_feature_deltas=None,
):
    """
    Create a diagram-style visualization of multiple feature vectors, clearly showing which values belong to which class.

    One subplot is drawn per sample (stacked vertically, sharing both axes). For each class a
    translucent bar spans the class's feature indices with height equal to the class sum, and
    the individual feature values are scattered on top of it.

    Args:
        feature_matrix (torch.Tensor or np.ndarray): Multiple feature vectors to visualize, shape (batch_size, n_features)
        n_classes (int): Number of classes in the data. Must be at least 1.
        max_samples (int, optional): Maximum number of samples to visualize. Defaults to 5.
        log_to_wandb (bool, optional): Whether to log the plot to Weights & Biases instead of showing it. Defaults to False.
        tag (str, optional): Tag/name for the generated plot. Defaults to "feature_matrix_diagram".
        overlay_vectors (torch.Tensor, np.ndarray, list or tuple, optional): Vectors to overlay as arrows.
            Either a single (n_vectors, 2) array drawn on every sample, a (batch_size, n_vectors, 2)
            array, or a sequence with one (n_vectors, 2) array (or None) per sample. Each row is
            [x, y] where x is the feature index and y is the vector magnitude (change).
        show_class_label (bool, optional): Whether to annotate each class bar with "Class k / Sum".
            Defaults to True.
        class_feature_deltas (array-like, optional): Per-class delta values printed next to each bar.
            A 1D array is reused for every sample; a 2D array supplies one row per sample.
            Defaults to None (no delta labels).

    Returns:
        None. The figure is logged or shown, then closed.

    Raises:
        ValueError: If ``n_classes`` is less than 1.

    Assumes the features are grouped by class, with equal features per class. When n_features is
    not divisible by n_classes a warning is printed and the trailing remainder features are not drawn.
    """
    if isinstance(feature_matrix, torch.Tensor):
        feature_matrix = feature_matrix.detach().cpu().numpy()

    n_samples, n_features = feature_matrix.shape
    n_samples_to_plot = min(n_samples, max_samples)

    # `<= 0` also covers a negative max_samples, which would otherwise reach plt.subplots.
    if n_samples_to_plot <= 0:
        print("No samples to plot.")
        return

    if n_classes < 1:
        raise ValueError(f"n_classes must be a positive integer, got {n_classes}")

    if n_features % n_classes != 0:
        print(
            f"Warning: n_features ({n_features}) is not divisible by n_classes ({n_classes})"
        )

    features_per_class = n_features // n_classes
    # n_classes * features_per_class <= n_features, so these spans never overrun the vector.
    class_spans = [
        (i, i * features_per_class, (i + 1) * features_per_class)
        for i in range(n_classes)
        if features_per_class > 0
    ]

    # Normalise the optional per-class deltas to a 1D or 2D NumPy array.
    if class_feature_deltas is not None:
        if isinstance(class_feature_deltas, torch.Tensor):
            class_feature_deltas = class_feature_deltas.detach().cpu().numpy()
        else:
            class_feature_deltas = np.asarray(class_feature_deltas, dtype=float)
        if class_feature_deltas.ndim not in (1, 2):
            print("Warning: class_feature_deltas has unexpected shape. Ignoring.")
            class_feature_deltas = None

    # Create figure with shared axes so samples can be compared directly
    fig_width = max(10, n_features * 0.5)
    fig, axes = plt.subplots(
        n_samples_to_plot,
        1,
        figsize=(fig_width, 3.5 * n_samples_to_plot),
        squeeze=False,
        sharex=True,
        sharey=True,
    )

    class_colors = BASE_COLOURS[:n_classes]
    processed_overlays = _feature_matrix_overlays(
        overlay_vectors, n_samples, n_samples_to_plot
    )

    # --- Global extents for consistent y-axis scaling across subplots ---
    global_max_sum = 0
    global_min_value = float("inf")
    global_max_value = float("-inf")
    global_vector_max_abs = 0

    for sample_idx in range(n_samples_to_plot):
        feature_vector = feature_matrix[sample_idx]
        global_min_value = min(global_min_value, np.min(feature_vector))
        global_max_value = max(global_max_value, np.max(feature_vector))

        for _, start_idx, end_idx in class_spans:
            class_sum = np.sum(feature_vector[start_idx:end_idx])
            global_max_sum = max(global_max_sum, abs(class_sum))

        sample_overlay = processed_overlays[sample_idx]
        if sample_overlay is not None:
            global_vector_max_abs = max(
                global_vector_max_abs, np.max(np.abs(sample_overlay[:, 1]))
            )

    y_value_margin = (global_max_value - global_min_value) * 0.15
    y_sum_margin = global_max_sum * 0.25
    y_vec_margin = global_vector_max_abs * 0.2

    # The +/-0.5 floor guarantees a non-degenerate range even for all-zero data.
    upper_bound = max(
        global_max_value + y_value_margin,
        global_max_sum + y_sum_margin,
        (global_max_value + global_vector_max_abs) + y_vec_margin,
        0.5,
    )
    lower_bound = min(
        global_min_value - y_value_margin,
        -global_max_sum - y_sum_margin,
        (global_min_value - global_vector_max_abs) - y_vec_margin,
        -0.5,
    )

    # Label geometry depends only on the global sum, so compute it once.
    y_margin_for_text = max(0.1, global_max_sum * 0.25)
    zero_bar_thickness_factor = 0.20
    zero_bar_half_height = (y_margin_for_text * zero_bar_thickness_factor) / 2.0
    sum_label_offset = y_margin_for_text * 0.15
    delta_label_stack_offset = y_margin_for_text * 0.20

    # --- Plot each sample ---
    for sample_idx in range(n_samples_to_plot):
        ax = axes[sample_idx, 0]
        feature_vector = feature_matrix[sample_idx]
        sample_overlay = processed_overlays[sample_idx]

        # Pick the delta row that applies to this sample (shared 1D, or per-sample 2D).
        if class_feature_deltas is None:
            sample_deltas = None
        elif class_feature_deltas.ndim == 1:
            sample_deltas = class_feature_deltas
        elif sample_idx < len(class_feature_deltas):
            sample_deltas = class_feature_deltas[sample_idx]
        else:
            sample_deltas = None

        for i, start_idx, end_idx in class_spans:
            class_values = feature_vector[start_idx:end_idx]
            x_positions = np.arange(start_idx, end_idx)
            class_sum = np.sum(class_values)
            # Cycle colours so more classes than base colours cannot raise IndexError.
            class_color = class_colors[i % len(class_colors)]

            x_min_rect, x_max_rect = start_idx - 0.5, end_idx - 0.5
            x_center = (x_min_rect + x_max_rect) / 2

            # Bar extents and where the primary label attaches to the bar.
            if abs(class_sum) < 1e-5:  # "Zero" sum: thin bar centred on 0
                y_min_bar = -zero_bar_half_height
                y_max_bar = zero_bar_half_height
                primary_label_y_pos = y_max_bar + sum_label_offset
                primary_va_align = "bottom"
            elif class_sum > 0:  # Positive sum: label above the bar
                y_min_bar = -zero_bar_half_height
                y_max_bar = class_sum
                primary_label_y_pos = y_max_bar + sum_label_offset
                primary_va_align = "bottom"
            else:  # Negative sum: label below the bar
                y_min_bar = class_sum
                y_max_bar = zero_bar_half_height
                primary_label_y_pos = y_min_bar - sum_label_offset
                primary_va_align = "top"

            ax.add_patch(
                plt.Rectangle(
                    (x_min_rect, y_min_bar),
                    x_max_rect - x_min_rect,
                    y_max_bar - y_min_bar,
                    fill=True,
                    color=class_color,
                    alpha=0.6,
                    zorder=1,
                    linewidth=1.5,
                    edgecolor=class_color,
                )
            )

            if show_class_label:
                ax.text(
                    x_center,
                    primary_label_y_pos,
                    f"Class {i + 1}\nSum: {class_sum:.2f}",
                    horizontalalignment="center",
                    verticalalignment=primary_va_align,
                    fontsize=10,
                    fontweight="bold",
                    color="black",
                    zorder=5,
                )

            if sample_deltas is not None and i < len(sample_deltas):
                delta_val = sample_deltas[i]
                delta_text_y = primary_label_y_pos
                if show_class_label:
                    # Stack the delta beyond the sum label, away from the bar.
                    if primary_va_align == "bottom":
                        delta_text_y = primary_label_y_pos + delta_label_stack_offset
                    else:
                        delta_text_y = primary_label_y_pos - delta_label_stack_offset

                ax.text(
                    x_center,
                    delta_text_y,
                    f"{delta_val:+.2f}",
                    horizontalalignment="center",
                    verticalalignment=primary_va_align,
                    fontsize=22,
                    fontweight="normal",
                    color="black",
                    zorder=6,
                )

            ax.scatter(
                x_positions,
                class_values,
                color=class_color,
                s=60,
                zorder=3,
                edgecolors="black",
                linewidth=0.5,
                alpha=0.85,
            )

            # Per-point value labels are only legible for small classes.
            if features_per_class <= 10:
                for x, y in zip(x_positions, class_values):
                    ax.text(
                        x,
                        y + (0.05 if y >= 0 else -0.05),
                        f"{y:.2f}",
                        horizontalalignment="center",
                        verticalalignment="bottom" if y >= 0 else "top",
                        fontsize=8,
                        zorder=4,
                    )

        # Overlay vectors for this sample: arrows start at the feature's current value.
        if sample_overlay is not None:
            positions = sample_overlay[:, 0]
            values = sample_overlay[:, 1]
            original_y_values = feature_vector[positions.astype(int)]

            ax.quiver(
                positions,
                original_y_values,
                np.zeros_like(values),
                values,
                angles="xy",
                scale_units="xy",
                scale=1,
                width=0.006,
                color="darkred",
                alpha=0.9,
                zorder=6,
            )

            for x, y_start, delta_y in zip(positions, original_y_values, values):
                ax.text(
                    x + 0.1,
                    y_start + delta_y,
                    f"{delta_y:.2f}",
                    color="darkred",
                    fontweight="bold",
                    fontsize=8,
                    horizontalalignment="left",
                    verticalalignment=(
                        "center"
                        if abs(delta_y) > 0.1
                        else ("bottom" if delta_y > 0 else "top")
                    ),
                    zorder=7,
                )

        # --- Customization for each subplot ---
        ax.grid(True, linestyle="--", alpha=0.5, zorder=0)
        if sample_idx == n_samples_to_plot - 1:  # Only label bottom x-axis
            ax.set_xlabel("Feature Index", fontsize=10)
        ax.set_ylabel("Value", fontsize=10)
        ax.set_title(f"Sample {sample_idx + 1}", fontsize=13)
        ax.axhline(
            y=0, color="black", linestyle="-", linewidth=0.7, alpha=0.5, zorder=0
        )
        ax.set_ylim(lower_bound, upper_bound)  # Apply consistent y-limits

    # Set shared x-axis ticks and labels once
    axes[-1, 0].set_xticks(np.arange(n_features))
    axes[-1, 0].set_xticklabels([f"{i}" for i in range(n_features)], fontsize=8)

    fig.suptitle(
        f"Feature Values by Class (Top {n_samples_to_plot} Samples)",
        fontsize=14,
        y=1.02,
    )

    plt.tight_layout(rect=[0, 0, 1, 0.98])  # Adjust layout slightly for suptitle

    # Log to wandb or show the plot
    if log_to_wandb:
        print(f"Logging {tag} figure")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
    else:
        plt.show()

    plt.close(fig)


def print_features_by_class_config(features, cfg, max_samples=20, precision=2):
    """
    Print a banner followed by the per-class feature values of each sample.

    Every sample gets its own section; within a section each class is printed on one line as
    ``Class k: [v1  v2 ...]``. Lines with more than 10 values use single-space separators,
    shorter lines use two spaces.

    Args:
        features (torch.Tensor or np.ndarray): Feature vectors to display, shape (batch_size, n_features) or (n_features,)
        cfg: Configuration object; only its ``n_classes`` attribute is read.
        max_samples (int, optional): Maximum number of samples to display. Defaults to 20.
        precision (int, optional): Number of decimal places to display. Defaults to 2.

    Assumes the features are grouped by class, with equal features per class. If n_features is
    not divisible by n_classes a warning is printed and the trailing remainder is not shown.
    """
    n_classes = cfg.n_classes

    # Work on a 2D NumPy array regardless of how the input arrived.
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()
    if features.ndim == 1:
        features = features.reshape(1, -1)

    total_rows, n_features = features.shape
    rows_to_show = min(total_rows, max_samples)

    features_per_class, leftover = divmod(n_features, n_classes)
    if leftover != 0:
        print(
            f"Warning: n_features ({n_features}) is not divisible by n_classes ({n_classes})"
        )

    heavy_rule = "=" * 80
    light_rule = "-" * 80
    value_spec = f".{precision}f"

    # Banner with config info
    print("\n" + heavy_rule)
    print(
        f"Feature Values by Class (showing {rows_to_show} samples, n_classes={n_classes})"
    )
    print(heavy_rule)

    for row_idx in range(rows_to_show):
        row = features[row_idx]

        print(f"\nSample {row_idx + 1}:")
        print(light_rule)

        for class_number in range(1, n_classes + 1):
            lo = (class_number - 1) * features_per_class
            if lo >= n_features:
                break
            hi = min(class_number * features_per_class, n_features)

            rendered = [format(value, value_spec) for value in row[lo:hi]]
            # Long rows are packed tightly; short rows get extra spacing for readability.
            joiner = " " if len(rendered) > 10 else "  "
            print(f"Class {class_number}: [{joiner.join(rendered)}]")

        print(light_rule)

    print("\n")


def print_feature_by_class_config(features, cfg, max_samples=20, precision=4):
    """
    Print the first sample's feature values to the console, one line per class, without headers.

    This is the compact, single-sample counterpart of ``print_features_by_class_config``: even
    when a batch is passed, only the first row is printed.

    Args:
        features (torch.Tensor or np.ndarray): Feature vector(s) to display, shape (batch_size, n_features) or (n_features,)
        cfg: Configuration object; only its ``n_classes`` attribute is read.
        max_samples (int, optional): Upper bound on the number of samples shown. At most one
            sample is ever printed, so this only matters when it is below 1, in which case
            no sample is printed. Defaults to 20.
        precision (int, optional): Number of decimal places to display. Defaults to 4.

    Assumes the features are grouped by class, with equal features per class. If n_features is
    not divisible by n_classes a warning is printed and the trailing remainder is not shown.
    An empty batch prints only the trailing blank lines instead of raising IndexError.
    """
    n_classes = cfg.n_classes

    # Convert to numpy if it's a torch tensor
    if isinstance(features, torch.Tensor):
        features = features.detach().cpu().numpy()

    # Handle both single vectors and batches
    if features.ndim == 1:
        features = features.reshape(1, -1)

    n_samples, n_features = features.shape
    n_samples = min(n_samples, max_samples)

    features_per_class = n_features // n_classes
    if n_features % n_classes != 0:
        print(
            f"Warning: n_features ({n_features}) is not divisible by n_classes ({n_classes})"
        )

    # Guard against empty batches: indexing features[0] would raise IndexError.
    if n_samples > 0:
        feature_vector = features[0]

        for class_idx in range(n_classes):
            start_idx = class_idx * features_per_class
            end_idx = min((class_idx + 1) * features_per_class, n_features)
            if start_idx >= n_features:
                break

            formatted_values = [
                f"{val:.{precision}f}" for val in feature_vector[start_idx:end_idx]
            ]
            # Long rows are packed tightly; short rows get extra spacing for readability.
            separator = " " if len(formatted_values) > 10 else "  "
            print(f"Class {class_idx + 1}: [{separator.join(formatted_values)}]")

    print("\n")


def desaturate_hex(hex_color, factor=0.8):
    """Return ``hex_color`` with its HSV saturation multiplied by ``factor``, as a hex string.

    A factor below 1 desaturates the colour. The resulting RGB channels are clipped
    to [0, 1] before being converted back to hex.
    """
    hue_sat_val = mcolors.rgb_to_hsv(mcolors.hex2color(hex_color))
    hue_sat_val[1] = hue_sat_val[1] * factor  # index 1 is the saturation channel
    clipped_rgb = np.clip(mcolors.hsv_to_rgb(hue_sat_val), 0, 1)
    return mcolors.rgb2hex(clipped_rgb)


def plot_adversarial_attack_mechanism(
    model,
    original_input=None,
    adversarial_input=None,
    instance_idx=0,
    batch_for_boundary=None,  # Need a batch to determine boundary grid size & axes limits
    batch_labels_for_boundary=None,  # Corresponding labels for coloring boundary
    fig_title="Adversarial Attack Mechanism",  # New default title
    save_path=None,
    log_to_wandb=False,
    tag="adversarial_mechanism_plot",
    save_plot_data_to: str | None = None,  # New: Path to save plot data
):
    """
    Creates a composite plot visualizing the mechanism of an adversarial attack.

    Combines:
    1. Decision boundary in activation space with original/adversarial activations highlighted.
       Also shows the feature directions (weight vectors) from the model's first layer.
    2. Feature breakdown diagram for the original input (if provided).
    3. Feature breakdown diagram for the adversarial input (if provided), overlaying the perturbation.

    Args:
        model: The classification model. Assumed to have model.cfg access.
        original_input (torch.Tensor, optional): The original input vector, shape (n_features,). If None, original input plots are skipped.
        adversarial_input (torch.Tensor, optional): The adversarial input vector, shape (n_features,). If None, adversarial input plots are skipped.
        instance_idx (int): The model instance index to analyze.
        batch_for_boundary (torch.Tensor): Batch of inputs to determine activation space bounds, shape (batch_size, n_features).
        batch_labels_for_boundary (torch.Tensor): Labels for the boundary batch, shape (batch_size,).
        fig_title (str, optional): Overall title for the figure.
        save_path (str, optional): If provided, saves the figure to this path.
        log_to_wandb (bool, optional): Whether to log the plot to Weights & Biases. Defaults to False.
        tag (str, optional): Tag/name for wandb logging. Defaults to "adversarial_mechanism_plot".
        save_plot_data_to (str, optional): If provided, saves all necessary data to reconstruct this plot to the given filepath.

    Returns:
        matplotlib.figure.Figure: The generated Matplotlib figure object.
    """
    # --- Input Validation ---
    if original_input is None and adversarial_input is None:
        raise ValueError(
            "At least one of original_input or adversarial_input must be provided"
        )

    if original_input is not None:
        assert original_input.ndim == 1, "original_input must be a 1D tensor"

    if adversarial_input is not None:
        assert adversarial_input.ndim == 1, "adversarial_input must be a 1D tensor"

    if original_input is not None and adversarial_input is not None:
        assert (
            original_input.shape == adversarial_input.shape
        ), "Inputs must have the same shape"

    assert (
        model.cfg.n_hidden == 2
    ), "This visualization requires a model with n_hidden=2"

    # --- Setup Figure and Grid ---
    fig = plt.figure(figsize=(30, 8))  # Made wider and shorter
    # Create a 5-column grid to achieve different spacings
    gs = gridspec.GridSpec(
        1,
        5,
        width_ratios=[
            1.2,
            0.15,
            1,
            0.25,
            1,
        ],  # Increased middle spacer from 0.08 to 0.15
        wspace=0,  # No spacing between columns, we control it with width_ratios
        left=0.05,
        right=0.78,
        bottom=0.25,
        top=0.85,
    )  # Adjusted for more space around plots

    # Create subplots for feature breakdowns
    ax_boundary = fig.add_subplot(gs[0])  # First column for decision boundary

    if original_input is not None and adversarial_input is not None:
        ax_original_features = fig.add_subplot(gs[2])  # Third column
        ax_adv_features = fig.add_subplot(gs[4])  # Fifth column
    elif original_input is not None:
        ax_original_features = fig.add_subplot(gs[2:])  # Span from third column to end
    elif adversarial_input is not None:
        ax_adv_features = fig.add_subplot(gs[2:])  # Span from third column to end

    # --- Calculations ---
    with torch.no_grad():
        # Calculate activations only for provided inputs
        h_original = None
        h_adversarial = None

        if original_input is not None:
            h_original = activations(
                model, original_input.unsqueeze(0), instance_idx
            ).squeeze()

        if adversarial_input is not None:
            h_adversarial = activations(
                model, adversarial_input.unsqueeze(0), instance_idx
            ).squeeze()

        # Stack only the available activation points
        h_points = []
        if h_original is not None:
            h_points.append(h_original)
        if h_adversarial is not None:
            h_points.append(h_adversarial)

        h_points = torch.stack(h_points).cpu() if h_points else None

        acts_boundary = activations(model, batch_for_boundary, instance_idx).cpu()

        # Calculate feature change only if both inputs are provided
        feature_change = None
        overlay_data = None
        if original_input is not None and adversarial_input is not None:
            feature_change = (adversarial_input - original_input).cpu()
            overlay_indices = torch.arange(
                len(feature_change), device=feature_change.device
            )
            overlay_data = torch.stack([overlay_indices.float(), feature_change], dim=1)

    # --- Define Color Palette ---
    n_classes = model.cfg.n_classes
    # Get standard tab10 colors
    cmap = plt.get_cmap("tab10")
    tab10_colors = [mcolors.rgb2hex(cmap(i)) for i in range(10)]

    # Create the slightly desaturated palette
    base_palette = [
        desaturate_hex(c, 0.75) for c in tab10_colors
    ]  # Adjust 0.75 factor as needed

    # Ensure palette has enough colours, cycle if needed
    color_palette = [base_palette[i % len(base_palette)] for i in range(n_classes)]

    # --- Plotting ---
    # 1. Plot Decision Boundary & Get Class Handles/Labels
    class_handles, class_labels = [], []  # Initialize
    ax_boundary, class_handles, class_labels = plot_decision_boundary(
        model,
        acts_boundary,
        batch_labels_for_boundary,
        instance_idx,
        additional_points=None,
        ax=ax_boundary,
        log_to_wandb=False,
        graph_axes=(
            -5,
            5,
            -5,
            5,
        ),
        title="Decision Boundary, Activations &\nLatent Directions",  # Split title into two lines
        create_legend=False,  # Get handles/labels back
        color_palette=color_palette,  # Pass the consistent palette
    )
    original_xlim = ax_boundary.get_xlim()
    original_ylim = ax_boundary.get_ylim()

    # 2. Overlay Feature Directions (Weight Vectors) on the decision boundary plot
    weights = model.W[instance_idx].cpu()
    n_features = model.cfg.n_features
    features_per_class = n_features // n_classes
    weight_colors = []
    for c_idx in range(n_classes):
        color = color_palette[c_idx % len(color_palette)]
        weight_colors.extend([color] * features_per_class)
    weight_colors.extend(["darkgrey"] * (n_features - len(weight_colors)))
    weight_colors_list = [weight_colors]

    # Increase linewidth/markersize for feature vectors
    plot_features_in_2d(
        W=weights.unsqueeze(0),
        colors=weight_colors_list,
        ax=ax_boundary,  # Plot on the decision boundary axes
        linewidth=2.5,  # Increased from 2.0
        markersize=7,
    )
    ax_boundary.set_xlim(original_xlim)
    ax_boundary.set_ylim(original_ylim)
    # Re-apply axis labels and tick sizes after plot_features_in_2d
    ax_boundary.set_xlabel("Dimension 1", fontsize=22)
    ax_boundary.set_ylabel("Dimension 2", fontsize=22)
    ax_boundary.tick_params(axis="both", which="major", labelsize=16)

    # 3. Plot Activation Points & Get Handles (only for provided inputs)
    activation_handles = []
    activation_labels = []

    if h_original is not None:
        scatter_orig = ax_boundary.scatter(
            h_original[0].item(),
            h_original[1].item(),
            marker="o",
            s=180,
            color="deepskyblue",
            label="Original Activation",
            zorder=10,
            edgecolors="black",
            linewidth=1.5,
        )
        activation_handles.append(scatter_orig)
        activation_labels.append("Original Activation")

    if h_adversarial is not None:
        scatter_adv = ax_boundary.scatter(
            h_adversarial[0].item(),
            h_adversarial[1].item(),
            marker="X",
            s=180,
            color="orangered",
            label="Adversarial Activation",
            zorder=10,
            edgecolors="black",
            linewidth=1.5,
        )
        activation_handles.append(scatter_adv)
        activation_labels.append("Adversarial Activation")

    # 4. Plot Feature Breakdowns (only for provided inputs)
    if original_input is not None and ax_original_features is not None:
        visualize_feature_vector_by_class(
            original_input.cpu(),
            n_classes=model.cfg.n_classes,
            ax=ax_original_features,
            log_to_wandb=False,
            title="Original Input",
            show_class_label=False,
            color_palette=color_palette,
            overlay_vectors=None,
            class_feature_deltas=None,
            show_individual_feature_values=False,  # Do not show individual values on original plot
            y_margin_factor=0.1,  # Reduced margin for more compact plot
        )

    if adversarial_input is not None and ax_adv_features is not None:
        class_deltas_for_adv_plot = None
        if adversarial_input is not None:
            class_deltas_for_adv_plot = (
                (adversarial_input - original_input).cpu().numpy()
            )
        visualize_feature_vector_by_class(
            adversarial_input.cpu(),
            n_classes=model.cfg.n_classes,
            ax=ax_adv_features,
            log_to_wandb=False,
            title="Adversarial Input",
            show_class_label=False,
            color_palette=color_palette,
            overlay_vectors=None,
            class_feature_deltas=class_deltas_for_adv_plot,
            show_individual_feature_values=False,  # Show individual values on adversarial plot (default)
            y_margin_factor=0.1,  # Reduced margin for more compact plot
        )

    # --- Final Touches ---
    # Create two separate legends on the right (only if there are items to include)
    leg1 = fig.legend(
        handles=class_handles,
        labels=class_labels,
        loc="upper left",
        bbox_to_anchor=(0.81, 0.95),  # Moved up from 0.90 to 0.95
        title="Classes",
        fontsize=22,
        title_fontsize=24,
    )  # Increased font sizes

    if activation_handles:
        leg2 = fig.legend(
            handles=activation_handles,
            labels=activation_labels,
            loc="upper left",
            bbox_to_anchor=(0.81, 0.40),  # Moved down from 0.50 to 0.40
            fontsize=18,
            labelspacing=1.2,
        )  # Increased font size from 16 to 18

    # Add more space between titles and plots
    for ax in [ax_boundary, ax_original_features, ax_adv_features]:
        if ax is not None:
            ax.set_title(
                ax.get_title(), pad=40, fontsize=24
            )  # Increased pad from 20 to 30
            # Add more space between x-axis labels and plot
            ax.xaxis.set_label_coords(0.5, -0.15)  # Move x-axis label down

    # Ensure x-axes align for bar charts
    if original_input is not None and adversarial_input is not None:
        ax_original_features.set_ylim(ax_adv_features.get_ylim())

    fig.suptitle(fig_title, fontsize=28, y=0.95)  # Increased main title font size

    # --- Save/Log/Show ---
    if save_path:
        print(f"Saving figure to {save_path}")
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if log_to_wandb:
        print(f"Logging {tag} figure to W&B")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
        plt.close(fig)
    if not save_path and not log_to_wandb:
        plt.show()
        plt.close(fig)

    # --- Save plot data if requested ---
    if save_plot_data_to:
        print(f"Saving plot data to {save_plot_data_to}")
        # Prepare model reconstruction info
        model_reconstruction_info = {
            "cfg": {
                "n_features": model.cfg.n_features,
                "n_hidden": model.cfg.n_hidden,
                "n_classes": model.cfg.n_classes,
                # Removed flags from here, they are direct attributes of the model
            },
            # Store flags directly from the model object
            "bias": model.bias if hasattr(model, "bias") else None,
            "privileged": model.privileged if hasattr(model, "privileged") else None,
            "projection_bias": (
                model.projection_bias if hasattr(model, "projection_bias") else None
            ),
            "privileged_out": (
                model.privileged_out if hasattr(model, "privileged_out") else None
            ),
            # Tensors:
            "W": model.W.detach().cpu().numpy() if hasattr(model, "W") else None,
            "b_enc": (
                model.b_enc.detach().cpu().numpy()
                if hasattr(model, "b_enc") and model.b_enc is not None
                else None
            ),
            "W_projection": (
                model.W_projection.detach().cpu().numpy()
                if hasattr(model, "W_projection")
                else None
            ),
            "b_projection": (
                model.b_projection.detach().cpu().numpy()
                if hasattr(model, "b_projection") and model.b_projection is not None
                else None
            ),
        }

        plot_data = {
            "plot_function_name": "plot_adversarial_attack_mechanism",
            "plot_args": {
                "model_reconstruction_info": model_reconstruction_info,
                "original_input": (
                    original_input.detach().cpu().numpy()
                    if original_input is not None
                    else None
                ),
                "adversarial_input": (
                    adversarial_input.detach().cpu().numpy()
                    if adversarial_input is not None
                    else None
                ),
                "instance_idx": instance_idx,
                "batch_for_boundary": (
                    batch_for_boundary.detach().cpu().numpy()
                    if batch_for_boundary is not None
                    else None
                ),
                "batch_labels_for_boundary": (
                    batch_labels_for_boundary.detach().cpu().numpy()
                    if batch_labels_for_boundary is not None
                    else None
                ),
                "fig_title": fig_title,
                # Note: color_palette is not saved as it's derived dynamically or uses defaults.
                # log_to_wandb, tag, save_path are runtime options, not intrinsic data for plot reconstruction.
            },
        }
        torch.save(plot_data, save_plot_data_to)

    return fig


def plot_adversarial_attack_mechanism_3d(
    model,
    original_input=None,
    adversarial_input=None,
    instance_idx=0,
    batch_for_boundary=None,  # Not read by this function (see docstring)
    batch_labels_for_boundary=None,  # Not read by this function (see docstring)
    fig_title="Adversarial Attack Mechanism (3D)",
    save_path=None,
    log_to_wandb=False,
    tag="adversarial_mechanism_plot_3d",
    save_plot_data_to: str | None = None,
):
    """
    Creates a composite plot visualizing the mechanism of an adversarial attack in 3D space.

    Combines:
    1. The columns of ``model.W[instance_idx]`` drawn as arrows from the origin in 3D,
       coloured by class.
    2. Feature breakdown diagram for the original input (if provided).
    3. Feature breakdown diagram for the adversarial input (if provided). The per-feature
       perturbation (adversarial - original) is overlaid only when both inputs are given.

    When only one of the two inputs is provided, its breakdown panel spans the full
    right-hand column.

    Features are assigned to classes in contiguous groups of ``n_features // n_classes``.
    Features that do not fall into any class group are drawn in "darkgrey".

    Args:
        model: The classification model. Reads ``model.cfg.n_features``, ``model.cfg.n_hidden``,
            ``model.cfg.n_classes`` and ``model.W``.
        original_input (torch.Tensor, optional): The original input vector, shape (n_features,). If None, original input plots are skipped.
        adversarial_input (torch.Tensor, optional): The adversarial input vector, shape (n_features,). If None, adversarial input plots are skipped.
        instance_idx (int): The model instance index to analyze.
        batch_for_boundary: Not used by this function; accepted but ignored.
        batch_labels_for_boundary: Not used by this function; accepted but ignored.
        fig_title (str, optional): Overall title for the figure.
        save_path (str, optional): If provided, saves the figure to this path.
        log_to_wandb (bool, optional): Whether to log the plot to Weights & Biases. Defaults to False.
        tag (str, optional): Tag/name for wandb logging. Defaults to "adversarial_mechanism_plot_3d".
        save_plot_data_to (str, optional): If provided, saves model weights, the inputs and the
            styling parameters used for the 3D panel to this filepath via ``torch.save``.

    Returns:
        matplotlib.figure.Figure: The generated Matplotlib figure object. The figure is
        closed before returning when it was logged to W&B or shown interactively.

    Raises:
        ValueError: If both ``original_input`` and ``adversarial_input`` are None.
        AssertionError: If an input is not 1D, the two inputs differ in shape, or
            ``model.cfg.n_hidden != 3``.
    """
    # Styling values shared between the drawing code and the saved plot data, so the
    # recorded parameters always match what was actually rendered.
    axis_label_fontsize = 27
    axis_labelpad = 12
    vector_linewidth = 7
    arrow_length_ratio = 0.1
    vector_alpha = 0.7

    # --- Input Validation ---
    has_original = original_input is not None
    has_adversarial = adversarial_input is not None

    if not has_original and not has_adversarial:
        raise ValueError(
            "At least one of original_input or adversarial_input must be provided"
        )

    if has_original:
        assert original_input.ndim == 1, "original_input must be a 1D tensor"

    if has_adversarial:
        assert adversarial_input.ndim == 1, "adversarial_input must be a 1D tensor"

    if has_original and has_adversarial:
        assert (
            original_input.shape == adversarial_input.shape
        ), "Inputs must have the same shape"

    assert (
        model.cfg.n_hidden == 3
    ), "This visualization requires a model with n_hidden=3"

    # --- Setup Figure and Grid ---
    fig = plt.figure(figsize=(26, 13))
    # right=0.78 reserves space on the right-hand side for the figure-level legend.
    gs = gridspec.GridSpec(
        2,
        2,
        width_ratios=[1.2, 1],
        height_ratios=[1, 1],
        wspace=0.22,
        hspace=0.5,
        left=0.05,
        right=0.78,
        bottom=0.1,
        top=0.9,
    )

    ax_boundary = fig.add_subplot(gs[:, 0], projection="3d")  # Spans rows 0-1, column 0

    # Create each breakdown axes exactly once. Creating a single-row axes first and a
    # spanning axes afterwards would leave an empty, overlapping panel in the figure.
    ax_original_features = None
    ax_adv_features = None
    if has_original and has_adversarial:
        ax_original_features = fig.add_subplot(gs[0, 1])
        ax_adv_features = fig.add_subplot(gs[1, 1])
    elif has_original:
        ax_original_features = fig.add_subplot(gs[:, 1])
    else:
        ax_adv_features = fig.add_subplot(gs[:, 1])

    # --- Define Color Palette ---
    n_classes = model.cfg.n_classes
    tab10 = plt.get_cmap("tab10")
    # Slightly desaturated tab10, cycled if there are more than 10 classes.
    base_palette = [
        desaturate_hex(mcolors.rgb2hex(tab10(i)), 0.75) for i in range(10)
    ]
    color_palette = [base_palette[i % len(base_palette)] for i in range(n_classes)]

    # --- Plotting ---
    # 1. Plot Feature Vectors
    weight_matrix = model.W[instance_idx].detach().cpu().numpy()
    n_features = model.cfg.n_features
    features_per_class = n_features // n_classes

    # Legend entries: one coloured swatch per class (labels are 1-based).
    class_handles = [
        plt.Rectangle((0, 0), 1, 1, color=class_color, alpha=vector_alpha)
        for class_color in color_palette
    ]
    class_labels = [f"Class {c_idx + 1}" for c_idx in range(n_classes)]

    # One 3-vector per feature (column i of the weight matrix).
    vectors = [weight_matrix[:, i] for i in range(n_features)]

    # Symmetric axis limits sized to the largest component, plus a 10% margin.
    max_extent = 1.1 * max(
        (float(np.max(np.abs(vector))) for vector in vectors), default=0.0
    )
    if max_extent == 0:
        # All-zero weights would otherwise produce degenerate [-0, 0] limits.
        max_extent = 1.0

    for i, vector in enumerate(vectors):
        # features_per_class is 0 when n_features < n_classes; such features, and any
        # leftovers past the last full class group, get the neutral colour.
        class_idx = i // features_per_class if features_per_class else n_classes
        color = color_palette[class_idx] if class_idx < n_classes else "darkgrey"

        ax_boundary.quiver(
            0.0,
            0.0,
            0.0,
            vector[0],
            vector[1],
            vector[2],
            color=color,
            arrow_length_ratio=arrow_length_ratio,
            alpha=vector_alpha,
            linewidth=vector_linewidth,
        )

    # Set axis labels and limits
    ax_boundary.set_xlabel(
        "Dimension 1", fontsize=axis_label_fontsize, labelpad=axis_labelpad
    )
    ax_boundary.set_ylabel(
        "Dimension 2", fontsize=axis_label_fontsize, labelpad=axis_labelpad
    )
    ax_boundary.set_zlabel(
        "Dimension 3", fontsize=axis_label_fontsize, labelpad=axis_labelpad
    )
    ax_boundary.set_title("Latent Directions", fontsize=34, pad=30)
    ax_boundary.tick_params(axis="both", which="major", labelsize=16)

    # Set equal limits for all axes
    ax_boundary.set_xlim([-max_extent, max_extent])
    ax_boundary.set_ylim([-max_extent, max_extent])
    ax_boundary.set_zlim([-max_extent, max_extent])

    # 2. Plot Feature Breakdowns (only for provided inputs)
    if ax_original_features is not None:
        visualize_feature_vector_by_class(
            original_input.detach().cpu(),
            n_classes=model.cfg.n_classes,
            ax=ax_original_features,
            log_to_wandb=False,
            title="Original Input",
            show_class_label=False,
            color_palette=color_palette,
            overlay_vectors=None,
            class_feature_deltas=None,
            show_individual_feature_values=False,
        )
        ax_original_features.set_title(
            ax_original_features.get_title(), pad=20, fontsize=32
        )

    if ax_adv_features is not None:
        # The perturbation can only be computed when the original input is available.
        class_deltas_for_adv_plot = None
        if has_original:
            class_deltas_for_adv_plot = (
                (adversarial_input - original_input).detach().cpu().numpy()
            )
        visualize_feature_vector_by_class(
            adversarial_input.detach().cpu(),
            n_classes=model.cfg.n_classes,
            ax=ax_adv_features,
            log_to_wandb=False,
            title="Adversarial Input",
            show_class_label=False,
            color_palette=color_palette,
            overlay_vectors=None,
            class_feature_deltas=class_deltas_for_adv_plot,
            show_individual_feature_values=False,
        )
        ax_adv_features.set_title(ax_adv_features.get_title(), pad=20, fontsize=32)

    # --- Final Touches ---
    # Class legend, placed in the space reserved to the right of the grid.
    fig.legend(
        handles=class_handles,
        labels=class_labels,
        loc="upper left",
        bbox_to_anchor=(0.81, 0.90),
        title="Classes",
        fontsize=24,
        title_fontsize=26,
    )

    fig.suptitle(fig_title, fontsize=24)

    # --- Save/Log/Show ---
    if save_path:
        print(f"Saving figure to {save_path}")
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if log_to_wandb:
        print(f"Logging {tag} figure to W&B")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
        plt.close(fig)
    if not save_path and not log_to_wandb:
        plt.show()
        plt.close(fig)

    # --- Save plot data if requested ---
    if save_plot_data_to:
        print(f"Saving plot data to {save_plot_data_to}")

        def _as_numpy(attr_name):
            """Return the named model tensor as a numpy array, or None if absent/None."""
            tensor = getattr(model, attr_name, None)
            return None if tensor is None else tensor.detach().cpu().numpy()

        def _as_numpy_input(tensor):
            """Return an optional input tensor as a numpy array."""
            return None if tensor is None else tensor.detach().cpu().numpy()

        model_reconstruction_info = {
            "cfg": {
                "n_features": model.cfg.n_features,
                "n_hidden": model.cfg.n_hidden,
                "n_classes": model.cfg.n_classes,
            },
            # Flags are read directly from the model object (None when missing).
            "bias": getattr(model, "bias", None),
            "privileged": getattr(model, "privileged", None),
            "projection_bias": getattr(model, "projection_bias", None),
            "privileged_out": getattr(model, "privileged_out", None),
            # Tensors:
            "W": _as_numpy("W"),
            "b_enc": _as_numpy("b_enc"),
            "W_projection": _as_numpy("W_projection"),
            "b_projection": _as_numpy("b_projection"),
        }

        # Styling parameters of the 3D panel, taken from the same values used above.
        plot_params = {
            "fig_title": fig_title,
            "instance_idx": instance_idx,
            "axis_label_fontsize": axis_label_fontsize,
            "axis_labelpad": axis_labelpad,
            "vector_linewidth": vector_linewidth,
            "arrow_length_ratio": arrow_length_ratio,
            "vector_alpha": vector_alpha,
        }

        plot_data = {
            "plot_function_name": "plot_adversarial_attack_mechanism_3d",
            "plot_args": {
                "model_reconstruction_info": model_reconstruction_info,
                "original_input": _as_numpy_input(original_input),
                "adversarial_input": _as_numpy_input(adversarial_input),
                "plot_params": plot_params,
            },
        }
        torch.save(plot_data, save_plot_data_to)

    return fig


def plot_attack_transitions(
    model,
    attack_lookup,
    instance_idx=0,
    batch_for_boundary=None,
    batch_labels_for_boundary=None,
    fig_title="Attack Transitions",
    save_path=None,
    log_to_wandb=False,
    tag="attack_transitions_plot",
    save_plot_data_to: str | None = None,
):
    """
    Creates a composite plot showing the decision boundary and a heatmap of attack transitions.

    Combines:
    1. Decision boundary in 2D activation space, with the columns of
       ``model.W[instance_idx]`` overlaid as feature directions.
    2. Heatmap showing the number of attacks from each class to each other class.

    Args:
        model: The classification model. Assumed to have model.cfg access.
        attack_lookup: Dictionary mapping (from_class, to_class) pairs to lists of Attack objects.
            Only the length of each list is used. Class indices are 0-based and index
            directly into the transition matrix.
        instance_idx (int): The model instance index to analyze.
        batch_for_boundary (torch.Tensor): Batch of inputs, forwarded to ``plot_decision_boundary``.
        batch_labels_for_boundary (torch.Tensor): Labels for the boundary batch, forwarded to
            ``plot_decision_boundary``.
        fig_title (str, optional): Stored in the saved plot data. It is not drawn on the
            figure (the main title is intentionally omitted).
        save_path (str, optional): If provided, saves the figure to this path.
        log_to_wandb (bool, optional): Whether to log the plot to Weights & Biases. Defaults to False.
        tag (str, optional): Tag/name for wandb logging. Defaults to "attack_transitions_plot".
        save_plot_data_to (str, optional): If provided, saves model weights, per-transition
            attack counts and the boundary batch to this filepath via ``torch.save``.

    Returns:
        matplotlib.figure.Figure: The generated Matplotlib figure object. The figure is
        closed before returning when it was logged to W&B or shown interactively.

    Raises:
        AssertionError: If ``model.cfg.n_hidden != 2``.
    """
    assert (
        model.cfg.n_hidden == 2
    ), "This visualization requires a model with n_hidden=2"

    # --- Setup Figure and Grid ---
    # Size the figure so the heatmap is square and as tall as the decision boundary plot.
    decision_boundary_height = 11
    decision_boundary_width = 15.6  # 1.2 * 13
    heatmap_size = decision_boundary_height
    total_width = decision_boundary_width + heatmap_size + 2.5  # extra room for wspace

    fig = plt.figure(figsize=(total_width, decision_boundary_height))
    # right=0.78 reserves space on the right-hand side for the figure-level legend.
    gs = gridspec.GridSpec(
        1,
        2,
        width_ratios=[decision_boundary_width, heatmap_size],
        wspace=0.25,
        left=0.05,
        right=0.78,
        bottom=0.1,
        top=0.95,
    )

    ax_boundary = fig.add_subplot(gs[0])
    ax_heatmap = fig.add_subplot(gs[1])

    # --- Define Color Palette ---
    n_classes = model.cfg.n_classes
    tab10 = plt.get_cmap("tab10")
    # Slightly desaturated tab10, cycled if there are more than 10 classes.
    base_palette = [
        desaturate_hex(mcolors.rgb2hex(tab10(i)), 0.75) for i in range(10)
    ]
    color_palette = [base_palette[i % len(base_palette)] for i in range(n_classes)]

    # --- Plot Decision Boundary ---
    ax_boundary, class_handles, class_labels = plot_decision_boundary(
        model,
        batch_for_boundary,
        batch_labels_for_boundary,
        instance_idx,
        additional_points=None,
        ax=ax_boundary,
        log_to_wandb=False,
        graph_axes=(-10, 10, -10, 10),  # Wide limits to show more of the vectors
        title="Decision Boundary & Feature Directions",
        create_legend=False,
        color_palette=color_palette,
    )
    # Remember the limits so they can be restored after the feature overlay below.
    original_xlim = ax_boundary.get_xlim()
    original_ylim = ax_boundary.get_ylim()

    # Plot feature vectors directly on the decision boundary plot
    weights = model.W[instance_idx].cpu()
    n_features = model.cfg.n_features
    features_per_class = n_features // n_classes
    # Contiguous groups of features share their class colour; leftovers are neutral.
    weight_colors = [
        class_color
        for class_color in color_palette
        for _ in range(features_per_class)
    ]
    weight_colors.extend(["darkgrey"] * (n_features - len(weight_colors)))

    plot_features_in_2d(
        W=weights.unsqueeze(0),
        colors=[weight_colors],
        ax=ax_boundary,
        linewidth=4.0,
        markersize=10,
    )
    ax_boundary.set_xlim(original_xlim)
    ax_boundary.set_ylim(original_ylim)
    # Re-apply axis labels and tick sizes after plot_features_in_2d
    ax_boundary.set_xlabel("Dimension 1", fontsize=28)
    ax_boundary.set_ylabel("Dimension 2", fontsize=28)
    ax_boundary.tick_params(axis="both", which="major", labelsize=20)

    # --- Create and Plot Heatmap ---
    # transition_matrix[src, dst] = number of attacks from class src to class dst.
    transition_matrix = np.zeros((n_classes, n_classes))
    for (from_class, to_class), attacks in attack_lookup.items():
        transition_matrix[from_class, to_class] = len(attacks)

    # Custom colormap from white to red
    heatmap_cmap = mcolors.LinearSegmentedColormap.from_list(
        "custom", ["#ffffff", "#d73027"]
    )

    # Heatmap with equal aspect ratio
    im = ax_heatmap.imshow(transition_matrix, cmap=heatmap_cmap, aspect="equal")

    cbar = plt.colorbar(im, ax=ax_heatmap)
    cbar.set_label("Number of Attacks", fontsize=24)
    cbar.ax.tick_params(labelsize=20)

    # Customize heatmap (tick labels are 1-based class numbers)
    tick_labels = [str(i + 1) for i in range(n_classes)]
    ax_heatmap.set_title("Attack Transitions", fontsize=32)
    ax_heatmap.set_xlabel("Target Class", fontsize=28)
    ax_heatmap.set_ylabel("Source Class", fontsize=28)
    ax_heatmap.set_xticks(np.arange(n_classes))
    ax_heatmap.set_yticks(np.arange(n_classes))
    ax_heatmap.set_xticklabels(tick_labels, fontsize=20)
    ax_heatmap.set_yticklabels(tick_labels, fontsize=20)

    # Add text annotations. White text is used only on cells strictly above half of the
    # maximum count, so an all-zero (entirely white) heatmap still gets readable black text.
    half_max = np.max(transition_matrix) / 2 if transition_matrix.size else 0.0
    for i in range(n_classes):
        for j in range(n_classes):
            count = int(transition_matrix[i, j])
            ax_heatmap.text(
                j,
                i,
                str(count),
                ha="center",
                va="center",
                color="white" if count > half_max else "black",
                fontsize=16,
                fontweight="bold",
            )

    # --- Final Touches ---
    # Class legend, placed in the space reserved to the right of the grid.
    fig.legend(
        handles=class_handles,
        labels=class_labels,
        loc="upper left",
        bbox_to_anchor=(0.83, 0.90),
        title="Classes",
        fontsize=24,
        title_fontsize=26,
    )

    # No fig.suptitle here: the main title is intentionally omitted for this figure.

    # --- Save/Log/Show ---
    if save_path:
        print(f"Saving figure to {save_path}")
        fig.savefig(save_path, bbox_inches="tight", dpi=300)
    if log_to_wandb:
        print(f"Logging {tag} figure to W&B")
        wandb.log({tag: wandb.Image(fig)}, commit=True)
        plt.close(fig)
    if not save_path and not log_to_wandb:
        plt.show()
        plt.close(fig)

    # --- Save plot data if requested ---
    if save_plot_data_to:
        print(f"Saving plot data to {save_plot_data_to}")

        def _as_numpy(attr_name):
            """Return the named model tensor as a numpy array, or None if absent/None."""
            tensor = getattr(model, attr_name, None)
            return None if tensor is None else tensor.detach().cpu().numpy()

        def _as_numpy_input(tensor):
            """Return an optional batch tensor as a numpy array."""
            return None if tensor is None else tensor.detach().cpu().numpy()

        model_reconstruction_info = {
            "cfg": {
                "n_features": model.cfg.n_features,
                "n_hidden": model.cfg.n_hidden,
                "n_classes": model.cfg.n_classes,
            },
            # Flags are read directly from the model object (None when missing).
            "bias": getattr(model, "bias", None),
            "privileged": getattr(model, "privileged", None),
            "projection_bias": getattr(model, "projection_bias", None),
            "privileged_out": getattr(model, "privileged_out", None),
            # Tensors:
            "W": _as_numpy("W"),
            "b_enc": _as_numpy("b_enc"),
            "W_projection": _as_numpy("W_projection"),
            "b_projection": _as_numpy("b_projection"),
        }

        # Only the per-transition counts are stored; they are all the heatmap needs.
        attack_lookup_counts = {
            key: len(attack_list) for key, attack_list in attack_lookup.items()
        }

        plot_data = {
            "plot_function_name": "plot_attack_transitions",
            "plot_args": {
                "model_reconstruction_info": model_reconstruction_info,
                "attack_lookup_counts": attack_lookup_counts,
                "instance_idx": instance_idx,
                "batch_for_boundary": _as_numpy_input(batch_for_boundary),
                "batch_labels_for_boundary": _as_numpy_input(
                    batch_labels_for_boundary
                ),
                "fig_title": fig_title,
            },
        }
        torch.save(plot_data, save_plot_data_to)

    return fig
