"""
Graph Attention Networks in DGL, extended with edge features and with all
weights supplied explicitly at call time.

References
----------
Paper: https://arxiv.org/abs/1710.10903
Author's code: https://github.com/PetarV-/GAT
PyTorch implementation: https://github.com/Diego999/pyGAT
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn.pytorch import edge_softmax


class GraphAttention(nn.Module):
    """One multi-head graph attention layer that also aggregates edge features.

    The layer owns no learnable tensors of its own: every weight is passed to
    :meth:`forward` through ``params`` and located by (partial) name with
    :meth:`get_param_id`. The ``*_in_dim`` / ``*_out_dim`` attributes set in
    ``__init__`` describe the shapes the caller has to allocate:

    - ``"fc_linear"``: (fc_in_dim, fc_out_dim)
    - ``"attn_l"`` / ``"attn_r"``: broadcastable against (N, num_heads, out_dim)
    - ``"update_linear"``: (update_in_dim, update_out_dim)
    - ``"residual_linear"``: (res_fc_in_dim, res_fc_out_dim), only needed
      when ``res_fc`` is not None
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_heads,
        feat_drop,
        attn_drop,
        alpha,
        edge_dim,
        residual=False,
    ):
        super(GraphAttention, self).__init__()
        self.num_heads = num_heads
        self.out_dim = out_dim
        self.edge_dim = edge_dim
        # Shape of the external "fc_linear" input projection weight.
        self.fc_in_dim = in_dim
        self.fc_out_dim = num_heads * out_dim
        # nn.Identity (instead of a lambda) keeps the module picklable while
        # still acting as a no-op when dropout is disabled.
        self.feat_drop = nn.Dropout(feat_drop) if feat_drop else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop) if attn_drop else nn.Identity()
        # Shape of the external "update_linear" weight that reduce_fn applies
        # to the concatenation [aggregated node features, edge features].
        self.update_in_dim = out_dim + edge_dim
        self.update_out_dim = out_dim
        # Filled from ``params`` at the start of every forward() call and read
        # by reduce_fn, which only receives ``nodes`` from DGL.
        self.update_linear = None

        self.leaky_relu = nn.LeakyReLU(alpha)
        self.softmax = edge_softmax
        self.residual = residual
        # Shape of the optional external "residual_linear" weight. Always
        # defined so callers can size weights without checking for the attr.
        self.res_fc_in_dim = in_dim
        self.res_fc_out_dim = num_heads * out_dim
        # Name of the residual projection weight, or None when no projection
        # is needed (residual disabled, or in_dim == out_dim so the input is
        # broadcast across heads instead). Previously this attribute was left
        # unset in the projecting case, crashing forward().
        if residual and in_dim != out_dim:
            self.res_fc = "residual_linear"
        else:
            self.res_fc = None

    def message_fn(self, edges):
        # Message = [source node features || edge features], scaled per head
        # by the dropped-out attention weight of the edge.
        return {
            "ft": torch.cat([edges.src["ft"], edges.data["er"]], dim=2)
            * edges.data["a_drop"]
        }

    def reduce_fn(self, nodes):
        # Sum the incoming messages, then project [features, edge features]
        # back to out_dim with the "update_linear" weight stored on self by
        # forward().
        accum = torch.sum(nodes.mailbox["ft"], 1)  # N x H x (out_dim + edge_dim)
        projected = F.linear(
            accum.view(-1, self.out_dim + self.edge_dim),
            weight=self.update_linear.t(),
        )
        return {"ft": projected.view(-1, accum.size(1), self.out_dim)}

    def get_param_id(self, param_name_dict, partial_name):
        """Return the value of the single key in ``param_name_dict`` containing ``partial_name``.

        :param param_name_dict: maps full weight names to indices into ``params``.
        :param partial_name: substring identifying the weight, e.g. ``"attn_l"``.
        :raises AssertionError: if zero or several keys contain ``partial_name``.
        """
        matches = [name for name in param_name_dict if partial_name in name]
        if len(matches) != 1:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            raise AssertionError(
                "expected exactly one parameter matching {!r}, found {}".format(
                    partial_name, matches
                )
            )
        return param_name_dict[matches[0]]

    def forward(self, g, inputs, params, param_name_dict):
        """Run the attention layer on ``g``.

        :param g: DGL graph carrying edge features in ``g.edata["e"]``
            (presumably shaped (E, 1, edge_dim) so that repeating along dim 1
            yields one copy per head — TODO confirm). Temporary node/edge
            fields written here are removed before returning.
        :param inputs: node features, N x in_dim.
        :param params: indexable collection of weight tensors.
        :param param_name_dict: maps names containing ``"fc_linear"``,
            ``"attn_l"``, ``"attn_r"``, ``"update_linear"`` (and
            ``"residual_linear"`` when ``res_fc`` is set) to indices into
            ``params``.
        :return: node features of shape N x num_heads x out_dim.
        """

        def param(name):
            return params[self.get_param_id(param_name_dict, name)]

        # prepare
        self.update_linear = param("update_linear")
        h = self.feat_drop(inputs)  # N x D
        ft = F.linear(h, weight=param("fc_linear").t()).reshape(
            (h.shape[0], self.num_heads, -1)
        )  # N x H x D'
        # Per-head attention logit contributions of each node as source (a1)
        # and as destination (a2).
        a1 = (ft * param("attn_l")).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        a2 = (ft * param("attn_r")).sum(dim=-1).unsqueeze(-1)  # N x H x 1
        g.ndata.update({"ft": ft, "a1": a1, "a2": a2})
        g.edata.update({"er": g.edata["e"].repeat(1, self.num_heads, 1)})
        # 1. compute edge attention
        g.apply_edges(self.edge_attention)
        # 2. compute softmax
        self.edge_softmax(g)
        # 3. compute the aggregated node features scaled by the dropped,
        # normalized attention values.
        g.update_all(self.message_fn, self.reduce_fn)
        ret = g.ndata.pop("ft")
        # Remove the temporary fields so the graph is left as it came in.
        del g.edata["er"]
        del g.ndata["a1"]
        del g.ndata["a2"]
        del g.edata["a_drop"]
        # 4. residual
        if self.residual:
            if self.res_fc is not None:
                resval = F.linear(h, param(self.res_fc).t()).reshape(
                    (h.shape[0], self.num_heads, -1)
                )  # N x H x D'
            else:
                # in_dim == out_dim: broadcast the input across all heads.
                resval = torch.unsqueeze(h, 1)  # N x 1 x D
            ret = resval + ret
        return ret

    def edge_attention(self, edges):
        # an edge UDF to compute unnormalized attention values from src and dst
        a = self.leaky_relu(edges.src["a1"] + edges.dst["a2"])
        return {"a": a}

    def edge_softmax(self, g):
        # Normalize the logits in edata["a"] (consuming them) with DGL's
        # edge_softmax, then store the dropped-out weights as edata["a_drop"].
        attention = self.softmax(g, g.edata.pop("a"))
        g.edata["a_drop"] = self.attn_drop(attention)


class GAT(nn.Module):
    """Stack of GraphAttention layers whose weights are held outside the layers.

    ``num_layers`` hidden layers are followed by one output layer, so
    ``heads`` presumably holds ``num_layers + 1`` entries (``heads[-2]`` sizes
    the output layer's input, ``heads[-1]`` its head count).

    All weights live in the plain list ``self.weights`` (not registered as
    module parameters). ``self.weight_names[i]`` names ``self.weights[i]`` as
    ``"<kind>.<layer index>"``, e.g. ``"attn_l.0"``.
    """

    def __init__(
        self,
        num_layers,
        in_dim,
        num_hidden,
        num_classes,
        heads,
        activation,
        feat_drop,
        attn_drop,
        alpha,
        edge_dim,
        residual,
        device,
    ):
        super(GAT, self).__init__()
        self.num_layers = num_layers
        self.gat_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gat_layers.append(
            GraphAttention(
                in_dim,
                num_hidden,
                heads[0],
                feat_drop,
                attn_drop,
                alpha,
                edge_dim,
                False,
            )
        )
        # hidden layers
        for layer_idx in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gat_layers.append(
                GraphAttention(
                    num_hidden * heads[layer_idx - 1],
                    num_hidden,
                    heads[layer_idx],
                    feat_drop,
                    attn_drop,
                    alpha,
                    edge_dim,
                    residual,
                )
            )
        # output projection
        self.gat_layers.append(
            GraphAttention(
                num_hidden * heads[-2],
                num_classes,
                heads[-1],
                feat_drop,
                attn_drop,
                alpha,
                edge_dim,
                residual,
            )
        )

        # define individual layer weights
        self.weights = []
        self.weight_names = []
        for gi, gat_layer in enumerate(self.gat_layers):
            self._add_weight(
                "fc_linear", gi, (gat_layer.fc_in_dim, gat_layer.fc_out_dim), device
            )
            head_shape = (1, gat_layer.num_heads, gat_layer.out_dim)
            self._add_weight("attn_l", gi, head_shape, device)
            self._add_weight("attn_r", gi, head_shape, device)
            self._add_weight(
                "update_linear",
                gi,
                (gat_layer.update_in_dim, gat_layer.update_out_dim),
                device,
            )
            # Only layers that project their residual need a weight. The
            # input layer is never residual, so checking the global
            # ``residual`` flag here (as before) crashed on it.
            if getattr(gat_layer, "res_fc", None) is not None:
                self._add_weight(
                    "residual_linear",
                    gi,
                    (gat_layer.res_fc_in_dim, gat_layer.res_fc_out_dim),
                    device,
                )

        # initialize (in place; order matches self.weights)
        for weight in self.weights:
            nn.init.xavier_normal_(weight, gain=1.414)

    def _add_weight(self, kind, layer_idx, shape, device):
        """Allocate an uninitialized grad-requiring tensor and record it as "<kind>.<layer_idx>"."""
        weight = torch.Tensor(size=shape).to(device)
        weight.requires_grad = True
        self.weights.append(weight)
        self.weight_names.append("{}.{}".format(kind, layer_idx))

    def prepare_param_idx(self, layer_id=0):
        """Map each weight name belonging to layer ``layer_id`` to its index in ``self.weights``.

        Matches the exact ``.<layer_id>`` suffix; a plain substring test would
        also pick up layers 10, 11, ... for ``layer_id=1``.
        """
        suffix = str(layer_id)
        return {
            name: idx
            for idx, name in enumerate(self.weight_names)
            if name.rsplit(".", 1)[-1] == suffix
        }

    def forward(self, g, inputs, params):
        """Compute per-node logits.

        :param g: DGL graph passed to every layer.
        :param inputs: input node features.
        :param params: weight tensors indexed like ``self.weights``.
        :return: output-layer features averaged over its heads.
        """
        h = inputs
        for layer_idx in range(self.num_layers):
            n_dict = self.prepare_param_idx(layer_idx)
            # Concatenate the heads of hidden layers.
            h = self.gat_layers[layer_idx](g, h, params, n_dict).flatten(1)
            h = self.activation(h)
        # output projection: average the heads instead of concatenating
        n_dict = self.prepare_param_idx(self.num_layers)
        logits = self.gat_layers[-1](g, h, params, n_dict).mean(1)
        return logits

    def graph_detach(self, g):
        """
        Replace every node and edge field of ``g`` with a detached copy.

        :param g: graph whose ``ndata``/``edata`` are updated in place.
        :return: the same graph object.
        """
        g.ndata.update({k: v.detach().clone() for k, v in g.ndata.items()})
        g.edata.update({k: v.detach().clone() for k, v in g.edata.items()})
        return g
