import math
import torch
import torch.nn as nn


class CrossAttentionWithResidual(nn.Module):
    """Stacked bidirectional text/image cross attention with residual connections.

    Both modalities are first linearly projected to ``hidden_dim``. Each of the
    ``num_layers`` layers then applies the following sub-layers, in this order,
    each wrapped in a residual connection (``x = x + sublayer(x)``):

    1. text queries attend over image features (cross attention);
    2. image queries attend over the *already updated* text features;
    3. text self attention;
    4. image self attention;
    5. a position-wise feedforward network (Linear -> ReLU -> Linear, 4x width)
       per modality.

    No normalization layer is applied anywhere in the stack.
    NOTE(review): without LayerNorm the residual stream may grow with depth;
    consider pre/post-norm if training is unstable.

    Args:
        input_dim_text: Last-dimension size of the incoming text features.
        input_dim_image: Last-dimension size of the incoming image features.
        num_layers: Number of stacked layers.
        hidden_dim: Model width shared by both modalities; must be divisible
            by ``num_heads``.
        num_heads: Number of heads in every attention sub-layer.
        dropout: Dropout probability inside every attention sub-layer.
    """

    def __init__(self, input_dim_text, input_dim_image, num_layers=6, hidden_dim=768,
                 num_heads=8, dropout=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim

        # Attribute names and registration order are unchanged so existing
        # checkpoints (state_dict keys) still load and seeded init is identical.

        # Cross attention layers for text and image
        self.cross_attentions_text = self._attention_stack(num_layers, hidden_dim, num_heads, dropout)
        self.cross_attentions_image = self._attention_stack(num_layers, hidden_dim, num_heads, dropout)

        # Self attention layers for text and image
        self.self_attentions_text = self._attention_stack(num_layers, hidden_dim, num_heads, dropout)
        self.self_attentions_image = self._attention_stack(num_layers, hidden_dim, num_heads, dropout)

        # Feedforward network layers for text and image
        self.ffns_text = self._ffn_stack(num_layers, hidden_dim)
        self.ffns_image = self._ffn_stack(num_layers, hidden_dim)

        # Linear projection for input sequences
        self.input_proj_text = nn.Linear(input_dim_text, hidden_dim)
        self.input_proj_image = nn.Linear(input_dim_image, hidden_dim)

    @staticmethod
    def _attention_stack(num_layers, hidden_dim, num_heads, dropout):
        """Build ``num_layers`` independent sequence-first attention modules."""
        return nn.ModuleList([
            nn.MultiheadAttention(hidden_dim, num_heads=num_heads, dropout=dropout)
            for _ in range(num_layers)
        ])

    @staticmethod
    def _ffn_stack(num_layers, hidden_dim):
        """Build ``num_layers`` independent Linear -> ReLU -> Linear blocks (4x expansion)."""
        return nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim * 4),
                nn.ReLU(),
                nn.Linear(hidden_dim * 4, hidden_dim),
            )
            for _ in range(num_layers)
        ])

    @staticmethod
    def _attend(attention, query, context, context_padding_mask):
        """Let ``query`` attend over ``context`` (used as both keys and values).

        Both tensors are sequence-first ``(seq, batch, hidden_dim)``, which is
        what ``nn.MultiheadAttention`` expects with its default
        ``batch_first=False``. The mask, if given, is ``(batch, context_len)``.
        """
        out, _ = attention(
            query, context, context,
            attn_mask=None, key_padding_mask=context_padding_mask, need_weights=False,
        )
        return out

    def forward(self, text_features, image_features,
                text_padding_mask=None, image_padding_mask=None):
        """Fuse text and image features through the layer stack.

        Args:
            text_features: Batch-first ``(batch, text_len, input_dim_text)``.
            image_features: Batch-first ``(batch, image_len, input_dim_image)``.
            text_padding_mask: Optional bool ``(batch, text_len)``; True marks
                text positions that must not be attended to as keys. Outputs
                at those positions are still computed but carry no meaning.
            image_padding_mask: Optional bool ``(batch, image_len)``; same
                semantics for image positions.

        Returns:
            ``(text, image)`` with shapes ``(batch, text_len, hidden_dim)`` and
            ``(batch, image_len, hidden_dim)``.
        """
        # Project, then switch to sequence-first exactly once. The previous
        # version permuted at every attention call, so attention outputs
        # (seq, batch, dim) were added to residuals still in (batch, seq, dim),
        # which crashes or silently mixes samples whenever batch != 1 or seq != 1.
        text = self.input_proj_text(text_features).permute(1, 0, 2)
        image = self.input_proj_image(image_features).permute(1, 0, 2)

        for i in range(self.num_layers):
            # Cross attention: text -> image, then image -> (updated) text.
            text = text + self._attend(self.cross_attentions_text[i], text, image, image_padding_mask)
            image = image + self._attend(self.cross_attentions_image[i], image, text, text_padding_mask)

            # Self attention within each modality.
            text = text + self._attend(self.self_attentions_text[i], text, text, text_padding_mask)
            image = image + self._attend(self.self_attentions_image[i], image, image, image_padding_mask)

            # Position-wise feedforward; layout-agnostic since it acts on the last dim.
            text = text + self.ffns_text[i](text)
            image = image + self.ffns_image[i](image)

        # Back to batch-first for the caller.
        return text.permute(1, 0, 2), image.permute(1, 0, 2)


if __name__ == "__main__":
    import time

    inputs1 = torch.randn(1, 1, 512)
    inputs2 = torch.randn(1, 1, 512)
    net = CrossAttentionWithResidual(input_dim_text=512, input_dim_image=512, num_layers=6, hidden_dim=768)

    # Measure inference latency: eval() disables dropout and no_grad() skips
    # building the autograd graph; neither belongs in a forward-pass timing.
    net.eval()
    with torch.no_grad():
        for _ in range(11):
            since = time.perf_counter()  # monotonic, high-resolution clock
            net(inputs1, inputs2)  # call the module (not .forward) so hooks run
            print(time.perf_counter() - since)

    from thop import profile

    flops, params = profile(net, inputs=(inputs1, inputs2,))
    # NOTE(review): thop's count is presumably multiply-accumulates, hence
    # the factor of 2 to approximate FLOPs — confirm against thop docs.
    print('FLOPs = ' + str(2 * flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
