import torch
import torch.nn as nn

from modules.hoi4abot.hoibot.modules.transformer.modules_attn.Attention import Transformer, Attention, Block
from modules.hoi4abot.hoibot.modules.transformer.modules_attn.build_positional_embedding import build_pose_embed

from torch.nn.modules.utils import _ntuple
from timm.models.layers import trunc_normal_
from functools import partial

from einops import rearrange, reduce, repeat

# Convenience alias built from torch's `_ntuple` factory (n=2).
# NOTE(review): not referenced anywhere in the visible part of this file —
# confirm whether other modules import it before removing.
to_2tuple = _ntuple(2)


class TransformerHead(nn.Module):
    """Self-attention head applied over a window of per-frame tokens.

    The head optionally concatenates two feature branches along the channel
    axis, optionally prepends a learnable HOI token, adds positional
    embeddings to every token after the first ``self.id_`` ones, and runs the
    result through a ``Transformer`` stack with a padding mask.

    ``self.id_`` is the number of leading tokens that receive no positional
    embedding and are never masked: ``add_id`` plus one when
    ``hoi_token_type == "learnable"``.
    """

    def __init__(self,
                 embed_dim=256,  # embedding dimension in the inner transformer
                 windows_size=6,  # total frames
                 depth=2,  # number of blocks
                 mlp_ratio=4.0,
                 add_id=1,
                 add_hoi_token=False,
                 concat=False, # False indicates just using the main branch, else both
                 hoi_token_type="None",
                 drop_rate=0.0, drop_path_rate=0.0, drop_path_type="progressive",
                 learnable_pos_embed="sinusoidal",
                 layer_scale_type=None, layer_scale_init_value=0.1, layer_norm_eps=1e-8,
                 attn_drop=0.2, num_heads=8, proj_drop=0, qk_scale=False, qkv_bias=True):
        """Build the positional embedding, the transformer stack and the optional HOI token.

        :param embed_dim: channel width of the tokens fed to the transformer.
        :param windows_size: number of frames; the positional table is built
            for ``windows_size + 1`` positions.
        :param depth: number of transformer blocks.
        :param mlp_ratio: forwarded to ``Transformer``.
        :param add_id: number of leading tokens skipped by the positional
            embedding and by the padding mask.
        :param add_hoi_token: accepted for interface compatibility; not read
            by this class.
        :param concat: if True, ``forward`` concatenates both branches along
            the last axis; otherwise only the main branch is used.
        :param hoi_token_type: ``"learnable"`` prepends a trainable token of
            shape ``(1, 1, embed_dim)``; any other value adds no token.
        :param learnable_pos_embed: positional-embedding kind forwarded to
            ``build_pose_embed``.
        :param drop_path_type: must be ``"progressive"`` or ``"uniform"``.

        The remaining arguments are forwarded unchanged to ``Attention`` /
        ``Transformer`` / ``nn.LayerNorm``.
        """
        super().__init__()
        assert drop_path_type in ["progressive",
                                  "uniform"], f"Drop path types are: [progressive, uniform]. Got {drop_path_type}."
        norm_layer = partial(nn.LayerNorm, eps=layer_norm_eps)
        total_num_patches = windows_size + 1
        self.pos_embed = build_pose_embed(learnable_pos_embed, embed_dim, total_num_patches, drop_rate)
        self.num_heads = num_heads
        self.concat = concat

        attn_target_head = partial(Attention, attn_drop=attn_drop,
                                   num_heads=num_heads, proj_drop=proj_drop,
                                   qk_scale=qk_scale, qkv_bias=qkv_bias)

        self.sa_head = Transformer(attn_target=attn_target_head,
                                   norm_layer=norm_layer,
                                   embed_dim=embed_dim, mlp_ratio=mlp_ratio,
                                   layer_scale_type=layer_scale_type, layer_scale_init_value=layer_scale_init_value,
                                   depth=depth, drop_rate=drop_rate, drop_path_type=drop_path_type,
                                   drop_path_rate=drop_path_rate)

        # NOTE(review): `norm` and `pre_logits` are not used by `forward`; they
        # are kept so existing checkpoints (state_dict keys) still load.
        self.norm = norm_layer(embed_dim)
        self.pre_logits = nn.Identity()

        self.hoi_token_type = hoi_token_type
        self.id_ = add_id
        if self.hoi_token_type == "learnable":
            self.hoi_token = nn.Parameter(torch.randn(1, 1, embed_dim))
            # The extra token sits in front of the sequence, so one more
            # leading position must be skipped.
            self.id_ += 1

        self.init_weights(learnable_pos_embed)

    def info_model(self):
        """Print the number of trainable parameters, in millions."""
        total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        print("[{} ({})] - {:.2f}M".format("Transformer", "Head", total_params / 10 ** 6))

    def init_weights(self, learnable_pos_embed):
        """Initialise the positional embedding (if learnable) and all submodules."""
        if learnable_pos_embed == "learnable":
            trunc_normal_(self.pos_embed, std=0.02)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Per-module initialiser used with ``nn.Module.apply``.

        Linear layers get truncated-normal weights (std 0.02) and zero bias;
        LayerNorm layers get unit weight and zero bias.
        """
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter-name set that should be excluded from weight decay."""
        return {"pos_embed"}

    def get_pos_embed(self, temporal_idx):
        """Gather one positional embedding per (sample, frame).

        :param temporal_idx: integer tensor indexed as ``[B, T]``; negative
            entries are treated as padding.
        :return: tensor ``[B, T, E]`` taken from ``self.pos_embed.pe``.

        The table is cut to ``temporal_idx.max() + 1`` rows and addressed
        with negative indices, so the largest index of every sample maps to
        the last row of the cut table. Padded entries map to row 0.
        """
        table = self.pos_embed.pe[0, :temporal_idx.max() + 1, :]
        row_max = temporal_idx.max(dim=1)[0][:, None]
        # Offsets are in [-(row_max + 1), -1] for valid frames; multiplying by
        # the validity mask sends padded frames to index 0.
        lookup = (temporal_idx - row_max - 1) * (temporal_idx >= 0)
        return table[lookup]

    def add_pos_embed(self, x, pos_embed_weight):
        """Return ``x`` with ``pos_embed_weight`` added to tokens ``self.id_:``.

        The input tensor is left untouched: ``forward`` may pass the caller's
        own ``mainbranch`` tensor here, and writing into it would leak the
        positional embedding back to the caller.
        """
        out = x.clone()
        out[:, self.id_:, :] = x[:, self.id_:, :] + pos_embed_weight
        return out

    def get_padding_mask(self, padding_mask):
        """Extend the padding mask over the leading tokens and all heads.

        :param padding_mask: tensor indexed as ``[B, T]``; non-zero marks a
            padded frame.
        :return: boolean tensor ``[B, num_heads, 1, T + self.id_]`` whose
            first ``self.id_`` columns are False.
        """
        B = padding_mask.shape[0]
        # Leading tokens are never masked.
        prefix = torch.zeros((B, self.id_), dtype=padding_mask.dtype, device=padding_mask.device)
        full_mask = torch.cat([prefix, padding_mask], dim=1)[..., None, :].type(torch.bool)
        return repeat(full_mask, "b t1 t2 -> b n t1 t2", n=self.num_heads)

    def forward(self,
                mainbranch,
                secondbranch,
                temporal_idx,
                padding_mask):
        """
        :param mainbranch:     [B, T+1, E] (leading token + t times) — shape as
                               stated by the original author
        :param secondbranch:   [B, T+1, E] (leading token + t times); only read
                               when ``self.concat`` is True
        :param temporal_idx:   [B, T]
        :param padding_mask:   [B, T]

        :return: output of ``self.sa_head`` for the assembled token sequence
        """
        pos_embed = self.get_pos_embed(temporal_idx)
        if self.concat:
            x = torch.cat([mainbranch, secondbranch], dim=-1)
        else:
            x = mainbranch

        if self.hoi_token_type == "learnable":
            # Replicate the single learnable token for every sample in the batch.
            hoi_tokens = repeat(self.hoi_token, "b t e -> (b n) t e", n=x.shape[0])
            x = torch.cat([hoi_tokens, x], dim=1)

        x = self.add_pos_embed(x, pos_embed_weight=pos_embed)

        padding_mask = self.get_padding_mask(padding_mask)
        # Swap the last two axes: [B, heads, 1, L] -> [B, heads, L, 1].
        padding_mask = rearrange(padding_mask, "b n h o -> b n o h")

        return self.sa_head(x, padding_mask=padding_mask)
