"""
Graph Representation Learning via Hard and Channel-Wise Attention Networks:
hard graph attention layers implemented in DGL (trained with the Adam optimizer).

References
----------
Paper: https://arxiv.org/abs/1907.04652
"""

from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.nn.pytorch import edge_softmax
from dgl.nn.pytorch.utils import Identity
from dgl.sampling import select_topk


class HardGAO(nn.Module):
    """Hard graph attention operator (hGAO).

    Every node receives a scalar importance score from a learnable projection
    vector ``p``. Each edge inherits the score of its source node, and
    ``dgl.sampling.select_topk`` keeps at most ``k`` edges per node ranked by
    that score. Node features are gated by ``sigmoid(score)``. Multi-head
    additive (GAT-style) attention is then computed on the resulting sparse
    subgraph.

    Parameters
    ----------
    in_feats : int
        Size of each input node feature.
    out_feats : int
        Size of each output feature per attention head.
    num_heads : int
        Number of attention heads.
    feat_drop : float
        Dropout rate applied to the gated input features.
    attn_drop : float
        Dropout rate applied to the normalized attention weights.
    negative_slope : float
        Negative slope of the LeakyReLU applied to raw attention logits.
    residual : bool
        If True, add a residual connection after the activation.
    activation : callable or None
        Applied to the aggregated output (before the residual is added).
    k : int
        Maximum number of edges kept per node by top-k selection.
    allow_zero_in_degree : bool
        If False (default), ``forward`` raises ``DGLError`` when the input
        graph has nodes with zero in-degree.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads=8,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=True,
        activation=F.elu,
        k=8,
        allow_zero_in_degree=False,
    ):
        super(HardGAO, self).__init__()
        self.num_heads = num_heads
        self.in_feats = in_feats
        self.out_feats = out_feats
        self.k = k
        self.residual = residual
        self.allow_zero_in_degree = allow_zero_in_degree
        # Parameters for the additive (GAT-style) attention.
        self.fc = nn.Linear(
            self.in_feats, self.out_feats * self.num_heads, bias=False
        )
        self.attn_l = nn.Parameter(
            torch.empty(1, self.num_heads, self.out_feats)
        )
        self.attn_r = nn.Parameter(
            torch.empty(1, self.num_heads, self.out_feats)
        )
        # Projection vector p used to score node importance for hard selection.
        self.p = nn.Parameter(torch.empty(1, in_feats))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if self.residual:
            if self.in_feats == self.out_feats:
                # Input reshapes to (N, 1, out_feats) and broadcasts over
                # heads, so no learnable projection is needed.
                self.residual_module = nn.Identity()
            else:
                self.residual_module = nn.Linear(
                    self.in_feats, self.out_feats * num_heads, bias=False
                )

        self.reset_parameters()
        self.activation = activation

    def reset_parameters(self):
        """Re-initialize all learnable weights with Xavier-normal (ReLU gain)."""
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc.weight, gain=gain)
        nn.init.xavier_normal_(self.p, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        # An identity residual has no weight; only a projection needs init.
        if self.residual and isinstance(self.residual_module, nn.Linear):
            nn.init.xavier_normal_(self.residual_module.weight, gain=gain)

    def forward(self, graph, feat, get_attention=False):
        """Apply hard attention over ``graph``.

        Parameters
        ----------
        graph : DGLGraph
            Input graph. Its node/edge features are left untouched: the
            temporary importance scores live inside ``graph.local_scope()``.
        feat : torch.Tensor
            Node features of shape ``(N, in_feats)``.
        get_attention : bool
            If True, also return the attention weights.

        Returns
        -------
        torch.Tensor
            Node outputs of shape ``(N, num_heads, out_feats)``.
        torch.Tensor, optional
            Attention weights, only when ``get_attention`` is True. These are
            indexed by the edges of the internal top-k subgraph, not by the
            edges of ``graph``.

        Raises
        ------
        DGLError
            If ``graph`` contains zero-in-degree nodes and
            ``allow_zero_in_degree`` is False.
        """
        if not self.allow_zero_in_degree and (graph.in_degrees() == 0).any():
            raise DGLError(
                "There are 0-in-degree nodes in the graph, "
                "output for those nodes will be invalid. "
                "This is harmful for some applications, "
                "causing silent performance regression. "
                "Adding self-loop on the input graph by "
                "calling `g = dgl.add_self_loop(g)` will resolve "
                "the issue. Setting ``allow_zero_in_degree`` "
                "to be `True` when constructing this module will "
                "suppress the check and let the code run."
            )
        # Scope the temporary "y" features so the caller's graph is not
        # modified (and any existing "y" feature is not clobbered).
        with graph.local_scope():
            # Importance score y_i = |p . x_i| / ||p||_2 for every node.
            graph.ndata["y"] = torch.abs(
                torch.matmul(self.p, feat.T).view(-1)
            ) / torch.norm(self.p, p=2)
            # Each edge takes its source node's score so that top-k selection
            # ranks neighbors by importance.
            graph.apply_edges(fn.copy_u("y", "y"))
            # Top-k selection runs on a CPU copy; move the result back.
            subgraph = select_topk(graph.cpu(), self.k, "y").to(graph.device)
        # Sigmoid of the score acts as a soft information gate on features.
        subgraph.ndata["y"] = torch.sigmoid(subgraph.ndata["y"])
        feat = subgraph.ndata["y"].view(-1, 1) * feat
        feat = self.feat_drop(feat)
        h = self.fc(feat).view(-1, self.num_heads, self.out_feats)
        el = (h * self.attn_l).sum(dim=-1).unsqueeze(-1)
        er = (h * self.attn_r).sum(dim=-1).unsqueeze(-1)
        subgraph.srcdata.update({"ft": h, "el": el})
        subgraph.dstdata.update({"er": er})
        # Edge logits e_ij = a_l . W h_i + a_r . W h_j.
        subgraph.apply_edges(fn.u_add_v("el", "er", "e"))
        e = self.leaky_relu(subgraph.edata.pop("e"))
        # Normalize over each node's selected in-edges.
        subgraph.edata["a"] = self.attn_drop(edge_softmax(subgraph, e))
        # Attention-weighted sum of the selected neighbors' projected features.
        subgraph.update_all(fn.u_mul_e("ft", "a", "m"), fn.sum("m", "ft"))
        rst = subgraph.dstdata["ft"]
        if self.activation:
            rst = self.activation(rst)
        if self.residual:
            rst = rst + self.residual_module(feat).view(
                feat.shape[0], -1, self.out_feats
            )

        if get_attention:
            return rst, subgraph.edata["a"]
        return rst


class HardGAT(nn.Module):
    """Multi-layer network built from ``HardGAO`` layers.

    Layout: one input layer, ``num_layers - 1`` hidden layers, and one output
    layer. The first ``num_layers`` layers have their heads concatenated, so
    the next layer's input width is ``num_hidden * heads[i]``. The output
    layer's heads are averaged to produce the logits.

    Parameters
    ----------
    g : DGLGraph
        Graph passed to every layer on each forward call.
    num_layers : int
        Number of layers applied before the output layer.
    in_dim : int
        Input feature size.
    num_hidden : int
        Per-head hidden size.
    num_classes : int
        Per-head output size of the final layer.
    heads : sequence of int
        Heads per layer. ``heads[-1]`` is used by the output layer, and
        ``heads[-2]`` determines the output layer's input width.
    activation : callable
        Activation for every layer except the output layer.
    feat_drop, attn_drop, negative_slope : float
        Forwarded unchanged to every ``HardGAO`` layer.
    residual : bool
        Residual connections for hidden layers only. The input and output
        layers never use one.
    k : int
        Top-k neighbor budget forwarded to every layer.
    """

    def __init__(
        self,
        g,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        negative_slope,
        residual,
        k,
    ):
        super(HardGAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        build = partial(
            HardGAO,
            feat_drop=feat_drop,
            attn_drop=attn_drop,
            negative_slope=negative_slope,
            k=k,
        )

        def layer_specs():
            # Yields (in_feats, out_feats, num_heads, residual, activation)
            # lazily, so layers are created strictly in input->output order.
            yield in_dim, num_hidden, heads[0], False, self.activation
            for idx in range(1, num_layers):
                # Previous layer's heads are concatenated in forward().
                yield (
                    num_hidden * heads[idx - 1],
                    num_hidden,
                    heads[idx],
                    residual,
                    self.activation,
                )
            yield num_hidden * heads[-2], num_classes, heads[-1], False, None

        for width_in, width_out, n_heads, use_residual, act in layer_specs():
            self.gat_layers.append(
                build(
                    width_in,
                    width_out,
                    num_heads=n_heads,
                    residual=use_residual,
                    activation=act,
                )
            )

    def forward(self, inputs):
        """Map node features ``inputs`` to per-node logits (heads averaged)."""
        h = inputs
        for layer in self.gat_layers[: self.num_layers]:
            # (N, heads, hidden) -> (N, heads * hidden)
            h = layer(self.g, h).flatten(1)
        output_layer = self.gat_layers[-1]
        return output_layer(self.g, h).mean(1)
