from torch_geometric.utils import get_laplacian, add_self_loops
from torch_geometric.nn import MessagePassing
from torch.nn import Parameter, Linear
import torch.nn.functional as F
import torch
from scipy.special import comb
import numpy as np


def poly_bern_fun(x, alphas, K):
    """Evaluate a degree-``K`` Bernstein polynomial at ``x / 2``.

    Computes ``sum_k alphas[k] * C(K, k) * (1 - x/2)**(K - k) * (x/2)**k``
    for ``k = 0..K``.

    Args:
        x: Tensor of evaluation points; it is halved before being used.
        alphas: Indexable collection of at least ``K + 1`` coefficients.
        K: Polynomial degree.

    Returns:
        Tensor holding the polynomial value for every entry of ``x``.
    """
    half = x / 2

    # The k = 0 term carries no power of `half`.
    total = alphas[0] * comb(K, 0) * torch.pow(1 - half, K)

    for k in range(1, K + 1):
        basis = torch.pow(1 - half, K - k) * torch.pow(half, k)
        total = total + alphas[k] * comb(K, k) * basis
    return total


class SAF(torch.nn.Module):
    """Model combining a spectral filter with an adapted spatial aggregation.

    The forward pass
      1. maps the node features through a two-layer MLP,
      2. runs the learnable Bernstein filter (``specBern_prop``) on the result,
      3. rebuilds a dense "adapted" Laplacian from ``dcpU`` / ``dcpW`` and the
         current filter coefficients,
      4. aggregates the MLP output with the propagation matrix ``I - L_new``,
      5. blends the filtered and aggregated outputs with per-node gates.

    Args:
        args: Object exposing ``dprate``, ``dropout`` and ``scale``.
        num_nodes: Side length of the identity matrix used in ``I - L_new``.
        dev: Device the identity matrix is moved to.
        in_dim: Input feature size of the first linear layer.
        hid_dim: Hidden size between the two linear layers.
        out_dim: Output size of the second linear layer.
        dcpW: Values the Bernstein polynomial is evaluated on.
            NOTE(review): presumably the Laplacian eigenvalues -- confirm
            against the caller.
        dcpU: Matrix used as ``U`` in ``U diag(w) U^T``.
            NOTE(review): presumably the matching eigenvectors -- confirm.
        agg_layers: Number of aggregation iterations in ``nonLoc_agg``.
        agg_alpha: Weight of the propagated signal in each iteration.
        drop_ama: Dropout probability applied to the blending gates.
        sparse_eps: Magnitude threshold for ``sparsification``; ``0``
            disables it.
        K: Degree of the Bernstein polynomial.
    """

    def __init__(self, args, num_nodes, dev, in_dim, hid_dim, out_dim, dcpW, dcpU, agg_layers, agg_alpha, drop_ama, sparse_eps=1e-2, K=10):
        super().__init__()
        self.dprate = args.dprate
        self.dropout = args.dropout
        self.scale = args.scale
        self.dcpW = dcpW
        self.dcpU = dcpU
        self.agg_layers = agg_layers
        self.agg_alpha = agg_alpha
        self.K = K

        # latent mapping (two-layer MLP) and the spectral filter
        self.lin1 = Linear(in_dim, hid_dim)
        self.lin2 = Linear(hid_dim, out_dim)
        self.filter_prop = specBern_prop(K)

        # one scalar gate per node for each branch
        self.lin_f = Linear(out_dim, 1, bias=False)
        self.lin_a = Linear(out_dim, 1, bias=False)
        self.drop_ama = drop_ama
        self.sparse_eps = sparse_eps

        # NOTE: plain tensor attributes (iden_mat, dcpW, dcpU) are not
        # registered buffers, so ``Module.to()`` does not move them.
        self.iden_mat = torch.eye(num_nodes).to(dev)
        # keeps the reciprocal in get_adapted_L away from an exact zero
        self.smallOffSet = 1e-10

    def _dropout(self, x, p):
        """Apply dropout with probability ``p``, active only in training mode."""
        return F.dropout(x, p=p, training=self.training)

    def get_adapted_L(self, alphas):
        """Build the adapted Laplacian ``U diag(w) U^T``.

        ``w`` is derived from the Bernstein polynomial evaluated on
        ``self.dcpW``: the polynomial is divided by ``max(alphas)``, mapped
        through ``1 / (p + smallOffSet) - 1`` and multiplied by ``self.scale``.

        Args:
            alphas: Filter coefficients (``K + 1`` values).

        Returns:
            Dense matrix ``dcpU @ diag(w) @ dcpU.T``.
        """
        response = poly_bern_fun(x=self.dcpW, alphas=alphas, K=self.K)
        # NOTE: if every coefficient is zero this is a division by zero.
        response = response / torch.max(alphas)
        adjusted_w = 1 / (response + self.smallOffSet) - 1

        # scale-normalization
        adjusted_w = self.scale * adjusted_w

        # diag(w) @ U^T merely rescales row i of U^T by w[i]; broadcasting
        # gives the same product without materialising an N x N diagonal
        # matrix and without a second dense matmul.
        scaled_ut = adjusted_w.unsqueeze(-1) * self.dcpU.T
        return torch.matmul(self.dcpU, scaled_ut)

    def sparsification(self, mat):
        """Zero every entry of ``mat`` whose magnitude is not above ``sparse_eps``.

        Returns ``mat`` itself, untouched, when ``sparse_eps`` is ``0``.
        """
        if self.sparse_eps == 0:
            return mat
        # float mask: 1 where |entry| > eps, 0 elsewhere
        keep = (mat.abs() > self.sparse_eps).float()
        return keep * mat

    def nonLoc_agg(self, x, new_L):
        """Aggregate ``x`` with the propagation matrix ``I - new_L``.

        Runs ``agg_layers`` iterations of
        ``z <- (1 - agg_alpha) * x + agg_alpha * A @ z`` starting from
        ``z = x``, where ``A`` is the sparsified ``I - new_L``.

        Args:
            x: Node representations to aggregate.
            new_L: Dense Laplacian, same shape as ``self.iden_mat``.

        Returns:
            The aggregated node representations.
        """
        new_A = self.sparsification(self.iden_mat - new_L)
        z_a = x
        for _ in range(self.agg_layers):
            z_a = (1 - self.agg_alpha) * x + self.agg_alpha * torch.matmul(new_A, z_a)
        return z_a

    def forward(self, data):
        """Compute the blended node outputs for ``data``.

        Args:
            data: Object whose ``graph`` mapping provides ``'node_feat'`` and
                ``'edge_index'``.

        Returns:
            Gate-weighted sum of the spectral-filter output and the
            spatial-aggregation output.
        """
        # load data
        x = data.graph['node_feat']
        edge_index = data.graph['edge_index']

        # latent mapping
        x = self._dropout(x, p=self.dropout)
        x = F.relu(self.lin1(x))
        x = self._dropout(x, p=self.dropout)
        x = self.lin2(x)
        if self.dprate != 0.0:
            x = self._dropout(x, p=self.dprate)

        # spectral-filtering
        z_f, filterWeights = self.filter_prop(x, edge_index)

        # spatial-aggregation on the Laplacian adapted to the current filter
        new_L = self.get_adapted_L(filterWeights)
        z_a = self.nonLoc_agg(x, new_L)

        # prediction-amalgamation: one sigmoid gate per branch and node,
        # L1-normalised across the two branches
        w_f = self._dropout(torch.sigmoid(self.lin_f(z_f)), p=self.drop_ama)
        w_a = self._dropout(torch.sigmoid(self.lin_a(z_a)), p=self.drop_ama)
        w = F.normalize(torch.cat((w_f, w_a), dim=1), p=1, dim=1)

        return w[:, 0].unsqueeze(1) * z_f + w[:, 1].unsqueeze(1) * z_a


class specBern_prop(MessagePassing):
    """Bernstein-polynomial spectral propagation layer.

    Adapted from
    https://github.com/ivam-he/BernNet/blob/main/NodeClassification/Bernpro.py

    Holds ``K + 1`` learnable coefficients (``temp``), all initialised to 1.
    ``forward`` returns the filtered signal together with the ReLU-clipped
    coefficients that were used.
    """

    def __init__(self, K, bias=True, **kwargs):
        # `bias` is accepted but not used by this layer.
        super().__init__(aggr='add', **kwargs)

        self.K = K
        self.temp = Parameter(torch.Tensor(self.K + 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Set every coefficient back to 1."""
        self.temp.data.fill_(1)

    def forward(self, x, edge_index, edge_weight=None):
        """Filter ``x`` and return ``(filtered, coefficients)``.

        The result is
        ``sum_k C(K, k) / 2**K * coeff[k] * L^k (2I - L)^(K - k) x``.
        """
        # coefficients are kept non-negative
        coeffs = F.relu(self.temp)
        n_nodes = x.size(self.node_dim)

        # L = I - D^(-0.5) A D^(-0.5)
        lap_index, lap_weight = get_laplacian(edge_index, edge_weight, normalization='sym', dtype=x.dtype, num_nodes=n_nodes)
        # 2I - L
        comp_index, comp_weight = add_self_loops(lap_index, -lap_weight, fill_value=2., num_nodes=n_nodes)

        # powers[m] = (2I - L)^m x for m = 0..K
        powers = [x]
        for _ in range(self.K):
            x = self.propagate(comp_index, x=x, norm=comp_weight, size=None)
            powers.append(x)

        denom = 2 ** self.K
        out = (comb(self.K, 0) / denom) * coeffs[0] * powers[self.K]

        for k in range(1, self.K + 1):
            # apply L exactly k times to (2I - L)^(K - k) x
            x = powers[self.K - k]
            for _ in range(k):
                x = self.propagate(lap_index, x=x, norm=lap_weight, size=None)

            out = out + (comb(self.K, k) / denom) * coeffs[k] * x
        return out, coeffs

    def message(self, x_j, norm):
        """Scale each neighbour feature row by its edge weight."""
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        cls_name = self.__class__.__name__
        return '{}(K={}, temp={})'.format(cls_name, self.K, self.temp)
