# SPDX-License-Identifier: Apache-2.0

import pytest
import torch

from vllm import SamplingParams

# Model ids the test below is parametrized over.
MODELS = [
    "distilbert/distilgpt2",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
def test_logits_processor_force_generate(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
) -> None:
    """Check that a per-request logits processor fully controls generation.

    A processor that cycles through the token ids of ``" vLLM"`` is attached
    to the first request, so that request's decoded text must equal the
    string repeated ``repeat_times`` times. The same request also asks for
    prompt logprobs. It is enqueued together with a prompt-logprobs-only
    request and a plain request, and all three are run in a single engine
    call, so the configurations are exercised side by side.
    """
    repeat_times = 2
    enforced_answers = " vLLM"

    with vllm_runner(model, dtype=dtype) as vllm_model:
        llm = vllm_model.model
        tokenizer = llm.get_tokenizer()
        forced_ids = tokenizer.encode(enforced_answers,
                                      add_special_tokens=False)
        max_tokens = repeat_times * len(forced_ids)

        def force_next_token(output_ids, logits):
            # Choose the next forced id from the number of tokens generated
            # so far. Then push that id's logit to the largest finite value
            # the logits' dtype can hold, so it wins over every other token.
            position = len(output_ids) % len(forced_ids)
            logits[forced_ids[position]] = torch.finfo(logits.dtype).max
            return logits

        request_params = [
            # logits_processors combined with prompt_logprobs
            SamplingParams(
                logits_processors=[force_next_token],
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
            # prompt_logprobs on its own
            SamplingParams(
                prompt_logprobs=3,
                max_tokens=max_tokens,
            ),
            # plain request, batched with the two above
            SamplingParams(max_tokens=max_tokens),
        ]
        for idx, params in enumerate(request_params):
            llm._add_request(example_prompts[idx], params=params)

        outputs = llm._run_engine(use_tqdm=False)

        expected_text = enforced_answers * repeat_times
        assert outputs[0].outputs[0].text == expected_text
