"""Utilities for pytorch NN package"""
# pylint: disable=no-member, invalid-name

import torch as th
import torch.nn.functional as F
from torch import nn

from ... import DGLGraph
from ... import function as fn
from ...base import dgl_warning


def matmul_maybe_select(A, B):
    """Perform Matrix multiplication C = A * B but A could be an integer id vector.

    If A is an integer vector, we treat it as multiplying a one-hot encoded tensor.
    In this case, the expensive dense matrix multiply can be replaced by a much
    cheaper index lookup.

    For example,
    ::

        A = [2, 0, 1],
        B = [[0.1, 0.2],
             [0.3, 0.4],
             [0.5, 0.6]]

    then matmul_maybe_select(A, B) is equivalent to
    ::

        [[0, 0, 1],     [[0.1, 0.2],
         [1, 0, 0],  *   [0.3, 0.4],
         [0, 1, 0]]      [0.5, 0.6]]

    In all other cases, perform a normal matmul.

    Parameters
    ----------
    A : torch.Tensor
        lhs tensor
    B : torch.Tensor
        rhs tensor

    Returns
    -------
    C : torch.Tensor
        result tensor
    """
    if A.dtype == th.int64 and len(A.shape) == 1:
        return B.index_select(0, A)
    else:
        return th.matmul(A, B)


def bmm_maybe_select(A, B, index):
    """Slice submatrices of A by the given index and perform bmm.

    B is a 3D tensor of shape (N, D1, D2), which can be viewed as a stack of
    N matrices of shape (D1, D2). The input index is an integer vector of length M.
    A could be either:
    (1) a dense tensor of shape (M, D1),
    (2) an integer vector of length M.
    The result C is a 2D matrix of shape (M, D2)

    For case (1), C is computed by bmm:
    ::

        C[i, :] = matmul(A[i, :], B[index[i], :, :])

    For case (2), C is computed by index select:
    ::

        C[i, :] = B[index[i], A[i], :]

    Parameters
    ----------
    A : torch.Tensor
        lhs tensor
    B : torch.Tensor
        rhs tensor
    index : torch.Tensor
        index tensor

    Returns
    -------
    C : torch.Tensor
        return tensor
    """
    if A.dtype == th.int64 and len(A.shape) == 1:
        # following is a faster version of B[index, A, :]
        B = B.view(-1, B.shape[2])
        flatidx = index * B.shape[1] + A
        return B.index_select(0, flatidx)
    else:
        BB = B.index_select(0, index)
        return th.bmm(A.unsqueeze(1), BB).squeeze()


# pylint: disable=W0235
class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive.
    (Identity has already been supported by PyTorch 1.2, we will directly
    import torch.nn.Identity in the future)
    """

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        """Return input"""
        return x


class Sequential(nn.Sequential):
    r"""A sequential container for stacking graph neural network modules

    DGL supports two modes: sequentially apply GNN modules on 1) the same graph or
    2) a list of given graphs. In the second case, the number of graphs equals the
    number of modules inside this container.

    Parameters
    ----------
    *args :
        Sub-modules of torch.nn.Module that will be added to the container in
        the order by which they are passed in the constructor.

    Examples
    --------
    The following example uses PyTorch backend.

    Mode 1: sequentially apply GNN modules on the same graph

    >>> import torch
    >>> import dgl
    >>> import torch.nn as nn
    >>> import dgl.function as fn
    >>> from dgl.nn.pytorch import Sequential
    >>> class ExampleLayer(nn.Module):
    >>>     def __init__(self):
    >>>         super().__init__()
    >>>     def forward(self, graph, n_feat, e_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             graph.apply_edges(fn.u_add_v('h', 'h', 'e'))
    >>>             e_feat += graph.edata['e']
    >>>             return n_feat, e_feat
    >>>
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 1, 2, 0, 1, 2, 0, 1, 2], [0, 0, 0, 1, 1, 1, 2, 2, 2])
    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    >>> n_feat = torch.rand(3, 4)
    >>> e_feat = torch.rand(9, 4)
    >>> net(g, n_feat, e_feat)
    (tensor([[39.8597, 45.4542, 25.1877, 30.8086],
             [40.7095, 45.3985, 25.4590, 30.0134],
             [40.7894, 45.2556, 25.5221, 30.4220]]),
     tensor([[80.3772, 89.7752, 50.7762, 60.5520],
             [80.5671, 89.3736, 50.6558, 60.6418],
             [80.4620, 89.5142, 50.3643, 60.3126],
             [80.4817, 89.8549, 50.9430, 59.9108],
             [80.2284, 89.6954, 50.0448, 60.1139],
             [79.7846, 89.6882, 50.5097, 60.6213],
             [80.2654, 90.2330, 50.2787, 60.6937],
             [80.3468, 90.0341, 50.2062, 60.2659],
             [80.0556, 90.2789, 50.2882, 60.5845]]))

    Mode 2: sequentially apply GNN modules on different graphs

    >>> import torch
    >>> import dgl
    >>> import torch.nn as nn
    >>> import dgl.function as fn
    >>> import networkx as nx
    >>> from dgl.nn.pytorch import Sequential
    >>> class ExampleLayer(nn.Module):
    >>>     def __init__(self):
    >>>         super().__init__()
    >>>     def forward(self, graph, n_feat):
    >>>         with graph.local_scope():
    >>>             graph.ndata['h'] = n_feat
    >>>             graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
    >>>             n_feat += graph.ndata['h']
    >>>             return n_feat.view(graph.number_of_nodes() // 2, 2, -1).sum(1)
    >>>
    >>> g1 = dgl.DGLGraph(nx.erdos_renyi_graph(32, 0.05))
    >>> g2 = dgl.DGLGraph(nx.erdos_renyi_graph(16, 0.2))
    >>> g3 = dgl.DGLGraph(nx.erdos_renyi_graph(8, 0.8))
    >>> net = Sequential(ExampleLayer(), ExampleLayer(), ExampleLayer())
    >>> n_feat = torch.rand(32, 4)
    >>> net([g1, g2, g3], n_feat)
    tensor([[209.6221, 225.5312, 193.8920, 220.1002],
            [250.0169, 271.9156, 240.2467, 267.7766],
            [220.4007, 239.7365, 213.8648, 234.9637],
            [196.4630, 207.6319, 184.2927, 208.7465]])
    """

    def __init__(self, *args):
        super(Sequential, self).__init__(*args)

    def forward(self, graph, *feats):
        r"""

        Sequentially apply modules to the input.

        Parameters
        ----------
        graph : DGLGraph or list of DGLGraphs
            The graph(s) to apply modules on.

        *feats :
            Input features.
            The output of the :math:`i`-th module should match the input
            of the :math:`(i+1)`-th module in the sequential.
        """
        if isinstance(graph, list):
            for graph_i, module in zip(graph, self):
                if not isinstance(feats, tuple):
                    feats = (feats,)
                feats = module(graph_i, *feats)
        elif isinstance(graph, DGLGraph):
            for module in self:
                if not isinstance(feats, tuple):
                    feats = (feats,)
                feats = module(graph, *feats)
        else:
            raise TypeError(
                "The first argument of forward must be a DGLGraph"
                " or a list of DGLGraph s"
            )
        return feats


class WeightBasis(nn.Module):
    r"""Basis decomposition from `Modeling Relational Data with Graph
    Convolutional Networks <https://arxiv.org/abs/1703.06103>`__

    It can be described as below:

    .. math::

        W_o = \sum_{b=1}^B a_{ob} V_b

    Each weight output :math:`W_o` is essentially a linear combination of basis
    transformations :math:`V_b` with coefficients :math:`a_{ob}`.

    If is useful as a form of regularization on a large parameter matrix. Thus,
    the number of weight outputs is usually larger than the number of bases.

    Parameters
    ----------
    shape : tuple[int]
        Shape of the basis parameter.
    num_bases : int
        Number of bases.
    num_outputs : int
        Number of outputs.
    """

    def __init__(self, shape, num_bases, num_outputs):
        super(WeightBasis, self).__init__()
        self.shape = shape
        self.num_bases = num_bases
        self.num_outputs = num_outputs

        if num_outputs <= num_bases:
            dgl_warning(
                "The number of weight outputs should be larger than the number"
                " of bases."
            )

        self.weight = nn.Parameter(th.Tensor(self.num_bases, *shape))
        nn.init.xavier_uniform_(
            self.weight, gain=nn.init.calculate_gain("relu")
        )
        # linear combination coefficients
        self.w_comp = nn.Parameter(th.Tensor(self.num_outputs, self.num_bases))
        nn.init.xavier_uniform_(
            self.w_comp, gain=nn.init.calculate_gain("relu")
        )

    def forward(self):
        r"""Forward computation

        Returns
        -------
        weight : torch.Tensor
            Composed weight tensor of shape ``(num_outputs,) + shape``
        """
        # generate all weights from bases
        weight = th.matmul(self.w_comp, self.weight.view(self.num_bases, -1))
        return weight.view(self.num_outputs, *self.shape)


class JumpingKnowledge(nn.Module):
    r"""The Jumping Knowledge aggregation module from `Representation Learning on
    Graphs with Jumping Knowledge Networks <https://arxiv.org/abs/1806.03536>`__

    It aggregates the output representations of multiple GNN layers with

    **concatenation**

    .. math::

        h_i^{(1)} \, \Vert \, \ldots \, \Vert \, h_i^{(T)}

    or **max pooling**

    .. math::

        \max \left( h_i^{(1)}, \ldots, h_i^{(T)} \right)

    or **LSTM**

    .. math::

        \sum_{t=1}^T \alpha_i^{(t)} h_i^{(t)}

    with attention scores :math:`\alpha_i^{(t)}` obtained from a BiLSTM

    Parameters
    ----------
    mode : str
        The aggregation to apply. It can be 'cat', 'max', or 'lstm',
        corresponding to the equations above in order.
    in_feats : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The output representation size of a single GNN layer. Note that
        all GNN layers need to have the same output representation size.
    num_layers : int, optional
        This argument is only required if :attr:`mode` is ``'lstm'``.
        The number of GNN layers for output aggregation.

    Examples
    --------
    >>> import dgl
    >>> import torch as th
    >>> from dgl.nn import JumpingKnowledge

    >>> # Output representations of two GNN layers
    >>> num_nodes = 3
    >>> in_feats = 4
    >>> feat_list = [th.zeros(num_nodes, in_feats), th.ones(num_nodes, in_feats)]

    >>> # Case1
    >>> model = JumpingKnowledge()
    >>> model(feat_list).shape
    torch.Size([3, 8])

    >>> # Case2
    >>> model = JumpingKnowledge(mode='max')
    >>> model(feat_list).shape
    torch.Size([3, 4])

    >>> # Case3
    >>> model = JumpingKnowledge(mode='max', in_feats=in_feats, num_layers=len(feat_list))
    >>> model(feat_list).shape
    torch.Size([3, 4])
    """

    def __init__(self, mode="cat", in_feats=None, num_layers=None):
        super(JumpingKnowledge, self).__init__()
        assert mode in [
            "cat",
            "max",
            "lstm",
        ], "Expect mode to be 'cat', or 'max' or 'lstm', got {}".format(mode)
        self.mode = mode

        if mode == "lstm":
            assert in_feats is not None, "in_feats is required for lstm mode"
            assert (
                num_layers is not None
            ), "num_layers is required for lstm mode"
            hidden_size = (num_layers * in_feats) // 2
            self.lstm = nn.LSTM(
                in_feats, hidden_size, bidirectional=True, batch_first=True
            )
            self.att = nn.Linear(2 * hidden_size, 1)

    def reset_parameters(self):
        r"""

        Description
        -----------
        Reinitialize learnable parameters. This comes into effect only for the lstm mode.
        """
        if self.mode == "lstm":
            self.lstm.reset_parameters()
            self.att.reset_parameters()

    def forward(self, feat_list):
        r"""

        Description
        -----------
        Aggregate output representations across multiple GNN layers.

        Parameters
        ----------
        feat_list : list[Tensor]
            feat_list[i] is the output representations of a GNN layer.

        Returns
        -------
        Tensor
            The aggregated representations.
        """
        if self.mode == "cat":
            return th.cat(feat_list, dim=-1)
        elif self.mode == "max":
            return th.stack(feat_list, dim=-1).max(dim=-1)[0]
        else:
            # LSTM
            stacked_feat_list = th.stack(
                feat_list, dim=1
            )  # (N, num_layers, in_feats)
            alpha, _ = self.lstm(stacked_feat_list)
            alpha = self.att(alpha).squeeze(-1)  # (N, num_layers)
            alpha = th.softmax(alpha, dim=-1)
            return (stacked_feat_list * alpha.unsqueeze(-1)).sum(dim=1)


class LabelPropagation(nn.Module):
    r"""Label Propagation from `Learning from Labeled and Unlabeled Data with Label
    Propagation <http://mlg.eng.cam.ac.uk/zoubin/papers/CMU-CALD-02-107.pdf>`__

    .. math::

        \mathbf{Y}^{(t+1)} = \alpha \tilde{A} \mathbf{Y}^{(t)} + (1 - \alpha) \mathbf{Y}^{(0)}

    where unlabeled data is initially set to zero and inferred from labeled data via
    propagation. :math:`\alpha` is a weight parameter for balancing between updated labels
    and initial labels. :math:`\tilde{A}` denotes the normalized adjacency matrix.

    Parameters
    ----------
    k: int
        The number of propagation steps.
    alpha : float
        The :math:`\alpha` coefficient in range [0, 1].
    norm_type : str, optional
        The type of normalization applied to the adjacency matrix, must be one of the
        following choices:

        * ``row``: row-normalized adjacency as :math:`D^{-1}A`

        * ``sym``: symmetrically normalized adjacency as :math:`D^{-1/2}AD^{-1/2}`

        Default: 'sym'.
    clamp : bool, optional
        A bool flag to indicate whether to clamp the labels to [0, 1] after propagation.
        Default: True.
    normalize: bool, optional
        A bool flag to indicate whether to apply row-normalization after propagation.
        Default: False.
    reset : bool, optional
        A bool flag to indicate whether to reset the known labels after each
        propagation step. Default: False.

    Examples
    --------
    >>> import torch
    >>> import dgl
    >>> from dgl.nn import LabelPropagation

    >>> label_propagation = LabelPropagation(k=5, alpha=0.5, clamp=False, normalize=True)
    >>> g = dgl.rand_graph(5, 10)
    >>> labels = torch.tensor([0, 2, 1, 3, 0]).long()
    >>> mask = torch.tensor([0, 1, 1, 1, 0]).bool()
    >>> new_labels = label_propagation(g, labels, mask)
    """

    def __init__(
        self,
        k,
        alpha,
        norm_type="sym",
        clamp=True,
        normalize=False,
        reset=False,
    ):
        super(LabelPropagation, self).__init__()
        self.k = k
        self.alpha = alpha
        self.norm_type = norm_type
        self.clamp = clamp
        self.normalize = normalize
        self.reset = reset

    def forward(self, g, labels, mask=None):
        r"""Compute the label propagation process.

        Parameters
        ----------
        g : DGLGraph
            The input graph.
        labels : torch.Tensor
            The input node labels. There are three cases supported.

            * A LongTensor of shape :math:`(N, 1)` or :math:`(N,)` for node class labels in
              multiclass classification, where :math:`N` is the number of nodes.
            * A LongTensor of shape :math:`(N, C)` for one-hot encoding of node class labels
              in multiclass classification, where :math:`C` is the number of classes.
            * A LongTensor of shape :math:`(N, L)` for node labels in multilabel binary
              classification, where :math:`L` is the number of labels.
        mask : torch.Tensor
            The bool indicators of shape :math:`(N,)` with True denoting labeled nodes.
            Default: None, indicating all nodes are labeled.

        Returns
        -------
        torch.Tensor
            The propagated node labels of shape :math:`(N, D)` with float type, where :math:`D`
            is the number of classes or labels.
        """
        with g.local_scope():
            # multi-label / multi-class
            if len(labels.size()) > 1 and labels.size(1) > 1:
                labels = labels.to(th.float32)
            # single-label multi-class
            else:
                labels = F.one_hot(labels.view(-1)).to(th.float32)

            y = labels
            if mask is not None:
                y = th.zeros_like(labels)
                y[mask] = labels[mask]

            init = (1 - self.alpha) * y
            in_degs = g.in_degrees().float().clamp(min=1)
            out_degs = g.out_degrees().float().clamp(min=1)
            if self.norm_type == "sym":
                norm_i = th.pow(in_degs, -0.5).to(labels.device).unsqueeze(1)
                norm_j = th.pow(out_degs, -0.5).to(labels.device).unsqueeze(1)
            elif self.norm_type == "row":
                norm_i = th.pow(in_degs, -1.0).to(labels.device).unsqueeze(1)
            else:
                raise ValueError(
                    f"Expect norm_type to be 'sym' or 'row', got {self.norm_type}"
                )

            for _ in range(self.k):
                g.ndata["h"] = y * norm_j if self.norm_type == "sym" else y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = init + self.alpha * g.ndata["h"] * norm_i

                if self.clamp:
                    y = y.clamp_(0.0, 1.0)
                if self.normalize:
                    y = F.normalize(y, p=1)
                if self.reset:
                    y[mask] = labels[mask]

            return y


class LaplacianPosEnc(nn.Module):
    r"""Laplacian Positional Encoder (LPE), as introduced in
    `GraphGPS: General Powerful Scalable Graph Transformers
    <https://arxiv.org/abs/2205.12454>`__
    This module is a learned laplacian positional encoding module using Transformer or DeepSet.

    Parameters
    ----------
    model_type : str
        Encoder model type for LPE, can only be "Transformer" or "DeepSet".
    num_layer : int
        Number of layers in Transformer/DeepSet Encoder.
    k : int
        Number of smallest non-trivial eigenvectors.
    lpe_dim : int
        Output size of final laplacian encoding.
    n_head : int, optional
        Number of heads in Transformer Encoder.
        Default : 1.
    batch_norm : bool, optional
        If True, apply batch normalization on raw LaplacianPE.
        Default : False.
    num_post_layer : int, optional
        If num_post_layer > 0, apply an MLP of ``num_post_layer`` layers after pooling.
        Default : 0.

    Example
    -------
    >>> import dgl
    >>> from dgl import LaplacianPE
    >>> from dgl.nn import LaplacianPosEnc

    >>> transform = LaplacianPE(k=5, feat_name='eigvec', eigval_name='eigval', padding=True)
    >>> g = dgl.graph(([0,1,2,3,4,2,3,1,4,0], [2,3,1,4,0,0,1,2,3,4]))
    >>> g = transform(g)
    >>> EigVals, EigVecs = g.ndata['eigval'], g.ndata['eigvec']
    >>> TransformerLPE = LaplacianPosEnc(model_type="Transformer", num_layer=3, k=5,
                                         lpe_dim=16, n_head=4)
    >>> PosEnc = TransformerLPE(EigVals, EigVecs)
    >>> DeepSetLPE = LaplacianPosEnc(model_type="DeepSet", num_layer=3, k=5,
                                     lpe_dim=16, num_post_layer=2)
    >>> PosEnc = DeepSetLPE(EigVals, EigVecs)
    """

    def __init__(
        self,
        model_type,
        num_layer,
        k,
        lpe_dim,
        n_head=1,
        batch_norm=False,
        num_post_layer=0,
    ):
        super(LaplacianPosEnc, self).__init__()
        self.model_type = model_type
        self.linear = nn.Linear(2, lpe_dim)

        if self.model_type == "Transformer":
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=lpe_dim, nhead=n_head, batch_first=True
            )
            self.pe_encoder = nn.TransformerEncoder(
                encoder_layer, num_layers=num_layer
            )
        elif self.model_type == "DeepSet":
            layers = []
            if num_layer == 1:
                layers.append(nn.ReLU())
            else:
                self.linear = nn.Linear(2, 2 * lpe_dim)
                layers.append(nn.ReLU())
                for _ in range(num_layer - 2):
                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            self.pe_encoder = nn.Sequential(*layers)
        else:
            raise ValueError(
                f"model_type '{model_type}' is not allowed, must be 'Transformer'"
                "or 'DeepSet'."
            )

        if batch_norm:
            self.raw_norm = nn.BatchNorm1d(k)
        else:
            self.raw_norm = None

        if num_post_layer > 0:
            layers = []
            if num_post_layer == 1:
                layers.append(nn.Linear(lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            else:
                layers.append(nn.Linear(lpe_dim, 2 * lpe_dim))
                layers.append(nn.ReLU())
                for _ in range(num_post_layer - 2):
                    layers.append(nn.Linear(2 * lpe_dim, 2 * lpe_dim))
                    layers.append(nn.ReLU())
                layers.append(nn.Linear(2 * lpe_dim, lpe_dim))
                layers.append(nn.ReLU())
            self.post_mlp = nn.Sequential(*layers)
        else:
            self.post_mlp = None

    def forward(self, EigVals, EigVecs):
        r"""
        Parameters
        ----------
        EigVals : Tensor
            Laplacian Eigenvalues of shape :math:`(N, k)`, k different eigenvalues repeat N times,
            can be obtained by using `LaplacianPE`.
        EigVecs : Tensor
            Laplacian Eigenvectors of shape :math:`(N, k)`, can be obtained by using `LaplacianPE`.

        Returns
        -------
        Tensor
            Return the laplacian positional encodings of shape :math:`(N, lpe_dim)`,
            where :math:`N` is the number of nodes in the input graph.
        """
        PosEnc = th.cat(
            (EigVecs.unsqueeze(2), EigVals.unsqueeze(2)), dim=2
        ).float()
        empty_mask = th.isnan(PosEnc)

        PosEnc[empty_mask] = 0
        if self.raw_norm:
            PosEnc = self.raw_norm(PosEnc)
        PosEnc = self.linear(PosEnc)

        if self.model_type == "Transformer":
            PosEnc = self.pe_encoder(
                src=PosEnc, src_key_padding_mask=empty_mask[:, :, 1]
            )
        else:
            PosEnc = self.pe_encoder(PosEnc)

        # Remove masked sequences
        PosEnc[empty_mask[:, :, 1]] = 0

        # Sum pooling
        PosEnc = th.sum(PosEnc, 1, keepdim=False)

        # MLP post pooling
        if self.post_mlp:
            PosEnc = self.post_mlp(PosEnc)

        return PosEnc
