import argparse, os

from context_manager import TaskEnvContextManager
from multiprocessing import Pool, cpu_count
from typing import Dict
from instances_check import setup_testbed
from utils import (
    extract_minimal_patch,
    get_instances,
    split_instances,
    DotDict
)


def evaluate_predictions(data: Dict):
    """
    Evaluate every prediction in a group, each inside its own task
    environment context manager.

    For each task instance the prediction patch is first trial-applied as-is.
    If that fails, a minimal patch is extracted from it (overwriting
    `task_instance["prediction"]` in place) and trial-applied instead. The
    trial application is reverted, then installation runs, the test patch and
    the prediction are applied, and the tests are run. The first step that
    reports failure ends the evaluation of that instance.

    Args:
        data: Dict wrapped in a DotDict; the keys read here are
            task_instances: iterable of task instances, each carrying at
                least "instance_id", "prediction" and "test_patch"
            testbed, venv, log_dir, conda_path, verbose, timeout: passed
                through to TaskEnvContextManager
    """
    cfg = DotDict(data)
    for instance in cfg.task_instances:
        with TaskEnvContextManager(
            instance,
            cfg.testbed,
            cfg.venv,
            cfg.log_dir,
            cfg.conda_path,
            verbose=cfg.verbose,
            timeout=cfg.timeout,
            is_eval=True,
        ) as env:
            # An existing log file means this instance was already handled
            if env.log_file_exists:
                print(
                    f"Skipping {instance['instance_id']}, log file ({env.log_file}) already exists"
                )
                continue

            # Bring the environment to the state required by this instance
            if not env.reset_task_env(instance):
                continue

            # Trial-apply the prediction; fall back to a minimal patch once
            trial_type = "pred_try"
            applied = env.apply_patch(instance["prediction"], patch_type=trial_type)
            if not applied:
                instance["prediction"] = extract_minimal_patch(instance["prediction"])
                trial_type = "pred_minimal_try"
                applied = env.apply_patch(instance["prediction"], patch_type=trial_type)
            if not applied:
                continue

            # Undo the trial application before installing
            env.apply_patch(instance["prediction"], patch_type=trial_type, revert=True)
            final_type = trial_type.replace("_try", "")

            # Install, apply test patch, apply prediction, run tests;
            # each step only runs if the previous one succeeded
            if not env.run_install_task(instance):
                continue
            if not env.apply_patch(instance["test_patch"], patch_type="test"):
                continue
            if not env.apply_patch(instance["prediction"], patch_type=final_type):
                continue
            env.run_tests_task(instance)


def main(args):
    """
    Filters out predictions that already have an evaluation log, splits the
    remainder into `num_workers` groups, and hands each group to
    `setup_testbed` (in parallel worker processes when num_workers > 1).

    Args:
        args: Parsed command-line namespace. This function reads
            `predictions_path`, `log_dir`, `num_workers` and
            `setup_refs_path`. Every attribute except `predictions_path` is
            also forwarded to `setup_testbed` inside each group's dict.

    Side effects:
        Sets `args.num_workers` to `cpu_count()` when it is None.
    """
    if args.num_workers is None:
        args.num_workers = cpu_count()

    predictions = get_instances(args.predictions_path)

    # Remove predictions that have already been evaluated, i.e. whose
    # "<instance_id>.<model>.eval.log" file is already present in log_dir
    pending = [
        p
        for p in predictions
        if not os.path.exists(
            os.path.join(args.log_dir, f"{p['instance_id']}.{p['model']}.eval.log")
        )
    ]
    if not pending:
        print("All predictions have already been evaluated")
        return
    print(f"# predictions to evaluate: {len(pending)} ({len(predictions) - len(pending)} already evaluated)")
    predictions = pending

    predictions_groups = split_instances(predictions, args.num_workers)
    setup_refs = get_instances(args.setup_refs_path) if args.setup_refs_path else None

    # One payload per worker: the group's instances, the evaluation callback,
    # the optional setup references, and all command-line options
    data_groups = [
        {
            "task_instances": g,
            "func": evaluate_predictions,
            "setup_refs": setup_refs,
            **vars(args),
        }
        for g in predictions_groups
    ]

    # TODO: Remove this?
    for group in data_groups:
        del group["predictions_path"]

    # Single worker: run in this process, no pool needed
    if args.num_workers == 1:
        setup_testbed(data_groups[0])
        return

    pool = Pool(processes=args.num_workers)
    try:
        pool.map(setup_testbed, data_groups)
        pool.close()
    except BaseException:
        # map() re-raises worker errors (and KeyboardInterrupt lands here);
        # terminate so worker processes are not left behind, then propagate
        pool.terminate()
        raise
    finally:
        # Valid on both paths: close() or terminate() has been called
        pool.join()


if __name__ == "__main__":
    # Command-line entry point: parse the options and pass them to main()
    parser = argparse.ArgumentParser()

    # Required positionals. Prediction records are expected to carry
    # (model, instance_id, prediction)
    parser.add_argument("predictions_path", type=str, help="Path to predictions instances file")
    parser.add_argument("log_dir", type=str, help="Path to log directory")

    # Optional string-valued path options with no explicit default
    for flag, description in (
        ("--path_conda", "(Optional) Path to miniconda3 or anaconda installation"),
        ("--testbed", "(Optional) Path to testbed directory"),
        ("--temp_dir", "(Optional) Path to temporary directory for storing virtual envs"),
    ):
        parser.add_argument(flag, type=str, help=description)

    # Execution controls
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="(Optional) Timeout (seconds) for testing script execution",
    )
    parser.add_argument("--verbose", action="store_true", help="(Optional) Verbose mode")
    parser.add_argument(
        "--num_workers",
        type=int,
        default=None,
        help="(Optional) Number of workers",
    )
    parser.add_argument(
        "--setup_refs_path",
        type=str,
        default=None,
        help="(Optional) Path to setup reference instances file",
    )

    args = parser.parse_args()
    main(args)
