# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions for the inference libraries."""

import os
from glob import glob
from typing import Any, Callable, Optional, Union, Tuple, List, Dict

import mediapy as media
import numpy as np
import torch
from PIL import Image

from models.cosmos_tokenizer.networks import TokenizerModels

# Default tensor dtype and device used by `numpy2tensor`.
_DTYPE = torch.bfloat16
_DEVICE = "cuda"
# Largest uint8 value as a float (255.0); rescales pixels to/from [0, 1].
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
# Default alignment for zero-padding the H and W axes.
_SPATIAL_ALIGN = 16
# Default temporal alignment: frame counts are padded to 1 + k * this value.
_TEMPORAL_ALIGN = 8


def load_model(
    jit_filepath: Optional[str] = None,
    tokenizer_config: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
) -> Union[torch.nn.Module, torch.jit.ScriptModule]:
    """Loads a full tokenizer model from a filepath.

    Args:
        jit_filepath: Path to the checkpoint. Without `tokenizer_config` it is
            loaded directly as a TorchScript archive. With a config it may be
            a `.jit` archive or a `.pth` checkpoint.
        tokenizer_config: Optional model config. Its "name" key selects the
            `TokenizerModels` entry used to build a PyTorch model, and the
            checkpoint weights are loaded into it (non-strictly).
        device: The device to load the model onto, default=cuda.
    Returns:
        The model loaded to `device` and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    # A `.jit` checkpoint is a module, so its weights come from `state_dict()`.
    # A `.pth` checkpoint comes from `torch.load` and is assumed to already be
    # a state dict; calling `.state_dict()` on a plain dict would fail.
    state_dict = ckpts.state_dict() if isinstance(ckpts, torch.nn.Module) else ckpts
    full_model.load_state_dict(state_dict, strict=False)
    return full_model.eval().to(device)


def load_encoder_model(
    jit_filepath: Optional[str] = None,
    tokenizer_config: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
) -> Union[torch.nn.Module, torch.jit.ScriptModule]:
    """Loads the encoder part of a tokenizer from a filepath.

    Args:
        jit_filepath: Path to the checkpoint. Without `tokenizer_config` it is
            loaded directly as a TorchScript archive (presumably of the
            encoder). With a config it may be a `.jit` archive or a `.pth`
            checkpoint.
        tokenizer_config: Optional model config. When given, the full model is
            built from it, its `encoder_jit()` sub-model is taken, and the
            checkpoint weights are loaded into that sub-model (non-strictly).
        device: The device to load the model onto, default=cuda.
    Returns:
        The encoder loaded to `device` and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    encoder_model = full_model.encoder_jit()
    # A `.jit` checkpoint is a module, so its weights come from `state_dict()`.
    # A `.pth` checkpoint is assumed to already be a state dict. Load exactly
    # once: passing a ScriptModule directly to `load_state_dict` fails.
    state_dict = ckpts.state_dict() if isinstance(ckpts, torch.nn.Module) else ckpts
    encoder_model.load_state_dict(state_dict, strict=False)
    return encoder_model.eval().to(device)


def load_decoder_model(
    jit_filepath: Optional[str] = None,
    tokenizer_config: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
) -> Union[torch.nn.Module, torch.jit.ScriptModule]:
    """Loads the decoder part of a tokenizer from a filepath.

    Args:
        jit_filepath: Path to the checkpoint. Without `tokenizer_config` it is
            loaded directly as a TorchScript archive (presumably of the
            decoder). With a config it may be a `.jit` archive or a `.pth`
            checkpoint.
        tokenizer_config: Optional model config. When given, the full model is
            built from it, its `decoder_jit()` sub-model is taken, and the
            checkpoint weights are loaded into that sub-model (non-strictly).
        device: The device to load the model onto, default=cuda.
    Returns:
        The decoder loaded to `device` and in eval mode.
    """
    if tokenizer_config is None:
        return load_jit_model(jit_filepath, device)
    full_model, ckpts = _load_pytorch_model(jit_filepath, tokenizer_config, device)
    decoder_model = full_model.decoder_jit()
    # Decide based on what was actually loaded rather than re-checking the
    # extension. Module checkpoints (e.g. `.jit`) provide their
    # `state_dict()`; anything else is assumed to already be a state dict.
    state_dict = ckpts.state_dict() if isinstance(ckpts, torch.nn.Module) else ckpts
    decoder_model.load_state_dict(state_dict, strict=False)
    return decoder_model.eval().to(device)


def _load_pytorch_model(
    jit_filepath: Optional[str] = None,
    tokenizer_config: Optional[Dict[str, Any]] = None,
    device: str = "cuda",
) -> Tuple[torch.nn.Module, Any]:
    """Builds a tokenizer model from its config and loads its checkpoint.

    The checkpoint weights are NOT applied to the model here. Callers pick the
    relevant (sub-)model and call `load_state_dict` themselves.

    Args:
        jit_filepath: Path to the checkpoint: a TorchScript archive (`.jit`)
            or a file readable by `torch.load` (`.pth`).
        tokenizer_config: Model config. Its "name" key selects the
            `TokenizerModels` entry, and the whole dict is passed to that
            entry's constructor as keyword arguments.
        device: Device the checkpoint tensors are mapped onto, default=cuda.
    Returns:
        A tuple `(model, ckpts)`: the newly constructed model and the loaded
        checkpoint (a ScriptModule for `.jit`, whatever `torch.load` returns
        for `.pth`).
    Raises:
        ValueError: If `jit_filepath` ends in neither `.jit` nor `.pth`.
    """
    # Reject unsupported files before paying for model construction.
    if not jit_filepath.endswith((".jit", ".pth")):
        raise ValueError(f"Invalid model file extension: {jit_filepath}")
    tokenizer_name = tokenizer_config["name"]
    model = TokenizerModels[tokenizer_name].value(**tokenizer_config)
    if jit_filepath.endswith(".jit"):
        # map_location lets a CUDA-saved archive load on a CPU-only host.
        ckpts = torch.jit.load(jit_filepath, map_location=device)
    else:
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        ckpts = torch.load(jit_filepath, map_location=device)
    return model, ckpts


def load_jit_model(
    jit_filepath: Optional[str] = None, device: str = "cuda"
) -> torch.jit.ScriptModule:
    """Loads a torch.jit.ScriptModule from a filepath.

    Args:
        jit_filepath: The filepath to the JIT-compiled model.
        device: The device to load the model onto, default=cuda.
    Returns:
        The JIT compiled model loaded to device and in eval mode.
    """
    # Deserialize directly onto `device`. Without map_location, tensors are
    # restored to the device they were saved from, so an archive saved from
    # CUDA could not be loaded on a CPU-only host even with device="cpu".
    model = torch.jit.load(jit_filepath, map_location=device)
    return model.eval().to(device)


def save_jit_model(
    model: Union[torch.jit.ScriptModule, torch.jit.RecursiveScriptModule] = None,
    jit_filepath: str = None,
) -> None:
    """Serializes a TorchScript module to disk with `torch.jit.save`.

    Args:
        model: The scripted or traced module to write, e.g. one loaded onto
            `config.checkpoint.jit.device`.
        jit_filepath: Destination path of the serialized archive.
    """
    scripted_module, destination = model, jit_filepath
    torch.jit.save(scripted_module, destination)


def get_filepaths(input_pattern) -> List[str]:
    """Returns the sorted, de-duplicated filepaths matching a glob pattern.

    Args:
        input_pattern: A glob pattern (a str or anything convertible with `str`).
    Returns:
        The matching filepaths in lexicographic order.
    """
    # De-duplicate first, then sort. Turning a sorted list into a set would
    # discard the ordering, because sets are unordered.
    return sorted(set(glob(str(input_pattern))))


def get_output_filepath(filepath: str, output_dir: Optional[str] = None) -> str:
    """Returns the output filepath for `filepath`, creating its directory.

    Args:
        filepath: The input filepath; its basename is reused for the output.
        output_dir: Directory for the output. If empty or None, a
            `reconstructions` sub-directory next to `filepath` is used.
    Returns:
        `output_dir` joined with the basename of `filepath`.
    """
    if not output_dir:
        # os.path.join keeps the default relative for a bare filename
        # ("reconstructions"). Formatting f"{''}/reconstructions" would point
        # at the filesystem root instead.
        output_dir = os.path.join(os.path.dirname(filepath), "reconstructions")
    os.makedirs(output_dir, exist_ok=True)
    return os.path.join(output_dir, os.path.basename(filepath))


def read_image(filepath: str) -> np.ndarray:
    """Reads an image from a filepath and normalizes it to 3 channels.

    Grey-scale (HxW or HxWx1) and grey+alpha (HxWx2) images are expanded to
    RGB by replicating the luminance channel. RGBA images lose their alpha
    channel. Other channel counts are returned unchanged.

    Args:
        filepath: The filepath to the image.

    Returns:
        The image as a numpy array with layout HxWxC (C == 3 for the cases
        above). The dtype is whatever mediapy decodes, presumably uint8 in
        [0..255] for 8-bit files.
    """
    image = media.read_image(filepath)
    # Tokenizers always assume 3-channel RGB input.
    if image.ndim == 2:
        image = image[..., np.newaxis]
    num_channels = image.shape[-1]
    if num_channels in (1, 2):
        # Grey or grey+alpha: replicate luminance and drop any alpha channel.
        image = np.repeat(image[..., :1], 3, axis=-1)
    elif num_channels == 4:
        # RGBA -> RGB.
        image = image[..., :3]
    return image


def read_video(filepath: str) -> np.ndarray:
    """Reads a video file and normalizes its frames to 3 channels.

    Args:
        filepath: The filepath to the video.
    Returns:
        The frames as a numpy array with layout TxHxWxC. Grey-scale clips
        (TxHxW) are replicated to three channels and RGBA frames lose their
        alpha channel. Expected range [0..255], uint8, as decoded by mediapy.
    """
    frames = media.read_video(filepath)
    is_grey = frames.ndim == 3
    if is_grey:
        # Tokenizers always expect 3-channel input, so copy luminance thrice.
        frames = np.stack((frames, frames, frames), axis=-1)
    has_alpha = frames.shape[-1] == 4
    return frames[..., :3] if has_alpha else frames


def resize_image(image: np.ndarray, short_size: Optional[int] = None) -> np.ndarray:
    """Resizes an image so that its shorter side equals `short_size`.

    The other side is scaled proportionally, rounded to the nearest integer,
    and bumped up to the next even number if it is odd. Ties (square images)
    are treated as height being the short side.

    Args:
        image: The image to resize, layout HxWxC, of any range.
        short_size: Target length of the short side. If None, the image is
            returned unchanged.
    Returns:
        The resized image.
    """
    if short_size is None:
        return image
    h, w = image.shape[-3:-1]

    def _scaled_even(long_side, short_side):
        # Proportional scale, round half up, then force an even result.
        scaled = int(long_side * short_size / short_side + 0.5)
        return scaled + scaled % 2

    if h <= w:
        target = (short_size, _scaled_even(w, h))
    else:
        target = (_scaled_even(h, w), short_size)
    return media.resize_image(image, shape=target)


def resize_video(video: np.ndarray, short_size: Optional[int] = None) -> np.ndarray:
    """Resizes every frame of a video so its shorter side equals `short_size`.

    The other side is scaled proportionally, rounded to the nearest integer,
    and bumped up to the next even number if it is odd.

    Args:
        video: The video to resize, layout TxHxWxC, of any range.
        short_size: Target length of the short side. If None, the video is
            returned unchanged.
    Returns:
        The resized video.
    """
    if short_size is None:
        return video
    height, width = video.shape[-3:-1]
    portrait = height > width
    short_side, long_side = (width, height) if portrait else (height, width)
    long_new = int(long_side * short_size / short_side + 0.5)
    if long_new % 2:
        long_new += 1
    new_shape = (long_new, short_size) if portrait else (short_size, long_new)
    return media.resize_video(video, shape=new_shape)


def write_image(filepath: str, image: np.ndarray):
    """Writes `image` to `filepath` using mediapy and returns mediapy's result."""
    target, pixels = filepath, image
    return media.write_image(target, pixels)


def write_video(filepath: str, video: np.ndarray, fps: int = 24) -> None:
    """Encodes `video` to `filepath` with mediapy at `fps` frames per second."""
    frame_rate = fps
    return media.write_video(filepath, video, fps=frame_rate)


def numpy2tensor(
    input_image: np.ndarray,
    dtype: torch.dtype = _DTYPE,
    device: str = _DEVICE,
    range_min: int = -1,
) -> torch.Tensor:
    """Converts a channels-last uint8 batch into a channels-first tensor.

    Args:
        input_image: A batch in range [0..255] with the channel axis last,
            e.g. BxHxWx3 for images or BxTxHxWx3 for videos.
        dtype: Target tensor dtype (module default: bfloat16).
        device: Target device (module default: "cuda").
        range_min: -1 rescales values to [-1..1]; any other value leaves them
            in [0..1].
    Returns:
        A `dtype` tensor on `device` with the channel axis moved to position 1
        (e.g. Bx3xHxW), with values in the range selected by `range_min`.
    """
    rank = input_image.ndim
    # Move the last (channel) axis right after the batch axis and keep the
    # remaining axes in their original order.
    rest = list(range(1, rank))
    axes = (0, *rest[-1:], *rest[:-1])
    scaled = input_image.transpose(axes) / _UINT8_MAX_F
    if range_min == -1:
        scaled = 2.0 * scaled - 1.0
    return torch.from_numpy(scaled).to(dtype).to(device)


def tensor2numpy(input_tensor: torch.Tensor, range_min: int = -1) -> np.ndarray:
    """Converts a channels-first tensor back into a channels-last uint8 array.

    Args:
        input_tensor: Tensor with the channel axis at position 1, e.g.
            Bx3xHxW or Bx3xTxHxW. Values are in [-1..1] when range_min == -1,
            otherwise in [0..1].
        range_min: -1 means the input is in [-1..1] and is first mapped to
            [0..1]; any other value means it is already in [0..1].
    Returns:
        A numpy array with the channel axis moved last (e.g. BxHxWx3), range
        [0..255], uint8 dtype. Out-of-range values are clamped.
    """
    # Always upcast to float32: NumPy has no bfloat16, so `.numpy()` on a bf16
    # tensor would fail when range_min != -1. `.detach()` lets tensors that
    # track gradients be converted too.
    output = input_tensor.detach().float()
    if range_min == -1:
        output = (output + 1.0) / 2.0
    ndim = output.ndim
    output_image = output.clamp(0, 1).cpu().numpy()
    # Move the channel axis (1) to the end, keeping the other axes in order.
    output_image = output_image.transpose((0,) + tuple(range(2, ndim)) + (1,))
    # Add 0.5 before truncating so values round to the nearest integer.
    return (output_image * _UINT8_MAX_F + 0.5).astype(np.uint8)


def pad_image_batch(
    batch: np.ndarray, spatial_align: int = _SPATIAL_ALIGN
) -> Tuple[np.ndarray, List[int]]:
    """Zero-pads a batch of images so H and W are multiples of `spatial_align`.

    The padding on each axis is split between both sides, with any odd
    leftover pixel going to the bottom/right.

    Args:
        batch: The batch of images to pad, layout BxHxWxC, in any range.
        spatial_align: The spatial alignment to pad to.
    Returns:
        The padded batch and the crop region [y1, x1, y2, x2] that recovers
        the original content (as consumed by `unpad_image_batch`).
    """
    height, width = batch.shape[1:3]
    # Smallest non-negative amount that makes each side a multiple of the alignment.
    pad_h = -height % spatial_align
    pad_w = -width % spatial_align
    top, left = pad_h // 2, pad_w // 2
    crop_region = [top, left, top + height, left + width]
    pad_width = ((0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0))
    return np.pad(batch, pad_width, mode="constant"), crop_region


def pad_video_batch(
    batch: np.ndarray,
    temporal_align: int = _TEMPORAL_ALIGN,
    spatial_align: int = _SPATIAL_ALIGN,
) -> Tuple[np.ndarray, List[int]]:
    """Pads a batch of videos to aligned spatial and temporal sizes.

    H and W are zero-padded to multiples of `spatial_align`. The frame axis is
    padded so that (num_frames - 1) is a multiple of `temporal_align`, by
    repeating the first/last frames (np.pad mode="edge", not a mirror
    reflection). On every axis the padding is split between both ends, with
    any odd leftover going to the end.

    Args:
        batch: The batch of videos to pad, layout BxFxHxWxC, in any range.
        temporal_align: The temporal alignment to pad to.
        spatial_align: The spatial alignment to pad to.
    Returns:
        The padded batch and the crop region [f1, y1, x1, f2, y2, x2] that
        recovers the original content (as consumed by `unpad_video_batch`).
    """
    num_frames, height, width = batch.shape[-4:-1]
    pad_h = -height % spatial_align
    pad_w = -width % spatial_align
    # Pad the frame count to 1 + k * temporal_align (presumably the causal
    # tokenizer treats the first frame separately; confirm).
    pad_t = -(num_frames - 1) % temporal_align

    front, top, left = pad_t >> 1, pad_h >> 1, pad_w >> 1
    crop_region = [
        front,
        top,
        left,
        front + num_frames,
        top + height,
        left + width,
    ]
    spatially_padded = np.pad(
        batch,
        ((0, 0), (0, 0), (top, pad_h - top), (left, pad_w - left), (0, 0)),
        mode="constant",
    )
    padded = np.pad(
        spatially_padded,
        ((0, 0), (front, pad_t - front), (0, 0), (0, 0), (0, 0)),
        mode="edge",
    )
    return padded, crop_region


def unpad_video_batch(batch: np.ndarray, crop_region: List[int]) -> np.ndarray:
    """Crops a padded video batch back to `crop_region`.

    Args:
        batch: A batch of numpy videos, layout BxFxHxWxC.
        crop_region: [f1, y1, x1, f2, y2, x2]: first/top/left start indices
            and last/bottom/right end indices (exclusive).

    Returns:
        np.ndarray: The cropped video view, layout BxFxHxWxC.
    """
    assert len(crop_region) == 6, "crop_region should be len of 6."
    first, top, left, last, bottom, right = crop_region
    window = (
        Ellipsis,
        slice(first, last),
        slice(top, bottom),
        slice(left, right),
        slice(None),
    )
    return batch[window]


def unpad_image_batch(batch: np.ndarray, crop_region: List[int]) -> np.ndarray:
    """Crops a padded image batch back to `crop_region`.

    Args:
        batch: A batch of numpy images, layout BxHxWxC.
        crop_region: [y1, x1, y2, x2]: top/left start indices and
            bottom/right end indices (exclusive).

    Returns:
        np.ndarray: The cropped image view, layout BxHxWxC.
    """
    assert len(crop_region) == 4, "crop_region should be len of 4."
    top, left, bottom, right = crop_region
    rows, cols = slice(top, bottom), slice(left, right)
    return batch[..., rows, cols, :]
