from typing import Optional, Tuple
import torch
import triton
import triton.language as tl


def torch_clip_fwd(v, sk, min_val=-1, max_val=1):
    """Pure-PyTorch clip: return ``v - sk`` clamped to ``[min_val, max_val]``."""
    diff = v - sk
    return diff.clamp(min=min_val, max=max_val)


def torch_clip_bwd(do, v, sk, min_val=-1, max_val=1):
    """Pure-PyTorch backward of ``clamp(v - sk, min_val, max_val)``.

    The gradient ``do`` flows to ``v`` only where ``min_val <= v - sk <= max_val``
    and is zero elsewhere; ``sk`` receives the negated gradient.

    Returns:
        ``(dv, dsk)`` with ``dsk == -dv``.
    """
    o = v - sk
    mask = (o >= min_val) & (o <= max_val)
    # Select instead of multiplying by the mask: ``do * mask`` turns an inf/NaN
    # incoming gradient into NaN at clipped positions (inf * 0 = NaN), whereas
    # the Triton backward kernel uses tl.where and produces an exact 0 there.
    dv = torch.where(mask, do, do.new_zeros(()))
    dsk = -dv
    return dv, dsk


@triton.heuristics(
    {
        "RETURN_MASK": lambda args: args["mask_ptr"] is not None,
    }
)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_N": BN, "BLOCK_SIZE_D": BD}, num_warps=nw, num_stages=ns
        )
        for BN in [32, 64]
        for BD in [64, 128]
        for nw in [2, 4, 8]
        for ns in [2, 3]
    ],
    key=["RETURN_MASK", "head_dim"],
)
@triton.jit
def _clip_kernel_fwd(
    v_ptr,
    sk_ptr,
    o_ptr,
    mask_ptr,
    min_val,
    max_val,
    seq_len,
    num_heads,
    head_dim,
    # stride
    stride_v_b,
    stride_v_n,
    stride_v_h,
    stride_v_d,
    stride_sk_b,
    stride_sk_n,
    stride_sk_h,
    stride_sk_d,
    stride_o_b,
    stride_o_n,
    stride_o_h,
    stride_o_d,
    stride_m_b,
    stride_m_n,
    stride_m_h,
    stride_m_d,
    # block sizes
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
    RETURN_MASK: tl.constexpr,
):
    """Write ``o = clamp(v - sk, min_val, max_val)`` for one tile.

    Program axis 0 enumerates (batch, head) pairs as ``b * num_heads + h``;
    axes 1 and 2 select a BLOCK_SIZE_N x BLOCK_SIZE_D tile of the
    (seq_len, head_dim) plane. When RETURN_MASK is set (mask_ptr is not None),
    also stores the boolean ``min_val <= v - sk <= max_val`` to mask_ptr.
    """
    pid_bh, pid_n, pid_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptrs to the start of this (batch, head) slice
    v_ptr = v_ptr + pid_b * stride_v_b + pid_h * stride_v_h
    sk_ptr = sk_ptr + pid_b * stride_sk_b + pid_h * stride_sk_h
    o_ptr = o_ptr + pid_b * stride_o_b + pid_h * stride_o_h
    if RETURN_MASK:
        mask_ptr = mask_ptr + pid_b * stride_m_b + pid_h * stride_m_h
    # ptrs
    v_ptrs = tl.make_block_ptr(
        base=v_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_v_n, stride_v_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    sk_ptrs = tl.make_block_ptr(
        base=sk_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_sk_n, stride_sk_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    o_ptrs = tl.make_block_ptr(
        base=o_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_o_n, stride_o_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    if RETURN_MASK:
        mask_ptrs = tl.make_block_ptr(
            base=mask_ptr,
            shape=(seq_len, head_dim),
            strides=(stride_m_n, stride_m_d),
            offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
            block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
            order=(1, 0),
        )
    # Loads need boundary_check just like the stores: the last tile along
    # either axis is partial whenever seq_len / head_dim is not a multiple of
    # the block size, and an unchecked block-pointer load reads past the end.
    # Padded lanes are never stored, so the padding value is irrelevant.
    v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero")
    sk = tl.load(sk_ptrs, boundary_check=(0, 1), padding_option="zero")
    o = v - sk
    in_range = (o >= min_val) & (o <= max_val)
    # Clamp via explicit below/above tests: NaN fails both comparisons and
    # passes through unchanged (as torch.clamp does) instead of becoming min_val.
    o = tl.where(o < min_val, min_val, tl.where(o > max_val, max_val, o))
    # store
    tl.store(o_ptrs, o.to(o_ptrs.dtype.element_ty), boundary_check=(0, 1))
    if RETURN_MASK:
        tl.store(
            mask_ptrs,
            in_range.to(mask_ptrs.dtype.element_ty),
            boundary_check=(0, 1),
        )


def clip_fwd(
    v: torch.Tensor,
    sk: torch.Tensor,
    min_val: float = -1,
    max_val: float = 1,
    return_mask: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Compute ``clamp(v - sk, min_val, max_val)`` with the Triton kernel.

    Args:
        v, sk: tensors of identical 4-D shape, unpacked as
            ``(batch_size, seq_len, num_heads, head_dim)``.
        min_val, max_val: clamp bounds. When both are infinite no kernel is
            launched and ``v - sk`` is returned directly.
        return_mask: if True, also produce a bool tensor that is True where
            ``v - sk`` lay inside ``[min_val, max_val]``.

    Returns:
        ``(o, mask)``; ``mask`` is None when ``return_mask`` is False or when
        both bounds are infinite.
    """
    batch_size, seq_len, num_heads, head_dim = v.shape
    assert v.shape == sk.shape

    unbounded = min_val == float("-inf") and max_val == float("inf")
    if unbounded:
        return v - sk, None

    o = torch.empty_like(v)
    mask = torch.empty_like(v, dtype=torch.bool) if return_mask else None
    # the kernel ignores the mask strides when no mask is passed
    mask_strides = (0, 0, 0, 0) if mask is None else mask.stride()

    def grid(meta):
        # axis 0: one program per (batch, head); axes 1/2: tiles over seq / dim
        n_tiles = triton.cdiv(seq_len, meta["BLOCK_SIZE_N"])
        d_tiles = triton.cdiv(head_dim, meta["BLOCK_SIZE_D"])
        return (batch_size * num_heads, n_tiles, d_tiles)

    _clip_kernel_fwd[grid](
        v,
        sk,
        o,
        mask,
        min_val,
        max_val,
        seq_len,
        num_heads,
        head_dim,
        *v.stride(),
        *sk.stride(),
        *o.stride(),
        *mask_strides,
    )
    return o, mask


@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_N": BN, "BLOCK_SIZE_D": BD}, num_warps=nw, num_stages=ns
        )
        for BN in [32, 64]
        for BD in [64, 128]
        for nw in [2, 4, 8]
        for ns in [2, 3]
    ],
    key=["head_dim"],
)
@triton.jit
def _clip_kernel_bwd(
    do_ptr,
    mask_ptr,
    dv_ptr,
    dsk_ptr,
    seq_len,
    num_heads,
    head_dim,
    # stride
    stride_do_b,
    stride_do_n,
    stride_do_h,
    stride_do_d,
    stride_m_b,
    stride_m_n,
    stride_m_h,
    stride_m_d,
    stride_dv_b,
    stride_dv_n,
    stride_dv_h,
    stride_dv_d,
    stride_dsk_b,
    stride_dsk_n,
    stride_dsk_h,
    stride_dsk_d,
    # block sizes
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_D: tl.constexpr,
):
    """Write ``dv = where(mask, do, 0)`` and ``dsk = -dv`` for one tile.

    Program axis 0 enumerates (batch, head) pairs as ``b * num_heads + h``;
    axes 1 and 2 select a BLOCK_SIZE_N x BLOCK_SIZE_D tile of the
    (seq_len, head_dim) plane.
    """
    pid_bh, pid_n, pid_d = tl.program_id(0), tl.program_id(1), tl.program_id(2)
    pid_b, pid_h = pid_bh // num_heads, pid_bh % num_heads
    # move ptrs to the start of this (batch, head) slice
    do_ptr = do_ptr + pid_b * stride_do_b + pid_h * stride_do_h
    mask_ptr = mask_ptr + pid_b * stride_m_b + pid_h * stride_m_h
    dv_ptr = dv_ptr + pid_b * stride_dv_b + pid_h * stride_dv_h
    dsk_ptr = dsk_ptr + pid_b * stride_dsk_b + pid_h * stride_dsk_h
    # ptrs
    do_ptrs = tl.make_block_ptr(
        base=do_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_do_n, stride_do_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    mask_ptrs = tl.make_block_ptr(
        base=mask_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_m_n, stride_m_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    # Loads need boundary_check just like the stores: edge tiles are partial
    # whenever seq_len / head_dim is not a multiple of the block size, and an
    # unchecked block-pointer load reads past the end of the tensor.
    do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero")
    mask = tl.load(mask_ptrs, boundary_check=(0, 1), padding_option="zero").to(
        tl.int1
    )
    # compute: gradient flows only where the forward value was inside the bounds
    dv = tl.where(mask, do, 0)
    dsk = -dv
    # store
    dv_ptrs = tl.make_block_ptr(
        base=dv_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_dv_n, stride_dv_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    dsk_ptrs = tl.make_block_ptr(
        base=dsk_ptr,
        shape=(seq_len, head_dim),
        strides=(stride_dsk_n, stride_dsk_d),
        offsets=(pid_n * BLOCK_SIZE_N, pid_d * BLOCK_SIZE_D),
        block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_D),
        order=(1, 0),
    )
    tl.store(dv_ptrs, dv.to(dv_ptrs.dtype.element_ty), boundary_check=(0, 1))
    tl.store(dsk_ptrs, dsk.to(dsk_ptrs.dtype.element_ty), boundary_check=(0, 1))


def clip_bwd(
    do: torch.Tensor, mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Backward of ``clip_fwd``: gate ``do`` by the forward mask.

    Args:
        do: upstream gradient, 4-D ``(batch_size, seq_len, num_heads, head_dim)``.
        mask: bool tensor from ``clip_fwd(..., return_mask=True)``, same shape
            as ``do``. If None, the gradient is passed through ungated
            (``dv = do``); callers must therefore supply the mask whenever the
            forward actually clipped.

    Returns:
        ``(dv, dsk)`` where ``dv = do`` at unclipped positions and 0 elsewhere,
        and ``dsk = -dv``.
    """
    batch_size, seq_len, num_heads, head_dim = do.shape
    if mask is None:
        return do, -do
    assert do.shape == mask.shape

    dv, dsk = torch.empty_like(do), torch.empty_like(do)

    def grid(meta):
        # axis 0: one program per (batch, head); axes 1/2: tiles over seq / dim
        n_tiles = triton.cdiv(seq_len, meta["BLOCK_SIZE_N"])
        d_tiles = triton.cdiv(head_dim, meta["BLOCK_SIZE_D"])
        return (batch_size * num_heads, n_tiles, d_tiles)

    _clip_kernel_bwd[grid](
        do,
        mask,
        dv,
        dsk,
        seq_len,
        num_heads,
        head_dim,
        *do.stride(),
        *mask.stride(),
        *dv.stride(),
        *dsk.stride(),
    )
    return dv, dsk
