# SPDX-License-Identifier: Apache-2.0

from unittest.mock import MagicMock

import pytest  # noqa

from vllm.config import CacheConfig, SchedulerConfig
from vllm.core.scheduler import Scheduler
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob, SequenceGroup

from .utils import create_dummy_prompt


def get_sequence_groups(scheduler_output):
    """Return the ``seq_group`` of every scheduled entry, in schedule order."""
    scheduled_entries = scheduler_output.scheduled_seq_groups
    return [entry.seq_group for entry in scheduled_entries]


def append_new_token(seq_group: SequenceGroup, token_id: int):
    """Append ``token_id`` to every sequence in ``seq_group``.

    Each append carries a fresh single-entry logprob dict that maps
    ``token_id`` to a dummy ``Logprob(token_id)``.
    """
    for sequence in seq_group.get_seqs():
        logprobs = {token_id: Logprob(token_id)}
        sequence.append_token_id(token_id, logprobs)


def schedule_and_update_computed_tokens(scheduler):
    """Run one scheduler step and record each scheduled chunk as computed.

    ``scheduler.schedule()`` returns a 3-tuple; its third element is unused
    here. Each scheduled group is paired positionally with its metadata
    (pairing stops at the shorter of the two lists), and the group's
    computed-token count is advanced by the metadata's ``token_chunk_size``.

    Returns:
        The ``(metadata list, scheduler output)`` pair from the step.
    """
    seq_group_metadata_list, scheduler_out, _unused = scheduler.schedule()
    pairs = zip(scheduler_out.scheduled_seq_groups, seq_group_metadata_list)
    for scheduled, metadata in pairs:
        scheduled.seq_group.update_num_computed_tokens(
            metadata.token_chunk_size)
    return seq_group_metadata_list, scheduler_out


def test_simple():
    """Verify short prompts are prefilled together, then decoded together.

    Four single-block prompts fit in one step's 64-token budget, so the first
    step prefills all of them. The second step schedules one decode token per
    group. Neither step may copy or swap any blocks.
    """
    block_size = 4
    num_seq_group = 4
    max_model_len = 16
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        num_seq_group,  # max_seqs
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)

    def check_step(expected_batched_tokens):
        # One scheduler step: every group is scheduled, no block movement.
        metas, step_out = schedule_and_update_computed_tokens(scheduler)
        assert set(get_sequence_groups(step_out)) == set(groups)
        assert step_out.num_batched_tokens == expected_batched_tokens
        assert (not step_out.blocks_to_copy and not step_out.blocks_to_swap_in
                and not step_out.blocks_to_swap_out)
        assert len(metas) == num_seq_group

    groups: list[SequenceGroup] = []
    for request_idx in range(num_seq_group):
        _, group = create_dummy_prompt(str(request_idx),
                                       prompt_length=block_size,
                                       block_size=block_size)
        scheduler.add_seq_group(group)
        groups.append(group)

    # Step 1: every prompt is prefilled in full.
    check_step(block_size * num_seq_group)
    for group in groups:
        append_new_token(group, 1)

    # Step 2: one decode token per group.
    check_step(num_seq_group)


def test_chunk():
    """Verify prefills are chunked properly.

    Two 60-token prompts share a 64-token budget. The first prompt is
    prefilled in full and the second gets the remaining 4 tokens. On the next
    step, the rest of the second prompt (56 tokens) is prefilled alongside one
    decode token for the first request.
    """
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 32
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # One chunked prefill, and one decoding.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # Prefills are listed before decodes: the first entry is the rest of the
    # second request's prefill (60 - 4 = 56 tokens).
    assert seq_group_meta[0].token_chunk_size == 56
    # The second entry is the first request's single decode token.
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 57


def test_concurrent_chunking():
    """Verify prefills are chunked properly when
    --max-num-partial-prefills is > 1.

    Two 60-token prompts share a 64-token budget with two partial-prefill
    slots. Each prompt gets half the budget (32 tokens) in the first step and
    its remaining 28 tokens in the second step.
    """
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 32
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Verify both requests are chunked with half of max_num_batched_tokens each
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 32
    assert seq_group_meta[1].token_chunk_size == 32
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # After one iteration, both should have 60 - 32 = 28 tokens left to prefill
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 28
    assert seq_group_meta[1].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    # 28 + 28; the rest of the 64-token budget stays unused.
    assert out.num_batched_tokens == 56


def test_concurrent_chunking_large_requests():
    """Verify large prefill requests are run one at a time.

    Even with two partial-prefill slots, only one of two 1200-token prompts is
    scheduled in the first step, and it receives the full 64-token budget.
    """
    block_size = 4
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        60,  # max_seqs
        2000,  # max_model_len
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    # Large KV cache size for large requests.
    cache_config.num_cpu_blocks = 3200
    cache_config.num_gpu_blocks = 3200
    scheduler = Scheduler(scheduler_config, cache_config, None)

    for request_idx in range(2):
        # Very large prompt.
        _, group = create_dummy_prompt(str(request_idx),
                                       prompt_length=1200,
                                       block_size=block_size)
        scheduler.add_seq_group(group)

    # Only a single request is chunked, and it gets all 64 tokens.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == max_num_batched_tokens


def test_short_prompts_jump_long_prompts_in_queue():
    """Verify large prefill requests are punted behind smaller ones if
    another large prefill request is already running.

    Two 1200-token prompts are queued ahead of two 40-token prompts, with two
    partial-prefill slots. Only the first long prompt makes progress. Both
    short prompts are scheduled ahead of the second long prompt, which stays
    at 0 computed tokens throughout.
    """
    block_size = 4
    max_seqs = 60
    max_model_len = 2000
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        max_num_partial_prefills=2,  # Up to 2 partial prefills at a time
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 3200  # large KV cache size for large requests
    cache_config.num_gpu_blocks = 3200
    scheduler = Scheduler(scheduler_config, cache_config, None)
    long_seqs: list[SequenceGroup] = []
    short_seqs: list[SequenceGroup] = []

    # Add 2 large seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i),
            prompt_length=1200,  # Very large prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        long_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # Add 2 small seq groups behind them
    for i in range(2):
        _, seq_group = create_dummy_prompt(
            str(i + 2),
            prompt_length=40,  # Very small prompt
            block_size=block_size)
        scheduler.add_seq_group(seq_group)
        short_seqs.append(seq_group)
        assert seq_group.is_prefill()

    # Verify one large req and 1 small req chunked
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large req gets 32 tokens
    assert seq_group_meta[1].token_chunk_size == 32  # small req gets 32 tokens

    # all 4 are prefilling
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # First short and first long sequences have been scheduled
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 32
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 0

    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64

    # In the second iteration, the first small request has only 8 tokens
    # left, so it finishes its prefill in this step and decodes afterwards.
    # The other small req is newly scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    # the new small req got 64 - (32+8) = 24 tokens
    assert seq_group_meta[0].token_chunk_size == 24
    assert seq_group_meta[1].token_chunk_size == 32  # large req still got 32
    # the first small request had only 8 tokens left
    assert seq_group_meta[2].token_chunk_size == 8  # 40-32

    # The first small request has finished prefill and will decode next
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert short_seqs[1].is_prefill()
    # Both small requests have started in front of the second long request
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 64
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 40
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 24

    assert out.num_prefill_groups == 3
    assert out.num_batched_tokens == 64
    # the first small seq group has a new token appended.
    append_new_token(short_seqs[0], 1)

    # in the third iteration,
    # the first small request is already decoding
    # the second small request only has 16 tokens left and will enter decoding
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 32  # large still got 32
    # small req finished prefilling 40-24=16 tokens
    assert seq_group_meta[1].token_chunk_size == 16
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 49  # (32+16+1 decode)

    # both small requests have now reached decode
    assert long_seqs[0].is_prefill()
    assert long_seqs[1].is_prefill()
    assert not short_seqs[0].is_prefill()
    assert not short_seqs[1].is_prefill()
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 96
    assert long_seqs[1].first_seq.get_num_computed_tokens() == 0
    assert short_seqs[0].first_seq.get_num_computed_tokens() == 41
    assert short_seqs[1].first_seq.get_num_computed_tokens() == 40

    # both the small seq groups have a new token appended
    append_new_token(short_seqs[0], 1)
    append_new_token(short_seqs[1], 1)

    # in the fourth iteration, both small requests are decoding
    # so large request gets all the budget
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)

    # large req gets 62 tokens (minus 2 for decode)
    assert seq_group_meta[0].token_chunk_size == 62
    assert seq_group_meta[1].token_chunk_size == 1  # decode
    assert seq_group_meta[2].token_chunk_size == 1  # decode
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 64

    # 96 computed before this step + 62 scheduled in it.
    assert long_seqs[0].first_seq.get_num_computed_tokens() == 158


def test_complex():
    """Verify chunked prefill mixes a decode, an in-progress prefill and the
    first chunk of a newly arrived request within one token budget."""
    block_size = 4
    max_seqs = 60
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 64
    cache_config.num_gpu_blocks = 64
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # Verify the second request is chunked.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)

    assert set(get_sequence_groups(out)) == set(running)
    assert seq_group_meta[0].token_chunk_size == 60
    # Verify it is chunked.
    assert seq_group_meta[1].token_chunk_size == 4
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Add 2 more requests.
    for i in range(2, 4):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=60,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Decoding & chunked prefill & first chunk of 3rd request is scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 3
    # The first one is the first chunk of the newly added 3rd request, which
    # gets the leftover budget: 64 - 56 - 1 = 7 tokens.
    assert seq_group_meta[0].token_chunk_size == 7
    # The second one is the rest of the 2nd request's prefill (60 - 4 = 56).
    assert seq_group_meta[1].token_chunk_size == 56
    # The last one is decode.
    assert seq_group_meta[2].token_chunk_size == 1
    # Two of them are in chunked prefill.
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 64
    # The first 2 requests are now in decoding phase.
    append_new_token(running[0], 1)
    assert not running[0].is_prefill()
    append_new_token(running[1], 1)
    assert not running[1].is_prefill()
    # The third request is still in prefill stage.
    assert running[2].is_prefill()


def test_maximal_decoding():
    """Verify decoding requests are prioritized.

    With a budget of only 2 tokens and max_seqs=2, the test checks that:
    - running decodes are always scheduled;
    - an in-progress prefill takes the remaining slot;
    - a newly added request waits until a slot frees up, which happens when
      a decoding request is aborted.
    """
    block_size = 4
    max_seqs = 2
    max_model_len = 8
    max_num_batched_tokens = 2
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 8
    cache_config.num_gpu_blocks = 8
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           prompt_length=2,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)
        assert seq_group.is_prefill()

    # The first prefill is scheduled. Its 2-token prompt uses the whole budget.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    # Only the first seq group has a new token appended.
    append_new_token(running[0], 1)

    # Create one more seq_group.
    _, seq_group = create_dummy_prompt("3",
                                       prompt_length=2,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    assert seq_group.is_prefill()
    # The first request's decode and a 1-token first chunk of the second
    # request's prefill are scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)

    # Decoding + running prefill is prioritized.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # Only decoding is prioritized; the third request keeps waiting.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()
    assert out.num_prefill_groups == 0
    assert out.num_batched_tokens == 2
    append_new_token(running[0], 1)
    append_new_token(running[1], 1)

    # After aborting the decoding request, the fcfs new prefill is prioritized.
    scheduler.abort_seq_group(running[0].request_id)
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 2
    assert seq_group_meta[0].token_chunk_size == 1
    assert seq_group_meta[1].token_chunk_size == 1
    assert not running[1].is_prefill()
    assert running[2].is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == 2


def test_prompt_limit():
    """Verify max_num_batched_tokens < max_model_len is possible.

    A 48-token prompt is longer than the 32-token batch budget but shorter
    than max_model_len (64). It must be scheduled as a first chunk that fills
    the whole budget rather than being rejected.
    """
    block_size = 4
    max_num_batched_tokens = 32
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        32,  # max_seqs
        64,  # max_model_len
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=48,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    assert seq_group.is_prefill()

    # The prompt length > max_num_batched_tokens should be still scheduled.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert len(get_sequence_groups(out)) == 1
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    # 16 of the 48 prompt tokens remain, so the group is still prefilling.
    assert seq_group.is_prefill()
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == max_num_batched_tokens


def test_prompt_limit_exceed():
    """Verify a prompt longer than max_model_len is reported as ignored."""
    block_size = 4
    max_seqs = 64
    max_model_len = 32
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig("generate",
                                       max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True)
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)
    # A 48-token prompt exceeds max_model_len (32).
    _, seq_group = create_dummy_prompt("2",
                                       prompt_length=48,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    assert seq_group.is_prefill()
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.ignored_seq_groups) == 1
    assert out.ignored_seq_groups[0] == seq_group


def test_chunked_prefill_preempt():
    """Verify preempt works with chunked prefill requests.

    A 60-token prompt is prefilled in 30-token chunks. After the first chunk,
    ``block_manager.can_append_slots`` is patched to refuse request "1". That
    preempts the running prefill. The request must then be rescheduled, and
    it finishes its prefill once appending is allowed again.
    """
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=60,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The request is chunked.
    # prefill scheduled now.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens

    # The request should be preempted.
    scheduler.block_manager.can_append_slots = MagicMock()

    # Refuse new slots only for request "1" (the request under test).
    def cannot_append_second_group1(seq_group, num_lookahead_slots):
        return seq_group.request_id != "1"

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group1)

    # The running prefill is now preempted. Nothing is scheduled and no
    # blocks are swapped in or out.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 0
    assert out.num_batched_tokens == 0
    assert out.blocks_to_swap_out == []
    assert out.blocks_to_swap_in == []

    # Make sure we can reschedule preempted request. The uncomputed-token
    # check below shows the earlier 30 computed tokens were discarded: after
    # this 30-token chunk, 30 of the 60 tokens are still left.
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens
    assert seq_group.get_num_uncomputed_tokens() == 30

    # We should be able to run prefill twice as it is chunked. From here on
    # every group may append slots, so the second chunk completes the prefill.
    def cannot_append_second_group2(seq_group, num_lookahead_slots):
        return True

    scheduler.block_manager.can_append_slots.side_effect = (
        cannot_append_second_group2)
    _, out = schedule_and_update_computed_tokens(scheduler)
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert not seq_group.is_prefill()
    assert out.num_batched_tokens == max_num_batched_tokens


@pytest.mark.parametrize("num_scheduler_steps", [1, 5])
def test_chunked_prefill_spec_prefill(num_scheduler_steps):
    """Verify that num_lookahead_slots is set appropriately for an all-prefill
    batch, depending on whether multi-step scheduling is enabled or not.

    With a single scheduler step no lookahead slots are reserved. With
    multi-step scheduling the configured ``num_lookahead_slots`` is used.
    """
    block_size = 4
    max_seqs = 30
    max_model_len = 200
    max_num_batched_tokens = 30
    num_lookahead_slots = 4
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
        num_lookahead_slots=num_lookahead_slots,
        num_scheduler_steps=num_scheduler_steps,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 16
    cache_config.num_gpu_blocks = 16
    scheduler = Scheduler(scheduler_config, cache_config, None)

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=30,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    _, out = schedule_and_update_computed_tokens(scheduler)
    # The 30-token prompt exactly fills the 30-token budget, so the whole
    # prefill is scheduled in this single all-prefill step.
    assert len(out.scheduled_seq_groups) == 1
    assert out.num_prefill_groups == 1
    assert out.num_batched_tokens == max_num_batched_tokens
    assert out.num_lookahead_slots == (0 if (num_scheduler_steps == 1) else
                                       num_lookahead_slots)


def test_chunked_prefill_max_seqs():
    """Verify at most max_num_seqs groups are scheduled per step, even when
    the token budget could fit more work."""
    block_size = 4
    max_seqs = 2
    max_model_len = 80
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        max_seqs,
        max_model_len,
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size, 1.0, 1, "auto")
    cache_config.num_cpu_blocks = 128
    cache_config.num_gpu_blocks = 128
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    _, seq_group = create_dummy_prompt("1",
                                       prompt_length=65,
                                       block_size=block_size)
    scheduler.add_seq_group(seq_group)
    running.append(seq_group)
    # The first prefill is chunked: 65 tokens do not fit in a 64-token budget.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 1

    # Add new requests. Their ids start at "2" so that none of them reuses
    # the id of request "1" added above.
    for i in range(4):
        _, seq_group = create_dummy_prompt(str(i + 2),
                                           prompt_length=65,
                                           block_size=block_size)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    # Make sure only 2 requests are scheduled: the last prompt token of the
    # first request plus a 63-token first chunk of the next request.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert out.num_batched_tokens == max_num_batched_tokens
    assert len(get_sequence_groups(out)) == 2
    assert not running[0].is_prefill()
    assert running[1].is_prefill()
    append_new_token(running[0], 1)

    # Although we have enough token budget, we can only schedule max_seqs:
    # the remaining 65 - 63 = 2 prefill tokens of the second request and one
    # decode token of the first.
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert seq_group_meta[0].token_chunk_size == 2
    assert seq_group_meta[1].token_chunk_size == 1
    assert out.num_batched_tokens == 3
    assert len(get_sequence_groups(out)) == max_seqs
    assert not running[0].is_prefill()
    assert not running[1].is_prefill()


def test_prefix_caching():
    """Verify allocating full blocks when prefix caching is enabled.

    Two 50-token prompts share a 64-token budget. The first prompt is
    prefilled in full. The second may only take whole blocks out of the
    14 leftover tokens, i.e. 12 tokens.
    """
    block_size = 4
    max_num_batched_tokens = 64
    scheduler_config = SchedulerConfig(
        "generate",
        max_num_batched_tokens,
        10,  # max_seqs
        80,  # max_model_len
        enable_chunked_prefill=True,
    )
    cache_config = CacheConfig(block_size,
                               1.0,
                               1,
                               "auto",
                               enable_prefix_caching=True)
    cache_config.num_cpu_blocks = 0
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)

    groups: list[SequenceGroup] = []
    for request_idx in range(2):
        _, group = create_dummy_prompt(str(request_idx),
                                       block_size=block_size,
                                       prompt_length=50)
        scheduler.add_seq_group(group)
        groups.append(group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(groups)
    first_meta, second_meta = seq_group_meta[0], seq_group_meta[1]
    assert first_meta.token_chunk_size == 50
    # Although the leftover budget is 64 - 50 = 14, only full blocks are
    # allocated with prefix caching: 4 * (14 // 4) = 12 tokens.
    assert second_meta.token_chunk_size == 12
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 62


def test_prefix_caching_with_concurrent_partial_prefills():
    """Verify allocating full blocks when prefix caching is enabled with
    --max-num-partial-prefills > 1.

    Each of the two partial-prefill slots may use up to 30 tokens, which is
    rounded down to whole blocks (28). The final chunk of each prompt does
    not need to be block-aligned.
    """
    block_size = 4
    max_seqs = 10
    max_model_len = 8000
    max_num_batched_tokens = 60  # With two slots, each slot will get 30 tokens
    scheduler_config = SchedulerConfig("generate",
                                       max_num_batched_tokens,
                                       max_seqs,
                                       max_model_len,
                                       enable_chunked_prefill=True,
                                       max_num_partial_prefills=2)
    cache_config = CacheConfig(block_size,
                               1.0,
                               1,
                               "auto",
                               enable_prefix_caching=True)
    cache_config.num_cpu_blocks = 0
    cache_config.num_gpu_blocks = 32
    scheduler = Scheduler(scheduler_config, cache_config, None)
    running: list[SequenceGroup] = []

    # Add seq groups to scheduler.
    for i in range(2):
        _, seq_group = create_dummy_prompt(str(i),
                                           block_size=block_size,
                                           prompt_length=50)
        scheduler.add_seq_group(seq_group)
        running.append(seq_group)

    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # To partially prefill both sequences, both can chunk up to 30 tokens,
    # but the largest multiple of the block size (4) that fits is 28.
    assert seq_group_meta[0].token_chunk_size == 28
    assert seq_group_meta[1].token_chunk_size == 28
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 56

    # On the next iteration, both sequences should finish prefill
    seq_group_meta, out = schedule_and_update_computed_tokens(scheduler)
    assert set(get_sequence_groups(out)) == set(running)
    # Both sequences have 50 - 28 = 22 tokens left to prefill.
    # This is not a multiple of the block size, which is acceptable because
    # the final partial block of a prefill is not cached.
    assert seq_group_meta[0].token_chunk_size == 22
    assert seq_group_meta[1].token_chunk_size == 22
    assert out.num_prefill_groups == 2
    assert out.num_batched_tokens == 44


@pytest.mark.parametrize("model", ["facebook/opt-125m"])
@pytest.mark.parametrize("max_num_partial_prefills", [2, 4, 8])
def test_chunked_prefill_with_actual_engine(model: str,
                                            max_num_partial_prefills: int):
    """Make sure the model can actually sample with concurrent
    partial prefills.

    ``max_num_partial_prefills`` identical requests are submitted. After the
    first engine step no request output is returned, and every request is in
    the first scheduler's running queue.
    """
    engine = LLMEngine.from_engine_args(
        EngineArgs(
            model=model,
            max_num_partial_prefills=max_num_partial_prefills,
            max_num_batched_tokens=40,
            max_num_seqs=8,
            enable_chunked_prefill=True,
            gpu_memory_utilization=0.8,
        ))
    greedy_params = SamplingParams(temperature=0)
    prompt = "hello" * 40

    request_ids = [f"{n}" for n in range(max_num_partial_prefills)]
    for request_id in request_ids:
        engine.add_request(request_id, prompt, greedy_params)

    # First step: every request is still prefilling, so nothing is returned.
    first_step_outputs = engine.step()
    assert len(first_step_outputs) == 0
    assert len(engine.scheduler[0].running) == max_num_partial_prefills
