import pickle
import numpy as np
import torch
import os
import logging

from vlmeval.vlm.eagle.save_attention_features import *
import warnings
# Silence UserWarnings process-wide for the rest of this script.
warnings.simplefilter('ignore', UserWarning)

# Load the five merged attention dumps, in this order: the combined ("all")
# run, then runs 0-3 (presumably one per vision encoder — confirm against the
# script that produced these pickles).
attn_all, attn_0, attn_1, attn_2, attn_3 = (
    load_vision_attention_scores(pkl_path)
    for pkl_path in (
        'attention_analysis/eagle_x4_7b_mme_all_merged.pkl',
        'attention_analysis/eagle_x4_7b_mme_0_merged.pkl',
        'attention_analysis/eagle_x4_7b_mme_1_merged.pkl',
        'attention_analysis/eagle_x4_7b_mme_2_merged.pkl',
        'attention_analysis/eagle_x4_7b_mme_3_merged.pkl',
    )
)



import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import pearsonr, spearmanr
from scipy.spatial.distance import jensenshannon
import seaborn as sns

def simple_attention_analysis_5enc(attn_all, attn_0, attn_1, attn_2, attn_3):
    """
    Compare the combined-run attention with each of the four per-encoder runs.

    Despite the ``_5enc`` suffix (five inputs), there are four encoder runs
    (``attn_0`` .. ``attn_3``), each compared against one combined run
    (``attn_all``).

    Args:
        attn_all, attn_0, attn_1, attn_2, attn_3: indexable sequences of
            per-sample records. Record ``i`` must have a
            ``'vision_attention_raw'`` torch tensor, which is cast to float32
            and flattened to 1-D. Only the first ``min(len(...))`` samples are
            used.

    For each sample, all five vectors are truncated to their common (shortest)
    length *before* being normalised to sum to 1. This way every normalised
    vector is a distribution over the same tokens.

    Returns:
        dict mapping each metric name ('pearson_corr', 'spearman_corr', 'mse',
        'mae', 'js_divergence', 'kl_divergence') to a list of four lists, one
        per encoder, each holding one value per sample. Pearson and Spearman
        use the truncated raw vectors; the other metrics use the normalised
        ones. NaNs in the inputs are counted and reported, and they propagate
        into the metrics.
    """
    encoder_runs = (attn_0, attn_1, attn_2, attn_3)
    n_enc = len(encoder_runs)
    n_samples = min(len(attn_all), *(len(run) for run in encoder_runs))
    print(f"Analyzing {n_samples} samples")

    metric_keys = ('pearson_corr', 'spearman_corr', 'mse', 'mae',
                   'js_divergence', 'kl_divergence')
    results = {key: [[] for _ in range(n_enc)] for key in metric_keys}

    eps = 1e-8  # guards normalisation and the KL log-ratio against division by zero / log(0)

    total_nan_all = 0
    total_nan_enc = [0] * n_enc
    samples_with_nan = 0

    def _as_vector(record):
        # Cast to float32 before .numpy(): numpy has no bfloat16 dtype.
        return record['vision_attention_raw'].to(torch.float32).numpy().flatten()

    for i in range(n_samples):
        all_attn = _as_vector(attn_all[i])
        encoders = [_as_vector(run[i]) for run in encoder_runs]

        # Count NaNs on the untruncated vectors.
        cnt_all_nan = int(np.isnan(all_attn).sum())
        cnt_enc_nan = [int(np.isnan(enc).sum()) for enc in encoders]
        if cnt_all_nan > 0 or any(c > 0 for c in cnt_enc_nan):
            samples_with_nan += 1
            print(f"[NaN] sample {i}: all_attn={cnt_all_nan}, encoders={cnt_enc_nan}")
        total_nan_all += cnt_all_nan
        total_nan_enc = [a + b for a, b in zip(total_nan_enc, cnt_enc_nan)]

        # Align all vectors to a shared length first, then normalise, so each
        # normalised vector sums to 1 over exactly the tokens being compared.
        min_len = min(len(all_attn), *(len(enc) for enc in encoders))
        all_attn = all_attn[:min_len]
        encoders = [enc[:min_len] for enc in encoders]
        all_attn_norm = all_attn / (all_attn.sum() + eps)
        encoders_norm = [enc / (enc.sum() + eps) for enc in encoders]

        for j, (enc_attn, enc_norm) in enumerate(zip(encoders, encoders_norm)):
            # 1. Pearson correlation
            corr, _ = pearsonr(all_attn, enc_attn)
            results['pearson_corr'][j].append(corr)

            # 2. Spearman correlation (rank-based)
            spear_corr, _ = spearmanr(all_attn, enc_attn)
            results['spearman_corr'][j].append(spear_corr)

            diff = all_attn_norm - enc_norm

            # 3. Mean Squared Error
            results['mse'][j].append(np.mean(diff ** 2))

            # 4. Mean Absolute Error
            results['mae'][j].append(np.mean(np.abs(diff)))

            # 5. Jensen-Shannon distance (scipy returns the square root of the JS divergence)
            results['js_divergence'][j].append(jensenshannon(all_attn_norm, enc_norm))

            # 6. KL(all || enc), with eps to avoid log(0)
            kl_div = np.sum(all_attn_norm * np.log((all_attn_norm + eps) / (enc_norm + eps)))
            results['kl_divergence'][j].append(kl_div)

    print(f"NaN summary — samples_with_any_NaN: {samples_with_nan}/{n_samples}, "
          f"total NaNs in all_attn: {total_nan_all}, per-encoder totals: {total_nan_enc}")

    return results

def plot_similarity_metrics_5enc(results):
    """
    Box-plot each similarity metric per encoder and print a per-metric summary.

    Args:
        results: dict as returned by ``simple_attention_analysis_5enc``. Each
            metric name maps to one list of per-sample values per encoder;
            the number of encoders is taken from that structure.

    NaN values (e.g. a correlation of a constant vector, or NaNs in the raw
    attention) are dropped before plotting and ignored when computing means.
    This way one bad sample does not turn a whole mean into NaN and corrupt
    the "best encoder" pick.

    Side effects: saves 'attention_similarity_metrics_5enc.png' and calls
    ``plt.show()``.
    """
    metrics = ['pearson_corr', 'spearman_corr', 'mse', 'mae', 'js_divergence', 'kl_divergence']
    metric_names = ['Pearson Correlation', 'Spearman Correlation', 'MSE', 'MAE', 'JS Divergence', 'KL Divergence']
    n_enc = len(results[metrics[0]])
    enc_labels = [f'Enc {j}' for j in range(n_enc)]

    def _finite(values):
        # Drop NaNs so they neither break the box plot nor poison the mean.
        arr = np.asarray(values, dtype=float)
        return arr[~np.isnan(arr)]

    def _mean(values):
        arr = _finite(values)
        return float(arr.mean()) if arr.size else float('nan')

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    axes = axes.flatten()

    for ax, metric, name in zip(axes, metrics, metric_names):
        data = [_finite(results[metric][j]) for j in range(n_enc)]

        # Tick labels are set separately: boxplot's ``labels=`` keyword was
        # renamed to ``tick_labels`` in Matplotlib 3.9 and is deprecated.
        ax.boxplot(data)
        ax.set_xticks(range(1, n_enc + 1))
        ax.set_xticklabels(enc_labels)
        ax.set_title(f'{name}')
        ax.grid(True, alpha=0.3)

        # Annotate each box (boxes sit at x = 1..n_enc) with its mean.
        for j, values in enumerate(data):
            if values.size:
                mean_val = float(values.mean())
                ax.text(j + 1, mean_val, f'{mean_val:.4f}',
                        ha='center', va='bottom', fontweight='bold', fontsize=8)

    plt.tight_layout()
    plt.savefig('attention_similarity_metrics_5enc.png', dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary
    print(f"\n=== SIMILARITY ANALYSIS SUMMARY ({n_enc} ENCODERS) ===")
    for metric, name in zip(metrics, metric_names):
        means = [_mean(results[metric][j]) for j in range(n_enc)]
        finite_means = [m for m in means if not np.isnan(m)]
        if finite_means:
            # Higher is better for correlations, lower for errors/divergences.
            pick = max if 'corr' in metric else min
            best_encoder = means.index(pick(finite_means))
        else:
            best_encoder = None

        print(f"\n{name}:")
        for j, mean_val in enumerate(means):
            marker = " ⭐" if j == best_encoder else ""
            print(f"  Encoder {j}: {mean_val:.6f}{marker}")

def analyze_attention_patterns_5enc(attn_all, attn_0, attn_1, attn_2, attn_3, sample_indices=None):
    """
    Plot normalised attention vectors side by side for a few samples.

    The figure has one row per sample and one column per run (Encoder 0-3,
    then All Encoders).

    Args:
        attn_all, attn_0, attn_1, attn_2, attn_3: indexable sequences of
            records with a ``'vision_attention_raw'`` torch tensor.
        sample_indices: sample indices to plot; defaults to
            [0, 100, 200, 300, 400]. Indices outside the shortest input are
            skipped with a message instead of raising IndexError.

    Side effects: saves 'cambrian_8b_attention_patterns_comparison_x4.png' and
    calls ``plt.show()``. Nothing is plotted if no index is valid.
    """
    encoders_data = [attn_0, attn_1, attn_2, attn_3, attn_all]
    encoder_names = ['Encoder 0', 'Encoder 1', 'Encoder 2', 'Encoder 3', 'All Encoders']

    if sample_indices is None:
        sample_indices = [0, 100, 200, 300, 400]  # Sample a few examples

    n_available = min(len(data) for data in encoders_data)
    valid_indices = [idx for idx in sample_indices if -n_available <= idx < n_available]
    skipped = [idx for idx in sample_indices if not -n_available <= idx < n_available]
    if skipped:
        print(f"Skipping out-of-range sample indices {skipped} "
              f"(only {n_available} samples available)")
    if not valid_indices:
        return

    n_rows, n_cols = len(valid_indices), len(encoders_data)
    # One column per run; squeeze=False keeps `axes` 2-D even for a single row.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)

    for row, sample_idx in enumerate(valid_indices):
        for col, (data, name) in enumerate(zip(encoders_data, encoder_names)):
            # Cast to float32 before .numpy(): numpy has no bfloat16 dtype.
            attn = data[sample_idx]['vision_attention_raw'].to(torch.float32).numpy().flatten()
            attn_norm = attn / (attn.sum() + 1e-8)

            ax = axes[row, col]
            ax.plot(attn_norm, linewidth=1)
            ax.set_title(f'{name}\n(Sample {sample_idx})')
            ax.set_xlabel('Vision Token Index')
            ax.set_ylabel('Attention Weight')
            ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # NOTE(review): the filename says cambrian_8b, but this script loads
    # eagle_x4_7b dumps — confirm the intended output name.
    plt.savefig('cambrian_8b_attention_patterns_comparison_x4.png', dpi=300, bbox_inches='tight')
    plt.show()

def calculate_contribution_weights_5enc(attn_all, attn_0, attn_1, attn_2, attn_3):
    """
    Estimate how much each encoder run contributes to the combined attention.

    For every vision-token position ``p``, a no-intercept least-squares fit is
    solved across samples::

        all[i, p] ~= sum_e w[e, p] * enc_e[i, p]      (e = 0..3)

    Args:
        attn_all, attn_0, attn_1, attn_2, attn_3: indexable sequences of
            per-sample records with a ``'vision_attention_raw'`` torch tensor.
            Only the first ``min(len(...))`` samples are used. Every vector is
            truncated to the shortest length found across all samples and runs.

    Returns:
        (contribution_weights, overall_contributions):
            contribution_weights -- array of shape (4, n_tokens) with the fitted w[e, p].
            overall_contributions -- array of shape (4,), the mean |w[e, p]| over positions.

    Raises:
        ValueError: if there are no samples to fit.
    """
    encoder_runs = (attn_0, attn_1, attn_2, attn_3)
    n_enc = len(encoder_runs)
    n_samples = min(len(attn_all), *(len(run) for run in encoder_runs))
    if n_samples == 0:
        raise ValueError("calculate_contribution_weights_5enc: no samples to fit")

    def _as_vector(record):
        # Cast to float32 before .numpy(): numpy has no bfloat16 dtype.
        return record['vision_attention_raw'].to(torch.float32).numpy().flatten()

    targets = [_as_vector(attn_all[i]) for i in range(n_samples)]
    features = [[_as_vector(run[i]) for run in encoder_runs] for i in range(n_samples)]

    # Common token count, so every sample and run lines up position by position.
    n_tokens = min(
        min(len(t) for t in targets),
        min(len(v) for sample in features for v in sample),
    )

    # X[i, e, p] = encoder e's attention at token p for sample i. The encoder
    # axis is kept explicit so X[:, :, p] is exactly "token p from every encoder".
    X = np.stack([np.stack([v[:n_tokens] for v in sample]) for sample in features])
    y = np.stack([t[:n_tokens] for t in targets])  # (n_samples, n_tokens)

    contribution_weights = np.zeros((n_enc, n_tokens))
    for pos in range(n_tokens):
        # Regression through the origin. lstsq returns the minimum-norm
        # solution when the system is rank-deficient.
        coef, *_ = np.linalg.lstsq(X[:, :, pos], y[:, pos], rcond=None)
        contribution_weights[:, pos] = coef

    # Overall contribution: mean absolute weight across all positions.
    overall_contributions = np.mean(np.abs(contribution_weights), axis=1)

    print(f"\n=== ENCODER CONTRIBUTION ANALYSIS ({n_enc} ENCODERS) ===")
    total = overall_contributions.sum()
    for i, contrib in enumerate(overall_contributions):
        percentage = (contrib / total) * 100 if total > 0 else 0.0
        print(f"Encoder {i}: {contrib:.6f} ({percentage:.1f}%)")

    return contribution_weights, overall_contributions

def plot_contribution_comparison(overall_contrib):
    """
    Bar-chart each encoder's percentage share of the total contribution.

    Args:
        overall_contrib: 1-D array-like of per-encoder contribution magnitudes,
            e.g. the second value returned by
            ``calculate_contribution_weights_5enc``. One bar is drawn per entry.

    Side effects: saves 'cambrian_8b_encoder_contributions_x4.png' and calls
    ``plt.show()``.
    """
    overall_contrib = np.asarray(overall_contrib, dtype=float)
    n_enc = overall_contrib.size
    total = overall_contrib.sum()
    # Avoid 0/0 -> NaN bars when every contribution is zero.
    percentages = (overall_contrib / total) * 100 if total > 0 else np.zeros(n_enc)

    palette = ['blue', 'orange', 'green', 'red', 'purple']
    colors = [palette[k % len(palette)] for k in range(n_enc)]  # one colour per bar
    encoders = [f'Encoder {i}' for i in range(n_enc)]

    fig, ax = plt.subplots(1, 1, figsize=(10, 6))
    bars = ax.bar(encoders, percentages, alpha=0.7, color=colors)

    # Add percentage labels on bars
    for bar, pct in zip(bars, percentages):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width() / 2., height + 0.5,
                f'{pct:.1f}%', ha='center', va='bottom', fontweight='bold')

    ax.set_title('Encoder Contribution to Combined Attention', fontsize=14, fontweight='bold')
    ax.set_ylabel('Contribution Percentage (%)')
    ax.set_xlabel('Encoder')
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    # NOTE(review): the filename says cambrian_8b, but this script loads
    # eagle_x4_7b dumps — confirm the intended output name.
    plt.savefig('cambrian_8b_encoder_contributions_x4.png', dpi=300, bbox_inches='tight')
    plt.show()

# ---- Driver: run each analysis stage in sequence ----
print("Starting 5-encoder attention analysis...")
results = simple_attention_analysis_5enc(
    attn_all=attn_all, attn_0=attn_0, attn_1=attn_1, attn_2=attn_2, attn_3=attn_3,
)
plot_similarity_metrics_5enc(results=results)

# Per-sample attention curves for a handful of samples.
print("Generating attention pattern visualizations...")
analyze_attention_patterns_5enc(
    attn_all=attn_all, attn_0=attn_0, attn_1=attn_1, attn_2=attn_2, attn_3=attn_3,
)

# Per-token regression weights of each encoder run against the combined run.
print("Calculating encoder contributions...")
contrib_weights, overall_contrib = calculate_contribution_weights_5enc(
    attn_all=attn_all, attn_0=attn_0, attn_1=attn_1, attn_2=attn_2, attn_3=attn_3,
)
plot_contribution_comparison(overall_contrib=overall_contrib)

print("Analysis complete! Check the generated plots for results.")