# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F
from mmcv.cnn.bricks.transformer import PatchEmbed
from mmengine.model import ModuleList
from mmengine.model.weight_init import trunc_normal_

from mmhug.registry import HF_MODELS
from mmhug.models.custom_transformers.sapiens.vit_sapiens import TransformerEncoderLayer
from mmhug.models.custom_transformers.sapiens.utils import (
    build_norm_layer,
    resize_pos_embed,
    to_2tuple,
    build_2d_sincos_position_embedding,
)
from mmhug.models.custom_transformers.sapiens.base_backbone import BaseBackbone
from mmhug.models.modeling_output import MAEEncoderOutput

MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


@HF_MODELS.register_module(force=True)
class SlipmaeEncoder(BaseBackbone):
    """Vision Transformer.

    A PyTorch implementation of `An Image is Worth 16x16 Words: Transformers
    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_

    Args:
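        arch (str | dict): Vision Transformer architecture. If use string,
            choose from 'small', 'base', 'large', 'huge', 'sapiens_0.3b',
            'sapiens_0.6b', 'sapiens_1b' and 'sapiens_2b' (or the short
            aliases 's', 'b', 'l', 'h', '0.3b', '0.6b', '1b' and '2b'); the
            string is matched case-insensitively. If use dict, it should
            have below keys: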

            - **embed_dims** (int): The dimensions of embedding.
            - **num_layers** (int): The number of transformer encoder layers.
            - **num_heads** (int): The number of heads in attention modules.
            - **feedforward_channels** (int): The hidden dimensions in
              feedforward modules.

            Defaults to 'base'.
        img_size (int | tuple): The input image shape (H, W) the model runs
            at. Inputs whose spatial size differs are bilinearly resized to
            this shape in ``forward``. Defaults to (1024, 768).
        patch_size (int | tuple): The patch size in patch embedding.
            Defaults to 16.
        in_channels (int): The num of input channels. Defaults to 3.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): stochastic depth rate. Defaults to 0.
        qkv_bias (bool): Whether to add bias for qkv in attention modules.
            Defaults to True.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        out_type (str): Not a constructor argument of this class. The layout
            of the returned patch tokens is selected by the ``do_mask``
            argument of ``forward`` instead:

            - ``do_mask=False``: the patch tokens are returned as a feature
              map with shape (B, C, H, W).
            - ``do_mask=True``: only the kept (unmasked) patch tokens are
              returned, as a sequence with shape (B, L_keep, C).

            In both cases the extra tokens are returned separately from the
            patch tokens.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Defaults to -1.
        interpolate_mode (str): Select the interpolate mode for position
            embedding vector resize. Defaults to "bicubic".
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
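        init_cfg (dict, optional): Initialization config dict. Defaults to
            a ``Pretrained`` config pointing at
            ``checkpoints/sapiens-pose-0.3b/backbone.pth``.
        pre_norm (bool): Whether to apply a norm layer to the tokens before
            the transformer layers. When True, the patch-embedding bias is
            disabled. Defaults to False.
        norm_in (bool): Whether to re-normalize the input in ``forward``.
            When True, the input is mapped with ``x * 0.5 + 0.5``, clamped
            to [0, 1] and normalized with ``MEAN`` / ``STD``.
            Defaults to True.
        num_extra_tokens (int): Number of learnable extra tokens prepended
            to the patch tokens. Defaults to 3.
        mask_ratio (float): Default ratio of patch tokens removed by
            ``random_masking`` when ``forward`` is called with
            ``do_mask=True``. Defaults to 0.75.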
    """

    # Named architecture presets looked up by ``__init__`` when ``arch`` is a
    # string. Every preset is registered under a short and a long alias;
    # ``dict.fromkeys`` makes both aliases point at the *same* settings dict
    # object, so the presets must be treated as read-only.
    arch_zoo = {
        **dict.fromkeys(
            ["s", "small"],
            {
                "embed_dims": 768,
                "num_layers": 8,
                "num_heads": 8,
                "feedforward_channels": 768 * 3,
            },
        ),
        **dict.fromkeys(
            ["b", "base"],
            {
                "embed_dims": 768,
                "num_layers": 12,
                "num_heads": 12,
                "feedforward_channels": 3072,
            },
        ),
        **dict.fromkeys(
            ["l", "large"],
            {
                "embed_dims": 1024,
                "num_layers": 24,
                "num_heads": 16,
                "feedforward_channels": 4096,
            },
        ),
        **dict.fromkeys(
            ["h", "huge"],
            {
                # The same as the implementation in MAE
                # <https://arxiv.org/abs/2111.06377>
                "embed_dims": 1280,
                "num_layers": 32,
                "num_heads": 16,
                "feedforward_channels": 5120,
            },
        ),
        # Sapiens presets. Keep ``embed_dims`` divisible by ``num_heads``.
        # The number of parameters and FLOPs grows with ``embed_dims`` and
        # ``num_layers``.
        **dict.fromkeys(  # same settings as "large" above (vit-large)
            ["0.3b", "sapiens_0.3b"],
            {
                "embed_dims": 1024,
                "num_layers": 24,
                "num_heads": 16,
                "feedforward_channels": 1024 * 4,
            },
        ),
        **dict.fromkeys(  # same settings as "huge" above (vit-huge)
            ["0.6b", "sapiens_0.6b"],
            {
                "embed_dims": 1280,
                "num_layers": 32,
                "num_heads": 16,
                "feedforward_channels": 1280 * 4,
            },
        ),
        **dict.fromkeys(  # this is vit-g
            ["1b", "sapiens_1b"],
            {
                "embed_dims": 1536,
                "num_layers": 40,
                "num_heads": 24,
                "feedforward_channels": 1536 * 4,
            },
        ),
        **dict.fromkeys(
            ["2b", "sapiens_2b"],
            {
                "embed_dims": 1920,
                "num_layers": 48,
                "num_heads": 32,
                "feedforward_channels": 1920 * 4,
            },
        ),
    }

    def train(self, mode: bool = True):
        super().train(mode)
        if mode:
            self.pos_embed.requires_grad_(False)

    def __init__(
        self,
        arch="base",
        img_size=(1024, 768),
        patch_size=16,
        in_channels=3,
        drop_rate=0.0,
        drop_path_rate=0.0,
        qkv_bias=True,
        norm_cfg=dict(type="LN", eps=1e-6),
        frozen_stages=-1,
        interpolate_mode="bicubic",
        layer_scale_init_value=0.0,
        pre_norm=False,
        norm_in: bool = True,
        num_extra_tokens=3,  # id, nonvocal, vocal
        mask_ratio: float = 0.75,
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
        ),
    ):
        """Build the patch embedding, position embeddings and transformer.

        Args:
            arch (str | dict): Key of ``arch_zoo`` (matched
                case-insensitively) or a dict providing ``embed_dims``,
                ``num_layers``, ``num_heads`` and ``feedforward_channels``.
                Defaults to "base".
            img_size (int | tuple): Input size the model runs at; stored as
                a 2-tuple in ``self.img_size``. Defaults to (1024, 768).
            patch_size (int | tuple): Kernel size and stride of the patch
                embedding convolution. Defaults to 16.
            in_channels (int): Number of input channels. Defaults to 3.
            drop_rate (float): Dropout probability used after the position
                embedding and passed to every transformer layer.
                Defaults to 0.
            drop_path_rate (float): Largest stochastic depth rate; layer
                rates are spaced linearly from 0 to this value.
                Defaults to 0.
            qkv_bias (bool): Passed to every transformer layer.
                Defaults to True.
            norm_cfg (dict): Config used to build the norm layers.
                Defaults to ``dict(type="LN", eps=1e-6)``.
            frozen_stages (int): Stored in ``self.frozen_stages`` and read
                by ``_freeze_stages``. Defaults to -1.
            interpolate_mode (str): Interpolation mode used when the
                position embedding is resized in ``forward``.
                Defaults to "bicubic".
            layer_scale_init_value (float): Passed to every transformer
                layer. Defaults to 0.
            pre_norm (bool): If True, a norm layer is applied before the
                transformer layers and the patch embedding has no bias.
                Defaults to False.
            norm_in (bool): Whether ``norm_input`` re-normalizes the input.
                Defaults to True.
            num_extra_tokens (int): Number of learnable extra tokens (with
                their own learnable position embedding). Defaults to 3.
            mask_ratio (float): Default masking ratio used by ``forward``.
                Defaults to 0.75.
            init_cfg (dict, optional): Forwarded to the base class
                constructor. Defaults to a ``Pretrained`` config.

        Raises:
            AssertionError: If ``arch`` is an unknown string, or is not a
                dict containing the four required keys.
        """
        super(SlipmaeEncoder, self).__init__(init_cfg)
        self.norm_in = norm_in
        # Per-channel statistics used by ``norm_input``; stored as buffers
        # shaped (1, 3, 1, 1) so they broadcast over (B, 3, H, W) inputs and
        # follow the module across devices/dtypes.
        self.register_buffer("mean", torch.tensor(MEAN).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor(STD).view(1, 3, 1, 1))

        # Resolve the architecture settings from a preset name or a custom
        # dict.
        if isinstance(arch, str):
            arch = arch.lower()
            assert arch in set(
                self.arch_zoo
            ), f"Arch {arch} is not in default archs {set(self.arch_zoo)}"
            self.arch_settings = self.arch_zoo[arch]
        else:
            essential_keys = {
                "embed_dims",
                "num_layers",
                "num_heads",
                "feedforward_channels",
            }
            assert isinstance(arch, dict) and essential_keys <= set(
                arch
            ), f"Custom arch needs a dict with keys {essential_keys}"
            self.arch_settings = arch

        self.embed_dims = self.arch_settings["embed_dims"]
        self.num_layers = self.arch_settings["num_layers"]
        self.img_size = to_2tuple(img_size)

        # Non-overlapping patch embedding (kernel size == stride) with a
        # fixed padding of 2.
        self.patch_embed = PatchEmbed(
            in_channels=in_channels,
            input_size=img_size,
            embed_dims=self.embed_dims,
            conv_type="Conv2d",
            kernel_size=patch_size,
            stride=patch_size,
            bias=not pre_norm,  # disable bias if pre_norm is used(e.g., CLIP)
            padding=2,
        )
        # Patch grid size reported by the patch embedding for ``img_size``.
        self.patch_resolution = self.patch_embed.init_out_size

        # Set position embedding
        self.interpolate_mode = interpolate_mode
        # The patch position embedding is a buffer initialised from a 2D
        # sin-cos embedding over the patch grid.
        # NOTE(review): the original comment said it "must be loaded from
        # pretrained model" - presumably the checkpoint value replaces this
        # initial buffer; confirm against the checkpoint loading code.
        self.register_buffer(
            "pos_embed",
            build_2d_sincos_position_embedding(self.patch_resolution, self.embed_dims),
        )

        # Learnable extra tokens and their own learnable position embedding;
        # both attributes only exist when ``num_extra_tokens > 0``.
        self.num_extra_tokens = num_extra_tokens
        if num_extra_tokens > 0:
            self.extra_pos_embed = nn.Parameter(
                torch.zeros(1, num_extra_tokens, self.embed_dims)
            )
            self.extra_tokens = nn.Parameter(
                torch.zeros(1, num_extra_tokens, self.embed_dims)
            )
            trunc_normal_(self.extra_pos_embed, std=0.02)
            torch.nn.init.normal_(self.extra_tokens, std=0.02)

        self.drop_after_pos = nn.Dropout(p=drop_rate)

        # stochastic depth decay rule
        dpr = np.linspace(0, drop_path_rate, self.num_layers)

        self.layers = ModuleList()
        for i in range(self.num_layers):
            _layer_cfg = dict(
                embed_dims=self.embed_dims,
                num_heads=self.arch_settings["num_heads"],
                feedforward_channels=self.arch_settings["feedforward_channels"],
                layer_scale_init_value=layer_scale_init_value,
                drop_rate=drop_rate,
                drop_path_rate=dpr[i],
                qkv_bias=qkv_bias,
                norm_cfg=norm_cfg,
            )
            self.layers.append(TransformerEncoderLayer(**_layer_cfg))

        # NOTE(review): ``frozen_stages`` is only stored here; nothing in
        # this constructor calls ``_freeze_stages`` - confirm whether the
        # base class or the training code is expected to do so.
        self.frozen_stages = frozen_stages
        if pre_norm:
            self.pre_norm = build_norm_layer(norm_cfg, self.embed_dims)
        else:
            self.pre_norm = nn.Identity()

        # Norm layer applied after the last transformer layer.
        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)

        # ------------------------ MAE ------------------------
        self.mask_ratio = mask_ratio
        # NOTE(review): the original trailing comment here read "compute each
        # patch's weight of id, nonvocal, vocal", but no such computation
        # follows - looks like a leftover placeholder.

    @property
    def norm1(self):
        """Alias for ``self.ln1``, the norm layer ``forward`` applies after
        the last transformer layer."""
        return self.ln1

    @staticmethod
    def resize_pos_embed(*args, **kwargs):
        """Interface for backward-compatibility.

        Forwards all positional and keyword arguments unchanged to the
        module-level ``resize_pos_embed`` helper and returns its result.
        """
        return resize_pos_embed(*args, **kwargs)

    def _freeze_stages(self):
        # freeze position embedding
        if self.pos_embed is not None:
            self.pos_embed.requires_grad = False
        # set dropout to eval model
        self.drop_after_pos.eval()
        # freeze patch embedding
        self.patch_embed.eval()
        for param in self.patch_embed.parameters():
            param.requires_grad = False
        # freeze pre-norm
        for param in self.pre_norm.parameters():
            param.requires_grad = False
        # freeze cls_token
        if self.cls_token is not None:
            self.cls_token.requires_grad = False
        # freeze layers
        for i in range(1, self.frozen_stages + 1):
            m = self.layers[i - 1]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False
        # freeze the last layer norm
        if self.frozen_stages == len(self.layers):
            if self.final_norm:
                self.ln1.eval()
                for param in self.ln1.parameters():
                    param.requires_grad = False

    def norm_input(self, x):
        if not self.norm_in:
            return x
        # make sure the input is normalized with mean 0.5 and std 0.5
        x = x * 0.5 + 0.5
        x = torch.clamp(x, 0, 1)
        return (x - self.mean) / self.std

    def resize_input(self, x):
        if x.shape[-2] != self.img_size[0] or x.shape[-1] != self.img_size[1]:
            x = F.interpolate(
                x, size=self.img_size, mode="bilinear", align_corners=False
            )
        return x

   
    def random_masking(
        self,
        x: torch.Tensor,
        mask_ratio: float = 0.75,
        mask_prob: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the mask for MAE Pre-training.

        Args:
            x (torch.Tensor): Image with data augmentation applied,
                which is of shape B x L x C.
            mask_ratio (float): The mask ratio of total patches.
                Defaults to 0.75.
            mask_prob (torch.Tensor, optional): Probabilities for masking each
                patch, shape [B, L]. If provided, patches with higher prob
                are more likely to be masked. Defaults to None, which means
                uniform masking.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                masked image, mask and the ids to restore original image.
                - `x_masked (torch.Tensor): masked image.
                - `mask (torch.Tensor): mask used to mask image.
                - `ids_restore (torch.Tensor): ids to restore original image.
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        if mask_prob is None:
            # uniform masking
            noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
            # sort noise for each sample
            ids_shuffle = torch.argsort(
                noise, dim=1
            )  # ascend: small is keep, large is remove
            ids_restore = torch.argsort(ids_shuffle, dim=1)
        else:
            # non-uniform masking based on mask_prob
            # Gumbel-Top-k trick
            # We want to keep `len_keep` patches.
            # The probability of keeping a patch is `1 - mask_prob`.
            p_keep = 1 - mask_prob
            # To prevent log(0), we clamp the probabilities
            log_p_keep = torch.log(p_keep.clamp(min=1e-6))
            # Gumbel noise
            gumbel_noise = -torch.log(
                -torch.log(
                    torch.rand(N, L, device=x.device).clamp(min=1e-6, max=1 - 1e-6)
                )
            )
            scores = log_p_keep + gumbel_noise

            # We keep the patches with the highest scores.
            # `argsort` with descending=True gives indices of elements from large to small.
            ids_shuffle = torch.argsort(scores, dim=1, descending=True)
            ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore

    def forward(
        self, x, do_mask: bool = True, mask_ratio=None, mask_prob=None
    ) -> MAEEncoderOutput:
        """
        Forward pass of the model.

        The input is resized to ``self.img_size``, optionally re-normalized,
        split into patch tokens and summed with the (resized) position
        embedding. With ``do_mask`` a random subset of the patch tokens is
        dropped before the extra tokens are prepended and the sequence is run
        through the transformer layers and the final norm.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).
            do_mask (bool, optional): Whether to apply masking. Defaults to
                True. If do_mask is False, mask_ratio and mask_prob are
                ignored.
            mask_ratio (float, optional): Ratio of patches to mask. If None,
                self.mask_ratio will be used.
            mask_prob (torch.Tensor, optional): Probabilities for masking each
                patch, shape [B, L]. If provided, patches with higher prob
                are more likely to be masked. Defaults to None, which means
                uniform masking.

        Returns:
            MAEEncoderOutput: Output with the fields

            - ``prompt_tokens``: the extra tokens, shape (B, num_extra, C).
              With ``do_mask=False`` and no extra tokens this is None; with
              ``do_mask=True`` and no extra tokens it is an empty slice.
            - ``patch_tokens``: with ``do_mask=True`` the kept patch tokens
              as a sequence (B, L_keep, C); with ``do_mask=False`` all patch
              tokens as a feature map (B, C, H, W).
            - ``mask`` / ``ids_restore``: the outputs of ``random_masking``,
              or None when ``do_mask`` is False.
        """
        batch_size = x.shape[0]
        x = self.resize_input(x)
        x = self.norm_input(x)
        x, patch_resolution = self.patch_embed(x)

        # Adapt the stored position embedding to the actual patch grid.
        pos_emb = resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=0,
        )

        x = x + pos_emb

        # mae drop: only the patch tokens are subject to masking, the extra
        # tokens are added afterwards
        if do_mask:
            ratio = self.mask_ratio if mask_ratio is None else mask_ratio
            x, mask, ids_restore = self.random_masking(x, ratio, mask_prob)
        else:
            mask = None
            ids_restore = None

        # add extra tokens (with their own position embedding) in front of
        # the patch tokens
        if self.num_extra_tokens > 0:
            extra_tokens = (
                self.extra_tokens.expand(batch_size, -1, -1) + self.extra_pos_embed
            )
            x = torch.cat((extra_tokens, x), dim=1)

        x = self.drop_after_pos(x)

        x = self.pre_norm(x)  ## B x (num tokens) x embed_dim

        for layer in self.layers:
            x = layer(x)

        x = self.ln1(x)

        if do_mask:
            # b n c: the kept tokens no longer form a grid, so they stay a
            # sequence
            extra_token, patch_token = (
                x[:, : self.num_extra_tokens],
                x[:, self.num_extra_tokens :],
            )
        else:
            # b n c, b c h w
            extra_token, patch_token = self._format_output(x, patch_resolution)

        return MAEEncoderOutput(
            prompt_tokens=extra_token,
            patch_tokens=patch_token,
            mask=mask,
            ids_restore=ids_restore,
        )

    def _format_output(self, x, hw):
        patch_token = x[:, self.num_extra_tokens :]
        B = x.size(0)
        # (B, N, C) -> (B, H, W, C) -> (B, C, H, W)
        patch_token = patch_token.reshape(B, *hw, -1).permute(0, 3, 1, 2)
        if self.num_extra_tokens > 0:
            extra_token = x[:, : self.num_extra_tokens]
            return extra_token, patch_token
        return None, patch_token


# Demo script: predicts keypoint heatmaps for a reference image, turns them
# into per-patch masking probabilities, and writes a side-by-side picture of
# a heatmap-guided mask and a uniform random mask to ``mask_img.png``.
if __name__ == "__main__":
    import numpy as np
    import torch
    from torchvision.io import read_image, write_png
    from mmengine.device import get_device
    from mmhug.models.custom_transformers.sapiens.heatmap_head import (
        HeatmapHead,
    )
    from mmhug.utils.vis_utils import apply_block_mask_to_rgb

    device = get_device()
    dtype = torch.bfloat16

    # Backbone used to extract the feature map fed to the heatmap head.
    backbone = HF_MODELS.build(
        dict(
            type="SapiensVisionTransformer",
            arch="sapiens_0.3b",
            norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
            img_size=(1024, 768),
            patch_size=16,
            in_channels=3,
            out_indices=-1,
            drop_rate=0.0,
            drop_path_rate=0.0,
            qkv_bias=True,
            norm_cfg=dict(type="LN", eps=1e-6),
            final_norm=True,
            with_cls_token=False,
            frozen_stages=-1,
            interpolate_mode="bicubic",
            layer_scale_init_value=0.0,
            patch_cfg=dict(padding=2),
            layer_cfgs=dict(),
            pre_norm=False,
            out_type="featmap",
            init_cfg=dict(
                type="Pretrained",
                checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
            ),
        ),
    ).to(device, dtype)

    backbone.init_weights()
    heatmap_head = HeatmapHead(
        in_channels=1024,
        out_channels=308,
        deconv_out_channels=(768, 768),  ## this will 2x at each step. so total is 4x
        deconv_kernel_sizes=(4, 4),
        conv_out_channels=(768, 768),
        conv_kernel_sizes=(1, 1),
        decoder=dict(
            type="UDPHeatmap", input_size=(768, 1024), heatmap_size=(256, 192), sigma=6
        ),
        init_cfg=dict(
            type="Pretrained",
            checkpoint="checkpoints/sapiens-pose-0.3b/heatmap_head.pth",
        ),
    ).to(device, dtype)
    heatmap_head.init_weights()

    # Load the reference image, add a batch dimension and normalize it with
    # mean 0.5 / std 0.5 after scaling by 1/255.
    img = "demo_assets/ref_img512.png"
    img = read_image(img)
    img = img.unsqueeze(0)
    img = img.float() / 255.0
    img = (img - 0.5) / 0.5
    img = img.to(device, dtype)
    # NOTE(review): no ``torch.no_grad()`` is used anywhere in this demo, so
    # autograd state is tracked during these inference-only calls.
    feats = backbone(img, out_type="featmap").last_hidden_state
    heatmap, preds = heatmap_head(feats, decode_kpt=True)

    # NOTE(review): this builds "SapiensMotionExtractorV2", not the
    # ``SlipmaeEncoder`` defined in this module - confirm that is the
    # intended registry entry. ``init_weights()`` is also never called on
    # ``encoder`` in this script; only its masks are used below.
    encoder = HF_MODELS.build(
        dict(  # 1024
            type="SapiensMotionExtractorV2",
            arch="sapiens_0.3b",
            norm_in=True,  # the input is normalized with mean 0.5 and std 0.5, we need to renormalize with Sapiens mean and std
            img_size=(512, 512),
            patch_size=16,
            in_channels=3,
            drop_rate=0.0,
            drop_path_rate=0.0,
            qkv_bias=True,
            norm_cfg=dict(type="LN", eps=1e-6),
            frozen_stages=-1,
            interpolate_mode="bicubic",
            layer_scale_init_value=0.0,
            pre_norm=False,
            num_extra_tokens=3,
            mask_ratio=0.75,
            init_cfg=dict(
                type="Pretrained",
                checkpoint="checkpoints/sapiens-pose-0.3b/backbone.pth",
            ),
        )
    ).to(device, dtype)

    # Masking probability per patch: resize the heatmaps to a 32 x 32 grid,
    # take the maximum over the channel dimension and invert it, so patches
    # with a strong heatmap response get a low probability of being masked.
    mask_prob = (
        1
        - torch.max(
            F.interpolate(heatmap, (32, 32), mode="bilinear", align_corners=False),
            dim=1,
        ).values
    )
    # Heatmap-guided masking; the mask is viewed back as a 32 x 32 grid.
    output_keepface = encoder(img, do_mask=True, mask_prob=mask_prob.view(1, -1))
    mask_keepface = output_keepface.mask
    mask_keepface = mask_keepface.view(1, 32, 32)

    # Uniform random masking for comparison.
    output_random = encoder(img, do_mask=True, mask_prob=None)
    mask_random = output_random.mask
    mask_random = mask_random.view(1, 32, 32)

    print(output_keepface.ids_restore)
    print(output_random.ids_restore)
    # Undo the (0.5, 0.5) normalization and overlay each mask on the image.
    img = img * 0.5 + 0.5
    rgb_keepface = (
        apply_block_mask_to_rgb(img.unsqueeze(0), mask_keepface.unsqueeze(0), alpha=0.5)
        .float()
        .cpu()
    )
    rgb_keepface = rgb_keepface.squeeze(0).squeeze(0)

    rgb_random = (
        apply_block_mask_to_rgb(img.unsqueeze(0), mask_random.unsqueeze(0), alpha=0.5)
        .float()
        .cpu()
    )
    rgb_random = rgb_random.squeeze(0).squeeze(0)

    # Concatenate both overlays along the last dimension and save as PNG.
    write_png(
        (torch.cat([rgb_keepface, rgb_random], dim=-1) * 255).to(torch.uint8),
        "mask_img.png",
    )
