# -*- coding: utf-8 -*-
"""
改进的COCO Localized Narratives处理代码
支持基于短语的更精确的轨迹分割，并集成Qwen2.5-VL-72B模型进行定位矫正
"""
import json
import collections
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
from rdp import rdp
import json
import re
from difflib import SequenceMatcher
import datetime
import torch
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
# from transformers import AutoModelForObjectDetection, AutoConfig
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from torchvision import transforms

# Resolve the Qwen vision-language model class at import time.
# Preference order: Qwen2_5_VLForConditionalGeneration, then
# Qwen2VLForConditionalGeneration. If neither can be imported a warning is
# printed and QWEN_MODEL_CLASS is left as None.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration
    QWEN_MODEL_CLASS = Qwen2_5_VLForConditionalGeneration
except ImportError:
    try:
        from transformers import Qwen2VLForConditionalGeneration
        QWEN_MODEL_CLASS = Qwen2VLForConditionalGeneration
    except ImportError:
        print("警告：无法导入Qwen2.5-VL模型类，请检查transformers版本")
        QWEN_MODEL_CLASS = None


class QwenVLLocationCorrector:
    """Score and correct phrase bounding boxes with Qwen2.5-VL and Grounding DINO.

    The constructor loads a Grounding DINO detector and a Qwen vision-language
    model. The public methods either ask Qwen to rate / re-predict a box for a
    phrase, or run Grounding DINO on the phrase. All boxes exchanged with
    callers are ``(xmin, ymin, xmax, ymax)`` tuples in normalized [0, 1]
    coordinates.
    """

    # A complete non-negative decimal number. The look-behind stops the match
    # from starting in the middle of a number or identifier, so "120" is read
    # as 120 (not as "1" and "0") and the digit in "x1" is ignored.
    _NUMBER_RE = re.compile(r'(?<![A-Za-z0-9_.])\d+(?:\.\d+)?')

    def __init__(self, model_path="/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct", dino_path_name="/storage-root/9950backfile/yangfan/coyo/grounding-dino/", num_gpus=4):
        """Load Grounding DINO and the Qwen2.5-VL model.

        Args:
            model_path: Local path of the Qwen checkpoint.
            dino_path_name: Local path of the Grounding DINO checkpoint.
            num_gpus: Number of GPUs to expose through CUDA_VISIBLE_DEVICES.

        Raises:
            RuntimeError: If no Qwen model class could be imported.
        """
        print(f"正在加载Qwen2.5-VL模型: {model_path}")
        print(f"使用 {num_gpus} 个GPU")

        # Fail early with a clear message instead of an AttributeError on
        # ``None.from_pretrained`` after several load attempts.
        if QWEN_MODEL_CLASS is None:
            raise RuntimeError("无法导入Qwen2.5-VL模型类，请检查transformers版本")

        # Expose GPUs 0..num_gpus-1 to this process.
        # NOTE(review): presumably this only takes effect if CUDA has not been
        # initialized yet in this process - confirm against the launcher.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(num_gpus))
        self.transform = transforms.Compose([
            transforms.Resize((800, 800)),  # input size commonly used for Grounding DINO
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # ImageNet statistics
        ])

        if num_gpus > 1:
            # Let transformers shard the model across the visible GPUs.
            device_map = "auto"
            print(f"使用自动设备映射分布到 {num_gpus} 个GPU")
        else:
            device_map = "cuda:0"
            print("使用单GPU模式")

        dino_ready = False
        try:
            # Put the detector on the last visible GPU. With the default
            # num_gpus=4 this is cuda:3, as before; smaller values no longer
            # request a device index that does not exist.
            self._load_dino(dino_path_name, torch.device(f"cuda:{max(num_gpus - 1, 0)}"))
            dino_ready = True

            self.model = QWEN_MODEL_CLASS.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map=device_map,
                attn_implementation="flash_attention_2",
            ).eval()

            self.processor = AutoProcessor.from_pretrained(
                model_path,
                max_pixels=12845056,
                min_pixels=3136,
            )

            print("Qwen2.5-VL模型加载完成")

        except Exception as e:
            print(f"模型加载失败，尝试备用方案: {e}")
            # Only reload the detector when it is what failed; otherwise keep
            # the instance that is already on its GPU.
            if not dino_ready:
                self._load_dino(dino_path_name, torch.device("cuda"))
            try:
                # Fallback 1: same checkpoint without flash_attention_2.
                self.model = QWEN_MODEL_CLASS.from_pretrained(
                    model_path,
                    torch_dtype="auto",
                    device_map="auto",
                ).eval()

                self.processor = AutoProcessor.from_pretrained(model_path)
                print("Qwen2.5-VL模型加载完成（备用方案）")

            except Exception as e2:
                print(f"备用方案也失败，使用在线模型: {e2}")
                # Fallback 2: load by hub model id instead of the local path.
                self.model = QWEN_MODEL_CLASS.from_pretrained(
                    "Qwen/Qwen2.5-VL-72B-Instruct",
                    torch_dtype=torch.bfloat16,
                    device_map="auto",
                ).eval()

                self.processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-72B-Instruct")
                print("Qwen2.5-VL模型加载完成（在线模型）")

    def _load_dino(self, dino_path_name, device):
        """Load the Grounding DINO processor and model onto ``device``."""
        self.device = device
        self.dino_processor = AutoProcessor.from_pretrained(dino_path_name)
        self.dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
            dino_path_name,
        ).to(self.device)

    @staticmethod
    def _fallback_bbox(original_bbox):
        """Return ``original_bbox`` if given, else the full-image box."""
        return original_bbox if original_bbox else (0.0, 0.0, 1.0, 1.0)

    @classmethod
    def _parse_quality_score(cls, output_text, default=0.5):
        """Return the first number in ``output_text`` that lies in [0, 1].

        Numbers outside that range are skipped rather than clamped, so a
        reply such as "10" is not mistaken for a score of 1. ``default`` is
        returned (and a message printed) when no usable number is found.
        """
        for token in cls._NUMBER_RE.findall(output_text):
            value = float(token)
            if 0.0 <= value <= 1.0:
                return value
        print(f"无法从回复中提取分数: {output_text}")
        return default

    @classmethod
    def _parse_normalized_bbox(cls, output_text, original_bbox=None):
        """Parse ``xmin,ymin,xmax,ymax`` from a model reply.

        The first four numbers of the reply are used. If fewer than four are
        present, or any of them exceeds 1 (the reply is not normalized, e.g.
        pixel coordinates), the fallback box is returned. An axis whose
        minimum is not below its maximum is widened to the full [0, 1] range.
        """
        values = [float(token) for token in cls._NUMBER_RE.findall(output_text)][:4]
        if len(values) < 4:
            print(f"无法从回复中提取足够的坐标: {output_text}")
            return cls._fallback_bbox(original_bbox)
        if any(value > 1.0 for value in values):
            print(f"坐标解析失败: {output_text}")
            return cls._fallback_bbox(original_bbox)

        xmin, ymin, xmax, ymax = values
        if xmin >= xmax:
            xmin, xmax = 0.0, 1.0
        if ymin >= ymax:
            ymin, ymax = 0.0, 1.0
        return (xmin, ymin, xmax, ymax)

    def _generate_text(self, messages, max_new_tokens):
        """Run one chat turn through Qwen and return only the generated text."""
        text = self.processor.apply_chat_template(
            [messages], tokenize=False, add_generation_prompt=True
        )[0]

        image_inputs, video_inputs = process_vision_info([messages])

        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.0,
            do_sample=False,  # greedy decoding
        )

        # Drop the prompt tokens so only the newly generated part is decoded.
        generated_ids_trimmed = [
            out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs)
        ]
        return self.processor.batch_decode(
            generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]

    def calculate_bbox_quality_score(self, bbox, phrase, image_path):
        """Ask Qwen2.5-VL how well ``bbox`` localizes ``phrase``.

        Args:
            bbox: (xmin, ymin, xmax, ymax) in normalized coordinates.
            phrase: Description the box is supposed to cover.
            image_path: Path of the image, passed to the model as-is.

        Returns:
            A score in [0, 1]; higher means a better localization. 0.5 is
            returned when no score can be read from the reply or when any
            step raises.
        """
        try:
            prompt = f"""请评估给定边界框是否准确定位了描述的内容。

描述短语: "{phrase}"
边界框坐标 (归一化): xmin={bbox[0]:.3f}, ymin={bbox[1]:.3f}, xmax={bbox[2]:.3f}, ymax={bbox[3]:.3f}

请给出一个0-1之间的质量分数，其中：
- 1.0: 边界框完美包含描述的内容
- 0.8-0.9: 边界框很好地包含描述的内容，略有偏差
- 0.6-0.7: 边界框包含描述的内容，但有明显偏差
- 0.4-0.5: 边界框部分包含描述的内容
- 0.0-0.3: 边界框没有包含描述的内容或严重偏差

只需要返回数字分数，不需要其他解释。"""
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_path},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            output_text = self._generate_text(messages, max_new_tokens=50)
            return self._parse_quality_score(output_text)

        except Exception as e:
            # Best effort: scoring must never abort the caller's pipeline.
            print(f"质量评估失败: {str(e)}")
            return 0.5

    def get_corrected_bbox_with_dino(self, phrase, image_path, original_bbox=None):
        """Locate ``phrase`` with Grounding DINO.

        Args:
            phrase: Description to ground.
            image_path: Local path, or a URL starting with "http".
            original_bbox: Optional box returned when detection fails.

        Returns:
            The highest-scoring detection as normalized
            (xmin, ymin, xmax, ymax). When nothing is detected, the box is
            degenerate, or any step raises, ``original_bbox`` is returned if
            given, otherwise the full-image box (0, 0, 1, 1).
        """
        import io
        import urllib.request

        try:
            if image_path.startswith("http"):
                # Fetched with the standard library; the module does not
                # import ``requests``.
                with urllib.request.urlopen(image_path, timeout=30) as response:
                    image = Image.open(io.BytesIO(response.read())).convert("RGB")
            else:
                image = Image.open(image_path).convert("RGB")

            # Text query is lower-cased and terminated with a period.
            text = f"{phrase.lower()}."

            inputs = self.dino_processor(images=image, text=text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = self.dino_model(**inputs)

            results = self.dino_processor.post_process_grounded_object_detection(
                outputs,
                inputs.input_ids,
                box_threshold=0.35,
                text_threshold=0.25,
                # PIL's size is (width, height); reversed it is (height, width).
                target_sizes=[image.size[::-1]],
            )

            if not results or len(results[0]["boxes"]) == 0:
                print(f"未检测到目标: {phrase}")
                return self._fallback_bbox(original_bbox)

            # Keep only the most confident detection.
            max_idx = torch.argmax(results[0]["scores"]).item()
            best_box = results[0]["boxes"][max_idx].tolist()  # [xmin, ymin, xmax, ymax]

            # Convert to normalized coordinates and clamp to [0, 1].
            img_width, img_height = image.size
            xmin = max(0.0, min(1.0, best_box[0] / img_width))
            ymin = max(0.0, min(1.0, best_box[1] / img_height))
            xmax = max(0.0, min(1.0, best_box[2] / img_width))
            ymax = max(0.0, min(1.0, best_box[3] / img_height))

            if xmin >= xmax or ymin >= ymax:
                print("无效的边界框，坐标范围错误")
                return self._fallback_bbox(original_bbox)

            return (xmin, ymin, xmax, ymax)

        except Exception as e:
            # Best effort: fall back instead of aborting the caller.
            print(f"边界框矫正失败: {str(e)}")
            return self._fallback_bbox(original_bbox)

    def get_corrected_bbox(self, phrase, image_path, original_bbox=None):
        """Ask Qwen2.5-VL for a bounding box of ``phrase``.

        Args:
            phrase: Description to localize.
            image_path: Path of the image, passed to the model as-is.
            original_bbox: Optional existing box; it is included in the
                prompt and returned when the reply cannot be used.

        Returns:
            Normalized (xmin, ymin, xmax, ymax). When the reply has fewer
            than four numbers, contains values above 1, or any step raises,
            ``original_bbox`` is returned if given, otherwise (0, 0, 1, 1).
        """
        try:
            original_info = ""
            if original_bbox:
                original_info = f"\n原始定位边界框: xmin={original_bbox[0]:.3f}, ymin={original_bbox[1]:.3f}, xmax={original_bbox[2]:.3f}, ymax={original_bbox[3]:.3f}"

            prompt = f"""请在图像中准确定位以下描述的内容，并给出归一化的边界框坐标。

描述内容: "{phrase}"{original_info}

请返回最准确的边界框坐标，格式为: xmin,ymin,xmax,ymax
坐标应该是0-1之间的归一化值。
只需要返回坐标数字，用逗号分隔，不需要其他文字。"""
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image_path},
                        {"type": "text", "text": prompt},
                    ],
                }
            ]
            output_text = self._generate_text(messages, max_new_tokens=100)
            return self._parse_normalized_bbox(output_text, original_bbox)

        except Exception as e:
            # Best effort: fall back instead of aborting the caller.
            print(f"边界框矫正失败: {str(e)}")
            return self._fallback_bbox(original_bbox)


def load_jsonl_by_index(jsonl_path, selected_index=1):
    """Return the parsed JSON object stored on one line of a JSONL file.

    Args:
        jsonl_path: Path of a UTF-8 encoded JSON-lines file.
        selected_index: 1-based number of the line to load.

    Returns:
        The result of ``json.loads`` on the selected line.

    Raises:
        IndexError: If the file has no line with that number.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        wanted_line = next(
            (raw for line_no, raw in enumerate(handle, 1) if line_no == selected_index),
            None,
        )
    if wanted_line is None:
        raise IndexError(f"JSONL文件中没有第{selected_index}条数据")
    return json.loads(wanted_line)


def visualize_all_boxes(
    image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
    tokens, segmentation_method, image_base_path, dataset_id, full_caption, output_dir="."):
    """Save trace-line, bounding-box and combined visualizations for one image.

    The image is read from ``<image_base_path>/<split>/<image_id, 12 digits>.jpg``
    where ``<split>`` is the second ``_``-separated field of ``dataset_id``.
    When that file does not exist a white 640x480 canvas is used instead and
    no copy of the original image is written.

    Args:
        image_id: Image identifier; must be convertible with ``int()``.
        segmented_xs: Per-segment lists of x coordinates, as fractions of the
            image width (they are multiplied by the width for drawing).
        segmented_ys: Per-segment lists of y coordinates, as fractions of the
            image height.
        xmins: Only its length is used; it fixes the number of segments drawn
            and the number of colors.
        xmaxs: Unused.
        ymins: Unused.
        ymaxs: Unused.
        tokens: Optional list of per-segment label strings.
        segmentation_method: Name shown in titles and used in file names.
        image_base_path: Root directory of the images.
        dataset_id: Dataset identifier containing the split name.
        full_caption: Unused.
        output_dir: Directory that receives the PNG/JPG files.

    Returns:
        None. Files are written to ``output_dir``.
    """
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)

    # Create the output directory before anything else. It used to be created
    # only after the image had been opened, so a missing image left the
    # directory absent and every savefig() below failed.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    try:
        im = Image.open(full_image_path)
        imw, imh = im.size
        im_array = np.array(im)
        # Keep a copy of the source image next to the visualizations.
        original_image_output_path = os.path.join(output_dir, f"original_{image_id}.jpg")
        im.save(original_image_output_path)
        print(f"Saved original image to: {original_image_output_path}")
    except FileNotFoundError:
        print(f"Image not found at {full_image_path}, using a blank canvas instead.")
        imw, imh = 640, 480
        im_array = np.ones((imh, imw, 3), dtype=np.uint8) * 255  # white background

    # One distinct color per segment.
    num_segments = len(xmins)
    colors = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, num_segments))

    # File-name-safe form of the method name, shared by all three outputs.
    method_tag = segmentation_method.replace(' ', '_').replace('/', '_')

    # 1. Trace lines only.
    fig1, ax1 = plt.subplots(figsize=(12, 9))
    ax1.imshow(im_array)
    ax1.set_xlim(0, imw)
    ax1.set_ylim(imh, 0)  # flip the y axis to match image coordinates

    for i in range(num_segments):
        if segmented_xs[i]:  # skip empty segments
            seg_xs_scaled = [x * imw for x in segmented_xs[i]]
            seg_ys_scaled = [y * imh for y in segmented_ys[i]]
            ax1.plot(seg_xs_scaled, seg_ys_scaled, linewidth=3, color=colors[i], alpha=0.8)

    ax1.set_title(f"Trace Lines for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)
    ax1.axis('off')

    trace_filename = f"traces_{image_id}_{method_tag}.png"
    plt.savefig(os.path.join(output_dir, trace_filename), bbox_inches='tight', dpi=300)
    plt.close(fig1)
    print(f"Saved trace lines image to: {os.path.join(output_dir, trace_filename)}")

    # 2. Bounding boxes only. Each box is the extent of the segment's trace
    # points, not the xmins/xmaxs/ymins/ymaxs arguments.
    fig2, ax2 = plt.subplots(figsize=(12, 9))
    ax2.imshow(im_array)
    ax2.set_xlim(0, imw)
    ax2.set_ylim(imh, 0)

    for i in range(num_segments):
        segment_color = colors[i]

        if segmented_xs[i] and segmented_ys[i]:
            min_x = min(segmented_xs[i]) * imw
            max_x = max(segmented_xs[i]) * imw
            min_y = min(segmented_ys[i]) * imh
            max_y = max(segmented_ys[i]) * imh

            rect = patches.Rectangle(
                (min_x, min_y),
                max_x - min_x,
                max_y - min_y,
                linewidth=3,
                edgecolor=segment_color,
                facecolor='none',
                alpha=0.8
            )
            ax2.add_patch(rect)

            # Label just above the top-left corner of the box.
            if tokens and i < len(tokens):
                ax2.text(min_x, min_y - 10, f"{i+1}: {tokens[i][:30]}...",
                        fontsize=10, color=segment_color, weight='bold',
                        bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))

    ax2.set_title(f"Bounding Boxes for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)
    ax2.axis('off')

    bbox_filename = f"bboxes_{image_id}_{method_tag}.png"
    plt.savefig(os.path.join(output_dir, bbox_filename), bbox_inches='tight', dpi=300)
    plt.close(fig2)
    print(f"Saved bounding boxes image to: {os.path.join(output_dir, bbox_filename)}")

    # 3. Combined view: trace lines plus dashed boxes.
    fig3, ax3 = plt.subplots(figsize=(12, 9))
    ax3.imshow(im_array)
    ax3.set_xlim(0, imw)
    ax3.set_ylim(imh, 0)

    for i in range(num_segments):
        segment_color = colors[i]

        if segmented_xs[i] and segmented_ys[i]:
            seg_xs_scaled = [x * imw for x in segmented_xs[i]]
            seg_ys_scaled = [y * imh for y in segmented_ys[i]]
            ax3.plot(seg_xs_scaled, seg_ys_scaled, linewidth=2, color=segment_color, alpha=0.6)

            min_x = min(segmented_xs[i]) * imw
            max_x = max(segmented_xs[i]) * imw
            min_y = min(segmented_ys[i]) * imh
            max_y = max(segmented_ys[i]) * imh

            rect = patches.Rectangle(
                (min_x, min_y),
                max_x - min_x,
                max_y - min_y,
                linewidth=2,
                edgecolor=segment_color,
                facecolor='none',
                alpha=0.8,
                linestyle='--'
            )
            ax3.add_patch(rect)

    ax3.set_title(f"Combined Visualization for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)

    # Color-coded caption along the bottom of the figure.
    if tokens:
        fig_width = fig3.get_window_extent().width
        y_pos = 0.03

        # Measure the rendered width of every token with throw-away text
        # objects so the caption can be centered horizontally.
        renderer = fig3.canvas.get_renderer()
        text_objs = [plt.text(0, 0, f"{t} ", fontsize=10, ha='left', va='bottom') for t in tokens]
        total_width_pixels = sum(t.get_window_extent(renderer).width for t in text_objs)
        for t in text_objs:
            t.remove()

        current_x = 0.5 - (total_width_pixels / fig_width) / 2

        for i, token in enumerate(tokens):
            if i < len(colors):
                txt_obj = plt.text(current_x, y_pos, f"{token} ", transform=fig3.transFigure,
                                 ha='left', va='bottom', fontsize=10, color=colors[i],
                                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
                # Advance by the width of the token just placed.
                current_x += txt_obj.get_window_extent(renderer).width / fig_width

    plt.subplots_adjust(bottom=0.15)
    ax3.axis('off')

    combined_filename = f"combined_{image_id}_{method_tag}.png"
    plt.savefig(os.path.join(output_dir, combined_filename), bbox_inches='tight', dpi=300)
    plt.close(fig3)
    print(f"Saved combined image to: {os.path.join(output_dir, combined_filename)}")

    print(f"All visualizations saved to: {os.path.abspath(output_dir)}")


def visualize_qwen_optimization_comparison(
    image_id, original_segmented_xs, original_segmented_ys, original_xmins, original_xmaxs, 
    original_ymins, original_ymaxs, optimized_bboxes, tokens, segmentation_method, 
    image_base_path, dataset_id, full_caption, output_dir=".", quality_scores=None):
    """Save a side-by-side figure of original boxes versus optimized boxes.

    The left panel draws the boxes given by ``original_xmins`` /
    ``original_xmaxs`` / ``original_ymins`` / ``original_ymaxs``; the right
    panel draws ``optimized_bboxes``. The image is read from
    ``<image_base_path>/<split>/<image_id, 12 digits>.jpg`` where ``<split>``
    is the second ``_``-separated field of ``dataset_id``; a white 640x480
    canvas is used when that file does not exist.

    Args:
        image_id: Image identifier; must be convertible with ``int()``.
        original_segmented_xs: Unused.
        original_segmented_ys: Unused.
        original_xmins: Per-box left edges as fractions of the image width.
        original_xmaxs: Per-box right edges as fractions of the image width.
        original_ymins: Per-box top edges as fractions of the image height.
        original_ymaxs: Per-box bottom edges as fractions of the image height.
        optimized_bboxes: Per-box ``(xmin, ymin, xmax, ymax)`` fractions;
            ``None`` entries are skipped.
        tokens: Optional per-box label strings. When given, only the first
            ``len(tokens)`` boxes are drawn.
        segmentation_method: Name shown in titles and used in the file name.
        image_base_path: Root directory of the images.
        dataset_id: Dataset identifier containing the split name.
        full_caption: Unused.
        output_dir: Directory that receives the output files.
        quality_scores: Optional per-box scores appended to the left labels.

    Returns:
        None. Files are written to ``output_dir``.
    """
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)

    # Create the output directory before loading the image so that the
    # blank-canvas fallback can still save its figure.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    try:
        im = Image.open(full_image_path)
        imw, imh = im.size
        im_array = np.array(im)

        # Keep a copy of the source image next to the comparison figure.
        original_image_output_path = os.path.join(output_dir, f"original_{image_id}.jpg")
        im.save(original_image_output_path)
        print(f"Saved original image to: {original_image_output_path}")
    except FileNotFoundError:
        print(f"Image not found at {full_image_path}, using a blank canvas instead.")
        imw, imh = 640, 480
        im_array = np.ones((imh, imw, 3), dtype=np.uint8) * 255

    # One color per token, as before. Without tokens (None or empty) fall
    # back to one color per box so the boxes are still drawn, unlabeled.
    if tokens:
        num_colors = len(tokens)
    else:
        num_colors = max(len(original_xmins), len(optimized_bboxes))
    colors = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, num_colors))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))

    # Left panel: original localization.
    ax1.imshow(im_array)
    ax1.set_xlim(0, imw)
    ax1.set_ylim(imh, 0)  # flip the y axis to match image coordinates

    for i in range(len(original_xmins)):
        if i >= len(colors):
            continue
        segment_color = colors[i]

        min_x = original_xmins[i] * imw
        max_x = original_xmaxs[i] * imw
        min_y = original_ymins[i] * imh
        max_y = original_ymaxs[i] * imh

        rect = patches.Rectangle(
            (min_x, min_y),
            max_x - min_x,
            max_y - min_y,
            linewidth=3,
            edgecolor=segment_color,
            facecolor='none',
            alpha=0.8
        )
        ax1.add_patch(rect)

        if tokens and i < len(tokens):
            label_text = f"{i+1}: {tokens[i][:30]}..."
            if quality_scores and i < len(quality_scores):
                label_text += f" (Q:{quality_scores[i]:.2f})"

            ax1.text(min_x, min_y - 10, label_text,
                    fontsize=10, color=segment_color, weight='bold',
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))

    ax1.set_title(f"Original {segmentation_method}\nImage {image_id}", fontsize=16)
    ax1.axis('off')

    # Right panel: optimized localization.
    ax2.imshow(im_array)
    ax2.set_xlim(0, imw)
    ax2.set_ylim(imh, 0)

    for i, bbox in enumerate(optimized_bboxes):
        if i >= len(colors) or bbox is None:
            continue
        segment_color = colors[i]

        min_x = bbox[0] * imw
        min_y = bbox[1] * imh
        max_x = bbox[2] * imw
        max_y = bbox[3] * imh

        rect = patches.Rectangle(
            (min_x, min_y),
            max_x - min_x,
            max_y - min_y,
            linewidth=3,
            edgecolor=segment_color,
            facecolor='none',
            alpha=0.8
        )
        ax2.add_patch(rect)

        if tokens and i < len(tokens):
            ax2.text(min_x, min_y - 10, f"{i+1}: {tokens[i][:30]}...",
                    fontsize=10, color=segment_color, weight='bold',
                    bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))

    ax2.set_title(f"Qwen2.5-VL Optimized {segmentation_method}\nImage {image_id}", fontsize=16)
    ax2.axis('off')

    comparison_filename = f"qwen_optimization_{image_id}_{segmentation_method.replace(' ', '_').replace('/', '_')}.png"
    plt.savefig(os.path.join(output_dir, comparison_filename), bbox_inches='tight', dpi=300)
    # Close this figure explicitly rather than whichever one is current.
    plt.close(fig)
    print(f"Saved Qwen optimization comparison to: {os.path.join(output_dir, comparison_filename)}")


def get_json_anno_external(json_anno):
    """Flatten one annotation dict into trace and caption columns.

    Args:
        json_anno: Dict with a ``'traces'`` list of point lists (each point
            has ``'x'``, ``'y'``, ``'t'``) and a ``'timed_caption'`` list of
            dicts with ``'utterance'``, ``'start_time'``, ``'end_time'``.

    Returns:
        ``(xs, ys, ts, toks, time_begins, time_ends)``. ``xs`` and ``ys`` are
        clamped to [0, 1]. A tuple of six ``None`` values is returned when
        the annotation contains no trace points.
    """
    points = []
    for stroke in json_anno['traces']:
        points.extend(stroke)

    if len(points) == 0:
        return (None,) * 6

    def clamp_unit(value):
        # Keep coordinates inside the unit interval.
        return max(0.0, min(1.0, value))

    xs = [clamp_unit(point['x']) for point in points]
    ys = [clamp_unit(point['y']) for point in points]
    ts = [point['t'] for point in points]

    caption_entries = json_anno['timed_caption']

    def column(key):
        # Pull one field out of every timed-caption entry.
        return [entry[key] for entry in caption_entries]

    toks = column('utterance')
    time_begins = column('start_time')
    time_ends = column('end_time')

    return xs, ys, ts, toks, time_begins, time_ends


def trace_segment_uniform_time_interval(
    trace_xs, trace_ys, trace_ts, time_interval, use_douglas_peucker=False, dp_epsilon=0.01):
    """Split a trace into consecutive bins of ``time_interval`` length.

    Bin ``k`` holds the points whose timestamp, measured from the first
    timestamp ``trace_ts[0]``, falls into ``[k, k + 1) * time_interval``.
    Points that lie before ``trace_ts[0]`` by a full interval or more are
    not assigned to any bin.

    Args:
        trace_xs: x coordinate per trace point.
        trace_ys: y coordinate per trace point.
        trace_ts: Timestamp per trace point.
        time_interval: Bin length; must be positive.
        use_douglas_peucker: When True, bins with more than two points are
            simplified with ``rdp``.
        dp_epsilon: ``epsilon`` passed to ``rdp``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)``: one list per bin,
        with empty lists for bins that contain no points. ``segmented_ts``
        always keeps every timestamp of the bin, so it can be longer than the
        matching x/y lists when simplification is applied.

    Raises:
        ValueError: If ``time_interval`` is not positive.
    """
    if trace_ts is None or len(trace_ts) == 0:
        return [], [], []
    if time_interval <= 0:
        # A zero or negative interval previously produced meaningless bins.
        raise ValueError(f"time_interval must be positive, got {time_interval!r}")

    tbins = ((np.array(trace_ts) - trace_ts[0]) / time_interval).astype(int)

    # Use the largest bin index rather than the last one so that points are
    # not dropped when the timestamps are not monotonically increasing.
    num_segments = int(tbins.max()) + 1

    # Group point indices by bin in one pass instead of scanning per bin.
    members = [[] for _ in range(num_segments)]
    for point_index, bin_index in enumerate(tbins.tolist()):
        if bin_index >= 0:
            members[bin_index].append(point_index)

    segmented_xs, segmented_ys, segmented_ts = [], [], []
    for indices in members:
        xs = [trace_xs[j] for j in indices]
        ys = [trace_ys[j] for j in indices]
        ts = [trace_ts[j] for j in indices]

        if use_douglas_peucker and len(xs) > 2:
            # Douglas-Peucker simplification of the bin's polyline.
            points = np.array(list(zip(xs, ys)))
            simplified_points = rdp(points, epsilon=dp_epsilon)
            xs = simplified_points[:, 0].tolist()
            ys = simplified_points[:, 1].tolist()

        segmented_xs.append(xs)
        segmented_ys.append(ys)
        segmented_ts.append(ts)

    return segmented_xs, segmented_ys, segmented_ts


def trace_segment_timestamp(trace_xs, trace_ys, trace_ts,
                            token_time_begins, token_time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """Cut the trace into one segment per timestamped token.

    A trace point belongs to a token when its timestamp lies inside the
    token's ``[begin, end]`` window, both ends included.

    Args:
        trace_xs: x coordinate per trace point.
        trace_ys: y coordinate per trace point.
        trace_ts: Timestamp per trace point.
        token_time_begins: Window start per token.
        token_time_ends: Window end per token.
        use_douglas_peucker: When True, segments with more than two points
            are simplified with ``rdp``.
        dp_epsilon: ``epsilon`` passed to ``rdp``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)`` with one list per
        token; tokens without any trace point get empty lists. The timestamp
        lists are never simplified.
    """
    timestamps = np.array(trace_ts)
    out_xs, out_ys, out_ts = [], [], []

    for window_start, window_end in zip(token_time_begins, token_time_ends):
        inside_window = (timestamps >= window_start) & (timestamps <= window_end)
        hits = np.nonzero(inside_window)[0]

        if len(hits) == 0:
            # No trace point was recorded while this token was spoken.
            out_xs.append([])
            out_ys.append([])
            out_ts.append([])
            continue

        seg_x = [trace_xs[k] for k in hits]
        seg_y = [trace_ys[k] for k in hits]
        seg_t = [trace_ts[k] for k in hits]

        if use_douglas_peucker and len(seg_x) > 2:
            # Douglas-Peucker simplification of the segment's polyline.
            polyline = np.array(list(zip(seg_x, seg_y)))
            reduced = rdp(polyline, epsilon=dp_epsilon)
            seg_x = reduced[:, 0].tolist()
            seg_y = reduced[:, 1].tolist()

        out_xs.append(seg_x)
        out_ys.append(seg_y)
        out_ts.append(seg_t)

    return out_xs, out_ys, out_ts


def split_caption_into_phrases(caption):
    """Split a full caption into short phrases.

    The caption is processed in four passes:
      1. Split into sentences at ``.``, ``!`` and ``?``, then at commas;
         a comma chunk with fewer than three words is glued onto the chunk
         before it (within the same sentence).
      2. Start a new phrase in front of each connector such as "There is",
         "In the" or "Which is" (case-insensitive).
      3. Split at " and "; a fragment after "and" with fewer than three
         words is glued back onto the previous phrase.
      4. Drop phrases with fewer than two words.

    Args:
        caption (str): Full image description.

    Returns:
        list: The phrases, in caption order.
    """
    # Collapse every run of whitespace into a single space.
    normalized = re.sub(r'\s+', ' ', caption.strip())

    # Pass 1: sentences, then comma chunks.
    comma_level = []
    for sentence in re.split(r'[.!?]+', normalized):
        chunks_in_sentence = []
        for chunk in sentence.split(','):
            chunk = chunk.strip()
            if not chunk:
                continue
            if chunks_in_sentence and len(chunk.split()) < 3:
                # Too short to stand alone: attach to the preceding chunk.
                chunks_in_sentence[-1] = chunks_in_sentence[-1] + ", " + chunk
            else:
                chunks_in_sentence.append(chunk)
        comma_level.extend(chunks_in_sentence)

    # Pass 2: break in front of connectors, keeping the connector at the
    # start of the phrase it introduces.
    connector_patterns = (
        r'\bThere is\b', r'\bThere are\b', r'\bWe can observe\b',
        r'\bIn the\b', r'\bOn the\b', r'\bAt the\b',
        r'\bWhich is\b', r'\bThat is\b',
    )
    connector_level = []
    for phrase in comma_level:
        pieces = [phrase]
        for connector in connector_patterns:
            regrouped = []
            for piece in pieces:
                buffer = ""
                for fragment in re.split(f'({connector})', piece, flags=re.IGNORECASE):
                    if re.match(connector, fragment, re.IGNORECASE) is None:
                        buffer += fragment
                        continue
                    # A connector: flush what came before and restart with it.
                    if buffer.strip():
                        regrouped.append(buffer.strip())
                    buffer = fragment
                if buffer.strip():
                    regrouped.append(buffer.strip())
            pieces = regrouped
        connector_level.extend(pieces)

    # Pass 3: split at "and" surrounded by whitespace.
    and_level = []
    for phrase in connector_level:
        fragments = re.split(r'\s+and\s+', phrase, flags=re.IGNORECASE)
        for position, fragment in enumerate(fragments):
            fragment = fragment.strip()
            if not fragment:
                continue
            stands_alone = position == 0 or len(fragment.split()) >= 3
            if stands_alone or not and_level:
                and_level.append(fragment)
            else:
                # Short tail after "and": merge it back into the last phrase.
                and_level[-1] += " and " + fragment

    # Pass 4: keep only phrases with at least two words.
    return [
        cleaned
        for cleaned in (phrase.strip() for phrase in and_level)
        if len(cleaned.split()) >= 2
    ]


def split_caption_into_phrases_improved(caption):
    """
    Split a caption into short phrases at sentence ends and connectors.

    Processing steps:
      1. Collapse runs of whitespace and cut the caption into sentences at
         runs of ``.``, ``!`` or ``?``.
      2. Split every sentence repeatedly at a fixed list of connectors
         (", and", "and", commas, "which"/"that"/"where"/"while",
         "in the"/"on the"/"at the", "behind", "there is"/"there are").
         A word connector stays at the front of the fragment that follows
         it; comma punctuation is dropped.
      3. Within a sentence, merge a fragment of fewer than three words into
         the previous fragment (joined with ", ") when the merged phrase has
         at most eight words. Fragments consisting only of "and"/"or"/"but"
         are discarded instead of being merged.
      4. Drop phrases with fewer than two words and strip a leading
         "and"/"or"/"but" from the remaining ones.

    Args:
        caption: Caption text to split.

    Returns:
        list[str]: The phrases in caption order (empty for a blank caption).
    """
    # Normalise whitespace so the word counts below are reliable.
    caption = re.sub(r'\s+', ' ', caption.strip())

    # Each pattern has exactly one capturing group, so ``split`` returns
    # [text, separator, text, separator, ...].  ", and" is listed before the
    # plain "and" pattern; in the opposite order it could never match.
    split_patterns = [
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r'(\s*,\s*and\s+)',       # ", and" conjunction
            r'(\s+and\s+)',           # "and" conjunction
            r'(\s*,\s*)',             # comma
            r'(\s+which\s+)',         # "which" clause
            r'(\s+that\s+)',          # "that" clause
            r'(\s+where\s+)',         # "where" clause
            r'(\s+while\s+)',         # "while" clause
            r'(\s+in\s+the\s+)',      # "in the" prepositional phrase
            r'(\s+on\s+the\s+)',      # "on the" prepositional phrase
            r'(\s+at\s+the\s+)',      # "at the" prepositional phrase
            r'(\s+behind\s+)',        # "behind" preposition
            r'(\s+there\s+is\s+)',    # existential "there is"
            r'(\s+there\s+are\s+)',   # existential "there are"
        )
    ]
    connector_words = {'and', 'or', 'but'}

    phrases = []

    for sentence in re.split(r'[.!?]+', caption):
        sentence = sentence.strip()
        if not sentence:
            continue

        # Apply the split rules one after another to the growing part list.
        current_parts = [sentence]
        for pattern in split_patterns:
            new_parts = []
            for part in current_parts:
                segments = pattern.split(part)

                head = segments[0].strip()
                if head:
                    new_parts.append(head)

                # Remaining items come in (separator, following text) pairs.
                for separator, tail in zip(segments[1::2], segments[2::2]):
                    tail = tail.strip()
                    if not tail:
                        continue
                    # Keep word connectors ("and", "in the", ...) but never
                    # the comma itself, otherwise fragments such as ", tire"
                    # leak into the output.
                    separator = separator.strip().lstrip(',').strip()
                    new_parts.append(f"{separator} {tail}" if separator else tail)

            current_parts = new_parts

        # Merge fragments that are too short to stand alone.
        final_parts = []
        for part in current_parts:
            # A lone connector carries no content; merging it would leave a
            # dangling "..., and" at the end of the previous phrase.
            if part.lower() in connector_words:
                continue

            words_in_part = len(part.split())
            if words_in_part < 3 and final_parts:
                last_part = final_parts[-1]
                # Only merge while the combined phrase stays short.
                if len(last_part.split()) + words_in_part <= 8:
                    final_parts[-1] = last_part + ", " + part
                    continue
            final_parts.append(part)

        phrases.extend(final_parts)

    # Final clean-up: require at least two words, then strip a leading connector.
    cleaned_phrases = []
    for phrase in phrases:
        phrase = phrase.strip()
        if phrase and len(phrase.split()) >= 2:
            phrase = re.sub(r'^(and|or|but)\s+', '', phrase, flags=re.IGNORECASE)
            if phrase:
                cleaned_phrases.append(phrase)

    return cleaned_phrases


def trace_segment_by_phrases_word_level(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """
    Cut a trace into one segment per phrase using word-overlap matching.

    Each phrase and each token is lower-cased, stripped of punctuation and
    split into words. A token is assigned to a phrase when the shared words
    cover at least 30% of the token's distinct words or at least 20% of the
    phrase's distinct words. If no token qualifies, a substring test between
    the normalised phrase text and token text (minimum length 3) is used
    instead. The trace points whose timestamp lies inside the
    ``[time_begins[i], time_ends[i]]`` window of any matched token form the
    phrase's segment, de-duplicated and ordered by time.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of trace coordinates and
            timestamps.
        phrases: Phrases to segment the trace by.
        original_tokens: Token strings, index-aligned with the time windows.
        time_begins, time_ends: Start/end time per token.
        use_douglas_peucker: Simplify each segment with ``rdp`` when it has
            more than two points.
        dp_epsilon: Epsilon passed to ``rdp``.

    Returns:
        tuple: ``(segmented_xs, segmented_ys, segmented_ts)`` with one entry
        per phrase (empty lists for phrases without trace points), or three
        empty lists when ``trace_ts`` or ``phrases`` is empty. The timestamp
        lists are not simplified, so they may be longer than the coordinate
        lists when Douglas-Peucker is applied.
    """
    if not trace_ts or not phrases:
        return [], [], []

    def to_words(text):
        # Punctuation becomes whitespace; split() then yields clean words.
        return re.sub(r'[^\w\s]', ' ', text.lower()).split()

    phrase_words = [to_words(phrase) for phrase in phrases]
    token_words = [to_words(token) for token in original_tokens]

    print(f"预处理后的短语词列表:")
    for number, words in enumerate(phrase_words, 1):
        print(f"  短语 {number}: {words}")

    print(f"预处理后的token词列表:")
    for number, words in enumerate(token_words, 1):
        print(f"  Token {number}: {words}")

    timestamps = np.array(trace_ts)
    out_xs, out_ys, out_ts = [], [], []

    for number, (phrase, vocab_list) in enumerate(zip(phrases, phrase_words), 1):
        print(f"\n处理短语 {number}: '{phrase}'")
        print(f"短语词列表: {vocab_list}")

        vocab = set(vocab_list)
        hits = []

        # Pass 1: word-overlap matching.
        for tok_idx, tok_list in enumerate(token_words):
            if not tok_list:
                continue  # token had no word characters

            tok_set = set(tok_list)
            shared = vocab & tok_set
            if not shared:
                continue

            share_of_phrase = len(shared) / len(vocab) if vocab_list else 0
            share_of_token = len(shared) / len(tok_set) if tok_list else 0

            if share_of_token >= 0.3 or share_of_phrase >= 0.2:
                hits.append(tok_idx)
                print(f"  匹配token {tok_idx}: '{original_tokens[tok_idx]}'")
                print(f"    重叠词: {shared}")
                print(f"    短语重叠比例: {share_of_phrase:.2f}, Token重叠比例: {share_of_token:.2f}")

        # Pass 2: substring fallback when no token shared a word.
        if not hits:
            print(f"  未找到词重叠匹配，尝试子串匹配...")
            phrase_text = ' '.join(vocab_list)

            for tok_idx, tok_list in enumerate(token_words):
                token_text = ' '.join(tok_list)
                phrase_in_token = phrase_text in token_text and len(phrase_text) >= 3
                token_in_phrase = token_text in phrase_text and len(token_text) >= 3
                if phrase_in_token or token_in_phrase:
                    hits.append(tok_idx)
                    print(f"  子串匹配token {tok_idx}: '{original_tokens[tok_idx]}'")

        # Gather (t, x, y) records for every matched token's time window.
        gathered = []
        for tok_idx in hits:
            if tok_idx >= len(time_begins) or tok_idx >= len(time_ends):
                continue
            window = (timestamps >= time_begins[tok_idx]) & (timestamps <= time_ends[tok_idx])
            for pos in np.where(window)[0]:
                gathered.append((trace_ts[pos], trace_xs[pos], trace_ys[pos]))

        print(f"  总共收集到 {len(gathered)} 个轨迹点")

        if not gathered:
            print(f"  未找到轨迹点")
            out_xs.append([])
            out_ys.append([])
            out_ts.append([])
            continue

        # Overlapping token windows can yield the same point twice.
        ordered = sorted(set(gathered), key=lambda record: record[0])
        seg_ts = [t for t, _, _ in ordered]
        seg_xs = [x for _, x, _ in ordered]
        seg_ys = [y for _, _, y in ordered]

        if use_douglas_peucker and len(seg_xs) > 2:
            simplified = rdp(np.array(list(zip(seg_xs, seg_ys))), epsilon=dp_epsilon)
            out_xs.append(simplified[:, 0].tolist())
            out_ys.append(simplified[:, 1].tolist())
        else:
            out_xs.append(seg_xs)
            out_ys.append(seg_ys)

        out_ts.append(seg_ts)

    print(f"\n将轨迹分割成 {len(out_xs)} 段，对应 {len(phrases)} 个短语")

    return out_xs, out_ys, out_ts


def trace_segment_by_phrases_position_based(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """
    Cut a trace into one segment per phrase by locating phrases in the token text.

    Tokens are normalised (lower-cased, punctuation removed, stripped) and
    joined with single spaces; the character span of every token inside that
    text is recorded. Each phrase is normalised the same way and looked up in
    the text. Tokens whose span overlaps the phrase's span are matched, and
    the trace points whose timestamp lies inside the
    ``[time_begins[i], time_ends[i]]`` window of a matched token form the
    phrase's segment, ordered by time.

    Lookup strategy per phrase:
      * exact search starting after the previous exact match, so a phrase
        that occurs several times maps to successive occurrences rather than
        always to the first one;
      * exact search over the whole text if nothing is found after that point
        (covers phrases supplied out of order);
      * otherwise a sliding-window ``SequenceMatcher`` comparison, accepting
        the best window with a ratio above 0.6.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of trace coordinates and
            timestamps.
        phrases: Phrases to segment the trace by.
        original_tokens: Token strings, index-aligned with the time windows.
        time_begins, time_ends: Start/end time per token.
        use_douglas_peucker: Simplify each segment with ``rdp`` when it has
            more than two points.
        dp_epsilon: Epsilon passed to ``rdp``.

    Returns:
        tuple: ``(segmented_xs, segmented_ys, segmented_ts)`` with one entry
        per phrase (empty lists for phrases that cannot be located or have no
        trace points), or three empty lists when ``trace_ts`` or ``phrases``
        is empty. Timestamp lists are not simplified by Douglas-Peucker.
    """
    if not trace_ts or not phrases:
        return [], [], []

    def clean_text(text):
        # Remove punctuation, lower-case, and trim surrounding whitespace.
        return re.sub(r'[^\w\s]', '', text.lower()).strip()

    full_original_text = ' '.join(original_tokens)

    # Build the searchable text from the cleaned tokens themselves so the
    # per-token offsets computed below match it exactly.  Cleaning the joined
    # string as a whole can shift every offset, e.g. when the first token is
    # pure punctuation and the leading space gets stripped.
    clean_tokens = [clean_text(token) for token in original_tokens]
    clean_full_text = ' '.join(clean_tokens)

    # Character span (start, end) of every token inside clean_full_text.
    token_spans = []
    offset = 0
    for clean_token in clean_tokens:
        token_spans.append((offset, offset + len(clean_token)))
        offset += len(clean_token) + 1  # +1 for the joining space

    print(f"完整原始文本: {full_original_text}")
    print(f"清理后文本: {clean_full_text}")

    trace_arr = np.array(trace_ts)
    segmented_xs, segmented_ys, segmented_ts = [], [], []

    # Where the next exact search starts: the end of the last exact match.
    search_from = 0

    for phrase_idx, phrase in enumerate(phrases):
        print(f"\n处理短语 {phrase_idx + 1}: '{phrase}'")

        clean_phrase = clean_text(phrase)
        print(f"清理后短语: '{clean_phrase}'")

        if not clean_phrase:
            # Nothing left to locate (phrase was punctuation/whitespace only).
            print(f"未找到匹配，跳过该短语")
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])
            continue

        # Prefer the first occurrence after the previous match, then fall
        # back to a search over the whole text.
        phrase_start_pos = clean_full_text.find(clean_phrase, search_from)
        if phrase_start_pos == -1 and search_from > 0:
            phrase_start_pos = clean_full_text.find(clean_phrase)

        if phrase_start_pos != -1:
            phrase_end_pos = phrase_start_pos + len(clean_phrase)
            search_from = phrase_end_pos
            print(f"精确匹配成功")
        else:
            # No exact occurrence: slide a phrase-sized window over the text
            # and keep the most similar one above the threshold.
            print(f"未找到精确匹配，尝试模糊匹配...")
            best_ratio = 0
            best_start = -1
            best_end = -1

            for i in range(len(clean_full_text) - len(clean_phrase) + 1):
                substring = clean_full_text[i:i + len(clean_phrase)]
                ratio = SequenceMatcher(None, clean_phrase, substring).ratio()
                if ratio > best_ratio and ratio > 0.6:  # similarity threshold
                    best_ratio = ratio
                    best_start = i
                    best_end = i + len(clean_phrase)

            if best_start == -1:
                print(f"未找到匹配，跳过该短语")
                segmented_xs.append([])
                segmented_ys.append([])
                segmented_ts.append([])
                continue

            phrase_start_pos = best_start
            phrase_end_pos = best_end
            print(f"模糊匹配成功，相似度: {best_ratio:.2f}")

        print(f"短语在文本中的位置: {phrase_start_pos}-{phrase_end_pos}")

        # Tokens whose character span overlaps the phrase span.
        matched_token_indices = []
        for token_idx, (token_start_pos, token_end_pos) in enumerate(token_spans):
            if token_end_pos > phrase_start_pos and token_start_pos < phrase_end_pos:
                matched_token_indices.append(token_idx)
                print(f"  匹配token {token_idx}: '{original_tokens[token_idx]}' (位置: {token_start_pos}-{token_end_pos})")

        # Collect the trace points inside each matched token's time window.
        phrase_xs, phrase_ys, phrase_ts = [], [], []
        for token_idx in matched_token_indices:
            if token_idx < len(time_begins) and token_idx < len(time_ends):
                token_start_time = time_begins[token_idx]
                token_end_time = time_ends[token_idx]
                indices = np.where((trace_arr >= token_start_time) & (trace_arr <= token_end_time))[0]

                for idx in indices:
                    phrase_xs.append(trace_xs[idx])
                    phrase_ys.append(trace_ys[idx])
                    phrase_ts.append(trace_ts[idx])

        print(f"  总共收集到 {len(phrase_xs)} 个轨迹点")

        if phrase_xs:
            # Order the points by time (stable for equal timestamps).
            sorted_indices = sorted(range(len(phrase_ts)), key=lambda i: phrase_ts[i])
            phrase_xs = [phrase_xs[i] for i in sorted_indices]
            phrase_ys = [phrase_ys[i] for i in sorted_indices]
            phrase_ts = [phrase_ts[i] for i in sorted_indices]

            # Optional Douglas-Peucker simplification of the coordinates.
            if use_douglas_peucker and len(phrase_xs) > 2:
                points = np.array(list(zip(phrase_xs, phrase_ys)))
                simplified_points = rdp(points, epsilon=dp_epsilon)
                segmented_xs.append(simplified_points[:, 0].tolist())
                segmented_ys.append(simplified_points[:, 1].tolist())
            else:
                segmented_xs.append(phrase_xs)
                segmented_ys.append(phrase_ys)

            segmented_ts.append(phrase_ts)
        else:
            # Keep the output index-aligned with the phrases.
            print(f"  未找到轨迹点")
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])

    print(f"\n将轨迹分割成 {len(segmented_xs)} 段，对应 {len(phrases)} 个短语")

    return segmented_xs, segmented_ys, segmented_ts


def traces_to_bboxs(xs_list, ys_list):
    """Return the axis-aligned bounding box of every trace segment.

    ``xs_list`` and ``ys_list`` are parallel lists of coordinate lists. A
    segment with no x coordinates gets the full-frame box
    ``(0.0, 1.0, 0.0, 1.0)``.

    Returns:
        tuple: Four lists ``(xmins, xmaxs, ymins, ymaxs)``, one value per segment.
    """
    full_frame = (0.0, 1.0, 0.0, 1.0)  # fallback for empty segments
    boxes = [
        (min(seg_xs), max(seg_xs), min(seg_ys), max(seg_ys)) if seg_xs else full_frame
        for seg_xs, seg_ys in zip(xs_list, ys_list)
    ]
    xmins = [box[0] for box in boxes]
    xmaxs = [box[1] for box in boxes]
    ymins = [box[2] for box in boxes]
    ymaxs = [box[3] for box in boxes]
    return xmins, xmaxs, ymins, ymaxs


def process_json_improved(json_anno, trace_segmentation_method, image_base_path, visualize=False, use_douglas_peucker=False, dp_epsilon=0.01, output_dir="."):
    """
    Segment one annotation's trace, derive bounding boxes and optionally plot them.

    Args:
        json_anno: Annotation dict; ``image_id``, ``dataset_id`` and
            ``caption`` are read here, the trace data is extracted by
            ``get_json_anno_external``.
        trace_segmentation_method: One of ``'uni_len_global'``,
            ``'timestamp'``, ``'phrases_word_level'`` or ``'phrases_position'``.
        image_base_path: Passed through to ``visualize_all_boxes``.
        visualize: Call ``visualize_all_boxes`` when true.
        use_douglas_peucker: Forwarded to the segmentation function.
        dp_epsilon: Forwarded to the segmentation function.
        output_dir: Passed through to ``visualize_all_boxes``.

    Returns:
        The labels associated with the segments: the caption phrases for the
        two phrase methods, otherwise the transcription. ``None`` when the
        extracted trace is empty.

    Raises:
        ValueError: If ``trace_segmentation_method`` is not recognised.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    method_title = trace_segmentation_method + (" + Douglas-Peucker" if use_douglas_peucker else "")
    print(f"\n--- Processing Image: {image_id} using '{method_title}' method ---")

    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)

    if not xs:
        print("Error: Empty or invalid trace data.")
        return

    # Simplification options shared by every segmentation function.
    dp_options = {'use_douglas_peucker': use_douglas_peucker, 'dp_epsilon': dp_epsilon}

    if trace_segmentation_method == 'uni_len_global':
        # Fixed-length time slices of 0.4.
        segmented_xs, segmented_ys, _ = trace_segment_uniform_time_interval(
            xs, ys, ts, 0.4, **dp_options)
        labels = transcription
    elif trace_segmentation_method == 'timestamp':
        segmented_xs, segmented_ys, _ = trace_segment_timestamp(
            xs, ys, ts, time_begins, time_ends, **dp_options)
        labels = transcription
    elif trace_segmentation_method in ('phrases_word_level', 'phrases_position'):
        # Both phrase methods split the caption identically and differ only
        # in how phrases are mapped onto the timed tokens.
        if trace_segmentation_method == 'phrases_word_level':
            segment_by_phrases = trace_segment_by_phrases_word_level
        else:
            segment_by_phrases = trace_segment_by_phrases_position_based

        phrases = split_caption_into_phrases(full_caption)
        print(f"拆分出 {len(phrases)} 个短语:")
        for number, phrase in enumerate(phrases, 1):
            print(f"  {number:2d}. {phrase}")

        segmented_xs, segmented_ys, _ = segment_by_phrases(
            xs, ys, ts, phrases, transcription, time_begins, time_ends, **dp_options)
        labels = phrases
    else:
        raise ValueError("Unknown trace_segmentation_method")

    # One bounding box per trace segment.
    xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)

    print(f"Generated {len(xmins)} bounding boxes.")

    # Name used by the visualisation; tagged when simplification was on.
    output_method_name = trace_segmentation_method
    if use_douglas_peucker:
        print(f"Applying Douglas-Peucker simplification with epsilon={dp_epsilon}.")
        output_method_name += "_DP"

    if visualize:
        visualize_all_boxes(
            image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
            labels, output_method_name,
            image_base_path, dataset_id, full_caption, output_dir=output_dir)

    return labels


def process_json_with_qwen_optimization(json_anno, trace_segmentation_method, image_base_path, qwen_optimizer, 
                                       quality_threshold=0.6, visualize=True, use_douglas_peucker=False, 
                                       dp_epsilon=0.01, output_dir="."):
    """
    Score trace-derived bounding boxes with the Qwen optimizer and refine weak ones.

    The trace is segmented with the chosen base method and turned into one
    bounding box per segment. Every (label, box) pair is scored with
    ``qwen_optimizer.calculate_bbox_quality_score``. A box scoring below
    ``quality_threshold`` is first corrected with the original box as a
    hint; if that does not raise the score, a box is regenerated without a
    hint. The regenerated box is kept only if it beats the original score,
    otherwise the original box is kept.

    Args:
        json_anno: Annotation dict; ``image_id``, ``dataset_id`` and
            ``caption`` are read here, the trace data is extracted by
            ``get_json_anno_external``.
        trace_segmentation_method: Base method: ``'uni_len_global'``,
            ``'timestamp'``, ``'phrases_word_level'`` or ``'phrases_position'``.
        image_base_path: Root directory of the images.
        qwen_optimizer: Object providing ``calculate_bbox_quality_score`` and
            ``get_corrected_bbox``.
        quality_threshold: Boxes scoring below this value are refined.
        visualize: When true, draw the comparison figure and write the
            results JSON file into ``output_dir``.
        use_douglas_peucker: Forwarded to the segmentation function.
        dp_epsilon: Forwarded to the segmentation function.
        output_dir: Output directory (created if missing).

    Returns:
        dict | None: Result dictionary with the labels, original and
        optimized boxes, quality scores, statistics and trace segments, or
        ``None`` when the extracted trace is empty.

    Raises:
        ValueError: If ``trace_segmentation_method`` is not recognised.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']
    
    print(f"\n--- Processing Image: {image_id} with Qwen2.5-VL Optimization ---")
    print(f"Base method: {trace_segmentation_method}")
    
    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)

    if not xs:
        print("Error: Empty or invalid trace data.")
        return None

    # Image path: <base>/<second '_'-separated field of dataset_id>/<12-digit id>.jpg
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)
    
    os.makedirs(output_dir, exist_ok=True)
    
    # Segment the trace with the requested base method.
    tokens_for_vis = None
    
    if trace_segmentation_method == 'uni_len_global':
        uniform_trace_segmentation_time_interval = 0.4
        segmented_xs, segmented_ys, _ = trace_segment_uniform_time_interval(
            xs, ys, ts, uniform_trace_segmentation_time_interval,
            use_douglas_peucker=use_douglas_peucker, dp_epsilon=dp_epsilon)
        tokens_for_vis = transcription
    elif trace_segmentation_method == 'timestamp':
        segmented_xs, segmented_ys, _ = trace_segment_timestamp(
            xs, ys, ts, time_begins, time_ends,
            use_douglas_peucker=use_douglas_peucker, dp_epsilon=dp_epsilon)
        tokens_for_vis = transcription
    elif trace_segmentation_method == 'phrases_word_level':
        phrases = split_caption_into_phrases(full_caption)
        print(f"拆分出 {len(phrases)} 个短语:")
        for i, phrase in enumerate(phrases, 1):
            print(f"  {i:2d}. {phrase}")
        
        segmented_xs, segmented_ys, _ = trace_segment_by_phrases_word_level(
            xs, ys, ts, phrases, transcription, time_begins, time_ends,
            use_douglas_peucker=use_douglas_peucker, dp_epsilon=dp_epsilon)
        tokens_for_vis = phrases
    elif trace_segmentation_method == 'phrases_position':
        phrases = split_caption_into_phrases(full_caption)
        print(f"拆分出 {len(phrases)} 个短语:")
        for i, phrase in enumerate(phrases, 1):
            print(f"  {i:2d}. {phrase}")
        
        segmented_xs, segmented_ys, _ = trace_segment_by_phrases_position_based(
            xs, ys, ts, phrases, transcription, time_begins, time_ends,
            use_douglas_peucker=use_douglas_peucker, dp_epsilon=dp_epsilon)
        tokens_for_vis = phrases
    else:
        raise ValueError("Unknown trace_segmentation_method")

    # One bounding box per trace segment.
    xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)
    
    print(f"Generated {len(xmins)} bounding boxes from {trace_segmentation_method}")
    
    # Boxes as (xmin, ymin, xmax, ymax), limited to the labels that exist.
    # There can also be fewer boxes than labels; zip() below then stops at
    # the shorter sequence, so indexing by label count is never needed.
    original_bboxes = list(zip(xmins, ymins, xmaxs, ymaxs))[:len(tokens_for_vis)]
    
    quality_scores = []
    optimized_bboxes = []
    # "optimized" counts refinement attempts, "quality_improved" the attempts
    # that produced a better-scoring box, "regenerated" the hint-free boxes kept.
    optimization_stats = {"total": 0, "optimized": 0, "quality_improved": 0, "regenerated": 0}
    
    for i, (token, original_bbox) in enumerate(zip(tokens_for_vis, original_bboxes)):
        optimization_stats["total"] += 1
        
        print(f"\n处理 {i+1}: '{token}'")
        print(f"原始边界框: {original_bbox}")
        
        # Score the trace-derived box.
        quality_score = qwen_optimizer.calculate_bbox_quality_score(
            original_bbox, token, full_image_path)
        quality_scores.append(quality_score)
        
        print(f"质量分数: {quality_score:.3f}")
        
        if quality_score >= quality_threshold:
            optimized_bboxes.append(original_bbox)
            print(f"✓ 质量分数达标，无需优化")
            continue
        
        print(f"质量分数 {quality_score:.3f} 低于阈值 {quality_threshold}，进行Qwen2.5-VL优化...")
        optimization_stats["optimized"] += 1
        
        # Step 1: correction that uses the original box as a hint.
        optimized_bbox = qwen_optimizer.get_corrected_bbox(
            token, full_image_path, original_bbox)
        optimized_quality = qwen_optimizer.calculate_bbox_quality_score(
            optimized_bbox, token, full_image_path)
        
        print(f"优化后边界框: {optimized_bbox}")
        print(f"优化后质量分数: {optimized_quality:.3f}")
        
        if optimized_quality > quality_score:
            optimized_bboxes.append(optimized_bbox)
            optimization_stats["quality_improved"] += 1
            print(f"✓ 优化成功，质量提升: {quality_score:.3f} -> {optimized_quality:.3f}")
            continue
        
        print(f"✗ 优化后质量未提升，使用Qwen2.5-VL重新生成定位框...")
        
        # Step 2: regenerate the box from the description alone (no hint).
        regenerated_bbox = qwen_optimizer.get_corrected_bbox(
            token, full_image_path, original_bbox=None)
        regenerated_quality = qwen_optimizer.calculate_bbox_quality_score(
            regenerated_bbox, token, full_image_path)
        
        print(f"重新生成的边界框: {regenerated_bbox}")
        print(f"重新生成的质量分数: {regenerated_quality:.3f}")
        
        # The hinted correction did not beat the original score on this
        # path, so the regenerated box only has to beat the original one.
        if regenerated_quality > quality_score:
            optimized_bboxes.append(regenerated_bbox)
            optimization_stats["regenerated"] += 1
            optimization_stats["quality_improved"] += 1
            print(f"✓ 重新生成成功，质量提升: {quality_score:.3f} -> {regenerated_quality:.3f}")
        else:
            optimized_bboxes.append(original_bbox)
            print(f"✗ 所有优化尝试均未改善质量，保持原始定位")
    
    # Summary of the refinement pass.
    print(f"\n=== Qwen2.5-VL优化统计 ===")
    print(f"总数: {optimization_stats['total']}")
    print(f"尝试优化数: {optimization_stats['optimized']}")
    print(f"成功改善数: {optimization_stats['quality_improved']}")
    print(f"重新生成数: {optimization_stats['regenerated']}")
    print(f"优化成功率: {optimization_stats['quality_improved']/optimization_stats['optimized']*100:.1f}%" if optimization_stats['optimized'] > 0 else "优化成功率: 0%")
    print(f"重新生成率: {optimization_stats['regenerated']/optimization_stats['optimized']*100:.1f}%" if optimization_stats['optimized'] > 0 else "重新生成率: 0%")
    
    results = {
        "image_id": image_id,
        "base_method": trace_segmentation_method,
        "tokens": tokens_for_vis,
        "original_bboxes": original_bboxes,
        "optimized_bboxes": optimized_bboxes,
        "quality_scores": quality_scores,
        "optimization_stats": optimization_stats,
        "segmented_xs": segmented_xs,
        "segmented_ys": segmented_ys
    }
    
    if visualize:
        visualize_qwen_optimization_comparison(
            image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
            optimized_bboxes, tokens_for_vis, f"{trace_segmentation_method}_qwen_optimized",
            image_base_path, dataset_id, full_caption, output_dir, quality_scores)
        
        def _json_default(value):
            # Called by json only for objects it cannot encode itself, so
            # already-serialisable values are written exactly as before.
            if isinstance(value, np.generic):
                return value.item()
            if isinstance(value, np.ndarray):
                return value.tolist()
            raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")
        
        # Persist the detailed results next to the figure.
        results_file = os.path.join(output_dir, f"qwen_optimization_results_{image_id}.json")
        json_results = {
            "image_id": image_id,
            "base_method": trace_segmentation_method,
            "tokens": tokens_for_vis,
            "original_bboxes": results["original_bboxes"],
            "optimized_bboxes": results["optimized_bboxes"], 
            "quality_scores": [float(score) for score in quality_scores],
            "optimization_stats": optimization_stats
        }
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(json_results, f, ensure_ascii=False, indent=2, default=_json_default)
        print(f"优化结果已保存到: {results_file}")
    
    return results


def test_improved_phrase_segmentation():
    """
    Print the phrases produced by ``split_caption_into_phrases_improved``
    for a few sample captions (a manual smoke check; nothing is asserted).
    """
    sample_captions = (
        "In this image there are two persons sitting on the vehicle, and there are cardboard boxes, tire, bicycle in the vehicle, and at the background there is sky.",
        "In this picture there is a room in which we can observe a sofa set and a table in front of the sofa set. There is a TV fixed to the wall which is in cream color.",
        "A woman is sitting on a chair and holding a mobile phone. Behind her there is a fencing.",
    )

    print("=== 测试改进的短语分割功能 ===\n")

    for case_number, caption in enumerate(sample_captions, start=1):
        print(f"测试用例 {case_number}:")
        print(f"原始caption: {caption}")

        phrases = split_caption_into_phrases_improved(caption)

        print(f"拆分后的短语 ({len(phrases)} 个):")
        for rank, phrase in enumerate(phrases, start=1):
            print(f"  {rank:2d}. {phrase}")
        print()


if __name__ == '__main__':
    # Locations of the images, the annotation file and the model weights.
    COCO_IMAGE_PATH = '/storage-root/datasets/yangfan/coco2017'
    JSONL_PATH = '/storage-root/datasets/yangfan/Seg_LLaVA_v2/datasets/Localized_Narratives/coco_train_localized_narratives-00000-of-00004.jsonl'
    MODEL_PATH = "/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct"

    # Load the Qwen2.5-VL corrector once, spread over 4 GPUs.
    qwen_optimizer = QwenVLLocationCorrector(MODEL_PATH, num_gpus=4)

    # Every run writes into its own timestamped directory.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_base_dir = f"output_visualizations_with_qwen_{timestamp}"
    os.makedirs(output_base_dir, exist_ok=True)
    print(f"Saving visualizations to: {os.path.abspath(output_base_dir)}")

    # Baseline runs per sample: (banner, segmentation method, output sub-directory).
    baseline_runs = (
        (f"\n=== 测试词级别匹配方法 ===", 'phrases_word_level', "word_level"),
        (f"\n=== 测试位置匹配方法 ===", 'phrases_position', "position_based"),
        (f"\n=== 对比：原始时间戳方法 ===", 'timestamp', "timestamp"),
    )

    # Annotation indices 2 through 20 inclusive.
    for selected_index in range(2, 21):
        try:
            COCO_VAL_JSON_DICT = load_jsonl_by_index(JSONL_PATH, selected_index)

            # One sub-directory per sample.
            current_output_dir = os.path.join(output_base_dir, f"index_{selected_index}")
            os.makedirs(current_output_dir, exist_ok=True)

            for banner, method, subdir in baseline_runs:
                print(banner)
                process_json_improved(
                    COCO_VAL_JSON_DICT,
                    trace_segmentation_method=method,
                    image_base_path=COCO_IMAGE_PATH,
                    visualize=True,
                    use_douglas_peucker=False,
                    dp_epsilon=0.01,
                    output_dir=os.path.join(current_output_dir, subdir))

            # Qwen refinement on top of the position-based phrase boxes.
            print(f"\n=== 使用Qwen2.5-VL优化短语级定位框 ===")
            process_json_with_qwen_optimization(
                COCO_VAL_JSON_DICT,
                trace_segmentation_method='phrases_position',
                image_base_path=COCO_IMAGE_PATH,
                qwen_optimizer=qwen_optimizer,
                quality_threshold=0.6,
                visualize=True,
                use_douglas_peucker=False,
                dp_epsilon=0.01,
                output_dir=os.path.join(current_output_dir, "qwen_optimized"),
            )

        except IndexError as e:
            print(f"Skipping index {selected_index}: {e}")
        except Exception as e:
            # Report the failure and continue with the next sample.
            print(f"An error occurred while processing index {selected_index}: {e}")
            import traceback
            traceback.print_exc()

    # Finish with the phrase-splitting smoke check.
    print("\n" + "="*50)
    test_improved_phrase_segmentation()

    print(f"\n程序执行完成！所有结果保存在: {os.path.abspath(output_base_dir)}")