import os
import time
import json
import torch
import random
import argparse
import numpy as np
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_attention_sink(attention_scores, epsilon):
  """Return, per key position, the percentage of heads whose mean received attention exceeds ``epsilon``.

  Args:
    attention_scores: NumPy array of shape
      (num_samples, num_layers, num_heads, num_tokens, num_tokens); the last
      two axes must have equal length.
    epsilon: threshold applied to the normalised attention each key position
      receives.

  Returns:
    A 1-D float tensor of length num_tokens holding, for every key position,
    the percentage (0-100) of (sample, layer, head) triples whose score is
    strictly greater than ``epsilon``.
  """
  _, _, _, n_query, n_key = attention_scores.shape
  assert n_query == n_key

  scores = torch.from_numpy(attention_scores)

  # Divisor for key position j is (n - j): n, n-1, ..., 1 along the last axis.
  # Presumably this is the number of queries that can see key j under a
  # causal mask, which turns the column sum into a column mean.
  # Broadcasting over the leading axes matches an explicit expand().
  divisors = torch.arange(n_query, 0, -1).to(scores)

  # Sum over the query axis -> (num_samples, num_layers, num_heads, num_tokens)
  received = (scores / divisors).sum(dim=-2)

  above_threshold = (received > epsilon).to(torch.float)
  return above_threshold.mean(dim=(0, 1, 2)) * 100


def turnoff_attn(model, tokenizer, prompts, save_dir, token_length=64, add_bos=True, device=torch.device("cuda") if torch.cuda.is_available() else "cpu", turnoff_layer_ids=(0, 1)):
  """Run prompts with self-attention zeroed in selected layers and save activations and attention maps.

  For every decoder layer, forward hooks capture the input and output of
  ``input_layernorm`` ("rms1"), ``post_attention_layernorm`` ("rms2") and
  ``mlp`` ("ffn"). The ``self_attn`` modules of the layers listed in
  ``turnoff_layer_ids`` are temporarily replaced by a stub that returns an
  all-zero attention output and all-zero attention weights. Hooks and stubs
  are removed again before the function returns, including on error.

  Files written to ``save_dir`` (suffix is ``bos`` or ``no_bos`` depending on
  ``add_bos``): ``rms1_in``, ``rms1_out``, ``rms2_in``, ``rms2_out``,
  ``ffn_in``, ``ffn_out`` and ``attn``, each as ``{name}_{suffix}.npy``.
  The tokenized inputs of every prompt and the final sink rate are printed.

  Args:
    model: model exposing ``config.num_hidden_layers``,
      ``config.num_attention_heads`` and ``model.layers``, where each layer
      has ``input_layernorm``, ``post_attention_layernorm``, ``mlp`` and
      ``self_attn`` (presumably a Hugging Face Llama-style causal LM).
    tokenizer: callable used as ``tokenizer(prompt, return_tensors="pt")``.
    prompts: iterable of prompts. Results are stacked across prompts, so all
      prompts must tokenize to the same length.
    save_dir: output directory; created if it does not exist.
    token_length: unused; kept so existing callers keep working.
    add_bos: when False, the first position of every tokenizer output longer
      than 7 positions is dropped before the forward pass.
    device: device the tokenized inputs are moved to.
    turnoff_layer_ids: indices into ``model.model.layers`` whose
      self-attention is disabled.

  Raises:
    ValueError: if ``prompts`` yields no items.
  """
  num_layers = model.config.num_hidden_layers
  num_heads = model.config.num_attention_heads
  suffix = "bos" if add_bos else "no_bos"

  os.makedirs(save_dir, exist_ok=True)

  activation_names = ("rms1_in", "rms1_out", "rms2_in", "rms2_out", "ffn_in", "ffn_out")
  # Per-prompt buffers: every hook call (one per layer) appends one tensor.
  captured = {name: [] for name in activation_names}
  # Across-prompt results: one stacked tensor per prompt.
  collected = {name: [] for name in activation_names}
  attention_scores_all_sample = []

  def make_hook(prefix):
    """Build a forward hook storing the module's first input and its output on the CPU."""
    def hook(module, inputs, output):
      captured[f"{prefix}_in"].append(inputs[0].detach().to("cpu"))
      captured[f"{prefix}_out"].append(output.detach().to("cpu"))
    return hook

  def forward_skip_self_attention(self, hidden_states, attention_mask=None, position_ids=None, past_key_values=None, use_cache=False, cache_position=None, position_embeddings=None, **kwargs):
    """Stand-in attention forward returning zero output and zero attention weights."""
    batch_size, seq_len, _ = hidden_states.shape
    attn_output = torch.zeros_like(hidden_states)
    # Sized from the input and the model config so that the weights can be
    # concatenated with those of the untouched layers.
    attn_weights = torch.zeros(batch_size, num_heads, seq_len, seq_len, device=hidden_states.device)
    # NOTE(review): the two-element return must match what the decoder layer
    # of the installed transformers version unpacks - confirm.
    return attn_output, attn_weights

  def to_numpy(tensor):
    """Convert a CPU tensor to NumPy, upcasting bfloat16 which NumPy cannot represent."""
    if tensor.dtype == torch.bfloat16:
      tensor = tensor.to(torch.float32)
    return tensor.numpy()

  def stack_and_save(name, data):
    """Stack the per-prompt tensors and write them to ``{save_dir}/{name}_{suffix}.npy``."""
    np.save(f"{save_dir}/{name}_{suffix}.npy", to_numpy(torch.stack(data)))

  hooks = []
  patched = []
  try:
    for block in model.model.layers:
      hooks.append(block.input_layernorm.register_forward_hook(make_hook("rms1")))
      hooks.append(block.post_attention_layernorm.register_forward_hook(make_hook("rms2")))
      hooks.append(block.mlp.register_forward_hook(make_hook("ffn")))

    for i in turnoff_layer_ids:
      attn_module = model.model.layers[i].self_attn
      # Remember the current forward so the model can be restored afterwards.
      patched.append((attn_module, attn_module.forward))
      attn_module.forward = forward_skip_self_attention.__get__(attn_module, type(attn_module))

    for prompt in tqdm(prompts):
      for buffer in captured.values():
        buffer.clear()

      inputs = tokenizer(prompt, return_tensors="pt").to(device)

      # Debug output: tokenized inputs before any BOS stripping.
      print(inputs)

      if not add_bos:
        # Assumes the tokenizer put a BOS token at position 0 - TODO confirm.
        # NOTE(review): only outputs longer than 7 positions are trimmed; the
        # reason for this threshold is not visible here.
        for key in inputs.keys():
          if inputs[key].shape[1] > 7:
            inputs[key] = inputs[key][:, 1:]

      # Inference only: skip building an autograd graph.
      with torch.no_grad():
        outputs = model(
          **inputs,
          output_attentions=True,
          output_hidden_states=False,
          use_cache=True,
          return_dict=True,
        )

      # (num_layers, 1, ...) -> (num_layers, ...), assuming a batch of one.
      for name, buffer in captured.items():
        collected[name].append(torch.stack(buffer).squeeze(dim=1))

      # Collect the attention maps of THIS prompt; doing it after the loop
      # would keep only the last prompt.
      attentions = outputs['attentions']
      assert len(attentions) == num_layers, "expected one attention map per layer"
      # Move to the CPU first so layers placed on different devices can be
      # concatenated and GPU memory is released between prompts.
      layer_scores = torch.cat([a.detach().to("cpu") for a in attentions], dim=0)
      attention_scores_all_sample.append(layer_scores.unsqueeze(dim=0))
  finally:
    for h in hooks:
      h.remove()
    # Undo the patches in reverse order so a layer listed twice ends up with
    # its original forward.
    for attn_module, original_forward in reversed(patched):
      attn_module.forward = original_forward

  if not attention_scores_all_sample:
    raise ValueError("prompts is empty; nothing to save")

  for name, per_prompt in collected.items():
    stack_and_save(name, per_prompt)

  # Store attention scores
  # (num_samples, num_layers, num_heads, num_tokens, num_tokens), assuming a batch of one per prompt
  attention_scores = to_numpy(torch.cat(attention_scores_all_sample, dim=0))
  sink_rate = compute_attention_sink(attention_scores, epsilon=0.3)
  np.save(f"{save_dir}/attn_{suffix}.npy", attention_scores)
  print(f"sink rate is {sink_rate}")
