# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from contextlib import contextmanager
from typing import Tuple, Union

import einops
import numpy as np
import torch
import torchvision
import torchvision.transforms.functional as transforms_F
from matplotlib import pyplot as plt

from cosmos_predict1.diffusion.training.models.extend_model import ExtendDiffusionModel
from cosmos_predict1.utils import log
from cosmos_predict1.utils.easy_io import easy_io

"""This file contains functions needed for long video generation.
* function `generate_video_from_batch_with_loop` is used by `single_gpu_sep20`

"""


@contextmanager
def switch_config_for_inference(model):
    """Temporarily patch an extend model's conditioner config for inference.

    While the context is active, ``condition_location`` is set to ``"first_n"`` and the
    condition-region corruption mode is remapped as follows (any other value is left as is):

    * ``"gaussian_blur"``    -> ``"clean"``
    * ``"noise_with_sigma"`` -> ``"noise_with_sigma_fixed"``

    On exit (including when the body raises) both fields are restored to the values they
    had on entry.

    Args:
        model (ExtendDiffusionModel): video generation model; its
            ``config.conditioner.video_cond_bool`` is modified in place.
    """
    # (training-time mode, inference-time replacement); checked in order, first match wins.
    inference_corruption_modes = (
        ("gaussian_blur", "clean"),
        ("noise_with_sigma", "noise_with_sigma_fixed"),
    )

    entry_cfg = model.config.conditioner.video_cond_bool
    saved_location = entry_cfg.condition_location
    saved_corruption = entry_cfg.apply_corruption_to_condition_region
    try:
        log.info(
            "Change the condition_location to 'first_n' for inference, and apply_corruption_to_condition_region to False"
        )
        cond_cfg = model.config.conditioner.video_cond_bool
        cond_cfg.condition_location = "first_n"
        for training_mode, inference_mode in inference_corruption_modes:
            if saved_corruption == training_mode:
                cond_cfg.apply_corruption_to_condition_region = inference_mode
                break
        yield
    finally:
        log.info(
            f"Restore the original condition_location {saved_location}, apply_corruption_to_condition_region {saved_corruption}"
        )
        cond_cfg = model.config.conditioner.video_cond_bool
        cond_cfg.condition_location = saved_location
        cond_cfg.apply_corruption_to_condition_region = saved_corruption


def visualize_latent_tensor_bcthw(tensor, nrow=1, show_norm=False, save_fig_path=None):
    """Debug helper that plots a latent BCTHW tensor as a grid of channel-mean images.

    Frames are tiled ``nrow`` per row (T must be divisible by ``nrow``). Each pixel of the
    figure is the mean over the channel dimension. Optionally, a second figure shows the
    per-pixel L2 norm over channels.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW
        nrow (int): number of frames per row of the grid; must divide T
        show_norm (bool): whether to also display the per-pixel channel norm
        save_fig_path (str): if given, path where the channel-mean figure is saved;
            parent directories are created when the path has a directory part
    """
    log.info(
        f"display latent tensor shape {tensor.shape}, max={tensor.max()}, min={tensor.min()}, mean={tensor.mean()}, std={tensor.std()}"
    )
    tensor = tensor.float().cpu().detach()
    # Tile frames into one 2D canvas: rows are (batch, time-row, height), columns are
    # (frame-in-row, width); channels are moved last so they can be reduced.
    tensor = einops.rearrange(tensor, "b c (t n) h w -> (b t h) (n w) c", n=nrow)
    tensor_mean = tensor.mean(-1)
    tensor_norm = tensor.norm(dim=-1)
    log.info(f"tensor_norm, tensor_mean {tensor_norm.shape}, {tensor_mean.shape}")

    fig = plt.figure(figsize=(20, 20))
    plt.imshow(tensor_mean)
    plt.title(f"mean {tensor_mean.mean()}, std {tensor_mean.std()}")
    if save_fig_path:
        # os.makedirs("") raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(save_fig_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        plt.savefig(save_fig_path, bbox_inches="tight", pad_inches=0)
    plt.show()
    # Close explicitly so repeated calls do not accumulate open figures on
    # non-interactive backends.
    plt.close(fig)

    if show_norm:
        norm_fig = plt.figure(figsize=(20, 20))
        plt.imshow(tensor_norm)
        plt.show()
        plt.close(norm_fig)


def visualize_tensor_bcthw(tensor: torch.Tensor, nrow=4, save_fig_path=None):
    """Debug helper that shows (and optionally saves) a BCTHW video tensor as an image grid.

    All B*T frames are laid out ``nrow`` per row by ``torchvision.utils.make_grid``. The
    same grid is displayed and, if requested, written to disk.

    Args:
        tensor (torch.Tensor): tensor in shape BCTHW; values must be < 200 (asserted), as
            a guard against un-normalized input
        nrow (int): number of frames per row of the grid
        save_fig_path (str): if given, path where the grid image is saved; parent
            directories are created when the path has a directory part
    """
    log.info(f"display {tensor.shape}, {tensor.max()}, {tensor.min()}")
    assert tensor.max() < 200, f"tensor max {tensor.max()} > 200, the data range is likely wrong"
    tensor = tensor.float().cpu().detach()
    tensor = einops.rearrange(tensor, "b c t h w -> (b t) c h w")
    grid = torchvision.utils.make_grid(tensor, nrow=nrow)
    if save_fig_path is not None:
        # os.makedirs("") raises, so only create a directory when the path has one.
        save_dir = os.path.dirname(save_fig_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        log.info(f"save to {os.path.abspath(save_fig_path)}")
        # Save the grid built above so the file honours ``nrow``; passing the raw frame
        # batch would re-grid it with save_image's default layout.
        torchvision.utils.save_image(grid, save_fig_path)
    fig = plt.figure(figsize=(20, 20))
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)


def compute_num_frames_condition(model: ExtendDiffusionModel, num_of_latent_overlap: int, downsample_factor=8) -> int:
    """Convert a count of overlapping latent frames into the matching count of pixel frames.

    Causal video VAE (``video_vae.is_casual`` truthy, or the attribute absent):
        every complete chunk of ``latent_chunk_duration`` latent frames maps to
        ``pixel_chunk_duration`` pixel frames. A trailing partial chunk of ``r >= 1``
        latent frames maps to ``1 + (r - 1) * downsample_factor`` pixel frames.
    Non-causal video VAE:
        each latent frame maps to ``downsample_factor`` pixel frames.

    Args:
        model (ExtendDiffusionModel): video generation model (only ``model.vae.video_vae`` is read)
        num_of_latent_overlap (int): number of latent frames to overlap
        downsample_factor (int): temporal compression factor of the VAE
    Returns:
        int: number of condition frames in pixel (output) space
    """
    video_vae = model.vae.video_vae
    if not getattr(video_vae, "is_casual", True):
        # Non-causal VAE: uniform temporal compression.
        return num_of_latent_overlap * downsample_factor

    full_chunks, leftover_latents = divmod(num_of_latent_overlap, video_vae.latent_chunk_duration)
    pixel_frames = full_chunks * video_vae.pixel_chunk_duration
    if leftover_latents >= 1:
        # One pixel frame for the chunk's first latent, downsample_factor for each further one.
        pixel_frames += 1 + (leftover_latents - 1) * downsample_factor
    return pixel_frames


def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = None,
    H: int = None,
    W: int = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Read a video or a single image from file into a (1, C, T, H, W) array.

    Paths ending in ``.png``/``.jpg``/``.jpeg`` (case-insensitive) are treated as a single
    image and become one frame. RGBA images are alpha-blended onto a white background, and
    grayscale images are replicated to three channels. Any other path is loaded as a video
    together with its fps metadata.

    Args:
        input_path (str): path to the input video or image
        input_path_format (str): file format forwarded to ``easy_io.load``; ``None`` lets
            the loader decide
        H (int): target height; frames are resized (bicubic, antialiased) only when
            ``normalize`` is True and both ``H`` and ``W`` are given
        W (int): target width (see ``H``)
        normalize (bool): if True, map pixel values with ``x / 128 - 1`` (8-bit input maps
            to ``[-1, 127/128]``), convert to a bfloat16 tensor, optionally resize, and
            move it to CUDA. If False, the raw frames are returned as a NumPy array
            without resizing.
        max_frames (int): keep at most this many leading frames; ``-1`` keeps all
        also_return_fps (bool): if True, also return fps (0 for images)
    Returns:
        torch.Tensor | np.ndarray: frames in shape (1, C, T, H, W); a tuple
        ``(frames, fps)`` when ``also_return_fps`` is True
    """
    log.info(f"Reading video from {input_path}")

    loaded_data = easy_io.load(input_path, file_format=input_path_format, backend_args=None)
    if input_path.lower().endswith((".png", ".jpg", ".jpeg")):
        frames = np.array(loaded_data)  # HWC (HW for grayscale), [0,255]
        if frames.ndim == 2:
            # Grayscale: replicate to RGB, otherwise the width axis would be mistaken for
            # the channel axis below.
            frames = np.repeat(frames[..., None], 3, axis=-1)
        elif frames.shape[-1] > 3:  # RGBA, set the transparent to white
            rgb_channels = frames[..., :3]
            alpha_channel = frames[..., 3] / 255.0  # alpha in [0, 1]
            white_bg = np.ones_like(rgb_channels) * 255
            # Blend the RGB channels with the white background based on the alpha channel
            frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
                np.uint8
            )
        frames = [frames]
        fps = 0
    else:
        frames, meta_data = loaded_data
        fps = int(meta_data.get("fps"))
    if max_frames != -1:
        frames = frames[:max_frames]
    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    if normalize:
        # NOTE: divides by 128 (not 127.5), so 255 maps to 127/128 rather than exactly 1.
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
        log.info(f"Raw data shape: {input_tensor.shape}")
        if H is not None and W is not None:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),  # type: ignore
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.info(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor


def create_condition_latent_from_input_frames(
    model: ExtendDiffusionModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Encode the last ``num_frames_condition`` input frames into a condition latent.

    The selected frames are placed at the start of a clip of
    ``model.vae.pixel_chunk_duration`` frames. The rest of the clip is zero-padded, and the
    clip is encoded with ``model.encode``. For multi-view models (``model.n_views`` exists),
    the view axis folded into the batch dimension is moved into the time axis before encoding.

    Args:
        model (ExtendDiffusionModel): video generation model
        input_frames (torch.Tensor): video tensor in shape (B, C, T, H, W), range [-1, 1]
        num_frames_condition (int): number of trailing frames to use as condition; 0 yields
            a clip that is entirely padding
    Returns:
        tuple[torch.Tensor, torch.Tensor]: the condition latent (B,C,T,H,W) and the padded
            pixel clip that was fed to ``model.encode``
    Raises:
        AssertionError: if ``input_frames`` has fewer than ``num_frames_condition`` frames,
            or ``num_frames_condition`` exceeds the encoded clip length
    """
    B, C, T, H, W = input_frames.shape
    # The encoded clip is always one pixel chunk long.
    num_frames_encode = model.vae.pixel_chunk_duration
    log.info(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )

    log.info(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enought for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}"

    # Put the conditional frames at the beginning of the clip and zero-pad the end.
    # Slice with an explicit start index: ``-0:`` would select the whole video instead of none.
    condition_frames = input_frames[:, :, T - num_frames_condition :]
    padding_frames = condition_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
    encode_input_frames = torch.cat([condition_frames, padding_frames], dim=2)

    log.info(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    if hasattr(model, "n_views"):
        encode_input_frames = einops.rearrange(encode_input_frames, "(B V) C T H W -> B C (V T) H W", V=model.n_views)
    latent = model.encode(encode_input_frames)
    return latent, encode_input_frames


def get_condition_latent(
    model: ExtendDiffusionModel,
    conditioned_image_or_video_path: str,
    num_of_latent_condition: int = 4,
    state_shape: list[int] = None,
    input_path_format: str = None,
):
    """Build the initial condition latent from an image or video file.

    Args:
        model (ExtendDiffusionModel): video generation model
        conditioned_image_or_video_path (str): path to the conditioning image or video
        num_of_latent_condition (int): number of latent frames to condition on; 0 skips
            reading the file and returns an all-zero latent
        state_shape (list[int]): latent shape without the batch dimension (presumably
            C, T, H, W); its last two entries are the latent height and width. Any
            sequence (list, tuple, torch.Size) is accepted. Defaults to ``model.state_shape``.
        input_path_format (str): file format forwarded to the file reader
    Returns:
        tuple: ``(condition_latent, input_frames)``. ``condition_latent`` is a bfloat16
        latent and ``input_frames`` are the frames read from disk. When
        ``num_of_latent_condition == 0`` the result is ``(zeros of shape [1, *state_shape]
        on CUDA, None)``.
    """
    if state_shape is None:
        state_shape = model.state_shape
    if num_of_latent_condition == 0:
        log.info("No condition latent needed, return empty latent")
        # ``[1, *state_shape]`` works for lists, tuples and torch.Size alike,
        # whereas ``[1] + state_shape`` raises for non-list sequences.
        condition_latent = torch.zeros([1, *state_shape], dtype=torch.bfloat16, device="cuda")
        return condition_latent, None

    spatial_factor = model.vae.spatial_compression_factor
    H, W = state_shape[-2] * spatial_factor, state_shape[-1] * spatial_factor
    input_frames = read_video_or_image_into_frames_BCTHW(
        conditioned_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )
    num_frames_condition = compute_num_frames_condition(
        model, num_of_latent_condition, downsample_factor=model.vae.temporal_compression_factor
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_frames_condition)
    return condition_latent.to(torch.bfloat16), input_frames


def _decode_latent_to_unit_range(model: ExtendDiffusionModel, latent: torch.Tensor) -> torch.Tensor:
    """Decode ``latent`` with ``model.decode`` and map each value x to clamp(1 + x, 0, 2) / 2 ([-1, 1] -> [0, 1])."""
    return (1.0 + model.decode(latent)).clamp(0, 2) / 2


def _unit_range_video_to_uint8_THWC(grid_BCTHW: torch.Tensor) -> np.ndarray:
    """Convert the first batch item of a [0, 1] BCTHW video into a uint8 THWC NumPy array in [0, 255]."""
    return (grid_BCTHW[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy().astype(np.uint8)


def _last_frames(video_BCTHW: torch.Tensor, num_frames: int) -> torch.Tensor:
    """Return the last ``num_frames`` entries along the time axis (dim 2).

    Unlike ``x[:, :, -num_frames:]``, ``num_frames == 0`` yields an empty slice instead of
    the whole tensor. A ``num_frames`` larger than the time axis yields the whole tensor.
    """
    return video_BCTHW[:, :, max(video_BCTHW.shape[2] - num_frames, 0) :]


def generate_video_from_batch_with_loop(
    model: ExtendDiffusionModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    condition_latent: torch.Tensor,
    # hyper-parameters for inference
    num_of_loops: int,
    num_of_latent_overlap_list: list[int],
    guidance: float,
    num_steps: int,
    seed: int,
    add_input_frames_guidance: bool = False,
    augment_sigma_list: list[float] = None,
    data_batch_list: Union[None, list[dict]] = None,
    visualize: bool = False,
    save_fig_path: str = None,
    skip_reencode: int = 0,
    return_noise: bool = False,
) -> Tuple[np.array, list, list, torch.Tensor] | Tuple[np.array, list, list, torch.Tensor, torch.Tensor]:
    """Generate a long video autoregressively, updating the condition latent after every loop.

    Each loop samples one clip conditioned on ``condition_latent`` (zero-padded in time up
    to ``state_shape[1]``). It then derives the next condition from the last
    ``num_of_latent_overlap_list[i + 1]`` latent frames: either taken directly from the
    sample (``skip_reencode``) or re-encoded from the decoded pixels. Frames of loop ``i > 0``
    that re-render the previous loop's overlap are dropped before concatenation.

    Args:
        model (ExtendDiffusionModel): video generation model
        state_shape (list): shape of the state tensor (without batch); ``state_shape[1]`` is
            the latent length each condition latent is padded to
        is_negative_prompt (bool): whether to use negative prompt
        data_batch (dict): data batch for video generation, reused for every loop when
            ``data_batch_list`` is None
        condition_latent (torch.Tensor): initial condition latent in shape BCTHW
        num_of_loops (int): number of loops (clips) to generate
        num_of_latent_overlap_list (list[int]): number of overlapping latent frames per loop;
            the last entry is reused for the next-condition overlap once the list runs out
        guidance (float): guidance scale forwarded to ``model.generate_samples_from_batch``
        num_steps (int): number of steps for diffusion sampling
        seed (int): base random seed; loop ``i`` uses ``seed + i``
        add_input_frames_guidance (bool): must be False (not supported)
        augment_sigma_list (list): per-loop sigma for condition corruption (the last value
            is reused for later loops). Defaults to the config's
            ``apply_corruption_to_condition_region_sigma_value``.
        data_batch_list (list): per-loop data batches, to support multiple prompts in
            autoregressive generation; default is None
        visualize (bool): whether to save debug visualizations of latents and grids
        save_fig_path (str): directory for the visualizations; required when ``visualize``
        skip_reencode (int): truthy to reuse the sample's tail latents as the next condition
            instead of re-encoding decoded frames
        return_noise (bool): whether to also return the noise reported by the sampler for
            the final loop (used for ODE pairs generation)
    Returns:
        np.array: generated video in shape THWC, uint8, range [0, 255]
        list: condition latent used at each loop, each in shape BCTHW
        list: sampled latent of each loop, each in shape BCTHW
        torch.Tensor: noise from the final loop's sampling call (only if ``return_noise``)
    """
    if data_batch_list is None:
        data_batch_list = [data_batch for _ in range(num_of_loops)]
    if visualize:
        assert save_fig_path is not None, "save_fig_path should be set when visualize is True"
    assert not add_input_frames_guidance, "add_input_frames_guidance should be False, not supported"

    if augment_sigma_list is None:
        augment_sigma_list = model.config.conditioner.video_cond_bool.apply_corruption_to_condition_region_sigma_value
    # When set, latents are collected and decoded once at the end instead of per loop.
    decode_all_at_end = model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i

    condition_latent_list = []
    decode_latent_list = []  # latents decoded together at the end (decode_all_at_end only)
    sample_latent = []
    grid_list = []  # per-loop uint8 THWC frames (per-loop decoding only)
    noise = None

    for i in range(num_of_loops):
        num_of_latent_overlap_i = num_of_latent_overlap_list[i]
        # Overlap for the next loop; the last entry is reused once the list runs out.
        num_of_latent_overlap_i_plus_1 = num_of_latent_overlap_list[min(i + 1, len(num_of_latent_overlap_list) - 1)]

        if condition_latent.shape[2] < state_shape[1]:
            # Padding condition latent to state shape
            log.info(f"Padding condition latent {condition_latent.shape} to state shape {state_shape}")
            b, c, t, h, w = condition_latent.shape
            condition_latent = torch.cat(
                [
                    condition_latent,
                    condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
                ],
                dim=2,
            ).contiguous()
            log.info(f"after padding, condition latent shape {condition_latent.shape}")
        log.info(f"Generate video loop {i} / {num_of_loops}")
        if visualize:
            log.info(f"Visualize condition latent {i}")
            visualize_latent_tensor_bcthw(
                condition_latent[:, :, :4].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_condition_latent_first_4.png"),
            )  # BCTHW

        condition_latent_list.append(condition_latent)

        if i < len(augment_sigma_list):
            condition_video_augment_sigma_in_inference = augment_sigma_list[i]
            log.info(f"condition_video_augment_sigma_in_inference {condition_video_augment_sigma_in_inference}")
        else:
            condition_video_augment_sigma_in_inference = augment_sigma_list[-1]

        sample = model.generate_samples_from_batch(
            data_batch_list[i],
            guidance=guidance,
            state_shape=state_shape,
            num_steps=num_steps,
            is_negative_prompt=is_negative_prompt,
            seed=seed + i,
            condition_latent=condition_latent,
            num_condition_t=num_of_latent_overlap_i,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            return_noise=return_noise,
        )
        if return_noise:
            sample, noise = sample

        if visualize:
            log.info(f"Visualize sampled latent {i} 4-8 frames")
            visualize_latent_tensor_bcthw(
                sample[:, :, 4:8].float(),
                nrow=4,
                save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_sample_latent_last_4.png"),
            )  # BCTHW

            diff_between_sample_and_condition = (sample - condition_latent)[:, :, :num_of_latent_overlap_i]
            log.info(
                f"Visualize diff between sample and condition latent {i} first 4 frames {diff_between_sample_and_condition.mean()}"
            )

        sample_latent.append(sample)
        T = condition_latent.shape[2]
        assert num_of_latent_overlap_i <= T, f"num_of_latent_overlap should be <= T, get {num_of_latent_overlap_i}, {T}"

        if decode_all_at_end:
            assert skip_reencode, "skip_reencode should be turned on when sample_tokens_start_from_p_or_i is True"
            # Later loops drop their leading overlap latents, which repeat the previous tail.
            decode_latent_list.append(sample if i == 0 else sample[:, :, num_of_latent_overlap_i:])
        else:
            grid_BCTHW = _decode_latent_to_unit_range(model, sample)  # [B, 3, T, H, W], [0, 1]

            if visualize:
                log.info(f"Visualize grid {i}")
                visualize_tensor_bcthw(
                    grid_BCTHW.float(), nrow=5, save_fig_path=os.path.join(save_fig_path, f"loop_{i:02d}_grid.png")
                )
            grid_np_THWC = _unit_range_video_to_uint8_THWC(grid_BCTHW)  # THW3, range [0, 255]

            if i == 0:
                grid_list.append(grid_np_THWC)  # First output, keep every frame
            else:
                # The leading frames re-render this loop's condition, which the previous loop
                # built from ``num_of_latent_overlap_i`` latent frames and already emitted.
                num_overlap_frames = compute_num_frames_condition(
                    model, num_of_latent_overlap_i, downsample_factor=model.vae.temporal_compression_factor
                )
                grid_list.append(grid_np_THWC[num_overlap_frames:])

            # Prepare the next loop: pixel frames that will be re-encoded as its condition.
            num_cond_frames = compute_num_frames_condition(
                model, num_of_latent_overlap_i_plus_1, downsample_factor=model.vae.temporal_compression_factor
            )
            if hasattr(model, "n_views"):
                grid_BCTHW = einops.rearrange(grid_BCTHW, "B C (V T) H W -> (B V) C T H W", V=model.n_views)
            condition_frame_input = _last_frames(grid_BCTHW, num_cond_frames) * 2 - 1  # [0, 1] -> [-1, 1]

        if skip_reencode:
            # Use the last num_of_latent_overlap latent token as condition latent
            log.info(f"Skip re-encode the condition frames, use the last {num_of_latent_overlap_i_plus_1} latent token")
            condition_latent = _last_frames(sample, num_of_latent_overlap_i_plus_1)
        else:
            # Re-encode the condition frames to get the new condition latent
            condition_latent, _ = create_condition_latent_from_input_frames(
                model, condition_frame_input, num_frames_condition=num_cond_frames
            )  # BCTHW
        condition_latent = condition_latent.to(torch.bfloat16)

    # save videos
    if decode_all_at_end:
        # decode all video together
        grid_BCTHW = _decode_latent_to_unit_range(model, torch.cat(decode_latent_list, dim=2))
        video_THWC = _unit_range_video_to_uint8_THWC(grid_BCTHW)  # THW3, range [0, 255]
    else:
        video_THWC = np.concatenate(grid_list, axis=0)  # THW3, range [0, 255]

    if return_noise:
        return video_THWC, condition_latent_list, sample_latent, noise
    return video_THWC, condition_latent_list, sample_latent
