import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from tqdm import tqdm


def _to_numpy(values, torch_dtype, numpy_dtype):
    """Return *values* (torch.Tensor or array-like) as a NumPy array of the requested dtype."""
    if isinstance(values, torch.Tensor):
        # Cast on the torch side first: some torch dtypes (e.g. bfloat16) have
        # no NumPy equivalent and cannot be converted directly.
        return values.detach().cpu().to(torch_dtype).numpy()
    return np.asarray(values, dtype=numpy_dtype)


def visualize_attention_grid(
    attn_tensor,
    attn_mask=None,
    save_path="attention_plots/attention_grid.png",
    tokens=None,
    cmap="Reds",
):
    """
    Visualizes a (L, H, S, S) attention tensor as a single figure with L rows and H columns of heatmaps.

    Each cell of every heatmap is annotated with its value (two decimals). The figure is
    written to ``save_path`` and then closed; nothing is returned.

    Args:
        attn_tensor (torch.Tensor or np.ndarray): Attention tensor of shape (L, H, S, S)
        attn_mask (torch.Tensor or np.ndarray, optional): Mask broadcastable to the shape of
            ``attn_tensor``. Non-zero/True entries are kept; zero/False entries are replaced
            by NaN, so they are drawn blank and not annotated (default: None)
        save_path (str, optional): Path to save the figure (default: "attention_plots/attention_grid.png")
        tokens (list, optional): List of tokens to label axes (default: None)
        cmap (str, optional): Color map for heatmaps (default: "Reds")

    Raises:
        ValueError: If ``attn_tensor`` is not 4-dimensional, or ``attn_mask`` cannot be
            broadcast to its shape.
    """
    attn = _to_numpy(attn_tensor, torch.float32, np.float32)
    if attn.ndim != 4:
        raise ValueError(
            f"attn_tensor must have shape (L, H, S, S), got {attn.shape}"
        )

    if attn_mask is not None:
        # Mask on the NumPy side so tensor and ndarray inputs both work.
        keep = _to_numpy(attn_mask, torch.bool, bool)
        attn = np.where(np.broadcast_to(keep, attn.shape), attn, np.nan)

    num_layers, num_heads = attn.shape[:2]

    # dirname is "" for a bare file name; os.makedirs("") would raise.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Explicit length check: plain truthiness is ambiguous for NumPy arrays.
    has_tokens = tokens is not None and len(tokens) > 0

    # squeeze=False always yields a 2-D array of Axes, so the 1x1, 1xH and Lx1
    # grids need no special-case indexing.
    fig, axes = plt.subplots(
        num_layers,
        num_heads,
        figsize=(num_heads * 3, num_layers * 3),
        squeeze=False,
    )

    try:
        total_plots = num_layers * num_heads
        with tqdm(total=total_plots, desc="Plotting Attention Heads") as pbar:
            for layer in range(num_layers):
                for head in range(num_heads):
                    ax = axes[layer, head]
                    attn_map = attn[layer, head]
                    ax.imshow(attn_map, cmap=cmap, aspect="auto")

                    # Annotate each cell with its value; masked (NaN) cells stay blank.
                    for (row, col), value in np.ndenumerate(attn_map):
                        if np.isnan(value):
                            continue
                        ax.text(
                            col,
                            row,
                            f"{value:.2f}",
                            ha="center",
                            va="center",
                            # White text on the darker (high-value) cells.
                            color="white" if value > 0.5 else "black",
                            fontsize=6,
                        )

                    ax.set_title(f"L{layer+1} H{head+1}", fontsize=8)

                    # Only label the x-axis on the last row.
                    if has_tokens and layer == num_layers - 1:
                        ax.set_xticks(range(len(tokens)))
                        ax.set_xticklabels(tokens, rotation=90, fontsize=6)
                    else:
                        ax.set_xticks([])

                    # Only label the y-axis on the first column.
                    if has_tokens and head == 0:
                        ax.set_yticks(range(len(tokens)))
                        ax.set_yticklabels(tokens, fontsize=6)
                    else:
                        ax.set_yticks([])

                    pbar.update(1)

        fig.tight_layout()
        fig.savefig(save_path, dpi=100)
    finally:
        # Close this specific figure even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close(fig)

    print(f"Attention grid saved at: {save_path}")
