import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import matplotlib.pyplot as plt
import numpy as np

# Root directory that holds the `1B_<variant>` checkpoints
base_dir = './'
# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Model variants to compare; each is loaded from f"{base_dir}/1B_{name}"
MODEL_VARIANTS = ('baseline', 'gate_elementwise', 'gate_headwise')

# Input text whose attention pattern is visualized
PROMPT = "Sparse gating mechanism mitigates attention sink."

# 0-based layer indices to plot: the 1st, 7th, 21st and 28th layers
LAYERS_TO_VISUALIZE = (0, 6, 20, 27)


def average_heads(attentions):
    """Average attention scores across all heads for each layer.

    Args:
        attentions: Iterable of per-layer attention tensors, each shaped
            (batch, head, seq_len, seq_len).

    Returns:
        A list with one float32 NumPy array of shape (seq_len, seq_len) per
        layer: the head-averaged attention of the first sample in the batch.
    """
    # Upcast before reducing: NumPy cannot represent bfloat16, and averaging
    # in float32 avoids half-precision rounding. For float32 inputs the cast
    # is a no-op.
    return [
        layer_attn[0].float().mean(dim=0).cpu().numpy()
        for layer_attn in attentions
    ]


def _compute_attentions(name):
    """Load one model variant and run PROMPT through it.

    Keeping the model and its outputs local to this function drops the
    references to them on return, so they are no longer held while the next
    variant is being loaded.

    Args:
        name: Variant suffix; the checkpoint is read from
            f"{base_dir}/1B_{name}".

    Returns:
        A tuple (tokens, averaged_attentions): the prompt's token strings and
        the per-layer head-averaged attention maps from `average_heads`.

    Raises:
        RuntimeError: If the model did not return attention weights.
    """
    model_name_or_path = f"{base_dir}/1B_{name}"
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path, trust_remote_code=True
    ).to(device)

    inputs = tokenizer(PROMPT, return_tensors="pt").to(device)

    # Forward pass with output_attentions=True to retrieve attention scores
    with torch.no_grad():
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            output_attentions=True,
        )

    # tuple of tensors: (layer) -> (batch, head, seq_len, seq_len)
    attentions = getattr(outputs, "attentions", None)
    if attentions is None or any(layer is None for layer in attentions):
        # Fail with an actionable message instead of an opaque TypeError later.
        raise RuntimeError(
            f"Model '{model_name_or_path}' did not return attention weights; "
            "its attention implementation may not support "
            "output_attentions=True (an 'eager' implementation may be needed)."
        )

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    return tokens, average_heads(attentions)


def _plot_attention_maps(name, tokens, averaged_attentions):
    """Draw, save and show the attention heatmaps of the selected layers.

    Args:
        name: Variant name, used as the prefix of the output PNG file.
        tokens: Token strings used as tick labels on both axes.
        averaged_attentions: Per-layer (seq_len, seq_len) attention maps.
    """
    num_layers = len(averaged_attentions)
    layers = [idx for idx in LAYERS_TO_VISUALIZE if idx < num_layers]
    skipped = [idx + 1 for idx in LAYERS_TO_VISUALIZE if idx >= num_layers]
    if skipped:
        # A shallower model would otherwise crash with an IndexError.
        print(f"[{name}] model has {num_layers} layers; skipping layers {skipped}")

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    axes = axes.flatten()
    tick_positions = np.arange(len(tokens))

    for ax, layer_idx in zip(axes, layers):
        im = ax.imshow(averaged_attentions[layer_idx], cmap="viridis")
        fig.colorbar(im, ax=ax)
        ax.set_title(f"Layer {layer_idx + 1}")

        # Label both axes with the prompt tokens
        ax.set_xticks(tick_positions)
        ax.set_yticks(tick_positions)
        ax.set_xticklabels(tokens, rotation=90)
        ax.set_yticklabels(tokens)

        # Hide tick marks
        ax.tick_params(axis='both', which='both', length=0)

    # Blank out subplots that have no layer to show
    for ax in axes[len(layers):]:
        ax.set_visible(False)

    fig.tight_layout()
    fig.savefig(f"{name}_selected_layer_attention_maps.png")
    plt.show()
    # Release the figure so figures do not accumulate across variants
    plt.close(fig)


def main():
    """Generate the attention-map figure for every model variant."""
    for name in MODEL_VARIANTS:
        tokens, averaged_attentions = _compute_attentions(name)
        _plot_attention_maps(name, tokens, averaged_attentions)


if __name__ == "__main__":
    main()