import torch
from torch.nn import Parameter as Param
from torch_geometric.nn.conv import MessagePassing

from codes.model.inits import uniform
from codes.utils.util import get_param, get_param_id


class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    Unlike the stock operator, this variant does not own any learnable
    parameters: the basis, root and bias tensors are looked up from the
    ``params`` / ``param_name_dict`` arguments on every :meth:`forward` call,
    and the per-relation basis coefficients are passed in as
    ``relation_weights``. Messages are aggregated with a plain sum
    (``aggr="add"``); the :math:`1 / |\mathcal{N}_r(i)|` factor from the
    formula above is only applied if the caller supplies it through
    ``edge_norm``.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not add
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_relations,
        num_bases,
        root_weight=True,
        bias=True,
        **kwargs
    ):
        super(RGCNConv, self).__init__(aggr="add", **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases
        self.root_weight = root_weight
        # Stored as ``use_bias`` because ``self.bias`` holds the bias tensor
        # itself once ``forward`` has run.
        self.use_bias = bias

        # NOTE: no parameters are created (and hence nothing is reset) here;
        # ``forward`` binds ``basis`` / ``root`` / ``bias`` from the externally
        # supplied parameter collection on each call.

    def forward(
        self,
        x,
        edge_index,
        edge_types,
        relation_weights,
        params,
        param_name_dict,
        edge_norm=None,
        size=None,
    ):
        """Bind the externally supplied weights and run message passing.

        Args:
            x: Node features; each source-node row is multiplied by an
                ``(in_channels, out_channels)`` matrix in :meth:`message`.
            edge_index: Graph connectivity, forwarded to ``propagate``.
            edge_types: One-dimensional :obj:`torch.long` tensor with one
                relation id per edge (used as an ``index_select`` index).
            relation_weights: Basis coefficients; multiplied with the basis
                reshaped to ``(num_bases, -1)`` and the product is viewed as
                ``(num_relations, in_channels, out_channels)``, so it has to
                be of shape ``(num_relations, num_bases)``.
            params: Parameter collection passed to ``get_param``.
            param_name_dict: Name lookup passed to ``get_param`` together
                with the keys ``"basis"``, ``"root"`` and ``"bias"``.
            edge_norm (optional): Per-edge scaling factor, viewed as
                ``(-1, 1)`` and multiplied onto the messages.
            size (optional): Forwarded unchanged to ``propagate``.

        Returns:
            The result of ``propagate``, i.e. the output of :meth:`update`.
        """
        self.basis = get_param(params, param_name_dict, "basis")
        if self.root_weight:
            self.root = get_param(params, param_name_dict, "root")
        if self.use_bias:
            self.bias = get_param(params, param_name_dict, "bias")

        return self.propagate(
            edge_index,
            size=size,
            x=x,
            edge_types=edge_types,
            edge_norm=edge_norm,
            relation_weights=relation_weights,
        )

    def message(self, x_j, edge_index_j, edge_types, edge_norm, relation_weights):
        """Compute one message per edge using its relation-specific weight.

        ``edge_index_j`` is accepted but not used by this implementation.
        """
        # Basis decomposition: combine the flattened bases into one weight
        # matrix per relation.
        w = torch.matmul(relation_weights, self.basis.view(self.num_bases, -1))
        w = w.view(self.num_relations, self.in_channels, self.out_channels)

        # Pick the weight matrix matching each edge's relation, then apply it
        # to the source-node features: (E, 1, in) @ (E, in, out) -> (E, out).
        w = torch.index_select(w, 0, edge_types)
        out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out, x):
        """Add the optional root transformation and bias to the aggregate.

        Starts from ``aggr_out`` so that the result is well defined for every
        combination of ``root_weight`` and ``use_bias`` (previously
        ``root_weight=False`` left the result unassigned and raised
        ``UnboundLocalError``).
        """
        out = aggr_out

        if self.root_weight:
            out = out + torch.matmul(x, self.root)

        if self.use_bias:
            out = out + self.bias

        return out

    def __repr__(self):
        return "{}({}, {}, num_relations={})".format(
            self.__class__.__name__,
            self.in_channels,
            self.out_channels,
            self.num_relations,
        )
