import triton
import triton.language as tl
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import NamedTuple, Optional, Dict


def matmul_tma_no_T_n_hook(named_args):
    """Set the descriptor block shapes for a launch with a fixed 256-wide N tile.

    Mutates the ``block_shape`` attribute of ``named_args['a_td']``,
    ``named_args['b_td']`` and ``named_args['c_td']`` in place.

    Args:
        named_args: mapping of launch arguments. Must contain ``T_m``,
            ``T_k`` and the three descriptors. The epilogue-subtile flag is
            optional and may be spelled ``EPILOGUE_SUBTILE`` or
            ``epilogue_subtile``; it defaults to False.
    """
    # The kernels in this file name the flag ``epilogue_subtile``; the
    # upper-case spelling is still honoured so existing callers keep working.
    epilogue_subtile = (
        named_args.get("EPILOGUE_SUBTILE", False)
        or named_args.get("epilogue_subtile", False)
    )

    T_m = named_args['T_m']
    T_n = 256  # N tile width is fixed for this hook
    T_k = named_args['T_k']

    named_args['a_td'].block_shape = [T_m, T_k]
    named_args['b_td'].block_shape = [T_k, T_n]
    # With epilogue subtiling the output is stored as two half-width pieces.
    c_width = T_n // 2 if epilogue_subtile else T_n
    named_args['c_td'].block_shape = [T_m, c_width]


def matmul_proton_no_T_n_launch_metadata(
        grid: tuple, metadata: NamedTuple, args: dict,
) -> dict:
    """Build the profiler metadata dictionary for one kernel launch.

    Args:
        grid: launch grid; only ``grid[0]`` is reported.
        metadata: compiled-kernel metadata; ``name``, ``num_warps``,
            ``num_stages``, ``num_ctas`` and ``shared`` are read.
        args: launch arguments; ``M``, ``N``, ``K``, ``T_m`` and ``T_k`` are
            required, ``warp_specialize`` and ``epilogue_subtile`` are
            appended to the name when present.

    Returns:
        A dict with the display name, a ``flops16`` count, a byte count and
        the grid size, cluster count and shared-memory size.
    """
    M, N, K = args["M"], args["N"], args["K"]
    tile_m = args["T_m"]
    tile_n = 256  # the N tile width is fixed for these kernels
    tile_k = args["T_k"]

    warps = metadata.num_warps
    stages = metadata.num_stages
    programs = grid[0]
    ctas = metadata.num_ctas
    smem = metadata.shared

    suffix_parts = [
        f"_<tile:{tile_m}m{tile_n}n{tile_k}k>",
        f"_<warps:{warps}>",
        f"_<stages:{stages}>",
    ]
    # Optional flags are reported only when the kernel actually takes them.
    for arg_name, tag in (
        ("warp_specialize", "ws"),
        ("epilogue_subtile", "ep_subtile"),
    ):
        if arg_name in args:
            suffix_parts.append(f"_<{tag}:{args[arg_name]}>")
    suffix = "".join(suffix_parts)

    elem_bytes = 2  # element size used for both the flops key and byte count
    return {
        "name": f"{metadata.name} [M={M}, N={N}, K={K}]{suffix}",
        f"flops{elem_bytes * 8}": 2. * M * N * K,
        "bytes": elem_bytes * (M * K + N * K + M * N),
        "grid_size": programs,
        "clusters": ctas,
        "shared_memory": smem,
    }

@triton.jit(launch_metadata=matmul_proton_no_T_n_launch_metadata)
def matmul_dense_to_bwell_only_indices_tma_opt_kernel(
    a_td,  # descriptor of the left operand, loaded as (T_m, T_k) tiles
    b_td,  # descriptor of the right operand, loaded as (T_k, T_n) tiles
    c_td,  # descriptor of the output values (half-width stores when subtiled)
    idxs_ptr,  # compacted column offsets, addressed with row stride N
    nnzs_ptr,  # per-(row, column block) counts, row stride num_blocks_n
    M,
    N,
    K,
    # tile sizes
    T_m: tl.constexpr,
    T_n: tl.constexpr,
    T_k: tl.constexpr,
    num_blocks_n: tl.constexpr,
    # number of row tiles per scheduling group
    G_m: tl.constexpr,
    epilogue_subtile: tl.constexpr,
    warp_specialize: tl.constexpr,
):
    """Matmul tile kernel that also emits block-compacted positive indices.

    Each program accumulates one (T_m, T_n) tile of ``a @ b`` in float32 and
    then writes three outputs:

    * ``c_td``: the tile converted to bfloat16 with every non-positive entry
      replaced by zero;
    * ``idxs_ptr``: for each row, the in-block column offsets (uint8) of the
      positive entries, packed at the start of that block's T_n-wide slot.
      Slots past the count are left untouched;
    * ``nnzs_ptr``: the number of positive entries per row and column block.

    NOTE: the count can reach T_n. With T_n == 256 and a uint8 ``nnzs``
    buffer a fully positive block wraps to 0; pass a wider integer buffer if
    that case can occur.
    """
    tl.assume(N % T_n == 0)
    tl.assume(N / T_n == num_blocks_n)

    pid = tl.program_id(0)
    num_blocks_m = tl.cdiv(M, T_m)

    # Grouped ordering: consecutive program ids walk down up to G_m row tiles
    # before advancing to the next column tile.
    elements_per_group = num_blocks_n * G_m
    group_id = pid // elements_per_group
    group_iter = pid % elements_per_group

    starting_grouped_m = group_id * G_m
    # The last group may hold fewer than G_m row tiles.
    current_group_size = min(num_blocks_m - starting_grouped_m, G_m)

    m = starting_grouped_m + (group_iter % current_group_size)
    n = group_iter // current_group_size

    start_index_m = m * T_m
    start_index_n = n * T_n

    c_block_accumulator = tl.zeros((T_m, T_n), dtype=tl.float32)
    for start_index_k in tl.range(0, K, T_k, warp_specialize=warp_specialize):
        a_block = a_td.load([start_index_m, start_index_k])
        b_block = b_td.load([start_index_k, start_index_n])
        c_block_accumulator = tl.dot(a_block, b_block, acc=c_block_accumulator)

    T_m_range = tl.arange(0, T_m)[:, None]
    # Rows of the last row tile can lie past M; the pointer stores to nnzs
    # below must not touch them.
    row_in_bounds = (start_index_m + T_m_range) < M

    if epilogue_subtile:
        tl.static_assert(T_n % 2 == 0)
        T_n_half: tl.constexpr = T_n // 2

        # Split the accumulator into its left and right column halves.
        c_block_accumulator = tl.reshape(c_block_accumulator, (T_m, 2, T_n_half))
        c_block_accumulator = tl.permute(c_block_accumulator, (0, 2, 1))
        c_block_accumulator_0, c_block_accumulator_1 = tl.split(c_block_accumulator)

        # Left half of the values.
        c_block_0 = c_block_accumulator_0.to(tl.bfloat16)
        is_positive_0 = c_block_accumulator_0 > 0
        c_td.store(
            [start_index_m, start_index_n],
            value=tl.where(is_positive_0, c_block_0, 0),
        )

        # Right half of the values.
        c_block_1 = c_block_accumulator_1.to(tl.bfloat16)
        is_positive_1 = c_block_accumulator_1 > 0
        c_td.store(
            [start_index_m, start_index_n + T_n_half],
            value=tl.where(is_positive_1, c_block_1, 0),
        )

        # Left-half indices: the exclusive prefix sum of the positive mask is
        # the packed destination slot of each positive column.
        # NOTE(review): rows past M are skipped here only because their
        # accumulator is zero, which assumes out-of-range descriptor loads
        # return zeros - confirm for the descriptors in use.
        T_n_range_half = tl.arange(0, T_n_half)[None, :]
        is_positive_int_0 = is_positive_0.to(tl.int32)
        n_offsets_0 = tl.cumsum(is_positive_int_0, axis=1) - is_positive_int_0

        idxs_block_ptr_0 = idxs_ptr + (start_index_m + T_m_range) * N + (
            start_index_n + n_offsets_0
        )
        tl.store(idxs_block_ptr_0, T_n_range_half.to(tl.uint8),
                 mask=is_positive_0)

        nnz_0 = tl.sum(is_positive_int_0, keep_dims=True, axis=1)

        # Right-half indices continue right after the left half's entries.
        is_positive_int_1 = is_positive_1.to(tl.int32)
        n_offsets_1 = tl.cumsum(is_positive_int_1, axis=1) - is_positive_int_1

        idxs_block_ptr_1 = idxs_ptr + (start_index_m + T_m_range) * N + (
            start_index_n + nnz_0 + n_offsets_1
        )
        tl.store(
            idxs_block_ptr_1,
            (T_n_range_half + T_n_half).to(tl.uint8),
            mask=is_positive_1,
        )

        nnz_1 = tl.sum(is_positive_int_1, keep_dims=True, axis=1)

        # Per-row count for this column block, written only for real rows.
        nnzs_block_ptr = nnzs_ptr + (start_index_m + T_m_range) * num_blocks_n + n
        tl.store(
            nnzs_block_ptr,
            (nnz_0 + nnz_1).to(nnzs_ptr.dtype.element_ty),
            mask=row_in_bounds,
        )

    else:
        c_block = c_block_accumulator.to(tl.bfloat16)
        is_positive = c_block_accumulator > 0
        c_td.store([start_index_m, start_index_n],
                   value=tl.where(is_positive, c_block, 0))

        T_n_range = tl.arange(0, T_n)[None, :]
        is_positive_int = is_positive.to(tl.int32)

        # Exclusive prefix sum gives each positive column its packed slot.
        n_offsets = tl.cumsum(is_positive_int, axis=1) - is_positive_int
        idxs_block_ptr = idxs_ptr + (start_index_m + T_m_range) * N + (
            start_index_n + n_offsets
        )
        tl.store(idxs_block_ptr, T_n_range.to(tl.uint8), mask=is_positive)

        # Per-row count for this column block, written only for real rows.
        nnzs_block_ptr = nnzs_ptr + (start_index_m + T_m_range) * num_blocks_n + n
        nnz = tl.sum(is_positive_int, keep_dims=True, axis=1)
        tl.store(
            nnzs_block_ptr,
            nnz.to(nnzs_ptr.dtype.element_ty),
            mask=row_in_bounds,
        )


def matmul_dense_to_bwell_only_indices_tma_opt(a, b, warp_specialize: bool = False):
    """Launch the subtiled TMA matmul that also emits compacted indices.

    Args:
        a: contiguous 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N); N must be a multiple of 256.
        warp_specialize: forwarded to the kernel's K loop; a warning is
            printed when it is enabled.

    Returns:
        ``(c, idxs, nnzs)``: the bfloat16 (M, N) values, the uint8 (M, N)
        index buffer and the uint8 (M, N // 256) count buffer. The buffers
        are allocated uninitialised, so only entries the kernel writes are
        meaningful.
    """
    assert a.shape[1] == b.shape[0], "Inner dimensions must match for tma nn ops"
    assert a.is_contiguous(), "Check memory in a is contiguous"
    M, K = a.shape
    N = b.shape[1]
    device = a.device

    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)
    if warp_specialize:
        print(
            "Warning: warp specialize enabled for non persistent TMA matmul "
            "is this running on blackwell?")

    # Fixed tile configuration for this wrapper.
    T_m, T_k, T_n = 128, 64, 256
    assert N % T_n == 0, "N must be multiple of 256 for this bwell nn ops"
    num_blocks_n = N // T_n

    a_td = TensorDescriptor.from_tensor(a, block_shape=[T_m, T_k])
    b_td = TensorDescriptor.from_tensor(b, block_shape=[T_k, T_n])
    # The kernel is launched with epilogue_subtile=True and therefore stores
    # the output in half-width pieces.
    c_td = TensorDescriptor.from_tensor(c, block_shape=[T_m, T_n//2])

    idxs = torch.empty((M, N), dtype=torch.uint8, device=device)
    nnzs = torch.empty((M, num_blocks_n), dtype=torch.uint8, device=device)

    def launch_grid(meta):
        # One program per (row tile, column tile) pair.
        row_tiles = triton.cdiv(M, meta['T_m'])
        return (row_tiles * num_blocks_n,)

    matmul_dense_to_bwell_only_indices_tma_opt_kernel[launch_grid](
        a_td=a_td,
        b_td=b_td,
        c_td=c_td,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        K=K,
        T_m=T_m,
        T_n=T_n,
        T_k=T_k,
        num_blocks_n=num_blocks_n,
        G_m=8,
        epilogue_subtile=True,
        warp_specialize=warp_specialize,
        num_warps=8,
        num_stages=4,
    )

    return c, idxs, nnzs

@triton.jit(launch_metadata=matmul_proton_no_T_n_launch_metadata)
def matmul_dense_to_bwell_only_indices_tma_kernel(
    a_td,  # descriptor of the left operand, loaded as (T_m, T_k) tiles
    b_td,  # descriptor of the right operand, loaded as (T_k, T_n) tiles
    c_td,  # descriptor of the output values, stored as (T_m, T_n // 2) halves
    idxs_ptr,  # compacted column offsets, addressed with row stride N
    nnzs_ptr,  # per-(row, column block) counts, row stride num_blocks_n
    M,
    N,
    K,
    # tile sizes
    T_m: tl.constexpr,
    T_n: tl.constexpr,
    T_k: tl.constexpr,
    num_blocks_n: tl.constexpr,
    # number of row tiles per scheduling group
    G_m: tl.constexpr,
    warp_specialize: tl.constexpr,
):
    """Matmul tile kernel with a two-half epilogue and compacted indices.

    Each program accumulates one (T_m, T_n) tile of ``a @ b`` in float32,
    splits it into left and right column halves and writes:

    * ``c_td``: each half converted to bfloat16 with non-positive entries
      replaced by zero;
    * ``idxs_ptr``: for each row, the in-block column offsets (uint8) of the
      positive entries, packed at the start of that block's T_n-wide slot.
      Slots past the count are left untouched;
    * ``nnzs_ptr``: the number of positive entries per row and column block.

    NOTE: the count can reach T_n. With T_n == 256 and a uint8 ``nnzs``
    buffer a fully positive block wraps to 0; pass a wider integer buffer if
    that case can occur.
    """
    tl.assume(N % T_n == 0)
    tl.assume(N / T_n == num_blocks_n)
    tl.assume(T_n % 2 == 0)

    pid = tl.program_id(0)
    num_blocks_m = tl.cdiv(M, T_m)

    # Grouped ordering: consecutive program ids walk down up to G_m row tiles
    # before advancing to the next column tile.
    elements_per_group = num_blocks_n * G_m
    group_id = pid // elements_per_group
    group_iter = pid % elements_per_group

    starting_grouped_m = group_id * G_m
    # The last group may hold fewer than G_m row tiles.
    current_group_size = min(num_blocks_m - starting_grouped_m, G_m)

    m = starting_grouped_m + (group_iter % current_group_size)
    n = group_iter // current_group_size

    start_index_m = m * T_m
    start_index_n = n * T_n

    c_block_accumulator = tl.zeros((T_m, T_n), dtype=tl.float32)
    for start_index_k in tl.range(0, K, T_k, warp_specialize=warp_specialize):
        a_block = a_td.load([start_index_m, start_index_k])
        b_block = b_td.load([start_index_k, start_index_n])
        c_block_accumulator = tl.dot(a_block, b_block, acc=c_block_accumulator)

    T_n_half: tl.constexpr = T_n // 2

    # Split the accumulator into its left and right column halves.
    c_block_accumulator = tl.reshape(c_block_accumulator, (T_m, 2, T_n_half))
    c_block_accumulator = tl.permute(c_block_accumulator, (0, 2, 1))
    c_block_accumulator_0, c_block_accumulator_1 = tl.split(c_block_accumulator)

    # Left half of the values.
    c_block_0 = c_block_accumulator_0.to(tl.bfloat16)
    is_positive_0 = c_block_accumulator_0 > 0
    c_td.store(
        [start_index_m, start_index_n],
        value=tl.where(is_positive_0, c_block_0, 0),
    )

    # Right half of the values.
    c_block_1 = c_block_accumulator_1.to(tl.bfloat16)
    is_positive_1 = c_block_accumulator_1 > 0
    c_td.store(
        [start_index_m, start_index_n + T_n_half],
        value=tl.where(is_positive_1, c_block_1, 0),
    )

    # Left-half indices: the exclusive prefix sum of the positive mask is the
    # packed destination slot of each positive column.
    # NOTE(review): rows past M are skipped here only because their
    # accumulator is zero, which assumes out-of-range descriptor loads return
    # zeros - confirm for the descriptors in use.
    T_m_range = tl.arange(0, T_m)[:, None]
    T_n_range_half = tl.arange(0, T_n_half)[None, :]
    is_positive_int_0 = is_positive_0.to(tl.int32)
    n_offsets_0 = tl.cumsum(is_positive_int_0, axis=1) - is_positive_int_0

    idxs_block_ptr_0 = idxs_ptr + (start_index_m + T_m_range) * N + (
        start_index_n + n_offsets_0
    )
    tl.store(idxs_block_ptr_0, T_n_range_half.to(tl.uint8),
             mask=is_positive_0)

    nnz_0 = tl.sum(is_positive_int_0, keep_dims=True, axis=1)

    # Right-half indices continue right after the left half's entries.
    is_positive_int_1 = is_positive_1.to(tl.int32)
    n_offsets_1 = tl.cumsum(is_positive_int_1, axis=1) - is_positive_int_1

    idxs_block_ptr_1 = idxs_ptr + (start_index_m + T_m_range) * N + (
        start_index_n + nnz_0 + n_offsets_1
    )
    tl.store(
        idxs_block_ptr_1,
        (T_n_range_half + T_n_half).to(tl.uint8),
        mask=is_positive_1,
    )

    nnz_1 = tl.sum(is_positive_int_1, keep_dims=True, axis=1)

    # Per-row count for this column block. Rows of the last row tile can lie
    # past M, so the store is masked to stay inside the nnzs buffer.
    row_in_bounds = (start_index_m + T_m_range) < M
    nnzs_block_ptr = nnzs_ptr + (start_index_m + T_m_range) * num_blocks_n + n
    tl.store(
        nnzs_block_ptr,
        (nnz_0 + nnz_1).to(nnzs_ptr.dtype.element_ty),
        mask=row_in_bounds,
    )


def matmul_dense_to_bwell_only_indices_tma(a, b, warp_specialize: bool = False):
    """Launch the TMA matmul that also emits compacted positive indices.

    Args:
        a: contiguous 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N); N must be a multiple of 256.
        warp_specialize: forwarded to the kernel's K loop; a warning is
            printed when it is enabled.

    Returns:
        ``(c, idxs, nnzs)``: the bfloat16 (M, N) values, the uint8 (M, N)
        index buffer and the uint8 (M, N // 256) count buffer. The buffers
        are allocated uninitialised, so only entries the kernel writes are
        meaningful.
    """
    assert a.shape[1] == b.shape[0], "Inner dimensions must match for tma nn ops"
    assert a.is_contiguous(), "Check memory in a is contiguous"
    M, K = a.shape
    N = b.shape[1]
    device = a.device

    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)
    if warp_specialize:
        print(
            "Warning: warp specialize enabled for non persistent TMA matmul "
            "is this running on blackwell?")

    # Fixed tile configuration for this wrapper.
    T_m, T_k, T_n = 128, 64, 256
    assert N % T_n == 0, "N must be multiple of 256 for this bwell nn ops"
    num_blocks_n = N // T_n

    a_td = TensorDescriptor.from_tensor(a, block_shape=[T_m, T_k])
    b_td = TensorDescriptor.from_tensor(b, block_shape=[T_k, T_n])
    # The kernel stores the output tile as two half-width pieces.
    c_td = TensorDescriptor.from_tensor(c, block_shape=[T_m, T_n//2])

    idxs = torch.empty((M, N), dtype=torch.uint8, device=device)
    nnzs = torch.empty((M, num_blocks_n), dtype=torch.uint8, device=device)

    def launch_grid(meta):
        # One program per (row tile, column tile) pair.
        row_tiles = triton.cdiv(M, meta['T_m'])
        return (row_tiles * num_blocks_n,)

    matmul_dense_to_bwell_only_indices_tma_kernel[launch_grid](
        a_td=a_td,
        b_td=b_td,
        c_td=c_td,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        K=K,
        T_m=T_m,
        T_n=T_n,
        T_k=T_k,
        num_blocks_n=num_blocks_n,
        G_m=8,
        warp_specialize=warp_specialize,
        num_warps=8,
        num_stages=4,
    )

    return c, idxs, nnzs

@triton.jit(launch_metadata=matmul_proton_no_T_n_launch_metadata)
def matmul_dense_to_bwell_only_indices_tma_alt_kernel(
    a_td,  # descriptor of the left operand, loaded as (T_m, T_k) tiles
    b_td,  # descriptor of the right operand, loaded as (T_k, T_n) tiles
    c_td,  # descriptor of the output values, stored as (T_m, T_n // 2) halves
    idxs_td,  # descriptor of the index output, stored as (T_m, T_n) tiles
    nnzs_ptr,  # per-(row, column block) counts, row stride num_blocks_n
    M,
    N,
    K,
    T_m: tl.constexpr,
    T_n: tl.constexpr,
    T_k: tl.constexpr,
    num_blocks_n: tl.constexpr,
    G_m: tl.constexpr,
    warp_specialize: tl.constexpr,
):
    """Matmul tile kernel that builds the index list with a per-row sort.

    Each program accumulates one (T_m, T_n) tile of ``a @ b`` in float32 and
    writes:

    * ``c_td``: the tile in two column halves, converted to bfloat16 with
      non-positive entries replaced by zero;
    * ``idxs_td``: per row, the in-block column offsets of the positive
      entries in ascending order, followed by filler. Only the first
      ``nnzs`` entries of each block are meaningful;
    * ``nnzs_ptr``: the number of positive entries per row and column block.

    NOTE: the count can reach T_n. With T_n == 256 and a uint8 ``nnzs``
    buffer a fully positive block wraps to 0; pass a wider integer buffer if
    that case can occur.
    """
    tl.assume(N % T_n == 0)
    tl.assume(N / T_n == num_blocks_n)
    tl.assume(T_n % 2 == 0)

    pid = tl.program_id(0)
    num_blocks_m = tl.cdiv(M, T_m)

    # Grouped ordering: consecutive program ids walk down up to G_m row tiles
    # before advancing to the next column tile.
    elements_per_group = num_blocks_n * G_m
    group_id = pid // elements_per_group
    group_iter = pid % elements_per_group

    starting_grouped_m = group_id * G_m
    # The last group may hold fewer than G_m row tiles.
    current_group_size = min(num_blocks_m - starting_grouped_m, G_m)

    m = starting_grouped_m + (group_iter % current_group_size)
    n = group_iter // current_group_size

    start_index_m = m * T_m
    start_index_n = n * T_n

    c_acc = tl.zeros((T_m, T_n), dtype=tl.float32)
    for start_index_k in tl.range(0, K, T_k, warp_specialize=warp_specialize):
        a_block = a_td.load([start_index_m, start_index_k])
        b_block = b_td.load([start_index_k, start_index_n])
        c_acc = tl.dot(a_block, b_block, acc=c_acc)

    # Positive mask over the full tile, used for the indices and the counts.
    is_pos = c_acc > 0

    # Store the values as two column halves.
    T_n_half: tl.constexpr = T_n // 2
    c_acc_rs = tl.reshape(c_acc, (T_m, 2, T_n_half))
    c_acc_rs = tl.permute(c_acc_rs, (0, 2, 1))
    c_acc_0, c_acc_1 = tl.split(c_acc_rs)

    c_block_0 = c_acc_0.to(tl.bfloat16)
    is_pos_0 = c_acc_0 > 0
    c_td.store(
        [start_index_m, start_index_n],
        value=tl.where(is_pos_0, c_block_0, 0),
    )

    c_block_1 = c_acc_1.to(tl.bfloat16)
    is_pos_1 = c_acc_1 > 0
    c_td.store(
        [start_index_m, start_index_n + T_n_half],
        value=tl.where(is_pos_1, c_block_1, 0),
    )

    # In-block column offset of every element.
    cols = tl.arange(0, T_n)[None, :].to(tl.int32)
    cols = tl.broadcast_to(cols, (T_m, T_n))

    # Non-positive columns get the sentinel T_n, which is larger than any
    # real offset, so sorting each row moves the positive columns to the
    # front in ascending order.
    key = tl.where(is_pos, cols, T_n)
    key_sorted = tl.sort(key, dim=1)

    # The sentinel does not survive the narrowing to uint8 when T_n is 256,
    # so the filler entries carry no information.
    idxs_tile = key_sorted.to(tl.uint8)
    idxs_td.store([start_index_m, start_index_n], idxs_tile)

    # Per-row count for this column block. Rows of the last row tile can lie
    # past M, so the pointer store is masked to stay inside the nnzs buffer.
    tmr = tl.arange(0, T_m)[:, None]
    row_in_bounds = (start_index_m + tmr) < M
    nnz = tl.sum(is_pos.to(tl.int32), axis=1, keep_dims=True)
    nnzs_block_ptr = nnzs_ptr + (start_index_m + tmr) * num_blocks_n + n
    tl.store(
        nnzs_block_ptr,
        nnz.to(nnzs_ptr.dtype.element_ty),
        mask=row_in_bounds,
    )



def matmul_dense_to_bwell_only_indices_tma_alt(a, b, warp_specialize=False):
    """Launch the sort-based TMA matmul that also emits positive indices.

    Args:
        a: contiguous 2-D left operand of shape (M, K).
        b: 2-D right operand of shape (K, N); N must be a multiple of 256.
        warp_specialize: forwarded to the kernel's K loop.

    Returns:
        ``(c, idxs, nnzs)``: the bfloat16 (M, N) values, the uint8 (M, N)
        index buffer and the uint8 (M, N // 256) count buffer.
    """
    assert a.shape[1] == b.shape[0]
    assert a.is_contiguous()
    M, K = a.shape
    N = b.shape[1]
    device = a.device

    c = torch.empty((M, N), device=device, dtype=torch.bfloat16)

    # Fixed tile configuration for this wrapper.
    T_m, T_k, T_n = 128, 64, 256
    assert N % T_n == 0
    num_blocks_n = N // T_n

    a_td = TensorDescriptor.from_tensor(a, block_shape=[T_m, T_k])
    b_td = TensorDescriptor.from_tensor(b, block_shape=[T_k, T_n])
    # Values are stored as two half-width pieces per tile.
    c_td = TensorDescriptor.from_tensor(c, block_shape=[T_m, T_n // 2])

    # Indices are written through a descriptor, one full tile at a time.
    idxs = torch.empty((M, N), dtype=torch.uint8, device=device)
    idxs_td = TensorDescriptor.from_tensor(idxs, block_shape=[T_m, T_n])

    nnzs = torch.empty((M, num_blocks_n), dtype=torch.uint8, device=device)

    def launch_grid(meta):
        # One program per (row tile, column tile) pair.
        row_tiles = triton.cdiv(M, meta["T_m"])
        return (row_tiles * num_blocks_n,)

    launch_kwargs = dict(
        a_td=a_td,
        b_td=b_td,
        c_td=c_td,
        idxs_td=idxs_td,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        K=K,
        T_m=T_m,
        T_n=T_n,
        T_k=T_k,
        num_blocks_n=num_blocks_n,
        G_m=8,
        warp_specialize=warp_specialize,
        num_warps=8,
        num_stages=4,
    )
    matmul_dense_to_bwell_only_indices_tma_alt_kernel[launch_grid](**launch_kwargs)

    return c, idxs, nnzs



def get_block256_compress_autotune_config():
    """Return the autotune candidates for the block-compression kernel.

    Every combination of 1/2/4/8 warps and 1/2/3 stages is listed, with the
    warp count varying slowest. No kernel meta-parameters are overridden.
    """
    return [
        triton.Config({}, num_stages=stages, num_warps=warps)
        for warps in (1, 2, 4, 8)
        for stages in (1, 2, 3)
    ]


@triton.autotune(
    configs=get_block256_compress_autotune_config(),
    key=["M", "N"],
)
@triton.jit
def compress_nonzeros_by_block_kernel(
    input_ptr,
    values_ptr,
    idxs_ptr,
    nnzs_ptr,
    M,
    N,
    input_stride_m,
    input_stride_n,
    values_stride_m,
    values_stride_n,
    idxs_stride_m,
    idxs_stride_n,
    nnzs_stride_m,
    nnzs_stride_n,
    T_n: tl.constexpr,
):
    """Pack the positive entries of one T_n-wide row segment to its front.

    Each program handles one (row, column block) pair. The positive values
    and their in-block column offsets (uint8) are written contiguously from
    the start of the block's slot in ``values_ptr`` / ``idxs_ptr``; the
    number of positive entries goes to ``nnzs_ptr``. Slots past the count
    are left untouched.
    """
    pid = tl.program_id(0)
    blocks_per_row = tl.cdiv(N, T_n)

    # Decompose the flat program id into (row, column block).
    row = pid // blocks_per_row
    block_n = pid - row * blocks_per_row
    block_origin = block_n * T_n

    lane = tl.arange(0, T_n)
    cols = block_origin + lane

    # Columns past N read as zero and are therefore never kept.
    in_range = cols < N
    segment = tl.load(
        input_ptr + row * input_stride_m + cols * input_stride_n,
        mask=in_range,
        other=0,
    )

    keep = segment > 0
    keep_int = keep.to(tl.int32)

    # Exclusive prefix sum: destination slot of every kept element.
    packed_cols = block_origin + (tl.cumsum(keep_int, axis=0) - keep_int)

    values_row_ptr = values_ptr + row * values_stride_m
    tl.store(
        values_row_ptr + packed_cols * values_stride_n,
        segment,
        mask=keep,
    )

    idxs_row_ptr = idxs_ptr + row * idxs_stride_m
    tl.store(
        idxs_row_ptr + packed_cols * idxs_stride_n,
        lane.to(tl.uint8),
        mask=keep,
    )

    kept_count = tl.sum(keep_int, axis=0)
    count_ptr = nnzs_ptr + row * nnzs_stride_m + block_n * nnzs_stride_n
    tl.store(count_ptr, kept_count.to(tl.uint8))


def compress_nonzeros_by_block(c):
    """Pack the positive entries of each 256-wide column block of ``c``.

    Args:
        c: 2-D CUDA tensor of shape (M, N).

    Returns:
        ``(values, idxs, nnzs)``: a buffer with ``c``'s shape and dtype
        holding the packed values, a uint8 (M, N) buffer holding the packed
        in-block column offsets, and a uint8 (M, ceil(N / 256)) buffer of
        per-block counts. The buffers are allocated uninitialised, so
        entries past each block's count are undefined.
    """
    assert c.ndim == 2
    assert c.is_cuda

    T_n = 256
    M, N = c.shape
    device = c.device
    num_blocks_n = triton.cdiv(N, T_n)

    values = torch.empty((M, N), device=device, dtype=c.dtype)
    idxs = torch.empty((M, N), device=device, dtype=torch.uint8)
    nnzs = torch.empty((M, num_blocks_n), device=device, dtype=torch.uint8)

    # One program per (row, column block) pair.
    launch_grid = (M * num_blocks_n,)

    in_stride_m, in_stride_n = c.stride(0), c.stride(1)
    val_stride_m, val_stride_n = values.stride(0), values.stride(1)
    idx_stride_m, idx_stride_n = idxs.stride(0), idxs.stride(1)
    nnz_stride_m, nnz_stride_n = nnzs.stride(0), nnzs.stride(1)

    compress_nonzeros_by_block_kernel[launch_grid](
        input_ptr=c,
        values_ptr=values,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=M,
        N=N,
        input_stride_m=in_stride_m,
        input_stride_n=in_stride_n,
        values_stride_m=val_stride_m,
        values_stride_n=val_stride_n,
        idxs_stride_m=idx_stride_m,
        idxs_stride_n=idx_stride_n,
        nnzs_stride_m=nnz_stride_m,
        nnzs_stride_n=nnz_stride_n,
        T_n=T_n,
    )

    return values, idxs, nnzs


@triton.jit()
def matmul_bwell_to_dense_only_indices_kernel(
    # left operand values, read at the positions named by the index lists
    a_ptr,
    # per-row, per-block packed in-block column offsets
    indices_by_block_ptr,
    # per-row, per-block counts; read as row-major with K_blocks columns
    nonzeros_by_block_ptr,
    b_ptr,
    c_ptr,
    # sizes
    M,
    N: tl.constexpr,
    # number of column blocks per row of a, and the width of each block
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    a_m,
    a_k,
    # strides
    indices_stride_m,
    indices_stride_k,
    b_stride_k,
    b_stride_n,
    c_stride_m,
    c_stride_n,
):
    """Compute one output row of ``a @ b`` using only the listed columns.

    One program handles one row of ``a``. For every column block it reads
    the block's count, then for each listed in-block offset loads the value
    from ``a`` and the matching row of ``b`` and accumulates their product in
    float32. The finished row is stored to ``c`` as bfloat16.

    ``N`` is used as a ``tl.arange`` extent, so it has to be a size that
    ``tl.arange`` accepts.
    """
    row_idx = tl.program_id(0)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # The whole column position is scaled by the index stride, not
            # just the offset inside the block.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_start + idx) * indices_stride_k
            ).to(tl.int64) + block_start
            value = tl.load(
                a_ptr + row_idx * a_m + index * a_k,)

            # Row of b selected by the absolute column index.
            b_row = tl.load(
                b_ptr + index * b_stride_k + tl.arange(0, N) * b_stride_n)

            # Widen before multiplying so the product is not rounded to the
            # input precision ahead of the float32 accumulation.
            output_accumulator += value.to(tl.float32) * b_row.to(tl.float32)
    c_row_ptr = c_ptr + row_idx * c_stride_m + tl.arange(0, N) * c_stride_n
    tl.store(c_row_ptr, output_accumulator.to(tl.bfloat16))


def matmul_bwell_to_dense_only_indices(a, indices_by_block, nonzeros_by_block, b):
    """Multiply a block-indexed sparse ``a`` by a dense ``b``.

    Args:
        a: contiguous (M, K) tensor of values.
        indices_by_block: contiguous tensor of packed in-block column
            offsets, laid out like ``a``.
        nonzeros_by_block: (M, K_blocks) tensor with the number of listed
            entries per row and column block; K must be divisible by
            K_blocks.
        b: contiguous (K, N) dense right operand.

    Returns:
        The (M, N) bfloat16 product, with one kernel program per row.

    Raises:
        AssertionError: if an operand is not contiguous or the shapes are
            inconsistent.
    """
    assert a.is_contiguous(), 'values must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert b.is_contiguous(), 'b must be contiguous'

    M, K = a.shape
    K_b, N = b.shape
    M_nnz, K_blocks = nonzeros_by_block.shape

    # Mismatched shapes would make the kernel read outside its inputs.
    assert K == K_b, "inner dimensions of a and b must match"
    assert M == M_nnz, "nonzeros_by_block must have one row per row of a"
    assert K % K_blocks == 0, "K must be divisible by K_blocks"
    block_size = K // K_blocks

    # The kernel addresses the counts as a dense row-major (M, K_blocks)
    # array, so a strided view has to be materialised first.
    nonzeros_by_block = nonzeros_by_block.contiguous()

    c = torch.empty(
        (M, N),
        device=a.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    matmul_bwell_to_dense_only_indices_kernel[grid](
        a_ptr=a,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        a_m=a.stride(0),
        a_k=a.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        b_stride_k=b.stride(0),
        b_stride_n=b.stride(1),
        c_stride_m=c.stride(0),
        c_stride_n=c.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return c


@triton.jit()
def matmul_bwell_to_dense_only_indices_alt_kernel(
    # left operand values, read at the positions named by the index lists
    a_ptr,
    # per-row, per-block packed in-block column offsets
    indices_by_block_ptr,
    # per-row, per-block counts; read as row-major with K_blocks columns
    nonzeros_by_block_ptr,
    b_ptr,
    c_ptr,
    # sizes
    M,
    N: tl.constexpr,
    # number of column blocks per row of a, and the width of each block
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    a_m,
    a_k,
    # strides
    indices_stride_m,
    indices_stride_k,
    b_stride_k,
    b_stride_n,
    c_stride_m,
    c_stride_n,
):
    """Compute one output row of ``a @ b`` using only the listed columns.

    One program handles one row of ``a``. For every column block it reads
    the block's count, then for each listed in-block offset loads the value
    from ``a`` and the matching row of ``b`` and accumulates their product in
    float32. The finished row is stored to ``c`` as bfloat16.

    ``N`` is used as a ``tl.arange`` extent, so it has to be a size that
    ``tl.arange`` accepts.
    """
    row_idx = tl.program_id(0)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # The whole column position is scaled by the index stride, not
            # just the offset inside the block.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_start + idx) * indices_stride_k
            ).to(tl.int64) + block_start
            value = tl.load(
                a_ptr + row_idx * a_m + index * a_k,)

            # Row of b selected by the absolute column index.
            b_row = tl.load(
                b_ptr + index * b_stride_k + tl.arange(0, N) * b_stride_n)

            # Widen before multiplying so the product is not rounded to the
            # input precision ahead of the float32 accumulation.
            output_accumulator += value.to(tl.float32) * b_row.to(tl.float32)
    c_row_ptr = c_ptr + row_idx * c_stride_m + tl.arange(0, N) * c_stride_n
    tl.store(c_row_ptr, output_accumulator.to(tl.bfloat16))


def matmul_bwell_to_dense_only_indices_alt(a, indices_by_block, nonzeros_by_block, b):
    """Multiply a block-indexed sparse ``a`` by a dense ``b`` (alt kernel).

    Args:
        a: contiguous (M, K) tensor of values.
        indices_by_block: contiguous tensor of packed in-block column
            offsets, laid out like ``a``.
        nonzeros_by_block: (M, K_blocks) tensor with the number of listed
            entries per row and column block; K must be divisible by
            K_blocks.
        b: contiguous (K, N) dense right operand.

    Returns:
        The (M, N) bfloat16 product, with one kernel program per row.

    Raises:
        AssertionError: if an operand is not contiguous or the shapes are
            inconsistent.
    """
    assert a.is_contiguous(), 'values must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert b.is_contiguous(), 'b must be contiguous'

    M, K = a.shape
    K_b, N = b.shape
    M_nnz, K_blocks = nonzeros_by_block.shape

    # Mismatched shapes would make the kernel read outside its inputs.
    assert K == K_b, "inner dimensions of a and b must match"
    assert M == M_nnz, "nonzeros_by_block must have one row per row of a"
    assert K % K_blocks == 0, "K must be divisible by K_blocks"
    block_size = K // K_blocks

    # The kernel addresses the counts as a dense row-major (M, K_blocks)
    # array, so a strided view has to be materialised first.
    nonzeros_by_block = nonzeros_by_block.contiguous()

    c = torch.empty(
        (M, N),
        device=a.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    matmul_bwell_to_dense_only_indices_alt_kernel[grid](
        a_ptr=a,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        a_m=a.stride(0),
        a_k=a.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        b_stride_k=b.stride(0),
        b_stride_n=b.stride(1),
        c_stride_m=c.stride(0),
        c_stride_n=c.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return c

def bwell_mlp_v0(x, b1, b2):
    """Run the two-stage MLP: dense-to-indexed matmul, then indexed-to-dense.

    Stage one computes ``x @ b1``, zeroes the non-positive entries and
    records, per 256-wide column block, which columns are positive. Stage
    two multiplies that result by ``b2`` while visiting only the recorded
    columns.

    Args:
        x: (B, D) input.
        b1: (D, D_inn) first weight; D_inn must be a multiple of 256.
        b2: (D_inn, D) second weight.

    Returns:
        The (B, D) bfloat16 output.

    Raises:
        AssertionError: if the shapes are inconsistent or D_inn is not a
            multiple of 256.
    """
    T_m = 128
    T_k = 64
    T_n = 256

    B, D = x.shape
    D_b1, D_inn = b1.shape
    assert D == D_b1, "Inner dimensions of x and b1 must match"
    assert D_inn % T_n == 0, "D_inn must be multiple of 256 for this bwell nn ops"
    assert tuple(b2.shape) == (D_inn, D), "b2 must have shape (D_inn, D)"

    num_blocks_n = D_inn // T_n
    # Round up so a final partial row tile is still computed.
    num_blocks_m = triton.cdiv(B, T_m)

    c1 = torch.empty((B, D_inn), device=x.device, dtype=torch.bfloat16)
    a_td = TensorDescriptor.from_tensor(x, block_shape=[T_m, T_k])
    b_td = TensorDescriptor.from_tensor(b1, block_shape=[T_k, T_n])
    # The first stage stores its output in half-width pieces.
    c_td = TensorDescriptor.from_tensor(c1, block_shape=[T_m, T_n // 2])

    idxs = torch.empty((B, D_inn), dtype=torch.uint8, device=x.device)
    # Padded to a whole number of row tiles so per-tile count writes for the
    # last partial tile always stay in bounds. The extra rows are never read.
    nnzs = torch.empty(
        (num_blocks_m * T_m, num_blocks_n), dtype=torch.uint8, device=x.device
    )

    matmul_dense_to_bwell_only_indices_tma_opt_kernel[(
            num_blocks_m * num_blocks_n,)](
        a_td=a_td,
        b_td=b_td,
        c_td=c_td,
        idxs_ptr=idxs,
        nnzs_ptr=nnzs,
        M=B, N=D_inn, K=D, T_n=T_n, num_blocks_n=num_blocks_n,
        warp_specialize=False,
        T_m=T_m,
        T_k=T_k,
        G_m=8,
        num_warps=8,
        num_stages=4,
        epilogue_subtile=True,
    )

    c2 = torch.empty((B, D), device=x.device, dtype=torch.bfloat16)

    # One program per output row; only the first B rows of nnzs are read.
    matmul_bwell_to_dense_only_indices_kernel[(B,)](
        a_ptr=c1,
        indices_by_block_ptr=idxs,
        nonzeros_by_block_ptr=nnzs,
        b_ptr=b2,
        c_ptr=c2,
        M=B,
        N=D,
        K_blocks=num_blocks_n,
        block_size=T_n,
        a_m=c1.stride(0),
        a_k=c1.stride(1),
        indices_stride_m=idxs.stride(0),
        indices_stride_k=idxs.stride(1),
        b_stride_k=b2.stride(0),
        b_stride_n=b2.stride(1),
        c_stride_m=c2.stride(0),
        c_stride_n=c2.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return c2