from typing import Optional

import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.transforms import BaseTransform

from greatx.functional import spmm


class FeaturePropagation(BaseTransform):
    r"""Implementation of FeaturePropagation
    from the `"On the Unreasonable Effectiveness
    of Feature propagation in Learning
    on Graphs with Missing Node Features"
    <https://openreview.net/forum?id=qe_qOarxjg>`_
    paper (Log'22)

    Parameters
    ----------
    num_iterations : int, optional
        number of iterations to run, by default 40
    missing_mask : Optional[Tensor], optional
        mask on missing features, by default None
    normalize : bool, optional
        whether to compute symmetric normalization
        coefficients on the fly, by default True
    add_self_loops : bool, optional
        whether to add self-loops to the input graph,
        by default True

    Example
    -------
    .. code-block:: python

        data = ... # PyG-like data
        data = FeaturePropagation(num_iterations=40)(data)

        # missing_mask is a mask `[num_nodes, num_features]`
        # indicating where the feature is missing
        data = FeaturePropagation(missing_mask=missing_mask)(data)

    See also
    --------
    :class:`torch_geometric.transforms.FeaturePropagation`


    Reference:

    https://github.com/twitter-research/feature-propagation

    """
    def __init__(self, missing_mask: Optional[Tensor] = None,
                 num_iterations: int = 40, normalize: bool = True):
        super().__init__()
        self.missing_mask = missing_mask
        self.num_iterations = num_iterations
        self.normalize = normalize

    def __call__(self, data: Data) -> Data:

        # out is inizialized to 0 for missing values.
        # However, its initialization does not
        # matter for the final value at convergence
        x = data.x
        known_feature_mask = None
        missing_mask = self.missing_mask
        if missing_mask is not None:
            out = torch.zeros_like(x)
            known_feature_mask = ~missing_mask
            out[known_feature_mask] = x[known_feature_mask]
        else:
            out = x.clone()

        edge_index, edge_weight = data.edge_index, data.edge_weight

        if self.normalize:
            edge_index, edge_weight = gcn_norm(edge_index, edge_weight,
                                               x.size(0), improved=False,
                                               add_self_loops=False,
                                               dtype=x.dtype)
        for _ in range(self.num_iterations):
            # Diffuse current features
            out = spmm(out, edge_index, edge_weight)
            if known_feature_mask is not None:
                # Reset original known features
                out[known_feature_mask] = x[known_feature_mask]
        data.x = out

        return data
