from typing import Optional

import sys
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_scatter import scatter, scatter_logsumexp

class GeneralPooling(nn.Module):
    """Learnable generalized-mean pooling of rows of ``h`` grouped by ``batch``.

    ``h`` is passed through ReLU and split channel-wise into two halves:

    * the first ``(hidden_dim + 1) // 2`` channels ("positive" half) are pooled
      per group as ``(sum_i (h_i + eps) ** p) ** (1 / p) * n ** -q_pos`` with
      ``p = 1 + softplus(p_pos)``, a soft maximum;
    * the remaining channels ("negative" half) are pooled as
      ``(sum_i (h_i + eps) ** -p) ** (-1 / p) * n ** -q_neg`` with
      ``p = 1 + softplus(p_neg)``, a soft minimum. Here entries below ``eps``
      are first replaced by ``1 / eps`` so that zeros do not dominate it.

    ``n`` is the number of rows in the group. In either half, a group whose
    per-channel maximum is below ``eps`` (all zeros, or no rows at all)
    yields 0 for that channel. The two pooled halves are concatenated.

    Args:
        hidden_dim: number of channels expected in ``h``; only used to decide
            where the positive/negative split happens.
        eps: small constant used for log-stability and the zero thresholds.
    """

    def __init__(self, hidden_dim, eps=1e-12):
        super().__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        # Unconstrained exponent parameters. The effective powers are
        # 1 + softplus(p_*), which are always > 1.
        self.p_pos = nn.Parameter(torch.zeros(1))
        self.p_neg = nn.Parameter(torch.zeros(1))
        # Group-size normalisation exponents (output is scaled by n ** -q_*).
        self.q_pos = nn.Parameter(torch.zeros(1))
        self.q_neg = nn.Parameter(torch.zeros(1))

    def forward(self, h, batch, size: Optional[int] = None):
        """Pool ``h`` into ``size`` rows according to the group index ``batch``.

        Args:
            h: features, indexed as ``h[:, channel]``; rows are grouped along
                dim 0. Presumably shape ``(N, hidden_dim)``. TODO confirm
                with callers.
            batch: group index for each row of ``h`` (assumed 1-D integer
                tensor of length ``N``).
            size: number of output groups; inferred as ``batch.max() + 1``
                when omitted (0 for an empty ``batch``).

        Returns:
            Tensor with ``size`` rows: the pooled positive-half channels
            followed by the pooled negative-half channels.
        """
        if size is None:
            # max() on an empty tensor raises, so handle "no rows" explicitly.
            size = int(batch.max().item()) + 1 if batch.numel() > 0 else 0

        h = F.relu(h)
        split = (self.hidden_dim + 1) // 2
        h_pos = h[:, :split]
        h_neg = h[:, split:]

        # Rows per group, in h's dtype/device. Clamping to >= 1 only affects
        # empty groups, whose outputs are zeroed by the masks below anyway.
        # Without it, (1 / 0) ** q is inf and back-propagates NaN into q_*.
        num_nodes = scatter(h.new_ones(h.shape[0]), batch, dim=0,
                            dim_size=size, reduce='add')
        inv_count = 1. / num_nodes.clamp(min=1).unsqueeze(-1)

        # Per-group maxima, taken before any replacement below; used to detect
        # groups that are entirely (numerically) zero.
        max_pos = scatter(h_pos, batch, dim=0, dim_size=size, reduce='max')
        max_neg = scatter(h_neg, batch, dim=0, dim_size=size, reduce='max')

        # Positive half: power mean with exponent p > 1, computed in log space.
        # softplus(x) == log(exp(x) + 1) but does not overflow for large x.
        p_pos = 1. + F.softplus(self.p_pos)
        pos = scatter_logsumexp(torch.log(h_pos + self.eps) * p_pos, batch,
                                dim=0, dim_size=size)
        pos = torch.exp(pos / p_pos) * inv_count ** self.q_pos
        pos = pos.masked_fill(max_pos < self.eps, 0.)

        # Negative half: power mean with exponent -p (a soft minimum). Zeros
        # are swapped for a huge value so they are effectively ignored.
        # This must be out-of-place: h is ReLU's output, which autograd saved
        # for the backward pass, and an in-place write would make
        # backward() raise.
        h_neg = h_neg.masked_fill(h_neg < self.eps, 1. / self.eps)
        p_neg = 1. + F.softplus(self.p_neg)
        neg = scatter_logsumexp(-torch.log(h_neg + self.eps) * p_neg, batch,
                                dim=0, dim_size=size)
        neg = torch.exp(-neg / p_neg) * inv_count ** self.q_neg
        neg = neg.masked_fill(max_neg < self.eps, 0.)

        return torch.cat((pos, neg), dim=-1)
        
class GCNConvWithGNP(GCNConv):
    """``GCNConv`` variant whose ``aggregate`` hook pools with ``GeneralPooling``.

    Messages are grouped by target ``index`` and reduced by a learnable
    ``GeneralPooling`` module sized to ``out_channels``, rather than by a
    fixed reduction.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        # Unless the caller chose an ``aggr`` explicitly, pass aggr=None.
        # The reduction is provided by the aggregate() override below.
        options = {'aggr': None, **kwargs}
        super().__init__(in_channels, out_channels, **options)
        self.gnp = GeneralPooling(out_channels)

    def aggregate(self, inputs, index, dim_size) -> Tensor:
        """Reduce ``inputs`` grouped by ``index`` into ``dim_size`` rows via ``self.gnp``."""
        pooled = self.gnp(inputs, index, dim_size)
        return pooled