# -*- coding: utf-8 -*-
"""
改进的COCO Localized Narratives处理代码
支持基于短语的更精确的轨迹分割
"""
import json
import collections
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
from rdp import rdp
import json
import re
from difflib import SequenceMatcher
import datetime


def load_jsonl_by_index(jsonl_path, selected_index=1):
    """
    Return the JSON record stored on one line of a JSONL file.

    Args:
        jsonl_path: Path of the JSONL file; it is opened as UTF-8 text.
        selected_index: 1-based line number of the record to decode.

    Returns:
        The object produced by ``json.loads`` for the requested line.

    Raises:
        IndexError: If the file ends before ``selected_index`` is reached.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as handle:
        # Stream the file and stop at the first line whose 1-based position
        # matches; lines yielded by a file object are never None, so None is
        # a safe "not found" marker.
        wanted_line = next(
            (raw for position, raw in enumerate(handle, 1) if position == selected_index),
            None,
        )

    if wanted_line is None:
        raise IndexError(f"JSONL文件中没有第{selected_index}条数据")
    return json.loads(wanted_line)


def visualize_all_boxes(
    image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
    tokens, segmentation_method, image_base_path, dataset_id, full_caption, output_dir="."):
    """
    Render and save the visualizations for one image's segmented trace.

    Files written to ``output_dir``: a copy of the source image (only when it
    is found on disk), a trace-line figure, a bounding-box figure and a
    combined figure (traces + dashed boxes + a colored caption strip).

    Args:
        image_id: Image identifier; converted with ``int()`` and zero-padded
            to 12 digits to build the ``.jpg`` file name.
        segmented_xs: Per-segment lists of x coordinates. They are multiplied
            by the image width before drawing.
        segmented_ys: Per-segment lists of y coordinates. They are multiplied
            by the image height before drawing.
        xmins: Only its length is used (the number of segments to draw).
        xmaxs, ymins, ymaxs: Accepted for interface compatibility; not read.
        tokens: Optional per-segment labels; label ``i`` is drawn for segment
            ``i`` when it exists.
        segmentation_method: Name shown in figure titles and embedded in the
            output file names (spaces and ``/`` become ``_``).
        image_base_path: Root directory under which the image is looked up.
        dataset_id: Its second ``_``-separated field names the sub-directory
            of ``image_base_path`` that holds the image.
        full_caption: Accepted for interface compatibility; not read.
        output_dir: Destination directory; created if it does not exist.
    """
    # Construct the path to the actual image file.
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, split_name, image_filename)

    # Create the output directory before touching the image so the figures can
    # still be saved when the image is missing and the blank canvas is used.
    os.makedirs(output_dir, exist_ok=True)

    try:
        # The context manager releases the underlying file handle as soon as
        # the pixel data has been copied out and the original re-saved.
        with Image.open(full_image_path) as im:
            imw, imh = im.size
            # Convert PIL image to numpy array for matplotlib
            im_array = np.array(im)
            # Save the original image to the output directory
            original_image_output_path = os.path.join(output_dir, f"original_{image_id}.jpg")
            im.save(original_image_output_path)
        print(f"Saved original image to: {original_image_output_path}")
    except FileNotFoundError:
        print(f"Image not found at {full_image_path}, using a blank canvas instead.")
        # Fallback to a blank canvas if image is not found.
        imw, imh = 640, 480
        im_array = np.ones((imh, imw, 3), dtype=np.uint8) * 255  # White background

    # Use a colormap to get distinct colors for each segment
    colors = plt.get_cmap('gist_rainbow')(np.linspace(0, 1, len(xmins)))

    # File-name-safe version of the method name, shared by all three figures.
    method_slug = segmentation_method.replace(' ', '_').replace('/', '_')

    # 1. Figure with the trace lines
    fig1, ax1 = plt.subplots(figsize=(12, 9))
    ax1.imshow(im_array)
    ax1.set_xlim(0, imw)
    ax1.set_ylim(imh, 0)  # Flip y-axis to match image coordinates

    for i in range(len(xmins)):
        segment_color = colors[i]

        # Plot the trace segment with the corresponding color
        if segmented_xs[i]:  # Check if segment is not empty
            seg_xs_scaled = [x * imw for x in segmented_xs[i]]
            seg_ys_scaled = [y * imh for y in segmented_ys[i]]
            ax1.plot(seg_xs_scaled, seg_ys_scaled, linewidth=3, color=segment_color, alpha=0.8)

    ax1.set_title(f"Trace Lines for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)
    ax1.axis('off')

    # Save trace lines image
    trace_filename = f"traces_{image_id}_{method_slug}.png"
    plt.savefig(os.path.join(output_dir, trace_filename), bbox_inches='tight', dpi=300)
    plt.close(fig1)
    print(f"Saved trace lines image to: {os.path.join(output_dir, trace_filename)}")

    # 2. Figure with the bounding boxes
    fig2, ax2 = plt.subplots(figsize=(12, 9))
    ax2.imshow(im_array)
    ax2.set_xlim(0, imw)
    ax2.set_ylim(imh, 0)  # Flip y-axis to match image coordinates

    for i in range(len(xmins)):
        segment_color = colors[i]

        # Build a bounding box around the trace points of each segment
        if segmented_xs[i] and segmented_ys[i]:  # Check if segment is not empty
            # Extent of the trace points, in pixels
            min_x = min(segmented_xs[i]) * imw
            max_x = max(segmented_xs[i]) * imw
            min_y = min(segmented_ys[i]) * imh
            max_y = max(segmented_ys[i]) * imh

            # Rectangle outline
            rect = patches.Rectangle(
                (min_x, min_y),
                max_x - min_x,
                max_y - min_y,
                linewidth=3,
                edgecolor=segment_color,
                facecolor='none',
                alpha=0.8
            )
            ax2.add_patch(rect)

            # Label
            if tokens and i < len(tokens):
                # Put the (truncated) label just above the box's top-left corner
                ax2.text(min_x, min_y - 10, f"{i+1}: {tokens[i][:30]}...",
                        fontsize=10, color=segment_color, weight='bold',
                        bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=2))

    ax2.set_title(f"Bounding Boxes for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)
    ax2.axis('off')

    # Save bounding boxes image
    bbox_filename = f"bboxes_{image_id}_{method_slug}.png"
    plt.savefig(os.path.join(output_dir, bbox_filename), bbox_inches='tight', dpi=300)
    plt.close(fig2)
    print(f"Saved bounding boxes image to: {os.path.join(output_dir, bbox_filename)}")

    # 3. Combined figure: trace lines and bounding boxes together
    fig3, ax3 = plt.subplots(figsize=(12, 9))
    ax3.imshow(im_array)
    ax3.set_xlim(0, imw)
    ax3.set_ylim(imh, 0)

    for i in range(len(xmins)):
        segment_color = colors[i]

        if segmented_xs[i] and segmented_ys[i]:
            # Trace line
            seg_xs_scaled = [x * imw for x in segmented_xs[i]]
            seg_ys_scaled = [y * imh for y in segmented_ys[i]]
            ax3.plot(seg_xs_scaled, seg_ys_scaled, linewidth=2, color=segment_color, alpha=0.6)

            # Bounding box (dashed so it stays distinguishable from the trace)
            min_x = min(segmented_xs[i]) * imw
            max_x = max(segmented_xs[i]) * imw
            min_y = min(segmented_ys[i]) * imh
            max_y = max(segmented_ys[i]) * imh

            rect = patches.Rectangle(
                (min_x, min_y),
                max_x - min_x,
                max_y - min_y,
                linewidth=2,
                edgecolor=segment_color,
                facecolor='none',
                alpha=0.8,
                linestyle='--'
            )
            ax3.add_patch(rect)

    ax3.set_title(f"Combined Visualization for Image {image_id}\nMethod: {segmentation_method}", fontsize=16)

    # Add caption with colors
    if tokens:
        fig_width = fig3.get_window_extent().width
        y_pos = 0.03

        # Measure the rendered width of every token with throw-away text
        # objects so the caption strip can be centered horizontally.
        renderer = fig3.canvas.get_renderer()
        text_objs = [plt.text(0, 0, f"{t} ", fontsize=10, ha='left', va='bottom') for t in tokens]
        total_width_pixels = sum(t.get_window_extent(renderer).width for t in text_objs)
        for t in text_objs:
            t.remove()

        start_x = 0.5 - (total_width_pixels / fig_width) / 2
        current_x = start_x

        # Lay the tokens out left to right in figure coordinates, each in the
        # color of its segment; tokens without a matching color are skipped.
        for i, token in enumerate(tokens):
            if i < len(colors):
                color = colors[i]
                txt_obj = plt.text(current_x, y_pos, f"{token} ", transform=fig3.transFigure,
                                 ha='left', va='bottom', fontsize=10, color=color,
                                 bbox=dict(facecolor='white', alpha=0.8, edgecolor='none', pad=1))
                current_x += txt_obj.get_window_extent(renderer).width / fig_width

    plt.subplots_adjust(bottom=0.15)
    ax3.axis('off')

    # Save combined image
    combined_filename = f"combined_{image_id}_{method_slug}.png"
    plt.savefig(os.path.join(output_dir, combined_filename), bbox_inches='tight', dpi=300)
    plt.close(fig3)
    print(f"Saved combined image to: {os.path.join(output_dir, combined_filename)}")

    print(f"All visualizations saved to: {os.path.abspath(output_dir)}")


def get_json_anno_external(json_anno):
    """Extract the flattened trace and the timed caption from an annotation dict.

    All points of every entry in ``json_anno['traces']`` are concatenated in
    order. The ``x`` and ``y`` values are clamped to ``[0.0, 1.0]``; ``t`` is
    passed through unchanged.

    Returns:
        ``(xs, ys, ts, toks, time_begins, time_ends)``, where the last three
        lists come from the ``utterance``, ``start_time`` and ``end_time``
        fields of ``json_anno['timed_caption']``. When there is no trace
        point at all, a tuple of six ``None`` values is returned and
        ``timed_caption`` is not read.
    """
    points = []
    for stroke in json_anno['traces']:
        points.extend(stroke)

    if not points:
        return (None,) * 6

    def clamp_unit(value):
        # Keep a coordinate inside the unit interval.
        return max(0.0, min(1.0, value))

    xs = list(map(clamp_unit, (point['x'] for point in points)))
    ys = list(map(clamp_unit, (point['y'] for point in points)))
    ts = [point['t'] for point in points]

    toks, time_begins, time_ends = [], [], []
    for entry in json_anno['timed_caption']:
        toks.append(entry['utterance'])
    for entry in json_anno['timed_caption']:
        time_begins.append(entry['start_time'])
    for entry in json_anno['timed_caption']:
        time_ends.append(entry['end_time'])

    return xs, ys, ts, toks, time_begins, time_ends


def trace_segment_uniform_time_interval(
    trace_xs, trace_ys, trace_ts, time_interval, use_douglas_peucker=False, dp_epsilon=0.01):
    """Segment a trace into consecutive bins of ``time_interval`` duration.

    Point ``j`` goes to bin ``int((trace_ts[j] - trace_ts[0]) / time_interval)``.
    One segment is emitted for every bin from 0 up to the largest bin that
    contains a point, so bins without points yield empty lists. Points that
    land in a negative bin (more than one interval earlier than the first
    timestamp) are not assigned to any segment.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel sequences of point coordinates
            and timestamps.
        time_interval: Bin width, in the same unit as ``trace_ts``; must be
            positive.
        use_douglas_peucker: When true, segments with more than two points
            have their x/y coordinates simplified with ``rdp``.
        dp_epsilon: Tolerance passed to ``rdp``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)`` — three lists holding
        one list per bin. ``segmented_ts`` always holds the unsimplified
        timestamps, so it can be longer than the simplified x/y lists.

    Raises:
        ValueError: If ``time_interval`` is not positive.
    """
    if trace_ts is None or len(trace_ts) == 0:
        return [], [], []
    if time_interval <= 0:
        # A zero or negative width makes the bin indices meaningless.
        raise ValueError("time_interval must be positive")

    tbins = ((np.array(trace_ts) - trace_ts[0]) / time_interval).astype(int)

    # Size the output from the largest bin rather than the last point's bin,
    # so that no point is dropped when timestamps are not strictly ordered.
    num_segments = int(tbins.max()) + 1

    # Group point indices by bin in a single pass over the trace.
    bin_members = [[] for _ in range(num_segments)]
    for point_idx, bin_idx in enumerate(tbins.tolist()):
        if bin_idx >= 0:
            bin_members[bin_idx].append(point_idx)

    segmented_xs, segmented_ys, segmented_ts = [], [], []
    for members in bin_members:
        xs = [trace_xs[j] for j in members]
        ys = [trace_ys[j] for j in members]
        ts = [trace_ts[j] for j in members]

        # Apply Douglas-Peucker simplification if requested
        if use_douglas_peucker and len(xs) > 2:
            points = np.array(list(zip(xs, ys)))
            simplified_points = rdp(points, epsilon=dp_epsilon)
            xs = simplified_points[:, 0].tolist()
            ys = simplified_points[:, 1].tolist()

        segmented_xs.append(xs)
        segmented_ys.append(ys)
        segmented_ts.append(ts)

    return segmented_xs, segmented_ys, segmented_ts


def trace_segment_timestamp(trace_xs, trace_ys, trace_ts,
                            token_time_begins, token_time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """Cut a trace into one segment per timestamped token.

    For every ``(begin, end)`` pair, the points whose timestamp lies in the
    closed interval ``[begin, end]`` form one segment. A token without any
    point yields empty lists. With ``use_douglas_peucker`` set, segments of
    more than two points have their x/y coordinates simplified by ``rdp``
    (tolerance ``dp_epsilon``); the returned timestamps are never simplified.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)``, each a list with one
        entry per token window.
    """
    timestamps = np.array(trace_ts)
    out_xs, out_ys, out_ts = [], [], []

    for window_start, window_end in zip(token_time_begins, token_time_ends):
        # Positions of the trace points that fall inside this token's window.
        in_window = (timestamps >= window_start) & (timestamps <= window_end)
        hits = np.where(in_window)[0]

        if len(hits) == 0:
            out_xs.append([])
            out_ys.append([])
            out_ts.append([])
            continue

        chunk_xs = [trace_xs[k] for k in hits]
        chunk_ys = [trace_ys[k] for k in hits]
        chunk_ts = [trace_ts[k] for k in hits]

        # Optional Douglas-Peucker simplification of the x/y polyline.
        if use_douglas_peucker and len(chunk_xs) > 2:
            polyline = np.array(list(zip(chunk_xs, chunk_ys)))
            reduced = rdp(polyline, epsilon=dp_epsilon)
            chunk_xs = reduced[:, 0].tolist()
            chunk_ys = reduced[:, 1].tolist()

        out_xs.append(chunk_xs)
        out_ys.append(chunk_ys)
        out_ts.append(chunk_ts)

    return out_xs, out_ys, out_ts

def split_caption_into_phrases(caption):
    """
    Break a full caption into short phrases.

    The caption is processed in four stages:

    1. Split into sentences on ``.``, ``!`` and ``?``, then on commas; a
       comma piece with fewer than three words is glued back (with ``", "``)
       onto the piece before it.
    2. Split in front of a fixed list of connector expressions ("There is",
       "In the", ...), keeping the connector at the start of the new phrase.
    3. Split on whitespace-delimited "and"; a piece after "and" with fewer
       than three words is re-attached to the previous phrase with " and ".
    4. Drop phrases with fewer than two words.

    Args:
        caption (str): The complete image description.

    Returns:
        list: The phrases, in caption order.
    """
    # Collapse runs of whitespace into single spaces.
    text = re.sub(r'\s+', ' ', caption.strip())

    # Stage 1: sentences, then comma pieces with short pieces merged back.
    comma_phrases = []
    for sentence in re.split(r'[.!?]+', text):
        merged = []
        for piece in sentence.split(','):
            piece = piece.strip()
            if not piece:
                continue
            if merged and len(piece.split()) < 3:
                merged[-1] = merged[-1] + ", " + piece
            else:
                merged.append(piece)
        comma_phrases.extend(merged)

    # Stage 2: start a new phrase at each connector expression.
    connector_sources = [
        r'\bThere is\b', r'\bThere are\b', r'\bWe can observe\b',
        r'\bIn the\b', r'\bOn the\b', r'\bAt the\b',
        r'\bWhich is\b', r'\bThat is\b'
    ]
    connector_regexes = [
        (re.compile(f'({source})', re.IGNORECASE), re.compile(source, re.IGNORECASE))
        for source in connector_sources
    ]

    connector_phrases = []
    for phrase in comma_phrases:
        pieces = [phrase]
        for splitter, detector in connector_regexes:
            regrouped = []
            for piece in pieces:
                buffer = ""
                # The capturing group keeps the connector itself in the
                # split result, so it can open the next phrase.
                for fragment in splitter.split(piece):
                    if detector.match(fragment):
                        if buffer.strip():
                            regrouped.append(buffer.strip())
                        buffer = fragment
                    else:
                        buffer += fragment
                if buffer.strip():
                    regrouped.append(buffer.strip())
            pieces = regrouped
        connector_phrases.extend(pieces)

    # Stage 3: split on " and ", re-attaching short trailing pieces.
    and_phrases = []
    for phrase in connector_phrases:
        halves = re.split(r'\s+and\s+', phrase, flags=re.IGNORECASE)
        if len(halves) == 1:
            whole = phrase.strip()
            if whole:
                and_phrases.append(whole)
            continue

        for position, half in enumerate(halves):
            half = half.strip()
            if not half:
                continue
            short_tail = position != 0 and len(half.split()) < 3
            if short_tail and and_phrases:
                and_phrases[-1] += " and " + half
            else:
                and_phrases.append(half)

    # Stage 4: keep only phrases with at least two words.
    stripped = (phrase.strip() for phrase in and_phrases)
    return [phrase for phrase in stripped if phrase and len(phrase.split()) >= 2]

def _split_and_attach_separator(text, pattern):
    """Split *text* on *pattern* and prefix each piece with its word separator.

    ``pattern`` must contain exactly one capturing group so that ``re.split``
    returns texts and separators alternately. Punctuation-only separators
    (commas) are dropped instead of being prefixed; for a ", and" separator
    only the word "and" is kept.
    """
    segments = re.split(pattern, text, flags=re.IGNORECASE)

    pieces = []
    head = segments[0].strip()
    if head:
        pieces.append(head)

    for separator, following in zip(segments[1::2], segments[2::2]):
        following = following.strip()
        if not following:
            # Nothing follows the separator, so there is no phrase to emit.
            continue
        # Strip the comma so a phrase never starts with stray punctuation.
        separator = separator.strip().lstrip(',').strip()
        pieces.append(f"{separator} {following}" if separator else following)

    return pieces


def split_caption_into_phrases_improved(caption):
    """
    Split a caption into phrases using conjunction, clause and preposition cues.

    Each sentence (delimited by ``.``, ``!`` or ``?``) is split successively
    on a fixed list of patterns ("and", commas, "which", "in the", ...). A
    word separator is kept at the start of the phrase it introduces, while a
    comma is discarded. Within a sentence, a piece with fewer than three
    words is merged (with ``", "``) into the previous piece when the merged
    phrase has at most eight words. Finally, phrases with fewer than two
    words are dropped and a leading "and"/"or"/"but" is removed.

    Args:
        caption (str): The complete image description.

    Returns:
        list: The phrases, in caption order.
    """
    # Collapse runs of whitespace into single spaces.
    caption = re.sub(r'\s+', ' ', caption.strip())

    # Each pattern has a single capturing group so the separator is kept.
    split_patterns = [
        r'(\s+and\s+)',           # "and" conjunction
        r'(\s*,\s*and\s+)',       # ", and" conjunction
        r'(\s*,\s*)',             # comma
        r'(\s+which\s+)',         # "which" clause
        r'(\s+that\s+)',          # "that" clause
        r'(\s+where\s+)',         # "where" clause
        r'(\s+while\s+)',         # "while" clause
        r'(\s+in\s+the\s+)',      # "in the" prepositional phrase
        r'(\s+on\s+the\s+)',      # "on the" prepositional phrase
        r'(\s+at\s+the\s+)',      # "at the" prepositional phrase
        r'(\s+behind\s+)',        # "behind" preposition
        r'(\s+there\s+is\s+)',    # "there is" existential
        r'(\s+there\s+are\s+)',   # "there are" existential
    ]

    phrases = []

    # Step 1: split on strong delimiters (period, exclamation, question mark).
    for sentence in re.split(r'[.!?]+', caption):
        sentence = sentence.strip()
        if not sentence:
            continue

        # Step 2: apply the split rules one after another.
        current_parts = [sentence]
        for pattern in split_patterns:
            current_parts = [
                piece
                for part in current_parts
                for piece in _split_and_attach_separator(part, pattern)
            ]

        # Step 3: merge pieces that are too short into their predecessor,
        # unless the merged phrase would exceed eight words.
        final_parts = []
        for part in current_parts:
            words_in_part = len(part.split())
            mergeable = (
                words_in_part < 3
                and final_parts
                and len(final_parts[-1].split()) + words_in_part <= 8
            )
            if mergeable:
                final_parts[-1] += ", " + part
            else:
                final_parts.append(part)

        phrases.extend(final_parts)

    # Final cleanup: drop one-word phrases, then remove a leading conjunction.
    cleaned_phrases = []
    for phrase in phrases:
        if len(phrase.split()) < 2:
            continue
        phrase = re.sub(r'^(and|or|but)\s+', '', phrase, flags=re.IGNORECASE)
        if phrase:
            cleaned_phrases.append(phrase)

    return cleaned_phrases


def trace_segment_by_phrases_word_level(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """
    Segment a trace by phrases, matching phrases to tokens at word level.

    A token is assigned to a phrase when the two share at least one word and
    the shared words make up at least 30% of the token's distinct words or at
    least 20% of the phrase's distinct words. If no token matches that way, a
    substring test on the cleaned texts (minimum length 3) is used instead.
    The trace points whose timestamp lies inside the closed time window of
    any matched token form the phrase's segment. Points are de-duplicated and
    ordered by timestamp; points sharing a timestamp keep their trace order.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of point coordinates and
            timestamps.
        phrases: Phrases to locate; one output segment per phrase.
        original_tokens: Token strings of the timed caption.
        time_begins, time_ends: Start/end time of each token. Tokens whose
            index is beyond either list are ignored.
        use_douglas_peucker: When true, segments with more than two points
            have their x/y coordinates simplified with ``rdp``.
        dp_epsilon: Tolerance passed to ``rdp``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)``, each with one list
        per phrase (empty when nothing matched). ``segmented_ts`` is never
        simplified. Three empty lists are returned when ``trace_ts`` or
        ``phrases`` is empty.
    """
    if not trace_ts or not phrases:
        return [], [], []

    def clean_and_tokenize(text):
        # Lower-case, turn punctuation into spaces, split on whitespace.
        return re.sub(r'[^\w\s]', ' ', text.lower()).split()

    segmented_xs, segmented_ys, segmented_ts = [], [], []

    # Everything that does not depend on the current phrase is computed once.
    phrase_words = [clean_and_tokenize(phrase) for phrase in phrases]
    token_words = [clean_and_tokenize(token) for token in original_tokens]
    token_word_sets = [set(words) for words in token_words]
    token_texts = [' '.join(words) for words in token_words]
    trace_arr = np.array(trace_ts)
    timed_token_count = min(len(time_begins), len(time_ends))

    print(f"预处理后的短语词列表:")
    for i, words in enumerate(phrase_words):
        print(f"  短语 {i+1}: {words}")

    print(f"预处理后的token词列表:")
    for i, words in enumerate(token_words):
        print(f"  Token {i+1}: {words}")

    # Find the matching tokens for every phrase.
    for phrase_idx, phrase_word_list in enumerate(phrase_words):
        print(f"\n处理短语 {phrase_idx + 1}: '{phrases[phrase_idx]}'")
        print(f"短语词列表: {phrase_word_list}")

        phrase_word_set = set(phrase_word_list)
        matched_token_indices = []

        # Primary rule: word overlap between the phrase and each token.
        for token_idx, token_word_set in enumerate(token_word_sets):
            if not token_word_set:  # skip tokens with no words
                continue

            overlap = phrase_word_set & token_word_set
            if not overlap:
                continue

            # A non-empty overlap guarantees both sets are non-empty.
            overlap_ratio_in_phrase = len(overlap) / len(phrase_word_set)
            overlap_ratio_in_token = len(overlap) / len(token_word_set)

            if overlap_ratio_in_token >= 0.3 or overlap_ratio_in_phrase >= 0.2:
                matched_token_indices.append(token_idx)
                print(f"  匹配token {token_idx}: '{original_tokens[token_idx]}'")
                print(f"    重叠词: {overlap}")
                print(f"    短语重叠比例: {overlap_ratio_in_phrase:.2f}, Token重叠比例: {overlap_ratio_in_token:.2f}")

        # Fallback rule: substring relation between the cleaned texts.
        if not matched_token_indices:
            print(f"  未找到词重叠匹配，尝试子串匹配...")
            phrase_text = ' '.join(phrase_word_list)

            for token_idx, token_text in enumerate(token_texts):
                if (phrase_text in token_text and len(phrase_text) >= 3) or \
                   (token_text in phrase_text and len(token_text) >= 3):
                    matched_token_indices.append(token_idx)
                    print(f"  子串匹配token {token_idx}: '{original_tokens[token_idx]}'")

        # Gather the trace points inside the time window of each matched
        # token. Indices go into a set so that a point covered by several
        # overlapping windows is kept once; the raw hit count is tracked
        # separately because it is what gets reported.
        point_indices = set()
        collected_count = 0
        for token_idx in matched_token_indices:
            if token_idx < timed_token_count:
                in_window = (trace_arr >= time_begins[token_idx]) & (trace_arr <= time_ends[token_idx])
                hits = np.where(in_window)[0].tolist()
                collected_count += len(hits)
                point_indices.update(hits)

        print(f"  总共收集到 {collected_count} 个轨迹点")

        if point_indices:
            # Order by time, breaking ties by position in the trace so the
            # result is deterministic, then drop identical (t, x, y) points.
            ordered = sorted(point_indices, key=lambda idx: (trace_ts[idx], idx))
            unique_points = list(dict.fromkeys(
                (trace_ts[idx], trace_xs[idx], trace_ys[idx]) for idx in ordered
            ))

            phrase_ts_sorted = [t for t, _, _ in unique_points]
            phrase_xs_sorted = [x for _, x, _ in unique_points]
            phrase_ys_sorted = [y for _, _, y in unique_points]

            # Optional Douglas-Peucker simplification of the x/y polyline.
            if use_douglas_peucker and len(phrase_xs_sorted) > 2:
                points = np.array(list(zip(phrase_xs_sorted, phrase_ys_sorted)))
                simplified_points = rdp(points, epsilon=dp_epsilon)
                segmented_xs.append(simplified_points[:, 0].tolist())
                segmented_ys.append(simplified_points[:, 1].tolist())
            else:
                segmented_xs.append(phrase_xs_sorted)
                segmented_ys.append(phrase_ys_sorted)

            segmented_ts.append(phrase_ts_sorted)
        else:
            # No trace point for this phrase: keep the slot with empty lists.
            print(f"  未找到轨迹点")
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])

    print(f"\n将轨迹分割成 {len(segmented_xs)} 段，对应 {len(phrases)} 个短语")

    return segmented_xs, segmented_ys, segmented_ts


def trace_segment_by_phrases_position_based(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends, use_douglas_peucker=False, dp_epsilon=0.01):
    """
    Segment a trace by phrases, matching phrases to tokens by text position.

    Every token is cleaned (lower-cased, punctuation removed, stripped) and
    the cleaned tokens are joined with single spaces, which fixes the
    character span of each token in the joined text. Each phrase is cleaned
    the same way and located in that text with ``str.find`` (the first
    occurrence). If that fails, the best same-length window with a
    ``SequenceMatcher`` ratio above 0.6 is used. Tokens whose span overlaps
    the phrase's span are matched, and the trace points inside the closed
    time window of any matched token form the phrase's segment. Each trace
    point appears at most once per segment; points are ordered by timestamp.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of point coordinates and
            timestamps.
        phrases: Phrases to locate; one output segment per phrase.
        original_tokens: Token strings of the timed caption.
        time_begins, time_ends: Start/end time of each token. Tokens whose
            index is beyond either list are ignored.
        use_douglas_peucker: When true, segments with more than two points
            have their x/y coordinates simplified with ``rdp``.
        dp_epsilon: Tolerance passed to ``rdp``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)``, each with one list
        per phrase (empty when the phrase could not be located or has no
        trace point). ``segmented_ts`` is never simplified. Three empty
        lists are returned when ``trace_ts`` or ``phrases`` is empty.
    """
    if not trace_ts or not phrases:
        return [], [], []

    def clean_text(text):
        # Lower-case, remove punctuation, trim surrounding whitespace.
        return re.sub(r'[^\w\s]', '', text.lower()).strip()

    full_original_text = ' '.join(original_tokens)

    # Build the searchable text from the cleaned tokens themselves, so the
    # spans below index exactly into the string that is searched. Cleaning
    # the joined text as a whole would strip leading punctuation-only tokens
    # and shift every subsequent offset.
    clean_tokens = [clean_text(token) for token in original_tokens]
    clean_full_text = ' '.join(clean_tokens)

    # Character span [start, end) of every token within clean_full_text.
    token_spans = []
    offset = 0
    for clean_token in clean_tokens:
        token_spans.append((offset, offset + len(clean_token)))
        offset += len(clean_token) + 1  # +1 for the joining space

    trace_arr = np.array(trace_ts)
    timed_token_count = min(len(time_begins), len(time_ends))

    print(f"完整原始文本: {full_original_text}")
    print(f"清理后文本: {clean_full_text}")

    segmented_xs, segmented_ys, segmented_ts = [], [], []

    # Locate each phrase and the tokens it covers.
    for phrase_idx, phrase in enumerate(phrases):
        print(f"\n处理短语 {phrase_idx + 1}: '{phrase}'")

        clean_phrase = clean_text(phrase)
        print(f"清理后短语: '{clean_phrase}'")

        # Exact search first (first occurrence in the text).
        phrase_start_pos = clean_full_text.find(clean_phrase)

        if phrase_start_pos == -1:
            # No exact match: slide a window of the phrase's length over the
            # text and keep the most similar one above the threshold.
            print(f"未找到精确匹配，尝试模糊匹配...")
            best_ratio = 0
            best_start = -1
            best_end = -1

            for i in range(len(clean_full_text) - len(clean_phrase) + 1):
                substring = clean_full_text[i:i + len(clean_phrase)]
                ratio = SequenceMatcher(None, clean_phrase, substring).ratio()
                if ratio > best_ratio and ratio > 0.6:  # similarity threshold
                    best_ratio = ratio
                    best_start = i
                    best_end = i + len(clean_phrase)

            if best_start != -1:
                phrase_start_pos = best_start
                phrase_end_pos = best_end
                print(f"模糊匹配成功，相似度: {best_ratio:.2f}")
            else:
                print(f"未找到匹配，跳过该短语")
                segmented_xs.append([])
                segmented_ys.append([])
                segmented_ts.append([])
                continue
        else:
            phrase_end_pos = phrase_start_pos + len(clean_phrase)
            print(f"精确匹配成功")

        print(f"短语在文本中的位置: {phrase_start_pos}-{phrase_end_pos}")

        # Tokens whose span overlaps the phrase's span.
        matched_token_indices = []
        for token_idx, (token_start_pos, token_end_pos) in enumerate(token_spans):
            if token_end_pos > phrase_start_pos and token_start_pos < phrase_end_pos:
                matched_token_indices.append(token_idx)
                print(f"  匹配token {token_idx}: '{original_tokens[token_idx]}' (位置: {token_start_pos}-{token_end_pos})")

        # Gather the trace points inside the time window of each matched
        # token. Adjacent tokens may share a boundary timestamp, so indices
        # go into a set to keep every trace point once; the raw hit count is
        # tracked separately because it is what gets reported.
        point_indices = set()
        collected_count = 0
        for token_idx in matched_token_indices:
            if token_idx < timed_token_count:
                in_window = (trace_arr >= time_begins[token_idx]) & (trace_arr <= time_ends[token_idx])
                hits = np.where(in_window)[0].tolist()
                collected_count += len(hits)
                point_indices.update(hits)

        print(f"  总共收集到 {collected_count} 个轨迹点")

        if point_indices:
            # Order by time; ties keep the order of the original trace.
            ordered = sorted(point_indices, key=lambda idx: (trace_ts[idx], idx))
            phrase_xs = [trace_xs[idx] for idx in ordered]
            phrase_ys = [trace_ys[idx] for idx in ordered]
            phrase_ts = [trace_ts[idx] for idx in ordered]

            # Optional Douglas-Peucker simplification of the x/y polyline.
            if use_douglas_peucker and len(phrase_xs) > 2:
                points = np.array(list(zip(phrase_xs, phrase_ys)))
                simplified_points = rdp(points, epsilon=dp_epsilon)
                segmented_xs.append(simplified_points[:, 0].tolist())
                segmented_ys.append(simplified_points[:, 1].tolist())
            else:
                segmented_xs.append(phrase_xs)
                segmented_ys.append(phrase_ys)

            segmented_ts.append(phrase_ts)
        else:
            # No trace point for this phrase: keep the slot with empty lists.
            print(f"  未找到轨迹点")
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])

    print(f"\n将轨迹分割成 {len(segmented_xs)} 段，对应 {len(phrases)} 个短语")

    return segmented_xs, segmented_ys, segmented_ts


def traces_to_bboxs(xs_list, ys_list):
    """Compute the axis-aligned bounding box of every trace segment.

    Segments are read pairwise from ``xs_list`` and ``ys_list``. A segment
    whose x list is empty gets the full unit box ``(0.0, 1.0, 0.0, 1.0)``.

    Returns:
        ``(xmins, xmaxs, ymins, ymaxs)`` — four lists with one value per
        segment.
    """
    full_box = (0.0, 1.0, 0.0, 1.0)  # (xmin, xmax, ymin, ymax) for empty segments

    boxes = [
        (min(xs), max(xs), min(ys), max(ys)) if xs else full_box
        for xs, ys in zip(xs_list, ys_list)
    ]

    xmins = [box[0] for box in boxes]
    xmaxs = [box[1] for box in boxes]
    ymins = [box[2] for box in boxes]
    ymaxs = [box[3] for box in boxes]
    return xmins, xmaxs, ymins, ymaxs


def process_json_improved(json_anno, trace_segmentation_method, image_base_path, visualize=False, use_douglas_peucker=False, dp_epsilon=0.01, output_dir="."):
    """
    Segment one annotation's trace, derive bounding boxes and optionally plot them.

    Args:
        json_anno: Annotation dict; ``image_id``, ``dataset_id`` and
            ``caption`` are read here, the trace and timed caption are read
            through ``get_json_anno_external``.
        trace_segmentation_method: One of ``'uni_len_global'`` (fixed 0.4
            time bins), ``'timestamp'`` (one segment per timed token),
            ``'phrases_word_level'`` or ``'phrases_position'`` (one segment
            per phrase of the caption).
        image_base_path: Forwarded to ``visualize_all_boxes``.
        visualize: When true, ``visualize_all_boxes`` is called.
        use_douglas_peucker: Forwarded to the segmentation function.
        dp_epsilon: Forwarded to the segmentation function.
        output_dir: Forwarded to ``visualize_all_boxes``.

    Returns:
        The labels associated with the segments: the caption phrases for the
        phrase-based methods, otherwise the token list of the timed caption.
        ``None`` when the annotation has no trace points.

    Raises:
        ValueError: If ``trace_segmentation_method`` is not recognised.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    method_title = trace_segmentation_method + (" + Douglas-Peucker" if use_douglas_peucker else "")
    print(f"\n--- Processing Image: {image_id} using '{method_title}' method ---")

    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)
    if not xs:
        print("Error: Empty or invalid trace data.")
        return

    # Simplification options shared by every segmentation function.
    dp_options = dict(use_douglas_peucker=use_douglas_peucker, dp_epsilon=dp_epsilon)

    # Segment the trace with the requested strategy.
    if trace_segmentation_method == 'uni_len_global':
        segmented_xs, segmented_ys, _ = trace_segment_uniform_time_interval(
            xs, ys, ts, 0.4, **dp_options)
        tokens_for_vis = transcription
    elif trace_segmentation_method == 'timestamp':
        segmented_xs, segmented_ys, _ = trace_segment_timestamp(
            xs, ys, ts, time_begins, time_ends, **dp_options)
        tokens_for_vis = transcription
    elif trace_segmentation_method in ('phrases_word_level', 'phrases_position'):
        # Both phrase-based methods split the caption the same way and differ
        # only in how phrases are mapped onto the timed tokens.
        phrases = split_caption_into_phrases(full_caption)
        print(f"拆分出 {len(phrases)} 个短语:")
        for number, phrase in enumerate(phrases, 1):
            print(f"  {number:2d}. {phrase}")

        if trace_segmentation_method == 'phrases_word_level':
            segment_by_phrases = trace_segment_by_phrases_word_level
        else:
            segment_by_phrases = trace_segment_by_phrases_position_based

        segmented_xs, segmented_ys, _ = segment_by_phrases(
            xs, ys, ts, phrases, transcription, time_begins, time_ends, **dp_options)
        tokens_for_vis = phrases
    else:
        raise ValueError("Unknown trace_segmentation_method")

    # Convert segmented trace chunks into bounding boxes.
    xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)
    print(f"Generated {len(xmins)} bounding boxes.")

    # Name used in the visualization titles and output file names.
    output_method_name = trace_segmentation_method
    if use_douglas_peucker:
        print(f"Applying Douglas-Peucker simplification with epsilon={dp_epsilon}.")
        output_method_name += "_DP"

    if visualize:
        visualize_all_boxes(
            image_id, segmented_xs, segmented_ys, xmins, xmaxs, ymins, ymaxs,
            tokens_for_vis, output_method_name,
            image_base_path, dataset_id, full_caption, output_dir=output_dir)

    # For the phrase-based methods tokens_for_vis is the phrase list itself.
    return tokens_for_vis


def test_improved_phrase_segmentation():
    """
    Smoke-test the improved caption-to-phrase splitter.

    Feeds three hard-coded sample captions to
    split_caption_into_phrases_improved and prints, for each one, the
    original caption followed by a numbered list of the phrases returned.
    Nothing is asserted and nothing is returned; the output is meant to be
    inspected by eye.
    """
    # Sample captions
    sample_captions = (
        "In this image there are two persons sitting on the vehicle, and there are cardboard boxes, tire, bicycle in the vehicle, and at the background there is sky.",
        "In this picture there is a room in which we can observe a sofa set and a table in front of the sofa set. There is a TV fixed to the wall which is in cream color.",
        "A woman is sitting on a chair and holding a mobile phone. Behind her there is a fencing.",
    )

    print("=== 测试改进的短语分割功能 ===\n")

    case_number = 0
    for caption_text in sample_captions:
        case_number += 1
        print(f"测试用例 {case_number}:")
        print(f"原始caption: {caption_text}")

        pieces = split_caption_into_phrases_improved(caption_text)
        piece_count = len(pieces)

        print(f"拆分后的短语 ({piece_count} 个):")
        # Phrases are numbered from 1, right-aligned to two digits.
        for position, piece in enumerate(pieces, start=1):
            print(f"  {position:2d}. {piece}")
        # Blank line separates consecutive cases.
        print()


if __name__ == '__main__':
    # Script entry point: run every trace-segmentation method on the first
    # MAX_SELECTED_INDEX records of one Localized Narratives shard, writing
    # the visualizations into a fresh timestamped directory, then run the
    # phrase-splitting smoke test.
    import traceback

    # Input locations
    COCO_IMAGE_PATH = '/storage-root/datasets/yangfan/coco2017'
    JSONL_PATH = '/storage-root/datasets/yangfan/Seg_LLaVA_v2/datasets/Localized_Narratives/coco_train_localized_narratives-00000-of-00004.jsonl'

    # Records 1..MAX_SELECTED_INDEX (1-based, inclusive) are processed.
    MAX_SELECTED_INDEX = 200

    # (console banner, trace_segmentation_method, output sub-directory),
    # listed in the order the methods are run for every record.
    SEGMENTATION_RUNS = (
        ("\n=== 测试词级别匹配方法 ===", 'phrases_word_level', "word_level"),
        ("\n=== 测试位置匹配方法 ===", 'phrases_position', "position_based"),
        ("\n=== 对比：原始时间戳方法 ===", 'timestamp', "timestamp"),
    )

    # Create the output directory
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_base_dir = f"output_visualizations_improved_{timestamp}"
    os.makedirs(output_base_dir, exist_ok=True)
    print(f"Saving visualizations to: {os.path.abspath(output_base_dir)}")

    # Read the needed lines in ONE pass. Looking each index up separately
    # would re-open and re-scan the file from the top for every record.
    raw_records = []
    jsonl_readable = True
    try:
        with open(JSONL_PATH, 'r', encoding='utf-8') as jsonl_file:
            for raw_line in jsonl_file:
                raw_records.append(raw_line)
                if len(raw_records) == MAX_SELECTED_INDEX:
                    break
    except (OSError, UnicodeDecodeError) as e:
        # Report an unreadable file once instead of once per index; any lines
        # read before the failure are still processed below.
        jsonl_readable = False
        print(f"Could not read {JSONL_PATH}: {e}")

    # Process the selected indices
    for selected_index, raw_record in enumerate(raw_records, 1):
        try:
            # Parse inside the try so one malformed line only skips its own index.
            COCO_VAL_JSON_DICT = json.loads(raw_record)

            # Per-record sub-directory
            current_output_dir = os.path.join(output_base_dir, f"index_{selected_index}")
            os.makedirs(current_output_dir, exist_ok=True)

            # A failure in one method abandons the remaining methods for this
            # record and moves on to the next record.
            for banner, method_name, subdir_name in SEGMENTATION_RUNS:
                print(banner)
                process_json_improved(COCO_VAL_JSON_DICT,
                             trace_segmentation_method=method_name,
                             image_base_path=COCO_IMAGE_PATH,
                             visualize=True,
                             use_douglas_peucker=False,
                             dp_epsilon=0.01,
                             output_dir=os.path.join(current_output_dir, subdir_name))
        except Exception as e:
            # Top-level boundary of a batch run: log the failure and keep going.
            print(f"An error occurred while processing index {selected_index}: {e}")
            traceback.print_exc()

    # The file ended before MAX_SELECTED_INDEX records: say so once.
    if jsonl_readable and len(raw_records) < MAX_SELECTED_INDEX:
        print(f"Skipping indices {len(raw_records) + 1}-{MAX_SELECTED_INDEX}: "
              f"JSONL file has only {len(raw_records)} records")

    # Run the phrase-splitting smoke test
    print("\n" + "="*50)
    test_improved_phrase_segmentation()