import torch

import quartet_c


def cvt_bf16_e2m1(
        x: torch.Tensor,
        out: torch.Tensor = None,
) -> torch.Tensor:
    """
    bf16 to fp4_e2m1

    Two fp4 values are packed into each element of ``out``, so the last
    dimension of ``x`` must be even.

    :param x: (..., 8X), bf16, input tensor
    :param out: (..., 4X), e2m1x2, optional preallocated output; allocated on
        ``x.device`` when omitted
    :return: ``out``, filled in place by ``quartet_c.cvt_bf16_e2m1``
    :raises AssertionError: on wrong dtypes, non-contiguous tensors, or an odd
        last dimension of ``x``
    """

    # Validate the input before allocating anything.
    assert x.dtype == torch.bfloat16
    assert x.is_contiguous()
    # An odd size would make `x.size(-1) // 2` silently drop an element per row.
    assert x.size(-1) % 2 == 0, f"last dim of x must be even, got {x.size(-1)}"

    if out is None:
        out = torch.empty(*x.shape[:-1], x.size(-1) // 2, dtype=torch.float4_e2m1fn_x2, device=x.device)
    assert out.dtype == torch.float4_e2m1fn_x2
    assert out.is_contiguous()
    quartet_c.cvt_bf16_e2m1(out, x)
    return out


def cvt_e2m1_bf16(
        x: torch.Tensor,
        out: torch.Tensor = None,
) -> torch.Tensor:
    """
    fp4_e2m1 to bf16

    Every e2m1x2 element of ``x`` unpacks into two bf16 values, so the last
    dimension of the result is twice that of ``x``.

    :param x: (..., 4X), e2m1x2, input tensor
    :param out: (..., 8X), bf16, optional preallocated output; allocated on
        ``x.device`` when omitted
    :return: ``out``, filled in place by ``quartet_c.cvt_e2m1_bf16``
    :raises AssertionError: on wrong dtypes or non-contiguous tensors
    """

    if out is None:
        unpacked_shape = (*x.shape[:-1], 2 * x.size(-1))
        out = torch.empty(unpacked_shape, dtype=torch.bfloat16, device=x.device)

    # Both dtypes and both layouts must match what the native kernel expects.
    assert x.dtype == torch.float4_e2m1fn_x2
    assert out.dtype == torch.bfloat16
    assert x.is_contiguous()
    assert out.is_contiguous()

    quartet_c.cvt_e2m1_bf16(out, x)
    return out


def forward_bf16(
        x: torch.Tensor,
        h: torch.Tensor,
        xh_e2m1: torch.Tensor = None,
        xh_e8m0: torch.Tensor = None,
        clip_mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Forward

    Thin wrapper around ``quartet_c.forward_bf16``. It validates the inputs,
    allocates any output buffer that was not supplied, and lets the kernel
    fill the outputs in place. The last dimension of ``x`` is handled in
    groups of 32 elements (one e8m0 scale and four clip-mask bytes per group),
    so it must be a multiple of 32.

    :param x: (..., 32X), bf16, input tensor
    :param h: (32, 32), bf16, Hadamard matrix
    :param xh_e2m1: (..., 16X), e2m1x2, optional preallocated output
    :param xh_e8m0: (..., X), e8m0, optional preallocated output
    :param clip_mask: (..., 4X), uint8 = bool x 8, optional preallocated output
    :return: ``(xh_e2m1, xh_e8m0, clip_mask)``; outputs that were not supplied
        are allocated on ``h.device``
    :raises AssertionError: on wrong dtypes, non-contiguous tensors, ``x`` and
        ``h`` on different devices, or a last dimension of ``x`` that is not a
        multiple of 32
    """

    # Validate inputs before allocating anything.
    assert x.dtype == h.dtype == torch.bfloat16
    assert x.is_contiguous() and h.is_contiguous()
    # Outputs are allocated on h.device and passed to the same kernel call as
    # x, so x must already live on that device.
    assert x.device == h.device, f"x and h must be on the same device, got {x.device} and {h.device}"
    # Otherwise the `// 32`, `// 8` and `// 2` below would silently floor and
    # the output shapes would no longer correspond to x.
    assert x.size(-1) % 32 == 0, f"last dim of x must be a multiple of 32, got {x.size(-1)}"

    if xh_e2m1 is None:
        xh_e2m1 = torch.empty(*x.shape[:-1], x.size(-1) // 2, dtype=torch.float4_e2m1fn_x2, device=h.device)
    if xh_e8m0 is None:
        xh_e8m0 = torch.empty(*x.shape[:-1], x.size(-1) // 32, dtype=torch.float8_e8m0fnu, device=h.device)
    if clip_mask is None:
        clip_mask = torch.empty(*x.shape[:-1], x.size(-1) // 8, dtype=torch.uint8, device=h.device)

    assert xh_e2m1.dtype == torch.float4_e2m1fn_x2 and xh_e8m0.dtype == torch.float8_e8m0fnu and clip_mask.dtype == torch.uint8
    assert xh_e2m1.is_contiguous() and xh_e8m0.is_contiguous() and clip_mask.is_contiguous()

    quartet_c.forward_bf16(x, h, xh_e2m1, xh_e8m0, clip_mask)

    return xh_e2m1, xh_e8m0, clip_mask


def backward_bf16(
        x: torch.Tensor,
        h: torch.Tensor,
        xh_e2m1: torch.Tensor = None,
        xh_e8m0: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Backward

    Thin wrapper around ``quartet_c.backward_bf16``. It validates the inputs,
    allocates any output buffer that was not supplied, and lets the kernel
    fill the outputs in place. The last dimension of ``x`` is handled in
    groups of 32 elements (one e8m0 scale per group), so it must be a multiple
    of 32.

    :param x: (..., 32X), bf16, input tensor
    :param h: (32, 32), bf16, Hadamard matrix
    :param xh_e2m1: (..., 16X), e2m1x2, optional preallocated output
    :param xh_e8m0: (..., X), e8m0, optional preallocated output
    :return: ``(xh_e2m1, xh_e8m0)``; outputs that were not supplied are
        allocated on ``h.device``
    :raises AssertionError: on wrong dtypes, non-contiguous tensors, ``x`` and
        ``h`` on different devices, or a last dimension of ``x`` that is not a
        multiple of 32
    """

    # Validate inputs before allocating anything.
    assert x.dtype == h.dtype == torch.bfloat16
    assert x.is_contiguous() and h.is_contiguous()
    # Outputs are allocated on h.device and passed to the same kernel call as
    # x, so x must already live on that device.
    assert x.device == h.device, f"x and h must be on the same device, got {x.device} and {h.device}"
    # Otherwise the `// 32` and `// 2` below would silently floor and the
    # output shapes would no longer correspond to x.
    assert x.size(-1) % 32 == 0, f"last dim of x must be a multiple of 32, got {x.size(-1)}"

    if xh_e2m1 is None:
        xh_e2m1 = torch.empty(*x.shape[:-1], x.size(-1) // 2, dtype=torch.float4_e2m1fn_x2, device=h.device)
    if xh_e8m0 is None:
        xh_e8m0 = torch.empty(*x.shape[:-1], x.size(-1) // 32, dtype=torch.float8_e8m0fnu, device=h.device)

    assert xh_e2m1.dtype == torch.float4_e2m1fn_x2 and xh_e8m0.dtype == torch.float8_e8m0fnu
    assert xh_e2m1.is_contiguous() and xh_e8m0.is_contiguous()

    quartet_c.backward_bf16(x, h, xh_e2m1, xh_e8m0)

    return xh_e2m1, xh_e8m0


def backward_t_bf16(
        x: torch.Tensor,
        h: torch.Tensor,
        xh_e2m1: torch.Tensor = None,
        xh_e8m0: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Backward (Transposed)

    Thin wrapper around ``quartet_c.backward_t_bf16``. The outputs swap the
    last two dimensions of ``x``, so the grouped dimension of the output comes
    from ``x.size(-2)``. That dimension is handled in groups of 32 elements
    (one e8m0 scale per group), so it must be a multiple of 32.

    :param x: (..., 32X, M), bf16, input tensor
    :param h: (32, 32), bf16, Hadamard matrix
    :param xh_e2m1: (..., M, 16X), e2m1x2, optional preallocated output
    :param xh_e8m0: (..., M, X), e8m0, optional preallocated output
    :return: ``(xh_e2m1, xh_e8m0)``; outputs that were not supplied are
        allocated on ``h.device``
    :raises AssertionError: on wrong dtypes, non-contiguous tensors, ``x`` and
        ``h`` on different devices, or ``x.size(-2)`` not a multiple of 32
    """

    # Validate inputs before allocating anything.
    assert x.dtype == h.dtype == torch.bfloat16
    assert x.is_contiguous() and h.is_contiguous()
    # Outputs are allocated on h.device and passed to the same kernel call as
    # x, so x must already live on that device.
    assert x.device == h.device, f"x and h must be on the same device, got {x.device} and {h.device}"
    # Otherwise the `// 32` and `// 2` below would silently floor and the
    # output shapes would no longer correspond to x.
    assert x.size(-2) % 32 == 0, f"second-to-last dim of x must be a multiple of 32, got {x.size(-2)}"

    if xh_e2m1 is None:
        xh_e2m1 = torch.empty(*x.shape[:-2], x.size(-1), x.size(-2) // 2, dtype=torch.float4_e2m1fn_x2, device=h.device)
    if xh_e8m0 is None:
        xh_e8m0 = torch.empty(*x.shape[:-2], x.size(-1), x.size(-2) // 32, dtype=torch.float8_e8m0fnu, device=h.device)

    assert xh_e2m1.dtype == torch.float4_e2m1fn_x2 and xh_e8m0.dtype == torch.float8_e8m0fnu
    assert xh_e2m1.is_contiguous() and xh_e8m0.is_contiguous()

    quartet_c.backward_t_bf16(x, h, xh_e2m1, xh_e8m0)

    return xh_e2m1, xh_e8m0


def backward_qt_bf16(
        x_e2m1: torch.Tensor,
        x_e8m0: torch.Tensor,
        h: torch.Tensor,
        alpha: float,
        xh_e2m1: torch.Tensor = None,
        xh_e8m0: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Backward (Quantized Transposed)

    Thin wrapper around ``quartet_c.backward_qt_bf16``. The input is given in
    already-quantized form (packed e2m1 values plus e8m0 scales). Its last two
    dimensions are swapped in the outputs. Missing output buffers are
    allocated on ``h.device`` and filled in place by the kernel.

    :param x_e2m1: (..., 32X, 16Y), e2m1x2, input tensor
    :param x_e8m0: (..., 32X, Y), e8m0, input tensor
    :param h: (32, 32), bf16, Hadamard matrix
    :param alpha: float, quantization scaling factor
    :param xh_e2m1: (..., 32Y, 16X), e2m1x2, optional preallocated output
    :param xh_e8m0: (..., 32Y, X), e8m0, optional preallocated output
    :return: ``(xh_e2m1, xh_e8m0)``
    :raises AssertionError: if any tensor is not contiguous
    """

    if xh_e2m1 is None:
        packed_lead = x_e2m1.shape[:-2]
        packed_rows = 2 * x_e2m1.size(-1)   # 16Y packed bytes -> 32Y rows
        packed_cols = x_e2m1.size(-2) // 2  # 32X rows -> 16X packed bytes
        xh_e2m1 = torch.empty(*packed_lead, packed_rows, packed_cols, dtype=torch.float4_e2m1fn_x2, device=h.device)
    if xh_e8m0 is None:
        scale_lead = x_e8m0.shape[:-2]
        scale_rows = 32 * x_e8m0.size(-1)   # Y scales -> 32Y rows
        scale_cols = x_e8m0.size(-2) // 32  # 32X rows -> X scales
        xh_e8m0 = torch.empty(*scale_lead, scale_rows, scale_cols, dtype=torch.float8_e8m0fnu, device=h.device)

    # NOTE(review): unlike the sibling wrappers, no dtype validation is done
    # here. A check requiring bf16 `h`, e2m1x2 `x_e2m1`/`xh_e2m1` and e8m0
    # `x_e8m0`/`xh_e8m0` is disabled. TODO confirm whether callers rely on
    # passing other dtypes before re-enabling it.
    assert all(t.is_contiguous() for t in (x_e2m1, x_e8m0, h, xh_e2m1, xh_e8m0))

    quartet_c.backward_qt_bf16(x_e2m1, x_e8m0, h, alpha, xh_e2m1, xh_e8m0)

    return xh_e2m1, xh_e8m0
