import argparse, os, re
import json

from constants import (
    APPLY_PATCH_FAIL,
    KEY_INSTANCE_ID,
    KEY_MODEL,
    KEY_PREDICTION,
)
from context_manager import TaskEnvContextManager
from engine_validation import setup_testbed
from multiprocessing import Pool, cpu_count
from tqdm.auto import tqdm
from utils import (
    extract_minimal_patch,
    get_instances,
    split_instances,
    DotDict, extract_fuzzy_patch, extract_custom_patches
)


def overwrite_ablation(tcm: TaskEnvContextManager, task_instance: dict):
    """
    Code for running ablation experiment to compare generating full files vs patches

    Each `[start of <path>]` ... `[end of <path>]` section found in the task
    instance's `full_output` is written over `<path>`, resolved relative to
    the current working directory. The task's test script is run afterwards.

    The function always returns None. It stops early (without running the
    tests) when `full_output` is missing or None, when any of the
    environment reset / installation / test-patch steps reports failure, or
    when a generated path resolves outside the current working directory.

    Args:
        tcm: TaskEnvContextManager
        task_instance: Dict containing task instance
    """
    # if full output is missing or none, write to log and skip altogether
    if 'full_output' not in task_instance:
        print(f'[{task_instance[KEY_INSTANCE_ID]}] No `full_output` field, skipping')
        with open(tcm.log_file, 'a') as f_log:
            f_log.write(f'{APPLY_PATCH_FAIL}; No `full_output` field\n')
        return
    if task_instance['full_output'] is None:
        print(f'[{task_instance[KEY_INSTANCE_ID]}] `full_output` is None, skipping')
        with open(tcm.log_file, 'a') as f_log:
            f_log.write(f'{APPLY_PATCH_FAIL}; `full_output` is None\n')
        return

    # Attempt to set up environment with task
    if not tcm.reset_task_env(task_instance):
        return

    # Run installation + apply test patch
    if (
        not tcm.run_install_task(task_instance)
        or not tcm.apply_patch(task_instance["test_patch"], patch_type="test")
    ):
        return

    # Captures (path, contents) for every delimited file section in the output
    filename_pat = re.compile(r'\[start of ([\w\.\-\/]+)\]\n(.+?)\n\[end of \1\]', re.DOTALL)
    root = os.path.abspath(os.getcwd())

    # overwrite files
    for filename, contents in filename_pat.findall(task_instance['full_output']):
        correct_filename = os.path.abspath('./' + filename.lstrip('/'))
        # Compare whole path components rather than raw string prefixes: a
        # prefix test would accept `../<cwd-name>-suffix/file`, which resolves
        # to a sibling directory that merely shares the textual prefix.
        if os.path.commonpath([root, correct_filename]) != root:
            print(f"[{task_instance[KEY_INSTANCE_ID]}] Generation attempted to create file outside of working directory")
            return

        # Make sure the parent directory exists before writing a new file
        os.makedirs(os.path.dirname(correct_filename), exist_ok=True)
        with open(correct_filename, 'w') as f:
            f.write(contents)
        with open(tcm.log_file, 'a') as f_log:
            f_log.write(f'Overwrote {correct_filename}\n')

    # run testing script
    tcm.run_tests_task(task_instance)


def evaluate_predictions(data: dict):
    """
    Sets up task environment context manager. Each prediction is then
    evaluated within the context manager.

    For every task instance the function:
      1. resets the task environment,
      2. tries to turn the instance's `full_output` into a patch that applies,
         using whichever of the `vanilla_patch`, `fuzzy_patch` and
         `custom_patch` strategies are enabled in `data` (each trial
         application is reverted again),
      3. logs the patch that applied, then runs installation and the test
         script in four states: before the extracted patch, with the
         extracted patch, with the extracted patch plus
         `task_instance["patch"]`, and finally with the extracted patch
         reverted.
    An instance is skipped (`continue`) as soon as one of the steps reports
    failure.

    Args:
        data: Dict containing task instances and other data
            task_instances: List of [task instance, prediction] pairs to evaluate
            + setup_testbed args
    """
    # DotDict gives attribute-style access to the entries of `data`
    data_dict = DotDict(data)
    for task_instance in tqdm(
        data_dict.task_instances,
        disable=data_dict.verbose,
        desc=f"Evaluating predictions for {data_dict.log_dir}"
    ):
        with TaskEnvContextManager(
            task_instance,
            data_dict.testbed,
            data_dict.venv,
            data_dict.log_dir,
            data_dict.conda_path,
            verbose=data_dict.verbose,
            timeout=data_dict.timeout,
            is_eval=True,
            log_suffix=data_dict.log_suffix,
        ) as tcm:
            # Attempt to set up environment with task instance
            if not tcm.reset_task_env(task_instance):
                continue

            model_prediction_raw = task_instance["full_output"]
            # Becomes True once one of the enabled strategies yields a patch that applies
            successful = False
            if data_dict.vanilla_patch:
                # Attempt to apply prediction
                patch_type = "pred_try"
                applied_patch = model_prediction_raw
                # NOTE(review): the None check only exists in this branch; with
                # vanilla_patch disabled a None `full_output` is passed on to the
                # extract_* helpers below — confirm they tolerate None.
                if applied_patch is None:
                    continue
                successful = tcm.apply_patch(applied_patch, patch_type=patch_type)
                if not successful:
                    # Raw output did not apply: retry with the minimal patch extracted from it
                    applied_patch = extract_minimal_patch(model_prediction_raw)
                    patch_type = "pred_minimal_try"
                    successful = tcm.apply_patch(applied_patch, patch_type=patch_type)
                if successful:
                    # Undo the trial application; the patch is applied again further below
                    tcm.apply_patch(applied_patch, patch_type=patch_type, revert=True)
            if data_dict.fuzzy_patch:
                if not successful:
                    # Attempt to apply fuzzy patches
                    patch_type = "fuzzy_try"
                    fuzzy_patches = extract_fuzzy_patch(model_prediction_raw)
                    successful = tcm.apply_fuzzy_patches(fuzzy_patches)
                    # Capture the resulting diff, then revert it; the revert is
                    # attempted whether or not the application succeeded
                    applied_patch = tcm.extract_git_diff()
                    tcm.apply_patch(applied_patch, patch_type=patch_type, revert=True)
            if data_dict.custom_patch:
                if not successful:
                    # Attempt to apply custom patches
                    patch_type = "custom_try"
                    custom_patches = extract_custom_patches(model_prediction_raw)
                    successful = tcm.apply_custom_patches(custom_patches, patch_type=patch_type)
                    # Same capture-then-revert step as in the fuzzy branch
                    applied_patch = tcm.extract_git_diff()
                    tcm.apply_patch(applied_patch, patch_type=patch_type, revert=True)
            # No enabled strategy produced an applicable patch: nothing to evaluate
            if not successful:
                continue
            with open(tcm.log_file, 'a') as f_log:
                f_log.write(f'Patch extracted and applied\n')
                f_log.write(json.dumps({"patch": applied_patch, "type": patch_type}) + '\n')

            # Drop the "_try" marker for the real (non-trial) application below
            patch_type = patch_type.replace("_try", "")
            # Run installation + testing script
            # run test: no added tests + no golden patch
            # (the extracted patch is applied as the last step of this condition)
            if not tcm.run_install_task(task_instance) or not tcm.run_tests_task(task_instance) or not tcm.apply_patch(applied_patch, patch_type=patch_type):
                with open(tcm.log_file, 'a') as f_log:
                    f_log.write(f'Installation or testing script failed or test patch could not be applied before run\n')
                continue
            # run test: added tests + no golden patch
            test_pass_pre = tcm.run_tests_task(task_instance)
            # NOTE(review): the golden patch (task_instance["patch"]) is applied
            # under patch_type="test" — confirm this label is intended.
            if not tcm.apply_patch(task_instance["patch"], patch_type="test"):
                with open(tcm.log_file, 'a') as f_log:
                    f_log.write(f'Golden patch could not be applied\n')
                continue
            # run test: added tests + golden patch
            test_pass_post = tcm.run_tests_task(task_instance)
            with open(tcm.log_file, 'a') as f_log:
                f_log.write(f'Tests passed before/after golden patch: {test_pass_pre};{test_pass_post}\n')
            # Revert the extracted patch while leaving the golden patch applied
            if not tcm.apply_patch(applied_patch, patch_type=patch_type, revert=True):
                with open(tcm.log_file, 'a') as f_log:
                    f_log.write(f'Test patch could not be reverted\n')
                continue
            # run test: no added tests + golden patch
            tcm.run_tests_task(task_instance)


def main(args):
    """
    Load the predictions, optionally drop those that already have an
    evaluation log, split the remainder into `num_workers` groups and pass
    each group to `setup_testbed`. A process pool is used unless exactly one
    worker is requested.

    Args:
        args: Parsed command-line namespace; its attributes are forwarded to
            `setup_testbed` alongside each group of task instances.
    """
    if args.num_workers is None:
        args.num_workers = cpu_count()

    predictions = get_instances(args.predictions_path)

    if args.skip_existing:

        def _eval_log_path(pred):
            # Path of the evaluation log this prediction would produce
            if args.log_suffix is None:
                name = f"{pred[KEY_INSTANCE_ID]}.{pred[KEY_MODEL]}.eval.log"
            else:
                name = f"{pred[KEY_INSTANCE_ID]}.{pred[KEY_MODEL]}.{args.log_suffix}.eval.log"
            return os.path.join(args.log_dir, name)

        # Keep only the predictions that have not been evaluated yet
        pending = [p for p in predictions if not os.path.exists(_eval_log_path(p))]
        if not pending:
            return
        predictions = pending

    # One payload per worker: its share of the instances, the evaluation
    # callback, and every command-line option
    data_groups = []
    for group in split_instances(predictions, args.num_workers):
        data_groups.append(
            {
                "task_instances": group,
                "func": evaluate_predictions,
                **vars(args),
            }
        )

    if args.num_workers == 1:
        # Single worker: run in the current process, no pool required
        setup_testbed(data_groups[0])
    else:
        worker_pool = Pool(processes=args.num_workers)
        worker_pool.map(setup_testbed, data_groups)
        worker_pool.close()
        worker_pool.join()


if __name__ == "__main__":
    # Command-line interface; every parsed option is forwarded to main().
    # NOTE(review): evaluate_predictions reads `vanilla_patch`, `fuzzy_patch`
    # and `custom_patch` from its data dict, but no such flags are defined
    # here — confirm where those keys are supposed to come from.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--predictions_path",
        type=str,
        help="Path to predictions instances file",
        required=True,
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        help="Path to log directory",
        required=True,
    )
    parser.add_argument(
        "--conda_link",
        type=str,
        default=None,
        help="(Optional) URL to conda installation to use",
    )
    parser.add_argument(
        "--log_suffix",
        type=str,
        default=None,
        help="(Optional) Suffix to append to log file names",
    )
    parser.add_argument(
        "--num_workers",
        type=int,
        default=1,
        help="(Optional) Number of workers",
    )
    parser.add_argument(
        "--path_conda",
        type=str,
        help="(Optional) Path to miniconda3 or anaconda installation",
    )
    parser.add_argument(
        "--skip_existing",
        action="store_true",
        help="(Optional) Skip existing logs",
    )
    parser.add_argument(
        "--testbed",
        type=str,
        help="(Optional) Path to testbed directory",
    )
    parser.add_argument(
        "--temp_dir",
        type=str,
        help="(Optional) Path to temporary directory for storing virtual envs",
    )
    parser.add_argument(
        "--timeout",
        type=int,
        default=None,
        help="(Optional) Timeout (seconds) for testing script execution",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="(Optional) Verbose mode",
    )
    args = parser.parse_args()
    main(args)
