import torch
import torch.nn as nn

from modules.hoi4abot.hoibot.modules.transformer.modules_attn.STTRAN_transformer import CrossAttentionEncoderLayer, SelfAttentionEncoderLayer, SATranformerEncoder, CrossEncoder
from modules.hoi4abot.hoibot.modules.transformer.modules_attn.build_positional_embedding import build_pose_embed

from torch.nn.modules.utils import _ntuple
from timm.models.layers import trunc_normal_
from functools import partial

from einops import rearrange, reduce, repeat

to_2tuple = _ntuple(2)


class SC_transformer(nn.Module):
    """Self-attention / cross-attention / self-attention transformer block.

    The query sequence is passed through a self-attention encoder stack, then
    a single cross-attention layer against the key/value sequence, then a
    second self-attention encoder stack (see ``forward``). Which of
    ``humans`` / ``objects`` acts as the query is selected by ``isquery``.

    Several constructor arguments (``dual_transformer_type``,
    ``drop_path_rate``, ``layer_scale_type``, ``layer_scale_init_value``,
    ``layer_norm_eps``, ``attn_drop``, ``proj_drop``, ``qk_scale``,
    ``qkv_bias``, ``no_reduction``) are accepted for interface compatibility
    but are not read by this class; ``drop_path_type`` is only validated.
    """

    def __init__(self,
                 embed_dim=256,  # embedding dimension in the inner transformer
                 windows_size=6,  # total frames
                 depth=12,  # number of blocks
                 dual_transformer_type="dual",  # Whether to use the dual transformer or a single
                 semantic_type="file",
                 mlp_ratio=4.0,
                 drop_rate=0.0, drop_path_rate=0.0, drop_path_type="progressive",
                 learnable_pos_embed="sinusoidal",
                 use_cls_token=True,
                 image_cls_type="mean",
                 isquery="objects",
                 layer_scale_type=None, layer_scale_init_value=0.1, layer_norm_eps=1e-8,
                 attn_drop=0.2,
                 num_heads=8, proj_drop=0, qk_scale=False, qkv_bias=True,
                 kv_sa=True,
                 kv_ffn=False,
                 is_entity=False,
                 no_reduction=False,
                 ):

        super().__init__()
        assert drop_path_type in ["progressive", "uniform"], f"Drop path types are: [progressive, uniform]. Got {drop_path_type}."
        self.embedding_dim = embed_dim

        self.semantic_type = semantic_type
        self.isquery = isquery
        self.image_cls_type = image_cls_type
        # Index of the first kv token that receives a positional embedding:
        # one leading token is skipped when an image-level token type is
        # configured and the model is not in entity mode.
        self.id_ = 1 if (image_cls_type != "None" and not is_entity) else 0

        # 1 when *no* cls token is used; subtracted from the leading-token
        # count in `_leading_token_count`.
        self.idx_cls_token = 1 if not use_cls_token else 0

        total_num_patches = windows_size + 1  # for the objects
        self.pos_embed = build_pose_embed(learnable_pos_embed, embed_dim, total_num_patches, drop_rate)

        self.num_heads = num_heads
        sa_encoder = SelfAttentionEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(mlp_ratio * embed_dim),
            dropout=drop_rate,
        )
        self.sa_encoder = SATranformerEncoder(sa_encoder, depth)
        ca_encoder = CrossAttentionEncoderLayer(
            d_model=embed_dim,
            d_kv=embed_dim,
            nhead=num_heads,
            dim_feedforward=int(mlp_ratio * embed_dim),
            dropout=drop_rate,
            q_sa=False,
            q_ffn=False,
            kv_sa=kv_sa,
            kv_ffn=kv_ffn,
        )
        self.ca_encoder = CrossEncoder(ca_encoder, 1)

        # NOTE(review): the same layer instance is used as the template for
        # both stacks; whether the two stacks end up sharing weights depends
        # on whether SATranformerEncoder copies the layer -- TODO confirm.
        self.sa_encoder_2 = SATranformerEncoder(sa_encoder, depth)

        self.apply(self._init_weights)

    def info_model(self):
        """Print the number of trainable parameters, in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{} ({})] - {:.2f}M".format("Transformer", "SC", total_params / 10 ** 6))

    def _init_weights(self, m):
        """Initialise one sub-module; intended for use with ``nn.Module.apply``.

        ``nn.Linear`` weights get a truncated normal (std 0.02) and a zero
        bias; ``nn.LayerNorm`` gets unit weight and zero bias. Parameters that
        are absent (``None``) are skipped.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter-name prefixes to exclude from weight decay."""
        return {"pos_embed"}

    def _leading_token_count(self):
        """Number of leading query tokens that precede the per-frame tokens.

        Two when a semantic type is configured, otherwise one, minus
        ``idx_cls_token`` (which is 1 when no cls token is used).
        """
        base = 2 if self.semantic_type != "None" else 1
        return base - self.idx_cls_token

    def get_pos_embed(self, temporal_idx):
        """Gather positional embeddings for every entry of ``temporal_idx``.

        The embedding table ``self.pos_embed.pe[0]`` is cut to its first
        ``temporal_idx.max() + 1`` rows. Each index is then shifted by its own
        row maximum so that it becomes a negative index counted from the end
        of that cut table (the row maximum maps to ``-1``). Entries with
        ``temporal_idx < 0`` are zeroed and therefore read row 0.

        :param temporal_idx: integer tensor; the code reduces over axis 1, so
            it must be at least 2-D (presumably [B, T] -- TODO confirm).
        :return: tensor of embeddings with shape ``temporal_idx.shape + (E,)``.
        """
        pos_embed_weight = self.pos_embed.pe[0, :temporal_idx.max() + 1, :]
        temporal_idx = (temporal_idx - temporal_idx.max(axis=1)[0][:, None] - 1) * (temporal_idx >= 0)
        pos_embed_weight = pos_embed_weight[temporal_idx]
        return pos_embed_weight

    def add_pos_embed(self, q, kv, pos_embed_weight):
        """Add positional embeddings to the per-frame tokens of ``q`` and ``kv``.

        NOTE: both tensors are modified **in place** (slice assignment), so
        the caller's tensors are changed as well.

        :param q: query tokens; tokens from index ``id_`` onward are updated.
        :param kv: key/value tokens; tokens from ``self.id_`` onward are updated.
        :param pos_embed_weight: embeddings broadcastable to both slices.
        :return: ``(q, kv, id_)`` where ``id_`` is the first updated q index.
        """
        id_ = self._leading_token_count()
        kv[:, self.id_:, :] = kv[:, self.id_:, :] + pos_embed_weight
        q[:, id_:, :] = q[:, id_:, :] + pos_embed_weight
        return q, kv, id_

    def get_padding_mask(self, padding_mask):
        """Build the boolean padding mask used by the encoders.

        Zero columns are prepended for the leading tokens (two on the
        "humans" side when ``image_cls_type != "None"``, ``id_`` on the
        "objects" side), the outer product of both masks is taken, and its
        first column is returned.

        NOTE(review): because only column 0 of the outer product is kept, the
        result equals the "humans" mask scaled by the first "objects" entry;
        it is all ``False`` whenever ``id_ >= 1``. Kept as-is -- confirm this
        is intended.

        :param padding_mask: per-frame mask, presumably [B, T] -- TODO confirm.
        :return: boolean tensor with the "humans"-side length on axis 1.
        """
        B = padding_mask.shape[0]
        id_ = self._leading_token_count()
        if self.image_cls_type != "None":
            padding_mask_humans = torch.cat([torch.zeros((B, 2), device=padding_mask.device), padding_mask], axis=1)
        else:
            padding_mask_humans = padding_mask.type(torch.float)

        padding_mask_objects = torch.cat([torch.zeros((B, id_), device=padding_mask.device), padding_mask], axis=1)
        padding_mask_result = torch.matmul(padding_mask_humans[..., None], padding_mask_objects[..., None, :]).type(torch.bool)
        return padding_mask_result[..., 0]

    def forward(self,
                humans: torch.Tensor,
                objects: torch.Tensor,
                temporal_idx=None,
                padding_mask=None):
        """Run self-attention, cross-attention and self-attention on the query.

        ``temporal_idx`` and ``padding_mask`` default to ``None`` in the
        signature but are both required: they are dereferenced
        unconditionally.

        NOTE: ``humans`` and ``objects`` are modified in place when the
        positional embeddings are added (see ``add_pos_embed``).

        :param humans:   [B, T+1, E]
        :param objects:  [B, T+2, E]
        :param temporal_idx:   [B, T]
        :param padding_mask:   [B, T]

        :return: ``(humans_out, objects_out)``, both batch-first, regardless
            of which one was used as the query.
        """
        if self.isquery == "objects":
            q, kv = objects, humans
        else:
            q, kv = humans, objects

        pos_embed = self.get_pos_embed(temporal_idx)
        q, kv, id_ = self.add_pos_embed(q=q, kv=kv, pos_embed_weight=pos_embed)

        padding_mask = self.get_padding_mask(padding_mask)

        # The encoders are called with sequence-first tensors.
        q = rearrange(q, "b t e -> t b e")
        kv = rearrange(kv, "b t e -> t b e")
        pos_embed = rearrange(pos_embed, "b t e -> t b e")

        q = self.sa_encoder(q, src_key_padding_mask=padding_mask)

        # Re-inject the positional embeddings before cross-attention.
        q[id_:] = q[id_:] + pos_embed

        _, lh = padding_mask.shape
        lo = objects.shape[1]
        # Always bind the kv mask: previously it was only assigned when the
        # lengths differed, which raised UnboundLocalError when lo == lh.
        padding_mask_kv = padding_mask if lo == lh else padding_mask[:, -lo:]
        q, kv, _ = self.ca_encoder(q, kv, attn_mask=None, q_padding_mask=padding_mask, kv_padding_mask=padding_mask_kv)

        q = self.sa_encoder_2(q, src_key_padding_mask=padding_mask)
        q = rearrange(q, "t b e -> b t e")
        kv = rearrange(kv, "t b e -> b t e")
        if self.isquery == "objects":
            return kv, q
        else:
            return q, kv


if __name__ == '__main__':
    # Smoke check: build the model with default settings and print its
    # structure. The tensors below are created but never fed to the model.
    batch, frames, n_objects, n_humans, dim = 16, 4, 8, 1, 256
    window = frames + 1

    object_feats = torch.randn(batch, window, n_objects, dim)
    human_feats = torch.randn(batch, window, n_humans, dim)

    dummy_shape = (16, 5, 8, 256)
    padding_mask = torch.zeros(*dummy_shape)
    attn_mask = torch.zeros(*dummy_shape)

    model = SC_transformer(embed_dim=dim, windows_size=window)
    print(model)
