from itertools import cycle

import pytest

from vllm import SamplingParams

from .conftest import get_token_ids_from_llm_generator


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Allow only 5 sequences of ~1024 tokens in worst case.
        "block_size": 16,
        "num_gpu_blocks_override": 5 * (64 + 1),
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "preemption_mode": "swap"
}, {
    "use_v2_block_manager": True,
    "preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator,
                                               test_llm_generator, batch_size):
    """Verify block manager v2 produces same outputs as block manager v1, even
    when there is preemption.

    This constructs two LLM, each with limited number of GPU blocks. The limit
    is decided such that as the sequences in the batch grow, sequences must be
    preempted and removed from cache.

    If the output token ids are equivalent, then we have confidence that the KV
    cache is not corrupted in the v2 block manager.

    NOTE: We want a significant number of generated tokens so that any incorrect
    KV mapping has time to build up error.
    """
    output_len = 1024
    temperature = 0.0

    # We want to ensure equality even with preemption.
    # We force the total block size to be 1 + cdiv(output_len, block_size)
    # so that only one sequence can fit at a time (once the sequences grow).

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Use a large block size to trigger more copy-on-writes.
        "block_size": 32,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "preemption_mode": "swap"
}, {
    "use_v2_block_manager": True,
    "preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator,
                                        test_llm_generator, batch_size):
    """Verify beam search equality with block manager v1 and v2.

    This requires copy-on-writes; if the v1 and v2 output is the same, then
    we have some confidence cow is working.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
        use_beam_search=True,
        best_of=2,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # Our prompts will generate 128 tokens; since the prompts themselves are
        # small, we don't need much KV space beyond 128.
        "max_model_len": 160,

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Lookahead scheduling only supported in v2 block manager.
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize(
    "per_test_common_llm_kwargs",
    [
        {
            "block_size": 16,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 8 = 128/block_size
            "num_gpu_blocks_override": 2 * (8 + 1),
        },
        {
            "block_size": 8,

            # Allow only 2 sequences of ~128 tokens in worst case.
            # Note 16 = 128/block_size
            "num_gpu_blocks_override": 2 * (16 + 2),
        }
    ])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "num_lookahead_slots": 0,
}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            # We run one test with block_size < lookahead_slots, one test with
            # block_size > lookahead_slots
            "num_lookahead_slots": 10,
            "preemption_mode": "swap",
        },
        {
            "num_lookahead_slots": 10,
            "preemption_mode": "recompute",
        }
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_lookahead_greedy_equality_with_preemption(baseline_llm_generator,
                                                   test_llm_generator,
                                                   batch_size):
    """Verify vLLM produces the same output with greedy sampling, when lookahead
    scheduling is used vs. not.

    Lookahead scheduling is not expected to modify the output, as it simply
    allocates empty slots ahead of the known token ids in a sliding fashion.

    This test constrains the total number of blocks to force preemption. It also
    varies the block size so that the lookahead size is less than and greater
    than the block size.
    """
    output_len = 128
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids without lookahead scheduling')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with lookahead scheduling')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [
        {
            # Use a small model for a fast test.
            "model": "facebook/opt-125m",

            # skip cuda graph creation for fast test.
            "enforce_eager": True,
            "enable_chunked_prefill": True,
        },
    ])
@pytest.mark.parametrize("per_test_common_llm_kwargs",
                         [{
                             "block_size": 8,
                             "max_num_batched_tokens": 2,
                             "max_num_seqs": 2,
                         }, {
                             "block_size": 8,
                             "max_num_batched_tokens": 3,
                             "max_num_seqs": 2,
                         }, {
                             "block_size": 8,
                             "max_num_batched_tokens": 256,
                             "max_num_seqs": 10,
                         }])
@pytest.mark.parametrize("baseline_llm_kwargs", [
    {
        "use_v2_block_manager": False,
    },
])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 0,
    },
    {
        "use_v2_block_manager": True,
        "num_lookahead_slots": 5,
    },
])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_chunked_prefill_block_manager_v2(baseline_llm_generator,
                                          test_llm_generator, batch_size):
    """Verify that chunked prefill works with BlockManagerV2, with and without
    lookahead scheduling.
    """
    output_len = 32
    temperature = 0.0

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        ("1 + " * 50) + " 1 = ",  # Longer prompt.
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids with BlockManagerV1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with BlockManagerV2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Allow only 5 sequences of ~1024 tokens in worst case.
        "block_size": 16,
        "num_gpu_blocks_override": 5 * (64 + 1),

        # Enable prefill cache
        "enable_prefix_caching": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "use_v2_block_manager": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "use_v2_block_manager": True,
    "preemption_mode": "swap"
}, {
    "use_v2_block_manager": True,
    "preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_v1_v2_greedy_equality_prefix_caching_enabled_with_preemption(
        baseline_llm_generator, test_llm_generator, batch_size):
    """Verify block manager v2 produces same outputs as block manager v1, even
    when there is preemption.

    This constructs two LLM, each with limited number of GPU blocks. The limit
    is decided such that as the sequences in the batch grow, sequences must be
    preempted and removed from cache.

    If the output token ids are equivalent, then we have confidence that the KV
    cache is not corrupted in the v2 block manager.

    NOTE: We want a significant number of generated tokens so that any incorrect
    KV mapping has time to build up error.
    """
    output_len = 1024
    temperature = 0.0

    # We want to ensure equality even with preemption.
    # We force the total block size to be 1 + cdiv(output_len, block_size)
    # so that only one sequence can fit at a time (once the sequences grow).

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids from block manager v1')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids from block manager v2')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # Allow only 5 sequences of ~1024 tokens in worst case.
        "block_size": 16,
        "num_gpu_blocks_override": 5 * (64 + 1),

        # Test APC in v2 block
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "enable_prefix_caching": True,
    "preemption_mode": "swap"
}, {
    "enable_prefix_caching": True,
    "preemption_mode": "recompute"
}])
@pytest.mark.parametrize("batch_size", [10])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_with_preemption(baseline_llm_generator,
                                             test_llm_generator, batch_size):
    """Verify block manager v2 with auto prefix caching enabled produces same
    outputs as auto prefix caching disabled, even when there is preemption.

    This constructs two LLM, each with limited number of GPU blocks. The limit
    is decided such that as the sequences in the batch grow, sequences must be
    preempted and removed from cache.

    If the output token ids are equivalent, then we have confidence that auto
    prefix caching itself at least don't cause result error.
    """
    output_len = 1024
    temperature = 0.0

    # We want to ensure equality even with preemption.
    # We force the total block size to be 1 + cdiv(output_len, block_size)
    # so that only one sequence can fit at a time (once the sequences grow).
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids with APC disabled')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with APC enabled')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Use a small model for a fast test.
        "model": "facebook/opt-125m",

        # skip cuda graph creation for fast test.
        "enforce_eager": True,

        # we keep the blocks small, so that hit eviction quickly
        "max_model_len": 48,
        "block_size": 16,
        "num_gpu_blocks_override": 3,

        # Test APC in v2 block
        "use_v2_block_manager": True,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{
    "enable_prefix_caching": False
}])
@pytest.mark.parametrize("test_llm_kwargs", [{
    "enable_prefix_caching": True,
}])
@pytest.mark.parametrize("seed", [1])
def test_auto_prefix_caching_after_evition_start(baseline_llm_generator,
                                                 test_llm_generator):
    """Verify block manager v2 with auto prefix caching could works normal
    even when eviction started.
    With APC enabled, all blocks are held by native block at the beginning.
    Then blocks are managed by evictor instead. If cache hit at the evitor's
    block, then it could be reused, or we need to recompute its kv cache.
    """
    output_len = 10
    temperature = 0.0

    prompts = [
        "You are a helpful assistant. Please answer truthfully and write "
        "out your thinking step by step to be sure you get the right answer. "
        "If you make a mistake, attempt to correct it. who are you?",
        "You are a helpful assistant. Please answer truthfully and write out "
        "your thinking step by step to be sure you get the right answer. You "
        "are helpful and harmless and you follow ethical guidelines. "
        "who are you?"
    ]

    sampling_params = SamplingParams(
        max_tokens=output_len,
        ignore_eos=True,
        temperature=temperature,
    )

    print('Getting token ids with APC disabled')
    baseline_token_ids = get_token_ids_from_llm_generator(
        baseline_llm_generator, prompts, sampling_params)

    print('Getting token ids with APC enabled')
    test_token_ids = get_token_ids_from_llm_generator(test_llm_generator,
                                                      prompts, sampling_params)

    for expected_token_ids, actual_token_ids in zip(baseline_token_ids,
                                                    test_token_ids):
        assert expected_token_ids == actual_token_ids

    assert baseline_token_ids == test_token_ids
