import fire

from loguru import logger

from src.evaluator.benchmark_evaluator import BenchmarkEvaluator
from src.evaluator.metric import (
    AVAILABLE_GENERATION_METRICS,
    AVAILABLE_METRICS,
    AVAILABLE_RETREIVAL_METRICS,
)
from src.schema import InferenceResult
from src.utils.json import (
    read_json_file,
    write_json_file,
)
from src.utils.log import set_log_level


def eval(
    inference_results_path: str,
    evaluation_report_path: str,
    infer_mode: str = "both",
    num_workers: int = 8,
) -> None:
    """Evaluate saved inference results and write the evaluation report.

    Args:
        inference_results_path: Path passed to ``read_json_file``; the loaded
            data is iterated and each entry is unpacked as keyword arguments
            into ``InferenceResult``.
        evaluation_report_path: Path passed to ``write_json_file`` for the
            dumped report.
        infer_mode: Which metric set to evaluate with: ``"both"``
            (``AVAILABLE_METRICS``), ``"ret"`` (``AVAILABLE_RETREIVAL_METRICS``)
            or ``"gen"`` (``AVAILABLE_GENERATION_METRICS``).
        num_workers: Forwarded unchanged to ``BenchmarkEvaluator.eval``.

    Raises:
        ValueError: If ``infer_mode`` is not one of the supported modes.
    """
    # Validate with a real exception before doing anything else: an ``assert``
    # is removed under ``python -O``, which would let a bad mode fall through
    # and fail later with an unrelated "unbound local variable" error.
    if infer_mode not in ("both", "ret", "gen"):
        raise ValueError(
            f"You can only choose from ['both', 'ret', 'gen'], but got {infer_mode}"
        )

    set_log_level()

    metrics_by_mode = {
        "both": AVAILABLE_METRICS,
        "ret": AVAILABLE_RETREIVAL_METRICS,
        "gen": AVAILABLE_GENERATION_METRICS,
    }
    evaluator = BenchmarkEvaluator(metric_list=metrics_by_mode[infer_mode])

    # Each loaded entry must be a mapping whose keys match InferenceResult's fields.
    raw_results = read_json_file(inference_results_path)
    inference_results = [InferenceResult(**entry) for entry in raw_results]

    eval_report = evaluator.eval(
        inference_results=inference_results,
        num_workers=num_workers,
    )
    # mode="json" presumably yields JSON-serialisable values (pydantic-style
    # model_dump) -- TODO confirm against the report class.
    eval_report_dict = eval_report.model_dump(mode="json")
    logger.info(f"[Evaluation Report]\n{eval_report_dict}")
    write_json_file(file_path=evaluation_report_path, data=eval_report_dict)


# Script entry point: hand `eval` to python-fire, which exposes the function's
# parameters as command-line arguments.
if __name__ == "__main__":
    fire.Fire(eval)
