import fire
from loguru import logger

from src.batch_processor import batch_processor_factory
from src.schema import Prediction
from src.utils.json import (
    read_jsonl_file,
    write_jsonl_file,
)
from src.utils.log import set_log_level


def gen_parse_res(
    retrieval_path: str,
    batch_path: str,
    response_path: str,
    generator_name: str,
    api_type: str,  # [s, p, b]
) -> None:
    """Parse a batch response file and write the generation results as JSONL.

    The retrieval records are loaded and turned into ``Prediction`` objects,
    handed to a batch processor together with ``batch_path``, and the parsed
    results are serialized to ``response_path``.

    Args:
        retrieval_path: JSONL file whose records are used as keyword
            arguments to build ``Prediction`` objects.
        batch_path: Path forwarded to the batch processor's
            ``parse_response``.
        response_path: Destination JSONL file for the parsed results.
        generator_name: Forwarded to ``batch_processor_factory`` to select
            the processor.
        api_type: Forwarded to ``batch_processor_factory``. The inline hint
            lists ``s``, ``p`` and ``b``; what each value means is defined by
            the factory, not here.
    """
    set_log_level()

    # Each JSONL record must match the Prediction constructor's fields.
    retrieval_data = read_jsonl_file(retrieval_path)
    predictions = [Prediction(**data) for data in retrieval_data]

    batch_processor = batch_processor_factory(
        generator_name=generator_name,
        api_type=api_type,
    )
    gen_results = batch_processor.parse_response(
        predictions=predictions,
        batch_path=batch_path,
    )

    # mode="json" asks model_dump for JSON-compatible values before writing.
    gen_results_dict = [gen_result.model_dump(mode="json") for gen_result in gen_results]
    write_jsonl_file(file_path=response_path, data=gen_results_dict)
    logger.success("Done!")


if __name__ == "__main__":
    # Expose gen_parse_res as a command-line interface; python-fire maps the
    # function's parameters to CLI arguments.
    fire.Fire(component=gen_parse_res)
