import os
import sys
# Append the parent directory of this file to sys.path. Presumably this is what
# lets the project-local imports below (arguments, collators, dataset, loaders,
# supported_models, utils) resolve when the script is run directly.
current_file_path = os.path.dirname(os.path.abspath(__file__))
module_path = os.path.join(current_file_path, "../")
sys.path.append(module_path)
from dataclasses import asdict
import math
from pathlib import Path
from typing import List, Optional
import yaml
import random
from accelerate.utils import DistributedType
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

import torch
import torch_npu                              # 适配 npu
from torch_npu.contrib import transfer_to_npu # 适配 npu
import torch.distributed as dist
import transformers
from transformers import Trainer, deepspeed


from arguments import ModelArguments, DataArguments, TrainingArguments, LoraArguments
from collators import COLLATORS
from dataset.datasets_mbeir_with_hard_neg import LazySupervisedDataset
from loaders import LOADERS
from supported_models import MODULE_KEYWORDS
from utils import (
    rank0_print, find_all_linear_names, safe_save_model_for_hf_trainer,
    get_peft_state_maybe_zero_3, TrainerWithCustomSampler
)

def train():
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments, LoraArguments)
    )
    
    model_args, data_args, training_args, lora_args = parser.parse_args_into_dataclasses()

    # dumping arguments
    output_dir = getattr(training_args, 'output_dir', None)
    assert output_dir is not None, "output_dir is required"
    args_dir = Path(output_dir) / "arguments"
    args_dir.mkdir(parents=True, exist_ok=True)
    yaml.dump(asdict(model_args), open(args_dir / "model.yaml", "w"))
    yaml.dump(asdict(data_args), open(args_dir / "data.yaml", "w"))
    yaml.dump(asdict(training_args), open(args_dir / "training.yaml", "w"))
    yaml.dump(asdict(lora_args), open(args_dir / "lora.yaml", "w"))

    compute_dtype = (torch.float16 if training_args.fp16 else (torch.bfloat16 if training_args.bf16 else torch.float32))
    if getattr(training_args, 'deepspeed', None) and getattr(lora_args, 'q_lora', False):
        training_args.distributed_state.distributed_type = DistributedType.DEEPSPEED

    device_map = None
    if lora_args.q_lora:
        device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if int(os.environ.get("WORLD_SIZE", 1)) != 1 else None
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
            raise ValueError("FSDP or ZeRO3 are not incompatible with QLoRA.")

    # llm quantization config (for q-lora)
    bnb_config = None
    if lora_args.use_lora and lora_args.q_lora:
        from transformers import BitsAndBytesConfig
        rank0_print("Quantization for LLM enabled...")
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4", 
        )
    
    # load model, tokenizer, processor
    rank0_print("Loading model, tokenizer, processor...")
    loader = LOADERS[model_args.model_family_id](
        model_hf_path=model_args.model_hf_path,
        model_local_path=model_args.model_local_path,
        compute_dtype=compute_dtype,
        bnb_config=bnb_config,
        use_flash_attn=training_args.use_flash_attn,
        device_map=device_map,
    )

    # 训练加载模型时，需要传入 pretrain=False 使用 qwen-vl-finetune 模型
    # 修改模型的配置，必须在这里进行修改，这个参数才会真的生效
    model, tokenizer, processor = loader.load(pretrain=False,hard_neg = True)
    tokenizer.model_max_length = training_args.model_max_length
    model.mean_pooling = model_args.mean_pooling                                    # 是否进行全局平均池化
    model.use_bi_atten = model_args.use_bi_atten                                    # 是否使用双向注意力
    model.use_latent_atten = model_args.use_latent_atten                            # 是否使用潜在注意力模块
    model.use_instruction_mask = model_args.use_instruction_mask                    # 是否使用指令mask
    model.use_bi_loss = model_args.use_bi_loss                                      # 是否使用双向损失
    model.use_isotropy_loss = model_args.use_isotropy_loss                          # 是否使用各向同性损失
    model.use_self_attent_pooling = model_args.use_self_attent_pooling              # 是否使用自注意力池化
    model.use_cross_entropy_loss = model_args.use_cross_entropy_loss                # 是否使用交叉熵损失
    model.use_focal_infonce_loss = model_args.use_focal_infonce_loss                # 是否使用 FocalInfoNCELoss
    model.use_focal_infonce_abs_loss = model_args.use_focal_infonce_abs_loss        # 是否使用 FocalInfoNCEABSLoss
    model.use_diht_loss = model_args.use_diht_loss                                  # 是否使用 DIHTLoss
    model.use_llave_loss = model_args.use_llave_loss                                # 是否使用 LLaVELoss
    model.use_softcse_weight_loss = model_args.use_softcse_weight_loss              # 是否使用 SoftCSEWeightLoss
    model.use_softcse_temperature_loss = model_args.use_softcse_temperature_loss    # 是否使用 SoftCSETemperatureLoss
    model.topk_hard_negative = model_args.topk_hard_negative                        # topk hard negative 数量
    model.topk_modality_hard_negative = model_args.topk_modality_hard_negative      # topk modality hard negative 数量
    model.ignore_batch_other_samples = model_args.ignore_batch_other_samples        # 是否忽略其他样本的损失
    model.use_feature_constraint = model_args.use_feature_constraint                # 是否使用特征约束
    model.use_rerank_scores = model_args.use_rerank_scores                          # 是否使用重排序分数

    # 验证有且仅有一个损失函数被启用
    assert model.use_cross_entropy_loss + model.use_focal_infonce_loss + model.use_diht_loss + \
        model.use_llave_loss + model.use_softcse_weight_loss + model.use_softcse_temperature_loss + model.use_focal_infonce_abs_loss == 1, \
        "Only one loss function can be set to True."
    model._initialize_loss_functions()  # 不使用交叉熵损失函数，必须执行重新初始化损失函数

    # 根据 use_self_attent_pooling 参数决定是否初始化 latent_attention 模块
    # 使用 latent_attention 模块时，需要在模型初始化时进行初始化 
    if model.use_latent_atten:
        model._initialize_latent_attention()  # 初始化 latent_attention 模块
        assert model.latent_attention is not None, "Latent attention module is not initialized"
    
    rank0_print("*"*50)
    rank0_print("模型的基础架构：",model)
    rank0_print("*"*50)
    rank0_print("model_args.mean_pooling: ",model_args.mean_pooling, type(model_args.mean_pooling))
    rank0_print("model_args.use_bi_atten: ",model_args.use_bi_atten, type(model_args.use_bi_atten))
    rank0_print("model_args.use_latent_atten: ",model_args.use_latent_atten, type(model_args.use_latent_atten))
    rank0_print("model_args.use_instruction_mask: ",model_args.use_instruction_mask, type(model_args.use_instruction_mask))
    rank0_print("model_args.use_bi_loss: ",model_args.use_bi_loss, type(model_args.use_bi_loss))
    rank0_print("model_args.use_isotropy_loss: ",model_args.use_isotropy_loss, type(model_args.use_isotropy_loss))
    rank0_print("model_args.use_self_attent_pooling: ",model_args.use_self_attent_pooling, type(model_args.use_self_attent_pooling))
    rank0_print("model_args.use_cross_entropy_loss: ",model_args.use_cross_entropy_loss, type(model_args.use_cross_entropy_loss))
    rank0_print("model_args.use_focal_infonce_loss: ",model_args.use_focal_infonce_loss, type(model_args.use_focal_infonce_loss))
    rank0_print("model_args.use_focal_infonce_abs_loss: ",model_args.use_focal_infonce_abs_loss, type(model_args.use_focal_infonce_abs_loss))
    rank0_print("model_args.use_diht_loss: ",model_args.use_diht_loss, type(model_args.use_diht_loss))
    rank0_print("model_args.use_llave_loss: ",model_args.use_llave_loss, type(model_args.use_llave_loss))
    rank0_print("model_args.use_softcse_weight_loss: ",model_args.use_softcse_weight_loss, type(model_args.use_softcse_weight_loss))
    rank0_print("model_args.use_softcse_temperature_loss: ",model_args.use_softcse_temperature_loss, type(model_args.use_softcse_temperature_loss)) 
    rank0_print("model_args.topk_hard_negative: ",model_args.topk_hard_negative, type(model_args.topk_hard_negative))
    rank0_print("model_args.topk_modality_hard_negative: ",model_args.topk_modality_hard_negative, type(model_args.topk_modality_hard_negative))  
    rank0_print("model_args.ignore_batch_other_samples: ",model_args.ignore_batch_other_samples, type(model_args.ignore_batch_other_samples))
    rank0_print("model_args.use_feature_constraint: ",model_args.use_feature_constraint, type(model_args.use_feature_constraint))
    rank0_print("model_args.use_rerank_scores: ",model_args.use_rerank_scores, type(model_args.use_rerank_scores))
    rank0_print("*"*50)
    # debug --------------------------------------------------------------------------------------------
    rank0_print("模型初始化完成————————————————————————————————————————————————————————————————————————")
    rank0_print("model 对象所属的类对象： ", model.__class__)
    rank0_print("model 对象的类型： ", type(model))
    rank0_print("mean_pooling: ",model.mean_pooling)
    rank0_print("use_bi_atten: ",model.use_bi_atten,)
    rank0_print("use_latent_atten: ",model.use_latent_atten,)
    rank0_print("use_instruction_mask: ",model.use_instruction_mask)
    rank0_print("use_bi_loss: ",model.use_bi_loss)
    rank0_print("use_isotropy_loss: ",model.use_isotropy_loss)
    rank0_print("use_self_attent_pooling: ",model.use_self_attent_pooling)
    rank0_print("use_cross_entropy_loss: ",model.use_cross_entropy_loss)
    rank0_print("use_focal_infonce_loss: ",model.use_focal_infonce_loss)
    rank0_print("use_focal_infonce_abs_loss: ",model.use_focal_infonce_abs_loss)
    rank0_print("use_diht_loss: ",model.use_diht_loss)
    rank0_print("use_llave_loss: ",model.use_llave_loss)
    rank0_print("use_softcse_weight_loss: ",model.use_softcse_weight_loss)
    rank0_print("use_softcse_temperature_loss: ",model.use_softcse_temperature_loss)
    rank0_print("topk_hard_negative: ",model.topk_hard_negative)
    rank0_print("topk_modality_hard_negative: ",model.topk_modality_hard_negative)
    rank0_print("ignore_batch_other_samples: ",model.ignore_batch_other_samples)
    rank0_print("use_feature_constraint: ",model.use_feature_constraint)
    rank0_print("use_rerank_scores: ",model.use_rerank_scores)
    rank0_print("model.loss_fct 对象的类的名称： ", model.loss_fct.__class__.__name__)
    rank0_print("是否训练温度参数 ", training_args.train_temperature)  # 温度参数
    rank0_print("-"*50)
    rank0_print("模型的配置信息： ", model.config)
    rank0_print("-"*50)
    # ---------------------------------------------------------------------------------------------------

    if training_args.gradient_checkpointing:
        model.enable_input_require_grads()

    # freeze certain params
    vision_encoder_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_encoder"]
    if not training_args.train_vision_encoder:
        rank0_print(f"Vision encoder is freezed... including:")
        for module in vision_encoder_keys:
            rank0_print(f"\t{module}")
            eval(f"model.{module}").requires_grad_(False)

    vision_projector_keys = MODULE_KEYWORDS[model_args.model_family_id]["vision_projector"]
    if not training_args.train_vision_projector:
        rank0_print(f"Vision projector is freezed... including:")
        for module in vision_projector_keys:
            rank0_print(f"\t{module}")
            eval(f"model.{module}").requires_grad_(False)
    
    temperature_keys = MODULE_KEYWORDS[model_args.model_family_id]["temperature"]
    if not training_args.train_temperature:
        rank0_print(f"temperature is freezed... including:")
        for module in temperature_keys:
            rank0_print(f"\t{module}")
            eval(f"model.{module}").requires_grad_(False)

    # other components preparation (e.g., image_newline, vision_resampler)
    # we will just freeze these
    if "others" in MODULE_KEYWORDS[model_args.model_family_id]:
        rank0_print(f"Other multimodal component is freezed... including:")
        for other_key in MODULE_KEYWORDS[model_args.model_family_id]["others"]:
            rank0_print(f"\t{other_key}")
            eval(f"model.{other_key}").requires_grad_(False)

    # lora preparation
    llm_keys = MODULE_KEYWORDS[model_args.model_family_id]["llm"]
    if not (lora_args.use_lora or (training_args.train_vision_encoder and lora_args.use_vision_lora)):
        rank0_print("No LoRA enabled...")        
    else:
        named_modules = {n: m for n, m in model.named_modules()}
        lora_modules = []
        full_modules = []

        if training_args.train_vision_encoder and lora_args.use_vision_lora:
            rank0_print("LoRA for vision encoder enabled...")
            lora_modules.extend(find_all_linear_names(named_modules, vision_encoder_keys))
        elif training_args.train_vision_encoder:
            rank0_print("Vision encoder will be fully trained...")
            full_modules.extend(vision_encoder_keys)
        
        if lora_args.use_lora:
            rank0_print("LoRA for LLM enabled...")
            lora_modules.extend(find_all_linear_names(named_modules, llm_keys))
        else:
            rank0_print("LLM will be fully trained...")
            full_modules.extend(llm_keys)
        
        if training_args.train_vision_projector:
            rank0_print("Vision projector will be fully trained...")
            full_modules.extend(vision_projector_keys)
        
        # 是否训练温度参数
        if training_args.train_temperature:
            rank0_print("temperature will be fully trained...")
            full_modules.extend(temperature_keys)
        
        lora_config = LoraConfig(
            r=lora_args.lora_r,
            lora_alpha=lora_args.lora_alpha,
            target_modules=lora_modules,
            modules_to_save=full_modules,
            lora_dropout=lora_args.lora_dropout,
            bias=lora_args.lora_bias,
            task_type=lora_args.task_type,
        )
        # 打印 LoRA 配置-------------------------------------
        rank0_print(f"LoRA config: {lora_config.task_type}")
        # --------------------------------------------------
        if lora_args.q_lora:
            model = prepare_model_for_kbit_training(
                model, use_gradient_checkpointing=training_args.gradient_checkpointing
            )
            
        model = get_peft_model(model, lora_config)
        
    # print trainable parameters for inspection
    rank0_print("Trainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            rank0_print(f"\t{name}")

    # load data
    rank0_print("Loading data...")
    train_dataset = LazySupervisedDataset(
        query_data_path=data_args.query_data_path,
        cand_pool_path=data_args.cand_pool_path,
        instructions_path=data_args.instructions_path,
        image_path_prefix=data_args.image_path_prefix,
        tokenizer=tokenizer,
        has_instruction = data_args.has_instruction,                                # M-BEIR 默认存在指令,代表是否进行指令匹配的计算
        use_instruction_token = data_args.use_instruction_token,                    # M-BEIR 默认不使用指令特殊 token，代表是否使用指令特殊 token
        has_hard_negative = data_args.has_hard_negative,                            # 是否使用 hard negative
        has_modality_hard_negative = data_args.has_modality_hard_negative,          # 是否使用 modality hard negative
        topk_hard_negative = model_args.topk_hard_negative,                         # topk hard negative 数量
        topk_modality_hard_negative = model_args.topk_modality_hard_negative,       # topk modality hard negative 数量
        hard_negative_path=data_args.hard_negative_path,                            # hard negative 数据集路径
        modality_hard_negative_path=data_args.modality_hard_negative_path,          # modality hard negative 数据集路径
        rerank_scores_path=data_args.rerank_scores_path,                            # 重排序分数数据集路径
        query_feature_path=data_args.query_feature_path,                            # query 特征约束数据集路径
        cand_feature_path=data_args.cand_feature_path,                              # cand 特征约束数据集路径
        has_rerank_scores=model_args.use_rerank_scores,                             # 是否使用重排序分数
        has_feature_constraint=model_args.use_feature_constraint,                   # 是否使用特征约束
    )

    # 打印当前使用数据集使用的 prompt 和 是否存在指令-------------------------------------
    rank0_print("当前使用的 LazySupervisedDataset 来自 dataset/datasets_mbeir_with_hard_neg.py")
    rank0_print("LazySupervisedDataset 对象的名称： ", train_dataset.__class__)
    rank0_print("prompt:", train_dataset.prompt)
    rank0_print("has_instruction: ", train_dataset.has_instruction)
    rank0_print("use_instruction_token: ", train_dataset.use_instruction_token)
    rank0_print("has_hard_negative: ", train_dataset.has_hard_negative)
    rank0_print("has_modality_hard_negative: ", train_dataset.has_modality_hard_negative)
    rank0_print("query_data 的长度: ", len(train_dataset.query_data))
    rank0_print("cand_pool 的长度: ", len(train_dataset.cand_pool))
    rank0_print("topk_hard_negative: ", train_dataset.topk_hard_negative)
    rank0_print("topk_modality_hard_negative: ", train_dataset.topk_modality_hard_negative)
    rank0_print("has_rerank_scores: ", train_dataset.has_rerank_scores)
    rank0_print("has_feature_constraint: ", train_dataset.has_feature_constraint)
    # -----------------------------------------------------------

    eval_dataset = None
    training_args.eval_strategy = "no"

    # data collator
    data_collator = COLLATORS[model_args.model_family_id](
        tokenizer=tokenizer,
        processor=processor,
        has_instruction=train_dataset.has_instruction, # M-BEIR 数据集存在指令
        use_instruction_token=train_dataset.use_instruction_token, # M-BEIR 数据集使用指令特殊 token
        has_hard_negative=train_dataset.has_hard_negative, # M-BEIR 数据集使用 hard negative
        has_modality_hard_negative=train_dataset.has_modality_hard_negative, # M-BEIR 数据集使用 modality hard negative
        has_rerank_scores=train_dataset.has_rerank_scores, # M-BEIR 数据集使用重排序分数
        has_feature_constraint=train_dataset.has_feature_constraint, # M-BEIR 数据集使用特征约束
    )

    # 打印当前 data_collator 是否存在指令以及模型是否使用 指令mask -------------------------------------
    rank0_print("has_instruction: ", data_collator.has_instruction)
    rank0_print("PAD_TOKEN_ID: ",data_collator.PAD_TOKEN_ID)
    rank0_print("IGNORE_TOKEN_ID: ",data_collator.IGNORE_TOKEN_ID)
    # -----------------------------------------------------------
    training_args.gradient_checkpointing_kwargs = {"use_reentrant": False} # add this one
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_state()
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=output_dir)
    
    # 独立保存 latent_attention 模块
    if dist.is_initialized():
        if dist.get_rank() == 0 and model.use_latent_atten:
            latent_path = os.path.join(output_dir, "latent_attention.bin")
            torch.save(model.latent_attention.state_dict(), latent_path)

# Script entry point: run training only when this file is executed directly.
if __name__ == "__main__":
    train()