import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.data.batch import Batch

from torch.nn import GRU, Linear, ReLU, Sequential
from torch_geometric.nn import NNConv, Set2Set


class MPNNs2s(nn.Module):
    r"""MPNN from "Neural Message Passing for Quantum Chemistry" for
    regression and classification on graphs.

    Node features are embedded with a linear layer. They are then refined by
    ``num_iterations`` rounds of edge-conditioned message passing
    (:class:`NNConv`), each followed by a GRU update. The result is pooled
    per graph with :class:`Set2Set` and mapped to the output by a two-layer
    MLP head.

    Source: https://github.com/pyg-team/pytorch_geometric/blob/master
    /examples/qm9_nn_conv.py

    Arguments:
        in_channels (int): Size of the input node features.
        out_channels (int): Size of the output produced for each graph.
        num_nn (int): Number of edge features, i.e. the input size of the
            edge network that generates the NNConv weights.
        hidden (int): Size of the hidden node representations.
            Defaults to 64.
        num_iterations (int): Number of message-passing + GRU rounds.
            Defaults to 3.
        processing_steps (int): ``processing_steps`` forwarded to
            :class:`Set2Set`. Defaults to 3.
        num_layers (int): ``num_layers`` forwarded to :class:`Set2Set`.
            Defaults to 3.
    """

    # Class-level default so that instances pickled before this attribute
    # existed still run ``forward`` with the original 3 rounds.
    num_iterations = 3

    def __init__(self, in_channels, out_channels, num_nn, hidden=64,
                 num_iterations=3, processing_steps=3, num_layers=3):
        super().__init__()
        self.num_iterations = num_iterations
        self.lin0 = torch.nn.Linear(in_channels, hidden)

        # Edge network: maps each edge's feature vector to a flattened
        # (hidden x hidden) weight matrix that NNConv uses for that edge.
        # Named ``edge_net`` so that it does not shadow the ``torch.nn`` alias.
        edge_net = Sequential(Linear(num_nn, 16), ReLU(),
                              Linear(16, hidden * hidden))
        self.conv = NNConv(hidden, hidden, edge_net, aggr='add')
        self.gru = GRU(hidden, hidden)

        # Set2Set yields 2 * hidden features per graph (hence lin1's input).
        self.set2set = Set2Set(hidden, processing_steps=processing_steps,
                               num_layers=num_layers)
        self.lin1 = torch.nn.Linear(2 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, out_channels)

    def forward(self, data):
        """Compute graph-level predictions.

        Args:
            data: Either a PyG ``Data``/``Batch`` object with ``x``,
                ``edge_index`` and ``edge_attr`` (plus ``batch`` for batched
                graphs), or a tuple/list ``(x, edge_index, edge_attr, batch)``.
                A missing (``None``) ``batch`` is treated as one single graph.

        Returns:
            Tensor with ``out_channels`` features for each graph in ``batch``.

        Raises:
            TypeError: If ``data`` is neither a PyG data object nor a
                tuple/list.
        """
        if isinstance(data, (Batch, Data)):
            x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
            # A lone ``Data`` object usually carries no ``batch`` vector.
            batch = getattr(data, 'batch', None)
        elif isinstance(data, (tuple, list)):
            x, edge_index, edge_attr, batch = data
        else:
            raise TypeError('Unsupported data type!')

        if batch is None:
            # Single graph: assign every node to graph 0 for Set2Set pooling.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Re-compact the GRU weights into contiguous memory. This is typically
        # needed after the module is replicated (e.g. nn.DataParallel).
        self.gru.flatten_parameters()

        # The edge features do not change across rounds, so cast them once.
        edge_attr = edge_attr.float()
        out = F.relu(self.lin0(x.float()))
        # GRU hidden state: (num_layers=1, num_nodes, hidden).
        h = out.unsqueeze(0)

        for _ in range(self.num_iterations):
            m = F.relu(self.conv(out, edge_index, edge_attr))
            # Sequence length 1: nodes act as the GRU's batch dimension.
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out = self.set2set(out, batch)
        out = F.relu(self.lin1(out))
        return self.lin2(out)
