from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Union, Literal
from functools import partial
import numpy as np
import torch
import torch.nn.functional as F
import torch.distributed as dist
from PIL import Image
import random
from trl.trainer.dpo_trainer import *



def read_image(img):
    """Return ``img`` as an RGB PIL image.

    Args:
        img: Either a string handed to ``Image.open`` or an already loaded
            ``Image.Image`` instance.

    Returns:
        The image itself when its mode is already ``'RGB'``, otherwise the
        result of ``convert('RGB')``.
    """
    if not isinstance(img, str):
        assert isinstance(img, Image.Image)
        picture = img
    else:
        picture = Image.open(img)
    return picture if picture.mode == 'RGB' else picture.convert('RGB')

def make_multi_image_prompts(prompt, processor):
    """Render two two-image chat prompts, one per image the answer should be based on.

    Args:
        prompt: Question text; an instruction naming the "first" or "second"
            image is appended to it.
        processor: Object exposing ``apply_chat_template``.

    Returns:
        A ``(prompt1, prompt2)`` tuple of rendered prompts targeting the first
        and the second image respectively.
    """
    rendered = []
    for ordinal in ("first", "second"):
        question = prompt + " Answer based on the {INDEX} image.".format(INDEX=ordinal)
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            }
        ]
        # Render to a string; this assumes the processor's chat template accepts
        # a list of dicts and recognises the {"type": "image"} entries.
        rendered.append(
            processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        )

    prompt1, prompt2 = rendered
    return prompt1, prompt2
        
def make_vco_data(
    examples,
    prompt_col,
    image1_col,
    image2_col,
    resp1_col,
    resp2_col,
    training_args,
    processor,
    resp1_target_span_col: Optional[str] = None,
    resp2_target_span_col: Optional[str] = None,
):
    """Convert a batch of raw paired examples into the column layout used for VCO training.

    Each row yields a single-image chat prompt plus the two images and their
    responses. With probability of about one half, the two (image, response,
    spans) triples are swapped so that the ordering carries no signal.

    Args:
        examples: Column-oriented batch (mapping of column name to sequence).
        prompt_col: Name of the column holding the question text.
        image1_col: Name of the column holding the first image (path or PIL image).
        image2_col: Name of the column holding the second image (path or PIL image).
        resp1_col: Name of the column holding the response for the first image.
        resp2_col: Name of the column holding the response for the second image.
        training_args: Object exposing the ``use_no_image``, ``use_multi_image``
            and ``use_resp_token_mask`` flags.
        processor: Object exposing ``apply_chat_template``.
        resp1_target_span_col: Column with target spans for the first response;
            read only when ``training_args.use_resp_token_mask`` is set.
        resp2_target_span_col: Column with target spans for the second response;
            read only when ``training_args.use_resp_token_mask`` is set.

    Returns:
        A dict of equally long lists with keys ``prompt``, ``resp_1``, ``resp_2``,
        ``image_1`` and ``image_2``, plus ``prompt_no_image``,
        ``prompt_multi_image_1`` / ``prompt_multi_image_2`` and
        ``resp_1_target_spans`` / ``resp_2_target_spans`` when the matching flag
        is enabled.
    """
    use_no_image = training_args.use_no_image
    use_multi_image = training_args.use_multi_image
    use_resp_token_mask = training_args.use_resp_token_mask

    vco_examples = {
        "prompt": [],
        "resp_1": [],
        "resp_2": [],
        "image_1": [],
        "image_2": [],
    }
    if use_no_image:
        vco_examples['prompt_no_image'] = []
    if use_multi_image:
        vco_examples['prompt_multi_image_1'] = []
        vco_examples['prompt_multi_image_2'] = []
    if use_resp_token_mask:
        vco_examples['resp_1_target_spans'] = []
        vco_examples['resp_2_target_spans'] = []

    for i, prompt in enumerate(examples[prompt_col]):
        image_1 = read_image(examples[image1_col][i])
        image_2 = read_image(examples[image2_col][i])
        resp_1 = examples[resp1_col][i]
        resp_2 = examples[resp2_col][i]
        if use_resp_token_mask:
            resp_1_spans = examples[resp1_target_span_col][i]
            resp_2_spans = examples[resp2_target_span_col][i]
        else:
            resp_1_spans, resp_2_spans = None, None

        # Randomly swap the pair so the image order carries no signal.
        if random.random() > 0.5:
            image_1, image_2 = image_2, image_1
            resp_1, resp_2 = resp_2, resp_1
            resp_1_spans, resp_2_spans = resp_2_spans, resp_1_spans

        vco_examples['resp_1'].append(resp_1)
        vco_examples['resp_2'].append(resp_2)
        if use_resp_token_mask:
            # Always append, even when a span list is empty: skipping the row
            # would leave the span columns shorter than (and misaligned with)
            # every other column of the batch.
            vco_examples['resp_1_target_spans'].append(resp_1_spans)
            vco_examples['resp_2_target_spans'].append(resp_2_spans)

        vco_examples['image_1'].append(image_1)
        vco_examples['image_2'].append(image_2)

        conv_prompt = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
        vco_examples['prompt'].append(
            processor.apply_chat_template(conv_prompt, tokenize=False, add_generation_prompt=True)
        )

        if use_no_image:
            conv_prompt_no_image = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
            vco_examples['prompt_no_image'].append(
                processor.apply_chat_template(conv_prompt_no_image, tokenize=False, add_generation_prompt=True)
            )

        if use_multi_image:
            multi_prompt_1, multi_prompt_2 = make_multi_image_prompts(prompt, processor=processor)
            vco_examples['prompt_multi_image_1'].append(multi_prompt_1)
            vco_examples['prompt_multi_image_2'].append(multi_prompt_2)

    return vco_examples

@dataclass
class VCOConfig(DPOConfig):
    """``DPOConfig`` extended with the switches and loss weights of the VCO objectives.

    ``vco_type`` selects a preset; ``__post_init__`` turns on the flags that
    preset implies before delegating to ``DPOConfig.__post_init__``.
    """

    vco_type: str = field(
        default='dpo',
        metadata={
            "help": "Type of VCO objective. Currently support: dpo, s-vco, v-dpo, symmpo, mdpo, "
            "and ic-vco* (optionally containing 'noanchor', 'novcdist' or 'skipsingle')."
        },
    )
    use_multi_image: bool = field(
        default=False, metadata={"help": "Whether to use multi-image sample for training."}
    )
    skip_single_image: bool = field(
        default=False, metadata={"help": "Whether to skip single image."}
    )
    use_no_image: bool = field(
        default=False, metadata={"help": "Whether to use no image sample as additional rejected sample."}
    )
    use_resp_token_mask: bool = field(
        default=False, metadata={"help": "Whether to use response token mask to compute token-level DPO loss"}
    )
    add_sft_loss: bool = field(
        default=False, metadata={"help": "Whether to use SFT loss."}
    )
    sft_loss_weight: float = field(
        default=0.1, metadata={"help": "Weight for SFT loss."}
    )
    add_hinge_sft_loss: bool = field(
        default=False, metadata={"help": "Whether to use hinge-SFT loss."}
    )
    hinge_sft_loss_weight: float = field(
        default=1.0, metadata={"help": "Weight for hinge-SFT loss."}
    )
    add_anchor_loss: bool = field(
        default=False, metadata={"help": "Whether to use anchor loss in mDPO paper."}
    )
    anchor_delta: float = field(
        default=0.0, metadata={"help": "Delta for anchor loss."}
    )
    anchor_weight: float = field(
        default=1.0, metadata={"help": "Weight for anchor loss."}
    )
    add_margin_consistency_loss: bool = field(
        default=False, metadata={"help": "Whether to use margin consistency loss in SymMPO paper."}
    )
    margin_consistency_weight: float = field(
        default=1e-4, metadata={"help": "Weight for margin consistency loss."}
    )
    add_margin_distillation_loss: bool = field(
        default=False, metadata={"help": "Whether to use margin distillation loss."}
    )
    margin_distillation_weight: float = field(
        default=1e-2, metadata={"help": "Weight for margin distillation loss."}
    )
    add_vcdist: bool = field(
        default=False, metadata={"help": "Whether to use distillation loss."}
    )
    vcdist_weight: float = field(
        default=0.5, metadata={"help": "Weight for distillation loss."}
    )
    vcdist_filter: bool = field(
        default=True, metadata={"help": "Whether to filter teacher."}
    )
    vcdist_stopgrad: bool = field(
        default=True, metadata={"help": "Whether to stop teacher gradient."}
    )
    remove_unused_columns: bool = field(
        default=False, metadata={"help": "Must be set to False so the custom dataloader can work."}
    )
    max_prompt_length: Optional[int] = field(
        default=None, metadata={"help": "Maximum length of the prompt."}
    )
    max_completion_length: Optional[int] = field(
        default=None, metadata={"help": "Maximum length of the completion."}
    )
    max_length: Optional[int] = field(
        default=None, metadata={"help": "Maximum length of the sequence. If None, will use the maximum length of the model."}
    )

    def __post_init__(self):
        """Enable the flags implied by ``vco_type``, then run the parent's post-init."""
        if self.vco_type.startswith('ic-vco'):
            self.use_multi_image = True
            # Anchor and vcdist losses are on by default for ic-vco variants and
            # are opted out of through substrings of the type name.
            if 'noanchor' not in self.vco_type:
                self.add_anchor_loss = True
            if 'novcdist' not in self.vco_type:
                self.add_vcdist = True
            if 'skipsingle' in self.vco_type:
                # Skipping single-image samples also switches vcdist off.
                self.skip_single_image = True
                self.add_vcdist = False
        elif self.vco_type == 's-vco':
            self.use_no_image = True
        elif self.vco_type == 'symmpo':
            self.add_anchor_loss = True
            self.add_margin_consistency_loss = True
        elif self.vco_type == 'mdpo':
            self.add_anchor_loss = True
        elif self.vco_type == 'v-dpo':
            self.use_no_image = True
        super().__post_init__()

@dataclass
class DataCollatorForVisualContrastivePreference(DataCollatorForPreference):
    """Collate tokenized VCO rows (prompts, responses, vision features) into batch tensors."""

    def torch_call(self, examples: list[list[int] | Any | dict[str, Any]]) -> dict[str, Any]:
        """Pad and stack the per-example features into one batch dict.

        Text keys ending in ``input_ids`` are padded (prompts on the left,
        responses on the right) and get a matching ``attention_mask`` entry, plus
        ``label_mask`` / ``token_type_ids`` entries when the examples carry them.
        Pixel values and pixel attention masks are stacked; when any example has
        an ``*_image_sizes`` entry the batch is treated as AnyRes and the tile
        dimension is zero-padded to the largest tile count in the batch first.

        Args:
            examples: Non-empty list of per-row feature dicts.

        Returns:
            Dict mapping feature names to batched tensors.
        """
        output = {}
        first = examples[0]

        def to_tensor(value, dtype):
            # as_tensor accepts lists, numpy arrays and tensors alike and does not
            # emit the copy-construct warning torch.tensor raises for tensor input.
            return torch.as_tensor(value, dtype=dtype)

        # -----------------------------------------------------------------
        # 1. Text sequences (input ids + attention mask + optional companions)
        # -----------------------------------------------------------------
        def collate_text(key_name: str, padding_side: str = "right"):
            """Pad ``key_name`` and its companion sequences if the batch carries that key."""
            if key_name not in first:
                return

            input_ids_list = [to_tensor(ex[key_name], torch.long) for ex in examples]
            # 1 for real tokens, 0 for padding.
            attention_mask_list = [torch.ones_like(ids) for ids in input_ids_list]

            output[key_name] = pad(input_ids_list, padding_value=self.pad_token_id, padding_side=padding_side)

            # e.g. image_1_prompt_input_ids -> image_1_prompt_attention_mask
            attn_mask_key = key_name.replace("input_ids", "attention_mask")
            output[attn_mask_key] = pad(attention_mask_list, padding_value=0, padding_side=padding_side)

            label_mask_key = key_name.replace("input_ids", "label_mask")
            if label_mask_key in first:
                label_masks = [to_tensor(ex[label_mask_key], torch.long) for ex in examples]
                output[label_mask_key] = pad(label_masks, padding_value=0, padding_side=padding_side)

            # Token type ids may be stored without the "prompt_" infix
            # (image_1_prompt_input_ids -> image_1_token_type_ids) or with a plain
            # suffix swap (resp_1_input_ids -> resp_1_token_type_ids); accept both.
            token_type_keys = dict.fromkeys(
                (
                    key_name.replace("prompt_input_ids", "token_type_ids"),
                    key_name.replace("input_ids", "token_type_ids"),
                )
            )
            for token_type_key in token_type_keys:
                if token_type_key in first:
                    token_types = [to_tensor(ex[token_type_key], torch.long) for ex in examples]
                    output[token_type_key] = pad(token_types, padding_value=0, padding_side=padding_side)

        # -----------------------------------------------------------------
        # 2. Prompts (left padding): single-image, no-image and multi-image keys
        # -----------------------------------------------------------------
        collate_text("image_1_prompt_input_ids", padding_side="left")
        collate_text("image_2_prompt_input_ids", padding_side="left")
        collate_text("no_image_prompt_input_ids", padding_side="left")
        collate_text("multi_image_1_prompt_input_ids", padding_side="left")
        collate_text("multi_image_2_prompt_input_ids", padding_side="left")

        # -----------------------------------------------------------------
        # 3. Responses (right padding)
        # -----------------------------------------------------------------
        collate_text("resp_1_input_ids", padding_side="right")
        collate_text("resp_2_input_ids", padding_side="right")

        # -----------------------------------------------------------------
        # 4. Vision features (pixel values and pixel attention masks)
        # -----------------------------------------------------------------
        vision_keys = [
            "image_1_pixel_values",
            "image_2_pixel_values",
            "multi_image_pixel_values",
        ]
        mask_keys = [
            "image_1_pixel_attention_mask",
            "image_2_pixel_attention_mask",
            "multi_image_pixel_attention_mask",
        ]
        size_keys = [
            "image_1_image_sizes",
            "image_2_image_sizes",
            "multi_image_image_sizes",
        ]

        # A populated image_sizes entry on any example marks the batch as AnyRes.
        is_anyres_mode = any(
            sk in ex and ex[sk] is not None for ex in examples for sk in size_keys
        )

        def add_image_dim(tensor, tile_rank):
            """Lift ``(tiles, ...)`` to ``(1, tiles, ...)`` so dim 1 is always the tile axis."""
            return tensor.unsqueeze(0) if tensor.dim() == tile_rank else tensor

        def drop_image_dim(tensor):
            """Undo ``add_image_dim`` so downstream code sees the original rank again."""
            return tensor.squeeze(0) if tensor.shape[0] == 1 else tensor

        def pad_tiles(tensor, target_n):
            """Zero-pad dim 1 of ``(num_images, num_tiles, ...)`` up to ``target_n`` tiles."""
            missing = target_n - tensor.shape[1]
            if missing == 0:
                return tensor
            filler = torch.zeros((tensor.shape[0], missing, *tensor.shape[2:]), dtype=tensor.dtype)
            return torch.cat([tensor, filler], dim=1)

        if is_anyres_mode:
            # Branch A: AnyRes -> tile counts differ per image, so pad before stacking.
            # Convert every pixel tensor once, with the same conversion and lifting
            # that is used for padding, so the tile count is never read off the
            # channel axis of an un-lifted 4-D array.
            pixel_tensors = {
                key: [add_image_dim(to_tensor(ex[key], torch.float32), 4) for ex in examples]
                for key in vision_keys
                if any(key in ex for ex in examples)
            }
            global_max_tiles = max(
                (t.shape[1] for tensors in pixel_tensors.values() for t in tensors),
                default=0,
            )

            for key, tensors in pixel_tensors.items():
                output[key] = torch.stack(
                    [drop_image_dim(pad_tiles(t, global_max_tiles)) for t in tensors], dim=0
                )

            for key in mask_keys:
                if any(key in ex for ex in examples):
                    # Masks are one rank lower than pixel values, e.g. (tiles, H, W),
                    # and need the same lifting so that dim 1 is the tile axis.
                    masks = [add_image_dim(to_tensor(ex[key], torch.long), 3) for ex in examples]
                    output[key] = torch.stack(
                        [drop_image_dim(pad_tiles(m, global_max_tiles)) for m in masks], dim=0
                    )
        else:
            # Branch B: fixed resolution -> shapes already agree, stack directly.
            # This assumes every example holds the same number of images per key.
            for key in vision_keys:
                if any(key in ex for ex in examples):
                    output[key] = torch.stack(
                        [to_tensor(ex[key], torch.float32) for ex in examples], dim=0
                    )
            for key in mask_keys:
                if any(key in ex for ex in examples):
                    output[key] = torch.stack(
                        [to_tensor(ex[key], torch.long) for ex in examples], dim=0
                    )

        # -----------------------------------------------------------------
        # 5. Image sizes (stacked in both modes when present)
        # -----------------------------------------------------------------
        for key in size_keys:
            if key in first:
                output[key] = torch.stack(
                    [to_tensor(ex[key], torch.long) for ex in examples], dim=0
                )

        return output



class VCOTrainer(DPOTrainer):

    def _prepare_dataset(
        self,
        dataset,
        processing_class,
        args,
        dataset_name: str,
    ):
        """Extract prompts, apply the chat template, then attach the row transform.

        The two preparation passes run through ``dataset.map`` inside
        ``PartialState().main_process_first()``. Tokenization is attached with
        ``set_transform`` instead of ``map``, so rows are converted on access and
        no preprocessed cache is written to disk.

        Args:
            dataset: Dataset to prepare.
            processing_class: Passed as the tokenizer to the chat-template step and
                to ``self.process_row``.
            args: Config providing ``dataset_num_proc``, ``tools``,
                ``max_prompt_length``, ``max_completion_length`` and
                ``use_multi_image``.
            dataset_name: Label used in the progress descriptions.

        Returns:
            The prepared dataset with the transform installed.
        """
        # IterableDataset supports neither num_proc nor writer_batch_size.
        map_kwargs = (
            {"num_proc": args.dataset_num_proc, "writer_batch_size": 10}
            if isinstance(dataset, Dataset)
            else {}
        )

        def describe(text):
            # `IterableDataset.map` does not support `desc`
            if isinstance(dataset, Dataset):
                map_kwargs["desc"] = text

        with PartialState().main_process_first():
            # Extract prompt if needed
            describe(f"Extracting prompt in {dataset_name} dataset")
            dataset = dataset.map(maybe_extract_prompt, **map_kwargs)

            # Apply the chat template if needed
            describe(f"Applying chat template to {dataset_name} dataset")
            dataset = dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class, "tools": args.tools},
                **map_kwargs,
            )

            # Tokenize the dataset
            describe(f"Tokenizing {dataset_name} dataset")

            row_transform = partial(
                self.process_row,
                processing_class=processing_class,
                max_prompt_length=args.max_prompt_length,
                max_completion_length=args.max_completion_length,
                add_special_tokens=False,
                use_multi_image=args.use_multi_image,
            )
            dataset.set_transform(row_transform)
        return dataset

    @staticmethod
    def prepare_token_label_mask(
        input_ids: List[int],
        offset_mapping: List[Tuple[int, int]],
        target_spans: List[Tuple[int, int]]
    ) -> List[int]:
        """
        针对单个样本生成 token mask。纯 Python 实现，无 Tensor 依赖。
        
        Args:
            input_ids: List[int], 单个序列的 token ids
            offset_mapping: List[(start, end)], 每个 token 对应的字符起止位置
            target_spans: List[(start, end)], 需要 mask 为 1 的目标字符区间列表
        
        Returns:
            mask: List[int], 与 input_ids 等长，0 或 1
        """
        seq_len = len(input_ids)
        # 初始化全 0 mask
        mask = [0] * seq_len
        
        # 如果没有目标 span，直接返回全 0
        if not target_spans:
            return mask

        # 遍历每个 Token
        for idx, (token_start, token_end) in enumerate(offset_mapping):
            # 1. 跳过特殊 Token
            # 大多数 Tokenizer 将 special tokens (BOS, EOS, PAD) 的 offset 设为 (0, 0)
            # 也有部分 tokenizer (如 Llama) 可能设为 (0, 0) 或其他空区间
            if token_start == token_end:
                continue
                
            # 2. 检查当前 Token 是否与 *任意一个* target span 重叠
            for span_start, span_end in target_spans:
                # 判定重叠 (Loose Overlap): max(A, C) < min(B, D)
                # 只要 Token 包含 span 的一部分，或者是 span 的一部分，都算
                overlap_start = max(token_start, span_start)
                overlap_end = min(token_end, span_end)
                
                if overlap_start < overlap_end:
                    mask[idx] = 1
                    break # 只要命中一个 span，该 token 就标记为 1，跳出内层循环
        return mask

    @staticmethod
    def process_row_single_image(
        features,
        processing_class,
        max_prompt_length = None,
        max_completion_length = None,
        add_special_tokens = True,
    ) -> dict[str, list[int]]:
        """
        Same as `tokenize_row` but for vision models. Please refer to `tokenize_row` for more information.
        """
        processor, tokenizer = processing_class, processing_class.tokenizer  # the processing class is a processor
        image_1_processed_features = processor(images=features['image_1'], text=features['prompt'], add_special_tokens=False)
        image_1_prompt_input_ids = image_1_processed_features["input_ids"][0]
        image_1_pixel_values = image_1_processed_features["pixel_values"][0]

        image_2_processed_features = processor(images=features['image_2'], text=features['prompt'], add_special_tokens=False)
        image_2_prompt_input_ids = image_2_processed_features["input_ids"][0]
        image_2_pixel_values = image_2_processed_features["pixel_values"][0]

        use_no_image = 'prompt_no_image' in features
        no_image_prompt_processed_features = processor(images=None, text=features['prompt_no_image'], add_special_tokens=False) if use_no_image else None
        no_image_prompt_input_ids = no_image_prompt_processed_features["input_ids"][0] if no_image_prompt_processed_features is not None else None

        use_resp_token_mask = 'resp_1_target_spans' in features and 'resp_2_target_spans' in features
        resp_1_features = tokenizer(features['resp_1'], add_special_tokens=False, return_offsets_mapping=use_resp_token_mask)
        resp_2_features = tokenizer(features['resp_2'], add_special_tokens=False, return_offsets_mapping=use_resp_token_mask)
        resp_1_input_ids = resp_1_features['input_ids']
        resp_2_input_ids = resp_2_features['input_ids']

        if use_resp_token_mask:
            resp_1_label_mask = VCOTrainer.prepare_token_label_mask(resp_1_input_ids, resp_1_features['offset_mapping'], features['resp_1_target_spans'])
            resp_2_label_mask = VCOTrainer.prepare_token_label_mask(resp_2_input_ids, resp_2_features['offset_mapping'], features['resp_2_target_spans'])
        else:
            resp_1_label_mask, resp_2_label_mask = None, None


        # Add special tokens (typically for encoder-decoder models)
        if add_special_tokens:
            if tokenizer.bos_token_id is not None:
                image_1_prompt_input_ids = [tokenizer.bos_token_id] + image_1_prompt_input_ids
                image_2_prompt_input_ids = [tokenizer.bos_token_id] + image_2_prompt_input_ids
                if use_no_image:
                    no_image_prompt_input_ids = [tokenizer.bos_token_id] + no_image_prompt_input_ids
            if tokenizer.eos_token_id is not None:
                image_1_prompt_input_ids = image_1_prompt_input_ids + [tokenizer.eos_token_id]
                image_2_prompt_input_ids = image_2_prompt_input_ids + [tokenizer.eos_token_id]
                if use_no_image:
                    no_image_prompt_input_ids = no_image_prompt_input_ids + [tokenizer.eos_token_id]

        resp_1_input_ids = resp_1_input_ids + [tokenizer.eos_token_id]
        resp_2_input_ids = resp_2_input_ids + [tokenizer.eos_token_id]
        if use_resp_token_mask:
            resp_1_label_mask += [0]
            resp_2_label_mask += [0]


        # Truncate prompt and completion sequences
        # if max_prompt_length is not None:
        #     image_1_prompt_input_ids = image_1_prompt_input_ids[-max_prompt_length:]
        #     image_2_prompt_input_ids = image_2_prompt_input_ids[-max_prompt_length:]
        #     if no_image_prompt_input_ids is not None:
        #         no_image_prompt_input_ids = no_image_prompt_input_ids[-max_prompt_length:]
                
        if max_completion_length is not None:
            resp_1_input_ids = resp_1_input_ids[:max_completion_length]
            resp_2_input_ids = resp_2_input_ids[:max_completion_length]
            if use_resp_token_mask:
                resp_1_label_mask = resp_1_label_mask[:max_completion_length]
                resp_2_label_mask = resp_2_label_mask[:max_completion_length]

        output = {
            "image_1_prompt_input_ids": image_1_prompt_input_ids,
            "image_2_prompt_input_ids": image_2_prompt_input_ids,
            "image_1_pixel_values": image_1_pixel_values,
            "image_2_pixel_values": image_2_pixel_values,
            "resp_1_input_ids": resp_1_input_ids,
            "resp_2_input_ids": resp_2_input_ids,
        }
        if use_no_image:
            output["no_image_prompt_input_ids"] = no_image_prompt_input_ids
        
        if use_resp_token_mask:
            output['resp_1_label_mask'] = resp_1_label_mask
            output['resp_2_label_mask'] = resp_2_label_mask

        # if "pixel_attention_mask" in image_1_processed_features:
        #     output["image_1_pixel_attention_mask"] = image_1_processed_features["pixel_attention_mask"][0]
        #     output["image_2_pixel_attention_mask"] = image_2_processed_features["pixel_attention_mask"][0]




        if "image_sizes" in image_1_processed_features: # AnyRes Mode
            output["image_1_image_sizes"] = image_1_processed_features["image_sizes"][0]
            output["image_2_image_sizes"] = image_2_processed_features["image_sizes"][0]

            # =======================================================
            # [FIX] 强制提取或生成 Pixel Attention Mask
            # =======================================================
            def ensure_pixel_mask(processed_features, key_prefix):
                # 1. 尝试直接从 processor 输出中获取
                if "pixel_attention_mask" in processed_features:
                    return processed_features["pixel_attention_mask"][0]
                
                # 2. 如果没有 (Processor 没返回)，则手动生成全 1 Mask
                # Pixel Values Shape: (Tiles, C, H, W)
                # Target Mask Shape:  (Tiles, H, W)
                pv = processed_features["pixel_values"][0]
                if not isinstance(pv, torch.Tensor):
                    pv = torch.tensor(pv)
                
                # 生成全 1 mask (表示所有像素都有效)
                # shape 取 pv 的 (Tiles, H, W) -> 也就是 (dim 0, dim 2, dim 3)
                mask_shape = (pv.shape[0], pv.shape[2], pv.shape[3])
                return torch.ones(mask_shape, dtype=torch.long)

            # 注入到 output 字典中
            output["image_1_pixel_attention_mask"] = ensure_pixel_mask(image_1_processed_features, "image_1")
            output["image_2_pixel_attention_mask"] = ensure_pixel_mask(image_2_processed_features, "image_2")

        if "token_type_ids" in image_1_processed_features:
            output["image_1_token_type_ids"] = image_1_processed_features["token_type_ids"][0]
            output["image_2_token_type_ids"] = image_2_processed_features["token_type_ids"][0]
        if no_image_prompt_processed_features is not None and 'token_type_ids' in no_image_prompt_processed_features:
            output['no_image_token_type_ids'] = no_image_prompt_processed_features["token_type_ids"][0]

        if "image_grid_thw" in image_1_processed_features:
            output["image_1_image_grid_thw"] = image_1_processed_features["image_grid_thw"].squeeze(0)
            output["image_2_image_grid_thw"] = image_2_processed_features["image_grid_thw"].squeeze(0)

        return output
    
    @staticmethod
    def process_row_multi_image(
        features,
        processing_class,
        max_prompt_length = None,
        add_special_tokens = True,
    ) -> dict[str, Any]:
        """
        Encode the two multi-image prompts of a single example.

        Both prompts are passed through the processor together with the same
        image pair ``[image_1, image_2]``. Visual features (pixel values, image
        sizes, grid, pixel mask) are taken from the first processor call only.

        Args:
            features: One example providing ``image_1``, ``image_2``,
                ``prompt_multi_image_1`` and ``prompt_multi_image_2``.
            processing_class: Callable processor exposing a ``tokenizer``
                attribute.
            max_prompt_length: Currently unused; prompt truncation is disabled
                in this method.
            add_special_tokens: When true, the tokenizer's BOS id is prepended
                and its EOS id appended to both prompt id sequences (each only
                if the tokenizer defines it).

        Returns:
            A dict that always contains ``multi_image_1_prompt_input_ids``,
            ``multi_image_2_prompt_input_ids`` and ``multi_image_pixel_values``.
            Depending on what the processor returned it may also contain
            ``multi_image_image_sizes``, ``multi_image_pixel_attention_mask``,
            ``multi_image_token_type_ids``, ``multi_image_grid_thw`` and
            ``multi_image_2_token_type_ids``.

        Raises:
            ValueError: If a pixel mask has to be synthesised but the pixel
                values have fewer than three dimensions.
        """
        processor, tokenizer = processing_class, processing_class.tokenizer

        def _pixel_values_as_tensor(pv):
            """Convert processor pixel values (tensor / ndarray / list) to a tensor."""
            if isinstance(pv, torch.Tensor):
                return pv
            if isinstance(pv, np.ndarray):
                return torch.from_numpy(pv)
            if isinstance(pv, list) and pv:
                head = pv[0]
                if isinstance(head, torch.Tensor):
                    return torch.stack(pv)
                if isinstance(head, np.ndarray):
                    # Stack in numpy first, then convert once.
                    return torch.from_numpy(np.stack(pv))
            # Nested lists, empty lists and anything else.
            return torch.tensor(pv)

        def _pixel_attention_mask(proc_feats, pv):
            """Return the processor's pixel mask, or an all-ones mask shaped like ``pv`` minus its channel dim."""
            if "pixel_attention_mask" in proc_feats:
                return proc_feats["pixel_attention_mask"][0]
            mask_shape = list(_pixel_values_as_tensor(pv).shape)
            if len(mask_shape) < 3:
                raise ValueError(
                    f"Cannot derive a pixel attention mask from pixel values of shape {mask_shape}"
                )
            # The channel axis is taken to be third from the end.
            del mask_shape[-3]
            return torch.ones(mask_shape, dtype=torch.long)

        # Both prompts see the same two images, in the same order.
        combined_images = [features["image_1"], features["image_2"]]

        processed_features_1 = processor(
            images=combined_images,
            text=features["prompt_multi_image_1"],
            add_special_tokens=False,
        )
        processed_features_2 = processor(
            images=combined_images,
            text=features["prompt_multi_image_2"],
            add_special_tokens=False,
        )

        multi_image_1_prompt_input_ids = processed_features_1["input_ids"][0]
        multi_image_2_prompt_input_ids = processed_features_2["input_ids"][0]
        multi_image_pixel_values = processed_features_1["pixel_values"]

        # Special tokens have to be added BEFORE the ids are stored in `output`:
        # list concatenation builds new lists, so doing it afterwards would leave
        # the stored ids untouched.
        # NOTE(review): token_type_ids below are not extended for the added
        # tokens — confirm this matches process_row_single_image.
        if add_special_tokens:
            if tokenizer.bos_token_id is not None:
                multi_image_1_prompt_input_ids = [tokenizer.bos_token_id] + list(multi_image_1_prompt_input_ids)
                multi_image_2_prompt_input_ids = [tokenizer.bos_token_id] + list(multi_image_2_prompt_input_ids)
            if tokenizer.eos_token_id is not None:
                multi_image_1_prompt_input_ids = list(multi_image_1_prompt_input_ids) + [tokenizer.eos_token_id]
                multi_image_2_prompt_input_ids = list(multi_image_2_prompt_input_ids) + [tokenizer.eos_token_id]

        output = {
            "multi_image_1_prompt_input_ids": multi_image_1_prompt_input_ids,
            "multi_image_2_prompt_input_ids": multi_image_2_prompt_input_ids,
            "multi_image_pixel_values": multi_image_pixel_values,
        }

        # Image sizes and the pixel mask are only emitted when the processor
        # reports `image_sizes`.
        if "image_sizes" in processed_features_1:
            output["multi_image_image_sizes"] = processed_features_1["image_sizes"][0]
            output["multi_image_pixel_attention_mask"] = _pixel_attention_mask(
                processed_features_1, multi_image_pixel_values
            )

        if "token_type_ids" in processed_features_1:
            output["multi_image_token_type_ids"] = processed_features_1["token_type_ids"][0]

        if "image_grid_thw" in processed_features_1:
            output["multi_image_grid_thw"] = processed_features_1["image_grid_thw"]

        # Only the token type ids are kept from the second prompt; its visual
        # features are not stored.
        if "token_type_ids" in processed_features_2:
            output["multi_image_2_token_type_ids"] = processed_features_2["token_type_ids"][0]

        return output


    @staticmethod
    def process_row(
        features,
        processing_class,
        max_prompt_length = None,
        max_completion_length = None,
        add_special_tokens: bool = True,
        use_multi_image: bool = False,

    ) -> dict[str, list[int]]:
        """
        Encode one example, or a column-wise batch of examples.

        The input is treated as a batch when ``features['prompt']`` is a list.
        In that case every column is indexed per sample, each sample is encoded
        on its own, and the per-sample results are regrouped into
        ``{key: [value_of_sample_0, value_of_sample_1, ...]}`` using the keys of
        the first sample. An empty batch yields ``{}``.

        Each sample is encoded with ``process_row_single_image``; when
        ``use_multi_image`` is set, the result of ``process_row_multi_image`` is
        merged on top of it.
        """

        def _encode(sample):
            # Single-image features first, multi-image features layered on top.
            encoded = VCOTrainer.process_row_single_image(
                features=sample,
                processing_class=processing_class,
                max_prompt_length=max_prompt_length,
                max_completion_length=max_completion_length,
                add_special_tokens=add_special_tokens,
            )
            if use_multi_image:
                encoded.update(
                    VCOTrainer.process_row_multi_image(
                        features=sample,
                        processing_class=processing_class,
                        max_prompt_length=max_prompt_length,
                        add_special_tokens=add_special_tokens,
                    )
                )
            return encoded

        # A list-valued 'prompt' column marks a batch.
        if not isinstance(features.get('prompt'), list):
            return _encode(features)

        rows = [
            _encode({name: column[idx] for name, column in features.items()})
            for idx in range(len(features['prompt']))
        ]
        if not rows:
            return {}

        # Regroup per-sample dicts into per-key lists.
        return {key: [row[key] for row in rows] for key in rows[0].keys()}

    
    
    def get_batch_loss_metrics(
        self,
        model: "PreTrainedModel | nn.Module",
        batch: "dict[str, list | torch.LongTensor]",
        train_eval: Literal["train", "eval"] = "train",
    ) -> tuple[torch.Tensor, dict[str, float]]:
        """
        Compute the combined loss and metrics for one batch.

        The single-image branch runs unless ``args.skip_single_image`` is set,
        and the multi-image branch runs when ``args.use_multi_image`` is set.
        Which ``run_*`` method is used in each branch depends on
        ``args.vco_type``. When ``args.add_vcdist`` is set and both branches
        produced log-prob packs, a soft-label loss (single = student,
        multi = teacher) weighted by ``args.vcdist_weight`` is added to the
        single-image loss. The final loss is the mean of the branch losses that
        were produced, or a zero tensor on ``model.device`` if none were.

        Args:
            model: Policy model, forwarded to the ``run_*`` methods.
            batch: Collated batch, forwarded to the ``run_*`` methods.
            train_eval: Prefix used for metric keys.

        Returns:
            ``(final_loss, metrics)``; ``metrics`` always contains
            ``f"{train_eval}_loss_combined"``.
        """
        metrics = {}
        batch_losses = []
        # Every name read further down is initialised here so that a skipped or
        # empty branch can never leave it unbound.
        single_loss, multi_loss = None, None
        single_logps, multi_logps = None, None
        vco_type = self.args.vco_type

        # --- Module 1: single-image tasks (image 1 / image 2) ---
        if not self.args.skip_single_image:
            if vco_type == 'v-dpo':
                single_loss = self.run_single_image_vdpo(model, batch, metrics, train_eval)
            elif vco_type == 's-vco':
                single_loss = self.run_single_image_mdpo(model, batch, metrics, train_eval)
            else:
                # The runner may return a bare None when the batch lacks the
                # single-image keys; only unpack an actual (loss, logps) pair.
                single_result = self.run_single_image_dpo(model, batch, metrics, train_eval)
                if single_result is not None:
                    single_loss, single_logps = single_result

            if vco_type == 'mdpo':
                mdpo_loss = self.run_single_image_mdpo(model, batch, metrics, train_eval)
                if mdpo_loss is not None:
                    single_loss = mdpo_loss if single_loss is None else single_loss + mdpo_loss

        # --- Module 2: multi-image tasks (only if enabled) ---
        if self.args.use_multi_image:
            if vco_type == 'mdpo':
                multi_loss = self.run_multi_image_mdpo(model, batch, metrics, train_eval)
            else:
                multi_result = self.run_multi_image_dpo(model, batch, metrics, train_eval)
                if multi_result is not None:
                    multi_loss, multi_logps = multi_result

        # === 1. Single-task branch ===
        if single_loss is not None:
            single_total_loss = single_loss

            # The vcdist term needs log-probs from BOTH branches; it is skipped
            # when either pack is unavailable.
            if self.args.add_vcdist and single_logps is not None and multi_logps is not None:
                vcdist_loss, vcdist_metrics = self.soft_label_dpo_loss(
                    student_chosen_logps=single_logps["chosen"],
                    student_rejected_logps=single_logps["rejected"],
                    student_ref_chosen_logps=single_logps["ref_chosen"],
                    student_ref_rejected_logps=single_logps["ref_rejected"],

                    teacher_chosen_logps=multi_logps["chosen"],
                    teacher_rejected_logps=multi_logps["rejected"],
                    teacher_ref_chosen_logps=multi_logps["ref_chosen"],
                    teacher_ref_rejected_logps=multi_logps["ref_rejected"],
                    beta=self.args.beta
                )

                keys = ["vcdist_valid_kl_sum", "vcdist_valid_count", "vcdist_global_kl_sum", "vcdist_global_count"]
                # Stack the four statistics so one all-reduce sums them across ranks.
                # Assumes every value is a tensor on the same device — TODO confirm.
                local_stats = torch.stack([vcdist_metrics[k] for k in keys])

                if dist.is_initialized():
                    dist.all_reduce(local_stats, op=dist.ReduceOp.SUM)

                global_stats = dict(zip(keys, local_stats))

                valid_count = global_stats["vcdist_valid_count"].item()
                global_count = global_stats["vcdist_global_count"].item()

                # Means are guarded against empty counts.
                if valid_count > 0:
                    metrics[f"{train_eval}_vcdist_valid_kl"] = (global_stats["vcdist_valid_kl_sum"] / valid_count).item()
                else:
                    metrics[f"{train_eval}_vcdist_valid_kl"] = 0.0

                if global_count > 0:
                    metrics[f"{train_eval}_vcdist_global_kl"] = (global_stats["vcdist_global_kl_sum"] / global_count).item()
                else:
                    metrics[f"{train_eval}_vcdist_global_kl"] = 0.0

                # Sample counts, logged for reference.
                metrics[f"{train_eval}_vcdist_valid_count"] = valid_count
                metrics[f"{train_eval}_vcdist_global_count"] = global_count

                single_total_loss = single_total_loss + self.args.vcdist_weight * vcdist_loss

            batch_losses.append(single_total_loss)

        # === 2. Multi-task branch ===
        if multi_loss is not None:
            batch_losses.append(multi_loss)

        # === 3. Final aggregation ===
        if batch_losses:
            # Mean over the branch totals (single incl. vcdist, and multi).
            final_loss = torch.stack(batch_losses).mean()
        else:
            final_loss = torch.tensor(0.0, device=model.device, requires_grad=True)

        metrics[f"{train_eval}_loss_combined"] = final_loss.item()

        return final_loss, metrics

    def forward_with_pre_concatenated_batch(
        self, model: nn.Module, concatenated_batch: dict[str, torch.Tensor]
    ) -> dict[str, torch.Tensor]:
        """
        Forward an already-concatenated batch and return per-sequence log-probs.

        Unlike `concatenated_forward`, no `concatenated_inputs` step is run:
        ``concatenated_batch`` is used as given, with the chosen rows in the
        first half of the batch dimension and the rejected rows in the second.

        Args:
            model: Model called as ``model(input_ids, **kwargs)``; its output
                must expose ``.logits``.
            concatenated_batch: Must contain ``prompt_input_ids``,
                ``prompt_attention_mask``, ``completion_input_ids`` and
                ``completion_attention_mask``. Optional keys that are forwarded
                to the model: ``pixel_values``, ``pixel_attention_mask``,
                ``image_sizes``, ``image_grid_thw``, ``token_type_ids``.
                Optional ``completion_label_mask`` restricts which completion
                tokens count towards the log-prob sum.

        Returns:
            ``chosen_logps`` / ``rejected_logps`` (summed token log-probs of the
            first / second half of the batch) and ``logits`` (already shifted,
            i.e. without the last position).
        """
        # 1. Model arguments.
        model_kwargs = {"use_cache": False}
        if self.aux_loss_enabled:
            model_kwargs["output_router_logits"] = True

        # Visual inputs are passed through untouched.
        for key in ("pixel_values", "pixel_attention_mask", "image_sizes", "image_grid_thw"):
            if key in concatenated_batch:
                model_kwargs[key] = concatenated_batch[key]

        # 2. Build the full sequence: prompt on the left, completion on the right.
        prompt_attention_mask = concatenated_batch["prompt_attention_mask"]
        completion_attention_mask = concatenated_batch["completion_attention_mask"]

        input_ids = torch.cat(
            (concatenated_batch["prompt_input_ids"], concatenated_batch["completion_input_ids"]), dim=1
        )
        attention_mask = torch.cat((prompt_attention_mask, completion_attention_mask), dim=1)
        prompt_zeros = torch.zeros_like(prompt_attention_mask)

        if "token_type_ids" in concatenated_batch:
            token_type_ids = concatenated_batch["token_type_ids"]
            # The ids may cover only the prompt; extend them with zeros so they
            # line up with `input_ids` before any truncation is applied.
            missing = input_ids.size(1) - token_type_ids.size(1)
            if missing > 0:
                token_type_ids = torch.cat(
                    (token_type_ids, token_type_ids.new_zeros((token_type_ids.size(0), missing))), dim=1
                )
            model_kwargs["token_type_ids"] = token_type_ids

        # 3. Loss mask: 0 over the prompt, completion attention mask afterwards.
        loss_mask = torch.cat((prompt_zeros, completion_attention_mask), dim=1)

        # Optional token-level mask, aligned to the full sequence the same way.
        token_level_mask = None
        if "completion_label_mask" in concatenated_batch:
            token_level_mask = torch.cat(
                (prompt_zeros, concatenated_batch["completion_label_mask"]), dim=1
            )

        # 4. Keep only the last `max_length` positions when the sequence is longer.
        if self.max_length is not None and self.max_length < attention_mask.size(1):
            keep = slice(-self.max_length, None)
            input_ids = input_ids[:, keep]
            attention_mask = attention_mask[:, keep]
            loss_mask = loss_mask[:, keep]
            if token_level_mask is not None:
                token_level_mask = token_level_mask[:, keep]
            if "token_type_ids" in model_kwargs:
                model_kwargs["token_type_ids"] = model_kwargs["token_type_ids"][:, keep]

        # 5. Forward pass.
        model_kwargs["attention_mask"] = attention_mask
        outputs = model(input_ids, **model_kwargs)

        # 6. Shift so that logits[t] is scored against the token at t + 1.
        logits = outputs.logits[:, :-1, :]
        labels = input_ids[:, 1:].clone()
        loss_mask = loss_mask[:, 1:]

        if token_level_mask is not None:
            # 0/1 masks: the product keeps positions enabled in both.
            loss_mask = loss_mask * token_level_mask[:, 1:]

        ignored = ~loss_mask.bool()
        # Point ignored positions at index 0 so gather never sees an invalid id.
        labels[ignored] = 0

        per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2)
        # Ignored positions contribute nothing to the sequence sum.
        per_token_logps = per_token_logps.masked_fill(ignored, 0)

        all_logps = per_token_logps.sum(-1)

        # 7. First half of the rows is "chosen", second half "rejected".
        batch_size = input_ids.shape[0] // 2
        return {
            "chosen_logps": all_logps[:batch_size],
            "rejected_logps": all_logps[batch_size:],
            "logits": logits,
        }
    
    

    def run_single_image_dpo(
        self, 
        model, 
        batch, 
        metrics, 
        train_eval
    ) -> torch.Tensor | None:
        """
        Executes Image 1 and Image 2 tasks in parallel using `forward_with_pre_concatenated_batch`.
        Supports:
        1. Unified DataCollator (Auto-Padding).
        2. Token-Level Label Masking (completion_label_mask).
        """
        # 基础校验
        if "image_1_prompt_input_ids" not in batch:
            return None

        pad_token_id = self.processing_class.tokenizer.pad_token_id or 0
        device = model.device

        # =================================================================
        # 1. 准备基础文本数据 (Lists)
        # =================================================================
        # 提取 Prompt
        prompts = [batch["image_1_prompt_input_ids"], batch["image_2_prompt_input_ids"]]
        prompt_masks = [batch["image_1_prompt_attention_mask"], batch["image_2_prompt_attention_mask"]]
        
        # 提取 Chosen (Response 1)
        chosens = [batch["resp_1_input_ids"], batch["resp_2_input_ids"]]
        chosen_masks = [batch["resp_1_attention_mask"], batch["resp_2_attention_mask"]]
        
        # 提取 Rejected (Response 2)
        rejecteds = [batch["resp_2_input_ids"], batch["resp_1_input_ids"]]
        rejected_masks = [batch["resp_2_attention_mask"], batch["resp_1_attention_mask"]]

        # Token Type IDs
        token_type_ids_list = []
        if "image_1_prompt_token_type_ids" in batch:
            tt1 = batch["image_1_prompt_token_type_ids"]
            tt2 = batch["image_2_prompt_token_type_ids"]
            token_type_ids_list = [tt1, tt2]

        # =================================================================
        # 2. 处理 Label Mask (新增逻辑)
        # =================================================================
        # 我们需要构造一个 mask 列表，长度与 batch 一致
        # Chosen 部分使用数据集提供的 mask，Rejected 部分通常全为 1 (不额外 mask)
        chosen_label_masks, rejected_label_masks = [], []
        has_label_mask = False
        
        # 尝试获取 Image 1 和 Image 2 的 Label Mask
        if "resp_1_label_mask" in batch:
            chosen_label_masks.append(batch["resp_1_label_mask"])
            rejected_label_masks.append(batch["resp_2_label_mask"])
            has_label_mask = True
        else:
            # 如果没有，用 None 占位 (后续处理) 或者全 1
            # 为了 pad_and_cat 方便，我们生成全 1
            chosen_label_masks.append(torch.ones_like(batch["resp_1_input_ids"]))
            rejected_label_masks.append(torch.ones_like(batch["resp_2_input_ids"]))

        if "resp_2_label_mask" in batch:
            chosen_label_masks.append(batch["resp_2_label_mask"])
            rejected_label_masks.append(batch["resp_1_label_mask"])
            has_label_mask = True
        else:
            chosen_label_masks.append(torch.ones_like(batch["resp_2_input_ids"]))
            rejected_label_masks.append(torch.ones_like(batch["resp_1_input_ids"]))

        # 如果完全没有 Label Mask，则不仅不需要构造，连 key 都不需要传
        final_label_masks = None
        if has_label_mask:
            # 拼接: [Chosen1, Chosen2, Rejected1, Rejected2]
            all_label_masks_list = chosen_label_masks + rejected_label_masks
            final_label_masks = self.pad_and_cat(all_label_masks_list, padding_side="right", padding_value=0, device=device)

        # =================================================================
        # 3. 处理视觉特征 (Pixel Values)
        # =================================================================
        p1 = batch["image_1_pixel_values"]
        p2 = batch["image_2_pixel_values"]

        # 1. 展平 Pixel Values
        def flatten_vision_input(tensor):
            if tensor.dim() == 6: return tensor.flatten(0, 2)
            elif tensor.dim() == 5: return tensor.flatten(0, 1)
            return tensor
        
        p1 = flatten_vision_input(p1)
        p2 = flatten_vision_input(p2)

        combined_pv = torch.cat([p1, p2], dim=0)
        final_pixel_values = torch.cat([combined_pv, combined_pv], dim=0)

        if 'image_1_pixel_attention_mask' in batch.keys():
            # 2. 准备 Mask
            m1 = batch["image_1_pixel_attention_mask"]
            m2 = batch["image_2_pixel_attention_mask"]

            def flatten_mask_input(tensor):
                if tensor.dim() == 5: return tensor.flatten(0, 2)
                elif tensor.dim() == 4: return tensor.flatten(0, 1)
                return tensor

            m1 = flatten_mask_input(m1)
            m2 = flatten_mask_input(m2)

            combined_pm = torch.cat([m1, m2], dim=0)
            final_pixel_mask = torch.cat([combined_pm, combined_pm], dim=0)

            # 3. 执行 Unpadding (过滤)
            # 计算有效性: 只要 HxW 平面上有一个点是 1，该 Tile 就有效
            valid_tile_mask = final_pixel_mask.sum(dim=(1, 2)) > 0
            
            # 维度对齐检查
            if final_pixel_values.shape[0] != valid_tile_mask.shape[0]:
                raise RuntimeError(f"Shape Mismatch! PV={final_pixel_values.shape[0]}, Mask={valid_tile_mask.shape[0]}")

            # 应用过滤
            # [Visual Guide]
            # Before: [Tile1(Valid), Tile2(Valid), Tile3(Pad), Tile4(Pad)] -> Shape 4
            # Mask:   [True,         True,         False,      False      ]
            # After:  [Tile1,        Tile2]                                 -> Shape 2
            
            final_pixel_values = final_pixel_values[valid_tile_mask]




        # E. 处理 Image Sizes
        final_image_sizes = None
        if "image_1_image_sizes" in batch:
            s1 = batch["image_1_image_sizes"]
            s2 = batch["image_2_image_sizes"]
            combined_sizes = torch.cat([s1, s2], dim=0)
            if combined_sizes.shape[1] == 1: combined_sizes = combined_sizes.squeeze(1)
            final_image_sizes = torch.cat([combined_sizes, combined_sizes], dim=0)

        # =================================================================
        # 4. 构造 Concatenated Batch (4*B)
        # =================================================================
        # 拼接顺序: [Chosen, Rejected]
        all_prompts = prompts + prompts
        all_prompt_masks = prompt_masks + prompt_masks
        
        all_resps = chosens + rejecteds
        all_resp_masks = chosen_masks + rejected_masks

        concatenated_batch = {
            "prompt_input_ids": self.pad_and_cat(all_prompts, padding_side="left", padding_value=pad_token_id, device=device),
            "prompt_attention_mask": self.pad_and_cat(all_prompt_masks, padding_side="left", padding_value=0, device=device),
            
            "completion_input_ids": self.pad_and_cat(all_resps, padding_side="right", padding_value=pad_token_id, device=device),
            "completion_attention_mask": self.pad_and_cat(all_resp_masks, padding_side="right", padding_value=0, device=device),
            
            "pixel_values": final_pixel_values,
        }

        # 注入可选参数
        if final_label_masks is not None:
            concatenated_batch["completion_label_mask"] = final_label_masks
        # if final_pixel_mask is not None:
        #     concatenated_batch["pixel_attention_mask"] = final_pixel_mask
        if final_image_sizes is not None:
            concatenated_batch["image_sizes"] = final_image_sizes
        
        if token_type_ids_list:
            # [TT1, TT2, TT1, TT2]
            all_tt = token_type_ids_list + token_type_ids_list
            concatenated_batch["token_type_ids"] = self.pad_and_cat(all_tt, padding_side="left", padding_value=pad_token_id, device=device)


        # =================================================================
        # 5. 前向传播 (调用新写的 forward)
        # =================================================================
        # 注意：这里不需要再传 model_kwargs，因为 forward_with_pre_concatenated_batch 会自己处理
        policy_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)

        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
            else:
                ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, concatenated_batch)

        # =================================================================
        # 6. Loss 计算
        # =================================================================
        all_losses, all_chosen_rewards, all_rejected_rewards = self.dpo_loss(
            policy_output["chosen_logps"], policy_output["rejected_logps"],
            ref_output["chosen_logps"], ref_output["rejected_logps"]
        )


        self.log_metrics(
            metrics, train_eval, "single", all_losses, all_chosen_rewards, all_rejected_rewards
        )

        total_loss = all_losses.mean()

        # SFT Loss
        if self.args.add_sft_loss:
            # 如果有 label mask，SFT Loss 的分母应该只计算 mask 为 1 的部分
            # 我们从 concatenated_batch 中取前一半 (Chosen) 的 Label Mask 或 Attention Mask
            
            if "completion_label_mask" in concatenated_batch:
                # 使用 label mask (Intersection of Attn Mask and Label Mask)
                # 注意：pad_and_cat 后的 mask 维度是 (4B, Seq)，我们需要前 2B
                chosen_len = concatenated_batch["completion_label_mask"].shape[0] // 2
                valid_mask = concatenated_batch["completion_label_mask"][:chosen_len] * concatenated_batch["completion_attention_mask"][:chosen_len]
            else:
                chosen_len = concatenated_batch["completion_attention_mask"].shape[0] // 2
                valid_mask = concatenated_batch["completion_attention_mask"][:chosen_len]

            sft_losses = self.sft_loss(
                policy_output["chosen_logps"], valid_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_single_sft_loss"] = sft_loss.item()
            total_loss += self.args.sft_loss_weight * sft_loss


        if self.args.add_hinge_sft_loss:
            # 同样使用 valid_mask
            if "completion_label_mask" in concatenated_batch:
                chosen_len = concatenated_batch["completion_label_mask"].shape[0] // 2
                valid_mask = concatenated_batch["completion_label_mask"][:chosen_len] * concatenated_batch["completion_attention_mask"][:chosen_len]
            else:
                chosen_len = concatenated_batch["completion_attention_mask"].shape[0] // 2
                valid_mask = concatenated_batch["completion_attention_mask"][:chosen_len]

            sft_losses = self.hinge_sft_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], valid_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_single_hinge_loss"] = sft_loss.item()
            total_loss += self.args.hinge_sft_loss_weight * sft_loss

        if self.args.add_anchor_loss:
            anchor_losses = self.anchor_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], self.args.anchor_delta
            )
            anchor_loss = anchor_losses.mean()
            metrics[f"{train_eval}_single_anchor_loss"] = anchor_loss.item()
            total_loss += self.args.anchor_weight * anchor_loss

        margin = None
        if self.args.add_margin_consistency_loss:
            mc_losses, margin = self.margin_consistency_loss(
                policy_output["chosen_logps"], 
                policy_output["rejected_logps"],
                ref_output["chosen_logps"], 
                ref_output["rejected_logps"]
            )
            mc_loss = mc_losses.mean()
            metrics[f"{train_eval}_single_margin_loss"] = mc_loss.item()
            total_loss += self.args.margin_consistency_weight * mc_loss

        logps_pack = {
            "chosen": policy_output["chosen_logps"],
            "rejected": policy_output["rejected_logps"],
            "ref_chosen": ref_output["chosen_logps"],
            "ref_rejected": ref_output["rejected_logps"]
        }


        return total_loss, logps_pack



    def run_multi_image_dpo(
        self, 
        model, 
        batch, 
        metrics, 
        train_eval
    ) -> torch.Tensor | None:
        """
        Executes Multi 1 and Multi 2 tasks in PARALLEL.
        FIXED: 
        1. Correctly flattens 6D pixel_values (OneVision).
        2. Manually concatenates batch to prevent TRL from dropping image_sizes.
        3. Passes correct masks to SFT loss.
        4. [NEW] handles 'token_mask' for phrase-level alignment.
        """
        if "multi_image_1_prompt_input_ids" not in batch:
            return None

        assert "multi_image_2_prompt_input_ids" in batch, "Multi Image Prompt 2 missing"
        
        pad_token_id = self.processing_class.tokenizer.pad_token_id or 0

        # =================================================================
        # 1. 准备基础数据 (List)
        # =================================================================
        prompts = [batch["multi_image_1_prompt_input_ids"], batch["multi_image_2_prompt_input_ids"]]
        prompt_masks = [batch["multi_image_1_prompt_attention_mask"], batch["multi_image_2_prompt_attention_mask"]]
        
        chosens = [batch["resp_1_input_ids"], batch["resp_2_input_ids"]]
        chosen_masks = [batch["resp_1_attention_mask"], batch["resp_2_attention_mask"]]
        
        rejecteds = [batch["resp_2_input_ids"], batch["resp_1_input_ids"]]
        rejected_masks = [batch["resp_2_attention_mask"], batch["resp_1_attention_mask"]]

        # [NEW] 准备 Token Masks (用于 Phrase-level DPO)
        # 对应关系与 chosens/rejecteds 保持严格一致
        chosen_token_masks = None
        rejected_token_masks = None
        
        if "resp_1_token_mask" in batch and "resp_2_token_mask" in batch:
            # Task 1 Chosen = Resp 1, Task 2 Chosen = Resp 2
            chosen_token_masks = [batch["resp_1_token_mask"], batch["resp_2_token_mask"]]
            # Task 1 Rejected = Resp 2, Task 2 Rejected = Resp 1
            rejected_token_masks = [batch["resp_2_token_mask"], batch["resp_1_token_mask"]]

        # =================================================================
        # 2. 准备 & 修正视觉特征 (Pixel Values)
        # =================================================================
        # Task1 和 Task2 共享图片，复制一份
        pixel_values_list = [batch["multi_image_pixel_values"], batch["multi_image_pixel_values"]]
        raw_pv = torch.cat(pixel_values_list, dim=0) # Shape: (2B, Num_Images, ...)
        
        # 同时支持 6D (OneVision) 和 5D (Standard) 的 Flatten
        if raw_pv.dim() == 6:
            # (2B, N, Patches, C, H, W) -> (2B*N, Patches, C, H, W)
            final_pv = raw_pv.flatten(0, 2)
        elif raw_pv.dim() == 5:
            # (2B, N, C, H, W) -> (2B*N, C, H, W)
            final_pv = raw_pv.flatten(0, 1)
        else:
            final_pv = raw_pv


        # =================================================================
        # 3. 准备视觉元数据 (Sizes, Grids)
        # =================================================================
        # Image Sizes
        final_sizes = None
        if "multi_image_image_sizes" in batch:
            sizes_list = [batch["multi_image_image_sizes"], batch["multi_image_image_sizes"]]
            raw_sizes = torch.cat(sizes_list, dim=0) # (2B, N, 2)
            
            # 如果 pv 被展平了，sizes 也要展平
            if raw_pv.dim() >= 5 and raw_sizes.dim() == 3:
                 final_sizes = raw_sizes.flatten(0, 1)
            else:
                 final_sizes = raw_sizes

        # Image Grids (Qwen)
        # final_grids = None
        # if "multi_image_grid_thw" in batch:
        #     final_grids = torch.cat([batch["multi_image_grid_thw"], batch["multi_image_grid_thw"]], dim=0)
        final_grids = None
        if "multi_image_grid_thw" in batch:
            # Grids 也是对应 Image 级别
            raw_grids = torch.cat([batch["multi_image_grid_thw"], batch["multi_image_grid_thw"]], dim=0)
            if raw_grids.dim() == 3:
                # [Fix] 之前缺少 Flatten，导致是 (2B, N, 3)，模型期望 (Total_Images, 3)
                final_grids = raw_grids.flatten(0, 1) 
            else:
                final_grids = raw_grids

        # Pixel Masks (Optional)
        # final_pixel_mask = None
        # if "multi_image_pixel_attention_mask" in batch:
        #      mask_list = [batch["multi_image_pixel_attention_mask"], batch["multi_image_pixel_attention_mask"]]
        #      raw_pm = torch.cat(mask_list, dim=0)
        #      if raw_pv.dim() >= 5 and raw_pm.dim() > 1:
        #          final_pixel_mask = raw_pm.flatten(0, 1)
        #      else:
        #          final_pixel_mask = raw_pm
    
        final_pixel_mask = None
        if "multi_image_pixel_attention_mask" in batch:
             mask_list = [batch["multi_image_pixel_attention_mask"], batch["multi_image_pixel_attention_mask"]]
             raw_pm = torch.cat(mask_list, dim=0)
             
             # [Fix] Mask 的展平逻辑必须严格跟随 Pixel Values
             if raw_pv.dim() == 6:
                 # PV 是 6D (B, N, T, C, H, W) -> flatten(0, 2)
                 # Mask 通常是 5D (B, N, T, H, W) -> 必须也是 flatten(0, 2)
                 if raw_pm.dim() == 5:
                     final_pixel_mask = raw_pm.flatten(0, 2)
                 else:
                     # 防御性：如果 Mask 维度定义不一致，尝试回退
                     final_pixel_mask = raw_pm.flatten(0, 1)
             elif raw_pv.dim() == 5 and raw_pm.dim() > 1:
                 # PV 是 5D -> flatten(0, 1)
                 final_pixel_mask = raw_pm.flatten(0, 1)
             else:
                 final_pixel_mask = raw_pm

        # =================================================================
        # 4. 手动构造 Concatenated Batch
        # =================================================================
        
        # A. 拼接文本: [Chosen_Task1, Chosen_Task2, Rejected_Task1, Rejected_Task2]
        all_prompts = prompts + prompts 
        all_prompt_masks = prompt_masks + prompt_masks
        all_resps = chosens + rejecteds
        all_resp_masks = chosen_masks + rejected_masks
        
        prompt_ids = self.pad_and_cat(all_prompts, padding_side="left", padding_value=pad_token_id, device=model.device)
        prompt_mask = self.pad_and_cat(all_prompt_masks, padding_side="left", padding_value=0, device=model.device)
        resp_ids = self.pad_and_cat(all_resps, padding_side="right", padding_value=pad_token_id, device=model.device)
        resp_mask = self.pad_and_cat(all_resp_masks, padding_side="right", padding_value=0, device=model.device)
        
        # B. 拼接视觉: [Chosen_PV, Rejected_PV]
        # final_pv 已经是 2B (Task1+Task2)，再拼一次变成 4B
        concatenated_pv = torch.cat([final_pv, final_pv], dim=0)
        
        concatenated_sizes = None
        if final_sizes is not None:
            concatenated_sizes = torch.cat([final_sizes, final_sizes], dim=0)
            
        concatenated_grids = None
        if final_grids is not None:
            concatenated_grids = torch.cat([final_grids, final_grids], dim=0)

        concatenated_pm = None
        if final_pixel_mask is not None:
            concatenated_pm = torch.cat([final_pixel_mask, final_pixel_mask], dim=0)

        # C. [NEW] 拼接 Token Masks (如果存在)
        concatenated_token_mask = None
        if chosen_token_masks is not None:
            all_token_masks = chosen_token_masks + rejected_token_masks
            # Padding side MUST be right for response
            concatenated_token_mask = self.pad_and_cat(
                all_token_masks, 
                padding_side="right", 
                padding_value=0, 
                device=model.device
            )

        # D. 组装 Batch
        concatenated_batch = {
            "prompt_input_ids": prompt_ids,
            "prompt_attention_mask": prompt_mask,
            "completion_input_ids": resp_ids,
            "completion_attention_mask": resp_mask,
            "pixel_values": concatenated_pv,
        }
        
        # [NEW] 放入 token_mask，供 forward 函数计算 logps 时使用 (masked_sum)
        if concatenated_token_mask is not None:
            concatenated_batch["token_mask"] = concatenated_token_mask
        
        # 显式填入 key
        if concatenated_sizes is not None:
            concatenated_batch["image_sizes"] = concatenated_sizes
        if concatenated_grids is not None:
            concatenated_batch["image_grid_thw"] = concatenated_grids
        if concatenated_pm is not None:
            concatenated_batch["pixel_attention_mask"] = concatenated_pm
            
        # Token Types
        if "multi_image_token_type_ids" in batch:
            tt = batch["multi_image_token_type_ids"]
            tt2 = batch["multi_image_2_token_type_ids"]
            all_tt = [tt, tt2, tt, tt2]
            concatenated_batch["token_type_ids"] = self.pad_and_cat(all_tt, padding_side="left", padding_value=pad_token_id, device=model.device)

        policy_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
            else:
                ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, concatenated_batch)

        # =================================================================
        # 6. Loss 计算 & 日志
        # =================================================================
        all_losses, all_c_rewards, all_r_rewards = self.dpo_loss(
            policy_output["chosen_logps"], policy_output["rejected_logps"],
            ref_output["chosen_logps"], ref_output["rejected_logps"]
        )


        self.log_metrics(metrics, train_eval, "multi", all_losses, all_c_rewards, all_r_rewards)
        
        total_loss = all_losses.mean()

        # [Modified] SFT Loss 使用 Token Mask (如果存在)
        if self.args.add_sft_loss:
            total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
            split_index = total_batch_len // 2
            
            # 优先使用 token_mask (Phrase Level)，否则使用 attention_mask (Sentence Level)
            if "token_mask" in concatenated_batch:
                real_chosen_mask = concatenated_batch["token_mask"][:split_index]
            else:
                real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
            sft_losses = self.sft_loss(
                policy_output["chosen_logps"], real_chosen_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_multi_sft_loss"] = sft_loss.item()
            total_loss += self.args.sft_loss_weight * sft_loss

        # [Modified] Hinge SFT Loss 同样适配 Token Mask
        if self.args.add_hinge_sft_loss:
            total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
            split_index = total_batch_len // 2
            
            if "token_mask" in concatenated_batch:
                real_chosen_mask = concatenated_batch["token_mask"][:split_index]
            else:
                real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
            sft_losses = self.hinge_sft_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], real_chosen_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_multi_hinge_loss"] = sft_loss.item()
            total_loss += self.args.hinge_sft_loss_weight * sft_loss
        
        if self.args.add_anchor_loss:
            anchor_losses = self.anchor_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], self.args.anchor_delta
            )
            anchor_loss = anchor_losses.mean()
            metrics[f"{train_eval}_multi_anchor_loss"] = anchor_loss.item()
            total_loss += self.args.anchor_weight * anchor_loss

        margin = None
        if self.args.add_margin_consistency_loss:
            mc_losses, margin = self.margin_consistency_loss(
                policy_output["chosen_logps"], 
                policy_output["rejected_logps"],
                ref_output["chosen_logps"], 
                ref_output["rejected_logps"]
            )
            mc_loss = mc_losses.mean()
            metrics[f"{train_eval}_multi_margin_loss"] = mc_loss.item()
            total_loss += self.args.margin_consistency_weight * mc_loss

        logps_pack = {
            "chosen": policy_output["chosen_logps"],
            "rejected": policy_output["rejected_logps"],
            "ref_chosen": ref_output["chosen_logps"],
            "ref_rejected": ref_output["rejected_logps"]
        }

        return total_loss, logps_pack




    # def run_multi_image_dpo(
    #     self, 
    #     model, 
    #     batch, 
    #     metrics, 
    #     train_eval
    # ) -> torch.Tensor | None:
    #     """
    #     Executes Multi 1 and Multi 2 tasks in PARALLEL.
    #     FIXED: 
    #     1. Correctly flattens 6D pixel_values (OneVision).
    #     2. Unpads pixel_values using masks to fix 'split_with_sizes' errors.
    #     3. Passes correct masks to SFT loss.
    #     """
    #     if "multi_image_1_prompt_input_ids" not in batch:
    #         return None

    #     assert "multi_image_2_prompt_input_ids" in batch, "Multi Image Prompt 2 missing"
        
    #     pad_token_id = self.processing_class.tokenizer.pad_token_id or 0

    #     # =================================================================
    #     # 1. 准备基础数据 (List)
    #     # =================================================================
    #     prompts = [batch["multi_image_1_prompt_input_ids"], batch["multi_image_2_prompt_input_ids"]]
    #     prompt_masks = [batch["multi_image_1_prompt_attention_mask"], batch["multi_image_2_prompt_attention_mask"]]
        
    #     chosens = [batch["resp_1_input_ids"], batch["resp_2_input_ids"]]
    #     chosen_masks = [batch["resp_1_attention_mask"], batch["resp_2_attention_mask"]]
        
    #     rejecteds = [batch["resp_2_input_ids"], batch["resp_1_input_ids"]]
    #     rejected_masks = [batch["resp_2_attention_mask"], batch["resp_1_attention_mask"]]

    #     # Token Masks
    #     chosen_token_masks = None
    #     rejected_token_masks = None
    #     if "resp_1_token_mask" in batch and "resp_2_token_mask" in batch:
    #         chosen_token_masks = [batch["resp_1_token_mask"], batch["resp_2_token_mask"]]
    #         rejected_token_masks = [batch["resp_2_token_mask"], batch["resp_1_token_mask"]]

    #     # =================================================================
    #     # [DEBUG ZONE 1] 检查 Mask 是否存在
    #     # =================================================================
    #     is_main_process = self.args.process_index == 0
    #     if "multi_image_pixel_attention_mask" not in batch:
    #         if is_main_process:
    #             print(f"\n[DEBUG] ❌ CRITICAL: 'multi_image_pixel_attention_mask' NOT FOUND in batch keys: {list(batch.keys())}")
    #         raise ValueError("multi_image_pixel_attention_mask missing!")

    #     # =================================================================
    #     # 2. 准备 & 修正视觉特征
    #     # =================================================================
    #     pixel_values_list = [batch["multi_image_pixel_values"], batch["multi_image_pixel_values"]]
    #     raw_pv = torch.cat(pixel_values_list, dim=0) 
        
    #     # 展平
    #     if raw_pv.dim() == 6: final_pv = raw_pv.flatten(0, 2)
    #     elif raw_pv.dim() == 5: final_pv = raw_pv.flatten(0, 1)
    #     else: final_pv = raw_pv

    #     # =================================================================
    #     # 3. 准备 Mask & Unpadding (带 Debug)
    #     # =================================================================
    #     mask_list = [batch["multi_image_pixel_attention_mask"], batch["multi_image_pixel_attention_mask"]]
    #     raw_pm = torch.cat(mask_list, dim=0)
        
    #     # Mask 展平
    #     if raw_pv.dim() == 6: 
    #          final_pixel_mask = raw_pm.flatten(0, 2) if raw_pm.dim() == 5 else raw_pm.flatten(0, 1)
    #     elif raw_pv.dim() == 5 and raw_pm.dim() > 1:
    #          final_pixel_mask = raw_pm.flatten(0, 1)
    #     else:
    #          final_pixel_mask = raw_pm

    #     # 构造 Concatenated Batch 的视觉部分 (Chosen + Rejected)
    #     # 注意：这里我们模拟后面 cat([final_pv, final_pv]) 的行为，因为 Unpad 必须在最终 cat 之后做或者对齐
    #     concatenated_pv = torch.cat([final_pv, final_pv], dim=0)
    #     concatenated_pm = torch.cat([final_pixel_mask, final_pixel_mask], dim=0)

    #     # --- [DEBUG ZONE 2] 打印 Unpadding 前的状态 ---
    #     if is_main_process:
    #         print(f"\n[DEBUG] === Vision Unpadding Check ===")
    #         print(f"[DEBUG] Raw PV Shape (Pre-Unpad): {concatenated_pv.shape}")
    #         print(f"[DEBUG] Mask Shape: {concatenated_pm.shape}")
            
    #         # 计算有效 Tile 数量
    #         valid_count = (concatenated_pm.sum(dim=(1, 2)) > 0).sum().item()
    #         total_count = concatenated_pm.shape[0]
    #         print(f"[DEBUG] Valid Tiles in Mask: {valid_count} / {total_count} (Padding: {total_count - valid_count})")

    #     # 执行 Unpadding
    #     valid_tile_indices = concatenated_pm.sum(dim=(1, 2)) > 0
        
    #     if concatenated_pv.shape[0] != valid_tile_indices.shape[0]:
    #         raise RuntimeError(f"Shape Mismatch! PV: {concatenated_pv.shape}, Mask: {valid_tile_indices.shape}")
            
    #     concatenated_pv = concatenated_pv[valid_tile_indices]

    #     # --- [DEBUG ZONE 3] 打印 Unpadding 后的状态 ---
    #     if is_main_process:
    #         print(f"[DEBUG] Final PV Shape (Post-Unpad): {concatenated_pv.shape}")
    #         print(f"[DEBUG] ⚠️ This number ({concatenated_pv.shape[0]}) represents the ACTUAL Input tensor size.")
    #         print(f"[DEBUG] ===============================\n")

    #     # =================================================================
    #     # 4. 其他元数据处理
    #     # =================================================================

    #     # C. [NEW] 拼接 Token Masks (如果存在)
    #     concatenated_token_mask = None
    #     if chosen_token_masks is not None:
    #         all_token_masks = chosen_token_masks + rejected_token_masks
    #         # Padding side MUST be right for response
    #         concatenated_token_mask = self.pad_and_cat(
    #             all_token_masks, 
    #             padding_side="right", 
    #             padding_value=0, 
    #             device=model.device
    #         )
    
    #     concatenated_sizes = None
    #     if "multi_image_image_sizes" in batch:
    #         sizes_list = [batch["multi_image_image_sizes"], batch["multi_image_image_sizes"]]
    #         raw_sizes = torch.cat(sizes_list, dim=0) 
    #         if raw_pv.dim() >= 5 and raw_sizes.dim() == 3:
    #              final_sizes = raw_sizes.flatten(0, 1)
    #         else:
    #              final_sizes = raw_sizes
    #         concatenated_sizes = torch.cat([final_sizes, final_sizes], dim=0)

    #     concatenated_grids = None
    #     if "multi_image_grid_thw" in batch:
    #         raw_grids = torch.cat([batch["multi_image_grid_thw"], batch["multi_image_grid_thw"]], dim=0)
    #         if raw_grids.dim() == 3:
    #             final_grids = raw_grids.flatten(0, 1)
    #         else:
    #             final_grids = raw_grids
    #         concatenated_grids = torch.cat([final_grids, final_grids], dim=0)

    #     # =================================================================
    #     # 5. 构造 Concatenated Batch & 检查 Prompt Token
    #     # =================================================================
    #     prompt_ids = self.pad_and_cat(prompts + prompts, padding_side="left", padding_value=pad_token_id, device=model.device)
    #     prompt_mask = self.pad_and_cat(prompt_masks + prompt_masks, padding_side="left", padding_value=0, device=model.device)
    #     resp_ids = self.pad_and_cat(chosens + rejecteds, padding_side="right", padding_value=pad_token_id, device=model.device)
    #     resp_mask = self.pad_and_cat(chosen_masks + rejected_masks, padding_side="right", padding_value=0, device=model.device)
        
    #     # --- [DEBUG ZONE 4] 检查 Prompt 中的 <image> Token 数量 ---
    #     if is_main_process:
    #         # 获取第一条 Prompt
    #         sample_ids = prompt_ids[0]
    #         # 尝试解码 (如果没有 tokenizer，只能尝试打印 ID)
    #         if hasattr(self.processing_class, "tokenizer"):
    #             decoded_text = self.processing_class.tokenizer.decode(sample_ids, skip_special_tokens=False)
    #             # 统计 <image> 出现的次数 (OneVision 使用 <image> 作为占位符)
    #             image_token_count = decoded_text.count("<image>")
                
    #             print(f"[DEBUG] === Prompt Token Check ===")
    #             print(f"[DEBUG] Decoded Prompt[0] snippet: {decoded_text[-100:]}") # 打印最后100字符
    #             print(f"[DEBUG] Count of '<image>' in ONE prompt: {image_token_count}")
    #             print(f"[DEBUG] Total prompts in batch: {prompt_ids.shape[0]}")
    #             print(f"[DEBUG] EXPECTED Total Images = {image_token_count} * {prompt_ids.shape[0]} = {image_token_count * prompt_ids.shape[0]}")
                
    #             # 这里的 EXPECTED Total Images 对应的 Tile 总数，应该等于上面的 Final PV Shape
    #             # 如果 Image Count 是 1，但你有 2 张图的数据，这里就会对不上！
    #             print(f"[DEBUG] ==========================\n")

    #     # E. 组装 Batch
    #     concatenated_batch = {
    #         "prompt_input_ids": prompt_ids,
    #         "prompt_attention_mask": prompt_mask,
    #         "completion_input_ids": resp_ids,
    #         "completion_attention_mask": resp_mask,
    #         "pixel_values": concatenated_pv, # 已经是 Unpadded 的干净数据
    #     }
        
    #     if concatenated_token_mask is not None:
    #         concatenated_batch["token_mask"] = concatenated_token_mask
        
    #     if concatenated_sizes is not None:
    #         concatenated_batch["image_sizes"] = concatenated_sizes
    #     if concatenated_grids is not None:
    #         concatenated_batch["image_grid_thw"] = concatenated_grids
            
    #     # Token Types
    #     if "multi_image_token_type_ids" in batch:
    #         tt = batch["multi_image_token_type_ids"]
    #         tt2 = batch["multi_image_2_token_type_ids"]
    #         all_tt = [tt, tt2, tt, tt2]
    #         concatenated_batch["token_type_ids"] = self.pad_and_cat(all_tt, padding_side="left", padding_value=pad_token_id, device=model.device)

    #     # =================================================================
    #     # 6. Forward passes (policy & reference)
    #     # =================================================================
    #     policy_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
        
    #     with torch.no_grad():
    #         if self.ref_model is None:
    #             with self.null_ref_context():
    #                 ref_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
    #         else:
    #             ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, concatenated_batch)

    #     # =================================================================
    #     # 7. Loss computation & logging
    #     # =================================================================
    #     all_losses, all_c_rewards, all_r_rewards = self.dpo_loss(
    #         policy_output["chosen_logps"], policy_output["rejected_logps"],
    #         ref_output["chosen_logps"], ref_output["rejected_logps"]
    #     )

    #     self.log_metrics(metrics, train_eval, "multi", all_losses, all_c_rewards, all_r_rewards)
        
    #     total_loss = all_losses.mean()

    #     # [Modified] SFT Loss
    #     if self.args.add_sft_loss:
    #         total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
    #         split_index = total_batch_len // 2
            
    #         if "token_mask" in concatenated_batch:
    #             real_chosen_mask = concatenated_batch["token_mask"][:split_index]
    #         else:
    #             real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
    #         sft_losses = self.sft_loss(
    #             policy_output["chosen_logps"], real_chosen_mask
    #         )
    #         sft_loss = sft_losses.mean()
    #         metrics[f"{train_eval}_multi_sft_loss"] = sft_loss.item()
    #         total_loss += self.args.sft_loss_weight * sft_loss

    #     # [Modified] Hinge SFT Loss
    #     if self.args.add_hinge_sft_loss:
    #         total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
    #         split_index = total_batch_len // 2
            
    #         if "token_mask" in concatenated_batch:
    #             real_chosen_mask = concatenated_batch["token_mask"][:split_index]
    #         else:
    #             real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
    #         sft_losses = self.hinge_sft_loss(
    #             policy_output["chosen_logps"], ref_output["chosen_logps"], real_chosen_mask
    #         )
    #         sft_loss = sft_losses.mean()
    #         metrics[f"{train_eval}_multi_hinge_loss"] = sft_loss.item()
    #         total_loss += self.args.hinge_sft_loss_weight * sft_loss
        
    #     if self.args.add_anchor_loss:
    #         anchor_losses = self.anchor_loss(
    #             policy_output["chosen_logps"], ref_output["chosen_logps"], self.args.anchor_delta
    #         )
    #         anchor_loss = anchor_losses.mean()
    #         metrics[f"{train_eval}_multi_anchor_loss"] = anchor_loss.item()
    #         total_loss += self.args.anchor_weight * anchor_loss

    #     if self.args.add_margin_consistency_loss:
    #         mc_losses, margin = self.margin_consistency_loss(
    #             policy_output["chosen_logps"], 
    #             policy_output["rejected_logps"],
    #             ref_output["chosen_logps"], 
    #             ref_output["rejected_logps"]
    #         )
    #         mc_loss = mc_losses.mean()
    #         metrics[f"{train_eval}_multi_margin_loss"] = mc_loss.item()
    #         total_loss += self.args.margin_consistency_weight * mc_loss

    #     logps_pack = {
    #         "chosen": policy_output["chosen_logps"],
    #         "rejected": policy_output["rejected_logps"],
    #         "ref_chosen": ref_output["chosen_logps"],
    #         "ref_rejected": ref_output["rejected_logps"]
    #     }

    #     return total_loss, logps_pack



    def run_single_image_mdpo(
        self, 
        model, 
        batch, 
        metrics, 
        train_eval
    ) -> torch.Tensor | None:
        """
        Executes mDPO (Multimodal DPO) OR S-VCO (Stepwise VCO) with Phrase-level alignment support.
        
        Mode 1: mDPO (Standard)
            (Img1, P1, R1) > (Img2, P2, R1)
            
        Mode 2: S-VCO (Stepwise, with No-Image Anchor)
            (Img1, P1, R1) > (NoImg, P_No, R1) > (Img2, P2, R1)
            Total 4 Pairs per sample group.
        """

        # --- 1. 基础数据准备 ---
        pad_token_id = self.processing_class.tokenizer.pad_token_id or 0
        
        p1 = batch["image_1_prompt_input_ids"]
        p2 = batch["image_2_prompt_input_ids"]
        p1_mask = batch["image_1_prompt_attention_mask"]
        p2_mask = batch["image_2_prompt_attention_mask"]

        r1 = batch["resp_1_input_ids"]
        r2 = batch["resp_2_input_ids"]
        r1_mask = batch["resp_1_attention_mask"]
        r2_mask = batch["resp_2_attention_mask"]
        
        # [NEW] Token Masks (Phrase Level)
        tm1 = batch.get("resp_1_token_mask")
        tm2 = batch.get("resp_2_token_mask")

        img1 = batch["image_1_pixel_values"]
        img2 = batch["image_2_pixel_values"]
        
        # 视觉元数据 (用于 LLaVA-Next/Qwen)
        img1_sizes = batch.get("image_1_image_sizes")
        img2_sizes = batch.get("image_2_image_sizes")
        img1_grid = batch.get("image_1_image_grid_thw")
        img2_grid = batch.get("image_2_image_grid_thw")

        # --- 2. 判断模式: S-VCO vs mDPO ---
        use_s_vco = "no_image_prompt_input_ids" in batch
        
        if use_s_vco:
            # ================= S-VCO 逻辑 (4 Pairs) =================
            # Chain 1 (Focus R1): (I1, P1) > (No, P_No) > (I2, P2)
            #   Pair A: Chosen=(I1, P1, R1), Rejected=(No, P_No, R1)
            #   Pair B: Chosen=(No, P_No, R1), Rejected=(I2, P2, R1)
            
            # Chain 2 (Focus R2): (I2, P2) > (No, P_No) > (I1, P1)
            #   Pair C: Chosen=(I2, P2, R2), Rejected=(No, P_No, R2)
            #   Pair D: Chosen=(No, P_No, R2), Rejected=(I1, P1, R2)
            
            p_no = batch["no_image_prompt_input_ids"]
            p_no_mask = batch["no_image_prompt_attention_mask"] 

            # A. Prompts: [Ch_A, Ch_B, Ch_C, Ch_D] + [Rej_A, Rej_B, Rej_C, Rej_D]
            chosen_prompts = [p1, p_no, p2, p_no]
            chosen_p_masks = [p1_mask, p_no_mask, p2_mask, p_no_mask]
            
            rejected_prompts = [p_no, p2, p_no, p1]
            rejected_p_masks = [p_no_mask, p2_mask, p_no_mask, p1_mask]
            
            prompts_list = chosen_prompts + rejected_prompts
            prompt_masks_list = chosen_p_masks + rejected_p_masks

            # B. Responses & Masks (Token Mask 跟随 Response)
            # Chosen: [R1, R1, R2, R2]
            chosen_resps = [r1, r1, r2, r2]
            chosen_r_masks = [r1_mask, r1_mask, r2_mask, r2_mask]
            
            # Rejected: [R1, R1, R2, R2]
            rejected_resps = [r1, r1, r2, r2]
            rejected_r_masks = [r1_mask, r1_mask, r2_mask, r2_mask]
            
            resps_list = chosen_resps + rejected_resps
            resp_masks_list = chosen_r_masks + rejected_r_masks
            
            # [NEW] Token Masks List for S-VCO
            token_masks_list = None
            if tm1 is not None and tm2 is not None:
                chosen_tms = [tm1, tm1, tm2, tm2]
                rejected_tms = [tm1, tm1, tm2, tm2]
                token_masks_list = chosen_tms + rejected_tms

            # C. Images (只包含 Prompt 中有 <image> 的样本)
            # Order Analysis derived from prompts list:
            # 0:P1(I1), 1:No, 2:P2(I2), 3:No, 4:No, 5:P2(I2), 6:No, 7:P1(I1)
            images_to_cat = [img1, img2, img2, img1]
            
            sizes_to_cat = []
            if img1_sizes is not None:
                sizes_to_cat = [img1_sizes, img2_sizes, img2_sizes, img1_sizes]
                
            grids_to_cat = []
            if img1_grid is not None:
                grids_to_cat = [img1_grid, img2_grid, img2_grid, img1_grid]

            log_prefix = "s_vco"

        else:
            # ================= mDPO 逻辑 (2 Pairs) =================
            # Pair 1: Chosen=(I1, P1, R1), Rejected=(I2, P2, R1)
            # Pair 2: Chosen=(I2, P2, R2), Rejected=(I1, P1, R2)
            
            # Prompts: [Ch1, Ch2] + [Rej1, Rej2]
            # Chosen: [P1, P2]
            # Rejected: [P2, P1]
            prompts_list = [p1, p2, p2, p1]
            prompt_masks_list = [p1_mask, p2_mask, p2_mask, p1_mask]
            
            # Responses: [R1, R2, R1, R2]
            resps_list = [r1, r2, r1, r2]
            resp_masks_list = [r1_mask, r2_mask, r1_mask, r2_mask]
            
            # [NEW] Token Masks List for mDPO
            # Token Mask 跟随 Response 走: R1对应TM1, R2对应TM2
            token_masks_list = None
            if tm1 is not None and tm2 is not None:
                # Chosen: [TM1, TM2], Rejected: [TM1, TM2]
                token_masks_list = [tm1, tm2, tm1, tm2]
            
            # Images: [I1, I2, I2, I1] (全部对应)
            images_to_cat = [img1, img2, img2, img1]
            
            sizes_to_cat = []
            if img1_sizes is not None:
                sizes_to_cat = [img1_sizes, img2_sizes, img2_sizes, img1_sizes]
            
            grids_to_cat = []
            if img1_grid is not None:
                grids_to_cat = [img1_grid, img2_grid, img2_grid, img1_grid]

            log_prefix = "mdpo"

        # --- 3. 构造 Concatenated Batch ---
        concatenated_batch = {
            "prompt_input_ids": self.pad_and_cat(prompts_list, padding_side="left", padding_value=pad_token_id, device=model.device),
            "prompt_attention_mask": self.pad_and_cat(prompt_masks_list, padding_side="left", padding_value=0, device=model.device),
            "completion_input_ids": self.pad_and_cat(resps_list, padding_side="right", padding_value=pad_token_id, device=model.device),
            "completion_attention_mask": self.pad_and_cat(resp_masks_list, padding_side="right", padding_value=0, device=model.device),
        }
        
        # [NEW] Add Token Mask
        if token_masks_list is not None:
            concatenated_batch["token_mask"] = self.pad_and_cat(
                token_masks_list, 
                padding_side="right", 
                padding_value=0, 
                device=model.device
            )

        # --- 4. 视觉特征拼接 (Flatten & 4D Check) ---
        # A. 拼接原始 Pixel Values (含 Padding)
        # images_to_cat 顺序是 [Img1, Img2, Img2, Img1] (对应 Chosen1, Chosen2, Rej1, Rej2)
        raw_pv = torch.cat(images_to_cat, dim=0) 
        
        # B. 展平 Pixel Values (适配 6D/5D)
        if raw_pv.dim() == 6:
            # (B, N, Tiles, C, H, W) -> (Total_Tiles, C, H, W)
            final_pv = raw_pv.flatten(0, 2)
        elif raw_pv.dim() == 5:
            # (B, Tiles, C, H, W) -> (Total_Tiles, C, H, W)
            final_pv = raw_pv.flatten(0, 1)
        else:
            final_pv = raw_pv

        # C. 准备 Mask 并执行 Unpadding (强制执行!)
        # LLaVA-OneVision 必须依赖 Mask 来过滤 Padding，否则模型切分会错位
        # if "image_1_pixel_attention_mask" not in batch:
        #      # 如果报错这里，请检查 process_row_single_image 是否正确注入了 mask，并清除 dataset 缓存
        #      raise ValueError("CRITICAL: 'image_1_pixel_attention_mask' missing in batch! Unpadding cannot proceed.")
        if "image_1_pixel_attention_mask" in batch:

            m1 = batch["image_1_pixel_attention_mask"]
            m2 = batch["image_2_pixel_attention_mask"]

            # 构造 Mask 列表，顺序必须严格对齐 images_to_cat
            # S-VCO/mDPO 逻辑中，Image 顺序固定为: [I1, I2, I2, I1]
            masks_to_cat = [m1, m2, m2, m1]
            
            raw_pm = torch.cat(masks_to_cat, dim=0)

            # Mask 展平逻辑 (需与 PV 保持一致)
            if raw_pm.dim() == 5: # 对应 6D PV
                final_pixel_mask = raw_pm.flatten(0, 2)
            elif raw_pm.dim() == 4: # 对应 5D PV
                final_pixel_mask = raw_pm.flatten(0, 1)
            else:
                final_pixel_mask = raw_pm

            # D. 执行过滤 (Unpadding)
            # 只要 Tile 的 Mask (H, W) 上有任意非 0 值，即视为有效 Tile
            valid_tile_indices = final_pixel_mask.sum(dim=(1, 2)) > 0
            
            # [安全检查] 确保 Mask 和 PV 维度对齐
            if final_pv.shape[0] != valid_tile_indices.shape[0]:
                raise RuntimeError(f"Dimension Mismatch! PV: {final_pv.shape[0]}, Mask: {valid_tile_indices.shape[0]}")

            # 只保留有效 Tile
            # [Visual Guide] [Pad, Valid, Valid, Pad] -> [Valid, Valid]
            final_pv = final_pv[valid_tile_indices]

        # E. 更新 Batch
        concatenated_batch["pixel_values"] = final_pv


        # 其他视觉元数据处理
        if sizes_to_cat:
            concatenated_batch["image_sizes"] = torch.cat(sizes_to_cat, dim=0)
            if concatenated_batch["image_sizes"].dim() == 3:
                 concatenated_batch["image_sizes"] = concatenated_batch["image_sizes"].flatten(0, 1)

        if grids_to_cat:
            concatenated_batch["image_grid_thw"] = torch.cat(grids_to_cat, dim=0)

        # Token Types (S-VCO 可能需要特殊处理，此处略，假设不需要)

        # --- 5. Forward & Loss ---
        policy_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
        
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_output = self.forward_with_pre_concatenated_batch(model, concatenated_batch)
            else:
                ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, concatenated_batch)

        all_losses, all_chosen_rewards, all_rejected_rewards = self.dpo_loss(
            policy_output["chosen_logps"], policy_output["rejected_logps"],
            ref_output["chosen_logps"], ref_output["rejected_logps"]
        )

        # --- 6. Log & Aux Losses ---
        self.log_metrics(metrics, train_eval, log_prefix, 
                         all_losses, all_chosen_rewards, all_rejected_rewards)
        
        total_loss = all_losses.mean()
        
        # [NEW] SFT Loss (适配 Token Mask)
        if self.args.add_sft_loss:
            total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
            split_index = total_batch_len // 2
            
            if "token_mask" in concatenated_batch:
                real_chosen_mask = concatenated_batch["token_mask"][:split_index]
            else:
                real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
            sft_losses = self.sft_loss(
                policy_output["chosen_logps"], real_chosen_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_{log_prefix}_sft_loss"] = sft_loss.item()
            total_loss += self.args.sft_loss_weight * sft_loss

        # [NEW] Hinge Loss
        if self.args.add_hinge_sft_loss:
            total_batch_len = concatenated_batch["completion_attention_mask"].shape[0]
            split_index = total_batch_len // 2
            
            if "token_mask" in concatenated_batch:
                real_chosen_mask = concatenated_batch["token_mask"][:split_index]
            else:
                real_chosen_mask = concatenated_batch["completion_attention_mask"][:split_index]
            
            sft_losses = self.hinge_sft_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], real_chosen_mask
            )
            sft_loss = sft_losses.mean()
            metrics[f"{train_eval}_{log_prefix}_hinge_loss"] = sft_loss.item()
            total_loss += self.args.hinge_sft_loss_weight * sft_loss
            
        # [NEW] Anchor Loss
        if self.args.add_anchor_loss:
            anchor_losses = self.anchor_loss(
                policy_output["chosen_logps"], ref_output["chosen_logps"], self.args.anchor_delta
            )
            anchor_loss = anchor_losses.mean()
            metrics[f"{train_eval}_{log_prefix}_anchor_loss"] = anchor_loss.item()
            total_loss += self.args.anchor_weight * anchor_loss

        # [NEW] Margin Consistency Loss
        if self.args.add_margin_consistency_loss:
            mc_losses = self.margin_consistency_loss(
                policy_output["chosen_logps"], 
                policy_output["rejected_logps"],
                ref_output["chosen_logps"], 
                ref_output["rejected_logps"]
            )
            mc_loss = mc_losses.mean()
            metrics[f"{train_eval}_{log_prefix}_margin_loss"] = mc_loss.item()
            total_loss += self.args.margin_consistency_weight * mc_loss

        return total_loss

    # def run_single_image_vdpo(
    #     self, 
    #     model, 
    #     batch, 
    #     metrics, 
    #     train_eval,
    #     gamma=0.75
    # ) -> torch.Tensor | None:
    #     """
    #     Executes V-DPO using a SINGLE Mixed Batch strategy.
    #     UPDATED: Supports Phrase-level Token Masks & 6D Pixel Values.
        
    #     Why Mixed Batch?
    #     - We mix Visual Rows (with images) and Blind Rows (without images) in ONE batch.
    #     - As long as the batch contains SOME images, the Vision Encoder runs and generates gradients.
    #     - This prevents DeepSpeed from crashing due to "unused parameters".
    #     """

    #     pad_token_id = self.processing_class.tokenizer.pad_token_id or 0
    #     vdpo_alpha = (1 - gamma) * self.args.beta

    #     # =================================================================
    #     # 1. 提取基础数据
    #     # =================================================================
    #     # Visual Prompts
    #     p1 = batch["image_1_prompt_input_ids"]
    #     p1_mask = batch["image_1_prompt_attention_mask"]
    #     p2 = batch["image_2_prompt_input_ids"]
    #     p2_mask = batch["image_2_prompt_attention_mask"]
        
    #     # Blind Prompts
    #     p_no = batch["no_image_prompt_input_ids"]
    #     p_no_mask = batch["no_image_prompt_attention_mask"]

    #     # Responses
    #     r1 = batch["resp_1_input_ids"]
    #     r1_mask = batch["resp_1_attention_mask"]
    #     r2 = batch["resp_2_input_ids"]
    #     r2_mask = batch["resp_2_attention_mask"]
        
    #     # [NEW] Token Masks (Phrase Level)
    #     tm1 = batch.get("resp_1_token_mask")
    #     tm2 = batch.get("resp_2_token_mask")

    #     # Images
    #     img1 = batch["image_1_pixel_values"]
    #     img2 = batch["image_2_pixel_values"]
        
    #     # Metadata
    #     img1_sizes = batch.get("image_1_image_sizes")
    #     img2_sizes = batch.get("image_2_image_sizes")
    #     img1_grid = batch.get("image_1_image_grid_thw")
    #     img2_grid = batch.get("image_2_image_grid_thw")
    #     img1_pm = batch.get("image_1_pixel_attention_mask")
    #     img2_pm = batch.get("image_2_pixel_attention_mask")

    #     # =================================================================
    #     # 2. 构造 Mixed Batch (8 Rows)
    #     # =================================================================
    #     # Structure Recap:
    #     # Index 0: Vis_1_Chosen (I1, P1, R1)
    #     # Index 1: Vis_2_Chosen (I2, P2, R2)
    #     # Index 2: Vis_1_Reject (I1, P1, R2)
    #     # Index 3: Vis_2_Reject (I2, P2, R1)
    #     # Index 4: Bld_1_Chosen (No, Pno, R1)
    #     # Index 5: Bld_2_Chosen (No, Pno, R2)
    #     # Index 6: Bld_1_Reject (No, Pno, R2)
    #     # Index 7: Bld_2_Reject (No, Pno, R1)

    #     # A. Prompts
    #     # [P1, P2, P1, P2] + [Pno, Pno, Pno, Pno]
    #     all_prompts = [p1, p2, p1, p2, p_no, p_no, p_no, p_no]
    #     all_p_masks = [p1_mask, p2_mask, p1_mask, p2_mask, p_no_mask, p_no_mask, p_no_mask, p_no_mask]

    #     # B. Responses
    #     # [R1, R2, R2, R1] + [R1, R2, R2, R1]
    #     all_resps = [r1, r2, r2, r1, r1, r2, r2, r1]
    #     all_r_masks = [r1_mask, r2_mask, r2_mask, r1_mask, r1_mask, r2_mask, r2_mask, r1_mask]

    #     # [NEW] C. Token Masks
    #     # 对应 Response 的顺序传递 Token Mask
    #     # TM1 对应 R1, TM2 对应 R2
    #     all_token_masks = None
    #     if tm1 is not None and tm2 is not None:
    #         # [TM1, TM2, TM2, TM1] + [TM1, TM2, TM2, TM1]
    #         all_token_masks = [tm1, tm2, tm2, tm1, tm1, tm2, tm2, tm1]

    #     # D. Images (Only for Visual Rows 0-3)
    #     # List: [Img1, Img2, Img1, Img2]
    #     images_to_cat = [img1, img2, img1, img2]
        
    #     # Metadata
    #     sizes_to_cat = [img1_sizes, img2_sizes, img1_sizes, img2_sizes] if img1_sizes is not None else []
    #     grids_to_cat = [img1_grid, img2_grid, img1_grid, img2_grid] if img1_grid is not None else []
    #     pm_to_cat = [img1_pm, img2_pm, img1_pm, img2_pm] if img1_pm is not None else []

    #     # =================================================================
    #     # 3. 物理构建 Batch Tensor
    #     # =================================================================
    #     mixed_batch = {
    #         "prompt_input_ids": self.pad_and_cat(all_prompts, padding_side="left", padding_value=pad_token_id, device=model.device),
    #         "prompt_attention_mask": self.pad_and_cat(all_p_masks, padding_side="left", padding_value=0, device=model.device),
    #         "completion_input_ids": self.pad_and_cat(all_resps, padding_side="right", padding_value=pad_token_id, device=model.device),
    #         "completion_attention_mask": self.pad_and_cat(all_r_masks, padding_side="right", padding_value=0, device=model.device),
    #     }
        
    #     # [NEW] Add Token Mask to Batch
    #     if all_token_masks is not None:
    #         mixed_batch["token_mask"] = self.pad_and_cat(
    #             all_token_masks, 
    #             padding_side="right", 
    #             padding_value=0, 
    #             device=model.device
    #         )

    #     # E. Pixel Values 拼接
    #     raw_pv = torch.cat(images_to_cat, dim=0)
        
    #     # [Fixed] 6D Flatten Logic for OneVision
    #     if raw_pv.dim() == 6:
    #         # (B, N, Patches, C, H, W) -> (B*N, Patches, C, H, W)
    #         mixed_batch["pixel_values"] = raw_pv.flatten(0, 2)
    #     elif raw_pv.dim() == 5:
    #         # 如果是 (B, N, C, H, W) 且 N=1, 可能需要 flatten，也可能不需要，视模型而定
    #         # 为了安全，通常 VLM 期望 (Total_Images, C, H, W)
    #         mixed_batch["pixel_values"] = raw_pv.flatten(0, 1)
    #     else:
    #         mixed_batch["pixel_values"] = raw_pv

    #     # Metadata Splice
    #     if sizes_to_cat:
    #         rs = torch.cat(sizes_to_cat, dim=0)
    #         mixed_batch["image_sizes"] = rs.flatten(0, 1) if rs.dim() == 3 else rs
    #     if grids_to_cat: 
    #         mixed_batch["image_grid_thw"] = torch.cat(grids_to_cat, dim=0)
    #     if pm_to_cat:
    #         rpm = torch.cat(pm_to_cat, dim=0)
    #         mixed_batch["pixel_attention_mask"] = rpm.flatten(0, 1) if (raw_pv.dim()>=5 and rpm.dim()>1) else rpm

    #     # =================================================================
    #     # 4. Single Forward Pass
    #     # =================================================================
        
    #     # Policy Forward
    #     policy_output = self.forward_with_pre_concatenated_batch(model, mixed_batch)
        
    #     # Reference Forward
    #     with torch.no_grad():
    #         if self.ref_model is None:
    #             with self.null_ref_context():
    #                 ref_output = self.forward_with_pre_concatenated_batch(model, mixed_batch)
    #         else:
    #             ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, mixed_batch)

    #     # =================================================================
    #     # 5. 解包 Logps 并计算 Loss
    #     # =================================================================
        
    #     def _extract_all_logps(out_dict):
    #         # [Correction] TRL splits output into chosen/rejected based on batch size / 2.
    #         # Our batch size is 8.
    #         # out_dict["chosen_logps"] contains rows 0, 1, 2, 3
    #         # out_dict["rejected_logps"] contains rows 4, 5, 6, 7
    #         # We need to concat them to get the flat list of 0..7
    #         c = out_dict["chosen_logps"] 
    #         r = out_dict["rejected_logps"] 
    #         return torch.cat([c, r], dim=0)

    #     # 获取扁平化的 8 个 logps
    #     all_pi = _extract_all_logps(policy_output)
    #     all_ref = _extract_all_logps(ref_output)

    #     # Mapping Indices:
    #     # Visual Part (0-3)
    #     pi_vis_chosen = torch.stack([all_pi[0], all_pi[1]])   # (I1,R1), (I2,R2)
    #     pi_vis_reject = torch.stack([all_pi[2], all_pi[3]])   # (I1,R2), (I2,R1)
        
    #     ref_vis_chosen = torch.stack([all_ref[0], all_ref[1]])
    #     ref_vis_reject = torch.stack([all_ref[2], all_ref[3]])
        
    #     # Blind Part (4-7)
    #     pi_bld_chosen = torch.stack([all_pi[4], all_pi[5]])   # (No,R1), (No,R2)
    #     pi_bld_reject = torch.stack([all_pi[6], all_pi[7]])   # (No,R2), (No,R1)
        
    #     ref_bld_chosen = torch.stack([all_ref[4], all_ref[5]])
    #     ref_bld_reject = torch.stack([all_ref[6], all_ref[7]])

    #     # 6. 计算 Rewards & V-DPO Loss
    #     beta = self.args.beta
        
    #     # Visual Rewards
    #     r_vis_w = beta * (pi_vis_chosen - ref_vis_chosen)
    #     r_vis_l = beta * (pi_vis_reject - ref_vis_reject)
        
    #     # Blind Rewards
    #     r_bld_w = beta * (pi_bld_chosen - ref_bld_chosen)
    #     r_bld_l = beta * (pi_bld_reject - ref_bld_reject)
        
    #     # V-DPO Aggregation
    #     diff_visual = r_vis_w - r_vis_l
    #     diff_blind = r_bld_w - r_bld_l
        
    #     # The Core Formula:
    #     # Increase Visual Margin, Decrease Blind Margin (Don't let model distinguish blind pairs)
    #     logits = diff_visual - vdpo_alpha * diff_blind
        
    #     losses = -F.logsigmoid(logits)

    #     # =================================================================
    #     # 7. Metrics
    #     # =================================================================
    #     self.log_metrics(metrics, train_eval, "single_vdpo", losses, r_vis_w.detach(), r_vis_l.detach())
        
    #     if train_eval == "train":
    #         metrics["train_blind_margin"] = diff_blind.mean().item()
    #         metrics["train_visual_margin"] = diff_visual.mean().item()

    #     return losses.mean()
        


    def run_single_image_vdpo(
        self, 
        model, 
        batch, 
        metrics, 
        train_eval,
        gamma=0.75
    ) -> torch.Tensor | None:
        """
        Executes V-DPO using a SINGLE Mixed Batch strategy.
        FIXED: Added Unpadding logic for LLaVA-OneVision.
        """
        pad_token_id = self.processing_class.tokenizer.pad_token_id or 0
        vdpo_alpha = (1 - gamma) * self.args.beta

        # =================================================================
        # 1. 提取基础数据
        # =================================================================
        p1 = batch["image_1_prompt_input_ids"]
        p1_mask = batch["image_1_prompt_attention_mask"]
        p2 = batch["image_2_prompt_input_ids"]
        p2_mask = batch["image_2_prompt_attention_mask"]
        
        p_no = batch["no_image_prompt_input_ids"]
        p_no_mask = batch["no_image_prompt_attention_mask"]

        r1 = batch["resp_1_input_ids"]
        r1_mask = batch["resp_1_attention_mask"]
        r2 = batch["resp_2_input_ids"]
        r2_mask = batch["resp_2_attention_mask"]
        
        tm1 = batch.get("resp_1_token_mask")
        tm2 = batch.get("resp_2_token_mask")

        img1 = batch["image_1_pixel_values"]
        img2 = batch["image_2_pixel_values"]
        
        # # 强制检查 Mask，否则无法 Unpad
        # if "image_1_pixel_attention_mask" not in batch:
        #      raise ValueError("CRITICAL: 'image_1_pixel_attention_mask' missing! Check process_row_single_image.")
        # img1_pm = batch["image_1_pixel_attention_mask"]
        # img2_pm = batch["image_2_pixel_attention_mask"]

        # Metadata
        img1_sizes = batch.get("image_1_image_sizes")
        img2_sizes = batch.get("image_2_image_sizes")
        img1_pm = batch.get("image_1_pixel_attention_mask")
        img2_pm = batch.get("image_2_pixel_attention_mask")
        img1_grid = batch.get("image_1_image_grid_thw")
        img2_grid = batch.get("image_2_image_grid_thw")

        # =================================================================
        # 2. 构造 Mixed Batch (8 Rows)
        # =================================================================
        # A. Prompts: [Vis_1, Vis_2, Vis_1, Vis_2] + [Bld, Bld, Bld, Bld]
        all_prompts = [p1, p2, p1, p2, p_no, p_no, p_no, p_no]
        all_p_masks = [p1_mask, p2_mask, p1_mask, p2_mask, p_no_mask, p_no_mask, p_no_mask, p_no_mask]

        # B. Responses: [R1, R2, R2, R1] + [R1, R2, R2, R1]
        all_resps = [r1, r2, r2, r1, r1, r2, r2, r1]
        all_r_masks = [r1_mask, r2_mask, r2_mask, r1_mask, r1_mask, r2_mask, r2_mask, r1_mask]

        # C. Token Masks
        all_token_masks = None
        if tm1 is not None and tm2 is not None:
            all_token_masks = [tm1, tm2, tm2, tm1, tm1, tm2, tm2, tm1]

        # =================================================================
        # 3. 处理视觉特征 (Pixel Values) & Unpadding [FIXED]
        # =================================================================
        # D. Images (Only for Visual Rows 0-3: [Img1, Img2, Img1, Img2])
        images_to_cat = [img1, img2, img1, img2]
        raw_pv = torch.cat(images_to_cat, dim=0)
        
        # 1. 展平 Pixel Values (适配 6D/5D)
        if raw_pv.dim() == 6:
            final_pv = raw_pv.flatten(0, 2)
        elif raw_pv.dim() == 5:
            final_pv = raw_pv.flatten(0, 1)
        else:
            final_pv = raw_pv

        # 2. 构造对应的 Mask 列表 (顺序必须严格一致)
        if img1_pm and img2_pm:
            pm_to_cat = [img1_pm, img2_pm, img1_pm, img2_pm]
            raw_pm = torch.cat(pm_to_cat, dim=0)

            # 3. 展平 Mask (与 PV 同步)
            if raw_pv.dim() == 6: 
                final_pixel_mask = raw_pm.flatten(0, 2) if raw_pm.dim() == 5 else raw_pm.flatten(0, 1)
            elif raw_pv.dim() == 5 and raw_pm.dim() > 1:
                final_pixel_mask = raw_pm.flatten(0, 1)
            else:
                final_pixel_mask = raw_pm

            # 4. 执行 Unpadding (过滤无效 Tile)
            valid_tile_indices = final_pixel_mask.sum(dim=(1, 2)) > 0
            
            if final_pv.shape[0] != valid_tile_indices.shape[0]:
                raise RuntimeError(f"Shape Mismatch! PV: {final_pv.shape[0]}, Mask: {valid_tile_indices.shape[0]}")
                
            final_pv = final_pv[valid_tile_indices] # [Unpadded]

        # =================================================================
        # 4. 物理构建 Batch Tensor
        # =================================================================
        mixed_batch = {
            "prompt_input_ids": self.pad_and_cat(all_prompts, padding_side="left", padding_value=pad_token_id, device=model.device),
            "prompt_attention_mask": self.pad_and_cat(all_p_masks, padding_side="left", padding_value=0, device=model.device),
            "completion_input_ids": self.pad_and_cat(all_resps, padding_side="right", padding_value=pad_token_id, device=model.device),
            "completion_attention_mask": self.pad_and_cat(all_r_masks, padding_side="right", padding_value=0, device=model.device),
            "pixel_values": final_pv, # 已去除 Padding
        }
        
        if all_token_masks is not None:
            mixed_batch["token_mask"] = self.pad_and_cat(all_token_masks, padding_side="right", padding_value=0, device=model.device)

        # Metadata Splice
        if img1_sizes is not None:
            sizes_list = [img1_sizes, img2_sizes, img1_sizes, img2_sizes]
            rs = torch.cat(sizes_list, dim=0)
            # Sizes 只要对应 Image 数量即可，不需要 Unpad
            mixed_batch["image_sizes"] = rs.flatten(0, 1) if rs.dim() == 3 else rs
            
        if img1_grid is not None: 
            grids_list = [img1_grid, img2_grid, img1_grid, img2_grid]
            mixed_batch["image_grid_thw"] = torch.cat(grids_list, dim=0)

        # =================================================================
        # 5. Forward & Loss (保持不变)
        # =================================================================
        policy_output = self.forward_with_pre_concatenated_batch(model, mixed_batch)
        
        with torch.no_grad():
            if self.ref_model is None:
                with self.null_ref_context():
                    ref_output = self.forward_with_pre_concatenated_batch(model, mixed_batch)
            else:
                ref_output = self.forward_with_pre_concatenated_batch(self.ref_model, mixed_batch)

        # Helper to extract logps
        def _extract_all_logps(out_dict):
            # mixed_batch size is 8
            # policy_output split into chosen(0-3)/rejected(4-7) by TRL logic?
            # Wait, standard DPO splits by batch/2.
            # Batch=8 -> Chosen=0,1,2,3; Rejected=4,5,6,7? NO.
            # TRL DPO Trainer expects concatenated [Chosen, Rejected].
            # Our batch construction:
            # 0,1,2,3: Visual Pairs
            # 4,5,6,7: Blind Pairs
            # This is NOT standard [Chosen, Rejected] structure for TRL's compute_loss.
            # But here we are calculating loss manually, so we just need access to raw logps.
            
            # forward_with_pre_concatenated_batch returns:
            # { "chosen_logps": logps[:BS/2], "rejected_logps": logps[BS/2:] }
            # Since BS=8, chosen=0-3, rejected=4-7.
            c = out_dict["chosen_logps"] 
            r = out_dict["rejected_logps"] 
            return torch.cat([c, r], dim=0) # [0, 1, 2, 3, 4, 5, 6, 7]

        all_pi = _extract_all_logps(policy_output)
        all_ref = _extract_all_logps(ref_output)

        # Mapping:
        # 0: Vis_1_Chosen (I1, R1)
        # 1: Vis_2_Chosen (I2, R2)
        # 2: Vis_1_Reject (I1, R2)
        # 3: Vis_2_Reject (I2, R1)
        pi_vis_chosen = torch.stack([all_pi[0], all_pi[1]]) 
        pi_vis_reject = torch.stack([all_pi[2], all_pi[3]]) 
        ref_vis_chosen = torch.stack([all_ref[0], all_ref[1]])
        ref_vis_reject = torch.stack([all_ref[2], all_ref[3]])
        
        # 4: Bld_1_Chosen (No, R1)
        # 5: Bld_2_Chosen (No, R2)
        # 6: Bld_1_Reject (No, R2)  <-- Correction: Index 6 is Bld_1_Reject
        # 7: Bld_2_Reject (No, R1)  <-- Correction: Index 7 is Bld_2_Reject
        pi_bld_chosen = torch.stack([all_pi[4], all_pi[5]]) 
        pi_bld_reject = torch.stack([all_pi[6], all_pi[7]]) 
        ref_bld_chosen = torch.stack([all_ref[4], all_ref[5]])
        ref_bld_reject = torch.stack([all_ref[6], all_ref[7]])

        # 6. 计算 Rewards & V-DPO Loss
        beta = self.args.beta
        
        r_vis_w = beta * (pi_vis_chosen - ref_vis_chosen)
        r_vis_l = beta * (pi_vis_reject - ref_vis_reject)
        
        r_bld_w = beta * (pi_bld_chosen - ref_bld_chosen)
        r_bld_l = beta * (pi_bld_reject - ref_bld_reject)
        
        diff_visual = r_vis_w - r_vis_l
        diff_blind = r_bld_w - r_bld_l
        
        logits = diff_visual - vdpo_alpha * diff_blind
        
        losses = -F.logsigmoid(logits)

        self.log_metrics(metrics, train_eval, "single_vdpo", losses, r_vis_w.detach(), r_vis_l.detach())
        
        if train_eval == "train":
            metrics["train_blind_margin"] = diff_blind.mean().item()
            metrics["train_visual_margin"] = diff_visual.mean().item()

        return losses.mean()

    def soft_label_dpo_loss(
        self,
        student_chosen_logps: torch.Tensor,
        student_rejected_logps: torch.Tensor,
        student_ref_chosen_logps: torch.Tensor,
        student_ref_rejected_logps: torch.Tensor,
        teacher_chosen_logps: torch.Tensor,
        teacher_rejected_logps: torch.Tensor,
        teacher_ref_chosen_logps: torch.Tensor,
        teacher_ref_rejected_logps: torch.Tensor,
        beta: float = 0.1,
        temperature: float = 1.0,
        validation_threshold: float = 0.5,
    ) -> tuple[torch.Tensor, dict]:
        """
        Distil the teacher's preference probability into the student (soft-label DPO).

        The teacher's DPO margin is turned into a soft target
        ``sigmoid(beta * teacher_logits / temperature)`` and the student margin
        ``beta * student_logits`` is trained against it with binary cross-entropy.

        Behaviour controlled by ``self.args``:
            * ``vcdist_stopgrad``: detach the soft target so no gradient reaches
              the teacher log-probs.
            * ``vcdist_filter``: only keep samples where the student probability
              is below the target AND the target exceeds ``validation_threshold``.
              Otherwise every sample with a strictly positive target is kept.

        Args:
            student_*_logps / teacher_*_logps: Policy and reference log-probs of
                the chosen / rejected responses for the student and teacher views.
            beta: Scale applied to both margins.
            temperature: Divides the teacher logits before scaling.
            validation_threshold: Minimum teacher probability for a sample to
                count as "valid" (used by the filter and by the valid-KL metric).

        Returns:
            ``(loss, stats)`` where ``loss`` is the masked mean BCE (zero, but
            still attached to the autograd graph, when no sample is kept) and
            ``stats`` holds detached sums/counts of the Bernoulli KL divergence,
            both for samples above the threshold and for all samples.
        """
        # 1. Student margin
        student_logits = (student_chosen_logps - student_rejected_logps) - (student_ref_chosen_logps - student_ref_rejected_logps)
        student_margin = beta * student_logits

        # 2. Teacher margin -> soft label
        teacher_logits = (teacher_chosen_logps - teacher_rejected_logps) - (teacher_ref_chosen_logps - teacher_ref_rejected_logps)
        teacher_margin = beta * (teacher_logits / temperature)
        target_probs = torch.sigmoid(teacher_margin)
        if self.args.vcdist_stopgrad:
            # Same effect as computing the target under no_grad.
            target_probs = target_probs.detach()

        # 3. Student probability (needed for the filter and the KL metrics)
        student_probs = torch.sigmoid(student_margin)

        # 4. Per-sample BCE, computed from logits so that saturated sigmoids do
        #    not lose precision or gradient.
        raw_loss = F.binary_cross_entropy_with_logits(student_margin, target_probs, reduction='none')

        # 5. Sample selection
        if self.args.vcdist_filter:
            # Keep a sample only while the student is still below the teacher
            # and the teacher itself is confident enough.
            loss_mask = ((student_probs < target_probs) & (target_probs > validation_threshold)).float()
        else:
            loss_mask = (target_probs > 0).float()

        # 6. Masked mean. Clamping the denominator avoids a 0/0 without a
        #    host-side branch, and keeps the result connected to the graph even
        #    when every sample is filtered out (the numerator is then exactly 0).
        num_valid = loss_mask.sum()
        final_loss = (raw_loss * loss_mask).sum() / num_valid.clamp(min=1.0)

        # 7. Monitoring: Bernoulli KL, P*log(P/Q) + (1-P)*log((1-P)/(1-Q))
        kl_div_element_wise = (target_probs * (torch.log(target_probs + 1e-6) - torch.log(student_probs + 1e-6)) +
                               (1 - target_probs) * (torch.log(1 - target_probs + 1e-6) - torch.log(1 - student_probs + 1e-6)))

        # A. KL restricted to samples the teacher rates above the threshold.
        #    This mask is intentionally independent of the loss mask above.
        kl_valid_mask = (target_probs > validation_threshold).float()
        valid_kl_sum = (kl_div_element_wise * kl_valid_mask).sum()
        valid_count = kl_valid_mask.sum()

        # B. KL over every sample
        global_kl_sum = kl_div_element_wise.sum()
        global_count = torch.tensor(kl_div_element_wise.numel(), device=kl_div_element_wise.device, dtype=torch.float)

        # Sums and counts are returned raw (detached) so callers can reduce them
        # across processes before dividing.
        return final_loss, {
            "vcdist_valid_kl_sum": valid_kl_sum.detach(),
            "vcdist_valid_count": valid_count.detach(),
            "vcdist_global_kl_sum": global_kl_sum.detach(),
            "vcdist_global_count": global_count.detach(),
        }

    def sft_loss(
        self,
        chosen_logps: torch.FloatTensor,
        chosen_attention_mask: torch.LongTensor
    ) -> torch.FloatTensor:
        """
        Compute the per-sample, token-averaged negative log-likelihood.

        Formula: L = -(summed log-prob) / (number of valid tokens)

        Args:
            chosen_logps: Summed response log-probabilities, one value per sample.
            chosen_attention_mask: Mask whose last dimension is summed to obtain
                the number of tokens each sample is averaged over.

        Returns:
            One loss value per sample (no reduction over the batch).
        """
        # Number of tokens contributing to each sample's summed log-prob.
        # Clamp to at least one so a fully masked row yields a finite value
        # instead of dividing by zero.
        valid_token_counts = chosen_attention_mask.sum(dim=-1).float().clamp(min=1.0)

        return -chosen_logps / valid_token_counts

    def hinge_sft_loss(
        self,
        chosen_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        chosen_attention_mask: torch.LongTensor
    ) -> torch.FloatTensor:
        """
        Hinge penalty on the per-token log-prob deficit relative to the reference.

        For every sample the summed deficit ``ref_chosen_logps - chosen_logps`` is
        divided by the number of valid tokens and passed through a ReLU, so a
        sample only contributes when the policy is on average below the reference.

        Args:
            chosen_logps: Summed policy log-probs, one value per sample.
            ref_chosen_logps: Summed reference log-probs, one value per sample.
            chosen_attention_mask: Mask whose last dimension is summed to obtain
                the token count used for averaging.

        Returns:
            The batch mean of the hinge losses (a scalar tensor).
        """
        sum_logp_diff = ref_chosen_logps - chosen_logps

        # Clamp to at least one token so a fully masked row cannot divide by
        # zero and turn the whole batch mean into NaN.
        valid_token_counts = chosen_attention_mask.sum(dim=-1).float().clamp(min=1.0)

        # Per-token average deficit
        avg_logp_diff = sum_logp_diff / valid_token_counts

        # Only positive deficits (policy below reference) are penalised.
        hinge_sft_losses = F.relu(avg_logp_diff)
        return hinge_sft_losses.mean()
        
    def anchor_loss(
        self,
        chosen_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        anchor_delta: float = 0.0,
    ) -> torch.FloatTensor:
        """
        Compute the anchor loss on the chosen responses only.

        Formula: L = -log(sigmoid(beta * log(pi / pi_ref) - delta))

        Args:
            chosen_logps: Policy log-probs of the chosen responses.
            ref_chosen_logps: Reference log-probs of the chosen responses;
                ignored when ``self.reference_free`` is set.
            anchor_delta: Offset ``delta`` subtracted from the scaled log-ratio.

        Returns:
            One loss value per sample (no reduction over the batch).
        """
        device = self.accelerator.device

        # Log-ratio log(pi) - log(pi_ref). In reference-free mode the reference
        # term is skipped outright rather than multiplied by zero, so a
        # non-finite reference value cannot leak NaNs into the loss.
        chosen_logratios = chosen_logps.to(device)
        if not self.reference_free:
            chosen_logratios = chosen_logratios - ref_chosen_logps.to(device)

        # beta * log(pi / pi_ref) - delta
        logits = self.beta * chosen_logratios - anchor_delta

        # -log(sigmoid(x)), via logsigmoid for numerical stability
        return -F.logsigmoid(logits)


    def margin_consistency_loss(
        self,
        chosen_logps: torch.FloatTensor,
        rejected_logps: torch.FloatTensor,
        ref_chosen_logps: torch.FloatTensor,
        ref_rejected_logps: torch.FloatTensor,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """
        Compute the margin consistency loss between the two halves of a batch.

        Formula: L = (Delta(m) - Delta(m'))^2
        where Delta = log(pi/ref)_chosen - log(pi/ref)_rejected (not scaled by beta).

        The inputs are expected to be laid out as ``[first_half, second_half]``
        with the two halves aligned sample by sample (per the original note:
        the image-1 rows followed by the image-2 rows).

        Args:
            chosen_logps / rejected_logps: Policy log-probs.
            ref_chosen_logps / ref_rejected_logps: Reference log-probs; ignored
                when ``self.reference_free`` is set.

        Returns:
            ``(losses, margin)``: the squared margin difference and the signed
            margin difference, each with one entry per aligned pair.

        Raises:
            ValueError: If the batch cannot be split into two equal halves.
        """
        device = self.accelerator.device

        total = chosen_logps.shape[0]
        if total % 2 != 0:
            # Unequal halves would otherwise broadcast silently or yield an
            # empty result whose mean is NaN.
            raise ValueError(
                f"margin_consistency_loss expects an even batch of aligned halves, got {total}"
            )

        # 1. Log-ratios log(pi / pi_ref). The reference term is skipped in
        #    reference-free mode instead of being multiplied by zero.
        chosen_logratios = chosen_logps.to(device)
        rejected_logratios = rejected_logps.to(device)
        if not self.reference_free:
            chosen_logratios = chosen_logratios - ref_chosen_logps.to(device)
            rejected_logratios = rejected_logratios - ref_rejected_logps.to(device)

        # 2. Per-row margin
        margins = chosen_logratios - rejected_logratios

        # 3. Split into the two aligned halves
        batch_size = total // 2
        delta_pos = margins[:batch_size]
        delta_neg = margins[batch_size:]

        # 4. Squared difference between the paired margins
        margin = delta_pos - delta_neg
        losses = margin ** 2

        return losses, margin
    
    def margin_distillation_loss(
        self,
        single_margin: torch.FloatTensor,
        multi_margin: torch.FloatTensor
    ) -> torch.FloatTensor:
        """
        One-sided distillation of the multi-image margin into the single-image margin.

        The multi-image margin acts as a fixed target (it is detached, so no
        gradient reaches it). Elements where the single-image margin already
        meets or exceeds the target contribute zero; the remaining shortfall is
        squared and averaged over all elements.

        Args:
            single_margin (torch.FloatTensor): Margins from the single-image context.
            multi_margin (torch.FloatTensor): Margins from the multi-image context.

        Returns:
            torch.FloatTensor: Scalar mean of the squared positive shortfall.
        """
        # Freeze the target so only the single-image branch is trained.
        target_margin = multi_margin.detach()

        # Positive part only: zero wherever single_margin >= target_margin.
        shortfall = torch.relu(target_margin - single_margin)

        return shortfall.pow(2).mean()

    # -------------------------------------------------------------------------
    # 4. Helper functions
    # -------------------------------------------------------------------------
    @staticmethod
    def pad_and_cat(tensors, padding_side="right", padding_value=0, device=None):
        """
        Helper: Pad list of tensors to max length and concatenate along dim 0.
        """
        if not tensors:
            return torch.tensor([], device=device)
        
        max_len = max(t.size(1) for t in tensors)
        padded_tensors = []
        for t in tensors:
            pad_len = max_len - t.size(1)
            if pad_len > 0:
                if padding_side == "right":
                    t = torch.nn.functional.pad(t, (0, pad_len), value=padding_value)
                elif padding_side == "left":
                    t = torch.nn.functional.pad(t, (pad_len, 0), value=padding_value)
            padded_tensors.append(t)
        
        return torch.cat(padded_tensors, dim=0)

    def log_metrics(self, metrics, prefix, task, losses, chosen_rewards, rejected_rewards):
        """
        Write globally averaged reward and loss statistics into ``metrics``.

        Each tensor is detached, gathered through
        ``self.accelerator.gather_for_metrics`` and reduced to a Python float
        with ``mean().item()``. Results are stored in ``metrics`` (mutated in
        place) under keys of the form ``"{prefix}_{task}_<name>"``.

        Args:
            metrics: Dict that receives the computed values.
            prefix: First part of every key.
            task: Second part of every key.
            losses: Loss tensor to average.
            chosen_rewards: Rewards of the chosen responses.
            rejected_rewards: Rewards of the rejected responses.
        """
        key_base = f"{prefix}_{task}_"

        def _global_mean(values):
            # Detach first so the gathered copy does not keep the graph alive.
            gathered = self.accelerator.gather_for_metrics(values.detach())
            return gathered.mean().item()

        # Derived per-example statistics are computed locally, before gathering,
        # so they have the same shape as the raw reward tensors.
        accuracy_flags = (chosen_rewards > rejected_rewards).float()
        reward_gaps = chosen_rewards - rejected_rewards

        # Order is kept fixed: it determines both the sequence of gather calls
        # and the insertion order of the keys.
        to_report = (
            ("rewards/chosen", chosen_rewards),
            ("rewards/rejected", rejected_rewards),
            ("rewards/accuracies", accuracy_flags),
            ("rewards/margins", reward_gaps),
            ("loss", losses),
        )
        for suffix, values in to_report:
            metrics[f"{key_base}{suffix}"] = _global_mean(values)



