import argparse
from pathlib import Path
from typing import Optional, Sequence

from llm_mcts.mcts_algo.algo_builder import build_algo
from llm_mcts.mcts_algo.mcts_config import MCTSConfig
from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.mcts_scorer.default import DefaultScorer
from llm_mcts.models.model_builder import build_model
from llm_mcts.prompt_configs import PromptConfig
from llm_mcts.run_mcts import run_mcts
from llm_mcts.tasks.swe_bench.task import SWEBenchTask


# Output root for this script: <grandparent of this file's directory>/logging/swe_bench,
# fully resolved to an absolute path.
logging_dir = (
    Path(__file__).resolve().parent.parent.parent / "logging" / "swe_bench"
).resolve()


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options for a single SWE-bench MCTS run.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so existing
            ``parse_args()`` calls behave exactly as before. Passing an
            explicit list lets the parser be driven from tests or notebooks.

    Returns:
        The parsed ``argparse.Namespace``. ``temperatures``,
        ``answer_models``, ``answer_model_probs`` and ``judge_temperature``
        are returned as raw strings; this function does not split the
        comma-separated lists.

    Raises:
        SystemExit: If a required option is missing or a value fails type
            conversion (standard argparse behaviour).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-e",
        "--experiment_name",
        type=str,
        required=True,
        help="name of the output sub-directory for this experiment",
    )
    parser.add_argument(
        "-i",
        "--idx",
        type=int,
        required=True,
        help="index of the dataset record to run on",
    )
    parser.add_argument(
        "--dataset_name", type=str, default="princeton-nlp/SWE-bench_Lite_bm25_13K"
    )
    parser.add_argument("--split", type=str, default="dev", help="dataset split")
    parser.add_argument(
        "-t",
        "--temperatures",
        type=str,
        default="0.0",
        help="comma separated list of temperatures",
    )
    parser.add_argument(
        "--answer_models",
        type=str,
        default="gpt-4o-2024-08-06",
        help="comma separated list of model names",
    )
    parser.add_argument(
        "--answer_model_probs",
        type=str,
        default=None,
        help="comma separated list of model probabilities",
    )
    parser.add_argument("--initial_prompt_type", type=str, default="swe_bench_baseline")
    parser.add_argument("--num_simulations", type=int, default=1)
    parser.add_argument("--num_expand_samples", type=int, default=None)
    parser.add_argument("--initial_expand_samples", type=int, default=None)
    parser.add_argument("--judge_model", type=str, default="gpt-4o-2024-08-06")
    parser.add_argument(
        "--judge_temperature",
        type=str,
        default="0.0",
        help="temperature setting for the judge model",
    )
    parser.add_argument("--mcts_algo", type=str, default="standard")

    # argparse treats argv=None as "use sys.argv[1:]", preserving the
    # original zero-argument behaviour.
    return parser.parse_args(argv)


def main() -> MCTSResult:
    """Run MCTS on one SWE-bench record and save the result.

    Reads CLI options via ``parse_args``, builds the answer and judge models,
    loads record ``args.idx`` as a ``SWEBenchTask``, runs the selected MCTS
    algorithm, and saves the result to
    ``<logging_dir>/<experiment_name>/mcts_result_<idx>.pkl``.

    Returns:
        The ``MCTSResult`` produced by ``run_mcts``.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    # parents=True: the logging/swe_bench root may not exist yet (e.g. on a
    # fresh checkout); without it mkdir raises FileNotFoundError.
    save_dir.mkdir(parents=True, exist_ok=True)

    answer_model = build_model(
        model_names=args.answer_models,
        model_probs=args.answer_model_probs,
        temperatures=args.temperatures,
        logging_dir=save_dir / f"prediction_{args.idx}",
    )

    judge_model = build_model(
        model_names=args.judge_model,
        model_probs=None,
        temperatures=args.judge_temperature,
        logging_dir=save_dir / f"judge_{args.idx}",
    )

    task = SWEBenchTask.load_record(
        idx=args.idx,
        judge_model=judge_model,
        dataset_name=args.dataset_name,
        split=args.split,
    )

    scorer = DefaultScorer()

    mcts_config = MCTSConfig(
        num_simulations=args.num_simulations,
        num_expand_samples=args.num_expand_samples,
        initial_expand_samples=args.initial_expand_samples,
        actions=("answer",),
    )

    # NOTE(review): with_image=True is set for a SWE-bench run — confirm this
    # is intended for this task.
    prompt_config = PromptConfig(
        is_o1=False, initial_prompt_type=args.initial_prompt_type, with_image=True
    )

    mcts_algo = build_algo(args.mcts_algo, config=mcts_config)

    mcts_result = run_mcts(
        task=task,
        model=answer_model,
        scorer=scorer,
        mcts_config=mcts_config,
        prompt_config=prompt_config,
        mcts_algo=mcts_algo,
    )

    mcts_result.save(save_dir / f"mcts_result_{args.idx}.pkl")

    return mcts_result


if __name__ == "__main__":
    # Script entry point. The MCTSResult returned by main() is discarded here;
    # main() has already saved it to disk.
    main()
