from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn import MessagePassing
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch.nn import Parameter

from greatx.nn.layers import Sequential, activations
from greatx.utils import wrapper


class GPRProp(MessagePassing):
    r"""GPR propagation with PPR initialization from
    `"Adaptive Universal Generalized PageRank Graph Neural Network"
    <https://arxiv.org/abs/2011.09643>`_ (Chien et al. ICLR'21)

    Computes :math:`\sum_{k=0}^{K} \gamma_k \hat{A}^k X`, where
    :math:`\hat{A}` is the ``gcn_norm``-normalized adjacency with
    self-loops and :math:`\gamma` are learnable coefficients initialized to
    :math:`\gamma_k = \alpha (1 - \alpha)^k` for :math:`k < K` and
    :math:`\gamma_K = (1 - \alpha)^K`, so that the initial coefficients
    sum to 1.

    Parameters
    ----------
    K : int
        Number of propagation steps.
    alpha : float
        Teleport probability for PPR initialization.
    norm : bool
        Whether to normalize the propagation coefficients before use
        (sign-preserving softmax over their absolute values).
    """

    def __init__(self, K: int, alpha: float = 0.1, norm: bool = False):
        super().__init__(aggr='add')
        self.K = K
        # Kept so that ``reset_parameters`` can rebuild the PPR initialization.
        self.alpha = alpha
        self.norm = norm
        self.temp = Parameter(self._ppr_coefficients(K, alpha))

    @staticmethod
    def _ppr_coefficients(K: int, alpha: float) -> Tensor:
        """Return the ``K + 1`` PPR initial coefficients as a float tensor."""
        temp = alpha * (1 - alpha) ** torch.arange(K + 1, dtype=torch.float)
        # The last coefficient absorbs the remaining (1 - alpha)^K mass so
        # that the coefficients sum to exactly 1.
        temp[-1] = (1 - alpha) ** K
        return temp

    def reset_parameters(self):
        """Reset the propagation coefficients to their PPR initialization."""
        # Let MessagePassing reset whatever it owns, if this PyG version
        # provides such a method; the GPR coefficients live on this class.
        parent_reset = getattr(super(), 'reset_parameters', None)
        if parent_reset is not None:
            parent_reset()
        with torch.no_grad():
            # copy_ preserves the parameter's identity, device and dtype,
            # so existing optimizer references stay valid.
            self.temp.copy_(self._ppr_coefficients(self.K, self.alpha))

    def normalize_coefficients(self) -> Tensor:
        """Return ``sign(temp) * softmax(|temp|)`` over the K + 1 coefficients."""
        temp = torch.sign(self.temp) * torch.softmax(self.temp.abs(), dim=0)
        return temp

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None) -> Tensor:
        """Propagate ``x`` for ``K`` steps and return the weighted sum.

        Parameters
        ----------
        x : Tensor
            Node features; ``x.size(0)`` is used as the number of nodes.
        edge_index : Tensor
            Graph connectivity passed to ``gcn_norm``.
        edge_weight : Optional[Tensor]
            Optional edge weights passed to ``gcn_norm``.

        Returns
        -------
        Tensor
            ``sum_k temp[k] * A_hat^k x`` with the same shape as ``x``.
        """
        # GCN normalization with self-loops, recomputed on every call.
        edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                           num_nodes=x.size(0),
                                           add_self_loops=True,
                                           dtype=x.dtype)

        temp = self.normalize_coefficients() if self.norm else self.temp

        out = x * temp[0]
        for k in range(1, self.K + 1):
            x = self.propagate(edge_index, x=x, norm=edge_weight)
            out = out + temp[k] * x
        return out

    def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
        """Scale each source-node message by its normalized edge weight."""
        return norm.view(-1, 1) * x_j


class GPRGNN(nn.Module):
    r"""GPR-GNN with PPR initialization from
    `"Adaptive Universal Generalized PageRank Graph Neural Network"
    <https://arxiv.org/abs/2011.09643>`_

    Node features are first transformed by an MLP and the result is then
    propagated with :class:`GPRProp`.

    Parameters
    ----------
    in_channels : int
        Input feature dimension.
    out_channels : int
        Output dimension (number of classes).
    hids : List[int]
        Hidden layer dimensions.
    acts : List[str]
        Activation functions for each hidden layer; must have the same
        length as ``hids``.
    K : int
        Number of propagation steps.
    alpha : float
        Teleport probability for PPR initialization.
    dropout : float
        Dropout probability applied after each hidden activation of the MLP.
    norm : bool
        Whether to normalize the propagation coefficients.

    Raises
    ------
    ValueError
        If ``hids`` and ``acts`` have different lengths.
    """

    @wrapper
    def __init__(self, in_channels: int, out_channels: int,
                 hids: List[int] = [64], acts: List[str] = ['relu'],
                 K: int = 10, alpha: float = 0.1,
                 dropout: float = 0.5, norm: bool = False):

        super().__init__()

        # Explicit check instead of ``assert`` so validation survives ``-O``.
        # The list defaults above are only read here, never mutated.
        if len(hids) != len(acts):
            raise ValueError(
                f"`hids` and `acts` must have the same length, "
                f"got {len(hids)} and {len(acts)}.")

        layers = []
        for hid, act in zip(hids, acts):
            layers.append(nn.Linear(in_channels, hid))
            layers.append(activations.get(act))
            layers.append(nn.Dropout(dropout))
            in_channels = hid
        # Final projection to the output dimension (no activation/dropout).
        layers.append(nn.Linear(in_channels, out_channels))
        self.mlp = Sequential(*layers)

        self.prop = GPRProp(K=K, alpha=alpha, norm=norm)

    def reset_parameters(self):
        """Reset the MLP and propagation parameters."""
        self.mlp.reset_parameters()
        self.prop.reset_parameters()

    def forward(self, x: Tensor, edge_index: Tensor,
                edge_weight: Optional[Tensor] = None) -> Tensor:
        """Apply the MLP to ``x``, then GPR propagation over the graph.

        Returns the propagated MLP output, whose last dimension is
        ``out_channels``.
        """
        x = self.mlp(x)
        return self.prop(x, edge_index, edge_weight)
