# SPDX-License-Identifier: Apache-2.0
"""Compare the with and without prefix caching."""

from typing import Optional

import pytest
import torch

from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
from vllm.v1.core.kv_cache_manager import KVCacheManager, Request
from vllm.v1.core.kv_cache_utils import (BlockHashType, KVCacheBlock,
                                         hash_block_tokens)
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                        KVCacheGroupSpec)


def make_request(request_id,
                 prompt_token_ids,
                 mm_positions=None,
                 mm_hashes=None,
                 prompt_logprobs: Optional[int] = None):
    if mm_positions is None:
        multi_modal_inputs = None
    else:
        multi_modal_inputs = [MultiModalKwargs({})] * len(mm_positions)

    return Request(
        request_id=request_id,
        prompt=None,
        prompt_token_ids=prompt_token_ids,
        multi_modal_inputs=multi_modal_inputs,
        multi_modal_hashes=mm_hashes,
        multi_modal_placeholders=mm_positions,
        sampling_params=SamplingParams(max_tokens=17,
                                       prompt_logprobs=prompt_logprobs),
        eos_token_id=100,
        arrival_time=0,
        lora_request=None,
    )


def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
    return KVCacheConfig(
        num_blocks=num_blocks,
        tensors={},
        kv_cache_groups=[
            KVCacheGroupSpec(['layer'],
                             FullAttentionSpec(block_size, 1, 1, torch.float32,
                                               False))
        ],
    )


@pytest.mark.parametrize("hash_algo", ["sha256", "hash"])
def test_prefill(hash_algo):
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        caching_hash_algo=hash_algo,
        num_preallocate_tokens=16,
    )

    # choose the hash function according to the parameter
    hash_fn = sha256 if hash_algo == "sha256" else hash

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]

    # Check full block metadata
    parent_block_hash = None
    for block_id in (1, 2, 3):
        block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
                                       block_tokens)
        assert manager.block_pool.blocks[block_id].block_hash == block_hash
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash.hash_value

    # Check partial/preallocated block metadata
    for block_id in (4, 5):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [6, 7]
    for block in computed_blocks:
        assert block.ref_cnt == 2

    # At this point, we should have 3 free blocks left.
    assert manager.block_pool.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (8, 9, 10)]
    # [unique_req0 (5, 4)]
    # [unique_req1 (7, 6)]
    # [common (3, 2, 1)]
    assert [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]

    # Cache hit in the common prefix when the original block is already free.
    # Incomplete 1 block (6 tokens)
    unique_token_ids = [3] * 6
    req2 = make_request("2", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req2, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [8, 9]

    # Although we only have 5 free blocks, we have 8 blocks in
    # the free block queue due to lazy removal.
    assert manager.block_pool.free_block_queue.num_free_blocks == 5
    assert all([
        b.ref_cnt == 0
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ])
    assert len([
        b for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ]) == 5

    manager.free(req2)

    # Cache miss and eviction.
    req3 = make_request("3", [99] * (16 * 9))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks)
    # This block ID order also checks the eviction order.
    assert [b.block_id for b in blocks] == [10, 5, 4, 7, 6, 9, 8, 3, 2, 1]
    assert manager.block_pool.free_block_queue.num_free_blocks == 0
    assert manager.block_pool.free_block_queue.free_list_head is None
    assert manager.block_pool.free_block_queue.free_list_tail is None


def test_prefill_plp():
    '''Test prefill with APC and some prompt logprobs (plp) requests.

    1. Schedule plp request and validate APC block allocation
    2. Schedule non-plp request and validate blocks
    3. Schedule plp request; no hit should occur; validate blocks
    '''
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=16,
    )
    # the default hash function is hash
    hash_fn = hash

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Request #0 is a prompt logprobs request
    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    all_token_ids = common_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids, prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert len(manager.req_to_block_hashes[req0.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
    req0_block_hashes = [b.block_hash for b in blocks]

    # Check full block metadata
    parent_block_hash = None
    for block_id in (1, 2, 3):
        block_tokens = tuple(all_token_ids[(block_id - 1) * 16:block_id * 16])
        block_hash = hash_block_tokens(hash_fn, parent_block_hash,
                                       block_tokens)
        assert manager.block_pool.blocks[block_id].block_hash == block_hash
        assert manager.block_pool.blocks[block_id].ref_cnt == 1
        parent_block_hash = block_hash.hash_value

    # Check partial/preallocated block metadata
    for block_id in (4, 5):
        assert manager.block_pool.blocks[block_id].block_hash is None
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    # Request #1 is a non-prompt-logprobs request:
    # Cache hit in the common prefix when the original block is still in use.
    # Incomplete 1 block (5 tokens)
    unique_token_ids = [3] * 5
    req1 = make_request("1", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert [b.block_id for b in computed_blocks] == [1, 2, 3]
    assert num_computed_tokens == 3 * 16
    num_new_tokens = 53 - 3 * 16
    blocks = manager.allocate_slots(req1, num_new_tokens, computed_blocks)
    assert [b.block_id for b in blocks] == [6, 7]
    for block in computed_blocks:
        assert block.ref_cnt == 2

    # At this point, we should have 3 free blocks left.
    assert manager.block_pool.free_block_queue.num_free_blocks == 3

    manager.free(req0)
    manager.free(req1)

    # All blocks should be available.
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    # The order should be
    # [unallocated (8, 9, 10)]
    # [unique_req0 (5, 4)]
    # [unique_req1 (7, 6)]
    # [common (3, 2, 1)]
    assert [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [8, 9, 10, 5, 4, 7, 6, 3, 2, 1]

    # Request #2 is a prompt-logprobs request:
    # NO cache hit in the common prefix; duplicates request #0 cached blocks
    unique_token_ids = [3] * 6
    req2 = make_request("2",
                        common_token_ids + unique_token_ids,
                        prompt_logprobs=5)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(manager.req_to_block_hashes[req2.request_id]) == 3
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req2, 55, computed_blocks)
    block_ids = [b.block_id for b in blocks]
    # Duplicate cached blocks have different ids but same hashes vs request #0
    assert [b.block_hash for b in blocks] == req0_block_hashes
    assert block_ids != [1, 2, 3, 4, 5]

    # Request #2 block hashes are valid since request #0 hashes are.
    # Check block reference counts.
    for block_id in block_ids:
        assert manager.block_pool.blocks[block_id].ref_cnt == 1

    manager.free(req2)


def test_decode():
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Complete 3 blocks (48 tokens)
    common_token_ids = [i for i in range(3) for _ in range(16)]

    # Fully cache miss
    # Incomplete 1 block (7 tokens)
    unique_token_ids = [3] * 7
    req0 = make_request("0", common_token_ids + unique_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 55, computed_blocks)
    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]

    # Append slots without allocating a new block.
    req0.num_computed_tokens = 55
    for _ in range(4):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(req0, 4)
    assert new_blocks is not None and len(new_blocks) == 0
    assert manager.req_to_blocks[req0.request_id][-2].block_hash is None

    # Append slots without allocating a new block, but start using the
    # preallocated block.
    req0.num_computed_tokens = 59
    # 6 tokens to fill the previous block, and 10 tokens to fill
    # the preallocated block.
    for _ in range(5 + 10):
        req0.append_output_token_ids(7)
    new_blocks = manager.allocate_slots(req0, 15)
    assert new_blocks is not None and len(new_blocks) == 0
    assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None

    # Append slots with allocating a new block.
    req0.num_computed_tokens = 74
    # 6 tokens to fill the previous block, and 10 tokens to fill
    # the preallocated block.
    for _ in range(6 + 11):
        req0.append_output_token_ids(12)
    new_blocks = manager.allocate_slots(req0, 17)
    # Plus one preallocated block.
    assert new_blocks is not None and len(new_blocks) == 2


def test_evict():
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    last_token_id = 5 * 16 + 7
    req0 = make_request("0", list(range(last_token_id)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, 5 * 16 + 7, computed_blocks)
    assert len(blocks) == 7  # 5 full + 1 partial + 1 preallocated

    # 3 blocks.
    req1 = make_request("1", list(range(last_token_id,
                                        last_token_id + 3 * 16)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req1, 3 * 16, computed_blocks)
    assert len(blocks) == 3  # 3 full blocks
    last_token_id += 3 * 16

    assert manager.block_pool.free_block_queue.num_free_blocks == 0

    manager.free(req0)
    manager.free(req1)
    assert manager.block_pool.free_block_queue.num_free_blocks == 10
    assert [
        b.block_id
        for b in manager.block_pool.free_block_queue.get_all_free_blocks()
    ] == [7, 6, 5, 4, 3, 2, 1, 10, 9, 8]

    # Touch the first 2 blocks.
    req2 = make_request("2", list(range(2 * 16 + 3)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert [b.block_id for b in computed_blocks] == [1, 2]
    assert num_computed_tokens == 2 * 16
    blocks = manager.allocate_slots(req2, 3, computed_blocks)
    assert [b.block_id for b in blocks] == [7, 6]
    assert manager.block_pool.free_block_queue.num_free_blocks == 6


def test_hash_block_correct_reuse():
    """
    This tests when a previously cached block is reused as a new block,
    its hash metadata should be correctly reset.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(16, 2),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Allocate 1 block and cache it.
    num_tokens = block_size * 1
    req = make_request("0", list(range(num_tokens)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req, num_tokens, computed_blocks)
    assert len(blocks) == 1

    # Deallocate the block.
    manager.free(req)

    # Allocate a new block that's not full, make sure hash info on the
    # block is cleared.
    req = make_request("1", list(range(num_tokens - 1)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks)
    assert len(blocks) == 1

    assert manager.block_pool.blocks[blocks[0].block_id].block_hash is None


def test_computed_blocks_not_evicted():
    """
    Test that the computed blocks are not evicted when getting new blocks
    for a request if there are any other free blocks.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 3),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    # Allocate a block and cache it.
    num_tokens = block_size * 1
    req0 = make_request("0", list(range(num_tokens)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req0, num_tokens, computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 1

    # Allocate another block.
    req1 = make_request("1", list(range(num_tokens, num_tokens * 2)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req1, num_tokens, computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 2

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    # Now if we have a cache hit on the first block, we should evict the second
    # cached block rather than the first one.
    req2 = make_request("2", list(range(num_tokens * 2)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert len(computed_blocks) == 1
    assert computed_blocks[0].block_id == 1
    assert num_computed_tokens == block_size

    blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens,
                                    computed_blocks)
    assert len(blocks) == 1
    assert blocks[0].block_id == 2


def test_basic_prefix_caching_disabled():
    """
    This tests that the prefix caching is disabled.
    """
    block_size = 4
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 5),
        max_model_len=8192,
        enable_caching=False,
        num_preallocate_tokens=0,
    )

    req1 = make_request("1", list(range(10)))  # 2 blocks and some more

    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req1, 10, computed_blocks)
    assert len(blocks) == 3

    # Free the blocks.
    manager.free(req1)

    # No caching.
    req2 = make_request("2", list(range(16)))  # shared prefix
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req2, 16, computed_blocks)
    assert len(blocks) == 4

    # New requests should not have any blocks.
    req3 = make_request("3", list(range(4)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert not computed_blocks
    assert num_computed_tokens == 0
    blocks = manager.allocate_slots(req3, 4, computed_blocks)
    assert not blocks


@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8)))
@pytest.mark.parametrize("block_size", [4])
def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int):
    """
    This tests that the preallocated blocks are correctly added.
    """
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=num_preallocate_tokens,
    )
    num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size)

    req = make_request("0", list(range(block_size * 30)))
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req)
    assert not computed_blocks
    assert num_computed_tokens == 0
    # Just ask for 1 block.
    blocks = manager.allocate_slots(req, block_size, computed_blocks)
    req.num_computed_tokens = block_size
    assert len(blocks) == 1 + num_preallocated_blocks

    # Assume all computed, only when num_preallocate_tokens > 0, we need to
    # consume the previously preallocated blocks.
    if num_preallocated_blocks > 0:
        manager.allocate_slots(req, block_size * (len(blocks) - 1))
        req.num_computed_tokens = block_size * len(blocks)

    # Append 1 block.
    blocks = manager.allocate_slots(req, block_size)
    assert len(blocks) == 1 + num_preallocated_blocks


@pytest.mark.parametrize("hash_fn", [sha256, hash])
def test_cache_blocks(hash_fn):
    """
    This is a unit test that tests the correctness of the _cache_full_blocks
    function of KVCacheManager.
    """
    block_size = 4
    block_pool = BlockPool(
        num_gpu_blocks=5,
        enable_caching=True,
    )
    # Req:
    #  Block 0: [0, 1, 2, 3]
    #  Block 1: [4, 5, 6, 7]
    #  Block 2: [8, 9, 10, 11]
    #  Block 3: [12, 13]
    req = make_request("0", list(range(14)))

    # Test that blocks are cached correctly for 2 full blocks from the start.
    blocks = [KVCacheBlock(block_id=i) for i in range(2)]
    block_hashes: list[BlockHashType] = []

    block_pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        block_hashes=block_hashes,
        num_cached_blocks=0,
        num_full_blocks=2,
        block_size=block_size,
        hash_fn=hash_fn,
    )

    assert len(block_pool.cached_block_hash_to_block) == 2
    assert all([block.block_hash is not None for block in blocks])

    # Test that blocks that don't start from the beginning are cached correctly.
    blocks += [KVCacheBlock(block_id=2)]
    block_pool.cache_full_blocks(
        request=req,
        blocks=blocks,
        block_hashes=block_hashes,
        num_cached_blocks=2,
        num_full_blocks=3,
        block_size=block_size,
        hash_fn=hash_fn,
    )
    assert len(block_pool.cached_block_hash_to_block) == 3
    assert blocks[0].block_hash is not None


def test_mm_prefix_caching():
    """
    This tests that the multi-modal prefix caching is correct.
    """
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=16,
    )

    # Common prompt tokens (T is text tokens and P is image placeholder tokens)
    # [T,...,T, P0,...,P0], [P0,...,P0,T,...,T,P1,...,P1], [P1,...,P1]
    common_token_ids = list(range(10)) + [-1] * 6
    common_token_ids += [-1] * 4 + list(range(10, 20)) + [-1] * 2
    common_token_ids += [-1] * 16

    common_mm_positions = [
        PlaceholderRange(offset=11, length=10),
        PlaceholderRange(offset=30, length=18),
    ]
    common_mm_hashes = ["aaa", "bbb"]

    # A unique image plus some text tokens.
    unique_token_ids = [-1] * 7 + [100] * 4
    all_token_ids = common_token_ids + unique_token_ids
    mm_positions = common_mm_positions + [
        PlaceholderRange(offset=48, length=7)
    ]
    mm_hashes = common_mm_hashes + ["ccc"]
    req0 = make_request("0",
                        all_token_ids,
                        mm_positions=mm_positions,
                        mm_hashes=mm_hashes)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)

    # Completed block should have hashes with extra keys.
    assert not computed_blocks
    assert num_computed_tokens == 0
    block_hashes = manager.req_to_block_hashes[req0.request_id]
    assert len(block_hashes) == 3
    assert block_hashes[0].extra_keys == ("aaa", )
    assert block_hashes[1].extra_keys == ("aaa", "bbb")
    assert block_hashes[2].extra_keys == ("bbb", )

    blocks = manager.allocate_slots(req0, 59, computed_blocks)
    assert [b.block_id for b in blocks] == [1, 2, 3, 4, 5]
    req0.num_computed_tokens = 59

    # Append slots without allocating a new block.
    for _ in range(5):
        req0.append_output_token_ids(8)
    new_blocks = manager.allocate_slots(req0, 5)
    assert new_blocks is not None and len(new_blocks) == 0

    # The just completed block should have hashes with extra keys.
    assert len(block_hashes) == 4
    assert block_hashes[3].extra_keys == ("ccc", )

    # Cache hit.
    unique_token_ids = [-1] * 7 + [200] * 5
    all_token_ids = common_token_ids + unique_token_ids
    mm_positions = common_mm_positions + [
        PlaceholderRange(offset=48, length=7)
    ]
    mm_hashes = common_mm_hashes + ["ccc"]
    req1 = make_request("1",
                        all_token_ids,
                        mm_positions=mm_positions,
                        mm_hashes=mm_hashes)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert len(computed_blocks) == 3
    assert num_computed_tokens == 3 * 16


def test_prefill_not_enough_free_blocks_with_computed_blocks():
    """
    This is a unit test that tests the correctness of the allocate_slots
    when there is not enough free blocks. Specifically, when a request
    has computed blocks but cannot be allocated due to not enough free blocks,
    the computed blocks should not be touched.
    """
    block_size = 16
    manager = KVCacheManager(
        make_kv_cache_config(block_size, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=0,
    )
    # Complete 3 blocks (48 tokens)
    # | Common-0 | Common-1 | Common-2 | ... |
    common_token_ids = [i for i in range(3) for _ in range(16)]
    req0 = make_request("0", common_token_ids)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req0)
    assert not computed_blocks
    assert num_computed_tokens == 0
    manager.allocate_slots(req0, 48, computed_blocks)
    block_part0 = manager.req_to_blocks[req0.request_id]

    # | Common-0 | Common-1 | Common-2 | Req1-3 | Req1-4 | Req1-5 | ... |
    req1 = make_request("1", common_token_ids * 2)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req1)
    assert computed_blocks == block_part0
    assert num_computed_tokens == 3 * 16
    manager.allocate_slots(req1, 48, computed_blocks)
    block_part1 = manager.req_to_blocks[req1.request_id]
    # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
    # | Req1-5(F)| ... |
    manager.free(req1)
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    assert {block.ref_cnt for block in block_part1[3:]} == {0}

    # | Common-0 | Common-1 | Common-2 | Req1-3 (F) | Req1-4 (F) |
    # | Req1-5(F)| Req2-0   | Req2-1   | ... |
    req2 = make_request("2", [7] * block_size * 2)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req2)
    assert not computed_blocks
    assert num_computed_tokens == 0
    manager.allocate_slots(req2, block_size * 2, computed_blocks)

    # Req3 is Req2 + 3 new blocks, so the first 6 blocks are computed,
    # but it cannot be allocated due to insufficient free blocks (2).
    # In this case, the ref_cnt of the computed blocks should not be changed.
    assert manager.block_pool.free_block_queue.num_free_blocks == 5
    req3 = make_request("3", common_token_ids * 3)
    computed_blocks, num_computed_tokens = manager.get_computed_blocks(req3)
    assert computed_blocks == block_part1
    assert num_computed_tokens == 6 * 16
    # Req3 cannot be allocated.
    assert manager.allocate_slots(req3, 48, computed_blocks) is None
    # Block 0-2 are used by Req 1.
    assert {block.ref_cnt for block in block_part1[:3]} == {1}
    # Block 3-5 are free.
    assert {block.ref_cnt for block in block_part1[3:]} == {0}


def test_reset_prefix_cache():
    manager = KVCacheManager(
        make_kv_cache_config(16, 11),
        max_model_len=8192,
        enable_caching=True,
        num_preallocate_tokens=0,
    )

    full_block_token_ids = [i for i in range(3) for _ in range(16)]
    unique_token_ids = [3] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req0 = make_request("0", all_token_ids)
    blocks = manager.allocate_slots(req0, 55)
    assert [b.block_id for b in blocks] == [1, 2, 3, 4]

    unique_token_ids = [4] * 7
    all_token_ids = full_block_token_ids + unique_token_ids
    req1 = make_request("1", all_token_ids)
    computed_blocks, _ = manager.get_computed_blocks(req1)
    assert len(manager.req_to_block_hashes[req1.request_id]) == 3
    assert len(computed_blocks) == 3
    blocks = manager.allocate_slots(req1, 7, computed_blocks)
    assert [b.block_id for b in blocks] == [5]

    # Failed to reset prefix cache because some blocks are not freed yet.
    assert not manager.reset_prefix_cache()
    assert manager.block_pool.cached_block_hash_to_block

    # Free the blocks.
    manager.free(req0)
    manager.free(req1)

    assert manager.reset_prefix_cache()
    assert not manager.block_pool.cached_block_hash_to_block
    assert all([blk.block_hash is None for blk in manager.block_pool.blocks])
