import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
from pygcn.layers import MLPLayer
from pygcn.utils import MixedDropout


class GCNPND(nn.Module):
    """Graph network that rewires the input graph before propagating features.

    Pipeline (see ``forward``):
      1. input dropout, ``layer1``, ReLU;
      2. smooth the hidden features over ``adj`` (``aggregation``);
      3. build a new sparse, row-softmaxed adjacency from the cosine similarity
         of the smoothed features (``cosine_similarity``);
      4. combine multi-hop propagations over that adjacency using learned
         per-node attention vectors ``self.prop`` (``propagation``);
      5. hidden dropout, ``layer2``.

    Args:
        num: number of rows of the attention parameter ``prop``. It must
            broadcast against the node axis in ``propagation``, so presumably
            it equals the number of nodes N (TODO confirm with callers).
        nfeat, nhid, nclass: widths passed to the two ``MLPLayer`` s.
        hops: number of propagation steps used in ``forward``.
        threshold: existing edges whose endpoint cosine similarity is below
            this value are suppressed in the rewired graph.
        aggre_range: number of smoothing steps in ``aggregation``.
        k_connection: number of most-similar nodes added per row.
        adj_droprate: drop rate handed to ``MixedDropout``.
        alpha: residual (self) weight in ``aggregation``.
        lambd: weight of the top-k similarity edges in the rewired graph.
        input_droprate, hidden_droprate: dropout before ``layer1``/``layer2``.
    """

    def __init__(self, num, nfeat, nhid, nclass, hops, threshold, aggre_range, k_connection, adj_droprate, alpha,
                 lambd, input_droprate, hidden_droprate):
        super().__init__()
        # Per-node attention vectors used to weight the hop outputs.
        self.prop = Parameter(torch.empty(num, nhid, dtype=torch.float32))
        self.reset_parameters()
        self.layer1 = MLPLayer(nfeat, nhid)
        self.layer2 = MLPLayer(nhid, nclass)

        # NOTE(review): presumably a dropout that supports sparse adjacency; see pygcn.utils.
        self.mixDrop = MixedDropout(adj_droprate)
        self.input_droprate = input_droprate
        self.hidden_droprate = hidden_droprate
        self.hops = hops
        self.lambd = lambd
        self.threshold = threshold
        self.aggre_range = aggre_range
        self.k_connection = k_connection
        self.alpha = alpha

    def forward(self, x, adj):
        """Return class logits for node features ``x`` on sparse graph ``adj``."""
        x = F.dropout(x, self.input_droprate, training=self.training)
        x = F.relu(self.layer1(x))
        smoothed = self.aggregation(x, adj)
        rewired_adj = self.cosine_similarity(smoothed, self.threshold, adj, self.lambd)
        propagated = self.propagation(x, rewired_adj, hops=self.hops)
        hidden = F.dropout(propagated, self.hidden_droprate, training=self.training)
        return self.layer2(hidden)

    def aggregation(self, x, adj):
        """Smooth ``x`` over the local neighbourhood before rewiring the graph.

        Applies ``aggre_range`` steps of ``h <- (1 - alpha) * adj @ h + alpha * h``.
        """
        original_feature = x
        for _ in range(self.aggre_range):
            original_feature = (1 - self.alpha) * adj @ original_feature + self.alpha * original_feature
        return original_feature

    def propagation(self, x, adj, hops, droprate=0.5):
        """Adaptively combine 0..``hops``-hop propagations of ``x`` over ``adj``.

        Each hop multiplies by a freshly ``mixDrop``-ed copy of ``adj``. The
        stacked hop outputs are scored against the per-node vectors in
        ``self.prop``, softmax-normalised over the hop axis, and summed.

        Assumes ``x`` is (N, nhid) and ``self.prop`` broadcasts as (N, nhid)
        (TODO confirm).

        Args:
            x: hidden node features.
            adj: adjacency to propagate over.
            hops: number of propagation steps (0 returns the dropped ``x``).
            droprate: dropout applied to the stacked hop features.
        """
        feature_list = [x]
        hop_feature = x
        for _ in range(hops):
            # The adjacency dropout mask is re-sampled on every hop.
            hop_feature = self.mixDrop(adj) @ hop_feature
            feature_list.append(hop_feature)
        stacked = torch.stack(feature_list, dim=1)  # (N, hops + 1, nhid)
        stacked = F.dropout(stacked, droprate, training=self.training)
        # (N, hops+1, nhid) @ (num, nhid, 1) -> (N, hops+1, 1). Squeeze only the
        # trailing axis so hops == 0 or a single node still yields a 2-D score.
        scores = torch.matmul(stacked, self.prop.unsqueeze(-1)).squeeze(-1)
        coef = F.softmax(scores, dim=1).unsqueeze(-1)
        return (stacked * coef).sum(dim=1)

    def reset_parameters(self):
        """Initialise ``prop`` uniformly in ``[-1/sqrt(nhid), 1/sqrt(nhid)]``."""
        stdv = 1. / math.sqrt(self.prop.size(1))
        with torch.no_grad():
            # uniform_(a, b) is the intended symmetric init; the previous
            # normal_(-stdv, stdv) drew from N(mean=-stdv, std=stdv).
            self.prop.uniform_(-stdv, stdv)

    def cosine_similarity(self, feature, threshold, adj, alpha):
        """Build a rewired, row-softmaxed sparse adjacency.

        Existing edges of ``adj`` keep their weight when the cosine similarity
        of their endpoints is at least ``threshold``; otherwise they get -1e9
        (about 0 after softmax). For each node, its ``k_connection`` most
        similar nodes (self excluded via a -1 diagonal shift) are added with
        weight ``alpha * similarity``. ``torch.sparse.softmax`` then normalises
        each row over the stored entries only.

        Args:
            feature: node features; cosine similarity is taken along the last axis.
            threshold: minimum similarity for keeping an existing edge.
            adj: sparse COO adjacency (it is coalesced here).
            alpha: weight of the similarity edges (``forward`` passes ``self.lambd``).

        Returns:
            A sparse tensor of shape ``adj.size()``.
        """
        adj = adj.coalesce()  # indices()/values() need a coalesced tensor
        index = adj.indices()
        adj_values = adj.values()

        normed = F.normalize(feature, p=2, dim=-1)
        cosine = torch.matmul(normed, normed.transpose(-1, -2))

        edge_cos = cosine[index[0], index[1]]
        suppressed = torch.full_like(edge_cos, -1e9)
        kept_values = torch.where(edge_cos < threshold, suppressed, adj_values)

        num_nodes = cosine.shape[0]
        # Build eye on the same device/dtype as cosine so this works on GPU.
        eye = torch.eye(num_nodes, dtype=cosine.dtype, device=cosine.device)
        # topk raises if k exceeds the axis length, so clamp it.
        k = min(self.k_connection, num_nodes)
        top_val, top_idx = torch.topk(cosine - eye, k, dim=-1)
        knn_dense = torch.zeros_like(cosine).scatter_(-1, top_idx, top_val)
        # to_sparse() stores only non-zero entries, so zero-valued picks vanish.
        knn_sparse = knn_dense.to_sparse()

        coef = torch.sparse_coo_tensor(index, kept_values, adj.size())
        coef = coef + alpha * knn_sparse
        # Unstored entries are treated as -inf by sparse softmax.
        return torch.sparse.softmax(coef, dim=1)
