# download the `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
# In its essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
#
# GAT in DGL
# ----------
#
# To begin, get an overall impression of how a ``GATLayer`` module is
# implemented in DGL. In this section, the four GAT equations are broken down
# one at a time below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pygsp import graphs, filters, reduction
import numpy as np

from dgl.data import citation_graph as citegrh
from dgl.data import reddit
from dgl.data import gnn_benckmark as gnnbnch
from scipy import sparse, stats
from scipy.sparse import csr_matrix


from dgl import DGLGraph
import networkx as nx
from dgl import  transform

import pandas as pd
import scipy

from numpy import inf
import random

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

class GATLayer(nn.Module):
    """Single-head graph attention layer bound to one fixed graph.

    Implements equations (1)-(4) described in the comments of this file. The
    nonlinearity sigma of equation (4) is NOT applied here; callers apply it
    themselves if needed.

    Args:
        g: the DGL graph to attend over; it is moved to ``device`` and stored.
        in_dim: size of the input node features.
        out_dim: size of the output node features.
    """

    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g.to(device)
        # equation (1): shared linear projection W (no bias)
        self.fc = nn.Linear(in_dim, out_dim, bias=False).to(device)
        # equation (2): attention vector a, applied to [z_src || z_dst]
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False).to(device)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): un-normalized attention score of each edge.

        Returns ``{'e': LeakyReLU(attn_fc([z_src || z_dst]))}``.
        """
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): send source embedding and edge score."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4); messages are batched along dim 1."""
        # equation (3): softmax of the scores over each node's incoming messages
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4): attention-weighted sum of the neighbour embeddings
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        """Run one attention pass over ``self.g`` and return the new node features.

        Args:
            h: node feature tensor with one row per node of ``self.g``
               (presumably ``(num_nodes, in_dim)`` — not checked here).

        Returns:
            The aggregated features of equation (4), without a nonlinearity.
            As a side effect ``'z'`` stays in ``self.g.ndata`` and ``'e'`` in
            ``self.g.edata`` until the next call.
        """
        # Only the input may live elsewhere: fc and g were moved to ``device``
        # in __init__, so every tensor derived from them is already there.
        z = self.fc(h.to(device))
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equations (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')
      
    
  
      

##################################################################
# Equation (1)
# ^^^^^^^^^^^^
#
# .. math::
#
#   z_i^{(l)}=W^{(l)}h_i^{(l)} \qquad (1)
#
# The first equation is a linear transformation. It is common and can be
# easily implemented in PyTorch using ``torch.nn.Linear``.
#
# Equation (2)
# ^^^^^^^^^^^^
#
# .. math::
#
#   e_{ij}^{(l)}=\text{LeakyReLU}\left(\vec a^{(l)^T}\left(z_i^{(l)}\,\|\,z_j^{(l)}\right)\right) \qquad (2)
#
# The un-normalized attention score :math:`e_{ij}` is calculated from the
# embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
# attention scores can be viewed as edge data, which can be computed with the
# ``apply_edges`` API. The argument to ``apply_edges`` is an **edge UDF**
# (user-defined function), defined as follows:

def edge_attention(self, edges):
    """Edge UDF for equation (2): score every edge from its two endpoints.

    The source and destination ``z`` embeddings are concatenated, projected
    to a scalar by ``self.attn_fc`` and passed through LeakyReLU.

    Returns:
        ``{'e': scores}``, one un-normalized score per edge.
    """
    endpoint_pair = torch.cat((edges.src['z'], edges.dst['z']), dim=1)
    raw_scores = self.attn_fc(endpoint_pair)
    return {'e' : F.leaky_relu(raw_scores)}

########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec{a}^{(l)}`
# is implemented again using PyTorch's linear transformation ``attn_fc``. Note
# that ``apply_edges`` will **batch** all the edge data in one tensor, so the
# ``cat`` and ``attn_fc`` here are applied to all the edges in parallel.
#
# Equation (3) & (4)
# ^^^^^^^^^^^^^^^^^^
#
# .. math::
#
#   \begin{align}
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
# As in GCN, the ``update_all`` API is used to trigger message passing on all
# the nodes. The message function sends out two tensors: the transformed ``z``
# embedding of the source node and the un-normalized attention score ``e`` on
# each edge. The reduce function then performs two tasks:
#
#
# * Normalize the attention scores using softmax (equation (3)).
# * Aggregate neighbor embeddings weighted by the attention scores (equation (4)).
#
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.

def reduce_func(self, nodes):
    """Reduce UDF for equations (3) & (4).

    Messages for each node are batched along dim 1 of the mailbox tensors:
    the scores ``'e'`` are softmax-normalized over that dimension and used to
    weight the neighbour embeddings ``'z'``, which are then summed.

    Returns:
        ``{'h': aggregated}``, one aggregated embedding per node.
    """
    mailbox = nodes.mailbox
    # equation (3): normalize each node's incoming scores
    attention = F.softmax(mailbox['e'], dim=1)
    # equation (4): attention-weighted sum over the incoming messages
    aggregated = (attention * mailbox['z']).sum(dim=1)
    return {'h' : aggregated}

#####################################################################
# Multi-head attention
# ^^^^^^^^^^^^^^^^^^^^
#
# Analogous to multiple channels in a ConvNet, GAT introduces **multi-head
# attention** to enrich the model capacity and to stabilize the learning
# process. Each attention head has its own parameters, and their outputs can be
# merged in two ways:
#
# .. math:: \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# or
#
# .. math:: \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# where :math:`K` is the number of heads. Typically, concatenation is used for
# intermediate layers and averaging for the final layer.
#
# Use the above defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:

class MultiHeadGATLayer(nn.Module):
    """Multi-head GAT layer in which every head runs on its own sparsified graph.

    For head ``i`` the input graph ``g`` is randomly sparsified with seed
    ``i`` (see ``graph_sparsify_precompute_reff``) and a single-head
    ``GATLayer`` is built on the result.

    Args:
        g: the input DGL graph. It is only read, never modified.
        in_dim: input feature size.
        out_dim: output feature size of every head.
        num_heads: number of attention heads (= number of sparsified graphs).
        resistance_distances: 2-D array-like indexed as
            ``resistance_distances[rows, cols]`` at the edges of ``g``;
            presumably the effective-resistance matrix of ``g`` — TODO confirm.
        epsilon: sparsification accuracy; the number of sampled edges grows
            as ``1 / epsilon**2``.
        merge: ``'cat'`` concatenates head outputs along the feature
            dimension; any other value averages them across heads.
    """

    def __init__(self, g, in_dim, out_dim, num_heads, resistance_distances, epsilon, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            print("Generating ", str(i), "th sparse graph for the ", str(i), "th head")
            # The head index is the sampling seed: heads get different but
            # reproducible sparsified graphs.
            g_sp = self.graph_sparsify_precompute_reff(g, resistance_distances, epsilon, i)
            self.heads.append(GATLayer(g_sp, in_dim, out_dim))
        self.merge = merge

    def graph_sparsify_precompute_reff(self, g, resistance_distances, epsilon, seed):
        """Return a randomly sparsified copy of ``g`` with a self loop on every node.

        ``g`` itself is not modified (earlier versions added a fresh set of
        self loops to ``g`` on every call). The reweighted edge values are
        not stored as edge features here.

        Args:
            g: the input DGL graph.
            resistance_distances: see the class docstring.
            epsilon: sparsification accuracy.
            seed: seed of the private sampling RNG.

        Returns:
            A new ``DGLGraph``.
        """
        W = g.adjacency_matrix_scipy(return_edge_ids=False)
        sparser_w = self._sparsify_adjacency(W, resistance_distances, epsilon, seed)

        g_sp = DGLGraph()
        g_sp.from_scipy_sparse_matrix(sparser_w)
        # add self loop (to the sparsified graph only)
        g_sp.add_edges(g_sp.nodes(), g_sp.nodes())
        return g_sp

    @staticmethod
    def _sparsify_adjacency(W, resistance_distances, epsilon, seed):
        """Sample a sparse, symmetric reweighting of the adjacency matrix ``W``.

        Every edge of the lower triangle of ``W`` (diagonal included) is drawn
        ``q`` times with replacement with probability ``p_e`` proportional to
        ``w_e * R_e``; a drawn edge is kept with weight
        ``count_e * w_e / (q * p_e)``.

        Args:
            W: ``(N, N)`` scipy sparse adjacency matrix; only its lower
                triangle is read, so it is assumed symmetric.
            resistance_distances: 2-D array-like indexable by ``[rows, cols]``.
            epsilon: accuracy parameter; ``q`` scales with ``1 / epsilon**2``.
            seed: seed for a private ``RandomState``; the global NumPy RNG is
                left untouched.

        Returns:
            A symmetric ``(N, N)`` ``scipy.sparse.csc_matrix`` without
            explicitly stored zeros.

        Raises:
            ValueError: if no edge has a positive sampling probability.
        """
        N = W.shape[0]
        rows, cols, weights = sparse.find(sparse.tril(W))

        # Sampling probability of each edge: (clipped) weight times resistance.
        weights = np.maximum(0, weights)
        # ravel() also flattens the 1 x k np.matrix returned by np.matrix /
        # scipy sparse fancy indexing.
        resist = np.maximum(0, np.asarray(resistance_distances[rows, cols]).ravel())
        probs = weights * resist
        total = np.sum(probs)
        if not total > 0:
            raise ValueError("cannot sparsify: no edge has a positive weight * resistance")
        probs = probs / total

        # Rudelson, 1996 Random Vectors in the Isotropic Position
        # (too hard to figure out actual C0)
        C0 = 1 / 30.
        # Rudelson and Vershynin, 2007, Thm. 3.1
        C = 4 * C0
        q = int(round(N * np.log(N) * 9 * C ** 2 / (epsilon ** 2)))

        # Same draws as np.random.seed(seed); np.random.choice(...), without
        # mutating the global RNG.
        rng = np.random.RandomState(seed)
        samples = rng.choice(probs.size, q, p=probs)
        counts = np.bincount(samples, minlength=probs.size)

        # Importance weight of a single draw. Edges that can never be drawn
        # (p_e == 0), or q == 0, get 0 instead of inf / nan.
        denom = q * probs
        per_draw = np.divide(weights, denom, out=np.zeros(probs.size), where=denom > 0)
        new_weights = counts * per_draw

        sparser_w = sparse.csc_matrix((new_weights, (rows, cols)), shape=(N, N))
        sparser_w = sparser_w + sparser_w.T
        # Edges that were never drawn must not survive as explicit zeros.
        sparser_w.eliminate_zeros()
        return sparser_w

    def forward(self, h):
        """Apply every head to ``h`` and merge their outputs.

        Returns:
            ``'cat'``: head outputs concatenated along dim 1.
            otherwise: element-wise mean over the heads, i.e. the same shape
            as a single head's output.
        """
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        # Average across heads only (dim=0 of the stack); a plain mean over
        # all elements would collapse the output to a scalar.
        return torch.mean(torch.stack(head_outs), dim=0)

###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.

class GAT(nn.Module):
    """Two-layer GAT: multi-head hidden layer, ELU, single-head output layer.

    Both layers are ``MultiHeadGATLayer`` instances built from the same graph
    ``g``; since head ``i`` of each layer uses seed ``i``, head 0 of the output
    layer samples with the same seed as head 0 of the hidden layer.

    Args:
        g: the input DGL graph.
        resistance_distances: forwarded to both ``MultiHeadGATLayer``s.
        epsilon: sparsification accuracy, forwarded to both layers.
        in_dim: input feature size.
        hidden_dim: per-head output size of the hidden layer.
        out_dim: output size of the final layer.
        num_heads: number of heads in the hidden layer.
    """

    def __init__(self, g, resistance_distances,epsilon, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        # The hidden layer concatenates its heads, so the output layer sees
        # hidden_dim * num_heads features; the output layer has one head.
        concat_dim = hidden_dim * num_heads
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads, resistance_distances, epsilon)
        self.layer2 = MultiHeadGATLayer(g, concat_dim, out_dim, 1, resistance_distances, epsilon)

    def forward(self, h):
        """Return the output-layer features (logits) for node features ``h``."""
        hidden = F.elu(self.layer1(h))
        return self.layer2(hidden)


def evaluate(model, g, features, labels, mask):
    """Return the classification accuracy of ``model`` on the nodes in ``mask``.

    The model is switched to eval mode and moved to ``device``; it is not
    switched back to training mode. ``g`` is not used by this function.

    Args:
        model: module mapping ``features`` to per-node logits.
        g: unused.
        features: node features passed straight to ``model``.
        labels: ground-truth class index per node.
        mask: index/boolean mask selecting the nodes to score.

    Returns:
        Fraction (float) of selected nodes whose arg-max logit equals the label.
    """
    model.eval().to(device)
    with torch.no_grad():
        selected_logits = model(features)[mask]
        selected_labels = labels[mask].to(device)
        _, predicted = torch.max(selected_logits, dim=1)
        n_correct = (predicted.to(device) == selected_labels).sum()
        return n_correct.item() * 1.0 / len(selected_labels)
