# -*- coding: utf-8 -*-
"""
基于小目标的轨迹构建脚本（带可视化保存）
基于dp_visual_chaifenjuzi_qwenyanzheng.py，专门用于构建小目标的轨迹数据
包含可视化功能，保存原图、轨迹线图、边界框图
输出：包含phrases和对应坐标点的JSON文件 + 可视化图片
"""

import json
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.font_manager as fm
import os
import re
from difflib import SequenceMatcher
import datetime
from tqdm import tqdm
from dp_visual_chaifenjuzi_qwenyanzheng import (
    process_json_with_qwen_optimization, 
    QwenVLLocationCorrector,
    get_json_anno_external,
    traces_to_bboxs
)
import matplotlib.patches as patches


def load_small_targets_data(small_targets_json_path):
    """
    Load the small-target annotation file.

    Args:
        small_targets_json_path: Path to a UTF-8 JSON file with a
            ``small_targets`` list (each entry carrying an ``image_id``) and a
            ``metadata`` dict containing ``total_small_targets``.

    Returns:
        tuple: (set of image ids that contain at least one small target,
        the full parsed JSON object).
    """
    print(f"加载小目标数据: {small_targets_json_path}")

    with open(small_targets_json_path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    # Several small targets can share one image; the set keeps each id once.
    image_ids = {entry['image_id'] for entry in payload['small_targets']}

    print(f"共找到 {len(image_ids)} 个包含小目标的图像")
    print(f"小目标总数: {payload['metadata']['total_small_targets']}")

    return image_ids, payload


def load_ln_jsonl_files(ln_base_path, small_target_image_ids, ln_files=None):
    """
    Load Localized Narratives JSONL shards, keeping only images with small targets.

    Args:
        ln_base_path: Directory containing the JSONL shard files.
        small_target_image_ids: Container of image ids to keep. Each record's
            ``image_id`` is converted with ``int()`` before the membership test,
            so the ids in this container are expected to be ints.
        ln_files: Optional list of shard file names (relative to
            ``ln_base_path``). Defaults to the four COCO-train shards.

    Returns:
        list: Parsed records (dicts) whose ``int(image_id)`` is in
        ``small_target_image_ids``. Missing files, blank lines, malformed JSON
        and records without a usable ``image_id`` are reported and skipped.
    """
    if ln_files is None:
        ln_files = [
            'coco_train_localized_narratives-00000-of-00004.jsonl',
            'coco_train_localized_narratives-00001-of-00004.jsonl',
            'coco_train_localized_narratives-00002-of-00004.jsonl',
            'coco_train_localized_narratives-00003-of-00004.jsonl',
        ]

    filtered_ln_data = []
    total_processed = 0

    for file_name in ln_files:
        file_path = os.path.join(ln_base_path, file_name)
        if not os.path.exists(file_path):
            print(f"Warning: File not found: {file_path}")
            continue

        print(f"处理LN文件: {file_path}")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Blank (e.g. trailing) lines are not records; json.loads('') would raise.
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {line_num} in {file_name}: {e}")
                    continue

                total_processed += 1

                # A single bad record must not abort the whole (large) load.
                try:
                    image_id = int(data['image_id'])
                except (KeyError, TypeError, ValueError) as e:
                    print(f"Skipping line {line_num} in {file_name}: invalid image_id ({e!r})")
                    continue

                if image_id in small_target_image_ids:
                    filtered_ln_data.append(data)

    print(f"总共处理 {total_processed} 条LN数据")
    print(f"筛选出包含小目标的LN数据: {len(filtered_ln_data)} 条")

    return filtered_ln_data


def get_json_anno_external(json_anno):
    """
    Extract trace points and the timed transcript from one annotation dict.

    Points from every stroke in ``json_anno['traces']`` are flattened in
    order; their x and y values are clamped to ``[0.0, 1.0]``.

    Returns:
        tuple: ``(xs, ys, ts, toks, time_begins, time_ends)``, or six ``None``
        values when the annotation has no trace points (``timed_caption`` is
        not read in that case).
    """
    points = [pt for stroke in json_anno['traces'] for pt in stroke]
    if not points:
        return (None,) * 6

    def clamp_unit(value):
        return max(0.0, min(1.0, value))

    xs = [clamp_unit(pt['x']) for pt in points]
    ys = [clamp_unit(pt['y']) for pt in points]
    ts = [pt['t'] for pt in points]

    utterances = json_anno['timed_caption']
    toks = [u['utterance'] for u in utterances]
    time_begins = [u['start_time'] for u in utterances]
    time_ends = [u['end_time'] for u in utterances]

    return xs, ys, ts, toks, time_begins, time_ends


def split_caption_into_phrases(caption):
    """
    Split a free-form caption into short phrases.

    Steps:
      1. Collapse whitespace and split into sentences on ``.``, ``!``, ``?``.
      2. Within each sentence, successively split on conjunctions, commas,
         relative-clause words and a few prepositional/existential patterns.
         Word separators (e.g. ``"and"``, ``"on the"``) are kept as the prefix
         of the following fragment; comma punctuation is dropped.
      3. Merge fragments of fewer than 3 words into the previous fragment
         (joined with ``", "``) as long as the result has at most 8 words.
      4. Drop phrases with fewer than 2 words and strip a leading
         ``and``/``or``/``but``.

    Args:
        caption: Caption text.

    Returns:
        list[str]: The extracted phrases, in caption order.
    """
    caption = re.sub(r'\s+', ' ', caption.strip())

    # Each pattern has exactly one capturing group, so re.split() returns
    # [text, sep, text, sep, ..., text].
    # NOTE: "(\s*,\s*and\s+)" only fires for ",and" without a preceding space,
    # because the plain "and" pattern runs first.
    split_patterns = [
        r'(\s+and\s+)',           # "and"
        r'(\s*,\s*and\s+)',       # ", and"
        r'(\s*,\s*)',             # comma
        r'(\s+which\s+)',         # "which" clause
        r'(\s+that\s+)',          # "that" clause
        r'(\s+where\s+)',         # "where" clause
        r'(\s+while\s+)',         # "while" clause
        r'(\s+in\s+the\s+)',      # "in the" prepositional phrase
        r'(\s+on\s+the\s+)',      # "on the" prepositional phrase
        r'(\s+at\s+the\s+)',      # "at the" prepositional phrase
        r'(\s+behind\s+)',        # "behind"
        r'(\s+there\s+is\s+)',    # "there is" existential
        r'(\s+there\s+are\s+)',   # "there are" existential
    ]

    phrases = []
    for sentence in re.split(r'[.!?]+', caption):
        sentence = sentence.strip()
        if not sentence:
            continue

        current_parts = [sentence]
        for pattern in split_patterns:
            new_parts = []
            for part in current_parts:
                segments = re.split(pattern, part, flags=re.IGNORECASE)

                head = segments[0].strip()
                if head:
                    new_parts.append(head)

                for raw_sep, raw_next in zip(segments[1::2], segments[2::2]):
                    # Commas are pure punctuation: dropping them keeps phrases from
                    # starting with ", " and stops "," being counted as a word below.
                    separator = raw_sep.strip().lstrip(',').strip()
                    next_part = raw_next.strip()
                    if next_part:
                        combined = f"{separator} {next_part}" if separator else next_part
                        new_parts.append(combined.strip())

            current_parts = new_parts

        # Merge very short fragments into the previous one when the result stays short.
        final_parts = []
        for part in current_parts:
            part = part.strip()
            if not part:
                continue

            words_in_part = len(part.split())
            if (words_in_part < 3 and final_parts
                    and len(final_parts[-1].split()) + words_in_part <= 8):
                final_parts[-1] = final_parts[-1] + ", " + part
            else:
                final_parts.append(part)

        phrases.extend(final_parts)

    cleaned_phrases = []
    for phrase in phrases:
        phrase = phrase.strip()
        if phrase and len(phrase.split()) >= 2:
            phrase = re.sub(r'^(and|or|but)\s+', '', phrase, flags=re.IGNORECASE)
            if phrase:
                cleaned_phrases.append(phrase)

    return cleaned_phrases


def trace_segment_by_phrases_position_based(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends):
    """
    Assign trace points to phrases by locating each phrase in the transcript.

    Each phrase is located in a lower-cased, punctuation-free transcript built
    from ``original_tokens``. An exact substring match is tried first, then a
    sliding-window fuzzy match (``SequenceMatcher`` ratio > 0.6). Every token
    whose character span overlaps the phrase contributes the trace points
    whose timestamp lies in that token's inclusive
    ``[time_begins[k], time_ends[k]]`` window.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of trace coordinates and timestamps.
        phrases: Phrases to locate.
        original_tokens: Transcript tokens, parallel to ``time_begins``/``time_ends``.
        time_begins, time_ends: Per-token start and end times.

    Returns:
        tuple: ``(segmented_xs, segmented_ys, segmented_ts)``. Each holds one
        list per phrase; the list is empty when the phrase could not be
        located or no point falls inside its time windows. Points are
        deduplicated by trace index and ordered by timestamp, with ties kept
        in trace order. Returns three empty lists when ``phrases`` is empty.
    """
    if not phrases:
        return [], [], []
    if not trace_ts:
        # Keep one (empty) entry per phrase so callers can index by phrase.
        return [[] for _ in phrases], [[] for _ in phrases], [[] for _ in phrases]

    def clean_text(text):
        return re.sub(r'[^\w\s]', '', text.lower()).strip()

    # Build the searchable text from the *cleaned* tokens and record each
    # token's character span in it, so text and spans agree by construction.
    clean_tokens = [clean_text(token) for token in original_tokens]
    clean_full_text = ' '.join(clean_tokens)
    token_spans = []
    cursor = 0
    for clean_token in clean_tokens:
        token_spans.append((cursor, cursor + len(clean_token)))
        cursor += len(clean_token) + 1  # +1 for the joining space

    def locate(clean_phrase):
        """Return (start, end) of the phrase in clean_full_text, or None."""
        start = clean_full_text.find(clean_phrase)
        if start != -1:
            return start, start + len(clean_phrase)

        width = len(clean_phrase)
        best_ratio, best_start = 0, -1
        for i in range(len(clean_full_text) - width + 1):
            ratio = SequenceMatcher(None, clean_phrase, clean_full_text[i:i + width]).ratio()
            if ratio > best_ratio and ratio > 0.6:
                best_ratio, best_start = ratio, i
        if best_start == -1:
            return None
        return best_start, best_start + width

    trace_arr = np.asarray(trace_ts)  # built once, reused for every token
    num_timed_tokens = min(len(time_begins), len(time_ends))

    segmented_xs, segmented_ys, segmented_ts = [], [], []
    for phrase in phrases:
        span = locate(clean_text(phrase))
        if span is None:
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])
            continue

        phrase_start, phrase_end = span
        in_phrase = np.zeros(trace_arr.shape[0], dtype=bool)
        for token_idx, (token_start, token_end) in enumerate(token_spans[:num_timed_tokens]):
            if token_end > phrase_start and token_start < phrase_end:
                in_phrase |= (trace_arr >= time_begins[token_idx]) & (trace_arr <= time_ends[token_idx])

        # flatnonzero yields each index once (dedup); the stable sort orders by
        # time and keeps trace order for equal timestamps.
        ordered = sorted(np.flatnonzero(in_phrase).tolist(), key=lambda k: trace_ts[k])

        segmented_xs.append([trace_xs[k] for k in ordered])
        segmented_ys.append([trace_ys[k] for k in ordered])
        segmented_ts.append([trace_ts[k] for k in ordered])

    return segmented_xs, segmented_ys, segmented_ts


def traces_to_bboxs(xs_list, ys_list):
    """
    Convert per-phrase trace coordinates into axis-aligned bounding boxes.

    Args:
        xs_list: Sequence of x-coordinate lists, one per phrase.
        ys_list: Sequence of y-coordinate lists, parallel to ``xs_list``.

    Returns:
        tuple: ``(xmins, xmaxs, ymins, ymaxs)`` lists. A phrase with no x
        coordinates gets the full-frame box ``(0.0, 1.0, 0.0, 1.0)``. Pairs
        beyond the shorter input are ignored (``zip`` semantics).
    """
    full_frame = (0.0, 1.0, 0.0, 1.0)
    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    for seg_x, seg_y in zip(xs_list, ys_list):
        box = (min(seg_x), max(seg_x), min(seg_y), max(seg_y)) if seg_x else full_frame
        for column, value in zip((xmins, xmaxs, ymins, ymaxs), box):
            column.append(value)
    return xmins, xmaxs, ymins, ymaxs


# Fonts tried, in order, for the Chinese figure titles. Only fonts that contain
# CJK glyphs belong here; a Latin-only font renders the titles as empty boxes.
_CJK_FONT_CANDIDATES = ('SimHei', 'Microsoft YaHei', 'Noto Sans CJK SC', 'WenQuanYi Micro Hei')


def _find_cjk_font_prop():
    """Return ``FontProperties`` for the first installed CJK font, or ``None``."""
    for font_name in _CJK_FONT_CANDIDATES:
        try:
            # With fallback_to_default=False, findfont raises ValueError instead of
            # silently substituting the default (Latin-only) font.
            font_path = fm.findfont(fm.FontProperties(family=font_name),
                                    fallback_to_default=False)
        except ValueError:
            continue
        return fm.FontProperties(fname=font_path)
    return None


def _set_localized_title(ax, font_prop, zh_title, en_title):
    """Use the Chinese title when a CJK font is available, else the English one."""
    if font_prop is None:
        ax.set_title(en_title, fontsize=14)
    else:
        ax.set_title(zh_title, fontsize=14, fontproperties=font_prop)


def _truncate_label(text, limit):
    """Cut ``text`` to ``limit`` characters, appending '...' only if it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _new_image_axes(im_array, imw, imh):
    """Create a 12x9 figure showing ``im_array`` in pixel coordinates (origin top-left)."""
    fig, ax = plt.subplots(figsize=(12, 9))
    ax.imshow(im_array)
    ax.set_xlim(0, imw)
    ax.set_ylim(imh, 0)
    return fig, ax


def _add_scaled_bbox(ax, xmin, xmax, ymin, ymax, imw, imh, color, **style):
    """Draw a box given in [0, 1] units scaled to pixels; return its (left, top) corner."""
    min_x, max_x = xmin * imw, xmax * imw
    min_y, max_y = ymin * imh, ymax * imh
    ax.add_patch(patches.Rectangle((min_x, min_y), max_x - min_x, max_y - min_y,
                                   edgecolor=color, facecolor='none', **style))
    return min_x, min_y


def visualize_trajectory_data(image_id, image_path, phrases, segmented_xs, segmented_ys,
                            xmins, xmaxs, ymins, ymaxs, full_caption, output_dir):
    """
    Visualize per-phrase trajectories and boxes for one image.

    Writes into ``<output_dir>/image_<image_id>/``:
      * ``original_<id>.jpg``  - the source image (only if it exists),
      * ``traces_<id>.png``    - trajectory polylines, one colour per phrase,
      * ``bboxes_<id>.png``    - bounding boxes with phrase labels,
      * ``combined_<id>.png``  - traces, dashed boxes, caption and phrase list.

    Trajectory and box coordinates are multiplied by the image width/height,
    i.e. treated as normalised values. A missing image is replaced by a
    640x480 white canvas.

    Returns:
        dict: Paths of the saved files (``original_image`` is None when the
        source image is missing), or None if any step raised.
    """
    try:
        im = None
        if os.path.exists(image_path):
            # The context manager closes the file; convert() returns an
            # independent in-memory copy.
            with Image.open(image_path) as src:
                im = src.convert('RGB')
            imw, imh = im.size
            im_array = np.array(im)
        else:
            print(f"图像不存在: {image_path}")
            imw, imh = 640, 480
            im_array = np.ones((imh, imw, 3), dtype=np.uint8) * 255

        image_output_dir = os.path.join(output_dir, f"image_{image_id}")
        os.makedirs(image_output_dir, exist_ok=True)

        # 1. Original image.
        original_save_path = None
        if im is not None:
            original_save_path = os.path.join(image_output_dir, f"original_{image_id}.jpg")
            im.save(original_save_path, quality=95)

        colors = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, max(len(phrases), 1)))
        font_prop = _find_cjk_font_prop()  # looked up once, reused for all three titles

        def color_for(i):
            return colors[i % len(colors)]

        # 2. Trajectory lines.
        trace_filename = os.path.join(image_output_dir, f"traces_{image_id}.png")
        fig1, ax1 = _new_image_axes(im_array, imw, imh)
        try:
            for i, phrase in enumerate(phrases):
                if i < len(segmented_xs) and segmented_xs[i]:
                    ax1.plot([x * imw for x in segmented_xs[i]],
                             [y * imh for y in segmented_ys[i]],
                             linewidth=3, color=color_for(i), alpha=0.8,
                             label=f"{i+1}: {_truncate_label(phrase, 30)}")
            _set_localized_title(ax1, font_prop, f"轨迹线图 - Image {image_id}",
                                 f"Trajectory Lines - Image {image_id}")
            ax1.axis('off')
            # Legend only when it stays readable and at least one line was drawn
            # (an empty legend makes matplotlib warn).
            if len(phrases) <= 8 and ax1.get_legend_handles_labels()[0]:
                ax1.legend(loc='upper right', fontsize=8, bbox_to_anchor=(1, 1))
            fig1.savefig(trace_filename, bbox_inches='tight', dpi=200)
        finally:
            plt.close(fig1)

        # 3. Bounding boxes.
        bbox_filename = os.path.join(image_output_dir, f"bboxes_{image_id}.png")
        fig2, ax2 = _new_image_axes(im_array, imw, imh)
        try:
            for i, phrase in enumerate(phrases[:len(xmins)]):
                min_x, min_y = _add_scaled_bbox(ax2, xmins[i], xmaxs[i], ymins[i], ymaxs[i],
                                                imw, imh, color_for(i), linewidth=3, alpha=0.8)
                ax2.text(min_x, min_y - 10, f"{i+1}: {_truncate_label(phrase, 25)}",
                         fontsize=10, color=color_for(i), weight='bold',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))
            _set_localized_title(ax2, font_prop, f"边界框图 - Image {image_id}",
                                 f"Bounding Boxes - Image {image_id}")
            ax2.axis('off')
            fig2.savefig(bbox_filename, bbox_inches='tight', dpi=200)
        finally:
            plt.close(fig2)

        # 4. Combined view: traces plus dashed boxes, with caption and phrase list.
        combined_filename = os.path.join(image_output_dir, f"combined_{image_id}.png")
        fig3, ax3 = _new_image_axes(im_array, imw, imh)
        try:
            for i in range(len(phrases)):
                if i < len(segmented_xs) and segmented_xs[i]:
                    ax3.plot([x * imw for x in segmented_xs[i]],
                             [y * imh for y in segmented_ys[i]],
                             linewidth=2, color=color_for(i), alpha=0.7)
                if i < len(xmins):
                    _add_scaled_bbox(ax3, xmins[i], xmaxs[i], ymins[i], ymaxs[i],
                                     imw, imh, color_for(i),
                                     linewidth=2, alpha=0.8, linestyle='--')
            _set_localized_title(ax3, font_prop, f"综合图 - Image {image_id}",
                                 f"Combined View - Image {image_id}")
            ax3.axis('off')

            if phrases:
                caption_text = f"Caption: {full_caption}"
                if len(caption_text) > 100:
                    caption_text = caption_text[:97] + "..."
                ax3.text(0.02, 0.98, caption_text, transform=ax3.transAxes,
                         fontsize=10, verticalalignment='top',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=3))

                phrases_text = "Phrases:\n" + "\n".join(
                    f"{i+1}. {phrase}" for i, phrase in enumerate(phrases[:5]))
                if len(phrases) > 5:
                    phrases_text += f"\n... and {len(phrases) - 5} more"
                ax3.text(0.02, 0.02, phrases_text, transform=ax3.transAxes,
                         fontsize=8, verticalalignment='bottom',
                         bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=3))

            fig3.savefig(combined_filename, bbox_inches='tight', dpi=200)
        finally:
            plt.close(fig3)

        print(f"可视化文件已保存到: {image_output_dir}")

        return {
            "original_image": original_save_path,
            "traces_image": trace_filename,
            "bboxes_image": bbox_filename,
            "combined_image": combined_filename
        }

    except Exception as e:
        # Best-effort: a visualization failure must not abort the batch.
        print(f"可视化失败 (image_id={image_id}): {e}")
        return None


def _to_json_builtin(value):
    """Recursively convert numpy arrays/scalars and tuples into JSON-serialisable builtins."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (list, tuple)):
        return [_to_json_builtin(v) for v in value]
    return value


def _optimize_bbox_with_retries(qwen_optimizer, token, image_path, original_bbox,
                                original_quality, quality_threshold, max_retries, stats):
    """
    Run the correct-then-regenerate loop for a single phrase.

    Attempt 0 asks the optimizer to correct ``original_bbox``. Attempts
    1..max_retries ask it to regenerate with ``original_bbox=None``. The loop
    stops early as soon as a candidate reaches ``quality_threshold``.
    ``stats`` is updated in place (``retries_used``, ``regenerated``).

    Returns:
        tuple: (best_bbox, best_quality, attempts). ``attempts`` is the retry
        counter at exit. It exceeds ``max_retries`` only when every attempt was
        used without reaching the threshold.
    """
    best_bbox = original_bbox
    best_quality = original_quality
    attempt = 0

    while attempt <= max_retries and best_quality < quality_threshold:
        if attempt == 0:
            print(f"第1次优化尝试：基于原始边界框进行优化...")
            # First attempt: refine the trace-derived box.
            candidate_bbox = qwen_optimizer.get_corrected_bbox(
                token, image_path, original_bbox)
        else:
            print(f"第{attempt + 1}次优化尝试：重新生成定位框...")
            stats["retries_used"] += 1
            # Later attempts: regenerate without the original box as a hint.
            candidate_bbox = qwen_optimizer.get_corrected_bbox(
                token, image_path, original_bbox=None)
            # NOTE(review): attempt 1 is also a regeneration but is not counted
            # here; confirm whether "regenerated" is meant to differ from
            # "retries_used" by one per phrase.
            if attempt > 1:
                stats["regenerated"] += 1

        if candidate_bbox is None:
            # Defensive: the optimizer may fail to produce a box (the comparison
            # visualizer also guards against None boxes); never keep None as best.
            print(f"  ✗ 未获得候选边界框")
            attempt += 1
            continue

        candidate_quality = qwen_optimizer.calculate_bbox_quality_score(
            candidate_bbox, token, image_path)

        print(f"  候选边界框: {candidate_bbox}")
        print(f"  候选质量分数: {candidate_quality:.3f}")

        if candidate_quality > best_quality:
            best_bbox = candidate_bbox
            best_quality = candidate_quality
            print(f"  ✓ 质量改善: {best_quality:.3f}")
            if best_quality >= quality_threshold:
                print(f"  ✓ 达到质量阈值，停止重试")
                break
        else:
            print(f"  ✗ 质量未改善")

        attempt += 1

    return best_bbox, best_quality, attempt


def process_json_with_qwen_optimization_improved(json_anno, trace_segmentation_method, image_base_path, qwen_optimizer, 
                                       quality_threshold=0.6, max_retries=3, visualize=True, use_douglas_peucker=False, 
                                       dp_epsilon=0.01, output_dir="."):
    """
    Segment traces into phrase boxes and refine low-quality boxes with Qwen2.5-VL.

    For each phrase box whose quality score is below ``quality_threshold``,
    the optimizer first corrects the box. It then regenerates it up to
    ``max_retries`` more times, keeping the best-scoring candidate.

    Args:
        json_anno: Localized Narratives annotation dict (``image_id``,
            ``dataset_id``, ``caption``, ``traces``, ``timed_caption``).
        trace_segmentation_method: Segmentation method; only
            ``'phrases_word_level'`` is supported.
        image_base_path: Root directory of the COCO images.
        qwen_optimizer: Object providing ``calculate_bbox_quality_score`` and
            ``get_corrected_bbox``.
        quality_threshold: Boxes scoring below this are optimized.
        max_retries: Number of regeneration attempts after the first correction.
        visualize: If True, save a comparison figure and a JSON result file.
        use_douglas_peucker: Unused; kept for interface compatibility.
        dp_epsilon: Unused; kept for interface compatibility.
        output_dir: Directory for figures and JSON output (created if missing).

    Returns:
        dict: Tokens, original/optimized boxes, quality scores, statistics and
        the segmented traces. Returns None when the annotation has no trace points.

    Raises:
        ValueError: If ``trace_segmentation_method`` is not supported.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    print(f"\n--- Processing Image: {image_id} with Improved Qwen2.5-VL Optimization ---")
    print(f"Base method: {trace_segmentation_method}")
    print(f"Quality threshold: {quality_threshold}, Max retries: {max_retries}")

    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)

    if not xs:
        print("Error: Empty or invalid trace data.")
        return None

    # COCO layout: <image_base_path>/<split>/<12-digit id>.jpg
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)

    os.makedirs(output_dir, exist_ok=True)

    if trace_segmentation_method != 'phrases_word_level':
        raise ValueError(f"Unsupported trace_segmentation_method: {trace_segmentation_method}")

    phrases = split_caption_into_phrases(full_caption)
    print(f"拆分出 {len(phrases)} 个短语:")
    for i, phrase in enumerate(phrases, 1):
        print(f"  {i:2d}. {phrase}")

    segmented_xs, segmented_ys, _ = trace_segment_by_phrases_position_based(
        xs, ys, ts, phrases, transcription, time_begins, time_ends)
    tokens_for_vis = phrases

    xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)
    # (xmin, ymin, xmax, ymax) per phrase; zip avoids IndexError if lengths ever differ.
    original_bboxes = list(zip(xmins, ymins, xmaxs, ymaxs))

    print(f"Generated {len(xmins)} bounding boxes from {trace_segmentation_method}")

    quality_scores = []
    optimized_bboxes = []
    optimization_stats = {
        "total": 0,
        "optimized": 0,
        "quality_improved": 0,
        "regenerated": 0,
        "retries_used": 0,
        "max_retries_reached": 0
    }

    for i, token in enumerate(tokens_for_vis):
        if i >= len(xmins):
            break

        original_bbox = (xmins[i], ymins[i], xmaxs[i], ymaxs[i])
        optimization_stats["total"] += 1

        print(f"\n处理 {i+1}: '{token}'")
        print(f"原始边界框: {original_bbox}")

        original_quality_score = qwen_optimizer.calculate_bbox_quality_score(
            original_bbox, token, full_image_path)
        print(f"原始质量分数: {original_quality_score:.3f}")

        if original_quality_score >= quality_threshold:
            optimized_bboxes.append(original_bbox)
            quality_scores.append(original_quality_score)
            print(f"✓ 质量分数达标，无需优化")
            continue

        print(f"质量分数 {original_quality_score:.3f} 低于阈值 {quality_threshold}，开始优化...")
        optimization_stats["optimized"] += 1

        best_bbox, best_quality, attempts = _optimize_bbox_with_retries(
            qwen_optimizer, token, full_image_path, original_bbox,
            original_quality_score, quality_threshold, max_retries, optimization_stats)

        optimized_bboxes.append(best_bbox)
        quality_scores.append(best_quality)

        if best_quality > original_quality_score:
            optimization_stats["quality_improved"] += 1
            print(f"✓ 优化成功，最终质量: {original_quality_score:.3f} -> {best_quality:.3f}")
        else:
            print(f"✗ 优化失败，保持原始定位")

        if attempts > max_retries:
            optimization_stats["max_retries_reached"] += 1
            print(f"⚠ 达到最大重试次数 {max_retries}")

    print(f"\n=== 改进版Qwen2.5-VL优化统计 ===")
    print(f"总数: {optimization_stats['total']}")
    print(f"尝试优化数: {optimization_stats['optimized']}")
    print(f"成功改善数: {optimization_stats['quality_improved']}")
    print(f"使用重试数: {optimization_stats['retries_used']}")
    print(f"重新生成数: {optimization_stats['regenerated']}")
    print(f"达到最大重试数: {optimization_stats['max_retries_reached']}")
    if optimization_stats['optimized'] > 0:
        print(f"优化成功率: {optimization_stats['quality_improved']/optimization_stats['optimized']*100:.1f}%")
        print(f"重试使用率: {optimization_stats['retries_used']/optimization_stats['optimized']*100:.1f}%")
    else:
        print(f"优化成功率: 0%")
        print(f"重试使用率: 0%")

    results = {
        "image_id": image_id,
        "base_method": trace_segmentation_method,
        "tokens": tokens_for_vis,
        "original_bboxes": original_bboxes,
        "optimized_bboxes": optimized_bboxes,
        "quality_scores": quality_scores,
        "optimization_stats": optimization_stats,
        "segmented_xs": segmented_xs,
        "segmented_ys": segmented_ys
    }

    if visualize:
        visualize_qwen_optimization_comparison_improved(
            image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
            optimized_bboxes, tokens_for_vis, f"{trace_segmentation_method}_qwen_optimized_improved",
            image_base_path, dataset_id, full_caption, output_dir, quality_scores, optimization_stats)

        results_file = os.path.join(output_dir, f"qwen_optimization_results_improved_{image_id}.json")
        # Convert numpy types (including inside optimizer-returned boxes) to
        # builtins so json.dump does not fail.
        json_results = {
            "image_id": image_id,
            "base_method": trace_segmentation_method,
            "tokens": tokens_for_vis,
            "original_bboxes": _to_json_builtin(original_bboxes),
            "optimized_bboxes": _to_json_builtin(optimized_bboxes),
            "quality_scores": [float(score) for score in quality_scores],
            "optimization_stats": optimization_stats
        }
        with open(results_file, 'w', encoding='utf-8') as f:
            json.dump(json_results, f, ensure_ascii=False, indent=2)
        print(f"优化结果已保存到: {results_file}")

    return results


def visualize_qwen_optimization_comparison_improved(
    image_id, original_segmented_xs, original_segmented_ys, original_xmins, original_xmaxs, 
    original_ymins, original_ymaxs, optimized_bboxes, tokens, segmentation_method, 
    image_base_path, dataset_id, full_caption, output_dir=".", quality_scores=None, optimization_stats=None):
    """
    Save a side-by-side figure: trace-derived boxes (left) vs optimized boxes (right).

    Boxes are given in [0, 1] units and scaled to the image size. Original boxes
    come from the parallel ``original_x/ymins/maxs`` lists. Each optimized box is
    an ``(xmin, ymin, xmax, ymax)`` sequence, and ``None`` entries are skipped.
    Right-hand labels include the quality score when ``quality_scores`` is
    given. The title includes improvement and retry counts when
    ``optimization_stats`` is given. The source image is also saved as
    ``original_<id>.jpg``. If it cannot be read, a 640x480 white canvas is used.

    ``original_segmented_xs/ys`` and ``full_caption`` are accepted for interface
    compatibility but are not drawn.
    """
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)

    # Create the output directory up front: the comparison figure is saved
    # there even when the source image is missing.
    os.makedirs(output_dir, exist_ok=True)

    im = None
    try:
        # Convert to RGB so grayscale/palette images are not drawn through a
        # colormap and RGBA/P images can be saved as JPEG.
        with Image.open(full_image_path) as src:
            im = src.convert('RGB')
    except FileNotFoundError:
        print(f"Image not found at {full_image_path}, using a blank canvas instead.")
    except OSError as e:
        print(f"Could not read image at {full_image_path} ({e}), using a blank canvas instead.")

    if im is not None:
        imw, imh = im.size
        im_array = np.array(im)
        original_image_output_path = os.path.join(output_dir, f"original_{image_id}.jpg")
        im.save(original_image_output_path)
        print(f"Saved original image to: {original_image_output_path}")
    else:
        imw, imh = 640, 480
        im_array = np.ones((imh, imw, 3), dtype=np.uint8) * 255

    # One colour per box; tokens may be None, so do not rely on len(tokens) alone.
    num_colors = max(len(tokens or ()), len(original_xmins), len(optimized_bboxes), 1)
    colors = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, num_colors))

    def draw_box(ax, min_x, min_y, max_x, max_y, color):
        ax.add_patch(patches.Rectangle(
            (min_x, min_y), max_x - min_x, max_y - min_y,
            linewidth=3, edgecolor=color, facecolor='none', alpha=0.8))

    def draw_label(ax, x, y, text, color):
        ax.text(x, y - 10, text, fontsize=10, color=color, weight='bold',
                bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(24, 12))
    try:
        for ax in (ax1, ax2):
            ax.imshow(im_array)
            ax.set_xlim(0, imw)
            ax.set_ylim(imh, 0)

        # Left: trace-derived boxes.
        for i in range(len(original_xmins)):
            color = colors[i]
            min_x = original_xmins[i] * imw
            max_x = original_xmaxs[i] * imw
            min_y = original_ymins[i] * imh
            max_y = original_ymaxs[i] * imh
            draw_box(ax1, min_x, min_y, max_x, max_y, color)
            if tokens and i < len(tokens):
                draw_label(ax1, min_x, min_y, f"{i+1}: {tokens[i][:25]}...", color)

        ax1.set_title(f"Original {segmentation_method}\nImage {image_id}", fontsize=16)
        ax1.axis('off')

        # Right: optimized boxes, labelled with their quality scores.
        for i, bbox in enumerate(optimized_bboxes):
            if bbox is None:
                continue
            color = colors[i]
            min_x, min_y = bbox[0] * imw, bbox[1] * imh
            max_x, max_y = bbox[2] * imw, bbox[3] * imh
            draw_box(ax2, min_x, min_y, max_x, max_y, color)
            if tokens and i < len(tokens):
                label_text = f"{i+1}: {tokens[i][:20]}..."
                if quality_scores and i < len(quality_scores):
                    label_text += f" (Q:{quality_scores[i]:.2f})"
                draw_label(ax2, min_x, min_y, label_text, color)

        title = f"Improved Qwen2.5-VL Optimized {segmentation_method}\nImage {image_id}"
        if optimization_stats:
            title += f"\nOptimized: {optimization_stats.get('quality_improved', 0)}/{optimization_stats.get('optimized', 0)}"
            title += f" | Retries: {optimization_stats.get('retries_used', 0)}"
        ax2.set_title(title, fontsize=16)
        ax2.axis('off')

        comparison_filename = f"qwen_optimization_improved_{image_id}_{segmentation_method.replace(' ', '_').replace('/', '_')}.png"
        comparison_path = os.path.join(output_dir, comparison_filename)
        fig.savefig(comparison_path, bbox_inches='tight', dpi=300)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

    print(f"Saved improved Qwen optimization comparison to: {comparison_path}")


def _trace_segment_bbox(seg_xs, seg_ys):
    """
    Return the (xmin, ymin, xmax, ymax) extent of one trace segment.

    Used as a fallback box when Qwen optimization yields no box for a phrase.
    Returns None when the segment is missing or has no points.
    """
    if seg_xs is None or seg_ys is None or len(seg_xs) == 0 or len(seg_ys) == 0:
        return None
    return (min(seg_xs), min(seg_ys), max(seg_xs), max(seg_ys))


def process_single_ln_data(json_anno, coco_images_path, output_dir, qwen_optimizer=None, visualize=True, use_qwen_optimization=False):
    """
    Process one Localized Narratives record into phrases with boxes and traces.

    The caption is split into phrases, each phrase gets a bounding box and its
    trace points, and the result is optionally visualized. Qwen2.5-VL box
    optimization is used only when ``use_qwen_optimization`` is True AND
    ``qwen_optimizer`` is not None. Otherwise the caption is split with
    split_caption_into_phrases and boxes come from the segmented trace.

    Args:
        json_anno: One LN annotation dict. It must contain 'image_id',
            'dataset_id' and 'caption', plus whatever get_json_anno_external reads.
        coco_images_path: Root of the COCO images. The split sub-directory is
            the second '_'-separated field of 'dataset_id'.
        output_dir: Directory passed through to the optimizer and visualizer.
        qwen_optimizer: Optional QwenVLLocationCorrector instance.
        visualize: Whether to call visualize_trajectory_data.
        use_qwen_optimization: Request the Qwen-optimized path.

    Returns:
        A result dict with image info, per-phrase bbox/trace data,
        visualization files, and whether Qwen optimization was actually used.
        Returns None when the record has no trace, yields no phrases, or the
        optimization returns None.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    # Extract the raw trace data.
    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)

    if not xs:
        return None

    # Build the image path: <root>/<split from dataset_id>/<12-digit id>.jpg
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(coco_images_path, split_name, image_filename)

    # Qwen optimization only happens when an optimizer is actually available.
    # This flag is also what gets reported in the result.
    qwen_enabled = bool(use_qwen_optimization and qwen_optimizer is not None)

    if qwen_enabled:
        print(f"使用Qwen2.5-VL优化处理图像 {image_id}...")

        optimization_results = process_json_with_qwen_optimization_improved(
            json_anno,
            trace_segmentation_method='phrases_word_level',
            image_base_path=coco_images_path,
            qwen_optimizer=qwen_optimizer,
            quality_threshold=0.6,
            max_retries=3,
            visualize=False,  # visualization is handled below
            use_douglas_peucker=False,
            dp_epsilon=0.01,
            output_dir=output_dir
        )

        if optimization_results is None:
            return None

        phrases, segmented_xs, segmented_ys = [], [], []
        xmins, ymins, xmaxs, ymaxs = [], [], [], []
        # zip stops at the shortest list, so only phrases that have a box entry
        # (and a trace entry) are kept, and all parallel lists stay aligned.
        for phrase, seg_x, seg_y, bbox in zip(
                optimization_results['tokens'],
                optimization_results['segmented_xs'],
                optimization_results['segmented_ys'],
                optimization_results['optimized_bboxes']):
            if bbox is None:
                # The optimizer left this phrase without a box. Fall back to the
                # extent of its raw trace, and drop the phrase if that is empty too.
                bbox = _trace_segment_bbox(seg_x, seg_y)
                if bbox is None:
                    continue
            phrases.append(phrase)
            segmented_xs.append(seg_x)
            segmented_ys.append(seg_y)
            xmins.append(bbox[0])
            ymins.append(bbox[1])
            xmaxs.append(bbox[2])
            ymaxs.append(bbox[3])

    else:
        # Original method: split the caption into phrases.
        phrases = split_caption_into_phrases(full_caption)

        if not phrases:
            return None

        # Segment the trace by phrase position.
        segmented_xs, segmented_ys, segmented_ts = trace_segment_by_phrases_position_based(
            xs, ys, ts, phrases, transcription, time_begins, time_ends
        )

        # Convert each trace segment to a bounding box (note the return order).
        xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)

    # Both paths use the same visualization call.
    visualization_files = None
    if visualize:
        visualization_files = visualize_trajectory_data(
            image_id, full_image_path, phrases, segmented_xs, segmented_ys,
            xmins, xmaxs, ymins, ymaxs, full_caption, output_dir
        )

    phrases_data = []
    for phrase, seg_x, seg_y, xmin, ymin, xmax, ymax in zip(
            phrases, segmented_xs, segmented_ys, xmins, ymins, xmaxs, ymaxs):
        # Explicit None check instead of truthiness, which is ambiguous for
        # array-like segments.
        has_trace = seg_x is not None and seg_y is not None
        phrases_data.append({
            "phrase": phrase,
            "bbox": (xmin, ymin, xmax, ymax),  # (xmin, ymin, xmax, ymax), normalized coordinates
            "trace_points": list(zip(seg_x, seg_y)) if has_trace else []
        })

    return {
        "image_id": image_id,
        "image_filename": image_filename,
        "image_path": full_image_path,
        "full_caption": full_caption,
        "phrases_data": phrases_data,
        "visualization_files": visualization_files,
        "qwen_optimized": qwen_enabled
    }


def build_small_targets_trajectory_dataset(ln_base_path, small_targets_json_path, 
                                         coco_images_path, output_path, 
                                         max_samples=None, visualize=True, 
                                         use_qwen_optimization=False, qwen_model_path=None):
    """
    Build the small-target trajectory dataset, with optional visualization.

    Loads the small-target image IDs and filters the Localized Narratives
    records to those images. Each record is processed with
    process_single_ln_data, and the collected results plus metadata are
    written to ``output_path`` as JSON.

    Args:
        ln_base_path: Directory containing the LN jsonl shards.
        small_targets_json_path: JSON file describing the small targets.
        coco_images_path: Root directory of the COCO images.
        output_path: Destination JSON file. Visualizations go to
            "<output_path without extension>_visualizations".
        max_samples: If truthy, only the first ``max_samples`` records are processed.
        visualize: Whether to save visualization images.
        use_qwen_optimization: Request Qwen2.5-VL box optimization.
        qwen_model_path: Model path for QwenVLLocationCorrector. Qwen
            optimization is disabled when this is missing or the model fails
            to load.

    Returns:
        The dataset dict that was written to ``output_path``.
    """
    # Qwen optimization cannot run without a model. Disable it up front so the
    # banner, the per-sample "qwen_optimized" flags and the metadata agree.
    if use_qwen_optimization and not qwen_model_path:
        print("未提供Qwen2.5-VL模型路径，将使用原始方法处理")
        use_qwen_optimization = False

    print("="*60)
    print("开始构建小目标轨迹数据集（带可视化）")
    if use_qwen_optimization:
        print("启用Qwen2.5-VL优化")
    print("="*60)
    
    # Initialize the Qwen optimizer if requested.
    qwen_optimizer = None
    if use_qwen_optimization:
        print("正在初始化Qwen2.5-VL优化器...")
        try:
            qwen_optimizer = QwenVLLocationCorrector(qwen_model_path, num_gpus=1)
            print("Qwen2.5-VL优化器初始化成功")
        except Exception as e:
            # Best-effort: fall back to the trace-only method instead of aborting.
            print(f"Qwen2.5-VL优化器初始化失败: {e}")
            print("将使用原始方法处理")
            use_qwen_optimization = False
    
    # 1. Load the small-target data.
    small_target_image_ids, small_targets_data = load_small_targets_data(small_targets_json_path)
    
    # 2. Load the LN records, keeping only images that contain small targets.
    filtered_ln_data = load_ln_jsonl_files(ln_base_path, small_target_image_ids)
    
    if max_samples:
        filtered_ln_data = filtered_ln_data[:max_samples]
        print(f"限制处理样本数: {max_samples}")
    
    # 3. Create the visualization output directory.
    base_output_dir = os.path.splitext(output_path)[0] + "_visualizations"
    if visualize:
        os.makedirs(base_output_dir, exist_ok=True)
        print(f"可视化文件将保存到: {base_output_dir}")
    
    # 4. Process each LN record.
    trajectory_dataset = {
        "metadata": {
            "source_files": [
                "coco_train_localized_narratives-00000-of-00004.jsonl",
                "coco_train_localized_narratives-00001-of-00004.jsonl",
                "coco_train_localized_narratives-00002-of-00004.jsonl",
                "coco_train_localized_narratives-00003-of-00004.jsonl"
            ],
            "small_targets_source": small_targets_json_path,
            "total_samples": len(filtered_ln_data),
            "visualization_enabled": visualize,
            "visualization_directory": base_output_dir if visualize else None,
            # Effective setting, after any fallback above.
            "qwen_optimization_enabled": use_qwen_optimization,
            "creation_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        },
        "data": []
    }
    
    print(f"开始处理 {len(filtered_ln_data)} 个样本...")
    
    successful_count = 0
    failed_count = 0
    
    for i, ln_data in enumerate(tqdm(filtered_ln_data, desc="处理轨迹数据")):
        try:
            result = process_single_ln_data(
                ln_data, coco_images_path, base_output_dir, 
                qwen_optimizer=qwen_optimizer, 
                visualize=visualize, 
                use_qwen_optimization=use_qwen_optimization
            )
        except Exception as e:
            # One bad sample must not abort the whole build; count it as failed.
            print(f"处理第 {i+1} 个样本时出错: {e}")
            failed_count += 1
            continue

        if result is not None:
            trajectory_dataset["data"].append(result)
            successful_count += 1
        else:
            failed_count += 1
    
    # 5. Report and save the results.
    print(f"\n处理完成:")
    print(f"  成功: {successful_count}")
    print(f"  失败: {failed_count}")
    print(f"  总计: {len(filtered_ln_data)}")
    
    trajectory_dataset["metadata"]["successful_samples"] = successful_count
    trajectory_dataset["metadata"]["failed_samples"] = failed_count
    
    total_phrases = sum(len(item["phrases_data"]) for item in trajectory_dataset["data"])
    trajectory_dataset["metadata"]["total_phrases"] = total_phrases
    
    print(f"  总短语数: {total_phrases}")
    
    print(f"\n保存结果到: {output_path}")
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(trajectory_dataset, f, ensure_ascii=False, indent=2)
    
    print("数据集构建完成！")
    
    return trajectory_dataset


if __name__ == '__main__':
    # Input locations (cluster-specific absolute paths).
    LN_BASE_PATH = '/storage-root/datasets/yangfan/Seg_LLaVA_v2/datasets/Localized_Narratives'
    SMALL_TARGETS_JSON_PATH = '/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/coco_bbox_analysis_20250824_232913/small_targets_details.json'
    COCO_IMAGES_PATH = '/storage-root/datasets/yangfan/coco2017'
    
    # Timestamped output file in the current working directory.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = f"small_targets_trajectory_with_vis_{timestamp}.json"
    
    try:
        # Optional model: when the directory is missing, qwen_model_path is None
        # and the builder falls back to the trace-only method.
        QWEN_MODEL_PATH = "/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct"
        
        trajectory_dataset = build_small_targets_trajectory_dataset(
            ln_base_path=LN_BASE_PATH,
            small_targets_json_path=SMALL_TARGETS_JSON_PATH,
            coco_images_path=COCO_IMAGES_PATH,
            output_path=output_path,
            max_samples=20,  # small sample count for testing
            visualize=True,
            use_qwen_optimization=True,
            qwen_model_path=QWEN_MODEL_PATH if os.path.exists(QWEN_MODEL_PATH) else None
        )
        
        # Read the directory the builder actually used (None if visualization is off).
        vis_dir = trajectory_dataset['metadata'].get('visualization_directory')
        
        print(f"\n数据集构建成功！")
        print(f"JSON文件: {os.path.abspath(output_path)}")
        print(f"可视化目录: {os.path.abspath(vis_dir) if vis_dir else None}")
        print(f"总样本数: {len(trajectory_dataset['data'])}")
        print(f"总短语数: {trajectory_dataset['metadata']['total_phrases']}")
        
        # Print a short preview of the first sample.
        if trajectory_dataset['data']:
            print(f"\n示例数据:")
            sample = trajectory_dataset['data'][0]
            print(f"图像ID: {sample['image_id']}")
            print(f"完整描述: {sample['full_caption']}")
            print(f"短语数量: {len(sample['phrases_data'])}")
            for i, phrase_data in enumerate(sample['phrases_data'][:3]):  # first 3 only
                print(f"  短语{i+1}: {phrase_data['phrase']}")
                print(f"    边界框: {phrase_data['bbox']}")
                print(f"    轨迹点数: {len(phrase_data['trace_points'])}")
            
            if sample['visualization_files']:
                print(f"可视化文件:")
                for key, file_path in sample['visualization_files'].items():
                    if file_path:
                        print(f"  {key}: {os.path.basename(file_path)}")
        
    except Exception as e:
        print(f"构建数据集时出错: {e}")
        import traceback
        traceback.print_exc()
        # Signal failure to the shell / job scheduler instead of exiting with 0.
        raise SystemExit(1)