"""Converts a dense tensor to a sparse approximation."""
import dataclasses

import torch


###############################################################################


@dataclasses.dataclass
class CscMatrix:
    """A matrix in sparse compressed column (CSC) format.

    The stored entries are grouped by compressed column: the entries of
    compressed column ``j`` are
    ``values[col_offsets[j]:col_offsets[j + 1]]`` and their positions along
    the other axis are ``row_indices[col_offsets[j]:col_offsets[j + 1]]``.
    """

    # Shape of the matrix being represented. Should have two entries.
    shape: torch.Size

    # The stored (non-zero) entries.
    # values.shape = [nnz]
    values: torch.Tensor

    # Start offset of every compressed column inside `values`/`row_indices`,
    # plus one trailing end offset.
    # col_offsets.shape = [rank + 1], dtype in (int32, int64)
    col_offsets: torch.Tensor

    # Position of every stored entry within its compressed column.
    # row_indices.shape = [nnz], dtype in (int32, int64)
    row_indices: torch.Tensor

    def to_sparse_csc_tensor(self) -> torch.Tensor:
        """Build a ``torch.sparse_csc`` tensor from the stored fields.

        Returns:
            A sparse CSC tensor whose size is ``shape`` reversed.
        """
        # NOTE: The size passed below is `shape` reversed, so the returned
        # tensor is the transpose of the matrix this object represents.
        return torch.sparse_csc_tensor(
            ccol_indices=self.col_offsets,
            row_indices=self.row_indices,
            values=self.values,
            size=tuple(reversed(self.shape)),
        )


###############################################################################


# Largest value representable by a signed 32-bit integer.
_MAX_INT32 = 2**31 - 1


@dataclasses.dataclass
class SingleLrmPefSparsifierCsc:
    """Converts a dense LRM-PEF into a CSC-format sparse approximation.

    The approximation keeps the ``nnz`` largest-magnitude entries of the
    dense matrix and drops all others.
    """

    @classmethod
    def create(cls, **kwargs):
        """Construct an instance by forwarding ``kwargs`` to the constructor."""
        return cls(**kwargs)

    def _requires_output_int64_indices(self, pef: torch.Tensor) -> bool:
        """Return True when the last dimension of ``pef`` is too large for int32 indices."""
        return pef.shape[-1] >= _MAX_INT32

    def sparsify(self, pef: torch.Tensor, nnz: int) -> CscMatrix:
        """Keep the ``nnz`` largest-magnitude entries of ``pef`` as a ``CscMatrix``.

        Args:
            pef: Dense matrix with ``pef.shape = [rank, n_parameters]``.
            nnz: Number of entries to keep; passed to ``torch.topk`` as ``k``.

        Returns:
            A ``CscMatrix`` with ``shape == pef.shape`` in which every row of
            ``pef`` is stored as one compressed column: ``col_offsets`` has
            ``rank + 1`` entries and ``row_indices`` holds positions along the
            ``n_parameters`` axis. The index tensors are int32 unless int64 is
            needed to hold them.
        """
        with torch.no_grad():
            rank = pef.shape[-2]
            n_parameters = pef.shape[-1]

            # The last column offset equals nnz, so nnz has to fit the index
            # dtype as well, not only the positions along the last dimension.
            needs_int64 = self._requires_output_int64_indices(pef) or nnz > _MAX_INT32
            indices_dtype = torch.int64 if needs_int64 else torch.int32

            flat_pef = torch.reshape(pef, [-1])

            # torch.topk returns int64 indices. Sorting them ascending puts
            # the kept entries in row-major order of `pef`, which groups them
            # by compressed column.
            _, flat_indices = torch.topk(torch.abs(flat_pef), k=nnz, sorted=False)
            flat_indices, _ = torch.sort(flat_indices, descending=False)

            values = flat_pef[flat_indices]

            # flat index = col_index * n_parameters + row_index
            col_indices = torch.floor_divide(flat_indices, n_parameters)
            row_indices = flat_indices % n_parameters

            # col_indices is sorted, so the offset of column c is the number
            # of kept entries whose column index is strictly below c. A single
            # searchsorted over c = 0..rank yields all rank + 1 offsets.
            column_ids = torch.arange(
                rank + 1, dtype=col_indices.dtype, device=col_indices.device
            )
            col_offsets = torch.searchsorted(col_indices, column_ids)

            return CscMatrix(
                shape=pef.shape,
                values=values,
                col_offsets=col_offsets.type(indices_dtype),
                row_indices=row_indices.type(indices_dtype),
            )
