import json
import logging
import os
from pprint import pformat

import hydra
from dotenv import load_dotenv
from loguru import logger
from omegaconf import DictConfig

from hallucinations.config import LllmJudgeConfig
from hallucinations.data.factory import get_dataset
from hallucinations.evaluation.eval_with_api import LlmJudgeEvaluator
from hallucinations.utils import resolve_config

# Pull settings from a local .env file before reading them; ``override=True``
# asks python-dotenv to let .env values replace ones already in the environment.
load_dotenv(override=True)

# Environment-derived settings, read once at import time.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "api_key_missing")
RATE_LIMIT = os.environ.get("RATE_LIMIT")  # stays a string (or None) until used

# Only surface warnings and above from the HTTP/OpenAI client loggers.
for _noisy_logger in ("openai", "httpx"):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)


@hydra.main(version_base="1.3", config_path="../../config", config_name="llm_as_judge")
def main(cfg: DictConfig) -> None:
    """Estimate the cost of an LLM-as-judge evaluation and print the total.

    Loads the configured dataset split and the stored answers, pairs them up
    item by item, and asks ``LlmJudgeEvaluator.estimate_cost`` for an estimate.

    Args:
        cfg: Hydra configuration (``llm_as_judge``); it is resolved and used
            to build an ``LllmJudgeConfig``.

    Raises:
        ValueError: If the dataset and the answers file differ in length
            (enforced by ``zip(..., strict=True)``), or if ``RATE_LIMIT`` is
            set to a non-blank value that is not an integer.
    """
    config = LllmJudgeConfig(**resolve_config(cfg))
    logger.info(f"Config: {pformat(config.model_dump())}")

    config.evaluation_config_file.parent.mkdir(parents=True, exist_ok=True)
    dataset = get_dataset(config.dataset, split=config.dataset.test_split_name)

    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with config.answers_file.open("r", encoding="utf-8") as file:
        raw_answers = json.load(file)

    # The question comes from the dataset; prediction and gold come from the
    # answers file. strict=True makes a length mismatch fail loudly instead of
    # silently truncating to the shorter sequence.
    judge_inputs = [
        {
            "question": ds_item["question"],
            "prediction": ans_item["prediction"],
            "gold": ans_item["gold"],
        }
        for ds_item, ans_item in zip(dataset, raw_answers, strict=True)
    ]

    # A variable that is set but blank (e.g. ``RATE_LIMIT=`` in .env) is
    # treated as "no limit" rather than crashing in int("").
    raw_rate_limit = (RATE_LIMIT or "").strip()
    rate_limit = int(raw_rate_limit) if raw_rate_limit else None

    # NOTE(review): a placeholder key is passed here instead of OPENAI_API_KEY,
    # presumably because estimate_cost issues no API requests; confirm.
    eval_func = LlmJudgeEvaluator(
        config=config,
        openai_api_key="<none>",
        batch_size=config.llm_api.batch_size,
        rate_limit=rate_limit,
    )

    costs = eval_func.estimate_cost(judge_inputs)
    print(f"Total cost: {costs['total_cost']:_.3f} USD")


# Script entry point: main() is called without arguments here because the
# @hydra.main decorator is what supplies the `cfg` argument.
if __name__ == "__main__":
    main()
