# query_former.py
# QueryFormer: transformer-style refinement module that lets latent queries attend to each other
# and optionally cross-attend to context tokens. Minimal implementation with LayerNorm, MHA and MLP.

from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class CrossAttentionBlock(nn.Module):
    """
    Pre-norm residual block for refining latent queries.

    Runs three residual sub-layers in order: (1) self-attention among the queries,
    (2) cross-attention from the queries to the context (skipped when no context is
    given), (3) a two-layer ReLU MLP. Each sub-layer applies its own LayerNorm to the
    queries first and adds its output back onto the un-normalized queries.

    Args:
        embed_dim: feature size D shared by queries and context.
        num_heads: number of heads used by both attention modules.
        mlp_ratio: MLP hidden width as a multiple of embed_dim (truncated to int).
        dropout: dropout probability passed to both attention modules and applied
            after the MLP's second linear layer.
    """

    def __init__(self, embed_dim: int = 512, num_heads: int = 8, mlp_ratio: float = 4.0, dropout: float = 0.1):
        super().__init__()
        # Attribute names and creation order are kept stable so checkpoints and
        # seeded initialisation remain compatible.
        self.self_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, dropout=dropout, batch_first=True)
        self.norm2 = nn.LayerNorm(embed_dim)
        hidden = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm3 = nn.LayerNorm(embed_dim)

    def forward(self, queries: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Refine the queries with self-attention, optional cross-attention and an MLP.

        queries: (B, K, D)
        context: optional (B, T, D) to cross-attend to; when None the cross-attention
            sub-layer is skipped entirely.

        Returns the refined queries, same shape as the input queries.
        """
        # self-attention
        normed = self.norm1(queries)
        # need_weights=False: the attention maps are never used here, so do not
        # compute and average them; this also lets PyTorch pick its fused kernel.
        sa_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        queries = queries + sa_out

        # cross-attention (queries as target, context as source)
        if context is not None:
            normed = self.norm2(queries)
            # The context is used as keys/values as-is; no LayerNorm is applied to it.
            ca_out, _ = self.cross_attn(normed, context, context, need_weights=False)
            queries = queries + ca_out

        # MLP
        queries = queries + self.mlp(self.norm3(queries))
        return queries


class QueryFormer(nn.Module):
    """
    Stacks multiple CrossAttentionBlock layers to refine latent queries.

    Args:
        embed_dim: feature size D shared by queries and context.
        num_layers: number of CrossAttentionBlock layers applied in sequence;
            with 0 layers the queries are returned unchanged.
        num_heads: number of attention heads in every layer.
        dropout: dropout probability forwarded to every layer.
        mlp_ratio: MLP hidden width as a multiple of embed_dim, forwarded to every
            layer. The default matches CrossAttentionBlock's own default.
    """

    def __init__(
        self,
        embed_dim: int = 512,
        num_layers: int = 2,
        num_heads: int = 8,
        dropout: float = 0.1,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        # mlp_ratio is appended last so existing positional callers are unaffected.
        self.layers = nn.ModuleList(
            [
                CrossAttentionBlock(
                    embed_dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, queries: torch.Tensor, context: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Pass the queries through every layer in order.

        queries: (B, K, D)
        context: optional (B, T, D); the same context is handed to every layer.

        Returns the refined queries.
        """
        x = queries
        for layer in self.layers:
            x = layer(x, context=context)
        return x
