"""
Implementation of a GCN with our R-Pool
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_scatter import scatter_add,scatter_mean
from torch.nn.parameter import Parameter

from sklearn.mixture import GaussianMixture
import numpy as np
from .utils import normalize_tensor_adj

import math

class convClass(nn.Module):
    """
    Bias-free graph convolution layer.

    The node features are first projected by a learnable weight matrix,
    then mixed across nodes by multiplying with the adjacency matrix,
    and finally passed through the supplied activation.

    ---
    Input:
        * input_dim : Size of the incoming node features
        * output_dim : Size of the produced node features
        * activation : Callable applied to the aggregated features
    """
    def __init__(self, input_dim , output_dim, activation):
        super().__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation

        # Stored as (output_dim, input_dim), the layout F.linear expects.
        self.weight = Parameter(torch.empty(output_dim, input_dim))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw the weight with Kaiming-uniform initialisation."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x, adj):
        """
        Apply the layer.

        ---
        Input:
            * x : Node features, one row per node
            * adj : Adjacency matrix multiplied on the left of the
                    projected features

        Output:
            * activation(adj @ (x @ weight.T))
        """
        projected = F.linear(x, self.weight)
        aggregated = torch.mm(adj, projected)
        return self.activation(aggregated)



class GCN_NEW(nn.Module):
    """
    Two-layer GCN graph classifier.

    Node features go through two graph convolutions (with dropout before
    each one), are pooled per graph and fed to a linear classifier that
    outputs log-probabilities.

    `forward` handles a batch of graphs. `predict` handles a single graph
    and additionally filters the node embeddings with a Gaussian-mixture
    based score before pooling.

    ---
    Input:
        * input_dim : Size of the input node features
        * hidden_dim : Size of the hidden node embeddings
        * output_dim : Number of classes
        * device : Device used by `predict` to place its inputs
        * pooling : Graph readout, one of "sum", "mean", "max"
                    (case-insensitive)
        * threshold : Quantile (in [0, 1]) of the node scores used by
                      `predict`; nodes scoring below it are zeroed out
        * dropout : Dropout probability applied before each convolution
    """
    def __init__(self, input_dim, hidden_dim, output_dim, device, pooling="sum", threshold = 0.9, dropout=0.5):
        super(GCN_NEW, self).__init__()
        self.device = device
        self.activation = nn.ReLU()
        # Previously this argument was ignored and 0.5 was hard-coded.
        self.dropout = dropout

        self.conv1 = convClass(input_dim, hidden_dim, activation = self.activation)
        self.conv2 = convClass(hidden_dim, hidden_dim, activation = self.activation)

        self.lin = nn.Linear(hidden_dim, output_dim)

        self.threshold = threshold
        self.pooling = pooling.lower()
        assert self.pooling in ["mean", "max", "sum"], \
            "pooling must be one of 'mean', 'max' or 'sum'"

    def _encode(self, x_in, adj):
        """Run dropout + conv1 + dropout + conv2 and return node embeddings."""
        x = F.dropout(x_in, p=self.dropout, training=self.training)
        x = self.conv1(x, adj)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.conv2(x, adj)

    def _readout(self, x, idx):
        """Pool node embeddings per graph and return class log-probabilities."""
        n_graphs = int(idx.max()) + 1
        # scatter_reduce needs an index with the same shape as the source.
        index = idx.unsqueeze(1).repeat(1, x.size(1))
        out = x.new_zeros(n_graphs, x.size(1))

        if self.pooling == "sum":
            out = out.scatter_reduce(0, index, x, reduce="sum")
        elif self.pooling == "mean":
            # include_self=False: otherwise the zero row of `out` is counted
            # as an extra element and the result is sum / (n + 1).
            out = out.scatter_reduce(0, index, x, reduce="mean", include_self=False)
        elif self.pooling == "max":
            out = out.scatter_reduce(0, index, x, reduce="amax", include_self=False)

        return F.log_softmax(self.lin(out), dim=1)

    def forward(self, x_in, adj, idx):
        """
        For a batch prediction from the model

        The adjacency is used as given; no normalization is applied here.

        ---
        Input:
            * adj : Adjacency
            * x_in : Features
            * idx : Batch list identification (1-D integer tensor giving,
                    for every node, the index of the graph it belongs to)

        Output:
            * Log-probabilities, one row per graph (max(idx) + 1 rows)
        """
        x = self._encode(x_in, adj)
        return self._readout(x, idx)

    def predict(self, adj, x_in):
        """
        For a single prediction from the model

        The node embeddings are scored against the component means of a
        2-component Gaussian mixture fitted on those same embeddings. Nodes
        whose score falls below the `threshold` quantile are replaced by
        zero vectors before pooling (they still count as nodes for "mean"
        pooling).

        Dropout follows `self.training`, so switch the model to eval mode
        for deterministic layers.
        NOTE(review): the mixture is fitted without a fixed random state,
        so results may vary between calls -- confirm this is intended.

        ---
        Input:
            * adj : Adjacency
            * x_in : Features

        Output:
            * Log-probabilities with a single row
        """
        n_nodes = adj.shape[0]

        adj = adj.to(self.device)
        x_in = x_in.to(self.device)

        adj = normalize_tensor_adj(adj, self.device)

        x = self._encode(x_in, adj)

        # Fit a GMM on the (detached) node embeddings
        embeddings = x.detach().cpu().numpy()
        gmm = GaussianMixture(n_components=2, max_iter=100)
        gmm.fit(embeddings)

        # Compute the OOD scores
        ood_scores = np.array([rescaled_GEM_score(element, gmm.means_, phi=1)
                               for element in embeddings])

        # define the quantile based on the threshold
        val_filter = np.quantile(ood_scores, self.threshold)

        # keep only the relevant nodes, zero out the others
        keep = torch.as_tensor(ood_scores >= val_filter, device=x.device)
        x = torch.where(keep.unsqueeze(1), x, torch.zeros_like(x))

        # Every node belongs to the single graph 0
        idx = torch.zeros(n_nodes, dtype=torch.long, device=x.device)

        return self._readout(x, idx)


def rescaled_GEM_score(x,mean,phi=1):
    """
    Score a sample against a set of component means.

    For every centre in `mean` the value exp(mahalanobis(x, centre, phi))
    is computed, and the values are added up. With no centres the result
    is 0.

    ---
    Input:
        * x : Sample vector
        * mean : Iterable of centre vectors
        * phi : Scale passed through to `mahalanobis`
    """
    return sum(np.exp(mahalanobis(x, center, phi)) for center in mean)

def mahalanobis(x,mu,phi=1):
    """
    Return -0.5 * (1/phi) * <x - mu, x - mu>.

    This is the negative, halved squared distance between `x` and `mu`,
    scaled by 1/phi.

    ---
    Input:
        * x : Sample vector
        * mu : Centre vector
        * phi : Scale dividing the squared distance
    """
    diff = x - mu
    scale = -0.5 * (1 / phi)
    return scale * np.inner(diff, diff)
