import torch
import random
import pytest
from unittest.mock import MagicMock

from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.spec_decode_worker import (SpecDecodeWorker,
                                                 split_num_cache_blocks_evenly)
from vllm.spec_decode.interfaces import SpeculativeProposals
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.layers.rejection_sampler import RejectionSampler
from .utils import (mock_worker, create_batch, ExecuteModelData,
                    create_sampler_output_list)
from vllm.spec_decode.metrics import (SpecDecodeWorkerMetrics,
                                      AsyncMetricsCollector)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_draft_model(k: int, batch_size: int):
    """Check the arguments SpecDecodeWorker hands to the draft worker.

    The draft worker's ``get_spec_proposals`` mock raises a sentinel
    ValueError, so ``execute_model`` aborts at that call. The single
    recorded call is then compared against the batch and ``k`` that were
    passed in. Every collaborator is a mock.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    collector = MagicMock(spec=AsyncMetricsCollector)
    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)

    # Raising from the mocked call stops execute_model at the interaction
    # under test.
    stop_message = 'artifical stop'
    proposer.get_spec_proposals.side_effect = ValueError(stop_message)

    expected_data, _, _ = create_batch(batch_size, k)

    with pytest.raises(ValueError, match=stop_message):
        spec_worker.execute_model(**expected_data.to_dict(),
                                  num_spec_tokens=k)

    recorded_calls = proposer.get_spec_proposals.call_args_list
    assert len(recorded_calls) == 1

    # The positional arguments of the only call must unpack into exactly
    # five values: the four batch fields followed by k.
    positional, _ = recorded_calls[0]
    metadata_list, swap_in, swap_out, to_copy, observed_k = positional
    observed_data = ExecuteModelData(metadata_list, swap_in, swap_out,
                                     to_copy)
    assert observed_data == expected_data
    assert observed_k == k


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_target_model(k: int, batch_size: int):
    """Check the token contexts SpecDecodeWorker sends to the target worker.

    The draft worker mock returns random proposals and the target worker's
    ``execute_model`` mock raises a sentinel ValueError. The recorded
    target call must carry ``(k + 1) * batch_size`` sequence groups whose
    token ids are, per input sequence, the prompt plus previous output
    plus every prefix (the empty one included) of its draft tokens.
    Every collaborator is a mock.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    device = 'cuda'
    proposer.device = device
    scorer.device = device

    set_random_seed(1)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)
    spec_worker.init_device()

    vocab_size = 32_000

    # Random draft proposals: k tokens per sequence plus their
    # probabilities; every sequence proposes the full k tokens.
    proposal_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, k),
        dtype=torch.int64,
        device=device,
    )
    proposal_probs = torch.rand(
        batch_size,
        k,
        vocab_size,
        dtype=torch.float32,
        device=device,
    )
    proposal_lens = k * torch.ones(
        batch_size, dtype=torch.int64, device=device)

    execute_model_data, prompts, prev_output_tokens = create_batch(
        batch_size, k)

    proposer.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    # Abort right at the target-model call so only its inputs are examined.
    stop_message = 'artifical stop'
    scorer.execute_model.side_effect = ValueError(stop_message)

    with pytest.raises(ValueError, match=stop_message):
        spec_worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)

    recorded_calls = scorer.execute_model.call_args_list
    assert len(recorded_calls) == 1

    _, call_kwargs = recorded_calls[0]
    scored_data = ExecuteModelData.from_dict(call_kwargs)
    scored_groups = scored_data.seq_group_metadata_list
    assert len(scored_groups) == (k + 1) * batch_size

    seen_contexts = [
        seq_data.get_token_ids() for group in scored_groups
        for seq_data in group.seq_data.values()
    ]

    # One expected context per draft-token prefix length 0..k.
    expected_seen_contexts = [
        prompt + prev_generated + draft_tokens[:prefix_len]
        for prompt, prev_generated, draft_tokens in zip(
            prompts, prev_output_tokens, proposal_token_ids.tolist())
        for prefix_len in range(len(draft_tokens) + 1)
    ]

    # The comparison is order-insensitive.
    assert sorted(expected_seen_contexts) == sorted(seen_contexts)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_calls_rejection_sampler(k: int, batch_size: int):
    """Check the tensors SpecDecodeWorker passes to the rejection sampler.

    Draft and target workers are mocks returning random proposals and
    random target outputs. The rejection sampler mock raises a sentinel
    ValueError; its single recorded call must receive, positionally, the
    target probabilities for the first k positions, the target token at
    the last position, the draft probabilities and the draft token ids.
    """
    vocab_size = 32_000
    device = 'cuda'

    proposer = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    scorer = mock_worker(vocab_size=vocab_size)
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)
    proposer.device = device
    scorer.device = device

    set_random_seed(1)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)
    spec_worker.init_device()

    # Random draft proposals, all of length k.
    proposal_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, k),
        dtype=torch.int64,
        device=device,
    )
    proposal_probs = torch.rand(
        batch_size,
        k,
        vocab_size,
        dtype=torch.float32,
        device=device,
    )
    proposal_lens = k * torch.ones(
        batch_size, dtype=torch.int64, device=device)

    execute_model_data, _, _ = create_batch(batch_size, k)

    proposer.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    # A single target step covering batch_size * (k + 1) positions.
    num_scored = batch_size * (k + 1)
    target_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(1, num_scored),
        dtype=torch.int64,
        device=device,
    )
    target_token_probs = torch.rand(
        1,
        num_scored,
        vocab_size,
        dtype=torch.float32,
        device=device,
    )
    target_steps = create_sampler_output_list(target_token_ids,
                                              target_token_probs)
    scorer.execute_model.return_value = target_steps[0]

    # Abort as soon as the rejection sampler is invoked.
    stop_message = 'artifical stop'
    sampler.side_effect = ValueError(stop_message)

    with pytest.raises(ValueError, match=stop_message):
        spec_worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)

    assert len(sampler.call_args_list) == 1
    positional, _ = sampler.call_args_list[0]
    scores_seen, bonus_seen, draft_probs_seen, draft_ids_seen = positional

    # View the flat target output per sequence: (batch_size, k + 1, ...).
    ids_by_seq = target_token_ids.reshape(batch_size, k + 1)
    probs_by_seq = target_token_probs.reshape(batch_size, k + 1, -1)

    assert torch.equal(bonus_seen, ids_by_seq[:, -1:])
    assert torch.equal(scores_seen, probs_by_seq[:, :-1])
    assert torch.equal(draft_ids_seen, proposal_token_ids)
    assert torch.equal(draft_probs_seen, proposal_probs)


@pytest.mark.parametrize('k', [1, 2, 6])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_correctly_formats_output(k: int, batch_size: int):
    """Check how SpecDecodeWorker formats the rejection sampler's output.

    The mocked rejection sampler returns a ``(batch_size, k + 1)`` token
    tensor in which a random-length tail of every row is set to -1. The
    worker's output is grouped per sequence and compared step by step
    with a sampler-output list built from the same tensor; steps absent
    from the actual output must correspond to -1 in the expected one.
    Every collaborator is a mock.
    """
    vocab_size = 32_000
    device = 'cuda'

    proposer = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    scorer = mock_worker(vocab_size=vocab_size)
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)
    proposer.device = device
    scorer.device = device

    set_random_seed(1)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)
    spec_worker.init_device()

    # Random draft proposals, all of length k.
    proposal_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, k),
        dtype=torch.int64,
        device=device,
    )
    proposal_probs = torch.rand(
        batch_size,
        k,
        vocab_size,
        dtype=torch.float32,
        device=device,
    )
    proposal_lens = k * torch.ones(
        batch_size, dtype=torch.int64, device=device)

    execute_model_data, _, _ = create_batch(batch_size, k)

    proposer.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    # A single target step covering batch_size * (k + 1) positions.
    num_scored = batch_size * (k + 1)
    target_token_ids = torch.randint(
        low=0,
        high=vocab_size,
        size=(1, num_scored),
        dtype=torch.int64,
        device=device,
    )
    target_token_probs = torch.rand(
        1,
        num_scored,
        vocab_size,
        dtype=torch.float32,
        device=device,
    )
    target_steps = create_sampler_output_list(target_token_ids,
                                              target_token_probs)
    scorer.execute_model.return_value = target_steps[0]

    sampler_tokens = torch.randint(
        low=0,
        high=vocab_size,
        size=(batch_size, k + 1),
        dtype=torch.int64,
        device=device,
    )
    # Overwrite the last 1..k+1 entries of every row with -1.
    shortest_tail = 1
    for row in range(batch_size):
        tail_len = random.randint(shortest_tail, k + 1)
        sampler_tokens[row, -tail_len:] = -1

    sampler.return_value = sampler_tokens

    output = spec_worker.execute_model(**execute_model_data.to_dict(),
                                       num_spec_tokens=k)

    expected_output = create_sampler_output_list(
        sampler_tokens.transpose(0, 1), [None] * (k + 1))

    # First sequence id of every sequence group in the input batch.
    seq_ids = [
        next(iter(group.seq_data.keys()))
        for group in execute_model_data.seq_group_metadata_list
    ]

    def samples_by_seq(steps):
        """Collect samples per parent sequence id, in step order."""
        grouped = {seq_id: [] for seq_id in seq_ids}
        for step in steps:
            for seq_group in step:
                for sample in seq_group.samples:
                    grouped[sample.parent_seq_id].append(sample)
        return grouped

    actual_by_seq = samples_by_seq(output)
    expected_by_seq = samples_by_seq(expected_output)

    for seq_id in set(actual_by_seq) | set(expected_by_seq):
        actual_steps = actual_by_seq[seq_id]
        expected_steps = expected_by_seq[seq_id]

        for step_idx in range(k + 1):
            expected_sample = expected_steps[step_idx]
            if step_idx < len(actual_steps):
                actual_sample = actual_steps[step_idx]
                assert (actual_sample.output_token ==
                        expected_sample.output_token)
                assert actual_sample.logprobs == expected_sample.logprobs
            else:
                # A step missing from the actual output is only valid
                # where the sampler output holds -1.
                assert expected_sample.output_token == -1


@pytest.mark.parametrize('k', [1, 2])
@pytest.mark.parametrize('batch_size', [1])
@pytest.mark.parametrize('returns_metrics', [True, False])
@torch.inference_mode()
def test_collects_metrics(k: int, batch_size: int, returns_metrics: bool):
    """Verify SpecDecodeWorker collects metrics.

    The metrics collector mock returns either a mocked
    SpecDecodeWorkerMetrics or None (per ``returns_metrics``). The test
    checks that this value is attached to the first output step and that
    ``maybe_collect_rejsample_metrics`` is called exactly once with ``k``,
    passed either as the first positional argument or as keyword ``k``.
    Every collaborator is a mock.
    """
    vocab_size = 32_000

    draft_worker = mock_worker(cls=MultiStepWorker, vocab_size=vocab_size)
    target_worker = mock_worker(vocab_size=vocab_size)
    rejection_sampler = MagicMock(spec=RejectionSampler)
    rejection_sampler.token_id_dtype = torch.int64
    metrics_collector = MagicMock(spec=AsyncMetricsCollector)
    draft_worker.device = 'cuda'
    target_worker.device = 'cuda'

    set_random_seed(1)

    worker = SpecDecodeWorker(draft_worker, target_worker, rejection_sampler,
                              metrics_collector)
    worker.init_device()

    # Random draft proposals, all of length k.
    proposal_token_ids = torch.randint(low=0,
                                       high=vocab_size,
                                       size=(batch_size, k),
                                       dtype=torch.int64,
                                       device='cuda')
    proposal_probs = torch.rand(batch_size,
                                k,
                                vocab_size,
                                dtype=torch.float32,
                                device='cuda')

    proposal_lens = torch.ones(batch_size, dtype=torch.int64,
                               device='cuda') * k

    execute_model_data, _, _ = create_batch(batch_size, k)

    draft_worker.get_spec_proposals.return_value = SpeculativeProposals(
        proposal_token_ids=proposal_token_ids,
        proposal_probs=proposal_probs,
        proposal_lens=proposal_lens)

    # A single target step covering batch_size * (k + 1) positions.
    target_token_ids = torch.randint(low=0,
                                     high=vocab_size,
                                     size=(1, batch_size * (k + 1)),
                                     dtype=torch.int64,
                                     device='cuda')
    target_token_probs = torch.rand(1,
                                    batch_size * (k + 1),
                                    vocab_size,
                                    dtype=torch.float32,
                                    device='cuda')
    target_output = create_sampler_output_list(target_token_ids,
                                               target_token_probs)

    target_worker.execute_model.return_value = target_output[0]

    rejection_sampler_output = torch.randint(low=0,
                                             high=vocab_size,
                                             size=(batch_size, k + 1),
                                             dtype=torch.int64,
                                             device='cuda')
    # Overwrite the last 1..k+1 entries of every row with -1.
    for i in range(batch_size):
        minimum_accepted_tokens = 1
        rejection_sampler_output[i][
            -random.randint(minimum_accepted_tokens, k + 1):] = -1

    rejection_sampler.return_value = rejection_sampler_output

    mock_rejsample_metrics = MagicMock(
        spec=SpecDecodeWorkerMetrics) if returns_metrics else None
    metrics_collector.maybe_collect_rejsample_metrics.return_value = (
        mock_rejsample_metrics)

    output = worker.execute_model(**execute_model_data.to_dict(),
                                  num_spec_tokens=k)
    assert output[0].spec_decode_worker_metrics == mock_rejsample_metrics

    call_args_list = (
        metrics_collector.maybe_collect_rejsample_metrics.call_args_list)
    assert len(call_args_list) == 1
    args, kwargs = call_args_list[0]
    # Guard the positional lookup: with a keyword-only call ``args`` is
    # empty, and an unguarded ``args[0]`` would raise IndexError before
    # the keyword fallback could be evaluated.
    passed_positionally = bool(args) and args[0] == k
    assert passed_positionally or kwargs.get('k', -1) == k


@pytest.mark.parametrize('k', [0])
@pytest.mark.parametrize('batch_size', [1, 2, 32])
@torch.inference_mode()
def test_k_equals_zero(k: int, batch_size: int):
    """Check that both workers still run when no tokens are speculated.

    With ``k == 0`` (this happens during prefill) the SpecDecodeWorker
    must call ``execute_model`` once on the draft worker and once on the
    target worker with the original batch, and return a single step whose
    ``probs`` and ``sampled_tokens`` are None.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    device = 'cuda'
    proposer.device = device
    scorer.device = device

    set_random_seed(1)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)

    execute_model_data, _, _ = create_batch(batch_size,
                                            k,
                                            prev_output_token_len=0)

    out = spec_worker.execute_model(**execute_model_data.to_dict(),
                                    num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    only_step = out[0]
    assert only_step.probs is None, "expect gpu tensor references to be None"
    assert (only_step.sampled_tokens is
            None), "expect gpu tensor references to be None"

    # The draft worker is additionally asked not to build Python output.
    proposer.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    scorer.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@pytest.mark.parametrize('k', [0, 5])
@pytest.mark.parametrize('batch_size', [0])
@torch.inference_mode()
def test_empty_input_batch(k: int, batch_size: int):
    """Check that both workers still run when the input batch is empty.

    An empty batch can occur if the engine communicates information to
    the workers without scheduling a batch. The SpecDecodeWorker must
    call ``execute_model`` once on the draft worker and once on the target
    worker with the original (empty) batch, and return a single step whose
    ``probs`` and ``sampled_tokens`` are None.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    device = 'cuda'
    proposer.device = device
    scorer.device = device

    set_random_seed(1)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)

    execute_model_data, _, _ = create_batch(batch_size,
                                            k,
                                            prev_output_token_len=0)

    out = spec_worker.execute_model(**execute_model_data.to_dict(),
                                    num_spec_tokens=k)

    assert len(out) == 1, f"expected only one token output when {k=}"
    only_step = out[0]
    assert only_step.probs is None, "expect gpu tensor references to be None"
    assert (only_step.sampled_tokens is
            None), "expect gpu tensor references to be None"

    # The draft worker is additionally asked not to build Python output.
    proposer.execute_model.assert_called_once_with(
        **execute_model_data.to_dict(), return_python_output=False)
    scorer.execute_model.assert_called_once_with(
        **execute_model_data.to_dict())


@torch.inference_mode()
def test_init_device():
    """Check that ``init_device`` fans out to every collaborator.

    Calling ``SpecDecodeWorker.init_device`` must invoke ``init_device``
    once on the draft and target workers and ``init_gpu_tensors`` once on
    the metrics collector and the rejection sampler.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)
    spec_worker.init_device()

    expected_single_calls = (
        proposer.init_device,
        scorer.init_device,
        collector.init_gpu_tensors,
        sampler.init_gpu_tensors,
    )
    for mocked_method in expected_single_calls:
        mocked_method.assert_called_once()


@torch.inference_mode()
def test_init_cache_engine():
    """Check that ``init_cache_engine`` is forwarded to both workers.

    The same cache config object passed to the SpecDecodeWorker must
    reach the draft and the target worker, each exactly once.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)

    cache_config = MagicMock()
    spec_worker.init_cache_engine(cache_config)

    for inner_worker in (proposer, scorer):
        inner_worker.init_cache_engine.assert_called_once_with(cache_config)


@pytest.mark.parametrize('available_gpu_blocks', [1, 1024])
@pytest.mark.parametrize('available_cpu_blocks', [500])
@pytest.mark.parametrize('target_cache_block_size_bytes', [2 * 2 * 4096])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_profile_num_available_blocks(available_gpu_blocks: int,
                                      available_cpu_blocks: int,
                                      target_cache_block_size_bytes: int,
                                      draft_kv_size_bytes: int):
    """Check how SpecDecodeWorker profiles the number of available blocks.

    Profiling must be delegated once to the target (scorer) worker. The
    CPU block count is returned unchanged, while the GPU block count must
    equal what ``split_num_cache_blocks_evenly`` yields for the target and
    draft cache block sizes.
    """
    proposer = mock_worker(cls=MultiStepWorker)
    scorer = mock_worker()
    sampler = MagicMock(spec=RejectionSampler)
    sampler.token_id_dtype = torch.int64
    collector = MagicMock(spec=AsyncMetricsCollector)

    # Canned profiling result and per-worker cache block sizes.
    scorer.profile_num_available_blocks.return_value = (
        available_gpu_blocks, available_cpu_blocks)
    scorer.get_cache_block_size_bytes.return_value = (
        target_cache_block_size_bytes)
    proposer.get_cache_block_size_bytes.return_value = draft_kv_size_bytes

    spec_worker = SpecDecodeWorker(proposer, scorer, sampler, collector)

    # These values do not directly impact the adjusted block size
    # calculation, so they can be fixed.
    block_size = 16
    gpu_memory_utilization = 0.9
    cpu_swap_space = 100

    profiled = spec_worker.profile_num_available_blocks(
        block_size, gpu_memory_utilization, cpu_swap_space,
        cache_dtype="auto")
    num_gpu_blocks, num_cpu_blocks = profiled

    scorer.profile_num_available_blocks.assert_called_once_with(
        block_size, gpu_memory_utilization, cpu_swap_space, "auto")
    assert num_cpu_blocks == available_cpu_blocks

    assert num_gpu_blocks == split_num_cache_blocks_evenly(
        target_cache_block_size_bytes, draft_kv_size_bytes,
        available_gpu_blocks)


@pytest.mark.parametrize('available_gpu_blocks',
                         list(range(20)) + [1024, 1024**2])
@pytest.mark.parametrize('target_cache_block_size_bytes',
                         [2 * 2 * 4096, 2 * 2 * 8192])
@pytest.mark.parametrize('draft_kv_size_bytes', [0, 2 * 2 * 768, 2 * 2 * 4096])
@torch.inference_mode()
def test_split_num_cache_blocks_evenly(available_gpu_blocks: int,
                                       target_cache_block_size_bytes: int,
                                       draft_kv_size_bytes: int):
    """Check that the even split never exceeds the original byte budget.

    Giving the returned number of blocks to both the target and the draft
    cache must use no more bytes than ``available_gpu_blocks`` target
    blocks would.
    """
    num_blocks = split_num_cache_blocks_evenly(
        target_cache_block_size_bytes,
        draft_kv_size_bytes,
        available_gpu_blocks,
    )

    target_bytes = num_blocks * target_cache_block_size_bytes
    draft_bytes = num_blocks * draft_kv_size_bytes
    budget_bytes = available_gpu_blocks * target_cache_block_size_bytes

    assert target_bytes + draft_bytes <= budget_bytes
