# flake8: noqa
# type: ignore

"""Core function for HumanEval and APPS evaluation.

Code for evaluating APPS adapted from
    https://github.com/hendrycks/apps/blob/main/eval/test_one_solution.py
Code for HumanEval adapted from
    https://github.com/openai/human-eval

The original files are very messy and would trigger type-check errors.
We want to only minimally modify these to ensure the functionality isn't modified.
We split this out so we can disable type checking for this entire file.
"""

import contextlib
import gc
from enum import Enum
import faulthandler
import io
from io import StringIO
import json
import multiprocessing
import os
import platform
import signal
import sys
import tempfile
from typing import List, Union, Dict, Optional
from unittest.mock import patch, mock_open

import numpy as np
from pyext import RuntimeModule

from helm.common.hierarchical_logger import hlog


# === APPS evaluation below ===
class CodeType(Enum):
    """How `run_test` exercises an APPS solution."""

    # The problem's `input_output.json` names a function (`fn_name`) that is called per test case.
    call_based = 0
    # No `fn_name` is given: the solution is run as a script fed through stdin.
    standard_input = 1


class TimeoutException(Exception):
    """Signals that code under evaluation ran past its time budget."""


def timeout_handler(signum, frame):
    """Signal handler that aborts the running code by raising `TimeoutException`.

    Both arguments are supplied by the `signal` machinery and are ignored.
    """
    raise TimeoutException()


# Import-time side effect: route SIGALRM to `timeout_handler`, so that an alarm timer expiring
# while `run_test` is executing a solution surfaces as a `TimeoutException`.
signal.signal(signal.SIGALRM, timeout_handler)


# lxuechen: The default (global) timeout given in the original APPS codebase is 4. But we found this yielded evaluation
#   runs that are just too slow. We thus make this variable settable in `run_test`.
# timeout = 4  # seconds


# Captures stdout as a list of lines.
# Adapted from https://stackoverflow.com/a/16571630/6416660
# An alternative is to use redirect_stdout() from contextlib.
class Capturing(list):
    """Context manager that collects everything written to stdout as a list of lines.

    On entry `sys.stdout` is replaced by an in-memory buffer; on exit the buffer is split
    into lines, appended to this list, and the previous `sys.stdout` is put back.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Code under test may call `sys.stdout.close()`; make that a no-op so the buffer can
        # still be read in `__exit__`. The replacement lives on the instance (it is not a
        # bound method), so it has to accept being called with no arguments at all.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Always hand stdout back, even if reading the buffer failed.
            del self._stringio  # free up some memory
            sys.stdout = self._stdout


def run_test(root: str, test: str, timeout: float) -> List[Union[int, bool]]:
    """Run a candidate solution against every test case of one APPS problem.

    The problem's `input_output.json` decides the mode: if it has a `fn_name`, that function
    (or `Solution().fn_name`) is called with each input; otherwise the solution is wrapped in
    a `code()` function and run as a script with the input served through stdin.

    Timeouts rely on the SIGALRM handler installed at module import time.

    Args:
        root: Path to the problem folder; it must contain `input_output.json`.
        test: Candidate source code solution in string format.
        timeout: Timeout in seconds (fractional values are accepted). Reduce to save time at
            the cost of accuracy. There are two places where timeout can happen:
                1) runtime function creation based on source code string, and
                2) evaluation of a test case.
    Returns:
        A list of numbers/booleans, normally one for each test case. Interpretation for all
        types of values:
            -3: didn't find the function for evaluation
            -2: compile error
            -1: runtime error
            True: passed
            False: failed
        A compile error (-2) or a missing function (-3) ends the run immediately, so the
        list then holds just that single entry.

        The -3 entry is an additional type compared to the original code of APPS authors
            https://github.com/hendrycks/apps/blob/main/eval/testing_util.py
    Raises:
        FileNotFoundError: If `input_output.json` does not exist under `root`.
    """

    def _arm_timer():
        # setitimer accepts fractional seconds; signal.alarm() only takes integers.
        signal.setitimer(signal.ITIMER_REAL, timeout)

    def _cancel_timer():
        signal.setitimer(signal.ITIMER_REAL, 0)

    input_output_fname = os.path.join(root, "input_output.json")
    # Open unconditionally: a missing file fails right here with FileNotFoundError rather
    # than later with an UnboundLocalError on `which_type`.
    with open(input_output_fname) as f:
        in_outs = json.load(f)
    if in_outs.get("fn_name") is None:
        which_type = CodeType.standard_input  # Standard input
        method_name = None
    else:
        which_type = CodeType.call_based  # Call-based
        method_name = in_outs["fn_name"]

    results = []
    # Imports that solutions commonly assume are available.
    sol = (
        "import sys\nimport time\nimport itertools\nfrom itertools import accumulate, product, permutations, "
        "combinations\nimport collections\nfrom collections import Counter, OrderedDict, deque, defaultdict, "
        "ChainMap\nfrom functools import lru_cache\nimport math\nfrom math import sqrt, sin, cos, tan, ceil, "
        "fabs, floor, gcd, exp, log, log2\nimport fractions\nfrom typing import List, Tuple\nimport numpy as "
        "np\nimport random\nimport heapq\nfrom heapq import *\n"
    )

    if which_type == CodeType.call_based:
        sol += test
        _arm_timer()
        try:
            tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
            if "class Solution" not in test:
                tmp = tmp_sol
            else:
                tmp = tmp_sol.Solution()
            _cancel_timer()
        except Exception:
            _cancel_timer()
            hlog("type 0 compilation error")
            results.append(-2)
            return results

    elif which_type == CodeType.standard_input:
        # Indent every non-import line so the whole script can become the body of `code()`.
        tmp_test = test.split("\n")

        new_test = []
        for x in tmp_test:
            if (not x.startswith("from ")) and (not x.startswith("import ")):
                new_test.append("\t" + x + "\n")
            else:
                new_test.append(x + "\n")
        tmp_test = new_test

        # Leading imports stay at module level; the first indented line opens `def code():`,
        # and any import appearing after that point is indented into the function as well.
        new_test = ""
        started = False
        for i in tmp_test:
            if i.startswith("\t") and not started:
                new_test += "stdin = sys.stdin\nstdout = sys.stdout\n"
                new_test += "def code():\n"
                new_test += i
                started = True
            elif started and ((i.startswith("from ")) or (i.startswith("import "))):
                new_test += "\t" + i
            else:
                new_test += i
        tmp_test = new_test

        sol += tmp_test
        method_name = "code"
        _arm_timer()
        try:
            tmp_sol = RuntimeModule.from_string("tmp_sol", "", sol)
            tmp = tmp_sol
            _cancel_timer()
        except Exception:
            _cancel_timer()
            hlog("type 1 compilation error")
            results.append(-2)
            return results

    try:
        method = getattr(tmp, method_name)  # get_attr second arg must be str
    except Exception:
        _cancel_timer()
        hlog("unable to get function error")
        results.append(-3)
        return results

    for index, inputs in enumerate(in_outs["inputs"]):
        gc.collect()

        # JSON forces dictionaries to have string keys; this undoes this (assuming a singleton list).
        # Each probe is best-effort: data that does not have the probed shape is left untouched.
        try:
            if isinstance(inputs[0], dict):
                inputs = [{int(k): v for k, v in inputs[0].items()}]
        except Exception:
            pass

        try:
            if isinstance(in_outs["outputs"][index], dict):
                in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index].items()}]
        except Exception:
            pass

        try:
            if isinstance(in_outs["outputs"][index][0], dict):
                in_outs["outputs"][index] = [{int(k): v for k, v in in_outs["outputs"][index][0].items()}]
        except Exception:
            pass

        if which_type == CodeType.call_based:  # Call-based
            _arm_timer()
            faulthandler.enable()
            try:
                output = method(*inputs)

                # ground truth sequences are not tuples
                if isinstance(output, tuple):
                    output = list(output)

                tmp_result = output == in_outs["outputs"][index]
                if isinstance(in_outs["outputs"][index], list) and in_outs["outputs"][index]:
                    tmp_result = tmp_result or (output == in_outs["outputs"][index][0])

                # ground truth sequences are not tuples
                try:
                    if isinstance(output[0], tuple):
                        tmp_result = tmp_result or ([list(x) for x in output] == in_outs["outputs"][index][0])
                except Exception:
                    pass
                results.append(tmp_result)

                # reset the alarm
                _cancel_timer()
            except Exception:
                _cancel_timer()
                faulthandler.disable()
                hlog("Call-based runtime error or time limit exceeded error")
                results.append(-1)
                continue
            faulthandler.disable()
        elif which_type == CodeType.standard_input:  # Standard input
            faulthandler.enable()
            _arm_timer()
            passed = False

            if isinstance(inputs, list):
                inputs = "\n".join(inputs)
            if isinstance(in_outs["outputs"][index], list):
                in_outs["outputs"][index] = "\n".join(in_outs["outputs"][index])

            with Capturing() as output:
                try:
                    call_method(method, inputs)
                    # reset the alarm
                    _cancel_timer()
                    passed = True
                except Exception:
                    # runtime error or took too long
                    _cancel_timer()
                    hlog("Standard input runtime error or time limit exceeded error")
                    results.append(-1)
                _cancel_timer()

            if not passed:
                continue

            # From here on, progressively more lenient comparisons are tried; the first one
            # that succeeds records True for this test case.
            if custom_compare_(output, in_outs["outputs"][index]):
                tmp_result = True
                results.append(tmp_result)
                continue

            # ground truth sequences are expressed as lists not tuples
            if isinstance(output, tuple):
                output = list(output)

            tmp_result = False
            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
                    if isinstance(output[0], str):
                        tmp_result = tmp_result or ([e.strip() for e in output] == in_outs["outputs"][index])
            except Exception:
                hlog("Failed check1 exception")

            if isinstance(tmp_result, bool) and tmp_result:
                results.append(tmp_result)
                continue

            # try one more time without \n
            if isinstance(in_outs["outputs"][index], list):
                for tmp_index, i in enumerate(in_outs["outputs"][index]):
                    in_outs["outputs"][index][tmp_index] = i.split("\n")
                    in_outs["outputs"][index][tmp_index] = [
                        x.strip() for x in in_outs["outputs"][index][tmp_index] if x
                    ]
            else:
                in_outs["outputs"][index] = in_outs["outputs"][index].split("\n")
                in_outs["outputs"][index] = list(filter(len, in_outs["outputs"][index]))
                in_outs["outputs"][index] = list(map(lambda x: x.strip(), in_outs["outputs"][index]))

            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
            except Exception:
                hlog("Failed check2 exception")

            if isinstance(tmp_result, bool) and tmp_result:
                results.append(tmp_result)
                continue

            # try by converting the output into a split up list too
            if isinstance(output, list):
                output = list(filter(len, output))

            try:
                tmp_result = output == [in_outs["outputs"][index]]
                if isinstance(in_outs["outputs"][index], list):
                    tmp_result = tmp_result or (output == in_outs["outputs"][index])
            except Exception:
                hlog("Failed check3 exception")

            # Numeric comparison with tolerance, first on the flat output, then on a nested one.
            try:
                output_float = [float(e) for e in output]
                gt_float = [float(e) for e in in_outs["outputs"][index]]
                tmp_result = tmp_result or (
                    (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)
                )
            except Exception:
                pass
            try:
                if isinstance(output[0], list):
                    output_float = [float(e) for e in output[0]]
                    gt_float = [float(e) for e in in_outs["outputs"][index][0]]
                    tmp_result = tmp_result or (
                        (len(output_float) == len(gt_float)) and np.allclose(output_float, gt_float)
                    )
            except Exception:
                pass

            if isinstance(tmp_result, bool) and tmp_result:
                results.append(tmp_result)
                continue

            # try by converting the stuff into split up list
            if isinstance(in_outs["outputs"][index], list):
                for tmp_index, i in enumerate(in_outs["outputs"][index]):
                    in_outs["outputs"][index][tmp_index] = set(i.split())
            else:
                in_outs["outputs"][index] = set(in_outs["outputs"][index].split())

            try:
                tmp_result = output == in_outs["outputs"][index]
            except Exception:
                # NOTE(review): this path records no entry for the test case, so `results`
                # ends up shorter than the number of test cases. Kept as in the original.
                hlog("Failed check4 exception")
                continue

            if isinstance(tmp_result, bool) and tmp_result:
                results.append(tmp_result)
                continue

            # try by converting the output into a split up list too
            if isinstance(output, list):
                for tmp_index, i in enumerate(output):
                    output[tmp_index] = i.split()
                output = list(filter(len, output))
                for tmp_index, i in enumerate(output):
                    output[tmp_index] = set(i)
            else:
                output = output.split()
                output = list(filter(len, output))
                output = set(output)

            try:
                tmp_result = set(frozenset(s) for s in output) == set(frozenset(s) for s in in_outs["outputs"][index])
            except Exception:
                hlog("Failed check5 exception")

            # if they are all numbers, round so that similar numbers are treated as identical
            try:
                tmp_result = tmp_result or (
                    set(frozenset(round(float(t), 3) for t in s) for s in output)
                    == set(frozenset(round(float(t), 3) for t in s) for s in in_outs["outputs"][index])
                )
            except Exception:
                hlog("Failed check6 exception")

            # By default, tmp_result is False; that's the 'catch' case.
            results.append(tmp_result)

    return results


def custom_compare_(output, ground_truth):
    """Leniently compare captured output lines against the expected output string.

    Only a list `output` can match. Two renderings of it are tried in turn: the lines
    joined by newlines as they are, then the same with each line stripped of surrounding
    whitespace. Either is compared to `ground_truth` with `stripped_string_compare`.
    """
    if not isinstance(output, list):
        return False

    joined_raw = "\n".join(output)
    if stripped_string_compare(joined_raw, ground_truth):
        return True

    joined_trimmed = "\n".join(line.lstrip().rstrip() for line in output)
    if stripped_string_compare(joined_trimmed, ground_truth):
        return True

    return False


def stripped_string_compare(s1, s2):
    """Return whether two strings are equal once leading/trailing whitespace is ignored."""
    return s1.strip() == s2.strip()


def call_method(method, inputs):
    """Call `method()` with `inputs` served as stdin and as the content of any opened file.

    Args:
        method: Zero-argument callable to run.
        inputs: The input text, or a list of lines that is joined with newlines.
    Returns:
        Whatever `method()` returns, or None if it raised `SystemExit`.

    NOTE(review): stacked `patch` decorators are entered bottom-up, so the `read`,
    `readlines` and `readline` replacements are installed on the `sys.stdin` object that
    exists when the call starts, and only then is `sys.stdin` swapped for the StringIO.
    The decorator order is therefore kept exactly as it was.
    """
    if isinstance(inputs, list):
        inputs = "\n".join(inputs)

    line_source = iter(inputs.split("\n"))

    @patch("builtins.open", mock_open(read_data=inputs))
    @patch("sys.stdin", StringIO(inputs))
    @patch("sys.stdin.readline", lambda *args: next(line_source))
    @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
    @patch("sys.stdin.read", lambda *args: inputs)
    def _run_patched(target):
        try:
            return target()
        except SystemExit:
            # A solution that calls exit()/sys.exit() is treated as having finished normally.
            return None

    return _run_patched(method)


# === HumanEval below ===
def check_correctness(problem: Dict, completion: str, timeout: float, completion_id: Optional[int] = None) -> Dict:
    """Evaluate the functional correctness of a completion by running the problem's test suite.

    The check program (prompt + completion + tests + a `check(entry_point)` call) is executed
    in a separate process, inside a temporary working directory.

    :param problem: mapping that provides "prompt", "test", "entry_point" and "task_id".
    :param completion: generated code that is appended to the prompt.
    :param timeout: time budget in seconds for the check program; the parent process waits
        one extra second before killing the worker.
    :param completion_id: an optional completion ID so we can match
        the results later even if execution finishes asynchronously.
    :return: dict with "task_id", "passed", "result" ("passed", "timed out" or
        "failed: <error>") and "completion_id".
    """

    def unsafe_execute():

        with create_tempdir():

            # reliability_guard() nulls these out, but cleaning up the tempdir needs them,
            # so keep references and put them back at the end.
            import os
            import shutil

            saved_rmtree, saved_rmdir, saved_chdir = shutil.rmtree, os.rmdir, os.chdir

            # Disable functionalities that can make destructive changes to the test.
            reliability_guard()

            # Construct the check program and run it.
            program_parts = [
                problem["prompt"],
                "    ",
                completion,
                "\n",
                problem["test"],
                "\n",
                f"check({problem['entry_point']})",
            ]
            check_program = "".join(program_parts)

            try:
                exec_globals = {}
                with swallow_io(), time_limit(timeout):
                    # WARNING
                    # This program exists to execute untrusted model-generated code. Although
                    # it is highly unlikely that model-generated code will do something overtly
                    # malicious in response to this test suite, model-generated code may act
                    # destructively due to a lack of model capability or alignment.
                    # Users are strongly encouraged to sandbox this evaluation suite so that it
                    # does not perform destructive actions on their host or network. For more
                    # information on how OpenAI sandboxes its code, see the accompanying paper.
                    exec(check_program, exec_globals)
                result.append("passed")
            except TimeoutException:
                result.append("timed out")
            except BaseException as e:
                result.append(f"failed: {e}")

            # Needed for cleaning up.
            shutil.rmtree = saved_rmtree
            os.rmdir = saved_rmdir
            os.chdir = saved_chdir

    # A manager-backed list lets the worker process report its outcome to this process.
    manager = multiprocessing.Manager()
    result: List[str] = manager.list()

    worker = multiprocessing.Process(target=unsafe_execute)
    worker.start()
    worker.join(timeout=timeout + 1)
    if worker.is_alive():
        worker.kill()

    # No outcome reported means the worker never finished in time.
    if not result:
        result.append("timed out")

    outcome = result[0]
    return {
        "task_id": problem["task_id"],
        "passed": outcome == "passed",
        "result": outcome,
        "completion_id": completion_id,
    }


def reliability_guard(maximum_memory_bytes: Optional[int] = None):
    """
    This disables various destructive functions and prevents the generated code
    from interfering with the test (e.g. fork bomb, killing other processes,
    removing filesystem files, etc.)
    WARNING
    This function is NOT a security sandbox. Untrusted code, including, model-
    generated code, should not be blindly executed outside of one. See the
    Codex paper for more information about OpenAI's code sandbox, and proceed
    with caution.

    The changes are process-wide and are not undone here: the listed attributes of
    `builtins`, `os`, `shutil` and `subprocess` are replaced by None, and several modules
    are blocked from import by mapping them to None in `sys.modules`.

    :param maximum_memory_bytes: if given, used as both soft and hard limit for the address
        space, data segment and (except on Darwin) stack size of the current process.
    """

    if maximum_memory_bytes is not None:
        import resource

        limits = (maximum_memory_bytes, maximum_memory_bytes)
        resource.setrlimit(resource.RLIMIT_AS, limits)
        resource.setrlimit(resource.RLIMIT_DATA, limits)
        if platform.uname().system != "Darwin":
            resource.setrlimit(resource.RLIMIT_STACK, limits)

    faulthandler.disable()

    import builtins

    builtins.exit = None
    builtins.quit = None
    # Go through the `builtins` module rather than `__builtins__[...]`: `__builtins__` is a
    # dict only in imported modules and is the module object itself when run as `__main__`.
    builtins.help = None

    import os

    # Must happen before `os.putenv` is disabled below.
    os.environ["OMP_NUM_THREADS"] = "1"

    disabled_os_functions = (
        "kill",
        "system",
        "putenv",
        "remove",
        "removedirs",
        "rmdir",
        "fchdir",
        "setuid",
        "fork",
        "forkpty",
        "killpg",
        "rename",
        "renames",
        "truncate",
        "replace",
        "unlink",
        "fchmod",
        "fchown",
        "chmod",
        "chown",
        "chroot",
        "lchflags",
        "lchmod",
        "lchown",
        "getcwd",
        "chdir",
    )
    for name in disabled_os_functions:
        setattr(os, name, None)

    import shutil

    for name in ("rmtree", "move", "chown"):
        setattr(shutil, name, None)

    import subprocess

    subprocess.Popen = None  # type: ignore

    import sys

    # A None entry in sys.modules makes any later `import <name>` fail.
    for name in ("ipdb", "joblib", "resource", "psutil", "tkinter"):
        sys.modules[name] = None


@contextlib.contextmanager
def time_limit(seconds: float):
    """Context manager that raises `TimeoutException` if its body runs longer than `seconds`.

    Implemented with the real-time interval timer and SIGALRM. On exit the timer is
    cancelled and the SIGALRM handler that was active on entry is reinstated.

    :param seconds: time budget in seconds; fractional values are accepted.
    """

    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Install the handler *before* arming the timer so that an alarm firing immediately
    # cannot be delivered to whatever handler happened to be in place.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the previous handler was not installed from
        # Python; such a handler cannot be reinstated from here.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


@contextlib.contextmanager
def swallow_io():
    """Redirect stdout, stderr and stdin to a single `WriteOnlyStringIO` for the duration.

    Anything written is captured in a buffer that is dropped on exit; reads from stdin
    raise because the buffer refuses to be read.
    """
    sink = WriteOnlyStringIO()
    with contextlib.redirect_stdout(sink), contextlib.redirect_stderr(sink), redirect_stdin(sink):
        yield


@contextlib.contextmanager
def create_tempdir():
    """Create a temporary directory, make it the working directory, and yield its path.

    On exit the previous working directory is restored (by `chdir`) and the temporary
    directory is removed (by `tempfile.TemporaryDirectory`).
    """
    with tempfile.TemporaryDirectory() as scratch_dir, chdir(scratch_dir):
        yield scratch_dir


class TimeoutException(Exception):
    """Raised when evaluated code exceeds its time limit (see `time_limit`).

    NOTE: a class with this name is also declared near the top of this module; this later
    definition is the one the module-level name refers to once the module has loaded.
    """


class WriteOnlyStringIO(io.StringIO):
    """In-memory text buffer that accepts writes but refuses every read call."""

    def read(self, *args, **kwargs):
        """Always fail: this stream cannot be read."""
        raise IOError()

    def readline(self, *args, **kwargs):
        """Always fail: this stream cannot be read."""
        raise IOError()

    def readlines(self, *args, **kwargs):
        """Always fail: this stream cannot be read."""
        raise IOError()

    def readable(self, *args, **kwargs):
        """Report that the stream is not readable."""
        return False


class redirect_stdin(contextlib._RedirectStream):  # type: ignore
    """Context manager that temporarily rebinds `sys.stdin`, in the manner of `contextlib.redirect_stdout`.

    Built on contextlib's private `_RedirectStream` base class.
    """

    # Name of the `sys` attribute that the base class swaps on enter and restores on exit.
    _stream = "stdin"


@contextlib.contextmanager
def chdir(root):
    """Context manager that runs its body with `root` as the working directory.

    The previous working directory is restored on exit, whether the body finished
    normally or raised. Passing "." is a no-op.

    :param root: directory to switch to.
    """
    if root == ".":
        yield
        return
    previous_dir = os.getcwd()
    os.chdir(root)
    try:
        yield
    finally:
        # `finally` alone guarantees the restore; exceptions propagate unchanged, so no
        # catch-and-re-raise is needed.
        os.chdir(previous_dir)
