"""
Random walker module for simplicial complexes.

Notes
-----
All functions and classes in this module are JIT compiled using TorchScript for
additional performance benefits.
"""
import torch

from scrawl.simplicial import SimplicialComplex, SimplicialData


@torch.jit.script
def structural_features(
    window_size: int,
    walk_simplices: torch.Tensor,
    lower_adjacency: torch.Tensor,
    upper_adjacency: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute the structural features (identity and connectivity) for a given walk.

    For every position ``i`` of the walk, the features compare the simplex at
    position ``i`` with its predecessors inside a sliding window. Column ``c``
    of each returned matrix refers to the simplex ``window_size - c`` steps
    before position ``i``; i.e., the farthest predecessor comes first.
    Predecessors that lie before the start of the walk never match (identity)
    and are reported as not connected (connectivity).

    The adjacency matrices are only read; the returned tensors are new.

    Parameters
    ----------
    window_size : int
        The local window size (i.e., matrix width) for the structural features.
    walk_simplices : torch.Tensor
        Indices of the simplices in the walk (one-dimensional).
    lower_adjacency : torch.Tensor
        Lower adjacency matrix of the simplicial complex.
    upper_adjacency : torch.Tensor
        Upper adjacency matrix of the simplicial complex.

    Returns
    -------
    identity_features : torch.Tensor
        Identity features of the walk, of shape
        ``(len(walk_simplices), window_size)``.
    lower_connectivity_features : torch.Tensor
        Lower connectivity features of the walk, of shape
        ``(len(walk_simplices), window_size - 1)``. The direct predecessor is
        not included.
    upper_connectivity_features : torch.Tensor
        Upper connectivity features of the walk, of shape
        ``(len(walk_simplices), window_size - 1)``. The direct predecessor is
        not included.
    """
    device = walk_simplices.device
    num_positions = walk_simplices.size(0)

    # Prepend `window_size` sentinel entries (-1) so that every position has a
    # full window of predecessors. The dtype is pinned to int64 and the helper
    # tensors are created on the walk's device so the result is always a valid
    # index tensor, independent of default dtype inference or default device.
    padding = torch.full((window_size,), -1, dtype=torch.int64, device=device)
    padded_walk = torch.hstack((padding, walk_simplices))

    # Row i holds the indices i, ..., i + window_size into `padded_walk`; the
    # last column points at the walk position i itself.
    window_indices = torch.arange(num_positions, device=device)[
        :, None
    ] + torch.arange(window_size + 1, device=device)

    current = walk_simplices[:, None]

    # The sentinel -1 can never equal a valid simplex index, so padded entries
    # are automatically False here.
    identity_features = torch.eq(padded_walk[window_indices[:, :-1]], current)

    # Connectivity skips the direct predecessor (hence `:-2`).
    predecessor_indices = window_indices[:, :-2]
    predecessors = padded_walk[predecessor_indices]

    # Entries that fall into the padding would index the adjacency matrices
    # with -1 (i.e., the last row), so they are explicitly reset to zero.
    in_padding = predecessor_indices < window_size

    lower_connectivity_features = lower_adjacency[predecessors, current]
    lower_connectivity_features[in_padding] = 0

    upper_connectivity_features = upper_adjacency[predecessors, current]
    upper_connectivity_features[in_padding] = 0

    return identity_features, lower_connectivity_features, upper_connectivity_features


@torch.jit.script
def feature_matrix(
    rank: int,
    walk_simplices: torch.Tensor,
    connection_simplices: torch.Tensor,
    connection_types: torch.Tensor,
    data: SimplicialData,
    identity_features: torch.Tensor,
    lower_connectivity_features: torch.Tensor,
    upper_connectivity_features: torch.Tensor,
    local_window_size: int,
    lower_feature_size: int,
    upper_feature_size: int,
) -> torch.Tensor:
    """
    Compute the feature matrix for a given walk.

    Each row describes one position of the walk. The columns are laid out as
    follows (in this order):

    1. features of the walk simplex (``data[rank]``),
    2. features of the lower simplex used to reach this position
       (``lower_feature_size`` columns, zero if the step was not a lower one),
    3. features of the upper simplex used to reach this position
       (``upper_feature_size`` columns, zero if the step was not an upper one),
    4. identity features (``local_window_size`` columns),
    5. lower connectivity features (``local_window_size - 1`` columns),
    6. upper connectivity features (``local_window_size - 1`` columns).

    Parameters
    ----------
    rank : int
        The rank of the simplices in the random walk.
    walk_simplices : torch.Tensor
        Indices of the simplices in the walk.
    connection_simplices : torch.Tensor
        Indices of the connection used in each walk step.
    connection_types : torch.Tensor
        Direction of the connection for each step in the walk: ``-1`` marks a
        lower connection and ``1`` an upper connection.
    data : SimplicialData
        Simplicial complex and associated data.
    identity_features : torch.Tensor
        Identity features of the walk.
    lower_connectivity_features : torch.Tensor
        Lower connectivity features of the walk.
    upper_connectivity_features : torch.Tensor
        Upper connectivity features of the walk.
    local_window_size : int
        The local window size for the structural features. It must not exceed
        the window size the structural features were computed for.
    lower_feature_size : int
        Size of the features on lower simplices.
    upper_feature_size : int
        Size of the features on upper simplices.

    Returns
    -------
    torch.Tensor
        The feature matrix describing the given walk, with the dtype and
        device of `data`.

    Raises
    ------
    ValueError
        If `local_window_size` is larger than the window size of the given
        structural features.

    See Also
    --------
    structural_features
        Function to compute values for `identity_features`,
        `lower_connectivity_features`, and `upper_connectivity_features`.
    """
    if local_window_size > identity_features.size(1):
        raise ValueError(
            "local_window_size must not exceed the window size of the structural features."
        )

    # Column offsets of the individual feature groups.
    lower_features_start = data[rank].size(1)
    upper_features_start = lower_features_start + lower_feature_size
    identity_features_start = upper_features_start + upper_feature_size
    lower_connectivity_start = identity_features_start + local_window_size
    upper_connectivity_start = lower_connectivity_start + local_window_size - 1
    num_columns = upper_connectivity_start + local_window_size - 1

    features = torch.zeros(
        (
            len(walk_simplices),
            num_columns,
        ),
        dtype=data.dtype,
        device=data.device,
    )

    features[:, 0:lower_features_start] = data[rank][walk_simplices]

    # Rows whose step did not use a lower (upper) connection keep zeros in the
    # corresponding column group.
    if rank > 0:
        via_lower = connection_types == -1
        features[via_lower, lower_features_start:upper_features_start] = data[
            rank - 1
        ][connection_simplices[via_lower]]

    if rank < data.domain.dim:
        via_upper = connection_types == 1
        features[via_upper, upper_features_start:identity_features_start] = data[
            rank + 1
        ][connection_simplices[via_upper]]

    # The structural features are precomputed for the maximal window size and
    # `structural_features` orders their columns from the farthest predecessor
    # to the nearest one. A smaller local window therefore corresponds to the
    # trailing columns, not the leading ones.
    connectivity_width = local_window_size - 1
    identity_offset = identity_features.size(1) - local_window_size
    lower_offset = lower_connectivity_features.size(1) - connectivity_width
    upper_offset = upper_connectivity_features.size(1) - connectivity_width

    features[:, identity_features_start:lower_connectivity_start] = identity_features[
        :, identity_offset:
    ]
    features[
        :, lower_connectivity_start:upper_connectivity_start
    ] = lower_connectivity_features[:, lower_offset:]
    features[:, upper_connectivity_start:] = upper_connectivity_features[
        :, upper_offset:
    ]

    return features


@torch.jit.script
class RandomWalk:
    """
    A single random walk on simplices of one fixed rank.

    The walk moves between simplices of rank `k`; every step crosses either a
    lower connection (a simplex of rank `k-1`) or an upper connection (a
    simplex of rank `k+1`). The identity and connectivity features of the walk
    are computed once at construction time for the given window size.

    Parameters
    ----------
    simplicial_complex : SimplicialComplex
        The simplicial complex on which the random walk is performed.
    rank : int
        The rank of the simplices in the random walk.
    walk_simplices : torch.Tensor
        Indices of the simplices in the walk.
    connection_direction : torch.Tensor
        Direction of the connection (i.e., upper or lower) for each step in the
        walk.
    connection_simplices : torch.Tensor
        Indices of the simplices connected to the walk simplices.
    window_size : int
        The local window size (i.e., matrix width) for the structural features.
    """

    simplicial_complex: SimplicialComplex
    rank: int
    window_size: int

    __walk_simplices: torch.Tensor
    __connection_simplices: torch.Tensor
    __connection_types: torch.Tensor

    _identity_features: torch.Tensor
    _lower_connectivity_features: torch.Tensor
    _upper_connectivity_features: torch.Tensor

    def __init__(
        self,
        simplicial_complex: SimplicialComplex,
        rank: int,
        walk_simplices: torch.Tensor,
        connection_direction: torch.Tensor,
        connection_simplices: torch.Tensor,
        window_size: int,
    ) -> None:
        """
        Store the walk and precompute its structural features.

        Parameters
        ----------
        simplicial_complex : SimplicialComplex
            The simplicial complex on which the random walk is performed.
        rank : int
            The rank of the simplices in the random walk.
        walk_simplices : torch.Tensor
            Indices of the simplices in the walk.
        connection_direction : torch.Tensor
            Direction of the connection (i.e., upper or lower) for each step in the
            walk.
        connection_simplices : torch.Tensor
            Indices of the simplices connected to the walk simplices.
        window_size : int
            The local window size (i.e., matrix width) for the structural features.
        """
        self.simplicial_complex = simplicial_complex
        self.rank = rank
        self.window_size = window_size

        self.__walk_simplices = walk_simplices
        self.__connection_simplices = connection_simplices
        self.__connection_types = connection_direction

        # The structural features are derived once, for the maximal window
        # size, so that they do not have to be recomputed for every layer of
        # the neural network.
        identity, lower_connectivity, upper_connectivity = structural_features(
            window_size,
            walk_simplices,
            simplicial_complex.lower_adjacency[rank],
            simplicial_complex.upper_adjacency[rank],
        )
        self._identity_features = identity
        self._lower_connectivity_features = lower_connectivity
        self._upper_connectivity_features = upper_connectivity

    @staticmethod
    def walk_feature_size(
        window_size: int,
        feature_size: int,
        lower_feature_size: int,
        upper_feature_size: int,
    ) -> int:
        """
        Number of columns of a walk feature matrix.

        Parameters
        ----------
        window_size : int
            The local window size for the structural features.
        feature_size : int
            The size of the features on the simplices in the walk.
        lower_feature_size : int
            The size of the features on the lower simplices.
        upper_feature_size : int
            The size of the features on the upper simplices.

        Returns
        -------
        int
            The size of the walk feature matrix.
        """
        # One identity block of width `window_size` plus two connectivity
        # blocks of width `window_size - 1` each.
        structural_size = 3 * window_size - 2
        return feature_size + lower_feature_size + upper_feature_size + structural_size

    def __len__(self) -> int:
        """
        Number of simplices stored in this random walk.

        Returns
        -------
        int
            The length of this random walk.
        """
        return len(self.__walk_simplices)

    def feature_matrix(
        self,
        data: SimplicialData,
        local_window_size: int,
        lower_feature_size: int,
        upper_feature_size: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the feature matrix for this random walk.

        The returned tensors live on the same device as the tensors in `data`.
        As a side effect, the cached structural features of this walk are
        moved to that device as well.

        Parameters
        ----------
        data : SimplicialData
            Simplicial complex and associated data for this random walk.
        local_window_size : int
            The local window size for the structural features.
        lower_feature_size : int
            The size of the features on the lower simplices.
        upper_feature_size : int
            The size of the features on the upper simplices.

        Returns
        -------
        indices : torch.Tensor
            Indices of the simplices in the walk.
        feature_matrix : torch.Tensor
            The feature matrix describing this random walk.
        """
        target_device = data.device

        # Keep the moved tensors so repeated calls on the same device do not
        # have to transfer them again.
        self._identity_features = self._identity_features.to(target_device)
        self._lower_connectivity_features = self._lower_connectivity_features.to(
            target_device
        )
        self._upper_connectivity_features = self._upper_connectivity_features.to(
            target_device
        )

        walk_indices = self.__walk_simplices.to(target_device)
        walk_features = feature_matrix(
            self.rank,
            self.__walk_simplices,
            self.__connection_simplices,
            self.__connection_types,
            data,
            self._identity_features,
            self._lower_connectivity_features,
            self._upper_connectivity_features,
            local_window_size,
            lower_feature_size,
            upper_feature_size,
        )

        return walk_indices, walk_features


@torch.jit.script
class Walker:
    """
    Random walker for simplicial complexes.

    For every rank up to `max_rank`, the walker precomputes flat lookup tables
    from the boundary matrices: for each simplex the list of its connections
    (lower and/or upper simplices), and for each connection the list of
    simplices it touches. A walk step samples a connection of the current
    simplex and then a simplex attached to that connection.

    Parameters
    ----------
    simplicial_complex : SimplicialComplex
        The simplicial complex on which the random walk is performed.
    use_lower_connections : bool
        Whether to use lower connections in the random walk.
    use_upper_connections : bool
        Whether to use upper connections in the random walk.
    max_rank : int
        The maximum rank of the simplices in the random walk.

    Raises
    ------
    ValueError
        If `use_lower_connections` and `use_upper_connections` are both `False`.
    """

    __num_connections: dict[int, torch.Tensor]
    __connection_offsets: dict[int, torch.Tensor]
    __connection_indices: dict[int, torch.Tensor]
    __num_neighbors: dict[int, torch.Tensor]
    __neighbor_offsets: dict[int, torch.Tensor]
    __neighbor_indices: dict[int, torch.Tensor]

    simplicial_complex: SimplicialComplex
    use_lower_connections: bool
    use_upper_connections: bool

    def __init__(
        self,
        simplicial_complex: SimplicialComplex,
        use_lower_connections: bool,
        use_upper_connections: bool,
        max_rank: int,
    ) -> None:
        """
        Random walker for simplicial complexes.

        Parameters
        ----------
        simplicial_complex : SimplicialComplex
            The simplicial complex on which the random walk is performed.
        use_lower_connections : bool
            Whether to use lower connections in the random walk.
        use_upper_connections : bool
            Whether to use upper connections in the random walk.
        max_rank : int
            The maximum rank of the simplices in the random walk.

        Raises
        ------
        ValueError
            If `use_lower_connections` and `use_upper_connections` are both `False`.
        """
        self.simplicial_complex = simplicial_complex
        self.use_lower_connections = use_lower_connections
        self.use_upper_connections = use_upper_connections

        if not use_lower_connections and not use_upper_connections:
            raise ValueError(
                "At least one of use_lower_connections and use_upper_connections must be True."
            )

        self.__num_connections: dict[int, torch.Tensor] = {}
        self.__connection_offsets: dict[int, torch.Tensor] = {}
        self.__connection_indices: dict[int, torch.Tensor] = {}

        self.__num_neighbors: dict[int, torch.Tensor] = {}
        self.__neighbor_offsets: dict[int, torch.Tensor] = {}
        self.__neighbor_indices: dict[int, torch.Tensor] = {}

        for rank in range(max_rank + 1):
            # Rows of the sampling matrix are simplices of the given rank and
            # columns are the connections a walk may cross. When both
            # directions are used, the lower connections come first, followed
            # by the upper connections.
            if not self.use_lower_connections or rank == 0:
                # Only upper connections (rank 0 has no lower connections).
                sampling_matrix = simplicial_complex.boundary[rank + 1]
            elif not self.use_upper_connections or rank == simplicial_complex.dim:
                # Only lower connections.
                sampling_matrix = simplicial_complex.boundary[rank].T
            else:
                sampling_matrix = torch.hstack(
                    (
                        simplicial_complex.boundary[rank].T,
                        simplicial_complex.boundary[rank + 1],
                    )
                )

            connections = [m.nonzero().squeeze(1) for m in sampling_matrix]
            neighbors = [m.nonzero().squeeze(1) for m in sampling_matrix.T]

            connection_counts = [len(m) for m in connections]
            neighbor_counts = [len(m) for m in neighbors]

            # The per-simplex lists are stored as one flat index tensor plus
            # offsets (cumulative counts), so that entry j of simplex s is
            # found at `indices[offsets[s] + j]`.
            self.__num_connections[rank] = torch.tensor(
                connection_counts, dtype=torch.int
            )
            self.__connection_offsets[rank] = torch.tensor(
                [0] + connection_counts
            ).cumsum(dim=0)
            self.__connection_indices[rank] = torch.hstack(connections)

            self.__num_neighbors[rank] = torch.tensor(neighbor_counts, dtype=torch.int)
            self.__neighbor_offsets[rank] = torch.tensor([0] + neighbor_counts).cumsum(
                dim=0
            )
            self.__neighbor_indices[rank] = torch.hstack(neighbors)

    def random_walk(self, rank: int, start_index: int, steps: int) -> RandomWalk:
        """
        Sample a random walk from the given start simplex.

        Parameters
        ----------
        rank : int
            The rank of the simplices on which to perform the random walk.
        start_index : int
            Index of the simplex from which to start the random walk.
        steps : int
            Number of steps to take for this random walk. A walk that starts
            on a simplex without any connection stays on that simplex for all
            steps.

        Returns
        -------
        RandomWalk
            The sampled random walk.
        """
        return self.random_walks(rank, torch.tensor([start_index]), steps)[0]

    def random_walks(
        self, rank: int, start_indices: torch.Tensor, steps: int
    ) -> list[RandomWalk]:
        """
        Sample random walks from the given start simplices.

        Parameters
        ----------
        rank : int
            The rank of the simplices on which to perform the random walk.
        start_indices : torch.Tensor
            Indices of the simplices from which to start the random walks.
            The tensor is not modified.
        steps : int
            Number of steps to take for each random walk. A walk that starts
            on a simplex without any connection stays on that simplex for all
            steps.

        Returns
        -------
        list[RandomWalk]
            One random walk per entry of `start_indices`, in the same order.
            Every walk contains ``steps + 1`` simplices (the start simplex
            followed by one simplex per step).
        """
        # For performance benefits, compute all random values at once: two
        # values per step (one to pick the connection, one to pick the
        # neighbor). The values are reduced modulo the number of candidates
        # when sampling, so the upper bound only needs to be large.
        random_values = torch.randint(
            634734743643574, (len(start_indices), steps * 2), dtype=torch.int64
        )

        (
            walk_simplices,
            connection_direction,
            connection_simplices,
        ) = self._parallel_walks(rank, start_indices, steps, random_values)

        # NOTE(review): the window size for the structural features is
        # currently fixed to 16 for all sampled walks.
        return [
            RandomWalk(
                self.simplicial_complex,
                rank,
                walk_simplices[i],
                connection_direction[i],
                connection_simplices[i],
                window_size=16,
            )
            for i in range(len(start_indices))
        ]

    def _parallel_walks(
        self, rank: int, indices: torch.Tensor, steps: int, random_values: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Perform random walks in parallel on a set of start simplices.

        Parameters
        ----------
        rank : int
            The rank of the simplices on which to perform the random walk.
        indices : torch.Tensor
            Indices of the simplices from which to start the random walks.
            The tensor is not modified.
        steps : int
            Number of steps to take for each random walk. A walk that starts
            on a simplex without any connection stays on that simplex for all
            steps.
        random_values : torch.Tensor
            Random values used to sample the random walks, of shape
            ``(len(indices), 2 * steps)``.

        Returns
        -------
        walk_simplices : torch.Tensor
            Indices of the simplices in the walks, of shape
            ``(len(indices), steps + 1)``.
        connection_direction : torch.Tensor
            Direction of the connection used to reach each position: ``-1``
            for a lower connection, ``1`` for an upper connection, and ``0``
            where no connection was used (the start position and walks stuck
            on an isolated simplex).
        connection_simplices : torch.Tensor
            Indices of the simplices connected to the walk simplices, relative
            to their own rank; ``-1`` where no connection was used.

        Raises
        ------
        ValueError
            If `rank` is out of range, if a start index does not refer to a
            simplex of the complex, or if `random_values` has the wrong shape.
        """
        if rank < 0 or rank > self.simplicial_complex.dim:
            raise ValueError(
                f"Order must be between 0 and the maximum dimension {self.simplicial_complex.dim}."
            )
        # Negative indices are rejected as well: they would silently wrap
        # around and -1 is reserved as a sentinel by the structural features.
        if torch.any((indices < 0) | (indices >= self.simplicial_complex.shape[rank])):
            raise ValueError("Simplex is not in the simplicial complex.")

        if random_values.shape != (len(indices), steps * 2):
            raise ValueError(
                f"Expected random values of shape {(len(indices), steps * 2)}, got {random_values.shape}."
            )

        # Number of lower connections preceding the upper ones in the
        # connection index space. This is only non-zero if the sampling matrix
        # built in `__init__` actually contains lower connections, i.e., if
        # lower connections are enabled and the rank has lower simplices.
        num_lower = (
            self.simplicial_complex.shape[rank - 1]
            if rank > 0 and self.use_lower_connections
            else 0
        )

        walk_simplices = torch.empty((len(indices), steps + 1), dtype=torch.int64)
        connection_direction = torch.zeros((len(indices), steps + 1), dtype=torch.int64)
        connection_simplices = torch.full(
            (len(indices), steps + 1), -1, dtype=torch.int64
        )

        # Work on a copy so the caller's start indices stay untouched.
        current = indices.clone()
        walk_simplices[:, 0] = current

        num_connections = self.__num_connections[rank]
        connection_offsets = self.__connection_offsets[rank]
        connection_table = self.__connection_indices[rank]
        num_neighbors = self.__num_neighbors[rank]
        neighbor_offsets = self.__neighbor_offsets[rank]
        neighbor_table = self.__neighbor_indices[rank]

        # Every simplex reached by a step has at least the connection it was
        # reached through, so it suffices to determine the isolated simplices
        # once for the start positions.
        non_isolated_simplices = num_connections[current] > 0

        for i in range(steps):
            active = current[non_isolated_simplices]

            # First pick one of the connections of the current simplex ...
            connection_indices = connection_table[
                connection_offsets[active]
                + (random_values[non_isolated_simplices, 2 * i] % num_connections[active])
            ]

            # ... then pick one of the simplices attached to that connection.
            current[non_isolated_simplices] = neighbor_table[
                neighbor_offsets[connection_indices]
                + (
                    random_values[non_isolated_simplices, 2 * i + 1]
                    % num_neighbors[connection_indices]
                )
            ]

            walk_simplices[:, i + 1] = current

            is_upper = connection_indices >= num_lower
            connection_direction[non_isolated_simplices, i + 1] = (
                is_upper.to(torch.int64) * 2 - 1
            )

            # Upper connections are stored after the lower ones; shift them
            # back so they index simplices of rank + 1 directly.
            connection_indices[is_upper] -= num_lower
            connection_simplices[non_isolated_simplices, i + 1] = connection_indices

        return walk_simplices, connection_direction, connection_simplices
