"""Set Transformer attention blocks."""

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from .feedforward import rFF


def masked_softmax(logits, mask, dim=-1):
    """Softmax over ``dim`` that assigns zero probability where ``mask`` is false.

    Args:
        logits: Tensor of unnormalized scores.
        mask: ``None`` for a plain softmax, or a tensor broadcastable to
            ``logits`` whose true/non-zero entries mark positions that may
            receive probability. Non-boolean masks (e.g. ``int`` or ``float``
            0/1 tensors) are converted with ``mask != 0``.
        dim: Dimension to normalize over.

    Returns:
        Tensor shaped like ``logits``. Masked-out positions are exactly zero,
        and a slice with no valid position is all zeros (never NaN).
    """
    if mask is None:
        return F.softmax(logits, dim=dim)
    # masked_fill and ``~`` require a boolean mask; accept 0/1 numeric masks too.
    if mask.dtype != torch.bool:
        mask = mask != 0
    logits = logits.masked_fill(~mask, torch.finfo(logits.dtype).min)
    attn = F.softmax(logits, dim=dim)
    # In a fully-masked slice every logit equals the fill value, so softmax
    # spreads mass uniformly; zero masked positions explicitly.
    attn = attn * mask.to(attn.dtype)
    # Renormalize; the clamp keeps fully-masked slices at 0 instead of 0/0 = NaN.
    attn_sum = attn.sum(dim=dim, keepdim=True).clamp(min=1e-9)
    return attn / attn_sum


class MAB(nn.Module):
    """Multihead Attention Block with an optional key mask.

    Projects queries ``Q`` and keys/values ``K`` to width ``dim_V``, runs
    scaled dot-product attention with ``num_heads`` heads, then applies
    ``O = ln0(Q_proj + fc_o(attn))`` followed by ``O = ln1(O + ff(O))``.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Model/output width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads (must be positive).
        ln: If True, use LayerNorm after each residual; otherwise identity.
        dropout: Dropout probability applied to the attention weights and
            also passed to the ``rFF`` feed-forward block.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim_V``.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=True, dropout=0.0):
        super().__init__()
        # Fail fast: otherwise the head split in forward() dies with an opaque
        # view() shape error (or a ZeroDivisionError for num_heads == 0).
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be divisible by num_heads ({num_heads})"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.head_dim = dim_V // num_heads

        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

        self.ln0 = nn.LayerNorm(dim_V) if ln else nn.Identity()
        self.ln1 = nn.LayerNorm(dim_V) if ln else nn.Identity()

        self.ff = rFF(dim_V, expansion=1, dropout=dropout)
        self.attn_dropout = nn.Dropout(dropout)

    def forward(self, Q, K, mask=None):
        """Attend from the query set ``Q`` to the key/value set ``K``.

        Args:
            Q: Query tensor of shape ``(B, N_q, dim_Q)``.
            K: Key/value tensor of shape ``(B, N_k, dim_K)``.
            mask: Optional tensor of shape ``(B, N_k)`` (or ``(1, N_k)``,
                broadcast over the batch). True/non-zero entries mark keys that
                may be attended; false/zero entries are excluded. Note this is
                the inverse of ``nn.MultiheadAttention``'s ``key_padding_mask``.

        Returns:
            Tensor of shape ``(B, N_q, dim_V)``.
        """
        B, _, _ = Q.shape
        _, N_k, _ = K.shape
        Q_proj = self.fc_q(Q)
        Q_ = self._split_heads(Q_proj)
        K_ = self._split_heads(self.fc_k(K))
        V_ = self._split_heads(self.fc_v(K))

        A = torch.bmm(Q_, K_.transpose(1, 2)) / math.sqrt(self.head_dim)

        mask_h = None
        if mask is not None:
            # Coerce 0/1 numeric masks to bool, then repeat the per-key mask for
            # every head so it matches the batch-major (B * num_heads) layout
            # produced by _split_heads.
            mask_h = (
                mask.to(torch.bool)[:, None, None, :]
                .expand(B, self.num_heads, 1, N_k)
                .reshape(B * self.num_heads, 1, N_k)
            )

        A = masked_softmax(A, mask_h, dim=-1)
        A = self.attn_dropout(A)
        O = self._merge_heads(torch.bmm(A, V_), B)
        O = self.fc_o(O)
        # Residual uses the projected queries so both operands have width dim_V.
        O = self.ln0(Q_proj + O)
        O = self.ln1(O + self.ff(O))
        return O

    def _split_heads(self, x):
        """Reshape ``(B, N, dim_V)`` to ``(B * num_heads, N, head_dim)``, batch-major."""
        B, N, _ = x.shape
        x = x.view(B, N, self.num_heads, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        return x.reshape(B * self.num_heads, N, self.head_dim)

    def _merge_heads(self, x, B):
        """Inverse of ``_split_heads``: ``(B * num_heads, N, head_dim)`` -> ``(B, N, dim_V)``."""
        N = x.shape[1]
        x = x.view(B, self.num_heads, N, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        return x.reshape(B, N, self.dim_V)


class SAB(nn.Module):
    """Set Attention Block: self-attention of a set over its own elements.

    A thin wrapper around ``MAB`` where the same tensor serves as both the
    query set and the key/value set.
    """

    def __init__(self, dim_in, dim_out, num_heads, ln=True, dropout=0.0):
        super().__init__()
        # Queries and keys come from the same set, so both use width dim_in.
        self.mab = MAB(
            dim_Q=dim_in,
            dim_K=dim_in,
            dim_V=dim_out,
            num_heads=num_heads,
            ln=ln,
            dropout=dropout,
        )

    def forward(self, X, mask=None):
        """Return ``MAB(X, X)``; ``mask`` (if given) selects attendable elements of ``X``."""
        attend = self.mab
        return attend(Q=X, K=X, mask=mask)


class ISAB(nn.Module):
    """Induced Set Attention Block.

    A learned bank of ``num_inds`` inducing vectors first attends over the
    input set ``X`` (respecting ``mask``), producing a fixed-size summary.
    ``X`` then attends over that summary to produce the output.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=True, dropout=0.0):
        super().__init__()
        inducing_init = torch.randn(1, num_inds, dim_out) * 0.02
        self.I = nn.Parameter(inducing_init)
        # Inducing vectors (width dim_out) query the input set (width dim_in).
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln, dropout=dropout)
        # The input set (width dim_in) queries the dim_out-wide summary.
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        """Compress ``X`` through the inducing vectors, then expand back to ``X``'s length."""
        batch_size = X.size(0)
        inducing = self.I.expand(batch_size, -1, -1)
        summary = self.mab0(inducing, X, mask=mask)
        # Every summary row is real data, so the second attention is unmasked.
        return self.mab1(X, summary, mask=None)


class PMA(nn.Module):
    """Pooling by Multihead Attention.

    ``num_seeds`` learned seed vectors attend over the input set, pooling it
    into ``num_seeds`` output vectors of width ``dim``.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=True, dropout=0.0):
        super().__init__()
        seed_init = torch.randn(1, num_seeds, dim) * 0.02
        self.S = nn.Parameter(seed_init)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln, dropout=dropout)

    def forward(self, X, mask=None):
        """Pool ``X`` into ``num_seeds`` vectors; ``mask`` selects attendable elements of ``X``."""
        seeds = self.S.expand(X.size(0), -1, -1)
        pooled = self.mab(seeds, X, mask=mask)
        return pooled

