import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Stack of multi-head self-attention layers, each followed by a feed-forward block.

    The input is first linearly projected from ``input_dim`` to ``hidden_dim``;
    it then passes through ``num_layers`` pairs of
    (``nn.MultiheadAttention`` -> two-layer ReLU MLP). There are no residual
    connections or normalisation layers in this stack.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_layers: Number of (attention, feed-forward) pairs.
        hidden_dim: Embedding size used inside the stack and in the output.
        num_heads: Number of attention heads per layer; ``hidden_dim`` must be
            divisible by it (enforced by ``nn.MultiheadAttention``).
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, input_dim, num_layers=6, hidden_dim=768, num_heads=8, dropout=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Self-attention layers (default layout: sequence-first inputs).
        self.self_attentions = nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

        # Position-wise feed-forward networks with a 4x expansion.
        self.ffns = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.ReLU(),
                nn.Linear(hidden_dim * 4, hidden_dim),
            )
            for _ in range(num_layers)
        ])

        # Linear projection that lifts the input features to hidden_dim.
        self.input_proj = nn.Linear(input_dim, hidden_dim)

    def forward(self, features):
        """Project ``features`` and run them through the attention stack.

        Args:
            features: 3-D tensor whose last dimension is ``input_dim``. The
                first two axes are swapped before attention, so the expected
                layout is presumably (batch, seq_len, input_dim).

        Returns:
            Tensor with the same first two axes as ``features`` and a last
            dimension of ``hidden_dim``.
        """
        hidden = self.input_proj(features)

        # nn.MultiheadAttention is built with batch_first=False, so it wants
        # (seq, batch, embed). Swap the axes exactly once here: permuting
        # inside the loop would flip the layout on every layer, making every
        # second layer attend across the batch axis and returning a
        # transposed result whenever num_layers is even.
        hidden = hidden.permute(1, 0, 2)

        for attention, ffn in zip(self.self_attentions, self.ffns):
            hidden, _ = attention(
                hidden, hidden, hidden,
                attn_mask=None, key_padding_mask=None, need_weights=False
            )
            hidden = ffn(hidden)

        # Restore the caller's original axis order.
        return hidden.permute(1, 0, 2)


class CrossAttention(nn.Module):
    """Thin wrapper that delegates to a ``SelfAttention`` stack.

    Despite the class name, no cross-attention is performed: following
    AVoiD-DF, the shared text/image features are processed by plain
    self-attention layers.

    Args:
        input_dim: Size of the last dimension of the incoming features.
        num_layers: Number of attention layers in the wrapped stack.
        hidden_dim: Embedding size used by the wrapped stack.
    """

    def __init__(self, input_dim, num_layers=6, hidden_dim=768):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Self-attention stack for both text and image (from AVoiD-DF,
        # not actually 'cross').
        self.self_attentions = SelfAttention(
            input_dim=input_dim,
            num_layers=num_layers,
            hidden_dim=hidden_dim,
        )

    def forward(self, features):
        """Return the wrapped self-attention stack applied to ``features``."""
        return self.self_attentions(features)


if __name__ == "__main__":
    # Smoke test / micro-benchmark: build a small model, time a few forward
    # passes on a single-token input, then report FLOPs and parameter count.
    import time

    cross_att = CrossAttention(input_dim=512, num_layers=3, hidden_dim=1024)
    print(cross_att)

    inputs = torch.randn(1, 1, 512)
    net = cross_att

    # Benchmark inference behaviour: disable dropout and autograd bookkeeping.
    net.eval()
    with torch.no_grad():
        # Warm-up call so one-off initialisation cost is not attributed to
        # the first timed iteration.
        net(inputs)

        for _ in range(11):
            # perf_counter is monotonic and high-resolution, unlike time.time.
            since = time.perf_counter()
            # Call the module rather than .forward() so registered hooks run.
            net(inputs)
            print(time.perf_counter() - since)

    # thop is an optional third-party profiler; skip the report if it is
    # not installed instead of crashing after the timings were printed.
    try:
        from thop import profile
    except ImportError:
        print('thop is not installed; skipping FLOPs/Params report')
    else:
        flops, params = profile(net, inputs=(inputs,))
        # NOTE(review): thop appears to report multiply-accumulates, hence
        # the factor of 2 to convert to FLOPs - confirm against thop docs.
        print('FLOPs = ' + str(2 * flops / 1000 ** 3) + 'G')
        print('Params = ' + str(params / 1000 ** 2) + 'M')
