import triton
import triton.language as tl
import torch
from triton.tools.tensor_descriptor import TensorDescriptor
from typing import NamedTuple, Optional, Dict


def matmul_proton_no_T_n_launch_metadata(
        grid: tuple, metadata: NamedTuple, args: dict,
) -> dict:
    """Describe one matmul launch for the Proton profiler.

    The N-dimension tile is not taken from ``args``; it is always reported
    as 256.

    Args:
        grid: launch grid; only ``grid[0]`` is reported (as ``grid_size``).
        metadata: compiled-kernel metadata; ``name``, ``num_warps``,
            ``num_stages``, ``num_ctas`` and ``shared`` are read from it.
        args: kernel arguments by name. ``M``, ``N``, ``K``, ``T_m`` and
            ``T_k`` are required. ``warp_specialize`` and
            ``epilogue_subtile`` are optional; when present they are
            appended to the reported name.

    Returns:
        A dict with ``name``, ``flops16`` (``2*M*N*K``), ``bytes``
        (``2 * (M*K + N*K + M*N)``), ``grid_size``, ``clusters`` and
        ``shared_memory``.
    """
    m, n, k = args["M"], args["N"], args["K"]
    tile_m, tile_n, tile_k = args["T_m"], 256, args["T_k"]

    warps = metadata.num_warps
    stages = metadata.num_stages
    programs = grid[0]
    ctas = metadata.num_ctas
    smem = metadata.shared

    suffix_parts = [
        f"_<tile:{tile_m}m{tile_n}n{tile_k}k>",
        f"_<warps:{warps}>",
        f"_<stages:{stages}>",
    ]
    # Optional tuning knobs are tagged in this fixed order when supplied.
    for arg_name, tag in (("warp_specialize", "ws"),
                          ("epilogue_subtile", "ep_subtile")):
        if arg_name in args:
            suffix_parts.append(f"_<{tag}:{args[arg_name]}>")
    suffix = "".join(suffix_parts)

    elem_bytes = 2  # every operand element is counted as 2 bytes
    return {
        "name": f"{metadata.name} [M={m}, N={n}, K={k}]{suffix}",
        f"flops{elem_bytes * 8}": 2. * m * n * k,
        "bytes": elem_bytes * (m * k + n * k + m * n),
        "grid_size": programs,
        "clusters": ctas,
        "shared_memory": smem,
    }


@triton.jit()
def gated_up_projection_kernel(
    # Not read by the kernel; kept so callers passing it keep working. The
    # gate values are read through ``gate_ptr`` using the a_m/a_k strides.
    a_ptr,
    # Gate values, read as gate[row, k] with strides (a_m, a_k).
    gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Dense weight (row `k` holds N values) and the output rows.
    b_ptr,
    c_ptr,
    # Sizes. M is not read in the body.
    M,
    N: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    a_m,
    a_k,
    # Element strides.
    indices_stride_m,
    indices_stride_k,
    b_stride_k,
    b_stride_n,
    c_stride_m,
    c_stride_n,
):
    """Compute one row of ``c = gate @ b``, visiting only the listed indices.

    Program ``i`` handles row ``i``. For every block and every listed index
    ``k`` in it, the kernel accumulates ``gate[i, k] * b[k, :]`` in fp32.
    The finished row is stored as bf16. ``N`` must be a power of two
    because rows are addressed with ``tl.arange(0, N)`` and no masks are
    used.
    """
    row_idx = tl.program_id(0)
    cols = tl.arange(0, N)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # The stride applies to the full in-row slot offset (block base
            # + idx), matching the other kernels in this file.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start
            value = tl.load(gate_ptr + row_idx * a_m + index * a_k)
            b_row = tl.load(b_ptr + index * b_stride_k + cols * b_stride_n)
            output_accumulator += (value * b_row).to(tl.float32)
    c_row_ptr = c_ptr + row_idx * c_stride_m + cols * c_stride_n
    tl.store(c_row_ptr, output_accumulator.to(tl.bfloat16))


def gated_up_projection(a, indices_by_block, nonzeros_by_block, b):
    """Block-sparse product ``c[m] = sum_k a[m, k] * b[k]`` over listed ``k``.

    Args:
        a: [M, K] gate values; only entries listed in ``indices_by_block``
            are read.
        indices_by_block: per-row local column indices, grouped into
            ``K_blocks`` blocks of ``K // K_blocks`` slots.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        b: [K, N] dense weight.

    Returns:
        [M, N] bf16 tensor ``c``.

    Raises:
        AssertionError: on non-contiguous inputs or mismatched shapes.
    """
    assert a.is_contiguous(), 'values must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert b.is_contiguous(), 'b must be contiguous'
    # The kernel indexes nonzeros_by_block without strides.
    assert nonzeros_by_block.is_contiguous(), 'nonzeros must be contiguous'

    M, K = a.shape
    b_K, N = b.shape
    nz_M, K_blocks = nonzeros_by_block.shape
    assert b_K == K, "Mismatched b shape"
    assert nz_M == M, "Mismatched nonzeros_by_block shape"
    assert K_blocks > 0 and K % K_blocks == 0, \
        "K must be divisible by K_blocks"
    block_size = K // K_blocks

    # Every row of c is fully written by the kernel, so no zero-fill is needed.
    c = torch.empty((M, N), device=a.device, dtype=torch.bfloat16)
    grid = (M,)

    gated_up_projection_kernel[grid](
        a_ptr=a,
        # The kernel reads the gate values through gate_ptr with a's strides.
        gate_ptr=a,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        b_ptr=b,
        c_ptr=c,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        a_m=a.stride(0),
        a_k=a.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        b_stride_k=b.stride(0),
        b_stride_n=b.stride(1),
        c_stride_m=c.stride(0),
        c_stride_n=c.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return c



def gated_downprojection_proton_launch_metadata(
    grid: tuple, metadata: NamedTuple, args: dict
) -> dict:
    """Describe one block-sparse projection launch for the Proton profiler.

    Bytes and FLOPs are dense estimates over all ``K_blocks * block_size``
    hidden features, not just the nonzero ones.

    Args:
        grid: launch grid; ``grid[0]`` is reported, or 0 for an empty grid.
        metadata: compiled-kernel metadata. ``name`` is required;
            ``num_warps``, ``num_stages``, ``num_ctas`` and ``shared`` fall
            back to 0, 0, 1 and 0 when missing.
        args: kernel arguments; ``M``, ``N``, ``K_blocks`` and
            ``block_size`` are read and converted with ``int``.

    Returns:
        A dict with ``name``, ``bytes``, ``flops16``, ``grid_size``,
        ``clusters`` and ``shared_memory``.
    """
    m = int(args["M"])
    n = int(args["N"])
    k_blocks = int(args["K_blocks"])
    blk = int(args["block_size"])

    warps = int(getattr(metadata, "num_warps", 0))
    stages = int(getattr(metadata, "num_stages", 0))
    ctas = int(getattr(metadata, "num_ctas", 1))
    smem = int(getattr(metadata, "shared", 0))

    n_programs = 0 if len(grid) == 0 else int(grid[0])

    elem_bytes = 2  # every element is counted as 2 bytes
    k_total = k_blocks * blk
    # Per row: the input row (n), one up-weight row and one down-weight
    # row (2n) for each of the k_total features, and the output row (n).
    row_bytes = elem_bytes * (n + k_total * (2 * n) + n)
    # Per row and feature: 4n FLOPs (2n for the dot product, 2n for the
    # scaled accumulate into the output row).
    flops = float(m) * float(k_total) * float(4 * n)

    title = (
        f"{metadata.name} [M={m}, N={n}, K_blocks={k_blocks}, "
        f"block_size={blk}]_<warps:{warps}>_<stages:{stages}>"
    )
    return {
        "name": title,
        "bytes": m * row_bytes,
        f"flops{elem_bytes * 8}": flops,
        "grid_size": n_programs,
        "clusters": ctas,
        "shared_memory": smem,
    }



















@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; one row of N values per program.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Down-projection weight; row `index` (N values) is read per feature.
    down_weight_ptr,
    # Output; one bf16 row of N values per program.
    output_ptr,
    # Sizes. M is not read in the body.
    M: tl.constexpr,
    N: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides. NOTE(review): these (and M) are tl.constexpr, so each
    # distinct value triggers a separate kernel specialization.
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Weight and output strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
    down_weight_stride_k: tl.constexpr,
    down_weight_stride_n: tl.constexpr,
    output_stride_m: tl.constexpr,
    output_stride_n: tl.constexpr,
):
    """Fused gated up/down projection for one row (program_id(0) = row).

    For each listed feature ``k`` (local slot value + block start):
        feature = bf16(sum(fp32(input[row] * up_weight[k])))
        gate    = sparse_gate[row, k] * feature
        acc    += fp32(gate * down_weight[k])
    The fp32 accumulator is stored to ``output[row]`` as bf16. ``N`` must be
    a power of two (``tl.arange(0, N)``); no masks are used.
    """
    row_idx = tl.program_id(0)
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Slot values are local to the block; adding block_start gives
            # the global feature index (int64 for the weight-row offsets).
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Up-projection row for this feature.
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            )

            # Dot product: elementwise product in the loaded dtype, summed
            # in fp32, then rounded to bf16.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32)).to(
                tl.bfloat16)

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            ) * feature

            # Down-projection row for this feature.
            down_weight_row = tl.load(
                down_weight_ptr + index * down_weight_stride_k +
                tl.arange(0, N) * down_weight_stride_n
            )

            output_accumulator += (
                gate * down_weight_row).to(tl.float32)

    output_row_ptr = (
        output_ptr + row_idx * output_stride_m +
        tl.arange(0, N) * output_stride_n
    )
    tl.store(output_row_ptr, output_accumulator.to(tl.bfloat16))


def gated_downprojection_bwell(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
    block_size=256,
):
    """Fused gated up/down projection over block-sparse hidden features.

    ``output[m] = sum_k sparse_gate[m, k] * (input[m] . up_weight[k]) *
    down_weight[k]``, where ``k`` ranges over the indices listed for row
    ``m`` (feature rounded to bf16, accumulation in fp32).

    Args:
        sparse_gate: [M, K] gate values.
        indices_by_block: per-row local indices, ``block_size`` slots per
            block.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        input: [M, N] activations.
        up_weight: [K, N] up-projection weight (one row per feature).
        down_weight: [K, N] down-projection weight (one row per feature).
        block_size: slot count per block in ``indices_by_block``. Defaults
            to 256, the value previously hard-coded here.

    Returns:
        [M, N] bf16 output.

    Raises:
        AssertionError: on mismatched shapes or a non-contiguous
            ``nonzeros_by_block``.
    """
    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nz, K_blocks = nonzeros_by_block.shape

    assert M_nz == M, "Mismatched nonzeros_by_block shape"
    assert K_down == K, "Mismatched down weight shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    # The kernel indexes nonzeros_by_block as row * K_blocks + block.
    assert nonzeros_by_block.is_contiguous(), 'nonzeros must be contiguous'

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_bwell_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output

@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_ord_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; one row of N values per processed row.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Down-projection weight; row `index` (N values) is read per feature.
    down_weight_ptr,
    # Output; one bf16 row of N values per processed row.
    output_ptr,
    # Row order: program `pid` processes row priority[pid]. Indexed without
    # a stride, so it must be contiguous.
    priority_ptr,
    # Sizes. M is not read in the body.
    M: tl.constexpr,
    N: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides. NOTE(review): these (and M) are tl.constexpr, so each
    # distinct value triggers a separate kernel specialization.
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Weight and output strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
    down_weight_stride_k: tl.constexpr,
    down_weight_stride_n: tl.constexpr,
    output_stride_m: tl.constexpr,
    output_stride_n: tl.constexpr,
):
    """Same computation as ``gated_downprojection_bwell_kernel``, with rows
    visited in the order given by ``priority``.

    Program ``pid`` computes and stores row ``priority[pid]``. Presumably
    ``priority`` is a permutation of ``0..M-1`` (TODO confirm with callers);
    otherwise some rows are computed twice and others are never written.
    ``N`` must be a power of two; no masks are used.
    """
    pid = tl.program_id(0)
    row_idx = tl.load(priority_ptr + pid).to(tl.int32)
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Up-projection row for this feature.
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            )

            # Dot product summed in fp32, then rounded to bf16.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32)).to(
                tl.bfloat16)

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            ) * feature

            # Down-projection row for this feature.
            down_weight_row = tl.load(
                down_weight_ptr + index * down_weight_stride_k +
                tl.arange(0, N) * down_weight_stride_n
            )

            output_accumulator += (
                gate * down_weight_row).to(tl.float32)

    output_row_ptr = (
        output_ptr + row_idx * output_stride_m +
        tl.arange(0, N) * output_stride_n
    )
    tl.store(output_row_ptr, output_accumulator.to(tl.bfloat16))

def gated_downprojection_bwell_ord(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    priority,
    input,
    up_weight,
    down_weight,
    block_size=256,
):
    """Fused gated up/down projection, visiting rows in ``priority`` order.

    Produces the same ``[M, N]`` bf16 result as the unordered variant; only
    the program-to-row assignment differs (program ``pid`` handles row
    ``priority[pid]``).

    Args:
        sparse_gate: [M, K] gate values.
        indices_by_block: per-row local indices, ``block_size`` slots per
            block.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        priority: contiguous 1-D row order with at least M entries
            (presumably a permutation of ``0..M-1``).
        input: [M, N] activations.
        up_weight: [K, N] up-projection weight.
        down_weight: [K, N] down-projection weight.
        block_size: slot count per block in ``indices_by_block``. Defaults
            to 256, the value previously hard-coded here.

    Returns:
        [M, N] bf16 output.

    Raises:
        AssertionError: on mismatched shapes or non-contiguous
            ``nonzeros_by_block`` / ``priority``.
    """
    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nz, K_blocks = nonzeros_by_block.shape

    assert M_nz == M, "Mismatched nonzeros_by_block shape"
    assert K_down == K, "Mismatched down weight shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    # Both tables are indexed without strides inside the kernel.
    assert nonzeros_by_block.is_contiguous(), 'nonzeros must be contiguous'
    assert priority.is_contiguous(), 'priority must be contiguous'
    # One priority entry is read per program, for pid in [0, M).
    assert priority.numel() >= M, "priority must have at least M entries"

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_bwell_ord_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        priority_ptr=priority,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output


@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_ord_subtile_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations, N values per row.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Down-projection weight; row `index` (N values) is read per feature.
    down_weight_ptr,
    # Output; one bf16 row of N values per processed row.
    output_ptr,
    # Row order: program `pid` processes row priority[pid] (contiguous).
    priority_ptr,
    # Sizes: N == 2 * half_N, and half_N must be a power of two.
    # M is not read in the body.
    M: tl.constexpr,
    N: tl.constexpr,
    half_N: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides; all of them are honoured below.
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Weight and output strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
    down_weight_stride_k: tl.constexpr,
    down_weight_stride_n: tl.constexpr,
    output_stride_m: tl.constexpr,
    output_stride_n: tl.constexpr,
):
    """Ordered fused gated up/down projection with the N axis split in two.

    Computes the same result as ``gated_downprojection_bwell_ord_kernel``
    while working on two ``half_N``-wide halves of every row. The feature
    for index ``k`` is the FULL dot product ``input[row] . up_weight[k]``:
    the sum of the two half dot products, accumulated in fp32 and rounded
    to bf16 once.
    """
    pid = tl.program_id(0)
    row_idx = tl.load(priority_ptr + pid).to(tl.int32)

    offs_lo = tl.arange(0, half_N)
    offs_hi = half_N + offs_lo

    input_base = input_ptr + row_idx * input_stride_m
    input_row_0 = tl.load(input_base + offs_lo * input_stride_n)
    input_row_1 = tl.load(input_base + offs_hi * input_stride_n)

    acc0 = tl.zeros((half_N,), dtype=tl.float32)
    acc1 = tl.zeros((half_N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            up_base = up_weight_ptr + index * up_weight_stride_k
            up_weight_row_0 = tl.load(up_base + offs_lo * up_weight_stride_n)
            up_weight_row_1 = tl.load(up_base + offs_hi * up_weight_stride_n)

            # The two halves form ONE dot product: add (not multiply) the
            # partial sums in fp32, then round once to bf16 like the
            # non-subtiled kernel.
            feature = (
                tl.sum((input_row_0 * up_weight_row_0).to(tl.float32))
                + tl.sum((input_row_1 * up_weight_row_1).to(tl.float32))
            ).to(tl.bfloat16)

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            ) * feature

            down_base = down_weight_ptr + index * down_weight_stride_k
            down_weight_row_0 = tl.load(
                down_base + offs_lo * down_weight_stride_n)
            acc0 += (gate * down_weight_row_0).to(tl.float32)
            down_weight_row_1 = tl.load(
                down_base + offs_hi * down_weight_stride_n)
            acc1 += (gate * down_weight_row_1).to(tl.float32)

    output_base = output_ptr + row_idx * output_stride_m
    tl.store(output_base + offs_lo * output_stride_n, acc0.to(tl.bfloat16))
    tl.store(output_base + offs_hi * output_stride_n, acc1.to(tl.bfloat16))


def gated_downprojection_bwell_ord_subtile(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    priority,
    input,
    up_weight,
    down_weight,
    block_size=256,
):
    """Ordered fused gated up/down projection, N axis processed in halves.

    Args:
        sparse_gate: [M, K] gate values.
        indices_by_block: per-row local indices, ``block_size`` slots per
            block.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        priority: contiguous 1-D row order with at least M entries
            (presumably a permutation of ``0..M-1``).
        input: [M, N] activations; N must be even and N // 2 a power of two.
        up_weight: [K, N] up-projection weight.
        down_weight: [K, N] down-projection weight.
        block_size: slot count per block in ``indices_by_block``. Defaults
            to 256, the value previously hard-coded here.

    Returns:
        [M, N] bf16 output.

    Raises:
        AssertionError: on mismatched shapes, odd N, or non-contiguous
            ``nonzeros_by_block`` / ``priority``.
    """
    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nz, K_blocks = nonzeros_by_block.shape

    assert M_nz == M, "Mismatched nonzeros_by_block shape"
    assert K_down == K, "Mismatched down weight shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    # The kernel splits each row into two equal halves of N // 2.
    assert N % 2 == 0, "N must be even for the two-way subtile split"
    assert nonzeros_by_block.is_contiguous(), 'nonzeros must be contiguous'
    assert priority.is_contiguous(), 'priority must be contiguous'
    assert priority.numel() >= M, "priority must have at least M entries"

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_bwell_ord_subtile_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        priority_ptr=priority,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        half_N=N // 2,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output




@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_upprojection_bwell_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; one row of N values per program.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Output; written at row * K + index, i.e. it must be a contiguous
    # row-major buffer with K columns. Only listed positions are written.
    output_ptr,
    # Sizes. M is not read in the body; K is the output row length.
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides (compile-time constants).
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Up-weight strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
):
    """Gated up-projection evaluated only at the listed features of a row.

    Program ``row`` writes, for each listed feature ``k``,
    ``output[row, k] = sparse_gate[row, k] * sum(fp32(input[row] *
    up_weight[k]))``. The fp32 product is cast to the output element type
    by ``tl.store``. Positions that are not listed are never written.
    ``N`` must be a power of two; no masks are used.
    """
    row_idx = tl.program_id(0)
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Up-projection row for this feature.
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            )

            # Dot product summed in fp32 (kept in fp32 here, not rounded).
            feature = tl.sum((input_row * up_weight_row).to(tl.float32))

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            )
            # Scatter the gated feature to its dense (row, index) slot.
            tl.store(
                output_ptr + row_idx * K + index, value=gate * feature
            )


def gated_upprojection_bwell(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
):
    """Gated up-projection evaluated only at the listed hidden features.

    For every index ``k`` listed for row ``m``:
    ``output[m, k] = sparse_gate[m, k] * (input[m] . up_weight[k])``.
    Every other entry of ``output`` is zero.

    Args:
        sparse_gate: [M, K] contiguous gate values.
        indices_by_block: contiguous per-row local indices, grouped into
            ``K_blocks`` blocks of ``K // K_blocks`` slots.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        input: [M, N] activations.
        up_weight: [K, N] up-projection weight (one row per feature).

    Returns:
        [M, K] bf16 output.

    Raises:
        AssertionError: on non-contiguous inputs or mismatched shapes.
    """
    assert sparse_gate.is_contiguous(), 'gate must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'

    M, K = sparse_gate.shape
    nz_M, K_blocks = nonzeros_by_block.shape
    up_K, N = up_weight.shape
    assert nz_M == M, "Mismatched nonzeros_by_block shape"
    assert up_K == K, "Mismatched transposed up weight shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"
    block_size = K // K_blocks

    # The kernel writes only the listed (row, index) positions, so start
    # from zeros rather than uninitialised memory.
    output = torch.zeros(
        (M, K),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_upprojection_bwell_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        output_ptr=output,
        M=M,
        N=N,
        K=K,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        num_stages=2,
        num_warps=1,
    )
    return output


@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_up_ip_bwell_kernel(
    # Gate values; the listed (row, index) entries are read AND overwritten.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; one row of N values per program.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Sizes. M and K are not read in the body.
    M,
    N: tl.constexpr,
    K: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides (compile-time constants).
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Up-weight strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
):
    """In-place gated up-projection for one row (program_id(0) = row).

    Each listed ``sparse_gate[row, k]`` is replaced by
    ``sparse_gate[row, k] * sum(fp32(input[row] * up_weight[k]))``, cast
    back to the gate's element type by ``tl.store``. Unlisted entries are
    untouched. NOTE(review): if an index were listed twice for a row, that
    entry would be scaled twice -- presumably indices are unique; confirm.
    """
    row_idx = tl.program_id(0)
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Up-projection row for this feature.
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            )

            # Dot product summed in fp32.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32))

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            )
            # Overwrite the same gate entry with the gated feature.
            tl.store(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k, value=gate * feature
            )


def gated_up_ip_bwell(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
):
    """Apply the gated up-projection in place on ``sparse_gate``.

    For every index ``k`` listed for row ``m`` (via ``indices_by_block`` and
    ``nonzeros_by_block``), ``sparse_gate[m, k]`` becomes
    ``sparse_gate[m, k] * (input[m] . up_weight[k])``. Unlisted entries
    keep their values.

    Returns:
        ``sparse_gate`` itself, modified in place.

    Raises:
        AssertionError: if the gate or indices are not contiguous, or if
            K is not a multiple of K_blocks.
    """
    for tensor, message in (
        (sparse_gate, 'gate must be contiguous'),
        (indices_by_block, 'indices must be contiguous'),
    ):
        assert tensor.is_contiguous(), message

    _, _ = sparse_gate.shape  # the gate must be 2-D; its sizes are not reused
    M, K_blocks = nonzeros_by_block.shape
    K, N = up_weight.shape
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"

    launch_args = dict(
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        M=M,
        N=N,
        K=K,
        K_blocks=K_blocks,
        block_size=K // K_blocks,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        num_stages=2,
        num_warps=1,
    )
    # One program per row.
    gated_up_ip_bwell_kernel[(M,)](**launch_args)
    return sparse_gate


@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_split_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; the full row of N values is loaded by every program.
    input_ptr,
    # Up-projection weight; full row `index` (N values) is read per feature.
    up_weight_ptr,
    # Down-projection weight; only this program's split_size columns of row
    # `index` are read.
    down_weight_ptr,
    # Output; each program writes split_size bf16 columns of one row.
    output_ptr,
    # Sizes. M is not read in the body.
    M,
    N: tl.constexpr,
    # Layout of the blocked K dimension and the output-column split width.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    split_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    # Element strides (compile-time constants).
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    # Weight and output strides.
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
    down_weight_stride_k: tl.constexpr,
    down_weight_stride_n: tl.constexpr,
    output_stride_m: tl.constexpr,
    output_stride_n: tl.constexpr,
):
    """Fused gated up/down projection, output columns split across programs.

    Program ``(row, s)`` writes ``output[row, s*split_size:(s+1)*split_size]``.
    Each split program of the same row recomputes the full-length feature
    dot product. Unlike ``gated_downprojection_bwell_kernel``, the feature
    stays in fp32 (it is not rounded to bf16) before gating. ``N`` and
    ``split_size`` must be powers of two; no masks are used.
    """
    row_idx = tl.program_id(0)
    # First output column handled by this program.
    split_start = tl.program_id(1) * split_size
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr)
    output_accumulator = tl.zeros((split_size,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Full up-projection row (needed for the complete dot product).
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            )

            # Dot product summed in fp32.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32))

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            )

            # Only this program's slice of the down-projection row.
            down_weight_row = tl.load(
                down_weight_ptr + index * down_weight_stride_k +
                (split_start + tl.arange(0, split_size)) * down_weight_stride_n
            )

            output_accumulator += (
                gate * feature * down_weight_row).to(tl.float32)

    output_row_ptr = (
        output_ptr + row_idx * output_stride_m +
        (split_start + tl.arange(0, split_size)) * output_stride_n
    )
    tl.store(output_row_ptr, output_accumulator.to(tl.bfloat16))


def gated_downprojection_bwell_split(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
    split_size=1024
):
    """Fused gated up/down projection with output columns split per program.

    Launches an ``(M, N // split_size)`` grid; program ``(m, s)`` writes
    columns ``[s * split_size, (s + 1) * split_size)`` of ``output[m]``.

    Args:
        sparse_gate: [M, K] contiguous gate values.
        indices_by_block: contiguous per-row local indices, grouped into
            ``K_blocks`` blocks of ``K // K_blocks`` slots.
        nonzeros_by_block: [M, K_blocks] number of valid slots per block.
        input: [M, N] activations.
        up_weight: [K, N] up-projection weight.
        down_weight: [K, N] contiguous down-projection weight.
        split_size: output columns per program. It must be positive. A
            value larger than ``N`` is clamped to ``N``, so the default
            also works for narrow rows. It must divide ``N`` and (for
            ``tl.arange``) be a power of two.

    Returns:
        [M, N] bf16 output.

    Raises:
        AssertionError: on non-contiguous inputs, mismatched shapes, or an
            invalid ``split_size``.
    """
    assert sparse_gate.is_contiguous(), 'gate must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert down_weight.is_contiguous(), 'down_weight must be contiguous'

    M, K = sparse_gate.shape
    down_K, N = down_weight.shape
    nz_M, K_blocks = nonzeros_by_block.shape

    assert nz_M == M, "Mismatched nonzeros_by_block shape"
    assert down_K == K, "Mismatched down weight shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"
    block_size = K // K_blocks

    assert split_size > 0, "split_size must be positive"
    if split_size > N > 0:
        # A split wider than the row can never divide N; a single split
        # covering the whole row is the natural fallback.
        split_size = N
    assert N % split_size == 0, "N must be divisible by split_size"
    n_splits = N // split_size

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M, n_splits)

    gated_downprojection_bwell_split_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        split_size=split_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=2,
        num_warps=1,
    )
    return output







@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_os_kernel(
    # Gate values; read only at the listed (row, index) positions.
    sparse_gate_ptr,
    # Per-row local column indices, block_size slots per block.
    indices_by_block_ptr,
    # Per-(row, block) count of valid slots; indexed as
    # row * K_blocks + block with no strides, so it must be contiguous.
    nonzeros_by_block_ptr,
    # Input activations; one row of N values per program.
    input_ptr,
    # Up-projection weight; row `index` (N values) is read per feature.
    up_weight_ptr,
    # Down-projection weight; row `index` (N values) is read per feature.
    down_weight_ptr,
    # Output; one bf16 row of N values per program.
    output_ptr,
    # Sizes. M is not read in the body.
    M,
    N: tl.constexpr,
    # Layout of the blocked K dimension.
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m,
    sparse_gate_stride_k,
    # Element strides; unlike the sibling kernels these are runtime values,
    # not tl.constexpr.
    indices_stride_m,
    indices_stride_k,
    input_stride_m,
    input_stride_n,
    # Weight and output strides.
    up_weight_stride_k,
    up_weight_stride_n,
    down_weight_stride_k,
    down_weight_stride_n,
    output_stride_m,
    output_stride_n,
):
    """Fused gated up/down projection for one row with all math in fp32.

    Same data flow as ``gated_downprojection_bwell_kernel``, but every
    loaded operand (input, up/down weight rows, gate) is upcast to fp32 and
    the feature is not rounded to bf16. Only the final row is stored as
    bf16. ``N`` must be a power of two; no masks are used.
    """
    row_idx = tl.program_id(0)
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, N) * input_stride_n
    )
    input_row = tl.load(input_row_ptr).to(tl.float32)
    output_accumulator = tl.zeros((N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        # Number of valid slots for this (row, block).
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Local slot value + block_start = global feature index.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            # Up-projection row for this feature, upcast to fp32.
            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, N) * up_weight_stride_n
            ).to(tl.float32)

            # Dot product entirely in fp32.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32))

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            ).to(tl.float32)

            # Down-projection row for this feature, upcast to fp32.
            down_weight_row = tl.load(
                down_weight_ptr + index * down_weight_stride_k +
                tl.arange(0, N) * down_weight_stride_n
            ).to(tl.float32)

            output_accumulator += (
                gate * feature * down_weight_row).to(tl.float32)

    output_row_ptr = (
        output_ptr + row_idx * output_stride_m +
        tl.arange(0, N) * output_stride_n
    )
    tl.store(output_row_ptr, output_accumulator.to(tl.bfloat16))


def gated_downprojection_bwell_os(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
):
    """Launch ``gated_downprojection_bwell_os_kernel`` with one program per row.

    As implemented by the kernel body, each output row ``m`` is the float32
    accumulation, over the selected positions ``k`` of that row, of
    ``sparse_gate[m, k] * dot(input[m], up_weight[k]) * down_weight[k]``,
    stored as bfloat16.

    Args:
        sparse_gate: ``(M, K)`` gate tensor; must be contiguous.
        indices_by_block: per-row offsets of the selected entries, relative to
            the start of their K-block (the kernel reads column
            ``block_idx * block_size + i`` and adds ``block_idx * block_size``
            to the value); must be contiguous.
        nonzeros_by_block: ``(M, K_blocks)`` number of valid offsets per block.
        input: ``(M, N)`` tensor; ``N`` must be a power of two.
        up_weight: ``(K, N)`` tensor.
        down_weight: ``(K, N)`` tensor; must be contiguous.

    Returns:
        ``(M, N)`` bfloat16 tensor on ``sparse_gate.device``.

    Raises:
        AssertionError: on a layout or shape mismatch.
    """
    assert sparse_gate.is_contiguous(), 'gate must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert down_weight.is_contiguous(), 'down_weight must be contiguous'
    # The kernel ignores nonzeros_by_block's strides and reads element (m, b)
    # at flat offset m * K_blocks + b; .contiguous() is a no-op if already dense.
    nonzeros_by_block = nonzeros_by_block.contiguous()

    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nnz, K_blocks = nonzeros_by_block.shape

    # The kernel indexes gate/weights with the same k and all per-row tensors
    # with the same row, so these dimensions must agree exactly.
    assert K_down == K, "Mismatched down weight shape"
    assert M_nnz == M, "Mismatched nonzeros_by_block shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    assert K_blocks > 0, "K_blocks must be positive"
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"
    # tl.arange(0, N) in the kernel only accepts a power-of-two extent.
    assert N > 0 and (N & (N - 1)) == 0, "N must be a power of two"
    block_size = K // K_blocks

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_bwell_os_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output



@triton.jit()
def gated_downprojection_dummy_bwell_kernel(
    sparse_gate_ptr,
    indices_by_block_ptr,
    nonzeros_by_block_ptr,
    input_ptr,
    up_weight_ptr,
    down_weight_ptr,
    output_ptr,
    M,
    N: tl.constexpr,
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_m,
    sparse_gate_k,
    indices_stride_m,
    indices_stride_k,
    input_m,
    input_n,
    up_weight_stride_k,
    up_weight_stride_n,
    down_weight_stride_k,
    down_weight_stride_n,
    output_stride_m,
    output_stride_n,
):
    """Gate-weighted sum of selected ``down_weight`` rows, one program per row.

    For row ``m`` this computes, over every selected position ``k``,
    ``sum(gate[m, k] * down_weight[k])`` accumulated in float32 and stored
    as bfloat16. ``input_ptr``, ``up_weight_ptr``, ``M`` and the input /
    up-weight strides are accepted but never read.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, N)

    # Per-row base pointers; only the inner offsets change inside the loops.
    nnz_row_ptr = nonzeros_by_block_ptr + row * K_blocks
    idx_row_ptr = indices_by_block_ptr + row * indices_stride_m
    gate_row_ptr = sparse_gate_ptr + row * sparse_gate_m

    acc = tl.zeros((N,), dtype=tl.float32)
    for blk in range(K_blocks):
        count = tl.load(nnz_row_ptr + blk).to(tl.int32)
        blk_offset = blk * block_size
        for j in range(count):
            # Stored offsets are block-relative; shift them to an absolute k.
            k = tl.load(
                idx_row_ptr + (blk * block_size + j) * indices_stride_k
            ).to(tl.int64) + blk_offset
            g = tl.load(gate_row_ptr + k * sparse_gate_k)
            w = tl.load(
                down_weight_ptr + k * down_weight_stride_k
                + cols * down_weight_stride_n
            )
            acc += (g * w).to(tl.float32)

    out_ptrs = output_ptr + row * output_stride_m + cols * output_stride_n
    tl.store(out_ptrs, acc.to(tl.bfloat16))


def gated_downprojection_dummy_bwell(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
):
    """Launch ``gated_downprojection_dummy_bwell_kernel`` with one program per row.

    The kernel accumulates ``sparse_gate[m, k] * down_weight[k]`` over the
    selected positions of each row and stores the result as bfloat16; it does
    not read ``input`` or ``up_weight``. Those are still shape-checked here so
    that the call contract stays the same as the full gated projection.

    Args:
        sparse_gate: ``(M, K)`` gate tensor; must be contiguous.
        indices_by_block: per-row block-relative offsets of the selected
            entries; must be contiguous.
        nonzeros_by_block: ``(M, K_blocks)`` number of valid offsets per block.
        input: ``(M, N)`` tensor (shape-checked only); ``N`` must be a power
            of two.
        up_weight: ``(K, N)`` tensor (shape-checked only).
        down_weight: ``(K, N)`` tensor; must be contiguous.

    Returns:
        ``(M, N)`` bfloat16 tensor on ``sparse_gate.device``.

    Raises:
        AssertionError: on a layout or shape mismatch.
    """
    assert sparse_gate.is_contiguous(), 'gate must be contiguous'
    assert indices_by_block.is_contiguous(), 'indices must be contiguous'
    assert down_weight.is_contiguous(), 'down_weight must be contiguous'
    # The kernel ignores nonzeros_by_block's strides and reads element (m, b)
    # at flat offset m * K_blocks + b; .contiguous() is a no-op if already dense.
    nonzeros_by_block = nonzeros_by_block.contiguous()

    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nnz, K_blocks = nonzeros_by_block.shape

    # Gate and down weight are indexed with the same k, and every per-row
    # tensor with the same row, so these dimensions must agree exactly.
    assert K_down == K, "Mismatched down weight shape"
    assert M_nnz == M, "Mismatched nonzeros_by_block shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    assert K_blocks > 0, "K_blocks must be positive"
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"
    # tl.arange(0, N) in the kernel only accepts a power-of-two extent.
    assert N > 0 and (N & (N - 1)) == 0, "N must be a power of two"
    block_size = K // K_blocks

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_dummy_bwell_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_m=sparse_gate.stride(0),
        sparse_gate_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_m=input.stride(0),
        input_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output


@triton.jit(launch_metadata=gated_downprojection_proton_launch_metadata)
def gated_downprojection_bwell_np2_kernel(
    sparse_gate_ptr,
    indices_by_block_ptr,
    nonzeros_by_block_ptr,
    input_ptr,
    up_weight_ptr,
    down_weight_ptr,
    output_ptr,
    M: tl.constexpr,
    N: tl.constexpr,
    load_N: tl.constexpr,
    K_blocks: tl.constexpr,
    block_size: tl.constexpr,
    sparse_gate_stride_m: tl.constexpr,
    sparse_gate_stride_k: tl.constexpr,
    indices_stride_m: tl.constexpr,
    indices_stride_k: tl.constexpr,
    input_stride_m: tl.constexpr,
    input_stride_n: tl.constexpr,
    up_weight_stride_k: tl.constexpr,
    up_weight_stride_n: tl.constexpr,
    down_weight_stride_k: tl.constexpr,
    down_weight_stride_n: tl.constexpr,
    output_stride_m: tl.constexpr,
    output_stride_n: tl.constexpr,
):
    """Gated sparse up/down projection for an arbitrary feature width ``N``.

    One program handles one row ``m``. For every selected position ``k`` it
    computes ``feature = dot(input[m], up_weight[k])`` (rounded to bfloat16),
    scales it by ``sparse_gate[m, k]``, and accumulates that scale times
    ``down_weight[k]`` in float32. The row is stored as bfloat16.

    ``load_N`` must be a power of two >= ``N`` because ``tl.arange`` requires
    one. The wrapper passes ``triton.next_power_of_2(N)``. Lanes ``>= N`` are
    masked.
    """
    row_idx = tl.program_id(0)
    mask = tl.arange(0, load_N) < N
    input_row_ptr = (
        input_ptr + row_idx * input_stride_m +
        tl.arange(0, load_N) * input_stride_n
    )
    # other=0.0: a masked load without `other` leaves padded lanes undefined,
    # and those lanes would otherwise leak into the tl.sum reduction below.
    input_row = tl.load(input_row_ptr, mask=mask, other=0.0)
    output_accumulator = tl.zeros((load_N,), dtype=tl.float32)
    for block_idx in range(K_blocks):
        num_nonzeros = tl.load(
            nonzeros_by_block_ptr + row_idx * K_blocks + block_idx,
        ).to(tl.int32)
        block_start = block_idx * block_size
        for idx in range(num_nonzeros):
            # Stored offsets are relative to the block; shift to an absolute k.
            index = tl.load(
                indices_by_block_ptr + row_idx * indices_stride_m
                + (block_idx * block_size + idx) * indices_stride_k
            ).to(tl.int64) + block_start

            up_weight_row = tl.load(
                up_weight_ptr + index * up_weight_stride_k +
                tl.arange(0, load_N) * up_weight_stride_n,
                mask=mask,
                other=0.0,
            )

            # Dot product reduced in float32, then rounded to bfloat16 before
            # being combined with the gate.
            feature = tl.sum((input_row * up_weight_row).to(tl.float32)).to(
                tl.bfloat16)

            gate = tl.load(
                sparse_gate_ptr + row_idx * sparse_gate_stride_m +
                index * sparse_gate_stride_k
            ) * feature

            # Padded lanes are never stored; zero-fill keeps them finite anyway.
            down_weight_row = tl.load(
                down_weight_ptr + index * down_weight_stride_k +
                tl.arange(0, load_N) * down_weight_stride_n,
                mask=mask,
                other=0.0,
            )

            output_accumulator += (
                gate * down_weight_row).to(tl.float32)

    output_row_ptr = (
        output_ptr + row_idx * output_stride_m +
        tl.arange(0, load_N) * output_stride_n
    )
    tl.store(output_row_ptr, output_accumulator.to(tl.bfloat16), mask=mask)


def gated_downprojection_bwell_np2(
    sparse_gate,
    indices_by_block,
    nonzeros_by_block,
    input,
    up_weight,
    down_weight,
):
    """Launch ``gated_downprojection_bwell_np2_kernel`` with one program per row.

    This variant supports a feature width ``N`` that is not a power of two.
    The kernel loads ``triton.next_power_of_2(N)`` lanes and masks the
    excess. Tensors are addressed through their strides, except
    ``nonzeros_by_block``, which the kernel reads as a dense row-major
    ``(M, K_blocks)`` array.

    Args:
        sparse_gate: ``(M, K)`` gate tensor.
        indices_by_block: per-row offsets of the selected entries, relative to
            the start of their K-block.
        nonzeros_by_block: ``(M, K_blocks)`` number of valid offsets per block.
        input: ``(M, N)`` tensor.
        up_weight: ``(K, N)`` tensor.
        down_weight: ``(K, N)`` tensor.

    Returns:
        ``(M, N)`` bfloat16 tensor on ``sparse_gate.device``.

    Raises:
        AssertionError: on a shape mismatch.
    """
    # The kernel ignores nonzeros_by_block's strides and reads element (m, b)
    # at flat offset m * K_blocks + b; .contiguous() is a no-op if already dense.
    nonzeros_by_block = nonzeros_by_block.contiguous()

    M, K = sparse_gate.shape
    K_down, N = down_weight.shape
    M_nnz, K_blocks = nonzeros_by_block.shape

    # Gate and weights are indexed with the same k, and every per-row tensor
    # with the same row, so these dimensions must agree exactly.
    assert K_down == K, "Mismatched down weight shape"
    assert M_nnz == M, "Mismatched nonzeros_by_block shape"
    assert M == input.shape[0], "Mismatched input shape"
    assert N == input.shape[1], "Mismatched input shape"
    assert K == up_weight.shape[0], "Mismatched transposed up weight shape"
    assert N == up_weight.shape[1], "Mismatched transposed up weight shape"
    assert K_blocks > 0, "K_blocks must be positive"
    assert K % K_blocks == 0, "Inner dimension K must be divisible by K_blocks"
    # The kernel turns block-relative offsets into absolute k via
    # block_idx * block_size, so this must be the actual K-block width.
    block_size = K // K_blocks

    output = torch.empty(
        (M, N),
        device=sparse_gate.device,
        dtype=torch.bfloat16,
    )

    grid = (M,)

    gated_downprojection_bwell_np2_kernel[grid](
        sparse_gate_ptr=sparse_gate,
        indices_by_block_ptr=indices_by_block,
        nonzeros_by_block_ptr=nonzeros_by_block,
        input_ptr=input,
        up_weight_ptr=up_weight,
        down_weight_ptr=down_weight,
        output_ptr=output,
        M=M,
        N=N,
        load_N=triton.next_power_of_2(N),
        K_blocks=K_blocks,
        block_size=block_size,
        sparse_gate_stride_m=sparse_gate.stride(0),
        sparse_gate_stride_k=sparse_gate.stride(1),
        indices_stride_m=indices_by_block.stride(0),
        indices_stride_k=indices_by_block.stride(1),
        input_stride_m=input.stride(0),
        input_stride_n=input.stride(1),
        up_weight_stride_k=up_weight.stride(0),
        up_weight_stride_n=up_weight.stride(1),
        down_weight_stride_k=down_weight.stride(0),
        down_weight_stride_n=down_weight.stride(1),
        output_stride_m=output.stride(0),
        output_stride_n=output.stride(1),
        num_stages=1,
        num_warps=1,
    )
    return output


