#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
重构后的小目标轨迹构建脚本（主orchestration模块）
拆分自build_small_targets_trajectory_with_vis.py，提高代码可维护性
"""

import json
import os
import datetime
from tqdm import tqdm

# 导入重构后的模块
try:
    from .data_loader import load_small_targets_data, load_ln_jsonl_files, get_json_anno_external
    from .trajectory_processor import split_caption_into_phrases, trace_segment_by_phrases_position_based, traces_to_bboxs
    from .visualizer import visualize_trajectory_data
    from .qwen_optimizer import process_json_with_qwen_optimization, initialize_qwen_optimizer
except ImportError:
    from data_loader import load_small_targets_data, load_ln_jsonl_files, get_json_anno_external
    from trajectory_processor import split_caption_into_phrases, trace_segment_by_phrases_position_based, traces_to_bboxs
    from visualizer import visualize_trajectory_data
    from qwen_optimizer import process_json_with_qwen_optimization, initialize_qwen_optimizer


def process_single_ln_data(json_anno, coco_images_path, output_dir, qwen_optimizer=None, 
                          visualize=True, use_qwen_optimization=False):
    """
    Process one Localized Narratives annotation into per-phrase boxes and traces.

    The caption is split into phrases, each phrase is assigned a trace segment
    and a bounding box, and the result is optionally visualized. Two paths exist:

      * Qwen2.5-VL refinement via ``process_json_with_qwen_optimization``, used
        only when ``use_qwen_optimization`` is true AND ``qwen_optimizer`` is
        not None;
      * otherwise the original phrase/position-based trace segmentation.

    Args:
        json_anno: LN annotation dict. It is read for 'image_id', 'dataset_id'
            and 'caption', and is also passed to ``get_json_anno_external``.
        coco_images_path: Root directory of the COCO images. The image path is
            built as ``<root>/<split>/<image_id:012d>.jpg``.
        output_dir: Directory for visualization output. It is also forwarded
            to the Qwen optimizer.
        qwen_optimizer: Optional Qwen optimizer instance.
        visualize: Whether to call ``visualize_trajectory_data``.
        use_qwen_optimization: Whether the Qwen path is requested.

    Returns:
        dict | None: A result dict with keys 'image_id', 'image_filename',
        'image_path', 'full_caption', 'phrases_data', 'visualization_files'
        and 'qwen_optimized'. Returns None when any of these hold:
        the annotation has no trace points, no phrases were produced, or the
        Qwen optimizer returned None.
    """
    image_id = json_anno['image_id']
    dataset_id = json_anno['dataset_id']
    full_caption = json_anno['caption']

    # Extract the raw trace coordinates/timestamps and the transcription timing.
    xs, ys, ts, transcription, time_begins, time_ends = get_json_anno_external(json_anno)

    if not xs:
        return None

    # Build the image path. The split name is the second '_'-separated field of
    # dataset_id (presumably e.g. "mscoco_train2017" -> "train2017"; confirm).
    split_name = dataset_id.split('_')[1]
    image_filename = f"{int(image_id):012d}.jpg"
    full_image_path = os.path.join(coco_images_path, split_name, image_filename)

    # Only claim Qwen optimization when it actually runs: without an optimizer
    # instance the original method is used, and the result must say so.
    qwen_used = use_qwen_optimization and qwen_optimizer is not None

    if qwen_used:
        # Qwen2.5-VL refined localization.
        print(f"使用Qwen2.5-VL优化处理图像 {image_id}...")

        optimization_results = process_json_with_qwen_optimization(
            json_anno, 
            trace_segmentation_method='phrases_word_level',
            image_base_path=coco_images_path,
            qwen_optimizer=qwen_optimizer,
            quality_threshold=0.6,
            max_retries=0,
            visualize=False,  # visualization is handled separately below
            use_douglas_peucker=False,
            dp_epsilon=0.01,
            output_dir=output_dir
        )

        if optimization_results is None:
            return None

        phrases = optimization_results['tokens']
        optimized_bboxes = optimization_results['optimized_bboxes']
        segmented_xs = optimization_results['segmented_xs']
        segmented_ys = optimization_results['segmented_ys']

        # Treat "no phrases" as a failure, exactly like the original path does.
        if len(phrases) == 0:
            return None

        # Optimized boxes are indexed as (xmin, ymin, xmax, ymax); split into columns.
        xmins = [bbox[0] for bbox in optimized_bboxes]
        ymins = [bbox[1] for bbox in optimized_bboxes]
        xmaxs = [bbox[2] for bbox in optimized_bboxes]
        ymaxs = [bbox[3] for bbox in optimized_bboxes]

    else:
        # Original method: split the caption into phrases first.
        phrases = split_caption_into_phrases(full_caption)

        if not phrases:
            return None

        # Segment the trace by phrase position; the per-segment timestamps are unused here.
        segmented_xs, segmented_ys, _segmented_ts = trace_segment_by_phrases_position_based(
            xs, ys, ts, phrases, transcription, time_begins, time_ends
        )

        # Convert each trace segment to a bounding box.
        xmins, xmaxs, ymins, ymaxs = traces_to_bboxs(segmented_xs, segmented_ys)

    # Optional visualization.
    visualization_files = None
    if visualize:
        visualization_files = visualize_trajectory_data(
            image_id, full_image_path, phrases, segmented_xs, segmented_ys,
            xmins, xmaxs, ymins, ymaxs, full_caption, output_dir
        )

    result = {
        "image_id": image_id,
        "image_filename": image_filename,
        "image_path": full_image_path,
        "full_caption": full_caption,
        "phrases_data": [],
        "visualization_files": visualization_files,
        "qwen_optimized": qwen_used
    }

    # Emit one entry per phrase that has a box. Phrases beyond the number of
    # boxes are dropped. Segment lists may be shorter than the box lists (they
    # come from separate fields in the Qwen path), so index them defensively.
    for i, phrase in enumerate(phrases):
        if i >= len(xmins):
            break

        bbox = (xmins[i], ymins[i], xmaxs[i], ymaxs[i])

        seg_x = segmented_xs[i] if i < len(segmented_xs) else None
        seg_y = segmented_ys[i] if i < len(segmented_ys) else None
        # len() instead of truthiness so array-like segments do not raise.
        if seg_x is not None and seg_y is not None and len(seg_x) > 0:
            trace_points = list(zip(seg_x, seg_y))
        else:
            trace_points = []

        result["phrases_data"].append({
            "phrase": phrase,
            "bbox": bbox,  # (xmin, ymin, xmax, ymax); presumably normalized coordinates
            "trace_points": trace_points
        })

    return result


def _select_ln_files(file_chunks):
    """Return (selected LN shard filenames, label) for the given chunk indices.

    ``file_chunks=None`` selects every shard with the label "all". Otherwise,
    indices outside the shard range are silently skipped, and the label is
    built from all requested indices.
    """
    all_ln_files = [
        "coco_train_localized_narratives-00000-of-00004.jsonl",
        "coco_train_localized_narratives-00001-of-00004.jsonl",
        "coco_train_localized_narratives-00002-of-00004.jsonl",
        "coco_train_localized_narratives-00003-of-00004.jsonl"
    ]
    if file_chunks is None:
        return all_ln_files, "all"
    used_files = [all_ln_files[i] for i in file_chunks if 0 <= i < len(all_ln_files)]
    return used_files, f"chunks_{'-'.join(map(str, file_chunks))}"


def build_small_targets_trajectory_dataset(ln_base_path, small_targets_json_path, 
                                         coco_images_path, output_path, 
                                         max_samples=None, visualize=True, 
                                         use_qwen_optimization=False, qwen_model_path=None,
                                         file_chunks=None, num_gpus=4, continue_on_error=False):
    """
    Build the small-target trajectory dataset, optionally with visualizations and
    Qwen2.5-VL refinement, and write it to ``output_path`` as JSON.

    Args:
        ln_base_path: Localized Narratives data directory.
        small_targets_json_path: Path to the small-target JSON file.
        coco_images_path: COCO image root directory.
        output_path: Output JSON file path. Visualizations go to
            ``<output_path without extension>_visualizations``.
        max_samples: Maximum number of samples to process. None means all
            samples; 0 processes none.
        visualize: Whether to generate visualizations.
        use_qwen_optimization: Whether Qwen refinement is requested. It is
            turned off (and recorded as off) when no model path is given or
            the optimizer fails to initialize.
        qwen_model_path: Qwen model path.
        file_chunks: Shard indices to process, e.g. [0], [1], [2], [3].
            None means all shards.
        num_gpus: GPU count passed to ``initialize_qwen_optimizer``.
        continue_on_error: If True, a sample that raises is logged, counted as
            failed, and skipped. If False (the default), the exception
            propagates.

    Returns:
        dict: The dataset dict with "metadata" and "data" keys.
    """
    print("="*60)
    print("开始构建小目标轨迹数据集（重构版本）")
    if use_qwen_optimization:
        print("启用Qwen2.5-VL优化")
    print("="*60)

    # Initialize the Qwen optimizer when requested. Whenever it is not
    # available (no model path, or init returned None), fall back AND clear the
    # flag, so that the metadata and per-sample "qwen_optimized" values are truthful.
    qwen_optimizer = None
    if use_qwen_optimization:
        if qwen_model_path:
            qwen_optimizer = initialize_qwen_optimizer(qwen_model_path, num_gpus=num_gpus)
        if qwen_optimizer is None:
            print("将使用原始方法处理")
            use_qwen_optimization = False

    # 1. Load the small-target data.
    small_target_image_ids, small_targets_data = load_small_targets_data(small_targets_json_path)

    # 2. Load the LN records and keep only those for small-target images.
    filtered_ln_data = load_ln_jsonl_files(ln_base_path, small_target_image_ids, file_chunks)

    # "is not None" so that max_samples=0 really means zero samples.
    if max_samples is not None:
        filtered_ln_data = filtered_ln_data[:max_samples]
        print(f"限制处理样本数: {max_samples}")

    # 3. Create the visualization directory.
    base_output_dir = os.path.splitext(output_path)[0] + "_visualizations"
    if visualize:
        os.makedirs(base_output_dir, exist_ok=True)
        print(f"可视化文件将保存到: {base_output_dir}")

    used_files, file_chunks_info = _select_ln_files(file_chunks)

    # 4. Process each LN record.
    trajectory_dataset = {
        "metadata": {
            "source_files": used_files,
            "file_chunks": file_chunks,
            "file_chunks_info": file_chunks_info,
            "small_targets_source": small_targets_json_path,
            "total_samples": len(filtered_ln_data),
            "visualization_enabled": visualize,
            "visualization_directory": base_output_dir if visualize else None,
            "qwen_optimization_enabled": use_qwen_optimization,
            "creation_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            "refactored_version": True
        },
        "data": []
    }

    print(f"开始处理 {len(filtered_ln_data)} 个样本...")

    successful_count = 0
    failed_count = 0

    for i, ln_data in enumerate(tqdm(filtered_ln_data, desc="处理轨迹数据")):
        try:
            result = process_single_ln_data(
                ln_data, coco_images_path, base_output_dir, 
                qwen_optimizer=qwen_optimizer, 
                visualize=visualize, 
                use_qwen_optimization=use_qwen_optimization
            )
        except Exception as e:
            # By default, fail fast, as before. Opt in to skipping bad samples so
            # a long run is not lost to one malformed record.
            if not continue_on_error:
                raise
            print(f"处理第 {i+1} 个样本时出错: {e}")
            failed_count += 1
            continue

        if result is None:
            failed_count += 1
        else:
            trajectory_dataset["data"].append(result)
            successful_count += 1

    # 5. Report and save.
    print(f"\n处理完成:")
    print(f"  成功: {successful_count}")
    print(f"  失败: {failed_count}")
    print(f"  总计: {len(filtered_ln_data)}")

    trajectory_dataset["metadata"]["successful_samples"] = successful_count
    trajectory_dataset["metadata"]["failed_samples"] = failed_count

    total_phrases = sum(len(item["phrases_data"]) for item in trajectory_dataset["data"])
    trajectory_dataset["metadata"]["total_phrases"] = total_phrases

    print(f"  总短语数: {total_phrases}")

    print(f"\n保存结果到: {output_path}")
    # A bare filename has an empty dirname; create the current directory in that case.
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(trajectory_dataset, f, ensure_ascii=False, indent=2)

    print("数据集构建完成！")

    return trajectory_dataset


def main():
    """
    Script entry point: build the small-target trajectory dataset for the
    configured LN shard(s), using the hard-coded paths below.

    Edit FILE_CHUNKS per node. If the build raises, the traceback is printed
    and the process exits with status 1 (via SystemExit), so batch schedulers
    can detect the failure.
    """
    # Input paths.
    LN_BASE_PATH = '/storage-root/datasets/yangfan/Seg_LLaVA_v2/datasets/Localized_Narratives'
    SMALL_TARGETS_JSON_PATH = '/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/coco_bbox_analysis_20250824_232913/small_targets_details.json'
    # SMALL_TARGETS_JSON_PATH = '/storage-root/datasets/yangfan/ICLR_AAAI/Trajectory-VLM/processed_coco_phrases/processed_phrases_train_chunk0_20250901_132607.json'
    COCO_IMAGES_PATH = '/storage-root/datasets/yangfan/coco2017'

    # Shard selection for this node. Options:
    #   [0] / [1] / [2] / [3] -> a single shard (00000 ... 00003)
    #   None                  -> all shards
    #   [0, 1]                -> several shards
    FILE_CHUNKS = [0]

    print("="*60)
    print("文件块处理配置:")
    if FILE_CHUNKS is None:
        print("  处理所有文件 (chunks 0,1,2,3)")
    else:
        # Zero-pad to 5 digits to match the shard filename scheme.
        chunk_names = [f"coco_train_localized_narratives-{i:05d}-of-00004.jsonl" for i in FILE_CHUNKS]
        print(f"  处理文件块: {FILE_CHUNKS}")
        print(f"  对应文件: {chunk_names}")
    print("="*60)

    # Timestamped output path tagged with the selected shards.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    chunk_suffix = "all" if FILE_CHUNKS is None else f"chunks_{'_'.join(map(str, FILE_CHUNKS))}"
    output_path = f"small_targets_trajectory_refactored_{chunk_suffix}_{timestamp}.json"

    try:
        # Optional Qwen model. If the path does not exist, None is passed and
        # the builder falls back to the original method.
        QWEN_MODEL_PATH = "/storage-root/9950backfile/yangfan/coyo/Qwen/Qwen2.5-VL-72B-Instruct"

        trajectory_dataset = build_small_targets_trajectory_dataset(
            ln_base_path=LN_BASE_PATH,
            small_targets_json_path=SMALL_TARGETS_JSON_PATH,
            coco_images_path=COCO_IMAGES_PATH,
            output_path=output_path,
            max_samples=None,  # None = all samples; set e.g. 20 for a quick test
            visualize=True,
            use_qwen_optimization=True,
            qwen_model_path=QWEN_MODEL_PATH if os.path.exists(QWEN_MODEL_PATH) else None,
            file_chunks=FILE_CHUNKS
        )

        print(f"\n数据集构建成功！")
        print(f"JSON文件: {os.path.abspath(output_path)}")
        print(f"可视化目录: {os.path.abspath(os.path.splitext(output_path)[0] + '_visualizations')}")
        print(f"总样本数: {len(trajectory_dataset['data'])}")
        print(f"总短语数: {trajectory_dataset['metadata']['total_phrases']}")

        # Show the first sample (up to 3 phrases) as a sanity check.
        if trajectory_dataset['data']:
            print(f"\n示例数据:")
            sample = trajectory_dataset['data'][0]
            print(f"图像ID: {sample['image_id']}")
            print(f"完整描述: {sample['full_caption']}")
            print(f"短语数量: {len(sample['phrases_data'])}")
            for i, phrase_data in enumerate(sample['phrases_data'][:3]):
                print(f"  短语{i+1}: {phrase_data['phrase']}")
                print(f"    边界框: {phrase_data['bbox']}")
                print(f"    轨迹点数: {len(phrase_data['trace_points'])}")

            if sample['visualization_files']:
                print(f"可视化文件:")
                for key, file_path in sample['visualization_files'].items():
                    if file_path:
                        print(f"  {key}: {os.path.basename(file_path)}")

    except Exception as e:
        print(f"构建数据集时出错: {e}")
        import traceback
        traceback.print_exc()
        # Previously the script exited with status 0 even on failure; surface it.
        raise SystemExit(1) from e


if __name__ == "__main__":
    # Build the dataset only when executed as a script, not on import.
    main()