import torch
import matplotlib.pyplot as plt
import argparse
import os

def load_attention_matrices(file_path, map_location="cpu"):
    """Load saved attention matrices from a file written with ``torch.save``.

    Args:
        file_path: Path of the file to read.
        map_location: Device remapping forwarded to ``torch.load``. Defaults
            to ``"cpu"`` so that a file written on a GPU machine can still be
            opened where no GPU is available, and so the loaded tensors can
            be handed to NumPy/matplotlib directly.

    Returns:
        The object stored in the file, as returned by ``torch.load``.

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    print(f"Loading attention matrices from {file_path}")
    # NOTE: torch.load relies on pickle; only open files from a trusted source.
    return torch.load(file_path, map_location=map_location)

def display_matrix(matrix, title, save_path=None):
    """Render a tensor as a heatmap and either save it or show it.

    Args:
        matrix: Torch tensor to plot. Columns are labelled as key positions
            and rows as query positions.
        title: Title placed above the heatmap.
        save_path: When truthy, the figure is written to this path and then
            closed; otherwise the figure is shown with ``plt.show()``.
    """
    # detach()/cpu() make the conversion work for tensors that live on an
    # accelerator or still track gradients; a bare .numpy() raises for those.
    values = matrix.detach().cpu().numpy()

    plt.figure(figsize=(10, 8))
    plt.imshow(values, aspect='auto', cmap='viridis')
    plt.colorbar()
    plt.title(title)
    plt.xlabel("Key Position")
    plt.ylabel("Query Position")
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

def _variance_explained(singular_values, top_k):
    """Return the percentage of the squared singular-value sum held by the first ``top_k`` values."""
    energy = singular_values ** 2
    total = energy.sum().item()
    if total == 0:
        # All-zero matrix: avoid 0/0 and report that nothing is explained.
        return 0.0
    return energy[:top_k].sum().item() / total * 100


def _low_rank_errors(matrix, U, S, Vh, ranks):
    """Return the relative norm error of the rank-r truncated SVD for each r in ``ranks``."""
    reference = torch.norm(matrix).item()
    errors = []
    for r in ranks:
        # Scaling the first r columns of U by S[:r] equals U_r @ diag(S_r).
        approx = (U[:, :r] * S[:r]) @ Vh[:r, :]
        residual = torch.norm(matrix - approx).item()
        errors.append(residual / reference if reference else 0.0)
    return errors


def analyze_matrix_properties(attn_matrices, file_prefix="matrix_analysis",
                              rank_threshold=1e-5, max_rank=20):
    """Analyze rank and spectral properties of attention matrices.

    For every recognised entry and every head, an SVD is computed and the
    following are produced: a block in ``{file_prefix}_summary.txt``
    (effective rank and share of squared singular values held by the top
    1/2/5 values), a singular-value bar chart, a low-rank approximation
    error curve, and a side-by-side original vs. rank-1 heatmap.

    Args:
        attn_matrices: Mapping of saved matrices. ``int`` keys are treated as
            raw attention for that layer index; ``str`` keys containing
            ``"_softmax"`` are treated as softmax attention, with the layer
            index taken from the text before the first underscore. Any other
            key is skipped. Element ``[0]`` of each value is iterated head by
            head, each head being a 2-D matrix.
        file_prefix: Path prefix for every output file. If it contains a
            directory part, that directory is created.
        rank_threshold: Singular values strictly above this count towards the
            effective rank.
        max_rank: Largest rank evaluated in the approximation-error plot.

    Output plot files are named
    ``{file_prefix}_{raw|softmax}_L{layer}_H{head}_<plot>.png`` so that raw
    and softmax plots of the same layer/head do not overwrite each other.
    """
    # Imported locally so the function carries its own plotting dependency.
    import matplotlib.pyplot as plt

    # os.makedirs("") raises, so only create a directory when the prefix has one.
    out_dir = os.path.dirname(file_prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    with open(f"{file_prefix}_summary.txt", "w") as summary_file:
        for key, entry in attn_matrices.items():
            if isinstance(key, int):  # Raw attention matrices
                matrix_type, tag, layer_idx = "Raw Attention", "raw", key
            elif isinstance(key, str) and "_softmax" in key:  # Softmax attention
                try:
                    layer_idx = int(key.split("_")[0])
                except ValueError:
                    # Not of the form "<layer>_softmax"; skip like other unknown keys.
                    continue
                matrix_type, tag = "Softmax Attention", "softmax"
            else:
                continue

            matrices = entry[0]  # Shape: (n_heads, seq_len, seq_len)
            layer_no = layer_idx + 1  # 1-based numbering in reports and file names

            summary_file.write(f"=== {matrix_type} Layer {layer_no} ===\n")

            for head_no, matrix in enumerate(matrices, start=1):
                matrix = matrix.detach()
                label = f"{matrix_type} Layer {layer_no} Head {head_no}"
                stem = f"{file_prefix}_{tag}_L{layer_no}_H{head_no}"

                # 1. Reduced SVD: matrix == U @ diag(S) @ Vh. The reduced form
                #    keeps the factor shapes compatible for non-square input.
                U, S, Vh = torch.linalg.svd(matrix, full_matrices=False)

                # 2. Effective rank: singular values above the threshold
                effective_rank = int((S > rank_threshold).sum().item())

                # 3. Share of squared singular values held by the top k
                var_explained_1 = _variance_explained(S, 1)
                var_explained_2 = _variance_explained(S, 2)
                var_explained_5 = _variance_explained(S, 5)

                # 4. Summary statistics
                summary_file.write(f"\nHead {head_no}:\n")
                summary_file.write(f"  Matrix shape: {matrix.shape}\n")
                summary_file.write(f"  Effective rank: {effective_rank}\n")
                summary_file.write(f"  Top 1 singular value explains: {var_explained_1:.2f}%\n")
                summary_file.write(f"  Top 2 singular values explain: {var_explained_2:.2f}%\n")
                summary_file.write(f"  Top 5 singular values explain: {var_explained_5:.2f}%\n")

                # 5. Singular value spectrum
                plt.figure(figsize=(10, 6))
                plt.bar(range(len(S)), S.cpu().numpy())
                plt.title(f"{label} Singular Values")
                plt.xlabel("Index")
                plt.ylabel("Singular Value")
                plt.yscale('log')  # Log scale to better see the distribution
                plt.grid(True, which="both", ls="--", alpha=0.5)
                plt.tight_layout()
                plt.savefig(f"{stem}_singular_values.png")
                plt.close()

                # 6. Low-rank approximation error for ranks 1..max_rank
                ranks = list(range(1, min(max_rank, len(S)) + 1))
                approx_errors = _low_rank_errors(matrix, U, S, Vh, ranks)

                plt.figure(figsize=(10, 6))
                plt.plot(ranks, approx_errors, 'o-')
                plt.title(f"{label} Approximation Error")
                plt.xlabel("Rank")
                plt.ylabel("Relative Error")
                plt.grid(True)
                plt.tight_layout()
                plt.savefig(f"{stem}_approx_error.png")
                plt.close()

                # 7. Original matrix next to its rank-1 approximation
                rank_one = (U[:, :1] * S[:1]) @ Vh[:1, :]

                plt.figure(figsize=(15, 6))

                plt.subplot(1, 2, 1)
                plt.imshow(matrix.cpu().numpy(), aspect='auto', cmap='viridis')
                plt.colorbar()
                plt.title("Original Matrix")

                plt.subplot(1, 2, 2)
                plt.imshow(rank_one.cpu().numpy(), aspect='auto', cmap='viridis')
                plt.colorbar()
                plt.title(f"Rank-1 Approximation ({var_explained_1:.1f}% variance)")

                plt.tight_layout()
                plt.savefig(f"{stem}_rank1_comparison.png")
                plt.close()

    print(f"Matrix analysis complete. Results saved with prefix: {file_prefix}")

# Command-line entry point: display one layer/head, then analyze every stored matrix.
def main():
    """Show one head's attention maps, then run the spectral analysis.

    Reads ``--file``, prints and plots the raw scores of the requested
    ``--layer``/``--head`` (plus the softmax weights when the file has a
    ``"<layer>_softmax"`` entry), and finally calls
    ``analyze_matrix_properties`` on everything that was loaded. With
    ``--save`` the plots and analysis files go into that directory; without
    it the plots are shown on screen and the analysis files are written to
    the current directory. Exits with status 1 when the layer or head is not
    available.
    """
    parser = argparse.ArgumentParser(description="Display saved attention matrices")
    parser.add_argument("--file", type=str, required=True, help="Path to the saved attention matrices file")
    parser.add_argument("--layer", type=int, default=0, help="Layer index to display")
    parser.add_argument("--head", type=int, default=0, help="Head index to display")
    parser.add_argument("--save", type=str, default=None, help="Directory to save plots instead of displaying")
    args = parser.parse_args()

    # Load the matrices
    attn_matrices = load_attention_matrices(args.file)

    # Get the specified layer's raw attention scores
    if args.layer not in attn_matrices:
        print(f"Layer {args.layer} not found. Available keys: {list(attn_matrices.keys())}")
        raise SystemExit(1)

    raw_attn = attn_matrices[args.layer][0]  # Shape: (n_heads, seq_len, seq_len)
    n_heads = raw_attn.shape[0]

    # Reject negative indices too; they would otherwise wrap around silently.
    if not 0 <= args.head < n_heads:
        print(f"Head index {args.head} out of range. There are {n_heads} heads.")
        raise SystemExit(1)

    # Create the output directory before the first plot is written into it.
    if args.save:
        os.makedirs(args.save, exist_ok=True)

    def plot_path(kind):
        """Return the image path for ``kind`` or None when plots are shown instead."""
        if not args.save:
            return None
        return os.path.join(args.save, f"layer{args.layer}_head{args.head}_{kind}.png")

    head_attn = raw_attn[args.head]
    print(f"Raw Attention Scores Layer {args.layer} and Head {args.head}")
    print(head_attn)

    # The softmax entry is optional; look it up once and guard every use.
    softmax_key = f"{args.layer}_softmax"
    soft_attn = None
    if softmax_key in attn_matrices:
        soft_attn = attn_matrices[softmax_key][0][args.head]
        print(f"Softmax Attention Scores Layer {args.layer} and Head {args.head}")
        print(soft_attn)

    # Display the raw scores
    title = f"Layer {args.layer} Head {args.head} Raw Attention Scores"
    display_matrix(head_attn, title, plot_path("raw"))

    # Also display the softmax version
    if soft_attn is not None:
        title = f"Layer {args.layer} Head {args.head} Attention Weights (Softmax)"
        display_matrix(soft_attn, title, plot_path("softmax"))

    # Analyze matrix properties. Fall back to the current directory rather
    # than formatting None into the path (which produced a "None/" folder).
    analysis_prefix = os.path.join(args.save or os.curdir, "matrix_analysis")
    analyze_matrix_properties(attn_matrices, file_prefix=analysis_prefix)


if __name__ == "__main__":
    main()