import os

from llms.vllms import vLLM
from utils_proof.file import *

class LLM:
    """Wrapper around a ``vLLM`` backend with optional JSON-cached inference.

    ``inference`` can persist predictions to a JSON file and, on a later call,
    reuse the predictions already stored there, generating only for the
    prompts that are not yet covered.
    """

    def __init__(self, llm_name) -> None:
        """Create the backend for *llm_name*.

        Args:
            llm_name: Model identifier, forwarded unchanged to ``vLLM``.
        """
        # Model identifier, kept for reference.
        self.llm_name = llm_name
        # Backend object whose ``generate`` does the actual work.
        self.llm = vLLM(llm_name)

    def generate(self, prompts, temp=1e-5, max_length=1000, max_retry=20):
        """Generate completions for *prompts* using the backend.

        Args:
            prompts: Prompts forwarded to the backend's ``generate``.
            temp: Temperature forwarded to the backend.
            max_length: Maximum generation length forwarded to the backend.
            max_retry: Currently unused; kept for interface compatibility.

        Returns:
            Whatever the backend's ``generate`` returns.
        """
        # Keyword arguments, matching how ``inference`` calls the backend.
        return self.llm.generate(prompts, temp=temp, max_length=max_length)

    def inference(self, prompts, output_path=None, temp=1e-5, max_length=1000, batch_size=50):
        """Run generation over *prompts*, optionally caching results in a JSON file.

        If *output_path* is given and already exists, its contents are treated
        as predictions for a prefix of *prompts* (in order). Only the remaining
        prompts are sent to the backend. The combined list is then written back
        to *output_path*. With ``output_path=None`` no file is read or written.

        Args:
            prompts: Sequence of prompts; must support ``len`` and slicing.
            output_path: Path of the JSON cache file, or ``None`` to disable caching.
            temp: Temperature forwarded to the backend.
            max_length: Maximum generation length forwarded to the backend.
            batch_size: Currently unused; kept for interface compatibility.

        Returns:
            List of cached predictions followed by newly generated ones
            (presumably one per prompt, assuming the backend returns one
            output per input; verify against ``vLLM.generate``).

        Raises:
            ValueError: If the cache file holds more predictions than there
                are prompts, so it cannot belong to this prompt list.
        """
        pred_list = []
        if output_path is not None and os.path.exists(output_path):
            cached = load_json(output_path)
            n_cached, n_prompts = len(cached), len(prompts)
            if n_cached > n_prompts:
                raise ValueError(
                    f"{output_path} holds {n_cached} predictions but only "
                    f"{n_prompts} prompts were given"
                )
            if n_cached == n_prompts:
                # Everything was already generated; return it rather than None
                # so every path of this method returns the prediction list.
                return cached
            pred_list += cached
            prompts = prompts[n_cached:]

        # Skip the backend call entirely when nothing is left to generate.
        if len(prompts) > 0:
            pred_list += self.llm.generate(prompts, temp=temp, max_length=max_length)

        if output_path is not None:
            write_json(output_path, pred_list)

        return pred_list