from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn

from torch_geometric.nn.aggr import Aggregation, MultiAggregation
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.typing import Adj, OptPairTensor, Size, SparseTensor, OptTensor
from torch_geometric.utils import spmm, degree

from graphgps.layer.utils import dag_degree


class DAGSAGEConv(MessagePassing):
    r"""A GraphSAGE-style convolution that uses a separate neighbor weight for
    each of three edge groups returned by :func:`dag_degree`.

    The layer is derived from the GraphSAGE operator of the `"Inductive
    Representation Learning on Large Graphs"
    <https://arxiv.org/abs/1706.02216>`_ paper. Instead of a single neighbor
    transformation, one linear map produces three stacked projections
    :math:`\mathbf{W}_0, \mathbf{W}_1, \mathbf{W}_2`. The edges are split into
    three groups :math:`\mathcal{E}_0, \mathcal{E}_1, \mathcal{E}_2` by
    ``dag_degree(edge_index, 3, num_nodes)``, and every edge of group
    :math:`k` carries the message :math:`\mathbf{W}_k \mathbf{x}_j` of its
    source node :math:`j`. All messages arriving at a node are mean-aggregated:

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_r \mathbf{x}_i +
        \mathrm{mean}_{k,\, (j \to i) \in \mathcal{E}_k}
        \mathbf{W}_k \mathbf{x}_j

    NOTE(review): the criterion by which ``dag_degree`` groups the edges is
    defined in ``graphgps.layer.utils`` and is not visible here -- confirm
    there.

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple is
            read as ``(neighbor_size, root_size)``; since :meth:`forward`
            feeds the same tensor to both linear maps, both entries have to
            match the feature size of :obj:`x`.
        out_channels (int): Size of each output sample.
        share_init (bool, optional): If set to :obj:`True`, the three
            neighbor projections start from identical weights (the first
            projection is copied into the other two at initialization).
            They are still separate parameters afterwards.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the root
            transformation will not learn an additive bias. The neighbor
            transformation never has a bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.

    Shapes:
        - **inputs:**
          node features :math:`(|\mathcal{V}|, F_{in})`,
          edge indices :math:`(2, |\mathcal{E}|)`
        - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})`
    """

    def __init__(
            self,
            in_channels: Union[int, Tuple[int, int]],
            out_channels: int,
            share_init: bool = False,
            root_weight: bool = True,
            bias: bool = True,
            **kwargs,
    ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.root_weight = root_weight
        self.share_init = share_init

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        # The aggregation is fixed to the mean over incoming messages.
        super().__init__('mean', **kwargs)

        aggr_out_channels = in_channels[0]

        # One linear map yields all three neighbor projections, stacked along
        # the output dimension; `forward` slices them apart again.
        self.lin_l = Linear(aggr_out_channels, out_channels * 3, bias=False)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=bias)

        self.reset_parameters()

    def _init_bound(self, fan_in: int) -> float:
        # Glorot-style uniform bound scaled by the ReLU gain. `fan_out` is
        # deliberately `out_channels` (one projection), not the stacked
        # `3 * out_channels` rows of `lin_l`.
        gain = nn.init.calculate_gain('relu')
        return float(gain * np.sqrt(6 / (fan_in + self.out_channels)))

    def reset_parameters(self) -> None:
        r"""Re-initializes the weights of the neighbor and root maps.

        Both weight matrices are drawn from :math:`\mathcal{U}(-a, a)` with
        :math:`a = \mathrm{gain}_{relu} \sqrt{6 / (F_{in} + F_{out})}`. With
        :obj:`share_init`, the first neighbor projection is then copied into
        the second and third one.
        """
        super().reset_parameters()

        # `in_channels` may be an int or a (neighbor, root) pair.
        if isinstance(self.in_channels, int):
            fan_in_l = fan_in_r = self.in_channels
        else:
            fan_in_l, fan_in_r = self.in_channels[0], self.in_channels[1]

        self.lin_l.reset_parameters()
        bound_l = self._init_bound(fan_in_l)
        torch.nn.init.uniform_(self.lin_l.weight.data, -bound_l, bound_l)
        if self.share_init:
            # Copy in place so that `lin_l.weight` stays the same Parameter
            # object (references held elsewhere, e.g. by an optimizer, remain
            # valid). `repeat` materializes a new tensor first, so source and
            # destination do not alias.
            with torch.no_grad():
                first = self.lin_l.weight[:self.out_channels]
                self.lin_l.weight.copy_(first.repeat(3, 1))

        # `lin_r` only exists when the root transformation is enabled.
        if self.root_weight:
            self.lin_r.reset_parameters()
            bound_r = self._init_bound(fan_in_r)
            torch.nn.init.uniform_(self.lin_r.weight.data, -bound_r, bound_r)

    def forward(
            self,
            x: Union[Tensor, OptPairTensor],
            edge_index: Adj,
            ordering: Optional[Tensor] = None,
            edge_weight: OptTensor = None,
            size: Size = None,
    ) -> Tensor:
        r"""Runs the layer.

        Args:
            x (Tensor): Node features of shape :obj:`(num_nodes, F_in)`.
                Only a single tensor is supported.
            edge_index: Edge indices handed to ``dag_degree``.
            ordering (Tensor, optional): Accepted for interface
                compatibility; not used by this layer.
            edge_weight (Tensor, optional): Accepted for interface
                compatibility; not used by this layer.
            size (tuple, optional): The :obj:`(num_nodes, num_nodes)` size of
                the graph. Derived from :obj:`x` if :obj:`None`.

        Returns:
            Tensor: Node features of shape :obj:`(num_nodes, out_channels)`.

        Raises:
            TypeError: If :obj:`x` is not a single tensor.
        """
        if not isinstance(x, Tensor):
            raise TypeError(
                f"'{self.__class__.__name__}' expects a single node feature "
                f"tensor, but got '{type(x).__name__}'")

        if size is None:
            size = [x.size(0), x.size(0)]

        # (num_nodes, 3 * out_channels): the three projections side by side.
        x_l = self.lin_l(x)
        # Three edge-index tensors; the second return value is not needed.
        edge_indices, _ = dag_degree(edge_index, 3, size[0])

        edge_index_cat = torch.cat(edge_indices, dim=1)
        # NOTE(review): PyG may parse the marker comment below; it is kept
        # verbatim on purpose.
        # propagate_type: (x: OptPairTensor)
        c = self.out_channels
        # Per-edge messages, in the same order as `edge_index_cat`: edges of
        # group k read the k-th projection of their source node (row 0).
        x_j = torch.cat((
            x_l[:, :c][edge_indices[0][0]],
            x_l[:, c:2 * c][edge_indices[1][0]],
            x_l[:, 2 * c:][edge_indices[2][0]],
        ), dim=0)
        out = self.propagate(edge_index_cat, x_j=x_j, size=size)

        if self.root_weight:
            out = out + self.lin_r(x)

        return out

    def _collect(self, args, edge_index, size, kwargs):
        # The messages are already materialized per edge in `forward`, so the
        # default gathering of node features is bypassed and `x_j` is passed
        # through unchanged. Aggregation targets are row 1 of `edge_index`.
        # NOTE(review): all size entries use `size[0]`, i.e. a square
        # (num_nodes, num_nodes) size is assumed, as produced by `forward`.
        out = {'x_j': kwargs.get('x_j'),
               'edge_index': edge_index,
               'size': size,
               'size_i': size[0],
               'size_j': size[0],
               'dim_size': size[0],
               'edge_weight': kwargs.get('edge_weight'),
               'index': edge_index[1]}
        return out

    def message(self, x_j: Tensor) -> Tensor:
        # Messages are precomputed in `forward`; nothing left to do per edge.
        return x_j

    def message_and_aggregate(self, adj_t: Adj, x: Tensor) -> Tensor:
        # Fused sparse path: drop stored edge values so that the reduction
        # runs over the plain adjacency structure.
        if isinstance(adj_t, SparseTensor):
            adj_t = adj_t.set_value(None, layout=None)
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}({self.in_channels}, '
                f'{self.out_channels}, aggr={self.aggr})')
