import fire

from loguru import logger

from src.generator import generator_factory
from src.evaluator.rubric_evaluator import RubricEvaluator
from src.evaluator.benchmark_evaluator import BenchmarkEvaluator
from src.evaluator.metric import (
    AVAILABLE_GENERATION_METRICS,
    AVAILABLE_METRICS,
    AVAILABLE_RETREIVAL_METRICS,
)
from src.schema import Prediction
from src.utils.json import (
    read_jsonl_file,
    write_json_file,
)
from src.utils.log import set_log_level


# Sections of the dumped report group whose per-sample "predictions" payload is
# stripped before writing the compact ("simple") report.
_REPORT_SECTIONS = (
    "total_eval_report",
    "medical_eval_report",
    "legal_eval_report",
    "casual_eval_report",
)


def _simple_report_path(report_path: str) -> str:
    """Return the path for the compact report that sits next to *report_path*."""
    suffix = ".json"
    if report_path.endswith(suffix):
        # Only touch the trailing extension; a blanket str.replace would also
        # rewrite ".json" occurrences inside directory names.
        return f"{report_path[: -len(suffix)]}_simple{suffix}"
    # With no trailing ".json" a str.replace is a no-op, which would make the
    # compact report overwrite the full one. Always produce a distinct path.
    return f"{report_path}_simple{suffix}"


def eval(
    response_path: str,
    eval_report_group_path: str,
    eval_mode: str = "both",
    num_workers: int = 8,
) -> None:
    """Evaluate saved predictions and write a full and a compact JSON report.

    The predictions are read from a JSONL file, scored by a
    ``BenchmarkEvaluator`` backed by a ``RubricEvaluator`` that uses the
    ``gpt-4o-2024-08-06`` generator at temperature 0.0, and the resulting
    report group is written twice: once in full and once with the
    ``"predictions"`` entry removed from each report section.

    NOTE: the name shadows the builtin ``eval``; it is kept because it is the
    public entry point exposed on the command line.

    Args:
        response_path: Path to the JSONL file of predictions; each record is
            expanded as keyword arguments into ``Prediction``.
        eval_report_group_path: Destination of the full JSON report. The
            compact report is written next to it with a ``_simple`` marker
            before the ``.json`` extension.
        eval_mode: Which metrics to compute: ``"both"``, ``"ret"`` (retrieval
            only) or ``"gen"`` (generation only).
        num_workers: Forwarded to ``BenchmarkEvaluator.eval``.

    Raises:
        ValueError: If ``eval_mode`` is not one of the supported modes.
        KeyError: If a report section or its ``"predictions"`` entry is
            missing from the dumped report group.
    """
    # Validate up front with a real exception: an ``assert`` is stripped under
    # ``python -O`` and would leave ``metric_list`` unbound for a bad mode.
    if eval_mode == "both":
        metric_list = AVAILABLE_METRICS
    elif eval_mode == "ret":
        metric_list = AVAILABLE_RETREIVAL_METRICS
    elif eval_mode == "gen":
        metric_list = AVAILABLE_GENERATION_METRICS
    else:
        raise ValueError(f"You can only choose from ['both', 'ret', 'gen'], but got {eval_mode}")

    set_log_level()

    generator_4o_greedy = generator_factory(generator_name="gpt-4o-2024-08-06", temperature=0.0)
    rubric_evaluator = RubricEvaluator(generator=generator_4o_greedy)
    evaluator = BenchmarkEvaluator(rubric_evaluator=rubric_evaluator, metric_list=metric_list)

    predictions = [Prediction(**record) for record in read_jsonl_file(response_path)]
    eval_report_group = evaluator.eval(predictions=predictions, num_workers=num_workers)

    report = eval_report_group.model_dump(mode="json")
    write_json_file(file_path=eval_report_group_path, data=report)

    # The full report is already on disk, so the same dict can be stripped in
    # place to build the compact version without copying the payload.
    for section in _REPORT_SECTIONS:
        report[section].pop("predictions")
    write_json_file(file_path=_simple_report_path(eval_report_group_path), data=report)
    logger.success("Done!")


# Script entry point: hand `eval` to python-fire so it can be invoked from the
# command line (only when this file is run directly, not when imported).
if __name__ == "__main__":
    fire.Fire(eval)
