# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import os
from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn.functional as F

from cosmos_transfer2._src.common.utils.s3_utils import load_from_s3_with_cache
from cosmos_transfer2._src.imaginaire.utils.distributed import rank0_first
from cosmos_transfer2._src.imaginaire.utils.env_parsers.cred_env_parser import CRED_ENVS


class BaseVAE(torch.nn.Module, ABC):
    """
    Abstract base class for a Variational Autoencoder (VAE).

    All subclasses should implement the methods to define the behavior for encoding
    and decoding, along with specifying the latent channel size.
    """

    def __init__(self, channel: int = 3, name: str = "vae"):
        super().__init__()
        self.channel = channel
        self.name = name

    @property
    def latent_ch(self) -> int:
        """
        Returns the number of latent channels in the VAE.
        """
        return self.channel

    @abstractmethod
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encodes the input tensor into a latent representation.

        Args:
        - state (torch.Tensor): The input tensor to encode.

        Returns:
        - torch.Tensor: The encoded latent tensor.
        """
        pass

    @abstractmethod
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decodes the latent representation back to the original space.

        Args:
        - latent (torch.Tensor): The latent tensor to decode.

        Returns:
        - torch.Tensor: The decoded tensor.
        """
        pass

    @property
    def spatial_compression_factor(self) -> int:
        """
        Returns the spatial reduction factor for the VAE.
        """
        raise NotImplementedError("The spatial_compression_factor property must be implemented in the derived class.")


class BasePretrainedImageVAE(BaseVAE):
    """
    A base class for pretrained Variational Autoencoder (VAE) that loads mean and standard deviation values
    from a remote store, handles data type conversions, and normalization
    using provided mean and standard deviation values for latent space representation.
    Derived classes should load pre-trained encoder and decoder components from a remote store

    Attributes:
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        mean_std_fp (str): File path to the pickle file containing mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
    """

    def __init__(
        self,
        name: str,
        mean_std_fp: str,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
        s3_credential_path=Optional[str],
        load_mean_std: bool = True,
    ) -> None:
        super().__init__(latent_ch, name)
        dtype = torch.bfloat16 if is_bf16 else torch.float32
        self.dtype = dtype
        self.is_image = is_image
        self.mean_std_fp = mean_std_fp
        self.name = name
        self.load_mean_std = load_mean_std

        assert s3_credential_path is None or isinstance(s3_credential_path, str)
        if s3_credential_path is None:
            self.backend_args = None
        elif os.path.exists(s3_credential_path) or CRED_ENVS.APP_ENV in ["prod", "dev", "stg"]:
            self.backend_args = {
                "backend": "s3",
                "path_mapping": None,
                "s3_credential_path": s3_credential_path,
            }
        else:
            raise FileNotFoundError(f"Invalid s3_credential_path: {s3_credential_path} and APP_ENV is not prod/dev/stg")

        self.register_mean_std(mean_std_fp)

    def register_mean_std(self, mean_std_fp: str) -> None:
        target_shape = [1, self.latent_ch, 1, 1] if self.is_image else [1, self.latent_ch, 1, 1, 1]

        if self.load_mean_std:
            extention = mean_std_fp.split(".")[-1]
            latent_mean, latent_std = load_from_s3_with_cache(
                mean_std_fp,
                f"vae/{self.name}_mean_std.{extention}",
                easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())},
                backend_args=self.backend_args,
            )
            self.register_buffer(
                "latent_mean",
                latent_mean.to(self.dtype).reshape(*target_shape),
                persistent=False,
            )
            self.register_buffer(
                "latent_std",
                latent_std.to(self.dtype).reshape(*target_shape),
                persistent=False,
            )
        else:
            # Use zeros for mean and ones for std when load_mean_std=False
            device = torch.device(torch.cuda.current_device()) if torch.cuda.is_available() else torch.device("cpu")
            self.register_buffer(
                "latent_mean",
                torch.zeros(*target_shape, dtype=self.dtype, device=device),
                persistent=False,
            )
            self.register_buffer(
                "latent_std",
                torch.ones(*target_shape, dtype=self.dtype, device=device),
                persistent=False,
            )

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        Encode the input state to latent space; also handle the dtype conversion, mean and std scaling
        """
        in_dtype = state.dtype
        latent_mean = self.latent_mean.to(in_dtype)
        latent_std = self.latent_std.to(in_dtype)
        encoded_state = self.encoder(state.to(self.dtype))
        if isinstance(encoded_state, torch.Tensor):
            pass
        elif isinstance(encoded_state, tuple):
            assert isinstance(encoded_state[0], torch.Tensor)
            encoded_state = encoded_state[0]
        else:
            raise ValueError("Invalid type of encoded state")
        return (encoded_state.to(in_dtype) - latent_mean) / latent_std

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        """
        Decode the input latent to state; also handle the dtype conversion, mean and std scaling
        """
        in_dtype = latent.dtype
        latent = latent * self.latent_std.to(in_dtype) + self.latent_mean.to(in_dtype)
        return self.decoder(latent.to(self.dtype)).to(in_dtype)

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.decoder.to(self.dtype)
        self.encoder.to(self.dtype)


class JITVAE(BasePretrainedImageVAE):
    """
    A JIT compiled Variational Autoencoder (VAE) that loads pre-trained encoder
    and decoder components from a remote store, handles data type conversions, and normalization
    using provided mean and standard deviation values for latent space representation.

    Attributes:
        encoder (Module): The JIT compiled encoder loaded from storage.
        decoder (Module): The JIT compiled decoder loaded from storage.
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        name (str): Name of the model, used for differentiating cache file paths.
        mean_std_fp (str): File path to the pickle file containing mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        name: str,
        mean_std_fp: str,
        s3_credential_path: Optional[str] = None,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
        load_mean_std: bool = True,
    ):
        super().__init__(
            name,
            mean_std_fp,
            latent_ch,
            is_image,
            is_bf16,
            s3_credential_path=s3_credential_path,
            load_mean_std=load_mean_std,
        )
        self.load_encoder(enc_fp)
        self.load_decoder(dec_fp)

    def load_encoder(self, enc_fp: str) -> None:
        """
        Load the encoder from the remote store.

        Args:
        - enc_fp (str): File path to the encoder's JIT file on the remote store.
        """
        self.encoder = load_from_s3_with_cache(
            enc_fp,
            f"vae/{self.name}_enc.jit",
            easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())},
            backend_args=self.backend_args,
        )
        self.encoder.eval()
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.to(self.dtype)

    def load_decoder(self, dec_fp: str) -> None:
        """
        Load the decoder from the remote store.

        Args:
        - dec_fp (str): File path to the decoder's JIT file on the remote store.
        """
        self.decoder = load_from_s3_with_cache(
            dec_fp,
            f"vae/{self.name}_dec.jit",
            easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())},
            backend_args=self.backend_args,
        )
        self.decoder.eval()
        for param in self.decoder.parameters():
            param.requires_grad = False
        self.decoder.to(self.dtype)


class StateDictVAE(BasePretrainedImageVAE):
    """
    A Variational Autoencoder (VAE) that loads pre-trained weights into
    provided encoder and decoder components from a remote store, handles data type conversions,
    and normalization using provided mean and standard deviation values for latent space representation.

    Attributes:
        encoder (Module): The encoder with weights loaded from storage.
        decoder (Module): The decoder with weights loaded from storage.
        latent_mean (Tensor): The mean used for normalizing the latent representation.
        latent_std (Tensor): The standard deviation used for normalizing the latent representation.
        dtype (dtype): Data type for model tensors, determined by whether bf16 is enabled.

    Args:
        enc_fp (str): File path to the encoder's JIT file on the remote store.
        dec_fp (str): File path to the decoder's JIT file on the remote store.
        vae (Module): Instance of VAE with not loaded weights
        name (str): Name of the model, used for differentiating cache file paths.
        mean_std_fp (str): File path to the pickle file containing mean and std of the latent space.
        latent_ch (int, optional): Number of latent channels (default is 16).
        is_image (bool, optional): Flag to indicate whether the output is an image (default is True).
        is_bf16 (bool, optional): Flag to use Brain Floating Point 16-bit data type (default is True).
    """

    def __init__(
        self,
        enc_fp: str,
        dec_fp: str,
        vae: torch.nn.Module,
        name: str,
        mean_std_fp: str,
        s3_credential_path: Optional[str] = None,
        latent_ch: int = 16,
        is_image: bool = True,
        is_bf16: bool = True,
    ):
        super().__init__(name, mean_std_fp, latent_ch, is_image, is_bf16, s3_credential_path=s3_credential_path)

        self.load_encoder_and_decoder(enc_fp, dec_fp, vae)

    def load_encoder_and_decoder(self, enc_fp: str, dec_fp: str, vae: torch.nn.Module) -> None:
        """
        Load the encoder from the remote store.

        Args:
        - vae_fp (str): File path to the vae's state dict file on the remote store.
        - vae (str): VAE module into which weights will be loaded.
        """
        state_dict_enc = load_from_s3_with_cache(
            enc_fp,
            f"vae/{self.name}_enc.jit",
            easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())},
            backend_args=self.backend_args,
        )

        state_dict_dec = load_from_s3_with_cache(
            dec_fp,
            f"vae/{self.name}_dec.jit",
            easy_io_kwargs={"map_location": torch.device(torch.cuda.current_device())},
            backend_args=self.backend_args,
        )

        jit_weights_state_dict = state_dict_enc.state_dict() | state_dict_dec.state_dict()
        jit_weights_state_dict = {
            k: v
            for k, v in jit_weights_state_dict.items()
            # Global variables captured by JIT
            if k
            not in (
                "encoder.patcher.wavelets",
                "encoder.patcher._arange",
                "decoder.unpatcher.wavelets",
                "decoder.unpatcher._arange",
            )
        }

        vae.load_state_dict(jit_weights_state_dict)
        vae.eval()
        for param in vae.parameters():
            param.requires_grad = False
        vae.to(self.dtype)

        self.vae = vae
        self.encoder = self.vae.encode
        self.decoder = self.vae.decode

    def reset_dtype(self, *args, **kwargs):
        """
        Resets the data type of the encoder and decoder to the model's default data type.

        Args:
            *args, **kwargs: Unused, present to allow flexibility in method calls.
        """
        del args, kwargs
        self.vae.to(self.dtype)


class SDVAE(BaseVAE):
    def __init__(self, batch_size=16, count_std: bool = False, is_downsample: bool = True) -> None:
        super().__init__(channel=4, name="sd_vae")
        self.dtype = torch.bfloat16
        self.register_buffer(
            "scale",
            torch.tensor([4.17, 4.62, 3.71, 3.28], dtype=self.dtype).reciprocal().reshape(1, -1, 1, 1),
            persistent=False,
        )
        self.register_buffer(
            "bias",
            -1.0 * torch.tensor([5.81, 3.25, 0.12, -2.15], dtype=self.dtype).reshape(1, -1, 1, 1) * self.scale,
            persistent=False,
        )
        self.batch_size = batch_size
        self.count_std = count_std
        self.is_downsample = is_downsample
        self.load_vae()
        self.reset_dtype()

    def reset_dtype(self, *args, **kwargs):
        del args, kwargs
        self.vae.to(self.dtype)

    @rank0_first
    def load_vae(self) -> None:
        os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
        os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
        import diffusers

        vae_name = "stabilityai/sd-vae-ft-mse"
        try:
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name, local_files_only=True)
        except:  # noqa: E722
            # Could not load the model from cache; try without local_files_only.
            vae = diffusers.models.AutoencoderKL.from_pretrained(vae_name)
        self.vae = vae.eval().requires_grad_(False)

    @torch.no_grad()
    def encode(self, state: torch.Tensor) -> torch.Tensor:
        """
        state : pixel range [-1, 1]
        """
        if self.is_downsample:
            _h, _w = state.shape[-2:]
            state = F.interpolate(state, size=(_h // 2, _w // 2), mode="bilinear", align_corners=False)
        in_dtype = state.dtype
        state = state.to(self.dtype)
        state = (state + 1.0) / 2.0
        latent_dist = self.vae.encode(state)["latent_dist"]
        mean, std = latent_dist.mean, latent_dist.std
        if self.count_std:
            latent = mean + torch.randn_like(mean) * std
        else:
            latent = mean
        latent = latent * self.scale
        latent = latent + self.bias
        return latent.to(in_dtype)

    @torch.no_grad()
    def decode(self, latent: torch.Tensor) -> torch.Tensor:
        in_dtype = latent.dtype
        latent = latent.to(self.dtype)
        latent = latent - self.bias
        latent = latent / self.scale
        latent = torch.cat([self.vae.decode(batch)["sample"] for batch in latent.split(self.batch_size)])
        if self.is_downsample:
            _h, _w = latent.shape[-2:]
            latent = F.interpolate(latent, size=(_h * 2, _w * 2), mode="bilinear", align_corners=False)
        return latent.to(in_dtype) * 2 - 1.0

    @property
    def spatial_compression_factor(self) -> int:
        return 8
