# This file is from:
# https://github.com/zhaoyue-zephyrus/bsq-vit/blob/main/transcoder/models/transformer.py
# https://github.com/zhaoyue-zephyrus/bsq-vit/blob/main/transcoder/models/attention_mask.py

# MIT License

# Copyright (c) 2024 Yue Zhao

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from collections import OrderedDict
from typing import Callable, Optional, Union
from einops import rearrange
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from timm.models.layers import to_2tuple
from timm.models.layers import trunc_normal_
from timm.models.layers import DropPath


def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    """
    Build the attention mask selected by ``mask_type``.

    Args:
        sequence_length: side length of the square mask.
        device: device on which the mask tensor is created.
        mask_type: ``None`` or ``"none"`` for no mask, ``"block-causal"`` or
            ``"causal"`` otherwise. String matching is case-insensitive.
        **kwargs: forwarded to the selected mask implementation
            (e.g. ``block_size`` for the block-causal mask).

    Returns:
        ``None`` when masking is disabled, otherwise the tensor returned by
        the selected implementation.

    Raises:
        NotImplementedError: if ``mask_type`` names an unknown mask.
    """
    # None must be tested before .lower() is called: None has no such method.
    if mask_type is None:
        return None
    mask_name = mask_type.lower()
    if mask_name == "none":
        return None
    if mask_name == "block-causal":
        return _block_caulsal_mask_impl(sequence_length, device, **kwargs)
    if mask_name == "causal":
        return _caulsal_mask_impl(sequence_length, device, **kwargs)
    raise NotImplementedError(f"Mask type {mask_type} not implemented")


def _block_caulsal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask
    """
    assert sequence_length % block_size == 0, (
        "for block causal masks sequence length must be divisible by block size"
    )
    blocks = torch.ones(
        sequence_length // block_size, block_size, block_size, device=device
    )
    block_diag_enable_mask = torch.block_diag(*blocks)
    causal_enable_mask = torch.ones(
        sequence_length, sequence_length, device=device
    ).tril_(0)
    disable_mask = (block_diag_enable_mask + causal_enable_mask) < 0.5
    return disable_mask


def _caulsal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a causal mask
    """
    causal_disable_mask = torch.triu(
        torch.full(
            (sequence_length, sequence_length),
            float("-inf"),
            dtype=torch.float32,
            device=device,
        ),
        diagonal=1,
    )
    return causal_disable_mask


class LayerScale(nn.Module):
    """
    Multiply the input by a learnable vector ``gamma`` of length ``dim``.

    ``gamma`` starts at ``init_values`` in every entry. With ``inplace=True``
    the input tensor itself is scaled and returned.
    """

    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        initial_gain = init_values * torch.ones(dim)
        self.gamma = nn.Parameter(initial_gain)
        self.inplace = inplace

    def forward(self, x):
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma


class ResidualAttentionBlock(nn.Module):
    """
    Transformer block: multi-head self-attention followed by an MLP, each
    wrapped in a residual connection.

    With ``use_preln=True`` each sub-layer is preceded by its norm and its
    output is passed through ``ls_1`` / ``ls_2`` before the residual sum.
    With ``use_preln=False`` the norm is applied after each residual sum and
    ``ls_1`` / ``ls_2`` are not used.

    Args:
        d_model: embedding width.
        n_head: number of attention heads.
        mlp_ratio: hidden MLP width as a multiple of ``d_model``.
        ls_init_value: initial LayerScale value; ``None`` replaces LayerScale
            with ``nn.Identity``.
        drop: dropout probability applied after the MLP output projection.
        attn_drop: dropout probability inside ``nn.MultiheadAttention``.
        drop_path: stochastic-depth probability; 0 disables ``DropPath``.
        act_layer: zero-argument factory for the MLP activation.
        norm_layer: factory taking the width and returning a norm module.
        use_preln: choose pre-norm (True) or post-norm (False) layout.
    """

    def __init__(
        self,
        d_model: int,
        n_head: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable = nn.GELU,
        norm_layer: Callable = nn.LayerNorm,
        use_preln: bool = True,
    ):
        super().__init__()

        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, dropout=attn_drop)
        self.ls_1 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        self.ln_2 = norm_layer(d_model)
        mlp_width = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            OrderedDict(
                [
                    ("c_fc", nn.Linear(d_model, mlp_width)),
                    ("gelu", act_layer()),
                    # disable this following JAX implementation.
                    # Reference: https://github.com/google-research/magvit/blob/main/videogvt/models/simplified_bert.py#L112
                    # ("drop1", nn.Dropout(drop)),
                    ("c_proj", nn.Linear(mlp_width, d_model)),
                    ("drop2", nn.Dropout(drop)),
                ]
            )
        )
        self.ls_2 = (
            LayerScale(d_model, ls_init_value)
            if ls_init_value is not None
            else nn.Identity()
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.use_preln = use_preln

    def attention(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ):
        """
        Run self-attention on ``x`` and return only the attention output.

        Args:
            x: input used as query, key and value.
            attn_mask: optional mask. A boolean mask (True = position may not
                be attended) is passed through unchanged; any other mask is
                treated as additive and cast to ``x.dtype``.
            is_causal: forwarded to ``nn.MultiheadAttention`` as a hint.
        """
        # A boolean mask must stay boolean. Casting it to float would turn
        # True/False into additive 1.0/0.0 biases, which masks nothing and
        # instead raises the logits of the positions meant to be blocked.
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(x.dtype)
        return self.attn(
            x, x, x, need_weights=False, attn_mask=attn_mask, is_causal=is_causal
        )[0]

    def checkpoint_forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ):
        """
        Same computation as ``forward`` but with the norm and LayerScale
        calls wrapped in activation checkpointing. The attention and MLP
        calls are executed normally.
        """
        state = x
        if self.use_preln:
            x = checkpoint(self.ln_1, x, use_reentrant=False)
            x = self.attention(x, attn_mask, is_causal)
            x = checkpoint(self.ls_1, x, use_reentrant=False)
            state = state + self.drop_path(x)
            x = checkpoint(self.ln_2, state, use_reentrant=False)
            x = self.mlp(x)
            x = checkpoint(self.ls_2, x, use_reentrant=False)
            state = state + self.drop_path(x)
        else:
            x = self.attention(x, attn_mask, is_causal)
            x = state + self.drop_path(x)
            state = checkpoint(self.ln_1, x, use_reentrant=False)
            x = self.mlp(state)
            state = state + self.drop_path(x)
            state = checkpoint(self.ln_2, state, use_reentrant=False)
        return state

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        selective_checkpointing: bool = False,
    ):
        """
        Apply the block to ``x``.

        Args:
            x: input tensor handed to ``nn.MultiheadAttention`` (constructed
                here without ``batch_first``).
            attn_mask: optional attention mask, see ``attention``.
            is_causal: causal hint forwarded to the attention module.
            selective_checkpointing: route through ``checkpoint_forward``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if selective_checkpointing:
            return self.checkpoint_forward(x, attn_mask, is_causal=is_causal)
        if self.use_preln:
            x = x + self.drop_path(
                self.ls_1(
                    self.attention(
                        self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal
                    )
                )
            )
            x = x + self.drop_path(self.ls_2(self.mlp(self.ln_2(x))))
        else:
            x = x + self.drop_path(
                self.attention(x, attn_mask=attn_mask, is_causal=is_causal)
            )
            x = self.ln_1(x)
            x = x + self.drop_path(self.mlp(x))
            x = self.ln_2(x)
        return x


class Transformer(nn.Module):
    """
    Stack of ``layers`` identical ``ResidualAttentionBlock`` modules.

    Gradient checkpointing is off by default and is controlled through the
    ``grad_checkpointing`` / ``selective_checkpointing`` attributes; it is
    only used while the module is in training mode.

    Args:
        width: embedding width of every block.
        layers: number of blocks.
        heads: number of attention heads per block.
        mlp_ratio: hidden MLP width as a multiple of ``width``.
        ls_init_value: initial LayerScale value, or ``None`` for no LayerScale.
        drop: dropout probability after the MLP output projection.
        attn_drop: attention dropout probability.
        drop_path: stochastic-depth probability.
        act_layer: zero-argument factory for the MLP activation.
        norm_layer: factory taking the width and returning a norm module.
        use_preln: pre-norm (True) or post-norm (False) block layout.
    """

    def __init__(
        self,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float = 4.0,
        ls_init_value: Optional[float] = None,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        use_preln: bool = True,
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        self.grad_checkpointing = False
        self.selective_checkpointing = False
        # The RNG state only has to be saved and restored when a stochastic
        # layer is active; otherwise recomputation is deterministic anyway.
        # All three rates must be checked: MLP dropout (`drop`) alone is
        # enough to make the recomputed forward differ without RNG restore.
        uses_randomness = not (drop == 0 and attn_drop == 0 and drop_path == 0)
        self.grad_checkpointing_params = {
            "use_reentrant": False,
            "preserve_rng_state": uses_randomness,
        }

        self.resblocks = nn.ModuleList(
            [
                ResidualAttentionBlock(
                    width,
                    heads,
                    mlp_ratio,
                    ls_init_value=ls_init_value,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path,
                    act_layer=act_layer,
                    norm_layer=norm_layer,
                    use_preln=use_preln,
                )
                for _ in range(layers)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
    ):
        """
        Run ``x`` through every block in order.

        Args:
            x: input tensor passed to each block.
            attn_mask: optional attention mask shared by all blocks.
            is_causal: causal hint shared by all blocks.

        Returns:
            Output of the last block (``x`` itself when there are no blocks).
        """
        use_checkpointing = (
            self.training
            and self.grad_checkpointing
            and not torch.jit.is_scripting()
        )
        for r in self.resblocks:
            if not use_checkpointing:
                # Forward is_causal here too so that this path and the two
                # checkpointed paths call the block with the same arguments.
                x = r(x, attn_mask=attn_mask, is_causal=is_causal)
            elif self.selective_checkpointing:
                # The block checkpoints only its norm / LayerScale calls.
                x = r(
                    x,
                    attn_mask=attn_mask,
                    is_causal=is_causal,
                    selective_checkpointing=True,
                )
            else:
                # Checkpoint the whole block.
                x = checkpoint(
                    r,
                    x,
                    attn_mask,
                    is_causal=is_causal,
                    **self.grad_checkpointing_params,
                )
        return x


class TransformerEncoder(nn.Module):
    """
    Patch-based transformer encoder producing per-token latent vectors.

    The input is cut into non-overlapping patches, each patch is flattened
    and linearly projected to ``width``, positional embeddings are added, and
    the token sequence is passed through a ``Transformer``, a LayerNorm and a
    final linear layer (``quant_embed``).

    Args:
        image_size: input height/width (int or pair).
        patch_size: patch height/width (int or pair).
        width: transformer embedding width.
        layers: number of transformer blocks.
        heads: number of attention heads.
        mlp_ratio: hidden MLP width as a multiple of ``width``.
        double_z: if True the output has ``2 * z_channels`` features,
            otherwise ``z_channels``.
        z_channels: latent feature size.
        num_frames: maximum number of frames; 1 means image input.
        cross_frames: for ``num_frames > 1``, attend across all frames
            jointly (True) or fold frames into the batch dimension (False).
        ls_init_value: initial LayerScale value, or ``None`` to disable.
        drop_rate: MLP dropout probability.
        attn_drop_rate: attention dropout probability.
        drop_path_rate: stochastic-depth probability.
        ln_pre: apply a norm before the transformer; also disables the bias
            of the patch projection.
        ln_post: accepted for signature parity with the decoder.
            NOTE(review): currently ignored, the post-norm is always created.
        act_layer: activation name; only ``"gelu"`` is supported.
        norm_layer: norm name; only ``"layer_norm"`` is supported.
        mask_type: attention mask type passed to ``get_attention_mask``.
        mask_block_size: block size for the block-causal mask; values
            ``<= 0`` select the number of patches per frame.

    Raises:
        ValueError: for an unsupported ``act_layer`` or ``norm_layer`` name.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        double_z: bool,
        z_channels: int,
        num_frames: int = 1,
        cross_frames: bool = True,
        ls_init_value: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        ln_pre: bool = True,
        ln_post: bool = True,
        act_layer: str = "gelu",
        norm_layer: str = "layer_norm",
        mask_type: Union[str, None] = "none",
        mask_block_size: int = -1,
    ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        )
        self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
        self.mask_type = mask_type
        self.mask_block_size = mask_block_size

        if act_layer.lower() == "gelu":
            self.act_layer = nn.GELU
        else:
            raise ValueError(f"Unsupported activation function: {act_layer}")
        if norm_layer.lower() == "layer_norm":
            self.norm_layer = nn.LayerNorm
        else:
            raise ValueError(f"Unsupported normalization: {norm_layer}")

        # Patch embedding: a linear layer over flattened 3-channel patches.
        self.conv1 = nn.Linear(
            in_features=3 * self.patch_size[0] * self.patch_size[1],
            out_features=width,
            bias=not ln_pre,
        )

        scale = width**-0.5
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1], width)
        )
        assert num_frames >= 1
        self.num_frames = num_frames
        self.cross_frames = cross_frames
        if num_frames > 1 and cross_frames:
            self.temporal_positional_embedding = nn.Parameter(
                torch.zeros(num_frames, width)
            )
        else:
            self.temporal_positional_embedding = None

        self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rate,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
        )

        # NOTE(review): `ln_post` is not consulted here (the decoder does
        # honor it). Left as-is because changing it would alter the set of
        # parameters in the state dict.
        self.ln_post = self.norm_layer(width)

        if double_z:
            self.quant_embed = nn.Linear(
                in_features=width, out_features=z_channels * 2
            )
        else:
            self.quant_embed = nn.Linear(
                in_features=width, out_features=z_channels
            )
        self.init_parameters()

    def init_parameters(self):
        """
        Re-initialise the positional embedding, the patch projection and all
        transformer block parameters.

        Inside the blocks, non-norm weights get a truncated normal
        (std 0.02), biases are zeroed, and LayerScale ``gamma`` parameters
        keep the value set by their constructor.

        Raises:
            NotImplementedError: for a block parameter whose name matches
                none of the known categories.
        """
        if self.positional_embedding is not None:
            nn.init.normal_(self.positional_embedding, std=0.02)
        trunc_normal_(self.conv1.weight, std=0.02)
        for block in self.transformer.resblocks:
            for n, p in block.named_parameters():
                if "weight" in n:
                    # Norm weights (ln_1 / ln_2) keep their default init.
                    if "ln" not in n:
                        trunc_normal_(p, std=0.02)
                elif "bias" in n:
                    nn.init.zeros_(p)
                elif n.endswith("gamma"):
                    # LayerScale gain: leave at ls_init_value. Without this
                    # branch any model built with LayerScale would fail here.
                    continue
                else:
                    raise NotImplementedError(f"Unknown parameters named {n}")

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, selective=False):
        """Toggle (selective) gradient checkpointing on the transformer."""
        self.transformer.grad_checkpointing = enable
        self.transformer.selective_checkpointing = selective

    def forward(self, x):
        """
        Encode an image or video batch.

        Args:
            x: ``(b, c, H, W)`` when ``num_frames == 1``, otherwise
                ``(b, c, t, H, W)``. ``c`` must be 3 and ``H``/``W`` must be
                multiples of the patch size.

        Returns:
            Tensor ``(B, tokens, z)``, where ``z`` is the output size of
            ``quant_embed``. With ``cross_frames`` the token axis spans all
            frames; without it frames are folded into ``B``.
        """
        if self.num_frames == 1:
            x = rearrange(
                x,
                "b c (hh sh) (ww sw) -> b (hh ww) (c sh sw)",
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )
            x = self.conv1(x)
            x = x + self.positional_embedding.to(x.dtype)
        elif self.cross_frames:
            num_frames = x.shape[2]
            assert num_frames <= self.num_frames, (
                "Number of frames should be less or equal to the model setting"
            )
            x = rearrange(
                x,
                "b c t (hh sh) (ww sw) -> b (t hh ww) (c sh sw)",
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )
            x = self.conv1(x)
            # Spatial embedding repeats per frame; temporal embedding repeats
            # per patch, so that both line up with the (t hh ww) token order.
            tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
            tile_tem_embed = self.temporal_positional_embedding[
                :num_frames
            ].repeat_interleave(self.patches_per_frame, 0)
            total_pos_embed = tile_pos_embed + tile_tem_embed
            x = x + total_pos_embed.to(x.dtype).squeeze(0)
        else:
            # Frames are encoded independently: fold time into the batch.
            x = rearrange(
                x,
                "b c t (hh sh) (ww sw) -> (b t) (hh ww) (c sh sw)",
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )
            x = self.conv1(x)
            x = x + self.positional_embedding.to(x.dtype)

        x = self.ln_pre(x)
        # The transformer blocks expect the token axis first.
        x = x.permute(1, 0, 2)
        block_size = (
            self.grid_size[0] * self.grid_size[1]
            if self.mask_block_size <= 0
            else self.mask_block_size
        )
        attn_mask = get_attention_mask(
            x.size(0), x.device, mask_type=self.mask_type, block_size=block_size
        )
        x = self.transformer(x, attn_mask, is_causal=self.mask_type == "causal")
        x = x.permute(1, 0, 2)
        x = self.ln_post(x)
        x = self.quant_embed(x)

        return x


class TransformerDecoder(nn.Module):
    """
    Transformer decoder mapping per-token latents back to pixel patches.

    Latent tokens are projected to ``width``, positional embeddings are
    added, the sequence is passed through a ``Transformer``, an optional
    post-norm and an optional tanh feed-forward layer, then projected to
    flattened patches and reassembled into an image or video tensor.

    Args:
        image_size: output height/width (int or pair).
        patch_size: patch height/width (int or pair).
        width: transformer embedding width.
        layers: number of transformer blocks.
        heads: number of attention heads.
        mlp_ratio: hidden MLP width as a multiple of ``width``.
        double_z: accepted for signature parity with the encoder; not used.
        z_channels: latent feature size of the input tokens.
        num_frames: maximum number of frames; 1 means image output.
        cross_frames: for ``num_frames > 1``, tokens of all frames form one
            sequence (True) or frames are folded into the batch (False).
        ls_init_value: initial LayerScale value, or ``None`` to disable.
        drop_rate: MLP dropout probability.
        attn_drop_rate: attention dropout probability.
        drop_path_rate: stochastic-depth probability.
        ln_pre: apply a norm before the transformer.
        ln_post: apply a norm after the transformer.
        act_layer: activation name; only ``"gelu"`` is supported.
        norm_layer: norm name; only ``"layer_norm"`` is supported.
        use_ffn_output: insert a ``Linear -> Tanh`` layer before the patch
            projection.
        dim_ffn_output: width of that layer.
        logit_laplace: if True the number of output channels is doubled.
        mask_type: attention mask type passed to ``get_attention_mask``.
        mask_block_size: block size for the block-causal mask; values
            ``<= 0`` select the number of patches per frame.

    Raises:
        ValueError: for an unsupported ``act_layer`` or ``norm_layer`` name.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        mlp_ratio: float,
        double_z: bool,
        z_channels: int,
        num_frames: int = 1,
        cross_frames: bool = True,
        ls_init_value: Optional[float] = None,
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        drop_path_rate: float = 0.0,
        ln_pre: bool = True,
        ln_post: bool = True,
        act_layer: str = "gelu",
        norm_layer: str = "layer_norm",
        use_ffn_output: bool = True,
        dim_ffn_output: int = 3072,
        logit_laplace: bool = False,
        mask_type: Union[str, None] = "none",
        mask_block_size: int = -1,
    ):
        super().__init__()
        self.image_size = to_2tuple(image_size)
        self.patch_size = to_2tuple(patch_size)
        self.grid_size = (
            self.image_size[0] // self.patch_size[0],
            self.image_size[1] // self.patch_size[1],
        )
        self.patches_per_frame = self.grid_size[0] * self.grid_size[1]
        self.mask_type = mask_type
        self.mask_block_size = mask_block_size

        if act_layer.lower() == "gelu":
            self.act_layer = nn.GELU
        else:
            raise ValueError(f"Unsupported activation function: {act_layer}")
        if norm_layer.lower() == "layer_norm":
            self.norm_layer = nn.LayerNorm
        else:
            raise ValueError(f"Unsupported normalization: {norm_layer}")

        self.use_ffn_output = use_ffn_output
        # Output head: optional tanh layer, then projection to flattened
        # patches with 3 * (1 + logit_laplace) channels.
        if use_ffn_output:
            self.ffn = nn.Sequential(
                nn.Linear(width, dim_ffn_output),
                nn.Tanh(),
            )
            self.conv_out = nn.Linear(
                in_features=dim_ffn_output,
                out_features=3
                * self.patch_size[0]
                * self.patch_size[1]
                * (1 + logit_laplace),
            )
        else:
            self.ffn = nn.Identity()
            self.conv_out = nn.Linear(
                in_features=width,
                out_features=3
                * self.patch_size[0]
                * self.patch_size[1]
                * (1 + logit_laplace),
            )

        scale = width**-0.5
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size[0] * self.grid_size[1], width)
        )
        assert num_frames >= 1
        self.num_frames = num_frames
        self.cross_frames = cross_frames
        if num_frames > 1 and cross_frames:
            self.temporal_positional_embedding = nn.Parameter(
                torch.zeros(num_frames, width)
            )
        else:
            self.temporal_positional_embedding = None

        self.ln_pre = self.norm_layer(width) if ln_pre else nn.Identity()

        self.transformer = Transformer(
            width,
            layers,
            heads,
            mlp_ratio,
            ls_init_value=ls_init_value,
            drop=drop_rate,
            attn_drop=attn_drop_rate,
            drop_path=drop_path_rate,
            act_layer=self.act_layer,
            norm_layer=self.norm_layer,
        )

        self.ln_post = self.norm_layer(width) if ln_post else nn.Identity()

        self.post_quant_embed = nn.Linear(
            in_features=z_channels, out_features=width
        )

        self.init_parameters()

    def init_parameters(self):
        """
        Re-initialise the positional embedding, all transformer block
        parameters and the output head weights.

        Inside the blocks, non-norm weights get a truncated normal
        (std 0.02), biases are zeroed, and LayerScale ``gamma`` parameters
        keep the value set by their constructor.

        Raises:
            NotImplementedError: for a block parameter whose name matches
                none of the known categories.
        """
        if self.positional_embedding is not None:
            nn.init.normal_(self.positional_embedding, std=0.02)

        for block in self.transformer.resblocks:
            for n, p in block.named_parameters():
                if "weight" in n:
                    # Norm weights (ln_1 / ln_2) keep their default init.
                    if "ln" not in n:
                        trunc_normal_(p, std=0.02)
                elif "bias" in n:
                    nn.init.zeros_(p)
                elif n.endswith("gamma"):
                    # LayerScale gain: leave at ls_init_value. Without this
                    # branch any model built with LayerScale would fail here.
                    continue
                else:
                    raise NotImplementedError(f"Unknown parameters named {n}")
        if self.use_ffn_output:
            trunc_normal_(self.ffn[0].weight, std=0.02)
        trunc_normal_(self.conv_out.weight, std=0.02)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True, selective=False):
        """Toggle (selective) gradient checkpointing on the transformer."""
        self.transformer.grad_checkpointing = enable
        self.transformer.selective_checkpointing = selective

    def forward(self, x):
        """
        Decode latent tokens into pixels.

        Args:
            x: ``(B, tokens, z_channels)``. With ``cross_frames`` and
                ``num_frames > 1`` the token axis holds
                ``t * patches_per_frame`` entries; otherwise it holds one
                frame's patches, and for ``num_frames > 1`` the batch axis
                is ``b * num_frames``.

        Returns:
            ``(b, c, H, W)`` when ``num_frames == 1``, otherwise
            ``(b, c, t, H, W)``.
        """
        x = self.post_quant_embed(x)

        if self.num_frames == 1 or not self.cross_frames:
            # Each sequence is a single frame. When frames were folded into
            # the batch, the frame count cannot be read from the input, so
            # the configured value is used to unfold it at the end
            # (assumes the batch axis is b * self.num_frames).
            num_frames = self.num_frames
            x = x + self.positional_embedding.to(x.dtype)
        else:
            num_frames = x.shape[1] // self.patches_per_frame
            assert num_frames <= self.num_frames, (
                "Number of frames should be less or equal to the model setting"
            )
            # Spatial embedding repeats per frame; temporal embedding repeats
            # per patch, so that both line up with the (t hh ww) token order.
            tile_pos_embed = self.positional_embedding.repeat(num_frames, 1)
            tile_tem_embed = self.temporal_positional_embedding[
                :num_frames
            ].repeat_interleave(self.patches_per_frame, 0)
            total_pos_embed = tile_pos_embed + tile_tem_embed
            x = x + total_pos_embed.to(x.dtype).squeeze(0)
        x = self.ln_pre(x)
        # The transformer blocks expect the token axis first.
        x = x.permute(1, 0, 2)
        block_size = (
            self.grid_size[0] * self.grid_size[1]
            if self.mask_block_size <= 0
            else self.mask_block_size
        )
        attn_mask = get_attention_mask(
            x.size(0), x.device, mask_type=self.mask_type, block_size=block_size
        )
        x = self.transformer(x, attn_mask, is_causal=self.mask_type == "causal")
        x = x.permute(1, 0, 2)
        x = self.ln_post(x)
        x = self.ffn(x)
        x = self.conv_out(x)
        if self.num_frames == 1:
            x = rearrange(
                x,
                "b (hh ww) (c sh sw) -> b c (hh sh) (ww sw)",
                hh=self.grid_size[0],
                ww=self.grid_size[1],
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )
        elif self.cross_frames:
            x = rearrange(
                x,
                "b (t hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames,
                hh=self.grid_size[0],
                ww=self.grid_size[1],
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )
        else:
            x = rearrange(
                x,
                "(b t) (hh ww) (c sh sw) -> b c t (hh sh) (ww sw)",
                t=num_frames,
                hh=self.grid_size[0],
                ww=self.grid_size[1],
                sh=self.patch_size[0],
                sw=self.patch_size[1],
            )

        return x
