from typing import List
from collections import Counter
from concurrent.futures import ThreadPoolExecutor

from tqdm import tqdm

from src.evaluator.rubric_evaluator import RubricEvaluator
from src.evaluator.metric import (
    AVAILABLE_METRICS,
    metric_factory,
)
from src.schema import (
    PredictedAnswer,
    Prediction,
    EvaluationReport,
    EvaluationReportGroup,
    Domain,
)


class BenchmarkEvaluator:
    """Evaluate predictions with a rubric evaluator and aggregate metric reports.

    For every prediction, the rubric evaluator is run on each predicted answer,
    every configured metric is updated, and the resulting scores and evaluated
    predictions are accumulated both overall and per domain (medical, legal,
    casual). ``eval`` is the public entry point.
    """

    def __init__(
        self,
        rubric_evaluator: RubricEvaluator,
        metric_list: List[str],
    ) -> None:
        """Create the metrics named in ``metric_list`` and empty accumulators.

        Args:
            rubric_evaluator: Evaluator whose ``eval(raw_answer=..., answers=...)``
                result is stored on each predicted answer.
            metric_list: Metric names; each must be in ``AVAILABLE_METRICS``.
        """
        self.rubric_evaluator = rubric_evaluator

        assert all(metric in AVAILABLE_METRICS for metric in metric_list), f"Invalid metric in {metric_list}"
        self.metric_dict = {metric_name: metric_factory(metric_name=metric_name) for metric_name in metric_list}

        self._clear_accumulators()

    def _clear_accumulators(self) -> None:
        """Replace every score and prediction accumulator with a fresh empty one."""
        self.metric_scores = {metric_name: [] for metric_name in self.metric_dict}
        self.medical_metric_scores = {metric_name: [] for metric_name in self.metric_dict}
        self.legal_metric_scores = {metric_name: [] for metric_name in self.metric_dict}
        self.casual_metric_scores = {metric_name: [] for metric_name in self.metric_dict}

        self.eval_predictions = []
        self.medical_eval_predictions = []
        self.legal_eval_predictions = []
        self.casual_eval_predictions = []

    def _reset(self) -> None:
        """Reset all metrics and drop everything accumulated by earlier runs.

        The per-domain accumulators are cleared as well as the overall ones;
        otherwise a second ``eval`` call would mix stale scores and predictions
        into the per-domain reports.
        """
        for metric in self.metric_dict.values():
            metric.reset()

        self._clear_accumulators()

    def _domain_accumulators(self, prediction: Prediction) -> tuple:
        """Return ``(scores_by_metric, predictions)`` for the prediction's domain.

        Raises:
            ValueError: If the domain is not MEDICAL, LEGAL or CASUAL.
        """
        domain = prediction.qa.domain
        if domain == Domain.MEDICAL:
            return self.medical_metric_scores, self.medical_eval_predictions
        if domain == Domain.LEGAL:
            return self.legal_metric_scores, self.legal_eval_predictions
        if domain == Domain.CASUAL:
            return self.casual_metric_scores, self.casual_eval_predictions
        raise ValueError(f"Unknown domain {prediction.qa.domain} in prediction {prediction.qa.qa_id}")

    def _evaluate_rubrics(self, prediction: Prediction) -> Prediction:
        """Return a copy of ``prediction`` with rubric results on every answer.

        Each predicted answer is rebuilt field by field, with
        ``rubric_question_answer_dict`` set to the rubric evaluator's output for
        that answer. No evaluator state is modified here, so ``eval`` can run
        this step on worker threads.
        """
        reference_answers = prediction.qa.answers
        evaluated_answers = [
            PredictedAnswer(
                answer_id=pred_answer.answer_id,
                ret_document_ids=pred_answer.ret_document_ids,
                ret_document_contents=pred_answer.ret_document_contents,
                ret_document_scores=pred_answer.ret_document_scores,
                ret_memory_ids=pred_answer.ret_memory_ids,
                ret_memory_contents=pred_answer.ret_memory_contents,
                ret_memory_scores=pred_answer.ret_memory_scores,
                raw_answer=pred_answer.raw_answer,
                rubric_question_answer_dict=self.rubric_evaluator.eval(
                    raw_answer=pred_answer.raw_answer,
                    answers=reference_answers,
                ),
                metadata=pred_answer.metadata,
            )
            for pred_answer in prediction.pred_answers
        ]
        return Prediction(qa=prediction.qa, pred_answers=evaluated_answers)

    def _accumulate(self, prediction: Prediction, eval_prediction: Prediction) -> None:
        """Feed ``eval_prediction`` to every metric and record the outcome.

        The domain is resolved before any state is touched, so an unknown
        domain raises without leaving partially updated metrics or score lists.

        Raises:
            ValueError: If the prediction's domain is unknown.
        """
        domain_scores, domain_predictions = self._domain_accumulators(prediction)

        for metric_name, metric in self.metric_dict.items():
            # Each metric returns the prediction (threaded on to the next
            # metric) together with the scores it produced for it.
            eval_prediction, scores = metric.update(eval_prediction)
            self.metric_scores[metric_name].extend(scores)
            domain_scores[metric_name].extend(scores)

        self.eval_predictions.append(eval_prediction)
        domain_predictions.append(eval_prediction)

    def _update(self, prediction: Prediction) -> None:
        """Rubric-evaluate one prediction and fold it into the accumulators.

        Raises:
            ValueError: If the prediction's domain is unknown.
        """
        self._accumulate(prediction, self._evaluate_rubrics(prediction))

    @staticmethod
    def _mean_scores(scores_by_metric: dict) -> dict:
        """Return the mean score per metric, using 0 for metrics with no scores."""
        return {
            metric_name: sum(scores) / len(scores) if scores else 0
            for metric_name, scores in scores_by_metric.items()
        }

    @staticmethod
    def _build_report(predictions: list, metric_values: dict) -> EvaluationReport:
        """Build a report from predictions, metric values and failure-mode counts.

        ``inst_failure_mode`` is read from every predicted answer and
        ``dist_failure_mode`` from every prediction (presumably populated by
        the metrics during ``update`` -- confirm against the metric classes).
        """
        return EvaluationReport(
            predictions=predictions,
            **metric_values,
            inst_failure_mode_counts=dict(Counter(pa.inst_failure_mode for p in predictions for pa in p.pred_answers)),
            dist_failure_mode_counts=dict(Counter(p.dist_failure_mode for p in predictions)),
        )

    def _compute(self) -> EvaluationReportGroup:
        """Assemble the overall report and the three per-domain reports.

        The overall report uses each metric's own ``compute()`` value; the
        per-domain reports use the mean of the scores accumulated for that
        domain.
        """
        metric_value_dict = {metric_name: metric.compute() for metric_name, metric in self.metric_dict.items()}
        return EvaluationReportGroup(
            total_eval_report=self._build_report(self.eval_predictions, metric_value_dict),
            medical_eval_report=self._build_report(
                self.medical_eval_predictions,
                self._mean_scores(self.medical_metric_scores),
            ),
            legal_eval_report=self._build_report(
                self.legal_eval_predictions,
                self._mean_scores(self.legal_metric_scores),
            ),
            casual_eval_report=self._build_report(
                self.casual_eval_predictions,
                self._mean_scores(self.casual_metric_scores),
            ),
        )

    def eval(self, predictions: List[Prediction], num_workers: int = 8) -> EvaluationReportGroup:
        """Evaluate ``predictions`` from a clean state and return the report group.

        Args:
            predictions: Predictions to evaluate.
            num_workers: Maximum number of threads used for rubric evaluation.

        Returns:
            The overall report together with the per-domain reports.

        Raises:
            ValueError: If a prediction has an unknown domain.
        """
        self._reset()
        with ThreadPoolExecutor(max_workers=num_workers) as pool:
            # Only the rubric evaluation is fanned out to threads. The metrics
            # and accumulator lists are shared, so they are updated here on the
            # calling thread; Executor.map yields results in input order, which
            # keeps the report order deterministic.
            evaluated = pool.map(self._evaluate_rubrics, predictions)
            for prediction, eval_prediction in tqdm(zip(predictions, evaluated), total=len(predictions)):
                self._accumulate(prediction, eval_prediction)
        return self._compute()
