from typing import Optional

from loguru import logger
from vllm import LLM, SamplingParams


def generate_completions(
    model_dir: str,
    prompts: list[str],
    max_completion_length: int = 256,
    stop_words: Optional[list[str]] = None,
    temperature: float = 1.0,
    gpu_memory_utilization: float = 0.5,
):
    """Generate a sampled completion for each prompt with vLLM.

    Loads the model at ``model_dir`` into a fresh ``vllm.LLM`` engine and runs
    batched generation over ``prompts`` with a single ``SamplingParams``.

    Args:
        model_dir: Model path or identifier passed to ``vllm.LLM``.
        prompts: Prompt strings to complete.
        max_completion_length: Maximum number of generated tokens per prompt
            (``SamplingParams.max_tokens``).
        stop_words: Strings that end generation when produced
            (``SamplingParams.stop``). ``None`` means no stop strings.
        temperature: Sampling temperature (``SamplingParams.temperature``).
        gpu_memory_utilization: Fraction of GPU memory the vLLM engine may
            reserve; defaults to 0.5, the value previously hard-coded.

    Returns:
        A list with one entry per result from ``llm.generate``: the token IDs
        of that result's first output (token IDs, not decoded text).
        Returns an empty list when ``prompts`` is empty, without loading
        the model.
    """
    if not prompts:
        # Nothing to generate: skip the expensive model load and avoid the
        # IndexError that ``prompts[0]`` below would raise.
        return []

    logger.info(f"Number of prompts: {len(prompts)}")
    logger.info(f"Example prompt: {prompts[0]}")

    llm = LLM(model_dir, gpu_memory_utilization=gpu_memory_utilization)
    sampling_params = SamplingParams(
        max_tokens=max_completion_length,
        temperature=temperature,
        # Pass a fresh list so the caller's list (or a shared default) is
        # never aliased inside the sampling params.
        stop=list(stop_words) if stop_words is not None else [],
    )
    results = llm.generate(prompts, sampling_params)
    return [res.outputs[0].token_ids for res in results]
