import torch
import qutlass._CUDA

# Public names of this module: the four wrappers defined below.
__all__ = [
    "matmul_mxf4_bf16_tn",
    "matmul_mxf8_bf16_tn",
    "fusedQuantize",
    "fusedQuantize_bwd",
]

def matmul_mxf4_bf16_tn(a: torch.Tensor,
                        b: torch.Tensor,
                        block_scale_a: torch.Tensor,
                        block_scale_b: torch.Tensor,
                        alpha: torch.Tensor) -> torch.Tensor:
    """Run the extension's ``matmul_mxf4_bf16_tn`` kernel.

    Thin pass-through: every argument is forwarded unchanged to
    ``qutlass._CUDA.matmul_mxf4_bf16_tn`` and its result is returned.
    No shape, dtype or device validation happens here.

    Args:
        a: First operand. The name suggests MXFP4-packed data in a "TN"
            layout (assumption; not checked here).
        b: Second operand (same caveat as ``a``).
        block_scale_a: Block-scale tensor associated with ``a``.
        block_scale_b: Block-scale tensor associated with ``b``.
        alpha: Output scaling factor forwarded to the kernel.

    Returns:
        The tensor produced by the extension. The name suggests a bf16
        result (assumption).
    """
    # NOTE(review): ``alpha`` was previously annotated ``torch.float32``,
    # which is a dtype object, not a type. It is annotated ``torch.Tensor``
    # on the assumption that the binding takes a float32 tensor; confirm
    # against the C++ signature.
    return qutlass._CUDA.matmul_mxf4_bf16_tn(a, b, block_scale_a, block_scale_b, alpha)

def matmul_mxf8_bf16_tn(a: torch.Tensor,
                        b: torch.Tensor,
                        block_scale_a: torch.Tensor,
                        block_scale_b: torch.Tensor,
                        alpha: torch.Tensor) -> torch.Tensor:
    """Run the extension's ``matmul_mxf8_bf16_tn`` kernel.

    Thin pass-through: every argument is forwarded unchanged to
    ``qutlass._CUDA.matmul_mxf8_bf16_tn`` and its result is returned.
    No shape, dtype or device validation happens here.

    Args:
        a: First operand. The name suggests MXFP8 data in a "TN" layout
            (assumption; not checked here).
        b: Second operand (same caveat as ``a``).
        block_scale_a: Block-scale tensor associated with ``a``.
        block_scale_b: Block-scale tensor associated with ``b``.
        alpha: Output scaling factor forwarded to the kernel.

    Returns:
        The tensor produced by the extension. The name suggests a bf16
        result (assumption).
    """
    # NOTE(review): ``alpha`` was previously annotated ``torch.float32``,
    # which is a dtype object, not a type. It is annotated ``torch.Tensor``
    # on the assumption that the binding takes a float32 tensor; confirm
    # against the C++ signature.
    return qutlass._CUDA.matmul_mxf8_bf16_tn(a, b, block_scale_a, block_scale_b, alpha)

def fusedQuantize(a: torch.Tensor,
                  b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Quantize ``a`` with the extension's ``fusedQuantize`` kernel.

    Allocates three uninitialized output buffers that match ``a`` in every
    dimension except the last, then calls ``qutlass._CUDA.fusedQuantize``
    and returns its result.

    Args:
        a: Input tensor of rank >= 1. Its last dimension ``n`` must be a
            multiple of 32.
        b: Second operand forwarded unchanged to the kernel. It is presumably
            the rotation matrix applied before quantization; confirm
            against the binding.

    Returns:
        Whatever the extension returns, annotated as a 3-tuple. The buffers
        handed to it are, in order:

        * ``uint8`` with last dim ``n // 2``. The name suggests packed E2M1
          values, two per byte (assumption).
        * ``float8_e8m0fnu`` with last dim ``n // 32``. The name suggests one
          E8M0 scale per 32-element block (assumption).
        * ``uint8`` with last dim ``n // 8``. The name suggests a clip mask,
          one bit per element (assumption).

    Raises:
        ValueError: If ``a`` is 0-d or its last dimension is not a multiple
            of 32.
    """
    # The buffer sizes below use floor division, so a ragged last dimension
    # would silently under-allocate them. Reject such inputs up front.
    if a.dim() == 0:
        raise ValueError("fusedQuantize expects `a` to have at least one dimension")
    n = a.size(-1)
    if n % 32 != 0:
        raise ValueError(
            f"fusedQuantize expects the last dimension of `a` to be a multiple of 32, got {n}"
        )

    lead = a.shape[:-1]
    xh_e2m1   = torch.empty(*lead, n // 2,  dtype=torch.uint8,          device=a.device)
    xh_e8m0   = torch.empty(*lead, n // 32, dtype=torch.float8_e8m0fnu, device=a.device)
    clip_mask = torch.empty(*lead, n // 8,  dtype=torch.uint8,          device=a.device)

    return qutlass._CUDA.fusedQuantize(a, b, xh_e2m1, xh_e8m0, clip_mask)

def fusedQuantize_bwd(a: torch.Tensor,
                      b: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``a`` with the extension's ``fusedQuantize_bwd`` kernel.

    Allocates two uninitialized output buffers that match ``a`` in every
    dimension except the last, then calls ``qutlass._CUDA.fusedQuantize_bwd``
    and returns its result. Unlike ``fusedQuantize``, no clip-mask buffer
    is allocated.

    Args:
        a: Input tensor of rank >= 1. Its last dimension ``n`` must be a
            multiple of 32.
        b: Second operand forwarded unchanged to the kernel. It is presumably
            a rotation matrix; confirm against the binding.

    Returns:
        Whatever the extension returns, annotated as a 2-tuple. The buffers
        handed to it are, in order:

        * ``uint8`` with last dim ``n // 2``. The name suggests packed E2M1
          values, two per byte (assumption).
        * ``float8_e8m0fnu`` with last dim ``n // 32``. The name suggests one
          E8M0 scale per 32-element block (assumption).

    Raises:
        ValueError: If ``a`` is 0-d or its last dimension is not a multiple
            of 32.
    """
    # The buffer sizes below use floor division, so a ragged last dimension
    # would silently under-allocate them. Reject such inputs up front.
    if a.dim() == 0:
        raise ValueError("fusedQuantize_bwd expects `a` to have at least one dimension")
    n = a.size(-1)
    if n % 32 != 0:
        raise ValueError(
            f"fusedQuantize_bwd expects the last dimension of `a` to be a multiple of 32, got {n}"
        )

    lead = a.shape[:-1]
    xh_e2m1 = torch.empty(*lead, n // 2,  dtype=torch.uint8,            device=a.device)
    xh_e8m0 = torch.empty(*lead, n // 32, dtype=torch.float8_e8m0fnu,   device=a.device)

    return qutlass._CUDA.fusedQuantize_bwd(a, b, xh_e2m1, xh_e8m0)