import os
import torch
import pdb

from vllm import LLM, SamplingParams

class vLLM:
    """Thin wrapper around a vLLM ``LLM`` engine for batch text generation.

    The engine is built once in ``__init__`` and reused by every
    ``generate`` call.

    Attributes:
        llm: The underlying ``vllm.LLM`` instance.
        llm_name: The model name passed to the constructor.
    """

    def __init__(
        self,
        llm_name: str,
        *,
        model_dir: str = "../../models",
        max_model_len: int = 43152,
        tensor_parallel_size=None,
    ):
        """Build the engine for the model at ``{model_dir}/{llm_name}``.

        Args:
            llm_name: Name of the model sub-directory under ``model_dir``.
            model_dir: Directory that contains the model folders. The default
                is a relative path (NOTE(review): presumably resolved against
                the process working directory -- confirm for your launcher).
            max_model_len: Forwarded to ``LLM(max_model_len=...)``.
            tensor_parallel_size: Forwarded to
                ``LLM(tensor_parallel_size=...)``. When ``None`` (the
                default), the number of CUDA devices visible to torch is
                used, with a floor of 1. Pass an explicit value when the
                visible GPU count is not a size the model should be sharded
                across.
        """
        if tensor_parallel_size is None:
            # device_count() is 0 when no GPU is visible; never ask the
            # engine for zero shards.
            tensor_parallel_size = max(1, torch.cuda.device_count())

        self.llm = LLM(
            f"{model_dir}/{llm_name}",
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
        )
        self.llm_name = llm_name

    def generate(self, prompts, temp: float, max_length: int) -> list:
        """Generate one completion per prompt.

        Args:
            prompts: Prompts handed unchanged to ``LLM.generate``.
            temp: Sampling temperature (``SamplingParams(temperature=...)``).
            max_length: Cap on generated tokens
                (``SamplingParams(max_tokens=...)``).

        Returns:
            A list holding the text of the first candidate of each request
            output, in the order the engine returned them.
        """
        sampling_params = SamplingParams(temperature=temp, max_tokens=max_length)
        outputs = self.llm.generate(prompts, sampling_params)
        # Each request output may carry several candidates; only the first
        # one is kept.
        return [output.outputs[0].text for output in outputs]
    
if __name__ == '__main__':
    # Smoke test: load one model and print completions for a few prompts.

    # Default to GPU 1, but respect a device selection that the user or a
    # job scheduler has already exported. This is done before the model is
    # constructed because the constructor queries the visible GPU count.
    os.environ.setdefault('CUDA_VISIBLE_DEVICES', '1')

    llm = vLLM('codeqwen1.5-7b-chat')
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Near-zero temperature (presumably to keep the output close to
    # deterministic) and a short 50-token cap per prompt.
    temperature = 1e-5
    max_new_tokens = 50
    response = llm.generate(prompts, temp=temperature, max_length=max_new_tokens)
    print(response)