import torch
import torch.nn as nn

from modules.hoi4abot.hoibot.modules.transformer.modules_attn.STTRAN_transformer import CrossAttentionEncoderLayer, SelfAttentionEncoderLayer, SATranformerEncoder, CrossEncoder
from modules.hoi4abot.hoibot.modules.transformer.modules_attn.build_positional_embedding import build_pose_embed

from torch.nn.modules.utils import _ntuple
from timm.models.layers import trunc_normal_
from functools import partial

from einops import rearrange, reduce, repeat

to_2tuple = _ntuple(2)


class TempTransformer(nn.Module):
    """Temporal self-attention transformer over a token sequence.

    The input sequence starts with ``mainbranch_id`` leading tokens that do not
    receive a temporal positional embedding, followed by one token per
    temporal position. ``mainbranch_id`` is computed as::

        (1 if use_hoi_token else 0)
        + (1 if semantic_type != "None" else 0)    when isquery == "objects"
        + (1 if image_cls_type != "None" else 0)   otherwise

    The per-position tokens get a positional embedding looked up from
    ``self.pos_embed.pe``. The full sequence is then run through ``depth``
    stacked self-attention encoder layers.

    NOTE(review): ``dual_transformer_type``, ``drop_path_rate``,
    ``layer_scale_type``, ``layer_scale_init_value``, ``layer_norm_eps``,
    ``attn_drop``, ``proj_drop``, ``qk_scale``, ``qkv_bias``, ``kv_sa``,
    ``kv_ffn`` and ``is_entity`` are accepted but not used by this class, and
    ``drop_path_type`` is only validated. They are kept for signature/config
    compatibility.
    """

    def __init__(self,
                 embed_dim=256,  # embedding dimension in the inner transformer
                 windows_size=6,  # total frames
                 depth=12,  # number of blocks
                 dual_transformer_type="dual",  # Whether to use the dual transformer or a single
                 semantic_type="file",
                 mlp_ratio=4.0,
                 drop_rate=0.0, drop_path_rate=0.0, drop_path_type="progressive",
                 learnable_pos_embed="sinusoidal",
                 use_hoi_token=True,
                 image_cls_type="mean",
                 isquery="objects",
                 layer_scale_type=None, layer_scale_init_value=0.1, layer_norm_eps=1e-8,
                 attn_drop=0.2,
                 num_heads=8, proj_drop=0, qk_scale=False, qkv_bias=True,
                 kv_sa=True,
                 kv_ffn=False,
                 is_entity=False,
                 ):

        super().__init__()
        assert drop_path_type in ["progressive", "uniform"], f"Drop path types are: [progressive, uniform]. Got {drop_path_type}."
        self.semantic_type = semantic_type
        self.mainbranch = isquery
        self.image_cls_type = image_cls_type
        self.idx_hoi_token = 1 if use_hoi_token else 0
        self.embedding_dim = embed_dim

        # Count of leading tokens placed before the per-position tokens; these
        # are skipped by add_pos_embed and never padded in get_padding_mask.
        obj_id = 1 if self.semantic_type != "None" else 0
        hum_id = 1 if self.image_cls_type != "None" else 0
        self.mainbranch_id = self.idx_hoi_token + (obj_id if self.mainbranch == "objects" else hum_id)

        # Positional table has one slot more than the window length.
        total_num_patches = windows_size + 1
        self.pos_embed = build_pose_embed(learnable_pos_embed, embed_dim, total_num_patches, drop_rate)

        self.num_heads = num_heads
        sa_encoder = SelfAttentionEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(mlp_ratio * embed_dim),
            dropout=drop_rate,
        )
        self.sa_encoder = SATranformerEncoder(sa_encoder, depth)

        self.apply(self._init_weights)

    def info_model(self):
        """Print the number of trainable parameters, in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{} ({})] - {:.2f}M".format("Transformer", "Temporal", total_params / 10 ** 6))

    def _init_weights(self, m):
        """Initialise one submodule; used with ``self.apply``.

        Linear weights get a truncated normal (std 0.02) and zero bias.
        LayerNorm gets weight 1 and bias 0. Affine parameters that are
        ``None`` (e.g. ``LayerNorm(elementwise_affine=False)``) are skipped.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return parameter-name keys that optimizer builders (e.g. timm's) should exempt from weight decay."""
        return {"pos_embed"}

    def get_pos_embed(self, temporal_idx):
        """Look up positional embeddings for a 2-D tensor of temporal indices.

        Entries ``< 0`` are treated as padding and mapped to table row 0.
        Valid indices are right-aligned: in each row, the largest index maps to
        the last row of the table slice ``pe[0, :temporal_idx.max() + 1]``.
        The lookup uses negative indexing into that slice.

        Args:
            temporal_idx: integer tensor of shape (batch, positions).

        Returns:
            Tensor of shape (batch, positions, pe_dim), gathered from
            ``self.pos_embed.pe`` (assumed (1, length, embed) — TODO confirm).
        """
        pos_table = self.pos_embed.pe[0, :temporal_idx.max() + 1, :]
        is_valid = temporal_idx >= 0
        row_max = temporal_idx.max(dim=1, keepdim=True)[0]
        # (idx - row_max - 1) is in [-(row_max + 1), -1] for valid entries;
        # padded entries are zeroed and therefore read row 0.
        lookup = (temporal_idx - row_max - 1) * is_valid
        return pos_table[lookup]

    def add_pos_embed(self, mainbranch, pos_embed_weight):
        """Add ``pos_embed_weight`` to the tokens after the first ``mainbranch_id``.

        Returns a new tensor with ``mainbranch``'s dtype. The input is not
        modified in place, which keeps autograd happy when ``mainbranch`` is a
        leaf that requires grad.
        """
        prefix = mainbranch[:, :self.mainbranch_id, :]
        temporal = (mainbranch[:, self.mainbranch_id:, :] + pos_embed_weight).to(mainbranch.dtype)
        return torch.cat([prefix, temporal], dim=1)

    def get_padding_mask(self, padding_mask):
        """Build a per-head boolean pairwise mask from a per-position mask.

        ``mainbranch_id`` zeros are prepended to ``padding_mask`` (so leading
        tokens are never marked). The outer product is taken, so entry
        ``[b, h, i, j]`` is True iff both positions i and j are non-zero in
        the extended mask.

        Args:
            padding_mask: tensor of shape (batch, positions); non-zero entries
                presumably mark padded positions — confirm against callers.

        Returns:
            Bool tensor of shape (batch, num_heads, T, T), where
            T = mainbranch_id + positions.
        """
        batch = padding_mask.shape[0]
        leading = torch.zeros((batch, self.mainbranch_id), device=padding_mask.device)
        extended = torch.cat([leading, padding_mask], dim=1)
        pairwise = torch.matmul(extended[..., None], extended[..., None, :]).type(torch.bool)
        return pairwise.unsqueeze(1).repeat(1, self.num_heads, 1, 1)

    def forward(self,
                mainbranch: torch.Tensor,
                temporal_idx=None,
                padding_mask=None):
        """Add temporal positional embeddings and run the self-attention stack.

        Args:
            mainbranch: tensor of shape (batch, tokens, embed), where the first
                ``mainbranch_id`` tokens are leading (non-temporal) tokens.
            temporal_idx: (batch, tokens - mainbranch_id) integer indices.
                Defaults to ``0..tokens - mainbranch_id - 1`` for every sample.
            padding_mask: (batch, tokens - mainbranch_id) mask. Defaults to
                all zeros.

        Returns:
            Tensor of shape (batch, tokens, embed).
        """
        num_positions = mainbranch.shape[1] - self.mainbranch_id
        if temporal_idx is None:
            temporal_idx = torch.arange(num_positions, device=mainbranch.device).expand(mainbranch.shape[0], -1)
        if padding_mask is None:
            padding_mask = torch.zeros(temporal_idx.shape, device=mainbranch.device)

        pos_embed = self.get_pos_embed(temporal_idx)
        mainbranch = self.add_pos_embed(mainbranch, pos_embed_weight=pos_embed)

        pairwise_mask = self.get_padding_mask(padding_mask)

        # The encoder is fed sequence-first: (tokens, batch, embed).
        tokens_first = mainbranch.transpose(0, 1)
        # NOTE(review): pairwise_mask[:, 0, 0] is the row for query token 0 of
        # head 0, i.e. extended[:, 0] * extended[:, j]. When mainbranch_id >= 1,
        # extended[:, 0] is the prepended 0, so this key-padding mask is all
        # False and padding is effectively not applied. The behaviour is kept
        # to stay compatible with trained checkpoints. The per-position mask
        # would be the extended mask itself — confirm with the encoder's
        # src_key_padding_mask convention before changing.
        tokens_first = self.sa_encoder(tokens_first, src_key_padding_mask=pairwise_mask[:, 0, 0])

        return tokens_first.transpose(0, 1)


if __name__ == '__main__':
    # Demo: build a temporal transformer for a 4-frame window with 256-d
    # embeddings, print its structure and report its trainable parameter count.
    T, E = 4, 256
    sc_transformer = TempTransformer(embed_dim=E, windows_size=T + 1)
    print(sc_transformer)
    sc_transformer.info_model()
