import torch
import matplotlib.pyplot as plt
import numpy as np
import time
import os
import argparse
from total_influence_experiments.gpt2 import gpt2_generate_hooks_and_model
from total_influence_experiments.bert import bert_generate_hooks_and_model
from total_influence_experiments.roberta import roberta_generate_hooks_and_model
from total_influence_experiments.gemma import gemma_generate_hooks_and_model

from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import BertConfig, BertForMaskedLM
from transformers import RobertaConfig, RobertaForMaskedLM
from transformers import AutoModelForCausalLM

from noise_stability.measure_noise_stability import model_measure_noise_stability

# Reductions that collapse one head's output tensor to the scalar whose gradient
# with respect to the input embeddings is measured (selected by the `norm` argument).
_NORM_REDUCERS = {
    "l1": lambda y: y.abs().sum(),
    "l2": lambda y: y.pow(2).sum().sqrt(),
    "avg": lambda y: y.mean(),
    "max": lambda y: y.max(),
    "min": lambda y: y.min(),
}

_SUPPORTED_MODELS = ("gpt2", "bert", "roberta", "gemma")


def _load_model_with_hooks(model_name, n, device, head_outputs, sampling, n_samples):
    """Load `model_name` via its project helper, whose hooks write per-head outputs into `head_outputs`.

    Returns the helper's `(model, samples, attention_mask, params)` tuple.

    Raises:
        RuntimeError: if the helper produced no samples.
    """
    loaders = {
        "gpt2": gpt2_generate_hooks_and_model,
        "bert": bert_generate_hooks_and_model,
        "roberta": roberta_generate_hooks_and_model,
        "gemma": gemma_generate_hooks_and_model,
    }
    loader = loaders[model_name]

    if model_name == "gemma":
        try:
            result = loader(n, device, head_outputs, sampling=sampling, num_samples=n_samples)
        except Exception as e:
            # Gemma checkpoints are gated; point the user at models that load without a token.
            print(f"\nGemma model loading failed: {e}")
            print("\nGemma models require authentication. For immediate testing, try:")
            print("  --model gpt2    (works without authentication)")
            print("  --model bert    (works without authentication)")
            print("  --model roberta (works without authentication)")
            raise
    else:
        result = loader(n, device, head_outputs, sampling=sampling, num_samples=n_samples)

    model, samples, attention_mask, params = result
    # An explicit raise (not assert) so the check survives `python -O`.
    if samples is None:
        raise RuntimeError("No samples generated. Check the sampling method.")
    return model, samples, attention_mask, params


def calculate_influence_for_model(model_name, n, n_samples, device, norm, sampling, verbose):
    """
    Calculates the influence of each input position on each attention head of one model.

    For every sample, each head's output (captured by forward hooks) is reduced to a
    scalar with `norm`. The absolute gradient of that scalar with respect to the input
    embeddings is then averaged over the embedding dimension, giving one value per token
    position. The result is averaged over `n_samples` samples.

    Args:
        model_name: Name of the model to use ("gpt2", "bert", "roberta" or "gemma")
        n: Number of tokens in the sequence
        n_samples: Number of samples to average over (must be >= 1)
        device: Device to run the model on
        norm: Reduction applied to each head output ("l1", "l2", "avg", "max", "min")
        sampling: Sampling method, forwarded to the model helper
        verbose: Whether to print verbose output

    Returns:
        influence: Tensor of shape (n_layers, n_heads, n)
        n_layers: Number of layers in the model
        n_heads: Number of heads in the model

    Raises:
        ValueError: if `model_name` or `norm` is unsupported, or `n_samples` < 1.
        RuntimeError: if the model helper returned no samples.
    """
    # Validate every cheap argument before the (slow) model load and forward passes.
    # An invalid norm used to surface only inside the gradient loop, and never at all
    # when n_samples == 0.
    if model_name not in _SUPPORTED_MODELS:
        raise ValueError(f"Model {model_name} not supported. Choose from: gpt2, bert, roberta, gemma")
    reduce_head = _NORM_REDUCERS.get(norm)
    if reduce_head is None:
        raise ValueError(f"Norm {norm} not supported. Use one of: l1, l2, avg, max, min")
    if n_samples < 1:
        # Averaging over zero samples would divide by zero and yield a NaN tensor.
        raise ValueError(f"n_samples must be a positive integer, got {n_samples}")

    # Filled by the hooks during each forward pass; read as head_outputs[layer][0, head].
    head_outputs = {}
    model, samples, attention_mask, params = _load_model_with_hooks(
        model_name, n, device, head_outputs, sampling, n_samples
    )

    # The helpers return params as (n_heads, n_layers, ...).
    n_heads = params[0]
    n_layers = params[1]

    if verbose:
        print(f"Model {model_name} has {n_layers} layers and {n_heads} heads.")

    # Influence tensor: layers x heads x sequence
    influence = torch.zeros((n_layers, n_heads, n), device=device)

    for sample in range(n_samples):
        if verbose:
            print(f"Sample {sample + 1}/{n_samples}")

        # Per the original code, samples are (1, n, d) input embeddings; make this a
        # grad-tracking input so head outputs can be differentiated against it.
        x = samples[sample].to(device).requires_grad_()

        # Forward pass; the registered hooks populate head_outputs for this input.
        if attention_mask is not None:
            model(inputs_embeds=x, attention_mask=attention_mask)
        else:
            model(inputs_embeds=x)

        for L in range(n_layers):
            if verbose:
                print(f"Processing Layer {L+1}/{n_layers}")

            for h in range(n_heads):
                normed_total = reduce_head(head_outputs[L][0, h])

                # retain_graph: the same forward graph is differentiated once per (layer, head).
                # autograd.grad returns the gradient directly and never touches x.grad.
                (grad,) = torch.autograd.grad(normed_total, x, retain_graph=True)  # (1, n, d)

                # |gradient| averaged over the embedding dimension -> one value per position.
                influence[L, h] += grad.abs()[0].mean(dim=-1)  # (n,)

    influence /= n_samples  # average over multiple draws

    return influence, n_layers, n_heads

def calculate_influence(args):
    """
    Runs `calculate_influence_for_model` for every model named in `args.models`.

    Args:
        args: Namespace providing `models`, `n`, `n_samples`, `device`, `norm`,
            `sampling` and `verbose`, which are forwarded to
            `calculate_influence_for_model`.

    Returns:
        A pair `(all_influence, all_params)`:
            all_influence: dict mapping model name -> influence tensor of shape
                (n_layers, n_heads, n)
            all_params: dict mapping model name -> (n_layers, n_heads)

    Raises:
        ValueError: if `args.models` is empty, so nothing was processed.
    """
    seq_len, sample_count, target_device, model_names = args.n, args.n_samples, args.device, args.models

    influence_by_model = {}
    shape_by_model = {}

    for name in model_names:
        print(f"\nProcessing model: {name}")
        influence, layer_count, head_count = calculate_influence_for_model(
            name, seq_len, sample_count, target_device, args.norm, args.sampling, args.verbose
        )
        influence_by_model[name] = influence
        shape_by_model[name] = (layer_count, head_count)

    if len(influence_by_model) == 0:
        raise ValueError("No models were successfully processed")

    return influence_by_model, shape_by_model

def calculate_noise_stability_for_model(model_name, r, n, device, num_trials=10):
    """
    Calculates the noise stability for a single model.

    Loads the pretrained HuggingFace checkpoint for `model_name`, moves it to `device`,
    and passes it, with its vocabulary size, to `model_measure_noise_stability`.

    Args:
        model_name: Name of the model to use ("gpt2", "bert", "roberta" or "gemma")
        r: Correlation coefficient between X and Y (float in [-1, 1]).
        n: Number of tokens in the sequence
        device: Device to load the model onto
        num_trials: Number of trials forwarded to `model_measure_noise_stability`
            (default 10, the value previously hard-coded here)

    Returns:
        noise_stability: Noise stability value
        l2_norm: L2 norm value

    Raises:
        ValueError: if `model_name` is not supported.
    """
    if model_name == "gpt2":
        config = GPT2Config.from_pretrained("gpt2", attn_implementation="eager")
        model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
    elif model_name == "bert":
        config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True)
        model = BertForMaskedLM.from_pretrained("bert-base-uncased", config=config)
    elif model_name == "roberta":
        config = RobertaConfig.from_pretrained("roberta-base", output_attentions=True)
        model = RobertaForMaskedLM.from_pretrained("roberta-base", config=config)
    elif model_name == "gemma":
        model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")
    else:
        # Previously an unknown name fell through and passed model=None to the
        # measurement routine, failing there with an unrelated error.
        raise ValueError(f"Model {model_name} not supported. Choose from: gpt2, bert, roberta, gemma")

    model = model.to(device)
    vocab_size = model.config.vocab_size

    noise_stability, l2_norm = model_measure_noise_stability(model,
                                                             n,
                                                             r,
                                                             vocab_size,
                                                             num_trials=num_trials,
                                                             device=device)

    return noise_stability, l2_norm
    


def calculate_noise_stability(args, r, device):
    """
    Measures noise stability for every model listed in `args.models`.

    Args:
        args: Namespace providing `models` (model names) and `n` (sequence length).
        r: Correlation coefficient between X and Y (float in [-1, 1]).
        device: Device the models are loaded onto.

    Returns:
        A pair `(all_noise_stability, all_l2_norm)` of dicts, each mapping a model
        name to the value returned by `calculate_noise_stability_for_model`.
    """
    stability_by_model = {}
    norm_by_model = {}

    for name in args.models:
        print(f"Processing model: {name}")
        stability_by_model[name], norm_by_model[name] = calculate_noise_stability_for_model(
            name, r, args.n, device
        )

    return stability_by_model, norm_by_model


def _to_percent(values):
    """Rescale `values` to percentages of their sum; return them unchanged if the sum is not positive."""
    total = values.sum()
    if total > 0:
        return (values / total) * 100
    return values  # avoid division by zero


def _save_and_close(filename):
    """Save the current matplotlib figure to `filename` (300 dpi, tight bbox) and close it."""
    plt.savefig(filename, dpi=300, bbox_inches='tight')
    plt.close()


def visualize_influence(all_influence, all_params, args, folder_name, n, l2_norm, X, layer_stride=6):
    """
    Plots influence comparisons for one or more models and saves them as PNGs.

    Writes into `folder_name`:
        total_influence_comparison.png   - last layer, summed over heads, per position
        layer_influence_comparison.png   - every `layer_stride`-th layer, summed over heads
        head_influence_comparison.png    - first `args.visible_heads` heads, summed over layers
        influence_heatmap_comparison.png - layer x head heatmap, averaged over positions
    Every curve/heatmap is normalised to percent of its own total.

    Args:
        all_influence: dict model name -> influence tensor of shape (n_layers, n_heads, n)
        all_params: dict model name -> (n_layers, n_heads)
        args: parsed arguments; `visible_heads` and `verbose` are read
        folder_name: existing directory the PNG files are written to
        n: sequence length, used only in the predicted-concentration printout
        l2_norm: dict model name -> L2 norm from the noise-stability run, or None to
            skip the predicted-concentration printout
        X: concentration exponent paired with `l2_norm` (None skips the printout)
        layer_stride: plot every `layer_stride`-th layer in the layer-wise figure
            (default 6, the value previously hard-coded)
    """
    # Define colors for different models
    colors = ['blue', 'red', 'green', 'orange', 'purple', 'brown', 'pink', 'gray']

    # Plot in high resolution
    plt.rcParams['figure.dpi'] = 300

    # Convert each tensor to numpy once; every analysis below reuses it.
    arrays = {name: influence.cpu().numpy() for name, influence in all_influence.items()}

    # Analysis 1: Plot total model influence (last layer only, summed across heads)
    plt.figure(figsize=(12, 6))
    for i, (model_name, data) in enumerate(arrays.items()):
        # The last layer is treated as the total model influence, since its gradients
        # flow back through all earlier layers.
        total_influence = data[-1].sum(axis=0)  # Sum across heads for the last layer
        total_influence_sum = total_influence.sum()
        print(f"Total influence sum for {model_name}: {total_influence_sum}")

        if l2_norm is not None and X is not None:
            print(f"Influence predicted concentration: ({total_influence_sum / (n*X*l2_norm[model_name])}, {X})")

        plt.plot(_to_percent(total_influence), color=colors[i % len(colors)], linewidth=2,
                 label=f'{model_name.upper()} Total Influence (%)', alpha=0.8)

    plt.title('Total Model Influence Across Token Positions (Last Layer) - Multiple Models', fontsize=16)
    plt.xlabel('Token Position', fontsize=14)
    plt.ylabel('Percentage of Total Influence (%)', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)
    _save_and_close(f"{folder_name}/total_influence_comparison.png")

    # Analysis 2: Plot influence per layer (summed across heads) for each model
    plt.figure(figsize=(15, 8))
    for i, (model_name, data) in enumerate(arrays.items()):
        n_layers, _ = all_params[model_name]
        layer_influence = data.sum(axis=1)  # Sum across heads for each layer
        color = colors[i % len(colors)]

        # Alternate line style between successive *plotted* layers. Testing `layer % 2`
        # never alternated: with an even stride every plotted layer index is even.
        for k, layer in enumerate(range(0, n_layers, layer_stride)):
            plt.plot(_to_percent(layer_influence[layer]), color=color,
                     label=f'{model_name.upper()} Layer {layer + 1}',
                     alpha=0.7, linestyle='-' if k % 2 == 0 else '--')

    plt.title('Layer-wise Influence Across Token Positions - Multiple Models', fontsize=16)
    plt.xlabel('Token Position', fontsize=14)
    plt.ylabel('Percentage of Layer Influence (%)', fontsize=14)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    _save_and_close(f"{folder_name}/layer_influence_comparison.png")

    # Analysis 3: Plot influence per head (summed across layers) for each model
    plt.figure(figsize=(15, 8))
    for i, (model_name, data) in enumerate(arrays.items()):
        _, n_heads = all_params[model_name]
        head_influence = data.sum(axis=0)  # Sum across layers for each head
        color = colors[i % len(colors)]

        for head in range(min(n_heads, args.visible_heads)):
            plt.plot(_to_percent(head_influence[head]), color=color,
                     label=f'{model_name.upper()} Head {head + 1}',
                     alpha=0.7, linestyle='-' if head % 2 == 0 else '--')

    plt.title('Head-wise Influence Across Token Positions - Multiple Models', fontsize=16)
    plt.xlabel('Token Position', fontsize=14)
    plt.ylabel('Percentage of Head Influence (%)', fontsize=14)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.grid(True, alpha=0.3)
    _save_and_close(f"{folder_name}/head_influence_comparison.png")

    # Analysis 4: Heatmap of influence by layer and head (averaged across positions) for each model
    # squeeze=False keeps a 2-D axes array even for a single model.
    _, axes = plt.subplots(1, len(arrays), figsize=(6 * len(arrays), 8), squeeze=False)

    for ax, (model_name, data) in zip(axes[0], arrays.items()):
        n_layers, n_heads = all_params[model_name]
        avg_influence_by_layer_head = data.mean(axis=2)  # Average across token positions

        # Percentage of total influence across all layers and heads
        im = ax.imshow(_to_percent(avg_influence_by_layer_head), aspect='auto', cmap='viridis')
        ax.set_xlabel('Head')
        ax.set_ylabel('Layer')
        ax.set_title(f'{model_name.upper()} - Average Influence by Layer and Head (%)', fontsize=14)
        ax.set_xticks(range(n_heads))
        ax.set_yticks(range(n_layers))

        plt.colorbar(im, ax=ax, label='Percentage of Total Average Influence (%)')

    plt.tight_layout()
    _save_and_close(f"{folder_name}/influence_heatmap_comparison.png")

    if args.verbose:
        print(f"Generated comparison visualizations in {folder_name}:")
        print("- total_influence_comparison.png: Total model influence comparison (last layer)")
        print("- layer_influence_comparison.png: Layer-wise influence comparison")
        print("- head_influence_comparison.png: Head-wise influence comparison")
        print("- influence_heatmap_comparison.png: Average influence comparison by layer/head")

def _str_to_bool(value):
    """Parse a boolean command-line value such as "true", "False", "1" or "no".

    `argparse` with `type=bool` turns every non-empty string, including "False",
    into True, so the flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if `value` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def _default_device():
    """Return the default torch device string.

    Keeps the historical default of the second GPU ("cuda:1") when at least two GPUs
    are visible. Falls back to "cuda:0" on single-GPU machines, where "cuda:1" does
    not exist, and to "cpu" when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 'cpu'
    return 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'


def _parse_args():
    """Build the command-line parser, parse `sys.argv`, and validate `--r`."""
    parser = argparse.ArgumentParser(description="Influence experiments ")
    parser.add_argument("--n", type=int, default=512,
                        help="Number of tokens in the sequence")
    # NOTE(review): --d is not read by the functions in this file; kept for CLI compatibility.
    parser.add_argument("--d", type=int, default=768,
                        help="Dimension of the model")
    parser.add_argument("--n_samples", type=int, default=10,
                        help="Number of samples to average over")
    parser.add_argument("--device", type=str, default=_default_device(),
                        help="Device to run the model on")
    # nargs='?' + const=True: a bare `--verbose` means True; `--verbose false` turns it off.
    parser.add_argument("--verbose", type=_str_to_bool, nargs="?", const=True, default=True,
                        help="Whether to print verbose output")
    parser.add_argument("--models", type=str, nargs='+', default=["gpt2"],
                        choices=["gpt2", "bert", "roberta", "gemma"],
                        help="List of model names to use (default: gpt2)")
    # choices rejects a typo immediately instead of after loading the model.
    parser.add_argument("--norm", type=str, default="l2",
                        choices=["l1", "l2", "avg", "max", "min"],
                        help="Norm to use for influence calculation (l1, l2, avg, max, min)")
    parser.add_argument("--visible_heads", type=int, default=4,
                        help="Number of heads to visualize")
    parser.add_argument("--sampling", type=str, default="uniform",
                        help="Sampling method (uniform, model)")
    parser.add_argument("--r", type=float, default=0.05,
                        help="Correlation coefficient between X and Y (float in [-1, 1])")
    args = parser.parse_args()

    if not -1.0 <= args.r <= 1.0:
        parser.error(f"--r must lie in [-1, 1], got {args.r}")
    return args


def main():
    """Measure influence, optionally measure noise stability (when r > 0), and save plots."""
    args = _parse_args()

    # Calculate influence for all models
    all_influence, all_params = calculate_influence(args)

    # Release cached GPU memory before loading the models again for noise stability.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print("GPU memory cleared.")

    if args.r > 0:
        noise_stability, l2_norm = calculate_noise_stability(args, args.r, args.device)  # dictionaries

        # Predicted concentration exponent log(10000) / log(1/r). It is shared by the
        # report below and the influence report in visualize_influence.
        # NOTE(review): the constant 10000 is unexplained here; confirm its meaning.
        concentration_x = np.log(10000) / np.log(1 / args.r)

        for model_name in args.models:
            delta = 1 - noise_stability[model_name] / l2_norm[model_name]
            print(f"Normalized noise stability for {model_name}: {delta}")
            print(f"L2 norm for {model_name}: {l2_norm[model_name]}")
            # NOTE(review): the 1.0001 factor is carried over unchanged from the original.
            print(f"Noise Stability Predicted concentration for {model_name}: ({delta*1.0001}, {concentration_x})")
    else:
        l2_norm = None
        concentration_x = None

    # Timestamped output folder that also names every model compared.
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    folder_name = f"plots/{timestamp}_{'_'.join(args.models)}"
    os.makedirs(folder_name, exist_ok=True)

    # Visualize influence comparison
    visualize_influence(all_influence,
                        all_params,
                        args,
                        folder_name,
                        args.n,
                        l2_norm,
                        X=concentration_x)


if __name__ == "__main__":
    main()
