# https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py
from vllm import LLM, SamplingParams
from transformers.generation.logits_process import LogitsProcessor
import torch
# Prompts passed to `llm.generate`; one result is printed per prompt.
prompts = list(
    (
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    )
)
class SuppressTokensLogitsProcessorText(LogitsProcessor):
    """Suppress every token id at or above ``start_index``.

    The processor overwrites ``scores[..., start_index:]`` in place with the
    most negative finite value of ``torch_dtype`` and returns the same tensor.
    Indexing the last axis means this works for a 1-D ``(vocab_size,)`` tensor
    (vLLM's per-request processors) and for a batched ``(batch_size, vocab_size)``
    tensor (Hugging Face's ``LogitsProcessor`` contract).

    Args:
        start_index: First vocabulary id to suppress. Ids ``>= start_index``
            are masked; if it is past the end of the vocabulary, nothing is masked.
        torch_dtype: Dtype whose ``finfo(...).min`` is used as the mask value.
    """

    def __init__(self, start_index, torch_dtype):
        self.start_index = start_index
        # Use the most negative *finite* value instead of -inf. This is a common
        # guard against NaNs in softmax if an entire row ends up masked.
        self.min = torch.finfo(torch_dtype).min

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """Mask the suppressed vocabulary ids in ``scores`` (in place) and return it.

        ``input_ids`` is accepted for interface compatibility and is unused.
        """
        # `...` pins the slice to the last (vocabulary) axis. A plain
        # `scores[self.start_index:]` would slice the *batch* axis of 2-D scores.
        scores[..., self.start_index:] = self.min
        return scores
# Dtype shared by the logits processor's mask value and the model weights, so
# the two cannot drift apart.
DTYPE = torch.bfloat16

# Masks every token id >= 32000. NOTE(review): 32000 is presumably the base
# Llama-2 vocabulary size, so this suppresses the ids added by the custom
# tokenizer. Confirm against the tokenizer. Only used by the commented-out
# SamplingParams variant below.
logits_processors = [SuppressTokensLogitsProcessorText(start_index=32000, torch_dtype=DTYPE)]

llm = LLM(
    # Alternative fine-tuned checkpoint; swap with the base model path below to use it.
    # model='YOUR_ROOT_PATH/model/checkpoint/MLLM/ablation_data_OIv3_full_512_5e-5_2_custom_resume/best_3746_ppl_644.861',
    model='YOUR_ROOT_PATH/model/llama2-1229/Llama-2-7b-hf',
    tokenizer='YOUR_ROOT_PATH/model/checkpoint/MLLM/tokenizer',
    tokenizer_mode="slow",
    trust_remote_code=True,
    tensor_parallel_size=1,
    dtype=DTYPE,
)

# Variant that samples with the token-suppressing processor defined above:
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, logits_processors=logits_processors)
# Generate a single token and request prompt log-probs (prompt_logprobs=0 adds
# no extra top-k alternatives beyond the actual prompt tokens).
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, prompt_logprobs=0, max_tokens=1)
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    # Report the prompt log-probs requested above; previously they were computed
    # and discarded. NOTE(review): the exact structure of `prompt_logprobs`
    # depends on the installed vLLM version. Confirm before parsing it.
    print(f"Prompt logprobs: {output.prompt_logprobs}")