
import numpy as np
import math
import torch
from torch import nn
from torch.nn import init
import torch.nn.functional as F

from dgl import function as fn
from dgl.base import DGLError
from dgl.utils import expand_as_pair
from dgl.nn.pytorch.conv import GraphConv


class GraphConvAGGR(nn.Module):
    """Parameter-free, degree-normalised neighbour aggregation.

    Runs only the message-passing part of a graph convolution: source-node
    features (optionally multiplied by a per-edge weight) are summed onto
    destination nodes. Degree normalisation follows the ``norm`` option; there
    is no learnable weight or bias.

    Args:
        in_feats: Input feature size. Only used by ``extra_repr``.
        norm: One of ``"none"``, ``"both"``, ``"right"``, ``"left"``.
            ``"left"``/``"both"`` scale source features by their out-degree
            before aggregation; ``"right"``/``"both"`` scale the aggregated
            result by destination in-degree. The factor is ``d ** -0.5`` for
            ``"both"`` and ``1 / d`` otherwise, with degrees clamped to >= 1.
        activation: Optional callable applied to the output.
        allow_zero_in_degree: When False, ``forward`` raises if any node has
            in-degree 0.

    Raises:
        DGLError: If ``norm`` is not one of the accepted values.
    """

    def __init__(
        self,
        in_feats,
        norm="both",
        activation=None,
        allow_zero_in_degree=False,
    ):
        super().__init__()
        accepted = ("none", "both", "right", "left")
        if norm not in accepted:
            raise DGLError(
                'Invalid norm value. Must be either "none", "both", "right" or "left".'
                ' But got "{}".'.format(norm)
            )
        self._in_feats = in_feats
        self._norm = norm
        self._allow_zero_in_degree = allow_zero_in_degree
        self._activation = activation

    def set_allow_zero_in_degree(self, set_value):
        """Toggle the zero-in-degree check performed in ``forward``."""
        self._allow_zero_in_degree = set_value

    def _scale_by_degree(self, values, degs, like):
        """Multiply ``values`` along dim 0 by a degree-derived factor.

        ``degs`` is cast to ``like``'s dtype/device and clamped to at least 1.
        The factor is ``degs ** -0.5`` when ``norm == "both"`` and ``1 / degs``
        otherwise. It is reshaped with trailing singleton dims so that it
        broadcasts over the remaining ``like.dim() - 1`` dimensions.
        """
        degs = degs.to(like).clamp(min=1)
        if self._norm == "both":
            factor = torch.pow(degs, -0.5)
        else:
            factor = 1.0 / degs
        factor = torch.reshape(factor, factor.shape + (1,) * (like.dim() - 1))
        return values * factor

    def forward(self, graph, feat, weight=None, edge_weight=None):
        """Aggregate source features onto destination nodes.

        Args:
            graph: DGL graph (bipartite graphs are supported via
                ``expand_as_pair``).
            feat: Node features passed to ``expand_as_pair`` to obtain the
                source and destination feature tensors.
            weight: Unused by this module.
            edge_weight: Optional per-edge tensor whose first dimension must
                equal ``graph.num_edges()``. When given, each message is
                ``h_src * edge_weight`` instead of ``h_src``.

        Returns:
            Aggregated destination-node features, normalised according to
            ``norm`` and passed through ``activation`` if one was set.

        Raises:
            DGLError: If zero in-degree nodes exist and the check is enabled.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree and (graph.in_degrees() == 0).any():
                raise DGLError(
                    "There are 0-in-degree nodes in the graph, "
                    "output for those nodes will be invalid. "
                    "This is harmful for some applications, "
                    "causing silent performance regression. "
                    "Adding self-loop on the input graph by "
                    "calling `g = dgl.add_self_loop(g)` will resolve "
                    "the issue. Setting ``allow_zero_in_degree`` "
                    "to be `True` when constructing this module will "
                    "suppress the check and let the code run."
                )

            if edge_weight is None:
                message_fn = fn.copy_u("h", "m")
            else:
                assert edge_weight.shape[0] == graph.num_edges()
                graph.edata["_edge_weight"] = edge_weight
                message_fn = fn.u_mul_e("h", "_edge_weight", "m")

            # Split features so bipartite graphs (e.g. per-relation blocks in
            # RGCN) get separate source / destination tensors.
            feat_src, feat_dst = expand_as_pair(feat, graph)

            if self._norm in ("left", "both"):
                feat_src = self._scale_by_degree(
                    feat_src, graph.out_degrees(), feat_src
                )

            graph.srcdata["h"] = feat_src
            graph.update_all(message_fn, fn.sum(msg="m", out="h"))
            result = graph.dstdata["h"]

            if self._norm in ("right", "both"):
                # The factor's dtype/device/rank follow feat_dst, not result.
                result = self._scale_by_degree(result, graph.in_degrees(), feat_dst)

            if self._activation is not None:
                result = self._activation(result)
            return result

    def extra_repr(self):
        """Summarise configuration for ``repr(module)``."""
        pieces = [f"in={self._in_feats}", f"normalization={self._norm}"]
        # A Module-valued activation lives in _modules, not __dict__, and is
        # already shown by nn.Module's own repr.
        if "_activation" in self.__dict__:
            pieces.append(f"activation={self._activation}")
        return ", ".join(pieces)



class ChebnetIIProp(nn.Module):
    """Propagation module holding ChebNetII-style filter values.

    ``temp`` stores ``K + 1`` learnable values, one per Chebyshev node
    ``x_j = cos((K - j + 0.5) * pi / (K + 1))``.

    NOTE(review): ``forward`` does not use ``temp``. Its output is ``K`` rounds
    of mean neighbour aggregation, so ``temp`` receives no gradient from this
    module. Confirm whether a Chebyshev filter was meant to be applied here.

    Args:
        K: Number of propagation steps (and polynomial order; ``temp`` has
            ``K + 1`` entries).
        Init: If True, initialise ``temp[j] = x_j ** 2``.
        random_init: If True, draw ``temp`` uniformly from [-1, 1] using
            NumPy's global RNG. This takes precedence over ``Init``.
    """

    def __init__(self, K, Init=False, random_init=False):
        super(ChebnetIIProp, self).__init__()
        self.K = K
        self.temp = nn.Parameter(torch.Tensor(self.K + 1))
        self.Init = Init
        self.random_init = random_init
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``temp`` in place.

        The existing Parameter object is kept, so optimizer references, dtype
        and device all survive a reset.
        """
        with torch.no_grad():
            self.temp.fill_(1.0)
            if self.Init:
                for j in range(self.K + 1):
                    x_j = math.cos((self.K - j + 0.5) * math.pi / (self.K + 1))
                    self.temp[j] = x_j ** 2
            if self.random_init:
                # Copy into the existing parameter instead of rebinding
                # self.temp. Rebinding would create a new float64 CPU
                # Parameter that an already-built optimizer does not track.
                values = np.random.uniform(-1.0, 1.0, self.K + 1)
                self.temp.copy_(torch.from_numpy(values))

    def forward(self, graph, x):
        """Apply ``K`` rounds of mean aggregation over incoming edges.

        Args:
            graph: DGL graph whose node data slot ``'h'`` is used as scratch
                space inside a local scope.
            x: Node feature tensor.

        Returns:
            Node features after ``K`` ``copy_u`` / ``mean`` update rounds.
        """
        with graph.local_scope():
            graph.ndata['h'] = x
            for _ in range(self.K):
                graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h'))
            return graph.ndata['h']


class ChebnetIIProp_SP(nn.Module):
    """ChebNetII-style polynomial filter with ``GraphConv`` as the operator.

    ``temp`` holds ``K + 1`` learnable filter values (passed through ReLU) at
    the Chebyshev nodes ``x_j = cos((K - j + 0.5) * pi / (K + 1))``. They are
    turned into Chebyshev-series coefficients by interpolation. The filter is
    then applied with the recurrence ``T_i = 2 P T_{i-1} - T_{i-2}``, where
    ``P`` is a parameter-free ``GraphConv`` (no weight, no bias, default
    normalisation).

    NOTE(review): the ChebNetII formulation presumably runs this recurrence with
    the shifted Laplacian ``L - I = -D^{-1/2} A D^{-1/2}``. Using the positive
    normalised adjacency here evaluates the filter at ``-x``, because
    ``T_i(-x) = (-1)^i T_i(x)``. Confirm the sign convention is intended.

    Args:
        in_feats: Feature size passed to the propagation ``GraphConv``.
        K: Polynomial order; ``temp`` has ``K + 1`` entries.
        Init: If True, initialise ``temp[j] = x_j ** 2``.
        random_init: If True, draw ``temp`` uniformly from [-1, 1] using
            NumPy's global RNG. This takes precedence over ``Init``.
    """

    def __init__(self, in_feats, K, Init=False, random_init=False):
        super(ChebnetIIProp_SP, self).__init__()
        self.K = K
        self.temp = nn.Parameter(torch.Tensor(self.K + 1))
        self.Init = Init
        self.random_init = random_init
        self.propagate = GraphConv(in_feats=in_feats, out_feats=in_feats, weight=False, bias=False, activation=None)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)initialise ``temp`` in place.

        The existing Parameter object is kept, so optimizer references, dtype
        and device all survive a reset.
        """
        with torch.no_grad():
            self.temp.fill_(1.0)
            if self.Init:
                for j in range(self.K + 1):
                    x_j = math.cos((self.K - j + 0.5) * math.pi / (self.K + 1))
                    self.temp[j] = x_j ** 2
            if self.random_init:
                # Copy into the existing parameter instead of rebinding
                # self.temp. Rebinding would create a new float64 CPU
                # Parameter that an already-built optimizer does not track.
                values = np.random.uniform(-1.0, 1.0, self.K + 1)
                self.temp.copy_(torch.from_numpy(values))

    def _chebyshev_coefficients(self):
        """Return the ``K + 1`` Chebyshev-series coefficients of the filter.

        ``coe[i] = 2 / (K + 1) * sum_j relu(temp[j]) * T_i(x_j)``. Because
        ``x_j = cos(theta_j)`` with ``theta_j`` in (0, pi), the identity
        ``T_i(cos t) = cos(i t)`` gives the polynomials in closed form.
        """
        gamma = F.relu(self.temp)
        k = self.K
        order = torch.arange(k + 1, dtype=gamma.dtype, device=gamma.device)
        # theta[j] = arccos(x_j) for the j-th Chebyshev node.
        theta = (k - order + 0.5) * (math.pi / (k + 1))
        # basis[i, j] = T_i(x_j)
        basis = torch.cos(order.unsqueeze(1) * theta.unsqueeze(0))
        return (2.0 / (k + 1)) * (basis @ gamma)

    def forward(self, graph, x):
        """Apply the learned Chebyshev polynomial filter to ``x``.

        Args:
            graph: DGL graph forwarded to the propagation ``GraphConv``.
            x: Node feature tensor.

        Returns:
            ``coe[0] / 2 * T_0 x + sum_{i >= 1} coe[i] * T_i x``.
        """
        coe = self._chebyshev_coefficients()
        tx_prev = x
        out = coe[0] / 2 * tx_prev
        if self.K == 0:
            # Order-0 filter: no propagation (and no coe[1]) needed.
            return out
        tx_curr = self.propagate(graph, x)
        out = out + coe[1] * tx_curr
        for i in range(2, self.K + 1):
            tx_next = 2 * self.propagate(graph, tx_curr) - tx_prev
            out = out + coe[i] * tx_next
            tx_prev, tx_curr = tx_curr, tx_next
        return out