from collections import Counter
from xml.parsers.expat import model
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import utilities.data_generation as dg
import json
from torch.optim.lr_scheduler import LambdaLR

####################################################################################
##################### NEW INFLUENCE EXPERIMENTS
####################################################################################

def measure_total_influence(model, n, vocab_size, num_trials=100, device='cpu'):
    """
    Estimate the total influence of a transformer model by Monte Carlo sampling.

    Total Influence = sum_i E[(f(X)-E[f(X_i)])^2]

    Here f(X) is the class index obtained by taking argmax(dim=1) of the model
    output, and X_i is X with its i-th token replaced by a token drawn uniformly
    from the vocabulary. E[f(X_i)] is estimated as the mean of f over all
    resampled trials for coordinate i; the outer expectation is the mean over
    the original (unmodified) trials.

    Args:
        model: The transformer model. It is switched to eval mode and called as
            ``model(tokens, attention_mask=mask)`` with a 1-D token tensor of
            length n and an all-ones mask of shape (1, n). The argmax over
            dim=1 of its output must contain exactly one element.
        n: Input sequence length.
        vocab_size: Number of unique tokens in the vocabulary.
        num_trials: Number of samples to average over (must be positive).
        device: 'cpu' or 'cuda'.

    Returns:
        Estimated total influence (float).

    Raises:
        ValueError: If num_trials is not positive.
    """
    if num_trials <= 0:
        raise ValueError("num_trials must be a positive integer")

    model.eval()

    # Column j of X (i.e. X[:, j]) is the j-th random input sequence.
    X = torch.randint(0, vocab_size, size=(n, num_trials), device=device)

    if n == 0:
        # No coordinates to perturb, hence no influence to accumulate.
        return 0.0

    # All positions are attended to; one mask serves every forward pass.
    attention_mask = torch.ones(1, n, device=device)

    def predict(tokens):
        """Return the predicted class index of one sequence as a Python float."""
        with torch.no_grad():
            logits = model(tokens, attention_mask=attention_mask)
        return torch.argmax(logits, dim=1).to(torch.float32).item()

    # f(X) does not depend on the perturbed coordinate, so evaluate every
    # unmodified sample exactly once instead of once per coordinate.
    # clone() turns the strided column view into a contiguous tensor.
    f_X = [predict(X[:, j].clone()) for j in range(num_trials)]

    total_influence = 0
    for i in range(n):
        # 1. Estimate E[f(X_i)] by resampling the i-th coordinate of each sample.
        E_f_Xi = 0
        for j in range(num_trials):
            Y = X[:, j].clone()
            Y[i] = np.random.randint(0, vocab_size)
            E_f_Xi += predict(Y)
        E_f_Xi /= num_trials

        # 2. Estimate E[(f(X)-E[f(X_i)])^2] over the unmodified samples.
        infl_i = sum((f_x - E_f_Xi) ** 2 for f_x in f_X) / num_trials
        total_influence += infl_i

    return total_influence


####################################################################################
##################### OLD INFLUENCE EXPERIMENTS
####################################################################################

def plot_boolean_function_influence(learn_function, seq_length, num_samples,
                                        folder_name=None, relevant_coords=None):
    """
    Plots the influence of each coordinate for a given boolean function.

    The influence of coordinate i is estimated as the fraction of random
    samples whose function value changes when bit i is flipped.

    Args:
        learn_function: The boolean function to analyze. Its ``__name__``
            selects the input encoding ("tribes" uses 0/1 inputs, everything
            else uses -1/+1) and the call convention ("junta" is additionally
            passed ``relevant_coords``).
        seq_length: Length of the input sequences
        num_samples: Number of random samples to use for estimation
        folder_name: Optional folder; when given, the plot is saved to
            ``{folder_name}/influence_{function name}.png`` instead of being
            displayed
        relevant_coords: Forwarded as the second argument to ``learn_function``
            when its name is "junta"; unused otherwise

    Returns:
        influence: NumPy array containing the influence of each coordinate
    """
    function_name = learn_function.__name__
    uses_binary_inputs = function_name == "tribes"

    def evaluate(x):
        """Call learn_function with the argument list its name requires."""
        if function_name != "junta":
            return learn_function(x)
        return learn_function(x, relevant_coords)

    # Initialize influence array
    influence = np.zeros(seq_length)

    # Generate random samples
    if uses_binary_inputs:
        # For tribes, use 0/1 inputs
        samples = torch.randint(0, 2, (num_samples, seq_length)).float()
    else:
        # For other functions, use -1/+1 inputs
        samples = (2 * torch.randint(0, 2, (num_samples, seq_length)) - 1).float()

    # The unflipped output is the same for every coordinate, so evaluate it
    # once per sample instead of once per (coordinate, sample) pair.
    original_outputs = [evaluate(samples[j]) for j in range(num_samples)]

    # Calculate influence for each coordinate
    for i in range(seq_length):
        for j in range(num_samples):
            # Create modified sample with bit i flipped
            modified_sample = samples[j].clone()
            if uses_binary_inputs:
                modified_sample[i] = 1 - modified_sample[i]  # Flip 0/1
            else:
                modified_sample[i] = -modified_sample[i]  # Flip -1/+1

            # Record if output changed
            influence[i] += (original_outputs[j] != evaluate(modified_sample))

    # Average across samples
    influence /= num_samples

    # Plot the results
    fig = plt.figure(figsize=(12, 6))
    plt.bar(range(seq_length), influence)
    plt.xlabel('Coordinate Index')
    plt.ylabel('Influence')
    plt.title(f'Coordinate Influence for {function_name} function (n={seq_length})')
    plt.grid(True, alpha=0.3)

    # Add total influence as text
    total_influence = influence.sum()
    plt.text(0.02, 0.95, f'Total Influence: {total_influence:.4f}',
             transform=plt.gca().transAxes, bbox=dict(facecolor='white', alpha=0.8))

    plt.tight_layout()

    if folder_name:
        plt.savefig(f"{folder_name}/influence_{function_name}.png")
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)
    else:
        plt.show()

    return influence

def calculate_influence(model, learn_function, seq_length, device, num_samples=30,
                            relevant_coords = None):
    """
    Calculate the total influence of a trained Transformer.

    For every sample from ``dg.generate_data`` and every coordinate, the
    coordinate is flipped and the function checks whether the model's predicted
    class (argmax of its output) changes. The same is done for
    ``learn_function`` itself. Per-coordinate change counts are divided by
    ``num_samples`` and summed over coordinates; both totals are printed.

    Args:
        model: Trained model; switched to eval mode and called as
            ``model(sample)`` on a float tensor with a leading batch
            dimension of 1.
        learn_function: Target function. Its ``__name__`` selects the flip rule
            ("tribes": x -> 1 - x, otherwise x -> -x) and whether
            ``relevant_coords`` is passed as a second argument ("junta").
        seq_length: Number of input coordinates to flip.
        device: Device the samples are moved to before the forward pass.
        num_samples: Number of samples requested from ``dg.generate_data`` and
            the divisor used for averaging.
        relevant_coords: Forwarded to ``dg.generate_data`` and, for "junta",
            to ``learn_function``.

    Returns:
        The total influence of the model (the ground-truth total is only printed).
    """
    model.eval()
    samples_x, _ = dg.generate_data(num_samples, seq_length, learn_function, relevant_coords)

    function_name = learn_function.__name__
    uses_binary_inputs = function_name == "tribes"

    def ground_truth(x):
        """Evaluate learn_function with the argument list its name requires."""
        if function_name != "junta":
            return learn_function(x)
        return learn_function(x, relevant_coords)

    def predict(x):
        """Return the model's predicted class index for one batched sample."""
        # Softmax is monotonic, so the argmax of the raw output is the same
        # class as the argmax of the probabilities.
        return torch.argmax(model(x)).item()

    # Per-coordinate counts of how often a flip changed the output.
    model_changes = [0] * seq_length
    target_changes = [0] * seq_length

    with torch.no_grad():
        for sample in samples_x:
            sample = sample.unsqueeze(0).to(device).float()

            # The unflipped outputs are shared by all coordinates of this
            # sample, so compute them once rather than once per coordinate.
            predicted = predict(sample)
            true = ground_truth(sample)

            for idx in range(seq_length):
                sample_prime = sample.clone()
                if uses_binary_inputs:
                    sample_prime[0, idx] = 1 - sample_prime[0, idx]
                else:
                    sample_prime[0, idx] *= (-1)

                model_changes[idx] += (predicted != predict(sample_prime))
                target_changes[idx] += (true != ground_truth(sample_prime))

    total_influence = 0
    original_total_influence = 0
    for idx in range(seq_length):
        total_influence += model_changes[idx] / num_samples
        original_total_influence += target_changes[idx] / num_samples

    print(f"Total influence of model for {learn_function.__name__} is {total_influence}")
    print(f"Total influence of {learn_function.__name__} is {original_total_influence}")

    return total_influence

def calculate_head_influence(model, seq_length, device, learn_function, num_samples=30, relevant_coords=None):
    """
    Calculate the influence of each input coordinate on each attention head.

    A forward hook on every entry of ``model.layers`` captures that layer's
    output and splits its last dimension into ``model.n_heads`` chunks of size
    ``model.d_model // model.n_heads``. For every sample and coordinate, the
    coordinate is flipped and the squared difference between the original and
    flipped per-head outputs is summed over all positions and features.
    NOTE(review): this equals the per-head attention output only if the hooked
    layer's output is still the concatenation of the heads - confirm against
    the model definition.

    Args:
        model: Model exposing ``layers``, ``n_layers``, ``n_heads`` and
            ``d_model``; switched to eval mode.
        seq_length: Number of input coordinates to flip.
        device: Device the samples are moved to before the forward pass.
        learn_function: Passed to ``dg.generate_data``; its ``__name__``
            selects the flip rule ("tribes": x -> 1 - x, otherwise x -> -x).
        num_samples: Number of samples requested from ``dg.generate_data`` and
            the divisor used for averaging.
        relevant_coords: Forwarded to ``dg.generate_data``.

    Returns:
        influence: tensor of shape (n_layers, n_heads, seq_length)
                  influence[l,h,i] = influence of coordinate i on head h in layer l
    """
    # Latest per-head output of each hooked layer, keyed by layer index.
    head_outputs = {}
    head_dim = model.d_model // model.n_heads

    def get_hook(layer_idx):
        def hook(module, inputs, output):
            # Output shape is (batch, seq_len, n_heads * head_dim)
            B, seq_len, _ = output.shape
            per_head = output.reshape(B, seq_len, model.n_heads, head_dim)
            # permute() returns a new view rather than modifying in place, so
            # its result must be the tensor that gets stored.
            head_outputs[layer_idx] = per_head.permute(0, 2, 1, 3)  # (B, H, n, d_H)
            # Returning None leaves the layer's output unchanged.
        return hook

    # Prepare storage for influence values
    influence = torch.zeros(model.n_layers, model.n_heads, seq_length)

    # Generate test samples
    X, _ = dg.generate_data(num_samples, seq_length, learn_function, relevant_coords)

    uses_binary_inputs = learn_function.__name__ == "tribes"

    # Register hooks on each transformer layer
    handles = [
        layer.register_forward_hook(get_hook(i))
        for i, layer in enumerate(model.layers)
    ]

    model.eval()
    try:
        with torch.no_grad():
            for x in X:
                x_orig = x.unsqueeze(0).to(device).float()

                # The unflipped forward pass is shared by all coordinates of
                # this sample, so run it once. Clone the captured outputs
                # because the next forward pass overwrites head_outputs.
                model(x_orig)
                orig_outputs = [
                    head_outputs[layer_idx].clone()
                    for layer_idx in range(model.n_layers)
                ]

                for coord_idx in range(seq_length):
                    # Create sample with flipped coordinate
                    x_flipped = x_orig.clone()
                    if uses_binary_inputs:
                        x_flipped[0, coord_idx] = 1 - x_flipped[0, coord_idx]
                    else:
                        x_flipped[0, coord_idx] *= (-1)

                    # Forward pass with flipped sample refreshes head_outputs
                    model(x_flipped)

                    for layer_idx in range(model.n_layers):
                        # (H, n, d_H): difference for every head at once
                        diff = orig_outputs[layer_idx][0] - head_outputs[layer_idx][0]
                        # Sum squared differences over positions and features,
                        # leaving one value per head.
                        per_head = (diff ** 2).sum(dim=(1, 2))
                        influence[layer_idx, :, coord_idx] += per_head.float().cpu()
    finally:
        # Always detach the hooks, even if a forward pass raised.
        for handle in handles:
            handle.remove()

    # Average over all samples
    return influence / num_samples

####################################################################################
##################### OLD INFLUENCE EXPERIMENTS: EXAMPLE DRIVER CODE (COMMENTED OUT)
####################################################################################

#  # Calculate influence
#     calculate_influence(model, learn_function, seq_length, device=device, \
#                             num_samples=30, relevant_coords=relevant_coords)
    
#     print("\nCalculating per-head influence...")
#     influence = calculate_head_influence(model, seq_length, device, learn_function, num_samples=30, \
#                                             relevant_coords=relevant_coords)

#     # Visualize learn_function influence
#     bf.plot_boolean_function_influence(learn_function, seq_length, 1000, folder_name, relevant_coords)

#     # Save the influence tensor to a file for later visualization
#     influence_file = f"{folder_name}/influence_data.pt"
#     torch.save({
#         'influence': influence,
#         'model_config': {
#             'n_layers': model.n_layers,
#             'n_heads': model.n_heads,
#             'seq_length': seq_length,
#             'function': args.function
#         }
#     }, influence_file)
#     print(f"Influence data saved to {influence_file}")