import torch
import triton
import triton.language as tl
from torch.nn import functional as F

@triton.jit
def _single2scatter(
    X_ptr, stride_xm, stride_xk,
    W_ptr, stride_we, stride_wk, stride_wn,
    Y_ptr, stride_ym, stride_yn,
    expert_idxs_ptr,
    FAN_OUT: tl.constexpr,
    K: tl.constexpr, N: tl.constexpr, E: tl.constexpr,
    BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    N_block_id = pid0
    if FAN_OUT == 1:
        in_idx = pid1
    else:
        in_idx = 0
    out_idx = pid1

    K_block = tl.arange(0, BLOCK_K)
    N_block = tl.max_contiguous(tl.multiple_of((N_block_id * BLOCK_N + tl.arange(0, BLOCK_N)) % N, BLOCK_N), BLOCK_N)
    E_idx = tl.load(expert_idxs_ptr + pid1)
    X_blk_ptrs = X_ptr + in_idx * stride_xm + K_block[:, None] * stride_xk
    W_blk_ptrs = W_ptr + E_idx * stride_we + K_block[:, None] * stride_wk + N_block[None, :] * stride_wn
    acc = tl.zeros((1, BLOCK_N), dtype=ACC_TYPE)
    for K_block_id in range(0, tl.cdiv(K, BLOCK_K)):
        x = tl.load(X_blk_ptrs)
        w = tl.load(W_blk_ptrs)
        acc += tl.sum(x * w, axis=0)[None, :]
        X_blk_ptrs += BLOCK_K * stride_xk
        W_blk_ptrs += BLOCK_K * stride_wk
    Y_blk_ptrs = Y_ptr + out_idx * stride_ym + N_block[None, :] * stride_yn
    tl.store(Y_blk_ptrs, acc)

def single2scatter(X, W, expert_idxs):
    E, xdim, ydim = W.size()
    k = expert_idxs.size(1)
    assert X.size(0) == k or X.size(0) == 1
    Y = torch.empty((k, ydim), device=X.device, dtype=X.dtype)
    BLOCK_N = 128
    BLOCK_K = 128
    grid = ydim // BLOCK_N, k
    _single2scatter[grid](
        X, X.stride(0), X.stride(1),
        W, W.stride(0), W.stride(1), W.stride(2),
        Y, Y.stride(0), Y.stride(1),
        expert_idxs,
        FAN_OUT=Y.size(0) // X.size(0),
        K=xdim, N=ydim, E=E,
        BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
        ACC_TYPE=tl.float32
    )
    return Y
