import os

from vllm import LLM, SamplingParams

class vLLM:
    """Thin convenience wrapper around ``vllm.LLM`` for batch text generation.

    The wrapped engine is stored on ``self.llm`` and the model's short name on
    ``self.llm_name``.
    """

    def __init__(self, llm_name, tensor_parallel_size=1, max_model_len=43152,
                 model_root="../../models"):
        """Load the model found at ``{model_root}/{llm_name}``.

        Args:
            llm_name: Name of the model directory below ``model_root``.
            tensor_parallel_size: Forwarded to ``vllm.LLM`` unchanged.
                Defaults to 1 so the wrapper can be built with just a model
                name.
            max_model_len: Forwarded to ``vllm.LLM`` unchanged. Defaults to
                43152, the value that used to be hard-coded here.
            model_root: Directory prefix joined with ``llm_name`` to form the
                model path. The default is a relative path, so it depends on
                the working directory the process was started from.
        """
        self.llm = LLM(
            f"{model_root}/{llm_name}",
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
        )
        self.llm_name = llm_name

    def generate(self, prompts, temp, max_length):
        """Generate text for ``prompts`` and return the resulting strings.

        Args:
            prompts: Prompt input passed straight to ``self.llm.generate``.
            temp: Sampling temperature (``SamplingParams.temperature``).
            max_length: Cap on generated tokens (``SamplingParams.max_tokens``).

        Returns:
            A list with one string per item returned by the engine: the text
            of that item's first candidate completion.
        """
        sampling_params = SamplingParams(temperature=temp, max_tokens=max_length)
        outputs = self.llm.generate(prompts, sampling_params)
        # Only the first candidate of each request is kept.
        return [output.outputs[0].text for output in outputs]

if __name__ == '__main__':
    # Smoke test: load a local model and complete a few short prompts.

    # GPUs exposed to this process; the tensor-parallel degree below is
    # derived from the same list so the two settings cannot drift apart.
    visible_devices = '0,1'
    os.environ['CUDA_VISIBLE_DEVICES'] = visible_devices

    # The wrapper's constructor takes the tensor-parallel size as its second
    # argument; the original call omitted it and failed with a TypeError.
    llm = vLLM(
        'llama-3.1-8b-instruct',
        tensor_parallel_size=len(visible_devices.split(',')),
    )
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    # Very low temperature and a 50-token cap per completion.
    response = llm.generate(prompts, 1e-5, 50)
    print(response)