from dataclasses import dataclass, field
import os
import torch.nn.functional as F
import wandb
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer,set_seed,DataCollatorWithPadding,AutoModel
from transformers.hf_argparser import HfArgumentParser
import torch
from transformers import AdamW,get_linear_schedule_with_warmup
from torch.utils.data import DataLoader

import numpy as np
import random
import os
import sys
sys.path.append("./")
from coh.data import CoHDataset, CoHDataArgs, CoHDataCollator
from coh.trainer import CoHTrainArgs, CoHTrainer, EvalCallback, RougeCallback, AccCallback, SaveCallback
from coh.custom_model import good_and_bad_model
@dataclass
class ExperimentArgs:
    """Command-line options for the experiment driver in this script.

    Parsed by ``HfArgumentParser`` together with the data and trainer
    argument dataclasses.
    """

    # --- model / tokenizer -------------------------------------------------
    # Checkpoint name handed to ``from_pretrained`` for the base model.
    model_name: str = "EleutherAI/gpt-j-6B"
    tokenizer_name: str = field(
        default=None,
        metadata={"help": "Will default to --model_name."},
    )

    # --- wandb logging -----------------------------------------------------
    wandb_project_name: str = field(default="CoH")
    wandb_run_name: str = field(default="CoH-GPT-J-6B")

    # --- webgpt dataset test size ------------------------------------------
    webgpt_dataset_test_size: float = field(
        metadata={"help": "webgpt_comparisons only have train: need to split."},
        default=0.1,
    )

    # --- peft ----------------------------------------------------------------
    use_lora: bool = field(
        metadata={"help": "use lora with huggingface peft. You must install loralib and peft."},
        default=False,
    )
    lora_r: int = field(default=8)
    lora_alpha: int = field(default=16)
    lora_dropout: float = field(default=0.05)
    train_8bit: bool = field(default=False)
    # When True the script does not switch ``requires_grad`` on for every
    # parameter after wrapping the model with PEFT.
    fix_base_model: bool = field(default=False)

    # --- run mode ------------------------------------------------------------
    # Skip training; load ``load_path`` and run the Rouge callback once.
    only_eval: bool = field(default=False)
    # Path passed to ``torch.load`` in eval-only mode.
    load_path: str = field(default=None)
    # Training schedule selector; main() recognises "stage1" and "stage1_2".
    stage: str = field(default=None)
    # Training method; main() recognises "coh", "lora" and "prompt tuning".
    method: str = field(default=None)
    epoch_flag: int = 1

    # --- prompt tuning -------------------------------------------------------
    # Initialisation text for the "good" / "bad" soft prompts.
    good_init_soft_tokens: str = field(default="good")
    bad_init_soft_tokens: str = field(default="bad")
    

def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmarking enabled cuDNN may pick different kernels from run to
    # run, which defeats the deterministic flag above; pin it off explicitly.
    torch.backends.cudnn.benchmark = False
# Values accepted by --method / --stage. They are checked right after argument
# parsing so that a typo fails before the base model is loaded.
_SUPPORTED_METHODS = ("coh", "lora", "prompt tuning")
_SUPPORTED_STAGES = ("stage1", "stage1_2")


def _validate_run_config(args):
    """Raise ``ValueError`` for method/stage/load_path combinations main() cannot run."""
    if args.method not in _SUPPORTED_METHODS:
        raise ValueError(
            f"unknown --method {args.method!r}; expected one of {_SUPPORTED_METHODS}"
        )
    if args.only_eval:
        if not args.load_path:
            raise ValueError("--only_eval requires --load_path to point at a saved state dict")
        return
    if args.stage not in _SUPPORTED_STAGES:
        raise ValueError(
            f"unknown --stage {args.stage!r}; expected one of {_SUPPORTED_STAGES}"
        )
    if args.stage == "stage1_2" and args.method == "coh":
        # Stage 2 touches model.good_model / model.bad_model, which only exist
        # on the wrapped PEFT model; fail now instead of after stage 1.
        raise ValueError("--stage stage1_2 requires --method 'lora' or 'prompt tuning'")


def _report_trainable(dual_model, unfreeze):
    """Print trainable-parameter counts of both sub-models, optionally unfreezing each first."""
    for peft_model in (dual_model.good_model, dual_model.bad_model):
        if unfreeze:
            for param in peft_model.parameters():
                param.requires_grad = True
        peft_model.print_trainable_parameters()


def _wrap_good_and_bad(base_model, good_config, bad_config, num_virtual_tokens, args, train_args):
    """Wrap the base model with two PEFT configs and combine them into one model."""
    from peft import get_peft_model

    # NOTE(review): both PEFT wrappers are built over the same base_model
    # instance -- confirm good_and_bad_model expects that sharing.
    good_model = get_peft_model(base_model, good_config)
    bad_model = get_peft_model(base_model, bad_config)
    dual_model = good_and_bad_model(good_model, bad_model, num_virtual_tokens)

    if args.fix_base_model and train_args.local_rank == 0:
        print("fix the base model")
    # Unless the base model is to stay fixed, every parameter becomes trainable.
    _report_trainable(dual_model, unfreeze=not args.fix_base_model)
    return dual_model


def _build_lora_model(base_model, args, train_args):
    """Create the good/bad LoRA model pair from the experiment's LoRA settings."""
    from peft import LoraConfig, TaskType

    def make_config():
        return LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=['q_proj', 'v_proj'],
            lora_dropout=args.lora_dropout,
            bias="none",
        )

    # LoRA adds no virtual tokens, hence 0.
    return _wrap_good_and_bad(base_model, make_config(), make_config(), 0, args, train_args)


def _build_prompt_tuning_model(base_model, tokenizer_name, args, train_args):
    """Create the good/bad prompt-tuning model pair, each with its own init text."""
    from peft import PromptTuningConfig, PromptTuningInit, TaskType

    def make_config(init_text):
        return PromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.TEXT,
            num_virtual_tokens=train_args.num_virtual_tokens,
            prompt_tuning_init_text=init_text,
            tokenizer_name_or_path=tokenizer_name,
        )

    return _wrap_good_and_bad(
        base_model,
        make_config(args.good_init_soft_tokens),
        make_config(args.bad_init_soft_tokens),
        train_args.num_virtual_tokens,
        args,
        train_args,
    )


def _build_trainer(model, tokenizer, args, train_args, train_dataset, test_dataset):
    """Build a ``CoHTrainer`` that evaluates on the test split and uses ``SaveCallback``."""
    return CoHTrainer(
        model=model,
        tokenizer=tokenizer,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=CoHDataCollator([tokenizer, args]),
        callbacks=[
            SaveCallback(test_dataset, wandb, train_args, model, tokenizer, args),
        ],
    )


def main():
    """Parse CLI arguments, build the model and datasets, then train or evaluate.

    With ``--only_eval`` the state dict at ``--load_path`` is loaded into the
    model and the Rouge callback is run once on the test split. Otherwise
    ``--stage`` selects a single training run ("stage1") or a first run
    followed by a second run with all parameters unfrozen ("stage1_2").

    Raises:
        ValueError: If ``--method`` or ``--stage`` is not recognised, if
            ``--only_eval`` is given without ``--load_path``, or if
            "stage1_2" is combined with the "coh" method.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    torch.backends.cuda.matmul.allow_tf32 = True  # allows tf32, only on Ampere GPUs
    torch.backends.cudnn.allow_tf32 = True

    parser = HfArgumentParser([ExperimentArgs, CoHDataArgs, CoHTrainArgs])
    args, data_args, coh_train_args = parser.parse_args_into_dataclasses()
    _validate_run_config(args)

    set_seed(data_args.set_seed)  # huggingface
    setup_seed(data_args.set_seed)

    coh_train_args.ddp_find_unused_parameters = False
    is_rank_zero = coh_train_args.local_rank == 0

    # multiple gpus device_map: one full copy per process, otherwise let
    # from_pretrained place the weights automatically.
    if int(os.environ.get("WORLD_SIZE", 1)) != 1:
        device_map = {"": coh_train_args.local_rank}
    else:
        device_map = 'auto'

    # wandb
    if is_rank_zero:
        # Take the API key from the environment instead of hard-coding it;
        # with key=None wandb uses its own credential lookup.
        wandb.login(key=os.environ.get("WANDB_API_KEY") or None)
        wandb.init(
            project=args.wandb_project_name,
            name=args.wandb_run_name,
            config={
                "experiment": args.__dict__,
                "data": data_args.__dict__,
                "train": coh_train_args.__dict__,
            },
        )

    # load tokenizer
    tokenizer_name = args.tokenizer_name if args.tokenizer_name else args.model_name
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        add_eos_token=True,
        cache_dir=data_args.cache_dir,
        use_fast=False,
    )
    # load base model
    # NOTE: args.train_8bit is not applied here; weights are loaded as bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map=device_map,
        cache_dir=data_args.cache_dir,
        torch_dtype=torch.bfloat16,
    )
    # add special token: fall back to EOS when no padding token is configured
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    if model.config.pad_token_id is None:
        model.config.pad_token_id = tokenizer.eos_token_id

    if args.method == "coh":
        # The plain base model is trained directly.
        if is_rank_zero:
            print("########################## coh_method ######################")
    elif args.method == "lora":
        if is_rank_zero:
            print("########################## lora_method ######################")
        model = _build_lora_model(model, args, coh_train_args)
    else:  # "prompt tuning" -- the method was validated above
        if is_rank_zero:
            print("########################## prompt_tuning_method ######################")
        model = _build_prompt_tuning_model(model, tokenizer_name, args, coh_train_args)

    # split webgpt dataset to train and test
    webgpt_data = CoHDataset.load_webgpt_dataset(test_size=args.webgpt_dataset_test_size)
    # create default dataset config; data_args_dict is data_args' own __dict__,
    # so the assignments below also update data_args.
    data_args_dict = data_args.__dict__
    data_args_dict["data_method"] = args.method
    coh_train_args.train_method = args.method

    coh_config = CoHDataset.get_default_config(data_args_dict)
    train_dataset = CoHDataset(coh_config, tokenizer, webgpt_data)
    # create eval dataset
    # NOTE(review): eval_dataset is built but never handed to the trainer,
    # which evaluates on the test split -- confirm whether that is intended.
    if coh_train_args.evaluation_strategy != 'no':
        data_args_dict['split'] = 'validation'
        eval_cfg = CoHDataset.get_default_config(data_args_dict)
        eval_dataset = CoHDataset(eval_cfg, tokenizer, webgpt_data)
    else:
        eval_dataset = None
        coh_train_args.eval_steps = None
    data_args_dict['split'] = 'test'
    test_cfg = CoHDataset.get_default_config(data_args_dict)
    test_dataset = CoHDataset(test_cfg, tokenizer, webgpt_data)

    if args.only_eval:
        # Reuse the model built above rather than loading the base weights a
        # second time or re-wrapping with a different virtual-token count.
        # Load to CPU first; load_state_dict copies onto each parameter's device.
        state_dict = torch.load(args.load_path, map_location="cpu")
        model.load_state_dict(state_dict)
        callback = RougeCallback(test_dataset, wandb, coh_train_args, model, tokenizer, args)
        callback.on_epoch_end(args, None, None)
        return

    #################### Stage 1 ###############################
    trainer = _build_trainer(
        model, tokenizer, args, coh_train_args, train_dataset, test_dataset
    )
    trainer.train()

    if args.stage == "stage1_2":
        #################### Stage 2 ###############################
        # Unfreeze everything and continue with stage-2 hyper-parameters.
        _report_trainable(model, unfreeze=True)

        coh_train_args.learning_rate = 2e-5
        coh_train_args.per_device_train_batch_size = 2
        coh_train_args.per_device_eval_batch_size = 2
        coh_train_args.gradient_accumulation_steps = 32
        trainer = _build_trainer(
            model, tokenizer, args, coh_train_args, train_dataset, test_dataset
        )
        trainer.train()

    wandb.finish()


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
