from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule
from mmhug.registry import HF_MODELS
from mmcv.cnn.bricks.transformer import FFN
from einops import rearrange
from mmengine.model.weight_init import trunc_normal_

from mmhug.models.custom_transformers.sapiens.utils import (
    build_2d_sincos_position_embedding,
    MultiheadAttention,
    CrossMultiheadAttention,
    SwiGLUFFNFused,
    build_norm_layer,
)


class MAEDecoderLayer(BaseModule):
    """Implements one decoder layer of the SlipMAE decoder.

    The layer is a pre-norm transformer block with three residual sub-layers,
    applied in order:

    1. self-attention over the patch tokens (``ln1`` -> ``sa``);
    2. cross-attention from the patch tokens to the prompt tokens
       (``ln2`` -> ``ca``);
    3. a feed-forward network (``ln3`` -> ``ffn``).

    Args:
        embed_dims (int): The feature dimension
        num_heads (int): Parallel attention heads
        feedforward_channels (int): The hidden dimension for FFNs
        layer_scale_init_value (float or torch.Tensor): Init value of layer
            scale. Defaults to 0.
        drop_rate (float): Probability of an element to be zeroed
            after the feed forward layer. Defaults to 0.
        attn_drop_rate (float): The drop out rate for attention output weights.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        num_fcs (int): The number of fully-connected layers for FFNs.
            Defaults to 2.
        qkv_bias (bool): enable bias for qkv if True. Defaults to True.
        ffn_type (str): Select the type of ffn layers, either ``'origin'``
            (mmcv ``FFN``) or ``'swiglu_fused'``. Defaults to 'origin'.
        act_cfg (dict): The activation config for FFNs (only used by the
            ``'origin'`` FFN). Defaults to ``dict(type='GELU')``.
        norm_cfg (dict): Config dict for normalization layer.
            Defaults to ``dict(type='LN')``.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.

    Raises:
        NotImplementedError: If ``ffn_type`` is not one of the supported
            values.
    """

    def __init__(
        self,
        embed_dims,
        num_heads,
        feedforward_channels,
        layer_scale_init_value=0.0,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        num_fcs=2,
        qkv_bias=True,
        ffn_type="origin",
        act_cfg=dict(type="GELU"),
        norm_cfg=dict(type="LN"),
        init_cfg=None,
    ):
        super(MAEDecoderLayer, self).__init__(init_cfg=init_cfg)

        self.embed_dims = embed_dims

        # Sub-layer 1: self-attention over the patch tokens.
        self.ln1 = build_norm_layer(norm_cfg, self.embed_dims)
        self.sa = MultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        # Sub-layer 2: cross-attention from patch tokens to prompt tokens.
        self.ln2 = build_norm_layer(norm_cfg, self.embed_dims)
        self.ca = CrossMultiheadAttention(
            embed_dims=embed_dims,
            num_heads=num_heads,
            attn_drop=attn_drop_rate,
            proj_drop=drop_rate,
            dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
            qkv_bias=qkv_bias,
            layer_scale_init_value=layer_scale_init_value,
        )

        # Sub-layer 3: feed-forward network.
        if ffn_type == "origin":
            self.ffn = FFN(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                num_fcs=num_fcs,
                ffn_drop=drop_rate,
                dropout_layer=dict(type="DropPath", drop_prob=drop_path_rate),
                act_cfg=act_cfg,
                layer_scale_init_value=layer_scale_init_value,
            )
        elif ffn_type == "swiglu_fused":
            self.ffn = SwiGLUFFNFused(
                embed_dims=embed_dims,
                feedforward_channels=feedforward_channels,
                layer_scale_init_value=layer_scale_init_value,
            )
        else:
            raise NotImplementedError(
                f"Unsupported ffn_type {ffn_type!r}; "
                "expected 'origin' or 'swiglu_fused'."
            )

        self.ln3 = build_norm_layer(norm_cfg, self.embed_dims)

    def init_weights(self):
        """Run the default init, then re-init every Linear inside the FFN.

        FFN linear weights get Xavier-uniform init and biases get a normal
        init with ``std=1e-6``.
        """
        super(MAEDecoderLayer, self).init_weights()
        for m in self.ffn.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                # A Linear built with ``bias=False`` has ``bias is None``.
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)

    def forward(self, x, prompt_tokens):
        """Apply self-attention, cross-attention and FFN with residuals.

        Args:
            x (torch.Tensor): Patch tokens, presumably of shape
                (B, N, embed_dims) — TODO confirm against the attention
                modules.
            prompt_tokens (torch.Tensor): Tokens that ``x`` cross-attends to.
                They are passed to ``self.ca`` without normalization.

        Returns:
            torch.Tensor: Updated patch tokens, same shape as ``x``.
        """
        x = x + self.sa(self.ln1(x))
        x = x + self.ca(self.ln2(x), prompt_tokens)
        # mmcv's ``FFN`` adds ``identity`` to its output itself, and when
        # ``identity`` is None it falls back to its own (normalized) input.
        # ``SwiGLUFFNFused`` presumably follows the same mmpretrain contract
        # (confirm). Passing the un-normalized residual stream explicitly gives
        # ``x + FFN(LN(x))`` instead of the previous ``x + LN(x) + FFN(LN(x))``.
        x = self.ffn(self.ln3(x), identity=x)
        return x


@HF_MODELS.register_module(force=True)
class SlipmaeDecoder(BaseModule):
    """MAE-style decoder that reconstructs patches while attending to condition tokens.

    The input sequence holds ``num_extra_tokens`` condition (prompt) tokens
    followed by the visible patch tokens. The decoder proceeds as follows:

    1. Project both groups from ``embed_dim`` to ``decoder_embed_dim``.
    2. Fill every masked position with a learnable mask token.
    3. Reorder the patch tokens with ``ids_restore``.
    4. Add position embeddings.
    5. Run ``decoder_depth`` :class:`MAEDecoderLayer` blocks (self-attention
       over patches, cross-attention to the condition tokens).
    6. Map every patch to ``predict_feature_dim`` values, folded into an
       image-shaped tensor.

    Args:
        img_size (tuple[int, int]): (height, width) of the reconstructed
            image. Defaults to (1024, 768).
        patch_size (int): Side length of a square patch. Defaults to 16.
        in_chans (int): Number of image channels, used only for the default
            ``predict_feature_dim``. Defaults to 3.
        embed_dim (int): Feature dimension of the incoming encoder features.
            Defaults to 1024.
        decoder_embed_dim (int): Feature dimension inside the decoder.
            Defaults to 512.
        decoder_depth (int): Number of decoder layers. Defaults to 8.
        decoder_num_heads (int): Attention heads per decoder layer.
            Defaults to 16.
        mlp_ratio (int | float): FFN hidden size as a multiple of
            ``decoder_embed_dim``. Defaults to 4.
        norm_cfg (dict): Normalization layer config.
            Defaults to ``dict(type="LN", eps=1e-6)``.
        predict_feature_dim (int, optional): Number of predicted values per
            patch. Defaults to ``patch_size ** 2 * in_chans``.
        num_extra_tokens (int): Number of condition tokens at the start of
            the input sequence. Defaults to 3.
        init_cfg (dict | list[dict], optional): Initialization config.
    """

    def __init__(
        self,
        img_size: Tuple[int, int] = (1024, 768),
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        decoder_depth: int = 8,
        decoder_num_heads: int = 16,
        mlp_ratio: Union[int, float] = 4,
        norm_cfg: dict = dict(type="LN", eps=1e-6),
        predict_feature_dim: Optional[int] = None,
        num_extra_tokens: int = 3,  # 3 types of condition: id, nonvocal, vocal
        init_cfg: Optional[Union[List[dict], dict]] = None,
    ) -> None:
        super().__init__(init_cfg=init_cfg)
        self.patch_size = patch_size
        self.img_size = img_size
        self.patch_resolution = (
            img_size[0] // patch_size,
            img_size[1] // patch_size,
        )
        self.num_extra_tokens = num_extra_tokens

        # used to convert the dim of features from encoder to the dim
        # compatible with that of decoder
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        # Learnable mask token shared by all masked positions.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
        torch.nn.init.normal_(self.mask_token, std=0.02)

        if num_extra_tokens > 0:
            # Learnable position embedding for the condition tokens.
            self.decoder_extra_pos_embed = nn.Parameter(
                torch.zeros(1, num_extra_tokens, decoder_embed_dim)
            )
            trunc_normal_(self.decoder_extra_pos_embed, std=0.02)
        else:
            # Keep the attribute defined so ``forward_trans`` can test for it.
            self.register_parameter("decoder_extra_pos_embed", None)

        # Fixed 2D sin-cos position embedding for the patch grid.
        decoder_pos_embed = build_2d_sincos_position_embedding(
            self.patch_resolution, decoder_embed_dim, cls_token=False
        )
        self.register_buffer("decoder_pos_embed", decoder_pos_embed)

        self.decoder_blocks = nn.ModuleList(
            [
                MAEDecoderLayer(
                    decoder_embed_dim,
                    decoder_num_heads,
                    int(mlp_ratio * decoder_embed_dim),
                    qkv_bias=True,
                    norm_cfg=norm_cfg,
                )
                for _ in range(decoder_depth)
            ]
        )

        # Registered as a plain submodule, so ``self.decoder_norm`` resolves
        # through ``nn.Module`` attribute lookup.
        self.decoder_norm = build_norm_layer(norm_cfg, decoder_embed_dim)

        # Used to map features to pixels
        if predict_feature_dim is None:
            predict_feature_dim = patch_size**2 * in_chans
        self.decoder_pred = nn.Linear(decoder_embed_dim, predict_feature_dim, bias=True)

    def forward_trans(
        self,
        x: torch.Tensor,
        ids_restore: torch.Tensor,
    ) -> torch.Tensor:
        """Embed tokens, restore the full patch sequence and run the decoder blocks.

        Args:
            x (torch.Tensor): Encoder features of shape
                B x (num_extra_tokens + L * (1 - mask_ratio)) x embed_dim.
                The condition tokens come first, followed by the visible
                patch tokens.
            ids_restore (torch.Tensor): Integer tensor of shape B x L. Output
                position ``i`` takes element ``ids_restore[b, i]`` of the
                concatenated [visible, mask] token sequence.

        Returns:
            torch.Tensor: Decoded patch tokens, B x L x decoder_embed_dim.
        """
        num_extra_tokens = self.num_extra_tokens

        x = self.decoder_embed(x)

        # Split into condition (prompt) tokens and visible patch tokens:
        # [B, num_extra_tokens, C], [B, L * (1 - mask_ratio), C]
        extra_tokens, x = torch.split(
            x, [num_extra_tokens, x.size(1) - num_extra_tokens], dim=1
        )

        # One mask token per masked patch: B x (L * mask_ratio) x C
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] - x.shape[1], 1
        )
        x = torch.cat([x, mask_tokens], dim=1)  # B x L x C
        # Reorder [visible, mask] into the order given by ids_restore.
        x = torch.gather(
            x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )

        x = x + self.decoder_pos_embed
        if self.decoder_extra_pos_embed is not None:
            extra_tokens = extra_tokens + self.decoder_extra_pos_embed
        # NOTE(review): with num_extra_tokens == 0 the cross-attention gets an
        # empty key/value sequence; whether CrossMultiheadAttention handles
        # that is not visible here — confirm before relying on it.

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x, extra_tokens)
        return x

    def to_pixel(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize, predict per-patch values and fold patches into an image.

        Args:
            x (torch.Tensor): Decoded tokens, B x (h * w) x decoder_embed_dim,
                where (h, w) is ``self.patch_resolution``.

        Returns:
            torch.Tensor: Tensor of shape B x c x (h * p) x (w * p), where
            p is ``patch_size`` and c is ``predict_feature_dim // p**2``.
        """
        x = self.decoder_norm(x)

        # (B N C) N = img_h // p * img_w // p, C = c * p * p
        x = self.decoder_pred(x)
        x = rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=self.patch_resolution[0],
            w=self.patch_resolution[1],
            p1=self.patch_size,
            p2=self.patch_size,
        )
        return x

    def forward(
        self,
        x: torch.Tensor,
        ids_restore: torch.Tensor,
    ) -> torch.Tensor:
        """The forward function.

        The process computes the visible patches' features vectors and the mask
        tokens to output feature vectors, which will be used for
        reconstruction.

        Args:
            x (torch.Tensor): Encoder features of shape
                B x (num_extra_tokens + L * (1 - mask_ratio)) x embed_dim,
                condition tokens first.
            ids_restore (torch.Tensor): ids to restore original image. B x L
        Returns:
            torch.Tensor: The reconstruction folded into image layout, of
            shape B x c x (h * patch_size) x (w * patch_size), where
            (h, w) is ``patch_resolution``.
        """
        # B x (num_extra_tokens + L * (1 - mask_ratio)) x embed_dim
        #   -> B x L x decoder_embed_dim
        x = self.forward_trans(x, ids_restore)

        x = self.to_pixel(x)

        return x
