import argparse
from types import MethodType
import json
import torch
from vllm import LLM, SamplingParams
from transformers import  AutoTokenizer,AutoModelForCausalLM
import pandas as pd
parser = argparse.ArgumentParser(
    description="Collect per-neuron MLP activation statistics with vLLM."
)
parser.add_argument(
    "-m",
    "--model",
    type=str,
    default="/home/ssliang/unlearning/llama2-7b-hf",
    help="Path or hub id of the model to profile.",
)
# Restrict to the datasets handled below: any other value would skip every
# dataset branch and save all-zero statistics without complaint.
parser.add_argument(
    "-l",
    "--lang",
    type=str,
    default="MATH",
    choices=["QA", "MATH", "RLHF"],
    help="Which dataset to run through the model.",
)
args = parser.parse_args()
from chatgpt_API import generate_samples
# True for Llama-style MLPs (gate_up_proj / down_proj); set False for
# BLOOM-style MLPs (dense_h_to_4h / dense_4h_to_h).
is_llama = True
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# enforce_eager disables CUDA-graph capture in vLLM, presumably so the
# patched Python MLP forward installed below runs on every step.
model = LLM(model=args.model, tensor_parallel_size=1, enforce_eager=True)
max_length = model.llm_engine.model_config.max_model_len
num_layers = model.llm_engine.model_config.hf_config.num_hidden_layers
intermediate_size = (
    model.llm_engine.model_config.hf_config.intermediate_size
    if is_llama
    else model.llm_engine.model_config.hf_config.hidden_size * 4
)

# Per-layer, per-neuron accumulators filled by the patched MLP forwards:
# running sums of the activation and its 2nd/3rd/4th powers, plus a count of
# strictly positive activations. Sized from the model config rather than a
# hard-coded width so other model sizes (and the BLOOM branch) work.
sum1 = torch.zeros(num_layers, intermediate_size, device=device)
sum2 = torch.zeros(num_layers, intermediate_size, device=device)
sum3 = torch.zeros(num_layers, intermediate_size, device=device)
sum4 = torch.zeros(num_layers, intermediate_size, device=device)
over_zero = torch.zeros(num_layers, intermediate_size, dtype=torch.int32, device=device)

def factory(idx):
    """Build a replacement MLP ``forward`` that also records activation statistics.

    The returned function is meant to be bound to layer ``idx``'s MLP module
    via ``MethodType``. Besides computing the MLP output, it adds the
    per-neuron sums of the activation and of its 2nd/3rd/4th powers, plus the
    count of strictly positive activations, into row ``idx`` of the
    module-level accumulators ``sum1``..``sum4`` and ``over_zero``.

    All leading dimensions of the activation are treated as token dimensions
    and flattened. The hook therefore works whether the engine passes a
    ``(batch, seq, hidden)`` tensor or a packed ``(num_tokens, hidden)`` one.

    Args:
        idx: Index of the decoder layer whose MLP is being patched.

    Returns:
        ``llama_forward`` when ``is_llama`` is true, otherwise ``bloom_forward``.
    """

    def _record(activation):
        """Add per-neuron moments and positive counts of ``activation`` to row ``idx``."""
        # Collapse every token dimension so only the neuron dimension is left.
        flat = activation.float().reshape(-1, activation.size(-1))
        sum1[idx, :] += flat.sum(dim=0)
        sum2[idx, :] += flat.pow(2).sum(dim=0)
        sum3[idx, :] += flat.pow(3).sum(dim=0)
        sum4[idx, :] += flat.pow(4).sum(dim=0)
        over_zero[idx, :] += (flat > 0).sum(dim=0)

    def llama_forward(self, x):
        # gate_up_proj returns a 2-tuple; its first element holds the gate
        # and up projections concatenated along the last dimension.
        gate_up, _ = self.gate_up_proj(x)
        half = gate_up.size(-1) // 2
        # SiLU is applied to the gate half in place; the recorded statistic
        # is this post-SiLU gate activation.
        gate_up[..., :half] = torch.nn.functional.silu(gate_up[..., :half])
        _record(gate_up[..., :half])
        x = gate_up[..., :half] * gate_up[..., half:]
        x, _ = self.down_proj(x)
        return x

    def bloom_forward(self, x: torch.Tensor):
        x, _ = self.dense_h_to_4h(x)
        x = self.gelu_impl(x)
        # The recorded statistic is the post-GELU activation.
        _record(x)
        x, _ = self.dense_4h_to_h(x)
        return x

    return llama_forward if is_llama else bloom_forward

# Replace each decoder layer's MLP forward with the statistics-recording
# version built by factory(), bound to that MLP module instance.
for i in range(num_layers):
    runner_model = model.llm_engine.driver_worker.model_runner.model
    obj = runner_model.model.layers[i].mlp if is_llama else runner_model.transformer.h[i].mlp
    obj.forward = MethodType(factory(i), obj)
# Load the tokenizer from the same checkpoint as the model so the token ids
# fed to vLLM match the profiled model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained(args.model)
lang = args.lang


def _format_example(question, answer):
    """Render one example in the Instruction / Input / Answer prompt format."""
    return f"Instruction:{question}\n Input:''\n Answer:{answer}"


# Build one prompt string per example for the selected dataset.
texts = []
if lang == "QA":
    with open("/home/ssliang/unlearning/data/zsre_test.json", "r", encoding="utf-8") as f:
        dataset = json.load(f)
    for example in dataset:
        question = example['src']
        newanswer = generate_samples(question)
        print('newanswer', newanswer)
        texts.append(_format_example(question, newanswer))
elif lang == "MATH":
    df = pd.read_csv("/home/ssliang/unlearning/data/test_GSM8K.csv")
    for newquestion, true_answer in zip(df["question"], df["answer"]):
        texts.append(_format_example(newquestion, true_answer))
elif lang == "RLHF":
    # One JSON object per line; blank lines are skipped.
    with open("/home/ssliang/unlearning/data/PKURLHF.json", "r", encoding="utf-8") as f:
        dataset = [json.loads(line) for line in f if line.strip()]
    for example in dataset:
        texts.append(_format_example(example["prompt"], example["response_0"]))
else:
    raise ValueError(f"Unsupported lang {lang!r}; expected one of QA, MATH, RLHF")

# No padding: pad tokens would also be pushed through the patched MLPs and
# skew the accumulated statistics. Truncate to the engine's context length.
input_ids = [
    tokenizer(text, truncation=True, max_length=max_length)["input_ids"]
    for text in texts
]
# max_tokens=1: only the prompt forward pass is needed; its MLP activations
# are accumulated by the patched forward functions.
output = model.generate(prompt_token_ids=input_ids, sampling_params=SamplingParams(max_tokens=1))
# Number of token positions whose activations were accumulated (the total
# prompt length across all examples), so that e.g. sum1 / n and
# over_zero / n are per-token averages. Replaces a hard-coded constant that
# was unrelated to the data actually processed.
n_tokens = sum(len(ids) for ids in input_ids)
output = dict(
    n=n_tokens,
    sum1=sum1.to('cpu'),
    sum2=sum2.to('cpu'),
    sum3=sum3.to('cpu'),
    sum4=sum4.to('cpu'),
    over_zero=over_zero.to('cpu'),
)
# Include the dataset name in the llama output path so that QA / RLHF runs
# no longer overwrite the MATH results.
if is_llama:
    torch.save(output, f'/home/ssliang/unlearning/data/7bhf_activation_{lang}')
else:
    torch.save(output, f'data/activation.{lang}.train.bloom-7b')