import argparse
import json
from pathlib import Path
from typing import Optional, Sequence

from llm_mcts.mcts_algo.algo_builder import build_algo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_scorer.default import DefaultScorer
from llm_mcts.models.model_builder import build_model
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.run_mcts import run_mcts
from llm_mcts.tasks.omni_math.omni_judge import MODEL_NAME as OMNI_JUDGE_MODEL_NAME
from llm_mcts.tasks.omni_math.task import OmniMathTask

# Root directory for all OmniMath experiment outputs (one subdirectory per
# --experiment_name). Path(__file__) is this script itself, not its directory,
# so the first ".." only strips the file name. The path therefore resolves to
# <grandparent of this script's directory>/logging/omni_math as an absolute path.
logging_dir = (Path(__file__) / ".." / ".." / ".." / "logging" / "omni_math").resolve()


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for a single OmniMath MCTS run.

    Args:
        argv: Argument strings to parse, excluding the program name. When
            ``None`` (the default), ``sys.argv[1:]`` is used, exactly as
            :meth:`argparse.ArgumentParser.parse_args` does. Passing an
            explicit list allows programmatic invocation and testing.

    Returns:
        The parsed options. The model-name, probability and temperature
        options (including ``judge_temperature``) are declared ``type=str``
        and are returned unsplit as strings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-e", "--experiment_name", type=str, required=True)
    parser.add_argument("-i", "--idx", type=int, required=True)
    parser.add_argument(
        "--answer_models",
        type=str,
        default="gpt-4o-2024-08-06",
        help="comma separated list of model names",
    )
    parser.add_argument(
        "--answer_model_probs",
        type=str,
        default=None,
        help="comma separated list of model probabilities",
    )
    parser.add_argument(
        "--answer_temperatures",
        type=str,
        default="0.0",
        help="comma separated list of temperatures",
    )
    parser.add_argument("--initial_prompt_type", type=str, default="omni_math_kou_v1")
    parser.add_argument("--num_simulations", type=int, default=0)
    parser.add_argument("--num_expand_samples", type=int, default=None)
    parser.add_argument("--initial_expand_samples", type=int, default=None)
    parser.add_argument("--mcts_algo", type=str, default="standard")
    parser.add_argument(
        "--judge_model",
        type=str,
        default="gpt-4o-2024-08-06",
        help="model name for the judge (evaluation)",
    )
    # Kept as a string (not float) because it is handed to build_model's
    # `temperatures` argument, the same string-typed slot that
    # --answer_temperatures feeds.
    parser.add_argument(
        "--judge_temperature",
        type=str,
        default="0.0",
        help="temperature for the judge (evaluation)",
    )
    parser.add_argument("--reward_model_name", type=str, default=None)
    parser.add_argument("--only_reward_model", action="store_true")
    parser.add_argument("--is_sigmoid", action="store_true")
    parser.add_argument("--dataset_name", type=str)
    return parser.parse_args(argv)


# Name of the per-run cost summary written into the answer model's log dir.
# It also matches the "*.txt" pattern used by the cost scan, so the scan must
# skip it explicitly (otherwise a rerun would re-read a previous summary).
_COST_FILENAME = "cost_dollar.txt"


def _sum_logged_prices(log_dir: Path) -> float:
    """Sum ``record["price"]["total"]`` over the JSON ``*.txt`` files in *log_dir*.

    This is best-effort. A file is skipped if it cannot be read, is not valid
    JSON, or lacks the ``price.total`` structure (or its value cannot be added).
    The ``cost_dollar.txt`` summary file is always skipped. Returns ``0`` when
    nothing contributes.
    """
    total = 0
    # Sorted so the floating-point summation order, and hence the exact
    # total, is reproducible regardless of filesystem listing order.
    for path in sorted(log_dir.glob("*.txt")):
        if path.name == _COST_FILENAME:
            continue
        try:
            record = json.loads(path.read_text(encoding="utf-8"))
            total += record["price"]["total"]
        except (OSError, ValueError, KeyError, TypeError):
            # OSError: unreadable file. ValueError: invalid JSON or bad
            # encoding (JSONDecodeError and UnicodeDecodeError both subclass
            # it). KeyError/TypeError: JSON without a numeric price.total.
            # Anything else is a real bug and is allowed to propagate.
            continue
    return total


def main() -> MCTSResult:
    """Run MCTS on one OmniMath record and persist the result and answer cost.

    Reads options via ``parse_args()``, builds the answer model (and the judge
    model unless it is Omni-Judge), loads record ``--idx``, and runs the search.
    It then writes ``mcts_result_{idx}.pkl`` into the experiment directory and
    ``cost_dollar.txt`` (summed answer-model prices) into the answer model's
    log directory.

    Returns:
        The ``MCTSResult`` produced by ``run_mcts``.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    save_dir.mkdir(parents=True, exist_ok=True)

    answer_model = build_model(
        model_names=args.answer_models,
        model_probs=args.answer_model_probs,
        temperatures=args.answer_temperatures,
        logging_dir=save_dir / f"prediction_{args.idx}",
    )

    # For Omni-Judge the task receives the model *name* (a str) instead of a
    # built model. Presumably OmniMathTask loads that judge itself; TODO confirm.
    if args.judge_model != OMNI_JUDGE_MODEL_NAME:
        judge_model = build_model(
            model_names=args.judge_model,
            model_probs=None,
            temperatures=args.judge_temperature,
            logging_dir=save_dir / f"judge_{args.idx}",
        )
    else:
        judge_model = args.judge_model

    task = OmniMathTask.load_record(
        idx=args.idx,
        judge_model=judge_model,
        dataset_name=args.dataset_name,
        reward_model_name=args.reward_model_name,
        only_reward_model=args.only_reward_model,
        is_sigmoid=args.is_sigmoid,
    )

    scorer = DefaultScorer()

    mcts_config = MCTSConfig(
        num_simulations=args.num_simulations,
        num_expand_samples=args.num_expand_samples,
        initial_expand_samples=args.initial_expand_samples,
        actions=("answer",),
    )

    prompt_config = PromptConfig(
        is_o1=False, initial_prompt_type=args.initial_prompt_type, with_image=True
    )

    mcts_algo = build_algo(args.mcts_algo, config=mcts_config)

    mcts_result = run_mcts(
        task=task,
        model=answer_model,
        scorer=scorer,
        mcts_config=mcts_config,
        prompt_config=prompt_config,
        mcts_algo=mcts_algo,
    )

    # Persist the (expensive) search result before the best-effort cost
    # bookkeeping, so a failure while summing costs cannot lose it.
    mcts_result.save(save_dir / f"mcts_result_{args.idx}.pkl")

    # Only the answer model's logs are summed; judge costs are not included.
    total_price = _sum_logged_prices(answer_model.logging_dir)
    (answer_model.logging_dir / _COST_FILENAME).write_text(str(total_price))

    return mcts_result


# Script entry point. The MCTSResult returned by main() is intentionally
# discarded here; it has already been saved to disk inside main().
if __name__ == "__main__":
    main()
