import torch
from torch import nn
import dgl
from dgl import function as fn
from dgl.utils import expand_as_pair
import torch.nn.functional as F
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling, Set2Set, SortPooling

class MinPooling(nn.Module):
    """Readout returning the element-wise minimum over the first axis of ``h``.

    The first ``forward`` argument (the graph, in DGL readout modules) is
    accepted only for signature compatibility and is ignored.
    """

    def __init__(self):
        super().__init__()

    def forward(self, _, h):
        # Tensor.min(dim=...) yields (values, indices); only the values are kept.
        values, _indices = h.min(dim=0)
        return values
    
class GINConv(nn.Module):
    """GIN-style graph convolution with a configurable neighbor aggregator.

    Computes ``apply_func(neigh + (1 + eps) * h_dst)``. Here ``neigh``
    aggregates the incoming messages. Each message is ``h_src``, or
    ``relu(h_src + edge_attr)`` when edge features are supplied.

    Args:
        apply_func: module applied to the combined features. For the
            GeneralPooling aggregator it must expose ``mlp.hidden_dim``.
        aggregator_type: ``'sum'``, ``'max'``, ``'mean'`` or ``'min'`` select
            the DGL built-in reducers. Any other value selects GeneralPooling.
        init_eps: initial value of ``eps``.
        learn_eps: if True, ``eps`` is a trainable parameter; otherwise it is
            a (non-trainable) buffer.
        edge_encoder: optional module applied to ``edge_feat`` before message
            passing.
        general_mode: forwarded to GeneralPooling (unused by the built-ins).
    """

    def __init__(self, apply_func, aggregator_type, init_eps=0, learn_eps=False, edge_encoder=None, general_mode=0):
        super(GINConv, self).__init__()
        self.apply_func = apply_func
        self._aggregator_type = aggregator_type
        self.edge_encoder = edge_encoder
        self.general_mode = general_mode
        builtin_reducers = {'sum': fn.sum, 'max': fn.max, 'mean': fn.mean, 'min': fn.min}
        if aggregator_type in builtin_reducers:
            self._reducer = builtin_reducers[aggregator_type]
        else:
            self.agg_fn = GeneralPooling(self.apply_func.mlp.hidden_dim, general_mode=general_mode)
            # The general reducer is the bound method ``_general_reduce``,
            # chosen in forward(). A closure over ``self`` stored here cannot
            # be pickled. After copy.deepcopy it would also keep calling the
            # ORIGINAL module's agg_fn.
            self._reducer = None
        if learn_eps:
            self.eps = torch.nn.Parameter(torch.FloatTensor([init_eps]))
        else:
            self.register_buffer('eps', torch.FloatTensor([init_eps]))

    def _general_reduce(self, nodes):
        """DGL reduce UDF: pool each node's message mailbox with ``agg_fn``."""
        return {'neigh': self.agg_fn(None, nodes.mailbox['m'])}

    def forward(self, graph, feat, edge_feat=None):
        """Run one round of message passing and apply ``apply_func``.

        Args:
            graph: DGL graph to propagate over.
            feat: node features, or a (src, dst) pair (see ``expand_as_pair``).
            edge_feat: optional edge features. They are passed through
                ``edge_encoder`` when one is configured and used as-is
                otherwise. They are added to the source features, so the
                result must be broadcast-compatible with them.

        Returns:
            ``apply_func(neigh + (1 + eps) * h_dst)``.
        """
        def _msg_func_with_edge_feats(edges):
            return {'m': F.relu(edges.src['h'] + edges.data['edge_attr'])}

        with graph.local_scope():
            feat_src, feat_dst = expand_as_pair(feat, graph)
            graph.srcdata['h'] = feat_src
            if edge_feat is not None:
                # Only touch edge data when edge features exist. This avoids
                # calling the encoder on None. It also guarantees 'edge_attr'
                # is set whenever the edge-aware message function runs.
                if self.edge_encoder is not None:
                    edge_feat = self.edge_encoder(edge_feat)
                graph.edata['edge_attr'] = edge_feat
                msg_func = _msg_func_with_edge_feats
            else:
                msg_func = fn.copy_u('h', 'm')
            if self._reducer is None:
                red_func = self._general_reduce
            else:
                red_func = self._reducer('m', 'neigh')
            graph.update_all(msg_func, red_func)

            neigh = graph.dstdata['neigh']
            if neigh.shape[-1] != feat_dst.shape[-1]:
                # The aggregated width differs from the destination width, so
                # the destination features are tiled twice along the feature
                # axis. NOTE(review): this assumes the aggregator exactly
                # doubled the width. Confirm which configuration relies on it.
                self_term = torch.cat((feat_dst, feat_dst), dim=-1)
            else:
                self_term = feat_dst
            agg_h = neigh + (1. + self.eps) * self_term
            return self.apply_func(agg_h)
        
class ApplyNodeFunc(nn.Module):
    """Node update applied after aggregation: ``relu(bn(mlp(h)))``.

    Args:
        mlp: module producing the new features. It must expose ``output_dim``,
            which sizes the optional BatchNorm1d.
        use_batchnorm: insert BatchNorm1d after the MLP. Otherwise Identity
            is used.
    """

    def __init__(self, mlp, use_batchnorm=False):
        super().__init__()
        self.mlp = mlp
        if use_batchnorm:
            self.bn = nn.BatchNorm1d(self.mlp.output_dim)
        else:
            self.bn = nn.Identity()

    def forward(self, h):
        return F.relu(self.bn(self.mlp(h)))

class GeneralPooling(nn.Module):
    """Learnable power-mean pooling over the second-to-last axis of ``h``.

    Two branches are available, selected by ``general_mode``:

    * positive branch (enabled when ``general_mode // 2 == 0``):
      ``(sum_i h_i ** p) ** (1 / p) * (1 / N) ** q_pos``. This tends towards
      the maximum as ``p`` grows.
    * negative branch (enabled when ``general_mode % 2 == 0``):
      ``(sum_i h_i ** -p) ** (-1 / p) * (1 / N) ** q_neg``. This tends towards
      the minimum as ``p`` grows. Entries below ``eps`` are replaced by
      ``1 / eps`` so they contribute almost nothing, rather than dragging the
      result to 0.

    ``N`` is the size of the reduced axis. Inputs pass through ReLU first. Any
    output entry whose inputs along the reduced axis are all below ``eps`` is
    set to exactly 0. With reparameterization (always on), the exponent used
    is ``1 + softplus(raw)``, so it is strictly greater than 1.

    When both branches are enabled, the feature axis is split at
    ``hidden_dim // 2``. The first part feeds the positive branch, the rest
    feeds the negative branch, and the results are concatenated. That path
    indexes ``h[:, :, ...]`` and so needs a rank-3 tensor, e.g.
    ``(batch, set_size, features)``. If neither branch is enabled (e.g.
    ``general_mode == 3``), ``forward`` returns None.
    """

    def __init__(self, hidden_dim, general_mode, eps=1e-12):
        super(GeneralPooling, self).__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        self.use_reparameterization = True
        self.p_pos = nn.Parameter(torch.FloatTensor([0.0 if self.use_reparameterization else 1.0]))
        self.p_neg = nn.Parameter(torch.FloatTensor([0.0 if self.use_reparameterization else 1.0]))
        self.q_pos = nn.Parameter(torch.FloatTensor([0.0]))
        self.q_neg = nn.Parameter(torch.FloatTensor([0.0]))

    def _exponent(self, raw):
        """Map a raw learnable exponent to the value used in the power mean."""
        if self.use_reparameterization:
            # 1 + softplus(raw) keeps p > 1. F.softplus is the overflow-safe
            # form of log(exp(raw) + 1), which becomes inf (and then NaN
            # outputs) once raw exceeds ~88 in float32.
            return 1. + F.softplus(raw)
        return raw

    def forward(self, g, h):
        """Pool ``h`` over axis -2. ``g`` is unused (readout-compatible signature)."""
        split = self.hidden_dim // 2
        pos = neg = None

        if self.use_pos:
            h_pos = F.relu(h[:, :, :split] if self.use_neg else h)
            allzero_pos = (h_pos < self.eps).all(dim=-2, keepdim=False)
            p_pos = self._exponent(self.p_pos)
            # (sum h^p)^(1/p), evaluated in log space for stability.
            pos = torch.exp(torch.logsumexp(torch.log(h_pos + self.eps) * p_pos, dim=-2) / p_pos)
            pos = pos * ((1. / h_pos.shape[-2]) ** self.q_pos)
            pos = pos.masked_fill(allzero_pos, 0.)

        if self.use_neg:
            h_neg = F.relu(h[:, :, split:] if self.use_pos else h)
            small_neg = h_neg < self.eps
            allzero_neg = small_neg.all(dim=-2, keepdim=False)
            # Must be out-of-place. ReLU saves its output for backward, so
            # writing into h_neg in place makes backward() raise a RuntimeError.
            h_neg = h_neg.masked_fill(small_neg, 1. / self.eps)
            p_neg = self._exponent(self.p_neg)
            # (sum h^-p)^(-1/p), evaluated in log space for stability.
            neg = torch.exp(-torch.logsumexp(-torch.log(h_neg + self.eps) * p_neg, dim=-2) / p_neg)
            neg = neg * ((1. / h_neg.shape[-2]) ** self.q_neg)
            neg = neg.masked_fill(allzero_neg, 0.)

        if self.use_pos and self.use_neg:
            return torch.cat((pos, neg), dim=-1)
        if self.use_pos:
            return pos
        if self.use_neg:
            return neg
        return None

        
class MLP(nn.Module):
    """Stack of ``num_layers`` linear layers with ReLU between them.

    Every layer except the last is followed by BatchNorm1d (or Identity when
    ``use_batchnorm`` is False) and ReLU. The last layer's output is returned
    without activation. Only the first layer honours ``use_bias``; every later
    layer always has a bias.

    Raises:
        ValueError: if ``num_layers`` is 1 or less.
    """

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim, use_batchnorm=False, use_bias=False):
        super().__init__()
        self.num_layers = num_layers
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        if num_layers <= 1:
            raise ValueError("number of layers should be more than or equal to 2!")

        self.linears = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()

        # input -> hidden, then (num_layers - 2) x hidden -> hidden, then hidden -> output.
        self.linears.append(nn.Linear(input_dim, hidden_dim, bias=use_bias))
        self.linears.extend(
            nn.Linear(hidden_dim, hidden_dim, bias=True) for _ in range(num_layers - 2)
        )
        self.linears.append(nn.Linear(hidden_dim, output_dim, bias=True))

        # One normalization slot per hidden (non-final) layer.
        for _ in range(num_layers - 1):
            norm = nn.BatchNorm1d(hidden_dim) if use_batchnorm else nn.Identity()
            self.batch_norms.append(norm)

    def forward(self, x):
        out = x
        # zip stops at len(batch_norms) == num_layers - 1, skipping the final layer.
        for linear, norm in zip(self.linears, self.batch_norms):
            out = F.relu(norm(linear(out)))
        return self.linears[-1](out)

class GIN(nn.Module):
    """Graph Isomorphism Network producing one prediction per graph.

    Node features are projected to ``hidden_dim`` by ``encoder``. They then go
    through ``num_layers`` GINConv layers, each followed by BatchNorm1d (or
    Identity) and ReLU. The result is pooled to one vector per graph and
    mapped to ``output_dim``.

    Args:
        graph_pooling_type: 'sum', 'mean', 'max', 'min', 'general',
            'set2set' or 'sort'. Anything else raises NotImplementedError.
        neighbor_pooling_type: aggregator passed to every GINConv.
        learn_eps: whether each GINConv learns its ``eps``.
        use_batchnorm: enable BatchNorm1d inside the MLPs and between layers.
        general_mode: overloaded. It is the GeneralPooling mode for the
            'general' aggregator/readout, ``n_iters`` for Set2Set, and ``k``
            for SortPooling.
        use_bias: bias flag for the first MLP layer and the prediction heads.
    """

    def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim, graph_pooling_type, neighbor_pooling_type,
                 learn_eps=True, use_batchnorm=False, general_mode=0, use_bias=False):
        super(GIN, self).__init__()

        self.num_layers = num_layers

        self.encoder = nn.Linear(input_dim, hidden_dim, bias=True)
        self.ginlayers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        for _ in range(self.num_layers):
            mlp = MLP(num_mlp_layers,
                      hidden_dim,
                      hidden_dim,
                      hidden_dim,
                      use_batchnorm=use_batchnorm, use_bias=use_bias)
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim) if use_batchnorm else nn.Identity())
            # learn_eps used to be hard-coded to True here, silently ignoring the argument.
            self.ginlayers.append(GINConv(ApplyNodeFunc(mlp, use_batchnorm), neighbor_pooling_type, 0, learn_eps,
                                          general_mode=general_mode))

        self.graph_pooling_type = graph_pooling_type

        self.linears_prediction = torch.nn.ModuleList()
        # Index 0 is never used by forward() (only [-1] is). It is kept so that
        # parameter-initialization order and state_dict keys stay unchanged.
        self.linears_prediction.append(nn.Linear(hidden_dim, output_dim, bias=use_bias))

        if graph_pooling_type == 'set2set':
            self.linears_prediction.append(nn.Linear(hidden_dim * 2, output_dim, bias=use_bias))
            self.pool = Set2Set(hidden_dim, general_mode, 1)
        elif graph_pooling_type == 'sort':
            self.linears_prediction.append(nn.Linear(hidden_dim * general_mode, output_dim, bias=use_bias))
            self.pool = SortPooling(general_mode)
        else:
            self.linears_prediction.append(nn.Linear(hidden_dim, output_dim, bias=use_bias))
            # Parameter-free readout modules, not lambdas: same results, and
            # the model stays picklable (torch.save(model), multiprocessing).
            simple_pools = {'sum': SumPooling, 'mean': AvgPooling, 'max': MaxPooling, 'min': MinPooling}
            if graph_pooling_type in simple_pools:
                self.pool = simple_pools[graph_pooling_type]()
            elif graph_pooling_type == 'general':
                self.pool = GeneralPooling(hidden_dim, general_mode=general_mode)
            else:
                raise NotImplementedError(
                    "unsupported graph_pooling_type: {!r}".format(graph_pooling_type))

    def forward(self, g, h, edge_attr=None):
        """Return per-graph predictions for the batched graph ``g``.

        The final ``.squeeze()`` drops every size-1 dimension. With a single
        graph or ``output_dim == 1``, the output therefore has fewer dims.
        """
        h = self.encoder(h)
        for conv, norm in zip(self.ginlayers, self.batch_norms):
            h = F.relu(norm(conv(g, h, edge_attr)))
        if self.graph_pooling_type in ('general', 'min'):
            # These readouts reduce one graph at a time and ignore the graph
            # argument, so splitting node rows per graph suffices. No
            # dgl.unbatch is needed.
            per_graph = torch.split(h, g.batch_num_nodes().tolist())
            if self.graph_pooling_type == 'general':
                # GeneralPooling reduces axis -2 and indexes h[:, :, ...], so add a batch axis.
                pooled = torch.stack([self.pool(None, part.unsqueeze(0)).squeeze(0) for part in per_graph], dim=0)
            else:
                pooled = torch.stack([self.pool(None, part) for part in per_graph], dim=0)
        else:
            pooled = self.pool(g, h)
        return self.linears_prediction[-1](pooled).squeeze()