import torch
import torch.nn.functional as F
from torch import nn

# Small stabilising constant: added to the variance before the square root in the
# std aggregators and to |moment| before the n-th root in aggregate_moment.
EPS = 1e-5

class GeneralPooling(nn.Module):
    """Learnable power-mean style pooling over the second-to-last axis.

    The input is passed through ReLU and each channel ``c`` of the last axis is
    reduced over axis ``-2`` as::

        out_c = (sum_i (h_ic + eps) ** p_c) ** (1 / p_c) * N ** (-q_c)

    where ``N`` is the number of pooled elements. The first
    ``ceil(hidden_dim / 2)`` channels use a positive exponent
    ``p = 1 + softplus(p_pos)``; the remaining ``floor(hidden_dim / 2)``
    channels use a negative exponent ``p = -(1 + softplus(p_neg))``.
    ``q_pos`` / ``q_neg`` are the matching learnable normalisation exponents.

    Args:
        hidden_dim: Number of channels on the last axis of the input.
        general_mode: Integer decoded into the ``use_pos`` / ``use_neg`` flags.
            The flags are stored on the module but not consulted by ``forward``.
        eps: Threshold below which an activation counts as zero; also added
            inside the logarithm for stability.
    """

    def __init__(self, hidden_dim, general_mode=0, eps=1e-12):
        super().__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        self.use_reparameterization = True
        # With the reparameterisation p = 1 + softplus(raw), the raw value starts at 0.
        p_init = 0.0 if self.use_reparameterization else 1.0
        self.p_pos = nn.Parameter(torch.tensor([p_init], dtype=torch.float32))
        self.p_neg = nn.Parameter(torch.tensor([p_init], dtype=torch.float32))
        self.q_pos = nn.Parameter(torch.tensor([0.0], dtype=torch.float32))
        self.q_neg = nn.Parameter(torch.tensor([0.0], dtype=torch.float32))

    def forward(self, h):
        """Pool ``h`` over axis ``-2``.

        Args:
            h: Tensor with at least two axes whose last axis holds
                ``hidden_dim`` channels, e.g. ``(batch, elements, hidden_dim)``.

        Returns:
            Tensor with axis ``-2`` removed. Channels whose activations are all
            below ``eps`` along the pooled axis are set to exactly 0.
        """
        h = F.relu(h)
        # Record the pooled-axis length now; after the reduction axis -2 is a
        # different axis and must not be used for the normalisation.
        num_elems = h.shape[-2]
        n_pos = (self.hidden_dim + 1) // 2
        n_neg = self.hidden_dim // 2

        near_zero = h < self.eps
        allzero = near_zero.all(dim=-2)

        # In the negative-exponent channels a zero would dominate the sum, so
        # zeros are replaced by a huge value that contributes ~nothing.
        # Done out of place: the ReLU output is needed intact by autograd.
        is_neg_channel = torch.arange(h.shape[-1], device=h.device) >= n_pos
        h = h.masked_fill(near_zero & is_neg_channel, 1. / self.eps)

        # softplus(x) == log(exp(x) + 1), computed without overflow for large x.
        p_pos = 1. + F.softplus(self.p_pos)
        p_neg = 1. + F.softplus(self.p_neg)
        ps = torch.cat((p_pos.repeat(n_pos), -p_neg.repeat(n_neg)), dim=0)
        qs = torch.cat((self.q_pos.repeat(n_pos), self.q_neg.repeat(n_neg)), dim=0)

        # (sum_i h_i ** p) ** (1 / p), evaluated in log space.
        pooled = torch.exp(torch.logsumexp(torch.log(h + self.eps) * ps, dim=-2) / ps)
        # q = 0 leaves the power sum untouched; q = 1 divides it by N.
        pooled = pooled * ((1. / max(num_elems, 1)) ** qs)
        return pooled.masked_fill(allzero, 0.)

def aggregate_mean(h):
    """Return the arithmetic mean of ``h`` taken along axis 1 (that axis is dropped)."""
    averaged = h.mean(dim=1)
    return averaged


def aggregate_max(h):
    """Return the element-wise maximum of ``h`` along axis 1 (values only, no indices)."""
    reduced = h.max(dim=1)
    return reduced.values


def aggregate_min(h):
    """Return the element-wise minimum of ``h`` along axis 1 (values only, no indices)."""
    reduced = h.min(dim=1)
    return reduced.values

class STDAggregation(nn.Module):
    """Module wrapper computing ``sqrt(aggregate_var(h) + EPS)``.

    ``hidden_dim`` is accepted by the constructor but not used; the module
    holds no parameters.
    """

    def __init__(self, hidden_dim):
        super().__init__()

    def forward(self, h):
        """Return the EPS-stabilised standard deviation of ``h``."""
        variance = aggregate_var(h)
        return torch.sqrt(variance + EPS)


def aggregate_std(h):
    """Return ``sqrt(aggregate_var(h) + EPS)``, the EPS-stabilised standard deviation."""
    variance = aggregate_var(h)
    stabilised = variance + EPS
    return torch.sqrt(stabilised)


def aggregate_var(h):
    """Return the variance of ``h`` along axis -2, computed as E[x^2] - E[x]^2.

    The difference is passed through ReLU so that round-off can never yield a
    negative variance.
    """
    first_moment = h.mean(dim=-2)
    second_moment = (h * h).mean(dim=-2)
    return torch.relu(second_moment - first_moment * first_moment)


def aggregate_moment(h, n=3, eps=None):
    """Return the signed n-th root of the n-th central moment along axis 1.

    For every slice along axis 1 this computes ``(E[(X - E[X])^n])^(1/n)``,
    keeping the sign of the moment so odd orders stay well defined.

    Args:
        h: Input tensor; axis 1 is reduced.
        n: Order of the central moment.
        eps: Value added to ``|moment|`` before the root for stability.
            Defaults to the module-level ``EPS`` when ``None``.

    Returns:
        Tensor shaped like ``h`` with axis 1 removed.
    """
    if eps is None:
        eps = EPS
    h_mean = torch.mean(h, dim=1, keepdim=True)
    # Reduce over axis 1 only; without ``dim`` the mean would collapse every
    # entry of the batch into one scalar shared by all rows.
    h_n = torch.mean(torch.pow(h - h_mean, n), dim=1)
    rooted_h_n = torch.sign(h_n) * torch.pow(torch.abs(h_n) + eps, 1. / n)
    return rooted_h_n


def aggregate_moment_3(h):
    """Return ``aggregate_moment`` of ``h`` with the order fixed to 3."""
    order = 3
    return aggregate_moment(h, n=order)


def aggregate_moment_4(h):
    """Return ``aggregate_moment`` of ``h`` with the order fixed to 4."""
    order = 4
    return aggregate_moment(h, n=order)


def aggregate_moment_5(h):
    """Return ``aggregate_moment`` of ``h`` with the order fixed to 5."""
    order = 5
    return aggregate_moment(h, n=order)


def aggregate_sum(h):
    """Return the sum of ``h`` taken along axis 1 (that axis is dropped)."""
    total = h.sum(dim=1)
    return total


# Registry mapping an aggregator name to a factory. Each factory takes the
# hidden dimension and returns a callable that performs the aggregation.
AGGREGATORS = {
    # These four names all build the same learnable GeneralPooling module.
    'mean': lambda hdim: GeneralPooling(hdim),
    'sum': lambda hdim: GeneralPooling(hdim),
    'max': lambda hdim: GeneralPooling(hdim),
    'min': lambda hdim: GeneralPooling(hdim),
    # Standard deviation is provided as a (parameter-free) module.
    'std': lambda hdim: STDAggregation(hdim),
    # The remaining factories ignore hdim and return plain functions.
    'var': lambda hdim: aggregate_var,
    'moment3': lambda hdim: aggregate_moment_3,
    'moment4': lambda hdim: aggregate_moment_4,
    'moment5': lambda hdim: aggregate_moment_5,
}
