import argparse
from pathlib import Path

from llm_mcts.file_logging import get_problem_info
from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.node_ranker.regularized_ranker import RegularizedRanker

from llm_mcts.node_ranker.simple_ranker import SimpleRanker
from llm_mcts.tasks.arc.task import ARCProblem
from llm_mcts.mcts_scorer.default import DefaultScorer

# https://arcprize.org/guide#scoring-methodology
MAX_SUBMISSIONS = 2


def submit(
    mcts_path: Path, problem_path: Path, transduction_result_path: Path | None = None
) -> bool:
    assert mcts_path.exists()
    mcts_result = MCTSResult.load(mcts_path)
    arc_problem_path = Path(problem_path)
    task = ARCProblem.load_file(arc_problem_path)
    scorer = DefaultScorer()

    submissions = RegularizedRanker(scorer, task).top_k_predictions(
        mcts_result,
        k=MAX_SUBMISSIONS,
    )
    for submission in submissions:
        if submission.score == 1.0:
            return True
    return False


def get_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--mcts-path", required=False, type=Path)
    parser.add_argument("--result-path", required=False, type=Path)
    parser.add_argument("--problem_path", required=False, type=str, default=None)
    parser.add_argument(
        "--transduction-result-path", required=False, type=Path, default=None
    )

    return parser.parse_args()


def main() -> None:
    """Run the checker from the CLI and write ``1`` (solved) or ``0`` to ``--result-path``."""
    args = get_args()

    is_correct = submit(
        mcts_path=args.mcts_path,
        problem_path=args.problem_path,
        transduction_result_path=args.transduction_result_path,
    )

    # Remove any previous result first, so an existing entry is replaced rather
    # than written through. ``missing_ok`` avoids the exists()/unlink() race.
    args.result_path.unlink(missing_ok=True)
    args.result_path.write_text("1" if is_correct else "0")


if __name__ == "__main__":
    main()
