
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    """Bias-free stack of linear layers with ReLU between consecutive layers.

    The first layer maps ``dim_in`` to ``dim_hidden`` and the last maps
    ``dim_hidden`` to ``dim_out``; every layer in between is
    ``dim_hidden -> dim_hidden``. With ``num_layers == 1`` the single layer
    maps ``dim_in`` directly to ``dim_out``. No activation follows the
    final layer.

    Args:
        dim_in: Size of the last dimension of the input.
        dim_out: Size of the last dimension of the output.
        dim_hidden: Width of the intermediate layers.
        num_layers: Number of linear layers to create.
    """

    def __init__(self, dim_in, dim_out, dim_hidden, num_layers):
        super().__init__()
        self.dim_in = dim_in
        self.dim_out = dim_out
        self.dim_hidden = dim_hidden
        self.num_layers = num_layers

        final_idx = num_layers - 1
        # Layers are created in index order so parameter names
        # ("net.<i>.weight") and seeded initialisation stay stable.
        self.net = nn.ModuleList([
            nn.Linear(
                dim_in if idx == 0 else dim_hidden,
                dim_out if idx == final_idx else dim_hidden,
                bias=False,
            )
            for idx in range(num_layers)
        ])

    def forward(self, x):
        """Apply the layers in order and return the result.

        ReLU runs in place, but only on tensors produced by a preceding
        layer, so the tensor passed in by the caller is not modified.
        When ``num_layers`` is 0 the input is returned unchanged.
        """
        for idx in range(self.num_layers):
            if idx > 0:
                # Activation sits between layers: it is applied to the
                # previous layer's output before feeding the next layer.
                x = F.relu(x, inplace=True)
            x = self.net[idx](x)
        return x


class AWG(nn.Module):
    """Produce two softmax-normalised weights from a visual/audio feature pair.

    The two inputs are concatenated along the last dimension and passed
    through a four-layer perceptron with fixed widths 512 -> 256 -> 128 -> 2,
    followed by a softmax over the last dimension, so the two outputs are
    non-negative and sum to 1.

    Args:
        visual_dim: Size of the last dimension of the first input.
        audio_dim: Size of the last dimension of the second input.
        hidden_dim: Accepted for interface compatibility; the layer widths
            are fixed and this value is currently not used.
    """

    def __init__(self, visual_dim, audio_dim, hidden_dim=512):
        super().__init__()

        widths = (visual_dim + audio_dim, 512, 256, 128, 2)
        stages = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stages.append(nn.Linear(fan_in, fan_out))
            stages.append(nn.ReLU())
        # The last linear layer is followed by a softmax, not a ReLU.
        stages[-1] = nn.Softmax(dim=-1)

        self.meta_net = nn.Sequential(*stages)

    def forward(self, visual, audio):
        """Return the weights for ``(visual, audio)``; last dimension has size 2."""
        joint = torch.cat([visual, audio], dim=-1)
        weights = self.meta_net(joint)
        return weights

class FusionModel(nn.Module):
    """Fuse visual and audio features with learned two-way gating weights.

    Visual and audio inputs are each projected to ``hidden_dim // 2``.
    When ``local_features`` and ``gl_features`` are supplied, the projected
    visual features are additionally modulated by them, gated against each
    other and compressed back to ``hidden_dim // 2``. The resulting visual
    representation and the audio projection are then gated against each
    other, concatenated and mapped by an MLP to one value per sample.

    Args:
        visual_dim: Size of the last dimension of ``visual_features``.
        audio_dim: Size of the last dimension of ``audio_features``.
        hidden_dim: Base width of the model; the projections use
            ``hidden_dim // 2``.
    """

    def __init__(self, visual_dim=1024, audio_dim=1024, hidden_dim=1024):
        super().__init__()

        half_dim = hidden_dim // 2
        # Width of the projection multiplied with gl_features.
        # assumes gl_features are 768-wide -- TODO confirm against caller
        gl_dim = 768

        # Sub-modules are created in a fixed order so that seeded
        # initialisation and parameter names stay stable.
        self.visual_proj = nn.Linear(visual_dim, half_dim)
        self.audio_proj = nn.Linear(audio_dim, half_dim)
        self.local_proj_visual = nn.Linear(half_dim, hidden_dim)
        self.gl_proj_visual = nn.Linear(half_dim, gl_dim)
        # Widths are derived from hidden_dim instead of being hard-coded
        # for hidden_dim == 1024; at the default they are unchanged
        # (1024 + 768 -> 512 and 1024 -> 1).
        self.lg = MLP(hidden_dim + gl_dim, half_dim, 64, 3)
        self.mlp = MLP(2 * half_dim, 1, 512, 3)
        self.meta_weight_generator = AWG(
            visual_dim=half_dim,
            audio_dim=half_dim,
        )
        self.meta_weight_generator2 = AWG(
            visual_dim=hidden_dim,
            audio_dim=gl_dim,
        )

    def _fuse_local_global(self, visual_proj, local_features, gl_features):
        """Gate the local- and global-modulated visual features and compress them."""
        local_scaled = local_features * self.local_proj_visual(visual_proj)
        global_scaled = gl_features * self.gl_proj_visual(visual_proj)

        weights = self.meta_weight_generator2(local_scaled, global_scaled)
        gated = torch.cat(
            (
                local_scaled * weights[..., 0:1],
                global_scaled * weights[..., 1:2],
            ),
            dim=-1,
        )
        return self.lg(gated)

    def forward(self, visual_features, audio_features, local_features=None, gl_features=None):
        """Compute the fused output.

        Args:
            visual_features: Tensor whose last dimension is ``visual_dim``.
            audio_features: Tensor whose last dimension is ``audio_dim``.
            local_features: Optional tensor multiplied element-wise with a
                ``hidden_dim``-wide projection of the visual features.
            gl_features: Tensor multiplied element-wise with a 768-wide
                projection of the visual features. Required whenever
                ``local_features`` is given; ignored otherwise.

        Returns:
            Tensor whose last dimension has size 1.

        Raises:
            ValueError: If ``local_features`` is given without ``gl_features``.
        """
        visual_proj = self.visual_proj(visual_features)
        audio_proj = self.audio_proj(audio_features)

        if local_features is None:
            # No local/global branch: use the plain visual projection, which
            # has the same width (hidden_dim // 2) as the branch output.
            visual_repr = visual_proj
        else:
            if gl_features is None:
                raise ValueError(
                    "gl_features must be provided when local_features is given"
                )
            visual_repr = self._fuse_local_global(
                visual_proj, local_features, gl_features
            )

        weights = self.meta_weight_generator(visual_repr, audio_proj)
        # Index 0 gates the visual stream and index 1 gates the audio stream.
        fused_features = torch.cat(
            (
                visual_repr * weights[..., 0:1],
                audio_proj * weights[..., 1:2],
            ),
            dim=-1,
        )
        return self.mlp(fused_features)


