# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Sequence
from torch.nn import functional as F
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn.bricks.transformer import FFN, PatchEmbed
from mmengine.model import BaseModule, ModuleList
from mmengine.model.weight_init import trunc_normal_
from mmhug.models.modeling_output import EncoderOutput

from mmhug.registry import HF_MODELS
from .utils import (
    MultiheadAttention,
    SwiGLUFFNFused,
    build_norm_layer,
    resize_pos_embed,
    to_2tuple,
    build_2d_sincos_position_embedding,
)
from .base_backbone import BaseBackbone

# Per-channel normalization statistics on the [0, 1] scale (these match the
# commonly used ImageNet statistics; channel order presumably RGB). They are
# registered as the ``mean``/``std`` buffers of SapiensVisionTransformer.
# The same values on the [0, 255] scale (each entry multiplied by 255) are:
#   MEAN = [123.675, 116.28, 103.53]
#   STD = [58.395, 57.12, 57.375]

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


class TransformerEncoderLayer(BaseModule):
    """Implements one pre-norm encoder layer in Vision Transformer.

    The forward pass is ``x = x + attn(ln1(x))`` followed by
    ``x = ffn(ln2(x), identity=x)``. The un-normalized input is handed to the
    FFN as ``identity``; the residual connection is expected to be applied
    inside the FFN module.

    Args:
        embed_dims (int): The feature dimension.
        num_heads (int): Parallel attention heads.
        feedforward_channels (int): The hidden dimension for FFNs.
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        ffn_type (str): Select the type of ffn layers: ``'origin'`` (mmcv
            ``FFN``) or ``'swiglu_fused'`` (``SwiGLUFFNFused``).
            Defaults to 'origin'. For ``'swiglu_fused'`` the FFN is built
            without ``drop_rate``, ``drop_path_rate``, ``num_fcs`` and
            ``act_cfg`` (they only affect the attention branch, if at all).
        act_cfg (dict): The activation config for FFNs (``'origin'`` only).
            Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Raises:
        NotImplementedError: If ``ffn_type`` is not one of the supported types.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        feedforward_channels,
        layer_scale_init_value=0.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        num_fcs=2,
        qkv_bias=True,
        ffn_type="origin",
        act_cfg=dict(type="GELU"),
        norm_cfg=dict(type="LN"),
        init_cfg=None,
    ):
        super(TransformerEncoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        # Pre-attention normalization.
        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        self.attn = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        # Pre-FFN normalization.
        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        if ffn_type == "origin":
            self.ffn = FFN(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
                act_cfg=act_cfg,
                layer_scale_init_value=layer_scale_init_value,
            )
        elif ffn_type == "swiglu_fused":
            self.ffn = SwiGLUFFNFused(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                layer_scale_init_value=layer_scale_init_value,
            )
        else:
            raise NotImplementedError(
                f"Unsupported ffn_type {ffn_type!r}; "
                "choose 'origin' or 'swiglu_fused'."
            )

    @property
    def norm1(self):
        """Alias of ``ln1`` (the pre-attention norm)."""
        return self.ln1

    @property
    def norm2(self):
        """Alias of ``ln2`` (the pre-FFN norm)."""
        return self.ln2

    def init_weights(self):
        """Run the base initialization, then re-init every Linear in the FFN.

        Weights use Xavier-uniform; biases (when present) get a tiny normal
        init so they start close to zero.
        """
        super(TransformerEncoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # Linear layers built with bias=False have ``bias is None``.
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x):
        """Apply attention and FFN sub-blocks with pre-normalization."""
        x = x + self.attn(self.ln1(x))
        # The FFN receives the residual input via ``identity``.
        x = self.ffn(self.ln2(x), identity=x)
        return x


@HF_MODELS.register_module()
class SapiensVisionTransformer(BaseBackbone):
    """Vision Transformer backbone used by Sapiens.

    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Inputs are resized to ``img_size`` and (optionally) re-normalized before
    patch embedding. A fixed 2D sin-cos position embedding is added to the
    patch tokens.

    Args:
        arch (str | dict): Vision Transformer architecture. If a string, it
            must be one of the keys of ``arch_zoo`` (case-insensitive), e.g.
            'base', 'large', 'huge', 'deit-small', 'dinov2-giant',
            'sapiens_0.3b', 'sapiens_0.6b', 'sapiens_1b' or 'sapiens_2b'.
            If a dict, it should have the keys below:

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The input image shape (H, W). Inputs of any
            other spatial size are bilinearly resized to this shape in
            ``forward``. Defaults to 224.
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3. Note that
            ``norm_in`` uses 3-channel mean/std buffers.
        out_indices (Sequence | int): Indices of the layers whose outputs are
            returned. Negative indices count from the last layer.
            Defaults to -1, means the last layer.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN', eps=1e-6)``.
        final_norm (bool): Whether to add an additional layer to normalize
            the output of the last layer. Defaults to False.
        out_type (str): The type of output features. Please choose from

            - ``"cls_token"``: The class token tensor with shape (B, C).
            - ``"featmap"``: The feature map tensor from the patch tokens
              with shape (B, C, H, W).
            - ``"avg_featmap"``: The averaged patch tokens with shape
              (B, 1, C) (``ln2`` is built for this mode but not applied).
            - ``"raw"``: The raw feature tensor includes patch tokens and
              class tokens with shape (B, L, C).
            - ``"latent_token"``: Accepted, but ``_format_output`` currently
              returns None for it.

            Defaults to ``"cls_token"``. Because ``with_cls_token`` defaults
            to False, at least one of these two must be set explicitly.
        with_cls_token (bool): Whether concatenating class token into image
            tokens as transformer input. Defaults to False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            encoder. Defaults to an empty dict.
        pre_norm (bool): Whether to normalize tokens before the first layer;
            when True the patch-embedding conv is built without bias.
            Defaults to False.
        norm_in (bool): Whether ``forward`` maps inputs normalized with
            mean 0.5 / std 0.5 back to [0, 1] and re-normalizes them with
            ``MEAN``/``STD``. Defaults to True.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """

    arch_zoo = {
        **dict.fromkeys(
            ["s", "small"],
            {
                "embed_dims": 768,
                "num_layers": 8,
                "num_heads": 8,
                "feedforward_channels": 768 * 3,
            },
        ),
        **dict.fromkeys(
            ["b", "base"],
            {
                "embed_dims": 768,
                "num_layers": 12,
                "num_heads": 12,
                "feedforward_channels": 3072,
            },
        ),
        **dict.fromkeys(
            ["l", "large"],
            {
                "embed_dims": 1024,
                "num_layers": 24,
                "num_heads": 16,
                "feedforward_channels": 4096,
            },
        ),
        **dict.fromkeys(
            ["h", "huge"],
            {
                # The same as the implementation in MAE
                # <https://arxiv.org/abs/2111.06377>
                "embed_dims": 1280,
                "num_layers": 32,
                "num_heads": 16,
                "feedforward_channels": 5120,
            },
        ),
        **dict.fromkeys(
            ["m", "mammoth"],
            {
                "embed_dims": 1536,
                "num_layers": 40,
                "num_heads": 24,
                "feedforward_channels": 1536 * 4,
            },
        ),
        # https://github.com/facebookresearch/maws/blob/abf391e8e4626e0e2d144583b073e0fdf37aeb10/maws/model_builder.py#L77
        **dict.fromkeys(
            ["maws-6.5b"],
            {
                "embed_dims": 4096,
                "num_layers": 32,
                "num_heads": 32,
                "feedforward_channels": 4096 * 4,
            },
        ),
        # https://github.com/apple/ml-aim/blob/0b1dea9128f4734ae89252078e65aa102999407a/aim/torch/models.py#L214
        **dict.fromkeys(
            ["aim-7b"],
            {
                "embed_dims": 4096,
                "num_layers": 32,
                "num_heads": 32,
                "feedforward_channels": 4096 * 4,
            },
        ),
        **dict.fromkeys(
            ["eva-g", "eva-giant"],
            {
                # The implementation in EVA
                # <https://arxiv.org/abs/2211.07636>
                "embed_dims": 1408,
                "num_layers": 40,
                "num_heads": 16,
                "feedforward_channels": 6144,
            },
        ),
        **dict.fromkeys(
            ["deit-t", "deit-tiny"],
            {
                "embed_dims": 192,
                "num_layers": 12,
                "num_heads": 3,
                "feedforward_channels": 192 * 4,
            },
        ),
        **dict.fromkeys(
            ["deit-s", "deit-small", "dinov2-s", "dinov2-small"],
            {
                "embed_dims": 384,
                "num_layers": 12,
                "num_heads": 6,
                "feedforward_channels": 384 * 4,
            },
        ),
        **dict.fromkeys(
            ["deit-b", "deit-base"],
            {
                "embed_dims": 768,
                "num_layers": 12,
                "num_heads": 12,
                "feedforward_channels": 768 * 4,
            },
        ),
        **dict.fromkeys(
            ["dinov2-g", "dinov2-giant"],
            {
                "embed_dims": 1536,
                "num_layers": 40,
                "num_heads": 24,
                "feedforward_channels": 6144,
            },
        ),
        ## ensure that embed dim is divisible by num heads
        ## num of params and flops increased by embed_dims and num_layers
        **dict.fromkeys(  ## this is vit-large
            ["0.3b", "sapiens_0.3b"],
            {
                "embed_dims": 1024,
                "num_layers": 24,
                "num_heads": 16,
                "feedforward_channels": 1024 * 4,
            },
        ),
        **dict.fromkeys(  ## this is vit-huge
            ["0.6b", "sapiens_0.6b"],
            {
                "embed_dims": 1280,
                "num_layers": 32,
                "num_heads": 16,
                "feedforward_channels": 1280 * 4,
            },
        ),
        **dict.fromkeys(  ## this is vit-g
            ["1b", "sapiens_1b"],
            {
                "embed_dims": 1536,
                "num_layers": 40,
                "num_heads": 24,
                "feedforward_channels": 1536 * 4,
            },
        ),
        **dict.fromkeys(
            ["2b", "sapiens_2b"],
            {
                "embed_dims": 1920,
                "num_layers": 48,
                "num_heads": 32,
                "feedforward_channels": 1920 * 4,
            },
        ),
    }
    num_extra_tokens = 1  # class token
    OUT_TYPES = {"raw", "cls_token", "featmap", "avg_featmap", "latent_token"}

    def __init__(
        self,
        arch="base",
        img_size=224,
        patch_size=16,
        in_channels=3,
        out_indices=-1,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        final_norm=False,
        out_type="cls_token",
        with_cls_token=False,
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        patch_cfg=dict(),
        layer_cfgs=dict(),
        pre_norm=False,
        norm_in: bool = True,
        init_cfg=None,
    ):
        super(SapiensVisionTransformer, self).__init__(init_cfg)
        self.norm_in = norm_in
        # Per-channel statistics used by ``norm_input``, shaped to broadcast
        # over a 4D (B, 3, H, W) batch.
        self.register_buffer("mean", torch.tensor(MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(STD).view(1, 3, 1, 1))

        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(
                self.arch_zoo
            ), f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                "embed_dims",
                "num_layers",
                "num_heads",
                "feedforward_channels",
            }
            assert isinstance(arch, dict) and essential_keys <= set(
                arch
            ), f"Custom arch needs a dict with keys {essential_keys}"
            self.arch_settings = arch

        self.embed_dims = self.arch_settings["embed_dims"]
        self.num_layers = self.arch_settings["num_layers"]
        self.img_size = to_2tuple(img_size)

        # Set patch embedding
        _patch_cfg = dict(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type="Conv2d",
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm,  # disable bias if pre_norm is used(e.g., CLIP)
        )
        _patch_cfg.update(patch_cfg)
        self.patch_embed = PatchEmbed(**_patch_cfg)
        self.patch_resolution = self.patch_embed.init_out_size

        # Set out type
        if out_type not in self.OUT_TYPES:
            raise ValueError(
                f"Unsupported `out_type` {out_type}, please "
                f"choose from {self.OUT_TYPES}"
            )
        self.out_type = out_type

        # Set cls token
        self.with_cls_token = with_cls_token
        if with_cls_token:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
        elif out_type != "cls_token":
            self.cls_token = None
            self.num_extra_tokens = 0
        else:
            raise ValueError('with_cls_token must be True when `out_type="cls_token"`.')

        # Set position embedding: a fixed 2D sin-cos embedding for the patch
        # tokens only (it has no slot for the class token; see ``forward``).
        # It is a buffer, not a trainable parameter.
        # NOTE: no load_state_dict pre-hook is registered, so a checkpoint
        # ``pos_embed`` with a different shape is not resized on load.
        self.interpolate_mode = interpolate_mode
        self.register_buffer(
            "pos_embed",
            build_2d_sincos_position_embedding(self.patch_resolution, self.embed_dims),
        )

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        if isinstance(out_indices, int):
            out_indices = [out_indices]
        assert isinstance(out_indices, Sequence), (
            f'"out_indices" must be a sequence or int, '
            f"got {type(out_indices)} instead."
        )
        # Copy so tuples are accepted and the caller's list is not mutated.
        out_indices = list(out_indices)
        assert len(out_indices) > 0, '"out_indices" must not be empty.'
        for i, index in enumerate(out_indices):
            if index < 0:
                out_indices[i] = self.num_layers + index
            # Valid layer indices are 0 .. num_layers - 1.
            assert (
                0 <= out_indices[i] < self.num_layers
            ), f"Invalid out_indices {index}"
        self.out_indices = out_indices

        # stochastic depth decay rule: drop-path rate grows linearly per layer
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        if isinstance(layer_cfgs, dict):
            layer_cfgs = [layer_cfgs] * self.num_layers
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings["num_heads"],
                feedforward_channels=self.arch_settings["feedforward_channels"],
                layer_scale_init_value=layer_scale_init_value,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
            )
            _layer_cfg.update(layer_cfgs[i])
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        self.frozen_stages = frozen_stages
        if pre_norm:
            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
        else:
            self.pre_norm = nn.Identity()

        self.final_norm = final_norm
        if final_norm:
            self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
        if self.out_type == "avg_featmap":
            # NOTE(review): built (and frozen by _freeze_stages) but not
            # applied by _format_output.
            self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)

        # freeze stages only when self.frozen_stages > 0
        if self.frozen_stages > 0:
            self._freeze_stages()

    @property
    def norm1(self):
        """The final norm ``ln1`` (exists only when ``final_norm`` is True)."""
        return self.ln1

    @property
    def norm2(self):
        """``ln2`` (exists only when ``out_type == 'avg_featmap'``)."""
        return self.ln2

    def init_weights(self):
        """Initialize weights via the base class.

        ``pos_embed`` is a fixed sin-cos buffer rather than a learnable
        parameter, so it is deliberately NOT re-initialized with
        ``trunc_normal_``. Doing so would replace the deterministic embedding
        with random values that can never be trained away.
        """
        super(SapiensVisionTransformer, self).init_weights()

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility."""
        return resize_pos_embed(*args, **kwargs)

    def _freeze_stages(self):
        """Stop gradients for (and set eval mode on) the first stages."""
        # freeze position embedding
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        # set dropout to eval model
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze pre-norm
        for param in self.pre_norm.parameters():
            param.requires_grad = False
        # freeze cls_token
        if self.cls_token is not None:
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers):
            if self.final_norm:
                self.ln1.eval()
                for param in self.ln1.parameters():
                    param.requires_grad = False

            if self.out_type == "avg_featmap":
                self.ln2.eval()
                for param in self.ln2.parameters():
                    param.requires_grad = False

    def norm_input(self, x):
        """Re-normalize ``x`` from mean 0.5 / std 0.5 to ``MEAN``/``STD``.

        Returns ``x`` unchanged when ``norm_in`` is False.
        """
        if not self.norm_in:
            return x
        # make sure the input is normalized with mean 0.5 and std 0.5:
        # undo that normalization to get back to [0, 1] ...
        x = x * 0.5 + 0.5
        x = torch.clamp(x, 0, 1)
        # ... then apply the per-channel statistics.
        return (x - self.mean) / self.std

    def resize_input(self, x):
        """Bilinearly resize ``x`` to ``img_size`` if its (H, W) differs."""
        if x.shape[-2] != self.img_size[0] or x.shape[-1] != self.img_size[1]:
            # Interpolate in float32, then cast back to the input dtype.
            x = F.interpolate(
                x.float(), size=self.img_size, mode="bilinear", align_corners=False
            ).to(x.dtype)
        return x

    def forward(self, x, out_type=None):
        """Encode a batch of images.

        Args:
            x (torch.Tensor): Images of shape (B, C, H, W). When ``norm_in``
                is True they are expected to be normalized with mean 0.5 and
                std 0.5.
            out_type (str, optional): Overrides ``self.out_type`` for this
                call.

        Returns:
            EncoderOutput: ``hidden_states`` holds one formatted feature per
            requested layer (in layer order) and ``last_hidden_state`` is the
            last of them.
        """
        B = x.shape[0]
        x = self.resize_input(x)
        x = self.norm_input(x)
        x, patch_resolution = self.patch_embed(x)

        # ``pos_embed`` covers the patch tokens only, so add it before the
        # class token is prepended (the class token gets no position term).
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=0,
        )

        if self.cls_token is not None:
            cls_token = self.cls_token.expand(B, -1, -1)
            x = torch.cat((cls_token, x), dim=1)

        x = self.drop_after_pos(x)

        x = self.pre_norm(x)  ## B x (num tokens) x embed_dim

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.ln1(x)

            if i in self.out_indices:
                outs.append(self._format_output(x, patch_resolution, out_type))

        return EncoderOutput(
            last_hidden_state=outs[-1],
            hidden_states=outs,
        )

    def _format_output(self, x, hw, out_type=None):
        """Convert token features ``x`` (B, L, C) to the requested out type.

        ``hw`` is the patch grid (H, W) used to reshape patch tokens for
        ``"featmap"``. Returns None for types not handled below
        (e.g. ``"latent_token"``).
        """
        out_type = out_type or self.out_type
        if out_type == "raw":
            return x
        if out_type == "cls_token":
            return x[:, 0]

        patch_token = x[:, self.num_extra_tokens :]
        if out_type == "featmap":
            B = x.size(0)
            # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
            return patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if out_type == "avg_featmap":
            return patch_token.mean(dim=1, keepdim=True)

    def get_layer_depth(self, param_name: str, prefix: str = ""):
        """Get the layer-wise depth of a parameter.

        Args:
            param_name (str): The name of the parameter.
            prefix (str): The prefix for the parameter.
                Defaults to an empty string.

        Returns:
            Tuple[int, int]: The layer-wise depth and the num of layers.

        Note:
            The first depth is the stem module (``layer_depth=0``), and the
            last depth is the subsequent module (``layer_depth=num_layers-1``)
        """
        num_layers = self.num_layers + 2

        if not param_name.startswith(prefix):
            # For subsequent module like head
            return num_layers - 1, num_layers

        param_name = param_name[len(prefix) :]

        if param_name in ("cls_token", "pos_embed"):
            layer_depth = 0
        elif param_name.startswith("patch_embed"):
            layer_depth = 0
        elif param_name.startswith("layers"):
            layer_id = int(param_name.split(".")[1])
            layer_depth = layer_id + 1
        else:
            layer_depth = num_layers - 1

        return layer_depth, num_layers


if __name__ == "__main__":
    # Smoke test: build a Sapiens-0.6B backbone, load the pretrained
    # checkpoint and print the shapes of the returned features.
    from mmengine.device import get_device

    device = get_device()
    model = SapiensVisionTransformer(
        arch="sapiens_0.6b",  # model name
        img_size=(768, 1024),  # (H, W) = (image_size[1], image_size[0])
        patch_size=16,  # patch size
        qkv_bias=True,
        final_norm=True,
        drop_path_rate=0.0,
        with_cls_token=False,
        out_type="featmap",
        patch_cfg=dict(padding=2),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.6b/sapiens_0.6b_goliath_best_goliath_AP_609.pth",
        ),
    )

    model.init_weights()
    model.to(device)
    model.eval()

    # Any spatial size works: forward resizes the input to img_size.
    tensor = torch.randn([1, 3, 224, 224], device=device)

    with torch.no_grad():
        outs = model(tensor)

    # ``forward`` returns an EncoderOutput rather than a list of tensors.
    print(outs.last_hidden_state.shape)
    for out in outs.hidden_states:
        print(out.shape)
