"""
Entrypoint of Sparse Utility. It provides function that convert dense 
matrices to sparse with the 50% structured sparsity on Ampere
"""
import torch
from dsp.meta import bdense2sparse, block_ell, meta_ell
from typing import Tuple
import numpy as np


def dense2sparse(dense_matrix: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    r""" A customized kernel that prunes a dense matrix to sparse. The output matches the
    format used in CUTLASS SpMM with sparse tensor core.
    With ``float32``  tensors, it will select the larger one from each 1x2 vector in the output.
    With ``bfloat16`` tensors, it will select the larger two from each 1x4 vector in the output.

    Args:
        dense_matrix: a (B, M, N) or (M, N) dense tensor to be pruned. It must be a
            ``torch.float32`` or ``torch.bfloat16`` tensor located on a CUDA device.

    Returns:
        nonzeros      : :math:`D` of size `(B, M, N/2)` or `(M, N/2)`
        metadata      : :math:`C` of size `(B, M, N/2)` or `(M, N/Q)`, `Q`=8 for `float32` and `Q`=16 for `bfloat16`
        uncompressed  : :math:`D` of size `(B, M, N)` or `(M, N)`, the uncompressed output sparse matrix

        NOTE(review): the batched metadata size `(B, M, N/2)` looks inconsistent with the
        2D size `(M, N/Q)` -- confirm against the kernel.

    Raises:
        ValueError: if ``dense_matrix`` is not 2D or 3D, is not ``torch.float32`` /
            ``torch.bfloat16``, or is not a CUDA tensor.

    Example:
        >>> import torch
        >>> import dspattn
        >>> dense_matrix = torch.randn(size=(8, 4096, 4096), dtype=torch.bfloat16, device='cuda')
        >>> nonzeros, metadata, uncompressed = dspattn.dense2sparse(dense_matrix)
    """

    #########################
    # Check input dimension #
    #########################

    if dense_matrix.dim() not in (2, 3):
        raise ValueError("expected 2D or 3D dense_matrix (got {}D input)".format(dense_matrix.dim()))

    #########################
    # Check input data type #
    #########################

    if dense_matrix.dtype not in (torch.float32, torch.bfloat16):
        raise ValueError("the dense_matrix should be in torch.float32 or torch.bfloat16 (got {})".format(dense_matrix.dtype))

    if not dense_matrix.is_cuda:
        raise ValueError("the dense_matrix should be on GPU (got CPU)")

    ################################
    # launch the extended function #
    ################################

    # The extension returns four tensors. Its third output is not part of this
    # function's result, so it is discarded; the fourth one (the reordered
    # metadata) is what gets returned as ``metadata``.
    nonzeros, uncompressed_matrix, _, metadata_reorder = bdense2sparse(dense_matrix)

    return nonzeros, metadata_reorder, uncompressed_matrix


def static_random_mask(batch: int, m: int, n: int, nnz: int, device: str = 'cuda') -> torch.Tensor:
    r""" Generate a random mask for Blocked-ELL format indices.

    Every block row independently draws ``nnz`` distinct block-column indices
    from ``[0, n)`` and stores them in ascending order.

    Args:
        batch: the batch size
        m: :math:`#row/block_size`,
        n: :math:`#col/block_size`,
        nnz: number of nonzero blocks in each row
        device: device on which the returned tensor is created (any value accepted by
            the ``device=`` argument of ``torch.as_tensor``). Defaults to ``'cuda'``.

    Returns:
        indices: the nonzero indices of Blocked-ELL format, an ``int32`` tensor of size
        `(m, nnz)` when ``batch == 1`` and `(batch, m, nnz)` otherwise.

    Raises:
        ValueError: if ``batch`` or ``m`` is negative, or (raised by NumPy while sampling
            a row) if ``nnz`` is larger than ``n``.
    """
    if batch < 0 or m < 0:
        raise ValueError("batch and m must be non-negative (got batch={}, m={})".format(batch, m))

    rng = np.random.default_rng()

    # Fill one contiguous int32 buffer row by row. Handing torch a (nested)
    # Python list of ndarrays would go through its slow per-element conversion
    # path, and would yield ragged/empty shapes for degenerate sizes.
    total_rows = batch * m
    table = np.empty((total_rows, nnz), dtype=np.int32)
    for row in range(total_rows):
        table[row] = np.sort(rng.choice(n, nnz, replace=False))

    # A single batch is returned without the leading batch dimension.
    out_shape = (m, nnz) if batch == 1 else (batch, m, nnz)
    return torch.as_tensor(table.reshape(out_shape), dtype=torch.int32, device=device)


def block_ell_prune(input_data: torch.Tensor, indices: torch.Tensor, block_size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
    r""" Prune a dense matrix to block-ell format.

    Thin wrapper that forwards its arguments unchanged to ``block_ell`` from
    ``dsp.meta`` and returns the two tensors it produces.

    Args:
        input_data: A tensor of size (m, n)
        indices: A 2D tensor of indices
        block_size_n: the block size along the column dimension.

    Returns:
        output_data: A tensor that contains only the nonzero values
        uncompressed: A tensor of size (m, n), the pruned values are set to 0
    """
    # Unpack explicitly so the result is always a 2-tuple, whatever sequence
    # type the extension hands back.
    kept_values, dense_view = block_ell(input_data, indices, block_size_n)
    return kept_values, dense_view


def meta_ell_prune(input_data: torch.Tensor, indices: torch.Tensor, block_size_n: int) -> Tuple[torch.Tensor, torch.Tensor]:
    r""" Prune a dense metadata to block-ell format.

    Thin wrapper that forwards its arguments unchanged to ``meta_ell`` from
    ``dsp.meta`` and returns the two tensors it produces.

    Args:
        input_data: A tensor of size (m, n/8) (float) or (m, n/16) (bfloat16)
        indices: A 2D tensor of indices
        block_size_n: the block size along the column dimension.

    Returns:
        output_data: A tensor that contains only the nonzero values
        uncompressed: the uncompressed counterpart, the pruned values are set to 0

        NOTE(review): the size of ``uncompressed`` was previously documented as (m, n);
        for metadata input it presumably follows ``input_data`` instead -- confirm
        against ``meta_ell``.
    """
    # Unpack explicitly so the result is always a 2-tuple, whatever sequence
    # type the extension hands back.
    kept_meta, dense_meta = meta_ell(input_data, indices, block_size_n)
    return kept_meta, dense_meta