import torch
import torch.nn as nn
from timm.models.vision_transformer import Block


class MbtEncoder(nn.Module):
    """One fusion layer: bottleneck cross-attention, then per-modality timm Blocks.

    A small set of learnable latent tokens first attends over the concatenated
    audio and video tokens. Each modality then attends back to those latents
    and adds the result through its own learnable scalar gate. Finally each
    stream is passed through its own timm ``Block``.

    Args:
        num_latents: number of learnable bottleneck (latent) tokens.
        dim: token width C, shared by both modalities and the latents.
        heads: attention heads inside the two timm Blocks (the fusion attention
            implemented in this class is single-head).
        mlp_ratio: MLP hidden-size ratio forwarded to the timm Blocks.
    """

    def __init__(self, num_latents=4, dim=512, heads=8, mlp_ratio=4):
        super().__init__()
        assert dim % heads == 0, "dim must be divided exactly by heads"

        # Learnable bottleneck tokens of shape (1, num_latents, dim), drawn from
        # N(0, 0.02^2) and broadcast over the batch in `fusion`.
        self.num_latents = num_latents
        latent_init = torch.empty(1, num_latents, dim)
        self.latents = nn.Parameter(latent_init.normal_(std=0.02))

        # Zero-initialised gates: with both at 0 the fusion step leaves the
        # tokens unchanged, and the model learns how much fused signal to mix in.
        self.scale_a = nn.Parameter(torch.zeros(1))
        self.scale_v = nn.Parameter(torch.zeros(1))

        # Independent transformer blocks per modality.
        self.audio_encoder = Block(dim, heads, mlp_ratio, qkv_bias=True)
        self.video_encoder = Block(dim, heads, mlp_ratio, qkv_bias=True)

    def attention(self, q, k, v):
        """Single-head scaled dot-product attention without projections.

        ``q`` must be 3-D (B, N, C). ``k`` must share the last dimension C with
        ``q``. The result is reshaped to (B, N, C), so ``v`` must also have
        width C.
        """
        batch, n_query, width = q.shape
        logits = torch.matmul(q, k.transpose(-2, -1)).mul(width ** -0.5)
        weights = logits.softmax(dim=-1)
        return torch.matmul(weights, v).reshape(batch, n_query, width)

    def fusion(self, audio_tokens, visual_tokens):
        """Exchange information between modalities through the latent bottleneck.

        Returns the gated-residual updated ``(audio_tokens, visual_tokens)``.
        """
        batch = audio_tokens.shape[0]
        joint = torch.cat((audio_tokens, visual_tokens), dim=1)

        # Latents gather from both modalities (AV -> latents).
        queries = self.latents.expand(batch, -1, -1)
        bottleneck = self.attention(q=queries, k=joint, v=joint)

        # Each modality reads back from the latents (latents -> AV); both reads
        # use the un-updated input tokens as queries.
        audio_delta = self.attention(q=audio_tokens, k=bottleneck, v=bottleneck)
        visual_delta = self.attention(q=visual_tokens, k=bottleneck, v=bottleneck)

        fused_audio = audio_tokens + self.scale_a * audio_delta
        fused_visual = visual_tokens + self.scale_v * visual_delta
        return fused_audio, fused_visual

    def forward(self, x, y):
        """Fuse, then encode each stream.

        Args:
            x: audio tokens, expected as [B, L, C].
            y: video tokens, expected as [B, L, C]. The token count may differ
                from ``x`` because the two are only concatenated along dim 1,
                but B and C must match.

        Returns:
            Tuple ``(audio, video)`` after fusion and the per-modality Blocks.
        """
        audio, video = self.fusion(x, y)
        return self.audio_encoder(audio), self.video_encoder(video)


class MBT(nn.Module):
    """Audio-visual transformer made of a stack of ``MbtEncoder`` fusion layers.

    Each modality gets its own learnable CLS token and learnable positional
    embedding, followed by dropout, ``depth`` MbtEncoder layers, and a final
    per-modality LayerNorm. The normalised CLS token of each stream is the
    returned feature.

    Args:
        depth: number of stacked MbtEncoder layers.
        t: per-modality token count the positional embeddings are sized for
            (they hold ``t + 1`` positions: CLS plus ``t`` tokens).
        num_latents: bottleneck tokens per MbtEncoder layer.
        dim: token width C.
        heads: attention heads passed to every MbtEncoder (must divide ``dim``).
        drop_p: dropout probability applied after adding positional embeddings.
    """

    def __init__(self, depth=4, t=20, num_latents=4, dim=512, heads=8, drop_p=0.1):
        super().__init__()

        # Learnable CLS tokens (zero-initialised), one per modality.
        self.audio_cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.video_cls_token = nn.Parameter(torch.zeros(1, 1, dim))

        # Learnable positional embeddings covering CLS + t tokens.
        n_positions = t + 1
        self.audio_pos_embed = nn.Parameter(torch.zeros(1, n_positions, dim))
        self.video_pos_embed = nn.Parameter(torch.zeros(1, n_positions, dim))

        self.audio_drop = nn.Dropout(drop_p)
        self.video_drop = nn.Dropout(drop_p)

        layers = [
            MbtEncoder(num_latents=num_latents, dim=dim, heads=heads)
            for _ in range(depth)
        ]
        self.mbt_blocks = nn.ModuleList(layers)

        self.audio_norm = nn.LayerNorm(dim, eps=1e-6)
        self.video_norm = nn.LayerNorm(dim, eps=1e-6)

    @staticmethod
    def _embed(tokens, cls_token, pos_embed, dropout, batch):
        """Prepend the CLS token, add the positional embedding, then apply dropout."""
        cls = cls_token.expand(batch, -1, -1)  # [B, 1, C]
        with_cls = torch.cat((cls, tokens), dim=1)  # [B, 1+T, C]
        return dropout(with_cls + pos_embed)  # [B, 1+T, C]

    def forward(self, audio, video):
        """Encode both modalities and return their CLS features.

        Args:
            audio: audio tokens, expected as [B, T, C].
            video: video tokens, expected as [B, T, C].

        Returns:
            Tuple ``(audio_feature, video_feature)``, each [B, C].
        """
        # The batch size is read from the audio input and used for both streams.
        batch = audio.shape[0]

        a = self._embed(audio, self.audio_cls_token, self.audio_pos_embed,
                        self.audio_drop, batch)
        v = self._embed(video, self.video_cls_token, self.video_pos_embed,
                        self.video_drop, batch)

        # Interleaved bottleneck fusion + per-modality encoding.
        for layer in self.mbt_blocks:
            a, v = layer(a, v)  # [B, 1+T, C]

        a = self.audio_norm(a)
        v = self.video_norm(v)

        # Position 0 holds the CLS token.
        return a[:, 0], v[:, 0]


if __name__ == "__main__":
    import time

    B, T, C = 1, 1, 512
    audio = torch.randn(B, T, C)  # [B, T, C]
    video = torch.randn(B, T, C)  # [B, T, C]

    depth = 6  # number of stacked MbtEncoder layers
    num_latents = 4  # number of bottleneck (latent) tokens per layer
    heads = 8  # attention heads in each timm Block; C must be divisible by heads

    net = MBT(depth=depth, t=T, num_latents=num_latents, dim=C, heads=heads)

    # Measure inference latency. eval() turns dropout off and no_grad() skips
    # building the autograd graph; neither belongs in an inference timing.
    # Call the module (not .forward) so registered hooks run as in normal use.
    # The first iteration includes one-off warm-up cost.
    net.eval()
    with torch.no_grad():
        for _ in range(11):
            since = time.perf_counter()  # monotonic, high-resolution clock
            net(audio, video)
            print(time.perf_counter() - since)

    from thop import profile  # optional third-party dependency, only needed here

    # thop's profile() counts multiply-accumulates (MACs); x2 converts to FLOPs.
    macs, params = profile(net, inputs=(audio, video,))
    print('FLOPs = ' + str(macs * 2 / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
