import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import argparse
import deepspeed
import os
import json
import copy
from transformers import AutoTokenizer, AutoModelForCausalLM, EarlyStoppingCallback
from trl import SFTTrainer, SFTConfig
from trl.trainer.sft_trainer import dft_loss

from torch.utils.data import Subset

def build_prompt(example):
    """Turn one PubMedQA record into a chat-formatted training sample.

    Args:
        example: Mapping that provides the ``question``, ``context``,
            ``long_answer`` and ``final_decision`` entries.

    Returns:
        A dict with a single ``messages`` key holding the system, user and
        assistant turns (in that order), each as ``{"role", "content"}``.

    Raises:
        KeyError: If any of the four expected entries is missing.
    """
    # Read every field up front so a malformed record fails before formatting.
    question, context, long_answer, final_decision = (
        example[field]
        for field in ('question', 'context', 'long_answer', 'final_decision')
    )

    # Fixed instruction block; the last line carries no trailing newline.
    system_prompt = "\n".join([
        "You are a helpful AI assistant designed to answer biomedical research questions based on provided abstracts from scientific papers.",
        "Your task is to carefully read the 'CONTEXT' (which is an abstract from a PubMed article) and then answer the 'QUESTION' regarding that context.",
        "Instead of jumping to conclusions, you must carefully think through the evidence step by step before giving the final answer.",
        "And the final answer must be 'yes' or 'no' or 'maybe'",
    ])
    user_prompt = f"## CONTEXT (PubMed Abstract):\n{context}\n\n## QUESTION:\n{question}"
    # Target completion: the free-text rationale followed by the short verdict.
    assistant_reply = f"{long_answer}\n\nAnswer: {final_decision}\n"

    turns = (
        ("system", system_prompt),
        ("user", user_prompt),
        ("assistant", assistant_reply),
    )
    return {'messages': [{"role": role, "content": text} for role, text in turns]}
    
def compute_custom_losses(model, pre_model):
    """Compute two weight-space penalty terms over attention/MLP projections.

    Only parameters whose name contains one of the projection-weight
    keywords below are considered; parameters with zero elements are skipped.

    Both terms are accumulated in float32. Squaring half-precision weights
    (or their differences) directly can overflow or lose precision, so every
    tensor is upcast before the arithmetic.

    Args:
        model: The model being trained; its ``named_parameters()`` are read.
        pre_model: Reference model; its ``state_dict()`` supplies the
            reference weights. A leading/embedded ``"module."`` in the
            trained model's parameter names is removed before the lookup.

    Returns:
        A tuple ``(loss1, loss2)`` of 0-dim tensors on the device of the
        model's first parameter:

        * ``loss1`` - sum of squared entries of the selected weights.
        * ``loss2`` - ``sqrt(sum of squared differences + 1e-8)`` between the
          selected weights and the matching reference weights. Parameters
          with no matching entry in ``pre_model`` contribute nothing.
    """
    target_keywords = (
        "self_attn.q_proj.weight", "self_attn.k_proj.weight",
        "self_attn.v_proj.weight", "self_attn.o_proj.weight",
        "mlp.gate_proj.weight", "mlp.up_proj.weight", "mlp.down_proj.weight",
    )

    device = next(model.parameters()).device
    loss1 = torch.tensor(0.0, device=device, requires_grad=True)
    loss2 = torch.tensor(0.0, device=device, requires_grad=True)

    pre_state_dict = pre_model.state_dict()

    for name, param in model.named_parameters():
        if not any(keyword in name for keyword in target_keywords):
            continue
        # Empty placeholders carry no data (presumably parameters partitioned
        # away by DeepSpeed ZeRO-3 -- TODO confirm).
        if param.numel() == 0:
            continue

        # Upcast once and reuse for both terms.
        current = param.float()
        loss1 = loss1 + torch.sum(current ** 2)

        pre_param_name = name.replace("module.", "")
        reference = pre_state_dict.get(pre_param_name)
        if reference is None:
            continue
        reference = reference.to(device=device, dtype=torch.float32)
        loss2 = loss2 + torch.sum((current - reference) ** 2)

    # The epsilon keeps the sqrt gradient finite when the distance is zero.
    return loss1, torch.sqrt(loss2 + 1e-8)

class CustomTrainer(SFTTrainer):
    """``SFTTrainer`` whose loss is augmented with two weight-based terms.

    The extra terms come from ``compute_custom_losses(model, pre_model)``:
    the first is added and the second is subtracted, each scaled by ``1e-6``.
    """

    def __init__(self, pre_model, *args, **kwargs):
        """Initialise the parent trainer, then remember the reference model.

        Args:
            pre_model: Reference model handed to ``compute_custom_losses``.
            *args: Forwarded unchanged to ``SFTTrainer.__init__``.
            **kwargs: Forwarded unchanged to ``SFTTrainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.pre_model = pre_model

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Return the parent's loss plus the scaled custom terms.

        Args:
            model: Model passed through to the parent implementation.
            inputs: Batch passed through to the parent implementation.
            return_outputs: When true, the parent's outputs are returned too.
            num_items_in_batch: Passed through to the parent implementation.

        Returns:
            ``(loss, outputs)`` when ``return_outputs`` is true, else ``loss``.
        """
        # The weight terms are evaluated before the parent's loss computation.
        weight_term, drift_term = compute_custom_losses(model, self.pre_model)

        parent_result = super().compute_loss(
            model,
            inputs,
            return_outputs=return_outputs,
            num_items_in_batch=num_items_in_batch,
        )
        if return_outputs:
            base_loss, outputs = parent_result
        else:
            base_loss = parent_result

        total_loss = base_loss + 1e-6*weight_term - 1e-6*drift_term
        if return_outputs:
            return (total_loss, outputs)
        return total_loss

    
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Options cover the model, dataset path, optimisation hyper-parameters,
    logging, fine-tuning method (with freeze/LoRA settings), DeepSpeed
    settings and the TSQP switch. DeepSpeed's own config arguments are
    registered through ``deepspeed.add_config_arguments``.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description='DeepSpeed ZeRO')

    def option(flag, kind, default, help_text):
        # Shorthand for the common "typed option with a default" pattern.
        parser.add_argument(flag, type=kind, default=default, help=help_text)

    # Model to fine-tune and sequence-length limits.
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        required=True,
        help='model name or path, you can also pass the path of model you want to finetune',
    )
    option('--src_len', int, 512, 'max source sentence length')
    option('--tgt_len', int, 128, 'max target sentence length')

    # Dataset location.
    option('--data_path', str, 'data/pubmedqa/split', 'Path to the training dataset.')

    # General optimisation settings.
    option('--train_micro_batch_size_per_gpu', int, 8, 'batch size per gpu')
    option('--gradient_accumulation_steps', int, 32, 'gradient accumulation steps')
    option('--max_lr', float, 1e-3, 'max learning rate')
    option('--initial_lr', float, 1e-6, 'initial learning rate')
    option('--min_lr', float, 1e-8, 'min learning rate')
    option('--weight_decay', float, 0.01, 'weight decay')
    option('--adam_beta1', float, 0.9, 'adam beta1')
    option('--adam_beta2', float, 0.999, 'adam beta2')
    parser.add_argument(
        '--fused',
        action='store_true',
        help='whether to use fused optimizer, if you can load all prameters of the model on a single gpu.',
    )
    option('--epochs', int, 100, 'number of epochs')
    option('--output_dir', str, './checkpoints/', 'save dir')

    # TensorBoard log directory.
    option('--log_dir', str, './logs', 'log dir')

    # Fine-tuning method selection.
    option('--finetune_method', str, 'lora', 'finetune method, support parameters lora, freeze, full-tuning')

    # Modules to freeze when the "freeze" method is selected.
    option('--freeze_modules', str, 'dense_h_to_4h', 'the layer of model you wanna freeze')

    # LoRA hyper-parameters.
    option('--lora_alpha', int, 32, 'alpha for LoRA')
    option('--lora_dropout', float, 0.05, 'dropout probability for LoRA')
    option('--lora_target_modules', str, 'query_key_value', 'target modules for LoRA, the name of layer in model you wanna use LoRA')
    option('--lora_r', int, 8, 'r for LoRA')

    # DeepSpeed settings.
    option('--ds_config_path', str, './config/ds_config.json', 'path to deepspeed config file')
    option('--offload_device', str, 'cpu', 'offload device, cpu or nvme, which mean you want to offload the model to cpu memory or nvme ssd')
    option('--nvme_path', str, './mnt/nvme', 'path to nvme ssd')
    option('--local_rank', int, -1, 'local rank passed from distributed launcher')
    option('--global_rank', int, -1, 'global rank passed from distributed launcher')
    option('--stage', int, 2, 'the stage of zero_optimization')
    parser = deepspeed.add_config_arguments(parser)

    # TSQP switch, passed as the string 'true' / 'false'. Registered directly
    # on the parser returned by DeepSpeed rather than through ``option``.
    parser.add_argument('--tsqp', type=str, default='false', help='Whether to use TSQP')

    return parser.parse_args()

def _build_model(args):
    """Load the base causal LM and apply the selected fine-tuning method.

    Args:
        args: Parsed CLI namespace; reads ``finetune_method``,
            ``model_name_or_path`` and the LoRA / freeze options.

    Returns:
        The model, wrapped by PEFT for ``lora``, with matching parameters
        frozen for ``freeze``, or untouched for ``full-tuning``.

    Raises:
        ValueError: If ``args.finetune_method`` is not a supported method.
    """
    method = args.finetune_method
    # Validate before loading the checkpoint so a bad flag fails immediately.
    if method not in ("full-tuning", "lora", "freeze"):
        raise ValueError("Invalid finetune method")

    model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path, trust_remote_code=True)

    if method == "lora":
        lora_target_modules = args.lora_target_modules.split(",")
        print(lora_target_modules)
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=lora_target_modules,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
    elif method == "freeze":
        freeze_modules = args.freeze_modules.split(",")
        # Freeze every parameter whose name contains one of the given fragments.
        for name, param in model.named_parameters():
            if any(item in name for item in freeze_modules):
                param.requires_grad = False

    return model


def _load_ds_config(args):
    """Read the DeepSpeed JSON config and overlay the CLI settings on it.

    Args:
        args: Parsed CLI namespace; reads ``ds_config_path``, the batch /
            accumulation sizes, ``offload_device``, ``stage`` and ``nvme_path``.

    Returns:
        The DeepSpeed configuration as a dict.

    Raises:
        ValueError: If NVMe offload is requested but ``args.nvme_path`` does
            not exist.
    """
    # The context manager closes the handle even when the JSON is malformed.
    with open(args.ds_config_path) as config_file:
        ds_config = json.load(config_file)

    ds_config['train_micro_batch_size_per_gpu'] = args.train_micro_batch_size_per_gpu
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps

    # Create missing sections instead of failing with a bare KeyError when
    # the config file omits them.
    zero_config = ds_config.setdefault('zero_optimization', {})
    offload_param = zero_config.setdefault('offload_param', {})
    offload_optimizer = zero_config.setdefault('offload_optimizer', {})
    offload_param['device'] = args.offload_device
    offload_optimizer['device'] = args.offload_device
    zero_config['stage'] = args.stage

    if args.offload_device == "nvme":
        # The directory is deliberately not created here; the user must
        # provide it.
        if not os.path.exists(args.nvme_path):
            raise ValueError(f"nvme path does not exist, please make directory {args.nvme_path} by yourself.")
        offload_param['nvme_path'] = args.nvme_path
        offload_optimizer['nvme_path'] = args.nvme_path

    return ds_config


def main():
    """Fine-tune a causal LM on the PubMedQA JSON splits and save the result.

    Steps: parse arguments, set up the device (and the distributed backend
    when a local rank is given), build the model and tokenizer, convert the
    ``train`` / ``validation`` splits into chat messages, assemble the
    DeepSpeed and SFT configs, train, and write the final model under
    ``args.output_dir``.
    """
    args = parse_args()

    if args.local_rank == -1:
        # No local rank was supplied: run as a single process on the default
        # CUDA device without initialising the distributed backend.
        device = torch.device("cuda")
    else:
        device = torch.device("cuda", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        deepspeed.init_distributed()
        torch.distributed.barrier()
    # local_rank == -1 is the single-process case, which must log as well.
    master_process = args.local_rank in (-1, 0)

    # Other fine-tuning methods (prefix tuning, P-tuning, ...) can be added
    # in _build_model.
    model = _build_model(args)

    if args.fused:
        model = model.to(device)

    # Print the model and its parameter count.
    if master_process:
        print(model)
        print(f"Local rank: {args.local_rank} \n Total number of parameters: {sum(p.numel() for p in model.parameters())}")

    # Tokenizer; fall back to EOS as the padding token when none is defined.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    datasets = load_dataset("json", data_dir=args.data_path)
    train_ds = datasets["train"]
    dev_ds = datasets["validation"]

    # Replace the raw columns with the chat-formatted ``messages`` column.
    train_ds = train_ds.map(build_prompt, remove_columns=train_ds.column_names)
    dev_ds = dev_ds.map(build_prompt, remove_columns=dev_ds.column_names)

    ds_config = _load_ds_config(args)
    if master_process:
        print(f"DeepSpeed config: {ds_config}")

    training_args = SFTConfig(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_micro_batch_size_per_gpu,
        per_device_eval_batch_size=args.train_micro_batch_size_per_gpu,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        max_length=args.src_len,
        learning_rate=args.max_lr,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        # Warm up only when the initial LR is below the peak LR.
        warmup_ratio=0.1 if args.initial_lr < args.max_lr else 0,
        logging_dir=args.log_dir,
        logging_steps=1,
        label_names=["labels"],
        save_strategy="epoch",
        save_total_limit=2,
        eval_strategy="epoch",
        metric_for_best_model="eval_loss",
        deepspeed=ds_config,
        # Mirror the precision mode chosen in the DeepSpeed config.
        fp16=ds_config.get("fp16", {}).get("enabled", False),
        bf16=ds_config.get("bf16", {}).get("enabled", False),
    )

    trainer_kwargs = dict(
        model=model,
        processing_class=tokenizer,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=dev_ds,
    )
    if args.tsqp == "true":
        # Snapshot of the starting weights, used as the reference model for
        # the custom loss terms.
        pre_model = copy.deepcopy(model)
        pre_model = pre_model.to(device)
        trainer = CustomTrainer(pre_model=pre_model, **trainer_kwargs)
    else:
        trainer = SFTTrainer(**trainer_kwargs)

    trainer.train()

    if args.finetune_method == 'lora':
        # Only one process merges and writes; otherwise every rank would
        # write the same files into the same directory at once.
        if trainer.is_world_process_zero():
            final_dir = os.path.join(args.output_dir, "final_merged_model")
            merged_model = model.merge_and_unload()
            merged_model.save_pretrained(final_dir)
            tokenizer.save_pretrained(final_dir)
    else:
        trainer.save_model(os.path.join(args.output_dir, "final_model"))
    
# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()