import argparse
from pathlib import Path
from typing import Optional, Sequence

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_scorer.default import DefaultScorer
from llm_mcts.models.model_builder import build_model
from llm_mcts.tasks.math_vista.task import MathVistaTask
from llm_mcts.visualize_mcts import render_mcts_graph

# Root of all MathVista experiment logs: ``logging/math_vista`` inside the
# directory two levels above the one containing this file. The ".." hops are
# collapsed by ``resolve()``, which also makes the result absolute.
logging_dir = Path(__file__).joinpath(
    "..", "..", "..", "logging", "math_vista"
).resolve()


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Args:
        argv: Arguments to parse (without the program name). ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is the original
            behavior; passing an explicit list makes the parser testable.

    Returns:
        Namespace with ``experiment_name`` (str), ``idx`` (int), ``split``
        (str), ``judge_model`` (str) and ``judge_temperature`` (str).
    """
    parser = argparse.ArgumentParser(
        description="Render the saved MCTS graph for one MathVista record."
    )
    parser.add_argument(
        "-e",
        "--experiment_name",
        type=str,
        required=True,
        help="Experiment directory name under the math_vista logging directory.",
    )
    parser.add_argument(
        "-i",
        "--idx",
        type=int,
        required=True,
        help="Index of the record whose MCTS result should be rendered.",
    )
    parser.add_argument(
        "-s",
        "--split",
        type=str,
        default="testmini",
        help="Dataset split the record is loaded from (default: %(default)s).",
    )
    parser.add_argument(
        "--judge_model",
        type=str,
        default="gpt-4o-2024-08-06",
        help="Model name used as the judge (default: %(default)s).",
    )
    # Deliberately kept as a string: the value is forwarded unchanged to the
    # model builder's ``temperatures`` argument, mirroring how the model name
    # is passed as a plain string.
    parser.add_argument(
        "--judge_temperature",
        type=str,
        default="0.0",
        help="Sampling temperature for the judge model (default: %(default)s).",
    )
    return parser.parse_args(argv)


def main() -> None:
    """Re-render the MCTS search graph for one saved MathVista record.

    Loads ``mcts_result_<idx>.pkl`` from ``logging_dir / <experiment_name>``.
    Builds the judge model and the MathVista task for record ``idx``. Then
    calls ``render_mcts_graph`` with ``<save_dir>/mcts_graph_<idx>`` as the
    output path and ``view=False``.

    Raises:
        FileNotFoundError: if the experiment directory does not exist.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, which would replace this clear message with a confusing
    # failure later when loading the result file.
    if not save_dir.is_dir():
        raise FileNotFoundError(
            f"Experiment {args.experiment_name} does not exist "
            f"(expected directory {save_dir})"
        )

    # NOTE(review): the ``.pkl`` suffix suggests MCTSResult.load unpickles the
    # file; only point this script at experiment directories you trust.
    mcts_result = MCTSResult.load(save_dir / f"mcts_result_{args.idx}.pkl")
    scorer = DefaultScorer()
    # The judge model gets its own per-record logging directory inside the
    # experiment directory.
    judge_model = build_model(
        model_names=args.judge_model,
        model_probs=None,
        temperatures=args.judge_temperature,
        logging_dir=save_dir / f"test_{args.idx}",
    )
    task = MathVistaTask.load_record(
        idx=args.idx, judge_model=judge_model, split=args.split
    )
    render_mcts_graph(
        mcts_result, scorer, task, save_dir / f"mcts_graph_{args.idx}", view=False
    )


if __name__ == "__main__":
    main()
