import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.module import Module
from torch_scatter import scatter_add, scatter
from torch_geometric.utils import add_remaining_self_loops
from script.models.utils import zeros, glorot

class GraphConvolution(Module):
    """
    Simple GCN layer.

    Pipeline: linear transform (no bias) -> optional dropout -> symmetric
    normalized neighbor aggregation (self-loops added via
    ``add_remaining_self_loops``) -> optional bias -> optional activation.

    Args:
        in_features: Size of each input node feature vector.
        out_features: Size of each output node feature vector.
        act: Callable applied to the aggregated output when ``use_act`` is True.
        dropout: Dropout probability applied to the linearly transformed
            features (after the linear layer, before aggregation); values
            ``<= 0`` disable dropout.
        use_bias: Whether to add a learnable bias after aggregation.
        use_act: Whether to apply ``act`` to the output.
    """

    def __init__(self, in_features, out_features, act, dropout=0.6, use_bias=True, use_act=True):
        super().__init__()
        self.dropout = dropout
        # The bias is added after aggregation (see forward), so the linear
        # layer itself is bias-free.
        self.linear = nn.Linear(in_features, out_features, bias=False)
        self.use_bias = use_bias
        self.use_act = use_act
        if self.use_bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            # Register as None so ``self.bias`` always exists; previously the
            # attribute was missing and reset_params() raised AttributeError.
            self.register_parameter('bias', None)
        self.act = act
        self.in_features = in_features
        self.out_features = out_features
        self.reset_params()

    def reset_params(self):
        """Re-initialize the linear weight with ``glorot`` and the bias (if any) with ``zeros``."""
        glorot(self.linear.weight)
        if self.bias is not None:
            zeros(self.bias)

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight=None, improved=False, dtype=None):
        """Add self-loops and compute symmetric GCN edge weights.

        Args:
            edge_index: Integer tensor of shape ``(2, E)``. Degrees are summed
                over row 0, and forward() aggregates into the row-0 node.
            num_nodes: Number of nodes in the graph.
            edge_weight: Optional per-edge weights (one per column of
                ``edge_index``); defaults to all ones.
            improved: If True, added self-loops get weight 2 instead of 1.
            dtype: dtype of the default all-ones ``edge_weight``.

        Returns:
            ``(edge_index, weight)``: the edge index with self-loops added,
            and for each edge ``(i, j)`` the weight
            ``deg[i]^{-1/2} * w_ij * deg[j]^{-1/2}``, where ``deg`` is the
            weighted degree summed over row 0. Zero-degree nodes get a factor
            of 0 instead of ``inf``.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1),), dtype=dtype,
                                     device=edge_index.device)

        fill_value = 2 if improved else 1
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # pow(-0.5) of a zero degree is inf; zero it so isolated nodes contribute nothing.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index):
        """Apply the layer.

        Args:
            x: Node features, 2-D with one row per node and ``in_features``
                columns (2-D is required by the row gather below).
            edge_index: Integer tensor of shape ``(2, E)``; see ``norm``.

        Returns:
            Tensor with one row per node and ``out_features`` columns.
        """
        num_nodes = x.size(0)

        ## Linear Transformation
        # Call the module (not .forward) so registered hooks are honored.
        hidden = self.linear(x)
        if self.dropout > 0:
            hidden = F.dropout(hidden, self.dropout, training=self.training)

        ## Aggregation: h[i] = sum over edges (i, j) of edge_norm_ij * hidden[j]
        edge_index, edge_norm = self.norm(edge_index, num_nodes, dtype=hidden.dtype)
        node_i, node_j = edge_index[0], edge_index[1]
        hidden_j = hidden.index_select(0, node_j)  # gather neighbor features
        support = edge_norm.view(-1, 1) * hidden_j
        h = scatter(support, node_i, dim=0, dim_size=num_nodes)  # aggregate the neighbors of node_i

        if self.use_bias:
            h = h + self.bias
        ## Activation
        return self.act(h) if self.use_act else h

    def extra_repr(self):
        return 'input_dim={}, output_dim={}'.format(
                self.in_features, self.out_features
        )