#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""

This file defines the GNN architecture (the GNN class).
A fundamental step in the GNN's update rule is the use of an
appropriate convolution. We define two convolutions: one for
coloured edges and one for colourless edges.

@author: ----
"""
import torch

from torch_geometric.nn import MessagePassing

import torch.nn.functional as F
from torch.nn import Parameter


class EC_GCNConv(MessagePassing):
    """Edge-coloured graph convolution with max aggregation.

    For each edge colour ``i`` in ``range(num_edge_types)``, neighbour
    features ``x_j`` along edges of colour ``i`` are max-aggregated. The
    result is multiplied by the colour-specific matrix ``weights[i]`` (no
    bias), and the per-colour results are summed. Edges whose type lies
    outside ``[0, num_edge_types)`` are never selected and so are ignored.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_edge_types (int): Number of distinct edge colours.
    """

    def __init__(self, in_channels, out_channels, num_edge_types=1):
        super(EC_GCNConv, self).__init__(aggr='max')  # "Max" aggregation
        # One (out_channels x in_channels) matrix per edge colour, used as the
        # weight argument of F.linear. Initialised with a very small std.
        self.weights = Parameter(torch.empty(num_edge_types, out_channels, in_channels))
        self.weights.data.normal_(0, 0.001)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_edge_types = num_edge_types

    def forward(self, x, edge_index, edge_type):
        """Return the sum over colours of ``weights[i] @ maxagg_i(x)``.

        Args:
            x: Node feature matrix; row count is the number of nodes.
                (Presumably shape ``(num_nodes, in_channels)``; confirm with callers.)
            edge_index: Edge index tensor; columns are selected per colour.
            edge_type: Per-edge colour ids, compared against ``range(num_edge_types)``.

        Returns:
            Tensor of shape ``(num_nodes, out_channels)`` with x's dtype and device.
        """
        num_nodes = x.size(0)
        # Match x's dtype as well as its device. A default-dtype (float32)
        # buffer would silently downcast the per-colour terms of a
        # .double() model during the in-place accumulation below.
        out = x.new_zeros(num_nodes, self.out_channels)
        for i in range(self.num_edge_types):
            colour_edges = edge_index[:, edge_type == i]
            aggregated = self.propagate(colour_edges, x=x, size=(num_nodes, num_nodes))
            out += F.linear(aggregated, self.weights[i], bias=None)
        return out

    def message(self, x_j):
        # Messages are the raw neighbour features; the weights are applied
        # after aggregation in forward().
        return x_j

    def update(self, aggr_out):
        return aggr_out

class NEC_GCNConv(MessagePassing):
    """Colourless graph convolution: max-aggregate neighbours, then apply a linear layer.

    ``num_edge_types`` (constructor) and ``edge_type`` (forward) are accepted
    but unused. They only keep the signature identical to ``EC_GCNConv`` so
    the two classes can be swapped freely.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_edge_types (int): Ignored.
    """

    def __init__(self, in_channels, out_channels, num_edge_types=1):
        super().__init__(aggr='max')  # "Max" aggregation
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index, edge_type):
        n = x.size(0)
        return self.propagate(edge_index, size=(n, n), x=x)

    def message(self, x_j):
        # Pass neighbour features through unchanged; the linear map is applied
        # once, after aggregation, in update().
        return x_j

    def update(self, aggr_out):
        return self.lin(aggr_out)

class GNN(torch.nn.Module):
    """Two-layer GNN whose node update is ``self-linear + convolution``.

    Layer 1 maps ``feature_dimension -> 2 * feature_dimension`` and is followed
    by a ReLU. Layer 2 maps back to ``feature_dimension``. The output is
    ``sigmoid(h - 10)``.

    Args:
        feature_dimension (int): Size of the input (and output) node features.
        edge_colours (bool): Use ``EC_GCNConv`` (coloured edges) when True,
            otherwise ``NEC_GCNConv``.
        num_edge_types (int): Number of edge colours, forwarded to the convs.
    """

    def __init__(self, feature_dimension, edge_colours=False, num_edge_types=1):
        super().__init__()
        hidden_dim = 2 * feature_dimension
        conv_cls = EC_GCNConv if edge_colours else NEC_GCNConv

        # Keep this construction order: it determines the order in which
        # parameters are randomly initialised and the state_dict key order.
        self.conv1 = conv_cls(feature_dimension, hidden_dim, num_edge_types)
        self.conv2 = conv_cls(hidden_dim, feature_dimension, num_edge_types)
        self.lin_self_1 = torch.nn.Linear(feature_dimension, hidden_dim)
        self.lin_self_2 = torch.nn.Linear(hidden_dim, feature_dimension)

        self.output = torch.nn.Sigmoid()

    def forward(self, data):
        """Run both layers on ``data`` (which needs ``x``, ``edge_index``, ``edge_type``)."""
        x = data.x
        edge_index = data.edge_index
        edge_type = data.edge_type

        hidden = torch.relu(self.lin_self_1(x) + self.conv1(x, edge_index, edge_type))
        logits = self.lin_self_2(hidden) + self.conv2(hidden, edge_index, edge_type)

        # The shift by -10 is mathematically redundant, because the bias
        # vectors are not constrained to be positive, and it is not mentioned
        # in the report. It is kept because the trained models were trained
        # with it.
        return self.output(logits - 10)
