# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os

import imageio
import torch

from cosmos1.models.autoregressive.inference.sjd_world_generation_pipeline import ARSJDVideo2WorldGenerationPipeline
from cosmos1.models.autoregressive.sjd import SJDConfig
from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args
from cosmos1.utils import log
from cosmos1.utils.io import read_prompts_from_file


def parse_args(argv=None):
    """Build the command-line parser for the SJD video2world demo and parse arguments.

    Registers the shared options from ``add_common_arguments`` plus the
    video2world-specific and Speculative Jacobi Decoding (SJD) flags defined here.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, i.e. the original behavior.

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    parser = argparse.ArgumentParser(description="SJD-enabled prompted video to world generation demo script")
    add_common_arguments(parser)
    parser.add_argument(
        "--ar_model_dir",
        type=str,
        default="Cosmos-1.0-Autoregressive-5B-Video2World",
        help="Autoregressive model checkpoint name, passed to the pipeline together with --checkpoint_dir",
    )
    parser.add_argument(
        "--input_type",
        type=str,
        default="text_and_video",
        choices=["text_and_image", "text_and_video"],
        help="Type of conditioning input: a text prompt plus either an image or a video",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt for generating a single video (used when --batch_input_path is not given)",
    )
    parser.add_argument(
        "--offload_text_encoder_model",
        action="store_true",
        help="Offload T5 model after inference",
    )
    # SJD specific arguments
    parser.add_argument("--enable_sjd", action="store_true", help="Enable Speculative Jacobi Decoding.")
    parser.add_argument(
        "--max_num_new_tokens", type=int, default=8, help="Maximum number of new tokens to consider in one step."
    )
    parser.add_argument(
        "--multi_token_init_scheme",
        type=str,
        default="random",
        choices=["random", "repeat_horizon"],
        help="Initialization scheme for multi-token prediction.",
    )

    return parser.parse_args(argv)


def main(args):
    """Run SJD-enabled prompted video-to-world generation.

    Steps performed:
    - Validate the parsed arguments and obtain the sampling configuration
      from ``validate_args``.
    - Build an ``SJDConfig`` from ``--enable_sjd``, ``--max_num_new_tokens``
      and ``--multi_token_init_scheme``.
    - Initialize ``ARSJDVideo2WorldGenerationPipeline`` with the checkpoint
      and model-offloading options.
    - Load the conditioning image(s)/video(s) and prompt(s). They come either
      from ``--batch_input_path`` (batch mode) or from
      ``--input_image_or_video_path`` + ``--prompt`` (single-input mode).
    - Generate one video per prompt entry and write it to disk as MP4. The
      same ``--seed`` is passed to every ``generate`` call.

    Args:
        args (argparse.Namespace): Parsed command-line arguments (see
            ``parse_args``) containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation/sampling parameters consumed by ``validate_args``
            - Input/output settings (images/videos, save paths)
            - Performance options (model offloading settings)
            - SJD configuration flags

    Outputs:
        One MP4 file (25 fps) per successfully generated video, written to
        ``--video_save_folder``:
        - ``<video_save_name>.mp4`` in single-input mode;
        - ``<idx>.mp4`` in batch mode, where ``idx`` is the entry's position
          in the prompt list.

    A prompt entry is logged at critical level and skipped in two cases: its
    input file is not among the loaded inputs, or ``pipeline.generate``
    returns ``None`` (guardrails blocked the generation). Processing then
    continues with the next entry.
    """
    inference_type = "video2world"
    sampling_config = validate_args(args, inference_type)

    sjd_config = SJDConfig(
        enable_sjd=args.enable_sjd,
        max_num_new_tokens=args.max_num_new_tokens,
        multi_token_init_scheme=args.multi_token_init_scheme,
    )

    # Initialize prompted base generation model pipeline
    pipeline = ARSJDVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name=args.ar_model_dir,
        sjd_config=sjd_config,
        disable_diffusion_decoder=args.disable_diffusion_decoder,
        offload_guardrail_models=args.offload_guardrail_models,
        offload_diffusion_decoder=args.offload_diffusion_decoder,
        offload_network=args.offload_ar_model,
        offload_tokenizer=args.offload_tokenizer,
        offload_text_encoder_model=args.offload_text_encoder_model,
    )

    # Load input image(s) or video(s); entries are looked up by file basename below.
    input_videos = load_vision_input(
        input_type=args.input_type,
        batch_input_path=args.batch_input_path,
        input_image_or_video_path=args.input_image_or_video_path,
        data_resolution=args.data_resolution,
        num_input_frames=args.num_input_frames,
    )

    # Load input prompt(s). Batch mode takes precedence over the single-input flags.
    batch_mode = bool(args.batch_input_path)
    if batch_mode:
        prompts_list = read_prompts_from_file(args.batch_input_path)
    else:
        prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}]

    # Idempotent. Makes saving independent of whether the folder was created
    # earlier. Skipped for an empty folder string, which means "current directory".
    if args.video_save_folder:
        os.makedirs(args.video_save_folder, exist_ok=True)

    # Iterate through prompts
    for idx, prompt_entry in enumerate(prompts_list):
        video_path = prompt_entry["visual_input"]
        input_filename = os.path.basename(video_path)

        # Check if video exists in loaded videos
        if input_filename not in input_videos:
            log.critical(f"Input file {input_filename} not found, skipping prompt.")
            continue

        inp_vid = input_videos[input_filename]
        inp_prompt = prompt_entry["prompt"]

        # Generate video
        log.info(f"Run with input: {prompt_entry}")
        out_vid = pipeline.generate(
            inp_prompt=inp_prompt,
            inp_vid=inp_vid,
            num_input_frames=args.num_input_frames,
            seed=args.seed,
            sampling_config=sampling_config,
        )
        if out_vid is None:
            log.critical("Guardrail blocked video2world generation.")
            continue

        # Save video. Filenames are chosen by the same mode switch used to load
        # prompts above. Every batch entry therefore gets a distinct index-based
        # name, even when --input_image_or_video_path is also set; previously
        # all batch outputs were written to <video_save_name>.mp4 and overwrote
        # one another.
        if batch_mode:
            out_vid_path = os.path.join(args.video_save_folder, f"{idx}.mp4")
        else:
            out_vid_path = os.path.join(args.video_save_folder, f"{args.video_save_name}.mp4")
        imageio.mimwrite(out_vid_path, list(out_vid), fps=25)

        log.info(f"Saved video to {out_vid_path}")


if __name__ == "__main__":
    # Turn off TorchScript's tensor-expression fuser before anything is built.
    # NOTE(review): the reason is not visible here; presumably it avoids a JIT
    # fusion issue in one of the models. Confirm before removing.
    torch._C._jit_set_texpr_fuser_enabled(False)
    main(parse_args())