# -*- coding: utf-8 -*-
"""
基于小目标的轨迹构建脚本
基于dp_visual_chaifenjuzi_qwenyanzheng.py，专门用于构建小目标的轨迹数据
数据来源：Localized Narratives的四个文件
筛选条件：只保留在small_targets_details.json中存在的图像
输出：包含phrases和对应的qwen2.5-vl纠正坐标点的JSON文件
"""

import json
import collections
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import os
from rdp import rdp
import re
from difflib import SequenceMatcher
import datetime
import torch
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
from tqdm import tqdm

# Resolve the Qwen model class to load: prefer the Qwen2.5-VL class, fall back
# to the Qwen2-VL class, and leave QWEN_MODEL_CLASS as None (after printing a
# warning) when the installed transformers provides neither.
try:
    from transformers import Qwen2_5_VLForConditionalGeneration
    QWEN_MODEL_CLASS = Qwen2_5_VLForConditionalGeneration
except ImportError:
    try:
        from transformers import Qwen2VLForConditionalGeneration
        QWEN_MODEL_CLASS = Qwen2VLForConditionalGeneration
    except ImportError:
        print("警告：无法导入Qwen2.5-VL模型类，请检查transformers版本")
        QWEN_MODEL_CLASS = None


class QwenVLLocationCorrector:
    """Refine phrase bounding boxes with a Qwen2.5-VL vision-language model.

    The model is shown an image, a phrase and (optionally) a prior box, and
    is asked to answer with a normalised ``xmin,ymin,xmax,ymax`` box.

    If the model cannot be loaded the instance stays usable: ``self.model``
    and ``self.processor`` are ``None`` and :meth:`get_corrected_bbox`
    returns its fallback box on every call.
    """

    # Box used when neither the model nor the caller provides anything better.
    _FULL_IMAGE_BBOX = (0.0, 0.0, 1.0, 1.0)

    # A standalone decimal number ("0.25", "1", ".5").  The look-arounds stop
    # the pattern from picking digits out of a larger token, e.g. the "1" and
    # "0" inside "120" or the trailing "0" of "1.00".
    _NUMBER_PATTERN = re.compile(
        r'(?<![A-Za-z0-9_.])(?:\d+\.\d+|\d+|\.\d+)(?![A-Za-z0-9_])(?!\.\d)'
    )

    def __init__(self, model_path="/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct", num_gpus=4):
        """Load the model and its processor.

        Args:
            model_path: Directory (or hub id) of the Qwen2.5-VL checkpoint.
            num_gpus: Number of GPUs to expose; with more than one GPU the
                weights are placed with ``device_map="auto"``.

        Loading errors are reported on stdout and leave ``self.model`` and
        ``self.processor`` set to ``None`` instead of raising.
        """
        print(f"正在加载Qwen2.5-VL模型: {model_path}")
        print(f"使用 {num_gpus} 个GPU")

        # NOTE(review): this overwrites any CUDA_VISIBLE_DEVICES set by the
        # launcher and presumably only has an effect if CUDA has not been
        # initialised yet in this process — confirm.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(i) for i in range(num_gpus)])

        # Defined up front so the attributes exist on every exit path.
        self.model = None
        self.processor = None

        try:
            if QWEN_MODEL_CLASS is None:
                # Fail with a readable message instead of an AttributeError
                # on ``None.from_pretrained``.
                raise ImportError("未找到可用的Qwen模型类，请检查transformers版本")

            if num_gpus > 1:
                device_map = "auto"  # let transformers shard the weights
                print(f"使用自动设备映射分布到 {num_gpus} 个GPU")
            else:
                device_map = "cuda:0"
                print("使用单GPU模式")

            self.model = QWEN_MODEL_CLASS.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
                device_map=device_map,
                attn_implementation="flash_attention_2",
            ).eval()

            self.processor = AutoProcessor.from_pretrained(
                model_path,
                max_pixels=12845056,
                min_pixels=3136
            )

            print("Qwen2.5-VL模型加载完成")

        except Exception as e:
            print(f"模型加载失败: {e}")
            # Callers check ``model is None`` to skip the correction step.
            self.model = None
            self.processor = None

    @classmethod
    def _parse_bbox(cls, output_text, fallback):
        """Extract a normalised ``(xmin, ymin, xmax, ymax)`` box from model text.

        The first four standalone numbers in ``output_text`` are used.
        ``fallback`` is returned when fewer than four numbers are present or
        when any of them lies outside ``[0, 1]`` (e.g. the model answered in
        pixels).  An axis whose min is not strictly below its max is widened
        to the full ``[0, 1]`` range.
        """
        values = [float(token) for token in cls._NUMBER_PATTERN.findall(output_text)]
        if len(values) < 4:
            print(f"无法从回复中提取足够的坐标: {output_text}")
            return fallback

        xmin, ymin, xmax, ymax = values[:4]
        if not all(0.0 <= value <= 1.0 for value in (xmin, ymin, xmax, ymax)):
            print(f"坐标超出归一化范围: {output_text}")
            return fallback

        # A degenerate axis carries no information; cover the whole axis.
        if xmin >= xmax:
            xmin, xmax = 0.0, 1.0
        if ymin >= ymax:
            ymin, ymax = 0.0, 1.0

        return (xmin, ymin, xmax, ymax)

    def get_corrected_bbox(self, phrase, image_path, original_bbox=None):
        """Ask the model for the bounding box of ``phrase`` in the image.

        Args:
            phrase: Text describing the region to locate.
            image_path: Path of the image handed to the model.
            original_bbox: Optional ``(xmin, ymin, xmax, ymax)`` prior in
                normalised coordinates; it is quoted in the prompt and used
                as the fallback result.

        Returns:
            A ``(xmin, ymin, xmax, ymax)`` tuple.  When the model is not
            loaded, inference fails, or the reply cannot be parsed, this is
            ``original_bbox`` if given, otherwise the full-image box.
        """
        fallback = original_bbox if original_bbox else self._FULL_IMAGE_BBOX

        if self.model is None:
            print("Qwen模型未加载，返回默认边界框")
            return fallback

        try:
            original_info = ""
            if original_bbox:
                original_info = f"\n原始定位边界框: xmin={original_bbox[0]:.3f}, ymin={original_bbox[1]:.3f}, xmax={original_bbox[2]:.3f}, ymax={original_bbox[3]:.3f}"

            prompt = f"""请在图像中准确定位以下描述的内容，并给出归一化的边界框坐标。

描述内容: "{phrase}"{original_info}

请返回最准确的边界框坐标，格式为: xmin,ymin,xmax,ymax
坐标应该是0-1之间的归一化值。
只需要返回坐标数字，用逗号分隔，不需要其他文字。"""

            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image_path,
                        },
                        {
                            "type": "text",
                            "text": prompt
                        }
                    ]
                }
            ]

            # The conversation is wrapped in a list, i.e. a batch of one.
            text = self.processor.apply_chat_template(
                [messages], tokenize=False, add_generation_prompt=True
            )[0]

            image_inputs, video_inputs = process_vision_info([messages])

            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(self.model.device)

            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                temperature=0.0,
                do_sample=False,
            )

            # Drop the prompt tokens so only the newly generated reply is decoded.
            generated_ids_trimmed = [
                out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, outputs)
            ]

            output_text = self.processor.batch_decode(
                generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
            )[0]

            return self._parse_bbox(output_text, fallback)

        except Exception as e:
            # Best effort: one failed image must not abort the whole run.
            print(f"边界框矫正失败: {str(e)}")
            return fallback


def load_small_targets_data(small_targets_json_path):
    """Read the small-target JSON file and collect the image ids it mentions.

    Args:
        small_targets_json_path: Path of a JSON file holding a
            ``small_targets`` list (each entry has an ``image_id``) and a
            ``metadata`` dict with ``total_small_targets``.

    Returns:
        ``(image_ids, data)``: the set of distinct ``image_id`` values and
        the parsed JSON document.
    """
    print(f"加载小目标数据: {small_targets_json_path}")

    with open(small_targets_json_path, 'r', encoding='utf-8') as handle:
        payload = json.load(handle)

    # A set comprehension de-duplicates images that hold several targets.
    image_ids = {entry['image_id'] for entry in payload['small_targets']}

    print(f"共找到 {len(image_ids)} 个包含小目标的图像")
    print(f"小目标总数: {payload['metadata']['total_small_targets']}")

    return image_ids, payload


def load_ln_jsonl_files(ln_base_path, small_target_image_ids):
    """Load the four COCO-train Localized Narratives shards, filtered by image id.

    Args:
        ln_base_path: Directory containing the ``*.jsonl`` shards.
        small_target_image_ids: Container of integer image ids to keep.

    Returns:
        The parsed records (dicts) whose ``int(image_id)`` is in
        ``small_target_image_ids``, in file order.

    Missing shards, blank lines, malformed JSON lines and records without a
    usable ``image_id`` are skipped (the last three with a message) so that
    one bad line cannot abort the whole load.
    """
    ln_files = [
        'coco_train_localized_narratives-00000-of-00004.jsonl',
        'coco_train_localized_narratives-00001-of-00004.jsonl',
        'coco_train_localized_narratives-00002-of-00004.jsonl',
        'coco_train_localized_narratives-00003-of-00004.jsonl'
    ]

    filtered_ln_data = []
    total_processed = 0  # records that parsed as JSON, kept or not

    for file_name in ln_files:
        file_path = os.path.join(ln_base_path, file_name)
        if not os.path.exists(file_path):
            print(f"Warning: File not found: {file_path}")
            continue

        print(f"处理LN文件: {file_path}")
        with open(file_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    # Blank lines (e.g. a trailing newline) are not errors.
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {line_num} in {file_name}: {e}")
                    continue
                total_processed += 1

                try:
                    image_id = int(data['image_id'])
                except (KeyError, TypeError, ValueError) as e:
                    print(f"Skipping line {line_num} in {file_name}: invalid image_id ({e!r})")
                    continue

                if image_id in small_target_image_ids:
                    filtered_ln_data.append(data)

    print(f"总共处理 {total_processed} 条LN数据")
    print(f"筛选出包含小目标的LN数据: {len(filtered_ln_data)} 条")

    return filtered_ln_data


def get_json_anno_external(json_anno):
    """Flatten the trace points and timed caption of one annotation dict.

    Returns:
        ``(xs, ys, ts, utterances, start_times, end_times)``.  ``xs`` and
        ``ys`` are clamped to ``[0, 1]``; ``ts`` are the point timestamps in
        stroke order.  When the annotation holds no trace points at all,
        six ``None`` values are returned instead.
    """
    points = []
    for stroke in json_anno['traces']:
        points.extend(stroke)

    if not points:
        return (None,) * 6

    def _clamp_unit(value):
        # Keep coordinates inside the normalised image range.
        return max(0.0, min(1.0, value))

    xs = [_clamp_unit(point['x']) for point in points]
    ys = [_clamp_unit(point['y']) for point in points]
    ts = [point['t'] for point in points]

    utterances, start_times, end_times = [], [], []
    for entry in json_anno['timed_caption']:
        utterances.append(entry['utterance'])
        start_times.append(entry['start_time'])
        end_times.append(entry['end_time'])

    return xs, ys, ts, utterances, start_times, end_times


def split_caption_into_phrases(caption):
    """Split a caption into short phrases.

    The caption is first cut into sentences at ``.``, ``!`` and ``?``.  Each
    sentence is then split, one rule after another, at conjunctions, commas,
    relative-clause words and a few prepositional openers.  A separator
    *word* stays attached to the part that follows it ("on the table"); a
    bare comma is dropped so that no phrase starts with punctuation.
    Parts shorter than three words are merged into the previous part when
    the merged phrase has at most eight words.

    Args:
        caption: Free-form caption text.

    Returns:
        List of phrases (each at least two words before a leading
        ``and``/``or``/``but`` is removed), in caption order.
    """
    # Collapse runs of whitespace so word counts and joins are predictable.
    caption = re.sub(r'\s+', ' ', caption.strip())

    # Applied in this order; each pattern has exactly one capture group so
    # re.split() returns [text, separator, text, separator, text, ...].
    split_patterns = [
        r'(\s+and\s+)',           # "and"
        r'(\s*,\s*and\s+)',       # ", and"
        r'(\s*,\s*)',             # comma
        r'(\s+which\s+)',         # "which" clause
        r'(\s+that\s+)',          # "that" clause
        r'(\s+where\s+)',         # "where" clause
        r'(\s+while\s+)',         # "while" clause
        r'(\s+in\s+the\s+)',      # "in the" prepositional phrase
        r'(\s+on\s+the\s+)',      # "on the" prepositional phrase
        r'(\s+at\s+the\s+)',      # "at the" prepositional phrase
        r'(\s+behind\s+)',        # "behind"
        r'(\s+there\s+is\s+)',    # "there is"
        r'(\s+there\s+are\s+)',   # "there are"
    ]

    phrases = []

    # Step 1: sentence boundaries.
    for sentence in re.split(r'[.!?]+', caption):
        sentence = sentence.strip()
        if not sentence:
            continue

        # Step 2: apply every split rule to every part produced so far.
        current_parts = [sentence]
        for pattern in split_patterns:
            new_parts = []
            for part in current_parts:
                segments = re.split(pattern, part, flags=re.IGNORECASE)

                head = segments[0].strip()
                if head:
                    new_parts.append(head)

                for separator, body in zip(segments[1::2], segments[2::2]):
                    body = body.strip()
                    if not body:
                        continue
                    # Keep separator words, but never a leading comma:
                    # "," -> "" and ", and" -> "and".
                    separator = separator.strip().lstrip(',').strip()
                    new_parts.append(f"{separator} {body}" if separator else body)

            current_parts = new_parts

        # Step 3: fold very short parts into their predecessor.
        final_parts = []
        for part in current_parts:
            part = part.strip()
            if not part:
                continue

            word_count = len(part.split())
            if word_count < 3 and final_parts:
                previous = final_parts[-1]
                # Cap the merged length so phrases stay short.
                if len(previous.split()) + word_count <= 8:
                    final_parts[-1] = previous + ", " + part
                    continue
            final_parts.append(part)

        phrases.extend(final_parts)

    # Final clean-up: drop one-word leftovers and leading conjunctions.
    cleaned_phrases = []
    for phrase in phrases:
        phrase = phrase.strip()
        if phrase and len(phrase.split()) >= 2:
            phrase = re.sub(r'^(and|or|but)\s+', '', phrase, flags=re.IGNORECASE)
            if phrase:
                cleaned_phrases.append(phrase)

    return cleaned_phrases


def trace_segment_by_phrases_position_based(trace_xs, trace_ys, trace_ts, phrases, original_tokens, time_begins, time_ends):
    """Assign trace points to phrases via the phrases' position in the transcript.

    Each phrase is located in the cleaned (lower-cased, punctuation-free)
    transcript built from ``original_tokens``.  Every token overlapping the
    located character span contributes the trace points whose timestamp lies
    inside that token's ``[start, end]`` interval.

    Args:
        trace_xs, trace_ys, trace_ts: Parallel lists of point coordinates
            and timestamps.
        phrases: Phrases in caption order.
        original_tokens: Transcript utterances.
        time_begins, time_ends: Per-token start and end times, parallel to
            ``original_tokens``.

    Returns:
        ``(segmented_xs, segmented_ys, segmented_ts)``: three lists with one
        entry per phrase, each a time-sorted list of values (empty when the
        phrase could not be located or no point fell into its tokens).
        Three empty lists are returned when there are no points or phrases.
    """
    if not trace_ts or not phrases:
        return [], [], []

    def clean_text(text):
        # Lower-case and drop punctuation so matching ignores both.
        return re.sub(r'[^\w\s]', '', text.lower()).strip()

    # Build the searchable text and the per-token character spans together,
    # so the spans always describe exactly the text that is searched.
    cleaned_tokens = [clean_text(token) for token in original_tokens]
    token_spans = []
    offset = 0
    for cleaned in cleaned_tokens:
        token_spans.append((offset, offset + len(cleaned)))
        offset += len(cleaned) + 1  # +1 for the joining space
    clean_full_text = ' '.join(cleaned_tokens)

    trace_arr = np.asarray(trace_ts)  # hoisted: identical for every token
    timed_token_count = min(len(time_begins), len(time_ends))

    segmented_xs, segmented_ys, segmented_ts = [], [], []
    search_from = 0  # end of the last exact match; phrases come in caption order

    for phrase in phrases:
        clean_phrase = clean_text(phrase)

        phrase_start_pos = -1
        phrase_end_pos = -1

        if clean_phrase:
            # Search forward first so a repeated phrase maps to its own
            # occurrence rather than always to the first one.
            phrase_start_pos = clean_full_text.find(clean_phrase, search_from)
            if phrase_start_pos == -1 and search_from:
                phrase_start_pos = clean_full_text.find(clean_phrase)

            if phrase_start_pos != -1:
                phrase_end_pos = phrase_start_pos + len(clean_phrase)
                search_from = max(search_from, phrase_end_pos)
            else:
                # No exact match: slide a phrase-sized window over the text
                # and keep the first window with the best similarity > 0.6.
                best_ratio = 0
                window = len(clean_phrase)
                for i in range(len(clean_full_text) - window + 1):
                    substring = clean_full_text[i:i + window]
                    ratio = SequenceMatcher(None, clean_phrase, substring).ratio()
                    if ratio > best_ratio and ratio > 0.6:
                        best_ratio = ratio
                        phrase_start_pos = i
                        phrase_end_pos = i + window

        if phrase_start_pos == -1:
            segmented_xs.append([])
            segmented_ys.append([])
            segmented_ts.append([])
            continue

        # Mark each point once, even if it sits on a boundary shared by two
        # consecutive tokens.
        selected = np.zeros(trace_arr.shape, dtype=bool)
        for token_idx, (token_start_pos, token_end_pos) in enumerate(token_spans):
            if token_idx >= timed_token_count:
                break  # no timing information for the remaining tokens
            if token_end_pos > phrase_start_pos and token_start_pos < phrase_end_pos:
                selected |= (trace_arr >= time_begins[token_idx]) & (trace_arr <= time_ends[token_idx])

        point_indices = sorted(np.flatnonzero(selected).tolist(), key=lambda i: trace_ts[i])

        segmented_xs.append([trace_xs[i] for i in point_indices])
        segmented_ys.append([trace_ys[i] for i in point_indices])
        segmented_ts.append([trace_ts[i] for i in point_indices])

    return segmented_xs, segmented_ys, segmented_ts


def traces_to_bboxs(xs_list, ys_list):
    """Compute one axis-aligned bounding box per trace segment.

    Args:
        xs_list: List of per-segment x-coordinate lists.
        ys_list: List of per-segment y-coordinate lists, parallel to
            ``xs_list``.

    Returns:
        ``(xmins, xmaxs, ymins, ymaxs)``, four lists with one value per
        segment.  A segment without points gets the full-image box
        ``0.0 .. 1.0`` on both axes.
    """
    full_image = (0.0, 1.0, 0.0, 1.0)  # (xmin, xmax, ymin, ymax)

    xmins, xmaxs, ymins, ymaxs = [], [], [], []
    columns = (xmins, xmaxs, ymins, ymaxs)

    for seg_xs, seg_ys in zip(xs_list, ys_list):
        if seg_xs:
            extent = (min(seg_xs), max(seg_xs), min(seg_ys), max(seg_ys))
        else:
            extent = full_image
        for column, value in zip(columns, extent):
            column.append(value)

    return xmins, xmaxs, ymins, ymaxs


def process_single_ln_data(json_anno, image_base_path, qwen_optimizer=None):
    """Turn one Localized Narratives record into per-phrase boxes and traces.

    Args:
        json_anno: Record with ``image_id``, ``dataset_id``, ``caption``,
            ``traces`` and ``timed_caption``.
        image_base_path: Root directory of the images; the file is looked up
            under ``<root>/<second '_'-part of dataset_id>/<12-digit id>.jpg``.
        qwen_optimizer: Optional object exposing
            ``get_corrected_bbox(phrase, image_path, original_bbox)``.

    Returns:
        A dict with the image identification, the full caption and a
        ``phrases_data`` list (phrase, trace-derived box, corrected box,
        trace points), or ``None`` when the record has no trace points, the
        image file is missing, or the caption yields no phrases.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)
    if not xs:
        return None

    # Locate the image on disk; skip the record if it is not there.
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(image_base_path, dataset_id.split('_')[1], image_filename)
    if not os.path.exists(full_image_path):
        print(f"图像不存在: {full_image_path}")
        return None

    phrases = split_caption_into_phrases(full_caption)
    if not phrases:
        return None

    # Cut the trace into one segment per phrase, then box every segment.
    segmented_xs, segmented_ys, segmented_ts = trace_segment_by_phrases_position_based(
        xs, ys, ts, phrases, transcription, time_begins, time_ends
    )
    xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)

    phrases_data = []
    # Only phrases that received a box are emitted.
    for idx, phrase in enumerate(phrases[:len(xmins)]):
        original_bbox = (xmins[idx], ymins[idx], xmaxs[idx], ymaxs[idx])

        # Without an optimizer (or if it raises) the trace box is kept.
        corrected_bbox = original_bbox
        if qwen_optimizer is not None:
            try:
                corrected_bbox = qwen_optimizer.get_corrected_bbox(
                    phrase, full_image_path, original_bbox
                )
            except Exception as e:
                print(f"Qwen矫正失败 (image_id={image_id}, phrase='{phrase}'): {e}")
                corrected_bbox = original_bbox

        points_x = segmented_xs[idx]
        trace_points = list(zip(points_x, segmented_ys[idx])) if points_x else []

        phrases_data.append({
            "phrase": phrase,
            "original_bbox": original_bbox,    # normalised (xmin, ymin, xmax, ymax)
            "corrected_bbox": corrected_bbox,  # box after the optional correction
            "trace_points": trace_points
        })

    return {
        "image_id": image_id,
        "image_filename": image_filename,
        "image_path": full_image_path,
        "full_caption": full_caption,
        "phrases_data": phrases_data
    }


def build_small_targets_trajectory_dataset(ln_base_path, small_targets_json_path,
                                         coco_images_path, output_path,
                                         use_qwen_correction=True, qwen_model_path=None,
                                         max_samples=None, num_gpus=4):
    """Build the small-target trajectory dataset and write it as JSON.

    Args:
        ln_base_path: Directory of the Localized Narratives ``.jsonl`` shards.
        small_targets_json_path: Path of the small-target JSON file.
        coco_images_path: Root directory of the COCO images.
        output_path: Destination JSON file; its parent directory is created
            if needed (a bare file name is written to the current directory).
        use_qwen_correction: Whether to refine boxes with Qwen2.5-VL.
        qwen_model_path: Model location; correction is skipped when empty.
        max_samples: Upper bound on processed records (``None`` means all).
        num_gpus: Number of GPUs handed to the Qwen corrector.

    Returns:
        The dataset dict that was written: ``metadata`` plus a ``data`` list
        with one entry per successfully processed record.
    """

    print("="*60)
    print("开始构建小目标轨迹数据集")
    print("="*60)

    # 1. Image ids that contain small targets.
    small_target_image_ids, small_targets_data = load_small_targets_data(small_targets_json_path)

    # 2. Localized Narratives records restricted to those images.
    filtered_ln_data = load_ln_jsonl_files(ln_base_path, small_target_image_ids)

    if max_samples:
        filtered_ln_data = filtered_ln_data[:max_samples]
        print(f"限制处理样本数: {max_samples}")

    # 3. Optional Qwen corrector.
    qwen_optimizer = None
    if use_qwen_correction and qwen_model_path:
        print("初始化Qwen2.5-VL优化器...")
        try:
            qwen_optimizer = QwenVLLocationCorrector(qwen_model_path, num_gpus=num_gpus)
        except Exception as e:
            print(f"Qwen优化器初始化失败: {e}")
            print("将跳过Qwen坐标矫正")
        else:
            # The corrector reports load failures by leaving ``model`` as
            # None; treat that as "no corrector" so the metadata is truthful
            # and the per-phrase calls are not made at all.
            if getattr(qwen_optimizer, "model", None) is None:
                print("将跳过Qwen坐标矫正")
                qwen_optimizer = None

    # 4. Process every record.
    trajectory_dataset = {
        "metadata": {
            "source_files": [
                "coco_train_localized_narratives-00000-of-00004.jsonl",
                "coco_train_localized_narratives-00001-of-00004.jsonl",
                "coco_train_localized_narratives-00002-of-00004.jsonl",
                "coco_train_localized_narratives-00003-of-00004.jsonl"
            ],
            "small_targets_source": small_targets_json_path,
            "total_samples": len(filtered_ln_data),
            "use_qwen_correction": use_qwen_correction and (qwen_optimizer is not None),
            "creation_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        },
        "data": []
    }

    print(f"开始处理 {len(filtered_ln_data)} 个样本...")

    successful_count = 0
    failed_count = 0

    for i, ln_data in enumerate(tqdm(filtered_ln_data, desc="处理轨迹数据")):
        try:
            result = process_single_ln_data(ln_data, coco_images_path, qwen_optimizer)
        except Exception as e:
            # One broken record must not stop the whole build.
            print(f"处理第 {i+1} 个样本时出错: {e}")
            failed_count += 1
            continue

        if result is not None:
            trajectory_dataset["data"].append(result)
            successful_count += 1
        else:
            failed_count += 1

    # 5. Report and save.
    print(f"\n处理完成:")
    print(f"  成功: {successful_count}")
    print(f"  失败: {failed_count}")
    print(f"  总计: {len(filtered_ln_data)}")

    trajectory_dataset["metadata"]["successful_samples"] = successful_count
    trajectory_dataset["metadata"]["failed_samples"] = failed_count

    total_phrases = sum(len(item["phrases_data"]) for item in trajectory_dataset["data"])
    trajectory_dataset["metadata"]["total_phrases"] = total_phrases

    print(f"  总短语数: {total_phrases}")

    print(f"\n保存结果到: {output_path}")
    # os.makedirs("") raises FileNotFoundError, so only create the parent
    # directory when the path actually has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(trajectory_dataset, f, ensure_ascii=False, indent=2)

    print("数据集构建完成！")

    return trajectory_dataset


if __name__ == '__main__':
    # Machine-specific input locations.
    LN_BASE_PATH = '/storage-root/datasets/yangfan/Seg_LLaVA_v2/datasets/Localized_Narratives'
    SMALL_TARGETS_JSON_PATH = '/public/yangfan/ICLR_AAAI/Trajectory-VLM/coco_bbox_analysis_20250824_232913/small_targets_details.json'
    COCO_IMAGES_PATH = '/storage-root/datasets/yangfan/coco2017'
    QWEN_MODEL_PATH = "/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct"

    # Timestamped output file in the current working directory.  The path is
    # made absolute so that its directory part is never the empty string,
    # which os.makedirs() rejects when the builder prepares the output folder.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    output_path = os.path.abspath(f"small_targets_trajectory_dataset_{timestamp}.json")

    try:
        trajectory_dataset = build_small_targets_trajectory_dataset(
            ln_base_path=LN_BASE_PATH,
            small_targets_json_path=SMALL_TARGETS_JSON_PATH,
            coco_images_path=COCO_IMAGES_PATH,
            output_path=output_path,
            use_qwen_correction=True,  # refine boxes with Qwen2.5-VL
            qwen_model_path=QWEN_MODEL_PATH,
            max_samples=50  # small test run: only the first 50 records
        )

        print(f"\n数据集构建成功！")
        print(f"输出文件: {os.path.abspath(output_path)}")
        print(f"总样本数: {len(trajectory_dataset['data'])}")
        print(f"总短语数: {trajectory_dataset['metadata']['total_phrases']}")

    except Exception as e:
        # Top-level boundary: report the failure with its traceback.
        print(f"构建数据集时出错: {e}")
        import traceback
        traceback.print_exc()