# download the `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
# In its essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
#
# GAT in DGL
# ----------
#
# To begin, you can get an overall impression about how a ``GATLayer`` module is
# implemented in DGL. In this section, the four equations above are broken down
# one at a time.

import torch
import torch.nn as nn
import torch.nn.functional as F
from pygsp import graphs, filters, reduction
from sklearn.metrics import f1_score
import numpy as np

# Device used for every module and tensor in this file: the first CUDA GPU when
# one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class GATLayer(nn.Module):
    """Single-head graph attention layer over a fixed DGL graph.

    Implements equations (1)-(4): a shared linear projection of the node
    features, an un-normalized score per edge, a softmax over each node's
    incoming messages, and an attention-weighted sum of the projected
    neighbor features. No output nonlinearity is applied here.

    Args:
        g: graph the layer operates on; it is moved to ``device`` and stored.
        in_dim: size of the input node features.
        out_dim: size of the projected features ``z`` and of the output.
        attn_type: edge-scoring function used by ``forward``:
            ``'gat'`` (default) -- LeakyReLU(a^T [z_src || z_dst]), eq. (2);
            ``'cosine'`` -- cosine similarity scaled by a learned scalar;
            ``'gann'`` -- dot product of separately projected endpoints.

    Raises:
        ValueError: if ``attn_type`` is not one of the names above.
    """

    # Maps the public ``attn_type`` names to the edge UDF method names.
    _EDGE_ATTENTION_FUNCS = {
        'gat': 'edge_attention',
        'cosine': 'edge_attention_cosine',
        'gann': 'edge_attention_GANN',
    }

    def __init__(self, g, in_dim, out_dim, attn_type='gat'):
        super(GATLayer, self).__init__()
        if attn_type not in self._EDGE_ATTENTION_FUNCS:
            raise ValueError(
                f"unknown attn_type {attn_type!r}; expected one of "
                f"{sorted(self._EDGE_ATTENTION_FUNCS)}")
        self.attn_type = attn_type
        self.g = g.to(device)
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False).to(device)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False).to(device)

        # Parameters of the alternative scoring functions. They are placed on
        # ``device`` like the layers above, so that selecting them does not
        # cause a device mismatch.
        self.attn_fc_cosine = nn.Parameter(torch.rand(1).to(device))
        self.attn_fc_GANN_src = nn.Linear(out_dim, 32, bias=True).to(device)
        self.attn_fc_GANN_dst = nn.Linear(out_dim, 32, bias=True).to(device)

    def edge_attention(self, edges):
        """Edge UDF for equation (2): LeakyReLU(a^T [z_src || z_dst]).

        Returns ``{'e': scores}``, one score per edge (``attn_fc`` has a
        single output unit).
        """
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        return {'e': F.leaky_relu(self.attn_fc(z2))}

    def edge_attention_cosine(self, edges):
        """Edge UDF: cosine similarity of the endpoint features ``z``, scaled
        by the learned scalar ``attn_fc_cosine``.

        Returns ``{'e': scores}`` with a trailing singleton dimension. No
        LeakyReLU is applied.
        """
        cos = F.cosine_similarity(edges.src['z'], edges.dst['z'], dim=1, eps=1e-6)
        return {'e': torch.unsqueeze(cos * self.attn_fc_cosine, dim=1)}

    def edge_attention_GANN(self, edges):
        """Edge UDF: dot product of separately projected endpoint features.

        Source features go through ``attn_fc_GANN_src`` and destination
        features through ``attn_fc_GANN_dst``. Returns ``{'e': scores}``
        with a trailing singleton dimension.
        """
        src = self.attn_fc_GANN_src(edges.src['z'])
        # Each endpoint uses its own projection, so score(i, j) need not
        # equal score(j, i).
        dst = self.attn_fc_GANN_dst(edges.dst['z'])
        return {'e': torch.sum(src * dst, dim=1, keepdim=True)}

    def message_func(self, edges):
        """Message UDF for equations (3) & (4): send the source embedding
        ``z`` and the edge score ``e``."""
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        """Reduce UDF for equations (3) & (4).

        Softmax-normalizes the received scores along dim 1, where the
        messages are batched, and returns the score-weighted sum of the
        received embeddings as ``'h'``. The sigma of equation (4) is not
        applied here.
        """
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        """Apply the layer to node features ``h`` (moved to ``device``) and
        return the aggregated features, ``out_dim`` wide."""
        # equation (1)
        z = self.fc(h.to(device))
        self.g.ndata['z'] = z
        # equation (2), using the scoring function chosen by ``attn_type``
        self.g.apply_edges(getattr(self, self._EDGE_ATTENTION_FUNCS[self.attn_type]))
        # equation (3) & (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')


##################################################################
# Equation (1)
# ^^^^^^^^^^^^
#
# .. math::
#
#   z_i^{(l)}=W^{(l)}h_i^{(l)},\qquad(1)
#
# Equation (1) is a linear transformation of the lower-layer embedding
# :math:`h_i^{(l)}` by a learnable weight matrix :math:`W^{(l)}`. It is common
# and can be easily implemented in PyTorch using ``torch.nn.Linear``.
#
# Equation (2)
# ^^^^^^^^^^^^
#
# .. math::
#
#   e_{ij}^{(l)}=\text{LeakyReLU}(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})),\qquad(2)
#
# The un-normalized attention score :math:`e_{ij}` is calculated using the
# embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
# attention scores can be viewed as edge data, which can be calculated by the
# ``apply_edges`` API. The argument to the ``apply_edges`` is an **Edge UDF**,
# which is defined as below:

def edge_attention(self, edges):
    """Edge UDF for equation (2).

    Joins each edge's source and destination embeddings ``z`` along the
    feature dimension, scores the pair with ``self.attn_fc``, and applies
    LeakyReLU. The un-normalized scores are returned under the key ``'e'``.
    """
    src_z, dst_z = edges.src['z'], edges.dst['z']
    logits = self.attn_fc(torch.cat((src_z, dst_z), dim=1))
    activated = F.leaky_relu(logits)
    return {'e': activated}


########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec a^{(l)}`
# is implemented again using PyTorch's linear transformation ``attn_fc``. Note
# that ``apply_edges`` will **batch** all the edge data in one tensor, so the
# ``cat``, ``attn_fc`` here are applied on all the edges in parallel.
#
# Equation (3) & (4)
# ^^^^^^^^^^^^^^^^^^
#
# .. math::
#
#   \begin{align}
#   \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}^{}\exp(e_{ik}^{(l)})},&(3)\\
#   h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)} {\alpha^{(l)}_{ij} z^{(l)}_j }\right),&(4)
#   \end{align}
#
# Similar to GCN, ``update_all`` API is used to trigger message passing on all
# the nodes. The message function sends out two tensors: the transformed ``z``
# embedding of the source node and the un-normalized attention score ``e`` on
# each edge. The reduce function then performs two tasks:
#
#
# * Normalize the attention scores using softmax (equation (3)).
# * Aggregate neighbor embeddings weighted by the attention scores (equation (4)).
#
# Both tasks first fetch data from the mailbox and then manipulate it on the
# second dimension (``dim=1``), on which the messages are batched.

def reduce_func(self, nodes):
    """Reduce UDF for equations (3) and (4).

    Normalizes the received attention scores ``mailbox['e']`` with a softmax
    along dim 1 (the axis on which the messages are batched). It then returns
    the score-weighted sum of the neighbor embeddings ``mailbox['z']`` under
    the key ``'h'``.
    """
    mailbox = nodes.mailbox
    weights = F.softmax(mailbox['e'], dim=1)          # equation (3)
    aggregated = (weights * mailbox['z']).sum(dim=1)  # equation (4)
    return {'h': aggregated}


#####################################################################
# Multi-head attention
# ^^^^^^^^^^^^^^^^^^^^
#
# Analogous to multiple channels in ConvNet, GAT introduces **multi-head
# attention** to enrich the model capacity and to stabilize the learning
# process. Each attention head has its own parameters and their outputs can be
# merged in two ways:
#
# .. math:: \text{concatenation}: h^{(l+1)}_{i} =||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# or
#
# .. math:: \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# where :math:`K` is the number of heads. You can use
# concatenation for intermediary layers and average for the final layer.
#
# Use the above defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:

class MultiHeadGATLayer(nn.Module):
    """Run ``num_heads`` independent single-head ``GATLayer``s and merge them.

    Args:
        g: graph passed to every head.
        in_dim: size of the input node features.
        out_dim: output size of each head.
        num_heads: number of attention heads, each with its own parameters.
        merge: ``'cat'`` (default) concatenates the head outputs along dim 1,
            giving ``out_dim * num_heads`` features. Any other value averages
            the heads element-wise, giving ``out_dim`` features.
    """

    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList(
            [GATLayer(g, in_dim, out_dim) for _ in range(num_heads)])
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concat on the output feature dimension (dim=1)
            return torch.cat(head_outs, dim=1)
        # Average across heads only (dim 0 of the stacked tensor). Without
        # ``dim``, torch.mean would collapse the result to a single scalar.
        return torch.mean(torch.stack(head_outs), dim=0)


###########################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.

class GAT(nn.Module):
    """Two-layer GAT.

    The model is a multi-head hidden layer with its heads concatenated, then
    an ELU, then a single-head output layer.

    Args:
        g: graph shared by both layers.
        in_dim: size of the input node features.
        hidden_dim: per-head output size of the first layer.
        out_dim: output size of the second (final) layer.
        num_heads: number of attention heads in the first layer.
    """

    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        # The first layer concatenates its heads, so its output (and hence the
        # second layer's input) is hidden_dim * num_heads wide.
        concat_dim = hidden_dim * num_heads
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # The output layer uses a single attention head.
        self.layer2 = MultiHeadGATLayer(g, concat_dim, out_dim, 1)

    def forward(self, h):
        hidden = F.elu(self.layer1(h).to(device)).to(device)
        return self.layer2(hidden).to(device)


def evaluate(model, g, features, labels, mask):
    """Return the micro-averaged F1 score of ``model`` on the masked nodes.

    Switches ``model`` to eval mode (and leaves it there), moves it to
    ``device``, and runs it without gradient tracking. For single-label
    multi-class predictions, micro-F1 equals accuracy.

    Args:
        model: called as ``model(features)``; its output is indexed by
            ``mask`` and arg-maxed over dim 1 to get class predictions.
        g: unused; kept so existing call sites keep working.
        features: input passed to ``model``.
        labels: ground-truth class per node, indexed by ``mask``.
        mask: selects which nodes are scored (presumably a boolean node
            mask -- confirm at call sites).
    """
    model.eval().to(device)
    with torch.no_grad():
        logits = model(features)[mask]
        labels = labels[mask].to(device)
        _, predicted = torch.max(logits, dim=1)
        # sklearn's signature is f1_score(y_true, y_pred, ...). Micro-F1 is
        # symmetric here, so this ordering fixes readability, not the value.
        return f1_score(labels.cpu().numpy(), predicted.cpu().numpy(),
                        average='micro')

# def evaluate(model, g, features, labels, mask):
#    model.eval().to(device)
#    with torch.no_grad():
#        logits = model(features)
#        labels = labels[mask].to(device)
#        predict = np.where(logits.data.cpu().numpy() >= 0., 1, 0)
#        score = f1_score(labels.data.cpu().numpy(),
#                         predict, average='micro')
#
#        return score
