"""Compare the with and without prefix caching.

Run `pytest tests/prefix_caching/test_prefix_caching.py`.
"""
from typing import List

import pytest

from tests.kernels.utils import override_backend_env_variable
from vllm.block import PhysicalTokenBlock
from vllm.core.block_manager_v1 import CachedBlockAllocator
from vllm.utils import Device

from ..models.utils import check_outputs_equal

# Model checkpoints exercised by the end-to-end prefix-caching test.
MODELS = ["facebook/opt-125m"]


@pytest.mark.parametrize("block_size", [16])
@pytest.mark.parametrize("num_blocks", [16])
def test_block_allocator(
    block_size: int,
    num_blocks: int,
):
    """Exercise hash-based block sharing and the prefix-cache hit-rate metric
    of ``CachedBlockAllocator``."""
    shared_hash = 1
    allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Two allocations under one hash must yield a single shared
    # PhysicalTokenBlock whose ref_count reflects both owners.
    block_a = allocator.allocate(shared_hash, 0)
    block_b = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_b.ref_count == 2

    # Query tally so far: 1 hit out of 2 queries.
    assert allocator.get_prefix_cache_hit_rate() == 0.5

    # Releasing one owner must drop the shared ref_count by one.
    allocator.free(block_a)
    assert block_b.ref_count == 1

    # Release the remaining owner as well.
    allocator.free(block_b)

    # Even though its ref_count reached 0, allocating the same hash again
    # must hand back the very same block.
    block_a = allocator.allocate(shared_hash, 0)
    assert block_a == block_b
    assert block_a.block_hash == shared_hash

    # One more allocation makes the hit rate exactly 3/4.
    allocator.allocate(shared_hash, 0)
    assert allocator.get_prefix_cache_hit_rate() == 0.75


@pytest.mark.parametrize("num_blocks", [16])
def test_eviction(num_blocks: int):
    """Check that freed blocks are reused least-recently-used first, and that
    a block re-allocated by its hash is taken out of the free list."""
    block_size = 16
    block_allocator = CachedBlockAllocator(Device.CPU, block_size, num_blocks)

    # Fill the allocator, using i as the block_hash so blocks[i] holds hash i.
    blocks: List[PhysicalTokenBlock] = [
        block_allocator.allocate(i, 0) for i in range(num_blocks)
    ]

    # Free all blocks, blocks[0] first.
    for block in blocks:
        block_allocator.free(block)

    # Hashes 0..num_blocks-1 are already in use, so num_blocks and
    # num_blocks + 1 are guaranteed to be fresh. (Deriving them from
    # block_size only works while block_size happens to be >= num_blocks.)

    # Allocate a new hash and confirm it reuses the first block freed,
    # i.e. the least recently used block.
    new_block_hash = num_blocks
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert new_block == blocks[0]
    assert new_block.block_hash == new_block_hash

    # Reallocate the second block by its hash to remove it from the free list.
    realloc_block_hash = 1
    realloc_block = block_allocator.allocate(realloc_block_hash, 0)
    assert realloc_block == blocks[realloc_block_hash]
    assert realloc_block.block_hash == realloc_block_hash

    # Allocate another new hash and confirm it is not realloc_block, since
    # realloc_block should no longer be in the free list.
    new_block_hash = num_blocks + 1
    new_block = block_allocator.allocate(new_block_hash, 0)
    assert realloc_block != new_block
    assert new_block.block_hash == new_block_hash
    assert new_block.block_number == 2


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"])
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("cached_position", [0, 1])
@pytest.mark.parametrize("use_v2_block_manager", [False, True])
def test_mixed_requests(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    backend: str,
    dtype: str,
    max_tokens: int,
    cached_position: int,
    use_v2_block_manager: bool,
    monkeypatch,
) -> None:
    """
    Test a batch where some sequences hit the prefix cache and others don't.

    The prompt at ``cached_position`` is run once beforehand so its prefix is
    cached; ``cached_position`` therefore controls where the cache-hitting
    sequence sits within the batch of prefills. Greedy vLLM outputs for the
    full batch must match the HuggingFace reference outputs.
    """
    override_backend_env_variable(monkeypatch, backend)

    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

    cached_prompt = example_prompts[cached_position]
    with vllm_runner(
            model,
            dtype=dtype,
            enable_prefix_caching=True,
            use_v2_block_manager=use_v2_block_manager,
    ) as vllm_model:
        # Populate the prefix cache by running only the prompt at
        # cached_position; its outputs are not needed.
        vllm_model.generate_greedy([cached_prompt], max_tokens)

        # Run all the prompts; only the previously run prompt should hit
        # the prefix cache.
        vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

    check_outputs_equal(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )
