import torch
import torch.nn.functional as F
import triton
import triton.language as tl

# Candidate num_warps values; each autotune config list in this file builds one
# triton.Config per entry and is keyed on the row length D.
NUM_WARPS_AUTOTUNE = [1, 2, 4, 8, 16, 32]

@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE
    ],
    key=["D"],
)
@triton.jit
def l2norm_add_fwd_kernel1(
    x,
    x_add,
    y,
    tgt_scale,
    rstd,
    eps,
    D,
    BD: tl.constexpr,
):
    """Fused residual add + row-wise L2 normalisation + per-row rescale.

    Program ``i_t`` handles row ``i_t`` of a row-major ``[N, D]`` view:

        s         = x[i_t] + x_add[i_t]                 (computed in float32)
        rstd[i_t] = 1 / sqrt(sum(s * s) + eps)
        y[i_t]    = s * rstd[i_t] * tgt_scale[i_t]      (cast back to x's dtype)

    ``tgt_scale`` and ``rstd`` hold one value per row. ``BD`` must be a power
    of two >= ``D`` (the launcher passes ``next_power_of_2(D)``); columns past
    ``D`` are masked out.
    """
    i_t = tl.program_id(0)
    x += i_t * D
    y += i_t * D
    x_add += i_t * D
    # Column offsets of this row; BD >= D, so the tail past D is masked off.
    cols = tl.arange(0, BD)
    mask = cols < D

    tgt_scale_f = tl.load(tgt_scale + i_t).to(tl.float32)

    b_x = tl.load(x + cols, mask=mask, other=0.0)
    orig_dtype = b_x.dtype
    # Do the add and the norm reduction in float32.
    b_x = b_x.to(tl.float32)
    b_x_add = tl.load(x_add + cols, mask=mask, other=0.0).to(tl.float32)
    b_x = b_x + b_x_add
    # Reciprocal L2 norm of the row; masked lanes are 0 and do not contribute.
    b_rstd = 1 / tl.sqrt(tl.sum(b_x * b_x) + eps)
    b_y = b_x * b_rstd * tgt_scale_f
    # The output goes back to the input dtype; rstd is stored unrounded for the backward.
    tl.store(y + cols, b_y.to(orig_dtype), mask=mask)
    tl.store(rstd + i_t, b_rstd)


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE
    ],
    key=["D"],
)
@triton.jit
def l2norm_bwd_kernel1(
    y,
    rstd,
    tgt_scale,
    dy,
    dx,
    dtgt_scale,
    eps,
    D,
    BD: tl.constexpr,
):
    """Backward of ``y = tgt_scale * s * rstd`` with ``rstd = 1 / sqrt(sum(s**2) + eps)``.

    One program per row. The unit vector ``x_hat = s * rstd`` is recovered as
    ``y / tgt_scale`` rather than being saved, and then:

        dtgt_scale = sum(dy * x_hat)
        dx         = tgt_scale * rstd * (dy - sum(dy * x_hat) * x_hat)

    ``eps`` is accepted but not used here (``rstd`` already includes it).

    NOTE(review): a row with ``tgt_scale == 0`` divides by zero when recovering
    ``x_hat`` and yields NaN gradients, although the analytic gradient is finite.
    Recovering ``x_hat`` from a low-precision ``y`` also loses precision.
    """
    row = tl.program_id(0)
    offs = tl.arange(0, BD)
    valid = offs < D

    # Per-row scalars.
    inv_norm = tl.load(rstd + row).to(tl.float32)
    scale = tl.load(tgt_scale + row).to(tl.float32)

    y_row = tl.load(y + row * D + offs, mask=valid, other=0.0)
    out_dtype = y_row.dtype
    # Undo the per-row scale to get back the unit-normalised vector.
    x_hat = y_row.to(tl.float32) / scale
    g = tl.load(dy + row * D + offs, mask=valid, other=0.0).to(tl.float32)

    # Projection of the upstream grad onto x_hat; this is both the scale's
    # gradient and the component removed from dx.
    proj = tl.sum(g * x_hat)
    grad_x = (g * inv_norm - proj * x_hat * inv_norm) * scale

    tl.store(dtgt_scale + row, proj)
    tl.store(dx + row * D + offs, grad_x.to(out_dtype), mask=valid)


def l2_norm_add_fwd(
    x: torch.Tensor,  # [B, D1, D2]
    x_add: torch.Tensor,  # [B, D1, D2]
    tgt_scale: torch.Tensor,  # [B, D1, 1]
    eps: float = 1e-5,
):
    """Launch the fused add + L2-normalise + rescale forward kernel.

    Computes, row-wise over the last dimension,
    ``y = tgt_scale * s / sqrt(sum(s**2) + eps)`` with ``s = x + x_add``.

    Args:
        x: ``[B, D1, D2]`` input.
        x_add: tensor added to ``x`` before normalising; must have ``x``'s shape.
        tgt_scale: one scale per row (``B * D1`` elements, e.g. ``[B, D1, 1]``).
        eps: added to the squared norm before the square root.

    Returns:
        y: same shape and dtype as ``x``.
        rstd: float32 tensor shaped like ``tgt_scale`` holding each row's
            ``1 / sqrt(sum(s**2) + eps)``.

    Raises:
        ValueError: if ``x_add`` does not match ``x``'s shape or ``tgt_scale``
            does not hold exactly one value per row.
    """
    B, nD, D = x.shape
    if x_add.shape != x.shape:
        raise ValueError(
            f"x_add shape {tuple(x_add.shape)} must match x shape {tuple(x.shape)}"
        )
    if tgt_scale.numel() != B * nD:
        raise ValueError(
            f"tgt_scale must hold one value per row ({B * nD}), got {tgt_scale.numel()}"
        )

    # The kernel addresses row i at flat offset i * D (and tgt_scale/rstd at i),
    # so every operand must be densely packed; .contiguous() is a no-op otherwise.
    x = x.contiguous()
    x_add = x_add.contiguous()
    tgt_scale = tgt_scale.contiguous()

    y = torch.empty_like(x)
    rstd = torch.empty_like(tgt_scale, dtype=torch.float32)

    block_size_d = triton.next_power_of_2(D)

    grid = (B * nD,)
    l2norm_add_fwd_kernel1[grid](x, x_add, y, tgt_scale, rstd, eps, D, block_size_d)
    return y, rstd


def l2_norm_bwd(
    dy: torch.Tensor,  # [B, D1, D2]
    y: torch.Tensor,  # [B, D1, D2]
    tgt_scale: torch.Tensor,  # [B, D1, 1]
    rstd: torch.Tensor,  # float32, shaped like tgt_scale (one value per row)
    eps: float = 1e-5,
):
    """Launch the backward kernel of the fused add + L2-norm + rescale op.

    Args:
        dy: upstream gradient, same shape as ``y``.
        y: forward output ``tgt_scale * s / sqrt(sum(s**2) + eps)``.
        tgt_scale: one scale per row.
        rstd: per-row reciprocal norms produced by ``l2_norm_add_fwd``.
        eps: forwarded to the kernel for symmetry with the forward.

    Returns:
        dx: gradient w.r.t. the pre-normalisation sum ``s``; same shape/dtype as ``dy``.
        dtgt_scale: gradient w.r.t. ``tgt_scale``; same shape/dtype as ``tgt_scale``.
    """
    # Gradients handed in by autograd are not guaranteed to be contiguous (and
    # empty_like would copy their strides), but the kernel addresses row i at
    # flat offset i * D2 and the per-row scalars at i.
    dy = dy.contiguous()
    y = y.contiguous()
    tgt_scale = tgt_scale.contiguous()
    rstd = rstd.contiguous()

    B, nD, D = y.shape
    block_size_d = triton.next_power_of_2(D)

    dx = torch.empty_like(dy)
    dtgt_scale = torch.empty_like(tgt_scale)

    grid = (B * nD,)
    l2norm_bwd_kernel1[grid](
        y, rstd, tgt_scale, dy, dx, dtgt_scale, eps, D, block_size_d
    )
    return dx, dtgt_scale


class L2NormAddFunction(torch.autograd.Function):
    """Autograd op ``y = tgt_scale * s / sqrt(sum(s**2) + eps)`` with ``s = x + x_add``.

    The normalisation runs row-wise over the last dimension through the Triton
    kernels launched by ``l2_norm_add_fwd`` and ``l2_norm_bwd``.
    """

    @staticmethod
    def forward(ctx, x, x_add, tgt_scale, eps=1e-5):
        ctx.eps = eps
        ctx.dtype = x.dtype
        normed, inv_norm = l2_norm_add_fwd(x, x_add, tgt_scale, eps)
        # Save the output rather than the inputs: the backward kernel recovers
        # the unit vector as y / tgt_scale.
        ctx.save_for_backward(normed, inv_norm, tgt_scale)
        return normed

    @staticmethod
    def backward(ctx, dy):
        normed, inv_norm, scale = ctx.saved_tensors
        grad_sum, grad_scale = l2_norm_bwd(dy, normed, scale, inv_norm, ctx.eps)
        # s = x + x_add, so both addends receive the gradient of s unchanged;
        # eps is a plain float and gets no gradient.
        return grad_sum, grad_sum, grad_scale, None


l2_norm_add_fused = L2NormAddFunction.apply


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE
    ],
    key=["D"],
)
@triton.jit
def swiglu_bwd_kernel_simple_v2(
    dh_ptr,  # [b, l, d]
    y0_y2_ptr,  # [b, l, d * 2]
    lr0_ptr,  # [b, l, 1]
    lr1_ptr,  # [b, l, 1]
    lr2_ptr,  # one value per (b, l) row
    out0_out2_ptr,  # [b, l, d * 2]
    out_hidden_lr1_ptr,  # [b, l, d]
    block_size_d: tl.constexpr,
    D: tl.constexpr,  # n_columns
):
    """Element-wise SwiGLU backward scaled by per-row learning rates.

    One program per (b, l) row. The math runs in float32 and results are
    stored in dh's dtype. With s = sigmoid(y0):

        out0           = lr0 * dh * y2 * s * (1 + y0 * (1 - s))  -> out0_out2[..., :D]
        out2           = lr2 * dh * y0 * s                       -> out0_out2[..., D:]
        out_hidden_lr1 = lr1 * y2 * y0 * s

    ``block_size_d`` must be a power of two >= ``D``; the tail is masked.
    """
    row = tl.program_id(0)
    cols = tl.arange(0, block_size_d)
    in_row = cols < D

    # Single-width tensors advance D elements per row; the paired ones advance 2 * D.
    narrow_off = row * D
    wide_off = row * D * 2

    dh_raw = tl.load(dh_ptr + narrow_off + cols, mask=in_row, other=0.0)
    store_dtype = dh_raw.dtype
    g = dh_raw.to(tl.float32)
    y0 = tl.load(y0_y2_ptr + wide_off + cols, mask=in_row, other=0.0).to(tl.float32)
    y2 = tl.load(y0_y2_ptr + wide_off + D + cols, mask=in_row, other=0.0).to(
        tl.float32
    )

    # Per-row scalar learning rates.
    lr0 = tl.load(lr0_ptr + row).to(tl.float32)
    lr1 = tl.load(lr1_ptr + row).to(tl.float32)
    lr2 = tl.load(lr2_ptr + row).to(tl.float32)

    sig = 1.0 / (1.0 + tl.exp(-y0))
    sig_g = sig * g

    res0 = lr0 * y2 * sig_g * (1.0 + y0 * (1.0 - sig))
    res2 = lr2 * sig_g * y0
    hidden = lr1 * y2 * sig * y0

    tl.store(out0_out2_ptr + wide_off + cols, res0.to(store_dtype), mask=in_row)
    tl.store(out0_out2_ptr + wide_off + D + cols, res2.to(store_dtype), mask=in_row)
    tl.store(
        out_hidden_lr1_ptr + narrow_off + cols, hidden.to(store_dtype), mask=in_row
    )


def swiglu_bwd_fused_simple_v2(dh, y0_y2, lr0, lr1, lr2):
    """Launch ``swiglu_bwd_kernel_simple_v2``: SwiGLU backward scaled by per-row lrs.

    With s = sigmoid(y0), computes element-wise:
        out0           = lr0 * dh * y2 * s * (1 + y0 * (1 - s))
        out2           = lr2 * dh * y0 * s
        out_hidden_lr1 = lr1 * y2 * y0 * s

    Args:
        dh:    [B, L, D]
        y0_y2: [B, L, D * 2]; y0 in the first D columns, y2 in the last D.
        lr0:   [B, L, 1] (one value per (B, L) row)
        lr1:   [B, L, 1]
        lr2:   [B, L, 1]
    Returns:
        out0_out2:      [B, L, D * 2]; out0 then out2 along the last dim, dh's dtype.
        out_hidden_lr1: [B, L, D], dh's dtype.
    """
    # The kernel addresses row r at flat offset r * D (r * 2D for y0_y2) and
    # the lrs at r, so every operand must be densely packed. Unlike an assert,
    # .contiguous() also holds under `python -O`, and it is free when the
    # tensor is already contiguous.
    dh = dh.contiguous()
    y0_y2 = y0_y2.contiguous()
    lr0 = lr0.contiguous()
    lr1 = lr1.contiguous()
    lr2 = lr2.contiguous()

    B, L, D = dh.shape
    out0_out2 = dh.new_empty((B, L, D * 2))
    out_hidden_lr1 = torch.empty_like(dh)

    N = B * L
    block_size_d = triton.next_power_of_2(D)

    swiglu_bwd_kernel_simple_v2[(N,)](
        dh,
        y0_y2,
        lr0,
        lr1,
        lr2,
        out0_out2,
        out_hidden_lr1,
        block_size_d=block_size_d,
        D=D,
    )

    return out0_out2, out_hidden_lr1


@triton.autotune(
    configs=[
        triton.Config({}, num_warps=num_warps) for num_warps in NUM_WARPS_AUTOTUNE
    ],
    key=["D"],
)
@triton.jit
def swiglu_bwd_bwd_kernel(
    dh_ptr,  # [b, l, d]
    x0_x2_ptr,  # [b, l, d * 2]
    lr0_ptr,  # [b, l, 1]
    lr1_ptr,  # [b, l, 1]
    lr2_ptr,  # one value per (b, l) row
    grad_dx0_dx2_ptr,  # [b, l, d * 2]
    grad_hidden_lr1_ptr,  # [b, l, d]
    grad_dh_ptr,  # outputs start here: [b, l, d]
    grad_x0_x2_ptr,  # [b, l, d * 2]
    grad_lr0_ptr,  # one value per (b, l) row
    grad_lr1_ptr,
    grad_lr2_ptr,
    block_size_d: tl.constexpr,
    D: tl.constexpr,  # n_columns
):
    """Gradient of the per-row SwiGLU backward (a double backward), one program per row.

    With s = sigmoid(x0), silu(x0) = x0 * s and m = silu'(x0) = s * (1 + x0 * (1 - s)),
    the previous pass computed:

        dx0        = lr0 * dh * x2 * m
        dx2        = lr2 * dh * silu(x0)
        hidden_lr1 = lr1 * x2 * silu(x0)

    Given upstream grads (grad_dx0, grad_dx2, grad_hidden_lr1) this kernel writes:

        grad_dh = grad_dx0 * lr0 * x2 * m + grad_dx2 * lr2 * silu(x0)
        grad_x2 = grad_dx0 * lr0 * dh * m + grad_hidden_lr1 * lr1 * silu(x0)
        grad_x0 = (grad_dx2 * lr2 * dh + grad_hidden_lr1 * lr1 * x2) * m
                  + grad_dx0 * lr0 * dh * x2 * m'(x0)

    where m'(x0) = s * (1 - s) * (2 + x0 * (1 - 2 * s)). It is assembled from two
    parts: m differentiated w.r.t. its explicit x0 with s held fixed (s * (1 - s)),
    plus m differentiated w.r.t. s (``grad_sigma``) times ds/dx0 = s * (1 - s).

    The lr gradients are reduced over the d columns of the row:

        grad_lr0 = sum_d(grad_dx0 * dh * x2 * m)
        grad_lr1 = sum_d(grad_hidden_lr1 * x2 * silu(x0))
        grad_lr2 = sum_d(grad_dx2 * dh * silu(x0))

    All math is in float32. Each result is cast once, directly to its output
    buffer's dtype, to avoid double rounding.
    """
    pid = tl.program_id(0)
    offsets_d = tl.arange(0, block_size_d)
    mask_d = offsets_d < D

    dh = tl.load(dh_ptr + pid * D + offsets_d, mask=mask_d, other=0.0).to(tl.float32)
    x0 = tl.load(x0_x2_ptr + pid * D * 2 + offsets_d, mask=mask_d, other=0.0).to(
        tl.float32
    )
    x2 = tl.load(x0_x2_ptr + pid * D * 2 + D + offsets_d, mask=mask_d, other=0.0).to(
        tl.float32
    )
    lr0 = tl.load(lr0_ptr + pid).to(tl.float32)
    lr1 = tl.load(lr1_ptr + pid).to(tl.float32)
    lr2 = tl.load(lr2_ptr + pid).to(tl.float32)
    grad_dx0 = tl.load(
        grad_dx0_dx2_ptr + pid * D * 2 + offsets_d, mask=mask_d, other=0.0
    ).to(tl.float32)
    grad_dx2 = tl.load(
        grad_dx0_dx2_ptr + pid * D * 2 + D + offsets_d, mask=mask_d, other=0.0
    ).to(tl.float32)
    grad_hidden_lr1 = tl.load(
        grad_hidden_lr1_ptr + pid * D + offsets_d, mask=mask_d, other=0.0
    ).to(tl.float32)

    sigma = 1.0 / (1.0 + tl.math.exp(-x0))
    silu_bp_multiplier = sigma * (1 + x0 * (1 - sigma))  # m = silu'(x0)
    silu_x0 = x0 * sigma

    grad_dh = grad_dx0 * lr0 * x2 * silu_bp_multiplier + grad_dx2 * lr2 * silu_x0
    tl.store(
        grad_dh_ptr + pid * D + offsets_d,
        grad_dh.to(grad_dh_ptr.dtype.element_ty),
        mask=mask_d,
    )

    grad_x2 = grad_dx0 * lr0 * dh * silu_bp_multiplier + grad_hidden_lr1 * lr1 * silu_x0
    tl.store(
        grad_x0_x2_ptr + pid * D * 2 + D + offsets_d,
        grad_x2.to(grad_x0_x2_ptr.dtype.element_ty),
        mask=mask_d,
    )

    # d(m)/d(sigma) with x0 held fixed, times the upstream factor of dx0.
    grad_sigma = grad_dx0 * lr0 * dh * x2 * (1 + x0 - 2 * sigma * x0)
    grad_x0_naive = (
        grad_dx2 * lr2 * dh + grad_hidden_lr1 * lr1 * x2
    ) * silu_bp_multiplier + grad_dx0 * lr0 * dh * x2 * sigma * (1 - sigma)
    # Chain rule through sigma: d(sigma)/d(x0) = sigma * (1 - sigma).
    grad_x0 = grad_x0_naive + grad_sigma * sigma * (1 - sigma)
    # Masked columns are zero in every operand, so they add nothing to the sums.
    grad_lr0 = tl.sum(grad_dx0 * dh * x2 * silu_bp_multiplier)
    grad_lr1 = tl.sum(grad_hidden_lr1 * x2 * silu_x0)
    grad_lr2 = tl.sum(grad_dx2 * dh * silu_x0)

    # write back, casting straight to each buffer's own dtype
    tl.store(
        grad_x0_x2_ptr + pid * D * 2 + offsets_d,
        grad_x0.to(grad_x0_x2_ptr.dtype.element_ty),
        mask=mask_d,
    )
    tl.store(grad_lr0_ptr + pid, grad_lr0.to(grad_lr0_ptr.dtype.element_ty))
    tl.store(grad_lr1_ptr + pid, grad_lr1.to(grad_lr1_ptr.dtype.element_ty))
    tl.store(grad_lr2_ptr + pid, grad_lr2.to(grad_lr2_ptr.dtype.element_ty))


def swiglu_bwd_bwd_fused(
    dh: torch.Tensor,  # [b, l, d]
    x0_x2: torch.Tensor,  # [b, l, d * 2]
    lr0: torch.Tensor,  # one value per (b, l) row, e.g. [b, l, 1]
    lr1: torch.Tensor,  # [b, l, 1]
    lr2: torch.Tensor,  # one value per (b, l) row, e.g. [b, l, 1]
    grad_dx0_dx2: torch.Tensor,  # [b, l, d * 2]
    grad_hidden_lr1: torch.Tensor,  # [b, l, d]
):
    """Launch ``swiglu_bwd_bwd_kernel``, the gradient of the per-row SwiGLU backward.

    Args:
        dh, x0_x2, lr0, lr1, lr2: the inputs of the SwiGLU backward being
            differentiated; ``x0_x2`` holds x0 in the first d columns and x2 in
            the last d.
        grad_dx0_dx2: upstream gradient of the concatenated ``[dx0, dx2]`` output.
        grad_hidden_lr1: upstream gradient of the ``hidden_lr1`` output.

    Returns:
        ``(grad_dh, grad_x0_x2, grad_lr0, grad_lr1, grad_lr2)``, each with the
        shape and dtype of the corresponding input. The lr gradients are
        summed over the d columns of their row.
    """
    # The kernel addresses row r at flat offset r * d (r * 2d for the paired
    # tensors) and the lrs at r, so every operand must be densely packed.
    # .contiguous() is free when a tensor already is.
    dh = dh.contiguous()
    x0_x2 = x0_x2.contiguous()
    lr0 = lr0.contiguous()
    lr1 = lr1.contiguous()
    lr2 = lr2.contiguous()
    grad_dx0_dx2 = grad_dx0_dx2.contiguous()
    grad_hidden_lr1 = grad_hidden_lr1.contiguous()

    B, L, D = dh.shape
    grad_dh = torch.empty_like(dh)
    grad_x0_x2 = torch.empty_like(x0_x2)
    grad_lr0 = torch.empty_like(lr0)
    grad_lr1 = torch.empty_like(lr1)
    grad_lr2 = torch.empty_like(lr2)

    N = B * L
    block_size_d = triton.next_power_of_2(D)

    swiglu_bwd_bwd_kernel[(N,)](
        dh,
        x0_x2,
        lr0,
        lr1,
        lr2,
        grad_dx0_dx2,
        grad_hidden_lr1,
        grad_dh,  # output begins
        grad_x0_x2,
        grad_lr0,
        grad_lr1,
        grad_lr2,
        block_size_d=block_size_d,
        D=D,
    )

    return grad_dh, grad_x0_x2, grad_lr0, grad_lr1, grad_lr2


class FusedSwigluBwdInTTTFunction(torch.autograd.Function):
    """Autograd op computing SwiGLU weight gradients, with a fused double backward.

    ``forward`` turns the upstream gradient ``dh`` into ``dw0_dw2`` and ``dw1``
    using ``swiglu_bwd_fused_simple_v2`` followed by two bmms. ``backward``
    differentiates that computation. It recomputes the kernel outputs from the
    saved inputs instead of saving them, and uses ``swiglu_bwd_bwd_fused`` for
    the element-wise part.
    """

    @staticmethod
    @torch.compile
    def forward(ctx, dh, y0_y2, lr0, lr1, lr2, k, v):
        """Compute ``dw0_dw2`` and ``dw1``.

        Args:
            dh: [b, l, d]
            y0_y2: [b, l, 2 * d]; y0 in the first d columns, y2 in the last d.
            lr0: [b, l, 1]
            lr1: [b, l, 1]
            lr2: [b, l, 1]
            k: [b, l, dk]
            v: [b, l, dv]

        Returns:
            dw0_dw2: [b, 2 * d, dk] = out0_out2^T @ k (dw0 in the first d rows,
                dw2 in the last d).
            dw1: [b, dv, d] = v^T @ out_hidden_lr1.
        """
        out0_out2, out_hidden_lr1 = swiglu_bwd_fused_simple_v2(dh, y0_y2, lr0, lr1, lr2)

        # out0 and out2 come out of the kernel already concatenated along the
        # last dim, so a single bmm yields dw0 and dw2 stacked along dim 1:
        # [b, 2 * d, l] @ [b, l, dk] -> [b, 2 * d, dk]
        # (use dw0_dw2.chunk(2, dim=1) when the halves are needed separately).
        dw0_dw2 = torch.bmm(out0_out2.transpose(1, 2), k)

        # [b, dv, l] @ [b, l, d] -> [b, dv, d]
        dw1 = torch.bmm(v.transpose(1, 2), out_hidden_lr1)  # [b, dv, d]

        # Only the inputs are saved; backward recomputes the kernel outputs.
        ctx.save_for_backward(dh, y0_y2, lr0, lr1, lr2, k, v)
        return dw0_dw2, dw1

    @staticmethod
    @torch.compile
    def backward(ctx, grad_dw0_dw2, grad_dw1):
        """Backpropagate through ``forward`` to all seven inputs.

        Args:
            grad_dw0_dw2: [b, 2 * d, dk]
            grad_dw1: [b, dv, d]

        Returns:
            Gradients for ``(dh, y0_y2, lr0, lr1, lr2, k, v)``, in that order.
        """
        dh, y0_y2, lr0, lr1, lr2, k, v = ctx.saved_tensors

        # Recompute the forward intermediates from the saved inputs.
        out0_out2, out_hidden_lr1 = swiglu_bwd_fused_simple_v2(dh, y0_y2, lr0, lr1, lr2)

        # TODO: the four bmms below form two pairs sharing an upstream grad
        # (grad_dw0_dw2 and grad_dw1); each pair could be a grouped GEMM.

        # dw0_dw2 = out0_out2^T @ k  =>  grad_k = out0_out2 @ grad_dw0_dw2
        # [b, l, 2 * d] @ [b, 2 * d, dk] -> [b, l, dk]
        grad_k = torch.bmm(out0_out2, grad_dw0_dw2)

        # dw1 = v^T @ out_hidden_lr1  =>  grad_v = out_hidden_lr1 @ grad_dw1^T
        # [b, l, d] @ [b, d, dv] -> [b, l, dv]
        grad_v = torch.bmm(out_hidden_lr1, grad_dw1.transpose(1, 2))
        # ... and grad_out_hidden_lr1 = v @ grad_dw1
        # [b, l, dv] @ [b, dv, d] -> [b, l, d]
        grad_out_hidden_lr1 = torch.bmm(v, grad_dw1)

        # grad_out0_out2 = k @ grad_dw0_dw2^T
        # [b, l, dk] @ [b, dk, 2 * d] -> [b, l, 2 * d]
        grad_out0_out2 = torch.bmm(k, grad_dw0_dw2.transpose(1, 2))

        # Push the kernel-output grads back through the element-wise SwiGLU backward.
        grad_dh, grad_x0_x2, grad_lr0, grad_lr1, grad_lr2 = swiglu_bwd_bwd_fused(
            dh, y0_y2, lr0, lr1, lr2, grad_out0_out2, grad_out_hidden_lr1
        )

        return (
            grad_dh,
            grad_x0_x2,
            grad_lr0,
            grad_lr1,
            grad_lr2,
            grad_k,
            grad_v,
        )


fused_swiglu_bwd = FusedSwigluBwdInTTTFunction.apply
