from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import os

# Runtime settings; each one can be overridden through an environment variable.
# Passed to vllm.LLM(tensor_parallel_size=...) in init_vllm.
tensor_parallel_size = int(os.getenv("TENSOR_PARALLEL_SIZE", 4))
# Passed to SamplingParams(max_tokens=...) for every generate call.
max_tokens = int(os.getenv("MAX_TOKENS", 8192))
# Passed to SamplingParams(temperature=...) for every generate call.
temperature = float(os.getenv("TEMPERATURE", 0.0))

# Process-wide singletons; both stay None until init_vllm() populates them.
llm = None
tokenizer = None

def init_vllm(model_path):
    """Create the process-wide tokenizer and vLLM engine on first call.

    Later calls return immediately and ignore ``model_path``: whichever model
    the first successful call loaded stays in use.

    Args:
        model_path: Model identifier or path, passed both to
            ``AutoTokenizer.from_pretrained`` and to ``vllm.LLM(model=...)``.
    """
    global llm, tokenizer
    if llm is not None:
        # Already initialised; reuse the existing engine.
        return
    # Build both objects before publishing either. Otherwise a failure inside
    # LLM(...) would leave a loaded tokenizer paired with a missing engine.
    new_tokenizer = AutoTokenizer.from_pretrained(model_path)
    new_llm = LLM(model=model_path, tensor_parallel_size=tensor_parallel_size)
    tokenizer, llm = new_tokenizer, new_llm
    print("VLLM model initialized.")
    
def get_response_list(messages, enable_thinking=False):
    """Generate completions for chat ``messages`` with the shared vLLM engine.

    Args:
        messages: Passed unchanged to ``tokenizer.apply_chat_template``.
            This is presumably a batch (a list of conversations, each a list
            of role/content dicts). TODO: confirm against callers.
        enable_thinking: Forwarded to the chat template as ``enable_thinking``.

    Returns:
        A list with the text of the first candidate of each request returned
        by ``llm.generate``.

    Raises:
        RuntimeError: If ``init_vllm`` has not completed successfully.
    """
    # Fail with a clear message instead of an AttributeError on None.
    if llm is None or tokenizer is None:
        raise RuntimeError(
            "vLLM is not initialized; call init_vllm(model_path) first."
        )

    # Render the prompts as text (tokenize=False); vLLM tokenizes them itself.
    texts = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=enable_thinking,
    )

    sampling_params = SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
    )
    outputs = llm.generate(texts, sampling_params=sampling_params)

    # Each RequestOutput may hold several candidates; keep only the first.
    return [output.outputs[0].text for output in outputs]