import importlib
import os
import time

from tqdm import tqdm

from llms.vllms import vLLM
from utils.file import *

class Evaluator:
    """Run LLM inference for one (task, llm, method) setup and score the outputs.

    Prompts are read from ``results/<task>/<llm>/<method>.json`` and model
    predictions are written to ``preds/<task>/<llm>/<method>.json``.
    """

    def __init__(self, args) -> None:
        """Store the run configuration and, if requested, load the model.

        Args:
            args: Parsed CLI namespace. Reads ``llm``, ``task``, ``method`` and
                ``evaluate``. When ``evaluate`` is truthy it also reads
                ``gpus``, a comma-separated GPU list whose entry count becomes
                the vLLM tensor-parallel size.

        Note:
            ``self.llm`` is only created when ``args.evaluate`` is truthy;
            ``inference`` needs it unless all predictions already exist.
        """
        self.llm_name = args.llm
        self.task_name = args.task
        self.method = args.method
        if args.evaluate:
            # One tensor-parallel shard per listed GPU, e.g. "0,1" -> 2.
            tensor_parallel_size = len(args.gpus.split(','))
            self.llm = vLLM(self.llm_name, tensor_parallel_size)
        self.input_path = f"results/{self.task_name}/{self.llm_name}/{self.method}.json"
        self.output_path = f"preds/{self.task_name}/{self.llm_name}/{self.method}.json"

    def inference(self, temp, max_length=1000):
        """Generate predictions for all prompts, resuming from a partial file.

        If ``self.output_path`` already holds as many predictions as there are
        prompts, this returns immediately. If it holds fewer, only the remaining
        prompts are sent to the model. The new outputs are appended after the
        existing ones, and the whole list is rewritten.

        Args:
            temp: Sampling temperature, forwarded to ``self.llm.generate``.
            max_length: Forwarded to ``self.llm.generate`` as ``max_length``.
                Defaults to 1000, the previously hard-coded value.

        Raises:
            ValueError: If the existing prediction file has more entries than
                there are prompts.
        """
        prompts, _ = read_prompt(self.input_path)
        pred_list = []
        if os.path.exists(self.output_path):
            preds = read_list_from_json(self.output_path)
            len_input, len_output = len(prompts), len(preds)
            if len_input == len_output:
                return  # every prompt already has a prediction
            if len_input < len_output:
                raise ValueError("len_input < len_output")
            # Resume: assumes the saved predictions correspond, in order, to
            # the first len_output prompts.
            pred_list += preds
            prompts = prompts[len_output:]
        response = self.llm.generate(prompts, temp=temp, max_length=max_length)
        pred_list += response
        # Make sure preds/<task>/<llm>/ exists before writing; a no-op if it
        # is already there.
        out_dir = os.path.dirname(self.output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        write_list_to_json(self.output_path, pred_list)

    def evaluate(self):
        """Score saved predictions with the task-specific evaluation function.

        Imports ``utils.evaluate.<task_name>`` and calls its
        ``evaluate(preds, golds)``. ``golds`` is read from ``self.input_path``
        and ``preds`` from ``self.output_path``.

        Returns:
            Whatever the task module's ``evaluate`` returns. The original
            implementation discarded it and returned ``None``.
        """
        golds = read_list_from_json(self.input_path)
        preds = read_list_from_json(self.output_path)
        # importlib instead of exec(): a name imported by exec() inside a
        # function lands in a locals() snapshot. Since Python 3.13 (PEP 667)
        # a later exec() cannot see it. Building exec() source from CLI input
        # could also run arbitrary code.
        task_module = importlib.import_module(f"utils.evaluate.{self.task_name}")
        return task_module.evaluate(preds, golds)