import torch
import matplotlib.pyplot as plt
import numpy as np
import argparse
import os

def visualize_head_influence(influence_data, output_folder="influence_plots"):
    """
    Visualize the influence of each input coordinate on each attention head.

    One PNG is written per figure into ``output_folder`` (created if missing);
    layer and head numbers in titles and filenames are 1-based:

    - ``layer_{L}_head_{H}_influence.png``: influence of every input coordinate
      on head H of layer L.
    - ``total_influence.png``: per-coordinate influence summed over all layers
      and heads.
    - ``layer_influence.png``: mean influence of each layer.
    - ``layer_{L}_head_influence.png``: mean influence of each head in layer L.
    - ``influence_heatmap.png``: mean influence as a layer-by-head heatmap.

    Args:
        influence_data: Dictionary containing an 'influence' tensor (or array)
            of shape (n_layers, n_heads, seq_length) and a 'model_config' dict
            whose 'function' entry is shown in the total-influence title.
        output_folder: Directory the plots are written to.
    """
    # Normalize to a detached CPU float tensor: matplotlib cannot convert CUDA
    # tensors or tensors that require grad, and mean() rejects integer dtypes.
    # torch.as_tensor also accepts a NumPy array.
    influence = torch.as_tensor(influence_data['influence']).detach().cpu().float()
    config = influence_data['model_config']

    n_layers, n_heads, seq_length = influence.shape

    # Create folder for plots if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)

    # Plot influence for each layer and head
    for layer in range(n_layers):
        for head in range(n_heads):
            plt.figure(figsize=(12, 8))
            plt.plot(influence[layer, head])
            plt.xlabel('Input Coordinate')
            plt.ylabel('Influence')
            plt.title(f'Layer {layer+1} Attention Head {head+1} Influence')
            _save_and_close(os.path.join(output_folder, f"layer_{layer+1}_head_{head+1}_influence.png"))

    # Plot overall influence of each input coordinate (summed across all heads and layers)
    total_influence = influence.sum(dim=(0, 1))
    plt.figure(figsize=(12, 6))
    plt.bar(range(seq_length), total_influence)
    plt.xlabel('Input Coordinate')
    plt.ylabel('Total Influence')
    plt.title(f'Total Influence of Each Input Coordinate ({config["function"]})')
    _save_and_close(os.path.join(output_folder, "total_influence.png"))

    # Plot average influence per layer (mean over heads and coordinates)
    layer_influence = influence.mean(dim=(1, 2))
    plt.figure(figsize=(10, 6))
    plt.bar(range(n_layers), layer_influence)
    plt.xlabel('Layer')
    plt.ylabel('Average Influence')
    plt.title('Average Influence by Layer')
    _save_and_close(os.path.join(output_folder, "layer_influence.png"))

    # Plot average influence per head for each layer (mean over coordinates)
    for layer in range(n_layers):
        head_influence = influence[layer].mean(dim=1)
        plt.figure(figsize=(10, 6))
        plt.bar(range(n_heads), head_influence)
        plt.xlabel('Head')
        plt.ylabel('Average Influence')
        plt.title(f'Layer {layer+1}: Average Influence by Head')
        _save_and_close(os.path.join(output_folder, f"layer_{layer+1}_head_influence.png"))

    # Heatmap of influence averaged over coordinates: rows are layers, columns heads
    avg_influence_by_head = influence.mean(dim=2)
    plt.figure(figsize=(10, 8))
    plt.imshow(avg_influence_by_head, aspect='auto', cmap='viridis')
    plt.colorbar(label='Average Influence')
    plt.xlabel('Head')
    plt.ylabel('Layer')
    plt.title('Average Influence by Layer and Head')
    plt.xticks(range(n_heads), [f'Head {i+1}' for i in range(n_heads)])
    plt.yticks(range(n_layers), [f'Layer {i+1}' for i in range(n_layers)])
    _save_and_close(os.path.join(output_folder, "influence_heatmap.png"))


def _save_and_close(path):
    """Save the current matplotlib figure to *path*, closing it even if saving fails."""
    try:
        plt.savefig(path)
    finally:
        # Close unconditionally so a failed save does not leak an open figure.
        plt.close()

def analyze_influence_statistics(influence_data, output_folder="influence_plots", top_k=5):
    """
    Analyze statistical properties of the influence data.

    Writes a plain-text report to ``<output_folder>/influence_statistics.txt``
    containing the total influence, the average influence per input
    coordinate, each layer's share of the total, and the ``top_k`` most and
    least influential coordinates (influence summed over layers and heads).

    Args:
        influence_data: Dictionary containing an 'influence' tensor (or array)
            of shape (n_layers, n_heads, seq_length) and a 'model_config' dict
            whose 'function' entry is named in the report header.
        output_folder: Directory the report is written to; created if missing.
        top_k: Number of most/least influential coordinates to list. Clamped
            to the number of coordinates so short sequences do not make
            ``torch.topk`` raise.
    """
    # Detached CPU float copy: works for GPU/grad-tracking/integer inputs alike.
    influence = torch.as_tensor(influence_data['influence']).detach().cpu().float()
    config = influence_data['model_config']

    # Create statistics file
    os.makedirs(output_folder, exist_ok=True)
    stats_file = os.path.join(output_folder, "influence_statistics.txt")

    # Take the coordinate count from the tensor itself rather than trusting
    # config['seq_length'] to match the stored data.
    n_coords = influence.shape[-1]
    total_inf = influence.sum().item()
    total_by_coord = influence.sum(dim=(0, 1))
    k = min(top_k, n_coords)

    with open(stats_file, 'w', encoding='utf-8') as f:
        f.write(f"Influence Statistics for {config['function']} function\n")
        f.write("=" * 50 + "\n\n")

        avg_inf = total_inf / n_coords if n_coords else 0.0
        f.write(f"Total Influence: {total_inf:.4f}\n")
        f.write(f"Average Influence per coordinate: {avg_inf:.4f}\n\n")

        # Layer statistics: absolute influence and share of the overall total
        f.write("Layer Statistics:\n")
        for layer_no, layer_tensor in enumerate(influence, start=1):
            layer_inf = layer_tensor.sum().item()
            # A zero total has no meaningful percentage split (and would divide by zero).
            share = f"{layer_inf / total_inf * 100:.2f}%" if total_inf else "n/a"
            f.write(f"  Layer {layer_no}: {layer_inf:.4f} ({share})\n")
        f.write("\n")

        # Coordinates with highest influence
        values, indices = torch.topk(total_by_coord, k)
        f.write(f"Top {k} Most Influential Coordinates:\n")
        for value, index in zip(values.tolist(), indices.tolist()):
            f.write(f"  Coordinate {index}: {value:.4f}\n")

        # Coordinates with lowest influence
        values, indices = torch.topk(total_by_coord, k, largest=False)
        f.write(f"\nTop {k} Least Influential Coordinates:\n")
        for value, index in zip(values.tolist(), indices.tolist()):
            f.write(f"  Coordinate {index}: {value:.4f}\n")

    print(f"Statistics written to {stats_file}")

if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Visualize influence data")
    parser.add_argument("--file", type=str, required=True, help="Path to the influence data file")
    parser.add_argument("--output", type=str, default="influence_visualizations", help="Output folder for visualizations")
    args = parser.parse_args()

    # Fail early with a readable message instead of a torch.load traceback;
    # isfile (not exists) also rejects a directory passed as --file.
    if not os.path.isfile(args.file):
        print(f"Error: File {args.file} not found", file=sys.stderr)
        # sys.exit rather than the site-provided exit() builtin, which is meant
        # for the interactive shell and is absent when Python runs with -S.
        sys.exit(1)

    print(f"Loading influence data from {args.file}")
    # NOTE: torch.load unpickles the file -- only load influence files from
    # trusted sources. map_location="cpu" lets data saved from a GPU run be
    # analyzed on a machine without CUDA.
    influence_data = torch.load(args.file, map_location="cpu")

    # Print basic information about the run that produced the data
    config = influence_data['model_config']
    print(f"Function: {config['function']}")
    print(f"Model configuration: {config['n_layers']} layers, {config['n_heads']} heads")
    print(f"Sequence length: {config['seq_length']}")
    print(f"Adversarial training: {config.get('adversarial', False)}")

    # Create output folder
    os.makedirs(args.output, exist_ok=True)

    # Generate visualizations
    print("Generating visualizations...")
    visualize_head_influence(influence_data, args.output)

    # Analyze statistics
    print("Analyzing influence statistics...")
    analyze_influence_statistics(influence_data, args.output)

    print(f"Visualization complete! Results saved to {args.output}")