# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import pickle
import traceback
import warnings
from typing import Dict, Optional
from moviepy.editor import ImageSequenceClip, vfx

import numpy as np
import torch
from decord import VideoReader, cpu
from concurrent.futures import ThreadPoolExecutor
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms as T
from torchvision.transforms import functional as transforms_F
from tqdm import tqdm
from einops import rearrange

from cosmos_transfer2._src.imaginaire.lazy_config import LazyCall as L
from cosmos_transfer2._src.imaginaire.lazy_config import instantiate
from cosmos_transfer2._src.imaginaire.modules.input_handling.utils import detect_aspect_ratio
from cosmos_transfer2._src.imaginaire.utils import log
from cosmos_transfer2._src.predict2.datasets.local_datasets.dataset_utils import (
    ResizePreprocess,
    ToTensorVideo,
)

# Mapping from control type to how its data is stored under the dataset folder:
#   "folder"        - sub-folder name under dataset_dir; None means the control is computed
#                     on the fly and nothing is read from disk
#   "format"        - on-disk format of the stored control data ("mp4" or "pickle")
#   "data_dict_key" - presumably the key used for this control in the sample dict (confirm with consumers)
# NOTE(review): WaymoMultiviewDataset below builds its control paths
# ("control_input_hdmap_bbox", "control_input_sparse") directly instead of using this table.
CTRL_TYPE_INFO = {
    "keypoint": {"folder": "keypoint", "format": "pickle", "data_dict_key": "keypoint"},
    "depth": {"folder": "depth", "format": "mp4", "data_dict_key": "depth"},
    "lidar": {"folder": "lidar", "format": "mp4", "data_dict_key": "lidar"},
    "hdmap_bbox": {"folder": "control_input_hdmap_bbox", "format": "mp4", "data_dict_key": "hdmap"},
    "seg": {"folder": "seg", "format": "pickle", "data_dict_key": "segmentation"},
    "edge": {"folder": None},  # Canny edge, computed on-the-fly
    "vis": {"folder": None},  # Blur, computed on-the-fly
    "upscale": {"folder": None},  # Computed on-the-fly
}

# Fallback caption used by WaymoMultiviewDataset when a clip has no caption JSON,
# or when its "caption" field is missing or empty.
AUTO_MV_DEFAULT_PROMPT = 'This multi-camera perspective captures a drive along a multi-lane urban freeway during the daytime under a hazy or partly cloudy sky. The vehicle travels in one of the right lanes, flanked on one side by a high retaining wall featuring a concrete base and a brown, brick-patterned upper section with some climbing vines, and on the other side by a concrete median barrier. As the car moves forward, it approaches and passes under a large concrete overpass, which frames a view of a distant downtown city skyline with numerous high-rise buildings. A green freeway sign for the "Hill St / Grand Ave" exit is briefly visible, and another overpass, distinguished by its overhead catenary power lines suggesting a light rail system, also crosses the roadway. The flow of traffic is moderate, with the camera vehicle sharing the road with other cars, including a white Dodge Grand Caravan minivan, a dark-colored SUV, and a black sedan, which are visible at various points ahead and behind. The asphalt road surface shows some visible cracks and wear, contributing to the overall scene of a typical day on a major metropolitan highway.'


class WaymoMultiviewDataset(Dataset):
    """
    Waymo variant of the multiview local dataset.

    Samples are described by ``<dataset_dir>/final_info.json``. For each camera in
    ``camera_keys`` a sample loads, with the same frame indices:

    * the RGB clip from ``videos/<camera>/<clip_name>.mp4``,
    * the sparse control clip and its mask from ``control_input_sparse/<camera>/``,
    * the HD-map/bbox control clip from ``control_input_hdmap_bbox/<camera>/``,
    * for non-``sparse`` samples, an area mask (``*_area.mp4``) used to blank the reference.

    A "reference" stream is derived from the RGB clip, the area mask and the control clip,
    then warped according to ``trajectory["action_for_seg"]``: ``left``/``right`` apply
    ``change_shift`` and ``up``/``down`` apply ``change_rate``. Per-view tensors are
    concatenated along dim 1 (time).

    This is a direct copy of MultiviewTransferDataset to serve as a starting point
    for integrating Waymo-style data layout. Update paths/layouts as needed.
    """

    # Upper bound on random re-draws made by ``__getitem__`` after a failed load before
    # giving up with a RuntimeError.
    _MAX_LOAD_RETRIES = 100

    def __init__(
        self,
        dataset_dir,
        num_frames,
        resolution,
        video_size,
        camera_keys,
        front_camera_key=None,
        camera_to_view_id: Optional[Dict] = None,
        sequence_interval=1,
        state_t=8,
        front_view_caption_only=True,
        is_train=True,
        val_percent=0.02,
        **kwargs,
    ):
        """
        Args:
            dataset_dir: Root folder containing ``final_info.json``, ``videos/``, ``captions/``,
                ``control_input_hdmap_bbox/`` and ``control_input_sparse/``.
            num_frames: Number of frames sampled per view.
            resolution: Stored as-is; not used by the loading code of this class.
            video_size: ``(H, W)`` target size passed to ``ResizePreprocess``.
            camera_keys: Ordered camera folder names; the order defines the view order of the output.
            front_camera_key: Camera used to sample frame indices and to locate captions
                (defaults to ``camera_keys[0]``).
            camera_to_view_id: Mapping camera key -> view id, used to build ``view_indices``.
            sequence_interval: Stride between sampled frames.
            state_t: Number of latent frames per view, used for ``latent_view_indices_B_T``.
            front_view_caption_only: Stored as-is; not used by the loading code of this class.
            is_train: Select the training split (otherwise validation).
            val_percent: Fraction of each method's samples held out (from the tail) in training.
            **kwargs: Ignored (lets configs pass extra keys such as ``hint_key``).
        """
        super().__init__()
        self.dataset_dir = dataset_dir
        self.sequence_length = num_frames
        self.resolution = resolution
        self.camera_keys = camera_keys
        self.camera_to_view_id = camera_to_view_id
        self.front_camera_key = front_camera_key
        self.sequence_interval = sequence_interval
        self.state_t = state_t
        self.front_view_caption_only = front_view_caption_only
        self.is_train = is_train
        self.H, self.W = video_size

        # Captions are looked up only under the front camera's folder.
        self.captions_dir = os.path.join(self.dataset_dir, "captions", self.front_camera_key or self.camera_keys[0])

        data_config_path = os.path.join(self.dataset_dir, "final_info.json")
        with open(data_config_path, "r", encoding="utf-8") as f:
            all_data_info = json.load(f)

        self.data_info, num_replacement, num_sparse = self._split_data_info(all_data_info, is_train, val_percent)
        log.info(f"Loaded {len(all_data_info)} total infos. Using {len(self.data_info)} for {'training' if is_train else 'validation'}.")
        log.info(f"Split: {num_replacement} object samples, {num_sparse} sparse samples.")

        self.num_failed_loads = 0
        self.preprocess = T.Compose([ToTensorVideo(), ResizePreprocess((video_size[0], video_size[1]))])

    @staticmethod
    def _split_data_info(all_data_info, is_train, val_percent):
        """
        Select the manifest entries for the requested split.

        Training keeps ``method == "replacement"`` and ``method == "sparse"`` entries, each
        with its last ``int(len * val_percent)`` items held out. Validation keeps every
        ``method == "object"`` and ``method == "sparse"`` entry.

        Returns:
            ``(data_info, num_replacement_or_object, num_sparse)``.
        """

        def _drop_tail(samples):
            # Slice by an explicit stop index: ``samples[:-0]`` would be empty, which used to
            # wipe the whole training split whenever the hold-out rounded down to zero.
            cutoff = int(len(samples) * val_percent)
            return samples[: max(len(samples) - cutoff, 0)]

        sparse_samples = [item for item in all_data_info if item.get("method") == "sparse"]
        if is_train:
            replacement_samples = _drop_tail(
                [item for item in all_data_info if item.get("method") == "replacement"]
            )
            sparse_samples = _drop_tail(sparse_samples)
        else:
            # NOTE(review): validation selects "object" where training selects "replacement";
            # confirm this asymmetry matches the validation manifest.
            replacement_samples = [item for item in all_data_info if item.get("method") == "object"]
        return replacement_samples + sparse_samples, len(replacement_samples), len(sparse_samples)

    def __str__(self) -> str:
        return (f"{len(self.data_info)} samples for {'training' if self.is_train else 'validation'} "
                f"from {self.dataset_dir}")

    def __len__(self):
        return len(self.data_info)

    def _sample_frames(self, video_path):
        """
        Pick a random window of ``video_path`` and decode it.

        ``sequence_length`` frames are taken every ``sequence_interval`` frames, so the window
        spans ``(sequence_length - 1) * sequence_interval + 1`` frames.
        Adapted from ExampleTransferDataset to support sequence_interval.

        Args:
            video_path: Path to video file.

        Returns:
            ``(frames, frame_ids, fps)``: the uint8 frames decoded by decord, the sampled
            frame indices and the average FPS (10 if decord cannot report it).
            ``(None, None, None)`` if the video cannot be opened or is too short.
        """
        try:
            vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
            n_frames = len(vr)
        except Exception as e:
            warnings.warn(f"Failed to open video {video_path}: {e}")
            return None, None, None

        total_frame_span = (self.sequence_length - 1) * self.sequence_interval + 1
        max_start_idx = n_frames - total_frame_span

        if max_start_idx < 0:  # Video is too short
            warnings.warn(f"Video is too short ({n_frames} frames) for {self.sequence_length} frames with interval {self.sequence_interval}: {video_path}")
            return None, None, None
        start_frame = np.random.randint(0, max_start_idx + 1)
        frame_ids = [start_frame + i * self.sequence_interval for i in range(self.sequence_length)]

        frames = vr.get_batch(frame_ids).asnumpy().astype(np.uint8)
        try:
            fps = vr.get_avg_fps()
        except Exception:  # failed to read FPS
            fps = 10  # Waymo default

        return frames, frame_ids, fps

    def _load_video(self, video_path, frame_ids):
        """
        Decode ``frame_ids`` from ``video_path``.

        Returns:
            ``(frame_data, fps)``; ``fps`` falls back to 24 if decord cannot report it.

        Raises:
            ValueError: if any frame index is outside ``[0, len(video))``.
        """
        vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
        ids = np.asarray(frame_ids)
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if (ids < 0).any() or (ids >= len(vr)).any():
            raise ValueError(f"Frame ids out of range [0, {len(vr)}) for {video_path}")
        vr.seek(0)
        frame_data = vr.get_batch(frame_ids).asnumpy()
        try:
            fps = vr.get_avg_fps()
        except Exception:  # failed to read FPS
            fps = 24
        return frame_data, fps

    def _load_frames_preprocessed(self, video_path, frame_ids):
        """
        Decode ``frame_ids`` from ``video_path`` and run ``self.preprocess`` on them.

        Returns:
            ``(video, fps)`` where ``video`` is a uint8 tensor laid out as (C, T, H, W).
        """
        frames_np, fps = self._load_video(video_path, frame_ids)
        frames_t = torch.from_numpy(frames_np.astype(np.uint8)).permute(0, 3, 1, 2)
        frames_t = self.preprocess(frames_t)
        frames_t = torch.clamp(frames_t * 255.0, 0, 255).to(torch.uint8)
        video = frames_t.permute(1, 0, 2, 3)  # C, T, H, W
        return video, fps

    def _load_caption(self, video_name):
        """Return the front-camera caption for ``video_name``, or ``AUTO_MV_DEFAULT_PROMPT``."""
        caption_path = os.path.join(self.captions_dir, f"{video_name}.json")
        if os.path.exists(caption_path):
            with open(caption_path, "r") as f:
                metadata = json.load(f)
            if "caption" in metadata and len(metadata["caption"]) > 0:
                return metadata["caption"]
        return AUTO_MV_DEFAULT_PROMPT

    def __getitem__(self, index):
        """
        Load sample ``index``; on failure, log it and retry with random indices.

        Raises:
            RuntimeError: if the dataset is empty, or if ``_MAX_LOAD_RETRIES`` random
                retries all fail.
        """
        attempt_index = index
        for attempt in range(self._MAX_LOAD_RETRIES + 1):
            try:
                return self._load_sample(attempt_index)
            except Exception as e:
                self.num_failed_loads += 1
                log.warning(
                    f"Failed to load data for index {attempt_index} (total failures: {self.num_failed_loads}): {e}\n"
                    f"{traceback.format_exc()}",
                    rank0_only=False,
                )
                if len(self) == 0:
                    raise RuntimeError("Failed to load data and dataset is empty.") from e
                if attempt == self._MAX_LOAD_RETRIES:
                    raise RuntimeError(
                        f"Failed to load data after {self._MAX_LOAD_RETRIES} retries (last index {attempt_index})."
                    ) from e
                # Retry with a random index
                attempt_index = np.random.randint(len(self))

    def _load_sample(self, index):
        """
        Build the multi-view sample for ``self.data_info[index]`` (no retry handling).

        Raises:
            Any error from decoding, missing files or malformed manifest entries.
        """
        info: dict = self.data_info[index]
        method = info.get("method")
        video_name = info["clip_name"]
        video_base_name, clip_id = video_name.rsplit("_", maxsplit=1)
        next_video_name = f"{video_base_name}_{int(clip_id) + 1}"
        global_ids = info["index"]
        traj_cfg = info["trajectory"]
        action = traj_cfg["action_for_seg"]

        # Sample frame indices once from the front camera; every view reuses them.
        front_camera_key = self.front_camera_key or self.camera_keys[0]
        front_video_path = os.path.join(self.dataset_dir, "videos", front_camera_key, f"{video_name}.mp4")
        frames_np, frame_ids, fps = self._sample_frames(front_video_path)
        if frames_np is None:
            raise ValueError(f"Failed to sample frames from {front_video_path}")

        # Identical for every view, so computed once.
        aspect_ratio = detect_aspect_ratio((self.W, self.H))
        ai_caption = self._load_caption(video_name)

        videos_list, sparse_video_list, sparse_mask_list, sparse_ctrl_list = [], [], [], []
        reference_video_list, reference_mask_list, reference_ctrl_list = [], [], []

        for camera_ids, camera_key in enumerate(self.camera_keys):
            video_dir = os.path.join(self.dataset_dir, "videos", camera_key)
            ctrl_dir = os.path.join(self.dataset_dir, "control_input_hdmap_bbox", camera_key)
            sparse_dir = os.path.join(self.dataset_dir, "control_input_sparse", camera_key)

            video, _ = self._load_frames_preprocessed(os.path.join(video_dir, f"{video_name}.mp4"), frame_ids)
            sparse_video, _ = self._load_frames_preprocessed(
                os.path.join(sparse_dir, f"{video_name}_{global_ids}.mp4"), frame_ids
            )
            sparse_mask, _ = self._load_frames_preprocessed(
                os.path.join(sparse_dir, f"{video_name}_{global_ids}_mask.mp4"), frame_ids
            )
            ctrl, _ = self._load_frames_preprocessed(os.path.join(ctrl_dir, f"{video_name}.mp4"), frame_ids)
            if method != "sparse":
                video_mask, _ = self._load_frames_preprocessed(
                    os.path.join(sparse_dir, f"{video_name}_{global_ids}_area.mp4"), frame_ids
                )
            else:
                video_mask = torch.zeros_like(video)

            if action == "up":
                # The next clip is appended (reference is 2x long) before change_rate re-times it.
                next_video, _ = self._load_frames_preprocessed(
                    os.path.join(video_dir, f"{next_video_name}.mp4"), frame_ids
                )
                next_ctrl, _ = self._load_frames_preprocessed(
                    os.path.join(ctrl_dir, f"{next_video_name}.mp4"), frame_ids
                )
                reference_video = torch.cat([video, next_video], dim=1)
                reference_mask = torch.cat([video_mask, torch.zeros_like(next_video)], dim=1)
                reference_ctrl = torch.cat([ctrl, next_ctrl], dim=1)
            else:
                reference_video = video.clone()
                reference_mask = video_mask.clone()
                reference_ctrl = ctrl.clone()

            # Zero the reference pixels covered by the (uint8) area mask.
            reference_video[reference_mask > 235] = 0

            if action in ("left", "right"):
                use_fixed_motion = traj_cfg["use_fixed_motion"]
                longitudinal = traj_cfg["longitudinal"] * (-1 if action == "left" else 1)
                # NOTE(review): the view id passed to change_shift is the position in
                # camera_keys (0 = front, 1/2 = front diagonals, 3/4 = sides in change_shift).
                reference_video = change_shift(reference_video, longitudinal, camera_ids, use_fixed_motion)
                reference_mask = change_shift(reference_mask, longitudinal, camera_ids, use_fixed_motion)
                reference_ctrl = change_shift(reference_ctrl, longitudinal, camera_ids, use_fixed_motion)
            elif action in ("up", "down"):
                rate = 1.0 + 0.1 * traj_cfg["shift"] * (1 if action == "up" else -1)
                reference_video = change_rate(reference_video, rate, len(frame_ids), fps)
                reference_mask = change_rate(reference_mask, rate, len(frame_ids), fps)
                reference_ctrl = change_rate(reference_ctrl, rate, len(frame_ids), fps)

            videos_list.append(video)
            sparse_video_list.append(sparse_video)
            sparse_mask_list.append(sparse_mask)
            sparse_ctrl_list.append(ctrl)
            reference_video_list.append(reference_video)
            reference_mask_list.append(reference_mask)
            reference_ctrl_list.append(reference_ctrl)

        view_indices = torch.tensor([self.camera_to_view_id[key] for key in self.camera_keys])
        final_data = {
            "video": torch.cat(videos_list, dim=1),
            "control_input_sparse_video": torch.cat(sparse_video_list, dim=1),
            "control_input_sparse_mask": torch.cat(sparse_mask_list, dim=1).to(torch.float32) / 255.0,
            "control_input_sparse_ctrl": torch.cat(sparse_ctrl_list, dim=1),
            "control_input_reference_video": torch.cat(reference_video_list, dim=1),
            "control_input_reference_mask": torch.cat(reference_mask_list, dim=1).to(torch.float32) / 255.0,
            "control_input_reference_ctrl": torch.cat(reference_ctrl_list, dim=1),
            "image_size": torch.tensor([self.H, self.W, self.H, self.W]),
            "fps": fps,
            "sample_n_views": len(self.camera_keys),
            "num_video_frames_per_view": self.sequence_length,
            "index": global_ids,
            "view_indices": view_indices.repeat_interleave(self.sequence_length).contiguous(),
            "latent_view_indices_B_T": view_indices.repeat_interleave(self.state_t).contiguous(),
            "video_name": {"video_path": front_video_path},
            "aspect_ratio": aspect_ratio,
            "ai_caption": ai_caption,
            "padding_mask": torch.zeros(1, self.H, self.W),
            "ref_cam_view_idx_sample_position": -1,
            "front_cam_view_idx_sample_position": torch.tensor([0]),
        }
        return final_data


def change_rate(video: torch.Tensor, rate: float, n_frames: int, fps: float) -> torch.Tensor:
    """
    Re-time a (C, T, H, W) clip by ``rate`` and return exactly ``n_frames`` frames.

    The clip is wrapped in a moviepy ``ImageSequenceClip`` at ``fps`` and sped up
    (``rate > 1``) or slowed down (``rate < 1``) with ``vfx.speedx``. The first ``n_frames``
    re-timed frames are kept. If fewer are produced (e.g. because of duration rounding), the
    last frame is repeated, so every caller gets the same number of frames.

    Args:
        video: Clip laid out as (C, T, H, W); C is presumably 3, since moviepy expects RGB frames.
        rate: Playback speed factor; must be positive.
        n_frames: Number of frames to return.
        fps: Frame rate used to build and sample the clip.

    Returns:
        Tensor laid out as (C, n_frames, H, W) on ``video.device``.

    Raises:
        ValueError: if ``rate`` is not positive or the re-timed clip yields no frames.
    """
    from itertools import islice

    if rate <= 0:
        raise ValueError(f"change_rate expects a positive rate, got {rate}")

    device = video.device
    frames_thwc = video.permute(1, 2, 3, 0).contiguous().cpu().numpy()
    clip_mpy = ImageSequenceClip(list(frames_thwc), fps=fps)
    new_clip = clip_mpy.fx(vfx.speedx, rate)
    # Only decode the frames that are actually kept.
    frames = np.array(list(islice(new_clip.iter_frames(), n_frames)))
    if len(frames) == 0:
        raise ValueError("change_rate produced no frames")
    if len(frames) < n_frames:
        # Hold the last frame so the output length never silently shrinks.
        pad = np.repeat(frames[-1:], n_frames - len(frames), axis=0)
        frames = np.concatenate([frames, pad], axis=0)
    out = torch.from_numpy(frames).permute(3, 0, 1, 2).to(device).contiguous()
    return out


def _compute_lane_change_curves(num_frames: int, total_shift: float, lane_change_frames: int):
    """
    Build per-frame lateral-shift and yaw curves for a lane change.

    Progress is ``p = clip(t / max(1, int(lane_change_frames)), 0, 1)`` for
    ``t = 0 .. num_frames - 1``.

    * shift: cosine ease from 0 to ``total_shift``, held at ``total_shift`` once ``p >= 1``.
    * yaw (radians): ``peak * sin(pi * p)``, forced to 0 once ``p >= 1``. ``peak`` is
      ``2`` degrees per unit of ``|total_shift|``, capped at 5 degrees.

    Returns:
        ``(shift_curve_m, yaw_curve_rad)`` as float32 arrays of length ``num_frames``.
    """
    ramp_len = max(1, int(lane_change_frames))
    progress = np.clip(np.arange(num_frames) / ramp_len, 0.0, 1.0)
    finished = progress >= 1.0

    # Cosine ease towards total_shift, then hold it.
    shift = total_shift * 0.5 * (1 - np.cos(np.pi * progress))
    shift[finished] = total_shift

    # Yaw amplitude in degrees grows linearly with the total shift up to a 5-degree cap.
    peak_rad = np.deg2rad(min(5.0, 2.0 * abs(total_shift)))
    yaw = peak_rad * np.sin(np.pi * progress)
    yaw[finished] = 0.0

    return shift.astype(np.float32), yaw.astype(np.float32)


def change_shift(video: torch.Tensor, shift_val: float, view_ids: int, use_fixed_motion: bool = False) -> torch.Tensor:
    """
    Warp a clip to mimic a lateral camera shift, with a warp that depends on the view.

    The horizontal shift in pixels is ``dx = shift_val * px_per_meter``. Per ``view_ids``:

    * ``0``: affine translate by ``-dx`` plus a rotation/shear whose sign follows the shift.
    * ``1`` / ``2``: affine translate by ``-0.7 * dx`` with a larger rotation/shear.
    * ``3`` / ``4``: perspective warp that moves the top edge by ``-k * W`` and the bottom
      edge by ``-0.4 * k * W`` (``k = persp_strength_side * dx / W``, x clamped to the
      image), followed by an affine translate by ``-0.3 * dx``.
    * anything else: the clip is returned unwarped (only the final dtype handling applies).

    For each frame the affine ``scale`` is searched by bisection in ``[1, max_scale_cap]``
    for the smallest value whose inverse mapping keeps all four output corners inside the
    source image (presumably to avoid exposing the zero fill at the borders). If that
    cannot be met, the result ends at ``max_scale_cap``. Small horizontal translation
    tweaks (+/-15% of the translation, clamped to 5% of W) are also tried, and the candidate
    with the smallest scale wins.

    Args:
        video: Clip unpacked as (C, T, H, W). uint8 input is warped in float and rounded
            back to uint8.
        shift_val: Lateral shift, converted to pixels with ``px_per_meter``.
        view_ids: View selector (see above).
        use_fixed_motion: If True, the full shift is applied to every frame. Otherwise the
            translation ramps per frame via ``_compute_lane_change_curves``, while
            rotation/shear use the sign of the total shift on every frame.

    Returns:
        The warped clip, same shape and dtype as ``video``.
    """
    # Tuning constants: pixels per unit of shift_val, rotation/shear in degrees,
    # side-view perspective strength, and the upper bound for the corner-covering scale.
    px_per_meter = 30.0
    shear_deg_front = 3.0
    rot_deg_front = 2.0
    shear_deg_diag = 4.0
    rot_deg_diag = 3.0
    persp_strength_side = 0.06
    max_scale_cap = 1.5

    def _apply_affine_to_clip(clip, angles, translates, scales, shears, interpolation=transforms_F.InterpolationMode.BILINEAR):
        """Apply torchvision ``affine`` frame by frame with per-frame parameters (zero fill)."""
        C, T, H, W = clip.shape
        assert len(angles) == T
        assert len(translates) == T
        assert len(scales) == T
        assert len(shears) == T
        outs = []
        for t in range(T):
            frame = clip[:, t]
            orig_dtype = frame.dtype
            # Interpolate in float; integer frames are converted back below.
            if not frame.is_floating_point():
                frame = frame.float()
            out_f = transforms_F.affine(
                frame,
                angle=float(angles[t]),
                translate=(float(translates[t][0]), float(translates[t][1])),
                scale=float(scales[t]),
                shear=(float(shears[t][0]), float(shears[t][1])),
                interpolation=interpolation,
                fill=0,
            )
            if orig_dtype == torch.uint8:
                out_f = out_f.clamp(0, 255).round().to(torch.uint8)
            else:
                out_f = out_f.to(orig_dtype)
            outs.append(out_f.unsqueeze(1))
        return torch.cat(outs, dim=1)

    def _apply_perspective_to_clip(clip, startpoints_list, endpoints_list, interpolation=transforms_F.InterpolationMode.BILINEAR):
        """Apply torchvision ``perspective`` frame by frame with per-frame corner pairs (zero fill)."""
        C, T, H, W = clip.shape
        assert len(startpoints_list) == T
        assert len(endpoints_list) == T
        outs = []
        for t in range(T):
            frame = clip[:, t]
            orig_dtype = frame.dtype
            if not frame.is_floating_point():
                frame = frame.float()
            out_f = transforms_F.perspective(
                frame,
                startpoints=startpoints_list[t],
                endpoints=endpoints_list[t],
                interpolation=interpolation,
                fill=0,
            )
            if orig_dtype == torch.uint8:
                out_f = out_f.clamp(0, 255).round().to(torch.uint8)
            else:
                out_f = out_f.to(orig_dtype)
            outs.append(out_f.unsqueeze(1))
        return torch.cat(outs, dim=1)

    def _corners(W: int, H: int):
        """Image corners in pixel coordinates, ordered top-left, top-right, bottom-right, bottom-left."""
        return [[0.0, 0.0], [float(W - 1), 0.0], [float(W - 1), float(H - 1)], [0.0, float(H - 1)]]

    def _mat33_from_affine_coeffs(coeffs):
        """Lift 6 affine coefficients (a, b, c, d, e, f) into a 3x3 homogeneous matrix."""
        a1, b1, c1, d1, e1, f1 = coeffs
        return np.array([[a1, b1, c1], [d1, e1, f1], [0.0, 0.0, 1.0]], dtype=np.float64)

    def _mul_point(M, x, y):
        """Map (x, y) through the homogeneous matrix M; returns (inf, inf) if w == 0."""
        v = M @ np.array([x, y, 1.0], dtype=np.float64)
        if v[2] == 0:
            return float('inf'), float('inf')
        return float(v[0] / v[2]), float(v[1] / v[2])

    def _get_inv_affine_mat33(W, H, angle, translate, scale, shear_x):
        """Inverse (output -> input) affine matrix about (W/2, H/2), horizontal shear only.

        Uses torchvision's private ``_get_inverse_affine_matrix``.
        """
        center = [W * 0.5, H * 0.5]
        coeffs = transforms_F._get_inverse_affine_matrix(center, angle, list(translate), scale, [shear_x, 0.0])
        return _mat33_from_affine_coeffs(coeffs)

    def _get_perspective_inv_mat33(startpoints, endpoints):
        """Homography mapping endpoints -> startpoints (output -> input).

        Solved with the DLT (SVD null vector) and normalized so that H[2, 2] == 1.
        """
        def build_A(src, dst):
            # Two DLT rows for the correspondence src -> dst.
            x, y = src
            u, v = dst
            return np.array([
                [x, y, 1, 0, 0, 0, -u * x, -u * y, -u],
                [0, 0, 0, x, y, 1, -v * x, -v * y, -v],
            ], dtype=np.float64)

        A = np.zeros((8, 9), dtype=np.float64)
        for i in range(4):
            A[2 * i:2 * i + 2, :] = build_A(endpoints[i], startpoints[i])
        _, _, Vt = np.linalg.svd(A)
        h = Vt[-1, :]
        Hm = h.reshape(3, 3)
        return Hm / Hm[2, 2]

    def _minimal_scale_for_affine(W, H, angle, shear_x, base_translate, max_cap):
        """Smallest scale in [1, max_cap] keeping all output corners inside the source.

        Also tries small horizontal translation tweaks around ``base_translate``.
        Returns ``(scale, tx)`` for the tweak with the smallest scale.
        """
        tx, ty = base_translate
        delta_candidates = [0.0]
        if tx != 0.0:
            # Try moving the translation by +/-15%, clamped to 5% of the width.
            d = 0.15 * tx
            lim = 0.05 * W
            delta_candidates += [max(-lim, -d), min(lim, d)]
        best_scale = max_cap
        best_tx = tx
        for delta in delta_candidates:
            tx_try = tx + delta

            def ok(scale_val: float) -> bool:
                # True if every output corner maps back inside the source image (1e-3 px tolerance).
                M = _get_inv_affine_mat33(W, H, angle, (tx_try, ty), scale_val, shear_x)
                for (cx, cy) in _corners(W, H):
                    x, y = _mul_point(M, cx, cy)
                    if x < -1e-3 or y < -1e-3 or x > (W - 1) + 1e-3 or y > (H - 1) + 1e-3:
                        return False
                return True

            if ok(1.0):
                cand = 1.0
            else:
                # Bisection on the scale; ends at max_cap if no scale in range passes.
                lo, hi = 1.0, max_cap
                for _ in range(18):
                    mid = (lo + hi) * 0.5
                    if ok(mid):
                        hi = mid
                    else:
                        lo = mid
                cand = hi
            if cand < best_scale:
                best_scale = cand
                best_tx = tx_try
        return best_scale, best_tx

    def _minimal_scale_for_persp_then_affine(W, H, startpoints, endpoints, base_translate, max_cap):
        """Same search as ``_minimal_scale_for_affine`` for a perspective warp followed by an
        affine with angle 0 and no shear; corners are mapped back through both inverses."""
        tx, ty = base_translate
        delta_candidates = [0.0]
        if tx != 0.0:
            d = 0.15 * tx
            lim = 0.05 * W
            delta_candidates += [max(-lim, -d), min(lim, d)]
        H_inv = _get_perspective_inv_mat33(startpoints, endpoints)
        best_scale = max_cap
        best_tx = tx
        for delta in delta_candidates:
            tx_try = tx + delta

            def ok(scale_val: float) -> bool:
                # Output -> affine inverse -> perspective inverse -> source coordinates.
                M_aff_inv = _get_inv_affine_mat33(W, H, 0.0, (tx_try, ty), scale_val, 0.0)
                M_total = H_inv @ M_aff_inv
                for (cx, cy) in _corners(W, H):
                    x, y = _mul_point(M_total, cx, cy)
                    if x < -1e-3 or y < -1e-3 or x > (W - 1) + 1e-3 or y > (H - 1) + 1e-3:
                        return False
                return True

            if ok(1.0):
                cand = 1.0
            else:
                lo, hi = 1.0, max_cap
                for _ in range(18):
                    mid = (lo + hi) * 0.5
                    if ok(mid):
                        hi = mid
                    else:
                        lo = mid
                cand = hi
            if cand < best_scale:
                best_scale = cand
                best_tx = tx_try
        return best_scale, best_tx

    C, T, H, W = video.shape
    orig_dtype = video.dtype

    # All transforms are expressed as per-frame parameter lists (length T).
    angles_list = []
    translates_list = []
    scales_list = []
    shears_list = []
    startpoints_list = []
    endpoints_list = []
    aff_angles_list = []  # only used by views 3/4: affine applied after the perspective warp
    aff_translates_list = []
    aff_scales_list = []
    aff_shears_list = []

    if use_fixed_motion:
        # Fixed transform: compute the parameters once and replicate them T times.
        dx = shift_val * px_per_meter
        if view_ids == 0:
            sign = (1.0 if dx > 0 else -1.0) if dx != 0 else 0.0
            angle = -rot_deg_front * sign
            shear_x = shear_deg_front * sign
            scale_border, tx_refine = _minimal_scale_for_affine(W, H, angle, shear_x, (-dx, 0.0), max_scale_cap)
            angles_list = [angle] * T
            translates_list = [(float(tx_refine), 0.0)] * T
            scales_list = [float(scale_border)] * T
            shears_list = [(shear_x, 0.0)] * T
        elif view_ids in (1, 2):
            sign = 1.0 if (dx > 0) else -1.0 if (dx < 0) else 0.0
            angle = -rot_deg_diag * sign
            shear_x = shear_deg_diag * sign
            scale_border, tx_refine = _minimal_scale_for_affine(W, H, angle, shear_x, (-0.7 * dx, 0.0), max_scale_cap)
            angles_list = [angle] * T
            translates_list = [(float(tx_refine), 0.0)] * T
            scales_list = [float(scale_border)] * T
            shears_list = [(shear_x, 0.0)] * T
        elif view_ids in (3, 4):
            start = [[0.0, 0.0], [float(W - 1), 0.0], [float(W - 1), float(H - 1)], [0.0, float(H - 1)]]
            # Top edge moves by -k*W, bottom edge by -0.4*k*W (x clamped to [0, W-1]).
            k = persp_strength_side * (dx / max(W, 1.0))
            shift_top = -k * W
            shift_bottom = -0.4 * k * W
            end = [
                [max(0.0, min(W - 1.0, 0.0 + shift_top)), 0.0],
                [max(0.0, min(W - 1.0, (W - 1.0) + shift_top)), 0.0],
                [max(0.0, min(W - 1.0, (W - 1.0) + shift_bottom)), float(H - 1.0)],
                [max(0.0, min(W - 1.0, 0.0 + shift_bottom)), float(H - 1.0)],
            ]
            startpoints_list = [start] * T
            endpoints_list = [end] * T
            scale_border, tx_refine = _minimal_scale_for_persp_then_affine(W, H, start, end, (-0.3 * dx, 0.0), max_scale_cap)
            aff_angles_list = [0.0] * T
            aff_translates_list = [(float(tx_refine), 0.0)] * T
            aff_scales_list = [float(scale_border)] * T
            aff_shears_list = [(0.0, 0.0)] * T
        else:
            angles_list = [0.0] * T
            translates_list = [(0.0, 0.0)] * T
            scales_list = [1.0] * T
            shears_list = [(0.0, 0.0)] * T
    else:
        # Gradual transform: only the translation is ramped per frame; rotation/shear
        # use the sign of the total shift on every frame.
        # NOTE: the ramp length equals T, so the last frame (t = T-1) reaches slightly less
        # than the full shift.
        shift_curve_m, _ = _compute_lane_change_curves(num_frames=T, total_shift=shift_val, lane_change_frames=T)
        dx_curve = shift_curve_m * px_per_meter
        total_dx = float(shift_val * px_per_meter)
        sign_total = (1.0 if total_dx > 0 else -1.0) if total_dx != 0.0 else 0.0
        for t in range(T):
            dx = float(dx_curve[t])
            if view_ids == 0:
                angle = -rot_deg_front * sign_total
                shear_x = shear_deg_front * sign_total
                scale_border, tx_refine = _minimal_scale_for_affine(W, H, angle, shear_x, (-dx, 0.0), max_scale_cap)
                angles_list.append(angle)
                translates_list.append((float(tx_refine), 0.0))
                scales_list.append(float(scale_border))
                shears_list.append((shear_x, 0.0))
            elif view_ids in (1, 2):
                angle = -rot_deg_diag * sign_total
                shear_x = shear_deg_diag * sign_total
                scale_border, tx_refine = _minimal_scale_for_affine(W, H, angle, shear_x, (-0.7 * dx, 0.0), max_scale_cap)
                angles_list.append(angle)
                translates_list.append((float(tx_refine), 0.0))
                scales_list.append(float(scale_border))
                shears_list.append((shear_x, 0.0))
            elif view_ids in (3, 4):
                start = [[0.0, 0.0], [float(W - 1), 0.0], [float(W - 1), float(H - 1)], [0.0, float(H - 1)]]
                k = persp_strength_side * (dx / max(W, 1.0))
                shift_top = -k * W
                shift_bottom = -0.4 * k * W
                end = [
                    [max(0.0, min(W - 1.0, 0.0 + shift_top)), 0.0],
                    [max(0.0, min(W - 1.0, (W - 1.0) + shift_top)), 0.0],
                    [max(0.0, min(W - 1.0, (W - 1.0) + shift_bottom)), float(H - 1.0)],
                    [max(0.0, min(W - 1.0, 0.0 + shift_bottom)), float(H - 1.0)],
                ]
                startpoints_list.append(start)
                endpoints_list.append(end)
                scale_border, tx_refine = _minimal_scale_for_persp_then_affine(W, H, start, end, (-0.3 * dx, 0.0), max_scale_cap)
                aff_angles_list.append(0.0)
                aff_translates_list.append((float(tx_refine), 0.0))
                aff_scales_list.append(float(scale_border))
                aff_shears_list.append((0.0, 0.0))
            else:
                angles_list.append(0.0)
                translates_list.append((0.0, 0.0))
                scales_list.append(1.0)
                shears_list.append((0.0, 0.0))

    # Apply the transforms.
    if view_ids in (0, 1, 2):
        out_video = _apply_affine_to_clip(video, angles_list, translates_list, scales_list, shears_list)
    elif view_ids in (3, 4):
        persp_video = _apply_perspective_to_clip(video, startpoints_list, endpoints_list)
        out_video = _apply_affine_to_clip(persp_video, aff_angles_list, aff_translates_list, aff_scales_list, aff_shears_list)
    else:
        out_video = video

    # Restore the input dtype (uint8 is clamped and rounded).
    if orig_dtype == torch.uint8:
        out_video = out_video.clamp(0, 255).round().to(torch.uint8).contiguous()
    else:
        out_video = out_video.to(orig_dtype).contiguous()

    return out_video


if __name__ == "__main__":
    # Sanity check for the dataset: iterate over every sample and dump one
    # visualization video per sample, with all camera views laid out left to
    # right and all modalities (video / sparse / reference controls) stacked
    # top to bottom.
    from collections import deque

    import imageio

    # Frames per camera view. Shared by the dataset config and the view
    # splitter below so the two cannot drift apart.
    num_frames = 29

    camera_keys = [
        "pinhole_front",
        "pinhole_front_left",
        "pinhole_front_right",
        "pinhole_side_left",
        "pinhole_side_right",
    ]
    # The view id of each camera is its position in ``camera_keys``.
    camera_to_view_id = {key: view_id for view_id, key in enumerate(camera_keys)}
    dataset_instance = L(WaymoMultiviewDataset)(
        dataset_dir="/data/lyy_dataset/waymo_transfer2/training/",
        hint_key="control_input_hdmap_bbox",
        resolution="720",
        state_t=8,
        num_frames=num_frames,
        camera_keys=camera_keys,
        video_size=(704, 1280),
        front_camera_key="pinhole_front",
        camera_to_view_id=camera_to_view_id,
        front_view_caption_only=True,
        is_train=True,
        val_percent=0.001,
    )
    dataset = instantiate(dataset_instance)
    print("finished init dataset")

    # Skip the first `start_index` samples (e.g. to resume a partial dump).
    start_index = 0
    dataset_subset = Subset(dataset, list(range(start_index, len(dataset))))
    dataloader_vis = DataLoader(
        dataset_subset,
        batch_size=1,  # the visualization below squeezes dim 0, so this must stay 1
        shuffle=False,
        num_workers=16,
        pin_memory=True,
    )

    # Create the output directory.
    output_dir = "/data/lyy_dataset/test_transfer/visualize_data"
    os.makedirs(output_dir, exist_ok=True)

    # Left-to-right display order of the views in the output video.
    display_order = [
        "pinhole_side_left",
        "pinhole_front_left",
        "pinhole_front",
        "pinhole_front_right",
        "pinhole_side_right",
    ]
    # Position of each displayed view inside the concatenated tensors; this
    # evaluates to [3, 1, 0, 2, 4].
    # NOTE(review): assumes the dataset concatenates views in ``camera_keys`` order.
    view_order = [camera_keys.index(key) for key in display_order]
    num_views = len(camera_keys)
    sequence_length = num_frames

    def split_views(tensor, num_views=num_views, seq_len=sequence_length):
        """Split multi-view data concatenated along time into per-view chunks.

        tensor: [B, C, T_total, H, W] where T_total = num_views * seq_len.
        Returns: a list of num_views tensors, each [B, C, seq_len, H, W].
        """
        return [tensor[:, :, i * seq_len:(i + 1) * seq_len, :, :] for i in range(num_views)]

    def reorder_views(views_list, order=view_order):
        """Return the views rearranged into display order."""
        return [views_list[i] for i in order]

    def concat_views_horizontally(views_list):
        """Concatenate views side by side along the width axis.

        views_list: list of [B, C, T, H, W].
        Returns: [B, C, T, H, W_total] where W_total = W * len(views_list).
        """
        return torch.cat(views_list, dim=-1)

    def view_row(tensor):
        """Turn one modality's [B, C, V*T, H, W] tensor into a single row of display-ordered views."""
        return concat_views_horizontally(reorder_views(split_views(tensor)))

    def mask_to_uint8(mask):
        """Scale a mask tensor by 255 and cast to uint8 for display."""
        # NOTE(review): assumes masks are in [0, 1]; values are truncated, not rounded.
        return (mask * 255.).to(torch.uint8)

    def save_video(path, frames, fps):
        imageio.mimsave(path, frames, fps=fps)

    # os.cpu_count() may return None when the count is undetermined.
    write_workers = max((os.cpu_count() or 2) // 2, 1)
    max_pending = max(write_workers * 2, 2)
    pending = deque()

    with ThreadPoolExecutor(max_workers=write_workers) as executor:
        for data in tqdm(dataloader_vis):
            # One row per modality, each row holding all views side by side.
            visualize_list = [
                view_row(data["video"]),
                view_row(data["control_input_sparse_ctrl"]),
                view_row(data["control_input_sparse_video"]),
                view_row(mask_to_uint8(data["control_input_sparse_mask"])),
                view_row(data["control_input_reference_ctrl"]),
                view_row(data["control_input_reference_video"]),
                view_row(mask_to_uint8(data["control_input_reference_mask"])),
            ]

            # Stack the modality rows along the height axis: [B, C, T, H_total, W_total].
            tensor_visualize = torch.cat(visualize_list, dim=-2)

            # Convert to numpy frames laid out as [T, H, W, C].
            tensor_visualize = tensor_visualize.squeeze(0).permute(1, 2, 3, 0).cpu().numpy()
            if tensor_visualize.dtype != np.uint8:
                # NOTE(review): assumes non-uint8 inputs are already in the [0, 255] range.
                tensor_visualize = np.clip(tensor_visualize, 0, 255).astype(np.uint8)

            video_name = os.path.join(output_dir, f"{str(data['index'].item()).zfill(4)}.mp4")
            pending.append(executor.submit(save_video, video_name, tensor_visualize, 10))
            if len(pending) >= max_pending:
                # Backpressure: wait for the oldest write so that queued frame
                # arrays do not accumulate in memory without bound.
                pending.popleft().result()

        # Wait for the remaining writes and surface any exception they raised.
        for future in pending:
            future.result()
