import torch
import triton
import triton.language as tl


@triton.jit
def _fused_swiglu_ffn_forward(
    x_ptr,         # [M, K] input activations
    w1_ptr,        # [K, H] gate projection (its output goes through SiLU)
    w3_ptr,        # [K, H] up projection (multiplied with the SiLU output)
    w2_ptr,        # [H, D] down projection
    y_ptr,         # [M, D] output
    M: tl.constexpr,
    K: tl.constexpr,
    H: tl.constexpr,
    D: tl.constexpr,
    stride_x_m, stride_x_k,
    stride_w1_k, stride_w1_h,
    stride_w3_k, stride_w3_h,
    stride_w2_h, stride_w2_d,
    stride_y_m, stride_y_d,
    BLOCK_M: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_H: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """Compute one [BLOCK_M, BLOCK_D] tile of y = (silu(x @ w1) * (x @ w3)) @ w2.

    Launched on a 2D grid: axis 0 walks BLOCK_M-row tiles of the output and
    axis 1 walks BLOCK_D-column tiles. Each program loops over the hidden
    dimension H in BLOCK_H chunks; for every chunk it reduces over K to form
    the gate and up projections, applies SiLU gating, and folds the chunk
    into the output tile through the matching rows of w2.

    All accumulation and the SiLU gating are done in fp32. The gated hidden
    tile is rounded to bf16 before the second matmul so that it matches the
    bf16 w2 tile passed by the launcher in this file.

    NOTE: the gate/up projections depend only on the row tile, so programs
    that differ only in their axis-1 index recompute the same hidden chunks.
    NOTE: M, K, H and D are constexpr, so every distinct problem shape
    compiles its own specialization of this kernel.
    """
    pid_m = tl.program_id(0)
    pid_d = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D)

    # Edge tiles may stick out past M / D; the masks keep loads and the
    # final store inside the tensors.
    mask_m = offs_m < M
    mask_d = offs_d < D

    acc_y = tl.zeros((BLOCK_M, BLOCK_D), dtype=tl.float32)

    offs_k_init = tl.arange(0, BLOCK_K)
    offs_h_init = tl.arange(0, BLOCK_H)

    # Row part of the x addresses and the row mask do not change inside
    # the loops, so build them once.
    x_row_offs = offs_m[:, None] * stride_x_m
    row_mask = mask_m[:, None]

    for h0 in range(0, H, BLOCK_H):
        hh = h0 + offs_h_init
        mask_h = hh < H

        # fp32 accumulators for this hidden chunk:
        #   acc_a = x @ w1[:, hh]   (gate, pre-activation)
        #   acc_b = x @ w3[:, hh]   (up projection)
        acc_a = tl.zeros((BLOCK_M, BLOCK_H), dtype=tl.float32)
        acc_b = tl.zeros((BLOCK_M, BLOCK_H), dtype=tl.float32)

        for k0 in range(0, K, BLOCK_K):
            kk = k0 + offs_k_init
            mask_k = kk < K

            x_ptrs = x_ptr + (x_row_offs + kk[None, :] * stride_x_k)
            w1_ptrs = w1_ptr + (kk[:, None] * stride_w1_k + hh[None, :] * stride_w1_h)
            w3_ptrs = w3_ptr + (kk[:, None] * stride_w3_k + hh[None, :] * stride_w3_h)

            # Out-of-range elements are loaded as 0.0, so padded rows and
            # columns add nothing to the dot products.
            w_mask = mask_k[:, None] & mask_h[None, :]
            x_tile = tl.load(x_ptrs, mask=(row_mask & mask_k[None, :]), other=0.0)
            w1_tile = tl.load(w1_ptrs, mask=w_mask, other=0.0)
            w3_tile = tl.load(w3_ptrs, mask=w_mask, other=0.0)

            # The same x tile feeds both projections.
            acc_a += tl.dot(x_tile, w1_tile, out_dtype=tl.float32)
            acc_b += tl.dot(x_tile, w3_tile, out_dtype=tl.float32)

        # SiLU and gating in fp32: silu(a) = a * sigmoid(a).
        silu_a = acc_a * tl.sigmoid(acc_a)
        gated = silu_a * acc_b

        w2_ptrs = w2_ptr + (hh[:, None] * stride_w2_h + offs_d[None, :] * stride_w2_d)
        w2_tile = tl.load(w2_ptrs, mask=(mask_h[:, None] & mask_d[None, :]), other=0.0)

        # Round the gated activations to bf16 so both tl.dot operands share
        # a dtype; the product is still accumulated in fp32.
        acc_y += tl.dot(gated.to(tl.bfloat16), w2_tile, out_dtype=tl.float32)

    # tl.store casts the fp32 accumulator to y's element type on write.
    y_ptrs = y_ptr + (offs_m[:, None] * stride_y_m + offs_d[None, :] * stride_y_d)
    tl.store(y_ptrs, acc_y, mask=(row_mask & mask_d[None, :]))


def launch_fused_swiglu_ffn_forward(
    x, w1, w3, w2, y,
    BLOCK_M: int = 64,
    BLOCK_K: int = 64,
    BLOCK_H: int = 64,
    BLOCK_D: int = 64,
    num_warps: int = 4,
    num_stages: int = 1,
):
    """
    Launch the fused SwiGLU FFN forward kernel, writing the result into ``y``.

    The kernel receives each tensor's strides, so the inputs do not have to
    be contiguous. One program is launched per (BLOCK_M, BLOCK_D) output tile.

    Args:
        x: [M, K] bf16
        w1: [K, H] bf16
        w3: [K, H] bf16
        w2: [H, D] bf16
        y: [M, D] bf16 (output buffer, written in place)
        BLOCK_M: rows of ``y`` handled by one program.
        BLOCK_K: tile size along the K reduction of the first two matmuls.
        BLOCK_H: tile size along the hidden dimension H.
        BLOCK_D: columns of ``y`` handled by one program.
        num_warps: forwarded to the Triton launch.
        num_stages: forwarded to the Triton launch.

    Returns:
        None. The output is stored in ``y``.

    Raises:
        AssertionError: if any tensor is not a CUDA tensor, the tensors are
            on different devices, any tensor is not bf16, or the shapes are
            inconsistent.
    """
    tensors = (x, w1, w3, w2, y)
    assert all(t.is_cuda for t in tensors), "all tensors must be CUDA tensors"
    assert all(t.device == x.device for t in tensors), \
        "all tensors must be on the same CUDA device"
    assert all(t.dtype == torch.bfloat16 for t in tensors), \
        "all tensors must be torch.bfloat16"

    M, K = x.shape
    K_w1, H = w1.shape
    K_w3, H_w3 = w3.shape
    H_w2, D = w2.shape
    M_y, D_y = y.shape

    assert K_w1 == K and K_w3 == K, "w1 and w3 must have x.shape[1] rows"
    assert H_w3 == H and H_w2 == H, "w1, w3 and w2 must agree on the hidden size"
    assert M_y == M and D_y == D, "y must have shape [x.shape[0], w2.shape[1]]"

    # An empty output has nothing to write; skip the zero-sized grid.
    if M == 0 or D == 0:
        return

    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(D, BLOCK_D))

    # Launch on the device that owns the tensors, which is not necessarily
    # the current CUDA device in a multi-GPU process.
    with torch.cuda.device(x.device):
        _fused_swiglu_ffn_forward[grid](
            x, w1, w3, w2, y,
            M, K, H, D,
            x.stride(0), x.stride(1),
            w1.stride(0), w1.stride(1),
            w3.stride(0), w3.stride(1),
            w2.stride(0), w2.stride(1),
            y.stride(0), y.stride(1),
            BLOCK_M=BLOCK_M,
            BLOCK_K=BLOCK_K,
            BLOCK_H=BLOCK_H,
            BLOCK_D=BLOCK_D,
            num_warps=num_warps,
            num_stages=num_stages,
        )