import torch
import triton

from ....triton.chunkwise.limit_chunk import mlstm_chunkwise__parallel_bw_dQKV_kernel
from ....triton.kernel_param_heuristics import get_head_dim_block_size
from ....utils.kernels import is_power_of_2
from ...utils import torch2triton_dtype


def mlstm_chunkwise__parallel_bw_dQKV(
    matQ: torch.Tensor,  # (B, NH, S, DHQK)
    matK: torch.Tensor,  # (B, NH, S, DHQK)
    matV: torch.Tensor,  # (B, NH, S, DHHV)
    vecB: torch.Tensor,  # (B, NH, NC, L)
    vecI: torch.Tensor,  # (B, NH, NC, L)
    vecM_combine: torch.Tensor,  # (B, NH, S) = (B, NH, NC * L)
    scaM_inter: torch.Tensor,  # (B, NH, NC+1)
    matC_states: torch.Tensor,  # (B, NH, NC * DHQK, DHHV)
    matDeltaH: torch.Tensor,  # (B, NH, S, DHHV)
    vecN_out: torch.Tensor,  # (B, NH, S)
    matDeltaC_states: torch.Tensor,  # (B, NH, NC * DHQK, DHHV)
    qk_scale: float = None,
    CHUNK_SIZE: int = 64,
    NUM_CHUNKS: int = 1,
    EPS: float = 1e-6,
) -> tuple[
    torch.Tensor, torch.Tensor, torch.Tensor
]:  # matDeltaQ (B,NH,S,DHQK), matDeltaK (B,NH,S,DHQK), matDeltaV (B,NH,S,DHHV)
    B, NH, S, DHQK, DHHV = *matQ.shape, matV.shape[-1]
    NC = NUM_CHUNKS
    L = CHUNK_SIZE
    _dtype, _device = matQ.dtype, matQ.device

    assert is_power_of_2(L), "Chunk size must be a power of 2."

    if qk_scale is None:
        qk_scale = DHQK**-0.5

    siz_b_DHQK = get_head_dim_block_size(
        head_dim=DHQK, min_block_size=32 if _dtype == torch.float32 else 64
    )
    siz_b_DHHV = get_head_dim_block_size(
        head_dim=DHHV, min_block_size=32 if _dtype == torch.float32 else 64
    )

    num_b_DHQK = triton.cdiv(DHQK, siz_b_DHQK)

    num_stages = 1
    num_warps = 4 if siz_b_DHQK == 64 else 2

    matDeltaQ = torch.zeros((B, NH, S, DHQK), dtype=_dtype, device=_device)
    matDeltaK = torch.zeros((B, NH, S, DHQK), dtype=_dtype, device=_device)
    # each b_DHQK thread block computes the contribution of its siz_b_DHQK block of matDeltaC
    # we need to sum them up to get the final result (we do this outside the kernel)
    matDeltaV = torch.zeros((num_b_DHQK, B, NH, S, DHHV), dtype=_dtype, device=_device)

    grid = (num_b_DHQK, NC, B * NH)
    # print(f"grid: {grid}")
    mlstm_chunkwise__parallel_bw_dQKV_kernel[grid](
        matQ=matQ,
        matK=matK,
        matV=matV,
        vecB=vecB,
        vecI=vecI,
        vecM_combine=vecM_combine,
        scaM_inter=scaM_inter,
        matC_states=matC_states,
        matDeltaH=matDeltaH,
        vecN_out=vecN_out,
        matDeltaC_states=matDeltaC_states,
        matDeltaQ=matDeltaQ,
        matDeltaK=matDeltaK,
        matDeltaV=matDeltaV,
        qk_scale=qk_scale,
        str_matQK_B_NH=matQ.stride(1),
        str_matQK_S=matQ.stride(2),
        str_matQK_DHQK=matQ.stride(3),
        str_matDV_num_b_DHQK=matDeltaV.stride(0),
        str_matHV_B_NH=matV.stride(1),
        str_matHV_S=matV.stride(2),
        str_matHV_DHHV=matV.stride(3),
        str_vecBI_B_NH=vecI.stride(1),
        str_vecBI_NC=vecI.stride(2),
        str_vecBI_L=vecI.stride(3),
        str_vecM_combine_B_NH=vecM_combine.stride(1),
        str_vecM_combine_S=vecM_combine.stride(2),
        str_scaM_inter_B_NH=scaM_inter.stride(1),
        str_scaM_inter_NC=scaM_inter.stride(2),
        str_matC_states_B_NH=matC_states.stride(1),
        str_matC_states_NCDHQK=matC_states.stride(2),
        str_matC_states_DHHV=matC_states.stride(3),
        str_vecN_out_B_NH=vecN_out.stride(1),
        str_vecN_out_S=vecN_out.stride(2),
        str_matDeltaC_states_B_NH=matDeltaC_states.stride(1),
        str_matDeltaC_states_NCDHQK=matDeltaC_states.stride(2),
        str_matDeltaC_states_DHHV=matDeltaC_states.stride(3),
        B=B,
        NH=NH,
        S=S,
        DHQK=DHQK,
        DHHV=DHHV,
        NC=NC,
        L=L,
        siz_b_DHQK=siz_b_DHQK,
        siz_b_DHHV=siz_b_DHHV,
        DTYPE=torch2triton_dtype(_dtype),
        EPS=EPS,
        num_stages=num_stages,
        num_warps=num_warps,
    )
    # sum up the contributions of each siz_b_DHQK block
    matDeltaV = matDeltaV.sum(dim=0)  # (B, NH, S, DHHV)

    return matDeltaQ, matDeltaK, matDeltaV
