import torch.nn as nn
import dgl

from nets.gru import GRU
from models.dgl.pna_layer import PNALayer
from nets.mlp_readout_layer import MLPReadout

import torch
import torch.nn.functional as F

"""
    PNA: Principal Neighbourhood Aggregation 
    Gabriele Corso, Luca Cavalleri, Dominique Beaini, Pietro Lio, Petar Velickovic
    https://arxiv.org/abs/2004.05718
    Architecture follows that in https://github.com/graphdeeplearning/benchmarking-gnns
"""

class GeneralPooling(nn.Module):
    """Learnable generalized-mean pooling over the node axis.

    The feature axis is split in two halves. For the first
    ``(hidden_dim + 1) // 2`` channels the module computes
    ``(sum_i x_i ** p) ** (1 / p)`` with a learnable exponent ``p > 1``; for
    the remaining ``hidden_dim // 2`` channels it uses the exponent ``-p``
    instead. The result is then multiplied by ``(1 / N) ** q`` where ``N`` is
    the number of pooled rows and ``q`` is a learnable scalar per half.

    Args:
        hidden_dim: Size of the last (feature) axis of the input.
        general_mode: Decoded into the ``use_pos`` / ``use_neg`` flags. These
            flags are stored but ``forward`` does not consult them.
        eps: Threshold below which an activation counts as zero; also added
            inside the logarithm for numerical safety.
    """

    def __init__(self, hidden_dim, general_mode=0, eps=1e-12):
        super(GeneralPooling, self).__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        # Stored for introspection only; forward() always pools both halves.
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        self.use_reparameterization = True
        # Raw exponent parameters; forward() maps them through 1 + softplus(.)
        # so the effective exponent is always > 1.
        p_init = 0.0 if self.use_reparameterization else 1.0
        self.p_pos = nn.Parameter(torch.FloatTensor([p_init]))
        self.p_neg = nn.Parameter(torch.FloatTensor([p_init]))
        # Exponents of the 1/N size normalisation (0 => no normalisation).
        self.q_pos = nn.Parameter(torch.FloatTensor([0.0]))
        self.q_neg = nn.Parameter(torch.FloatTensor([0.0]))

    def forward(self, h):
        """Pool ``h`` over its second-to-last axis.

        Args:
            h: Tensor of shape ``(..., N, hidden_dim)``. The input tensor is
                not modified.

        Returns:
            Tensor of shape ``(..., hidden_dim)``. Channels whose rectified
            activations are all below ``eps`` are set to exactly zero; an
            input with ``N == 0`` yields an all-zero result.
        """
        # The row count must be read BEFORE the reduction below: afterwards
        # shape[-2] no longer refers to the pooled axis.
        num_rows = h.shape[-2]
        if num_rows == 0:
            return h.new_zeros(h.shape[:-2] + h.shape[-1:])

        n_pos = (self.hidden_dim + 1) // 2
        n_neg = self.hidden_dim // 2

        h = F.relu(h)
        below_eps = h < self.eps
        allzero = below_eps.all(dim=-2)

        # In the negative-exponent half a zero would dominate the sum
        # (0 ** -p = inf), so zeros are replaced by a huge value whose
        # contribution x ** -p is negligible. Done out-of-place so autograd's
        # saved tensors are never overwritten.
        in_neg_half = torch.arange(h.shape[-1], device=h.device) >= n_pos
        h = h.masked_fill(below_eps & in_neg_half, 1. / self.eps)

        # softplus(x) == log(exp(x) + 1) but does not overflow for large x.
        p_pos = 1. + F.softplus(self.p_pos)
        p_neg = 1. + F.softplus(self.p_neg)
        ps = torch.cat((p_pos.repeat(n_pos), -p_neg.repeat(n_neg)), dim=0)
        qs = torch.cat((self.q_pos.repeat(n_pos), self.q_neg.repeat(n_neg)), dim=0)

        # (sum_i x_i ** p) ** (1 / p), evaluated in log-space for stability.
        h = torch.exp(torch.logsumexp(torch.log(h + self.eps) * ps, dim=-2) / ps)
        h = h * ((1. / num_rows) ** qs)
        return h.masked_fill(allzero, 0.)

class PNANet(nn.Module):
    """Stack of PNA layers followed by a graph-level readout and an MLP head.

    Node (and optionally edge) integer types are embedded, passed through
    ``L`` ``PNALayer`` modules (optionally combined with a GRU between
    layers), pooled to one vector per graph and mapped to a single scalar
    per graph by ``MLPReadout``.

    Args:
        net_params: Dict of hyper-parameters. Keys read here:
            ``num_atom_type``, ``num_bond_type``, ``hidden_dim``, ``out_dim``,
            ``in_feat_dropout``, ``dropout``, ``L``, ``readout``,
            ``graph_norm``, ``batch_norm``, ``residual``, ``aggregators``,
            ``scalers``, ``avg_d``, ``towers``, ``divide_input_first``,
            ``divide_input_last``, ``edge_feat``, ``edge_dim``,
            ``pretrans_layers``, ``posttrans_layers``, ``gru``, ``device``.
    """

    def __init__(self, net_params):
        super().__init__()
        num_atom_type = net_params['num_atom_type']
        num_bond_type = net_params['num_bond_type']
        hidden_dim = net_params['hidden_dim']
        out_dim = net_params['out_dim']
        in_feat_dropout = net_params['in_feat_dropout']
        dropout = net_params['dropout']
        n_layers = net_params['L']
        self.readout = net_params['readout']
        self.graph_norm = net_params['graph_norm']
        self.batch_norm = net_params['batch_norm']
        self.residual = net_params['residual']
        self.aggregators = net_params['aggregators']
        self.scalers = net_params['scalers']
        self.avg_d = net_params['avg_d']
        self.towers = net_params['towers']
        self.divide_input_first = net_params['divide_input_first']
        self.divide_input_last = net_params['divide_input_last']
        self.edge_feat = net_params['edge_feat']
        edge_dim = net_params['edge_dim']
        pretrans_layers = net_params['pretrans_layers']
        posttrans_layers = net_params['posttrans_layers']
        self.gru_enable = net_params['gru']
        device = net_params['device']

        self.in_feat_dropout = nn.Dropout(in_feat_dropout)

        self.embedding_h = nn.Embedding(num_atom_type, hidden_dim)

        # Edge types are only embedded when edge features are enabled.
        if self.edge_feat:
            self.embedding_e = nn.Embedding(num_bond_type, edge_dim)

        # The first L - 1 layers keep the hidden width ...
        self.layers = nn.ModuleList([PNALayer(in_dim=hidden_dim, out_dim=hidden_dim, dropout=dropout,
                                              graph_norm=self.graph_norm, batch_norm=self.batch_norm,
                                              residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
                                              avg_d=self.avg_d, towers=self.towers, edge_features=self.edge_feat,
                                              edge_dim=edge_dim, divide_input=self.divide_input_first,
                                              pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers) for _
                                     in range(n_layers - 1)])
        # ... and the last one projects to out_dim, using divide_input_last.
        self.layers.append(PNALayer(in_dim=hidden_dim, out_dim=out_dim, dropout=dropout,
                                    graph_norm=self.graph_norm, batch_norm=self.batch_norm,
                                    residual=self.residual, aggregators=self.aggregators, scalers=self.scalers,
                                    avg_d=self.avg_d, towers=self.towers, divide_input=self.divide_input_last,
                                    edge_features=self.edge_feat, edge_dim=edge_dim,
                                    pretrans_layers=pretrans_layers, posttrans_layers=posttrans_layers))

        if self.gru_enable:
            self.gru = GRU(hidden_dim, hidden_dim, device)

        self.MLP_layer = MLPReadout(out_dim, 1)  # 1 out dim since regression problem
        # Learnable pooling used when `readout` is not "sum", "max" or "mean".
        self.pool = GeneralPooling(out_dim)

    def forward(self, g, h, e, snorm_n, snorm_e):
        """Compute one prediction per graph in the batch ``g``.

        Args:
            g: Batched DGL graph. Its ``ndata['h']`` entry is overwritten with
                the final node features.
            h: Node type indices, fed to the node embedding.
            e: Edge type indices; embedded only when ``edge_feat`` is set,
                otherwise forwarded to the layers unchanged.
            snorm_n: Node normalisation term forwarded to every PNA layer.
            snorm_e: Accepted for interface compatibility; not used here.

        Returns:
            The output of ``MLP_layer`` applied to the pooled graph features.
        """
        h = self.embedding_h(h)
        h = self.in_feat_dropout(h)
        if self.edge_feat:
            e = self.embedding_e(e)

        last_idx = len(self.layers) - 1
        for i, conv in enumerate(self.layers):
            h_t = conv(g, h, e, snorm_n)
            # The GRU is skipped after the final layer (it is built with
            # hidden_dim, while the final layer emits out_dim features).
            if self.gru_enable and i != last_idx:
                h_t = self.gru(h, h_t)
            h = h_t

        g.ndata['h'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'h')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'h')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'h')
        else:
            # Any other readout name selects the learnable pooling.
            hg = self._pool_per_graph(g, h)

        return self.MLP_layer(hg)

    def _pool_per_graph(self, g, h):
        """Apply ``self.pool`` to the node features of each graph separately."""
        # `batch_num_nodes` is a list attribute on older DGL releases and a
        # method returning a tensor on newer ones; torch.split needs a plain
        # list of ints, so accept both forms.
        sizes = g.batch_num_nodes
        if callable(sizes):
            sizes = sizes()
        sizes = sizes.tolist() if torch.is_tensor(sizes) else list(sizes)

        per_graph = torch.split(h, sizes)
        return torch.stack(
            [self.pool(part.unsqueeze(0)).squeeze(0) for part in per_graph], dim=0)

    def loss(self, scores, targets):
        """Return the mean absolute error (L1 loss) between scores and targets."""
        loss = nn.L1Loss()(scores, targets)
        return loss
