import re

from .log_parsers import MAP_REPO_TO_PARSER, TestStatus


# Evaluation Log Constants
# Marker strings that are searched for in evaluation / validation log text
# by get_logs_eval and get_logs_gold in this module.
APPLY_PATCH_FAIL = ">>>>> Patch Apply Failed"  # presence => evaluation treated as failed
APPLY_PATCH_PASS = ">>>>> Applied Patch"  # required; "<marker> (pred)" delimits the prediction's output
INSTALL_FAIL = ">>>>> Init Failed"  # NOTE(review): not referenced in this view
INSTALL_PASS = ">>>>> Init Succeeded"  # NOTE(review): not referenced in this view
RESET_FAILED = ">>>>> Reset Failed"  # presence => evaluation treated as failed
TESTS_TIMEOUT = ">>>>> Tests Timed Out"  # presence => evaluation treated as failed
TESTS_ERROR = ">>>>> Tests Errored"  # presence => evaluation treated as failed
TEST_SCRIPT = "Test Script:"  # section header; a validation log must contain exactly four

# Result Categories
# NOTE(review): these labels are not referenced in this view. The names
# suggest "<status before>_TO_<status after>" transitions of a single test
# case, with NULL meaning the test is absent — confirm against callers.
FAIL_TO_PASS = "FAIL_TO_PASS"
FAIL_TO_FAIL = "FAIL_TO_FAIL"
PASS_TO_PASS = "PASS_TO_PASS"
PASS_TO_FAIL = "PASS_TO_FAIL"
FAIL_TO_NULL = "FAIL_TO_NULL"
PASS_TO_NULL = "PASS_TO_NULL"
NULL_TO_FAIL = "NULL_TO_FAIL"
NULL_TO_PASS = "NULL_TO_PASS"


def get_diffs(sm_1: dict, sm_2: dict) -> dict:
    """
    Get differences between two test status maps

    A test case is reported when it appears in only one of the maps, or when
    it appears in both maps with different statuses. Test cases with the same
    status in both maps are omitted.

    Args:
        sm_1 (dict): test case to test status mapping
        sm_2 (dict): test case to test status mapping
    Returns:
        dict: test case to list of differing statuses. For a test case in both
            maps the list is ``[sm_1 status, sm_2 status]``; for a test case in
            only one map the list holds that single status.
    """
    diff_map = {}
    # Walk sm_1 first and then sm_2, so both the key order and the order of
    # the statuses inside each list are deterministic. A set symmetric
    # difference would iterate in hash order, which varies between runs.
    for this_map, other_map in ((sm_1, sm_2), (sm_2, sm_1)):
        for case, status in this_map.items():
            if case in other_map and other_map[case] == status:
                # Identical entry in both maps: not a difference.
                continue
            diff_map.setdefault(case, []).append(status)
    return diff_map


def get_logs_eval(log_fp: str) -> (dict, bool):
    """
    Retrieve evaluation results for a task instance from its corresponding log file

    Args:
        log_fp (str): path to log file; its file name determines which
            repository-specific parser from ``MAP_REPO_TO_PARSER`` is used
    Returns:
        dict: status map produced by the log parser (``{}`` on failure)
        bool: whether the patch applied successfully and no failure marker
            was found in the log
    Raises:
        KeyError: if the repository derived from ``log_fp`` has no entry in
            ``MAP_REPO_TO_PARSER``
    """
    repo = get_repo_from_lp(log_fp)
    log_parser = MAP_REPO_TO_PARSER[repo]

    with open(log_fp) as f:
        content = f.read()

    # Any of these markers means the evaluation did not complete normally.
    failure_markers = (
        APPLY_PATCH_FAIL,
        RESET_FAILED,
        TESTS_ERROR,
        TESTS_TIMEOUT,
        "Failed to reset task environment",
    )
    if any(marker in content for marker in failure_markers) or APPLY_PATCH_PASS not in content:
        # Eval patch was not applied successfully
        return {}, False

    # Only parse the output after the prediction patch was applied; when the
    # "(pred)" marker is absent, split() leaves the whole log as the last piece.
    content = content.split(f"{APPLY_PATCH_PASS} (pred)")[-1]
    return log_parser(content), True


def get_logs_gold(log_fp: str) -> (str, str, str, str):
    """
    Retrieve pre-patch, post-patch test logs from a validation log file

    The log must contain exactly four ``TEST_SCRIPT`` headers. It is split on
    ``TEST_SCRIPT`` and on ``"Coverage Script: "`` markers, and the pieces at
    indices 1, 3, 5 and 7 are returned.

    Args:
        log_fp (str): path to log file
    Returns:
        tuple of four str: the nothing, add-testcase, add-testcase-patch and
            add-patch test logs, in that order; or four ``None`` values when
            the log does not have the expected structure
    """
    with open(log_fp) as f:
        content = f.read()
    # TEST_SCRIPT is a plain literal, so count it as text rather than as a regex.
    if content.count(TEST_SCRIPT) != 4:
        return None, None, None, None
    # Escape the literal markers so any regex metacharacters match verbatim.
    separator = f"{re.escape(TEST_SCRIPT)}|{re.escape('Coverage Script: ')}"
    logs = re.split(separator, content)
    if len(logs) < 8:
        # Report the malformed log file so it can be inspected.
        print(log_fp)
        return None, None, None, None
    return logs[1], logs[3], logs[5], logs[7]


def get_file_name_from_lp(x: str) -> str:
    """Return the last '/'-separated component of log path ``x``."""
    return x.rsplit("/", 1)[-1]


def get_id_from_lp(x: str) -> str:
    """Return the file name of log path ``x`` up to (not including) its first '.'."""
    return get_file_name_from_lp(x).split(".")[0]


def get_repo_from_lp(x: str) -> str:
    """
    Derive the repository name from log path ``x``.

    Drops the last '-'-separated segment of the id and replaces "__" with "/",
    e.g. ``"logs/django__django-12345.eval.log"`` -> ``"django/django"``.
    """
    return get_id_from_lp(x).rsplit("-", 1)[0].replace("__", "/")


def log_path_to_sms(log_fp: str, log_parser) -> (list, bool):
    """
    Parse each test-log section of a validation log file into a status map

    Args:
        log_fp (str): path to log file
        log_parser (function): function that turns one log section into a
            status map
    Returns:
        list: one status map per section returned by ``get_logs_gold``
            (``None`` on failure)
        bool: whether or not log file was parsed properly
    """
    logs = get_logs_gold(log_fp)
    if any(log is None for log in logs):
        # Skip if the log file did not have the expected sections
        return None, False

    try:
        sms = [log_parser(log) for log in logs]
    except Exception as e:
        # Skip if log file was not parsed properly. Return here: falling
        # through would iterate over ``None`` below and raise TypeError.
        print(f"Error parsing log {log_fp}: {e}")
        return None, False

    if any(sm is None for sm in sms):
        # Skip if the parser produced no status map for some section
        return None, False

    return sms, True


def test_passed(case, sm) -> bool:
    """Return True if ``case`` is in status map ``sm`` with the PASSED status."""
    return case in sm and sm[case] == TestStatus.PASSED.value


def test_failed(case, sm) -> bool:
    """Return True if ``case`` has the FAILED or ERROR status in status map ``sm``."""
    return sm.get(case) in (TestStatus.FAILED.value, TestStatus.ERROR.value)


def test_missing(case, sm) -> bool:
    """Return True if ``case`` has no entry in status map ``sm``."""
    return case not in sm
