import triton


def _cdiv(numerator, denominator):
    """Return ``ceil(numerator / denominator)`` using exact integer arithmetic.

    Matches ``triton.cdiv`` for a positive ``denominator``. It is kept local so this
    pure host-side bookkeeping does not need the Triton runtime.
    """
    return -(-numerator // denominator)


def calc_dims(q_len, k_len, block_size_q, block_size_k, q_offset):
    """Compute the block layout of a query window inside a key sequence.

    The query tokens occupy positions ``[q_offset, q_offset + q_len)`` of a sequence
    that has ``k_len`` key tokens. Positions are grouped into blocks of
    ``block_size_q`` tokens (queries) and ``block_size_k`` tokens (keys).

    Args:
        q_len: Number of query tokens (must be >= 0).
        k_len: Number of key tokens; must be >= ``q_offset + q_len``.
        block_size_q: Tokens per query block (must be > 0).
        block_size_k: Tokens per key block (must be > 0).
        q_offset: Position of the first query token in the full sequence
            (must be >= 0).

    Returns:
        A tuple ``(q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks)``:

        - ``q_block_offset``: index of the query block containing ``q_offset``.
        - ``q_blocks``: number of query blocks from ``q_block_offset`` up to and
          including the block that holds the last query token.
        - ``q_start_padding``: tokens in the first query block before ``q_offset``.
        - ``q_end_padding``: tokens in the last query block after the last query token.
        - ``k_blocks``: number of key blocks covering ``k_len``. It is raised if
          needed so the key blocks also cover every position spanned by the
          query blocks.

    Raises:
        ValueError: if a block size is not positive, if ``q_len`` or ``q_offset``
            is negative, or if ``q_offset + q_len > k_len``.
    """
    # Validate explicitly rather than with `assert`: asserts are removed under
    # `python -O`, which would silently yield an inconsistent layout.
    if block_size_q <= 0 or block_size_k <= 0:
        raise ValueError("Block sizes must be positive")
    if q_len < 0 or q_offset < 0:
        raise ValueError("Query length and offset must be non-negative")
    if q_offset + q_len > k_len:
        raise ValueError("Query length + offset must be less than or equal to key length")

    # Block index of the first query token, and the padding tokens preceding it
    # inside that block.
    q_block_offset, q_start_padding = divmod(q_offset, block_size_q)

    # Query blocks from q_block_offset through the block holding the last query token.
    q_blocks = _cdiv(q_offset + q_len, block_size_q) - q_block_offset

    # Padding tokens after the last query token in the last query block.
    q_end_padding = q_blocks * block_size_q - (q_start_padding + q_len)

    # Key blocks must cover all k_len keys. Because of end padding, the query
    # blocks can extend past k_len, so the key blocks must reach at least as far.
    q_span_end = (q_block_offset + q_blocks) * block_size_q
    k_blocks = max(_cdiv(k_len, block_size_k), _cdiv(q_span_end, block_size_k))

    return q_block_offset, q_blocks, q_start_padding, q_end_padding, k_blocks
