import numpy as np
import scipy.linalg as spl
import scipy.sparse as sp
import torch
import torch_geometric as pyg


class LP(torch.nn.Module):
    """Label propagation with a closed-form propagation matrix.

    :meth:`fit` precomputes the dense matrix
    ``(1 - alpha) (I - alpha D^{-sigma} A D^{sigma-1})^{-1}`` for a graph and
    stores it as ``self.pi``; :meth:`forward` propagates one-hot encoded known
    labels through it.

    Parameters
    ----------
    num_classes : int
        Number of classes, i.e. the width of the one-hot label encoding.
    alpha : float
        ``1 - alpha`` is the teleport probability.
    sigma : float
        Degree-normalisation exponent; ``sigma=1`` gives the PPR matrix.
    """

    def __init__(self, num_classes, alpha, sigma):
        super().__init__()
        self.alpha = alpha
        self.sigma = sigma
        self.num_classes = num_classes

    def _propagation_matrix(self, adj, alpha, sigma, nodes=None):
        """
        Computes the propagation matrix  (1-alpha)(I - alpha D^{-sigma} A D^{sigma-1})^{-1}.

        Nodes with zero degree get a normalisation factor of 0 instead of inf.

        Parameters
        ----------
        adj : sp.spmatrix, shape [n, n]
            Sparse adjacency matrix (integer or floating-point weights).
        alpha : float
            (1-alpha) is the teleport probability.
        sigma
            Hyper-parameter controlling the propagation style.
            Set sigma=1 to obtain the PPR matrix.
        nodes : np.ndarray, shape [?]
            Nodes for which we want to compute Personalized PageRank.
        Returns
        -------
        prop_matrix : np.ndarray, shape [n, n] (or [len(nodes), n] if nodes is given)
            Propagation matrix; with ``nodes`` only the corresponding rows.

        """
        n = adj.shape[0]
        # np.asarray(...).ravel() handles both scipy sparse matrices (whose
        # sum returns an np.matrix) and sparse arrays (plain ndarray). Casting
        # to float avoids numpy's ValueError for integer bases raised to
        # negative integer powers (e.g. an integer adjacency with sigma=1).
        deg = np.asarray(adj.sum(1), dtype=np.float64).ravel()

        # A zero-degree node would get 0**(negative) = inf. Use 0 instead (the
        # usual degree-normalisation convention): otherwise a node without
        # outgoing but with incoming edges (directed graphs) puts inf into the
        # matrix, and isolated nodes trigger divide-by-zero warnings.
        with np.errstate(divide="ignore"):
            deg_min_sig = np.power(deg, -sigma)
            deg_sig_min = np.power(deg, sigma - 1)
        deg_min_sig[np.isinf(deg_min_sig)] = 0.0
        deg_sig_min[np.isinf(deg_sig_min)] = 0.0

        normalized_adj = sp.diags(deg_min_sig) @ adj @ sp.diags(deg_sig_min)
        pre_inv = sp.eye(n) - alpha * normalized_adj

        # Rows `nodes` of pre_inv^{-1} equal (pre_inv^{-T} e_j)^T, so solve
        # pre_inv.T @ x = b for the requested unit vectors and transpose back.
        b = np.eye(n)
        if nodes is not None:
            b = b[:, nodes]

        return (1 - alpha) * spl.solve(pre_inv.toarray().T, b).T

    def fit(self, edge_index, num_nodes=None, **kwargs):
        """Precompute and store the propagation matrix ``self.pi`` (float32).

        Parameters
        ----------
        edge_index : torch.Tensor, shape [2, num_edges]
            Graph connectivity in PyG COO format.
        num_nodes : int, optional
            Total number of nodes. If omitted, PyG infers it from
            ``edge_index`` (max index + 1), so trailing isolated nodes would be
            dropped and ``self.pi`` would not match the labels in ``forward``.
        **kwargs
            Unused; presumably accepted for a shared ``fit`` interface.
        """
        adj = pyg.utils.to_scipy_sparse_matrix(edge_index, num_nodes=num_nodes)
        self.pi = torch.from_numpy(
            self._propagation_matrix(adj, alpha=self.alpha, sigma=self.sigma)
        ).to(torch.float32)

    def forward(self, y, mask):
        """Propagate the labels of the masked nodes through ``self.pi``.

        Parameters
        ----------
        y : torch.Tensor, shape [n]
            Integer class labels; only entries selected by ``mask`` are read.
        mask : torch.Tensor
            Boolean mask or index tensor selecting the nodes with known labels.
            NOTE(review): ``y`` and ``mask`` are assumed to live on the same
            device as ``self.pi``.

        Returns
        -------
        torch.Tensor, shape [n, num_classes]
            ``self.pi @ one_hot(y)``, where unlabelled rows are all-zero.
        """
        # Build the label matrix with pi's dtype/device so the matmul does not
        # fail on a mismatch (e.g. float64 default dtype or pi moved to GPU).
        dtype, device = self.pi.dtype, self.pi.device
        one_hot_map = torch.eye(self.num_classes, dtype=dtype, device=device)
        one_hot_labels = torch.zeros(
            (y.shape[0], self.num_classes), dtype=dtype, device=device
        )
        one_hot_labels[mask] = one_hot_map[y[mask]]
        return self.pi @ one_hot_labels
