import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl.nn.pytorch import GATConv


# Semantic attention in the metapath-based aggregation (the same as that in the HAN)
class SemanticAttention(nn.Module):
    """Fuse per-metapath node embeddings using one learned weight per metapath.

    Every (node, metapath) embedding is scored by a small two-layer network.
    The scores are averaged over all nodes, turned into a distribution over
    metapaths with a softmax, and used to form a weighted sum of the
    metapath-specific embeddings. All nodes therefore share the same
    metapath weights.

    Args:
        in_size: size of the last dimension of the embeddings passed to
            ``forward``.
        hidden_size: width of the hidden layer of the scoring network.
    """

    def __init__(self, in_size, hidden_size=128):
        super().__init__()

        # embedding -> hidden (tanh) -> one scalar score, no bias on the score
        scorer = [
            nn.Linear(in_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, 1, bias=False),
        ]
        self.project = nn.Sequential(*scorer)

    def forward(self, z):
        """Combine the metapath-specific embeddings of every node.

        Args:
            z: tensor of shape (N, M, D * K), where N is the number of
                nodes, M the number of metapath patterns, D the hidden size
                and K the number of attention heads.

        Returns:
            Tensor of shape (N, D * K): the attention-weighted sum of ``z``
            over the metapath dimension.
        """
        num_nodes = z.shape[0]

        scores = self.project(z)  # (N, M, 1)
        mean_scores = scores.mean(dim=0)  # (M, 1), averaged over nodes
        weights = torch.softmax(mean_scores, dim=0)  # (M, 1), sums to 1 over M

        # Repeat the shared metapath weights for every node: (N, M, 1)
        weights = weights.expand(num_nodes, *weights.shape)

        weighted = weights * z  # (N, M, D * K)
        return weighted.sum(dim=1)  # (N, D * K)


# Metapath-based aggregation (the same as the HANLayer)
class HANLayer(nn.Module):
    """Metapath-based aggregation: one GAT per metapath plus semantic attention.

    For every metapath pattern the layer builds (and caches) the
    metapath-reachable graph of the input graph, runs a dedicated ``GATConv``
    on it, and finally fuses the per-metapath outputs with
    ``SemanticAttention``.

    Args:
        meta_path_patterns: iterable of metapaths; each metapath is a
            sequence of edge types and is stored as a tuple so it can be
            used as a dictionary key.
        in_size: input feature size given to every ``GATConv``.
        out_size: per-head output size of every ``GATConv``.
        layer_num_heads: number of attention heads of every ``GATConv``.
        dropout: rate used for both the feature dropout and the attention
            dropout of every ``GATConv``.
    """

    def __init__(
        self, meta_path_patterns, in_size, out_size, layer_num_heads, dropout
    ):
        super(HANLayer, self).__init__()

        # Tuples are hashable, which the per-metapath graph cache relies on.
        self.meta_path_patterns = [
            tuple(meta_path_pattern)
            for meta_path_pattern in meta_path_patterns
        ]

        # One GAT layer for each meta path based adjacency matrix
        self.gat_layers = nn.ModuleList(
            GATConv(
                in_size,
                out_size,
                layer_num_heads,
                feat_drop=dropout,
                attn_drop=dropout,
                activation=F.elu,
                # metapath-reachable graphs may contain nodes without
                # incoming edges; do not let GATConv reject them
                allow_zero_in_degree=True,
            )
            for _ in self.meta_path_patterns
        )
        self.semantic_attention = SemanticAttention(
            in_size=out_size * layer_num_heads
        )

        # Graph object the cache below was built from (compared by identity)
        self._cached_graph = None
        # metapath tuple -> metapath-reachable graph of ``_cached_graph``
        self._cached_coalesced_graph = {}

    def forward(self, g, h):
        """Compute metapath-aggregated embeddings.

        Args:
            g: heterogeneous graph the metapaths are defined on. The
                metapath-reachable graphs are rebuilt only when a different
                graph *object* is passed; in-place changes to the same
                object are not detected.
            h: input node features handed unchanged to every ``GATConv``
                (presumably of shape (N, in_size) — confirm against caller).

        Returns:
            Tensor of shape (N, D * K) produced by the semantic attention,
            with D = ``out_size`` and K = ``layer_num_heads``.
        """
        # obtain metapath reachable graph
        if self._cached_graph is None or self._cached_graph is not g:
            # Build the complete cache before committing it. If one metapath
            # fails, the previous (consistent) cache stays in place instead of
            # a half-filled one that is already marked as valid for ``g``.
            coalesced = {
                meta_path_pattern: dgl.metapath_reachable_graph(
                    g, meta_path_pattern
                )
                for meta_path_pattern in self.meta_path_patterns
            }
            self._cached_coalesced_graph = coalesced
            self._cached_graph = g

        semantic_embeddings = [
            # GATConv output is flattened from dim 1 onwards (heads merged
            # into the feature dimension) -> (N, D * K)
            gat(self._cached_coalesced_graph[meta_path_pattern], h).flatten(1)
            for gat, meta_path_pattern in zip(
                self.gat_layers, self.meta_path_patterns
            )
        ]
        semantic_embeddings = torch.stack(
            semantic_embeddings, dim=1
        )  # (N, M, D * K)

        return self.semantic_attention(semantic_embeddings)  # (N, D * K)


# Relational neighbor aggregation
class RelationalAGG(nn.Module):
    """Relational neighbor aggregation with attention over incoming edges.

    For every edge type the source features are linearly transformed,
    multiplied element-wise with the destination features and mapped to a
    scalar that is exponentiated. The exponentials are normalised per
    destination node over *all* incoming edges of *all* edge types, and the
    resulting weights are used to sum the (untransformed) source features.
    The aggregate is passed through ReLU, LayerNorm and dropout.

    NOTE(review): the element-wise product of ``t_h`` (``out_size``) with
    ``h`` (``in_size``) and the ``LayerNorm(out_size)`` applied to aggregated
    ``in_size`` features appear to require ``in_size == out_size`` — confirm.

    Args:
        g: heterogeneous graph; only ``g.etypes`` is read here to create one
            set of weights per edge type name.
        in_size: input feature size.
        out_size: size of the transformed features.
        dropout: dropout rate applied to the output.
    """

    def __init__(self, g, in_size, out_size, dropout=0.1):
        super(RelationalAGG, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # Transform weights for different types of edges
        self.W_T = nn.ModuleDict(
            {
                name: nn.Linear(in_size, out_size, bias=False)
                for name in g.etypes
            }
        )

        # Attention weights for different types of edges
        self.W_A = nn.ModuleDict(
            {name: nn.Linear(out_size, 1, bias=False) for name in g.etypes}
        )

        # layernorm
        self.layernorm = nn.LayerNorm(out_size)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat_dict):
        """Aggregate neighbor features for every node type.

        Side effects: writes the node fields ``"h"``, ``"t_h"`` and ``"att"``
        and the edge fields ``"x"`` and ``"att"`` on ``g``.

        Args:
            g: heterogeneous graph to run message passing on.
            feat_dict: mapping from node type to that type's feature tensor;
                indexed by the source and destination type of every edge
                type. The tensors themselves are never modified in place.

        Returns:
            dict mapping every node type in ``g.ntypes`` to its aggregated
            features after ReLU, LayerNorm and dropout.
        """
        funcs = {}
        for srctype, etype, dsttype in g.canonical_etypes:
            g.nodes[dsttype].data["h"] = feat_dict[
                dsttype
            ]  # nodes' original feature
            g.nodes[srctype].data["h"] = feat_dict[srctype]
            g.nodes[srctype].data["t_h"] = self.W_T[etype](
                feat_dict[srctype]
            )  # src nodes' transformed feature

            # compute the attention numerator (exp); this must happen inside
            # the loop because "t_h" is overwritten for the next edge type
            g.apply_edges(fn.u_mul_v("t_h", "h", "x"), etype=etype)
            g.edges[etype].data["x"] = torch.exp(
                self.W_A[etype](g.edges[etype].data["x"])
            )

            # first update to compute the attention denominator (\sum exp)
            funcs[etype] = (fn.copy_e("x", "m"), fn.sum("m", "att"))
        # "sum" across edge types: one denominator per destination node that
        # covers all of its incoming edges
        g.multi_update_all(funcs, "sum")

        funcs = {}
        for _, etype, _ in g.canonical_etypes:
            g.apply_edges(
                fn.e_div_v("x", "att", "att"), etype=etype
            )  # compute attention weights (numerator/denominator)
            funcs[etype] = (
                fn.u_mul_e("h", "att", "m"),
                fn.sum("m", "h"),
            )  # \sum(h0*att) -> h1
        # second update to obtain h1
        g.multi_update_all(funcs, "sum")

        # apply activation, layernorm, and dropout.
        # ReLU is deliberately NOT in-place: for a node type that received no
        # messages, "h" on the graph is still the caller's input tensor (e.g.
        # a trainable embedding), which an in-place op would either reject
        # (leaf tensor requiring grad) or silently overwrite under no_grad.
        return {
            ntype: self.dropout(
                self.layernorm(F.relu(g.nodes[ntype].data["h"]))
            )
            for ntype in g.ntypes
        }


class TAHIN(nn.Module):
    """Score (user, item) node pairs on a heterogeneous graph.

    Three views of every node are combined: the trainable initial embedding
    (h0), the relational neighbor aggregation (h1, ``RelationalAGG``) and the
    metapath-based aggregation (h2, one ``HANLayer`` per key of
    ``meta_path_patterns``).

    Args:
        g: heterogeneous graph; used here to size one embedding table per
            node type and passed on to ``RelationalAGG``.
        meta_path_patterns: dict mapping a node type to the list of
            metapaths used for that type's ``HANLayer``.
        in_size: size of the initial node embeddings.
        out_size: hidden/output size of the aggregation and fusion layers.
        num_heads: number of attention heads of each ``HANLayer``.
        dropout: dropout rate for the ``HANLayer`` modules and the scoring
            network.
    """

    def __init__(
        self, g, meta_path_patterns, in_size, out_size, num_heads, dropout
    ):
        super().__init__()

        # embeddings for different types of nodes, h0
        self.initializer = nn.init.xavier_uniform_
        initial_embeddings = {}
        for ntype in g.ntypes:
            table = torch.empty(g.num_nodes(ntype), in_size)
            initial_embeddings[ntype] = nn.Parameter(self.initializer(table))
        self.feature_dict = nn.ParameterDict(initial_embeddings)

        # relational neighbor aggregation, this produces h1
        self.RelationalAGG = RelationalAGG(g, in_size, out_size)

        # metapath-based aggregation modules for user and item, this
        # produces h2 (one HANLayer per key of meta_path_patterns)
        self.meta_path_patterns = meta_path_patterns
        han_layers = {}
        for key, patterns in self.meta_path_patterns.items():
            han_layers[key] = HANLayer(
                patterns, in_size, out_size, num_heads, dropout
            )
        self.hans = nn.ModuleDict(han_layers)

        # layers to combine h0, h1, and h2
        # layer1 consumes [h1 | h2]: out_size + num_heads * out_size features
        # layer2 consumes [layer1 output | h0]
        fused_size = (num_heads + 1) * out_size
        self.user_layer1 = nn.Linear(fused_size, out_size, bias=True)
        self.user_layer2 = nn.Linear(2 * out_size, out_size, bias=True)
        self.item_layer1 = nn.Linear(fused_size, out_size, bias=True)
        self.item_layer2 = nn.Linear(2 * out_size, out_size, bias=True)

        # layernorm (shared by the user and the item branch)
        self.layernorm = nn.LayerNorm(out_size)

        # network to score the node pairs
        self.pred = nn.Linear(out_size, out_size)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_size, 1)

    def forward(self, g, user_key, item_key, user_idx, item_idx):
        """Predict an interaction score for each (user, item) pair.

        Args:
            g: heterogeneous graph to aggregate over.
            user_key: node type name of the users.
            item_key: node type name of the items.
            user_idx: indices selecting rows of the user embeddings.
            item_idx: indices selecting rows of the item embeddings; paired
                element-wise with ``user_idx``.

        Returns:
            Sigmoid scores with dimension 1 squeezed away, one per pair.
        """
        # relational neighbor aggregation, h1
        h1 = self.RelationalAGG(g, self.feature_dict)

        # metapath-based aggregation, h2
        h2 = {
            key: self.hans[key](g, self.feature_dict[key])
            for key in self.meta_path_patterns
        }

        # update node embeddings: the user and item branches share the same
        # recipe and only differ in their linear layers
        branches = (
            (user_key, self.user_layer1, self.user_layer2),
            (item_key, self.item_layer1, self.item_layer2),
        )
        updated = []
        for key, layer1, layer2 in branches:
            emb = layer1(torch.cat((h1[key], h2[key]), 1))
            emb = layer2(torch.cat((emb, self.feature_dict[key]), 1))
            emb = F.relu_(emb)  # Relu
            updated.append(self.layernorm(emb))  # layer norm
        user_emb, item_emb = updated

        # obtain users/items embeddings and their interactions
        interaction = user_emb[user_idx] * item_emb[item_idx]

        # score the node pairs
        hidden = self.dropout(self.pred(interaction))
        scores = torch.sigmoid(self.fc(hidden))

        return scores.squeeze(1)
