"""Set transformer code adapted from https://github.com/juho-lee/set_transformer (MIT License)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MAB(nn.Module):
    """Multihead attention block with residual connections.

    Queries are computed from ``Q`` and keys/values from ``K`` through linear
    projections to ``dim_V`` features.  The projected tensors are cut into
    heads along the feature axis, attention is evaluated per head, and the
    heads are concatenated back before a residual feed-forward step.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Feature size of the projections and of the output.
        num_heads: Number of attention heads; each head receives
            ``dim_V // num_heads`` features.
        ln: If true, layer normalisation is applied after the attention
            residual and again after the feed-forward residual.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        # The creation order below fixes both the parameter ordering and the
        # order in which the random number generator is consumed at init.
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` to ``K``.

        Args:
            Q: Query tensor of shape ``(batch, n_q, dim_Q)``.
            K: Key/value tensor of shape ``(batch, n_k, dim_K)``.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        queries = self.fc_q(Q)
        keys = self.fc_k(K)
        values = self.fc_v(K)

        batch_size = queries.size(0)
        head_width = self.dim_V // self.num_heads

        def stack_heads(tensor):
            # Move the heads from the feature axis onto the batch axis so a
            # single bmm handles every head at once.
            return torch.cat(tensor.split(head_width, 2), 0)

        q_heads = stack_heads(queries)
        k_heads = stack_heads(keys)
        v_heads = stack_heads(values)

        # Scores are scaled by sqrt(dim_V), i.e. the full projection width.
        scores = q_heads.bmm(k_heads.transpose(1, 2)) / math.sqrt(self.dim_V)
        weights = torch.softmax(scores, 2)
        attended = q_heads + weights.bmm(v_heads)

        # Undo the head stacking: batch-axis chunks go back to the features.
        out = torch.cat(attended.split(batch_size, 0), 2)

        first_norm = getattr(self, 'ln0', None)
        if first_norm is not None:
            out = first_norm(out)

        out = out + F.relu(self.fc_o(out))

        second_norm = getattr(self, 'ln1', None)
        if second_norm is not None:
            out = second_norm(out)
        return out

class SAB(nn.Module):
    """Set attention block: a MAB in which the input set attends to itself.

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the output.
        num_heads: Number of attention heads, forwarded to the MAB.
        ln: Whether the MAB applies layer normalisation.
    """

    def __init__(self, dim_in, dim_out, num_heads, ln=False):
        super().__init__()
        # Queries and keys are the same set, so both input widths are dim_in.
        self_attention = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
        self.mab = self_attention

    def forward(self, X):
        """Return the MAB output with ``X`` used as both queries and keys."""
        queries = keys = X
        return self.mab(queries, keys)

class ISAB(nn.Module):
    """Induced set attention block.

    A learned set of ``num_inds`` inducing points first attends to the input
    (``mab0``); the input then attends to that result (``mab1``).

    Args:
        dim_in: Feature size of the input set.
        dim_out: Feature size of the inducing points and of the output.
        num_heads: Number of attention heads, forwarded to both MABs.
        num_inds: Number of inducing points.
        ln: Whether the MABs apply layer normalisation.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()
        # Allocate the inducing points uninitialised, then Xavier-initialise.
        inducing_points = torch.Tensor(1, num_inds, dim_out)
        nn.init.xavier_uniform_(inducing_points)
        self.I = nn.Parameter(inducing_points)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        """Map ``X`` of shape ``(batch, n, dim_in)`` to ``(batch, n, dim_out)``."""
        batch_size = X.size(0)
        # One copy of the shared inducing points per batch element.
        inducing = self.I.repeat(batch_size, 1, 1)
        summary = self.mab0(inducing, X)
        return self.mab1(X, summary)

class PMA(nn.Module):
    """Pooling by multihead attention with randomly sampled seed vectors.

    On every forward pass, ``num_seeds`` query vectors per batch element are
    drawn from a learned diagonal Gaussian with mean ``edges_mu`` and standard
    deviation ``exp(edges_logsigma)`` (one distribution shared by all seeds).
    These queries attend to the input set through a MAB.  Sampling is done on
    each call; ``forward`` does not look at the module's train/eval mode.

    Args:
        dim: Feature size of the input set, of the seeds and of the output.
        num_heads: Number of attention heads, forwarded to the MAB.
        num_seeds: Number of seed vectors, i.e. the size of the output set.
        ln: Whether the MAB applies layer normalisation.
    """

    def __init__(self, dim, num_heads, num_seeds, ln=False):
        super().__init__()
        self.num_seeds = num_seeds
        self.edges_mu = nn.Parameter(torch.randn(1, 1, dim))
        # The zeros are only a placeholder; Xavier init overwrites them.
        self.edges_logsigma = nn.Parameter(torch.zeros(1, 1, dim))
        nn.init.xavier_uniform_(self.edges_logsigma)
        self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

    def forward(self, X):
        """Pool ``X`` into ``num_seeds`` vectors.

        Args:
            X: Input set of shape ``(batch, n, dim)``.

        Returns:
            Tensor of shape ``(batch, num_seeds, dim)``.

        Raises:
            ValueError: If ``X`` is not a 3-D tensor.
        """
        batch_size, _, _ = X.shape
        mu = self.edges_mu.expand(batch_size, self.num_seeds, -1)
        sigma = self.edges_logsigma.exp().expand(batch_size, self.num_seeds, -1)
        # Draw the noise in the parameters' dtype.  Default-dtype (float32)
        # noise would promote the seeds to float32 when the parameters are
        # half/bfloat16, which then mismatches the MAB's projection weights.
        noise = torch.randn(mu.shape, device=X.device, dtype=mu.dtype)
        edges = mu + sigma * noise

        return self.mab(edges, X)


class SetTransformer(nn.Module):
    """Encoder/decoder set model built from SAB and PMA blocks.

    The encoder is a stack of four SABs.  The decoder pools with a PMA into
    ``num_outputs`` elements, applies two more SABs and finishes with a linear
    layer to ``dim_output`` features.

    Args:
        dim_input: Feature size of the input set elements.
        num_outputs: Number of seeds used by the PMA.
        dim_output: Feature size produced by the final linear layer.
        dim_hidden: Feature size used between the blocks.
        num_heads: Number of attention heads in every block.
        ln: Whether the attention blocks apply layer normalisation.
    """

    def __init__(self, dim_input, num_outputs, dim_output,
            dim_hidden=128, num_heads=4, ln=False):
        super().__init__()

        def hidden_sab():
            # A SAB that maps hidden features to hidden features.
            return SAB(dim_hidden, dim_hidden, num_heads, ln=ln)

        # Layers are instantiated strictly in their original order.
        encoder_layers = [SAB(dim_input, dim_hidden, num_heads, ln=ln)]
        for _ in range(3):
            encoder_layers.append(hidden_sab())
        self.enc = nn.Sequential(*encoder_layers)

        decoder_layers = [
            PMA(dim_hidden, num_heads, num_outputs, ln=ln),
            hidden_sab(),
            hidden_sab(),
            nn.Linear(dim_hidden, dim_output),
        ]
        self.dec = nn.Sequential(*decoder_layers)

    def forward(self, X):
        """Run the encoder on ``X`` and feed the result to the decoder."""
        encoded = self.enc(X)
        return self.dec(encoded)



class STSet2Hypergraph(nn.Module):
    """Predict hyperedge incidences for a set of nodes with a Set Transformer.

    Node features are projected to ``d_hid``, a SetTransformer turns them into
    ``max_k`` edge vectors, and two sigmoid heads score every (edge, node)
    pair and every edge on its own.

    Args:
        max_k: Number of edge vectors requested from the SetTransformer.
        d_in: Feature size of the input nodes.
        d_hid: Hidden feature size used throughout.
    """

    def __init__(self, max_k, d_in, d_hid):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_hid)
        self.set2set = SetTransformer(d_hid, max_k, d_hid, ln=True)

        # Scores one concatenated (node, edge) feature pair.
        pair_scorer = [
            nn.Linear(2 * d_hid, d_hid),
            nn.ReLU(inplace=True),
            nn.Linear(d_hid, 1),
            nn.Sigmoid(),
        ]
        self.mlp_out = nn.Sequential(*pair_scorer)

        # Scores a single edge vector.
        edge_scorer = [nn.Linear(d_hid, 1), nn.Sigmoid()]
        self.edge_ind = nn.Sequential(*edge_scorer)

    def forward(self, x):
        """Score node/edge incidences and per-edge indicators.

        Args:
            x: Node features of shape ``(batch, n_nodes, d_in)``.

        Returns:
            Tensor of shape ``(batch, n_edges, n_nodes + 1)``: the first
            ``n_nodes`` entries of the last axis are the pair scores, and the
            final entry is the per-edge score.
        """
        nodes = self.proj_in(x)
        edges = self.set2set(nodes)

        node_count = nodes.size(1)
        edge_count = edges.size(1)

        # Broadcast both sets onto a shared (batch, edge, node, feature) grid.
        node_grid = nodes.unsqueeze(1).expand(-1, edge_count, -1, -1)
        edge_grid = edges.unsqueeze(2).expand(-1, -1, node_count, -1)
        pair_features = torch.cat((node_grid, edge_grid), dim=3)

        pair_scores = self.mlp_out(pair_features).squeeze(3)
        edge_scores = self.edge_ind(edges)
        return torch.cat((pair_scores, edge_scores), dim=-1)