"""
Module for data classes related to simplicial complexes.

Notes
-----
All functions and classes in this module are JIT compiled using TorchScript for
additional performance benefits.
"""
import torch


@torch.jit.script
class SimplicialComplex:
    """
    Data class for a simplicial complex.

    The simplicial complex is represented by its boundary matrices.

    Parameters
    ----------
    boundary_matrices : dict[int, torch.Tensor]
        The boundary matrices of the simplicial complex, keyed by rank. The
        matrix of rank ``r`` has one row per ``(r - 1)``-simplex and one column
        per ``r``-simplex. Every rank from ``1`` to ``len(boundary_matrices)``
        must be present.
    """

    def __init__(self, boundary_matrices: dict[int, torch.Tensor]) -> None:
        """
        Data class for a simplicial complex.

        The simplicial complex is represented by its boundary matrices.

        Parameters
        ----------
        boundary_matrices : dict[int, torch.Tensor]
            The boundary matrices of the simplicial complex, keyed by rank.
            Every rank from ``1`` to ``len(boundary_matrices)`` must be present.
        """
        self.boundary = boundary_matrices
        # Number of simplices of each rank: rank 0 is read from the rows of the
        # rank-1 boundary matrix, rank r >= 1 from the columns of the rank-r
        # matrix. Ranks are looked up explicitly so the result does not depend
        # on the insertion order of the dict.
        self.shape = [self.boundary[1].size(0)] + [
            self.boundary[rank].size(1) for rank in range(1, len(self.boundary) + 1)
        ]

        # Pre-compute the lower and upper adjacency matrices of every rank
        # 0..dim so they can be indexed by rank.
        self.lower_adjacency = []
        self.upper_adjacency = []
        for rank in range(0, len(self.boundary) + 1):
            self.lower_adjacency.append(self._lower_adjacency(rank))
            self.upper_adjacency.append(self._upper_adjacency(rank))

    @property
    def dim(self) -> int:
        """
        Rank of the largest simplex in the simplicial complex.

        Returns
        -------
        int
            The rank of the largest simplex in the simplicial complex.
        """
        return len(self.boundary)

    def _zero_matrix(self, size: int) -> torch.Tensor:
        """
        Return a square zero matrix matching the boundary matrices.

        Parameters
        ----------
        size : int
            Number of rows and columns.

        Returns
        -------
        torch.Tensor, shape = [size, size]
            Zero matrix with the dtype and device of the rank-1 boundary
            matrix, so it is consistent with the matrices of the other ranks.
        """
        reference = self.boundary[1]
        return torch.zeros(
            (size, size), dtype=reference.dtype, device=reference.device
        )

    def down_laplacian_matrix(self, rank: int) -> torch.Tensor:
        """
        Return the down laplacian matrix of given rank.

        Parameters
        ----------
        rank : int
            The rank of the down laplacian matrix.

        Returns
        -------
        torch.Tensor, shape = [n, n]
            The down laplacian matrix of rank `rank`, i.e. ``B_r^T @ B_r``.
            For rank 0 (no faces) this is a zero matrix.
        """
        if rank == 0:
            return self._zero_matrix(self.shape[0])

        return self.boundary[rank].transpose(0, 1) @ self.boundary[rank]

    def up_laplacian_matrix(self, rank: int) -> torch.Tensor:
        """
        Return the up laplacian matrix of given rank.

        Parameters
        ----------
        rank : int
            The rank of the up laplacian matrix.

        Returns
        -------
        torch.Tensor, shape = [n, n]
            The up laplacian matrix of rank `rank`, i.e.
            ``B_{r+1} @ B_{r+1}^T``. For the top rank (no cofaces) this is a
            zero matrix.
        """
        if rank == len(self.boundary):
            return self._zero_matrix(self.shape[rank])

        return self.boundary[rank + 1] @ self.boundary[rank + 1].transpose(0, 1)

    def _lower_adjacency(self, rank: int) -> torch.Tensor:
        """
        Compute the lower adjacency matrix of given rank.

        Entry ``(i, j)`` is 1 when simplices ``i`` and ``j`` share a face and 0
        otherwise. The diagonal is not cleared, so it is 1 for every simplex
        that has at least one face.

        Parameters
        ----------
        rank : int
            The rank of the lower adjacency matrix.

        Returns
        -------
        torch.Tensor, shape = [n, n]
            The lower adjacency matrix of rank `rank`.
        """
        laplacian = self.down_laplacian_matrix(rank)
        # With oriented (signed) boundary matrices, simplices that share a face
        # can get a negative entry, so binarise on "non-zero", not "positive".
        return (laplacian != 0).to(laplacian.dtype)

    def _upper_adjacency(self, rank: int) -> torch.Tensor:
        """
        Compute the upper adjacency matrix of given rank.

        Entry ``(i, j)`` is 1 when simplices ``i`` and ``j`` are faces of a
        common higher-rank simplex and 0 otherwise. The diagonal is not
        cleared, so it is 1 for every simplex that has at least one coface.

        Parameters
        ----------
        rank : int
            The rank of the upper adjacency matrix.

        Returns
        -------
        torch.Tensor, shape = [n, n]
            The upper adjacency matrix of rank `rank`.
        """
        laplacian = self.up_laplacian_matrix(rank)
        # Same reasoning as in `_lower_adjacency`: signed entries must count.
        return (laplacian != 0).to(laplacian.dtype)


@torch.jit.script
class SimplicialData:
    """
    Container for simplicial complexes with associated cochains on the simplices.

    Parameters
    ----------
    simplicial_complex : SimplicialComplex
        The simplicial complex.
    dtype : torch.dtype
        The dtype of the data supported on the simplicial complexes.
    device : torch.device
        Device on which to store the data.
    """

    def __init__(
        self,
        simplicial_complex: SimplicialComplex,
        dtype: torch.dtype,
        device: torch.device,
    ):
        """
        Container for simplicial complexes with associated cochains on the simplices.

        Parameters
        ----------
        simplicial_complex : SimplicialComplex
            The simplicial complex.
        dtype : torch.dtype
            The dtype of the data supported on the simplicial complexes.
        device : torch.device
            Device on which to store the data.
        """
        self.domain = simplicial_complex
        self.dtype = dtype
        self.device = device

        # Cochains and auxiliary tensors, both keyed by rank.
        self._cochains: dict[int, torch.Tensor] = {}
        self._aux_tensors: dict[int, torch.Tensor] = {}
        # Number of features per rank (size(1) of the cochain); 0 until a
        # cochain of that rank has been set.
        self.shape = [0 for _ in range(self.domain.dim + 1)]

    def aux_tensor(self, rank: int) -> torch.Tensor:
        """
        Return the auxiliary tensor of given rank.

        Parameters
        ----------
        rank : int
            The rank of the auxiliary tensor.

        Returns
        -------
        torch.Tensor, shape = [n, f]
            The auxiliary tensor of rank `rank`.

        Raises
        ------
        RuntimeError
            If the simplicial data has no auxiliary tensor of given rank.
        """
        if rank not in self._aux_tensors:
            raise RuntimeError(f"No aux tensor for rank {rank}.")
        return self._aux_tensors[rank]

    def has_aux_tensor(self, rank: int) -> bool:
        """
        Return whether the simplicial data has an auxiliary tensor of given rank.

        Parameters
        ----------
        rank : int
            The rank of the auxiliary tensor.

        Returns
        -------
        bool
            Whether the simplicial data has an auxiliary tensor of given rank.
        """
        return rank in self._aux_tensors

    def set_aux_tensor(self, rank: int, tensor: torch.Tensor) -> None:
        """
        Set the auxiliary tensor of given rank.

        Only the device of the tensor is validated; its rank, shape and dtype
        are not checked.

        Parameters
        ----------
        rank : int
            The rank of the auxiliary tensor.
        tensor : torch.Tensor, shape = [n, f]
            The auxiliary tensor of rank `rank` to set.

        Raises
        ------
        ValueError
            If the tensor is not on the same device as the simplicial data.
        """
        if self.device != tensor.device:
            raise ValueError(
                f"Input tensor must be on the same device as the other tensors, got {tensor.device} and expected {self.device}."
            )

        self._aux_tensors[rank] = tensor

    def __getitem__(self, rank: int) -> torch.Tensor:
        """
        Return the cochain of given rank.

        Parameters
        ----------
        rank : int
            The rank of the cochain.

        Returns
        -------
        torch.Tensor, shape = [n, f]
            The cochain of rank `rank`. If no cochain has been set for this
            rank, an empty ``[n, 0]`` tensor with the container's dtype and
            device is returned.
        """
        if rank in self._cochains:
            return self._cochains[rank]
        return torch.zeros(
            (self.domain.shape[rank], 0), dtype=self.dtype, device=self.device
        )

    def __setitem__(self, rank: int, tensor: torch.Tensor) -> None:
        """
        Set the cochain of given rank.

        Parameters
        ----------
        rank : int
            The rank of the cochain.
        tensor : torch.Tensor, shape = [n, f]
            The cochain of rank `rank` to set.

        Raises
        ------
        ValueError
            If the rank is negative.
        ValueError
            If the rank is larger than the dimension of the simplicial complex.
        ValueError
            If the tensor is not 2D or 3D.
        ValueError
            If the first dimension of the tensor does not match the size of the
            simplicial complex.
        ValueError
            If the dtype or device of the tensor does not match the dtype of the other tensors.
        """
        # A negative rank would otherwise be accepted through Python-style
        # negative indexing and overwrite the feature count of the top rank.
        if rank < 0:
            raise ValueError(f"Rank must be non-negative, got {rank}.")
        if rank > self.domain.dim:
            raise ValueError(
                f"Rank must be less than or equal to the dimension of the simplicial complex, got {rank} and expected at most {self.domain.dim}."
            )

        # Check the number of dimensions first so that `tensor.size(0)` below
        # cannot fail on a 0-d tensor.
        if len(tensor.shape) != 2 and len(tensor.shape) != 3:
            raise ValueError(f"Input must be a 2D or 3D tensor, got {tensor.shape}.")

        if tensor.size(0) != self.domain.shape[rank]:
            raise ValueError(
                f"The tensor's first dimension for rank {rank} must have size {self.domain.shape[rank]}, got {tensor.size(0)}."
            )

        if self.dtype != tensor.dtype:
            raise ValueError(
                f"Input tensor must be of the same dtype as the other tensors, got {tensor.dtype} and expected {self.dtype}."
            )
        if self.device != tensor.device:
            raise ValueError(
                f"Input tensor must be on the same device as the other tensors, got {tensor.device} and expected {self.device}."
            )

        self._cochains[rank] = tensor
        self.shape[rank] = tensor.size(1)

    def to(self, to) -> "SimplicialData":
        """
        Change the device or dtype of the simplicial data in-place.

        Auxiliary tensors are only moved when `to` is a device; cochains are
        converted in both cases.

        Parameters
        ----------
        to : torch.device or torch.dtype
            The device or dtype to which to move the data.

        Returns
        -------
        SimplicialData
            Self.
        """
        # NOTE(review): `to` is unannotated. TorchScript types unannotated
        # parameters as Tensor, so inside scripted code the device branch below
        # is presumably dead and `cochain.to(to)` resolves to `Tensor.to(other)`.
        # Eager Python callers are unaffected — confirm how this is called.
        if isinstance(to, torch.device):
            self.device = to
            for key, aux_tensor in self._aux_tensors.items():
                self._aux_tensors[key] = aux_tensor.to(to)

        # NOTE(review): when `to` is a dtype, `self.dtype` is not updated, so
        # `__setitem__` keeps expecting the old dtype and `__getitem__` returns
        # empty cochains of the old dtype — confirm whether that is intended.
        for key, cochain in self._cochains.items():
            self._cochains[key] = cochain.to(to)

        return self
