import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import argparse
import deepspeed
import os
import json
import copy
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, Trainer, TrainingArguments,EarlyStoppingCallback
from transformers.modeling_outputs import SequenceClassifierOutput
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score
from torch.utils.data import Subset

def loss1(model):
    """Return the sum of squared entries of the attention/MLP projection weights.

    A parameter is counted when its name contains one of the
    ``self_attn.{q,k,v,o}_proj.weight`` or ``mlp.{gate,up,down}_proj.weight``
    fragments. Each matching tensor is cast with ``.float()`` before squaring.

    Args:
        model: Object exposing ``named_parameters()``.

    Returns:
        A scalar tensor, or the plain integer ``0`` when no parameter matches.
    """
    projection_keys = (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    )
    total = 0
    for param_name, weight in model.named_parameters():
        if not any(key in param_name for key in projection_keys):
            continue
        total += weight.float().pow(2).sum()
    return total

def loss2(model, pre_model):
    """Return the L2 distance between projection weights of two models.

    For every parameter of ``model`` whose name contains one of the
    ``self_attn.{q,k,v,o}_proj.weight`` or ``mlp.{gate,up,down}_proj.weight``
    fragments, the squared difference to the same-named entry of
    ``pre_model.state_dict()`` is accumulated. The result is
    ``sqrt(sum + 1e-8)``.

    Args:
        model: Model being trained. Any ``"module."`` fragment in its
            parameter names is removed before the reference lookup.
        pre_model: Reference model exposing ``state_dict()``.

    Returns:
        A scalar tensor. When no parameter matches it equals ``sqrt(1e-8)``.

    Raises:
        KeyError: If a matched parameter has no entry in the reference
            state dict.
    """
    projection_keys = (
        "self_attn.q_proj.weight",
        "self_attn.k_proj.weight",
        "self_attn.v_proj.weight",
        "self_attn.o_proj.weight",
        "mlp.gate_proj.weight",
        "mlp.up_proj.weight",
        "mlp.down_proj.weight",
    )
    # Build the reference mapping once instead of once per matched parameter.
    reference = pre_model.state_dict()
    loss = 0
    for name, param in model.named_parameters():
        if not any(key in name for key in projection_keys):
            continue
        pre_data = reference[name.replace("module.", "")]
        # No-op when both tensors already share a device.
        pre_data = pre_data.to(param.device)
        loss += torch.sum((param - pre_data) ** 2)
    # as_tensor keeps an existing tensor (and its autograd history) untouched
    # and wraps the bare float produced when nothing matched, which
    # torch.sqrt would otherwise reject.
    return torch.sqrt(torch.as_tensor(loss + 1e-8))

def loss1_lora(model, lora_modules):
    """Return the sum of squared entries of the merged LoRA-targeted parameters.

    The model (unwrapped through ``.module`` when that attribute exists) is
    deep-copied, the copy is merged with ``merge_and_unload()``, and every
    parameter of the merged copy whose name contains at least one entry of
    ``lora_modules`` contributes ``sum(param.float() ** 2)``.

    Each parameter is counted once even when several entries of
    ``lora_modules`` occur in its name (e.g. ``dense`` and ``dense_h_to_4h``).

    NOTE(review): the summed tensors belong to a deep copy of the model, so
    this term probably does not propagate gradients to the live adapter
    parameters -- confirm that this is intended.

    Args:
        model: Model exposing ``merge_and_unload()``, optionally wrapped in an
            object that exposes it as ``.module``.
        lora_modules: Iterable of name fragments selecting the parameters.

    Returns:
        A scalar tensor, or the float ``0.0`` when no parameter matches.
    """
    loss = 0.0
    model = model.module if hasattr(model, 'module') else model
    # Merge on a copy so the model being trained keeps its adapter layers.
    merged_model = copy.deepcopy(model).merge_and_unload()
    for name, param in merged_model.named_parameters():
        if any(lora_module in name for lora_module in lora_modules):
            loss += torch.sum(param.float() ** 2)
    return loss

def loss2_lora(model, lora_modules):
    """Return the L2 distance between merged and base weights of LoRA targets.

    The model (unwrapped through ``.module`` when that attribute exists) is
    deep-copied and the copy is merged with ``merge_and_unload()``. For each
    merged ``*.weight`` parameter whose name contains an entry of
    ``lora_modules``, the squared difference to the un-merged model's
    ``base_model.model.<name>.base_layer.weight`` state-dict entry is
    accumulated. The result is ``sqrt(sum + 1e-8)``.

    Only names ending in ``.weight`` are considered because the key mapping
    below is defined for weights only; other parameters of a targeted layer
    (such as biases) have no key of that form. Each weight is counted once
    even when several entries of ``lora_modules`` occur in its name.

    NOTE(review): both operands come from a deep copy or from a state dict,
    so this term probably does not propagate gradients to the live adapter
    parameters -- confirm that this is intended.

    Args:
        model: Model exposing ``merge_and_unload()`` and ``state_dict()``,
            optionally wrapped in an object that exposes it as ``.module``.
        lora_modules: Iterable of name fragments selecting the weights.

    Returns:
        A scalar tensor. When no weight matches it equals ``sqrt(1e-8)``.

    Raises:
        KeyError: If a matched weight has no ``base_layer`` entry in the
            un-merged model's state dict.
    """
    loss = 0.0
    model = model.module if hasattr(model, 'module') else model
    # Merge on a copy so the model being trained keeps its adapter layers.
    merged_model = copy.deepcopy(model).merge_and_unload()
    # Build the base-weight mapping once instead of once per matched weight.
    base_state = model.state_dict()
    for name, param in merged_model.named_parameters():
        if not name.endswith(".weight"):
            continue
        if not any(lora_module in name for lora_module in lora_modules):
            continue
        base_key = "base_model.model." + name.replace("module.", "").replace(".weight", ".base_layer.weight")
        pre_data = base_state[base_key]
        loss += torch.sum((param - pre_data) ** 2)
    # as_tensor keeps an existing tensor untouched and wraps the bare float
    # produced when nothing matched, which torch.sqrt would otherwise reject.
    return torch.sqrt(torch.as_tensor(loss + 1e-8))

class CustomTrainer(Trainer):
    """Trainer whose loss is BCE-with-logits plus two weight-based terms.

    The training loss is
    ``BCEWithLogits(logits, labels) + 1e-6 * norm_term - 1e-4 * drift_term``,
    where the two terms come from ``loss1``/``loss2`` or, when ``is_lora`` is
    set, from ``loss1_lora``/``loss2_lora``.
    """

    def __init__(self, pre_model, is_lora, lora_modules, *args, **kwargs):
        """Store the regularisation settings and forward the rest to ``Trainer``.

        Args:
            pre_model: Reference model handed to ``loss2`` (non-LoRA path).
            is_lora: Selects the ``*_lora`` variants of the weight terms.
            lora_modules: Name fragments handed to the ``*_lora`` variants.
            *args: Positional arguments for ``Trainer.__init__``.
            **kwargs: Keyword arguments for ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.loss_fn = nn.BCEWithLogitsLoss()
        self.lora_modules = lora_modules
        self.is_lora = is_lora
        self.pre_model = pre_model

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined loss for one batch.

        Args:
            model: Model to run; called as ``model(**inputs)``.
            inputs: Batch mapping; must contain ``"labels"``.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted but not used here.

        Returns:
            The loss tensor, or a ``(loss, outputs)`` tuple.
        """
        outputs = model(**inputs)
        logits, labels = outputs.logits, inputs["labels"]

        if self.is_lora:
            norm_term = 1e-6 * loss1_lora(model, self.lora_modules)
            drift_term = 1e-4 * loss2_lora(model, self.lora_modules)
        else:
            norm_term = 1e-6 * loss1(model)
            drift_term = 1e-4 * loss2(model, self.pre_model)

        # The norm term is added and the drift term is subtracted.
        loss = self.loss_fn(logits, labels) + norm_term - drift_term
        if return_outputs:
            return (loss, outputs)
        return loss

class MultilabelDataCollator(DataCollatorWithPadding):
    """Padding collator that delivers ``labels`` as a float32 tensor."""

    def __call__(self, features):
        """Collate ``features`` with the parent class, then cast the labels.

        A ``labels`` entry that is a Python list is turned into a float32
        tensor; any other value is converted with ``.to(dtype=torch.float32)``.
        A batch without ``labels`` is returned as produced by the parent.
        """
        batch = super().__call__(features)
        if 'labels' not in batch:
            return batch

        raw_labels = batch['labels']
        batch['labels'] = (
            torch.tensor(raw_labels, dtype=torch.float32)
            if isinstance(raw_labels, list)
            else raw_labels.to(dtype=torch.float32)
        )
        return batch
    
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    The options are registered in a fixed order: model, dataset, optimisation,
    logging, fine-tuning method, freeze/LoRA settings and DeepSpeed settings.
    ``deepspeed.add_config_arguments`` is then applied, and ``--tsqp`` is
    registered on the parser it returns.

    Returns:
        The namespace produced by ``parser.parse_args()``.
    """
    parser = argparse.ArgumentParser(description='DeepSpeed ZeRO')

    option_table = [
        # which model to fine-tune
        ('--model_name_or_path', dict(type=str, required=True, help='model name or path, you can also pass the path of model you want to finetune')),
        ('--src_len', dict(type=int, default=512, help='max source sentence length')),
        ('--tgt_len', dict(type=int, default=128, help='max target sentence length')),
        # dataset params
        ('--data_path', dict(type=str, required=True, help='Path to the training dataset.')),
        # typical params
        ('--train_micro_batch_size_per_gpu', dict(type=int, default=8, help='batch size per gpu')),
        ('--gradient_accumulation_steps', dict(type=int, default=32, help='gradient accumulation steps')),
        ('--max_lr', dict(type=float, default=1e-3, help='max learning rate')),
        ('--initial_lr', dict(type=float, default=1e-6, help='initial learning rate')),
        ('--min_lr', dict(type=float, default=1e-8, help='min learning rate')),
        ('--weight_decay', dict(type=float, default=0.01, help='weight decay')),
        ('--adam_beta1', dict(type=float, default=0.9, help='adam beta1')),
        ('--adam_beta2', dict(type=float, default=0.999, help='adam beta2')),
        ('--fused', dict(action='store_true', help='whether to use fused optimizer, if you can load all prameters of the model on a single gpu.')),
        ('--epochs', dict(type=int, default=100, help='number of epochs')),
        ('--output_dir', dict(type=str, default='./checkpoints/', help='save dir')),
        # log dir of tensorboard
        ('--log_dir', dict(type=str, default='./logs', help='log dir')),
        # which fine-tuning method to use
        ('--finetune_method', dict(type=str, default='lora', help='finetune method, support parameters lora, freeze, full-tuning')),
        # freeze modules
        ('--freeze_modules', dict(type=str, default='dense_h_to_4h', help='the layer of model you wanna freeze')),
        # LoRA params
        ('--lora_alpha', dict(type=int, default=32, help='alpha for LoRA')),
        ('--lora_dropout', dict(type=float, default=0.05, help='dropout probability for LoRA')),
        ('--lora_target_modules', dict(type=str, default='query_key_value', help='target modules for LoRA, the name of layer in model you wanna use LoRA')),
        ('--lora_r', dict(type=int, default=8, help='r for LoRA')),
        # deepspeed params
        ('--ds_config_path', dict(type=str, default='./config/ds_config.json', help='path to deepspeed config file')),
        ('--offload_device', dict(type=str, default='cpu', help='offload device, cpu or nvme, which mean you want to offload the model to cpu memory or nvme ssd')),
        ('--nvme_path', dict(type=str, default='./mnt/nvme', help='path to nvme ssd')),
        ('--local_rank', dict(type=int, default=-1, help='local rank passed from distributed launcher')),
        ('--global_rank', dict(type=int, default=-1, help='global rank passed from distributed launcher')),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    # Let DeepSpeed register its own options, then continue with its parser.
    parser = deepspeed.add_config_arguments(parser)

    # tsqp params
    parser.add_argument('--tsqp', default='false', type=str, help='Whether to use TSQP')

    return parser.parse_args()

def main():
    """Fine-tune a 28-label multi-label sequence classifier with the HF Trainer.

    Steps, in order:
      1. Parse the CLI arguments and, when ``--local_rank`` is set, initialise
         the DeepSpeed distributed backend.
      2. Load the model and apply the selected method (``full-tuning``,
         ``lora`` or ``freeze``).
      3. Tokenise the ``train`` and ``validation`` splits and turn the label
         index lists into multi-hot float vectors.
      4. Load the DeepSpeed JSON config and override batch size, gradient
         accumulation and offload settings from the CLI.
      5. Train with ``CustomTrainer`` when ``--tsqp true`` is given, otherwise
         with the stock ``Trainer``; both use early stopping on ``macro_f1``.
      6. Save the merged LoRA model plus tokenizer, or the trained model.

    Raises:
        ValueError: For an unknown ``--finetune_method``, or when
            ``--offload_device nvme`` is used and ``--nvme_path`` is missing.
    """
    args = parse_args()
    num_labels = 28

    # init deepspeed
    if args.local_rank == -1:
        # No launcher-provided rank: plain single-process run on the default
        # CUDA device, without initialising the distributed backend.
        device = torch.device("cuda")
    else:
        device = torch.device("cuda", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        deepspeed.init_distributed()
        torch.distributed.barrier()
    # A single-process run (local_rank == -1) must log as well, otherwise the
    # model summary and DeepSpeed config are never printed in that mode.
    master_process = args.local_rank in (-1, 0)

    # get model, and you can add other finetuning methods here, eg: prefix tuning, P-tuning, etc.
    if args.finetune_method not in ("full-tuning", "lora", "freeze"):
        # Reject an unknown method before the checkpoint is loaded.
        raise ValueError("Invalid finetune method")
    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name_or_path,
        num_labels=num_labels,
        problem_type="multi_label_classification",
    )
    lora_target_modules = None
    if args.finetune_method == "lora":
        lora_target_modules = args.lora_target_modules.split(",")
        print(lora_target_modules)
        lora_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            target_modules=lora_target_modules,
            inference_mode=False,
        )
        model = get_peft_model(model, lora_config)
    elif args.finetune_method == "freeze":
        freeze_modules = args.freeze_modules.split(",")
        for name, param in model.named_parameters():
            if any(item in name for item in freeze_modules):
                param.requires_grad = False

    if args.fused:
        model = model.to(device)

    # print the number of the model parameters
    if master_process:
        print(model)
        print(f"Local rank: {args.local_rank} \n Total number of parameters: {sum(p.numel() for p in model.parameters())}")

    # get the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to the EOS token when the tokenizer defines no pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id

    def tokenize_function(examples):
        """Tokenise a batch of texts and build its multi-hot label matrix."""
        tokenized = tokenizer(
            examples['text'],
            truncation=True,
            padding=True,
            max_length=args.src_len,
            return_tensors="pt"
        )

        # One row per example; the listed label indices are set to 1.0.
        labels = torch.zeros((len(examples['labels']), num_labels), dtype=torch.float32)
        for i, label_indices in enumerate(examples['labels']):
            labels[i, label_indices] = 1.0

        tokenized['labels'] = labels

        return tokenized

    dataset = load_dataset(args.data_path)
    train_dataset = dataset['train']
    eval_dataset = dataset["validation"]
    # preprocess dataset
    tokenized_train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=train_dataset.column_names)
    tokenized_eval_dataset = eval_dataset.map(tokenize_function, batched=True, remove_columns=eval_dataset.column_names)
    # collator that pads every batch to src_len and casts labels to float32
    data_collator = MultilabelDataCollator(
        tokenizer=tokenizer,
        padding='max_length',
        max_length=args.src_len,
        return_tensors="pt"
    )

    # load the deepspeed config json file (closing the handle afterwards)
    with open(args.ds_config_path, encoding="utf-8") as config_file:
        ds_config = json.load(config_file)
    ds_config['train_micro_batch_size_per_gpu'] = args.train_micro_batch_size_per_gpu
    ds_config['gradient_accumulation_steps'] = args.gradient_accumulation_steps
    ds_config['zero_optimization']['offload_param']['device'] = args.offload_device
    ds_config['zero_optimization']['offload_optimizer']['device'] = args.offload_device
    if args.offload_device == "nvme":
        # The directory is deliberately not created here; it must already exist.
        if not os.path.exists(args.nvme_path):
            raise ValueError(f"nvme path does not exist, please make directory {args.nvme_path} by yourself.")
        ds_config['zero_optimization']['offload_param']['nvme_path'] = args.nvme_path
        ds_config['zero_optimization']['offload_optimizer']['nvme_path'] = args.nvme_path

    if master_process:
        print(f"DeepSpeed config: {ds_config}")

    # train
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.train_micro_batch_size_per_gpu,
        per_device_eval_batch_size=args.train_micro_batch_size_per_gpu,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.max_lr,
        weight_decay=args.weight_decay,
        adam_beta1=args.adam_beta1,
        adam_beta2=args.adam_beta2,
        warmup_ratio=0.1 if args.initial_lr < args.max_lr else 0,
        logging_dir=args.log_dir,
        logging_steps=10,
        label_names=["labels"],
        save_strategy="steps",
        save_steps=50,
        eval_strategy="steps",
        eval_steps=50,
        deepspeed=ds_config,
        # Mixed precision follows whatever the DeepSpeed config enables.
        fp16=ds_config.get("fp16", {}).get("enabled", False),
        bf16=ds_config.get("bf16", {}).get("enabled", False),
        metric_for_best_model="macro_f1",
        load_best_model_at_end=True,
        greater_is_better=True,
        save_total_limit=4
    )

    def compute_metrics(p):
        """Compute the multi-label classification evaluation metrics."""
        try:
            preds = p.predictions
            if isinstance(preds, tuple):
                preds = preds[0]

            # Sigmoid the predictions and threshold at 0.5 per label.
            probs = torch.sigmoid(torch.tensor(preds)).numpy()
            y_pred = (probs >= 0.5).astype(np.int32)

            y_true = p.label_ids.astype(np.int32)

            micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
            macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
            samples_f1 = f1_score(y_true, y_pred, average='samples', zero_division=0)

            precision = precision_score(y_true, y_pred, average='samples', zero_division=0)
            recall = recall_score(y_true, y_pred, average='samples', zero_division=0)
            accuracy = accuracy_score(y_true, y_pred)

            return {
                'micro_f1': micro_f1,
                'macro_f1': macro_f1,
                'samples_f1': samples_f1,
                'precision': precision,
                'recall': recall,
                'accuracy': accuracy,
            }
        except Exception as e:
            # Best effort: report the failure and fall back to all-zero metrics.
            print(f"计算评估指标时出错: {e}")
            return {
                'micro_f1': 0.0,
                'macro_f1': 0.0,
                'samples_f1': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'accuracy': 0.0,
            }

    callbacks = [
        EarlyStoppingCallback(
            early_stopping_patience=3,
            early_stopping_threshold=0
        )
    ]

    if args.tsqp == "true":
        is_lora = args.finetune_method == 'lora'
        if is_lora:
            # The LoRA loss terms read the base weights from the model itself.
            pre_model = None
        else:
            # Copy of the starting weights used as reference by loss2.
            pre_model = copy.deepcopy(model)
            pre_model = pre_model.to(device)
        trainer = CustomTrainer(
            pre_model=pre_model,
            is_lora=is_lora,
            lora_modules=lora_target_modules if is_lora else None,
            model=model,
            args=training_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
        )
    else:
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_train_dataset,
            eval_dataset=tokenized_eval_dataset,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
        )

    trainer.train()

    if args.finetune_method == 'lora':
        # Every rank reaches this point; only the main process may write,
        # otherwise all ranks write the same files concurrently.
        if trainer.is_world_process_zero():
            final_dir = os.path.join(args.output_dir, "final_merged_model")
            merged_model = model.merge_and_unload()
            merged_model.save_pretrained(final_dir)
            tokenizer.save_pretrained(final_dir)
    else:
        trainer.save_model(os.path.join(args.output_dir, "final_model"))
    
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()