import argparse
import json
from pathlib import Path

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.models.model_builder import build_model
from llm_mcts.node_ranker.simple_ranker import SimpleRanker
from llm_mcts.tasks.omni_math.omni_judge import MODEL_NAME as OMNI_JUDGE_MODEL_NAME
from llm_mcts.tasks.omni_math.task import OmniMathTask

# Root directory holding the per-experiment omni_math outputs: three levels
# above this file, then "logging/omni_math", fully resolved.
logging_dir = Path(__file__).joinpath(
    "..", "..", "..", "logging", "omni_math"
).resolve()


def parse_args() -> argparse.Namespace:
    """Read the evaluation options from the command line.

    Returns:
        The parsed namespace with ``experiment_name`` and ``idx`` (both
        required), ``judge_model``, ``judge_temperature`` (kept as a string),
        ``reward_model_name``, the ``only_reward_model`` / ``is_sigmoid``
        switches, and ``dataset_name``.
    """
    cli = argparse.ArgumentParser()

    # Required: which experiment directory and which record index to process.
    cli.add_argument("-e", "--experiment_name", required=True, type=str)
    cli.add_argument("-i", "--idx", required=True, type=int)

    # Judge settings; both are plain strings with a default and a help text.
    judge_options = {
        "--judge_model": (
            "gpt-4o-2024-08-06",
            "model name for the answer extraction",
        ),
        "--judge_temperature": (
            "0.0",
            "temperature for the answer extraction",
        ),
    }
    for flag, (default_value, help_text) in judge_options.items():
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # Reward-model settings and dataset selection.
    cli.add_argument("--reward_model_name", default=None, type=str)
    for switch in ("--only_reward_model", "--is_sigmoid"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--dataset_name", type=str)

    namespace = cli.parse_args()
    return namespace


def main() -> None:
    """Evaluate the top-ranked MCTS submission for one record and save it.

    Loads the task for ``--idx``, loads ``mcts_result_<idx>.pkl`` from the
    experiment directory, walks the ranked nodes until one yields a non-empty
    extracted answer, and writes a JSON summary to
    ``<experiment>/each_result/<idx>.json``.

    If no node yields a non-empty extracted answer, the last node that had a
    completion is reported.

    Raises:
        FileNotFoundError: If the experiment directory or the MCTS result
            file does not exist.
        RuntimeError: If none of the ranked nodes carries a completion.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not save_dir.is_dir():
        raise FileNotFoundError(f"Experiment {args.experiment_name} does not exist")

    save_result_json_dir = save_dir / "each_result"
    save_result_json_dir.mkdir(parents=True, exist_ok=True)

    # Load task
    if args.judge_model != OMNI_JUDGE_MODEL_NAME:
        judge_model = build_model(
            model_names=args.judge_model,
            model_probs=None,
            temperatures=args.judge_temperature,
            logging_dir=save_dir / f"extraction_{args.idx}",
        )
    else:
        # The Omni judge is passed by name rather than built here.
        judge_model = args.judge_model

    task = OmniMathTask.load_record(
        idx=args.idx,
        judge_model=judge_model,
        dataset_name=args.dataset_name,
        reward_model_name=args.reward_model_name,
        only_reward_model=args.only_reward_model,
        is_sigmoid=args.is_sigmoid,
    )
    instance_id = args.idx

    # Load mcts result
    mcts_path = save_dir / f"mcts_result_{args.idx}.pkl"
    if not mcts_path.is_file():
        raise FileNotFoundError(f"MCTS result {mcts_path} does not exist")
    # NOTE(review): the ".pkl" suffix suggests pickle-based loading; only load
    # result files produced by a trusted run.
    mcts_result = MCTSResult.load(mcts_path)

    # Get top submission
    # Keep node, completion, extracted answer and test score together so the
    # saved record always describes a single node, even when later nodes in
    # the ranking have no completion.
    selected = None
    for node in SimpleRanker().top_k_predictions(mcts_result):
        completion = node.completion
        if completion is None:
            continue
        extracted_result = task.llm_judge.extract_answer(completion.generation)
        _, test_score = task.llm_judge.check_test_score(
            task.problem, completion.generation
        )
        selected = (node, completion, extracted_result, test_score)
        if extracted_result != "":
            break

    if selected is None:
        raise RuntimeError(
            f"No ranked node with a completion found in {mcts_path}"
        )
    node, completion, extracted_result, test_score = selected

    # Fall back to 0 when the node has no eval results (None or empty).
    llm_judge_score = node.eval_results[0].score if node.eval_results else 0

    result_path = save_result_json_dir / f"{instance_id}.json"
    with open(result_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "instance_id": instance_id,
                "model_name_or_path": args.experiment_name,
                "answer_raw": completion.generation,
                "answer_extracted": extracted_result,
                "llm_judge_score": llm_judge_score,
                "test_score": test_score,  # Get the test result here
            },
            f,
        )
    print(f"Saved to {result_path}")


if __name__ == "__main__":
    main()
