"""
Entry point of the Softmax kernel.
"""
from typing import Tuple
import torch
from dsp.softmax import long_seq_softmax_bf16, long_seq_softmax

def softmax(input_matrix: torch.Tensor, scaler: float) -> torch.Tensor:
    r"""Normalize a CUDA tensor with the customized Softmax kernel.

    The input is validated (dtype, then device) and dispatched to the
    extension kernel matching its dtype: ``long_seq_softmax`` for
    ``torch.float32`` and ``long_seq_softmax_bf16`` for ``torch.bfloat16``.

    Args:
        input_matrix: input matrix to be normalized with softmax. Must have
            dtype ``torch.float32`` or ``torch.bfloat16`` and live on a CUDA
            device.
        scaler: scale factor to be applied before the softmax.

    Returns:
        output_matrix: matrix normalized by softmax, as returned by the
        extension kernel.

    Raises:
        ValueError: if ``input_matrix`` has an unsupported dtype, or if it is
            not on a CUDA device.

    Example:
        >>> import torch
        >>> import dspattn
        >>> import numpy as np
        >>> input_matrix = torch.randn(size=(8, 4096, 2048), dtype=torch.bfloat16, device='cuda')
        >>> output_matrix = dspattn.softmax(input_matrix, 1./np.sqrt(2048))
    """

    #########################
    # Check input data type #
    #########################

    if input_matrix.dtype not in (torch.float32, torch.bfloat16):
        raise ValueError(
            "the input_matrix should be in torch.float32 or torch.bfloat16 (got {})".format(
                input_matrix.dtype
            )
        )

    ######################
    # Check input device #
    ######################

    if not input_matrix.is_cuda:
        # Report the tensor's real device: a non-CUDA tensor is not
        # necessarily on the CPU (it may be on e.g. a "meta" device).
        raise ValueError(
            "the input_matrix should be on GPU (got {})".format(input_matrix.device)
        )

    ################################
    # launch the extended function #
    ################################

    # The dtype check above guarantees that anything not float32 is bfloat16.
    if input_matrix.dtype == torch.float32:
        kernel = long_seq_softmax
    else:
        kernel = long_seq_softmax_bf16

    # NOTE(review): -1 is presumably the softmax dimension (last axis) --
    # confirm against the dsp.softmax extension signature.
    return kernel(input_matrix, -1, scaler)
    