import math

import einops
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
import torch as t
import torch.nn.functional as F
from jaxtyping import Float
from plotly.subplots import make_subplots
from torch import Tensor

from adversarial_superposition.constants import DEVICE


def linear_lr(step, steps):
    """Linear-decay LR multiplier: 1.0 at step 0, dropping by 1/steps per step."""
    progress = step / steps
    return 1 - progress


def constant_lr(*_args):
    """Constant LR multiplier: ignores all positional arguments and returns 1.0."""
    multiplier = 1.0
    return multiplier


def cosine_decay_lr(step, steps):
    """Cosine-decay LR multiplier.

    Follows a quarter cosine from 1.0 at ``step == 0`` down to 0.0 at
    ``step == steps - 1``.

    Args:
        step: current step index.
        steps: total number of steps in the schedule.

    Returns:
        The multiplier as a numpy float. For a schedule of a single step
        (``steps <= 1``) the only step gets the full rate 1.0, instead of
        dividing by ``steps - 1 == 0``.
    """
    if steps <= 1:
        return np.float64(1.0)
    return np.cos(0.5 * np.pi * step / (steps - 1))


def cast_element_to_nested_list(elem, shape: tuple):
    """
    Creates a nested list of shape `shape`, where every element is `elem`.
    Example: ("a", (2, 2)) -> [["a", "a"], ["a", "a"]]

    Every sub-list is a distinct list object, so mutating one row does not
    affect the others (``[row] * n`` would repeat one shared row). `elem`
    itself is not copied: all leaves refer to the same object. An empty
    `shape` returns `elem` unchanged.
    """
    if len(shape) == 0:
        return elem
    inner_shape = shape[1:]
    # Build each row independently to avoid aliasing between rows.
    return [cast_element_to_nested_list(elem, inner_shape) for _ in range(shape[0])]


def _feature_plot_axis_limit(values: np.ndarray) -> float:
    """Symmetric axis half-width: 110% of max |value|, or 1.0 when that is 0 or `values` is empty."""
    peak = float(np.abs(values).max()) if values.size else 0.0
    # A zero half-width would hand matplotlib identical low/high limits.
    return peak * 1.1 if peak > 0 else 1.0


def _resolve_feature_plot_colors(colors, n_instances: int, n_feats: int) -> list:
    """Normalise the `colors` argument of `plot_features_in_2d` into an [n_instances][n_feats] nested list.

    Accepted forms: None (all black); one colour string (used everywhere); a flat
    list of `n_feats` colour strings (shared by every instance); a nested list of
    shape (n_instances, n_feats); or a numeric tensor of shape (n_instances, n_feats)
    mapped through the "viridis" colormap (normalised over the whole tensor).
    A shape mismatch prints a warning and falls back to black; an empty list or
    an unsupported type silently falls back to black.
    """
    shape = (n_instances, n_feats)

    if colors is None:
        return cast_element_to_nested_list("black", shape)
    if isinstance(colors, str):
        return cast_element_to_nested_list(colors, shape)

    if isinstance(colors, Tensor):
        if colors.shape != shape:
            print(
                f"Warning: Color tensor shape mismatch. Expected ({n_instances}, {n_feats}). Using black."
            )
            return cast_element_to_nested_list("black", shape)
        norm = plt.Normalize(vmin=colors.min().item(), vmax=colors.max().item())
        cmap = plt.get_cmap("viridis")
        return [[cmap(norm(v)) for v in row.tolist()] for row in colors]

    if isinstance(colors, list) and colors:
        if isinstance(colors[0], str):  # one colour per feature, shared by all instances
            if len(colors) == n_feats:
                return [colors for _ in range(n_instances)]
            print(
                f"Warning: Provided list of colors length ({len(colors)}) doesn't match n_feats ({n_feats}). Using black."
            )
            return cast_element_to_nested_list("black", shape)
        if isinstance(colors[0], list):  # nested [instances][feats]
            if len(colors) == n_instances and all(len(c) == n_feats for c in colors):
                return colors
            print(
                f"Warning: Nested color list shape mismatch. Expected ({n_instances}, {n_feats}). Using black."
            )
            return cast_element_to_nested_list("black", shape)

    # Empty list or unsupported type.
    return cast_element_to_nested_list("black", shape)


def plot_features_in_2d(
    W: Float[Tensor, "*inst d_hidden feats"] | list[Float[Tensor, "d_hidden feats"]],
    colors: Float[Tensor, "inst feats"] | list[str] | list[list[str]] | None = None,
    title: str | None = None,
    subplot_titles: list[str] | None = None,
    allow_different_limits_across_subplots: bool = False,
    n_rows: int | None = None,
    return_fig: bool = False,
    ax: plt.Axes | None = None,
    linewidth: float | None = None,
    markersize: float | None = None,
):
    """
    Visualises superposition in 2D using Matplotlib.

    Each feature is drawn as a line from the origin to its (dim 1, dim 2)
    coordinates with a marker at the tip; each instance gets its own subplot.

    Args:
        W: Tensor of shape (*inst, d_hidden=2, feats) or a list of tensors of shape
            (d_hidden=2, feats). A 2-D tensor is treated as a single instance; any
            number of leading instance dimensions is flattened into one.
        colors: Colors for the features: None, a colour string, a list of `feats`
            colour strings, a nested [inst][feats] list, or a numeric (inst, feats)
            tensor mapped through "viridis". Mismatched shapes fall back to black.
        title: Overall title for the figure (only used when a new figure is created).
        subplot_titles: List of titles for each subplot (instance).
        allow_different_limits_across_subplots: If True, each subplot gets its own
            axis limits; otherwise all subplots share the global limit.
        n_rows: Number of rows for subplots. If None, plots in a single row.
        return_fig: If True, returns the figure object instead of showing it.
        ax: Matplotlib Axes to draw on. If provided, only the first instance is drawn
            (with a printed warning when there are more), `n_rows`, `return_fig`
            and `title` are ignored, and the plot goes directly onto `ax`.
        linewidth: Optional override for line width.
        markersize: Optional override for marker size.

    Returns:
        The new Figure if `ax` is None and `return_fig` is True; `ax` if one was
        provided; otherwise None (the figure is shown and then closed).

    Raises:
        AssertionError: if d_hidden != 2.
    """
    W_tensor = t.stack(W) if isinstance(W, list) else W
    if W_tensor.ndim == 2:
        W_tensor = W_tensor.unsqueeze(0)
    elif W_tensor.ndim > 3:
        # Flatten every leading instance dimension (the `*inst` of the type hint) into one.
        W_tensor = W_tensor.reshape(-1, *W_tensor.shape[-2:])

    # Single device->host copy, laid out as [inst, feats, d_hidden] for plotting.
    W_np = W_tensor.detach().cpu().numpy().transpose(0, 2, 1)
    n_instances, n_feats, d_hidden = W_np.shape

    assert d_hidden == 2, f"plot_features_in_2d requires d_hidden=2, got {d_hidden}"

    if allow_different_limits_across_subplots:
        limits = [_feature_plot_axis_limit(w) for w in W_np]
    else:
        limits = [_feature_plot_axis_limit(W_np)] * n_instances

    default_linewidth, default_markersize = (1, 4) if n_feats >= 25 else (1.5, 6)
    if linewidth is None:
        linewidth = default_linewidth
    if markersize is None:
        markersize = default_markersize

    # Resolve colours for every instance *before* any truncation in `ax` mode,
    # so per-instance colour specs still match the input's instance count.
    colors_per_instance = _resolve_feature_plot_colors(colors, n_instances, n_feats)

    create_figure = ax is None
    if create_figure:
        if n_rows is None:
            n_rows_fig, n_cols_fig = 1, n_instances
        else:
            n_rows_fig = n_rows
            n_cols_fig = math.ceil(n_instances / n_rows)

        fig, axs_array = plt.subplots(
            n_rows_fig,
            n_cols_fig,
            figsize=(3.5 * n_cols_fig, 3.0 * n_rows_fig),
            squeeze=False,
        )
        fig.subplots_adjust(
            bottom=0.1,
            top=(0.85 if title else 0.92),
            left=0.08,
            right=0.95,
            hspace=0.35,
            wspace=0.25,
        )
    else:
        # Treat the provided ax as a 1x1 grid.
        axs_array = [[ax]]
        n_rows_fig = n_cols_fig = 1
        if n_instances > 1:
            print(
                "Warning: Plotting multiple instances onto a single provided axis. Only the first instance will be shown."
            )
            n_instances = 1
            W_np = W_np[:1]
            limits = limits[:1]
            colors_per_instance = colors_per_instance[:1]
            if subplot_titles:
                subplot_titles = subplot_titles[:1]

    for idx in range(n_instances):
        row, col = divmod(idx, n_cols_fig)
        current_ax = axs_array[row][col]
        lim = limits[idx]

        current_ax.set_xlim(-lim, lim)
        current_ax.set_ylim(-lim, lim)
        current_ax.set_aspect("equal", adjustable="box")
        current_ax.grid(True, linestyle="--", alpha=0.6)

        for (x, y), color in zip(W_np[idx], colors_per_instance[idx]):
            # Line from the origin, then a marker at the feature's tip.
            current_ax.plot([0, x], [0, y], color=color, lw=linewidth, zorder=2)
            current_ax.plot(
                x, y, color=color, marker="o", markersize=markersize, zorder=3
            )

        if subplot_titles and idx < len(subplot_titles):
            current_ax.set_title(subplot_titles[idx], fontsize=10)

        # A subplot is the bottom of its column when the cell below it holds no
        # instance, which also covers partially filled last rows.
        if idx + n_cols_fig >= n_instances:
            current_ax.set_xlabel("Dimension 1", fontsize=11)
        else:
            current_ax.tick_params(axis="x", labelbottom=False)
        if col == 0:
            current_ax.set_ylabel("Dimension 2", fontsize=11)
        else:
            current_ax.tick_params(axis="y", labelleft=False)

        current_ax.tick_params(axis="both", labelsize=8)

    if create_figure:
        # Hide grid cells that received no instance.
        for idx in range(n_instances, n_rows_fig * n_cols_fig):
            row, col = divmod(idx, n_cols_fig)
            axs_array[row][col].axis("off")

        if title:
            fig.suptitle(title, fontsize=14)

        if return_fig:
            return fig
        plt.show()
        plt.close(fig)  # free the figure's memory once displayed
        return None

    return ax


def plots(
    m, neuron_basis=False, classification=True, log_to_wandb=False, n_samples=200
):
    """
    Plot a model's 2-D weights and the hidden activations of a random batch.

    Nothing is drawn unless `m.cfg.n_hidden == 2`.

    Args:
        m: Model exposing `cfg.n_hidden`, `cfg.n_features` (regression) or
            `cfg.n_classes` / `cfg.n_features_per_class` (classification), a
            weight tensor `W` laid out as (instances, hidden, features), and
            `generate_batch(n)` returning a (batch_size, instances, features)
            batch, or `(batch, labels)` when `classification` is True.
        neuron_basis: Currently unused; kept for API compatibility.
        classification: Colour weight features by class and unpack labels
            from `generate_batch`.
        log_to_wandb: If True, return figures for wandb logging instead of displaying.
        n_samples: Number of samples drawn for the activation plot.

    Returns:
        If log_to_wandb is True, a dict with keys "weights_2d" and
        "activations_2d" (empty when n_hidden != 2); otherwise None.
    """
    figures = {}

    if m.cfg.n_hidden != 2:
        return figures if log_to_wandb else None

    if classification:
        colors = generate_plotly_colors_classification(
            m.cfg.n_classes, m.cfg.n_features_per_class
        )
    else:
        colors = generate_plotly_colors(m.cfg.n_features)

    weights_fig = plot_features_in_2d(
        m.W,
        colors=colors,
        title="Weights of model",
        allow_different_limits_across_subplots=True,
        return_fig=log_to_wandb,  # only build/return the figure when logging
    )
    if log_to_wandb:
        print("Saved weights_2d")
        figures["weights_2d"] = weights_fig

    with t.inference_mode():
        if classification:
            batch, _labels = m.generate_batch(n_samples)
        else:
            batch = m.generate_batch(n_samples)

        # Project every sample into the 2-D hidden space; samples play the role
        # of "features" for the scatter of activations.
        hidden = einops.einsum(
            batch,
            m.W,
            "batch_size instances features, instances hidden features -> instances hidden batch_size",
        )

    activations_fig = plot_features_in_2d(
        hidden,
        title="Activations of a random batch of data",
        allow_different_limits_across_subplots=True,
        return_fig=log_to_wandb,
    )
    if log_to_wandb:
        print("Saved activations_2d")
        figures["activations_2d"] = activations_fig

    return figures if log_to_wandb else None


def plot_w(tensor, log_to_wandb=False):
    """
    Render each slice ``tensor[k]`` of a 3-D tensor as a side-by-side Plotly heatmap.

    Every slice is transposed before plotting, so the x axis indexes the
    slice's rows and the y axis its columns. All slices share one diverging
    blue-grey-red scale that is symmetric about zero; only the last subplot
    shows the colour bar.

    Args:
        tensor: 3-D tensor (N x H x W); it is detached and moved to CPU.
        log_to_wandb: If True, return the figure instead of calling ``fig.show()``.

    Returns:
        The Plotly figure when log_to_wandb is True, otherwise None.

    Raises:
        ValueError: if ``tensor`` is not 3-D.
    """
    if len(tensor.shape) != 3:
        raise ValueError("Expected 3D tensor (N x H x W)")

    n_slices = tensor.shape[0]
    data = tensor.detach().cpu().numpy()

    # Symmetric range so that zero falls on the middle of the colour scale.
    bound = max(abs(data.min()), abs(data.max()))
    z_low, z_high = -bound, bound

    diverging_scale = [
        [0, "rgb(0,0,255)"],  # most negative: dark blue
        [0.2, "rgb(100,100,255)"],
        [0.4, "rgb(180,180,180)"],
        [0.5, "rgb(128,128,128)"],  # zero: medium grey
        [0.6, "rgb(180,180,180)"],
        [0.8, "rgb(255,100,100)"],
        [1, "rgb(255,0,0)"],  # most positive: dark red
    ]

    fig = make_subplots(
        rows=1,
        cols=n_slices,
        subplot_titles=[f"Layer {k + 1}" for k in range(n_slices)],
        horizontal_spacing=0.01,
    )

    last_slice = n_slices - 1
    for k in range(n_slices):
        col = k + 1
        is_first = k == 0
        fig.add_trace(
            go.Heatmap(
                z=data[k].T,
                colorscale=diverging_scale,
                zmid=0,
                zmin=z_low,
                zmax=z_high,
                showscale=(k == last_slice),
                hoverongaps=False,
            ),
            row=1,
            col=col,
        )
        # Lock x to y so cells are square; disable zoom for a static layout.
        fig.update_xaxes(
            row=1,
            col=col,
            constrain="domain",
            scaleanchor=f"y{col}",
            scaleratio=1,
            title="Row Index" if is_first else None,
            showgrid=False,
            fixedrange=True,
        )
        fig.update_yaxes(
            row=1,
            col=col,
            constrain="domain",
            title="Column Index" if is_first else None,
            showgrid=False,
            fixedrange=True,
        )

    panel_size = 300
    fig.update_layout(
        title="3D Tensor Heatmaps",
        width=min(panel_size * n_slices, 1100),
        height=panel_size,
        showlegend=False,
        margin=dict(l=50, r=50, t=50, b=50),
        plot_bgcolor="white",
    )

    if not log_to_wandb:
        fig.show()
        return None
    return fig


def generate_plotly_colors(n):
    """
    Return `n` colours from Plotly's default qualitative palette.

    The 10-colour palette is cycled, so colour ``i`` is always
    ``palette[i % 10]``. A non-positive `n` yields an empty list (previously a
    negative `n` sliced from the end and returned a truncated palette).

    Args:
        n (int): Number of colors to generate

    Returns:
        list: List of color hex strings
    """
    base_colors = [
        "#1f77b4",  # muted blue
        "#ff7f0e",  # safety orange
        "#2ca02c",  # cooked asparagus green
        "#d62728",  # brick red
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
        "#7f7f7f",  # middle gray
        "#bcbd22",  # curry yellow-green
        "#17becf",  # blue-teal
    ]
    return [base_colors[i % len(base_colors)] for i in range(n)]


def generate_plotly_colors_classification(n_classes, features_per_class):
    """
    Colour features by class using Plotly's default palette.

    Class ``k`` gets palette colour ``k % 10``, and that colour is repeated
    `features_per_class` times, so features are assumed to be grouped
    contiguously by class. A "Class k: <colour name>" legend line is printed
    for every class.

    Args:
        n_classes (int): Number of classes.
        features_per_class (int): Number of consecutive features per class.

    Returns:
        list: ``n_classes * features_per_class`` hex colour strings, one per feature.
    """
    palette = [
        ("#1f77b4", "muted blue"),
        ("#ff7f0e", "safety orange"),
        ("#2ca02c", "cooked asparagus green"),
        ("#d62728", "brick red"),
        ("#9467bd", "muted purple"),
        ("#8c564b", "chestnut brown"),
        ("#e377c2", "raspberry yogurt pink"),
        ("#7f7f7f", "middle gray"),
        ("#bcbd22", "curry yellow-green"),
        ("#17becf", "blue-teal"),
    ]

    class_colors = []
    for class_idx in range(n_classes):
        hex_code, colour_name = palette[class_idx % len(palette)]  # cycle past 10 classes
        print(f"Class {class_idx+1}: {colour_name}")
        class_colors.append(hex_code)
    return [c for c in class_colors for _ in range(features_per_class)]


def cosine_similarity_efficient(x1: t.Tensor, x2: t.Tensor) -> t.Tensor:
    """
    Row-wise cosine similarity between two equally shaped 2-D tensors.

    Entry ``i`` of the result is the cosine similarity of ``x1[i]`` and
    ``x2[i]``: both rows are L2-normalised along dim 1 with ``F.normalize``,
    then multiplied element-wise and summed (a per-row dot product, not a
    matrix multiplication).

    Args:
        x1 (torch.Tensor): shape (batch_size, vector_dim)
        x2 (torch.Tensor): shape (batch_size, vector_dim)

    Returns:
        torch.Tensor: shape (batch_size,)
    """
    unit1, unit2 = (F.normalize(v, p=2, dim=1) for v in (x1, x2))
    return (unit1 * unit2).sum(dim=1)


def cosine_similarity_matrix(x1: t.Tensor, x2: t.Tensor) -> t.Tensor:
    """
    All-pairs cosine similarity between the rows of two 2-D tensors.

    Rows of both inputs are L2-normalised along dim 1, and the normalised
    matrices are multiplied as ``x1_hat @ x2_hat^T``. Entry (i, j) is therefore
    the cosine similarity of ``x1[i]`` and ``x2[j]``.

    Args:
        x1 (torch.Tensor): shape (n, vector_dim)
        x2 (torch.Tensor): shape (m, vector_dim)

    Returns:
        torch.Tensor: shape (n, m)
    """
    left = F.normalize(x1, p=2, dim=1)
    right = F.normalize(x2, p=2, dim=1)
    return left.mm(right.t())


def pairwise_cosine_similarity(x: t.Tensor) -> t.Tensor:
    """
    Cosine similarity between every pair of rows of a single 2-D tensor.

    Args:
        x (torch.Tensor): shape (n, vector_dim)

    Returns:
        torch.Tensor: shape (n, n); entry (i, j) is the cosine similarity of
        row i and row j (the Gram matrix of the L2-normalised rows).
    """
    unit_rows = F.normalize(x, p=2, dim=1)
    return unit_rows.mm(unit_rows.t())


def project_input_onto_latent_space(m):
    """Project a hand-built "perfect" input through `m.W` into the hidden space.

    Builds an all-zeros input of shape (1, n_instances, n_features), sets a
    slice of feature positions to 1.0 for each class index, moves it to
    DEVICE and contracts it with `m.W` (laid out as instances, hidden,
    features per the einsum pattern).

    Returns:
        Tensor of shape (instances, hidden, 1): the projection computed on the
        final loop iteration. `perfect` is never reset between iterations, so
        this is the projection of the input with every class's slice set.

    NOTE(review): the slice bounds look wrong. For class i they are
    ``i * n_classes : i * n_classes + i * n_features_per_class``, which is
    empty for i == 0 and uses `n_classes` as a feature stride. The more likely
    intent is ``i * n_features_per_class : (i + 1) * n_features_per_class`` —
    confirm with the author before relying on this output.
    NOTE(review): `h` is only assigned inside the loop, so n_classes == 0
    raises UnboundLocalError.
    """
    perfect = t.zeros(1, m.cfg.n_instances, m.cfg.n_features)

    for i in range(m.cfg.n_classes):
        # NOTE(review): suspicious bounds, see docstring.
        perfect[
            :,
            :,
            i * m.cfg.n_classes : i * m.cfg.n_classes + i * m.cfg.n_features_per_class,
        ] = 1.0
        perfect = perfect.to(DEVICE)

        # Recomputed every iteration; only the last value of `h` is returned.
        with t.inference_mode():
            batch = perfect
            h = einops.einsum(
                batch,
                m.W,
                "batch_size instances features, instances hidden features -> instances hidden batch_size",
            )

    return h


def get_average_class_vectors(m):
    """Get the average of all the vectors for each class in latent space.

    The last axis of `m.W` is split into consecutive blocks of
    `m.cfg.n_features_per_class` columns, stepping over `m.cfg.n_features`.
    Block k is averaged into output column k.

    Returns:
        Tensor of shape (m.W.shape[0], m.W.shape[1], m.cfg.n_classes), allocated
        with `t.zeros` on the default device/dtype (not necessarily m.W's device).
    """
    n_inst, n_hidden = m.W.shape[0], m.W.shape[1]
    per_class = m.cfg.n_features_per_class
    averages = t.zeros(n_inst, n_hidden, m.cfg.n_classes)

    block_starts = range(0, m.cfg.n_features, per_class)
    for class_idx, start in enumerate(block_starts):
        block = m.W[:, :, start : start + per_class]
        averages[:, :, class_idx] = block.mean(-1)
    return averages


def randomize_signs_torch(tensor):
    """Randomly flip the sign of each non-negative entry; negative entries pass through.

    One ``torch.randint(0, 2, ...)`` draw per element (on the tensor's device)
    is mapped to a float ±1. Entries with ``tensor >= 0`` are multiplied by it;
    all other entries keep their value.
    """
    coin = torch.randint(0, 2, tensor.shape, device=tensor.device)
    flips = coin * 2.0 - 1.0
    return torch.where(tensor >= 0, tensor * flips, tensor)


def plot_2d(vectors, points):
    """Draw 2-D vectors as arrows from the origin and overlay labelled points.

    Args:
        vectors: tensor or array-like of shape (2, n_vectors). Column i is drawn as
            an arrow from the origin to (vectors[0, i], vectors[1, i]), labelled i+1.
        points: iterable of (point, label) pairs. point[0] / point[1] are the x / y
            coordinates, and label is written next to the black marker.
    """
    # Accept CPU/GPU tensors (with or without grad) as well as arrays and lists.
    if isinstance(vectors, Tensor):
        vectors_np = vectors.detach().cpu().numpy()
    else:
        vectors_np = np.asarray(vectors)
    points = list(points)  # iterated twice: once for limits, once for plotting

    # Symmetric half-width covering every vector tip and every point, plus 20% padding.
    extent = float(np.abs(vectors_np[:2]).max()) if vectors_np.size else 0.0
    for point, _ in points:
        extent = max(extent, abs(float(point[0])), abs(float(point[1])))
    max_abs = extent * 1.2 if extent > 0 else 1.0

    plt.figure(figsize=(8, 8))

    colors = ["r", "g", "b", "c", "m", "y"]
    for i in range(vectors_np.shape[1]):
        color = colors[i % len(colors)]  # cycle instead of IndexError past 6 vectors
        vx, vy = vectors_np[0, i], vectors_np[1, i]
        # angles/scale_units="xy" with scale=1 makes the arrow end exactly at (vx, vy).
        plt.quiver(
            0,
            0,
            vx,
            vy,
            angles="xy",
            scale_units="xy",
            scale=1,
            color=color,
            label=f"Vector {i+1}",
            width=0.008,
        )
        # Label just beyond the tip (10% further along the vector).
        plt.text(vx * 1.1, vy * 1.1, f"{i+1}", color=color, fontweight="bold", fontsize=12)

    # Offset point labels proportionally so they stay readable at any scale.
    label_offset = 0.05 * max_abs
    for i, (point, label) in enumerate(points):
        print(f"Plotting {point[0]} {point[1]}")
        plt.scatter(point[0], point[1], color="black", s=100, label=f"Point {i}")
        plt.text(
            point[0] + label_offset,
            point[1] + label_offset,
            f"{label}",
            fontsize=12,
            fontweight="bold",
        )

    # Equal limits on both axes keep vector proportions.
    plt.xlim(-max_abs, max_abs)
    plt.ylim(-max_abs, max_abs)

    plt.grid(True)
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.title("Activation Space")
    plt.axhline(y=0, color="k", linestyle="-", linewidth=0.5)
    plt.axvline(x=0, color="k", linestyle="-", linewidth=0.5)
    plt.gca().set_aspect("equal")
    plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

    # Lay out only after every artist (including the legend) exists.
    plt.tight_layout()
    plt.show()


def visualize_3d_weight_matrices(
    weights, figsize=(20, 8), title="Weight Matrix Visualization"
):
    """
    Draw each instance of an (n_instances, 3, n_vectors) weight array as a 3-D quiver plot.

    Subplots are laid out four per row, with at least two rows, so the
    original 8-instance 2x4 layout is unchanged; more than 8 instances now
    adds rows instead of raising IndexError. Each subplot shows n_vectors
    arrows from the origin, coloured along the "rainbow" colormap, with axis
    limits shared across all subplots.

    Args:
        weights: numpy array, tensor or array-like of shape (n_instances, 3, n_vectors).
            Only the first three entries along axis 1 are used as x / y / z.
        figsize: size of the figure (width, height)
        title: title for the overall figure

    Note:
        Calls ``plt.ion()``, which turns on matplotlib's global interactive mode.
    """
    plt.ion()

    if hasattr(weights, "detach"):  # torch tensor, possibly on GPU / requiring grad
        weights = weights.detach().cpu().numpy()
    elif hasattr(weights, "numpy"):
        weights = weights.numpy()
    weights = np.asarray(weights)

    n_instances = weights.shape[0]
    n_vectors = weights.shape[2]

    colors = plt.cm.rainbow(np.linspace(0, 1, n_vectors))

    fig = plt.figure(figsize=figsize)
    fig.suptitle(title, fontsize=16)

    n_cols = 4
    n_rows = max(2, math.ceil(n_instances / n_cols))
    gs = gridspec.GridSpec(n_rows, n_cols)

    # Shared limits (20% margin); fall back to 1.0 so all-zero weights stay plottable.
    max_val = float(np.abs(weights).max()) if weights.size else 0.0
    axis_lim = max_val * 1.2 if max_val > 0 else 1.0

    origin = np.zeros(3)
    for i in range(n_instances):
        row, col = divmod(i, n_cols)
        ax = fig.add_subplot(gs[row, col], projection="3d")

        for j in range(n_vectors):
            vector = weights[i, :, j]
            ax.quiver(
                origin[0],
                origin[1],
                origin[2],
                vector[0],
                vector[1],
                vector[2],
                color=colors[j],
                arrow_length_ratio=0.1,
                label=f"Vector {j+1}",
            )

        ax.set_xlim([-axis_lim, axis_lim])
        ax.set_ylim([-axis_lim, axis_lim])
        ax.set_zlim([-axis_lim, axis_lim])

        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(f"Instance {i+1}")

        # Legend on the first subplot only; the colours are the same everywhere.
        if i == 0:
            ax.legend(loc="upper right", bbox_to_anchor=(1.1, 1.1))

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # leave room for the suptitle

    plt.show()
