import os
# Expose only GPUs 0-3 to this process. Kept above the torch/transformers
# imports so the setting is in place before CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
# NOTE(review): init_process_group is later called with an explicit tcp://
# init_method, so MASTER_ADDR may not actually be consulted — confirm.
os.environ['MASTER_ADDR'] = 'localhost'
from counterfact import CounterFactDataset,CounterFactDatasetnew
from GSM8k import GSM8k
from eval_utils_counterfact import compute_rewrite_quality_counterfact
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, DataCollatorForLanguageModeling
import tqdm
import argparse
import logging
import random
import time



import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset,Dataset
from torch.optim import Adam,AdamW
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint
# Bring up a one-member NCCL process group (this process is rank 0 of 1),
# rendezvousing over a fixed localhost TCP port.
dist.init_process_group(
    backend='nccl',
    init_method='tcp://localhost:51911',
    rank=0,
    world_size=1,
)
# Seed every RNG source the script touches (torch, NumPy, stdlib) with the
# same value, in the same order as before, for reproducible runs.
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    _seed_fn(8888)
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast as autocast
from utils import (
    compute_kl,
    get_answer_loss,
    get_rand_ans_loss,
    get_truthfulQA_answers_plaintext,
)
import numpy as np
# Pick the compute device: visible GPU index 2 when CUDA is present, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")
scaler = GradScaler()

# Dataset registry: key -> (dataset class, evaluation function).
DS_DICT = {
    "cf": (CounterFactDataset, compute_rewrite_quality_counterfact),
}
ds_class, ds_eval_method = DS_DICT["cf"]

tokenizer = AutoTokenizer.from_pretrained("/home/ssliang/unlearning/llama2-7b-hf")
# Fall back to the unk token when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

con = GSM8k("/home/ssliang/unlearning/data/train_GSM8K.csv")
newknowledge, trueknowledge = con[tokenizer]

# NOTE(review): this path repeats the "/home/ssliang/unlearning/models" prefix —
# confirm it is the real location of the unlearned checkpoint.
model_afterunlearn = AutoModelForCausalLM.from_pretrained(
    "/home/ssliang/unlearning/models/home/ssliang/unlearning/models/7bhf_unlearnedGSM8k",
    torch_dtype=torch.bfloat16,
)
# Loaded without an explicit torch_dtype, unlike the unlearned model above.
model_pre = AutoModelForCausalLM.from_pretrained("/home/ssliang/unlearning/llama2-7b-hf")
# Parse the neuron list file into ``neurons_zsre``: one list of tokens per line.
# Judging by the stripping below, each line presumably looks like a bracketed,
# optionally quoted, space-separated list — TODO confirm the file format.
# ``with`` guarantees the handle is closed (the original never closed it).
with open("/home/ssliang/unlearning/data/neurons_rlhf.txt", 'r') as file:
    X = file.readlines()  # read every line at once
n = len(X)
neurons_zsre = []
for i, line in enumerate(X):
    zsrenew = []
    # Drop surrounding whitespace (incl. the newline), then the enclosing [ ],
    # then outer double quotes, and split on single spaces (not commas).
    X[i] = line.strip().strip("[]").strip('"').split(" ")
    for j in X[i]:
        # Splitting on " " yields empty / whitespace-only pieces wherever tokens
        # are separated by more than one space; skip those.
        if j and not j.isspace():
            print(j)
            zsrenew.append(j)
    neurons_zsre.append(zsrenew)
print('before unlearn')

# With delta = theta_unlearned - theta_pre, write theta_pre - 0.3 * delta back
# into model_pre (i.e. step 0.3 of delta in the opposite direction), then save.
#
# Fetch each state dict once: state_dict() builds a fresh dict on every call,
# and the original loop rebuilt it four times per tensor.
pre_state = model_pre.state_dict()
unlearned_state = model_afterunlearn.state_dict()
_updated_ptrs = set()
with torch.no_grad():
    # Pair tensors by name instead of by zip position, so a key-order mismatch
    # raises KeyError rather than silently combining unrelated tensors.
    for name, pre_tensor in pre_state.items():
        # Tied/shared weights appear under several keys with the same storage;
        # update that storage only once so the delta is not applied twice.
        ptr = pre_tensor.data_ptr()
        if ptr in _updated_ptrs:
            continue
        _updated_ptrs.add(ptr)

        unlearn_para = unlearned_state[name] - pre_tensor
        renew_para = pre_tensor - 0.3 * unlearn_para
        # state_dict() tensors share storage with model_pre's parameters and
        # buffers, so copy_ modifies the model in place (casting to the
        # destination dtype if needed). The previous
        # `model_pre.state_dict()[k] = renew_para` only rebound a key in a
        # throwaway dict, so the saved model was never changed.
        pre_tensor.copy_(renew_para)
print('after_unlearn')
model_pre.save_pretrained("/home/ssliang/unlearning/models/7bhf_subtractedGSM8k")