import torch
from jaxtyping import Float
from torch import Tensor, nn
from torch_geometric.data import Batch

from src.utils.adjacency_utils import get_adj, get_adj_row_indices


class RandomWalkPE(nn.Module):
    """Random Walk Positional Encoding.

    For every node the encoding is its row in each of the matrix powers
    ``adj^1, ..., adj^k`` of the (padded, batched) adjacency matrix,
    concatenated along the feature axis.
    """

    def __init__(self, k: int, n: int) -> None:
        """Initialize Random Walk Positional Encoding.

        :param k: Number of random walk steps.
        :param n: Maximal number of nodes in one graph.
        """
        super().__init__()

        self.k = k
        self.n = n

    @property
    def d(self) -> int:
        """Return the dimension of random walk positional encoding.

        :return: Dimension of random walk positional encoding, i.e. ``k * n``.
        """
        return int(self.k * self.n)

    def forward(self, data: Batch) -> Float[Tensor, "n_nodes n*k"]:
        """Compute random walk positional encoding.

        :param data: PyG batch object.
        :return: The random walk positional encoding for each node.
        """
        # NOTE(review): assumes ``get_adj`` returns a dense [b, n, n] tensor
        # padded to ``self.n`` nodes per graph -- confirm against the helper.
        adj = get_adj(data, self.n)
        # NOTE(review): presumably one (graph index, node-in-graph index) pair
        # per real node, used to drop the padded rows -- confirm against the helper.
        batch, node = get_adj_row_indices(data)
        rw_probabilities = self.k_step_random_walk(adj, self.k)
        return rw_probabilities[batch, node, :]

    @staticmethod
    def k_step_random_walk(adj: Tensor, k: int) -> Float[Tensor, "b n n*k"]:
        """Returns concatenated random walk matrices for k in {1, ..., k}.

        The i-th block of the last axis holds ``adj`` raised to the i-th matrix
        power. NOTE(review): these are transition probabilities only if ``adj``
        is row-normalised; otherwise they are walk counts -- confirm upstream.

        :param adj: Batched adjacency matrix of shape [b, n, n].
        :param k: Number of random walk steps; must be non-negative.
        :return: Tensor of shape [b, n, n*k] containing the random walk probabilities.
        :raises ValueError: If ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if k == 0:
            # Zero steps means zero feature columns, which keeps the output
            # width equal to ``n * k`` (and to ``RandomWalkPE.d``).
            return adj.new_empty((*adj.shape[:-1], 0))

        rw_probabilities = adj
        output = [rw_probabilities]
        for _ in range(k - 1):
            # Extend every walk by one more step: adj^i @ adj = adj^(i + 1).
            rw_probabilities = torch.bmm(rw_probabilities, adj)
            output.append(rw_probabilities)
        return torch.cat(output, dim=-1)
