import torch
import torch.nn as nn

from modules.hoi4abot.hoibot.modules.transformer.modules_attn.CrossAttention import CrossTransformer, CrossAttention, \
    CrossAttentionBlock
from modules.hoi4abot.hoibot.modules.transformer.modules_attn.build_positional_embedding import build_pose_embed

from torch.nn.modules.utils import _ntuple
from timm.models.layers import trunc_normal_
from functools import partial

from einops import rearrange, reduce, repeat

# Helper that turns a scalar into a 2-tuple (iterables are converted to a tuple as-is).
to_2tuple = _ntuple(n=2)


class DualTransformer(nn.Module):
    """Two-stage cross-attention transformer over a "main" and a "second" token branch.

    ``mainbranch`` decides the roles. With ``"objects"``, the ``objects`` input is
    the main branch and ``humans`` is the second branch; any other value swaps them.
    Both branches are indexed as ``(batch, tokens, embed_dim)`` tensors and start
    with a fixed number of prefix tokens (``mainbranch_id`` / ``secondbranch_id``).
    The temporal positional embedding is added only to the tokens after that prefix.

    Stage 1 runs only when ``dual_transformer_type == "dual"``: ``first_stage``
    (a ``CrossTransformer``) updates the second branch given the main branch.
    Stage 2: ``second_stage`` (a ``Decoder``) updates the main branch, passed as
    ``x_q``, using the second branch as ``x_kv``.

    Args:
        embed_dim: Embedding size used by both branches and both stages.
        windows_size: Total number of frames. The positional table is built for
            ``windows_size + 1`` positions.
        depth: Number of blocks in each stage.
        dual_transformer_type: ``"dual"`` enables stage 1; any other value skips it.
        semantic_type: Unless it is the string ``"None"``, the object branch has
            one prefix token.
        mlp_ratio, drop_rate: Forwarded to both stages.
        drop_path_rate, drop_path_type: Stochastic depth for stage 1 only
            (see the note in ``__init__``).
        learnable_pos_embed: Positional-embedding type passed to ``build_pose_embed``.
        use_hoi_token: When true, the main branch has one extra prefix token.
        mainbranch: ``"objects"`` makes objects the main branch; otherwise humans.
        image_cls_type: Unless it is the string ``"None"``, the human branch has
            one prefix token.
        layer_scale_type, layer_scale_init_value, layer_norm_eps: Layer-scale and
            LayerNorm settings.
        attn_drop, num_heads, proj_drop, qk_scale, qkv_bias: Forwarded to every
            ``CrossAttention``.
    """

    def __init__(self,
                 embed_dim=256,  # embedding dimension in the inner transformer
                 windows_size=6,  # total frames
                 depth=12,  # number of blocks
                 dual_transformer_type="dual",  # Whether to use the dual transformer or a single
                 semantic_type="file",
                 mlp_ratio=4.0,
                 drop_rate=0.0, drop_path_rate=0.0, drop_path_type="progressive",
                 learnable_pos_embed="sinusoidal",
                 use_hoi_token=True,
                 mainbranch="objects",
                 image_cls_type="mean",
                 layer_scale_type=None, layer_scale_init_value=0.1, layer_norm_eps=1e-8,
                 attn_drop=0.2, num_heads=8, proj_drop=0, qk_scale=False, qkv_bias=True):

        super().__init__()
        assert drop_path_type in ["progressive",
                                  "uniform"], f"Drop path types are: [progressive, uniform]. Got {drop_path_type}."
        norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)
        self.embedding_dim = embed_dim
        self.dual_transformer_type = dual_transformer_type
        self.semantic_type = semantic_type
        self.mainbranch = mainbranch
        self.image_cls_type = image_cls_type
        self.idx_hoi_token = 1 if use_hoi_token else 0

        # Number of prefix tokens per branch; these never receive a positional embedding.
        obj_id = 1 if self.semantic_type != "None" else 0
        hum_id = 1 if self.image_cls_type != "None" else 0
        self.secondbranch_id = hum_id if self.mainbranch == "objects" else obj_id
        self.mainbranch_id = self.idx_hoi_token + (obj_id if self.mainbranch == "objects" else hum_id)

        total_num_patches = windows_size + 1  # for the objects
        self.pos_embed = build_pose_embed(learnable_pos_embed, embed_dim, total_num_patches, drop_rate)

        self.num_heads = num_heads
        if self.dual_transformer_type == "dual":
            attn_target_first_stage = partial(CrossAttention,
                                              attn_drop=attn_drop, num_heads=num_heads, proj_drop=proj_drop,
                                              qk_scale=qk_scale, qkv_bias=qkv_bias)
            self.first_stage = CrossTransformer(attn_target=attn_target_first_stage,
                                                norm_layer=norm_layer,
                                                embed_dim=embed_dim, mlp_ratio=mlp_ratio,
                                                layer_scale_type=layer_scale_type,
                                                layer_scale_init_value=layer_scale_init_value,
                                                depth=depth, drop_rate=drop_rate, drop_path_type=drop_path_type,
                                                drop_path_rate=drop_path_rate)

        attn_target_second_stage = partial(CrossAttention,
                                           attn_drop=attn_drop, num_heads=num_heads, proj_drop=proj_drop,
                                           qk_scale=qk_scale, qkv_bias=qkv_bias)

        # NOTE(review): the second stage receives neither drop_path_rate nor
        # layer_scale_init_value nor layer_norm_eps, so Decoder falls back to its
        # own defaults (0.0, 1e-4, 1e-6). This is left unchanged to keep training
        # behaviour identical; confirm whether it is intended. `act_layer` and
        # `layer_norm` land in Decoder's **kwargs and are not used there.
        self.second_stage = Decoder(
            total_num_patches=total_num_patches,
            attn_target=attn_target_second_stage,
            decoder_depth=depth,
            decoder_embed_dim=embed_dim,
            embed_dim=embed_dim,
            learnable_pos_embed=learnable_pos_embed,
            share_pos_embed=True,
            layer_scale_type=layer_scale_type,
            act_layer="gelu",
            layer_norm="default",
            mlp_ratio=mlp_ratio,
            drop_rate=drop_rate
        )

        self.norm = norm_layer(embed_dim)
        self.pre_logits = nn.Identity()

        self.init_weights(learnable_pos_embed)

    def info_model(self):
        """Print the number of trainable parameters, in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print(f"[Transformer (Dual)] - {total_params / 10 ** 6:.2f}M")

    def init_weights(self, learnable_pos_embed):
        """Apply ``_init_weights`` to every submodule.

        ``learnable_pos_embed`` is accepted for interface compatibility but is unused.
        """
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights with a truncated normal and zero their biases;
        reset LayerNorm to weight 1 and bias 0 (skipping absent affine parameters)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has no weight/bias tensors.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter names to exclude from weight decay.

        This must be a class-level method so it is discoverable via
        ``hasattr(model, "no_weight_decay")``, which is presumably how an optimizer
        factory such as timm's consumes it. It was previously nested inside
        ``_init_weights`` and therefore unreachable.
        """
        return {"pos_embed"}

    def get_pos_embed(self, temporal_idx):
        """Look up a positional embedding for every entry of ``temporal_idx``.

        ``temporal_idx`` is indexed as a 2-D integer tensor (batch, frames), and
        ``self.pos_embed.pe`` as a 3-D table whose first row is used.

        Each row is aligned to its own maximum index. The largest index in a row
        maps to the last row of the table slice ``pe[0, :temporal_idx.max() + 1]``,
        and earlier frames map to the rows before it. Entries with a negative
        index (presumably padding) map to row 0.

        Returns:
            Tensor of shape ``temporal_idx.shape + (embed_dim,)``.
        """
        pos_embed_weight = self.pos_embed.pe[0, :temporal_idx.max() + 1, :]
        row_max = temporal_idx.max(dim=1).values[:, None]
        # Negative offsets index from the end of the slice; padding is zeroed to row 0.
        temporal_idx = (temporal_idx - row_max - 1) * (temporal_idx >= 0)
        return pos_embed_weight[temporal_idx]

    def add_pos_embed(self, mainbranch, secondbranch, pos_embed_weight):
        """Return both branches with ``pos_embed_weight`` added after their prefix tokens.

        The inputs are cloned first, so the caller's tensors are left untouched.
        Writing into them in place would mutate the arguments, and it raises for
        leaf tensors that require grad.

        Returns:
            ``(mainbranch, secondbranch)`` as new tensors.
        """
        secondbranch = secondbranch.clone()
        mainbranch = mainbranch.clone()
        secondbranch[:, self.secondbranch_id:, :] = secondbranch[:, self.secondbranch_id:, :] + pos_embed_weight
        mainbranch[:, self.mainbranch_id:, :] = mainbranch[:, self.mainbranch_id:, :] + pos_embed_weight
        return mainbranch, secondbranch

    def get_padding_mask(self, padding_mask):
        """Build a per-head pairwise mask between second-branch and main-branch tokens.

        ``padding_mask`` is indexed as (batch, frames). It is left-padded with
        zeros for each branch's prefix tokens. Its outer product is then taken,
        so entry ``[b, :, s, m]`` is True iff both position ``s`` (second branch)
        and position ``m`` (main branch) are nonzero. Prefix positions are
        therefore always False.

        NOTE(review): whether True means "attend" or "masked out" is decided by
        CrossAttention, which is not visible here.

        Returns:
            Bool tensor of shape
            ``(B, num_heads, secondbranch_id + T, mainbranch_id + T)``.
        """
        B = padding_mask.shape[0]
        padding_mask_second = torch.cat([torch.zeros((B, self.secondbranch_id), device=padding_mask.device), padding_mask], axis=1)
        padding_mask_main = torch.cat([torch.zeros((B, self.mainbranch_id), device=padding_mask.device), padding_mask], axis=1)

        padding_mask_result = torch.matmul(padding_mask_second[..., None], padding_mask_main[..., None, :]).type(torch.bool)
        padding_mask_result = repeat(padding_mask_result, "b t1 t2 -> b n t1 t2", n=self.num_heads)
        return padding_mask_result

    def forward(self,
                humans: torch.Tensor,
                objects: torch.Tensor,
                temporal_idx,
                padding_mask):
        """Run both stages.

        Args:
            humans: Human-branch tokens, indexed as (batch, prefix + frames, embed_dim).
            objects: Object-branch tokens, indexed as (batch, prefix + frames, embed_dim).
            temporal_idx: Integer frame indices, indexed as (batch, frames).
            padding_mask: Per-frame mask, indexed as (batch, frames).

        Returns:
            ``(mainbranch, secondbranch)``, i.e. ``(objects, humans)`` when
            ``mainbranch == "objects"`` and ``(humans, objects)`` otherwise.
        """
        if self.mainbranch == "objects":
            mainbranch, secondbranch = objects, humans
        else:
            mainbranch, secondbranch = humans, objects

        pos_embed = self.get_pos_embed(temporal_idx)
        mainbranch, secondbranch = self.add_pos_embed(mainbranch=mainbranch, secondbranch=secondbranch,
                                                      pos_embed_weight=pos_embed)

        padding_mask = self.get_padding_mask(padding_mask)

        if self.dual_transformer_type == "dual":
            secondbranch = self.first_stage(secondbranch, mainbranch, padding_mask=padding_mask)

            # Re-add the temporal embedding after stage 1; prefix tokens get zeros.
            # new_zeros keeps pos_embed's dtype/device, so half precision is not upcast.
            B, _, E = pos_embed.shape
            prefix_zeros = pos_embed.new_zeros((B, self.secondbranch_id, E))
            secondbranch = secondbranch + torch.cat([prefix_zeros, pos_embed], dim=1)

        # Stage 2 queries with the main branch, so swap the two token axes of the mask.
        padding_mask = rearrange(padding_mask, "b n s m -> b n m s")
        mainbranch = self.second_stage(mainbranch, secondbranch, input_pos_embed=pos_embed, padding_mask=padding_mask)

        return mainbranch, secondbranch


class Decoder(nn.Module):
    """Stack of ``CrossAttentionBlock``s that updates query tokens against key/value tokens.

    ``forward`` projects ``x_q`` with ``decoder_embed`` (``embed_dim`` ->
    ``decoder_embed_dim``). It then passes the result through ``decoder_depth``
    blocks, calling each as ``block(query, x_kv, padding_mask=padding_mask)``.

    Args:
        total_num_patches: Size of the positional table. Only used when
            ``share_pos_embed`` is not ``True``.
        attn_target: Attention factory handed to every ``CrossAttentionBlock``.
        embed_dim: Input width of ``x_q``.
        decoder_embed_dim: Width used inside the blocks.
        decoder_depth: Number of blocks.
        drop_path_rate: Drop-path rate of the last block. Per-block rates are
            spaced linearly from 0 to this value.
        mlp_ratio, drop_rate: Forwarded to every block. ``drop_rate`` also goes
            to the positional embedding when one is built here.
        layer_norm_eps: Epsilon for the LayerNorms built here.
        share_pos_embed: When exactly ``True`` (identity check), no positional
            embedding is built and ``forward`` uses ``input_pos_embed`` instead.
        learnable_pos_embed: Positional-embedding type for ``build_pose_embed``.
        layer_scale_type: From CaiT; ``None``, ``"per_channel"`` or ``"scalar"``.
        layer_scale_init_value: From CaiT; initial layer-scale value.
        **kwargs: Accepted and ignored.

    NOTE(review): ``self.norm`` is constructed but ``forward`` never applies it.
    """

    def __init__(
            self,
            total_num_patches,
            attn_target,
            embed_dim,
            decoder_embed_dim=512,
            decoder_depth=8,
            drop_path_rate=0.0,
            mlp_ratio=4,
            drop_rate=0.0,
            layer_norm_eps=1e-6,
            share_pos_embed=True,
            learnable_pos_embed="sinusoidal",
            layer_scale_type=None,  # from cait; possible values are None, "per_channel", "scalar"
            layer_scale_init_value=1e-4,  # from cait; float
            **kwargs,
    ):
        super().__init__()
        # build_pos_embedding reads self.drop_rate, so it has to be stored first.
        self.drop_rate = drop_rate
        self.share_pos_embed = share_pos_embed
        self.learnable_pos_embed = learnable_pos_embed
        self.build_pos_embedding(
            share_pos_embed=share_pos_embed,
            learnable_pos_embed=learnable_pos_embed,
            total_num_patches=total_num_patches,
            embed_dim=embed_dim,
        )

        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
        self.decoder_embed_dim = decoder_embed_dim
        make_norm = partial(nn.LayerNorm, eps=layer_norm_eps)
        self.norm = make_norm(decoder_embed_dim)

        # Stochastic-depth decay rule: drop-path rate grows linearly with depth.
        drop_path_schedule = torch.linspace(0, drop_path_rate, decoder_depth).tolist()
        blocks = []
        for block_drop_path in drop_path_schedule:
            blocks.append(
                CrossAttentionBlock(
                    dim=decoder_embed_dim,
                    attn_target=attn_target,
                    mlp_ratio=mlp_ratio,
                    drop=drop_rate,
                    drop_path=block_drop_path,
                    norm_layer=make_norm,
                    layer_scale_type=layer_scale_type,
                    layer_scale_init_value=layer_scale_init_value,
                )
            )
        self.decoder_blocks = nn.ModuleList(blocks)

    def build_pos_embedding(
            self,
            share_pos_embed,
            learnable_pos_embed,
            total_num_patches,
            embed_dim,
    ):
        """Create ``self.pos_embed``, or set it to ``None`` when it is shared.

        With sharing, the embedding is expected as ``input_pos_embed`` in ``forward``.
        Sharing nn.Parameter objects across modules is not recommended practice in PyTorch.
        """
        self.pos_embed = (
            None
            if share_pos_embed is True
            else build_pose_embed(learnable_pos_embed, embed_dim, total_num_patches, self.drop_rate)
        )

    def forward(self, x_q, x_kv, padding_mask=None, input_pos_embed=None, add_pos_embed=False):
        """Project ``x_q`` and refine it through every decoder block against ``x_kv``.

        ``padding_mask`` is forwarded unchanged to each block. When
        ``add_pos_embed`` is true, the positional embedding (``input_pos_embed``
        if shared, else ``self.pos_embed``) is *called* on ``x_q`` first.

        NOTE(review): in shared mode a plain tensor passed as ``input_pos_embed``
        is not callable, so ``add_pos_embed=True`` would fail there. The default
        ``False`` avoids that path.
        """
        if add_pos_embed:
            embed_fn = input_pos_embed if self.share_pos_embed else self.pos_embed
            x_q = embed_fn(x_q)

        hidden = self.decoder_embed(x_q)
        for block in self.decoder_blocks:
            hidden = block(hidden, x_kv, padding_mask=padding_mask)
        return hidden


if __name__ == '__main__':
    # Smoke test with the default configuration (mainbranch="objects", so the
    # objects are the main branch and the humans are the second branch).
    B, T, E = 16, 4, 256
    num_frames = T + 1
    dual_transformer = DualTransformer(embed_dim=E, windows_size=num_frames)
    dual_transformer.info_model()

    # Each branch is (batch, prefix tokens + one token per frame, embed_dim).
    humans = torch.randn(B, dual_transformer.secondbranch_id + num_frames, E)
    objects = torch.randn(B, dual_transformer.mainbranch_id + num_frames, E)
    temporal_idx = torch.arange(num_frames).repeat(B, 1)
    padding_mask = torch.ones(B, num_frames)

    print("Input objects", objects.shape)
    print("Input humans", humans.shape)
    # forward(humans, objects, ...) returns (main, second) == (objects, humans) here.
    objects, humans = dual_transformer(humans, objects, temporal_idx, padding_mask)
    print("Output objects", objects.shape)
    print("Output humans", humans.shape)
