import math
import torch
import triton
import triton.language as tl
from .tril_softmax import tril_softmax, naive_tril_softmax


def attn_weight(Q: torch.Tensor, K: torch.Tensor, scale: float, naive_softmax: bool=False):
    """Return lower-triangular softmax weights of the batched scores ``Q @ K^T``.

    Args:
        Q: batched left operand for ``torch.bmm``.
        K: batched right operand; its dims 1 and 2 are swapped before the product.
        scale: multiplier applied to the raw scores before the softmax.
        naive_softmax: when True, pre-scale the scores and use ``naive_tril_softmax``;
            otherwise hand the scale to the fused ``tril_softmax``.
    """
    scores = torch.bmm(Q, K.transpose(1, 2))
    if not naive_softmax:
        # The fused path applies the scale itself.
        return tril_softmax(scores, scale=scale)
    return naive_tril_softmax(scores * scale)


def preattn(K, V):
    """Compute U from keys K and values V via the sequential delta-attention solve.

    K is unpacked as (BS, NH, T, D); V is presumably the same shape (TODO confirm).
    Batch and head dims are merged, tril-softmax weights of K against itself are
    built with scale 1/sqrt(D), and ``deltattn_compileable`` fills U rows 0..T-1.

    Returns:
        Tensor viewed back to (BS, NH, T, D).
    """
    batch, heads, seq_len, dim = K.size()
    keys = K.flatten(0, 1)
    values = V.flatten(0, 1)
    weights = attn_weight(keys, keys, 1 / math.sqrt(dim), naive_softmax=True)

    out = torch.empty_like(keys)  # every row is written by the op below
    deltattn_compileable(out, values, weights, 0, 0, seq_len)
    return out.view(batch, heads, seq_len, dim)


@torch.library.custom_op("deltattn::iterative_fwd", mutates_args={"U"})
def deltattn_compileable(
    U: torch.Tensor,
    V: torch.Tensor,
    W: torch.Tensor,
    w_start: int,
    u_start: int,
    u_end: int,
) -> None:
    """Launch ``cumsum_attn_naive_kernel`` to fill rows ``u_start .. u_end - 1`` of U in place.

    Registered as the custom op ``deltattn::iterative_fwd`` with ``U`` declared as mutated.

    Args:
        U: 3-D output tensor (asserted), unpacked as (BS, T, D).
        V: 3-D tensor read row-by-row by the kernel.
        W: 3-D weight tensor; the kernel reads W[b, row, col].
        w_start: first column of W / row of U that enters each row's sum.
        u_start: first row of U to compute.
        u_end: one past the last row of U to compute.
    """
    assert U.ndim == 3
    batch, _, dim = U.size()

    def grid(meta):
        # axis 0 tiles the feature dim in BLOCK_D chunks, axis 1 walks the batch.
        return (triton.cdiv(dim, meta['BLOCK_D']), batch)

    cumsum_attn_naive_kernel[grid](
        U, *(U.stride(axis) for axis in range(3)),
        V, *(V.stride(axis) for axis in range(3)),
        W, *(W.stride(axis) for axis in range(3)),
        w_start, u_start, u_end,
        dim,
        ACC_TYPE=tl.float32,
    )


def _config_deltattn():
    """Return the autotune search space used by ``cumsum_attn_naive_kernel``.

    Each entry of ``search_space`` is (BLOCK_D, BLOCK_R, num_stages, num_warps);
    currently a single configuration.
    """
    search_space = [
        (32, 256, 3, 4),
    ]
    return [
        triton.Config(
            {'BLOCK_D': block_d, 'BLOCK_R': block_r}, num_stages=stages, num_warps=warps
        )
        for block_d, block_r, stages, warps in search_space
    ]


@triton.autotune(
    configs=_config_deltattn(),
    key=['D', 'u_end']
)
@triton.jit
def cumsum_attn_naive_kernel(
    U_ptr,
    stride_uh,
    stride_ur,
    stride_uc,
    V_ptr,
    stride_vh,
    stride_vr,
    stride_vc,
    W_ptr,
    stride_wh,
    stride_wr,
    stride_wc,
    w_start,
    u_start,
    u_end,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,  # columns >= D are masked, so D need not be a multiple of BLOCK_D
    BLOCK_R: tl.constexpr,
    ACC_TYPE: tl.constexpr,
):
    """Sequential solve U[i] = V[i] - sum_{w_start <= j < i} W[i, j] * U[j].

    Program (pid0, pid1) owns feature columns [pid0*BLOCK_D, (pid0+1)*BLOCK_D) of
    batch entry pid1 and produces rows u_start .. u_end-1 in order. Each row reads
    rows of U stored earlier by the same program, BLOCK_R rows at a time. Products
    and the running sum are computed in ACC_TYPE; the result is cast to U's dtype on
    store. Note: ``u_end`` is an autotune key, so each distinct value re-tunes.
    """
    pid0 = tl.program_id(axis=0).to(tl.uint64)
    pid1 = tl.program_id(axis=1).to(tl.uint64)
    R_range = tl.arange(0, BLOCK_R).to(tl.uint64)
    D_block = tl.max_contiguous(tl.arange(0, BLOCK_D).to(tl.uint64) + pid0 * BLOCK_D, BLOCK_D)
    D_mask = D_block < D

    # Per-batch base pointers; each tensor uses its OWN batch stride
    # (V previously used U's batch stride).
    U_base = U_ptr + pid1 * stride_uh
    V_base = V_ptr + pid1 * stride_vh
    W_base = W_ptr + pid1 * stride_wh

    for row_i in range(u_start, u_end):
        vi_ptr = V_base + row_i * stride_vr + D_block * stride_vc
        acc = tl.load(vi_ptr, mask=D_mask, other=0).to(ACC_TYPE)

        for j in range(w_start, row_i, BLOCK_R):
            R_block = R_range + j
            R_mask = R_block < row_i
            wij_ptr = W_base + row_i * stride_wr + R_block * stride_wc
            w = tl.load(wij_ptr, mask=R_mask, other=0).to(ACC_TYPE)

            uj_ptr = U_base + R_block[:, None] * stride_ur + D_block[None, :] * stride_uc
            uj = tl.load(uj_ptr, mask=R_mask[:, None] & D_mask[None, :], other=0).to(ACC_TYPE)

            acc -= tl.sum(w[:, None] * uj, axis=0)

        ui_ptr = U_base + row_i * stride_ur + D_block * stride_uc
        tl.store(ui_ptr, acc.to(U_ptr.dtype.element_ty), mask=D_mask)
        # The next row re-loads this row through a 2-D layout whose threads differ
        # from the ones that just stored it; synchronize the block so the store is
        # visible before that load.
        tl.debug_barrier()


@torch.library.custom_op("deltattn::cumsum_fwd", mutates_args={"U"})
def cumsum_attn_compileable(
    U: torch.Tensor,
    W: torch.Tensor,
) -> None:
    """Launch ``cumsum_attn_smem_kernel`` to update the chunk tensor U in place using W.

    Registered as the custom op ``deltattn::cumsum_fwd`` with ``U`` declared as mutated.

    Args:
        U: 3-D tensor (asserted), unpacked as (BS, C, D); overwritten in place.
        W: 3-D weight tensor read as W[b, row, col].

    NOTE(review): C is forwarded as a constexpr and the kernel builds
    ``tl.arange(0, C)``, which Triton requires to be a power of two.
    """
    assert U.ndim == 3
    batch, chunk, dim = U.size()

    def grid(meta):
        # One program per BLOCK_D column tile (axis 0) and per batch entry (axis 1).
        return (triton.cdiv(dim, meta['BLOCK_D']), batch)

    u_strides = (U.stride(0), U.stride(1), U.stride(2))
    w_strides = (W.stride(0), W.stride(1), W.stride(2))
    cumsum_attn_smem_kernel[grid](
        U, *u_strides,
        W, *w_strides,
        chunk, dim,
        ACC_TYPE=tl.float32,
    )


def _config_cumsum_attn():
    """Return the autotune search space used by ``cumsum_attn_smem_kernel``.

    Each entry of ``search_space`` is (BLOCK_D, num_stages, num_warps);
    currently a single configuration.
    """
    search_space = [
        (16, 4, 4),
    ]
    return [
        triton.Config({'BLOCK_D': block_d}, num_stages=stages, num_warps=warps)
        for block_d, stages, warps in search_space
    ]


@triton.autotune(
    configs=_config_cumsum_attn(),
    key=['C', 'D']
)
@triton.jit
def cumsum_attn_smem_kernel(
    u_ptr,
    stride_uh,
    stride_ur,
    stride_uc,
    w_ptr,
    stride_wh,
    stride_wr,
    stride_wc,
    C: tl.constexpr,  # must be a power of two (used as a tl.arange bound)
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,  # the D tail is handled by boundary_check
    ACC_TYPE: tl.constexpr,
):
    """In-place forward substitution over a C-row chunk.

    For rows i = 0 .. C-1 in order: u[i] <- u[i] - sum_{j < i} w[i, j] * u[j],
    where u[j] is the already-updated row. Program (pid0, pid1) handles the
    BLOCK_D-wide column tile pid0 of batch entry pid1. It loads the whole
    (C, BLOCK_D) tile once, updates it row by row in ACC_TYPE, and stores it
    once (cast back to u's dtype).
    """
    pid0 = tl.program_id(axis=0)
    pid1 = tl.program_id(axis=1)

    c_block = tl.arange(0, C)

    u_block_ptr = tl.make_block_ptr(
        # int64 batch offset, matching the W pointer below.
        base=u_ptr + pid1.to(tl.int64) * stride_uh,
        shape=(C, D),
        strides=(stride_ur, stride_uc),
        offsets=(0, pid0 * BLOCK_D),
        block_shape=(C, BLOCK_D),
        order=(1, 0),  # row-major: dim 1 (columns) varies fastest
    )
    # Accumulate in ACC_TYPE so rounding to the storage dtype happens only once.
    u = tl.load(u_block_ptr, boundary_check=(1,), padding_option="zero").to(ACC_TYPE)

    w_batch_base = w_ptr + stride_wh.to(tl.int64) * pid1
    for row_i in range(0, C):
        w_row_ptrs = w_batch_base + stride_wr.to(tl.int64) * row_i + stride_wc * c_block
        mask_c = c_block < row_i
        w = tl.load(w_row_ptrs, mask=mask_c, other=0.).to(ACC_TYPE)
        # Zero rows >= row_i so non-finite values there cannot leak in via 0 * x.
        up = tl.where(mask_c[:, None], u, 0.)
        delta = tl.sum(w[:, None] * up, axis=0)
        mask_i = c_block == row_i
        u = tl.where(mask_i[:, None], u - delta[None, :], u)

    tl.store(u_block_ptr, u.to(u_ptr.dtype.element_ty), boundary_check=(1,))
