import argparse
import json
import pickle
import re
from pathlib import Path
from typing import List

from llm_mcts.mcts_algo.mcts_result import MCTSResult
from llm_mcts.node_ranker.simple_ranker import SimpleRanker
from llm_mcts.tasks.live_code_bench_code_generation.task import (
    LiveCodeBenchCodeGenerationTask,
)

# Root directory for this task's logs. Starting from this file's own path,
# three ".." components are applied (the first one steps from the file to its
# containing directory) before descending into
# "logging/live_code_bench_code_generation"; the result is resolved to an
# absolute path. `main` looks up per-experiment sub-directories inside it.
logging_dir = Path(__file__).joinpath(
    "..", "..", "..", "logging", "live_code_bench_code_generation"
).resolve()


def extract_python_code(text: str) -> str:
    """
    Return the contents of the first fenced code block in `text`.

    A fence opened with "```python" takes precedence: if that marker occurs
    anywhere in `text`, only python-tagged fences are searched. Otherwise the
    first generic "```" fence is used (anything following the opening fence,
    including a language tag, is part of the returned text). The captured text
    is stripped of surrounding whitespace. An empty string is returned when
    `text` contains no fence, or when the selected fence is never closed.
    """
    # Ordered by priority; the first marker present in the text decides which
    # pattern is applied, even if that pattern then fails to match.
    fence_patterns = (
        ("```python", r"```python(.*?)```"),
        ("```", r"```(.*?)```"),
    )
    for marker, pattern in fence_patterns:
        if marker not in text:
            continue
        # DOTALL lets the non-greedy group span multiple lines.
        found = re.search(pattern, text, re.DOTALL)
        if found is None:
            return ""
        return found.group(1).strip()
    return ""


def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of this script.

    Options:
        -e / --experiment_name (str, required)
        -i / --idx (int, required)
        --release_version (str, defaults to "release_v4")
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("-e", "--experiment_name", required=True, type=str)
    cli.add_argument("-i", "--idx", required=True, type=int)
    cli.add_argument("--release_version", default="release_v4", type=str)
    parsed = cli.parse_args()
    return parsed


def _predictions_payload(eval_results) -> List[dict]:
    """Convert evaluation results into JSON-serializable answer/groundtruth pairs.

    Each result is read through its `answer` and `groundtruth` attributes.
    A `None` or empty `eval_results` yields an empty list.
    """
    return [
        {
            "answer": eval_result.answer,
            "groundtruth": eval_result.groundtruth,
        }
        for eval_result in (eval_results or [])
    ]


def _select_submission(ranked_nodes):
    """Pick the node to submit from `ranked_nodes` (iterated in the given order).

    Nodes whose `completion` is None are skipped. The first node whose
    completion text contains an extractable code block is chosen. If no node
    yields code, the last node that has a completion is used with an empty
    code string, so a result record can still be produced.

    Returns:
        A `(node, completion, code)` tuple that always belongs to the same node.

    Raises:
        RuntimeError: if no node has a completion at all.
    """
    chosen_node = None
    chosen_completion = None
    chosen_code = ""
    for candidate in ranked_nodes:
        completion = candidate.completion
        if completion is None:
            continue
        # Keep node, completion and code in sync so the fallback never mixes
        # values coming from different nodes.
        chosen_node = candidate
        chosen_completion = completion
        chosen_code = extract_python_code(completion.generation)
        if chosen_code != "":
            break

    if chosen_node is None:
        raise RuntimeError("MCTS result contains no node with a completion")
    return chosen_node, chosen_completion, chosen_code


def main() -> None:
    """Export the selected submission of one experiment/problem pair.

    Reads the command-line arguments, loads the task record and the pickled
    MCTS result `mcts_result_{idx}.pkl` from the experiment directory, selects
    a submission from the ranked nodes, evaluates it with
    `task.evaluate_on_test`, and writes:

    * `each_result/{instance_id}.json` with the generated code plus public and
      private test scores and predictions;
    * `private_test_results/{instance_id}.pkl` with the pickled private
      evaluation results.

    Raises:
        FileNotFoundError: if the experiment directory or the MCTS result file
            does not exist.
        RuntimeError: if no ranked node has a completion.
    """
    args = parse_args()

    save_dir = logging_dir / args.experiment_name
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not save_dir.is_dir():
        raise FileNotFoundError(f"Experiment {args.experiment_name} does not exist")

    save_result_json_dir = save_dir / "each_result"
    save_result_json_dir.mkdir(parents=True, exist_ok=True)

    # Load task
    task = LiveCodeBenchCodeGenerationTask.load_record(
        idx=args.idx, release_version=args.release_version
    )
    # Spaces and slashes in the title are replaced so the id is usable as a
    # file name.
    instance_id = (
        f"{args.idx}_{task.problem.question_title.replace(' ', '_').replace('/', '_')}"
    )

    # Load mcts result
    mcts_path = save_dir / f"mcts_result_{args.idx}.pkl"
    if not mcts_path.exists():
        raise FileNotFoundError(f"MCTS result file {mcts_path} does not exist")
    mcts_result = MCTSResult.load(mcts_path)

    # Get top submission with valid python code
    top_node_list = SimpleRanker().top_k_predictions(mcts_result, k=None)
    node, generated_completion, result_code = _select_submission(top_node_list)

    # get test results
    public_eval_results = node.eval_results
    if public_eval_results:
        public_score = sum(
            eval_result.get_score() for eval_result in public_eval_results
        ) / len(public_eval_results)
    else:
        # Covers both None and an empty list (the latter would divide by zero).
        public_score = 0
    private_eval_results, private_score = task.evaluate_on_test(generated_completion)

    result_json_path = save_result_json_dir / f"{instance_id}.json"
    with open(result_json_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "instance_id": instance_id,
                "model_name_or_path": args.experiment_name,
                "generated_code": result_code,
                "public_tests": {
                    "score": public_score,
                    "predictions": _predictions_payload(public_eval_results),
                },
                "private_tests": {
                    # NOTE(review): dividing by the number of results presumes
                    # `private_score` is a total over all tests — confirm
                    # against `evaluate_on_test`.
                    "score": (
                        private_score / len(private_eval_results)
                        if private_eval_results
                        else 0
                    ),
                    "predictions": _predictions_payload(private_eval_results),
                },
            },
            f,
        )

    print(f"Saved to {result_json_path}")

    save_private_test_results_pkl_dir = save_dir / "private_test_results"
    save_private_test_results_pkl_dir.mkdir(parents=True, exist_ok=True)

    with open(save_private_test_results_pkl_dir / f"{instance_id}.pkl", "wb") as f:
        pickle.dump(private_eval_results, f)


# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
