import argparse
import json
import os
import random
import numpy as np
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from pointcloud_validation.lib_render import (
    ProjectionConfig,
    PointCloudProjector,
    CameraRigSpec,
    NovelPoseScheduler,
    PointCloudProcessor,
    ObjectInfoProcessor,
)
from pointcloud_validation.lib_render.object_info import InsertionInfoProvider
from pointcloud_validation.lib_render.clip_config import ClipConfig, create_clip_config
from pointcloud_validation.lib_render.utils import POINT_CLOUD_STORAGE_FORMAT, VideoSaver, ConfigSaver


def set_global_seed(seed: int) -> None:
    """Seed every random source used by this pipeline.

    Intended to be called both in the driver and inside every Ray actor so
    that single-process and Ray runs draw the same random streams.

    Seeds, in order:
      1. the ``PYTHONHASHSEED`` environment variable,
      2. TensorFlow (best-effort: any failure, including TensorFlow not
         being importable, is silently ignored),
      3. Python's ``random`` module,
      4. NumPy's global RNG.

    Note: writing ``PYTHONHASHSEED`` at runtime does not change hash
    randomization of the already-running interpreter; it only reaches child
    processes started afterwards that inherit this environment.
    """

    def _seed_tensorflow() -> None:
        # Imported lazily so the function still works without TensorFlow.
        import tensorflow as tf
        tf.random.set_seed(seed)
        # Ask TF for deterministic kernels unless the caller already chose.
        # NOTE(review): this env var is set after TF is imported; whether it
        # is still honored depends on the TF version (newer releases expose
        # tf.config.experimental.enable_op_determinism()) — TODO confirm.
        os.environ.setdefault("TF_DETERMINISTIC_OPS", "1")

    os.environ["PYTHONHASHSEED"] = f"{seed}"
    try:
        _seed_tensorflow()
    except Exception:
        # Seeding TensorFlow is optional; never let it abort the run.
        pass
    for seeder in (random.seed, np.random.seed):
        seeder(seed)

def resolve_files(
    root: Path,
    clip_id: str,
    frame_idx: int,
    storage_format: Optional[str] = None,
) -> Tuple[Optional[Path], Optional[str]]:
    """Locate the merged point-cloud file of one frame of a clip.

    The expected path is ``<root>/<clip_id>/<frame_idx:06d>_merged.<fmt>``.
    The storage format selects ``fmt``:

    * ``"npz"``, ``"npz_fp16"``, ``"npz_bf16"`` -> ``"npz"``
    * ``"bfloat16"``                         -> ``"pkl.gz"``
    * anything else                          -> ``"ply"``

    Args:
        root: Directory holding one sub-directory per clip (``str`` or ``Path``).
        clip_id: Name of the clip sub-directory.
        frame_idx: Frame index; zero-padded to six digits in the file name.
        storage_format: Storage format to resolve. When ``None`` (the
            default) the module-level ``POINT_CLOUD_STORAGE_FORMAT`` is used,
            matching the original behavior.

    Returns:
        ``(path, fmt)`` if the file exists, otherwise ``(None, None)``.
    """
    if storage_format is None:
        storage_format = POINT_CLOUD_STORAGE_FORMAT

    if storage_format in ("npz", "npz_fp16", "npz_bf16"):
        fmt = "npz"
    elif storage_format == "bfloat16":
        fmt = "pkl.gz"
    else:
        fmt = "ply"

    # The format tag doubles as the file extension.
    clip_dir = Path(root) / clip_id
    pc_path = clip_dir / f"{frame_idx:06d}_merged.{fmt}"
    if pc_path.exists():
        return pc_path, fmt
    return None, None


def _select_right_lane_change_entries(
    args: argparse.Namespace,
    ply_root: Path,
    entries: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], ClipConfig]]:
    """Build a ClipConfig per entry and keep only right lane-change samples.

    Returns ``(entry, clip_config)`` pairs in input order. Each entry stays
    bound to the config built from it, so no lookup by index is needed.
    """
    selected: List[Tuple[Dict[str, Any], ClipConfig]] = []
    for entry in entries:
        clip_config = create_clip_config(
            entry=entry,
            ply_root=ply_root,
            global_base_shift=args.base_shift,
            global_base_longitudinal=args.base_longitudinal,
        )
        # Only right lane-change samples are rendered by this script.
        if clip_config.action_for_seg != "right":
            continue
        selected.append((entry, clip_config))
    return selected


def _append_camera_frames(
    per_camera: Dict[str, List[Any]],
    frames_by_camera: Dict[str, Any],
) -> None:
    """Append each camera's frame to its list; cameras not in *per_camera* are ignored."""
    for cam_name, frame in frames_by_camera.items():
        if cam_name in per_camera:
            per_camera[cam_name].append(frame)


def _save_entry_videos(
    output_root: Path,
    clip_config: ClipConfig,
    camera_names: List[str],
    rgb_video_frames: Dict[str, List[Any]],
    mask_video_frames: Dict[str, List[Any]],
    area_video_frames: Dict[str, List[Any]],
) -> List[Path]:
    """Encode one entry's collected frames into per-camera videos.

    For every camera with at least one RGB frame, this writes three files
    under ``<output_root>/<cam>/``:

    * ``<save_name>.mp4`` and ``<save_name>_mask.mp4``, via
      ``VideoSaver.save_rgb_and_mask_videos``;
    * ``<save_name>_area.mp4``.

    The clip config is saved once, after all cameras, and only if at least
    one camera produced a video. ``process_clip_logic`` checks for that
    config to skip finished entries on a re-run.

    Returns:
        The RGB and area video paths. Mask videos are written but not listed.
    """
    saved: List[Path] = []
    valid_cameras = [cam for cam in camera_names if rgb_video_frames.get(cam)]

    for cam_name in valid_cameras:
        cam_dir = output_root / cam_name
        cam_dir.mkdir(parents=True, exist_ok=True)

        rgb_video_file = cam_dir / f"{clip_config.save_name}.mp4"
        mask_video_file = cam_dir / f"{clip_config.save_name}_mask.mp4"
        area_video_file = cam_dir / f"{clip_config.save_name}_area.mp4"

        VideoSaver.save_rgb_and_mask_videos(
            rgb_video_file=rgb_video_file,
            mask_video_file=mask_video_file,
            rgb_frames=rgb_video_frames[cam_name],
            mask_frames=mask_video_frames.get(cam_name, []),
            fps=10,
            crf=18,
        )
        saved.append(rgb_video_file)

        VideoSaver.save_rgb_video(
            video_file=area_video_file,
            rgb_frames=area_video_frames.get(cam_name, []),
            fps=10,
            crf=18,
        )
        saved.append(area_video_file)

    # Save the config only after every camera's videos are written: its
    # presence is what marks the entry as done.
    if valid_cameras:
        ConfigSaver.save(output_root, clip_config)
    return saved


def process_clip_logic(
    args: argparse.Namespace,
    ply_root: Path,
    clip_dir_name: str,
    entries: List[Dict[str, Any]],
    processor: PointCloudProcessor
) -> Tuple[str, List[Path]]:
    """Render every right lane-change entry of one clip into per-camera videos.

    This is the core per-clip routine. It is called either by the
    sequential driver or by a Ray actor.

    For each selected entry, the entry's segment is re-projected frame by
    frame from the merged point clouds, using novel camera poses from
    ``NovelPoseScheduler``. The segment covers frames
    ``[seg_idx * segment_len, min(num_frames, (seg_idx + 1) * segment_len))``.
    The result is written as RGB / mask / area videos. Entries whose config
    was already saved are skipped, so the function can be re-run to resume.

    Args:
        args: Parsed CLI namespace. Reads ``cams``, ``base_shift``,
            ``base_longitudinal``, ``waymo_root``, ``segment_len``,
            ``lane_change_frames``, ``accel_frames``, ``output``,
            ``pointcloud`` and ``visualize``.
        ply_root: Root directory forwarded to ``create_clip_config``.
        clip_dir_name: Clip directory name under the point-cloud / Waymo roots.
        entries: Config entries that belong to this clip.
        processor: Renders one point-cloud frame for all requested cameras.

    Returns:
        ``(clip_dir_name, saved_video_paths)``.
    """
    default_cam_names = ['front', 'front_left', 'front_right', 'side_left', 'side_right']
    camera_names = default_cam_names if (args.cams is None or args.cams == ['all']) else args.cams

    selected = _select_right_lane_change_entries(args, ply_root, entries)
    if not selected:
        return clip_dir_name, []

    scheduler = NovelPoseScheduler(
        waymo_root=args.waymo_root,
        clip_id=clip_dir_name,
        camera_names=camera_names,
        segment_len=args.segment_len,
        lane_change_frames=args.lane_change_frames,
        accel_frames=args.accel_frames,
        clip_configs=[clip_config for _, clip_config in selected],
    )
    rig = CameraRigSpec.from_waymo(
        extrinsics_list=scheduler.extrinsics,
        intrinsics_list=scheduler.intrinsics,
        camera_names=camera_names,
    )
    object_processor = ObjectInfoProcessor(
        input_root=args.waymo_root,
        clip_id=clip_dir_name,
        num_frames=int(scheduler.num_frames),
    )
    # waymo_root is passed so inserted-object poses are normalized the same
    # way ObjectInfoProcessor does it. They then live in the same world frame
    # as the cameras and point clouds.
    insertion_info_provider = InsertionInfoProvider(
        clip_id=clip_dir_name,
        waymo_root=args.waymo_root,
    )
    insertion_info = insertion_info_provider.get_insertion_info()

    saved_videos: List[Path] = []
    output_root = Path(args.output)
    pointcloud_root = Path(args.pointcloud)
    clip_out_dir = output_root / clip_dir_name

    if args.visualize:
        for cam in camera_names:
            (clip_out_dir / cam).mkdir(parents=True, exist_ok=True)

    entry_iterator = tqdm(selected, desc=f"{clip_dir_name} Entries", position=1, leave=False)
    for entry, clip_config in entry_iterator:
        entry_index = int(entry.get('index'))

        # Resume support: a saved config means this entry is finished.
        if ConfigSaver.exists(output_root, clip_config.save_name):
            entry_iterator.set_postfix_str(f"skip {clip_config.save_name}")
            continue

        seg_idx = clip_config.seg_idx_of_entry
        seg_start = seg_idx * args.segment_len
        seg_end = min(scheduler.num_frames, seg_start + args.segment_len)

        rgb_video_frames: Dict[str, List[np.ndarray]] = {cam: [] for cam in camera_names}
        mask_video_frames: Dict[str, List[np.ndarray]] = {cam: [] for cam in camera_names}
        area_video_frames: Dict[str, List[np.ndarray]] = {cam: [] for cam in camera_names}
        missing_frames: List[int] = []

        for f in range(seg_start, seg_end):
            pc_path, _fmt = resolve_files(pointcloud_root, clip_dir_name, f)
            if pc_path is None:
                # Record every gap: skipped frames shorten the video and
                # introduce time jumps, so they are all reported below.
                missing_frames.append(f)
                continue

            override_cam_c2w: Dict[str, np.ndarray] = {}
            base_cam_c2w: Dict[str, np.ndarray] = {}
            for cam in camera_names:
                override_cam_c2w[cam] = scheduler.get_pose(cam, f, entry_index)
                base_cam_c2w[cam] = scheduler.get_base_pose(cam, f)

            rgb_frames_dict, mask_frames_dict, area_frames_dict = processor.process_file(
                pc_path,
                clip_out_dir,
                rig=rig,
                base_cam_c2w=base_cam_c2w,
                override_cam_c2w=override_cam_c2w,
                cams=camera_names,
                front_cam_name='front',
                frame_idx_int=f,
                object_processor=object_processor,
                clip_config=clip_config,
                insertion_info=insertion_info,
                visualize=args.visualize,
            )

            _append_camera_frames(rgb_video_frames, rgb_frames_dict)
            _append_camera_frames(mask_video_frames, mask_frames_dict)
            _append_camera_frames(area_video_frames, area_frames_dict)

        if missing_frames:
            print(
                f"[WARN] {clip_dir_name} 段 {seg_idx} 缺少 {len(missing_frames)} 帧点云"
                f"（首个缺失帧 {missing_frames[0]:06d}），已跳过这些帧"
            )

        saved_videos.extend(
            _save_entry_videos(
                output_root,
                clip_config,
                camera_names,
                rgb_video_frames,
                mask_video_frames,
                area_video_frames,
            )
        )

    return clip_dir_name, saved_videos


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the CLI parser for the validation-split re-projection script."""
    parser = argparse.ArgumentParser(description='点云重投影工具')
    parser.add_argument('--pointcloud', type=str, required=True, help='点云文件目录或单个点云文件路径')
    parser.add_argument('--output', type=str, required=True, help='输出目录')
    parser.add_argument('--cams', type=str, nargs='+', default=['all'], help='要处理的相机名列表')
    parser.add_argument('--seed', type=int, default=42, help='随机数种子')
    parser.add_argument('--segment_len', type=int, default=29, help='分段长度（帧）')
    parser.add_argument('--lane_change_frames', type=int, default=29, help='变道完成所需帧数')
    parser.add_argument('--accel_frames', type=int, default=29, help='加/减速完成所需帧数')
    parser.add_argument('--base_shift', type=float, default=2.0, help='单次变道横向位移（米）')
    parser.add_argument('--base_longitudinal', type=float, default=2.0, help='单次纵向加/减速位移（米）')
    parser.add_argument('--waymo_root', type=str, default='/data/dataset/waymo', help='Waymo 数据根目录')
    parser.add_argument('--visualize', action='store_true', default=False, help='是否保存可视化图片和视频')
    parser.add_argument('--use_ray', action='store_true', default=True, help='是否使用 Ray 进行并行处理')
    # --use_ray defaults to True, so it can never be switched off on its own.
    # --no_ray writes the same destination, which makes the sequential path
    # reachable from the command line.
    parser.add_argument('--no_ray', dest='use_ray', action='store_false', help='禁用 Ray，改为单进程顺序处理')
    parser.add_argument('--num_gpus', type=int, default=2, help='要使用的 GPU 总数')
    parser.add_argument('--parallel_factor', type=int, default=3, help='每个 GPU 的并行工作单元数量')
    # Previously hard-coded paths; the defaults keep the original behavior.
    parser.add_argument(
        '--ply_root', type=str,
        default='/data/dataset/validation_transfer/objects_ply_transformed_with_training',
        help='插入物体 PLY 根目录',
    )
    parser.add_argument(
        '--config_path', type=str,
        default='/data/dataset/waymo_transfer2/validation/final_info.json',
        help='clip 配置 JSON 文件路径',
    )
    return parser


def _group_entries_by_clip(cfg: List[Dict[str, Any]]) -> Dict[str, List[Dict[str, Any]]]:
    """Group config entries by clip directory name.

    The directory name is the entry's ``clip_name`` with its last
    ``_``-separated suffix removed (``"abc_3"`` -> ``"abc"``). Insertion
    order is preserved.
    """
    clip_groups: Dict[str, List[Dict[str, Any]]] = {}
    for entry in cfg:
        clip_name = entry.get('clip_name').rsplit('_', 1)[0]
        clip_groups.setdefault(clip_name, []).append(entry)
    return clip_groups


def _run_with_ray(
    args: argparse.Namespace,
    config: ProjectionConfig,
    ply_root: Path,
    clip_groups: Dict[str, List[Dict[str, Any]]],
) -> None:
    """Process every clip on a pool of fractional-GPU Ray actors.

    ``num_gpus * parallel_factor`` actors are created; each requests
    ``1 / parallel_factor`` GPU. Per-clip failures are reported and do not
    stop the remaining clips. Ray is always shut down on exit.

    Raises:
        ValueError: if ``num_gpus`` or ``parallel_factor`` is not positive.
    """
    print(f"启用 Ray 并行处理。GPU: {args.num_gpus}, 并行因子: {args.parallel_factor}x")

    # Validate before starting Ray so a bad configuration does not leave a
    # running Ray instance behind.
    if args.num_gpus <= 0 or args.parallel_factor <= 0:
        raise ValueError(
            f"GPU 数量 ({args.num_gpus}) 与并行因子 ({args.parallel_factor}) 必须为正数，无法创建 Actor。"
        )
    num_actors = int(args.num_gpus * args.parallel_factor)
    gpu_per_actor = args.num_gpus / num_actors

    import ray
    from ray.util import ActorPool

    ray.init(num_gpus=args.num_gpus)
    try:
        print(f"Ray 已启动，可用资源: {ray.cluster_resources()}")
        print(f"将启动 {num_actors} 个 Actor, 每个 Actor 请求 {gpu_per_actor:.2f} GPU 资源。")

        @ray.remote(num_gpus=gpu_per_actor)
        class GpuRunner:
            """Ray actor that runs all processing of one clip per call."""

            def __init__(self, args: argparse.Namespace, config: ProjectionConfig, ply_root: Path):
                # Hide GPUs from TensorFlow inside the actor process.
                import tensorflow
                tensorflow.config.set_visible_devices([], 'GPU')

                # Re-seed inside the actor process so it matches the driver.
                set_global_seed(args.seed)

                self.args = args
                self.config = config
                self.ply_root = ply_root

                self.projector = PointCloudProjector(self.config)
                self.processor = PointCloudProcessor(self.config, self.projector)

            def process_clip(self, clip_dir_name: str, entries: List[Dict[str, Any]]) -> Tuple[str, List[Path]]:
                return process_clip_logic(
                    self.args,
                    self.ply_root,
                    clip_dir_name,
                    entries,
                    self.processor,
                )

        actors = [GpuRunner.remote(args, config, ply_root) for _ in range(num_actors)]
        pool = ActorPool(actors)

        print(f"正在为 {len(clip_groups)} 个 clips 提交任务...")
        # ActorPool.submit returns None, so count submissions explicitly.
        num_tasks = 0
        for clip_dir_name, entries in clip_groups.items():
            pool.submit(
                lambda actor, clip_args: actor.process_clip.remote(*clip_args),
                (clip_dir_name, entries),
            )
            num_tasks += 1

        print(f"已提交所有 {num_tasks} 个任务，正在等待完成...")
        with tqdm(total=num_tasks, desc="完成 Clips") as main_pbar:
            for _ in range(num_tasks):
                try:
                    clip_name, video_files = pool.get_next_unordered()
                except Exception as e:
                    # One failed clip must not abort the remaining ones.
                    print(f"\n一个 clip 任务失败: {e}")
                else:
                    print(f"\nClip {clip_name} 已完成。共保存 {len(video_files)} 个视频。")
                main_pbar.update(1)
    finally:
        ray.shutdown()
    print("所有 Clip 任务已完成，Ray 已关闭。")


def _run_sequential(
    args: argparse.Namespace,
    config: ProjectionConfig,
    ply_root: Path,
    clip_groups: Dict[str, List[Dict[str, Any]]],
) -> None:
    """Process every clip one after another in the current process."""
    import tensorflow
    tensorflow.config.set_visible_devices([], 'GPU')

    # Re-seed after the TensorFlow setup; this mirrors the Ray actor init.
    set_global_seed(args.seed)

    projector = PointCloudProjector(config)
    processor = PointCloudProcessor(config, projector)

    main_pbar = tqdm(clip_groups.items(), desc="Clips")
    for clip_dir_name, entries in main_pbar:
        main_pbar.set_description(f"正在处理 {clip_dir_name}")
        clip_name, video_files = process_clip_logic(
            args,
            ply_root,
            clip_dir_name,
            entries,
            processor,
        )
        print(f"Clip {clip_name} 已完成。共保存 {len(video_files)} 个视频。")
    print("所有 Clip 任务已顺序完成。")


def main():
    """CLI entry point: re-project the validation split into per-camera videos.

    Paths are resolved under the ``validation`` split:

    * ``<output>/render/validation`` for the rendered videos;
    * ``<waymo_root>/validation`` for the Waymo data;
    * ``<pointcloud>/validation`` for the merged point clouds.

    Clip entries are read from ``--config_path`` and grouped per clip
    directory. The groups are then processed with Ray (default) or
    sequentially (``--no_ray``).
    """
    args = _build_arg_parser().parse_args()

    # Seed once up front so results do not depend on whether Ray is used.
    set_global_seed(args.seed)

    ply_root = Path(args.ply_root)
    base_output = Path(args.output)
    args.output = base_output / 'render' / 'validation'
    args.waymo_root = Path(args.waymo_root) / 'validation'
    args.pointcloud = Path(args.pointcloud) / 'validation'
    # Derived from the user-supplied --output (not the render sub-directory),
    # so it sits next to render/validation.
    args.clip_config_save_path = base_output / 'clip_config' / 'validation'
    args.config_path = Path(args.config_path)

    config = ProjectionConfig(projection_resolution=(720, 1280))

    args.output.mkdir(parents=True, exist_ok=True)

    with open(args.config_path, 'r', encoding='utf-8') as fcfg:
        cfg = json.load(fcfg)

    clip_groups = _group_entries_by_clip(cfg)

    if args.use_ray:
        _run_with_ray(args, config, ply_root, clip_groups)
    else:
        _run_sequential(args, config, ply_root, clip_groups)


if __name__ == '__main__':
    # Script entry point: parse the CLI arguments and run the rendering pipeline.
    main()


