from collections import Counter
from xml.parsers.expat import model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import utilities.data_generation as dg
import json
from torch.optim.lr_scheduler import LambdaLR

def _save_attention_heatmap(matrix, title, path):
    """Render one 2-D attention map as a heatmap and save it to ``path``."""
    plt.figure(figsize=(10, 8))
    plt.imshow(matrix, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.title(title)
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")
    plt.tight_layout()
    plt.savefig(path)
    # Close explicitly so figures do not accumulate across layers and heads.
    plt.close()


def visualize_attention_matrices_for_sample(model, seq_length, learn_function, relevant_coords,
                                            device, folder_name="attention_matrices_2"):
    """
    Visualize per-layer, per-head attention scores for one generated sample.

    A single sample is drawn with ``dg.generate_data(1, seq_length,
    learn_function, relevant_coords)`` and pushed through ``model`` in eval
    mode under ``torch.no_grad()``. A forward hook on every entry of
    ``model.layers`` recomputes the scores from that layer's input and its
    ``W_Q`` projection.

    NOTE: keys are set equal to queries (the ``W_K`` projection is commented
    out), so the stored scores are ``Q @ Q^T / sqrt(head_dim)``, not
    ``Q @ K^T``. Presumably the model ties W_K to W_Q -- confirm against the
    layer implementation. The scores are recomputed inside the hook rather
    than captured from the layer, so anything the layer applies beyond this
    formula (masking, normalization, ...) is not reflected here.

    Args:
        model: Module exposing ``layers``; each layer must provide
            ``num_heads`` and a callable ``W_Q``.
        seq_length: Forwarded to ``dg.generate_data``.
        learn_function: Forwarded to ``dg.generate_data``.
        relevant_coords: Forwarded to ``dg.generate_data``.
        device: Device the sample is moved to before the forward pass.
        folder_name: Output directory; created if it does not exist.

    Side effects:
        * Puts ``model`` in eval mode (the previous mode is not restored).
        * Prints the sample's shape.
        * Writes ``attention_matrices.pt``: a dict with key ``layer_idx`` for
          the raw scores and ``"{layer_idx}_softmax"`` for the softmaxed
          weights.
        * Writes two PNGs per layer/head: raw scores and softmax weights.
    """
    os.makedirs(folder_name, exist_ok=True)
    attn_matrices = {}

    sample, _ = dg.generate_data(1, seq_length, learn_function, relevant_coords)

    def get_attention_hook(layer_idx):
        def hook(module, inputs, _output):
            # inputs[0] is unpacked as (batch, seq_len, d_model).
            x = inputs[0]
            batch_size, seq_len, d_model = x.shape
            n_heads = module.num_heads
            head_dim = d_model // n_heads

            # Split the query projection into heads: (B, H, n, head_dim).
            Q = module.W_Q(x).view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
            # Keys are deliberately aliased to the queries (see docstring).
            K = Q

            # Scaled dot-product scores: (B, H, n, n).
            scores = torch.matmul(Q, K.transpose(-2, -1)) / (head_dim ** 0.5)

            # Keep both the raw scores and the softmax-normalized weights.
            attn_matrices[layer_idx] = scores.detach().cpu()
            attn_matrices[f"{layer_idx}_softmax"] = F.softmax(scores, dim=-1).detach().cpu()
        return hook

    handles = []
    try:
        for i, layer in enumerate(model.layers):
            handles.append(layer.register_forward_hook(get_attention_hook(i)))

        # Run a forward pass with the sample.
        model.eval()
        with torch.no_grad():
            sample = sample.to(device).float()
            print(sample.shape)
            _ = model(sample)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # fails; otherwise they stay on the model and fire on later calls.
        for h in handles:
            h.remove()

    # Save the attention matrices to a file.
    torch.save(attn_matrices, f"{folder_name}/attention_matrices.pt")

    # Plot and save attention matrices.
    for layer_idx in range(len(model.layers)):
        # Index 0 selects the single sample: (n_heads, seq_len, seq_len).
        raw_attn = attn_matrices[layer_idx][0]
        soft_attn = attn_matrices[f"{layer_idx}_softmax"][0]

        for head in range(raw_attn.shape[0]):
            label = f"Layer {layer_idx+1} Head {head+1}"
            prefix = f"{folder_name}/layer{layer_idx+1}_head{head+1}"

            # 1. Raw attention scores (before softmax).
            _save_attention_heatmap(
                raw_attn[head].numpy(),
                f"{label} Raw Attention Scores",
                f"{prefix}_raw_attention.png",
            )
            # 2. Softmax attention weights.
            _save_attention_heatmap(
                soft_attn[head].numpy(),
                f"{label} Attention Weights",
                f"{prefix}_attention_weights.png",
            )

def _save_annotated_heatmap(W, title, path):
    """Save a heatmap of the 2-D array ``W`` with every cell's value printed on it."""
    plt.figure(figsize=(10, 8))
    plt.imshow(W, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.title(title)

    # viridis runs dark -> bright over [W.min(), W.max()], so the text colour
    # flips at the midpoint of that range. Computed once, not per cell.
    midpoint = (W.min() + W.max()) / 2 if W.size else 0.0
    nrows, ncols = W.shape
    for i in range(nrows):
        for j in range(ncols):
            plt.text(j, i, f"{W[i, j]:.2f}", ha='center', va='center',
                     color='white' if W[i, j] < midpoint else 'black', fontsize=6)

    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def _save_eigenvalue_plots(eigvals, label, prefix):
    """Save the complex-plane scatter and the magnitude histogram of ``eigvals``."""
    # Eigenvalue spectrum in the complex plane.
    plt.figure(figsize=(10, 8))
    plt.scatter(eigvals.real, eigvals.imag, alpha=0.7)
    plt.grid(True)
    plt.axhline(y=0, color='k', linestyle='--', alpha=0.3)
    plt.axvline(x=0, color='k', linestyle='--', alpha=0.3)
    plt.xlabel("Real part")
    plt.ylabel("Imaginary part")
    plt.title(f"{label} Eigenvalue Spectrum")
    plt.tight_layout()
    plt.savefig(f"{prefix}_spectrum.png")
    plt.close()

    # Distribution of eigenvalue magnitudes.
    plt.figure(figsize=(10, 6))
    plt.hist(np.abs(eigvals), bins=20, alpha=0.7)
    plt.xlabel("Magnitude")
    plt.ylabel("Count")
    plt.title(f"{label} Eigenvalue Magnitudes")
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(f"{prefix}_magnitudes.png")
    plt.close()


def _format_spectral_stats(label, W, eigvals):
    """Return the text block of spectral statistics for ``W`` under heading ``label``."""
    magnitudes = np.abs(eigvals)

    # The spectral norm and 2-norm condition number are defined through the
    # singular values; eigenvalue magnitudes only match them for normal matrices.
    singular_values = np.linalg.svd(W, compute_uv=False)  # sorted descending
    sigma_max = singular_values[0]
    sigma_min = singular_values[-1]
    condition = sigma_max / sigma_min if sigma_min > 0 else float("inf")

    return (
        f"{label}:\n"
        f"  Spectral norm: {sigma_max:.4f}\n"
        f"  Spectral radius: {np.max(magnitudes):.4f}\n"
        f"  Condition number: {condition:.4f}\n"
        f"  Mean eigenvalue magnitude: {np.mean(magnitudes):.4f}\n"
        f"  Max real part: {np.max(eigvals.real):.4f}\n"
        f"  Min real part: {np.min(eigvals.real):.4f}\n"
        f"  Rank: {np.linalg.matrix_rank(W)}\n\n"
    )


def visualize_attention_matrices(model, folder_name="attention_matrices"):
    """
    Visualize the W_Q, W_K, W_V matrices and their eigenvalue spectra for each layer.

    For every entry of ``model.layers`` the ``W_Q`` and ``W_V`` weights are
    read. ``W_K`` is aliased to ``W_Q`` (the separate key projection is
    commented out), so the "W_K" outputs duplicate the "W_Q" ones.
    If ``model.embedding`` is an ``nn.Linear`` its weight is plotted as well.

    Args:
        model: Module exposing ``embedding`` and ``layers``; each layer must
            provide ``W_Q.weight`` and ``W_V.weight``. The weights must be
            square, because ``np.linalg.eigvals`` is applied to them.
        folder_name: Output directory; created if it does not exist.

    Side effects (all inside ``folder_name``):
        * ``embedding_matrix_linear.png`` (only for an ``nn.Linear`` embedding).
        * Per layer and matrix: ``*_matrix.png`` (annotated heatmap),
          ``*_spectrum.png`` (eigenvalues in the complex plane) and
          ``*_magnitudes.png`` (histogram of |eigenvalue|).
        * ``spectral_stats.txt`` is opened in append mode, so statistics from
          earlier runs into the same folder are kept and new ones are added
          after them.
    """
    os.makedirs(folder_name, exist_ok=True)

    # Visualize the embedding matrix if it is a nn.Linear.
    if isinstance(model.embedding, nn.Linear):
        # nn.Linear stores its weight as (out_features, in_features).
        W_emb = model.embedding.weight.detach().cpu().numpy()
        _save_annotated_heatmap(
            W_emb,
            "Embedding Matrix (nn.Linear)",
            f"{folder_name}/embedding_matrix_linear.png",
        )

    # For each layer, extract and visualize matrices.
    for layer_idx, layer in enumerate(model.layers):
        W_Q = layer.W_Q.weight.detach().cpu().numpy()
        # Keys share the query projection (see docstring).
        W_K = W_Q
        W_V = layer.W_V.weight.detach().cpu().numpy()

        for name, W in [("W_Q", W_Q), ("W_K", W_K), ("W_V", W_V)]:
            label = f"Layer {layer_idx+1} {name}"
            prefix = f"{folder_name}/layer{layer_idx+1}_{name}"

            # 1. The matrix itself, with the value printed in each cell.
            _save_annotated_heatmap(W, label, f"{prefix}_matrix.png")

            # 2. Eigenvalues, their spectrum and magnitude distribution.
            eigvals = np.linalg.eigvals(W)
            _save_eigenvalue_plots(eigvals, label, prefix)

            # 3. Summary statistics, appended after each matrix.
            with open(f"{folder_name}/spectral_stats.txt", "a") as f:
                f.write(_format_spectral_stats(label, W, eigvals))

########################################################################
##################### OLD VISUALIZATION CODE
########################################################################

# if args.visualize_matrices:
#     print("\nVisualizing attention matrices and their spectra...")
#     tf.visualize_attention_matrices(model, folder_name=folder_name)
#     print(f"Matrix visualizations saved.")
    
#     print("\nVisualizing attention matrices for a sample...")
#     # Test with a single example from each class if possible
#     sample = dg.generate_data(1, seq_length, learn_function, relevant_coords)
    
#     tf.visualize_attention_matrices_for_sample(
#         model, seq_length, learn_function, relevant_coords,\
#              device, folder_name=f"{folder_name}/attention_class0"
#     )
#     print("Attention matrix visualizations saved.")