import torch
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree


class GCNConv(MessagePassing):
    """Graph convolution layer with symmetric degree normalisation.

    Computes ``out = D^{-1/2} (A + I) D^{-1/2} (x W) + b``. The terms are:

    - ``A``: the adjacency described by ``edge_index``.
    - ``I``: one self-loop added per node.
    - ``D``: the degree counted over target (``col``) indices after the
      self-loops are added.
    - ``W``: a bias-free linear map.
    - ``b``: a learnable bias, initialised to zero.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        **kwargs: Accepted for signature compatibility but currently ignored;
            they are NOT forwarded to ``MessagePassing``.
    """

    def __init__(self, in_channels: int, out_channels: int, **kwargs):
        # NOTE(review): kwargs are dropped here. Forwarding them would enable
        # MessagePassing options (e.g. node_dim), but could raise TypeError for
        # unrecognised keys, so the original behaviour is preserved.
        super().__init__(aggr='add')  # sum the messages arriving at each node

        # The weight has no bias of its own. The bias is applied once after
        # aggregation in forward(), not once per message.
        self.lin = Linear(in_channels, out_channels, bias=False)
        # torch.empty replaces the legacy torch.Tensor(n) constructor. The
        # uninitialised values are overwritten by reset_parameters() below.
        self.bias = Parameter(torch.empty(out_channels))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Re-initialise the linear weight and reset the bias to zero."""
        self.lin.reset_parameters()
        # init.zeros_ writes under no_grad, avoiding direct ``.data`` access.
        torch.nn.init.zeros_(self.bias)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Apply one round of normalised neighbourhood aggregation.

        Args:
            x: Node features of shape ``[N, in_channels]``. For this project,
                each row is a row-major flattened image.
            edge_index: Graph connectivity of shape ``[2, E]``. Row 0 holds
                source indices and row 1 holds target indices (assumes PyG's
                default source_to_target flow).

        Returns:
            Aggregated node features, one row per node, with ``out_channels``
            columns.
        """
        num_nodes = x.size(0)

        # TODO: confirm whether self-loops are required for this use case.
        # They make each node's own features part of its own aggregation.
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

        # Apply the linear transform first, so the messages carry
        # out_channels features.
        x = self.lin(x)

        # Symmetric normalisation: edge (j -> i) gets weight
        # 1 / sqrt(deg(j) * deg(i)).
        # TODO: normalisation will need to change when moving to converting
        # back to an image for convolution.
        row, col = edge_index
        deg = degree(col, num_nodes, dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        # deg == 0 would yield inf; zero it instead. With self-loops added,
        # every node has deg >= 1, so this guard is defensive.
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Gather, scale (see message()) and sum the transformed features.
        out = self.propagate(edge_index, x=x, norm=norm)

        # Out-of-place add, so propagate's output tensor is not mutated.
        return out + self.bias

    def message(self, x_j: torch.Tensor, norm: torch.Tensor) -> torch.Tensor:
        """Scale each edge's source-node features by its normalisation weight.

        ``x_j`` holds the linearly transformed features of each edge's source
        node. ``norm`` holds the matching per-edge coefficient, and
        ``view(-1, 1)`` broadcasts it across the feature axis.
        """
        # A plain scaled copy of the neighbour's features. This is the
        # simplest message and may be sufficient for this project.
        return norm.view(-1, 1) * x_j

