import json
import threading
import time

from dataloader.bench_dataloader import BenchmarkDataLoader
from prompt.omni_prompt import extract_omni_answer

from .evaluator import BaseEvaluator


class AudioEvaluator(BaseEvaluator):
    """Evaluator for the ``"audio"`` and ``"omni"`` modalities.

    Streams benchmark samples from ``BenchmarkDataLoader`` through a router
    and appends every batch of results to ``self.output_file`` as JSON Lines.
    """

    def __init__(self, args):
        """Initialise the evaluator.

        Args:
            args: Run configuration, forwarded to ``BaseEvaluator`` and to
                ``BenchmarkDataLoader``.

        Raises:
            ValueError: If ``self.modality`` (populated by
                ``BaseEvaluator.__init__``, presumably from ``args``) is
                neither ``"audio"`` nor ``"omni"``.
        """
        super().__init__(args)
        # Explicit check instead of ``assert`` so the guard still runs
        # under ``python -O``.
        if self.modality not in ("audio", "omni"):
            raise ValueError(
                "AudioEvaluator requires modality 'audio' or 'omni', "
                f"got {self.modality!r}"
            )

        self.bench_loader = BenchmarkDataLoader(self.args)
        # Serialises appends to the shared output file. NOTE(review): this
        # assumes the router may invoke handlers from several threads.
        self.lock = threading.Lock()

    def json_handler(self, results):
        """Append ``results`` to ``self.output_file`` as JSON Lines.

        Args:
            results: Iterable of JSON-serialisable records. Each record is
                written on its own line, with non-ASCII characters preserved.
        """
        # Serialise outside the lock. This keeps the critical section to the
        # file write only. If any record fails to serialise, the batch is
        # aborted before anything is written, so no partial batch appears.
        lines = [json.dumps(r, ensure_ascii=False) + "\n" for r in results]
        with self.lock:
            with open(self.output_file, "a", encoding="utf-8") as fw:
                fw.writelines(lines)

    def eval(self, router):
        """Run the benchmark through ``router`` and save the results.

        Registers ``self.eval_hander`` and ``self.json_handler`` on the
        router, initialises its vLLM backend, feeds it the benchmark
        iterator, prints the elapsed time, then calls ``self.save()``.

        Args:
            router: Object exposing ``add_handler``, ``init_vllm`` and
                ``run``.
        """
        benchmark_dataset_iter = self.bench_loader.load_iter(
            self.benchmark, self.modality
        )

        # NOTE(review): ``eval_hander`` (sic) is not defined in this class;
        # presumably it is inherited from BaseEvaluator. Keep the spelling in
        # sync with its definition.
        router.add_handler(self.eval_hander)
        router.add_handler(self.json_handler)

        # perf_counter is monotonic, so the reading is unaffected by system
        # clock changes during long runs. The measured span covers
        # init_vllm() as well as run().
        start_time = time.perf_counter()
        router.init_vllm()
        router.run(benchmark_dataset_iter)
        elapsed = time.perf_counter() - start_time

        print(f"⏱️ AudioEvaluator router.run() Spend Time: {elapsed:.2f} 秒")
        self.save()