import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F


class GraphConv(nn.Module):
    """Graph convolution that mixes each node's features with a neighbour mean.

    For every node the layer aggregates messages from its incoming edges with
    a mean reducer, concatenates that aggregate with the node's own input
    features along the last dimension, and applies a single linear projection.

    Args:
        in_dim: Size of the last dimension of the input node features.
        out_dim: Size of the last dimension of the output node features.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # The projection consumes [aggregated, original] features concatenated
        # on the last axis, hence twice the input width.
        self.layer = nn.Linear(in_dim * 2, out_dim, bias=True)

    def forward(self, graph, n_feats, e_weights=None):
        """Run one round of message passing followed by the linear projection.

        Note that this method stores data on ``graph``: ``graph.ndata['h']``
        holds the layer output after the call, and ``graph.edata['ew']`` holds
        ``e_weights`` whenever edge weights are supplied.

        Args:
            graph: Graph object providing ``ndata``, ``edata`` and
                ``update_all`` (presumably a ``dgl.DGLGraph`` -- confirm
                against the caller).
            n_feats: Node feature tensor whose last dimension is ``in_dim``.
            e_weights: Optional per-edge weights. When given, each message is
                the source node feature multiplied by the edge weight.
                NOTE(review): the shape must be compatible with ``n_feats``
                for ``u_mul_e`` -- confirm the expected shape with callers.

        Returns:
            The projected node features, as read back from
            ``graph.ndata['h']``.
        """
        graph.ndata['h'] = n_feats

        # Identity check: `== None` on a tensor/array-like may dispatch to an
        # element-wise comparison instead of answering "was it omitted?".
        if e_weights is None:
            message_fn = fn.copy_u('h', 'm')
        else:
            graph.edata['ew'] = e_weights
            message_fn = fn.u_mul_e('h', 'ew', 'm')

        # Overwrites ndata['h'] with the mean of the incoming messages.
        graph.update_all(message_fn, fn.mean('m', 'h'))

        # Concatenate the aggregate with the untouched input features so the
        # projection sees both the neighbourhood and the node itself.
        graph.ndata['h'] = self.layer(
            th.cat([graph.ndata['h'], n_feats], dim=-1)
        )

        return graph.ndata['h']


class GraphConvModel(nn.Module):
    """Stack of three ``GraphConv`` layers with ReLU between them.

    The layers map ``in_dim -> hid_dim -> hid_dim -> out_dim``. ReLU is applied
    after the first two layers; the output of the last layer is returned
    without an activation.

    Args:
        in_dim: Width of the input node features.
        hid_dim: Width of the two hidden representations.
        out_dim: Width of the output node features.
    """

    def __init__(self, in_dim, hid_dim, out_dim):
        super().__init__()
        self.in_dim, self.hid_dim, self.out_dim = in_dim, hid_dim, out_dim

        # Registration order is kept as input, hidden, output.
        self.in_layer = GraphConv(in_dim, hid_dim)
        self.hid_layer = GraphConv(hid_dim, hid_dim)
        self.out_layer = GraphConv(hid_dim, out_dim)

    def forward(self, graph, n_feat, edge_weights=None):
        """Propagate node features through the three convolution layers.

        Args:
            graph: Graph passed unchanged to every ``GraphConv`` layer.
            n_feat: Input node features.
            edge_weights: Optional edge weights, forwarded as-is to each layer.

        Returns:
            The output of the final layer (no activation applied).
        """
        hidden = n_feat
        # The first two layers are each followed by a ReLU.
        for conv in (self.in_layer, self.hid_layer):
            hidden = F.relu(conv(graph, hidden, edge_weights))

        return self.out_layer(graph, hidden, edge_weights)
