import torch
from torch import nn
import torch.nn.functional as F
from set_transformer_modules import *
import dgl
from dgl.nn.pytorch.glob import Set2Set
        
class GeneralPoolingSet(nn.Module):
    """Learnable generalized-mean pooling over the set axis (dim -2).

    Features are pooled by up to two branches:

    * positive branch: ``(sum_i relu(x_i) ** p) ** (1 / p)`` (up to ``eps``),
      a power mean that moves towards max-pooling as ``p`` grows;
    * negative branch: ``(sum_i relu(x_i) ** -p) ** (-1 / p)`` over the
      entries that are ``>= eps`` (smaller ones are pushed to ``1 / eps`` so
      they contribute ~0), which moves towards min-pooling as ``p`` grows.

    Each branch output is multiplied by ``(1 / mask_sum) ** q`` where
    ``mask_sum`` is ``_mask`` summed over the set axis, and any feature whose
    entries are all ``< eps`` across the set pools to exactly 0.

    ``general_mode`` selects the branches: 0 = both (the last axis is split at
    ``hidden_dim // 2`` and the two results are concatenated), 1 = positive
    only, 2 = negative only, 3 = neither (``forward`` returns ``None``).
    """

    def __init__(self, hidden_dim, general_mode=0, eps=1e-12):
        super(GeneralPoolingSet, self).__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        # When True the exponents are 1 + softplus(raw), i.e. always > 1.
        self.use_reparameterization = True
        self.p_pos = nn.Parameter(torch.FloatTensor([0.0 if self.use_reparameterization else 1.0]))
        self.p_neg = nn.Parameter(torch.FloatTensor([0.0 if self.use_reparameterization else 1.0]))
        # Set-size normalization exponents; 0 means no normalization.
        self.q_pos = nn.Parameter(torch.FloatTensor([0.0]))
        self.q_neg = nn.Parameter(torch.FloatTensor([0.0]))

    def _exponent(self, raw):
        """Map a raw exponent parameter to the power actually used."""
        if self.use_reparameterization:
            # softplus(x) == log(exp(x) + 1), but without overflowing to inf
            # for large x.
            return 1. + F.softplus(raw)
        return raw

    def forward(self, _mask, h):
        """Pool ``h`` over dim -2.

        Args:
            _mask: element mask; summed over dim -2 for the ``q`` scaling
                (assumed shaped like ``h`` with a size-1 last axis -- confirm).
            h: element features; the set axis is dim -2.

        Returns:
            The pooled features (positive and/or negative branch), or ``None``
            when ``general_mode`` disables both branches.
        """
        if not (self.use_pos or self.use_neg):
            return None

        half = self.hidden_dim // 2
        inv_size = 1. / torch.sum(_mask, dim=-2)

        if self.use_pos:
            h_pos = F.relu(h[..., :half] if self.use_neg else h)
            allzero_pos = (h_pos < self.eps).all(dim=-2)
            p_pos = self._exponent(self.p_pos)
            # (sum x^p)^(1/p) computed in log space for stability.
            pos = torch.exp(torch.logsumexp(torch.log(h_pos + self.eps) * p_pos, dim=-2) / p_pos)
            pos = pos * (inv_size ** self.q_pos)
            pos = pos.masked_fill(allzero_pos, 0.)

        if self.use_neg:
            h_neg = F.relu(h[..., half:] if self.use_pos else h)
            mask_neg = h_neg < self.eps
            allzero_neg = mask_neg.all(dim=-2)
            # Out-of-place on purpose: ReLU's backward reuses its output, so
            # writing into h_neg in place would break autograd. Near-zero
            # entries become huge so that x^-p ~ 0 and they drop out.
            h_neg = h_neg.masked_fill(mask_neg, 1. / self.eps)
            p_neg = self._exponent(self.p_neg)
            # (sum x^-p)^(-1/p) computed in log space for stability.
            neg = torch.exp(-torch.logsumexp(-torch.log(h_neg + self.eps) * p_neg, dim=-2) / p_neg)
            neg = neg * (inv_size ** self.q_neg)
            neg = neg.masked_fill(allzero_neg, 0.)

        if self.use_pos and self.use_neg:
            return torch.cat((pos, neg), dim=-1)
        if self.use_pos:
            return pos
        return neg
    
class Set2SetPooling(nn.Module):
    """Set2Set readout (DGL) over padded sets.

    Every padded set becomes an edgeless DGL graph whose nodes are the set's
    valid elements; DGL's ``Set2Set`` then pools the node features per graph.
    """

    def __init__(self, input_dim, n_iters):
        super(Set2SetPooling, self).__init__()
        self.input_dim = input_dim
        # Positional args: feature size, number of iterations, number of LSTM layers.
        self.s2s = Set2Set(input_dim, n_iters, 1)

    def forward(self, mask, h):
        """Pool ``h`` (assumed ``(B, N, input_dim)`` -- confirm) per set.

        ``mask`` marks valid elements with values > 0.5 (assumed ``(B, N, 1)``).
        """
        # One boolean mask drives both the per-set node counts and the feature
        # selection, so they always agree, and the counts are Python ints as
        # dgl.graph expects for num_nodes.
        keep = mask.squeeze(-1) > 0.5
        lengths = keep.sum(dim=-1).tolist()
        graphs = dgl.batch([dgl.graph(([], []), num_nodes=n) for n in lengths]).to(h.device)
        # Row-major flattening keeps each set's valid elements contiguous and
        # in batch order, matching the node order of the batched graph.
        feats = h.reshape(-1, self.input_dim)[keep.reshape(-1)]
        return self.s2s(graphs, feats)
        
        
class SetTransformer(nn.Module):
    """Set Transformer: two attention encoder blocks, a PMA readout, a linear head.

    ``block_type == 0`` builds the encoder from two ``ISAB`` blocks (using
    ``num_inds``); any other value uses two ``SAB`` blocks. ``ISAB``, ``SAB``
    and ``PMA`` come from ``set_transformer_modules``.

    NOTE(review): ``num_outputs`` is accepted but never used here -- confirm
    whether it was meant to reach ``PMA``.
    """

    def __init__(self, dim_input, num_outputs, dim_output,
            num_inds=32, dim_hidden=128, num_heads=4, ln=False, block_type=0):
        super(SetTransformer, self).__init__()
        # Blocks are built in the same order as the layers they feed, which
        # keeps parameter initialization reproducible.
        if block_type == 0:
            first = ISAB(dim_input, dim_hidden, num_inds, num_heads=num_heads, ln=ln)
            second = ISAB(dim_hidden, dim_hidden, num_inds, num_heads=num_heads, ln=ln)
        else:
            first = SAB(dim_input, dim_hidden, num_heads=num_heads, ln=ln)
            second = SAB(dim_hidden, dim_hidden, num_heads=num_heads, ln=ln)
        self.enc = nn.Sequential(first, second)

        readout = PMA(dim_hidden, dim_hidden, num_heads=num_heads, ln=ln)
        head = nn.Linear(dim_hidden, dim_output)
        self.dec = nn.Sequential(readout, head)

    def forward(self, X, mask):
        """Encode ``X``, pool it with PMA and project with the linear head.

        ``mask`` has a trailing size-1 axis; entries < 0.5 mark padding. The
        result is ``squeeze()``-d, so every size-1 dimension is dropped.
        """
        # True where padded; presumably consumed by the attention blocks as a
        # key-padding mask -- confirm in set_transformer_modules.
        padding = mask.squeeze(-1) < 0.5
        encoded = X
        for block in self.enc:
            encoded = block(encoded, mask=padding)
        readout, head = self.dec[0], self.dec[1]
        pooled = readout(encoded, mask=padding)
        return head(pooled).squeeze()
    
class SetModel(nn.Module):
    """Element-wise linear encoder, a set pooling, and a linear prediction head.

    Args:
        num_layers: number of ``nn.Linear`` layers applied to each element.
        input_dim: feature size of each set element.
        hidden_dim: width of the element embeddings fed to the pooling.
        output_dim: size of the prediction head's output.
        pooling_type: one of ``'sum'``, ``'mean'``, ``'max'``, ``'min'``,
            ``'general'`` (GeneralPoolingSet) or ``'set2set'`` (Set2SetPooling).
        general_mode: GeneralPoolingSet mode for ``'general'``; reused as the
            number of Set2Set iterations for ``'set2set'``.
        use_bias: whether the prediction head has a bias.

    Raises:
        NotImplementedError: if ``pooling_type`` is not recognized.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, pooling_type, general_mode=0, use_bias=False):
        super(SetModel, self).__init__()
        self.pooling_type = pooling_type
        # Element-wise linear layers (see the NOTE in forward about activations).
        self.layers = nn.ModuleList()
        for i in range(num_layers):
            self.layers.append(nn.Linear(hidden_dim if i > 0 else input_dim, hidden_dim))
        self.linears_prediction = nn.Linear(hidden_dim, output_dim, bias=use_bias)

        # Stateless poolings are static methods rather than lambdas so that the
        # whole model stays picklable (e.g. torch.save(model)).
        if pooling_type == 'sum':
            self.pool = self._sum_pool
        elif pooling_type == 'mean':
            self.pool = self._mean_pool
        elif pooling_type == 'max':
            self.pool = self._max_pool
        elif pooling_type == 'general':
            self.pool = GeneralPoolingSet(hidden_dim, general_mode=general_mode)
        elif pooling_type == 'min':
            self.pool = self._min_pool
        elif pooling_type == 'set2set':
            self.pool = Set2SetPooling(hidden_dim, n_iters=general_mode)
            # The Set2Set readout is 2 * hidden_dim wide, so the head is rebuilt
            # (same attribute name, so submodule order is unchanged).
            self.linears_prediction = nn.Linear(hidden_dim * 2, output_dim, bias=use_bias)
        else:
            raise NotImplementedError(f"unknown pooling_type: {pooling_type!r}")

    @staticmethod
    def _sum_pool(mask, h):
        """Sum over the set axis (padded rows are already zeroed by forward)."""
        return torch.sum(h, dim=-2)

    @staticmethod
    def _mean_pool(mask, h):
        """Sum over the set axis divided by the mask sum (set size for a 0/1 mask)."""
        return torch.sum(h, dim=-2) / torch.sum(mask, dim=-2)

    @staticmethod
    def _max_pool(mask, h):
        """Max over the set axis.

        Padded rows are 0 and forward applies ReLU first, so they never exceed
        a valid element.
        """
        return torch.max(h, dim=-2)[0]

    @staticmethod
    def _min_pool(mask, h):
        """Min over the set axis, ignoring elements whose mask is < 0.5."""
        # Padded rows are shifted by +1e9 so they never win the min.
        return torch.min(h + ((mask < 0.5).float() * 1e9), dim=-2)[0]

    def forward(self, h, mask):
        """Embed each element, zero the padding, pool, and predict.

        ``mask`` multiplies ``h`` element-wise (assumed ``(B, N, 1)`` against
        ``h`` of ``(B, N, input_dim)`` -- confirm). The output is
        ``squeeze()``-d, so every size-1 dimension (including the batch when
        B == 1) is dropped.
        """
        for layer in self.layers:
            h = layer(h)
        # NOTE(review): ReLU is applied once after the whole stack, so for
        # num_layers > 1 the layers compose into a single affine map. If a
        # per-layer nonlinearity was intended it belongs inside the loop; that
        # would change trained-model outputs, so it is left as is.
        h = F.relu(h)
        h = h * mask
        soft_h = self.pool(mask, h)
        soft_output = self.linears_prediction(soft_h).squeeze()
        return soft_output