import os
import logging
import pathlib
import torch
import transformers
import json
from typing import Dict
import shutil
import sys
from pathlib import Path

# 导入 LoRA 相关库
from peft import get_peft_model, LoraConfig, TaskType,PeftModel, PeftConfig

project_root = Path(__file__).parent.parent.parent
sys.path.append(str(project_root))

import qwenvl.train.trainer
from trainer import replace_qwen2_vl_attention_class

from transformers import (
    Qwen2VLForConditionalGeneration,
    Qwen2_5_VLForConditionalGeneration,
)
from qwenvl.data.data_qwen import make_supervised_data_module

from qwenvl.train.argument import (
    ModelArguments,
    DataArguments,
    TrainingArguments,
)
from transformers import AutoTokenizer, AutoProcessor, Qwen2VLImageProcessor, Trainer

# Local process rank, assigned by train() from TrainingArguments.local_rank.
# None until train() runs.
local_rank = None


def rank0_print(*args):
    """Print *args* only from the main local process.

    Prints when ``local_rank`` is 0 (rank 0 of a distributed launch) or -1
    (the ``TrainingArguments.local_rank`` default for a non-distributed run,
    where the single process is the main one). Stays silent on every other
    rank and while ``local_rank`` is still None.
    """
    if local_rank in (0, -1):
        print(*args)


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Write the trainer's model weights to *output_dir*.

    With DeepSpeed active, synchronize CUDA and delegate to
    ``trainer.save_model``. Otherwise, take the model's state dict, and on
    processes where ``trainer.args.should_save`` is true, move every tensor
    to CPU and pass the result to ``trainer._save``.
    """
    if not trainer.deepspeed:
        full_state = trainer.model.state_dict()
        if not trainer.args.should_save:
            return
        host_state = dict((name, tensor.cpu()) for name, tensor in full_state.items())
        # Drop the reference to the on-device copy before writing.
        del full_state
        trainer._save(output_dir, state_dict=host_state)  # noqa
        return

    torch.cuda.synchronize()
    trainer.save_model(output_dir)


def set_model(model_args, model):
    """Set ``requires_grad`` on the base model's parts from the ``tune_mm_*`` flags.

    This toggles plain parameter trainability (no adapters are added):
      * ``tune_mm_vision`` -> every parameter under ``model.visual``
      * ``tune_mm_mlp``    -> every parameter under ``model.visual.merger``
      * ``tune_mm_llm``    -> every parameter under ``model.model`` and ``model.lm_head``

    Args:
        model_args: Object exposing the three boolean-like ``tune_mm_*`` flags.
        model: Module with ``visual``, ``visual.merger``, ``model`` and
            ``lm_head`` submodules.
    """
    # The merger lives inside model.visual, so set the vision flag first and
    # let the merger flag override it for the merger's own parameters.
    model.visual.requires_grad_(bool(model_args.tune_mm_vision))
    model.visual.merger.requires_grad_(bool(model_args.tune_mm_mlp))

    tune_llm = bool(model_args.tune_mm_llm)
    model.model.requires_grad_(tune_llm)
    # Use requires_grad_() rather than assigning `.requires_grad` on the module:
    # attribute assignment on an nn.Module does not touch its parameters.
    model.lm_head.requires_grad_(tune_llm)

def _qualified_name(root, target):
    """Return the dotted path of *target* as reported by ``root.named_modules()``.

    Raises:
        ValueError: if *target* is not a submodule of *root*.
    """
    for name, module in root.named_modules():
        if module is target:
            return name
    raise ValueError(
        f"{type(target).__name__} is not a submodule of {type(root).__name__}"
    )


def set_model_lora(model_args, model):
    """Set trainability flags and collect LoRA target module names.

    Flags read from *model_args*:
        tune_mm_vision - set ``requires_grad`` on every ``model.visual`` parameter
        tune_mm_mlp    - set ``requires_grad`` on the patch-merger parameters and
                         target every ``nn.Linear`` inside ``model.visual.merger``
        tune_mm_llm    - set ``requires_grad`` on ``model.model`` / ``model.lm_head``
                         and target the decoder MLP projections
                         (``gate_proj``, ``up_proj``, ``down_proj``) of every layer

    Target names are fully qualified paths within *model* (derived from its
    own module tree), so they only match the intended modules.

    NOTE(review): ``get_peft_model`` by default leaves only adapter parameters
    trainable, so the ``requires_grad`` flags set here on base weights are
    likely overridden afterwards -- confirm; use ``modules_to_save`` if full
    ViT/merger tuning is really intended.

    Returns:
        Sorted, de-duplicated list of module names (empty if no flag is set).
    """
    target_modules = []

    # ---------- Vision ViT ----------
    model.visual.requires_grad_(bool(model_args.tune_mm_vision))

    # ---------- Patch merger (inside model.visual, so this overrides the above) ----------
    merger = model.visual.merger
    merger.requires_grad_(bool(model_args.tune_mm_mlp))
    if model_args.tune_mm_mlp:
        merger_path = _qualified_name(model, merger)
        for name, module in merger.named_modules():
            if isinstance(module, torch.nn.Linear):
                target_modules.append(f"{merger_path}.{name}" if name else merger_path)

    # ---------- LLM ----------
    tune_llm = bool(model_args.tune_mm_llm)
    model.model.requires_grad_(tune_llm)
    # requires_grad_() actually updates the head's parameters; assigning
    # `.requires_grad` on the module would only set an unused attribute.
    model.lm_head.requires_grad_(tune_llm)
    if tune_llm:
        layers_path = _qualified_name(model, model.model.layers)
        for layer_idx in range(len(model.model.layers)):
            for proj in ("gate_proj", "up_proj", "down_proj"):
                target_modules.append(f"{layers_path}.{layer_idx}.mlp.{proj}")

    # Sorted for a deterministic LoRA config across runs.
    return sorted(set(target_modules))
def load_model(model_args, training_args, data_args, attn_implementation="flash_attention_2"):
    """Load the base Qwen2-VL / Qwen2.5-VL model and its image processor.

    The model family is chosen from the final path component of
    ``model_args.model_name_or_path``. A name containing ``"qwen2.5"``
    (case-insensitive) loads Qwen2.5-VL; anything else loads Qwen2-VL.
    The component is taken with ``Path(...).name``, so a trailing slash
    (``.../Qwen2.5-VL-7B/``) no longer hides the model name.

    Side effects:
        Sets ``data_args.image_processor`` and ``data_args.model_type``.

    Returns:
        ``(model, data_args)``: the loaded model and the same, mutated *data_args*.
    """
    model_path = model_args.model_name_or_path
    model_kwargs = dict(
        cache_dir=training_args.cache_dir,
        attn_implementation=attn_implementation,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
    )

    if "qwen2.5" in Path(model_path).name.lower():
        print("Loading Qwen2.5 model...")
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path, **model_kwargs)
        # Use the same cache_dir as the weights so the processor is fetched from/to the same place.
        data_args.image_processor = AutoProcessor.from_pretrained(
            model_path,
            cache_dir=training_args.cache_dir,
        ).image_processor
        data_args.model_type = "qwen2.5vl"
    else:
        print("Loading Qwen2 model...")
        model = Qwen2VLForConditionalGeneration.from_pretrained(model_path, **model_kwargs)
        data_args.image_processor = Qwen2VLImageProcessor.from_pretrained(
            model_path,
            cache_dir=training_args.cache_dir,
        )
        data_args.model_type = "qwen2vl"
    return model, data_args

def train(attn_implementation="flash_attention_2"):
    """Parse CLI arguments, build a LoRA-wrapped Qwen-VL model and train it.

    Two modes, selected by ``model_args.lora_model``:
      * truthy - ``model_args.model_name_or_path`` is an existing PEFT adapter
        directory. The base model named in its adapter config is loaded
        (as Qwen2.5-VL), and the adapter is attached in trainable mode so
        training can continue.
      * falsy  - ``model_args.model_name_or_path`` is a base checkpoint. It is
        loaded via ``load_model`` and a fresh LoRA adapter is created on the
        modules returned by ``set_model_lora``.

    Training resumes automatically when ``checkpoint-*`` directories exist in
    ``training_args.output_dir``. Afterwards the trainer state, weights,
    image processor, chat template (when one is found) and tokenizer are
    written to the output directory.

    Args:
        attn_implementation: Attention backend passed to ``from_pretrained``.
    """
    global local_rank

    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    local_rank = training_args.local_rank
    os.makedirs(training_args.output_dir, exist_ok=True)

    if model_args.lora_model:
        peft_config = PeftConfig.from_pretrained(model_args.model_name_or_path)
        base_model_path = peft_config.base_model_name_or_path
        # The tokenizer must come from the base model path, not the adapter dir.
        tokenizer = AutoTokenizer.from_pretrained(
            base_model_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
            truncation=True,
        )
        base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            base_model_path,
            cache_dir=training_args.cache_dir,
            attn_implementation=attn_implementation,
            torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        )
        # load_model() sets these in the other branch. The data module and the
        # final save both read data_args.image_processor, so set it here as well.
        data_args.image_processor = AutoProcessor.from_pretrained(
            base_model_path,
            cache_dir=training_args.cache_dir,
        ).image_processor
        data_args.model_type = "qwen2.5vl"
        # is_trainable=True: PeftModel.from_pretrained otherwise loads the
        # adapter in inference mode (frozen), leaving nothing to train.
        model = PeftModel.from_pretrained(
            base_model, model_args.model_name_or_path, is_trainable=True
        )
        model.print_trainable_parameters()
    else:
        base_model_path = model_args.model_name_or_path
        model, data_args = load_model(model_args, training_args, data_args, attn_implementation=attn_implementation)
        print("get_base_model:\n", model)
        target_modules = set_model_lora(model_args, model)
        lora_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            target_modules=target_modules,
            r=8,
            lora_alpha=16,
            lora_dropout=0.05,
            bias="none",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
        print("get_peft_model:\n", model)
        tokenizer = transformers.AutoTokenizer.from_pretrained(
            model_args.model_name_or_path,
            cache_dir=training_args.cache_dir,
            model_max_length=training_args.model_max_length,
            padding_side="right",
            use_fast=False,
            truncation=True,
        )

    if data_args.data_flatten:
        replace_qwen2_vl_attention_class()
    # KV cache is not used during training; re-enabled before the final save.
    model.config.use_cache = False

    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            # Fallback: make embedding outputs require grad so checkpointed
            # segments still receive gradients.
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)

            try:
                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
            except (AttributeError, NotImplementedError) as err:
                raise RuntimeError(
                    "Gradient checkpointing is enabled but model does not support input grads setup."
                ) from err

    rank0_print(f"正在加载数据集：{data_args}")
    data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)

    trainer = Trainer(
        model=model, processing_class=tokenizer, args=training_args, **data_module
    )

    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        logging.info("checkpoint found, resume training")
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()

    print(model)
    trainer.save_state()
    trainer.save_model(training_args.output_dir)
    data_args.image_processor.save_pretrained(training_args.output_dir)

    # Copy chat_template.json from the first directory that has it (an adapter
    # dir may not). A missing template only warns, so a finished run still
    # gets its final weights and tokenizer saved below.
    template_path = os.path.join(training_args.output_dir, "chat_template.json")
    candidate_dirs = list(dict.fromkeys((model_args.model_name_or_path, base_model_path)))
    for candidate_dir in candidate_dirs:
        source_path = os.path.join(candidate_dir, "chat_template.json")
        if os.path.isfile(source_path):
            shutil.copy2(source_path, template_path)
            break
    else:
        logging.warning("chat_template.json not found in %s; not copied", candidate_dirs)

    # Save again after re-enabling the cache -- presumably so the written
    # config records use_cache=True.
    model.config.use_cache = True
    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)
    tokenizer.save_pretrained(training_args.output_dir)

if __name__ == "__main__":
    # Script entry point: run training with the FlashAttention-2 backend.
    train("flash_attention_2")
