
import tqdm
import argparse
import logging
import random
import time
import os
# Process-wide environment configuration. This is placed before `import torch`
# further down, presumably so CUDA_VISIBLE_DEVICES is seen before CUDA is
# initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ['MASTER_ADDR'] = 'localhost'
# NOTE(review): RANK is set to '3' here, but dist.init_process_group later in
# this file is called with rank=0, world_size=1 — confirm which rank is intended
# and whether this variable is read by anything (e.g. Accelerate) at all.
os.environ['RANK'] = '3'
import numpy as np
import torch
from PKURLHF import PKURLHF
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler, DataCollatorForLanguageModeling
from accelerate import Accelerator
from datasets import load_dataset,Dataset
from torch.optim import Adam,AdamW
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint
import bitsandbytes as bnb
# Create a single-process NCCL process group (rank 0 of world_size 1) at import
# time, rendezvousing over a hard-coded local TCP port.
# NOTE(review): os.environ['RANK'] is set to '3' earlier in this file, which
# disagrees with rank=0 here; the fixed port 11723 presumably prevents two runs
# on the same host from coexisting — confirm both are intended.
dist.init_process_group(backend='nccl', init_method='tcp://localhost:11723',rank=0,world_size = 1)
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast as autocast
from utils import (
    compute_kl,
    get_answer_loss,
    get_rand_ans_loss,
    get_truthfulQA_answers_plaintext,
)
# Module-level state shared with main(): an AMP gradient scaler, the tokenizer,
# and the PKU-RLHF splits produced by the project-local PKURLHF loader.
scaler = GradScaler()

tokenizer = AutoTokenizer.from_pretrained("/home/ssliang/unlearning/llama2-7b-hf")
# Without a pad token the tokenizer cannot pad batches, so reuse the unk token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

con = PKURLHF("/home/ssliang/unlearning/data/PKURLHF.json")
# PKURLHF's indexing operator takes the tokenizer and returns two datasets.
newknowledge, trueknowledge = con[tokenizer]

def main(args) -> None:
    """Train the MLP weights of the model at the hard-coded llama2-7b-hf path.

    Every parameter whose name does not contain ``"mlp"`` is frozen. The
    remaining ones are optimised with bitsandbytes 8-bit Adam on the loss
    returned by ``get_answer_loss("gd", ...)``, using gradient accumulation
    over ``accumulation_steps`` micro-batches. The data comes from an 80/20
    random split of the module-level ``newknowledge`` dataset.

    Args:
        args: Parsed CLI namespace. This function reads ``batch_size``, ``lr``,
            ``max_unlearn_steps`` (number of passes over the training split),
            ``save_every`` (save every N passes) and ``model_save_dir``.
            The other options (``use_lora``, ``model_name``, the ``*_weight``
            options and ``max_bad_loss``) are currently not used.
    """
    accelerator = Accelerator()
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    model = AutoModelForCausalLM.from_pretrained(
        "/home/ssliang/unlearning/llama2-7b-hf",
        torch_dtype=torch.bfloat16,
        attn_implementation='sdpa',
    ).to(device)

    # Only the MLP sub-layers are trainable; everything else is frozen.
    for name, param in model.named_parameters():
        param.requires_grad = 'mlp' in name

    # NOTE: a LoRA (AdaLoraConfig + get_peft_model) path used to live here but
    # was disabled, so ``args.use_lora`` currently has no effect.

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    train_len = int(0.8 * len(newknowledge))
    test_len = len(newknowledge) - train_len
    train_baddata, test_baddata = torch.utils.data.random_split(
        newknowledge, [train_len, test_len]
    )

    trainbadloader = torch.utils.data.DataLoader(
        train_baddata, batch_size=args.batch_size, collate_fn=data_collator, shuffle=True
    )
    # Built and prepared for symmetry, but not consumed by the loop below.
    testbadloader = torch.utils.data.DataLoader(
        test_baddata, batch_size=args.batch_size, collate_fn=data_collator, shuffle=True
    )
    optimizer = bnb.optim.Adam8bit(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        betas=(0.9, 0.995),
    )

    # NOTE(review): this scheduler is prepared but never stepped, so the LR
    # stays at args.lr for the whole run. Stepping it per optimizer step would
    # decay the LR to 0 after max_unlearn_steps steps — decide which is intended.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=args.max_unlearn_steps,
    )

    (
        model,
        optimizer,
        trainbadloader,
        testbadloader,
        lr_scheduler,
    ) = accelerator.prepare(
        model, optimizer, trainbadloader, testbadloader, lr_scheduler
    )
    model.train()

    accumulation_steps = 32
    # Runs exactly max_unlearn_steps passes over the training split; max_bad_loss
    # is not checked as an early-stop criterion.
    for epoch in range(1, args.max_unlearn_steps + 1):
        pending = 0  # micro-batches whose gradients have not been applied yet
        # NOTE(review): zip() over a single loader yields 1-tuples ``(batch,)``,
        # and that tuple is what get_answer_loss receives. Kept as-is because
        # utils.get_answer_loss is not visible here — confirm the form it expects.
        for step, bad_batch in enumerate(zip(trainbadloader)):
            bad_loss = get_answer_loss("gd", bad_batch, model, device)
            print('epoch', epoch, 'step', step, 'bad_loss', bad_loss)
            # Scale so the accumulated gradient averages over the window.
            accelerator.backward(bad_loss / accumulation_steps)
            pending += 1
            if pending == accumulation_steps:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0
        # Apply a trailing partial window so its gradients do not leak into
        # (and get stepped together with) the first batch of the next epoch.
        if pending:
            optimizer.step()
            optimizer.zero_grad()

        if epoch % args.save_every == 0:
            # prepare() may have wrapped the model (e.g. in DDP); save the
            # underlying transformers model.
            accelerator.unwrap_model(model).save_pretrained(args.model_save_dir)
if __name__ == "__main__":
    # Command-line interface. Each tuple is (flag, type, default, help); the
    # registration order below is the order shown in --help. Several of these
    # options (e.g. --bad_weight, --max_bad_loss, --model_name) are accepted but
    # not read by main().
    cli = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    cli.add_argument("--use_lora", action="store_true")

    option_specs = [
        ("--max_unlearn_steps", int, 30, "Max number of unlearning steps."),
        ("--bad_weight", float, 0.5, "Weight on the bad loss."),
        ("--random_weight", float, 1, "Weight on learning the random outputs."),
        ("--normal_weight", float, 1.2, "Weight on normal loss."),
        ("--batch_size", int, 1, "Batch size of unlearning."),
        ("--lr", float, 8e-6, "Unlearning LR."),
        ("--max_bad_loss", float, 300, "Maximum loss on bad samples to terminate."),
        (
            "--model_name",
            str,
            "/home/ssliang/tangent_ta/opt-1.3b",
            "Name of the pretrained model.",
        ),
        (
            "--model_save_dir",
            str,
            "/home/ssliang/unlearning/models/7bhf_unlearn_PKURLHF",
            "Directory to save model.",
        ),
        ("--save_every", int, 2, "How many steps to save model."),
    ]
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = cli.parse_args()
    main(args)
