import os
# Expose only physical GPUs 0-3 to this process. This has to happen before
# CUDA is initialised, which is why it sits above the torch/transformers imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
from transformers import AutoTokenizer, DataCollatorForLanguageModeling
from transformers import AutoModelForCausalLM, get_scheduler
from zsre import ZSRE

import torch
# Training device. "cuda:2" indexes into the CUDA_VISIBLE_DEVICES list set
# above ("0,1,2,3"), i.e. physical GPU 2; falls back to CPU without CUDA.
device = torch.device("cuda:2") if torch.cuda.is_available() else torch.device("cpu")
from eval_utils_counterfact import compute_rewrite_quality_counterfact
from peft import AdaLoraConfig, TaskType, get_peft_model
import bitsandbytes as bnb
import tqdm
import argparse
import logging
import random
import time

# NOTE(review): MASTER_ADDR is normally consumed by env:// process-group
# initialisation. The init_process_group call below passes an explicit
# tcp:// init_method instead, so this may only matter to other libraries
# (e.g. accelerate) -- verify whether it is still needed.
os.environ['MASTER_ADDR'] = 'localhost'


import numpy as np
import torch
from accelerate import Accelerator
from datasets import load_dataset,Dataset
from torch.optim import Adam,AdamW
import torch.distributed as dist
from torch.utils.checkpoint import checkpoint
# Create a single-process NCCL process group (rank 0 of world_size 1) at
# import time. NOTE(review): nothing visible in this script issues collective
# ops (accelerator.prepare is disabled in main), so this may be vestigial.
# The hard-coded port 2185 will likely clash if two copies run concurrently.
dist.init_process_group(backend='nccl', init_method='tcp://localhost:2185',rank=0,world_size = 1)
# Seed torch, numpy and Python's `random` so runs are repeatable. The torch
# seed drives random_split and DataLoader shuffling in main().
torch.manual_seed(8888)
np.random.seed(8888)
random.seed(8888)
from torch.cuda.amp import GradScaler as GradScaler
from torch.cuda.amp import autocast as autocast
from utils import (
    compute_kl,
    get_answer_loss,
    get_rand_ans_loss,
    get_truthfulQA_answers_plaintext,
)

# Module-level tokenizer; main() reuses it for the language-modelling collator.
tokenizer = AutoTokenizer.from_pretrained("/home/ssliang/unlearning/llama2-7b-hf")

# The checkpoint may not define a pad token. Reuse the unk token so the
# collator is able to pad variable-length examples into a batch.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.unk_token

# Load the ZSRE data. The loader hands back a (newknowledge, trueknowledge)
# pair; only `trueknowledge` is consumed by main().
con = ZSRE("/home/ssliang/unlearning/data/")
newknowledge, trueknowledge = con.__getitem__(tokenizer)
def main(args) -> None:
    """Train the MLP weights of the local Llama-2-7B checkpoint on ``trueknowledge``.

    Loads the model from a hard-coded path in bfloat16 and freezes every
    parameter whose name does not contain ``'mlp'``. The remaining weights are
    trained with bitsandbytes' 8-bit Adam, accumulating gradients over
    ``accumulation_steps`` micro-batches per optimizer step. The per-batch loss
    comes from ``utils.get_answer_loss`` with operation ``"gd"``.

    Args:
        args: Parsed CLI namespace. Reads ``batch_size``, ``lr``,
            ``max_unlearn_steps`` (number of full passes over the training
            split), ``save_every`` (checkpoint cadence, in passes) and
            ``model_save_dir``.

    Side effects:
        Prints the loss of every micro-batch. Writes the model to
        ``args.model_save_dir`` every ``save_every`` passes and after the
        final pass.
    """
    # NOTE(review): unused since accelerator.prepare(...) was disabled. Kept in
    # case constructing it has side effects the rest of the setup relies on.
    accelerator = Accelerator()
    model = AutoModelForCausalLM.from_pretrained(
        '/home/ssliang/unlearning/llama2-7b-hf',
        torch_dtype=torch.bfloat16,
        attn_implementation='sdpa',
    ).to(device)
    # PEFT/AdaLoRA wrapping was disabled in earlier revisions. Instead, train
    # only the MLP sub-modules and keep everything else frozen.
    for name, param in model.named_parameters():
        param.requires_grad = 'mlp' in name

    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    # 80/20 split; only the 80% training part is consumed by the loop below.
    train_len = int(0.8 * len(trueknowledge))
    test_len = len(trueknowledge) - train_len
    train_baddata, test_baddata = torch.utils.data.random_split(
        trueknowledge, [train_len, test_len]
    )
    trainbadloader = torch.utils.data.DataLoader(
        train_baddata, batch_size=args.batch_size, collate_fn=data_collator, shuffle=True
    )
    # Held-out loader; not evaluated anywhere in this function yet.
    testbadloader = torch.utils.data.DataLoader(
        test_baddata, batch_size=args.batch_size, collate_fn=data_collator, shuffle=True
    )

    optimizer = bnb.optim.Adam8bit(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr,
        betas=(0.9, 0.995),
    )
    # NOTE(review): lr_scheduler.step() is never called, so the learning rate
    # stays at args.lr for the whole run. Because num_training_steps equals
    # max_unlearn_steps (the number of passes), a linear decay would need one
    # step() per pass -- confirm that is wanted before enabling it.
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=args.max_unlearn_steps,
    )

    model.train()
    accumulation_steps = 32
    num_batches = len(trainbadloader)
    # Run exactly max_unlearn_steps passes. The previous `cnt = 1;
    # while cnt < max` loop ran one pass fewer than requested.
    for epoch in range(1, args.max_unlearn_steps + 1):
        # NOTE(review): zip() over a single loader yields 1-tuples, so each
        # bad_batch is `(batch,)` rather than the collated dict. This is kept
        # because get_answer_loss (in utils, not shown) may rely on it --
        # verify before unwrapping.
        for step, bad_batch in enumerate(zip(trainbadloader), start=1):
            bad_loss = get_answer_loss("gd", bad_batch, model, device)
            print('epoch', epoch, 'step', step, 'bad_loss', bad_loss)
            (bad_loss / accumulation_steps).backward()
            # Step on every full accumulation window AND on the pass's last
            # micro-batch. Otherwise tail gradients leak into the next pass,
            # and a pass with fewer than accumulation_steps batches would
            # never update the model.
            if step % accumulation_steps == 0 or step == num_batches:
                optimizer.step()
                optimizer.zero_grad()
        # Also save after the final pass so its training is never discarded.
        if epoch % args.save_every == 0 or epoch == args.max_unlearn_steps:
            model.save_pretrained(args.model_save_dir)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # NOTE(review): --use_lora, --bad_weight, --random_weight, --normal_weight,
    # --max_bad_loss and --model_name are parsed but never read by main().
    # They are kept so existing command lines still parse. --use_lora has no
    # type, so any value (including its default "False") is a truthy string.
    parser.add_argument("--use_lora", default="False")

    parser.add_argument(
        "--max_unlearn_steps",
        type=int,
        default=40,
        help="Max number of training epochs (full passes over the training split).",
    )
    parser.add_argument(
        "--bad_weight", type=float, default=0.5, help="Weight on the bad loss."
    )
    parser.add_argument(
        "--random_weight",
        type=float,
        default=1,
        help="Weight on learning the random outputs.",
    )
    parser.add_argument(
        "--normal_weight",
        type=float,
        default=1.2,
        help="Weight on normal loss.",
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size of unlearning."
    )
    parser.add_argument("--lr", type=float, default=8e-6, help="Unlearning LR.")
    parser.add_argument(
        "--max_bad_loss",
        type=float,
        default=300,
        help="Currently unused: training does not stop early on the bad loss.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="/home/ssliang/tangent_ta/opt-1.3b_finetunedattribute_logits",
        help="Currently unused: main() loads a hard-coded Llama-2 checkpoint.",
    )
    # NOTE(review): this default repeats "/home/ssliang/unlearning/models"
    # twice, which looks like a copy-paste slip. It is left unchanged because
    # existing checkpoints are written there -- confirm before correcting it.
    parser.add_argument(
        "--model_save_dir",
        type=str,
        default="/home/ssliang/unlearning/models/home/ssliang/unlearning/models/7bhf_finetunedzsre",
        help="Directory to save model.",
    )
    parser.add_argument(
        "--save_every", type=int, default=2, help="Save the model every N epochs."
    )

    args = parser.parse_args()

    main(args)
