#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
import networkx as nx

class GraphConvLayer(layers.Layer):
    """Dense graph-convolution layer computing ``relu(A_hat @ X @ W)``.

    ``A_hat = D^-1/2 @ A @ D^-1/2`` is the symmetrically normalised adjacency
    matrix, where ``D`` is the diagonal matrix of row sums of ``A``.
    No self-loops are added here: if a node's own features should contribute
    to its output, ``adjacency_matrix`` must already contain them.

    Args:
        num_output_features: Size of the output feature dimension.
        **kwargs: Standard Keras ``Layer`` arguments (``name``, ``dtype``, ...).

    Call inputs:
        A pair ``(adjacency_matrix, node_features)``. Presumably shaped
        ``(..., N, N)`` and ``(..., N, F)`` with an optional leading batch
        axis -- TODO confirm against callers.

    Returns:
        Node features of width ``num_output_features`` after ReLU.
    """

    def __init__(self, num_output_features, **kwargs):
        super().__init__(**kwargs)
        self.num_output_features = num_output_features

    def build(self, input_shape):
        # input_shape is the pair (adjacency_shape, features_shape); the
        # kernel maps the last feature axis to num_output_features.
        num_input_features = input_shape[1][-1]
        # Keyword arguments are required: Keras 3 reordered add_weight's
        # positional parameters (shape comes first, name last).
        self.kernel = self.add_weight(
            name="kernel",
            shape=(num_input_features, self.num_output_features),
        )

    def call(self, inputs):
        adjacency_matrix, node_features = inputs
        # Match the feature dtype so int / float64 adjacency inputs work.
        adjacency_matrix = tf.cast(adjacency_matrix, node_features.dtype)

        degree = tf.reduce_sum(adjacency_matrix, axis=-1)
        # 1/sqrt(deg), with zero-degree nodes mapped to 0 instead of a huge
        # epsilon-driven weight.
        inv_sqrt_degree = tf.math.divide_no_nan(
            tf.ones_like(degree), tf.math.sqrt(degree)
        )
        # Scale rows and columns by broadcasting rather than materialising
        # dense diagonal matrices: D^-1/2 @ A @ D^-1/2.
        normalized_adjacency = (
            inv_sqrt_degree[..., :, tf.newaxis]
            * adjacency_matrix
            * inv_sqrt_degree[..., tf.newaxis, :]
        )

        aggregated = tf.linalg.matmul(normalized_adjacency, node_features)
        return tf.nn.relu(tf.linalg.matmul(aggregated, self.kernel))

    def get_config(self):
        """Return the layer config so the layer can be re-created/serialised."""
        config = super().get_config()
        config.update({"num_output_features": self.num_output_features})
        return config



class GCN(Model):
    """Three stacked graph convolutions, mean pooling over nodes, softmax head.

    Args:
        num_classes: Number of units (classes) in the final softmax ``Dense``.
        **kwargs: Standard Keras ``Model`` arguments (e.g. ``name``).

    Call inputs:
        A pair ``(adjacency_matrix, node_features)``; the same adjacency is
        fed to every graph-convolution layer.

    Returns:
        Softmax class probabilities produced by ``self.fc``.
    """

    def __init__(self, num_classes, **kwargs):
        super().__init__(**kwargs)
        # Attribute names determine tracked variable paths in checkpoints;
        # keep them stable.
        self.graph_conv1 = GraphConvLayer(128)
        self.graph_conv2 = GraphConvLayer(32)
        self.graph_conv3 = GraphConvLayer(18)
        self.fc = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        adjacency_matrix, node_features = inputs
        x = node_features
        for graph_conv in (self.graph_conv1, self.graph_conv2, self.graph_conv3):
            x = graph_conv([adjacency_matrix, x])
        # Mean over axis 1: the node axis assuming batched
        # (batch, nodes, features) input -- TODO confirm with callers.
        # NOTE(review): if graphs are zero-padded to a common size, padding
        # nodes are included in this mean and dilute it.
        x = tf.reduce_mean(x, axis=1)
        return self.fc(x)

