#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample the auto-multiview model directly on `WaymoMultiviewDataset`.

This script mirrors the overall flow of `examples/multiview.py`, but instead of
reading JSON parameter files, it:

- Builds a `WaymoMultiviewDataset` (typically on the validation split).
- Iterates over up to `num_samples` samples.
- Runs the trained multiview model via `ControlVideo2WorldInference`.
- Saves each generated multi-view video, plus a side-by-side grid of its views, to an output directory.
"""

from pathlib import Path
from typing import Annotated, Literal

import numpy as np
import pydantic
import torch
import tyro
from cosmos_oss.init import cleanup_environment, init_environment, init_output_dir

from cosmos_transfer2.config import (
    ResolvedDirectoryPath,
    handle_tyro_exception,
    is_rank0,
)
from cosmos_transfer2.multiview_config import MultiviewSetupArguments
from cosmos_transfer2._src.imaginaire.utils import log
from cosmos_transfer2._src.imaginaire.visualize.video import save_img_or_video
from cosmos_transfer2._src.predict2.models.video2world_model import NUM_CONDITIONAL_FRAMES_KEY
from cosmos_transfer2._src.transfer2.datasets.local_datasets.waymo_multiview_dataset import (
    WaymoMultiviewDataset,
)
from cosmos_transfer2._src.transfer2_multiview.inference.inference import ControlVideo2WorldInference


# Batch dictionary key under which `main` stores `args.control_weight` before
# handing the batch to `ControlVideo2WorldInference.generate_from_batch`.
CONTROL_WEIGHT_KEY = "control_weight"


class WaymoSampleArgs(pydantic.BaseModel):
    """Command-line arguments for sampling the multiview model on `WaymoMultiviewDataset`.

    Parsed with `tyro.cli` in the ``__main__`` block; the string literals placed
    after each field are read by tyro as that field's ``--help`` text.
    """

    # extra="forbid": unknown fields are a validation error.
    # frozen=True: instances cannot be mutated after construction.
    model_config = pydantic.ConfigDict(extra="forbid", frozen=True)

    # Dataset-related arguments
    # Required; also available on the command line as `-d`.
    dataset_dir: Annotated[
        ResolvedDirectoryPath,
        tyro.conf.arg(aliases=("-d",), help="Root directory of the Waymo multiview dataset."),
    ]
    split: Literal["train", "val"] = "val"
    """Which subset to use: training split (train) or validation split (val)."""

    # Upper bound: `main` generates min(num_samples, len(dataloader)) samples.
    num_samples: int = 1000
    """Number of samples to generate from the chosen split."""
    num_workers: int = 2
    """Number of dataloader workers."""

    # Sampling-related arguments
    # NOTE(review): typed as int, so fractional guidance scales (e.g. 7.5) are
    # rejected at parse time — confirm this matches what the sampler accepts.
    guidance: int = 3
    """Guidance scale used by the diffusion sampler."""
    num_conditional_frames: int = 0
    """Number of conditional frames; 0 means unconditional generation."""
    control_weight: float = 1.0
    """Control weight passed into the model."""
    # The same seed is passed for every sample in `main`.
    seed: int = 0
    """Random seed used inside the model sampler."""

    fps: int | None = None
    """FPS of the output video. If None, use dataset FPS."""

    # Model / checkpoint setup
    # Required nested argument group (no default), declared after defaulted fields.
    setup: MultiviewSetupArguments
    """Multiview model setup arguments (experiment, checkpoint, output_dir, etc.)."""


def build_waymo_dataset(args: WaymoSampleArgs) -> WaymoMultiviewDataset:
    """
    Create the `WaymoMultiviewDataset` for the split selected by ``args.split``.

    The hard-coded settings (five pinhole cameras, ``resolution="720"``,
    ``num_frames=29``, ``video_size=(704, 1280)``, front-view captions only)
    follow `get_hdmap_multiview_waymo_dataset` and `WaymoMultiviewDataset.__main__`.
    The train/val partition uses the dataset's default ``val_percent`` (0.02).
    """
    # A camera's view id is its position in this tuple:
    # front=0, front_left=1, front_right=2, side_left=3, side_right=4.
    ordered_cameras = (
        "pinhole_front",
        "pinhole_front_left",
        "pinhole_front_right",
        "pinhole_side_left",
        "pinhole_side_right",
    )
    view_ids = {camera: view_id for view_id, camera in enumerate(ordered_cameras)}

    waymo_dataset = WaymoMultiviewDataset(
        dataset_dir=str(args.dataset_dir),
        resolution="720",
        state_t=8,
        num_frames=29,
        sequence_interval=1,
        camera_keys=list(ordered_cameras),
        video_size=(704, 1280),
        front_camera_key=ordered_cameras[0],
        camera_to_view_id=view_ids,
        front_view_caption_only=True,
        is_train=(args.split == "train"),
    )

    log.info(str(waymo_dataset))
    return waymo_dataset


# Left-to-right order of views in the grid video, expressed as view ids from
# `build_waymo_dataset`: side_left(3), front_left(1), front(0), front_right(2), side_right(4).
_GRID_VIEW_ORDER: tuple[int, ...] = (3, 1, 0, 2, 4)


def _first_int(value: object) -> int:
    """Convert a collated scalar (tensor, list/tuple, or plain number) to ``int``.

    The DataLoader runs with ``batch_size=1``, so only the first element of a
    tensor or sequence is used.
    """
    if isinstance(value, torch.Tensor):
        return int(value.flatten()[0].item())
    if isinstance(value, (list, tuple)):
        return int(value[0])
    return int(value)


def _extract_prompt(batch: dict) -> str:
    """Return the first entry of ``batch["ai_caption"]``, or ``""`` if it is absent or empty."""
    captions = batch.get("ai_caption", "")
    if isinstance(captions, (list, tuple)):
        return str(captions[0]) if captions else ""
    return str(captions)


def _resolve_fps(args: WaymoSampleArgs, batch: dict) -> int:
    """Return ``args.fps`` when set, otherwise the batch's ``fps`` entry (default 10)."""
    if args.fps is not None:
        return args.fps
    return _first_int(batch.get("fps", 10))


def _output_stem(split: str, batch: dict, batch_idx: int) -> str:
    """Build ``waymo_<split>_<index:06d>``, preferring the batch's ``index`` over ``batch_idx``."""
    index = _first_int(batch["index"]) if "index" in batch else batch_idx
    return f"waymo_{split}_{index:06d}"


def _make_view_grid(video: torch.Tensor, num_views: int, frames_per_view: int) -> torch.Tensor:
    """Tile the per-view clips of ``video`` side by side along the width axis.

    Args:
        video: ``[C, T, H, W]`` tensor whose time axis holds the views back to
            back, i.e. ``T == num_views * frames_per_view``.
        num_views: number of views packed along ``T``.
        frames_per_view: number of frames belonging to each view.

    Returns:
        ``[C, frames_per_view, H, num_views * W]`` tensor. Views are arranged in
        ``_GRID_VIEW_ORDER`` when there are exactly that many views, and in their
        original order otherwise.

    Raises:
        ValueError: if ``T`` does not equal ``num_views * frames_per_view``.
    """
    channels, total_frames, height, width = video.shape
    expected_frames = num_views * frames_per_view
    if total_frames != expected_frames:
        raise ValueError(f"Expected T={expected_frames}, got T={total_frames}")

    if num_views == len(_GRID_VIEW_ORDER):
        view_order = list(_GRID_VIEW_ORDER)
    else:
        # The fixed Waymo layout does not apply; keep the dataset's view order.
        view_order = list(range(num_views))

    per_view = video.reshape(channels, num_views, frames_per_view, height, width)
    # per_view[:, v] is one view's clip of shape [C, F, H, W].
    return torch.cat([per_view[:, v] for v in view_order], dim=-1)


def main(args: WaymoSampleArgs) -> None:
    """Generate multiview videos for up to ``args.num_samples`` samples of the chosen split.

    For each sample this saves ``<output_dir>/waymo_<split>_<index>`` (the model
    output) and ``<output_dir>/waymo_<split>_<index>_grid`` (views tiled
    horizontally). The inference pipeline is cleaned up even if sampling fails.

    Raises:
        ValueError: if the selected split yields no samples, or a generated video's
            frame count does not match the batch's view metadata.
    """
    # Prepare output directory (and save config on rank0)
    init_output_dir(args.setup.output_dir, profile=args.setup.profile)

    # Build inference pipeline (loads model + config from checkpoint & experiment)
    pipe = ControlVideo2WorldInference(
        experiment_name=args.setup.experiment,
        ckpt_path=args.setup.checkpoint_path,
        context_parallel_size=args.setup.context_parallel_size,
    )

    # Everything after pipeline construction runs under try/finally so that
    # pipe.cleanup() is reached on error paths too (e.g. an empty dataset).
    try:
        dataset = build_waymo_dataset(args)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=args.num_workers,
            pin_memory=True,
            drop_last=False,
        )

        if len(dataloader) == 0:
            raise ValueError(f"No samples found in dataset_dir={args.dataset_dir}")

        n_to_generate = max(0, min(args.num_samples, len(dataloader)))
        log.info(f"Generating {n_to_generate} samples from split='{args.split}'")

        for batch_idx, batch in enumerate(dataloader):
            if batch_idx >= n_to_generate:
                break

            # Ensure the control / conditioning keys are present in the batch
            batch[NUM_CONDITIONAL_FRAMES_KEY] = args.num_conditional_frames
            batch[CONTROL_WEIGHT_KEY] = args.control_weight

            prompt = _extract_prompt(batch)
            log.info(f"[{batch_idx + 1}/{n_to_generate}] Batch {batch_idx}, prompt: {prompt!r}")

            video = pipe.generate_from_batch(
                batch,
                guidance=args.guidance,
                seed=args.seed,
            )
            # video: [B=1, C, T, H, W] in [-1, 1] -> [C, T, H, W] uint8 in [0, 255]
            video_uint8 = (((1.0 + video[0]) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8)

            fps = _resolve_fps(args, batch)
            out_stem = _output_stem(args.split, batch, batch_idx)
            out_path = args.setup.output_dir / out_stem
            save_img_or_video(video_uint8, str(out_path), fps=fps)
            log.success(f"Saved sample to {out_path}.mp4 (fps={fps})")

            # The grid needs to know how the views are packed along the time axis.
            num_views = batch.get("sample_n_views")
            frames_per_view = batch.get("num_video_frames_per_view")
            if num_views is None or frames_per_view is None:
                log.warning(f"Batch {batch_idx} has no view metadata; skipping grid video")
                continue

            video_grid = _make_view_grid(video_uint8, _first_int(num_views), _first_int(frames_per_view))
            grid_out_path = args.setup.output_dir / f"{out_stem}_grid"
            save_img_or_video(video_grid, str(grid_out_path), fps=fps)
            log.success(f"Saved grid sample to {grid_out_path}.mp4 (fps={fps})")
    finally:
        # Clean up distributed state if needed
        pipe.cleanup()


if __name__ == "__main__":
    init_environment()
    try:
        args = tyro.cli(
            WaymoSampleArgs,
            description=__doc__,
            console_outputs=is_rank0(),
            config=(tyro.conf.OmitArgPrefixes,),
        )
    except Exception as e:
        handle_tyro_exception(e)

    try:
        # pyrefly: ignore  # unbound-name
        main(args)
    finally:
        # Release whatever init_environment() set up even when sampling raises;
        # previously cleanup was skipped on any exception from main().
        cleanup_environment()


