from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import AutoModelForCausalLM, get_scheduler

#from easyeditor.dataset.test import testCounterFactDataset
#from easyeditor.evaluate.evaluate import compute_edit_quality,compute_rewrite_or_rephrase_quality,compute_locality_quality,compute_portability_quality
from PEAK import PEAK
from eval_utils_counterfact import compute_rewrite_quality_counterfact

#import bitsandbytes as bnb
import tqdm
import argparse
import logging
import random
import time
import os
# Process-wide runtime configuration. These assignments run at import time and
# are placed before `torch` is imported below.
# Make only GPUs 0 and 1 visible to CUDA in this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
# Rendezvous settings for torch.distributed (single local process).
os.environ['MASTER_ADDR'] = 'localhost'

os.environ['RANK'] = '0'
import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset,Dataset
from torch.optim import Adam,AdamW
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint
#local_rank = dist.get_rank()
# Create a one-process NCCL group (rank 0, world_size 1) on a hard-coded local
# TCP port. This executes on import, not inside `main`.
# NOTE(review): the fixed port will presumably clash if two runs share a host - confirm.
dist.init_process_group(backend='nccl', init_method='tcp://localhost:34285',rank=0,world_size = 1)
# Seed torch, NumPy and Python's `random` with the same fixed value.
torch.manual_seed(8888)
np.random.seed(8888)
random.seed(8888)
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast as autocast
from utils import (
    compute_kl,
    get_answer_loss,
    get_rand_ans_loss,
    get_truthfulQA_answers_plaintext,
)

# Tokenizer and PEAK data. `trueknowledge` is the dataset `main` trains on;
# `newknowledge` is loaded alongside it but not used in this file's visible code.
tokenizer = AutoTokenizer.from_pretrained("/home/ssliang/unlearning/opt-1.3b")
con = PEAK("/home/ssliang/unlearning/data/")
newknowledge, trueknowledge = con.__getitem__(tokenizer)

# `model_afterunlearn` is the unlearned checkpoint, `model_pre` the base model.
model_afterunlearn = AutoModelForCausalLM.from_pretrained('/home/ssliang/unlearning/models/opt-1.3b_unlearn_PKURLHF')
model_pre = AutoModelForCausalLM.from_pretrained("/home/ssliang/unlearning/opt-1.3b")
scaler = GradScaler()

# Fraction of the (unlearned - base) weight delta that is subtracted from base.
UNLEARN_DELTA_SCALE = 0.15


def _subtract_scaled_delta(base_model, edited_model, scale):
    """Overwrite ``base_model`` weights with ``base - scale * (edited - base)``.

    The previous implementation assigned the result into the dictionary
    returned by ``state_dict()``. That dictionary is rebuilt on every call, so
    the assignment never reached the model and the weights stayed unchanged.
    Here the new values are collected first and then written back with
    ``load_state_dict``.

    Args:
        base_model: Module whose weights are updated in place.
        edited_model: Module providing the edited weights; it must expose the
            same state-dict keys as ``base_model`` (a missing key raises
            ``KeyError`` instead of silently pairing unrelated tensors).
        scale: Multiplier applied to the weight delta before subtraction.
    """
    with torch.no_grad():
        base_state = base_model.state_dict()
        edited_state = edited_model.state_dict()
        merged_state = {}
        for name, base_weight in base_state.items():
            if not base_weight.is_floating_point():
                # Integer/bool buffers cannot take a fractional delta; keep them.
                merged_state[name] = base_weight
                continue
            unlearn_para = edited_state[name] - base_weight
            # Every value is computed from the untouched base tensors, so
            # entries that share storage (tied weights) get the delta once.
            merged_state[name] = base_weight - scale * unlearn_para
        base_model.load_state_dict(merged_state)


print('before unlearn')
_subtract_scaled_delta(model_pre, model_afterunlearn, UNLEARN_DELTA_SCALE)
print('after_unlearn')
def main(args) -> None:
    """Fine-tune the module-level ``model_pre`` on the ``trueknowledge`` data.

    The dataset is split 80/20 with ``random_split``; only the 80% part is
    trained on. Each pass over the training loader counts as one epoch, and
    epochs are numbered ``1 .. args.max_unlearn_steps - 1``. Gradients are
    accumulated over ``accumulation_steps`` micro-batches per optimizer step.
    The model is saved to ``args.model_save_dir`` every ``args.save_every``
    epochs and once more after the final epoch if that epoch was not saved.

    Args:
        args: Parsed command-line namespace. Reads ``batch_size``, ``lr``,
            ``max_unlearn_steps``, ``save_every`` and ``model_save_dir``.
    """
    # NOTE(review): the instance is never used below; it is kept only so that
    # whatever initialisation `Accelerator()` performs still happens.
    accelerator = Accelerator()

    # Prefer the second visible GPU as before, but fall back to the first one
    # on single-GPU hosts instead of failing on a non-existent "cuda:1".
    if torch.cuda.is_available():
        gpu_index = 1 if torch.cuda.device_count() > 1 else 0
        device = torch.device(f"cuda:{gpu_index}")
    else:
        device = torch.device("cpu")

    model = model_pre.to(device)

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_len = int(0.8 * len(trueknowledge))
    test_len = len(trueknowledge) - train_len
    train_baddata, test_baddata = torch.utils.data.random_split(
        trueknowledge, [train_len, test_len]
    )

    trainbadloader = torch.utils.data.DataLoader(
        train_baddata,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        shuffle=True,
    )
    # Held-out loader; built for parity with the training loader but not
    # consumed anywhere in this function.
    testbadloader = torch.utils.data.DataLoader(
        test_baddata,
        batch_size=args.batch_size,
        collate_fn=data_collator,
        shuffle=True,
    )

    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = Adam(trainable_params, lr=args.lr, betas=(0.9, 0.999))

    # NOTE(review): this scheduler is created but never stepped, so the
    # learning rate stays at `args.lr`. Confirm whether a decay was intended
    # (and per epoch or per optimizer step) before wiring in `.step()`.
    num_training_steps = args.max_unlearn_steps
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    model.train()
    accumulation_steps = 32
    last_epoch = args.max_unlearn_steps - 1
    optimizer.zero_grad()

    for cnt in range(1, args.max_unlearn_steps):
        has_pending_grads = False
        # NOTE(review): `zip` over a single loader yields 1-tuples `(batch,)`,
        # which is what `get_answer_loss` has been receiving. Left unchanged
        # because the helper's expected input is not visible here - confirm.
        for acc, bad_batch in enumerate(zip(trainbadloader)):
            bad_loss = get_answer_loss("gd", bad_batch, model, device)
            print('epoch', cnt, 'step', acc, 'bad_loss', bad_loss)

            # Scale so that a full window sums to an average gradient.
            (bad_loss / accumulation_steps).backward()
            has_pending_grads = True

            # Step once a full window has been accumulated. The old check
            # (`acc % accumulation_steps == 0`) fired on the very first
            # micro-batch of every epoch, i.e. after a single sample.
            if (acc + 1) % accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()
                has_pending_grads = False

        # Apply a trailing partial window so that no gradient is dropped and
        # the checkpoint below reflects every batch of this epoch.
        if has_pending_grads:
            optimizer.step()
            optimizer.zero_grad()

        if cnt % args.save_every == 0:
            model.save_pretrained(args.model_save_dir)

    # Persist the final epoch when the periodic save did not already cover it
    # (e.g. 39 epochs with save_every=2 previously lost the last epoch).
    if last_epoch >= 1 and last_epoch % args.save_every != 0:
        model.save_pretrained(args.model_save_dir)
if __name__ == "__main__":
    # Command-line entry point: register the options in a fixed order, parse
    # them, and hand the resulting namespace to `main`.
    cli = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    # NOTE(review): --use_lora defaults to the *string* "False", which is
    # truthy if it is ever tested with a plain `if`.
    option_table = (
        ("--use_lora", {"default": "False"}),
        (
            "--max_unlearn_steps",
            {"type": int, "default": 40, "help": "Max number of unlearning steps."},
        ),
        (
            "--bad_weight",
            {"type": float, "default": 0.5, "help": "Weight on the bad loss."},
        ),
        (
            "--random_weight",
            {
                "type": float,
                "default": 1,
                "help": "Weight on learning the random outputs.",
            },
        ),
        (
            "--normal_weight",
            {"type": float, "default": 1.2, "help": "Weight on normal loss."},
        ),
        (
            "--batch_size",
            {"type": int, "default": 1, "help": "Batch size of unlearning."},
        ),
        ("--lr", {"type": float, "default": 8e-6, "help": "Unlearning LR."}),
        (
            "--max_bad_loss",
            {
                "type": float,
                "default": 300,
                "help": "Maximum loss on bad samples to terminate.",
            },
        ),
        (
            "--model_name",
            {
                "type": str,
                "default": "/home/ssliang/tangent_ta/opt-1.3b_finetunedattribute_logits",
                "help": "Name of the pretrained model.",
            },
        ),
        (
            "--model_save_dir",
            {
                "type": str,
                "default": "/home/ssliang/unlearning/models/opt-1.3b_learned_PEAK",
                "help": "Directory to save model.",
            },
        ),
        (
            "--save_every",
            {"type": int, "default": 2, "help": "How many steps to save model."},
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    args = cli.parse_args()
    main(args)
