import torch
import torch.nn as nn
import torch.nn.functional as F

class DETRVisualAligner(nn.Module):
    """Map DETR feature maps into a CLIP-sized embedding space and match them to a cache.

    Pipeline: 1x1 conv projection -> Transformer encoder over flattened
    spatial positions -> mean pooling over tokens -> MLP -> L2 normalization.
    """

    def __init__(
        self,
        detr_channels: int,
        clip_dim: int = 768,
        num_patches: int = 64,
        encoder_layers: int = 1,
        nhead: int = 8,
        mlp_hidden: int = 1024,
        dropout: float = 0.1,
    ):
        """Build the projection, encoder, pooling and MLP sub-modules.

        Args:
            detr_channels: Number of channels in the DETR feature map (C).
            clip_dim: Target embedding dimension (default 768). Also used as
                the Transformer ``d_model``.
            num_patches: Expected H*W (e.g. 8x8=64). Currently not read by
                this class; kept so existing callers keep working. No
                positional embedding is created here, so any H*W is accepted.
            encoder_layers: Number of Transformer encoder layers.
            nhead: Number of attention heads.
            mlp_hidden: Hidden size of the MLP projection; also used as the
                encoder's ``dim_feedforward``.
            dropout: Dropout probability inside the Transformer encoder layer.
        """
        super().__init__()
        # 1. Project DETR features to clip_dim as patch embeddings
        self.patch_proj = nn.Conv2d(detr_channels, clip_dim, kernel_size=1)

        # 2. Lightweight Transformer encoder to model spatial relations
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=clip_dim,
            nhead=nhead,
            dim_feedforward=mlp_hidden,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=encoder_layers)

        # 3. Pooling + MLP projection to CLIP space
        self.pool = nn.AdaptiveAvgPool1d(1)  # Pool over token dimension
        self.mlp = nn.Sequential(
            nn.Linear(clip_dim, mlp_hidden, bias=False),
            nn.GELU(),
            nn.LayerNorm(mlp_hidden),
            nn.Linear(mlp_hidden, clip_dim, bias=False),
        )

    def forward_get_best_cache(self, detr_feats: torch.Tensor, cache_embs: torch.Tensor) -> torch.Tensor:
        """Return the most similar cache embedding for each DETR feature map.

        Args:
            detr_feats: Feature maps of shape [B, C, H, W].
            cache_embs: Candidate embeddings of shape [K, D] with K >= 1 and
                D equal to the module's embedding dimension. They are moved to
                the device and dtype of the computed image embedding.

        Returns:
            Tensor of shape [B, D]: for every sample, the L2-normalized row of
            ``cache_embs`` with the highest cosine similarity to the sample's
            embedding. The rows are taken from the (normalized) cache, not from
            the image embedding itself.

        Raises:
            ValueError: If ``detr_feats`` is not 4-D, or ``cache_embs`` is not
                a non-empty 2-D tensor whose width matches the embedding size.
        """
        if detr_feats.dim() != 4:
            raise ValueError(
                f"detr_feats must have shape [B, C, H, W], got {tuple(detr_feats.shape)}"
            )
        if cache_embs.dim() != 2 or cache_embs.shape[0] == 0:
            raise ValueError(
                f"cache_embs must have shape [K, D] with K >= 1, got {tuple(cache_embs.shape)}"
            )

        # 1. Project and encode
        x = self.patch_proj(detr_feats)         # [B, D, H, W]
        x = x.flatten(2).transpose(1, 2)        # [B, P, D]
        x_enc = self.encoder(x)                 # [B, P, D]

        # Average over the token axis to obtain one vector per sample.
        global_token = self.pool(x_enc.transpose(1, 2)).squeeze(-1)  # [B, D]
        img_emb = self.mlp(global_token)        # [B, D]
        img_emb = F.normalize(img_emb, dim=1)   # [B, D]

        if cache_embs.shape[1] != img_emb.shape[1]:
            raise ValueError(
                f"cache_embs has embedding size {cache_embs.shape[1]}, "
                f"expected {img_emb.shape[1]}"
            )

        # Match dtype as well as device: matmul rejects mixed dtypes, so a
        # cache stored in e.g. fp16/fp64 would otherwise fail below.
        cache_embs = cache_embs.to(device=img_emb.device, dtype=img_emb.dtype)
        cache_embs = F.normalize(cache_embs, dim=1)  # [K, D]

        # 2. Cosine similarity: [B, D] x [D, K] -> [B, K]
        sim_scores = torch.matmul(img_emb, cache_embs.T)  # [B, K]

        # 3. Best match per sample
        best_ids = torch.argmax(sim_scores, dim=1)  # [B]

        # 4. Gather best cache embeddings
        matched_cache = cache_embs[best_ids]  # [B, D]
        return matched_cache





