"""
智能描述细化引擎
封装了基于 Jitter 分析和视觉定位的迭代细化流程
"""
import warnings
warnings.filterwarnings("ignore")

import torch
import copy
import time
import os
import sys
import numpy as np
import logging
import gc
from PIL import Image
from typing import Optional, Tuple, Dict, Any

from crop import Config, GroundingAgent, JitterAnalyzer, TextMiner
from llava.model.builder import load_pretrained_model
from llava.mm_utils import process_images, tokenizer_image_token
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN
from llava.conversation import conv_templates
from llava.hooks.fast_dllm_hook import register_fast_dllm_hook


class RefinementEngine:
    """智能描述细化引擎"""
    
    def __init__(
        self,
        model_path: str,
        model_base: str,
        model_name: str = "llava_llada_lora",
        vision_tower_path: str = "google/siglip2-so400m-patch14-384",
        device: str = "cuda:0",
        max_steps: int = 10,
        jitter_threshold: float = 0.30,
        mask_expansion: int = 2,
        temp_dir: str = "./cropped_image",
        image_input_mode: str = "both",
        mask_mode: str = "span",
        token_selection_mode: str = "jitter",
        local_refinement_mode: str = "crop",
        logger: Optional[logging.Logger] = None
    ):
        """Configure the refinement engine and load the models it uses.

        Args:
            model_path: Path of the trained checkpoint.
            model_base: Path of the base model.
            model_name: Model name forwarded to ``load_pretrained_model``.
            vision_tower_path: Vision encoder used to override ``mm_vision_tower``.
            device: Torch device the LLaDA model is moved to.
            max_steps: Maximum number of refinement iterations.
            jitter_threshold: Jitter threshold for picking refinement targets.
            mask_expansion: Mask expansion width (only used when mask_mode == "span").
            temp_dir: Directory for temporary files; created if it does not exist.
            image_input_mode: "original" (full image only), "crop" (local crop
                only) or "both" (full image plus local crop).
            mask_mode: "span" (span parsed by TextMiner), "single" (mask only the
                high-jitter token) or "expand" (the high-jitter token plus
                neighbouring tokens on each side).
            token_selection_mode: "jitter" (highest jitter), "random" (random
                token), "confidence" (lowest confidence) or "jitter_confidence"
                (combination of both criteria).
            local_refinement_mode: "crop" (gather local evidence by cropping) or
                "bbox" (draw a red bounding box on the full image).
            logger: Logger to use; when None, this module's logger is used and
                given an INFO level and a stream handler (added only once).

        Raises:
            ValueError: if one of the four mode arguments is not an allowed value
                (checked in the order listed above).
        """
        # Reject unsupported modes before any state is stored or files/models are touched.
        if image_input_mode not in ("original", "crop", "both"):
            raise ValueError(f"image_input_mode 必须是 'original', 'crop' 或 'both'，当前值: {image_input_mode}")
        if mask_mode not in ("span", "single", "expand"):
            raise ValueError(f"mask_mode 必须是 'span', 'single' 或 'expand'，当前值: {mask_mode}")
        if token_selection_mode not in ("jitter", "random", "confidence", "jitter_confidence"):
            raise ValueError(f"token_selection_mode 必须是 'jitter', 'random', 'confidence' 或 'jitter_confidence'，当前值: {token_selection_mode}")
        if local_refinement_mode not in ("crop", "bbox"):
            raise ValueError(f"local_refinement_mode 必须是 'crop' 或 'bbox'，当前值: {local_refinement_mode}")

        self.device = device
        self.max_steps = max_steps
        self.jitter_threshold = jitter_threshold
        self.mask_expansion = mask_expansion
        self.temp_dir = temp_dir
        # NOTE(review): hard-coded mask token id, presumably LLaDA's [MASK]; confirm against the tokenizer.
        self.mask_token_id = 126336
        self.image_input_mode = image_input_mode
        self.mask_mode = mask_mode
        self.token_selection_mode = token_selection_mode
        self.local_refinement_mode = local_refinement_mode

        if logger is not None:
            self.logger = logger
        else:
            module_logger = logging.getLogger(__name__)
            module_logger.setLevel(logging.INFO)
            # Attach the stream handler only once, even if several engines are built.
            if not module_logger.handlers:
                stream_handler = logging.StreamHandler()
                stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
                module_logger.addHandler(stream_handler)
            self.logger = module_logger

        os.makedirs(temp_dir, exist_ok=True)

        self._load_models(model_path, model_base, model_name, vision_tower_path)
        
    def _load_models(self, model_path, model_base, model_name, vision_tower_path):
        """Load the LLaDA vision-language model and the GroundingDINO grounder.

        Populates ``self.tokenizer``, ``self.model``, ``self.image_processor``,
        ``self.max_length`` and ``self.grounder``.
        """
        self.logger.info(">>> Loading LLaDA Model...")
        loaded = load_pretrained_model(
            model_path,
            model_base,
            model_name,
            attn_implementation="sdpa",
            device_map=None,
            # Replace the vision tower recorded in the checkpoint's config.
            overwrite_config={"mm_vision_tower": vision_tower_path},
        )
        self.tokenizer, self.model, self.image_processor, self.max_length = loaded
        self.tokenizer.padding_side = "left"
        # device_map=None above, so device placement is done explicitly here.
        self.model = self.model.to(self.device)
        self.model.eval()
        # NOTE: register_fast_dllm_hook is imported by this module but deliberately not applied.

        self.logger.info(">>> Loading GroundingDINO (Standing by on CPU)...")
        self.grounder = GroundingAgent(Config(), logger=self.logger)
        self.grounder.load()
        # NOTE(review): this method does not itself move the grounder to CPU as the
        # log line suggests; confirm GroundingAgent.load() places it there.
    
    def _flush(self):
        """Force a garbage collection and release cached CUDA memory.

        Always runs ``gc.collect()``.  The CUDA calls (``empty_cache`` and
        ``synchronize``) run only when CUDA is available: ``torch.cuda.synchronize``
        raises on CPU-only builds or machines without a GPU, and this method is
        called unconditionally (e.g. by ``_move_model_to_cpu``).
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    
    def _move_model_to_cpu(self, model_obj):
        """Move ``model_obj`` to the CPU and then flush freed memory.

        ``None`` is ignored.  The check is an explicit ``is not None`` test
        rather than truthiness because container modules (e.g. an empty
        ``nn.Sequential`` / ``nn.ModuleList``) define ``__len__`` and would
        otherwise be silently skipped.
        """
        if model_obj is None:
            return
        model_obj.to("cpu")
        self._flush()
    
    def _move_model_to_gpu(self, model_obj):
        """将模型移动回 GPU"""
        if model_obj:
            model_obj.to(self.device)
    
    def _expand2square(self, pil_img, background_color):
        """Pad ``pil_img`` onto a square canvas filled with ``background_color``.

        The image is centred along its shorter axis; when the padding is odd the
        extra pixel ends up at the bottom/right.  A square input is returned
        as-is (the same object), otherwise a new image of the same mode is built.
        """
        w, h = pil_img.size
        if w == h:
            return pil_img
        side = max(w, h)
        canvas = Image.new(pil_img.mode, (side, side), background_color)
        # Exactly one of these offsets is non-zero: only the shorter axis is padded.
        canvas.paste(pil_img, ((side - w) // 2, (side - h) // 2))
        return canvas
    
    def _draw_bbox_on_image(self, image: Image.Image, bbox: Tuple[int, int, int, int]) -> Image.Image:
        """Return a copy of ``image`` with a red rectangle (3 px line) drawn at ``bbox``.

        Args:
            image: PIL image.  It is converted to RGB first, so grayscale ("L")
                or palette ("P") images no longer break the colour handling.
            bbox: ``(x1, y1, x2, y2)`` corner coordinates in pixels.  Non-integer
                values are rounded, because ``cv2.rectangle`` only accepts
                integer points.

        Returns:
            A new RGB PIL image; ``image`` itself is left unmodified.
        """
        import cv2

        # np.array copies the pixel buffer, so drawing never touches ``image``.
        canvas = np.array(image.convert("RGB"))
        x1, y1, x2, y2 = (int(round(float(v))) for v in bbox)
        # The array is in RGB order, so (255, 0, 0) is red; this yields the same
        # pixels as drawing (0, 0, 255) on a BGR copy and converting back.
        cv2.rectangle(canvas, (x1, y1), (x2, y2), (255, 0, 0), 3)
        return Image.fromarray(canvas)
    
    def _map_char_span_to_token_span(self, token_ids, tokenizer, char_start, char_end):
        """将字符级范围映射回 Token 索引范围"""
        current_char_pos = 0
        token_start = None
        token_end = None

        for i, tid in enumerate(token_ids):
            token_text = tokenizer.decode([tid], skip_special_tokens=False)
            token_len = len(token_text)
            
            t_s = current_char_pos
            t_e = current_char_pos + token_len
            
            if t_e > char_start and t_s < char_end:
                if token_start is None:
                    token_start = i
                token_end = i + 1
            
            current_char_pos += token_len
            
            if current_char_pos >= char_end and token_start is not None:
                break
                
        if token_start is None:
            return None, None
            
        return token_start, token_end
    
    def _check_token_pos(self, token_idx, current_generated_ids, doc):
        """Tell whether the word at generated token ``token_idx`` has a content POS tag.

        The word is located in ``doc`` as follows.  The ids before ``token_idx``
        are decoded (``skip_special_tokens=True``) to get a character offset.
        Whitespace at that offset is skipped.  The doc token whose character
        range contains the offset is then selected.

        Args:
            token_idx: Index into ``current_generated_ids``.
            current_generated_ids: Generated token ids.
            doc: spaCy document parsed from the decoded generated text.

        Returns:
            True if the matched token's ``pos_`` is NOUN, PROPN, ADJ, VERB or ADP.
            False if ``token_idx`` is past the end of the ids, if no doc token
            covers the offset, or if the tag is anything else.
        """
        if token_idx >= len(current_generated_ids):
            return False

        text = doc.text
        offset = len(self.tokenizer.decode(current_generated_ids[:token_idx], skip_special_tokens=True))
        # The prefix ends just before this token; step over separating whitespace.
        while offset < len(text) and text[offset].isspace():
            offset += 1

        match = next(
            (tok for tok in doc if tok.idx <= offset < (tok.idx + len(tok.text))),
            None,
        )
        if not match:
            return False
        return match.pos_ in {"NOUN", "PROPN", "ADJ", "VERB", "ADP"}
    
    def refine(
        self,
        image_path: str,
        base_instruction: str = "Please describe the image in detail. Use less absolute directional descriptions.",
        **kwargs
    ) -> Tuple[str, Dict[str, Any]]:
        """
        对图像进行智能描述细化
        
        Args:
            image_path: 图像路径
            base_instruction: 基础指令
            **kwargs: 其他参数
            
        Returns:
            (final_response, metadata): 最终响应和元数据
        """
        # 加载图像
        original_image_obj = Image.open(image_path).convert("RGB")
        
        # 状态变量
        last_intermediate_confidences = None
        current_generated_ids = []
        global_covered_indices = set()
        initial_jitter = []
        initial_confidence = []
        target_span_range = None
        # jitter_confidence 模式：记录当前应该处理的类型（"jitter" 或 "confidence"）
        jitter_conf_mode_type = None  # 在 jitter_confidence 模式下使用
        
        self.logger.info(f"{'=' * 20} Start Intelligent Loop (Threshold: {self.jitter_threshold}) {'=' * 20}")
        
        # 主循环
        for step_idx in range(self.max_steps + 1):
            # 分析与决策阶段
            if step_idx > 0:
                self.logger.info(f"\n>>> [Phase {step_idx}] Analyzing Uncertainty...")
                
                if self.token_selection_mode != "random":
                    if last_intermediate_confidences:
                        # print(f"last_intermediate_confidences: {last_intermediate_confidences}")
                        history_data = [{'confidences': [float(c) for c in step.get('confidences', [])]}
                                        for step in last_intermediate_confidences]
                        
                        del last_intermediate_confidences
                        last_intermediate_confidences = None
                        self._flush()
                        
                        conf_matrix, _ = JitterAnalyzer.extract_confidence_data(history_data)
                        
                        # 获取最终 confidence（最后一次生成的 confidence，用于 confidence 和 jitter_confidence 模式）
                        # conf_matrix 的形状是 (num_tokens, num_steps)，所以应该用 conf_matrix[:, -1] 获取最后一列（所有 token 的最后一步 confidence）
                        if conf_matrix.size > 0 and conf_matrix.shape[1] > 0:
                            final_confidence_values = conf_matrix[:, -1].copy()  # 获取最后一列（所有 token 的最后一步 confidence）
                        else:
                            final_confidence_values = np.array([])
                            
                        # print(f"final_confidence_values: {final_confidence_values}")
                        
                        # 计算 jitter（jitter 模式和 jitter_confidence 模式都需要）
                        if self.token_selection_mode in ["jitter", "jitter_confidence"]:
                            jitter_values = JitterAnalyzer.calculate_jitter_numpy(conf_matrix)
                            # print(f"jitter_values: {jitter_values}")
                            if step_idx == 1:
                                initial_jitter = jitter_values.copy()
                            else:
                                if len(initial_jitter) > 0 and len(jitter_values) == len(initial_jitter) and target_span_range:
                                    s, e = target_span_range
                                    initial_jitter[s:e] = jitter_values[s:e].copy()
                                    jitter_values = initial_jitter.copy()
                                else:
                                    self.logger.warning("  -> Initial jitter and current jitter have different lengths. Skipping.")
                                    continue
                            
                            if len(jitter_values) == 0:
                                break
                        
                        # 处理 confidence 相关的模式
                        if self.token_selection_mode in ["confidence", "jitter_confidence"]:
                            # 但需要确保 final_confidence_values 的长度和 current_generated_ids 一致
                            if len(final_confidence_values) == 0:
                                # 如果没有 confidence 数据，创建一个与 current_generated_ids 长度相同的数组，填充为 1.0
                                final_confidence_values = np.ones(len(current_generated_ids), dtype=np.float32)
                            elif len(final_confidence_values) != len(current_generated_ids):
                                # 如果长度不一致，调整 final_confidence_values 的长度
                                if len(final_confidence_values) > len(current_generated_ids):
                                    final_confidence_values = final_confidence_values[:len(current_generated_ids)]
                                else:
                                    # 如果 final_confidence 长度小于 current_generated_ids，用 1.0 填充
                                    final_confidence_values = np.pad(final_confidence_values, (0, len(current_generated_ids) - len(final_confidence_values)), 
                                                                    mode='constant', constant_values=1.0)
                            else:
                                if len(initial_confidence) > len(current_generated_ids):
                                    initial_confidence = initial_confidence[:len(current_generated_ids)]
                                else:
                                    initial_confidence = np.pad(initial_confidence, (0, len(current_generated_ids) - len(initial_confidence)), 
                                                                    mode='constant', constant_values=1.0)
                            if step_idx == 1:
                                initial_confidence = final_confidence_values.copy()
                            else:
                                if len(initial_confidence) > 0 and len(final_confidence_values) == len(initial_confidence) and target_span_range:
                                    s, e = target_span_range
                                    initial_confidence[s:e] = final_confidence_values[s:e].copy()
                                    final_confidence_values = initial_confidence.copy()
                                else:
                                    self.logger.warning("  -> Initial confidence and current confidence have different lengths. Skipping.")
                                    continue
                        
                        del history_data, conf_matrix
                
                # 应用全局 Mask
                if global_covered_indices:
                    if self.token_selection_mode in ["jitter", "jitter_confidence"]:
                        suppressed_mask = np.zeros_like(jitter_values, dtype=bool)
                        for idx in global_covered_indices:
                            if idx < len(jitter_values):
                                suppressed_mask[idx] = True
                        jitter_values[suppressed_mask] = 0.0
                        self.logger.info(f"  -> Suppressed {sum(suppressed_mask)} tokens based on history.")
                    if self.token_selection_mode in ["confidence", "jitter_confidence"]:
                        suppressed_mask = np.zeros_like(final_confidence_values, dtype=bool)
                        for idx in global_covered_indices:
                            if idx < len(final_confidence_values):
                                suppressed_mask[idx] = True
                        final_confidence_values[suppressed_mask] = 1.0  # 将已覆盖的 token 的 confidence 设为最高，避免被选中
                        if self.token_selection_mode == "confidence":
                            self.logger.info(f"  -> Suppressed {sum(suppressed_mask)} tokens based on history.")
                        # jitter_confidence 模式的日志在 jitter 分支已经输出，这里不重复
                
                # 根据 token_selection_mode 选择候选 token 索引
                if self.token_selection_mode == "jitter":
                    # 模式1: 选择 jitter 最高的 token（原有方式）
                    sorted_indices = np.argsort(jitter_values)[::-1]
                    candidate_indices = sorted_indices
                    selection_metric = jitter_values
                    metric_name = "Jitter"
                elif self.token_selection_mode == "random":
                    # 模式2: 从所有 token 中随机选择一个（排除已覆盖的）
                    all_indices = np.arange(len(current_generated_ids))
                    # 排除已覆盖的 token
                    if global_covered_indices:
                        available_indices = all_indices[~np.isin(all_indices, list(global_covered_indices))]
                    else:
                        available_indices = all_indices
                    
                    if len(available_indices) == 0:
                        candidate_indices = []
                    else:
                        # 如果使用 span 模式，需要尝试多个随机 token，直到找到一个有效的
                        # 因为 analyze_token_for_refinement 可能返回 None（找不到有效的 chunk）
                        if self.mask_mode == "span":
                            import random as py_random
                            # 打乱顺序，随机尝试多个 token
                            shuffled_indices = available_indices.tolist()
                            py_random.shuffle(shuffled_indices)
                            # 最多尝试 min(10, len(available_indices)) 个 token
                            max_attempts = min(10, len(shuffled_indices))
                            candidate_indices = shuffled_indices[:max_attempts]
                            self.logger.info(f"  -> Random mode (span): Will try {len(candidate_indices)} random tokens to find a valid target.")
                        else:
                            # single 和 expand 模式不需要检查词性，直接随机选择一个
                            import random as py_random
                            selected_random_idx = py_random.choice(available_indices.tolist())
                            candidate_indices = [selected_random_idx]
                    selection_metric = np.zeros(len(current_generated_ids))  # random 模式不需要 metric
                    metric_name = "Random"
                elif self.token_selection_mode == "confidence":
                    # 模式3: 从所有 token 中选择最终 confidence 最低的 token（排除已覆盖的）
                    # if len(final_confidence_values) == 0 or len(final_confidence_values) != len(current_generated_ids):
                    #     self.logger.warning("  -> No confidence data available or length mismatch, using fallback")
                    #     # 如果 confidence 数据不可用，创建一个默认数组
                    #     final_confidence_values = np.ones(len(current_generated_ids), dtype=np.float32)
                    
                    # 从所有 token 中选择（排除已覆盖的）
                    all_indices = np.arange(len(current_generated_ids))
                    if global_covered_indices:
                        available_indices = all_indices[~np.isin(all_indices, list(global_covered_indices))]
                    else:
                        available_indices = all_indices
                    
                    if len(available_indices) == 0:
                        candidate_indices = []
                    else:
                        # 在可用 token 中选择最终 confidence 最低的
                        available_confidences = final_confidence_values[available_indices]
                        sorted_available_idx = np.argsort(available_confidences)  # 升序排列，最低的在前面
                        candidate_indices = available_indices[sorted_available_idx]
                    selection_metric = final_confidence_values
                    metric_name = "Confidence"
                elif self.token_selection_mode == "jitter_confidence":
                    # 模式4: 取并集 - 高 jitter 的 token 和低置信度的 token 都要处理
                    # 前 (max_steps+1)//2 步处理高 jitter 的 token
                    # 后 max_steps//2 步处理低置信度的 token
                    
                    # 从所有 token 中选择（排除已覆盖的）
                    all_indices = np.arange(len(current_generated_ids))
                    if global_covered_indices:
                        available_indices = all_indices[~np.isin(all_indices, list(global_covered_indices))]
                    else:
                        available_indices = all_indices
                    
                    if len(available_indices) == 0:
                        candidate_indices = []
                        jitter_conf_mode_type = None
                    else:
                        # 计算应该处理哪一类
                        jitter_steps = (self.max_steps + 1) // 2  # 前 jitter_steps 步处理高 jitter
                        confidence_steps = self.max_steps // 2  # 后 confidence_steps 步处理低置信度
                        
                        if step_idx <= jitter_steps:
                            # 处理高 jitter 的 token
                            jitter_conf_mode_type = "jitter"
                            available_jitters = jitter_values[available_indices]
                            
                            # 按 jitter 降序排列
                            sorted_by_jitter = np.argsort(available_jitters)[::-1]  # 降序排列，最高的在前面
                            candidate_indices = available_indices[sorted_by_jitter]
                            
                            self.logger.info(f"  -> Jitter-Confidence mode (Step {step_idx}/{self.max_steps}): Processing HIGH JITTER tokens. Selected {len(candidate_indices)} tokens sorted by jitter.")
                        else:
                            # 处理低置信度的 token
                            jitter_conf_mode_type = "confidence"
                            available_confidences = final_confidence_values[available_indices]
                            
                            # 筛选出低置信度的 token（使用中位数作为阈值，选择低于中位数的 token）
                            confidence_median = np.median(available_confidences)
                            low_conf_mask = available_confidences <= confidence_median
                            low_conf_indices = available_indices[low_conf_mask]
                            
                            if len(low_conf_indices) == 0:
                                # 如果没有低于中位数的，则选择最低的 50% 的 token
                                num_low_conf = max(1, len(available_indices) // 2)
                                sorted_by_conf = np.argsort(available_confidences)
                                low_conf_indices = available_indices[sorted_by_conf[:num_low_conf]]
                            
                            # 对低置信度 token 进行词性筛选：只保留 NOUN, ADJ, VERB, ADP
                            original_low_conf_count = len(low_conf_indices)
                            if original_low_conf_count > 0:
                                full_text = self.tokenizer.decode(current_generated_ids, skip_special_tokens=True)
                                nlp = TextMiner.get_nlp()
                                doc = nlp(full_text)
                                
                                pos_filtered_indices = []
                                for idx in low_conf_indices:
                                    if self._check_token_pos(idx, current_generated_ids, doc):
                                        pos_filtered_indices.append(idx)
                                
                                if len(pos_filtered_indices) > 0:
                                    low_conf_indices = np.array(pos_filtered_indices)
                                    self.logger.info(f"  -> Jitter-Confidence mode (Step {step_idx}/{self.max_steps}): POS filtering (NOUN/ADJ/VERB/ADP) applied. {len(pos_filtered_indices)}/{original_low_conf_count} tokens passed.")
                                else:
                                    self.logger.info(f"  -> Jitter-Confidence mode (Step {step_idx}/{self.max_steps}): No tokens passed POS filtering (NOUN/ADJ/VERB/ADP).")
                                    low_conf_indices = np.array([])
                            
                            # 按置信度升序排列（最低的在前面）
                            low_conf_confidences = final_confidence_values[low_conf_indices]
                            sorted_by_conf = np.argsort(low_conf_confidences)  # 升序排列，最低的在前面
                            candidate_indices = low_conf_indices[sorted_by_conf]
                            
                            self.logger.info(f"  -> Jitter-Confidence mode (Step {step_idx}/{self.max_steps}): Processing LOW CONFIDENCE tokens. Selected {len(candidate_indices)} tokens (median: {confidence_median:.4f}), sorted by confidence.")
                    
                    # 使用 jitter 或 confidence 作为 selection_metric（根据当前类型）
                    if jitter_conf_mode_type == "jitter":
                        selection_metric = jitter_values
                        metric_name = "Jitter-Confidence (Jitter)"
                    else:
                        selection_metric = final_confidence_values
                        metric_name = "Jitter-Confidence (Confidence)"
                
                # 智能选点逻辑
                selected_idx = None
                target_phrase = None
                target_span_range = None
                max_val = None
                
                if self.mask_mode == "span":
                    # 模式1: 使用TextMiner解析span（原有方式）
                    full_text = self.tokenizer.decode(current_generated_ids, skip_special_tokens=True)
                    nlp = TextMiner.get_nlp()
                    doc = nlp(full_text)
                    
                    self.logger.info(f"  -> Searching for valid refinement target (NOUN/ADJ) using TextMiner (Selection: {metric_name})...")
                    
                    for idx in candidate_indices:
                        # 对于 jitter 模式，需要检查 jitter_threshold
                        # 对于 jitter_confidence 模式，只在处理 jitter 类型时检查阈值
                        if self.token_selection_mode == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        
                        # 检查 idx 是否在有效范围内
                        if idx >= len(current_generated_ids):
                            self.logger.warning(f"     [Skip] Index {idx} out of range (length: {len(current_generated_ids)}). Skipping.")
                            continue
                        
                        prefix_ids = current_generated_ids[:idx]
                        prefix_text = self.tokenizer.decode(prefix_ids, skip_special_tokens=True)
                        char_offset = len(prefix_text)
                        
                        chunk_text, char_span = TextMiner.analyze_token_for_refinement(doc, char_offset)
                        
                        if chunk_text:
                            selected_idx = idx
                            target_phrase = chunk_text
                            if self.token_selection_mode == "jitter":
                                max_val = jitter_values[idx]
                            elif self.token_selection_mode == "confidence":
                                max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                            elif self.token_selection_mode == "jitter_confidence":
                                if jitter_conf_mode_type == "jitter":
                                    max_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                                else:  # "confidence"
                                    max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                            else:  # random
                                max_val = 0.0  # random 模式不需要 metric 值
                            
                            if self.token_selection_mode == "jitter_confidence":
                                jitter_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                                conf_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                                if jitter_conf_mode_type == "jitter":
                                    metric_str = f"Jitter: {jitter_val:.4f} (Confidence: {conf_val:.4f})"
                                else:
                                    metric_str = f"Confidence: {conf_val:.4f} (Jitter: {jitter_val:.4f})"
                            else:
                                metric_str = f"{metric_name}: {max_val:.4f}"
                            self.logger.info(f"  -> Valid Target Found: '{self.tokenizer.decode([current_generated_ids[idx]])}' (POS Valid). Expanding to chunk: '{target_phrase}' ({metric_str})")
                            
                            start_char, end_char = char_span
                            s_start, s_end = self._map_char_span_to_token_span(
                                current_generated_ids, self.tokenizer, start_char, end_char
                            )
                            
                            if s_start is None:
                                self.logger.warning(f"     [Error] Failed to map char span to tokens. Skipping.")
                                continue
                            
                            target_span_range = [s_start, s_end]
                            break
                        else:
                            # 再次检查 idx 是否在有效范围内（防止在循环中 current_generated_ids 被修改）
                            if idx < len(current_generated_ids):
                                token_text = self.tokenizer.decode([current_generated_ids[idx]], skip_special_tokens=True)
                                self.logger.info(f"     [Skip] Token '{token_text}' (index {idx}) is not a valid refinement target (no valid chunk found).")
                            else:
                                self.logger.warning(f"     [Skip] Index {idx} out of range (length: {len(current_generated_ids)}). Skipping.")
                    
                    if selected_idx is None:
                        self.logger.info(f"  -> No valid refinement target found. Stability reached.")
                        break
                    
                    if self.token_selection_mode == "jitter_confidence":
                        jitter_val = jitter_values[selected_idx] if selected_idx < len(jitter_values) else 0.0
                        conf_val = final_confidence_values[selected_idx] if selected_idx < len(final_confidence_values) else 1.0
                        if jitter_conf_mode_type == "jitter":
                            metric_str = f"Jitter: {jitter_val:.4f} (Confidence: {conf_val:.4f})"
                        else:
                            metric_str = f"Confidence: {conf_val:.4f} (Jitter: {jitter_val:.4f})"
                    else:
                        metric_str = f"{metric_name}: {max_val:.4f}"
                    self.logger.info(f"  -> Selected Focus: '{target_phrase}' ({metric_str})")
                
                elif self.mask_mode == "single":
                    # 模式2: 只mask单个token
                    self.logger.info(f"  -> Searching for token (single token mode, Selection: {metric_name})...")
                    
                    for idx in candidate_indices:
                        # 对于 jitter 模式，需要检查 jitter_threshold
                        # 对于 jitter_confidence 模式，只在处理 jitter 类型时检查阈值
                        if self.token_selection_mode == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        
                        # 检查 idx 是否在有效范围内
                        if idx >= len(current_generated_ids):
                            self.logger.warning(f"     [Skip] Index {idx} out of range (length: {len(current_generated_ids)}). Skipping.")
                            continue
                        
                        selected_idx = idx
                        if self.token_selection_mode == "jitter":
                            max_val = jitter_values[idx]
                        elif self.token_selection_mode == "confidence":
                            max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                        elif self.token_selection_mode == "jitter_confidence":
                            if jitter_conf_mode_type == "jitter":
                                max_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                            else:  # "confidence"
                                max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                        else:  # random
                            max_val = 0.0  # random 模式不需要 metric 值
                        target_span_range = [idx, idx + 1]  # 只mask单个token
                        target_phrase = self.tokenizer.decode([current_generated_ids[idx]], skip_special_tokens=True)
                        
                        if self.token_selection_mode == "jitter_confidence":
                            jitter_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                            conf_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                            if jitter_conf_mode_type == "jitter":
                                metric_str = f"Jitter: {jitter_val:.4f} (Confidence: {conf_val:.4f})"
                            else:
                                metric_str = f"Confidence: {conf_val:.4f} (Jitter: {jitter_val:.4f})"
                        else:
                            metric_str = f"{metric_name}: {max_val:.4f}"
                        self.logger.info(f"  -> Selected Token: '{target_phrase}' ({metric_str}, Index: {idx})")
                        break
                    
                    if selected_idx is None:
                        self.logger.info(f"  -> No valid refinement target found. Stability reached.")
                        break
                
                elif self.mask_mode == "expand":
                    # 模式3: 扩展token左右各4个token
                    self.logger.info(f"  -> Searching for token (expand mode, ±4 tokens, Selection: {metric_name})...")
                    
                    for idx in candidate_indices:
                        # 对于 jitter 模式，需要检查 jitter_threshold
                        # 对于 jitter_confidence 模式，只在处理 jitter 类型时检查阈值
                        if self.token_selection_mode == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "jitter" and jitter_values[idx] < self.jitter_threshold:
                            break
                        
                        # 检查 idx 是否在有效范围内
                        if idx >= len(current_generated_ids):
                            self.logger.warning(f"     [Skip] Index {idx} out of range (length: {len(current_generated_ids)}). Skipping.")
                            continue
                        
                        selected_idx = idx
                        if self.token_selection_mode == "jitter":
                            max_val = jitter_values[idx]
                        elif self.token_selection_mode == "confidence":
                            max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                        elif self.token_selection_mode == "jitter_confidence":
                            if jitter_conf_mode_type == "jitter":
                                max_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                            else:  # "confidence"
                                max_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                        else:  # random
                            max_val = 0.0  # random 模式不需要 metric 值
                        
                        # 扩展左右各4个token
                        expand_size = 4
                        s = max(0, idx - expand_size)
                        e = min(len(current_generated_ids), idx + expand_size + 1)
                        target_span_range = [s, e]
                        target_phrase = self.tokenizer.decode(current_generated_ids[s:e], skip_special_tokens=True)
                        
                        if self.token_selection_mode == "jitter_confidence":
                            jitter_val = jitter_values[idx] if idx < len(jitter_values) else 0.0
                            conf_val = final_confidence_values[idx] if idx < len(final_confidence_values) else 1.0
                            if jitter_conf_mode_type == "jitter":
                                metric_str = f"Jitter: {jitter_val:.4f} (Confidence: {conf_val:.4f})"
                            else:
                                metric_str = f"Confidence: {conf_val:.4f} (Jitter: {jitter_val:.4f})"
                        else:
                            metric_str = f"{metric_name}: {max_val:.4f}"
                        self.logger.info(f"  -> Selected Token: '{self.tokenizer.decode([current_generated_ids[idx]])}' ({metric_str}, Index: {idx})")
                        self.logger.info(f"  -> Expanded Range: [{s}, {e}) ({e-s} tokens)")
                        break
                    
                    if selected_idx is None:
                        self.logger.info(f"  -> No valid refinement target found. Stability reached.")
                        break
                
                # 所有模式都支持 Grounding（如果 image_input_mode 不是 "original"）
                # 对于 jitter_confidence 模式，需要根据当前类型决定图像输入方式
                crop_path = None
                dino_conf = None
                
                # jitter_confidence 模式特殊处理：低置信度 token 只使用全局图像
                effective_image_mode = self.image_input_mode
                if self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "confidence":
                    # 处理低置信度 token 时，强制使用全局图像
                    effective_image_mode = "original"
                    self.logger.info(f"  -> Jitter-Confidence mode (confidence): Forcing original image only for low-confidence tokens.")
                elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "jitter":
                    # 处理高 jitter token 时，强制使用双视角输入
                    if self.image_input_mode != "original":
                        effective_image_mode = "both"
                        self.logger.info(f"  -> Jitter-Confidence mode (jitter): Forcing dual-image input for high-jitter tokens.")
                
                if effective_image_mode == "original":
                    active_images = [original_image_obj]
                    self.logger.info(f"  -> Using original image only ({original_image_obj.size})")
                else:
                    # Grounding (所有模式都支持)
                    # self._move_model_to_gpu(self.grounder.model)
                    
                    # 根据 local_refinement_mode 选择处理方式
                    if self.local_refinement_mode == "bbox":
                        # BBox 模式：获取边界框坐标并在原图上绘制
                        bbox_result, dino_conf = self.grounder.run_grounding(
                            target_phrase, self.temp_dir, image_path, return_bbox=True
                        )
                        
                        if bbox_result:
                            bbox = bbox_result
                            # 在原图上绘制红色边界框
                            annotated_image = self._draw_bbox_on_image(original_image_obj.copy(), bbox)
                            
                            # 根据 effective_image_mode 决定输入图像
                            if effective_image_mode == "crop":
                                active_images = [annotated_image]
                                self.logger.info(f"  -> Using annotated image with bbox ({annotated_image.size}, BBox: {bbox})")
                            else:  # "both"
                                # bbox 模式下，"both" 实际上只使用标注后的原图
                                active_images = [annotated_image]
                                self.logger.info(f"  -> Using annotated image with bbox ({annotated_image.size}, BBox: {bbox})")
                        else:
                            # 没有找到 bbox 的情况
                            if dino_conf:
                                self.logger.info(f"  -> [REJECT] Visual evidence weak (Conf: {dino_conf:.2f} < 0.40). Skipping.")
                            else:
                                self.logger.info("  -> No bbox found. SKIPPING refinement for this step.")
                            if target_span_range:
                                for i in range(target_span_range[0], target_span_range[1]):
                                    global_covered_indices.add(i)
                            continue
                    else:
                        # Crop 模式：原有的裁剪逻辑
                        crop_path, dino_conf = self.grounder.run_grounding(
                            target_phrase, self.temp_dir, image_path, return_bbox=False
                        )
                        # self._move_model_to_cpu(self.grounder.model)
                        
                        # 加载 Crop 并根据 effective_image_mode 决定输入图像
                        if crop_path:
                            crop_img = Image.open(crop_path).convert("RGB")
                            # if hasattr(self.image_processor, 'image_mean') and self.image_processor.image_mean is not None:
                            #     mean = self.image_processor.image_mean
                            #     bg_color = tuple(int(x * 255) for x in mean)
                            # else:
                            #     bg_color = (122, 116, 104)
                            # crop_img = self._expand2square(crop_img, bg_color)
                            crop_img = crop_img.resize((384, 384), Image.Resampling.LANCZOS)
                            
                            # 根据 effective_image_mode 决定输入图像
                            if effective_image_mode == "crop":
                                active_images = [crop_img]
                                self.logger.info(f"  -> Using crop image only ({crop_img.size})")
                            else:  # "both"
                                active_images = [original_image_obj, crop_img]
                                self.logger.info(f"  -> Using dual-image input: original image ({original_image_obj.size}) + crop image ({crop_img.size})")
                        else:
                            # 没有 crop_path 的情况
                            if effective_image_mode == "crop":
                                # 如果只需要局部图，但没有 crop，则跳过
                                if dino_conf:
                                    self.logger.info(f"  -> [REJECT] Visual evidence weak (Conf: {dino_conf:.2f} < 0.40). Skipping.")
                                else:
                                    self.logger.info("  -> No crop found. SKIPPING refinement for this step (crop mode requires crop).")
                                if target_span_range:
                                    for i in range(target_span_range[0], target_span_range[1]):
                                        global_covered_indices.add(i)
                                continue
                            else:  # "both"
                                # 如果需要两个图，但没有 crop，则跳过
                                # if dino_conf:
                                #     self.logger.info(f"  -> [REJECT] Visual evidence weak (Conf: {dino_conf:.2f} < 0.40). Skipping.")
                                # else:
                                #     self.logger.info("  -> No crop found. SKIPPING refinement for this step.")
                                active_images = [original_image_obj]
                                self.logger.info(f"  -> Using original image only ({original_image_obj.size})")
                                if target_span_range:
                                    for i in range(target_span_range[0], target_span_range[1]):
                                        global_covered_indices.add(i)
                                # continue
            else:
                self.logger.info(f"\n>>> [Phase {step_idx}] Initial Generation...")
                active_images = [original_image_obj]
            
            # 构建 Prompt
            self._flush()
            
            image_tensor = process_images(active_images, self.image_processor, self.model.config)
            if isinstance(image_tensor, list):
                image_tensor = [_img.to(dtype=torch.float16, device=self.device) for _img in image_tensor]
            else:
                image_tensor = image_tensor.to(dtype=torch.float16, device=self.device)
            
            img_tokens_str = " ".join([DEFAULT_IMAGE_TOKEN] * len(active_images))
            kwargs_text_ids = {}
            
            if step_idx == 0:
                prompt_text = img_tokens_str + "\n" + base_instruction
            else:
                masked_ids = current_generated_ids.copy()
                
                if target_span_range:
                    s = target_span_range[0]
                    e = target_span_range[1]
                    
                    # 只在 span 模式下应用 mask_expansion
                    if self.mask_mode == "span":
                        for _ in range(self.mask_expansion):
                            if s > 0:
                                prev_token_id = current_generated_ids[s - 1]
                                prev_token = self.tokenizer.decode([prev_token_id])
                                if '.' in prev_token or '!' in prev_token or '?' in prev_token or ',' in prev_token:
                                    break
                                s -= 1
                        for _ in range(self.mask_expansion):
                            if e < len(current_generated_ids):
                                next_token_id = current_generated_ids[e]
                                next_token = self.tokenizer.decode([next_token_id])
                                if '.' in next_token or '!' in next_token or '?' in next_token or ',' in next_token:
                                    break
                                e += 1
                    
                    original_text_segment = self.tokenizer.decode(masked_ids[s:e], skip_special_tokens=False)
                    self.logger.info(f"  -> Applied Masking on: '{original_text_segment}' (Indices: {s}-{e}, Mode: {self.mask_mode})")
                    
                    for i in range(s, e):
                        masked_ids[i] = self.mask_token_id
                        global_covered_indices.add(i)
                
                text_token_ids_tensor = torch.tensor([masked_ids], device=self.device)
                kwargs_text_ids = {"text_token_ids": text_token_ids_tensor}
                
                # 根据实际使用的图像输入模式调整 prompt（使用 effective_image_mode 和 local_refinement_mode）
                # 对于 jitter_confidence 模式，effective_image_mode 已经在前面设置好了
                if self.local_refinement_mode == "bbox":
                    # BBox 模式：使用标注了边界框的原图
                    refine_instruction = (
                        "\nYou are presented with the original full image with a red bounding box highlighting the region of interest, along with a text description with missing parts. "
                        "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                        "within the red bounding box region to accurately restore the missing text. "
                        "Instruction: The masked part of the text describes the visual evidence within the red bounding box. "
                        "Pay close attention to the details inside the bounding box. "
                        "Do not guess based on common language patterns. Instead, look closely at the region marked by the red box to accurately "
                        "restore the missing text. What you see in the bounding box is the ground truth."
                    )
                elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "confidence":
                    # 低置信度 token：使用全局图像 prompt
                    refine_instruction = (
                        "\nYou are presented with the original full image along with a text description with missing parts. "
                        "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                        "in the image to accurately restore the missing text. "
                        "Instruction: The masked part of the text describes specific visual evidence in the image. "
                        "Do not guess based on common language patterns. Instead, look closely at the image to accurately "
                        "restore the missing text. What you see in the image is the ground truth."
                    )
                elif self.token_selection_mode == "jitter_confidence" and jitter_conf_mode_type == "jitter":
                    # 高 jitter token：使用双视角 prompt
                    refine_instruction = (
                        "\nYou are presented with two images: the original full image and a zoomed-in visual detail, along with a text description with missing parts. "
                        "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                        "in the provided zoomed-in image crop, while considering the context from the original full image. "
                        "Instruction: The masked part of the text describes this specific visual evidence shown in the zoomed-in crop. "
                        "Do not guess based on common language patterns. Instead, look closely at the zoomed-in crop to accurately "
                        "restore the missing text. What you see in the crop is the ground truth. "
                        "The first image is the original full image, and the second image is the zoomed-in detail."
                    )
                else:
                    # 其他模式：根据 image_input_mode 决定（crop 模式）
                    if self.image_input_mode == "original":
                        refine_instruction = (
                            "\nYou are presented with the original full image along with a text description with missing parts. "
                            "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                            "in the image to accurately restore the missing text. "
                            "Instruction: The masked part of the text describes specific visual evidence in the image. "
                            "Do not guess based on common language patterns. Instead, look closely at the image to accurately "
                            "restore the missing text. What you see in the image is the ground truth."
                        )
                    elif self.image_input_mode == "crop":
                        refine_instruction = (
                            "\nYou are presented with a zoomed-in visual detail along with a text description with missing parts. "
                            "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                            "in the provided zoomed-in image crop to accurately restore the missing text. "
                            "Instruction: The masked part of the text describes this specific visual evidence shown in the zoomed-in crop. "
                            "Do not guess based on common language patterns. Instead, look closely at the zoomed-in crop to accurately "
                            "restore the missing text. What you see in the crop is the ground truth."
                        )
                    else:  # "both"
                        refine_instruction = (
                            "\nYou are presented with two images: the original full image and a zoomed-in visual detail, along with a text description with missing parts. "
                            "Task: Analyze the specific visual attributes (such as texture, pattern, color, and object shape) "
                            "in the provided zoomed-in image crop, while considering the context from the original full image. "
                            "Instruction: The masked part of the text describes this specific visual evidence shown in the zoomed-in crop. "
                            "Do not guess based on common language patterns. Instead, look closely at the zoomed-in crop to accurately "
                            "restore the missing text. What you see in the crop is the ground truth. "
                            "The first image is the original full image, and the second image is the zoomed-in detail."
                        )
                prompt_text = img_tokens_str + refine_instruction
            
            # 对话模版处理
            conv = copy.deepcopy(conv_templates["llava_llada"])
            conv.append_message(conv.roles[0], prompt_text)
            conv.append_message(conv.roles[1], None)
            prompt_ids = tokenizer_image_token(conv.get_prompt(), self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt").unsqueeze(0).to(self.device)
            
            # 模型推理
            self._flush()
            
            try:
                with torch.inference_mode():
                    result = self.model.generate(
                        prompt_ids,
                        images=image_tensor,
                        image_sizes=[img.size for img in active_images],
                        steps=128,
                        gen_length=128,
                        block_length=128,
                        tokenizer=self.tokenizer,
                        stopping_criteria=['<|eot_id|>'],
                        return_confidences=True,
                        save_confidence_interval=1,
                        is_initial_generation=(step_idx == 0),
                        repetition_penalty=1.3,  # 增加重复惩罚，避免重复生成
                        no_repeat_ngram_size=3,
                        **kwargs_text_ids
                    )
            except RuntimeError as e:
                if "out of memory" in str(e):
                    self.logger.error(f"!!! CUDA OOM at Step {step_idx}. Exiting.")
                    if 'last_intermediate_confidences' in locals():
                        del last_intermediate_confidences
                    del image_tensor, prompt_ids
                    self._flush()
                    break
                else:
                    raise e
            
            # 结果处理
            if isinstance(result, tuple):
                if len(result) == 3:
                    cont, _, intermediate_confidences = result
                elif len(result) == 2:
                    cont, _ = result
                    intermediate_confidences = None
            else:
                cont = result
                intermediate_confidences = None
            
            last_intermediate_confidences = intermediate_confidences
            
            if hasattr(cont, 'cpu'):
                current_generated_ids = cont[0].cpu().numpy().tolist()
            elif isinstance(cont, list):
                current_generated_ids = cont[0] if len(cont) > 0 else []
            
            if current_generated_ids:
                text_res = self.tokenizer.decode(current_generated_ids, skip_special_tokens=True)
                self.logger.info(f"--> Result: {text_res.replace(chr(10), ' ')}")
            
            # 清理
            del result, cont, image_tensor, prompt_ids
            if 'text_token_ids_tensor' in locals():
                del text_token_ids_tensor
            if 'crop_img' in locals():
                del crop_img
            if 'annotated_image' in locals():
                del annotated_image
            self._flush()
        
        # 生成最终响应
        if current_generated_ids:
            final_response = self.tokenizer.decode(current_generated_ids, skip_special_tokens=True)
        else:
            final_response = ""
        
        metadata = {
            "total_steps": step_idx + 1,
            "final_step": step_idx
        }
        
        self.logger.info(f"\n{'=' * 20} Loop Completed {'=' * 20}")
        
        return final_response, metadata

