from torch.utils.data import Dataset, DataLoader
import json
import torch
import torch.nn.functional as F
import torch.nn as nn
import math
import numpy as np
from PIL import Image
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
from einops import rearrange
from internvl3.tools import expand2square
import os
import transformers
from transformers import AutoTokenizer
from internvl3.modeling_internvl_chat import InternVLChatModel
from dataclasses import dataclass
from infer import QwenVLEncoder
import torch.distributed as dist
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from transformers import Trainer
from tqdm import tqdm
from peft import LoraConfig, get_peft_model
from transformers import AutoModel, BitsAndBytesConfig
# Helper definitions below are kept unchanged.
# Three-element mean / std tuples; going by the names these are presumably the
# usual ImageNet per-channel normalisation statistics (TODO confirm). Neither
# constant is referenced in the visible part of this file.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def gather_with_grad(tensor):
    """Concatenate ``tensor`` from every rank along dim 0, keeping local gradients.

    ``dist.all_gather`` fills freshly allocated buffers that carry no autograd
    history, so a plain gather would return a result through which nothing can
    back-propagate. To keep the local contribution differentiable, the slot
    belonging to this rank is replaced by the original ``tensor`` after the
    collective has finished. Slices received from other ranks remain constants.

    Args:
        tensor: Local tensor; every rank must pass the same shape and dtype.

    Returns:
        ``tensor`` itself when torch.distributed is unavailable or not
        initialised, otherwise the concatenation (rank order) of all shards.
    """
    if not dist.is_available() or not dist.is_initialized():
        return tensor

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # contiguous() is a no-op (and preserves autograd history) for tensors
    # that are already contiguous; all_gather needs a dense send buffer.
    local = tensor.contiguous()
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local)

    # Re-attach the autograd graph for the shard owned by this rank.
    gathered[rank] = local
    return torch.cat(gathered, dim=0)


def gather_with_replace(x):
    """All-gather ``x`` across ranks and keep the local shard differentiable.

    The buffers filled by ``dist.all_gather`` have no autograd history, so the
    entry at this process's rank is swapped back to the local tensor before
    concatenation. Gradients therefore flow only into the local shard.

    Args:
        x: Local tensor; every rank must pass the same shape and dtype.

    Returns:
        ``x`` unchanged when torch.distributed is unavailable or has not been
        initialised; otherwise the dim-0 concatenation of all ranks' tensors.
    """
    # Same guard as gather_with_grad: on builds without distributed support
    # the process-group helpers must not be touched at all.
    if not dist.is_available() or not dist.is_initialized():
        return x

    world_size = dist.get_world_size()
    rank = dist.get_rank()

    local = x.contiguous()
    x_list = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(x_list, local)

    # Put the graph-attached local tensor back into its own slot.
    x_list[rank] = local
    return torch.cat(x_list, dim=0)


def load_model_and_tokenizer(model_path):
    """Load the Qwen3-VL checkpoint at ``model_path`` together with its processor.

    Returns:
        A ``(model, processor)`` pair. The second element is an
        ``AutoProcessor`` configured for left padding, even though the
        function name calls it a tokenizer.
    """
    model_options = {
        "dtype": torch.bfloat16,
        "attn_implementation": "sdpa",
        "low_cpu_mem_usage": True,
    }
    vl_model = Qwen3VLForConditionalGeneration.from_pretrained(model_path, **model_options)

    # Per the original author's note: the Qwen2/3-VL processor handles dynamic
    # resolution itself, resizing images according to min_pixels / max_pixels
    # while keeping the aspect ratio (not verified from this file).
    processor_options = {"trust_remote_code": True, "padding_side": "left"}
    vl_processor = AutoProcessor.from_pretrained(model_path, **processor_options)

    return vl_model, vl_processor


class ImageTextDataset(Dataset):
    """Image / caption pairs prepared for contrastive training.

    The annotation file is a JSON list whose items are read through the keys
    ``'image'`` (path relative to ``image_root``) and ``'caption'``.

    Args:
        ann_path: Path to the JSON annotation file (read as UTF-8).
        image_root: Directory that image paths are joined onto.
        processor: Callable processor that also provides
            ``apply_chat_template``.
        max_length: Stored on the instance. NOTE(review): ``__getitem__``
            does not read it; captions are truncated at a fixed 128 tokens.
    """

    def __init__(self, ann_path, image_root, processor, max_length=64):
        self.image_root = image_root
        self.processor = processor
        self.max_length = max_length
        # Explicit encoding: captions may contain non-ASCII text and the
        # platform default encoding is not guaranteed to be UTF-8.
        with open(ann_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the processed tensors for sample ``idx``.

        Returns:
            A dict with ``image_pixel_values``, ``image_grid_thw``,
            ``text_input_ids``, ``text_attention_mask``, ``image_input_ids``
            and ``image_attention_mask``; or ``None`` when the image cannot
            be loaded (such samples are expected to be dropped by the
            collate function).
        """
        item = self.data[idx]

        try:
            # Dynamic resolution: no expand2square / fixed resize here, the
            # raw image is handed to the processor as-is.
            image_path = os.path.join(self.image_root, item['image'])
            # convert() yields an independent in-memory image, so the file
            # handle can be released as soon as the with-block ends.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert("RGB")
        except Exception as e:
            # Best effort: a broken or missing image skips the sample.
            print(f"Skipping invalid image: {item['image']}, {e}")
            return None

        # Prompt for the image branch.
        image_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Summary above image in one word:"}
                ]
            }
        ]

        image_inputs = self.processor(
            text=[self.processor.apply_chat_template(image_messages, tokenize=False, add_generation_prompt=True)],
            images=[image],
            return_tensors="pt",
            padding=True
        )

        # Use the grid computed by the processor; with dynamic resolution it
        # differs per image. Index [0] selects the entry of the single image
        # passed above (original note: tensor of shape [1, 3] -> [3]).
        image_grid_thw = image_inputs["image_grid_thw"][0]

        # Original note: pixel values are laid out as [N_patches, Dim].
        pixel_values = image_inputs["pixel_values"]

        # Prompt for the text branch.
        text_messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": f"{item['caption']} Summary above sentence in one word:"}
                ]
            }
        ]

        text_inputs = self.processor(
            text=[self.processor.apply_chat_template(text_messages, tokenize=False, add_generation_prompt=True)],
            return_tensors="pt",
            padding=True,
            max_length=128,
            truncation=True
        )

        return {
            "image_pixel_values": pixel_values,
            "image_grid_thw": image_grid_thw,
            "text_input_ids": text_inputs["input_ids"][0],
            "text_attention_mask": text_inputs["attention_mask"][0],
            "image_input_ids": image_inputs["input_ids"][0],
            "image_attention_mask": image_inputs["attention_mask"][0]
        }


def retrieval_collate_fn(batch):
    """Collate dataset samples into one batch dict, dropping ``None`` entries.

    Returns ``None`` when no usable sample remains. Otherwise:
      * ``pixel_values``: every sample's patches concatenated along dim 0
        (samples may contribute different numbers of patches),
      * ``image_grid_thw``: per-sample grids stacked along a new dim 0,
      * the four id / mask entries: left-padded to the longest sequence.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None

    def _padded(field, fill):
        # Each entry is 1-D; left_pad expects a list of [1, L_i] tensors.
        rows = [sample[field].unsqueeze(0) for sample in samples]
        return left_pad(rows, pad_value=fill)

    all_patches = torch.cat([sample["image_pixel_values"] for sample in samples], dim=0)
    grids = torch.stack([sample["image_grid_thw"] for sample in samples], dim=0)

    return {
        "pixel_values": all_patches,
        "image_grid_thw": grids,
        "text_input_ids": _padded("text_input_ids", 151643),
        "text_attention_mask": _padded("text_attention_mask", 0),
        "image_input_ids": _padded("image_input_ids", 151643),
        "image_attention_mask": _padded("image_attention_mask", 0),
    }


def left_pad(seqs, pad_value):
    """Left-pad a list of ``[1, L_i]`` tensors to a common length and stack them.

    Args:
        seqs: List of tensors, each of shape ``[1, L_i]``.
        pad_value: Fill value placed in front of the shorter sequences.

    Returns:
        Tensor of shape ``[B, L_max]``.
    """
    target_len = max(seq.size(1) for seq in seqs)

    def _pad_row(seq):
        missing = target_len - seq.size(1)
        if missing > 0:
            filler = seq.new_full((1, missing), pad_value)
            seq = torch.cat([filler, seq], dim=1)
        return seq.squeeze(0)

    return torch.stack([_pad_row(seq) for seq in seqs], dim=0)


class ContrastiveTrainer(Trainer):
    """``Trainer`` subclass that optimises a symmetric image-text contrastive loss.

    ``compute_loss`` is a regular method of the class so that it overrides the
    ``Trainer`` hook of the same name; defined anywhere else (for instance as
    a local function inside ``__init__``) it would never be called.
    """

    def __init__(self, temperature=0.05, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # NOTE(review): kept for callers, but compute_loss scales the logits
        # by ``model.logit_scale`` and does not read this attribute.
        self.temperature = temperature

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the image<->text contrastive loss for one batch.

        Args:
            model: Training model, possibly wrapped in an object exposing
                ``.module`` (e.g. DDP). After unwrapping it must provide
                ``.model`` and ``.logit_scale``.
            inputs: Batch dict with the keys ``text_input_ids``,
                ``text_attention_mask``, ``image_input_ids``,
                ``image_attention_mask``, ``pixel_values`` and
                ``image_grid_thw``.
            return_outputs: When true, also return the local embeddings.
            num_items_in_batch: Accepted for ``Trainer`` compatibility; unused.

        Returns:
            ``loss``, or ``(loss, (image_emb, text_emb))`` when
            ``return_outputs`` is true.
        """
        # Unwrap a wrapper that exposes ``.module``.
        model = model.module if hasattr(model, "module") else model
        qwen_model = model.model

        # Token-embedding layer of the language model.
        embed_tokens = qwen_model.language_model.get_input_embeddings()

        # Id of the image placeholder token (original note: usually 151655).
        image_token_id = qwen_model.config.image_token_id

        # ==========================
        # 1. Text branch
        # ==========================
        text_input_ids = inputs["text_input_ids"]
        text_embeds = embed_tokens(text_input_ids)

        text_outputs = qwen_model(
            input_ids=None,
            inputs_embeds=text_embeds,
            attention_mask=inputs["text_attention_mask"],
            output_hidden_states=True,
            return_dict=True
        )

        # Last position of the last hidden state is the sentence embedding
        # (the collate function pads on the left).
        text_hidden = text_outputs.hidden_states[-1]
        text_emb = text_hidden[:, -1, :]
        text_emb = F.normalize(text_emb, dim=-1)

        # ==========================
        # 2. Image branch
        # ==========================
        image_input_ids = inputs["image_input_ids"]  # expected [B, L]
        pixel_values = inputs["pixel_values"]        # expected [sum(patches), D]
        image_grid_thw = inputs["image_grid_thw"]    # expected [B, 3]

        dtype = embed_tokens.weight.dtype

        # ---- A. token embeddings of the image prompt ----
        image_embeds = embed_tokens(image_input_ids)

        # ---- B. visual feature extraction ----
        image_embeds_list, deepstack_visual_embeds = \
            qwen_model.get_image_features(pixel_values, image_grid_thw)

        visual_features = torch.cat(image_embeds_list, dim=0).to(dtype)

        visual_mask = (image_input_ids == image_token_id)

        # Overwrite the placeholder-token embeddings with the visual features.
        image_embeds[visual_mask] = visual_features

        # ---- C. position ids ----
        position_ids, _ = qwen_model.model.get_rope_index(
            input_ids=image_input_ids,
            image_grid_thw=image_grid_thw,
            video_grid_thw=None,
            attention_mask=None
        )

        # The language model is told which positions hold image tokens.
        visual_pos_masks = visual_mask

        # ---- D. forward through the language model ----
        image_outputs = qwen_model.language_model(
            input_ids=None,
            inputs_embeds=image_embeds,
            attention_mask=inputs["image_attention_mask"],
            position_ids=position_ids,
            visual_pos_masks=visual_pos_masks,
            deepstack_visual_embeds=deepstack_visual_embeds,
            output_hidden_states=True,
            return_dict=True
        )

        # ---- E. image embedding: last position of the last hidden state ----
        image_hidden = image_outputs.hidden_states[-1]
        image_emb = image_hidden[:, -1, :]
        image_emb = F.normalize(image_emb, dim=-1)

        # ==========================
        # 3. Loss
        # ==========================
        text_emb_all = gather_with_replace(text_emb)
        image_emb_all = gather_with_replace(image_emb)

        # [N, D] @ [D, N] similarity matrix, scaled by the model's logit_scale.
        logits = image_emb_all @ text_emb_all.t() / model.logit_scale

        B = image_emb.size(0)
        rank = dist.get_rank() if dist.is_initialized() else 0
        # Positives sit on the diagonal of the global matrix; this assumes
        # every rank contributes the same local batch size B.
        labels = torch.arange(B, device=image_emb.device) + rank * B

        start = rank * B
        end = start + B

        # Only the rows belonging to this rank contribute to the local loss.
        loss_i2t = F.cross_entropy(logits[start:end], labels)
        loss_t2i = F.cross_entropy(logits.t()[start:end], labels)
        loss = (loss_i2t + loss_t2i) / 2

        return (loss, (image_emb, text_emb)) if return_outputs else loss


def apply_lora(model):
    """Freeze ``model`` and wrap its language model with LoRA adapters.

    Every existing parameter has ``requires_grad`` switched off, after which
    ``model.model.language_model`` is replaced by its PEFT-wrapped version
    (rank 8, alpha 16, dropout 0.05, targeting the attention and MLP
    projection layers). The trainable-parameter summary is printed.

    Returns:
        The same ``model`` object, modified in place.
    """
    # Freeze everything first; only parameters added by PEFT stay trainable.
    for param in model.parameters():
        param.requires_grad_(False)

    projection_names = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    adapter_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=projection_names,
    )

    host = model.model
    host.language_model = get_peft_model(host.language_model, adapter_config)
    host.language_model.is_peft_model = True
    host.language_model.print_trainable_parameters()
    return model

def calculate_manual_lora_params(model_to_inspect, target_modules, r=8):
    """Count the parameters LoRA would add, without using the peft library.

    A module is counted when it is an ``nn.Linear`` whose qualified name is
    exactly one of ``target_modules`` or ends with ``"." + target``. Matching
    on the dotted suffix (rather than a bare ``endswith``) prevents a module
    such as ``xq_proj`` from being counted for the target ``q_proj``.

    Args:
        model_to_inspect: Module to scan (e.g. ``model.model.language_model``).
        target_modules: List of target module names (e.g. ``["q_proj", "v_proj"]``).
        r: LoRA rank.

    Returns:
        Total number of LoRA parameters: for each matched layer,
        ``r * in_features`` (matrix A) plus ``out_features * r`` (matrix B).
    """
    def _is_target(module_name):
        return any(
            module_name == target or module_name.endswith("." + target)
            for target in target_modules
        )

    lora_params = 0
    trainable_modules_count = 0

    print(f"正在统计 LoRA (r={r}) 参数量...")

    for name, module in model_to_inspect.named_modules():
        if isinstance(module, nn.Linear) and _is_target(name):
            # A: [r, in_features], B: [out_features, r]
            lora_params += r * (module.in_features + module.out_features)
            trainable_modules_count += 1

    print(f"统计完成: 共有 {trainable_modules_count} 个模块会被注入 LoRA。")
    return lora_params

if __name__ == '__main__':
    # Script entry point: load the model, report the LoRA parameter budget,
    # attach the adapters and run contrastive training.
    path = "coco2017_train.json"
    image_root = ''
    model_path = "Qwen3-VL-4B-Instruct"

    # Load the base model and wrap it in the project's encoder.
    model, tokenizer = load_model_and_tokenizer(model_path)
    model = QwenVLEncoder(model, tokenizer)

    lora_r = 8
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ]
    target_model_part = model.model.language_model
    lora_params = calculate_manual_lora_params(
        target_model_part,
        target_modules,
        r=lora_r
    )

    # Parameter count of the inspected part, used for the percentage below.
    total_params = sum(p.numel() for p in target_model_part.parameters())

    print(f"==================================================")
    print(f"手动统计结果:")
    print(f"LoRA Trainable Params: {lora_params:,}")
    print(f"Total Model Params:    {total_params:,}")
    print(f"Trainable %:           {100 * lora_params / total_params:.4f}%")
    print(f"==================================================")
    model = apply_lora(model)
    model.train()
    grad_checkpoint = False
    if grad_checkpoint:
        model.enable_input_require_grads()

    # With dynamic resolution a very large image can cause OOM; limit
    # max_pixels in the processor config or reduce the batch size if needed.
    micro_batch_size = 32
    num_epochs = 1
    learning_rate = 1e-5
    output_dir = './res_qwen2b_lora'
    # run_name was referenced below without ever being defined (NameError);
    # derive it from the output directory name.
    run_name = os.path.basename(os.path.normpath(output_dir))
    save_steps = 100
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1

    data = ImageTextDataset(path, image_root, tokenizer, None)

    trainer = ContrastiveTrainer(
        model=model,
        train_dataset=data,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=micro_batch_size,
            gradient_accumulation_steps=1,
            warmup_steps=100,
            num_train_epochs=num_epochs,
            learning_rate=learning_rate,
            fp16=False,
            bf16=True,
            logging_steps=10,
            save_strategy="steps",
            save_steps=save_steps,
            output_dir=output_dir,
            save_safetensors=False,
            save_total_limit=2,
            ddp_find_unused_parameters=False if ddp else None,
            run_name=run_name,
            # The collate function relies on custom column names, so the
            # Trainer must not drop columns it does not recognise.
            remove_unused_columns=False,
            gradient_checkpointing=grad_checkpoint,
        ),
        data_collator=retrieval_collate_fn
    )
    trainer.train()
    model.model.language_model.save_pretrained('./qwen2b_lora')