import argparse
import json
from pathlib import Path
from typing import Optional, Sequence

from llm_mcts.mcts_algo.algo_builder import build_algo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_scorer.default import DefaultScorer
from llm_mcts.models.model_builder import build_model
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.run_mcts import run_mcts
from llm_mcts.tasks.arc.task import ARCProblem

# Root for ARC run artifacts: "<dir two levels above this file's directory>/logging/arc",
# made absolute (and symlink-resolved) via Path.resolve().
logging_dir = Path(__file__).joinpath("..", "..", "..", "logging", "arc").resolve()


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for a single-task ARC MCTS run.

    Args:
        argv: Argument strings to parse, without the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, which is the
            original zero-argument behavior. Passing an explicit list allows
            programmatic invocation and testing.

    Returns:
        The parsed options. The comma-separated list options (temperatures,
        answer model names, answer model probabilities) are returned as raw,
        unsplit strings.
    """
    parser = argparse.ArgumentParser(
        description="Run MCTS on a single ARC-AGI task."
    )
    parser.add_argument(
        "-e",
        "--experiment_name",
        type=str,
        required=True,
        help="experiment name; outputs go to logging/arc/<experiment_name>",
    )
    parser.add_argument(
        "-i",
        "--idx",
        type=str,
        required=True,
        help="ARC task id; loaded from ARC-AGI/data/<split>/<idx>.json",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="evaluation",
        help="ARC-AGI data split directory name (default: evaluation)",
    )
    parser.add_argument(
        "-t",
        "--temperatures",
        type=str,
        default="0.0",
        help="comma separated list of temperatures",
    )
    parser.add_argument(
        "--answer_models",
        type=str,
        default="gpt-4o-2024-08-06",
        help="comma separated list of model names",
    )
    parser.add_argument(
        "--answer_model_probs",
        type=str,
        default=None,
        help="comma separated list of model probabilities",
    )
    parser.add_argument(
        "--initial_prompt_type",
        type=str,
        default="baseline_v1_single_turn",
        help="initial prompt type passed to PromptConfig",
    )
    parser.add_argument(
        "--num_simulations",
        type=int,
        default=1,
        help="number of MCTS simulations",
    )
    parser.add_argument(
        "--num_expand_samples",
        type=int,
        default=None,
        help="passed to MCTSConfig as num_expand_samples",
    )
    parser.add_argument(
        "--initial_expand_samples",
        type=int,
        default=None,
        help="passed to MCTSConfig as initial_expand_samples",
    )
    parser.add_argument(
        "--mcts_algo",
        type=str,
        default="standard",
        help="algorithm name passed to build_algo",
    )
    parser.add_argument(
        "--priors",
        type=str,
        default=None,
        help="forwarded to build_algo as priors",
    )
    parser.add_argument(
        "--multimodel_strategy",
        type=str,
        default="stack",
        help="forwarded to build_algo as multimodel_strategy",
    )

    return parser.parse_args(argv)


def _sum_logged_prices(log_dir: Path) -> float:
    """Return the sum of ``["price"]["total"]`` over the ``*.txt`` files in *log_dir*.

    Each matching file is parsed as JSON. Files that cannot be read, are not
    valid JSON, or do not hold an addable ``["price"]["total"]`` value are
    skipped, so the result is a best-effort total (``0`` when nothing matches).
    """
    total = 0
    for fname in log_dir.glob("*.txt"):
        try:
            record = json.loads(fname.read_text(encoding="utf-8"))
            total += record["price"]["total"]
        except (OSError, ValueError, KeyError, TypeError):
            # Not every *.txt file is a priced call log. For example, the
            # cost_dollar.txt written by an earlier run holds a bare number,
            # and subscripting it with ["price"] raises TypeError. Skip those.
            continue
    return total


def main() -> MCTSResult:
    """Run MCTS on one ARC task and persist its result and API cost.

    Options come from ``parse_args`` (i.e. ``sys.argv``). Outputs are written
    under ``logging_dir / <experiment_name>``:

    * ``mcts_result_<idx>.pkl``: the search result, written via
      ``MCTSResult.save``.
    * ``cost_dollar.txt`` in the answer model's logging directory (which
      ``build_model`` is given as ``prediction_<idx>/``): the summed
      ``["price"]["total"]`` of the JSON ``*.txt`` logs found there.

    Returns:
        The ``MCTSResult`` returned by ``run_mcts``.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    # parents=True so a fresh checkout without logging/arc/ does not fail here.
    save_dir.mkdir(parents=True, exist_ok=True)

    answer_model = build_model(
        model_names=args.answer_models,
        model_probs=args.answer_model_probs,
        temperatures=args.temperatures,
        logging_dir=save_dir / f"prediction_{args.idx}",
    )

    # NOTE(review): this path is relative to the current working directory,
    # so the script presumably has to be launched from the directory that
    # contains the ARC-AGI checkout -- confirm.
    arc_problem_path = Path(f"ARC-AGI/data/{args.split}/{args.idx}.json")
    task = ARCProblem.load_file(arc_problem_path)

    scorer = DefaultScorer()

    mcts_config = MCTSConfig(
        num_simulations=args.num_simulations,
        num_expand_samples=args.num_expand_samples,
        initial_expand_samples=args.initial_expand_samples,
        actions=("transform",),
    )

    prompt_config = PromptConfig(
        is_o1=False, initial_prompt_type=args.initial_prompt_type, with_image=True
    )

    mcts_algo = build_algo(
        algo_name=args.mcts_algo,
        config=mcts_config,
        priors=args.priors,
        answer_models=args.answer_models,
        multimodel_strategy=args.multimodel_strategy,
    )

    mcts_result = run_mcts(
        task=task,
        model=answer_model,
        scorer=scorer,
        mcts_config=mcts_config,
        prompt_config=prompt_config,
        mcts_algo=mcts_algo,
    )

    # Persist the search result before the cost bookkeeping, so a failure
    # there cannot lose the (expensive) result.
    mcts_result.save(save_dir / f"mcts_result_{args.idx}.pkl")

    cost_dir = answer_model.logging_dir
    total_price = _sum_logged_prices(cost_dir)
    # The model's logging directory may not exist if nothing was ever logged
    # (unverified); create it so the cost write below cannot fail on that.
    cost_dir.mkdir(parents=True, exist_ok=True)
    (cost_dir / "cost_dollar.txt").write_text(str(total_price))

    return mcts_result


if __name__ == "__main__":
    # Script entry point. main() already saves its MCTSResult to disk,
    # so the returned object is not used further here.
    _result = main()
