import random
from unittest.mock import MagicMock

import pytest
import torch

from vllm.model_executor.utils import set_random_seed
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
from vllm.worker.worker import Worker

from .utils import (assert_logprobs_dict_allclose, create_batch,
                    create_seq_group_metadata_from_prompts, create_worker,
                    patch_execute_model_with_seeds, zero_kv_cache)


@pytest.mark.parametrize('num_steps', list(range(1, 17)))
def test_assert_enough_kv_space(num_steps: int):
    """Test that the multi step worker checks for sufficient space in the KV
    cache. It should throw if it cannot run all the steps.
    """
    block_size = 16
    num_gpu_blocks = 2048 // block_size

    prompts = [
        list(range(block_size * 3)),
        list(range(block_size * 2)),
    ]

    prev_output_tokens = [
        list(range(block_size * 1)),
        list(range(block_size * 2)),
    ]

    final_prompt_lens = [
        len(prompt + output) + num_steps
        for prompt, output in zip(prompts, prev_output_tokens)
    ]

    inputs = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens,
        continuations=prev_output_tokens)

    assert_enough_kv_space = MultiStepWorker._assert_enough_kv_space  # pylint: disable=protected-access
    worker = MagicMock()
    worker.model_runner.block_size = block_size

    for seq_group_metadata in inputs:
        original_block_tables = seq_group_metadata.block_tables

        # No exception.
        assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = {
            seq_id: []
            for seq_id, physical_blocks in original_block_tables.items()
        }

        # Expect exception.
        with pytest.raises(ValueError,
                           match='times but found insufficient KV space for'):
            assert_enough_kv_space(worker, inputs, num_steps)

        seq_group_metadata.block_tables = original_block_tables


@torch.inference_mode()
def test_same_output_for_single_step():
    """Verify the multi step worker produces the same output as the normal
    worker for num_steps=1.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 32
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )
    # multi_step_worker.model_runner = worker.model_runner
    # multi_step_worker.cache_engine = worker.cache_engine

    num_steps = 1

    prompts = [
        [1, 2, 3, 4, 5],
        [6, 7, 8, 9, 10],
    ]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    multi_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    actual_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=multi_step_seq_group),
        sample_len=num_steps)
    assert len(actual_output) == num_steps
    actual_output = actual_output[0]

    single_step_seq_group = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        final_prompt_lens=final_prompt_lens)

    zero_kv_cache(worker.cache_engine)
    set_random_seed(seed)
    expected_output = worker.execute_model(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=single_step_seq_group))[0]

    actual_token_ids = [
        output.samples[0].output_token for output in actual_output
    ]
    actual_logprobs = [output.samples[0].logprobs for output in actual_output]

    expected_token_ids = [
        output.samples[0].output_token for output in expected_output
    ]
    expected_logprobs = [
        output.samples[0].logprobs for output in expected_output
    ]

    assert actual_token_ids == expected_token_ids

    print(f'{actual_logprobs=}')
    print(f'{expected_logprobs=}')
    assert_logprobs_dict_allclose(actual_logprobs, expected_logprobs)


@torch.inference_mode()
def test_same_output_for_multi_step():
    """Verify the multi-step worker produces the same output as the normal
    worker when num_steps > 1. This test runs the multi-step worker once, and
    then runs the worker num_steps times, and compares the output.
    """
    seed = 100
    model_name = 'JackFram/llama-68m'

    block_size = 16
    num_gpu_blocks = 2048 // block_size
    multi_step_worker = create_worker(
        MultiStepWorker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    worker = create_worker(
        Worker,
        model_name,
        block_size,
        num_gpu_blocks,
        seed,
    )

    # Make sure we go over the block boundary.
    num_steps = block_size + 1

    random.seed(seed)
    prompts = [[
        random.randint(0, 1000) for _ in range(random.randint(10, 20))
    ] for _ in range(10)]

    final_prompt_lens = [len(prompt) + num_steps for prompt in prompts]

    rand_seeds = list(random.randint(0, 100) for _ in range(num_steps))
    multi_step_worker.execute_model = patch_execute_model_with_seeds(
        multi_step_worker, rand_seeds)
    worker.execute_model = patch_execute_model_with_seeds(worker, rand_seeds)

    continuations = [[1] for _ in prompts]
    seq_group_metadata_list = create_seq_group_metadata_from_prompts(
        prompts,
        num_gpu_blocks,
        block_size,
        continuations=continuations,
        final_prompt_lens=final_prompt_lens)

    # Run multi-step.
    zero_kv_cache(multi_step_worker.cache_engine)
    set_random_seed(seed)
    multi_step_output, _ = multi_step_worker.sampler_output(
        execute_model_req=ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list),
        sample_len=num_steps)

    # Run single-step repeatedly.
    zero_kv_cache(worker.cache_engine)
    single_step_output = []
    continuations = [[1] for _ in prompts]
    set_random_seed(seed)

    for _ in multi_step_output:

        seq_group_metadata_list = create_seq_group_metadata_from_prompts(
            prompts,
            num_gpu_blocks,
            block_size,
            continuations=continuations,
            final_prompt_lens=final_prompt_lens)

        single_step_output.extend(
            worker.execute_model(execute_model_req=ExecuteModelRequest(
                seq_group_metadata_list=seq_group_metadata_list)))

        # Append output tokens to new sequence data.
        for i, seq_group_output in enumerate(single_step_output[-1]):
            continuations[i].append(seq_group_output.samples[0].output_token)

    # Get token ids and logprobs for comparison.
    multi_step_output_logprobs = [[] for _ in prompts]
    single_step_output_logprobs = [[] for _ in prompts]

    multi_step_output_token_ids = [[] for _ in prompts]
    single_step_output_token_ids = [[] for _ in prompts]
    for i, _ in enumerate(prompts):
        for multi_step, single_step in zip(multi_step_output,
                                           single_step_output):
            multi_step_output_token_ids[i].append(
                multi_step[i].samples[0].output_token)
            single_step_output_token_ids[i].append(
                single_step[i].samples[0].output_token)

            multi_step_output_logprobs[i].append(
                multi_step[i].samples[0].logprobs)
            single_step_output_logprobs[i].append(
                single_step[i].samples[0].logprobs)

    # Print per-sequence token ids
    for i, (multi_step_tokens, single_step_tokens) in enumerate(
            zip(multi_step_output_token_ids, single_step_output_token_ids)):
        print(f'{i=} {multi_step_tokens=}')
        print(f'{i=} {single_step_tokens=}')
        print(f'{i=} equal {multi_step_tokens == single_step_tokens}')

    # Assert token ids are equal.
    for multi_step_tokens, single_step_tokens in zip(
            multi_step_output_token_ids, single_step_output_token_ids):
        assert multi_step_tokens == single_step_tokens

    # Assert logprobs are equal.
    for multi_step_logprobs, single_step_logprobs in zip(
            multi_step_output_logprobs, single_step_output_logprobs):
        assert_logprobs_dict_allclose(multi_step_logprobs,
                                      single_step_logprobs)


@torch.inference_mode()
def test_draft_proposals_full_speculation_len():
    """Verify Top1Proposer correctly handles case where all sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=2048,
    )
    draft_worker.sampler_output.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(batch_size,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(batch_size,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(low=0,
                                            high=vocab_size,
                                            size=(batch_size, ),
                                            device=device,
                                            dtype=torch.long),
        ) for _ in range(k)
    ], True

    seq_group_metadata_list, _, _ = create_batch(batch_size, k)

    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k), )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [k for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_no_speculations():
    """Verify Top1Proposer correctly handles case where no sequences
    can speculate.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'
    prompt_len = 10

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=prompt_len + k - 1,
    )

    seq_group_metadata_list, _, _ = create_batch(batch_size,
                                                 k,
                                                 prompt_len=prompt_len)

    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k), )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [0 for _ in range(batch_size)]


@torch.inference_mode()
def test_draft_proposals_mixed_k():
    """Verify Top1Proposer correctly handles case some sequences can
    speculate and some can't.
    """
    k = 10
    batch_size = 32
    vocab_size = 32_000
    device = 'cuda:0'

    small_prompt_len = 5
    long_prompt_len = 10
    prev_output_token_len = 20

    expected_num_proposal_seqs = 6
    expected_num_no_proposal_seqs = batch_size - expected_num_proposal_seqs

    prompt_len = [
        small_prompt_len for _ in range(expected_num_proposal_seqs - 1)
    ] + [long_prompt_len
         for _ in range(expected_num_no_proposal_seqs)] + [small_prompt_len]

    draft_worker = MagicMock()
    proposer = Top1Proposer(
        worker=draft_worker,
        device=device,
        vocab_size=vocab_size,
        max_proposal_len=long_prompt_len + prev_output_token_len + k - 1,
    )

    draft_worker.sampler_output.return_value = [
        SamplerOutput(
            outputs=[],
            sampled_token_probs=torch.rand(expected_num_proposal_seqs,
                                           vocab_size,
                                           device=device,
                                           dtype=torch.float32),
            logprobs=torch.rand(expected_num_proposal_seqs,
                                vocab_size,
                                device=device,
                                dtype=torch.float32),
            sampled_token_ids=torch.randint(
                low=0,
                high=vocab_size,
                size=(expected_num_proposal_seqs, ),
                device=device,
                dtype=torch.long),
        ) for _ in range(k)
    ], True

    seq_group_metadata_list, _, _ = create_batch(
        batch_size,
        k,
        prompt_len=prompt_len,
        prev_output_token_len=prev_output_token_len,
    )

    proposals = proposer.get_proposals(execute_model_req=ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        num_lookahead_slots=k), )

    assert torch.is_tensor(proposals.proposal_token_ids)
    assert torch.is_tensor(proposals.proposal_probs)

    assert proposals.proposal_token_ids.shape == torch.Size([batch_size, k])
    assert proposals.proposal_probs.shape[:-1] == torch.Size([batch_size, k])

    assert proposals.proposal_lens.shape == torch.Size([batch_size])
    assert proposals.proposal_lens.tolist() == [
        k for _ in range(expected_num_proposal_seqs - 1)
    ] + [0 for _ in range(expected_num_no_proposal_seqs)] + [k]
