# RGCN - DGL implementation

import torch
import torch.nn as nn
import dgl.function as fn
from functools import partial
import torch.nn.functional as F


class RGCNLayer(nn.Module):
    """Base class for a relational graph convolution layer on a DGL-style graph.

    Subclasses implement :meth:`propagate`, which is expected to leave the
    aggregated neighbour representation in ``g.ndata["h"]``. :meth:`forward`
    then optionally adds a bias and a self-loop term, applies the activation,
    and writes the result back to ``g.ndata["h"]``. The graph is modified in
    place and nothing is returned.

    Args:
        in_feat: size of the incoming node features (used by the self-loop weight).
        out_feat: size of the produced node features.
        bias: if truthy, add a learnable, zero-initialised bias of size ``out_feat``.
        activation: optional callable applied to the final node representation.
        self_loop: if True, add ``h_in @ loop_weight`` computed from the input features.
        dropout: dropout probability applied to the self-loop message only; 0 disables it.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        bias=None,
        activation=None,
        self_loop=False,
        dropout=0.0,
    ):
        super(RGCNLayer, self).__init__()
        self.activation = activation
        self.self_loop = self_loop

        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feat))
            # Xavier init needs a tensor with >= 2 dims and raises on this 1-D
            # bias, so use the conventional zero init for biases.
            nn.init.zeros_(self.bias)
        else:
            # Use None (not False) so forward can use an unambiguous `is not None` test.
            self.bias = None

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(torch.Tensor(in_feat, out_feat))
            nn.init.xavier_uniform_(
                self.loop_weight, gain=nn.init.calculate_gain("relu")
            )

        if dropout:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

    # define how propagation is done in subclass
    def propagate(self, g):
        """Aggregate neighbour messages into ``g.ndata["h"]``; implemented by subclasses."""
        raise NotImplementedError

    def forward(self, g):
        """Run one layer update on ``g`` in place, storing the result in ``g.ndata["h"]``."""
        if self.self_loop:
            # Computed before propagate(), because propagate overwrites g.ndata["h"].
            loop_message = torch.mm(g.ndata["h"], self.loop_weight)
            if self.dropout is not None:
                loop_message = self.dropout(loop_message)

        self.propagate(g)

        # apply bias and activation
        node_repr = g.ndata["h"]
        # `if self.bias:` would raise for a multi-element Parameter, so test against None.
        if self.bias is not None:
            node_repr = node_repr + self.bias
        if self.self_loop:
            node_repr = node_repr + loop_message
        if self.activation:
            node_repr = self.activation(node_repr)

        g.ndata["h"] = node_repr


class RGCNBasisLayer(RGCNLayer):
    """R-GCN layer whose per-relation weights are built from shared bases.

    When ``num_bases < num_rels``, relation ``r`` uses the weight
    ``sum_b w_comp[r, b] * weight[b]``. Otherwise ``weight`` holds one
    ``(in_feat, out_feat)`` matrix per relation directly. Each edge sends
    ``h_src @ W[type] * norm``, and messages are summed into ``g.ndata["h"]``.

    Args:
        in_feat: size of the source-node features read from ``ndata["h"]``.
        out_feat: size of the produced node features.
        num_rels: number of relation types; ``edata["type"]`` indexes into this range.
        num_bases: number of bases; values ``<= 0`` or ``> num_rels`` fall back to ``num_rels``.
        bias: forwarded to :class:`RGCNLayer`.
        activation: forwarded to :class:`RGCNLayer`.
        is_input_layer: stored on the instance.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        num_bases=-1,
        bias=None,
        activation=None,
        is_input_layer=False,
    ):
        super(RGCNBasisLayer, self).__init__(in_feat, out_feat, bias, activation)
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.num_rels = num_rels
        self.num_bases = num_bases
        # NOTE(review): not read anywhere in this class; propagate always uses
        # ndata["h"]. Confirm whether an id-embedding input path was intended.
        self.is_input_layer = is_input_layer
        if self.num_bases <= 0 or self.num_bases > self.num_rels:
            self.num_bases = self.num_rels

        # add basis weights
        self.weight = nn.Parameter(
            torch.Tensor(self.num_bases, self.in_feat, self.out_feat)
        )
        if self.num_bases < self.num_rels:
            # linear combination coefficients
            self.w_comp = nn.Parameter(torch.Tensor(self.num_rels, self.num_bases))
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain("relu"))
        if self.num_bases < self.num_rels:
            nn.init.xavier_uniform_(self.w_comp, gain=nn.init.calculate_gain("relu"))

    def propagate(self, g):
        """Send per-relation transformed messages along every edge and sum them into ``ndata["h"]``."""
        if self.num_bases < self.num_rels:
            # generate all weights from bases: (num_rels, in_feat, out_feat)
            weight = self.weight.view(self.num_bases, self.in_feat * self.out_feat)
            weight = torch.matmul(self.w_comp, weight).view(
                self.num_rels, self.in_feat, self.out_feat
            )
        else:
            weight = self.weight

        def msg_func(edges):
            # One (in_feat, out_feat) matrix per edge, selected by integer relation id.
            w = weight.index_select(0, edges.data["type"])
            # Squeeze only the singleton middle dim. A bare squeeze() also drops
            # the feature dim when out_feat == 1; the message then becomes (E,),
            # and multiplying by a per-edge norm (apparently (E, 1)) broadcasts
            # to (E, E).
            msg = torch.bmm(edges.src["h"].unsqueeze(1), w).squeeze(1)
            msg = msg * edges.data["norm"]
            return {"msg": msg}

        g.update_all(msg_func, fn.sum(msg="msg", out="h"), None)


class RGCNBlockLayer(RGCNLayer):
    """R-GCN layer with block-diagonal per-relation weight matrices.

    Each relation's ``in_feat x out_feat`` transform consists of ``num_bases``
    diagonal blocks of shape ``(in_feat // num_bases, out_feat // num_bases)``.
    Each row of ``weight`` stores one relation's blocks flattened. Messages are
    summed per node and then scaled by ``ndata["norm"]``.

    Args:
        in_feat: size of the source-node features read from ``ndata["h"]``.
        out_feat: size of the produced node features.
        num_rels: number of relation types; ``edata["type"]`` indexes into this range.
        num_bases: number of diagonal blocks; must divide both feature sizes.
        bias, activation, self_loop, dropout: forwarded to :class:`RGCNLayer`.

    Raises:
        AssertionError: if ``num_bases`` is not positive.
        ValueError: if ``in_feat`` or ``out_feat`` is not divisible by ``num_bases``.
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        num_rels,
        num_bases,
        bias=None,
        activation=None,
        self_loop=False,
        dropout=0.0,
    ):
        super(RGCNBlockLayer, self).__init__(
            in_feat, out_feat, bias, activation, self_loop=self_loop, dropout=dropout
        )
        self.num_rels = num_rels
        self.num_bases = num_bases
        assert self.num_bases > 0

        # Fail early. Non-divisible sizes make the views in msg_func mis-shape
        # the messages or fail later with an opaque shape error.
        if in_feat % num_bases or out_feat % num_bases:
            raise ValueError(
                f"in_feat ({in_feat}) and out_feat ({out_feat}) must both be "
                f"divisible by num_bases ({num_bases})"
            )

        self.out_feat = out_feat
        self.submat_in = in_feat // self.num_bases
        self.submat_out = out_feat // self.num_bases

        # One flattened row of num_bases (submat_in x submat_out) blocks per relation.
        self.weight = nn.Parameter(
            torch.Tensor(
                self.num_rels, self.num_bases * self.submat_in * self.submat_out
            )
        )
        nn.init.xavier_uniform_(self.weight, gain=nn.init.calculate_gain("relu"))

    def msg_func(self, edges):
        """Apply each edge's block-diagonal relation transform to its source features."""
        # (E * num_bases, submat_in, submat_out): one block per (edge, base) pair.
        weight = self.weight.index_select(0, edges.data["type"]).view(
            -1, self.submat_in, self.submat_out
        )
        # Split each source feature vector into num_bases chunks of size submat_in.
        node = edges.src["h"].view(-1, 1, self.submat_in)
        # Multiply each chunk by its block, then concatenate the chunks back to out_feat.
        msg = torch.bmm(node, weight).view(-1, self.out_feat)
        return {"msg": msg}

    def propagate(self, g):
        """Sum block-transformed messages into ``ndata["h"]``, then apply node-level norm."""
        g.update_all(self.msg_func, fn.sum(msg="msg", out="h"), self.apply_func)

    def apply_func(self, nodes):
        """Scale the aggregated representation by the per-node ``ndata["norm"]``."""
        return {"h": nodes.data["h"] * nodes.data["norm"]}


class BaseRGCN(nn.Module):
    """Skeleton R-GCN model: an ordered stack of layers applied to a graph in place.

    Subclasses supply the layers through the ``build_*_layer`` hooks. The input
    and output hooks may return ``None`` to omit that layer. A subclass (or
    caller) may set ``self.features``; if it is not ``None``, it is written to
    ``g.ndata["id"]`` before the layers run.

    Args:
        h_dim: hidden feature size, stored for use by the layer builders.
        out_dim: output feature size, stored for use by the layer builders.
        num_rels: number of relation types.
        num_bases: number of bases (interpretation is up to the layers built).
        num_hidden_layers: how many times ``build_hidden_layer`` is called.
        dropout: stored for use by the layer builders.
        use_cuda: stored for use by subclasses (e.g. in ``create_features``).
    """

    def __init__(
        self,
        h_dim,
        out_dim,
        num_rels,
        num_bases=-1,
        num_hidden_layers=1,
        dropout=0,
        use_cuda=False,
    ):
        super(BaseRGCN, self).__init__()
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.num_rels = num_rels
        self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_cuda = use_cuda

        # create rgcn layers
        self.build_model()

        # Initial node features are not created here; ``features`` may be
        # assigned later (forward treats a missing attribute as None).

    def build_model(self):
        """Populate ``self.layers`` with the input, hidden and output layers, in that order."""
        self.layers = nn.ModuleList()
        # i2h
        i2h = self.build_input_layer()
        if i2h is not None:
            self.layers.append(i2h)
        # h2h
        for idx in range(self.num_hidden_layers):
            h2h = self.build_hidden_layer(idx)
            self.layers.append(h2h)
        # h2o
        h2o = self.build_output_layer()
        if h2o is not None:
            self.layers.append(h2o)

    # initialize feature for each node
    def create_features(self):
        """Hook for building initial node features; the base returns None."""
        return None

    def build_input_layer(self):
        """Hook returning the input layer, or None to omit it."""
        return None

    def build_hidden_layer(self, idx):
        """Hook returning hidden layer ``idx``; must be implemented by subclasses."""
        raise NotImplementedError

    def build_output_layer(self):
        """Hook returning the output layer, or None to omit it."""
        return None

    def forward(self, g):
        """Run every layer on ``g`` in order, then pop and return ``g.ndata["h"]``."""
        # `features` is never assigned in __init__, so a plain attribute read
        # would raise AttributeError unless some external code set it.
        features = getattr(self, "features", None)
        if features is not None:
            g.ndata["id"] = features
        for layer in self.layers:
            # Layers update g.ndata in place; their return values are ignored.
            layer(g)
        return g.ndata.pop("h")


class EntityClassify(BaseRGCN):
    """Entity-classification R-GCN assembled from ``RGCNBasisLayer`` modules.

    Stack: an input layer (1 -> h_dim, ReLU, flagged ``is_input_layer``), then
    ``num_hidden_layers`` hidden layers (h_dim -> h_dim, ReLU), then an output
    layer (h_dim -> out_dim) that applies a softmax over dim 1.

    NOTE(review): if the training loss applies its own (log-)softmax, e.g.
    ``F.cross_entropy``, the output is normalised twice. Confirm against the
    training script.
    """

    def create_features(self, num_nodes):
        """Return node indices ``0 .. num_nodes - 1``, moved to CUDA when ``use_cuda`` is set."""
        node_ids = torch.arange(num_nodes)
        return node_ids.cuda() if self.use_cuda else node_ids

    def _basis_layer(self, in_feat, out_feat, **layer_kwargs):
        """Build an ``RGCNBasisLayer`` that shares this model's relation/basis counts."""
        return RGCNBasisLayer(
            in_feat, out_feat, self.num_rels, self.num_bases, **layer_kwargs
        )

    def build_input_layer(self):
        return self._basis_layer(
            1, self.h_dim, activation=F.relu, is_input_layer=True
        )

    def build_hidden_layer(self, idx):
        return self._basis_layer(self.h_dim, self.h_dim, activation=F.relu)

    def build_output_layer(self):
        # A partial (rather than a lambda) keeps the activation picklable.
        row_softmax = partial(F.softmax, dim=1)
        return self._basis_layer(self.h_dim, self.out_dim, activation=row_softmax)
