import torch
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import glorot
from torch_geometric.nn.dense.linear import Linear


class Simplified_GNN(MessagePassing):
    """Parameter-free neighbourhood averaging followed by one linear layer.

    ``forward`` applies ``args.num_layers`` propagation steps. Each step
    replaces a node's features with the mean of its incoming neighbours'
    features: sum aggregation of messages weighted by ``1 / in_degree``.
    The propagated features then go through a single
    ``Linear(in_channels, out_channels)``.

    Args:
        args: configuration object; only ``args.num_layers`` is read here.
        in_channels: input feature dimension (input size of ``self.linear``).
        hid_channels: not used by this class; kept for interface compatibility.
        out_channels: output feature dimension (output size of ``self.linear``).
        graph: not used by this class; kept for interface compatibility.
    """

    def __init__(self, args, in_channels, hid_channels, out_channels, graph):
        # Explicit sum aggregation; together with the 1/in-degree weight
        # applied in ``message`` this yields a mean over incoming neighbours.
        super().__init__(aggr="add")

        self.args = args

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.setup_layers()

    def setup_layers(self):
        """Create the single trainable layer (bias on, glorot weight init)."""
        self.linear = Linear(self.in_channels, self.out_channels, bias=True, weight_initializer='glorot')

    def forward(self, x, edge_index):
        """Propagate ``x`` ``args.num_layers`` times, then apply ``self.linear``.

        Args:
            x: node feature tensor; presumably (num_nodes, in_channels) — TODO confirm.
            edge_index: index tensor with two rows. Row 0 holds source nodes
                and row 1 holds target nodes, per PyG's default
                ``source_to_target`` flow.

        Returns:
            The output of ``self.linear`` applied to the propagated features.
        """
        num_nodes = x.size(0)
        target = edge_index[1]

        # In-degree of every node, in x's dtype. ``dim_size`` guarantees one
        # entry per node, so the lookup below cannot go out of bounds even if
        # the highest-indexed nodes receive no edges.
        in_degree = scatter_add(x.new_ones(target.size(0)), target, dim=0, dim_size=num_nodes)
        # Every looked-up node is the target of at least this edge, so its
        # in-degree is >= 1 and the reciprocal is finite.
        edge_weight = in_degree[target].reciprocal()

        ax = x
        for _ in range(self.args.num_layers):
            ax = self.propagate(edge_index, x=ax, edge_weight=edge_weight)
        return self.linear(ax)

    def message(self, x_j, edge_weight):
        """Scale each source-node feature row by its edge's 1/in-degree weight.

        Without this override PyG's default ``message`` returns ``x_j``
        unchanged, and ``edge_weight`` would be silently ignored.
        """
        return edge_weight.view(-1, 1) * x_j
