"""
This code was copied from the GCN implementation in DGL examples.
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch import SGConv
from torch.nn import init

from dgl import function as fn
from dgl.base import DGLError
from dgl.utils import expand_as_pair
    
    

class ChebConv(nn.Module):
    """Graph propagation layer with symmetric degree normalization.

    The layer performs a single propagation step: source features are scaled
    by ``out_degree ** -0.5``, summed over the incoming edges of every
    destination node, and the sums are scaled by ``in_degree ** -0.5``
    (degrees are clamped to a minimum of 1 before the power is taken).

    ``weight`` and ``bias`` are registered as ``None``, so the layer holds no
    learnable parameters of its own.  ``K`` and ``norm`` are stored and shown
    by :meth:`extra_repr`, but :meth:`forward` does not read them.

    Parameters
    ----------
    K : int
        Stored as ``_K``; only reported in the module representation.
    norm : str, optional
        Stored as ``_norm``; only reported in the module representation.
        The normalization described above is applied regardless of its value.
    activation : callable, optional
        If not ``None``, applied to the normalized output.
    allow_zero_in_degree : bool, optional
        If ``False`` (default), :meth:`forward` raises ``DGLError`` when the
        graph contains nodes without incoming edges.
    """

    def __init__(self,
                 K,
                 norm='both',
                 activation=None,
                 allow_zero_in_degree=False):
        super().__init__()
        self._K = K
        self._norm = norm
        self._allow_zero_in_degree = allow_zero_in_degree

        self.register_parameter('weight', None)
        self.register_parameter('bias', None)
        self._activation = activation

    def set_allow_zero_in_degree(self, set_value):
        """Enable or disable the zero-in-degree check done by :meth:`forward`.

        Parameters
        ----------
        set_value : bool
            New value of the ``allow_zero_in_degree`` flag.
        """
        self._allow_zero_in_degree = set_value

    @staticmethod
    def _inv_sqrt_degree(degs, feat_ndim):
        """Return ``clamp(degs, min=1) ** -0.5`` shaped to broadcast over features.

        The result has ``feat_ndim`` dimensions: the node dimension followed
        by singleton dimensions.
        """
        norm = torch.pow(degs.float().clamp(min=1), -0.5)
        return torch.reshape(norm, norm.shape + (1,) * (feat_ndim - 1))

    def propagate(self, graph, feat):
        """Sum the source features ``feat`` over the incoming edges of each node.

        ``feat`` is written to ``graph.srcdata['h']`` and the aggregated
        ``graph.dstdata['h']`` is returned.
        """
        graph.srcdata['h'] = feat
        # ``copy_u`` is the supported name of the builtin formerly exposed as
        # ``copy_src``; the old alias is deprecated and is not available in
        # newer DGL releases.
        graph.update_all(fn.copy_u('h', 'm'), fn.sum(msg='m', out='h'))
        return graph.dstdata['h']

    def forward(self, graph, feat, weight=None, edge_weight=None):
        """Propagate ``feat`` over ``graph`` with symmetric degree normalization.

        Parameters
        ----------
        graph : DGLGraph
            Graph (or block) to propagate over.
        feat : torch.Tensor or pair of torch.Tensor
            Node features; split into source and destination features with
            ``expand_as_pair``.
        weight : optional
            Unused by this implementation.
        edge_weight : optional
            Unused by this implementation.

        Returns
        -------
        torch.Tensor
            Normalized, aggregated features of the destination nodes, passed
            through the activation if one was given.

        Raises
        ------
        DGLError
            If the graph has zero-in-degree nodes and the check is enabled.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree and (graph.in_degrees() == 0).any():
                raise DGLError('There are 0-in-degree nodes in the graph, '
                               'output for those nodes will be invalid. '
                               'This is harmful for some applications, '
                               'causing silent performance regression. '
                               'Adding self-loop on the input graph by '
                               'calling `g = dgl.add_self_loop(g)` will resolve '
                               'the issue. Setting ``allow_zero_in_degree`` '
                               'to be `True` when constructing this module will '
                               'suppress the check and let the code run.')
            feat_src, feat_dst = expand_as_pair(feat, graph)

            # Left normalization on the source side, one propagation step,
            # then right normalization on the destination side.
            feat_src = feat_src * self._inv_sqrt_degree(graph.out_degrees(),
                                                        feat_src.dim())
            rst = self.propagate(graph, feat_src)
            rst = rst * self._inv_sqrt_degree(graph.in_degrees(), feat_dst.dim())

            if self._activation is not None:
                rst = self._activation(rst)
            return rst

    def extra_repr(self):
        """Return the ``K``/``norm`` summary used in the module representation.

        The activation is included only when it is stored as a plain
        attribute; an ``nn.Module`` activation is registered as a submodule
        and therefore does not appear in ``__dict__``.
        """
        summary = 'K={_K}, normalization={_norm}'
        if '_activation' in self.__dict__:
            summary += ', activation={_activation}'
        return summary.format(**self.__dict__)
    
    
class ChebNetII(nn.Module):
    """Stack of ``ChebConv`` layers with residual connections.

    The stack always contains an input and an output layer, plus
    ``n_layers - 2`` hidden layers when ``n_layers > 2``.  Every layer except
    the last has a ``Linear(in_feats, in_feats)`` residual branch; the output
    layer uses an identity residual.

    Parameters
    ----------
    g : object
        Stored as ``self.g``; not used by :meth:`forward`.
    in_feats : int
        Feature size used for the residual linears and the batch norms.
    K : int
        Forwarded to every ``ChebConv``.
    n_layers : int
        Requested number of layers (see above for the effective count).
    activation : callable or None
        Forwarded to every ``ChebConv``; the same object is shared by all
        layers.
    dropout : float
        Dropout probability applied to the input and to every layer output.
    bias : bool, optional
        Unused.
    weight : bool, optional
        Unused.

    Notes
    -----
    ``self.bns`` holds one ``BatchNorm1d`` per non-output layer, but
    :meth:`forward` does not apply them.
    """

    def __init__(self,
                 g,
                 in_feats,
                 K,
                 n_layers,
                 activation,
                 dropout,
                 bias=False,
                 weight=False):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.res_linears = nn.ModuleList()

        def add_residual_layer():
            # One propagation layer with its batch norm and linear residual.
            self.layers.append(ChebConv(K=K, activation=activation))
            self.bns.append(nn.BatchNorm1d(in_feats, momentum=0.01))
            self.res_linears.append(nn.Linear(in_feats, in_feats))

        # input layer
        add_residual_layer()

        # hidden layers
        for _ in range(1, n_layers - 1):
            add_residual_layer()

        # output layer
        self.layers.append(ChebConv(K=K, activation=activation))
        self.res_linears.append(nn.Identity())
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, blocks):
        """Run the layer stack over ``blocks`` and return the output-node features.

        Parameters
        ----------
        blocks : sequence
            One graph block per layer (presumably DGL message-flow blocks from
            a neighbor sampler -- confirm against the caller).  The input
            features are read from ``blocks[0].srcdata['feat']``.

        Returns
        -------
        torch.Tensor
            The first ``blocks[-1].num_dst_nodes()`` rows of the last layer's
            output, residual included.
        """
        h = self.dropout(blocks[0].srcdata['feat'])
        num_output_nodes = blocks[-1].num_dst_nodes()
        for layer, res_linear, block in zip(self.layers, self.res_linears, blocks):
            # The residual branch uses the destination-node rows of the input.
            residual_in = h[:block.num_dst_nodes()]
            h = self.dropout(layer(block, h))
            # Out-of-place add: an in-place ``+=`` would overwrite the layer
            # output, which autograd may still need for the backward pass
            # (e.g. ReLU saves its output), and would only reach the returned
            # tensor through view aliasing.
            h = h + res_linear(residual_in)
        return h[:num_output_nodes]
        
    
class GCN(nn.Module):
    """Stack of ``GraphConv`` layers with residual connections.

    The stack always contains an input layer (``in_feats -> n_hidden``) and an
    output layer (``n_hidden -> n_classes``), plus ``n_layers - 2`` hidden
    layers (``n_hidden -> n_hidden``) when ``n_layers > 2``.  Each layer has a
    residual branch fed with the destination-node rows of the layer input.

    Parameters
    ----------
    g : object
        Stored as ``self.g``; not used by :meth:`forward`.
    in_feats : int
        Input feature size.
    n_hidden : int
        Hidden feature size.
    n_classes : int
        Output feature size.
    n_layers : int
        Requested number of layers (see above for the effective count).
    activation : callable or None
        Activation of the input and hidden layers; the same object is shared
        by those layers.  The output layer has no activation.
    dropout : float
        Dropout probability applied to the input and to every layer output.
    bias : bool, optional
        Forwarded to the input and hidden ``GraphConv`` layers.
    weight : bool, optional
        Forwarded to the input and hidden ``GraphConv`` layers.

    Notes
    -----
    ``self.bns`` holds one ``BatchNorm1d`` per non-output layer, but
    :meth:`forward` does not apply them.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 bias=True,
                 weight=True):
        super().__init__()
        self.g = g
        self.layers = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.res_linears = nn.ModuleList()

        # input layer
        self.layers.append(GraphConv(in_feats, n_hidden, weight=weight,
                                     bias=bias, activation=activation))
        self.bns.append(nn.BatchNorm1d(n_hidden, momentum=0.01))
        self.res_linears.append(nn.Linear(in_feats, n_hidden))

        # hidden layers
        for _ in range(1, n_layers - 1):
            self.layers.append(GraphConv(n_hidden, n_hidden, weight=weight,
                                         bias=bias, activation=activation))
            self.bns.append(nn.BatchNorm1d(n_hidden, momentum=0.01))
            self.res_linears.append(nn.Linear(n_hidden, n_hidden))

        # output layer
        self.layers.append(GraphConv(n_hidden, n_classes))
        if n_hidden == n_classes:
            self.res_linears.append(nn.Identity())
        else:
            # An identity residual cannot be added to the output when the
            # feature sizes differ; project it, as the input layer does.
            self.res_linears.append(nn.Linear(n_hidden, n_classes))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, blocks):
        """Run the layer stack over ``blocks`` and return the output-node features.

        Parameters
        ----------
        blocks : sequence
            One graph block per layer (presumably DGL message-flow blocks from
            a neighbor sampler -- confirm against the caller).  The input
            features are read from ``blocks[0].srcdata['feat']``.

        Returns
        -------
        torch.Tensor
            The first ``blocks[-1].num_dst_nodes()`` rows of the last layer's
            output, residual included.
        """
        h = self.dropout(blocks[0].srcdata['feat'])
        num_output_nodes = blocks[-1].num_dst_nodes()
        for layer, res_linear, block in zip(self.layers, self.res_linears, blocks):
            # The residual branch uses the destination-node rows of the input.
            residual_in = h[:block.num_dst_nodes()]
            h = self.dropout(layer(block, h))
            # Out-of-place add: an in-place ``+=`` would overwrite the layer
            # output, which autograd may still need for the backward pass
            # (e.g. ReLU saves its output), and would only reach the returned
            # tensor through view aliasing.
            h = h + res_linear(residual_in)
        return h[:num_output_nodes]

