import os
import time
import json
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def compute_attention_sink(score_path, epsilon):
  """Return, per token position, the percentage of (sample, layer, head)
  triples in which that position's importance score exceeds ``epsilon``.

  ``score_path`` must point to a ``.npy`` array with five axes
  (num_samples, num_layers, num_heads, num_tokens, num_tokens); the last two
  axes must be equal in size. Entry ``[..., q, k]`` is treated as the weight
  that query row ``q`` assigns to key column ``k``. Every key column ``k`` is
  divided by ``num_tokens - k`` (which equals the number of query positions
  ``q >= k``), and the result is summed over query rows to give a per-key
  importance score.

  Returns:
    A float tensor of shape (num_tokens,) with values in [0, 100].
  """
  scores = np.load(score_path)
  n_samples, n_layers, n_heads, n_queries, n_keys = scores.shape
  assert n_queries == n_keys
  scores = torch.from_numpy(scores)

  # Divisor for key column k is num_tokens - k, i.e. [T, T-1, ..., 1];
  # cast to the scores' dtype/device, then broadcast over all leading axes.
  divisors = torch.arange(n_queries, 0, -1).to(scores)
  per_key = torch.sum(scores / divisors, dim=-2)  # (samples, layers, heads, tokens)

  # Fraction of (sample, layer, head) triples above the threshold, per token.
  above = (per_key > epsilon).to(torch.float)
  return 100 * above.mean(dim=(0, 1, 2))


def measure_activations(model, tokenizer, prompts, save_dir, intervene_token_id, intervene_layer_id, token_length=64, add_bos=True, device = torch.device("cuda") if torch.cuda.is_available() else "cpu", num_samples=10):
  """Run ``model`` on prompts while shifting a stored vector between token
  positions, and save the attention maps of every layer.

  A forward hook is registered on
  ``model.model.layers[intervene_layer_id].post_attention_layernorm``. The hook
  subtracts the vector loaded from ``{save_dir}/vec_random.npy`` from the
  hooked output at ``[0, 0]`` and adds it at ``[0, intervene_token_id]``. This
  assumes the norm returns a (batch, seq, hidden) tensor; TODO confirm for
  non-Llama layouts.

  Attention maps from every layer are collected for up to ``num_samples``
  prompts. They are stacked into one array and saved to
  ``{save_dir}/attn_intervene_bos.npy`` when ``add_bos`` is true, or to
  ``{save_dir}/attn_intervene_no_bos.npy`` otherwise. The array has shape
  (num_samples, num_layers, num_heads, token_length, token_length).

  Args:
    model: causal LM exposing ``model.model.layers[i].post_attention_layernorm``
      and returning ``attentions`` when called with ``output_attentions=True``.
    tokenizer: called as ``tokenizer(prompt, return_tensors="pt")``.
    prompts: iterable of prompt strings.
    save_dir: directory containing ``vec_random.npy``; the output is written
      here too.
    intervene_token_id: sequence position that receives the vector.
    intervene_layer_id: index of the layer whose post-attention norm is hooked.
    token_length: inputs are truncated to this many tokens. Prompts that
      tokenize to fewer tokens are skipped so every sample has the same shape.
    add_bos: only selects the output file name. Whether a BOS token is
      actually inserted is controlled by the tokenizer.
    device: device the tokenized inputs are moved to.
    num_samples: maximum number of prompts to collect.

  Returns:
    The path of the saved ``.npy`` file.

  Raises:
    ValueError: if no prompt has at least ``token_length`` tokens.
  """
  num_layers = model.config.num_hidden_layers
  if add_bos:
    score_path = f"{save_dir}/attn_intervene_bos.npy"
  else:
    score_path = f"{save_dir}/attn_intervene_no_bos.npy"

  delta_path = f"{save_dir}/vec_random.npy"
  delta = torch.from_numpy(np.load(delta_path)).to(device)

  def hook_fn(module, args, output):
    # Move the vector from position 0 to position intervene_token_id
    # (batch item 0). Cast it to the activation's device and dtype so that
    # sharded (device_map="auto") or half-precision models do not fail.
    shift = delta.to(device=output.device, dtype=output.dtype)
    output[0, 0] -= shift
    output[0, intervene_token_id] += shift
    return output

  norm = model.model.layers[intervene_layer_id].post_attention_layernorm
  handle = norm.register_forward_hook(hook_fn)

  attention_scores_all_sample = []
  try:
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
      for prompt in tqdm(prompts):
        if len(attention_scores_all_sample) >= num_samples:
          break

        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Short prompts would produce a different (T, T) shape and break
        # stacking across samples.
        if inputs["input_ids"].shape[1] < token_length:
          continue
        for key in inputs.keys():
          inputs[key] = inputs[key][:, :token_length]

        outputs = model(
          **inputs,
          output_attentions=True,
          output_hidden_states=False,
          use_cache=True,
          return_dict=True,
        )

        attentions = outputs['attentions']
        assert len(attentions) == num_layers
        # Per-layer maps may live on different devices under device_map="auto",
        # and bfloat16 cannot be converted to numpy. So move each map to the CPU
        # as float32 before concatenating over layers.
        sample = torch.cat([a.detach().float().cpu() for a in attentions], dim=0)
        attention_scores_all_sample.append(sample.unsqueeze(dim=0))
  finally:
    # Always detach the hook so the model is not left modified on error.
    handle.remove()

  if not attention_scores_all_sample:
    raise ValueError(
      f"no prompt has at least {token_length} tokens; nothing to save"
    )

  stacked = torch.cat(attention_scores_all_sample, dim=0)  # (num_samples, num_layers, num_heads, num_tokens, num_tokens)
  np.save(score_path, stacked.numpy())
  return score_path

def measure_open_sourced_lms():
  """Measure attention sinks under an activation intervention for each model
  in ``model_pool``.

  For every model, this:
    1. loads it with eager attention so that attention weights are returned;
    2. calls ``measure_activations`` on the prompts from
       ``datasets/probe_valid.jsonl`` to save attention maps under
       ``results/<model_name>``;
    3. prints the attention-sink percentage (epsilon=0.3) for position 0, for
       the intervened position, and for all positions.
  """
  import gc

  device = torch.device("cuda") if torch.cuda.is_available() else "cpu"
  os.makedirs("results", exist_ok=True)

  gpt_family = ["openai-community/gpt2"]#  ["openai-community/gpt2", "openai-community/gpt2-medium", "openai-community/gpt2-large", "openai-community/gpt2-xl"] 
  llama2_family = ["meta-llama/Llama-2-7b-hf"] # ["meta-llama/Llama-2-7b-hf",  "meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-7b-chat-hf",  "meta-llama/Llama-2-13b-chat-hf"]
  llama3_family = ["meta-llama/Meta-Llama-3-8B", "meta-llama/Meta-Llama-3.1-8B", "meta-llama/Meta-Llama-3-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
  pythia_family = [f"EleutherAI/pythia-{size}" for size in ["14m", "31m", "70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b"]] 
  opt_family = [f"facebook/opt-{size}" for size in ["125m", "350m", "1.3b", "2.7b", "6.7b", "13b"]] 
  mistral_family = ["mistralai/Mistral-7B-v0.1"] # [f"mistralai/Mistral-7B-v0.1", f"mistralai/Mistral-7B-Instruct-v0.1"]
  model_pool = llama2_family  # + llama2_family + llama3_family + pythia_family + opt_family + mistral_family
  ########################################

  # The prompts do not depend on the model, so read them once. Blank lines
  # (e.g. a trailing newline) are skipped instead of crashing json.loads.
  file_path = "datasets/probe_valid.jsonl"
  with open(file_path, 'r', encoding="utf-8") as f:
    prompts = [json.loads(line)["text"] for line in f if line.strip()]

  token_length = 10
  intervene_token_id = 1
  intervene_layer_id = 1
  epsilon = 0.3

  for model_path in tqdm(model_pool):
    model_name = model_path.split("/")[-1]
    save_dir = f"results/{model_name}"
    os.makedirs(save_dir, exist_ok=True)

    model = AutoModelForCausalLM.from_pretrained(
      model_path,
      attn_implementation="eager",
      # torch_dtype=torch.bfloat16,
      device_map="auto",
    )
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    tokenizer.add_bos_token = True

    measure_activations(model, tokenizer, prompts, save_dir, intervene_token_id, intervene_layer_id, token_length, tokenizer.add_bos_token, device)
    if tokenizer.add_bos_token:
      score_path = f"{save_dir}/attn_intervene_bos.npy"
    else:
      score_path = f"{save_dir}/attn_intervene_no_bos.npy"

    # Load and reduce the saved scores once, then print the three views.
    sink_percent = compute_attention_sink(score_path, epsilon=epsilon)
    print(sink_percent[0])
    print(sink_percent[intervene_token_id])
    print(sink_percent)

    # Release this model before loading the next one; otherwise GPU memory
    # accumulates across iterations of model_pool.
    del model, tokenizer
    gc.collect()
    if torch.cuda.is_available():
      torch.cuda.empty_cache()

if __name__ == "__main__":
  # Script entry point: run the measurement sweep over the configured models.
  measure_open_sourced_lms()
     


