"""
GCN using modified message passing
- Adapted from DGL https://github.com/dmlc/dgl/blob/master/examples/pytorch/gcn/gcn_mp.py
"""

import torch
import torch.nn as nn


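# Sends the source node's feature h along each edge as message m.
# Equivalent DGL built-in: msg = fn.copy_src(src='h', out='m')
# Note: msg = fn.copy_edge('h', 'm') is NOT equivalent -- it copies the
# *edge* feature 'h' rather than the source node feature.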
def copy_message_func(edges):
    """DGL message UDF: forward the source node feature ``h`` unchanged.

    Args:
        edges: DGL edge batch; only ``edges.src["h"]`` is read.

    Returns:
        ``{"m": edges.src["h"]}`` -- the message stored in each mailbox.
    """
    message = edges.src["h"]
    return dict(m=message)


def edge_message_func(edges):
    """DGL message UDF: send the source feature ``h`` scaled by the edge feature ``e``.

    Args:
        edges: DGL edge batch; reads ``edges.src["h"]`` and ``edges.data["e"]``.

    Returns:
        ``{"m": h_src * e}`` where the edge feature is first made broadcastable:

        * If ``e`` holds exactly one value per edge (e.g. ``(E,)``, ``(E, 1)``,
          ``(E, 1, 1)``), it is reshaped to ``(E, 1, ..., 1)``. Each message is
          then the source feature multiplied by its own edge's scalar.
        * Otherwise the original ``squeeze(1)`` behaviour is kept, e.g. an
          ``(E, 1, F)`` edge tensor becomes ``(E, F)`` and is applied element-wise.
    """
    src_h = edges.src["h"]
    weight = edges.data["e"]
    num_edges = src_h.shape[0]
    if weight.numel() == num_edges:
        # Why: ``(E, 1).squeeze(1)`` gives ``(E,)``, which broadcasts against the
        # *last* (feature) axis of ``h``. That is a shape error when F != E, and
        # a silent per-feature instead of per-edge scaling when F == E.
        weight = weight.reshape(num_edges, *([1] * (src_h.dim() - 1)))
    else:
        weight = weight.squeeze(1)
    return {"m": src_h * weight}


def reduce(nodes):
    """DGL reduce UDF: average the incoming messages into the node feature ``h``.

    Args:
        nodes: DGL node batch; ``nodes.mailbox["m"]`` stacks the received
            messages along dimension 1.

    Returns:
        ``{"h": mean of the mailbox over dimension 1}``, which replaces the
        node's previous ``h``.
    """
    mailbox = nodes.mailbox["m"]
    averaged = mailbox.mean(dim=1)
    return {"h": averaged}


class NodeApplyModule(nn.Module):
    """DGL node-apply UDF: update the node feature ``h`` as ``activation(W h + b)``.

    Args:
        in_feats: size of the last dimension of the incoming feature ``h``.
        out_feats: size of the updated feature.
        activation: callable applied after the linear layer (e.g. ``torch.relu``).
            ``None`` means no activation (identity); it is also the default.
    """

    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return ``{"h": activation(linear(node.data["h"]))}`` for the node batch."""
        h = self.linear(node.data["h"])
        # Why: previously a ``None`` activation crashed here. Treat it as identity.
        if self.activation is not None:
            h = self.activation(h)
        return {"h": h}


class GCN(nn.Module):
    """Single graph-convolution layer built from DGL user-defined functions.

    Each forward pass does three things:

    1. It turns source-node features into per-edge messages.
    2. It averages those messages per destination node (``reduce``).
    3. It applies a linear layer plus activation (``NodeApplyModule``).
    """

    def __init__(self, in_feats, out_feats, activation):
        super().__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, data):
        """Run one message-passing round.

        Args:
            data: a ``(graph, feature)`` pair. ``feature`` is stored as the
                graph's node feature ``"h"``.

        Returns:
            The updated node features. Note: ``"h"`` is popped from
            ``graph.ndata``, so any pre-existing ``"h"`` on the graph is
            overwritten and then removed.
        """
        graph, feature = data
        graph.ndata["h"] = feature
        # Width-1 features (the first iteration, per the original author) are
        # only copied. Wider features are scaled by the edge feature "e".
        if graph.ndata["h"].size(1) != 1:
            message_fn = edge_message_func
        else:
            message_fn = copy_message_func
        graph.update_all(message_fn, reduce)
        graph.apply_nodes(func=self.apply_mod)
        return graph.ndata.pop("h")
