"""
Implementation of a GCN
"""


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch_scatter import scatter_add,scatter_mean
import numpy as np
import math
from .utils import normalize_tensor_adj


class convClass(nn.Module):
    """Single graph-convolution layer: ``activation(adj @ (x @ W^T))``.

    The layer has no bias. ``W`` has shape ``(output_dim, input_dim)``.
    """

    def __init__(self, input_dim , output_dim, activation):
        super(convClass, self).__init__()

        # Keep the dimensions around for introspection by callers.
        self.input_dim, self.output_dim = input_dim, output_dim
        # Allocated uninitialised here; filled in by reset_parameters() below.
        self.weight = Parameter(torch.Tensor(output_dim, input_dim))
        self.activation = activation
        self.reset_parameters()

    def forward(self, x, adj):
        """Project node features with ``weight``, aggregate over ``adj``, then activate.

        ``torch.mm`` requires ``adj`` and the projected features to be 2-D, so
        ``x`` is expected to be ``(num_nodes, input_dim)`` and ``adj``
        ``(num_nodes, num_nodes)``.
        """
        projected = F.linear(x, self.weight)
        aggregated = torch.mm(adj, projected)
        return self.activation(aggregated)

    def reset_parameters(self):
        """Re-initialise ``weight`` with the same Kaiming-uniform scheme as ``nn.Linear``."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))


class GCN(nn.Module):
    """Two-layer GCN with a per-graph pooling readout and a linear classifier.

    Args:
        input_dim: number of input node features.
        hidden_dim: width of both graph-convolution layers.
        output_dim: number of output classes.
        device: device that ``predict`` moves its inputs to; stored on ``self.device``.
        pooling: readout over the nodes of each graph: ``"sum"``, ``"mean"`` or ``"max"``.
        dropout: dropout probability applied to the input features and between
            the two convolutions (only active in training mode).
        threshold: stored on ``self.threshold``; not used within this class.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, device, pooling="sum", dropout=0.5, threshold=0.3):
        super(GCN, self).__init__()
        self.device = device

        # One ReLU module shared by both convolution layers.
        self.activation = nn.ReLU()

        self.conv1 = convClass(input_dim, hidden_dim, activation = self.activation)
        self.conv2 = convClass(hidden_dim, hidden_dim, activation = self.activation)

        self.lin = nn.Linear(hidden_dim, output_dim)

        self.pooling = pooling
        self.dropout = dropout
        self.threshold = threshold

        # Placeholder attribute; not used within this class.
        self.defense = None

    def _pool(self, x, idx):
        """Aggregate node rows of ``x`` into one row per graph id in ``idx``.

        Args:
            x: node embeddings, ``(num_nodes, dim)``.
            idx: int64 graph id for each node, ``(num_nodes,)``.

        Returns:
            ``(max(idx) + 1, dim)`` tensor. Graph ids that no node maps to stay zero.

        Raises:
            ValueError: if ``self.pooling`` is not ``"sum"``, ``"mean"`` or ``"max"``.
        """
        n_graphs = int(idx.max()) + 1
        index = idx.unsqueeze(1).expand(-1, x.size(1))
        # Match x's dtype and device so scatter_reduce never sees a mismatch.
        out = x.new_zeros(n_graphs, x.size(1))
        if self.pooling == "sum":
            return out.scatter_reduce(0, index, x, reduce="sum")
        if self.pooling == "mean":
            # include_self=False: otherwise the zero initial value counts as an
            # extra element and every mean becomes sum / (n + 1).
            return out.scatter_reduce(0, index, x, reduce="mean", include_self=False)
        if self.pooling == "max":
            # include_self=False so all-negative embeddings are not clipped at 0.
            return out.scatter_reduce(0, index, x, reduce="amax", include_self=False)
        raise ValueError(
            f"Unknown pooling {self.pooling!r}; expected 'sum', 'mean' or 'max'"
        )

    def forward(self, x_in, adj, idx):
        """Classify a batch of graphs.

        Args:
            x_in: node features, ``(num_nodes, input_dim)``.
            adj: ``(num_nodes, num_nodes)`` adjacency used as-is. Presumably
                already normalised by the caller; ``predict`` normalises
                explicitly but this method does not.
            idx: int64 graph id of each node, ``(num_nodes,)``, on the same
                device as ``x_in``.

        Returns:
            Log-probabilities of shape ``(max(idx) + 1, output_dim)``.
        """
        x = F.dropout(x_in, p=self.dropout, training=self.training)
        x = self.conv1(x, adj)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, adj)

        out = self.lin(self._pool(x, idx))
        return F.log_softmax(out, dim=1)

    def predict(self, adj, x):
        """
        For a single prediction from the model

        ---
        Input:
            * adj : Adjacency of a single graph, (n_nodes, n_nodes); normalised
                    here with ``normalize_tensor_adj`` before use
            * x : Node features, (n_nodes, input_dim)

        Output:
            Log-probabilities of shape (1, output_dim)
        """
        n_nodes = adj.shape[0]
        adj = adj.to(self.device)
        x = x.to(self.device)
        adj = normalize_tensor_adj(adj, device=self.device)

        # Unlike forward(), no dropout is applied to the raw input here; in
        # eval mode both paths are identical since dropout is the identity.
        x = self.conv1(x, adj)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, adj)

        # Every node belongs to graph 0.
        idx = torch.zeros(n_nodes, dtype=torch.long, device=x.device)

        out = self.lin(self._pool(x, idx))
        return F.log_softmax(out, dim=1)
