# multimodal_fusion.py
# Mid-fusion module combining audio features and text features.
# Uses a perceiver-style resampler to produce a fixed-length set of latent queries,
# then refines the fused representation with a small QueryFormer transformer module.

from typing import Optional
import torch
import torch.nn as nn
from perceiver_resampler import PerceiverResampler
from query_former import QueryFormer


class MultiModalFusion(nn.Module):
    """
    Fuse audio and text features into a single fixed-length representation.

    Both modalities are linearly projected to ``latent_dim``, concatenated
    along the time axis, resampled to ``num_latents`` latent tokens by a
    ``PerceiverResampler`` and finally refined by a ``QueryFormer`` that also
    receives the concatenated tokens as context.

    Inputs accepted by ``forward``:
    - audio_feats: (B, T_a, D_a) or (T_a, D_a)
    - text_feats:  (B, T_t, D_t) or (T_t, D_t)
    An unbatched (or batch-size-1) input may be combined with a batched
    partner; it is broadcast across the partner's batch.

    Returns fused_repr: (B, L, D_latent) where L is number of latent queries
    (resampler output).

    Args:
        audio_dim: Feature size of the incoming audio tokens (D_a).
        text_dim: Feature size of the incoming text tokens (D_t).
        latent_dim: Common width both modalities are projected to.
        num_latents: Number of latent tokens requested from the resampler.
        num_query_layers: Forwarded to ``QueryFormer`` as ``num_layers``.
        num_heads: Forwarded to ``QueryFormer`` as ``num_heads``.
        dropout: Forwarded to ``QueryFormer`` as ``dropout``.
    """

    def __init__(self,
                 audio_dim: int = 512,
                 text_dim: int = 768,
                 latent_dim: int = 512,
                 num_latents: int = 16,
                 num_query_layers: int = 2,
                 num_heads: int = 8,
                 dropout: float = 0.1):
        super().__init__()
        # project audio/text into common dimension
        self.audio_proj = nn.Linear(audio_dim, latent_dim)
        self.text_proj = nn.Linear(text_dim, latent_dim)
        # perceiver resampler: maps variable-length input tokens -> fixed-size latent set
        self.resampler = PerceiverResampler(latent_dim=latent_dim, num_latents=num_latents)
        # query-former: transformer-like refinement on latents
        self.query_former = QueryFormer(embed_dim=latent_dim, num_layers=num_query_layers, num_heads=num_heads, dropout=dropout)
        # final pooling or projection may be done externally

    @staticmethod
    def _as_batched(feats: torch.Tensor, name: str) -> torch.Tensor:
        """Return ``feats`` as (B, T, D), adding a batch axis to 2-D input."""
        rank = feats.dim()
        if rank == 2:
            return feats.unsqueeze(0)
        if rank == 3:
            return feats
        # Any other rank would make the dim=1 concatenation below either fail
        # with an opaque message or silently join along the wrong axis.
        raise ValueError(
            f"{name} must have shape (B, T, D) or (T, D); "
            f"got a {rank}-D tensor of shape {tuple(feats.shape)}"
        )

    def forward(self, audio_feats: torch.Tensor, text_feats: torch.Tensor) -> torch.Tensor:
        """
        Project, concatenate, resample and refine the two modalities.

        audio_feats: (B, Ta, Da) or (Ta, Da)
        text_feats:  (B, Tt, Dt) or (Tt, Dt)
        Returns:
          fused: (B, num_latents, latent_dim), as produced by the query-former
        Raises:
          ValueError: if either input is not 2-D or 3-D.
          Batch sizes that differ and are both greater than 1 are not handled
          here; the resulting error comes from ``torch.cat``.
        """
        # ensure batch dimension
        audio_feats = self._as_batched(audio_feats, "audio_feats")
        text_feats = self._as_batched(text_feats, "text_feats")

        a = self.audio_proj(audio_feats)   # (B, Ta, latent_dim)
        t = self.text_proj(text_feats)     # (B, Tt, latent_dim)

        # Allow one unbatched / batch-1 modality to be shared across the other
        # modality's batch. expand() creates a view, so nothing is copied until
        # torch.cat materialises the result.
        if a.size(0) != t.size(0):
            if a.size(0) == 1:
                a = a.expand(t.size(0), -1, -1)
            elif t.size(0) == 1:
                t = t.expand(a.size(0), -1, -1)

        # concatenate along time dimension
        concat = torch.cat([a, t], dim=1)  # (B, Ta+Tt, latent_dim)
        # resample to fixed latents
        latents = self.resampler(concat)   # (B, num_latents, latent_dim)
        # refine via query-former (self-attention + cross-attention style)
        fused = self.query_former(latents, context=concat)  # (B, num_latents, latent_dim)
        return fused
