

import math
from typing import Any, Optional

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, AutoProcessor
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTModel


class ViTEmbeddingsNoCLS(ViTEmbeddings):
    """ViT Embedding Module without CLS token.

    The inherited position-embedding table keeps its leading (CLS) slot; this module
    only ever uses the patch part, ``position_embeddings[:, 1:]``.
    """

    def __init__(self, config: AutoConfig, use_mask_token: bool = False):
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
        """
        super(ViTEmbeddingsNoCLS, self).__init__(config, use_mask_token=use_mask_token)
        # Remove the CLS token parameter; `forward` never prepends a CLS token.
        self.cls_token = None

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return patch position embeddings matching the input resolution.

        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. Only the patch slots of the table are returned, because this module
        emits no CLS token.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings (torch.Tensor): patch embeddings; dim 1 is the token axis, the last dim is the hidden size.
            height (int): input image height in pixels.
            width (int): input image width in pixels.

        Returns:
            torch.Tensor: position embeddings of shape (1, num_patches, dim).
        """
        num_patches = embeddings.shape[1]
        patch_pos_embed = self.position_embeddings[:, 1:]
        num_positions = patch_pos_embed.shape[1]
        if num_patches == num_positions and height == width:
            # Return only the patch slots: the full table still contains the unused CLS
            # slot and would not broadcast against the CLS-free `embeddings`.
            return patch_pos_embed
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        side = math.sqrt(num_positions)  # pretrained grid is assumed square
        patch_pos_embed = patch_pos_embed.reshape(1, int(side), int(side), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / side, w0 / side),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return patch_pos_embed

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Embed images into patch tokens (no CLS token) with position encodings added.

        Args:
            pixel_values (torch.Tensor): images of shape (batch, channels, height, width).
            bool_masked_pos (Optional[torch.BoolTensor]): per-patch mask; masked patches are
                replaced by ``self.mask_token``. Defaults to None.
            interpolate_pos_encoding (bool): resize position embeddings to the input resolution.
                Defaults to False.

        Returns:
            torch.Tensor: patch embeddings after dropout.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # add positional encoding to each token (patch slots only; the CLS slot is skipped)
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + self.position_embeddings[:, 1:]

        embeddings = self.dropout(embeddings)

        return embeddings


# modified from huggingface transformers ViTModel
class ViTModelNoCLS(ViTModel):
    """ViT Model without CLS token."""

    def __init__(self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False) -> None:
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            add_pooling_layer (bool, optional): forwarded to ``ViTModel``. Defaults to True.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
        """
        super(ViTModelNoCLS, self).__init__(config, add_pooling_layer, use_mask_token)
        self.embeddings = ViTEmbeddingsNoCLS(config, use_mask_token=use_mask_token)
        # `ViTModel.__init__` ran its weight initialisation before this swap, so the fresh
        # embeddings would otherwise keep their raw constructor init (e.g. randn position
        # embeddings). Apply this model's init scheme to them explicitly.
        self.embeddings.apply(self._init_weights)
        self.no_cls = True

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights of a single submodule.

        Linear/Conv2d weights use a truncated normal with std ``config.initializer_range``
        and zero bias. LayerNorm gets weight 1 and bias 0. ``ViTEmbeddings`` only re-draws its
        position embeddings, because the CLS token is removed in this model.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)


# modified from huggingface transformers ViTEmbeddings
class ViTEmbeddingsReg(ViTEmbeddings):
    """
    ViT Embedding Module with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1

    Output tokens are laid out as ``[CLS, patch_1..patch_N, reg_1..reg_R]`` with
    ``R = num_reg_tokens``. Register tokens have their own learnable position embeddings
    (``reg_pos_embed``), which are never interpolated.
    """

    def __init__(self, config: AutoConfig, use_mask_token: bool = False, num_reg_tokens: int = 7):
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
            num_reg_tokens (int, optional): number of register tokens. Defaults to 7.
        """
        super(ViTEmbeddingsReg, self).__init__(config, use_mask_token=use_mask_token)
        self.reg_token = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))
        self.num_reg_tokens = num_reg_tokens
        self.reg_pos_embed = nn.Parameter(torch.randn(1, num_reg_tokens, config.hidden_size))

        # Re-draw both register parameters from a truncated normal (std = initializer_range).
        # The fp32 round-trip avoids `trunc_normal_` not being implemented for half on CPU.
        self.reg_pos_embed.data = nn.init.trunc_normal_(
            self.reg_pos_embed.data.to(torch.float32),
            mean=0.0,
            std=self.config.initializer_range,
        ).to(self.reg_pos_embed.dtype)

        self.reg_token.data = nn.init.trunc_normal_(
            self.reg_token.data.to(torch.float32),
            mean=0.0,
            std=self.config.initializer_range,
        ).to(self.reg_token.dtype)

    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
        """Return position embeddings for the full ``[CLS, patches, registers]`` sequence.

        This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
        resolution images. Only the patch part is interpolated. The CLS and register position
        embeddings are kept as-is.

        Source:
        https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174

        Args:
            embeddings (torch.Tensor): token embeddings including CLS and register tokens.
            height (int): input image height in pixels.
            width (int): input image width in pixels.

        Returns:
            torch.Tensor: position embeddings of shape (1, 1 + num_patches + num_reg_tokens, dim).
        """
        num_patches = embeddings.shape[1] - 1 - self.num_reg_tokens
        num_positions = self.position_embeddings.shape[1] - 1
        if num_patches == num_positions and height == width:
            # Same layout as the non-interpolated path in `forward`: the register slots must
            # be appended, otherwise the result is too short to add to `embeddings`.
            return torch.cat((self.position_embeddings, self.reg_pos_embed), dim=1)
        class_pos_embed = self.position_embeddings[:, 0]
        patch_pos_embed = self.position_embeddings[:, 1:]
        reg_pos_embed = self.reg_pos_embed
        dim = embeddings.shape[-1]
        h0 = height // self.config.patch_size
        w0 = width // self.config.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        h0, w0 = h0 + 0.1, w0 + 0.1
        side = math.sqrt(num_positions)  # pretrained grid is assumed square
        patch_pos_embed = patch_pos_embed.reshape(1, int(side), int(side), dim)
        patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed,
            scale_factor=(h0 / side, w0 / side),
            mode="bicubic",
            align_corners=False,
        )
        assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, reg_pos_embed), dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        bool_masked_pos: Optional[torch.BoolTensor] = None,
        interpolate_pos_encoding: bool = False,
    ) -> torch.Tensor:
        """Embed images into ``[CLS, patches, registers]`` tokens with position encodings added.

        Args:
            pixel_values (torch.Tensor): images of shape (batch, channels, height, width).
            bool_masked_pos (Optional[torch.BoolTensor]): per-patch mask; masked patches are
                replaced by ``self.mask_token``. Defaults to None.
            interpolate_pos_encoding (bool): resize patch position embeddings to the input
                resolution. Defaults to False.

        Returns:
            torch.Tensor: token embeddings after dropout.
        """
        batch_size, num_channels, height, width = pixel_values.shape
        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

        if bool_masked_pos is not None:
            seq_length = embeddings.shape[1]
            mask_tokens = self.mask_token.expand(batch_size, seq_length, -1)
            # replace the masked visual tokens by mask_tokens
            mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            embeddings = embeddings * (1.0 - mask) + mask_tokens * mask

        # prepend the [CLS] token and append the register tokens
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        reg_tokens = self.reg_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings, reg_tokens), dim=1)

        # add positional encoding to each token
        if interpolate_pos_encoding:
            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
        else:
            embeddings = embeddings + torch.cat([self.position_embeddings, self.reg_pos_embed], dim=1)

        embeddings = self.dropout(embeddings)

        return embeddings


# modified from huggingface transformers ViTModel
class ViTModelReg(ViTModel):
    """ViT Model with register tokens. https://openreview.net/forum?id=2dnO3LLiJ1"""

    def __init__(
        self, config: AutoConfig, add_pooling_layer: bool = True, use_mask_token: bool = False, num_reg_tokens: int = 7
    ):
        """Initialization.

        Args:
            config (AutoConfig): config for ViT.
            add_pooling_layer (bool, optional): forwarded to ``ViTModel``. Defaults to True.
            use_mask_token (bool, optional): whether to use mask token. Defaults to False.
            num_reg_tokens (int, optional): number of register tokens. Defaults to 7.
        """
        super(ViTModelReg, self).__init__(config, add_pooling_layer, use_mask_token)
        self.embeddings = ViTEmbeddingsReg(config, use_mask_token=use_mask_token, num_reg_tokens=num_reg_tokens)
        # `ViTModel.__init__` ran its weight initialisation before this swap, so apply this
        # model's init scheme to the fresh embeddings. Register parameters are not touched
        # here; ViTEmbeddingsReg initialises them itself.
        self.embeddings.apply(self._init_weights)
        self.num_reg_tokens = num_reg_tokens

    def _init_weights(self, module: nn.Module) -> None:
        """Initialize the weights of a single submodule.

        Linear/Conv2d weights use a truncated normal with std ``config.initializer_range``
        and zero bias. LayerNorm gets weight 1 and bias 0. ``ViTEmbeddings`` re-draws its
        position embeddings and CLS token.
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid
            # `trunc_normal_cpu` not implemented in `half` issues
            module.weight.data = nn.init.trunc_normal_(
                module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range
            ).to(module.weight.dtype)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, ViTEmbeddings):
            module.position_embeddings.data = nn.init.trunc_normal_(
                module.position_embeddings.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.position_embeddings.dtype)
            module.cls_token.data = nn.init.trunc_normal_(
                module.cls_token.data.to(torch.float32),
                mean=0.0,
                std=self.config.initializer_range,
            ).to(module.cls_token.dtype)


class DeiT(nn.Module):
    """DeiT model.

    Paper: Training data-efficient image transformers & distillation through attention
        https://arxiv.org/abs/2012.12877
    Huggingface Reference: https://huggingface.co/docs/transformers/en/model_doc/deit

    Attributes:
        model_name (str): name of the model.
        pretrained (bool): whether to use pretrained weights.
    """

    def __init__(
        self,
        model_name: str = "facebook/deit-small-patch16-224",
        pretrained: bool = False,
        image_size: int = 224,
    ):
        super().__init__()
        self.image_size = image_size
        model = AutoModel.from_pretrained(model_name)
        if pretrained:
            self.model = model
        else:
            deit_config = model.config
            self.model = AutoModel.from_config(deit_config)
            del model

        self.model.pooler = nn.Identity()

        self.processor = AutoProcessor.from_pretrained(model_name)

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape.
        """
        with torch.inference_mode():
            image_size = (224, 224)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)[:, 1:]  # for getting feature size, discard cls token
            size = y.size()[1:][::-1]
            if keep_spatial:
                assert math.isqrt(size[-1])
                h = w = int(math.sqrt(size[-1]))
                size = (size[0], h, w)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

        Returns:
            torch.Tensor: model output.
        """
        input = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state


class DeiTNoCLS(nn.Module):
    """Modified DeiT model without CLS token."""

    def __init__(
        self, model_name: str = "nocls-facebook/deit-small-patch16-224", pretrained: bool = False, image_size: int = 224
    ):
        super().__init__()
        self.image_size = image_size
        pretrained_model_name = model_name.replace("nocls-", "")
        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
        self.model = ViTModelNoCLS(deit_config)
        if pretrained:
            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
            self.load_state_dict(pretrained_dict, strict=False)
            del pretrained_model, pretrained_dict

        self.model.pooler = nn.Identity()
        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
        self.no_cls = True

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape.
        """
        with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)
            size = y.size()[1:][::-1]
            if keep_spatial:
                assert math.isqrt(size[-1])
                h = w = int(math.sqrt(size[-1]))
                size = (size[0], h, w)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

            - argument for forward
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.

        Returns:
            torch.Tensor: model output.
        """
        input = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state


class DeiTReg(nn.Module):
    """Modified DeiT model with register tokens."""

    def __init__(
        self,
        model_name: str = "reg-facebook/deit-small-patch16-224",
        pretrained: bool = False,
        image_size: int = 224,
        num_reg_tokens: int = 7,
    ):
        super().__init__()
        self.image_size = image_size
        pretrained_model_name = model_name.replace("reg-", "")
        deit_config = AutoConfig.from_pretrained(pretrained_model_name)
        self.model = ViTModelReg(deit_config, num_reg_tokens=num_reg_tokens)
        if pretrained:
            pretrained_model = AutoModel.from_pretrained(pretrained_model_name)
            pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in self.model.state_dict()}
            self.load_state_dict(pretrained_dict, strict=False)
            del pretrained_model, pretrained_dict

        self.model.pooler = nn.Identity()
        self.processor = AutoProcessor.from_pretrained(pretrained_model_name)
        self.num_reg_tokens = num_reg_tokens

    def get_feature_size(
        self,
        keep_spatial: bool = False,
        return_torch_size: bool = False,
    ) -> torch.Size | tuple[int, ...]:
        """Get the size of the feature.

        Args:
            keep_spatial (bool): keep spatial dim of the feature shape. Defaults to False.
            return_torch_size (bool): if true, return torch.Size type. Defaults to False.

        Returns:
            torch.Size | tuple[int, ...]: returned feature shape.
        """
        with torch.inference_mode():
            image_size = (self.image_size, self.image_size)
            x = torch.zeros((1, *image_size, 3), dtype=torch.uint8)
            y = self.forward(x)[:, 1 : -self.num_reg_tokens]
            size = y.size()[1:][::-1]
            if keep_spatial:
                assert math.isqrt(size[-1])
                h = w = int(math.sqrt(size[-1]))
                size = (size[0], h, w)
                if return_torch_size:
                    size = torch.Size(size)
            return size

    def forward(
        self,
        x: torch.Tensor,
        do_resize: bool = True,
        interpolate_pos_encoding: Optional[bool] = None,
        do_rescale: bool = True,
        do_normalize: bool = True,
    ) -> torch.Tensor:
        """Forward pass of the model

        Args:
            x (torch.Tensor): model input.

            - arguments for self.processor. Details can be find at
                https://huggingface.co/docs/transformers/v4.41.3/en/model_doc/deit#transformers.DeiTImageProcessor
            do_resize (bool): if do resizing in processor. Defaults to True.
            interpolate_pos_encoding (bool): if interpolate the positional embedding. Defaults to None.
            do_rescale (bool): if do rescaling (0-255 -> 0-1) in processor. Defaults to True.
            do_normalize (bool): if do normalize in processor. Defaults to True.

        Returns:
            torch.Tensor: model output.
        """
        input = self.processor(
            x, return_tensors="pt", do_resize=do_resize, do_rescale=do_rescale, do_normalize=do_normalize
        ).to(self.model.device)
        y = self.model(**input, interpolate_pos_encoding=interpolate_pos_encoding)
        return y.last_hidden_state


def build_backbone(model_name: str, pretrained: bool = False, image_size: int = 224, **kwargs: Any) -> nn.Module:
    """Construct the visual-encoder backbone named by ``model_name``.

    Dispatch is by substring, checked in this order: ``"reg"`` -> ``DeiTReg``,
    ``"nocls"`` -> ``DeiTNoCLS``, ``"deit"`` -> ``DeiT``.

    Args:
        model_name (str): name of the model.
        pretrained (bool): whether to use pretrained weights. Defaults to False.
        image_size (int): size of the (square) image. Defaults to 224.
        kwargs (Any): extra model-specific options forwarded to ``DeiTReg`` / ``DeiTNoCLS``
            (e.g. ``num_reg_tokens`` for ``DeiTReg``); they are not passed to plain ``DeiT``.

    Returns:
        nn.Module: backbone network.

    Raises:
        NotImplementedError: if no known variant matches ``model_name``.
    """
    common = dict(model_name=model_name, pretrained=pretrained, image_size=image_size)
    if "reg" in model_name:
        return DeiTReg(**common, **kwargs)
    if "nocls" in model_name:
        return DeiTNoCLS(**common, **kwargs)
    if "deit" in model_name:
        return DeiT(**common)
    raise NotImplementedError(f"Requested {model_name} is not implemented.")
