import torch
import torch.nn as nn
from typing import Iterable, Optional, Dict, List, Tuple


class SlidingInterAttention(nn.Module):
    """
    Sliding attention following Algorithm 1 (batched, mask-safe, numeric-stable),
    with feature attention recomputed at each iteration.

    Forward signature:
        out_U, out_V, attn_U, attn_V, W = forward(U, V, mask_U=None, mask_V=None, P_U=None, Q_V=None)

    Returns (all from the last iteration):
      - out_U:  [B, m, d] updated U embeddings (padding rows zeroed if mask_U given)
      - out_V:  [B, n, d] updated V embeddings (padding rows zeroed if mask_V given)
      - attn_U: [B, m, n] unnormalized weights W_raw = A * S
      - attn_V: [B, n, m] transpose of attn_U
      - W:      [B, m, n] row-normalized weights (W_raw / row sum)

    Notes:
      - mask_U/mask_V: boolean [B, L] where True = padding (will be excluded).
        Each mask is applied on its own, so passing only one still excludes
        that side's padding.
      - Bandwidth is computed from effective (non-padding) length of V and
        clamped to [min_bandwidth, max_bandwidth].
      - All divisions use small eps clamp to avoid NaN.
      - num_heads is stored but not used: attention is computed once over the
        full embedding dimension.
    """
    def __init__(self, embed_dim: int, num_heads: int = 10, steps: int = 3,
                 min_bandwidth: int = 48, max_bandwidth: int = 144, eps: float = 1e-6,
                 scale: int = 3):
        super().__init__()
        if int(steps) < 1:
            # forward() needs at least one iteration to produce W / W_raw.
            raise ValueError(f"steps must be >= 1, got {steps}")
        self.embed_dim = embed_dim
        self.num_heads = num_heads  # kept for config compatibility; unused
        self.steps = int(steps)
        self.min_bandwidth = min_bandwidth
        self.max_bandwidth = max_bandwidth
        self.eps = float(eps)
        self.scale = scale  # bandwidth = effective length / scale (before clamping)

        # learnable linear maps (ES, ER, EU, EV in paper)
        self.E_S = nn.Linear(embed_dim, embed_dim, bias=False)
        self.E_R = nn.Linear(embed_dim, embed_dim, bias=False)
        self.E_U = nn.Linear(embed_dim, embed_dim, bias=False)
        self.E_V = nn.Linear(embed_dim, embed_dim, bias=False)

        self.dim_scale = (embed_dim ** -0.5)

    def compute_bandwidth(self, mask_V: Optional[torch.Tensor], n: int) -> float:
        """
        Compute bandwidth from effective sequence length (non-padding).

        Uses the max valid length in the batch so the bandwidth covers the
        longest example, then divides by `scale` and clamps to
        [min_bandwidth, max_bandwidth].
        """
        if mask_V is None:
            valid_len = float(n)
        else:
            valid_counts = (~mask_V).sum(dim=1)  # [B]
            valid_len = max(float(valid_counts.max().item()), 1.0)
        bw = valid_len / self.scale
        bw = max(self.min_bandwidth, min(self.max_bandwidth, bw))
        return float(bw)

    def gaussian_window(self, P: torch.Tensor, Q: torch.Tensor, h: float) -> torch.Tensor:
        """
        Batched gaussian window exp(-(P_i - Q_j)^2 / (2 h^2)).

        P: [B, m] sliding positions
        Q: [n] or [B, n] reference positions
        returns: S [B, m, n] unnormalized
        """
        B, m = P.shape
        if Q.dim() == 1:
            Qb = Q.unsqueeze(0).expand(B, -1)  # [B, n]
        else:
            Qb = Q  # assume [B,n]
        diff = P.unsqueeze(2) - Qb.unsqueeze(1)  # [B,m,n]
        S = torch.exp(- (diff * diff) / (2.0 * (h * h)))
        return S

    def _feature_attention(
        self,
        X: torch.Tensor,
        Y: torch.Tensor,
        mask_U: Optional[torch.Tensor],
        mask_V: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Return exp-scaled feature affinities A [B, m, n], exactly 0 on padded pairs."""
        logits = torch.matmul(self.E_S(X), self.E_R(Y).transpose(-2, -1)) * self.dim_scale

        # Combine whichever masks are present (broadcast to [B, m, n]).
        pad = None
        if mask_U is not None:
            pad = mask_U.unsqueeze(-1)          # [B, m, 1]
        if mask_V is not None:
            col = mask_V.unsqueeze(1)           # [B, 1, n]
            pad = col if pad is None else (pad | col)

        if pad is None:
            return torch.exp(logits - logits.amax(dim=-1, keepdim=True))

        # Exclude padding BEFORE taking the max, so a large padded logit cannot
        # push every valid entry to underflow.
        logits = logits.masked_fill(pad, float("-inf"))
        row_max = logits.amax(dim=-1, keepdim=True)
        # Fully-padded rows have max -inf; shift by 0 instead to avoid -inf - -inf = NaN.
        row_max = torch.where(torch.isfinite(row_max), row_max, torch.zeros_like(row_max))
        return torch.exp(logits - row_max)  # exp(-inf) == 0 on padded pairs

    def forward(
        self,
        U: torch.Tensor,
        V: torch.Tensor,
        mask_U: Optional[torch.Tensor] = None,
        mask_V: Optional[torch.Tensor] = None,
        P_U: Optional[torch.Tensor] = None,
        Q_V: Optional[torch.Tensor] = None
    ):
        """
        U: [B, m, d]
        V: [B, n, d]
        mask_U/mask_V: [B, L] boolean True=padding
        P_U: optional initial [m] or [B,m] positions (cast to U's dtype/device)
        Q_V: optional positions [n] or [B,n] (cast to U's dtype/device)

        Returns (out_U, out_V, attn_U, attn_V, W); see the class docstring.
        """
        B, m, _ = U.shape
        n = V.shape[1]
        device, dtype = U.device, U.dtype
        eps = self.eps

        # Reference positions of V: 1..n unless supplied.
        if Q_V is None:
            Q_V = torch.arange(1, n + 1, device=device, dtype=dtype)  # [n]
        else:
            # Integer/other-precision positions would otherwise fail in matmul/bmm.
            Q_V = Q_V.to(device=device, dtype=dtype)

        # Initial sliding positions of U: evenly spread over [1, n] unless supplied.
        if P_U is None:
            P_init = torch.linspace(1.0, float(n), steps=m, device=device, dtype=dtype)  # [m]
            P_U = P_init.unsqueeze(0).expand(B, -1)  # [B, m]
        else:
            P_U = P_U.to(device=device, dtype=dtype)
            if P_U.dim() == 1:
                P_U = P_U.unsqueeze(0).expand(B, -1)

        bw = self.compute_bandwidth(mask_V, n)

        # iterative sliding attention
        X_cur = U
        Y_cur = V
        for _ in range(self.steps):
            # feature attention, recomputed from the current embeddings
            A = self._feature_attention(X_cur, Y_cur, mask_U, mask_V)  # [B,m,n]

            # spatial attention around the current sliding positions
            S = self.gaussian_window(P_U, Q_V, bw)
            W_raw = A * S

            # row-normalize
            W = W_raw / W_raw.sum(dim=-1, keepdim=True).clamp_min(eps)

            # update sliding positions: weighted average of reference positions
            if Q_V.dim() == 1:
                P_U = torch.matmul(W, Q_V)
            else:
                P_U = torch.bmm(W, Q_V.unsqueeze(-1)).squeeze(-1)

            # update U embeddings (residual) from the current V embeddings
            X_cur = torch.bmm(W, self.E_V(Y_cur)) + X_cur

            # update V embeddings (residual) using the transposed, re-normalized
            # weights and the already-updated U embeddings
            W_T_raw = W_raw.transpose(-2, -1)
            W_T = W_T_raw / W_T_raw.sum(dim=-1, keepdim=True).clamp_min(eps)
            Y_cur = torch.bmm(W_T, self.E_U(X_cur)) + Y_cur

        # final (unnormalized) attention maps
        attn_U = W_raw
        attn_V = W_raw.transpose(-2, -1)

        # mask outputs
        if mask_U is not None:
            X_cur = X_cur.masked_fill(mask_U.unsqueeze(-1), 0.0)
        if mask_V is not None:
            Y_cur = Y_cur.masked_fill(mask_V.unsqueeze(-1), 0.0)

        return X_cur, Y_cur, attn_U, attn_V, W


class TripleInterAttention(nn.Module):
    """
    Apply SlidingInterAttention twice: Ag<->H and Ag<->L, then merge antigen updates.

    The same sliding-attention module (and therefore the same weights) is used
    for both pairs. Antigen outputs are merged as
        out_ag = H_weight * out_ag_h + (1 - H_weight) * out_ag_l
    """
    def __init__(self, sliding_attn: "SlidingInterAttention", H_weight: float = 0.5):
        super().__init__()
        self.sliding = sliding_attn
        self.H_weight = H_weight

    def pad_to_pairwise(self, pad_q: Optional[torch.Tensor], pad_k: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Combine [B, Lq] and [B, Lk] padding masks into a [B, Lq, Lk] pairwise mask
        (True where either side is padding). Returns None if either mask is None.
        """
        if pad_q is None or pad_k is None:
            return None
        return pad_q.unsqueeze(2) | pad_k.unsqueeze(1)

    def forward(
        self,
        H_embed: torch.Tensor,
        L_embed: torch.Tensor,
        Ag_embed: torch.Tensor,
        h_pad_mask: Optional[torch.Tensor] = None,
        l_pad_mask: Optional[torch.Tensor] = None,
        ag_pad_mask: Optional[torch.Tensor] = None,
    ):
        """
        Run Ag<->H and Ag<->L sliding attention and merge the antigen updates.

        Returns:
            (out_h, out_l, out_ag, attn_h_ag, attn_l_ag, (attn_ag_h, attn_ag_l))
            where the attention maps are the 4th (Ag->X) and 3rd (X->Ag) outputs
            of the respective sliding-attention calls.
        """
        # Padding is handled inside the sliding module via per-sequence masks,
        # so no pairwise mask is built here (use pad_to_pairwise if needed).

        # Sliding attention Ag <-> H
        out_ag_h, out_h, attn_ag_h, attn_h_ag, _ = self.sliding(
            Ag_embed, H_embed, mask_U=ag_pad_mask, mask_V=h_pad_mask
        )

        # Sliding attention Ag <-> L
        out_ag_l, out_l, attn_ag_l, attn_l_ag, _ = self.sliding(
            Ag_embed, L_embed, mask_U=ag_pad_mask, mask_V=l_pad_mask
        )

        # Merge antigen updates as a convex combination weighted by H_weight
        out_ag = self.H_weight * out_ag_h + (1 - self.H_weight) * out_ag_l

        return out_h, out_l, out_ag, attn_h_ag, attn_l_ag, (attn_ag_h, attn_ag_l)