# perceiver_resampler.py
# Lightweight Perceiver resampler: maps variable-length sequence to a fixed number of latent queries
# using cross-attention from learnable latents to inputs (single-head or multi-head attention).
# This is a minimal, easy-to-understand prototype.

from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class PerceiverResampler(nn.Module):
    """Resample a variable-length token sequence into a fixed number of latents.

    Given input tokens ``X`` of shape ``(B, T, D)``, ``K`` learnable latent
    vectors act as queries in one multi-head cross-attention step over the
    inputs::

        A = softmax((L Wq)(X Wk)^T / sqrt(d_head))     # per head, (K, T)
        Y = LayerNorm(concat_heads(A (X Wv)) Wo + L)   # (B, K, D)

    The latents ``L`` are shared across the batch. The feature dimension of
    the inputs must already equal ``latent_dim``; project beforehand otherwise.

    Args:
        latent_dim: Feature size ``D`` of both inputs and latents.
        num_latents: Number of output latents ``K``.
        num_heads: Number of attention heads; must divide ``latent_dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``latent_dim``.
    """

    def __init__(self, latent_dim: int = 512, num_latents: int = 16, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        # Fail at construction time instead of on the first forward pass.
        if num_heads <= 0 or latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})."
            )
        self.num_latents = num_latents
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        # Small-variance init for the learnable queries.
        self.latents = nn.Parameter(torch.randn(num_latents, latent_dim) * 0.02)  # (K, D)

        # projections
        self.q_proj = nn.Linear(latent_dim, latent_dim)
        self.k_proj = nn.Linear(latent_dim, latent_dim)
        self.v_proj = nn.Linear(latent_dim, latent_dim)
        self.out = nn.Linear(latent_dim, latent_dim)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(latent_dim)

    def forward(self, inputs: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Cross-attend from the learnable latents to ``inputs``.

        Args:
            inputs: Tensor of shape ``(B, T, D)`` with ``D == latent_dim``.
            mask: Optional tensor of shape ``(B, T)``; truthy entries mark
                positions that may be attended to, falsy entries (e.g.
                padding) are ignored. ``None`` attends to every position.
                If every position of a sample is masked out, attention for
                that sample falls back to uniform weights instead of NaN.

        Returns:
            Tensor of shape ``(B, K, D)`` with ``K == num_latents``.

        Raises:
            ValueError: If ``inputs`` is not 3-D, its last dimension differs
                from ``latent_dim``, or ``mask`` does not have shape ``(B, T)``.
        """
        # Explicit raises rather than asserts: asserts vanish under `python -O`.
        if inputs.dim() != 3:
            raise ValueError(
                f"inputs must have shape (B, T, D); got {tuple(inputs.shape)}."
            )
        batch, seq_len, dim = inputs.shape
        if dim != self.latent_dim:
            raise ValueError("Input dim must match latent_dim or project beforehand.")
        if mask is not None and tuple(mask.shape) != (batch, seq_len):
            raise ValueError(
                f"mask must have shape ({batch}, {seq_len}); got {tuple(mask.shape)}."
            )

        heads = self.num_heads
        head_dim = self.latent_dim // heads

        # Queries depend only on the shared latents, so project them once and
        # let matmul broadcast over the batch instead of copying to (B, K, D).
        queries = self.q_proj(self.latents)  # (K, D)
        keys = self.k_proj(inputs)           # (B, T, D)
        values = self.v_proj(inputs)         # (B, T, D)

        # split heads
        qh = queries.view(self.num_latents, heads, head_dim).transpose(0, 1)  # (H, K, d)
        kh = keys.view(batch, seq_len, heads, head_dim).transpose(1, 2)       # (B, H, T, d)
        vh = values.view(batch, seq_len, heads, head_dim).transpose(1, 2)     # (B, H, T, d)

        # scaled dot-product attention
        scale = head_dim ** -0.5
        attn_logits = torch.matmul(qh, kh.transpose(-2, -1)) * scale  # (B, H, K, T)
        if mask is not None:
            keep = mask.to(device=attn_logits.device, dtype=torch.bool)
            # Use the dtype's most negative finite value rather than -inf so a
            # fully masked row yields uniform weights, not NaN.
            attn_logits = attn_logits.masked_fill(
                ~keep[:, None, None, :], torch.finfo(attn_logits.dtype).min
            )
        attn = torch.softmax(attn_logits, dim=-1)  # (B, H, K, T)
        attn = self.dropout(attn)

        # attend to values, then merge heads back into the feature dimension
        attended = torch.matmul(attn, vh)  # (B, H, K, d)
        merged = attended.transpose(1, 2).reshape(batch, self.num_latents, dim)  # (B, K, D)

        projected = self.out(merged)  # (B, K, D)
        # residual with the (broadcast) latents, followed by LayerNorm
        return self.norm(projected + self.latents)
