"""
Graph Attention Networks in DGL, extended so that edge features are concatenated into each message.
References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
PyTorch implementation: https://github.com/Diego999/pyGAT
"""

import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax


class GraphAttention(nn.Module):
    """Multi-head graph attention layer whose messages also carry edge features.

    Attention logits are computed GAT-style from per-node projections
    (``attn_l`` for the source side, ``attn_r`` for the destination side),
    passed through a LeakyReLU and normalised with ``edge_softmax``. Each
    message is the source node's projected features concatenated with the
    edge feature, scaled by the dropped-out attention score. The summed
    messages are mapped back to ``out_dim`` per head by ``update_linear``.

    Args:
        in_dim: Size of each input node feature vector.
        out_dim: Size of each per-head output feature vector.
        num_heads: Number of attention heads.
        feat_drop: Dropout probability for input features; a falsy value
            disables dropout.
        attn_drop: Dropout probability for attention scores; a falsy value
            disables dropout.
        alpha: Negative slope of the LeakyReLU applied to attention logits.
        edge_dim: Size of each edge feature vector read from ``g.edata["e"]``.
        residual: Whether to add a residual connection to the output.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        feat_drop,
        attn_drop,
        alpha,
        edge_dim,
        residual=False,
    ):
        super(GraphAttention, self).__init__()
        self.num_heads = num_heads
        self.fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
        self.out_dim = out_dim
        self.edge_dim = edge_dim
        # nn.Identity (instead of a lambda) keeps the module picklable, e.g.
        # for torch.save(model) or multiprocessing.
        self.feat_drop = self._make_dropout(feat_drop)
        self.attn_drop = self._make_dropout(attn_drop)
        self.attn_l = nn.Parameter(torch.empty(1, num_heads, out_dim))
        self.attn_r = nn.Parameter(torch.empty(1, num_heads, out_dim))
        # Maps the aggregated [node features | edge features] back to out_dim.
        self.update_linear = nn.Linear(out_dim + edge_dim, out_dim, bias=False)

        nn.init.xavier_normal_(self.fc.weight.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_l.data, gain=1.414)
        nn.init.xavier_normal_(self.attn_r.data, gain=1.414)
        nn.init.xavier_normal_(self.update_linear.weight.data, gain=1.414)

        self.leaky_relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.residual = residual
        if residual and in_dim != out_dim:
            # Project the input so it can be added to the (N, H, out_dim) output.
            self.res_fc = nn.Linear(in_dim, num_heads * out_dim, bias=False)
            nn.init.xavier_normal_(self.res_fc.weight.data, gain=1.414)
        else:
            # Identity residual (or no residual); always defined so that
            # attribute access never raises AttributeError.
            self.res_fc = None

    @staticmethod
    def _make_dropout(p):
        """Return ``nn.Dropout(p)`` for a truthy ``p``, else a no-op ``nn.Identity``."""
        return nn.Dropout(p) if p else nn.Identity()

    def message_fn(self, edges):
        """Build each edge's message: ``[src ft | edge er] * a_drop``.

        ``ft`` is (E, H, out_dim) and ``er`` is (E, H, edge_dim); they are
        concatenated on dim 2 and scaled by ``a_drop``, the dropped-out
        attention score computed from (E, H, 1) logits.
        """
        return {
            "ft": torch.cat([edges.src["ft"], edges.data["er"]], dim=2)
            * edges.data["a_drop"]
        }

    def reduce_fn(self, nodes):
        """Sum incoming messages and project them back to ``out_dim`` per head.

        The mailbox is summed over dim 1 (the incoming-message axis). The
        result is flattened to (N * H, out_dim + edge_dim) for
        ``update_linear`` and reshaped to (N, H, out_dim).
        """
        accum = torch.sum(nodes.mailbox["ft"], 1)
        return {
            "ft": self.update_linear(accum.view(-1, self.out_dim + self.edge_dim)).view(
                -1, accum.size(1), self.out_dim
            )
        }

    def forward(self, g, inputs):
        """Apply the attention layer on graph ``g``.

        Args:
            g: DGL graph. ``g.edata["e"]`` must hold the edge features. Because
                of the ``repeat``/``cat`` below, it has to be shaped
                ``(E, 1, edge_dim)`` so that it expands to
                ``(E, num_heads, edge_dim)``.
            inputs: Node features, ``(N, in_dim)``.

        Returns:
            Tensor of shape ``(N, num_heads, out_dim)``. When ``residual`` is
            set, the residual term is added to it.

        Note:
            Temporarily stores ``ft``, ``a1`` and ``a2`` in ``g.ndata`` and
            ``er``, ``a`` and ``a_drop`` in ``g.edata``. All of them are removed
            before returning, so any pre-existing features under those keys
            are lost.
        """
        # prepare
        h = self.feat_drop(inputs)  # NxD
        ft = self.fc(h).reshape((h.shape[0], self.num_heads, -1))  # NxHxD'
        a1 = (ft * self.attn_l).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        a2 = (ft * self.attn_r).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        g.ndata.update({"ft": ft, "a1": a1, "a2": a2})
        # Broadcast the single edge-feature row to every head: E x H x edge_dim.
        er = g.edata["e"].repeat(1, self.num_heads, 1)
        g.edata.update({"er": er})
        # 1. compute edge attention
        g.apply_edges(self.edge_attention)
        # 2. compute softmax
        self.edge_softmax(g)
        # 3. compute the aggregated node features scaled by the dropped,
        # normalized attention values.
        g.update_all(self.message_fn, self.reduce_fn)
        ret = g.ndata.pop("ft")
        del g.edata["er"]
        del g.ndata["a1"]
        del g.ndata["a2"]
        del g.edata["a_drop"]
        # 4. residual (uses the feature-dropped input h)
        if self.residual:
            if self.res_fc is not None:
                resval = self.res_fc(h).reshape(
                    (h.shape[0], self.num_heads, -1)
                )  # NxHxD'
            else:
                resval = torch.unsqueeze(h, 1)  # Nx1xD', broadcast over heads
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        """Edge UDF: unnormalised attention ``leaky_relu(a1[src] + a2[dst])``."""
        a = self.leaky_relu(edges.src["a1"] + edges.dst["a2"])
        return {"a": a}

    def edge_softmax(self, g):
        """Normalise ``g.edata["a"]`` (popped) and store the dropped-out result in ``g.edata["a_drop"]``."""
        attention = self.softmax(g, g.edata.pop("a"))
        # Dropout attention scores and save them
        g.edata["a_drop"] = self.attn_drop(attention)


class GAT(nn.Module):
    """Stack of ``GraphAttention`` layers with edge features.

    The stack is an input layer (never residual), ``num_layers - 1`` hidden
    layers and an output layer. Hidden outputs are flattened across heads and
    passed through ``activation``. The output layer's heads are averaged.

    ``heads[i]`` is the head count of the i-th hidden/input layer and
    ``heads[-1]`` that of the output layer. The output layer's input size is
    taken from ``heads[-2]``, so ``heads`` is presumably expected to have
    ``num_layers + 1`` entries.
    """

    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        alpha,
        edge_dim,
        residual,
    ):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation

        # (layer input size, per-head output size, head count, residual flag),
        # listed in construction order: input, hidden..., output.
        layer_specs = [(in_dim, num_hidden, heads[0], False)]
        # Multi-head outputs are concatenated, so a hidden layer's input size
        # is num_hidden times the previous layer's head count.
        layer_specs.extend(
            (num_hidden * heads[idx - 1], num_hidden, heads[idx], residual)
            for idx in range(1, num_layers)
        )
        layer_specs.append((num_hidden * heads[-2], num_classes, heads[-1], residual))

        for layer_in, layer_out, n_heads, use_residual in layer_specs:
            self.gat_layers.append(
                GraphAttention(
                    layer_in,
                    layer_out,
                    n_heads,
                    feat_drop,
                    attn_drop,
                    alpha,
                    edge_dim,
                    residual=use_residual,
                )
            )

    def forward(self, g, inputs):
        """Run all layers on ``g`` and return the head-averaged output logits."""
        features = inputs
        for idx in range(self.num_layers):
            per_head = self.gat_layers[idx](g, features)
            features = self.activation(per_head.flatten(1))
        output_layer = self.gat_layers[-1]
        # Average over heads for the final prediction.
        return output_layer(g, features).mean(1)

    def graph_detach(self, g):
        """
        Replace every node and edge feature of the graph with a detached copy.
        :param g: graph whose ndata and edata entries are rewritten in place
        :return: the same graph object g
        """
        for store in (g.ndata, g.edata):
            store.update({key: val.detach().clone() for key, val in store.items()})
        return g
