import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import argparse
import deepspeed
import os
import json
import copy
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput


def loss1(model):
    """Sum of squared entries over the tracked attention/MLP projection weights.

    A parameter is tracked when its name contains one of the q/k/v/o
    attention projection or gate/up/down MLP projection weight fragments
    listed below. Each tracked tensor is upcast with ``.float()`` before it
    is squared.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        The accumulated sum as a 0-dim tensor, or the integer ``0`` when no
        parameter name matches.
    """
    tracked_fragments = (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    )
    return sum(
        torch.sum(weight.float() ** 2)
        for param_name, weight in model.named_parameters()
        if any(fragment in param_name for fragment in tracked_fragments)
    )

def loss2(model, pre_model):
    """Return the L2 distance between tracked weights of ``model`` and ``pre_model``.

    A parameter is tracked when its name contains one of the q/k/v/o
    attention projection or gate/up/down MLP projection weight fragments
    listed below. Its reference tensor is looked up in
    ``pre_model.state_dict()`` under the same name with every ``"module."``
    occurrence removed.

    Args:
        model: Object exposing ``named_parameters()``; its parameters are
            used directly, so the result is differentiable w.r.t. them.
        pre_model: Object exposing ``state_dict()`` with the reference weights.

    Returns:
        A 0-dim float32 tensor equal to ``sqrt(sum((w - w_ref) ** 2) + 1e-8)``.
        When no parameter name matches, the sum is zero and the result is
        ``sqrt(1e-8)``.

    Raises:
        KeyError: If a tracked parameter has no entry in ``pre_model``.
    """
    tracked_fragments = (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    )
    # Build the reference mapping once instead of once per matching parameter.
    reference = pre_model.state_dict()

    total = None
    for name, param in model.named_parameters():
        if not any(fragment in name for fragment in tracked_fragments):
            continue
        pre_data = reference[name.replace("module.", "")]
        # Compare in float32 on the parameter's device: squared differences
        # can leave the float16 range, and the two models may live on
        # different devices.
        pre_data = pre_data.to(device=param.device, dtype=torch.float32)
        squared = torch.sum((param.float() - pre_data) ** 2)
        total = squared if total is None else total + squared

    if total is None:
        # Nothing matched: keep the return type a tensor so torch.sqrt and the
        # caller's arithmetic still work.
        anchor = next(iter(model.parameters()), None)
        total = torch.zeros((), device=None if anchor is None else anchor.device)
    return torch.sqrt(total + 1e-8)

def loss1_gpt(model):
    """Sum of squared entries over the tracked GPT-style attention/MLP weights.

    A parameter is tracked when its name contains one of the ``attn`` or
    ``mlp`` weight fragments listed below.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        The accumulated sum as a 0-dim tensor, or the integer ``0`` when no
        parameter name matches.
    """
    tracked_fragments = (
        "attn.c_attn.weight",
        "attn.q_attn.weight",
        "attn.c_proj.weight",
        "mlp.c_fc.weight",
        "mlp.c_proj.weight",
    )
    # Upcast before squaring: float16 would exceed its precision here.
    return sum(
        torch.sum(weight.float() ** 2)
        for param_name, weight in model.named_parameters()
        if any(fragment in param_name for fragment in tracked_fragments)
    )

def loss2_gpt(model, pre_model):
    """Return the L2 distance between tracked GPT-style weights of two models.

    A parameter is tracked when its name contains one of the ``attn`` or
    ``mlp`` weight fragments listed below. Its reference tensor is looked up
    in ``pre_model.state_dict()`` under the same name with every
    ``"module."`` occurrence removed.

    Args:
        model: Object exposing ``named_parameters()``; its parameters are
            used directly, so the result is differentiable w.r.t. them.
        pre_model: Object exposing ``state_dict()`` with the reference weights.

    Returns:
        A 0-dim float32 tensor equal to ``sqrt(sum((w - w_ref) ** 2) + 1e-8)``.
        When no parameter name matches, the sum is zero and the result is
        ``sqrt(1e-8)``.

    Raises:
        KeyError: If a tracked parameter has no entry in ``pre_model``.
    """
    tracked_fragments = (
        "attn.c_attn.weight",
        "attn.q_attn.weight",
        "attn.c_proj.weight",
        "mlp.c_fc.weight",
        "mlp.c_proj.weight",
    )
    # Build the reference mapping once instead of once per matching parameter.
    reference = pre_model.state_dict()

    total = None
    for name, param in model.named_parameters():
        if not any(fragment in name for fragment in tracked_fragments):
            continue
        pre_data = reference[name.replace("module.", "")]
        # Compare in float32 on the parameter's device: squared differences
        # can leave the float16 range, and the two models may live on
        # different devices.
        pre_data = pre_data.to(device=param.device, dtype=torch.float32)
        squared = torch.sum((param.float() - pre_data) ** 2)
        total = squared if total is None else total + squared

    if total is None:
        # Nothing matched: keep the return type a tensor so torch.sqrt and the
        # caller's arithmetic still work.
        anchor = next(iter(model.parameters()), None)
        total = torch.zeros((), device=None if anchor is None else anchor.device)
    return torch.sqrt(total + 1e-8)

class CustomTrainer(Trainer):
    """Trainer whose loss adds weight-based terms to the cross-entropy.

    The training loss is
    ``cross_entropy + 1e-6 * loss1(model) - 1e-4 * loss2(model, pre_model)``.

    NOTE(review): the ``loss2`` term is subtracted here, while
    ``CustomTrainer4GPT`` adds its ``loss2_gpt`` term - confirm the sign
    difference is intended.
    """

    def __init__(self, pre_model, *args, **kwargs):
        """Keep ``pre_model`` as the reference; forward the rest to ``Trainer``."""
        super().__init__(*args, **kwargs)
        self.pre_model = pre_model

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss for one batch.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                ``.logits``.
            inputs: Batch mapping that includes a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        outputs = model(**inputs)
        targets = inputs["labels"]

        weight_norm_term = 1e-6 * loss1(model)
        distance_term = 1e-4 * loss2(model, self.pre_model)
        classification_term = nn.functional.cross_entropy(outputs.logits, targets)

        loss = classification_term + weight_norm_term - distance_term
        if return_outputs:
            return loss, outputs
        return loss

class CustomTrainer4GPT(Trainer):
    """Trainer for GPT-style models with weight-based terms added to the loss.

    The training loss is
    ``cross_entropy + 1e-4 * loss1_gpt(model) + 1e-4 * loss2_gpt(model, pre_model)``.
    """

    def __init__(self, pre_model, *args, **kwargs):
        """Keep ``pre_model`` as the reference; forward the rest to ``Trainer``."""
        super().__init__(*args, **kwargs)
        self.pre_model = pre_model
        # Built once here and reused for every batch.
        self.loss_fn = nn.CrossEntropyLoss()

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss for one batch.

        Args:
            model: Model called as ``model(**inputs)``; its output must expose
                ``.logits``.
            inputs: Batch mapping that includes a ``"labels"`` entry.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for interface compatibility; unused.

        Returns:
            The loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        outputs = model(**inputs)
        targets = inputs["labels"]

        weight_norm_term = 1e-4 * loss1_gpt(model)
        distance_term = 1e-4 * loss2_gpt(model, self.pre_model)
        classification_term = self.loss_fn(outputs.logits, targets)

        loss = classification_term + weight_norm_term + distance_term
        if return_outputs:
            return loss, outputs
        return loss
    
def parse_args():
    """Build the command-line parser for this script and parse the arguments.

    Options are registered in a fixed order: the script's own options, then
    whatever ``deepspeed.add_config_arguments`` adds, then ``--tsqp``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description='DeepSpeed ZeRO')

    option_table = [
        # Model to finetune and sequence-length limits.
        ('--model_name_or_path', dict(type=str, required=True, help='model name or path, you can also pass the path of model you want to finetune')),
        ('--src_len', dict(type=int, default=512, help='max source sentence length')),
        ('--tgt_len', dict(type=int, default=128, help='max target sentence length')),
        # Dataset.
        ('--data_path', dict(type=str, required=True, help='Path to the training dataset.')),
        ('--num_labels', dict(type=int, required=True, help='The number of the dataset labels.')),
        # Optimisation hyper-parameters and output location.
        ('--train_micro_batch_size_per_gpu', dict(type=int, default=8, help='batch size per gpu')),
        ('--gradient_accumulation_steps', dict(type=int, default=32, help='gradient accumulation steps')),
        ('--max_lr', dict(type=float, default=1e-3, help='max learning rate')),
        ('--initial_lr', dict(type=float, default=1e-6, help='initial learning rate')),
        ('--min_lr', dict(type=float, default=1e-8, help='min learning rate')),
        ('--weight_decay', dict(type=float, default=0.01, help='weight decay')),
        ('--adam_beta1', dict(type=float, default=0.9, help='adam beta1')),
        ('--adam_beta2', dict(type=float, default=0.999, help='adam beta2')),
        ('--fused', dict(action='store_true', help='whether to use fused optimizer, if you can load all prameters of the model on a single gpu.')),
        ('--epochs', dict(type=int, default=100, help='number of epochs')),
        ('--output_dir', dict(type=str, default='./checkpoints/', help='save dir')),
        # TensorBoard log directory.
        ('--log_dir', dict(type=str, default='./logs', help='log dir')),
        # Finetuning method selection.
        ('--finetune_method', dict(type=str, default='lora', help='finetune method, support parameters lora, freeze, full-tuning')),
        # Modules to freeze (used by the "freeze" method).
        ('--freeze_modules', dict(type=str, default='dense_h_to_4h', help='the layer of model you wanna freeze')),
        # LoRA hyper-parameters.
        ('--lora_alpha', dict(type=int, default=32, help='alpha for LoRA')),
        ('--lora_dropout', dict(type=float, default=0.05, help='dropout probability for LoRA')),
        ('--lora_target_modules', dict(type=str, default='query_key_value', help='target modules for LoRA, the name of layer in model you wanna use LoRA')),
        ('--lora_r', dict(type=int, default=8, help='r for LoRA')),
        # DeepSpeed / launcher settings.
        ('--ds_config_path', dict(type=str, default='./config/ds_config.json', help='path to deepspeed config file')),
        ('--offload_device', dict(type=str, default='cpu', help='offload device, cpu or nvme, which mean you want to offload the model to cpu memory or nvme ssd')),
        ('--nvme_path', dict(type=str, default='./mnt/nvme', help='path to nvme ssd')),
        ('--local_rank', dict(type=int, default=-1, help='local rank passed from distributed launcher')),
        ('--global_rank', dict(type=int, default=-1, help='global rank passed from distributed launcher')),
    ]
    for flag, settings in option_table:
        parser.add_argument(flag, **settings)

    # DeepSpeed adds its own arguments to the parser at this point.
    parser = deepspeed.add_config_arguments(parser)

    # TSQP switch, passed as the string 'true' or 'false'.
    parser.add_argument('--tsqp', default='false', type=str, help='Whether to use TSQP')

    return parser.parse_args()

# Text column(s) passed to the tokenizer for each supported dataset, keyed by
# a substring of --data_path. Keys are checked in insertion order.
_DATASET_TEXT_COLUMNS = {
    'mnli': ("premise", "hypothesis"),
    'sst2': ("sentence",),
    'qnli': ("question", "sentence"),
    'qqp': ("question1", "question2"),
}


def _build_model(args):
    """Load the classification model and apply ``args.finetune_method``.

    Supported methods: ``"full-tuning"`` (model returned as loaded),
    ``"lora"`` (wrapped by ``get_peft_model`` with a ``LoraConfig`` targeting
    ``args.lora_target_modules``) and ``"freeze"`` (parameters whose name
    contains any entry of ``args.freeze_modules`` get ``requires_grad=False``).

    Raises:
        ValueError: If ``args.finetune_method`` is none of the above.
    """
    if args.finetune_method not in ("full-tuning", "lora", "freeze"):
        # Reject bad input before paying for a model load.
        raise ValueError("Invalid finetune method")

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        num_labels=args.num_labels,
        trust_remote_code=True,
    )

    if args.finetune_method == "lora":
        lora_target_modules = args.lora_target_modules.split(",")
        print(lora_target_modules)
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=lora_target_modules,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
    elif args.finetune_method == "freeze":
        freeze_modules = args.freeze_modules.split(",")
        for name, param in model.named_parameters():
            if any(item in name for item in freeze_modules):
                param.requires_grad = False
    return model


def _load_ds_config(args):
    """Read the DeepSpeed JSON config and overlay the command-line settings.

    Raises:
        ValueError: If ``args.offload_device`` is ``"nvme"`` and
            ``args.nvme_path`` does not exist.
        KeyError: If the config lacks ``zero_optimization.offload_param`` or
            ``zero_optimization.offload_optimizer``.
    """
    # Context manager so the file handle is closed deterministically.
    with open(args.ds_config_path, encoding="utf-8") as config_file:
        ds_config = json.load(config_file)

    ds_config['train_micro_batch_size_per_gpu'] = args.train_micro_batch_size_per_gpu
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps

    zero_config = ds_config['zero_optimization']
    zero_config['offload_param']['device'] = args.offload_device
    zero_config['offload_optimizer']['device'] = args.offload_device
    if args.offload_device == "nvme":
        # The directory is deliberately not created here, so an existing one
        # is never clobbered by accident.
        if not os.path.exists(args.nvme_path):
            raise ValueError(f"nvme path does not exist, please make directory {args.nvme_path} by yourself.")
        zero_config['offload_param']['nvme_path'] = args.nvme_path
        zero_config['offload_optimizer']['nvme_path'] = args.nvme_path
    return ds_config


def main():
    """Run the finetuning pipeline end to end.

    Steps: parse arguments, set up the device (and the distributed backend
    when a launcher supplied ``--local_rank``), build the model, tokenize the
    ``train``/``validation`` splits, configure DeepSpeed and the Hugging Face
    trainer (a TSQP variant when ``--tsqp true``), train, and save the result
    under ``args.output_dir``.

    Raises:
        ValueError: For an unknown finetune method, a ``--data_path`` that
            names no supported dataset, or a missing NVMe offload directory.
    """
    args = parse_args()

    # Resolve the dataset's text columns up front so an unsupported dataset
    # fails with a clear message before any model or data is loaded.
    text_columns = next(
        (columns for key, columns in _DATASET_TEXT_COLUMNS.items() if key in args.data_path),
        None,
    )
    if text_columns is None:
        raise ValueError(
            f"Unsupported dataset path {args.data_path!r}; expected it to contain one of "
            f"{', '.join(_DATASET_TEXT_COLUMNS)}"
        )

    if args.local_rank == -1:
        # No distributed launcher: plain single-process run on the default GPU.
        device = torch.device("cuda")
    else:
        device = torch.device("cuda", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        deepspeed.init_distributed()
        torch.distributed.barrier()
    # A single-process run (local_rank == -1) is its own master; otherwise the
    # diagnostics below would never be printed in that mode.
    master_process = args.local_rank in (-1, 0)

    model = _build_model(args)

    if args.fused:
        model = model.to(device)

    if master_process:
        print(model)
        print(f"Local rank: {args.local_rank} \n Total number of parameters: {sum(p.numel() for p in model.parameters())}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    def tokenize_function(example):
        """Tokenize one batch of examples and carry the labels over."""
        result = tokenizer(
            *(example[column] for column in text_columns),
            truncation=True,
            max_length=args.src_len,
            padding=False,
        )
        if "label" in example:
            result["labels"] = torch.tensor(example["label"], dtype=torch.long)
        return result

    dataset = load_dataset(args.data_path)
    train_dataset = dataset['train']
    eval_dataset = dataset["validation"]
    tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
    tokenized_eval_dataset = eval_dataset.map(tokenize_function, batched=True, remove_columns=eval_dataset.column_names)

    data_collator = DataCollatorWithPadding(
        tokenizer=tokenizer,
        padding='max_length',
        max_length=args.src_len,
        return_tensors="pt",
    )

    ds_config = _load_ds_config(args)
    if master_process:
        print(f"DeepSpeed config: {ds_config}")

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_micro_batch_size_per_gpu,
        per_device_eval_batch_size=args.train_micro_batch_size_per_gpu,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.max_lr,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        warmup_ratio=0.1 if args.initial_lr < args.max_lr else 0,
        logging_dir=args.log_dir,
        logging_steps=10,
        label_names=["labels"],
        save_strategy="steps",
        save_steps=6000,
        eval_strategy="steps",
        eval_steps=60,
        deepspeed=ds_config,
        fp16=ds_config.get("fp16", {}).get("enabled", False),
    )

    def compute_metrics(eval_pred):
        """Return accuracy of the argmax predictions against the labels."""
        predictions, labels = eval_pred
        predictions = predictions.argmax(axis=-1)
        return {"accuracy": (predictions == labels).mean()}

    # Shared by every trainer variant so TSQP and baseline runs batch and pad
    # the data the same way.
    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=tokenized_train_dataset,
        eval_dataset=tokenized_eval_dataset,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # Accept "True"/"TRUE" as well as the documented "true".
    if args.tsqp.strip().lower() == "true":
        # Snapshot of the weights before training; used as the reference in
        # the distance term of the TSQP loss. It already carries the
        # pad_token_id set on model.config above.
        pre_model = copy.deepcopy(model).to(device)
        # The reference is never trained.
        pre_model.requires_grad_(False)
        trainer_cls = CustomTrainer4GPT if "gpt2" in args.model_name_or_path else CustomTrainer
        trainer = trainer_cls(pre_model=pre_model, tokenizer=tokenizer, **trainer_kwargs)
    else:
        trainer = Trainer(**trainer_kwargs)

    trainer.train()

    if args.finetune_method == 'lora':
        merged_model = model.merge_and_unload()
        # Only the global main process writes, so ranks do not race on the
        # same output files.
        if trainer.is_world_process_zero():
            final_dir = os.path.join(args.output_dir, "final_merged_model")
            merged_model.save_pretrained(final_dir)
            tokenizer.save_pretrained(final_dir)
    else:
        trainer.save_model(os.path.join(args.output_dir, "final_model"))
    
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()