import torch
import triton
import triton.language as tl
@triton.jit
def unpack_dequant_kernel(
    value,  # packed codes, addressed as dense [batch, head_num, seq, QUANT_HEAD_SIZE] (int8 per the caller)
    indices,  # token positions to gather, addressed as dense [batch, token_num]
    value_scale,  # per-group scale, addressed as dense [batch, head_num, seq, QUANT_PARAM_SIZE]
    value_mn,  # per-group offset ("min"), addressed exactly like value_scale
    dequant_value,  # output, addressed as dense [batch, head_num, token_num, head_dim]
    BSZ: tl.constexpr,
    HEAD_NUM: tl.constexpr,
    SEQ_LEN: tl.constexpr,
    TOKEN_NUM: tl.constexpr,
    HEAD_DIM: tl.constexpr,
    QUANT_HEAD_SIZE: tl.constexpr,
    QUANT_PARAM_SIZE: tl.constexpr,
    SEQ_BLOCK: tl.constexpr,
    SEQ_BLOCK_NUM: tl.constexpr,
):
    """
        Gather selected tokens, unpack their 2-bit codes and dequantize them.

        Every packed byte stores four 2-bit elements, most significant pair
        first.  Each program handles one (batch, head, block of SEQ_BLOCK
        gathered tokens) tile and writes `code * scale + mn` for every element.

        All tensors are addressed with dense row-major strides derived from the
        constexpr sizes; the kernel does not read tensor strides.  BSZ is
        accepted for interface compatibility but is not read.

        Rows whose gathered index is >= SEQ_LEN load code/scale/mn as zero and
        therefore store 0; rows at or beyond TOKEN_NUM are not stored at all.
    """
    # Widen the program id so the batch/head base offsets below are computed
    # in 64 bits and cannot wrap for very large caches.
    pid = tl.program_id(axis=0).to(tl.int64)
    bsz_idx = pid // (HEAD_NUM * SEQ_BLOCK_NUM)
    head_num_idx = pid // SEQ_BLOCK_NUM % HEAD_NUM
    seq_block_idx = pid % SEQ_BLOCK_NUM

    packed_cols = tl.arange(0, QUANT_HEAD_SIZE)
    out_rows = seq_block_idx * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)
    rows_in_range = out_rows < TOKEN_NUM

    # Out-of-range rows get the sentinel SEQ_LEN so that seq_mask disables
    # every dependent load for them.
    token_indices = tl.load(indices + bsz_idx * TOKEN_NUM + out_rows, mask=rows_in_range, other=SEQ_LEN)
    seq_mask = (token_indices < SEQ_LEN)[:, None]
    token_rows = token_indices.to(tl.int64)[:, None]

    value_offset = (
        bsz_idx * (HEAD_NUM * QUANT_HEAD_SIZE * SEQ_LEN)
        + head_num_idx * (SEQ_LEN * QUANT_HEAD_SIZE)
        + token_rows * QUANT_HEAD_SIZE
        + packed_cols[None, :]
    )
    value_block = tl.load(value + value_offset, mask=seq_mask, other=0)

    # Scale/min rows hold QUANT_PARAM_SIZE entries per token, so the token
    # stride must be QUANT_PARAM_SIZE (not QUANT_HEAD_SIZE).  Each parameter
    # covers QUANT_HEAD_SIZE // QUANT_PARAM_SIZE consecutive packed bytes.
    quant_param_offset = (
        bsz_idx * (HEAD_NUM * QUANT_PARAM_SIZE * SEQ_LEN)
        + head_num_idx * (SEQ_LEN * QUANT_PARAM_SIZE)
        + token_rows * QUANT_PARAM_SIZE
        + packed_cols[None, :] // (QUANT_HEAD_SIZE // QUANT_PARAM_SIZE)
    )
    scale = tl.load(value_scale + quant_param_offset, mask=seq_mask, other=0.0)
    mn = tl.load(value_mn + quant_param_offset, mask=seq_mask, other=0.0)

    # Packed byte j of a row expands to output elements 4*j .. 4*j + 3.
    dequant_value_offset = (
        bsz_idx * (HEAD_NUM * HEAD_DIM * TOKEN_NUM)
        + head_num_idx * (TOKEN_NUM * HEAD_DIM)
        + out_rows[:, None] * HEAD_DIM
        + packed_cols[None, :] * 4
    )
    result_mask = rows_in_range[:, None]

    bit_mask = 0xFF
    for i in range(4):
        # unpack: keep the bits not yet consumed, then shift the current
        # 2-bit field (bits 7-6, 5-4, 3-2, 1-0 in turn) down to the bottom
        unpack_value = (value_block & bit_mask) >> (2 * (3 - i))
        unpack_value_bf16 = unpack_value.to(dtype=tl.bfloat16)
        result = unpack_value_bf16 * scale + mn
        # store element i of every packed byte
        tl.store(dequant_value + dequant_value_offset, result, mask=result_mask)
        # shift mask: drop the field that was just consumed
        bit_mask = bit_mask >> 2
        # update offset: move to the next element inside each group of four
        dequant_value_offset = dequant_value_offset + 1
def unpack_dequant(
    value,
    indices,
    value_scale,
    value_mn,
    head_dim,
    dequant_value,
    offset
):
    """Gather, unpack and dequantize selected tokens into ``dequant_value``.

    Args:
        value: Packed 2-bit codes, 4-D with dims (batch, head, seq, packed);
            reinterpreted as int8 so the last dim must cover ``head_dim // 4``
            bytes.
        indices: 2-D (batch, token_num) token positions to gather along the
            seq dim of ``value``.
        value_scale: Per-group scales; its last dim is the number of
            quantization groups per token.
        value_mn: Per-group offsets, addressed by the kernel exactly like
            ``value_scale``.
        head_dim: Number of dequantized elements per token and head.
        dequant_value: Pre-allocated output buffer; rows
            ``offset : offset + token_num`` of dim 2 are overwritten.
        offset: First row of dim 2 in ``dequant_value`` to write.

    Returns:
        The ``narrow``-ed view of ``dequant_value`` that was written.

    Raises:
        ValueError: If the packed width of ``value`` does not match
            ``head_dim // 4``.
    """
    bsz, head_num, seq = value.shape[0], value.shape[1], value.shape[2]
    token_num = indices.shape[1]

    # The kernel addresses every input with dense row-major strides, so make
    # that true here (no-op for tensors that are already contiguous).
    # reinterpret_cast to int8
    value_int8 = value.contiguous().view(dtype=torch.int8)
    indices = indices.contiguous()
    value_scale = value_scale.contiguous()
    value_mn = value_mn.contiguous()

    quant_head_size = value_int8.shape[-1]
    quant_param_size = value_scale.shape[-1]
    if quant_head_size * 4 != head_dim:
        raise ValueError(
            f"packed width {quant_head_size} (int8) does not match head_dim // 4 "
            f"for head_dim={head_dim}"
        )

    out_view = dequant_value.narrow(2, offset, token_num)
    if token_num == 0:
        # Nothing selected: avoid compiling/launching an empty grid.
        return out_view

    # The kernel writes a dense [bsz, head_num, token_num, head_dim] layout.
    # A narrowed view of a longer buffer is not dense, so let the kernel fill
    # a dense scratch tensor and copy it into the view afterwards.
    # (Passing the output strides to the kernel would avoid this copy.)
    if out_view.is_contiguous():
        kernel_out = out_view
    else:
        kernel_out = torch.empty_like(out_view, memory_format=torch.contiguous_format)

    # One program handles a tile of roughly 1024 packed bytes.
    SEQ_BLOCK = 1024 // quant_head_size
    SEQ_BLOCK_NUM = (token_num + SEQ_BLOCK - 1) // SEQ_BLOCK
    TOTAL_BLOCK = SEQ_BLOCK_NUM * head_num * bsz
    unpack_dequant_kernel[(TOTAL_BLOCK,)](
        value_int8,
        indices,
        value_scale,
        value_mn,
        kernel_out,
        bsz,
        head_num,
        seq,
        token_num,
        head_dim,
        quant_head_size,
        quant_param_size,
        SEQ_BLOCK,
        SEQ_BLOCK_NUM
    )
    if kernel_out is not out_view:
        out_view.copy_(kernel_out)
    return out_view