import torch
import torch.nn as nn
import torch.nn.functional as F

# -------------------------
# Masked Attention Pooling
# -------------------------
class MaskedAttentionPooling(nn.Module):
  """
  Pool a padded sequence into a single vector with learned, mask-aware attention.

  Every position is scored with ``w(tanh(V(x)))``; the scores are
  softmax-normalized over the valid positions only and then used as the
  weights of a sum over ``X``.

  Input X: [B, K_max, D]  (padded)
  mask: [B, K_max]       (True / non-zero = valid, False / zero = pad)
  Output Z: [B, D], attn: [B, K_max]
    - attn is exactly 0 on pad positions and sums to 1 over the valid
      positions of each row.
    - A row without any valid position yields attn == 0 everywhere and
      Z == 0 (for finite X) instead of NaN.
  """
  def __init__(self, input_dim=1280, hidden_dim=128):
    super().__init__()
    self.V = nn.Linear(input_dim, hidden_dim)
    self.w = nn.Linear(hidden_dim, 1, bias=False)

  def forward(self, X, mask):
    """
    Args:
      X: [B, K, D] padded token embeddings.
      mask: [B, K]; bool, or any dtype where non-zero marks a valid token.

    Returns:
      (Z, attn): Z is the [B, D] attention-weighted sum of X, attn is the
      [B, K] weight matrix.

    Raises:
      ValueError: if X is not a 3-D tensor.
    """
    if X.dim() != 3:
      raise ValueError(
        f"X must have shape [B, K, D], got {tuple(X.shape)}"
      )

    H = torch.tanh(self.V(X))              # [B, K, L]
    scores = self.w(H).squeeze(-1)         # [B, K]

    mask = mask.bool()                     # no-op when mask is already bool
    pad = ~mask                            # [B, K], True on padding
    has_valid = mask.any(dim=1, keepdim=True)  # [B, 1]

    # Pad positions get -inf so softmax assigns them exactly zero weight.
    scores = scores.masked_fill(pad, float('-inf'))
    # A row made only of padding would be all -inf, and softmax over such a
    # row is NaN (which then poisons Z and the gradients). Give those rows
    # finite placeholder scores; their weights are zeroed right below.
    scores = scores.masked_fill(~has_valid, 0.0)

    attn = F.softmax(scores, dim=1)        # [B, K], each row sums to 1
    # Only changes the fully padded rows (valid rows already have 0 on pads):
    # they end up with all-zero attention and therefore Z == 0.
    attn = attn.masked_fill(pad, 0.0)

    Z = torch.bmm(attn.unsqueeze(1), X).squeeze(1)  # [B, 1, K] @ [B, K, D] -> [B, 1, D] -> [B, D]
    return Z, attn

# -------------------------
# PU_AttnPoolNet
# -------------------------
class PU_AttnPoolNet(nn.Module):
  """
  Small classifier for protein pockets built on per-residue ESM embeddings.

  Pipeline:
    per-residue embeddings [B, K, D]
      -> MaskedAttentionPooling   (one D-dim vector per sample)
      -> LayerNorm -> Linear(D, D // 2) -> GELU -> Dropout
      -> Linear(D // 2, num_classes)
  """
  def __init__(self, input_dim=1280, hidden_dim=128, num_classes=8, dropout=0.2):
    super().__init__()
    bottleneck_dim = input_dim // 2

    # Sub-modules are created in a fixed order (pool first, then the head
    # layers) so parameter initialization under a given seed is reproducible.
    self.pool = MaskedAttentionPooling(input_dim, hidden_dim)
    head_layers = [
      nn.LayerNorm(input_dim),
      nn.Linear(input_dim, bottleneck_dim),
      nn.GELU(),
      nn.Dropout(dropout),
      nn.Linear(bottleneck_dim, num_classes),
    ]
    self.classifier = nn.Sequential(*head_layers)

  def forward(self, X, mask):
    """
    Pool the residues of each sample and classify the pooled vector.

    X: [B, K, D]
    mask: [B, K] (1 for valid, 0 for pad)

    Returns (logits, attn): the classifier output for each sample and the
    attention weights produced by the pooling layer.
    """
    pooled, weights = self.pool(X, mask)
    return self.classifier(pooled), weights