from pathlib import Path
from typing import Any, List

import cv2
import numpy as np
import imageio as imageio_v1

# 复用现有的高效 I/O 实现
from demo_waymo import (
    load_point_cloud_efficient,
    save_point_cloud_efficient,
    POINT_CLOUD_STORAGE_FORMAT,
)


class ImageSaver:
    """Write single frames (RGB images and masks) to disk with OpenCV.

    File names are ``str(frame_id)`` left-padded with zeros to at least three
    characters, e.g. ``frame_id=7`` -> ``007.jpg`` / ``007_m.png``.
    """

    @staticmethod
    def _imwrite_checked(path: Path, image: np.ndarray) -> None:
        """Write ``image`` to ``path`` and raise if OpenCV reports failure.

        ``cv2.imwrite`` signals failure by returning ``False`` instead of
        raising, which would otherwise lose frames silently.

        Raises:
            OSError: If ``cv2.imwrite`` returns a falsy value.
        """
        if not cv2.imwrite(str(path), image):
            raise OSError(f"cv2.imwrite failed to write {path}")

    @staticmethod
    def save_rgb_jpg(output_dir: Path, rgb_image: np.ndarray, frame_id: str) -> None:
        """Save an RGB image as ``<output_dir>/<frame_id zero-padded>.jpg``.

        Args:
            output_dir: Target directory; created (with parents) if missing.
            rgb_image: Image in RGB channel order. It is converted to BGR
                because OpenCV expects BGR order when encoding.
            frame_id: Frame identifier; ``str()``-ed and zero-padded to 3 chars.

        Raises:
            OSError: If OpenCV fails to write the file.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        ImageSaver._imwrite_checked(
            output_dir / f"{str(frame_id).zfill(3)}.jpg",
            cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR),
        )

    @staticmethod
    def save_mask_png(output_dir: Path, mask_image: np.ndarray, frame_id: str) -> None:
        """Save a mask as ``<output_dir>/<frame_id zero-padded>_m.png``.

        Normalization rules:
          * non-uint8 masks (bool, float, other ints) are binarized:
            values > 0 become 255, everything else 0;
          * uint8 masks whose maximum is <= 1 (0/1 masks) are scaled by 255;
          * other uint8 masks are written unchanged.

        Args:
            output_dir: Target directory; created (with parents) if missing.
            mask_image: Mask array.
            frame_id: Frame identifier; ``str()``-ed and zero-padded to 3 chars.

        Raises:
            OSError: If OpenCV fails to write the file.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        if mask_image.dtype != np.uint8:
            normalized = (mask_image > 0).astype(np.uint8) * 255
        else:
            max_val = int(mask_image.max()) if mask_image.size > 0 else 0
            # A 0/1 uint8 mask would be nearly invisible as a PNG; stretch it.
            normalized = mask_image if max_val > 1 else (mask_image * 255).astype(np.uint8)
        ImageSaver._imwrite_checked(
            output_dir / f"{str(frame_id).zfill(3)}_m.png",
            normalized,
        )


class VideoSaver:
    """Encode lists of frames to H.264 video files via imageio's ffmpeg writer."""

    @staticmethod
    def _open_writer(video_file: Path, fps: int, crf: int):
        """Open an imageio ffmpeg writer with the shared H.264 settings.

        NOTE(review): libx264 with ``yuv420p`` generally needs even frame
        width/height; with ``macro_block_size=None`` odd sizes are not padded
        — confirm callers supply even dimensions.
        """
        return imageio_v1.get_writer(
            video_file,
            format="ffmpeg",
            fps=fps,
            codec="libx264",
            # Keep the original frame size; by default imageio rescales frames
            # so width/height are multiples of 16.
            macro_block_size=None,
            ffmpeg_params=[
                "-crf", str(crf),       # 0-51, lower means higher quality
                "-preset", "slow",      # slower encode, better compression/quality
                "-pix_fmt", "yuv420p",  # widely compatible pixel format
            ],
        )

    @staticmethod
    def _to_rgb_frame(frame: np.ndarray) -> np.ndarray:
        """Coerce one frame to uint8 with three channels.

        Non-uint8 values are clipped to [0, 255] before the cast. 2-D frames
        are replicated to 3 channels, and a 4th (alpha) channel is dropped.
        NOTE(review): float frames on a [0, 1] scale become near-black here;
        callers are expected to pass 0-255 values.
        """
        if frame.dtype != np.uint8:
            frame = np.clip(frame, 0, 255).astype(np.uint8)
        if frame.ndim == 2:
            frame = np.stack([frame] * 3, axis=-1)
        elif frame.shape[-1] == 4:
            frame = frame[..., :3]
        return frame

    @staticmethod
    def _normalize_mask(mask: np.ndarray) -> np.ndarray:
        """Map a mask to uint8 in [0, 255] and expand 2-D masks to 3 channels.

        * bool masks: True -> 255, False -> 0.
        * other masks with ``int(max) > 1`` are treated as already 0-255 and
          saturated into that range.
        * otherwise (0/1 ints, [0, 1] floats) values are scaled by 255.
        """
        if mask.dtype == bool:
            normalized = mask.astype(np.uint8) * 255
        else:
            max_val = int(mask.max()) if mask.size > 0 else 0
            if mask.dtype == np.uint8:
                normalized = mask if max_val > 1 else (mask * 255).astype(np.uint8)
            elif max_val > 1:
                # Clip first: a bare astype would wrap out-of-range values
                # (e.g. label 256 -> 0, -1 -> 255).
                normalized = np.clip(mask, 0, 255).astype(np.uint8)
            else:
                # Clip so negative values cannot wrap around when cast.
                normalized = np.clip(mask * 255, 0, 255).astype(np.uint8)

        # The video encoder is fed (H, W, 3) frames.
        if normalized.ndim == 2:
            normalized = np.stack([normalized] * 3, axis=-1)
        return normalized

    @staticmethod
    def save_rgb_video(
        video_file: Path,
        rgb_frames: List[np.ndarray],
        fps: int = 10,
        crf: int = 18,
    ) -> None:
        """Save RGB frames as an H.264 video.

        Args:
            video_file: Output video path; its parent directory is created.
            rgb_frames: Frames, expected as (H, W, 3) uint8 arrays. 2-D, RGBA
                and non-uint8 frames are coerced (see ``_to_rgb_frame``).
            fps: Frame rate.
            crf: x264 CRF value (0-51, lower means higher quality).

        Does nothing when ``rgb_frames`` is empty.
        """
        if len(rgb_frames) == 0:
            return

        video_file.parent.mkdir(parents=True, exist_ok=True)

        writer = VideoSaver._open_writer(video_file, fps, crf)
        try:
            for frame in rgb_frames:
                writer.append_data(VideoSaver._to_rgb_frame(frame))
        finally:
            # Release the writer even if a frame fails to encode.
            writer.close()

    @staticmethod
    def save_mask_video(
        video_file: Path,
        mask_frames: List[np.ndarray],
        fps: int = 10,
        crf: int = 18,
    ) -> None:
        """Save masks as a grayscale-looking (3 equal channels) H.264 video.

        Args:
            video_file: Output video path; its parent directory is created.
            mask_frames: Masks, each an (H, W) bool or numeric array; see
                ``_normalize_mask`` for how values map to 0-255.
            fps: Frame rate.
            crf: x264 CRF value (0-51, lower means higher quality).

        Does nothing when ``mask_frames`` is empty. Frames are normalized
        one at a time while writing, rather than all up front.
        """
        if len(mask_frames) == 0:
            return

        video_file.parent.mkdir(parents=True, exist_ok=True)

        writer = VideoSaver._open_writer(video_file, fps, crf)
        try:
            for mask in mask_frames:
                writer.append_data(VideoSaver._normalize_mask(mask))
        finally:
            # Release the writer even if a frame fails to encode.
            writer.close()

    @staticmethod
    def save_rgb_and_mask_videos(
        rgb_video_file: Path,
        mask_video_file: Path,
        rgb_frames: List[np.ndarray],
        mask_frames: List[np.ndarray],
        fps: int = 10,
        crf: int = 18,
    ) -> None:
        """Save the RGB video and then the mask video with shared settings.

        Args:
            rgb_video_file: Output path for the RGB video.
            mask_video_file: Output path for the mask video.
            rgb_frames: RGB frames (see ``save_rgb_video``).
            mask_frames: Mask frames (see ``save_mask_video``).
            fps: Frame rate for both videos.
            crf: x264 CRF value for both videos.
        """
        VideoSaver.save_rgb_video(rgb_video_file, rgb_frames, fps, crf)
        VideoSaver.save_mask_video(mask_video_file, mask_frames, fps, crf)


class ConfigSaver:
    """Persist per-clip configuration as JSON under ``<output_root>/clip_configs``."""

    @staticmethod
    def _config_dir(output_root: Path) -> Path:
        """Return the directory that holds all clip config files."""
        return output_root / "clip_configs"

    @staticmethod
    def get_config_path(output_root: Path, save_name: str) -> Path:
        """Return the JSON path for ``save_name``; the file may not exist yet."""
        return ConfigSaver._config_dir(output_root) / f"{save_name}.json"

    @staticmethod
    def exists(output_root: Path, save_name: str) -> bool:
        """Return True if a config file for ``save_name`` is already on disk."""
        return ConfigSaver.get_config_path(output_root, save_name).exists()

    @staticmethod
    def save(output_root: Path, clip_config: Any) -> Path:
        """Save one entry's configuration into the ``clip_configs`` directory.

        Args:
            output_root: Root output directory; ``clip_configs`` is created
                under it if missing.
            clip_config: Object exposing the attributes listed below.
                ``objects_to_add_info`` must be an iterable of 2-item pairs.

        Returns:
            Path of the written ``<save_name>.json`` file.

        Raises:
            AttributeError: If ``clip_config`` lacks one of the fields.
            TypeError: If a field value is not JSON-serializable.
        """
        import json

        cfg_dir = ConfigSaver._config_dir(output_root)
        cfg_dir.mkdir(parents=True, exist_ok=True)

        # Field names and order must stay identical to the existing file format.
        config_dict = dict(
            do_back_project=clip_config.do_back_project,
            object_ids_to_remove=clip_config.object_ids_to_remove,
            # Stringify ids and paths (e.g. Path objects) so they serialize.
            objects_to_add_info=[
                (str(oid), str(pth)) for oid, pth in clip_config.objects_to_add_info
            ],
            local_base_shift=clip_config.local_base_shift,
            local_base_longitudinal=clip_config.local_base_longitudinal,
            action_for_seg=clip_config.action_for_seg,
            seg_idx_of_entry=clip_config.seg_idx_of_entry,
            entry_index=clip_config.entry_index,
            save_name=clip_config.save_name,
            use_fixed_motion=clip_config.use_fixed_motion,
        )

        save_path = ConfigSaver.get_config_path(output_root, config_dict["save_name"])
        # ensure_ascii=False emits non-ASCII characters verbatim, so pin the
        # file encoding instead of relying on the platform's locale default.
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(config_dict, f, indent=2, ensure_ascii=False)
        return save_path
