from utils import extract_omni_answer
from abc import ABC, abstractmethod
from datetime import datetime
import json
import os


class BaseEvaluator(ABC):
    """Shared bookkeeping for benchmark evaluators.

    Appends per-sample judge records to a dated JSONL file, accumulates
    correctness counters per ``suffix`` and per sample id, and reports a
    majority-vote accuracy per suffix in ``save``.

    NOTE(review): ``self.lock`` and ``self.sampling_n`` are read by the
    methods below but never assigned in this class; presumably subclasses
    set them (a lock usable in ``with`` and the number of samples drawn per
    question) — confirm.
    """

    def __init__(self, args):
        """Set up the output directory and the dated JSONL output path.

        Args:
            args: Object providing ``benchmark`` and ``output_dir`` attributes.
        """
        self.args = args

        self.benchmark = self.args.benchmark

        self.output_dir = self.args.output_dir
        # exist_ok=True already tolerates an existing directory.
        os.makedirs(self.output_dir, exist_ok=True)
        time_str = datetime.now().strftime("%Y%m%d")
        self.output_file = os.path.join(self.output_dir, f"{self.benchmark}_{time_str}_v2.jsonl")

        # Layout: {suffix: {"correct"/"error"/"count"/"parse_error": int,
        #                   <sample id>: {the same four counters}}}
        self.eval_metrics = {}

    @staticmethod
    def _unified_test_enabled():
        """Return True when the ``TEST_UNIFIED`` env var holds a truthy value."""
        return os.environ.get("TEST_UNIFIED", "False").lower() in ["1", "true", "yes"]

    @staticmethod
    def _new_counters():
        """Return a fresh counter dict (used at suffix and sample-id level)."""
        return {"correct": 0, "error": 0, "count": 0, "parse_error": 0}

    @staticmethod
    def _unified_extract(response, answer):
        """Judge ``response`` for the unified test by searching for a fixed tag.

        Args:
            response: Model output text to search.
            answer: 0 when "Answer 1" is the better one, 1 when "Answer 2" is.

        Returns:
            ``(correct, pred)``; ``pred`` is a placeholder string (never None),
            so unified runs never register parse errors.

        Raises:
            ValueError: If ``answer`` is neither 0 nor 1.
        """
        if answer == 0:
            tag = "Answer 1 is better"
        elif answer == 1:
            tag = "Answer 2 is better"
        else:
            raise ValueError(f"answer is not 0 or 1, got {answer!r}")
        return tag in response, "dummy set for unified test"

    def json_handler(self, results):
        """Append one line per result to ``self.output_file`` under ``self.lock``.

        In unified-test mode each record is written as
        ``json.dumps(data.dumps())``; otherwise the string returned by
        ``data.criteria_and_judge_dumps()`` is written. Records whose
        ``criteria_and_judge_dumps`` call fails are skipped (best effort) with
        a printed warning; file-write errors propagate.

        Args:
            results: Iterable of result objects exposing the dump methods above.
        """
        unified = self._unified_test_enabled()
        with self.lock:
            with open(self.output_file, "a", encoding="utf-8") as fw:
                for data in results:
                    if unified:
                        fw.write(json.dumps(data.dumps()) + "\n")
                        continue
                    try:
                        line = data.criteria_and_judge_dumps() + "\n"
                    except Exception as exc:
                        print(f"⚠️ Skipping record that failed to serialize: {exc!r}")
                        continue
                    fw.write(line)

    def eval_hander(self, results):
        """Accumulate correctness counters for a batch of judged results.

        Each result updates counters at both the suffix level and the
        (suffix, sample id) level of ``self.eval_metrics``. A result whose
        evaluation raises is skipped with a printed warning; its id entry may
        already exist (with zero counts), in which case ``save`` still counts
        it in the denominator.

        Args:
            results: Iterable of result objects exposing
                ``paired_data.suffix``, ``paired_data.id``, an indexable
                ``llm_response`` (element 0 is judged) and ``answer``.
        """
        unified = self._unified_test_enabled()
        if unified:
            print("use unified extract answer func!!!")

        with self.lock:
            for r in results:
                try:
                    suffix = r.paired_data.suffix
                    _id = r.paired_data.id

                    suffix_metrics = self.eval_metrics.setdefault(suffix, self._new_counters())
                    # Keyed by sample id so repeated samples of the same
                    # question accumulate instead of being reset each time.
                    id_metrics = suffix_metrics.setdefault(_id, self._new_counters())

                    response = r.llm_response[0]
                    if unified:
                        correct, pred = self._unified_extract(response, r.answer)
                    else:
                        correct, pred = extract_omni_answer(response, r.answer)

                    for counters in (suffix_metrics, id_metrics):
                        counters["count"] += 1
                        counters["correct" if correct else "error"] += 1
                        # A None prediction is tallied as a parse error.
                        if pred is None:
                            counters["parse_error"] += 1
                except Exception as exc:
                    print(f"⚠️ Skipping result that could not be evaluated: {exc!r}")
                    continue

    def save(self):
        """Compute per-suffix majority-vote accuracy, append it, and print it.

        A sample id counts as correct when its ``correct`` counter exceeds
        ``self.sampling_n / 2``. Accuracy is correct ids over all ids of the
        suffix (0.0 when there are none). One JSON line with all suffixes is
        appended to ``self.output_file`` with its extension replaced by
        ``.metrics``.
        """
        new_eval_metrics = {}
        for suffix, values in self.eval_metrics.items():
            threshold = self.sampling_n / 2
            # Per-id entries are dicts; the suffix-level aggregates are ints.
            per_id = [v for v in values.values() if isinstance(v, dict)]
            count = len(per_id)
            correct = sum(1 for v in per_id if v["correct"] > threshold)
            acc = correct / count if count else 0.0

            new_eval_metrics[suffix] = {
                # Key spelling kept as-is for downstream readers of the file.
                "accury": acc,
                "correct": correct,
                "count": count,
                "error": values["error"],
                "parse_error": values["parse_error"],
            }
            print(f"🎯 The {suffix}  Accuracy = {acc}")

        # splitext only touches the file extension, never directory names.
        metrics_file = os.path.splitext(self.output_file)[0] + ".metrics"
        with open(metrics_file, "a", encoding="utf-8") as fw:
            fw.write(json.dumps(new_eval_metrics, ensure_ascii=False) + "\n")
        print(f"📊 The Final Evaluation Results: {new_eval_metrics}")
