"""
多模态模型微调脚本，支持 LoRA 和 QLoRA，适用于 NPU 加速。
主要功能：加载模型与数据，配置训练参数，进行高效微调，并保存结果。
"""

import os
import sys
current_file_path = os.path.dirname(os.path.abspath(__file__))
module_path = os.path.join(current_file_path, "../")
sys.path.append(module_path)  # 添加自定义模块路径

from dataclasses import asdict
import math
from pathlib import Path
from typing import List, Optional
import yaml  # 用于 YAML 配置文件处理

# 深度学习相关库
from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training  # 参数高效微调工具
import torch
import torch_npu  # 适配 NPU
from torch_npu.contrib import transfer_to_npu  # 将张量转移到 NPU

# Transformers库组件
import transformers
from transformers import Trainer, deepspeed
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.utils.data import DataLoader

# 自定义模块
from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments  # 训练参数配置
from collators.mbeir_rerank import MbeirRerankDataCollator                             # 数据整理器
from dataset.datasets_mbeir_rerank import LazySupervisedDataset  # 惰性加载数据集
from loaders.qwen2_vl_raw import Qwen2VLModelLoader              # 模型加载器
from supported_models import MODULE_KEYWORDS                     # 模型结构关键字
from utils import (
    rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
    get_peft_state_maybe_zero_3                                  # 工具函数
)
# deepseek 自定义一个 Trainer 用来实现多机采样
class CustomTrainer(Trainer):
    def __init__(self, world_size, global_rank, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.world_size = world_size
        self.global_rank = global_rank

    def get_train_dataloader(self):
        sampler = DistributedSampler(
            self.train_dataset,
            num_replicas=self.world_size,
            rank=self.global_rank,
            shuffle=True
        )
        return DataLoader(
            self.train_dataset,
            batch_size=self.args.train_batch_size,
            sampler=sampler,
            collate_fn=self.data_collator,
            drop_last=self.args.dataloader_drop_last,
        )

# # chatgpt 的 定义方法
# class CustomTrainer(Trainer):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.world_size = self.args.world_size
#         self.global_rank = self.args.process_index

#     def get_train_dataloader(self):
#         if self.train_dataset is None:
#             raise ValueError("Trainer: training requires a train_dataset.")

#         # 根据是否分布式训练选择采样器
#         if self.args.local_rank != -1:
#             sampler = DistributedSampler(
#                 self.train_dataset,
#                 num_replicas=self.world_size,
#                 rank=self.global_rank,
#                 shuffle=True
#             )
#         else:
#             sampler = RandomSampler(self.train_dataset)

#         return DataLoader(
#             self.train_dataset,
#             batch_size=self.args.per_device_train_batch_size,
#             sampler=sampler,
#             collate_fn=self.data_collator,
#             drop_last=self.args.dataloader_drop_last,
#             num_workers=self.args.dataloader_num_workers,
#             pin_memory=self.args.dataloader_pin_memory,
#         )

def train():
        
    """主训练流程，包含参数解析、模型加载、数据准备、训练配置等步骤"""
    # 解析命令行参数
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()
    
    # 在 train() 函数中获取全局 rank 和 world size -------------------------------
    if dist.is_initialized():
        # local_rank = training_args.local_rank
        global global_rank, world_size, local_rank
        global_rank = int(os.environ.get("RANK", 0))
        world_size = int(os.environ.get("WORLD_SIZE", 1))
        local_rank = int(os.environ.get("LOCAL_RANK", 0))

    # 参数持久化 ----------------------------------------------------------------
    output_dir = getattr(training_args, 'output_dir', None)
    assert output_dir is not None, "必须指定输出目录"
    args_dir = Path(output_dir) / "arguments"
    args_dir.mkdir(parents=True, exist_ok=True)
    # 将各参数类转为字典并保存为YAML文件
    yaml.dump(asdict(model_args), open(args_dir / "model.yaml", "w"))
    yaml.dump(asdict(data_args), open(args_dir / "data.yaml", "w"))
    yaml.dump(asdict(training_args), open(args_dir / "training.yaml", "w"))
    yaml.dump(asdict(lora_args), open(args_dir / "lora.yaml", "w"))

    # 计算精度配置 --------------------------------------------------------------
    # 根据训练参数中的 fp16 和 bf16 选项确定计算时使用的数据类型
    compute_dtype = (
        torch.float16 if training_args.fp16 
        else (torch.bfloat16 if training_args.bf16 else torch.float32)
    )
    
    # 分布式训练配置检查
     # 如果使用了 deepspeed 且启用了 QLoRA
    if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
        # 将分布式训练类型设置为 DEEPSPEED
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    # 设备映射配置（ QLoRA 模式需要特殊处理）
    device_map = None
    if lora_args.q_lora:
        # 如果是多机多卡训练，根据 LOCAL_RANK 确定设备映射
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None
        # FSDP 或 ZeRO3 与 QLoRA 不兼容，若使用则抛出异常
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            raise ValueError("FSDP 或 ZeRO3 与 QLoRA 不兼容")

    # QLoRA 量化配置 ------------------------------------------------------------
    bnb_config = None
    if lora_args.use_lora and lora_args.q_lora:
        from transformers import BitsAndBytesConfig
        rank0_print("启用 LLM 量化配置...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,                     # 4bit量化加载
            bnb_4bit_compute_dtype=compute_dtype,  # 计算时数据类型
            bnb_4bit_quant_type="nf4",             # 量化类型
        )
    
    # 加载模型/分词器/处理器 ----------------------------------------------------
    rank0_print("加载模型、分词器、处理器...")
    loader = Qwen2VLModelLoader(
        model_hf_path=model_args.model_hf_path,        # HuggingFace模型路径
        model_local_path=model_args.model_local_path,  # 本地模型路径
        compute_dtype=compute_dtype,                   # 计算精度
        bnb_config=bnb_config,                         # 量化配置
        use_flash_attn=training_args.use_flash_attn,   # 是否使用 FlashAttention
        device_map=device_map,                         # 设备映射
    )
    model, tokenizer, processor = loader.load()        # 加载三件套
    # model = model.to(f"npu:{local_rank}") # 显示设置 NPU 设备
    tokenizer.model_max_length = training_args.model_max_length  # 设置分词器最大长度

    # 梯度检查点配置
    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()  # 启用模型的梯度检查点功能

    # 冻结视觉编码器参数 --------------------------------------------------------
    # 获取当前模型家族对应的视觉编码器模块关键字列表
    vision_encoder_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_encoder"]
    if not training_args.train_vision_encoder:
        rank0_print(f"冻结视觉编码器参数，包含模块:")
        for module in vision_encoder_keys:
            rank0_print(f"\t{module}")
            eval(f"model.{module}").requires_grad_(False)  # 动态执行冻结操作

    # 冻结视觉投影器参数 --------------------------------------------------------
    vision_projector_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_projector"]
    if not training_args.train_vision_projector:
        rank0_print(f"冻结视觉投影器参数，包含模块:")
        for module in vision_projector_keys:
            rank0_print(f"\t{module}")
            eval(f"model.{module}").requires_grad_(False) # 动态执行冻结操作，将指定模块的参数设置为不可训练

    # 冻结其他多模态组件 --------------------------------------------------------
    if "others" in MODULE_KEYWORDS[model_args.model_family_id]:
        rank0_print(f"冻结其他多模态组件:")
        for other_key in MODULE_KEYWORDS[model_args.model_family_id]["others"]:
            rank0_print(f"\t{other_key}")
            eval(f"model.{other_key}").requires_grad_(False)

    # LoRA配置部分 -------------------------------------------------------------
    # 获取当前模型家族对应的大语言模型模块关键字列表
    llm_keys = MODULE_KEYWORDS[model_args.model_family_id]["llm"]
    if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
        rank0_print("未启用LoRA...")        
    else:
        # 获取所有命名模块
        named_modules = {n: m for n, m in model.named_modules()}  
        lora_modules = [] # 应用 LoRA 的模块列表
        full_modules = [] # 全参数训练的模块列表

        # 视觉编码器LoRA配置
        if training_args.train_vision_encoder and lora_args.use_vision_lora:
            rank0_print("启用视觉编码器LoRA...")
            
            # 找到视觉编码器中所有需要应用 LoRA 的线性层
            lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
            
        elif training_args.train_vision_encoder:
            rank0_print("视觉编码器将全参数训练...")
            
            # 将视觉编码器模块添加到全参数训练列表
            full_modules.extend(vision_encoder_keys)
        
        # LLM 部分 LoRA 配置
        if lora_args.use_lora:
            rank0_print("启用 LLM 的 LoRA...")
            # 找到 LLM 中所有需要应用 LoRA 的线性层
            lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
        else:
            rank0_print("LLM 将全参数训练...")
            # 将 LLM 模块添加到全参数训练列表
            full_modules.extend(llm_keys)
        
        # 视觉投影器配置
        if training_args.train_vision_projector:
            rank0_print("视觉投影器将全参数训练...")
            # 将视觉投影器模块添加到全参数训练列表
            full_modules.extend(vision_projector_keys)
        
        # 创建 LoRA 配置对象
        lora_config = LoraConfig(
            r=lora_args.lora_r,               # LoRA秩
            lora_alpha=lora_args.lora_alpha,  # 缩放系数
            target_modules=lora_modules,      # 目标模块列表
            modules_to_save=full_modules,         # 全参数训练模块
            lora_dropout=lora_args.lora_dropout,  # Dropout率
            bias=lora_args.lora_bias,  # 偏置项处理方式
            task_type="CAUSAL_LM",     # 任务类型
        )

        # QLoRA 特殊处理
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, 
                use_gradient_checkpointing=training_args.gradient_checkpointing
            )
        # 应用 LoRA 配置到模型    
        model = get_peft_model(model, lora_config)
        
    # # 在模型加载后，进行同步，这个也不确定是否需要--------------
    # if dist.is_available() and dist.is_initialized():
    #     dist.barrier()
    
    # 打印可训练参数 ------------------------------------------------------------
    rank0_print("可训练参数列表:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f"\t{name}")

    # 数据加载部分 --------------------------------------------------------------
    rank0_print("加载训练数据...")
    # 使用 rank 和 world_size 在单机 8 卡上面训练导致数据量变成原来的 1/8 -- 2603（total = 20825）
    # 不知到使用 rank 和 world_szie 究竟问题出在哪里了，最有可能的是脚本已经自动把数据集拆分好了，又做了一次拆分
    # 将这个脚本恢复成 train_rerank.py 的样子 就是跑不出来结果，老是报错 ！！！
    train_dataset = LazySupervisedDataset(
        query_data_path=data_args.query_data_path,  # 查询数据路径
        cand_pool_path=data_args.cand_pool_path,    # 候选池路径
        instructions_path=data_args.instructions_path,  # 指令文件路径
        rerank_data_path=data_args.rerank_data_path,    # 重排数据路径
        image_path_prefix=data_args.image_path_prefix,  # 图像路径前缀
        tokenizer=tokenizer,            # 分词器
        # rank = global_rank,           # 适配多机多卡训练
        # world_size = world_size       # 适配多机多卡训练
    )
    print("数据加载完成————————————————————————————————————————————————")
    # 数据整理器配置
    data_collator = MbeirRerankDataCollator(
        tokenizer=tokenizer,
        processor=processor,  # 图像处理器
    )
    
    # 新增：设置分布式采样器，分布式采样器-----------------------------------
    # 如果指定了 rank 和 word_size 那么不需要使用分布式采样器
    # 如果没有指定 rank 和 word_size 那么需要指定 DistributedSampler
    model.floating_point_ops = lambda s: 0  # 禁用浮点运算计数（节省资源）
    trainer = Trainer(
        # world_size=world_size,
        # global_rank=global_rank,
        model=model,
        args=training_args,           # 训练参数
        data_collator=data_collator,  # 数据整理器
        train_dataset=train_dataset,  # 训练数据集
        eval_dataset=None,            # 无验证集
    )
    # 训练流程
    trainer.train()
    # 保存训练状态
    trainer.save_state()

    # 模型保存：safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
    # 修改 safe_save_model_for_hf_trainer 函数或直接使用 Trainer 的保存方法
    # if global_rank == 0: # 仅在 rank == 0 时 保存模型：

    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
    # trainer.save_model(output_dir)  # 自动处理分布式保存，不确定是否需要进行这样的保存

    
if __name__ == "__main__":
    train()