import sys
from pathlib import Path

# Put the directory two levels above this file at the front of sys.path.
# NOTE(review): presumably this is what lets the `training` package imported
# further down resolve when the script is run directly — confirm.
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_from_disk
import argparse
import deepspeed
import os
import json
import copy
from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback
from trl import SFTTrainer, SFTConfig

from torch.utils.data import Subset
from training.utils.create_finqa_dataset import load_finqa_dataset


# Name fragments selecting the attention and MLP projection weights that
# contribute to the penalty terms.
_REGULARIZED_WEIGHT_KEYS = (
    "self_attn.q_proj.weight",
    "self_attn.k_proj.weight",
    "self_attn.v_proj.weight",
    "self_attn.o_proj.weight",
    "mlp.gate_proj.weight",
    "mlp.up_proj.weight",
    "mlp.down_proj.weight",
)


def compute_custom_losses(model, pre_model):
    """Compute two weight-space penalty terms over selected projection weights.

    Only parameters of ``model`` whose name contains one of
    ``_REGULARIZED_WEIGHT_KEYS`` are considered. Parameters reporting zero
    elements are skipped.

    All arithmetic is done in float32, regardless of the parameter dtype, so
    that squaring half-precision weights cannot overflow and both terms are
    accumulated with the same precision.

    Args:
        model: Module being trained; its parameters receive gradients
            through the returned tensors.
        pre_model: Reference module. Its ``state_dict()`` supplies the
            reference weights; a parameter is looked up under its own name
            with every ``"module."`` fragment removed.

    Returns:
        A ``(loss1, loss2)`` tuple of scalar float32 tensors on the device of
        the first parameter of ``model``:

        * ``loss1`` - sum of squared entries of the selected weights.
        * ``loss2`` - ``sqrt(sum of squared differences + 1e-8)`` between the
          selected weights and their counterparts found in ``pre_model``.
    """
    device = next(model.parameters()).device
    # requires_grad=True keeps both results attached to autograd even when no
    # parameter ends up contributing to them.
    loss1 = torch.tensor(0.0, device=device, requires_grad=True)
    loss2 = torch.tensor(0.0, device=device, requires_grad=True)

    pre_state_dict = pre_model.state_dict()

    for name, param in model.named_parameters():
        if not any(key in name for key in _REGULARIZED_WEIGHT_KEYS):
            continue
        # NOTE(review): empty parameters are presumably placeholders left by
        # parameter partitioning (e.g. ZeRO stage 3) — confirm.
        if param.numel() == 0:
            continue

        weights = param.float()
        loss1 = loss1 + torch.sum(weights ** 2)

        reference = pre_state_dict.get(name.replace("module.", ""))
        if reference is not None:
            # Upcast the reference as well: the difference of two fp16/bf16
            # tensors squared can overflow or silently lose precision.
            reference = reference.to(device=device, dtype=torch.float32)
            loss2 = loss2 + torch.sum((weights - reference) ** 2)

    # The epsilon keeps the gradient of sqrt finite when the difference is 0.
    return loss1, torch.sqrt(loss2 + 1e-8)

class CustomTrainer(SFTTrainer):
    """SFTTrainer whose loss is adjusted by two weight-space penalty terms.

    The terms come from ``compute_custom_losses(model, pre_model)``: the first
    returned value is added to the parent trainer's loss and the second is
    subtracted, each scaled by ``1e-6``.
    """

    def __init__(self, pre_model, *args, **kwargs):
        """Build the trainer and remember the reference model.

        Args:
            pre_model: Reference model handed to ``compute_custom_losses`` on
                every loss computation.
            *args: Forwarded unchanged to ``SFTTrainer.__init__``.
            **kwargs: Forwarded unchanged to ``SFTTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.pre_model = pre_model

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the parent loss plus the scaled penalty terms.

        Args:
            model: Model whose loss is being computed.
            inputs: Batch forwarded to the parent ``compute_loss``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            num_items_in_batch: Forwarded unchanged to the parent.

        Returns:
            The adjusted loss, or ``(adjusted loss, outputs)`` when
            ``return_outputs`` is true.
        """
        penalty_add, penalty_sub = compute_custom_losses(model, self.pre_model)
        parent_result = super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )

        if not return_outputs:
            return parent_result + 1e-6 * penalty_add - 1e-6 * penalty_sub

        task_loss, outputs = parent_result
        return task_loss + 1e-6 * penalty_add - 1e-6 * penalty_sub, outputs

    
def parse_args():
    """Parse the command-line options of the fine-tuning script.

    Besides the options declared here, DeepSpeed is given the parser so it
    can register its own command-line options on it.

    ``--tsqp`` is lower-cased while parsing, so ``True``/``TRUE`` are accepted
    as equivalents of ``true`` instead of silently disabling the feature.
    The parsed value is still a string.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(description='DeepSpeed ZeRO')

    # which model to fine-tune
    parser.add_argument('--model_name_or_path', type=str, required=True, help='model name or path, you can also pass the path of model you want to finetune')
    parser.add_argument('--src_len', type=int, default=512, help='max source sentence length')
    parser.add_argument('--tgt_len', type=int, default=128, help='max target sentence length')

    # dataset params
    parser.add_argument('--data_path', type=str, required=True, help='Path to the training dataset.')

    # typical params
    parser.add_argument('--train_micro_batch_size_per_gpu', type=int, default=8, help='batch size per gpu')
    parser.add_argument('--gradient_accumulation_steps', type=int, default=32, help='gradient accumulation steps')
    parser.add_argument('--max_lr', type=float, default=1e-3, help='max learning rate')
    parser.add_argument('--initial_lr', type=float, default=1e-6, help='initial learning rate')
    parser.add_argument('--min_lr', type=float, default=1e-8, help='min learning rate')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='weight decay')
    parser.add_argument('--adam_beta1', type=float, default=0.9, help='adam beta1')
    parser.add_argument('--adam_beta2', type=float, default=0.999, help='adam beta2')
    parser.add_argument('--fused', action='store_true', help='whether to use fused optimizer, if you can load all prameters of the model on a single gpu.')
    parser.add_argument('--epochs', type=int, default=100, help='number of epochs')
    parser.add_argument('--output_dir', type=str, default='./checkpoints/',help='save dir')

    # log dir of tensorboard
    parser.add_argument('--log_dir', type=str, default='./logs', help='log dir')

    # which fine-tuning method to use
    parser.add_argument('--finetune_method', type=str, default='lora', help='finetune method, support parameters lora, freeze, full-tuning')

    # modules to freeze (comma-separated name fragments)
    parser.add_argument('--freeze_modules', type=str, default='dense_h_to_4h', help='the layer of model you wanna freeze')

    # LoRA params
    parser.add_argument('--lora_alpha', type=int, default=32, help='alpha for LoRA')
    parser.add_argument('--lora_dropout', type=float, default=0.05, help='dropout probability for LoRA')
    parser.add_argument('--lora_target_modules', type=str, default='query_key_value', help='target modules for LoRA, the name of layer in model you wanna use LoRA')
    parser.add_argument('--lora_r', type=int, default=8, help='r for LoRA')

    # deepspeed params
    parser.add_argument('--ds_config_path', type=str, default='./config/ds_config.json', help='path to deepspeed config file')
    parser.add_argument('--offload_device', type=str, default='cpu', help='offload device, cpu or nvme, which mean you want to offload the model to cpu memory or nvme ssd')
    parser.add_argument('--nvme_path', type=str, default='./mnt/nvme', help='path to nvme ssd')
    parser.add_argument('--local_rank', type=int, default=-1, help='local rank passed from distributed launcher')
    parser.add_argument('--global_rank', type=int, default=-1, help='global rank passed from distributed launcher')
    parser.add_argument('--stage', type=int, default=2, help='the stage of zero_optimization')
    parser = deepspeed.add_config_arguments(parser)

    # tsqp params
    # The caller compares this value with the literal "true", so normalize the
    # case here; the result is still a plain string.
    parser.add_argument('--tsqp', default='false', type=str.lower, help='Whether to use TSQP')

    return parser.parse_args()

def _build_model(args):
    """Load the base causal LM and apply the method in ``args.finetune_method``.

    Raises:
        ValueError: If the method is not ``full-tuning``, ``lora`` or
            ``freeze``. The check happens before any weights are loaded.
    """
    method = args.finetune_method
    if method not in ("full-tuning", "lora", "freeze"):
        raise ValueError("Invalid finetune method")

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    if method == "lora":
        lora_target_modules = args.lora_target_modules.split(",")
        print(lora_target_modules)
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=lora_target_modules,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
    elif method == "freeze":
        # Freeze every parameter whose name contains one of the fragments.
        freeze_modules = args.freeze_modules.split(",")
        for name, param in model.named_parameters():
            if any(item in name for item in freeze_modules):
                param.requires_grad = False

    return model


def _load_ds_config(args):
    """Read the DeepSpeed JSON config and overlay the command-line settings.

    Raises:
        ValueError: If NVMe offload is requested but ``args.nvme_path`` does
            not exist.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(args.ds_config_path, encoding="utf-8") as config_file:
        ds_config = json.load(config_file)

    ds_config['train_micro_batch_size_per_gpu'] = args.train_micro_batch_size_per_gpu
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
    zero_config = ds_config['zero_optimization']
    zero_config['offload_param']['device'] = args.offload_device
    zero_config['offload_optimizer']['device'] = args.offload_device
    zero_config['stage'] = args.stage

    if args.offload_device == "nvme":
        # The directory is deliberately not created here; the user must
        # provide it so an existing location is never clobbered.
        if not os.path.exists(args.nvme_path):
            raise ValueError(f"nvme path does not exist, please make directory {args.nvme_path} by yourself.")
        zero_config['offload_param']['nvme_path'] = args.nvme_path
        zero_config['offload_optimizer']['nvme_path'] = args.nvme_path

    return ds_config


def main():
    """Fine-tune a causal LM on the FinQA data and save the result.

    Steps: parse arguments, set up the device (and the distributed backend
    when a local rank was supplied), build the model, load tokenizer and
    datasets, prepare the DeepSpeed config, train with early stopping, then
    save either the merged LoRA model or the trainer's model.
    """
    args = parse_args()

    # init deepspeed
    if args.local_rank == -1:
        # No local rank was supplied (the argument's default), so run as a
        # plain single-process job on the default CUDA device.
        device = torch.device("cuda")
    else:
        device = torch.device("cuda", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        deepspeed.init_distributed()
        torch.distributed.barrier()
    # Rank 0 reports in distributed runs; in single-process mode
    # (local_rank == -1) the only process must report too, otherwise the
    # diagnostics below are never printed.
    master_process = args.local_rank in (-1, 0)

    # get model; other finetuning methods (prefix tuning, P-tuning, ...) can
    # be added in _build_model.
    model = _build_model(args)

    if args.fused:
        model = model.to(device)

    # print the number of the model parameters
    if master_process:
        print(model)
        print(f"Local rank: {args.local_rank} \n Total number of parameters: {sum(p.numel() for p in model.parameters())}")

    # get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    train_ds = load_finqa_dataset(os.path.join(args.data_path, "train.json"))
    # For a quick smoke test, subsample the training set:
    # train_ds = Subset(train_ds, torch.randperm(len(train_ds)).tolist()[:len(train_ds)//100])
    dev_ds = load_finqa_dataset(os.path.join(args.data_path, "dev.json"))

    # load the deepspeed config json file
    ds_config = _load_ds_config(args)

    if master_process:
        print(f"DeepSpeed config: {ds_config}")

    # train
    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_micro_batch_size_per_gpu,
        per_device_eval_batch_size=args.train_micro_batch_size_per_gpu,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_length=args.src_len,
        learning_rate=args.max_lr,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        warmup_ratio=0.1 if args.initial_lr < args.max_lr else 0,
        logging_dir=args.log_dir,
        logging_steps=10,
        label_names=["labels"],
        save_strategy="steps",
        save_steps=30,
        eval_strategy="steps",
        eval_steps=30,
        metric_for_best_model="eval_loss",
        load_best_model_at_end=True,
        save_total_limit=2,
        deepspeed=ds_config,
        # Mirror the precision mode enabled in the DeepSpeed config.
        fp16=ds_config.get("fp16", {}).get("enabled", False),
        bf16=ds_config.get("bf16", {}).get("enabled", False),
    )

    callbacks = [
        EarlyStoppingCallback(
            early_stopping_patience=2,
            early_stopping_threshold=0.001,
        )
    ]

    if args.tsqp == "true":
        # Snapshot the model before training; CustomTrainer uses it as the
        # reference for its penalty terms.
        pre_model = copy.deepcopy(model)
        pre_model = pre_model.to(device)
        trainer = CustomTrainer(
            pre_model=pre_model,
            model=model,
            processing_class=tokenizer,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=dev_ds,
            callbacks=callbacks,
        )
    else:
        trainer = SFTTrainer(
            model=model,
            processing_class=tokenizer,
            args=training_args,
            train_dataset=train_ds,
            eval_dataset=dev_ds,
            callbacks=callbacks,
        )

    trainer.train()

    if args.finetune_method == 'lora':
        # NOTE(review): every process reaching this point writes to the same
        # directory — confirm whether this should be limited to one process.
        merged_model = model.merge_and_unload()
        merged_model.save_pretrained(os.path.join(args.output_dir, "final_merged_model"))
        tokenizer.save_pretrained(os.path.join(args.output_dir, "final_merged_model"))
    else:
        trainer.save_model(os.path.join(args.output_dir, "final_model"))
    
# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()