#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#
"""
Implementation of the following modules is borrowed from ml-cvnets repo:
https://github.com/apple/ml-cvnets/blob/main/cvnets/models/classification/vit.py

Please see ACKNOWLEDGEMENTS for license details.
"""

from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn

from timm.models import register_model
from helper.models.mobileclip.modules.common.transformer import (
    PositionalEmbedding,
    TransformerEncoder,
    get_normalization_layer,
)
from helper.models.mobileclip.modules.image.image_projection import SimpleImageProjectionHead
from helper.models.mobileclip import logger


class ConvNormAct(nn.Module):
    """
    Applies an N-dimensional convolution over an input.

    Args:
        cfg: Model configuration.
        in_channels: :math:`C_{out}` from an expected output of size
            :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
        out_channels: :math:`C_{out}` from an expected output of size
            :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.
        kernel_size: Kernel size for convolution. An integer, or tuple of length ``N``.
        stride: Stride for convolution. An integer, or tuple of length ``N``. Default: 1.
        dilation: Dilation rate for convolution. An integer, or tuple of length ``N``.
            Default: ``1``.
        padding: Padding for convolution. An integer, or tuple of length ``N``.
            If not specified, padding is automatically computed based on kernel size and
            dilation range. Default : ``None`` (equivalent to ``[
            int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(N)]``).
        groups: Number of groups in convolution. Default: ``1``.
        bias: Use bias. Default: ``False``.
        padding_mode: Padding mode ('zeros', 'reflect', 'replicate' or 'circular').
            Default: ``zeros``.
        use_norm: Use normalization layer after convolution. Default: ``True``.
        use_act: Use activation layer after convolution (or convolution and normalization).
            Default: ``True``.
        norm_layer: If not None, the provided normalization layer object will be used.
            Otherwise, a normalization object will be created based on config
            ``model.normalization.*`` opts.
        act_layer: If not None, the provided activation function will be used.
            Otherwise, an activation function will be created based on config
            ``model.activation.*`` opts.

    Shape:
        - Input: :math:`(bs, C_{in}, X_{1}, ..., X_{N})`.
        - Output: :math:`(bs, C_{out}, Y_{1}, ..., Y_{N})`.

    .. note::
        For depth-wise convolution, `groups=C_{in}=C_{out}`.
    """

    def __init__(
        self,
        cfg: Dict,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, ...]],
        stride: Union[int, Tuple[int, ...]] = 1,
        dilation: Union[int, Tuple[int, ...]] = 1,
        padding: Optional[Union[int, Tuple[int, ...]]] = None,
        groups: int = 1,
        bias: bool = False,
        padding_mode: str = "zeros",
        use_norm: bool = True,
        use_act: bool = True,
        norm_layer: Optional[nn.Module] = None,
        act_layer: Optional[nn.Module] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__()
        self.ndim = 2

        if norm_layer is None and use_norm:
            norm_type = cfg.get("normalization", "batch_norm")
            if norm_type == "batch_norm":
                norm_layer = nn.BatchNorm2d(
                    num_features=out_channels,
                    momentum=cfg.get("momentum", 0.1),
                )
            else:
                norm_layer = get_normalization_layer(
                    num_features=out_channels, norm_type=norm_type
                )
        elif norm_layer is not None and use_norm:
            logger.error(
                f"When use_norm is False, norm_layer should be None, but norm_layer={norm_layer} is provided."
            )

        if act_layer is None and use_act:
            act_layer = nn.GELU()  # Default to GELU
        elif act_layer is not None and use_act:
            logger.error(
                f"When use_act is False, act_layer should be None, but act_layer={act_layer} is provided."
            )

        if (
            use_norm
            and any(param[0] == "bias" for param in norm_layer.named_parameters())
            and bias
        ):
            assert (
                not bias
            ), "Do not use bias when using normalization layers with bias."

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * self.ndim

        if isinstance(stride, int):
            stride = (stride,) * self.ndim

        if isinstance(dilation, int):
            dilation = (dilation,) * self.ndim

        assert isinstance(kernel_size, Tuple)
        assert isinstance(stride, Tuple)
        assert isinstance(dilation, Tuple)

        if padding is None:
            padding = (
                int((kernel_size[i] - 1) / 2) * dilation[i] for i in range(self.ndim)
            )

        if in_channels % groups != 0:
            logger.error(
                "Input channels are not divisible by groups. {}%{} != 0 ".format(
                    in_channels, groups
                )
            )
        if out_channels % groups != 0:
            logger.error(
                "Output channels are not divisible by groups. {}%{} != 0 ".format(
                    out_channels, groups
                )
            )

        block = nn.Sequential()

        conv_layer = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,  # type: ignore
            stride=stride,  # type: ignore
            padding=padding,
            dilation=dilation,  # type: ignore
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

        block.add_module(name="conv", module=conv_layer)

        self.norm_name = None
        if use_norm:
            block.add_module(name="norm", module=norm_layer)
            self.norm_name = norm_layer.__class__.__name__

        self.act_name = None
        if use_act:
            block.add_module(name="act", module=act_layer)
            self.act_name = act_layer.__class__.__name__

        self.block = block
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.groups = groups
        self.kernel_size = conv_layer.kernel_size
        self.bias = bias
        self.dilation = dilation

    def forward(self, x: Tensor) -> Tensor:
        return self.block(x)


class VisionTransformer(nn.Module):
    """
    This class defines the `Vision Transformer architecture <https://arxiv.org/abs/2010.11929>`_. Our model implementation
    is inspired from `Early Convolutions Help Transformers See Better <https://arxiv.org/abs/2106.14881>`_

    .. note::
        Our implementation is different from the original implementation in two ways:
        1. Kernel size is odd.
        2. Our positional encoding implementation allows us to use ViT with any multiple input scales
        3. We do not use StochasticDepth
        4. We do not add positional encoding to class token (if enabled), as suggested in `DeiT-3 paper <https://arxiv.org/abs/2204.07118>`_
    """

    def __init__(self, cfg, *args, **kwargs) -> None:
        super().__init__()
        image_channels = 3
        num_classes = cfg.get("n_classes", 1000)

        self.projection_dim = None
        if "projection_dim" in kwargs:
            self.projection_dim = kwargs.get("projection_dim")

        kernel_sizes_conv_stem = [4, 2, 2]
        strides_conv_stem = [4, 2, 2]

        # Typically, in the ImageNet dataset, we use 224x224 as a resolution.
        # For out ViT implementation, patch size is 16 (16 = 4 * 2 * 2)
        # Therefore, total number of embeddings along width and height are (224 / 16)^2
        num_embeddings = (224 // 16) ** 2

        embed_dim = cfg["embed_dim"]
        ffn_dim = cfg["embed_dim"] * 4
        pos_emb_drop_p = cfg.get("pos_emb_drop_p", 0.0)
        n_transformer_layers = cfg["n_transformer_layers"]
        num_heads = cfg["n_attn_heads"]
        attn_dropout = cfg.get("attn_dropout", 0.0)
        dropout = cfg.get("dropout", 0.0)
        ffn_dropout = cfg.get("ffn_dropout", 0.0)
        norm_layer = cfg.get("norm_layer", "layer_norm")

        conv_stem_proj_dim = max(32, embed_dim // 4)
        patch_emb = [
            ConvNormAct(
                cfg=cfg,
                in_channels=image_channels,
                out_channels=conv_stem_proj_dim,
                kernel_size=kernel_sizes_conv_stem[0],
                stride=strides_conv_stem[0],
                bias=False,
                use_norm=True,
                use_act=True,
            ),
            ConvNormAct(
                cfg=cfg,
                in_channels=conv_stem_proj_dim,
                out_channels=conv_stem_proj_dim,
                kernel_size=kernel_sizes_conv_stem[1],
                stride=strides_conv_stem[1],
                bias=False,
                use_norm=True,
                use_act=True,
            ),
            ConvNormAct(
                cfg=cfg,
                in_channels=conv_stem_proj_dim,
                out_channels=embed_dim,
                kernel_size=kernel_sizes_conv_stem[2],
                stride=strides_conv_stem[2],
                bias=True,
                use_norm=False,
                use_act=False,
            ),
        ]

        self.patch_emb = nn.Sequential(*patch_emb)

        use_cls_token = not cfg.get("no_cls_token", False)
        stochastic_dropout = cfg.get("stochastic_dropout", 0.0)
        per_layer_stochastic_drop_rate = [
            round(x, 3)
            for x in np.linspace(0, stochastic_dropout, n_transformer_layers)
        ]
        transformer_blocks = [
            TransformerEncoder(
                embed_dim=embed_dim,
                ffn_latent_dim=ffn_dim,
                num_heads=num_heads,
                attn_dropout=attn_dropout,
                dropout=dropout,
                ffn_dropout=ffn_dropout,
                transformer_norm_layer=norm_layer,
                stochastic_dropout=per_layer_stochastic_drop_rate[layer_idx],
            )
            for layer_idx in range(n_transformer_layers)
        ]

        self.post_transformer_norm = get_normalization_layer(
            num_features=embed_dim, norm_type=norm_layer
        )

        self.transformer = nn.Sequential(*transformer_blocks)

        if self.projection_dim is None:
            self.classifier = nn.Linear(embed_dim, num_classes)
        else:
            self.classifier = SimpleImageProjectionHead(embed_dim, self.projection_dim)

        if use_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(size=(1, 1, embed_dim)))
            torch.nn.init.trunc_normal_(self.cls_token, std=0.02)
        else:
            self.cls_token = None

        self.pos_embed = PositionalEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=embed_dim,
            padding_idx=None,
            interpolation_mode="bilinear",
        )
        self.emb_dropout = nn.Dropout(p=pos_emb_drop_p)

    def extract_patch_embeddings(self, x: Tensor) -> Tuple[Tensor, Tuple[int, int]]:
        # input is of shape [Batch, in_channels, height, width]. in_channels is mostly 3 (for RGB images)
        batch_size = x.shape[0]

        # [Batch, in_channels, height, width] --> [Batch, emb_dim, num_patches_height, num_patches_width]
        patch_emb = self.patch_emb(x)
        n_h, n_w = patch_emb.shape[-2:]

        # [Batch, emb_dim, num_patches_height, num_patches_width] --> [Batch, emb_dim, num_patches]
        patch_emb = patch_emb.flatten(2)
        # [Batch, emb_dim, num_patches] --> [Batch, num_patches, emb_dim]
        patch_emb = patch_emb.transpose(1, 2).contiguous()

        n_patches = patch_emb.shape[1]
        # we resize the positional encodings dynamically.
        pos_emb = self.pos_embed(n_patches).to(patch_emb.dtype)

        # add positional encodings
        patch_emb = pos_emb + patch_emb

        # add classification token
        if self.cls_token is not None:
            # [1, 1, emb_dim] --> [Batch, 1, emb_dim]
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            # Concat([Batch, 1, emb_dim], [Batch, num_patches, emb_dim]) --> [Batch, num_patches + 1, emb_dim]
            patch_emb = torch.cat((cls_tokens, patch_emb), dim=1)

        # dropout
        patch_emb = self.emb_dropout(patch_emb)
        return patch_emb, (n_h, n_w)

    def _features_from_transformer(
        self, x: Tensor, *args, **kwargs
    ) -> Tuple[Tensor, Tuple[int, int]]:
        # this function extract patch embeddings and then apply transformer module to learn
        # inter-patch representations

        # [B, N, C] --> [N, B, embed_dim], where B is batch size, N is number of tokens,
        # and embed_dim is feature dim
        x, (n_h, n_w) = self.extract_patch_embeddings(x)

        for layer in self.transformer:
            x = layer(x)
        x = self.post_transformer_norm(x)

        return x, (n_h, n_w)

    def extract_features(
        self, x: Tensor, *args, **kwargs
    ) -> Tuple[Tensor, Optional[Tensor]]:
        # The extract_features function for ViT returns two outputs: (1) embedding corresponding to CLS token
        # and (2) image embeddings of the shape [B, C, h//o, w//o], where the value of o is typically 16.
        return_image_embeddings = kwargs.get("return_image_embeddings", False)

        # [B, C, H, W] --> [B, N + 1, embed_dim] or [B, N, embed_dim]
        # here, B is batch size, C is input channels
        # H and W are input height and width
        # N is the number of pixels (or tokens) after processing input with conv stem and reshaping
        # We add +1 for cls token (if applicable)
        # embed_dim --> embedding dimension
        x, (n_h, n_w) = self._features_from_transformer(x, *args, **kwargs)

        if self.cls_token is not None:
            # [B, N + 1, embed_dim] --> [B, embed_dim], [B, N, embed_dim]
            cls_embedding, image_embedding = torch.split(
                x, split_size_or_sections=[1, x.shape[1] - 1], dim=1
            )
            cls_embedding = cls_embedding.squeeze(1)
        else:
            # [B, N, embed_dim] -> [B, embed_dim]
            cls_embedding = torch.mean(x, dim=1)
            # [B, N, embed_dim]
            image_embedding = x

        if return_image_embeddings:
            # reshape image embedding to 4-D tensor
            # [B, N, C] --> [B, C, N]
            image_embedding = image_embedding.transpose(1, 2).contiguous()
            image_embedding = image_embedding.reshape(
                image_embedding.shape[0], -1, n_h, n_w
            )

            return cls_embedding, image_embedding
        else:
            return cls_embedding, None

    def forward_classifier(self, x: Tensor, *args, **kwargs) -> Tuple[Tensor, Tensor]:
        cls_embedding, image_embedding = self.extract_features(x, *args, **kwargs)
        # classify based on CLS token
        cls_embedding = self.classifier(cls_embedding)
        return cls_embedding, image_embedding

    def forward(self, x: Tensor, *args, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
        # In ViT model, we can return either classifier embeddings (logits) or image embeddings or both.
        # To return the image embeddings, we need to set keyword argument (return_image_embeddings) as True.
        if kwargs.get("return_image_embeddings", False):
            out_dict = dict()
            prediction, image_embedding = self.forward_classifier(x, *args, **kwargs)
            out_dict.update({"logits": prediction})
            if image_embedding is not None:
                out_dict.update({"image_embeddings": image_embedding})
            return out_dict
        else:
            prediction, _ = self.forward_classifier(x, *args, **kwargs)
            return prediction


@register_model
def vit_b16(pretrained=False, **kwargs):
    # Vision transformer config
    cfg = {
        "norm_layer": "layer_norm_fp32",
        "act_layer": "gelu",
        "embed_dim": 768,
        "n_transformer_layers": 12,
        "n_attn_heads": 12,
    }
    model = VisionTransformer(cfg=cfg, **kwargs)
    if pretrained:
        raise ValueError("Functionality not implemented.")
    return model
