import os
import sys
# 获取当前文件的绝对路径
current_file_path = os.path.dirname(os.path.abspath(__file__))
# 计算上级目录的路径，用于将其添加到系统路径中，以便可以导入上级目录中的模块
module_path = os.path.join(current_file_path, "../")
# 将上级目录添加到系统路径中
sys.path.append(module_path)

from dataclasses import asdict
import math
from pathlib import Path
from typing import List, Optional
import yaml

from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

import torch
# 导入华为昇腾NPU相关的PyTorch扩展，用于适配NPU设备
import torch_npu                             
# 从torch_npu.contrib中导入transfer_to_npu，用于将数据转移到NPU设备
from torch_npu.contrib import transfer_to_npu 

import transformers
from transformers import Trainer, deepspeed

from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments # 从arguments模块中导入自定义的参数类
from collators import COLLATORS                           # 从collators模块中导入数据整理器字典
from dataset.datasets_mbeir import LazySupervisedDataset  # 从 dataset.datasets_mbeir 模块中导入自定义的数据集类
from loaders import LOADERS                               # 从 loaders 模块中导入模型加载器字典
from supported_models import MODULE_KEYWORDS              # 从 supported_models 模块中导入支持的模型的模块关键字字典
from utils import (                                       # 从utils模块中导入一些工具函数和自定义的训练器类
    rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
    get_peft_state_maybe_zero_3, TrainerWithCustomSampler
)


def train():
    # 创建一个 HfArgumentParser 对象，用于解析命令行参数
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    
    # 解析命令行参数并将其转换为对应的 dataclass 实例
    model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    # 将参数保存到 yaml 文件中，方便后续查看和复现
    output_dir = getattr(training_args, 'output_dir', None)  # 获取训练输出目录
    assert output_dir is not None, "output_dir is required"  # 确保输出目录存在
    
    args_dir = Path(output_dir) / "arguments"                # 创建用于保存参数的目录
    args_dir.mkdir(parents=True, exist_ok=True)              # 创建目录，若父目录不存在则一并创建，若目录已存在则不报错
    
    yaml.dump(asdict(model_args), open(args_dir / "model.yaml", "w"))  # 将模型参数保存到model.yaml文件中
    yaml.dump(asdict(data_args), open(args_dir / "data.yaml", "w"))    # 将数据参数保存到data.yaml文件中
    yaml.dump(asdict(training_args), open(args_dir / "training.yaml", "w")) # 将训练参数保存到training.yaml文件中
    yaml.dump(asdict(lora_args), open(args_dir / "lora.yaml", "w"))  # 将LoRA参数保存到lora.yaml文件中

    # 根据训练参数中的混合精度设置，确定计算数据类型
    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    
    # 如果使用了 DeepSpeed 且启用了 QLoRA，则将分布式训练类型设置为 DeepSpeed
    if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    # 初始化设备映射，用于指定模型在不同设备上的分布
    device_map = None
    
    if lora_args.q_lora: # 如果启用了 QLoRA 根据环境变量确定设备映射
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None
        # 如果使用了 FSDP 或 DeepSpeed ZeRO3，抛出错误，因为它们与 QLoRA 不兼容
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            raise ValueError("FSDP or ZeRO3 are not incompatible with QLoRA.")

    # 初始化大语言模型量化配置
    bnb_config = None
    if lora_args.use_lora and lora_args.q_lora:        # 如果启用了 LoRA 且启用了 QLoRA
        from transformers import BitsAndBytesConfig
        rank0_print("Quantization for LLM enabled...") # 打印信息，表示启用了大语言模型量化
        bnb_config = BitsAndBytesConfig(               # 配置 BitsAndBytesConfig，用于 4 位量化
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4", 
        )
    
    rank0_print("Loading model, tokenizer, processor...") # 加载模型、分词器和处理器
    # 根据模型家族 ID 选择对应的加载器
    loader = LOADERS[model_args.model_family_id](
        model_hf_path=model_args.model_hf_path,
        model_local_path=model_args.model_local_path,
        compute_dtype=compute_dtype,
        bnb_config=bnb_config,
        use_flash_attn=training_args.use_flash_attn,
        device_map=device_map,
    )
    # 调用加载器的load方法加载模型、分词器和处理器
    model, tokenizer, processor = loader.load(pretrain=False)
    # 设置分词器的最大输入长度
    tokenizer.model_max_length = training_args.model_max_length

    # 如果启用了梯度检查点，则启用模型的输入梯度计算
    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # 冻结某些参数
    # 获取视觉编码器的模块关键字
    vision_encoder_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_encoder"]
    # 如果不训练视觉编码器
    if not training_args.train_vision_encoder:
        rank0_print(f"Vision encoder is freezed... including:")
        # 遍历视觉编码器的模块关键字
        for module in vision_encoder_keys:
            rank0_print(f"\t{module}")
            # 冻结对应的模块参数
            eval(f"model.{module}").requires_grad_(False)

    # 获取视觉投影器的模块关键字
    vision_projector_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_projector"]
    # 如果不训练视觉投影器
    if not training_args.train_vision_projector:
        rank0_print(f"Vision projector is freezed... including:")
        # 遍历视觉投影器的模块关键字
        for module in vision_projector_keys:
            rank0_print(f"\t{module}")
            # 冻结对应的模块参数
            eval(f"model.{module}").requires_grad_(False)

    # 处理其他多模态组件，将其参数冻结
    if "others" in MODULE_KEYWORDS[model_args.model_family_id]:
        rank0_print(f"Other multimodal component is freezed... including:")
        # 遍历其他多模态组件的关键字
        for other_key in MODULE_KEYWORDS[model_args.model_family_id]["others"]:
            rank0_print(f"\t{other_key}")
            # 冻结对应的模块参数
            eval(f"model.{other_key}").requires_grad_(False)

    # LoRA配置
    # 获取大语言模型的模块关键字
    llm_keys = MODULE_KEYWORDS[model_args.model_family_id]["llm"]
    # 如果没有启用LoRA且不训练视觉编码器或不启用视觉LoRA
    if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
        rank0_print("No LoRA enabled...")        
    else:
        # 获取模型的所有命名模块
        named_modules = {n: m for n, m in model.named_modules()}
        # 初始化需要应用LoRA的模块列表
        lora_modules = []
        # 初始化需要完全训练的模块列表
        full_modules = []

        # 如果训练视觉编码器且启用了视觉LoRA
        if training_args.train_vision_encoder and lora_args.use_vision_lora:
            rank0_print("LoRA for vision encoder enabled...")
            # 找到视觉编码器中所有线性层的名称并添加到lora_modules中
            lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
        # 如果只训练视觉编码器
        elif training_args.train_vision_encoder:
            rank0_print("Vision encoder will be fully trained...")
            # 将视觉编码器的模块关键字添加到full_modules中
            full_modules.extend(vision_encoder_keys)
        
        # 如果启用了LoRA
        if lora_args.use_lora:
            rank0_print("LoRA for LLM enabled...")
            # 找到大语言模型中所有线性层的名称并添加到lora_modules中
            lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
        else:
            rank0_print("LLM will be fully trained...")
            # 将大语言模型的模块关键字添加到full_modules中
            full_modules.extend(llm_keys)
        
        # 如果训练视觉投影器
        if training_args.train_vision_projector:
            rank0_print("Vision projector will be fully trained...")
            # 将视觉投影器的模块关键字添加到full_modules中
            full_modules.extend(vision_projector_keys)
        
        # 配置LoRA
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_modules,
            modules_to_save=full_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type="CAUSAL_LM",
        )

        # 如果启用了QLoRA
        if lora_args.q_lora:
            # 为模型准备4位训练
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )
            
        # 将LoRA配置应用到模型上
        model = get_peft_model(model, lora_config)
        
    # 打印可训练的参数，方便检查
    rank0_print("Trainable parameters:")
    # 遍历模型的所有命名参数
    for name, param in model.named_parameters():
        # 如果参数需要梯度更新
        if param.requires_grad:
            rank0_print(f"\t{name}")

    # 加载数据
    rank0_print("Loading data...")
    # 创建训练数据集
    train_dataset = LazySupervisedDataset(
        query_data_path=data_args.query_data_path,
        cand_pool_path=data_args.cand_pool_path,
        instructions_path=data_args.instructions_path,
        image_path_prefix=data_args.image_path_prefix,
        tokenizer=tokenizer 
    )
    
    # 初始化评估数据集为None
    eval_dataset = None
    # 设置评估策略为不进行评估
    training_args.eval_strategy = "no"

    # 创建数据整理器
    data_collator = COLLATORS[model_args.model_family_id](
        tokenizer=tokenizer,
        processor=processor,
    )

    # 设置梯度检查点的参数
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} 
    # 创建训练器
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset, 
    )
    
    # 开始训练
    trainer.train()
    # 保存训练状态
    trainer.save_state()

    # 安全地保存模型
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
    

if __name__ == "__main__":
    # 调用train函数开始训练
    train()