import torch
import triton
import triton.language as tl


def torch_merge_output_fwd(
    sq: torch.Tensor,
    rq: torch.Tensor,
    alpha: torch.Tensor,
    gamma: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Pure-PyTorch computation of ``(exp(alpha) * sq + gamma * rq) * scale``.

    This is the same formula evaluated by ``_merge_fwd_kernel``. ``alpha`` and
    ``gamma`` lack the last dimension of ``sq``/``rq`` and are broadcast over it.

    Args:
        sq: First tensor to merge.
        rq: Second tensor to merge, broadcast-compatible with ``sq``.
        alpha: Weight for ``sq``, exponentiated before use.
        gamma: Weight for ``rq``, used as-is.
        scale: Scalar multiplier applied to the merged result.

    Returns:
        The merged tensor, cast to ``sq.dtype``.
    """
    # Add a trailing singleton axis so the per-row weights broadcast over the
    # last dimension of sq / rq.
    sq_weight = torch.exp(alpha).unsqueeze(-1)
    rq_weight = gamma.unsqueeze(-1)
    merged = sq_weight * sq + rq_weight * rq
    return merged.mul(scale).to(sq.dtype)


# Triton kernel behind ``merge_output_fwd``.
#
# Each program writes one (BLOCK_SIZE_N, BLOCK_SIZE_D) tile of
#     o = (exp(alpha) * sq + gamma * rq) * scale
# for a single (batch, head) pair:
#   * grid axis 0 enumerates the batch * num_heads pairs,
#   * grid axis 1 tiles seq_len in steps of BLOCK_SIZE_N,
#   * grid axis 2 tiles head_dim in steps of BLOCK_SIZE_D.
# alpha / gamma are addressed with (batch, token, head) strides only; each is
# read as a (BLOCK_SIZE_N, 1) column and broadcast across the head_dim tile.
# Rows/columns past seq_len / head_dim are zero-padded on load and masked on
# store via boundary_check.
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_N": BN, "BLOCK_SIZE_D": BD}, num_warps=nw, num_stages=ns
        )
        for BN in [32, 64]
        for BD in [32, 64]
        for nw in [2, 4]
        for ns in [2, 3]
    ],
    key=["head_dim"],
)
@triton.jit
def _merge_fwd_kernel(
    sq_ptr,
    rq_ptr,
    o_ptr,
    alpha_ptr,
    gamma_ptr,
    scale,
    # shapes
    seq_len,
    num_heads,
    head_dim,
    # strides
    stride_sq_b,
    stride_sq_n,
    stride_sq_h,
    stride_sq_d,
    stride_rq_b,
    stride_rq_n,
    stride_rq_h,
    stride_rq_d,
    stride_o_b,
    stride_o_n,
    stride_o_h,
    stride_o_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_g_b,
    stride_g_n,
    stride_g_h,
    # block sizes
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Widen the (batch, head) index to int64 before it is multiplied by the
    # batch/head strides. Triton specializes integer arguments that fit in 32
    # bits to int32, so pid_b * stride_b would otherwise wrap around once a
    # tensor holds more than 2**31 elements.
    pid_bh = tl.program_id(0).to(tl.int64)
    pid_n, pid_d = tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # init_ptrs
    sq_ptrs = tl.make_block_ptr(
        base=sq_ptr + pid_b * stride_sq_b + pid_h * stride_sq_h,
        shape=(seq_len, head_dim),
        strides=(stride_sq_n, stride_sq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    rq_ptrs = tl.make_block_ptr(
        base=rq_ptr + pid_b * stride_rq_b + pid_h * stride_rq_h,
        shape=(seq_len, head_dim),
        strides=(stride_rq_n, stride_rq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # alpha / gamma: a (seq_len, 1) view with zero stride on the dummy column.
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h,
        shape=(seq_len, 1),
        strides=(stride_a_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    gamma_ptrs = tl.make_block_ptr(
        base=gamma_ptr + pid_b * stride_g_b + pid_h * stride_g_h,
        shape=(seq_len, 1),
        strides=(stride_g_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    # load and merge
    sq = tl.load(sq_ptrs, boundary_check=(0, 1), padding_option="zero")
    rq = tl.load(rq_ptrs, boundary_check=(0, 1), padding_option="zero")
    alpha = tl.load(alpha_ptrs, boundary_check=(0, 1), padding_option="zero")
    gamma = tl.load(gamma_ptrs, boundary_check=(0, 1), padding_option="zero")
    # (BLOCK_SIZE_N, 1) weights broadcast across the BLOCK_SIZE_D columns.
    o = sq * tl.exp(alpha) + rq * gamma
    o = o * scale
    # save output
    o_ptrs = tl.make_block_ptr(
        base=o_ptr + pid_b * stride_o_b + pid_h * stride_o_h,
        shape=(seq_len, head_dim),
        strides=(stride_o_n, stride_o_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(o_ptrs, o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1))


def merge_output_fwd(
    sq: torch.Tensor,
    rq: torch.Tensor,
    alpha: torch.Tensor,
    gamma: torch.Tensor,
    scale: float,
) -> torch.Tensor:
    """Compute ``(exp(alpha) * sq + gamma * rq) * scale`` with ``_merge_fwd_kernel``.

    Args:
        sq: Tensor of shape ``(batch, seq_len, num_heads, head_dim)``.
        rq: Tensor with exactly the same shape as ``sq``.
        alpha: Tensor of shape ``(batch, seq_len, num_heads)``. ``exp(alpha)``
            weights ``sq`` and is broadcast over ``head_dim``.
        gamma: Tensor of shape ``(batch, seq_len, num_heads)`` weighting ``rq``.
        scale: Scalar multiplier applied to the merged result.

    Returns:
        A new tensor with the shape and dtype of ``sq``.

    Raises:
        AssertionError: If ``rq``, ``alpha`` or ``gamma`` do not have the shapes
            implied by ``sq``.
    """
    batch_size, seq_len, num_heads, head_dim = sq.shape
    assert rq.shape == sq.shape, (
        f"rq shape {tuple(rq.shape)} does not match sq shape {tuple(sq.shape)}"
    )
    # The kernel reads alpha/gamma through their first three strides with
    # seq_len taken from sq, so a mis-shaped weight tensor would be read out
    # of bounds rather than raising.
    weight_shape = (batch_size, seq_len, num_heads)
    assert alpha.shape == weight_shape, (
        f"alpha shape {tuple(alpha.shape)} != expected {weight_shape}"
    )
    assert gamma.shape == weight_shape, (
        f"gamma shape {tuple(gamma.shape)} != expected {weight_shape}"
    )
    o = torch.empty_like(sq)
    # Nothing to compute. Returning here also keeps an empty call from
    # autotuning on zero-size launches, because the autotuner caches its pick
    # per head_dim.
    if o.numel() == 0:
        return o

    def grid(meta):
        # (batch * heads, seq tiles, head_dim tiles) -- see _merge_fwd_kernel.
        return (
            batch_size * num_heads,
            triton.cdiv(seq_len, meta["BLOCK_SIZE_N"]),
            triton.cdiv(head_dim, meta["BLOCK_SIZE_D"]),
        )

    _merge_fwd_kernel[grid](
        sq,
        rq,
        o,
        alpha,
        gamma,
        scale,
        seq_len,
        num_heads,
        head_dim,
        sq.stride(0),
        sq.stride(1),
        sq.stride(2),
        sq.stride(3),
        rq.stride(0),
        rq.stride(1),
        rq.stride(2),
        rq.stride(3),
        o.stride(0),
        o.stride(1),
        o.stride(2),
        o.stride(3),
        alpha.stride(0),
        alpha.stride(1),
        alpha.stride(2),
        gamma.stride(0),
        gamma.stride(1),
        gamma.stride(2),
    )
    return o


def torch_merge_output_bwd(
    do: torch.Tensor,
    sq: torch.Tensor,
    rq: torch.Tensor,
    alpha: torch.Tensor,
    gamma: torch.Tensor,
    scale: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Pure-PyTorch gradients of ``o = (exp(alpha) * sq + gamma * rq) * scale``.

    Args:
        do: Upstream gradient ``dL/do``, shaped like ``sq``.
        sq: First merged input.
        rq: Second merged input, same shape as ``sq``.
        alpha: Weight for ``sq`` (exponentiated in the forward), with the last
            dimension of ``sq`` removed; it was broadcast over that dimension.
        gamma: Weight for ``rq``, shaped like ``alpha``.
        scale: Scalar multiplier used in the forward.

    Returns:
        ``(dsq, drq, dalpha, dgamma)``. Each is cast to the dtype of the input
        it differentiates, which matches the buffers allocated by
        ``merge_output_bwd``.

    fp16/bf16 operands are widened to float32 for the arithmetic, so the
    reductions over the last dimension do not accumulate in half precision.
    float32/float64 operands are used unchanged.
    """

    def _acc(t: torch.Tensor) -> torch.Tensor:
        # Widen to at least float32 (float64 stays float64).
        return t.to(torch.promote_types(t.dtype, torch.float32))

    do_acc = _acc(do)
    exp_alpha = _acc(alpha).exp()

    # do/dsq = exp(alpha) * scale; unsqueeze broadcasts over the last dim.
    dsq = do_acc * exp_alpha.unsqueeze(-1) * scale
    # do/drq = gamma * scale.
    drq = do_acc * _acc(gamma).unsqueeze(-1) * scale
    # do/dalpha = exp(alpha) * sq * scale. alpha was broadcast over the last
    # dim in the forward, so its gradient sums over that dim.
    dalpha = (do_acc * _acc(sq) * scale).sum(dim=-1) * exp_alpha
    # do/dgamma = rq * scale, summed over the broadcast dim as above.
    dgamma = (do_acc * _acc(rq) * scale).sum(dim=-1)

    return (
        dsq.to(sq.dtype),
        drq.to(rq.dtype),
        dalpha.to(alpha.dtype),
        dgamma.to(gamma.dtype),
    )


# Triton kernel behind ``merge_output_bwd``.
#
# Given the upstream gradient ``do`` of
#     o = (exp(alpha) * sq + gamma * rq) * scale
# each program handles one (batch, head) pair and BLOCK_SIZE_N tokens and
# writes
#     dsq    = do * exp(alpha) * scale
#     drq    = do * gamma * scale
#     dalpha = sum_d(do * sq) * exp(alpha) * scale
#     dgamma = sum_d(do * rq) * scale
# dalpha / dgamma reduce over head_dim inside one tile. The heuristic
# therefore fixes BLOCK_SIZE_D = next_power_of_2(head_dim), and only
# BLOCK_SIZE_N, warps and stages are autotuned. With grid axis 2 =
# cdiv(head_dim, BLOCK_SIZE_D) that axis has size 1, so pid_d is 0.
# Columns past head_dim are zero-padded on load and add nothing to the sums.
@triton.heuristics(
    {
        "BLOCK_SIZE_D": lambda args: triton.next_power_of_2(args["head_dim"]),
    }
)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE_N": BN}, num_warps=nw, num_stages=ns)
        for BN in [32, 64]
        for nw in [2, 4]
        for ns in [2, 3]
    ],
    key=["head_dim"],
)
@triton.jit
def _merge_bwd_kernel(
    do_ptr,
    sq_ptr,
    rq_ptr,
    alpha_ptr,
    gamma_ptr,
    dsq_ptr,
    drq_ptr,
    dalpha_ptr,
    dgamma_ptr,
    scale,
    # shapes
    seq_len,
    num_heads,
    head_dim,
    # strides
    stride_do_b,
    stride_do_n,
    stride_do_h,
    stride_do_d,
    stride_sq_b,
    stride_sq_n,
    stride_sq_h,
    stride_sq_d,
    stride_rq_b,
    stride_rq_n,
    stride_rq_h,
    stride_rq_d,
    stride_a_b,
    stride_a_n,
    stride_a_h,
    stride_g_b,
    stride_g_n,
    stride_g_h,
    stride_dsq_b,
    stride_dsq_n,
    stride_dsq_h,
    stride_dsq_d,
    stride_drq_b,
    stride_drq_n,
    stride_drq_h,
    stride_drq_d,
    stride_da_b,
    stride_da_n,
    stride_da_h,
    stride_dg_b,
    stride_dg_n,
    stride_dg_h,
    # block sizes
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    # Widen the (batch, head) index to int64 before it is multiplied by the
    # batch/head strides. Triton specializes integer arguments that fit in 32
    # bits to int32, so pid_b * stride_b would otherwise wrap around once a
    # tensor holds more than 2**31 elements.
    pid_bh = tl.program_id(0).to(tl.int64)
    pid_n, pid_d = tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # init_ptrs
    do_ptrs = tl.make_block_ptr(
        base=do_ptr + pid_b * stride_do_b + pid_h * stride_do_h,
        shape=(seq_len, head_dim),
        strides=(stride_do_n, stride_do_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    sq_ptrs = tl.make_block_ptr(
        base=sq_ptr + pid_b * stride_sq_b + pid_h * stride_sq_h,
        shape=(seq_len, head_dim),
        strides=(stride_sq_n, stride_sq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    rq_ptrs = tl.make_block_ptr(
        base=rq_ptr + pid_b * stride_rq_b + pid_h * stride_rq_h,
        shape=(seq_len, head_dim),
        strides=(stride_rq_n, stride_rq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # alpha / gamma: a (seq_len, 1) view with zero stride on the dummy column.
    alpha_ptrs = tl.make_block_ptr(
        base=alpha_ptr + pid_b * stride_a_b + pid_h * stride_a_h,
        shape=(seq_len, 1),
        strides=(stride_a_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    gamma_ptrs = tl.make_block_ptr(
        base=gamma_ptr + pid_b * stride_g_b + pid_h * stride_g_h,
        shape=(seq_len, 1),
        strides=(stride_g_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    # load data (do is upcast to float32 so the products below promote to it)
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero").to(tl.float32)
    sq = tl.load(sq_ptrs, boundary_check=(0, 1), padding_option="zero")
    rq = tl.load(rq_ptrs, boundary_check=(0, 1), padding_option="zero")
    alpha = tl.load(alpha_ptrs, boundary_check=(0, 1), padding_option="zero")
    gamma = tl.load(gamma_ptrs, boundary_check=(0, 1), padding_option="zero")
    # compute gradients
    exp_alpha = tl.exp(alpha)
    dsq = do * exp_alpha * scale
    drq = do * gamma * scale
    # Row sums over the full head_dim (the whole tile, see header comment),
    # kept as (BLOCK_SIZE_N, 1) columns to match the alpha/gamma block shape.
    dalpha = tl.sum(do * sq, axis=1)[:, None] * exp_alpha * scale
    dgamma = tl.sum(do * rq, axis=1)[:, None] * scale
    # gradient pointers
    dsq_ptrs = tl.make_block_ptr(
        base=dsq_ptr + pid_b * stride_dsq_b + pid_h * stride_dsq_h,
        shape=(seq_len, head_dim),
        strides=(stride_dsq_n, stride_dsq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    drq_ptrs = tl.make_block_ptr(
        base=drq_ptr + pid_b * stride_drq_b + pid_h * stride_drq_h,
        shape=(seq_len, head_dim),
        strides=(stride_drq_n, stride_drq_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dalpha_ptrs = tl.make_block_ptr(
        base=dalpha_ptr + pid_b * stride_da_b + pid_h * stride_da_h,
        shape=(seq_len, 1),
        strides=(stride_da_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    dgamma_ptrs = tl.make_block_ptr(
        base=dgamma_ptr + pid_b * stride_dg_b + pid_h * stride_dg_h,
        shape=(seq_len, 1),
        strides=(stride_dg_n, 0),
        offsets=(pid_n * BLOCK_SIZE_N, 0),
        block_shape=(BLOCK_SIZE_N, 1),
        order=(0, 1),
    )
    # save gradients (cast to each output buffer's element type)
    tl.store(dsq_ptrs, dsq.to(dsq_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(drq_ptrs, drq.to(drq_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dalpha_ptrs, dalpha.to(dalpha_ptr.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dgamma_ptrs, dgamma.to(dgamma_ptr.dtype.element_ty), boundary_check=(0, 1))


def merge_output_bwd(
    do: torch.Tensor,
    sq: torch.Tensor,
    rq: torch.Tensor,
    alpha: torch.Tensor,
    gamma: torch.Tensor,
    scale: float,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Gradients of ``o = (exp(alpha) * sq + gamma * rq) * scale`` via ``_merge_bwd_kernel``.

    Args:
        do: Upstream gradient, same shape as ``sq``.
        sq: Tensor of shape ``(batch, seq_len, num_heads, head_dim)``.
        rq: Tensor with the same shape as ``sq``.
        alpha: Tensor of shape ``(batch, seq_len, num_heads)``.
        gamma: Tensor of shape ``(batch, seq_len, num_heads)``.
        scale: Scalar multiplier used in the forward.

    Returns:
        ``(dsq, drq, dalpha, dgamma)``, each with the shape and dtype of the
        corresponding input.

    Raises:
        AssertionError: If ``do``, ``rq``, ``alpha`` or ``gamma`` do not have the
            shapes implied by ``sq``.
    """
    batch_size, seq_len, num_heads, head_dim = sq.shape
    assert do.shape == sq.shape, (
        f"do shape {tuple(do.shape)} does not match sq shape {tuple(sq.shape)}"
    )
    assert rq.shape == sq.shape, (
        f"rq shape {tuple(rq.shape)} does not match sq shape {tuple(sq.shape)}"
    )
    # The kernel reads alpha/gamma through their first three strides with
    # seq_len taken from sq, so mis-shaped weights would be read out of bounds.
    weight_shape = (batch_size, seq_len, num_heads)
    assert alpha.shape == weight_shape, (
        f"alpha shape {tuple(alpha.shape)} != expected {weight_shape}"
    )
    assert gamma.shape == weight_shape, (
        f"gamma shape {tuple(gamma.shape)} != expected {weight_shape}"
    )

    if sq.numel() == 0:
        # Either there are no (batch, token, head) rows, or head_dim == 0. In
        # the latter case the kernel's BLOCK_SIZE_D heuristic cannot size a
        # tile, and the reductions over an empty head_dim are exactly zero.
        # Skipping the launch also avoids autotuning on an empty call.
        return (
            torch.empty_like(sq),
            torch.empty_like(rq),
            torch.zeros_like(alpha),
            torch.zeros_like(gamma),
        )

    # Gradient buffers take each input's dtype; the kernel casts on store.
    dsq = torch.empty_like(sq)
    drq = torch.empty_like(rq)
    dalpha = torch.empty_like(alpha)
    dgamma = torch.empty_like(gamma)

    def grid(meta):
        # BLOCK_SIZE_D comes from the kernel's heuristic (>= head_dim), so the
        # last axis evaluates to 1.
        return (
            batch_size * num_heads,
            triton.cdiv(seq_len, meta["BLOCK_SIZE_N"]),
            triton.cdiv(head_dim, meta["BLOCK_SIZE_D"]),
        )

    _merge_bwd_kernel[grid](
        do,
        sq,
        rq,
        alpha,
        gamma,
        dsq,
        drq,
        dalpha,
        dgamma,
        scale,
        seq_len,
        num_heads,
        head_dim,
        do.stride(0),
        do.stride(1),
        do.stride(2),
        do.stride(3),
        sq.stride(0),
        sq.stride(1),
        sq.stride(2),
        sq.stride(3),
        rq.stride(0),
        rq.stride(1),
        rq.stride(2),
        rq.stride(3),
        alpha.stride(0),
        alpha.stride(1),
        alpha.stride(2),
        gamma.stride(0),
        gamma.stride(1),
        gamma.stride(2),
        dsq.stride(0),
        dsq.stride(1),
        dsq.stride(2),
        dsq.stride(3),
        drq.stride(0),
        drq.stride(1),
        drq.stride(2),
        drq.stride(3),
        dalpha.stride(0),
        dalpha.stride(1),
        dalpha.stride(2),
        dgamma.stride(0),
        dgamma.stride(1),
        dgamma.stride(2),
    )
    # dalpha/dgamma were allocated with alpha/gamma dtypes, so no cast needed.
    return dsq, drq, dalpha, dgamma
