import torch as th
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn

class Layer(nn.Module):
    """Message-passing layer combining a node's features with its neighbourhood mean.

    Incoming messages are averaged per destination node, concatenated with
    the node's own input features along the last dimension, and projected by
    a single biased linear map.
    """

    def __init__(self, in_dim, out_dim):
        """Create the projection.

        Args:
            in_dim: Size of the last dimension of the node features.
            out_dim: Size of the last dimension of the layer output.
        """
        super().__init__()
        # Doubled input width: forward() feeds [aggregated, own] features,
        # each of width in_dim, into this projection.
        self.layer = nn.Linear(2 * in_dim, out_dim, bias=True)

    def forward(self, graph, feat, eweight=None):
        """Aggregate neighbour features and project them.

        Args:
            graph: Graph exposing ``local_scope``, ``ndata``, ``edata`` and
                ``update_all`` (presumably a homogeneous DGL graph - confirm
                against callers).
            feat: Node features whose last dimension is ``in_dim``.
            eweight: Optional edge weights. When given, each message is the
                source node's feature multiplied by the edge's weight.
                NOTE(review): expected shape is not visible here; presumably
                one weight per edge, broadcastable against ``feat``.

        Returns:
            The linear projection of the concatenated features, with last
            dimension ``out_dim``.
        """
        # All feature writes happen inside the local scope so the temporary
        # 'h' / 'ew' entries stay confined to this call.
        with graph.local_scope():
            graph.ndata['h'] = feat

            if eweight is not None:
                graph.edata['ew'] = eweight
                message_fn = fn.u_mul_e('h', 'ew', 'm')
            else:
                message_fn = fn.copy_u('h', 'm')

            # The same mean reducer is used whether or not messages are
            # weighted; its result overwrites 'h' on the destination nodes.
            graph.update_all(message_fn, fn.mean('m', 'h'))

            aggregated = graph.ndata['h']
            combined = th.cat([aggregated, feat], dim=-1)
            return self.layer(combined)

class Model(nn.Module):
    """Three stacked ``Layer`` modules with ReLU after the first two.

    The widths are ``in_dim -> hid_dim -> hid_dim -> out_dim``; the final
    layer's output is returned without an activation.
    """

    def __init__(self, in_dim, out_dim, hid_dim=40):
        """Build the three layers.

        Args:
            in_dim: Size of the last dimension of the input node features.
            out_dim: Size of the last dimension of the model output.
            hid_dim: Width of both hidden representations (default 40).
        """
        super().__init__()
        self.in_layer = Layer(in_dim, hid_dim)
        self.hid_layer = Layer(hid_dim, hid_dim)
        self.out_layer = Layer(hid_dim, out_dim)

    def forward(self, graph, feat, eweight=None):
        """Run the three layers over ``graph``.

        Args:
            graph: Graph object forwarded unchanged to every layer.
            feat: Input node features; converted with ``.float()`` before
                the first layer.
            eweight: Optional edge weights, forwarded unchanged to every
                layer.

        Returns:
            The output of the last layer (no activation applied).
        """
        hidden = feat.float()
        # The input and hidden layers are each followed by a ReLU.
        for block in (self.in_layer, self.hid_layer):
            hidden = F.relu(block(graph, hidden, eweight))
        return self.out_layer(graph, hidden, eweight)
