import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree



class CCLayer(MessagePassing):
    """Sum-aggregation message-passing layer with a distance-based edge gate.

    For every edge ``(i, j)`` taken from ``data.edge_index`` the layer computes a
    gate::

        g_ij = tanh(1 / (||h_i / s_i - h_j / s_j||^2 + 1e-4) ** gamma)

    where ``s = norm_degree_origin`` (see the note in ``__init__``). It then
    scales the gate by ``d_i^-1/2 * d_j^-1/2``, applies dropout to the per-edge
    weights, and sums ``weight * x_j`` messages via ``propagate``.

    Args:
        data: graph object; only ``data.edge_index`` (unpacked into two rows,
            so presumably shape ``(2, E)``) and ``data.y.shape[0]`` (used as the
            node count) are read.
        num_hidden: hidden width; only used to size the (unused) ``gate``.
        dropout: dropout probability applied to the per-edge weights.
        gammma: exponent ``gamma`` of the distance term (parameter name kept,
            typo included, for caller compatibility).
    """

    def __init__(self, data, num_hidden, dropout, gammma):
        super(CCLayer, self).__init__(aggr='add')
        self.data = data
        self.dropout = nn.Dropout(dropout)
        # NOTE(review): ``gate`` is initialised but never used in ``forward``;
        # kept so existing checkpoints (state_dict keys) still load.
        self.gate = nn.Linear(2 * num_hidden, 1)
        self.gamma = gammma

        edge_index = data.edge_index
        row, col = edge_index
        deg = degree(row, num_nodes=data.y.shape[0]).clamp(min=1)
        norm_degree = torch.pow(deg, -0.5)
        # NOTE(review): this is (deg ** -0.5) ** -0.5 == deg ** 0.25. The name
        # hints that sqrt(deg) may have been intended; kept as-is to preserve
        # trained results — confirm before changing.
        norm_degree_origin = torch.pow(norm_degree, -0.5)

        # Register as non-persistent buffers so the tensors follow the module
        # across .to()/.cuda() without adding keys to state_dict.
        self.register_buffer('prop_edge_index', edge_index, persistent=False)
        self.register_buffer('row', row, persistent=False)
        self.register_buffer('col', col, persistent=False)
        self.register_buffer('norm_degree', norm_degree, persistent=False)
        self.register_buffer('norm_degree_origin', norm_degree_origin, persistent=False)

        nn.init.xavier_normal_(self.gate.weight, gain=1.414)

    def forward(self, h):
        """Propagate node features ``h`` (assumed 2-D ``(N, F)``) one hop."""
        row, col = self.row, self.col
        # (N, 1) so it broadcasts across the feature dimension.
        scale = self.norm_degree_origin.unsqueeze(1)
        result_row = h[row] / scale[row]
        result_col = h[col] / scale[col]
        sq_dist = torch.sum(torch.pow(result_row - result_col, 2), dim=1)
        g = torch.tanh(1 / torch.pow(sq_dist + 1e-4, self.gamma))  # gamma = 2-k
        norm = g * self.norm_degree[row] * self.norm_degree[col]
        norm = self.dropout(norm)
        return self.propagate(self.prop_edge_index, size=(h.size(0), h.size(0)), x=h, norm=norm)

    def message(self, x_j, norm):
        # Scale each neighbour feature row by its edge weight.
        return norm.view(-1, 1) * x_j

    def update(self, aggr_out):
        return aggr_out




class CGCN(nn.Module):
    """Node classifier: input projection, stacked ``CCLayer`` propagation, output head.

    Forward pipeline:
    dropout -> Linear(num_features, num_hidden) -> ReLU -> dropout ->
    for each propagation layer: ``alpha * initial + (1 - alpha) * layer(x)`` ->
    Linear(num_hidden, num_classes).

    Returns a ``(log_probabilities, logits)`` pair, where the log-probabilities
    are ``log_softmax`` over dim 1.
    """

    def __init__(self, data, num_features, num_hidden, num_classes, dropout, alpha, gamma, layer_num=2):
        super(CGCN, self).__init__()
        self.alpha = alpha
        self.layer_num = layer_num
        self.dropout = dropout
        # Built before the linear layers so parameter creation order (and RNG
        # consumption) matches registration order.
        self.layers = nn.ModuleList(
            [CCLayer(data, num_hidden, dropout, gamma) for _ in range(self.layer_num)]
        )
        self.t1 = nn.Linear(num_features, num_hidden)
        self.t2 = nn.Linear(num_hidden, num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-normal (gain 1.414) init of the input and output projection weights."""
        for linear in (self.t1, self.t2):
            nn.init.xavier_normal_(linear.weight, gain=1.414)

    def forward(self, h):
        is_training = self.training
        x = F.dropout(h, p=self.dropout, training=is_training)
        x = F.dropout(torch.relu(self.t1(x)), p=self.dropout, training=is_training)
        initial = x
        residual_weight = 1 - self.alpha
        for layer in self.layers:
            # Blend each propagation step with the pre-propagation embedding.
            x = self.alpha * initial + residual_weight * layer(x)
        logits = self.t2(x)
        return F.log_softmax(logits, 1), logits