from concurrent.futures import ThreadPoolExecutor
from typing import List

from tqdm import tqdm

from src.evaluator.metric import (
    AVAILABLE_METRICS,
    metric_factory,
)
from src.schema import (
    EvaluationReport,
    InferenceResult,
)


class BenchmarkEvaluator:
    """Run a configurable set of metrics over inference results.

    One metric object is created per requested name through ``metric_factory``.
    ``eval`` resets those objects, feeds every inference result to each of
    them and packs the computed values into an ``EvaluationReport``.
    """

    def __init__(self, metric_list: List[str]) -> None:
        """Create one metric object per name in ``metric_list``.

        Args:
            metric_list: Names of the metrics to compute; each must be a
                member of ``AVAILABLE_METRICS``.

        Raises:
            ValueError: If any name is not in ``AVAILABLE_METRICS``.
        """
        # Raise explicitly instead of asserting: assert statements are removed
        # when Python runs with -O, which would let bad names through.
        unknown = [name for name in metric_list if name not in AVAILABLE_METRICS]
        if unknown:
            raise ValueError(f"Invalid metric in {metric_list}: {unknown}")

        self.metric_dict = {metric_name: metric_factory(metric_name=metric_name) for metric_name in metric_list}

    @staticmethod
    def _normalize_element(element: str) -> str:
        """Return the lower-cased, trimmed text of ``element`` without its bracket tag.

        The text after the last ``]`` is used when it is not blank
        (``"[1] Paris"`` -> ``"paris"``). Otherwise every ``[`` and ``]`` is
        dropped from the whole string (``"[2]"`` -> ``"2"``).
        """
        tail = element.split("]")[-1]
        if tail.strip():
            return tail.lower().strip()
        return element.replace("[", "").replace("]", "").lower().strip()

    def _reset(self) -> None:
        """Reset every configured metric."""
        for metric in self.metric_dict.values():
            metric.reset()

    def _update(self, processed_res: InferenceResult) -> None:
        """Normalize the answer elements of ``processed_res`` and update every metric.

        The ``elements`` attribute of each predicted answer set and of each
        answer set under ``processed_res.qa`` is replaced in place by a set of
        normalized strings, so ``processed_res`` itself is modified.
        """
        # AnswerSeparationService sometimes produces elements with wrong capitalization.
        # FIXME improve AnswerSeparationService to produce the exact elements
        normalize = self._normalize_element
        for pred_answer_set in processed_res.pred_answer_sets:
            pred_answer_set.elements = {normalize(s) for s in pred_answer_set.elements}
        # Apply the same rule to the reference side so both are compared alike.
        for answer_set in processed_res.qa.answer_sets:
            answer_set.elements = {normalize(s) for s in answer_set.elements}

        for metric in self.metric_dict.values():
            metric.update(processed_res)

    def _compute(self) -> EvaluationReport:
        """Compute every metric and build the report, one keyword per metric name."""
        eval_report_dict = {metric_name: metric.compute() for metric_name, metric in self.metric_dict.items()}
        return EvaluationReport(**eval_report_dict)

    def eval(self, inference_results: List[InferenceResult], num_workers: int = 8) -> EvaluationReport:
        """Evaluate ``inference_results`` with all configured metrics.

        Args:
            inference_results: Results to evaluate; their answer-set elements
                are normalized in place (see ``_update``).
            num_workers: Maximum number of threads in the pool that runs
                ``_update``.

        Returns:
            The report built from the value computed by each metric.
        """
        self._reset()
        # NOTE(review): _update runs on pool threads, so metric.update() may be
        # called concurrently -- confirm the metric implementations allow that.
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            progress = tqdm(
                pool.map(self._update, inference_results),
                total=len(inference_results),
            )
            # Draining the iterator advances the progress bar and re-raises
            # any exception raised inside a worker.
            for _ in progress:
                pass
        return self._compute()
