import torch
import os
import matplotlib.pyplot as plt
from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM
from transformers import AutoTokenizer
import math
import re 
import numpy as np

# Load model & tokenizer. device_map="auto" hands weight placement to Accelerate, so the model is never moved with .to(device).
model_path = "/path/to/model/weights"  # NOTE: Please update with weights to Phi-3-medium-128k-instruct before running
dtype = torch.bfloat16
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # target for input tensors
# Eager attention: fused SDPA/flash kernels do not return the attention weights this script visualizes.
model = Phi3ForCausalLM.from_pretrained(
    model_path, torch_dtype=dtype, local_files_only=True,
    attn_implementation="eager", device_map="auto",
)
# Fast tokenizer, loaded from the same local weights directory.
tokenizer = AutoTokenizer.from_pretrained(model_path, local_files_only=True, use_fast=True)

####################################################

## Input Prompts -- keep exactly one `prompts = [...]` line active (a later assignment overrides an earlier one)
# # Repetition prompts
# prompts = ["Hello hello hello how how how are are are you you you repeat repeat repeat stop stop stop."] # Exclude "the" as that is a common sink
# prompts = ["Hello hello hello the the the how how how are are are you you you repeat repeat repeat stop stop stop."]

# # COT prompts
# prompts = ["The derivative of x² is 2x. What is the derivative of x³?"]
# prompts = ["I had 4 oranges and bought 3 more. 4 + 3 = 7. Now, I have 3 apples and buy 2 more. How many apples do I have?"]

# # Zero Shot COT prompts
# prompts = ["What is the derivative of x³? Let's think step by step."]
prompts = ["I have 3 apples. I buy 2 more. How many do I have? Let's think step by step."]

# # Referential prompts
# prompts = ["John met Mary at the park. She smiled. He waved back. They talked about it."]

# # Streaming numbers/numerical prompts
# prompts = ["Input: 12323212. Label: 2 Input: 7887689. Label: 8 Input: 32454323 Label: 3"]

# # Sentiment
# prompts = ["I laughed with friends until my stomach hurt. I wiped away tears of sadness and said nothing. I slammed my fist in anger on the table."]

## Other
# prompts = ["Classify each sentence as happy or sad. The sun shone brightly, warming my face as I walked through the park. The flowers I had planted last spring never bloomed. Children laughed and played on the swings. A cold wind swept through the empty streets, reminding me of nights spent alone. The rain poured relentlessly, drenching the wilted petals of forgotten flowers."]
# prompts = ["Read the following paragraph and determine if the hypothesis is true. \n \n Premise: A: Oh, oh yeah, and every time you see one hit on the side of the road you say is that my cat. B: Uh-huh. A: And you go crazy thinking it might be yours. B: Right, well I didn’t realize my husband was such a sucker for animals until I brought one home one night. Hypothesis: her husband was such a sucker for animals. Answer:"]
# Prefix every prompt with the literal BOS token text (space-separated), then tokenize the batch onto `device`.
prompts = [f"{tokenizer.bos_token} {text}" for text in prompts]
inputs = tokenizer(prompts, return_tensors="pt").to(device)

####################################################

# Create Output Directories (every artifact of this script is written below base_output_dir)
base_output_dir = "phi3_visualizations"
value_embeddings_dir, input_embeddings_dir, final_embeddings_dir, attention_sink_dir = (
    os.path.join(base_output_dir, subdir)
    for subdir in ("value_embeddings", "input_embeddings", "final_embeddings", "attention_sink")
)

# exist_ok=True makes re-running the script against an existing output tree harmless.
for _output_dir in (value_embeddings_dir, input_embeddings_dir, final_embeddings_dir, attention_sink_dir):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir  # do not leave the loop variable behind as a module global

####################################################

def tokenize_and_map(prompts, tokenizer=None):
    """Tokenize a prompt and return one display label per token.

    The labels are used as heat-map tick labels, so the result always has
    exactly one entry per token produced by ``tokenizer(prompts)``, in order:

    - Special tokens (including a BOS token, when present) keep their own
      entry. They are deliberately *not* removed: dropping one would shift
      every label against the rows/columns of the attention matrix.
    - The SentencePiece word-boundary marker "▁" is stripped from labels.
    - Single-character punctuation/delimiter tokens get their own entry.
    - Words split into several sub-word pieces are not merged back; each
      piece keeps its own label.

    Args:
        prompts: the prompt text. Only the first sequence of the tokenized
            batch is mapped, so pass a single prompt.
        tokenizer: optional, already-loaded tokenizer. When omitted, a fast
            tokenizer is loaded from the module-level ``model_path`` on every
            call; passing one avoids that reload from disk.

    Returns:
        ``(mapped_words, mapped_token_indices)`` -- parallel lists where
        ``mapped_token_indices[i]`` is the one-element list of token
        positions labelled by ``mapped_words[i]``.
    """
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, local_files_only=True)

    inputs = tokenizer(prompts, return_tensors="pt")
    # Take the first sequence explicitly: squeeze() would collapse a one-token
    # input to a scalar id, and convert_ids_to_tokens would then not return a list.
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    special_tokens = set(tokenizer.all_special_tokens)  # built once instead of per token

    # First pass: group tokens into words; a new word starts at a "▁" token.
    words = []
    word_to_tokens = []
    current_word = ""
    current_word_tokens = []

    def close_word():
        """Record the word currently being accumulated, if any."""
        if current_word_tokens:
            words.append(current_word)
            word_to_tokens.append(current_word_tokens)

    for idx, token in enumerate(tokens):
        # A punctuation token is a single non-"▁" character, so it can be tested
        # together with special tokens ahead of the "▁" branch.
        if token in special_tokens or re.fullmatch(r"[.,!?;:(){}\[\]<>/\\|\"'`~@#$%^&*_+=-]", token):
            close_word()
            words.append(token)  # special/punctuation tokens stand alone
            word_to_tokens.append([idx])
            current_word, current_word_tokens = "", []
        elif token.startswith("▁"):  # SentencePiece word boundary
            close_word()
            current_word, current_word_tokens = token[1:], [idx]
        else:
            current_word += token  # continuation sub-word piece
            current_word_tokens.append(idx)
    close_word()

    # Second pass: keep sub-word pieces separate so there is one label per token.
    mapped_words = []
    mapped_token_indices = []
    for word, token_indices in zip(words, word_to_tokens):
        if len(token_indices) == 1:
            mapped_words.append(word)
            mapped_token_indices.append(token_indices)
        else:
            for idx in token_indices:
                mapped_words.append(tokens[idx].replace("▁", ""))
                mapped_token_indices.append([idx])

    return mapped_words, mapped_token_indices


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Tile grouped KV heads so that every attention head has its own KV head.

    Args:
        hidden_states: tensor of shape (batch, num_kv_heads, seq_len, head_dim).
        n_rep: number of copies of each KV head
            (num_attention_heads // num_kv_heads).

    Returns:
        Tensor of shape (batch, num_kv_heads * n_rep, seq_len, head_dim) where
        input head ``h`` occupies output heads ``h * n_rep`` .. ``h * n_rep + n_rep - 1``.
        When ``n_rep == 1`` the input tensor itself is returned.
    """
    bsz, kv_heads, length, dim = hidden_states.shape  # unpacking also rejects non-4-D input
    if n_rep != 1:
        tiled = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, length, dim)
        return tiled.reshape(bsz, kv_heads * n_rep, length, dim)
    return hidden_states

def extract_value_embeddings(model, tokenizer, prompts):
    """Compute and save the attention-weighted value vectors of every head.

    For each layer, the per-head attention probabilities ``P`` (returned via
    ``output_attentions=True``) are multiplied by that layer's cached value
    states ``V``. The values are expanded from the KV heads to all attention
    heads with ``repeat_kv``. The result ``P @ V`` has shape
    (seq_len, head_dim) per head. Each head's result is saved to
    ``value_embeddings_dir`` as ``Phi-3-layer-{L}-head-{H}-value.pth``, with
    1-based layer/head numbers.

    Args:
        model: causal LM that returns attention weights (loaded with eager
            attention at the top of this script).
        tokenizer: tokenizer matching ``model``.
        prompts: prompt(s) to run; they must tokenize to exactly one sequence.

    Raises:
        ValueError: if the tokenized batch does not contain exactly one sequence.
    """
    print("Extracting Value Embeddings for All Layers & Heads")

    inputs = tokenizer(prompts, return_tensors="pt").to(device)
    batch_size = inputs["input_ids"].shape[0]
    if batch_size != 1:
        raise ValueError(f"extract_value_embeddings expects a single prompt, got a batch of {batch_size}")

    # Hidden states are not used here, so they are not requested (saves memory);
    # use_cache=True is what makes past_key_values (and thus V) available.
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True, use_cache=True)

    attentions = outputs['attentions']
    past_key_values = outputs.past_key_values

    # Grouped-query attention: each cached KV head is shared by kv_repeats query heads.
    kv_repeats = model.config.num_attention_heads // model.config.num_key_value_heads

    for layer_idx in range(len(past_key_values)):
        # (num_heads, seq_len, seq_len) attention probabilities of the single sequence.
        attn_probs = attentions[layer_idx][0].cpu().to(torch.float32)
        # Cached value states expanded to one per attention head, batch dim dropped:
        # (num_heads, seq_len, head_dim).
        values = repeat_kv(past_key_values[layer_idx][1], kv_repeats)[0].cpu().to(torch.float32)
        head_attn_value = torch.bmm(attn_probs, values)  # (num_heads, seq_len, head_dim)

        for head_idx, head_view in enumerate(head_attn_value):
            # clone(): torch.save of a view would serialize the whole multi-head storage.
            head_embedding = head_view.clone()
            save_path = os.path.join(value_embeddings_dir, f"Phi-3-layer-{layer_idx+1}-head-{head_idx+1}-value.pth")
            torch.save(head_embedding, save_path)
            print(f"Saved: {save_path} | Shape: {head_embedding.shape}")

    print(f"All Value Embeddings Saved in {value_embeddings_dir}")

####################################################

extract_value_embeddings(model, tokenizer, prompts)

# Extracting X + PXW_vW_o for All Layers Using Hooks
# The dictionaries that receive the captured activations are created and returned by register_hooks.

def register_hooks(model, *, handles=None):
    """Attach forward hooks that capture per-layer activations for every decoder layer.

    For each layer ``i`` of ``model.model.layers`` two hooks are registered:

    * on ``self_attn.qkv_proj`` -- stores that projection's input (after
      RMSNorm, pre-attention) under ``"Layer_{i}"`` in the pre-attention dict;
    * on ``mlp`` -- stores the MLP's input (the hidden state after the attention
      sub-layer; presumably after the post-attention RMSNorm in HF's Phi-3
      layer -- confirm) under ``"Layer_{i}"`` in the layer-embeddings dict.

    Captured tensors are detached, moved to the CPU and cast to float32. Every
    forward pass overwrites the previous capture. The hooks stay attached until
    their handles are removed.

    Args:
        model: model exposing ``model.model.layers``, each with
            ``self_attn.qkv_proj`` and ``mlp`` sub-modules.
        handles: optional list. If given, the handle of every registered hook is
            appended to it so the caller can later detach them with ``h.remove()``.

    Returns:
        ``(extracted_layer_embeddings, extracted_pre_attention)`` -- the two
        dicts the hooks fill in.
    """
    extracted_pre_attention = {}  # qkv_proj inputs: after RMSNorm, pre-attention
    extracted_layer_embeddings = {}  # mlp inputs: post-attention

    def make_hook(store, layer_idx):
        """Return a hook that saves the module's first positional input into ``store``."""
        def hook(module, args, output):
            store[f"Layer_{layer_idx}"] = args[0].detach().cpu().to(torch.float32)
        return hook

    new_handles = []
    for i, layer in enumerate(model.model.layers):
        new_handles.append(layer.self_attn.qkv_proj.register_forward_hook(make_hook(extracted_pre_attention, i)))
        new_handles.append(layer.mlp.register_forward_hook(make_hook(extracted_layer_embeddings, i)))

    if handles is not None:
        handles.extend(new_handles)

    return extracted_layer_embeddings, extracted_pre_attention

# Register Hooks (they remain attached to the model for the rest of the script)
extracted_layer_embeddings, extracted_pre_attention = register_hooks(model)

# Run one forward pass so every hook fires once. The output is not needed, so it is not
# bound to a module-level name (that would keep logits and KV cache alive for the rest
# of the script), and no KV cache is built at all.
with torch.no_grad():
    model(**inputs, use_cache=False)

# Save Extracted Embeddings (the folders normally exist already; makedirs is idempotent)
os.makedirs(input_embeddings_dir, exist_ok=True)
os.makedirs(final_embeddings_dir, exist_ok=True)

# Save the captured qkv_proj inputs (after RMSNorm, pre-attention) for every layer
for layer_name, embedding in extracted_pre_attention.items():
    torch.save(embedding, os.path.join(input_embeddings_dir, f"Phi-3-{layer_name}-input.pth"))

# Save the captured mlp inputs (post-attention) for every layer, Layer_0 included
for layer_name, embedding in extracted_layer_embeddings.items():
    torch.save(embedding, os.path.join(final_embeddings_dir, f"Phi-3-{layer_name}-final.pth"))

# Attention Sink Measurement & Visualization
def _collect_attention_scores(model, tokenizer, prompts, num_layers, device):
    """Run each prompt through ``model.generate`` and stack its prefill attention maps.

    Returns:
        ``(scores, mapped_tokens)``: ``scores`` is a float32 CPU tensor of shape
        (num_prompts, num_layers, num_heads, num_tokens, num_tokens), and
        ``mapped_tokens`` are the tick labels of the last prompt.

    Raises:
        RuntimeError: if generate does not return exactly one step of attentions.
        ValueError: if the prompts tokenize to different lengths.
    """
    per_prompt_scores = []
    mapped_tokens = []

    for prompt in prompts:
        inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False).to(device)

        # Tick labels for this prompt (one label per token).
        mapped_tokens, _ = tokenize_and_map(prompt)

        outputs = model.generate(
            **inputs,
            output_attentions=True,
            return_dict_in_generate=True,
            max_new_tokens=1,
        )

        attentions = outputs['attentions']  # one entry per generation step
        if len(attentions) != 1:
            raise RuntimeError(f"Expected attentions for exactly 1 generation step, got {len(attentions)}")

        # With device_map="auto" the layers may live on different GPUs, and torch.cat
        # cannot mix devices, so each layer is moved to the CPU (as float32) first.
        layer_scores = [
            attentions[0][layer_idx].detach().cpu().to(torch.float32)
            for layer_idx in range(num_layers)
        ]
        # (1, num_heads, T, T) per layer -> (1, num_layers, num_heads, T, T)
        per_prompt_scores.append(torch.cat(layer_scores, dim=0).unsqueeze(dim=0))

    token_counts = sorted({scores.shape[-1] for scores in per_prompt_scores})
    if len(token_counts) > 1:
        raise ValueError(
            "All prompts must tokenize to the same length to be averaged; "
            f"got token counts {token_counts}"
        )

    return torch.cat(per_prompt_scores, dim=0), mapped_tokens


def _save_head_heatmap(head_scores, mapped_tokens, layer_idx, head_idx, path):
    """Write one head's (num_tokens, num_tokens) attention map, with token tick labels, to ``path``."""
    num_tokens = len(mapped_tokens)
    fig = plt.figure(figsize=(max(8, num_tokens * 0.5), max(6, num_tokens * 0.5)))
    plt.imshow(head_scores, cmap="inferno", interpolation='nearest')

    cbar = plt.colorbar(shrink=0.75)
    cbar.set_label("Attention Score", fontsize=18)
    cbar.ax.tick_params(labelsize=18)

    plt.title(f"Layer {layer_idx}, Head {head_idx}", fontsize=18)
    plt.xticks(ticks=range(num_tokens), labels=mapped_tokens, rotation=90, ha="right", fontsize=18)
    plt.yticks(ticks=range(num_tokens), labels=mapped_tokens, fontsize=18)

    plt.subplots_adjust(bottom=0.25, left=0.25)  # room for the token labels
    fig.savefig(path, bbox_inches="tight", dpi=200)
    plt.close(fig)


def _save_layer_grid(layer_scores, mapped_tokens, layer_idx, path, num_cols=5):
    """Write a grid with every head of one layer, sharing one colour scale, to ``path``."""
    num_heads = layer_scores.shape[0]
    num_rows = math.ceil(num_heads / num_cols)

    # squeeze=False keeps `axes` 2-D even for a single row, so no 1-D special case is needed.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(num_cols * 3, num_rows * 3), squeeze=False)
    fig.suptitle(f"Layer {layer_idx} - Attention Heads", fontsize=16)
    flat_axes = axes.ravel()

    # A single colour range for the whole layer, so the one shared colorbar is valid for every panel.
    vmin, vmax = float(layer_scores.min()), float(layer_scores.max())

    im = None
    for head_idx in range(num_heads):
        ax = flat_axes[head_idx]
        im = ax.imshow(layer_scores[head_idx], cmap="inferno", interpolation='nearest', vmin=vmin, vmax=vmax)
        ax.set_title(f"Head {head_idx}", fontsize=10)
        ax.set_xticks(ticks=range(len(mapped_tokens)))
        ax.set_xticklabels(mapped_tokens, rotation=90, fontsize=6)
        ax.set_yticks(ticks=range(len(mapped_tokens)))
        ax.set_yticklabels(mapped_tokens, fontsize=6)

    # Hide unused subplots
    for ax in flat_axes[num_heads:]:
        fig.delaxes(ax)

    if im is not None:
        fig.colorbar(im, ax=flat_axes[:num_heads].tolist(), label="Attention Score", shrink=0.6)
    fig.savefig(path, bbox_inches="tight", dpi=150)
    plt.close(fig)


def measure_attention_sink(model, tokenizer, prompts, score_path, token_length=50, device=None):
    """Plot the prompt-processing attention maps of every layer and head.

    Each prompt is run through ``model.generate`` with ``max_new_tokens=1`` and
    ``output_attentions=True``. The attention probabilities of that single
    (prefill) step are averaged over prompts and written as PNGs:

    * ``{score_path}/layer_{L}/head_{H}.png`` -- one labelled heat map per head;
    * ``{score_path}/layer_{L}_all_heads.png`` -- a grid of all heads of layer L
      drawn on one shared colour scale.

    Args:
        model: causal LM that returns attention weights (eager attention).
        tokenizer: tokenizer used to encode each prompt.
        prompts: non-empty list of prompts. To be averaged they must all
            tokenize to the same length. Tick labels come from the last prompt.
        score_path: output directory (created if missing).
        token_length: unused; kept so existing callers keep working.
        device: device for the input ids. Defaults to ``model.device``, which
            also works on CPU-only machines.

    Raises:
        ValueError: if ``prompts`` is empty or the prompts tokenize to different lengths.
        RuntimeError: if generate does not return exactly one step of attentions.
    """
    if not prompts:
        raise ValueError("measure_attention_sink needs at least one prompt")
    if device is None:
        device = model.device

    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads

    os.makedirs(score_path, exist_ok=True)  # Ensure base directory exists

    scores, mapped_tokens = _collect_attention_scores(model, tokenizer, prompts, num_layers, device)
    avg_attention_scores = scores.mean(dim=0).numpy()  # (num_layers, num_heads, num_tokens, num_tokens)

    for layer_idx in range(num_layers):
        layer_folder = os.path.join(score_path, f"layer_{layer_idx}")  # Subfolder per layer
        os.makedirs(layer_folder, exist_ok=True)
        layer_scores = avg_attention_scores[layer_idx, :num_heads]

        for head_idx in range(num_heads):
            _save_head_heatmap(
                layer_scores[head_idx], mapped_tokens, layer_idx, head_idx,
                os.path.join(layer_folder, f"head_{head_idx}.png"),
            )

        _save_layer_grid(
            layer_scores, mapped_tokens, layer_idx,
            os.path.join(score_path, f"layer_{layer_idx}_all_heads.png"),
        )

    print(f"All Layer & Head Attention Visualizations Saved in {score_path}")

# Pass the script's `device` explicitly so the inputs land on the device selected at the top
# (CPU when CUDA is unavailable) instead of relying on the function's default.
# token_length=88 is accepted by measure_attention_sink but not used by it.
measure_attention_sink(model, tokenizer, prompts, attention_sink_dir, token_length=88, device=device)

print("All Processing Complete")

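# NOTE: Please read the following before running this script.
# * This script is meant to be run first to extract features from Phi-3-Medium.
# * Update `model_path` (near the top of the file) to point at the model weights.
# * Select the prompt in the "Input Prompts" section (the last uncommented `prompts = [...]` assignment is used).
# * Update `base_output_dir` in the "Create Output Directories" section (this is the save directory).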

####################################################
