from typing import List

import matplotlib.pyplot as plt
import numpy as np
from torch import Tensor

# Module-level side effect: enable matplotlib interactive mode at import time so the
# plt.draw()/plt.pause() calls in the plotting functions below render figures while
# execution continues to the input() prompt that closes them.
plt.ion()


def _draw_attention_lines(ax, points, attn, threshold, min_linewidth, max_linewidth):
    """Draw black lines between token pairs whose normalized attention exceeds a threshold.

    :param ax: 3D matplotlib axes to draw on
    :param points: Array of shape (length, 3) with the token coordinates of one head
    :param attn: Array of shape (length, length) with that head's attention weights
    :param threshold: Only pairs whose min-max normalized weight is above this are drawn
    :param min_linewidth: Line width for a normalized weight of 0
    :param max_linewidth: Line width for a normalized weight of 1
    """
    lo, hi = attn.min(), attn.max()
    if hi > lo:
        attn = (attn - lo) / (hi - lo)  # Normalize 0-1
    else:
        # Constant attention map: min-max scaling would be 0/0 (NaN); treat as all zeros.
        attn = np.zeros_like(attn)

    # Only the strict upper triangle (k < m) is inspected, so each pair is drawn at most
    # once; the reverse direction attn[m, k] is not considered.
    rows, cols = np.nonzero(np.triu(attn > threshold, k=1))
    for k, m in zip(rows, cols):
        weight = float(attn[k, m])
        pair = [k, m]
        ax.plot(points[pair, 0], points[pair, 1], points[pair, 2],
                color='black',
                alpha=0.5 * weight,  # Alpha scales with attention
                linewidth=min_linewidth + (max_linewidth - min_linewidth) * weight)


def _set_symmetric_integer_ticks(ax):
    """Give all three axes the same integer limits [-r, r] with ticks at -r, 0 and r.

    r is the smallest integer >= 1 that still contains every current (autoscaled) axis
    limit, so data lying away from the origin is not clipped.
    """
    extent = max(abs(v) for v in (*ax.get_xlim(), *ax.get_ylim(), *ax.get_zlim()))
    limit = max(1, int(np.ceil(extent)))
    ax.set_xlim(-limit, limit)
    ax.set_ylim(-limit, limit)
    ax.set_zlim(-limit, limit)
    tick_values = np.linspace(-limit, limit, num=3, dtype=int)
    ax.set_xticks(tick_values)
    ax.set_yticks(tick_values)
    ax.set_zticks(tick_values)


def plot_heads(hidden_states,
               title: str = None,
               elevation: int = 15,
               azimuthal: int = -135,
               attention_weights: Tensor = None,
               mask: Tensor = None,
               head_dims: int = 3,
               save_dir: str = None,
               show: bool = True,
               axes_labels: List[str] = None,
               max_linewidth: float = 2.0,
               min_linewidth: float = 0.1,
               attention_threshold: float = 0.1,
               plot_attention_lines: bool = False,
               ticks: bool = False) -> "plt.Figure":
    """
    Scatter-plot every head of ``hidden_states`` as a 3D point cloud.

    The last dimension of ``hidden_states`` is split into ``dim // head_dims`` heads.
    One 3D subplot is drawn per (batch item, head) pair, in a grid with one row per
    batch item and one column per head.

    :param hidden_states: Tensor of shape (batch, length, dim)
    :param title: Optional text appended to every subplot title (omitted when None)
    :param elevation: Elevation angle for 3D plots
    :param azimuthal: Azimuthal angle for 3D plots
    :param attention_weights: Tensor of shape (batch, heads, length, length); ignored
        (with a printed warning) when its shape does not match
    :param mask: Attention mask (0: masked, 1: not-masked, 10000: global tokens);
        tokens whose mask value equals 10000 are drawn as red stars. Presumably of
        shape (batch, length) -- TODO confirm with callers
    :param head_dims: Dimension of each head. Must be 3 for 3D visualization
    :param save_dir: File path (.png) the figure is saved to; not saved when falsy
    :param show: If True, display the figure, wait for ENTER, then close it
    :param axes_labels: List of ``dim`` axis labels; head ``j`` uses the slice
        ``[j * head_dims, (j + 1) * head_dims)``
    :param max_linewidth: Maximum linewidth for attention lines
    :param min_linewidth: Minimum linewidth for attention lines
    :param attention_threshold: Show attention lines above this (normalized) threshold
    :param plot_attention_lines: Plot attention lines if True
    :param ticks: Show symmetric integer axes ticks if True
    :return: The matplotlib Figure (already closed when ``show`` is True)
    :raises AssertionError: if ``dim`` is not a multiple of ``head_dims``
    :raises ValueError: if ``head_dims`` is not 3
    """
    # float() also makes half/bfloat16 inputs plottable.
    hidden_states = hidden_states.detach().cpu().float()
    b, s, c = hidden_states.shape
    assert c % head_dims == 0, f"dim ({c}) not a multiple of head_dims ({head_dims})"
    if head_dims != 3:
        # Every subplot is 3D and unpacks exactly three coordinates per token.
        raise ValueError(f'head_dims must be 3 for 3D visualization, got {head_dims}')
    num_heads = c // head_dims
    heads = hidden_states.reshape(b, s, num_heads, head_dims)

    if axes_labels is not None and len(axes_labels) != c:
        print('length of axes labels should be equal to last dimension of hidden_states. '
              'axes labels will be set to None')
        axes_labels = None

    if attention_weights is not None:
        if attention_weights.shape != (b, num_heads, s, s):
            print('NOT ABLE TO PLOT ATTENTION LINES')
            print(f'RESHAPE ATTENTION WEIGHTS TO STANDARD SHAPE {(b, num_heads, s, s)}')
            attention_weights = None
        else:
            attention_weights = attention_weights.detach().cpu().float()

    draw_lines = attention_weights is not None and bool(plot_attention_lines)
    if draw_lines:
        print('plotting attention lines............')

    fig = plt.figure(figsize=(8 * b, 6 * num_heads))

    for i in range(b):
        for j in range(num_heads):
            ax = fig.add_subplot(b, num_heads, i * num_heads + j + 1, projection='3d')
            head_j = heads[i, :, j, :]

            # Separate global tokens (mask == 10000) so they can be highlighted.
            if mask is not None:
                global_mask = (mask[i] == 10000).to(head_j.device)
                global_tokens = head_j[global_mask]
                regular_tokens = head_j[~global_mask]
            else:
                global_tokens = head_j[:0]
                regular_tokens = head_j

            xr, yr, zr = regular_tokens.T
            ax.scatter(xr, yr, zr, alpha=.8, s=10, linewidths=0.2, depthshade=False)
            if len(global_tokens) > 0:
                xg, yg, zg = global_tokens.T
                ax.scatter(xg, yg, zg, c='red', marker='*', s=200, edgecolors='b', linewidth=1.0, depthshade=False)

            if draw_lines:
                _draw_attention_lines(ax, head_j.numpy(), attention_weights[i, j].numpy(),
                                      attention_threshold, min_linewidth, max_linewidth)

            # Plot formatting
            suffix = '' if title is None else f' | {title}'
            ax.set_title(f'batch {i} | head {j}{suffix}', fontsize=10)
            ax.grid(False)
            if axes_labels is not None:
                head_labels = axes_labels[j * head_dims:(j + 1) * head_dims]
                ax.set_xlabel(head_labels[0])
                ax.set_ylabel(head_labels[1])
                ax.set_zlabel(head_labels[2])

            if ticks:
                _set_symmetric_integer_ticks(ax)
            else:
                ax.set_xticks([])
                ax.set_yticks([])
                ax.set_zticks([])

            ax.view_init(elev=elevation, azim=azimuthal)

    fig.tight_layout()
    if save_dir:
        fig.savefig(save_dir, dpi=300, bbox_inches='tight', pad_inches=.5)
    if show:
        plt.draw()
        plt.pause(0.001)
        input("PRESS ENTER TO CLOSE ALL FIGURES AND CONTINUE")
        plt.close(fig)
    return fig


def _eigenvalues_as_numpy(values) -> np.ndarray:
    """Convert one iteration's eigenvalues (Tensor or array-like) to a NumPy array.

    Tensors are detached and moved to CPU first, because NumPy cannot convert
    tensors that require grad or live on an accelerator.
    """
    if isinstance(values, Tensor):
        values = values.detach().cpu()
    return np.asarray(values)


def plot_eigenvalues(eigenvalues_data: List[Tensor],
                     n: int = 3,
                     save_path: str = None,
                     show: bool = True) -> "plt.Figure":
    """
    Plot eigenvalues over iterations, and the spectral ratio between the leading and trailing eigenvalues.

    For each batch item, two subplots are drawn side by side:
      * left: eigenvalues[:, i, 0] / eigenvalues[:, i, -1] per iteration (log scale)
      * right: the first ``n`` eigenvalues per iteration (log scale)
    Index 0 is treated as the leading eigenvalue and index -1 as the trailing one, so
    each entry is presumably sorted in descending order -- TODO confirm with the producer.

    :param eigenvalues_data: List of eigenvalues, one entry per iteration, each of shape (batch, dim)
    :param n: Number of top eigenvalues to plot (clamped to dim; default: 3)
    :param save_path: Directory in which ``spectral_analysis.png`` is saved
    :param show: If True (default), display the figure, wait for ENTER, then close it
    :return: The matplotlib Figure (already closed when ``show`` is True)
    :raises ValueError: if ``eigenvalues_data`` is empty or its entries are not 2-D
    """
    if len(eigenvalues_data) == 0:
        raise ValueError('eigenvalues_data is empty; nothing to plot')
    eigenvalues = np.stack([_eigenvalues_as_numpy(e) for e in eigenvalues_data])  # (num_iter, batch, dim)
    if eigenvalues.ndim != 3:
        raise ValueError('each entry of eigenvalues_data must have shape (batch, dim); '
                         f'stacked shape is {eigenvalues.shape}')
    num_iter, batch, dim = eigenvalues.shape
    iterations = np.arange(num_iter)
    n = min(n, dim)
    colors = plt.cm.viridis(np.linspace(0, 1, n))

    # A dedicated figure, so we never draw into a figure left open by another call.
    fig = plt.figure(figsize=(12, 4 * batch))
    for i in range(batch):
        # A zero trailing eigenvalue yields inf/nan; silence the warning and let the plot skip it.
        with np.errstate(divide='ignore', invalid='ignore'):
            widest_gap = eigenvalues[:, i, 0] / eigenvalues[:, i, -1]
        ax_gap = fig.add_subplot(batch, 2, 2 * i + 1)
        ax_gap.plot(iterations, widest_gap, linewidth=1)
        ax_gap.set_xlabel('Iteration', fontsize=10)
        ax_gap.set_ylabel('Spectral Gap Ratio', fontsize=10)
        ax_gap.set_title('Spectral gap between leading and trailing eigenvalue', fontsize=10)
        ax_gap.grid(True)
        ax_gap.set_yscale('log')  # Use log scale to better see all trends

        ax_eig = fig.add_subplot(batch, 2, 2 * i + 2)
        for k in range(n):
            ax_eig.plot(iterations, eigenvalues[:, i, k], linewidth=1,
                        label=f'λ{k + 1}', color=colors[k])
        ax_eig.set_xlabel('Iteration', fontsize=10)
        ax_eig.set_ylabel('Eigenvalue Magnitude', fontsize=10)
        ax_eig.set_title(f'Top {n} Eigenvalue Evolution', fontsize=10)
        ax_eig.grid(False)
        ax_eig.legend()
        ax_eig.set_yscale('log')  # Use log scale due to large values

    fig.tight_layout()
    if save_path is not None:
        fig.savefig(f'{save_path}/spectral_analysis.png', dpi=300, bbox_inches='tight', pad_inches=.5)
    if show:
        plt.draw()
        plt.pause(0.001)
        input("PRESS ENTER TO CLOSE ALL FIGURES AND CONTINUE")
        plt.close(fig)
    return fig
