import torch
import triton
import triton.language as tl


def dtype_tc2tn(dtype_torch):
    if dtype_torch == torch.int8:
        return tl.int8
    elif dtype_torch == torch.float16:
        return tl.float16
    elif dtype_torch == torch.float8_e5m2:
        return tl.float8e5
    elif dtype_torch == torch.float8_e4m3fn:
        return tl.float8e4nv
    elif dtype_torch == torch.float32:
        return tl.float32
    else:
        raise ValueError(dtype_torch)


def dtype_tn2tc(dtype_triton):
    """Translate a ``triton.language`` dtype into the matching torch dtype.

    This is the inverse of ``dtype_tc2tn`` and accepts every Triton dtype
    that function can produce: ``int8``, ``float16``, ``float8e5``,
    ``float8e4nv`` and ``float32``.

    Raises:
        ValueError: if ``dtype_triton`` is none of the dtypes above; the
            offending value is passed as the exception argument.
    """
    if dtype_triton == tl.int8:
        return torch.int8
    if dtype_triton == tl.float16:
        return torch.float16
    if dtype_triton == tl.float8e5:
        return torch.float8_e5m2
    if dtype_triton == tl.float8e4nv:
        return torch.float8_e4m3fn
    # Mirror of the torch.float32 -> tl.float32 entry in dtype_tc2tn, so a
    # float32 dtype survives a round trip instead of raising.
    if dtype_triton == tl.float32:
        return torch.float32
    raise ValueError(dtype_triton)


def dtype_to_max_torch(dtype):
    dtype_max = None
    if dtype == torch.int8:
        dtype_max = 127.0
    elif dtype == torch.float16:
        dtype_max = 65504.0
    elif dtype == torch.float8_e5m2:
        dtype_max = 57344.0
    elif dtype == torch.float8_e4m3fn:
        dtype_max = 448.0
    else:
        raise ValueError(f"unknown datatype {dtype}")

    return dtype_max


@triton.jit
def dtype_to_max(dtype):
    """Map a Triton dtype to the maximum value this module uses for it.

    Triton-side counterpart of ``dtype_to_max_torch``; in addition to the
    dtypes handled there it also covers ``float8e5b16``, ``float8e4b8`` and
    ``float8e4b15``.

    Returns the maximum as a float literal, or raises ``ValueError`` for any
    other dtype.
    """
    # NOTE(review): `raise` and `None` inside a @triton.jit function --
    # presumably this relies on `dtype` being a compile-time constant so that
    # only the matching branch is traced; confirm against the Triton version
    # in use.
    dtype_max = None
    if dtype == tl.int8:
        dtype_max = 127.0
    elif dtype == tl.float16:
        dtype_max = 65504.0
    elif dtype == tl.float8e5:
        dtype_max = 57344.0
    elif dtype == tl.float8e5b16:
        dtype_max = 28672.0
    elif dtype == tl.float8e4nv:
        dtype_max = 448.0
    elif dtype == tl.float8e4b8:
        dtype_max = 224.0
    elif dtype == tl.float8e4b15:
        dtype_max = 1.75
    else:
        raise ValueError(f"unknown datatype {dtype}")

    return dtype_max


def get_gemm_configs(bs, force_bk_1=False):
    """Build the list of ``triton.Config`` candidates for a given ``bs``.

    Every base tile size (``BM``, ``BN``, ``BK``) is floor-divided by ``bs``;
    ``GM`` is always 8. Three tiers of configs are defined and a tier is
    included only when ``bs`` does not exceed its limit (32, 64 or 128), so
    a ``bs`` above 128 yields an empty list.

    Args:
        bs: Divisor applied to the base tile sizes.
        force_bk_1: When true, every returned config has ``BK`` overwritten
            with 1 and its ``num_warps`` doubled.

    Returns:
        A freshly built list of ``triton.Config`` objects, ordered by tier
        (32, then 64, then 128).
    """

    def build(rows):
        # Each row is (BM, BN, BK, num_stages, num_warps), with tile sizes
        # given before the division by ``bs``.
        return [
            triton.Config(
                {"BM": bm // bs, "BN": bn // bs, "BK": bk // bs, "GM": 8},
                num_stages=stages,
                num_warps=warps,
            )
            for bm, bn, bk, stages, warps in rows
        ]

    # All tiers are constructed up front, regardless of which ones end up
    # selected below.
    tier_32 = build(
        [
            (64, 256, 32, 4, 4),
            (128, 128, 32, 4, 4),
            (128, 64, 32, 4, 4),
            (64, 128, 32, 4, 4),
            (128, 32, 32, 4, 4),
            (64, 32, 32, 5, 2),
            (128, 32, 64, 4, 4),
        ]
    )
    tier_64 = build(
        [
            (128, 256, 64, 3, 8),
            (256, 64, 128, 4, 4),
            (128, 64, 64, 4, 4),
            (64, 128, 64, 4, 4),
        ]
    )
    tier_128 = build(
        [
            (128, 256, 128, 3, 8),
            (256, 128, 128, 3, 8),
            (64, 256, 128, 4, 4),
            (128, 128, 128, 4, 4),
        ]
    )

    selected = []
    for limit, tier in ((32, tier_32), (64, tier_64), (128, tier_128)):
        if bs <= limit:
            selected.extend(tier)

    if force_bk_1:
        for cfg in selected:
            cfg.kwargs["BK"] = 1
            # The stage count is deliberately left untouched here.
            cfg.num_warps *= 2
    return selected
