# Original Code: https://github.com/LiveCodeBench/LiveCodeBench/blob/f05cda286956b0a976df08afe2e2a323358d32d1/lcb_runner/evaluation/testing_util.py
# We slightly modified the execution method by generating a Python file and executing it.
import ast
import faulthandler
import json
import os
import platform
import signal
import subprocess
import sys
import tempfile
from enum import Enum
from io import StringIO
from typing import Any, Dict, List, Optional, Tuple

import numpy as np


def truncatefn(s, length=300):
    """Shorten *s* for logging, keeping its head and tail.

    Args:
        s: The string to shorten. Must be a ``str``.
        length: Maximum number of characters of *s* to keep. Values below
            zero are treated as zero.

    Returns:
        *s* unchanged when ``len(s) <= length``; otherwise the first
        ``length // 2`` characters, the marker ``"...(truncated) ..."`` and
        the last ``length - length // 2`` characters.

    Raises:
        AssertionError: If *s* is not a ``str``.
    """
    assert isinstance(s, str)
    if len(s) <= length:
        return s
    length = max(length, 0)
    head = length // 2
    tail = length - head
    # Index from the front for the tail: a ``s[-tail:]`` slice would return the
    # whole string when ``tail`` is 0 instead of nothing.
    return s[:head] + "...(truncated) ..." + s[len(s) - tail :]


class CODE_TYPE(Enum):
    """How the harness drives a solution under test."""

    # A named function/method is called directly; chosen by
    # determine_code_type when the test spec carries an "fn_name".
    call_based = 0
    # The test input is fed on stdin and the program's stdout is compared.
    standard_input = 1


class TimeoutException(Exception):
    """Raised by ``timeout_handler`` when the alarm signal is delivered."""


def timeout_handler(signum, frame):
    """Signal handler: announce the alarm on stdout, then raise TimeoutException.

    Both arguments are the standard signal-handler parameters and are unused.
    """
    print("alarm went off")
    raise TimeoutException()


# Install timeout_handler as the SIGALRM handler as a side effect of importing
# this module.
# NOTE(review): no signal.alarm()/setitimer() call is visible in this file, so
# the handler looks unused here — confirm against other callers before removing.
signal.signal(signal.SIGALRM, timeout_handler)


class Capturing(list):
    """Context manager that captures everything written to ``sys.stdout``.

    While active, ``sys.stdout`` is replaced by an in-memory ``StringIO``.
    On exit the captured text is appended to this list as a single string and
    the previous ``sys.stdout`` is put back.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        # Make closing the StringIO a no-op so code that closes stdout does not
        # destroy the captured text. The replacement is stored on the instance
        # and therefore called without ``self``; it must accept zero arguments.
        self._stringio.close = lambda *args, **kwargs: None
        return self

    def __exit__(self, *args):
        try:
            self.append(self._stringio.getvalue())
        finally:
            # Always restore the real stdout, even if reading the buffer fails.
            del self._stringio
            sys.stdout = self._stdout


class reliability_context:
    """
    Context manager to apply reliability_guard inside its scope and revert after exiting.

    While active, a fixed set of destructive functions on ``builtins``, ``os``
    and ``shutil`` is replaced by ``None``, a few modules are made
    un-importable by storing ``None`` in ``sys.modules``, ``OMP_NUM_THREADS``
    is forced to ``"1"`` and ``faulthandler`` is disabled. ``os.kill`` and
    ``subprocess.Popen`` are deliberately left untouched because the harness
    needs them to run and stop solutions.

    On exit every replaced attribute, module entry, the environment variable
    and any changed resource limits are put back. Names that did not exist
    before entry are removed again rather than left behind as ``None``.
    ``faulthandler`` is not re-enabled.

    Args:
        maximum_memory_bytes: Optional soft limit, in bytes, applied to the
            address space, data segment and (except on macOS) stack of the
            current process while the context is active.
    """

    # Attribute names that are set to None while the guard is active.
    _BUILTIN_NAMES = ("exit", "quit", "help")
    _OS_NAMES = (
        "system",
        "putenv",
        "remove",
        "removedirs",
        "rmdir",
        "fchdir",
        "setuid",
        "fork",
        "forkpty",
        "killpg",
        "rename",
        "renames",
        "truncate",
        "replace",
        "unlink",
        "fchmod",
        "fchown",
        "chmod",
        "chown",
        "chroot",
        "lchflags",
        "lchmod",
        "lchown",
        "getcwd",
        "chdir",
    )
    _SHUTIL_NAMES = ("rmtree", "move", "chown")
    # Modules whose import is blocked while the guard is active.
    _BLOCKED_MODULES = ("ipdb", "joblib", "resource", "psutil", "tkinter")
    _THREAD_ENV_VAR = "OMP_NUM_THREADS"

    def __init__(self, maximum_memory_bytes=None):
        self.maximum_memory_bytes = maximum_memory_bytes
        # Saved values, keyed "exit"/"quit"/"help", "os_<name>",
        # "shutil_<name>" and the blocked module names. None is stored for
        # names that did not exist.
        self.original_state = {}
        # Keys of original_state whose name did not exist before the guard.
        self._absent_keys = set()
        self._env_saved = False
        self._env_previous = None
        self._resource_module = None
        self._previous_rlimits = []

    def __enter__(self):
        self.apply_reliability_guard(self.maximum_memory_bytes)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.restore_reliability_guard()

    def _guarded_attributes(self, builtins_mod, os_mod, shutil_mod):
        """Yield ``(state_key, owner_module, attribute_name)`` for every guarded name."""
        for name in self._BUILTIN_NAMES:
            yield name, builtins_mod, name
        for name in self._OS_NAMES:
            yield "os_" + name, os_mod, name
        for name in self._SHUTIL_NAMES:
            yield "shutil_" + name, shutil_mod, name

    def _apply_memory_limits(self, maximum_memory_bytes):
        """Lower the soft memory limits, remembering the previous values.

        Only the soft limit is changed (clamped to the existing hard limit) so
        that the previous limits can be restored on exit. Failures on a single
        limit are skipped; this is best effort.
        """
        try:
            import resource
        except ImportError:
            # Not available on this platform, or blocked by an enclosing guard.
            return

        self._resource_module = resource
        limits = [resource.RLIMIT_AS, resource.RLIMIT_DATA]
        if platform.uname().system != "Darwin":
            limits.append(resource.RLIMIT_STACK)

        for limit in limits:
            try:
                soft, hard = resource.getrlimit(limit)
                new_soft = maximum_memory_bytes
                if hard != resource.RLIM_INFINITY:
                    new_soft = min(new_soft, hard)
                resource.setrlimit(limit, (new_soft, hard))
            except (ValueError, OSError):
                continue
            self._previous_rlimits.append((limit, (soft, hard)))

    def apply_reliability_guard(self, maximum_memory_bytes=None):
        """Save the current state and disable the guarded functionality."""
        import builtins
        import os
        import shutil
        import sys

        self._absent_keys = set()
        self._previous_rlimits = []

        # Must happen before "resource" is blocked in sys.modules below,
        # otherwise the import fails and no limit is ever applied.
        if maximum_memory_bytes is not None:
            self._apply_memory_limits(maximum_memory_bytes)

        faulthandler.disable()

        # Must happen before os.putenv is disabled: assigning to os.environ
        # calls os.putenv internally.
        self._env_previous = os.environ.get(self._THREAD_ENV_VAR)
        self._env_saved = True
        os.environ[self._THREAD_ENV_VAR] = "1"

        for key, owner, name in self._guarded_attributes(builtins, os, shutil):
            if hasattr(owner, name):
                self.original_state[key] = getattr(owner, name)
            else:
                self.original_state[key] = None
                self._absent_keys.add(key)
            setattr(owner, name, None)

        for name in self._BLOCKED_MODULES:
            if name in sys.modules:
                self.original_state[name] = sys.modules[name]
            else:
                self.original_state[name] = None
                self._absent_keys.add(name)
            # A None entry makes any later "import <name>" raise ImportError.
            sys.modules[name] = None

    def restore_reliability_guard(self):
        """Undo everything apply_reliability_guard changed (except faulthandler)."""
        import builtins
        import os
        import shutil
        import sys

        for key, owner, name in self._guarded_attributes(builtins, os, shutil):
            if key in self._absent_keys:
                # The name did not exist before; remove the None placeholder.
                if hasattr(owner, name):
                    delattr(owner, name)
            elif key in self.original_state:
                setattr(owner, name, self.original_state[key])

        for name in self._BLOCKED_MODULES:
            if name in self._absent_keys:
                # Drop the None entry so the module can be imported again.
                sys.modules.pop(name, None)
            elif name in self.original_state:
                sys.modules[name] = self.original_state[name]

        # os.putenv has been restored above, so os.environ is writable again.
        if self._env_saved:
            if self._env_previous is None:
                os.environ.pop(self._THREAD_ENV_VAR, None)
            else:
                os.environ[self._THREAD_ENV_VAR] = self._env_previous
            self._env_saved = False

        for limit, previous in reversed(self._previous_rlimits):
            try:
                self._resource_module.setrlimit(limit, previous)
            except (ValueError, OSError):
                # Best effort: the limit could not be raised back.
                pass
        self._previous_rlimits = []


def run_test(
    sample: Dict[str, Any],
    test: Optional[str] = None,
    debug: bool = False,
    timeout: int = 6,
) -> Tuple[List[str], List[str], List[int], Dict[str, Any]]:
    """Run the solution source *test* against every case described by *sample*.

    The whole run happens inside ``reliability_context()``.

    Args:
        sample: Mapping whose ``"input_output"`` entry is a JSON string with
            ``"inputs"``, ``"outputs"`` and optionally ``"fn_name"``.
        test: Source code of the solution; ``None`` yields an error result.
        debug: Forwarded to the helpers to enable diagnostic prints.
        timeout: Per-case time limit in seconds, forwarded to the runner.

    Returns:
        ``(expected_outputs, generated_outputs, results, errors)``.
        ``results`` holds ``1`` for a passed case and ``0`` otherwise. On the
        normal path ``errors`` is a list with one dict per case (empty for a
        pass); on the early-exit paths it is a single dict and ``results`` is
        ``[]`` (or ``[-2]`` when ``compile_solution`` reports an error).
    """
    with reliability_context():
        in_outs = parse_input_output(sample, debug=debug)
        if in_outs is None:
            return [""], [""], [], {"error": "Failed to parse input_output"}

        expected_outputs = [json.dumps(o) for o in in_outs["outputs"]]
        generated_outputs = ["" for _ in in_outs["outputs"]]

        code_type, method_name = determine_code_type(in_outs)

        if test is None:
            no_code = {"error": "no test code provided"}
            return expected_outputs, generated_outputs, [], no_code

        base_sol = build_base_solution_imports()
        if debug:
            print(
                f"Base imports prepared. Code type: {code_type}, method: {method_name}"
            )

        sol_code, compile_error = compile_solution(
            base_sol, test, code_type, debug, timeout
        )
        if compile_error:
            return expected_outputs, generated_outputs, [-2], compile_error

        # Call-based solutions may be a bare function or a ``Solution`` method.
        has_solution_class = ("class Solution" in test) and (
            code_type == CODE_TYPE.call_based
        )

        results = []
        err_infos = []
        all_expected_outputs = []
        all_generated_outputs = []

        for case_index, _ in enumerate(in_outs["inputs"]):
            res, err_info, gen_output, exp_output = run_single_test_case(
                code_type=code_type,
                method=method_name,
                in_outs=in_outs,
                index=case_index,
                debug=debug,
                timeout=timeout,
                sol_code=sol_code,
                has_solution_class=has_solution_class,
            )

            # Fall back to the raw expectation / an empty string when the case
            # produced no output (it failed before comparison).
            if exp_output is None:
                exp_output = json.dumps(in_outs["outputs"][case_index])
            all_expected_outputs.append(exp_output)
            all_generated_outputs.append("" if gen_output is None else gen_output)

            passed = res is True
            results.append(1 if passed else 0)
            err_infos.append({} if passed else err_info)

        return all_expected_outputs, all_generated_outputs, results, err_infos


def parse_input_output(
    sample: Dict[str, Any], debug: bool = False
) -> Optional[Dict[str, Any]]:
    """Decode the JSON test specification stored in ``sample["input_output"]``.

    Args:
        sample: Mapping expected to hold a JSON string under ``"input_output"``.
        debug: When true, print the parsed ``fn_name``.

    Returns:
        The decoded JSON value, or ``None`` when the entry is missing, is not
        a string/bytes value, or is not valid JSON.
    """
    try:
        in_outs = json.loads(sample["input_output"])
    except (ValueError, TypeError, KeyError):
        # ValueError: malformed JSON; TypeError: non-string payload (or a
        # non-mapping sample); KeyError: no "input_output" entry at all.
        return None
    # Only a decoded JSON object has .get(); other JSON values are returned
    # untouched without the debug line.
    if debug and in_outs and isinstance(in_outs, dict):
        print(f"Parsed in_outs with fn_name={in_outs.get('fn_name')}")
    return in_outs


def determine_code_type(in_outs: Dict[str, Any]) -> Tuple[CODE_TYPE, Optional[str]]:
    """Classify a test specification by the presence of ``"fn_name"``.

    Returns:
        ``(CODE_TYPE.call_based, fn_name)`` when ``in_outs`` carries a
        non-``None`` ``"fn_name"``, otherwise
        ``(CODE_TYPE.standard_input, None)``.
    """
    fn_name = in_outs.get("fn_name")
    if fn_name is not None:
        return CODE_TYPE.call_based, fn_name
    return CODE_TYPE.standard_input, None


def build_base_solution_imports() -> str:
    """Return the import prelude that is prepended to every solution.

    The prelude star-imports a set of standard-library modules, imports the
    same modules by name (all but ``builtins`` and ``typing``), and finally
    raises the recursion limit to ``6*10**5``. Every line ends with a newline.
    """
    named_modules = [
        "string",
        "re",
        "datetime",
        "collections",
        "heapq",
        "bisect",
        "copy",
        "math",
        "random",
        "statistics",
        "itertools",
        "functools",
        "operator",
        "io",
        "sys",
        "json",
    ]
    # builtins and typing are only star-imported, never imported by name.
    star_modules = named_modules + ["builtins", "typing"]

    prelude = [f"from {module} import *" for module in star_modules]
    prelude.extend(f"import {module}" for module in named_modules)
    prelude.append("sys.setrecursionlimit(6*10**5)")
    return "\n".join(prelude) + "\n"


def run_code_in_subprocess(code_str: str, input_str: str = "", timeout: int = 6):
    """Write *code_str* to a temporary file and run it with ``python3``.

    Args:
        code_str: Complete Python source to execute.
        input_str: Text sent to the child's stdin.
        timeout: Seconds to wait for the child before killing it.

    Returns:
        ``(stdout_text, {})`` on success, or ``(None, error_info)`` where
        ``error_info`` has ``"error"``, ``"error_code"`` and
        ``"error_message"``: code ``-3`` for a timeout, ``-4`` (with the
        child's stderr as ``"error"``) for a non-zero exit status.
    """
    # The script lives in /dev/shm and is removed when the with-block exits.
    with tempfile.NamedTemporaryFile(
        dir="/dev/shm", delete=True, suffix=".py", mode="w"
    ) as f:
        f.write(code_str)
        f.flush()
        temp_file = f.name

        # Using Popen as a context manager closes the pipes and waits for the
        # child on every exit path, so a killed child is always reaped.
        with subprocess.Popen(
            ["python3", temp_file],
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        ) as p:
            try:
                output, err = p.communicate(input_str, timeout=timeout)
            except subprocess.TimeoutExpired:
                p.kill()
                return None, {
                    "error": "TimeoutExpired",
                    "error_code": -3,
                    "error_message": "Time Limit Exceeded",
                }

        if p.returncode != 0:
            return None, {
                "error": err,
                "error_code": -4,
                "error_message": "Runtime Error",
            }

    return output, {}


def compile_solution(
    base_sol: str, test: str, code_type: CODE_TYPE, debug: bool, timeout: int
) -> Tuple[Optional[str], Dict[str, Any]]:
    """Assemble the full solution source from the prelude and the submission.

    Call-based submissions are appended verbatim; standard-input submissions
    are first rewritten by ``transform_standard_input_code``. Nothing is
    compiled or executed here and ``timeout`` is not used.

    Returns:
        ``(source, {})`` — the error dict is always empty.
    """
    if code_type == CODE_TYPE.call_based:
        body = test
    else:
        body = transform_standard_input_code(test, debug)
    return base_sol + body, {}


def transform_standard_input_code(test: str, debug: bool = False) -> str:
    """Wrap a stdin/stdout program in a ``def code():`` function.

    Steps:
      1. If the program ends with an ``if __name__ == '__main__':`` block,
         that block is replaced by its body (via ``ast``); any failure in
         this step leaves the source unchanged.
      2. Lines starting with ``from `` / ``import `` that appear before the
         first other line stay at module level.
      3. At the first other line, ``stdin``/``stdout`` aliases and the
         ``def code():`` header are emitted; that line and every line after
         it (imports included) is indented with one tab.

    Returns:
        The rewritten source; every emitted line ends with a newline.
    """
    main_guards = ("__name__ == '__main__'", '__name__ == "__main__"')
    try:
        tree = ast.parse(test)
        tail = tree.body[-1]
        if isinstance(tail, ast.If) and ast.unparse(tail.test).strip() in main_guards:
            test = ast.unparse(tree.body[:-1]) + "\n" + ast.unparse(tail.body)
    except Exception as e:
        if debug:
            print(f"Error processing if __name__ == '__main__': block: {e}")

    pieces = []
    inside_code = False
    for line in test.split("\n"):
        is_import = line.startswith("from ") or line.startswith("import ")
        if is_import and not inside_code:
            # Leading imports stay global.
            pieces.append(line + "\n")
            continue
        if not inside_code:
            pieces.append("stdin = sys.stdin\nstdout = sys.stdout\n")
            pieces.append("def code():\n")
            inside_code = True
        # Everything from the first non-import line on lives inside code().
        pieces.append("\t" + line + "\n")

    return "".join(pieces)


def run_single_test_case(
    code_type: CODE_TYPE,
    method: Any,
    in_outs: Dict[str, Any],
    index: int,
    debug: bool,
    timeout: int,
    sol_code: str,
    has_solution_class: bool,
) -> Tuple[Any, Dict[str, Any], str, str]:
    """Run test case *index* with the runner that matches *code_type*.

    The raw input/expected values are shortened for logging first and the
    shortened texts are passed on to the runner.

    Returns:
        Whatever the selected runner returns:
        ``(result, error_info, generated_output, expected_output)``.
    """
    inputs_for_log, outputs_for_log = prepare_logging_text(
        in_outs["inputs"][index], in_outs["outputs"][index]
    )

    if code_type != CODE_TYPE.call_based:
        return run_standard_input_test_case(
            in_outs=in_outs,
            index=index,
            debug=debug,
            timeout=timeout,
            raw_inputs=inputs_for_log,
            raw_outputs=outputs_for_log,
            sol_code=sol_code,
        )

    return run_call_based_test_case(
        method_name=method,
        in_outs=in_outs,
        index=index,
        debug=debug,
        timeout=timeout,
        raw_inputs=inputs_for_log,
        raw_outputs=outputs_for_log,
        sol_code=sol_code,
        has_solution_class=has_solution_class,
    )


def run_call_based_test_case(
    method_name: str,
    in_outs: Dict[str, Any],
    index: int,
    debug: bool,
    timeout: int,
    raw_inputs: str,
    raw_outputs: str,
    sol_code: str,
    has_solution_class: bool,
) -> Tuple[Any, Dict[str, Any], str, str]:
    """Run one call-based test case in a subprocess and grade it.

    A driver script is generated from *sol_code* that calls *method_name*
    (on a fresh ``Solution()`` when *has_solution_class* is true) with the
    decoded inputs and prints the JSON-encoded return value as the last line
    of stdout.

    Args:
        method_name: Name of the function or method to call.
        in_outs: Test specification with ``"inputs"`` and ``"outputs"`` lists.
        index: Which test case to run.
        debug: Forwarded to ``parse_json_inputs``.
        timeout: Time limit in seconds for the subprocess.
        raw_inputs: Input text used only in error reports.
        raw_outputs: Expected text used only in error reports.
        sol_code: Prelude plus solution source.
        has_solution_class: Whether the method lives on a ``Solution`` class.

    Returns:
        ``(True, {}, generated, expected)`` on a pass,
        ``(False, error_info, generated, expected)`` on a wrong answer, or
        ``(-1, error_info, None, None)`` when the subprocess failed.
    """
    inputs = parse_json_inputs(in_outs["inputs"][index], debug)
    expected = in_outs["outputs"][index]

    # Using repr since we're directly inserting Python objects into code
    inputs_repr = repr(inputs)

    call_code = "from __future__ import print_function\n"
    call_code += sol_code + "\n"
    if has_solution_class:
        call_code += "solution = Solution()\n"
        call_code += f"res = solution.{method_name}(*{inputs_repr})\n"
    else:
        call_code += f"res = {method_name}(*{inputs_repr})\n"
    # The leading newline guarantees the result starts on a line of its own
    # even if the solution printed something without a trailing newline.
    call_code += 'import json\nprint("\\n" + json.dumps(res))\n'

    output, err_info = run_code_in_subprocess(call_code, timeout=timeout)
    if err_info:
        err_info.update({"inputs": raw_inputs, "expected": raw_outputs})
        return -1, err_info, None, None

    # json.dumps output never contains a raw newline, so the result is exactly
    # the last line; anything before it was printed by the solution itself and
    # must not take part in grading.
    output_lines = output.strip().splitlines()
    generated_output_str = output_lines[-1].strip() if output_lines else ""

    if not compare_call_based_output(generated_output_str, expected):
        return (
            False,
            {
                "output": truncatefn(generated_output_str, 200),
                "expected": expected,
                "inputs": raw_inputs,
                "error_code": -2,
                "error_message": "Wrong Answer",
            },
            generated_output_str,
            expected,
        )

    return True, {}, generated_output_str, expected


def run_standard_input_test_case(
    in_outs: Dict[str, Any],
    index: int,
    debug: bool,
    timeout: int,
    raw_inputs: str,
    raw_outputs: str,
    sol_code: str,
) -> Tuple[Any, Dict[str, Any], str, str]:
    """Run one stdin/stdout test case in a subprocess and grade it.

    The solution source is extended with a ``__main__`` guard that calls
    ``code()``, the case's input is piped to stdin and the captured stdout is
    compared with the expected output. List-valued inputs/outputs are joined
    with newlines.

    Returns:
        ``(True, {}, generated, expected)`` on a pass,
        ``(False, error_info, generated, expected)`` on a wrong answer, or
        ``(-1, error_info, None, None)`` when the subprocess failed.
        ``raw_inputs`` / ``raw_outputs`` are only used in ``error_info``.
    """
    case_input = in_outs["inputs"][index]
    expected = in_outs["outputs"][index]

    stdin_text = "\n".join(case_input) if isinstance(case_input, list) else case_input
    expected_output_str = "\n".join(expected) if isinstance(expected, list) else expected

    # For standard input, execute code in subprocess to call code()
    program = sol_code + "\nif __name__=='__main__':\n    code()\n"

    output, err_info = run_code_in_subprocess(
        program, input_str=stdin_text, timeout=timeout
    )
    if err_info:
        err_info["inputs"] = raw_inputs
        err_info["expected"] = raw_outputs
        return -1, err_info, None, None

    output_lines = output.splitlines()
    generated_output_str = "\n".join(output_lines)

    if compare_standard_input_output(output_lines, expected):
        return True, {}, generated_output_str, expected_output_str

    wrong_answer = {
        "output": truncatefn(output, 200),
        "expected": raw_outputs,
        "inputs": raw_inputs,
        "error_code": -2,
        "error_message": "Wrong Answer",
    }
    return False, wrong_answer, generated_output_str, expected_output_str


def parse_json_inputs(inputs_str: str, debug: bool = False) -> List[Any]:
    """Decode call arguments given as one JSON document per line.

    Args:
        inputs_str: Newline-separated JSON values, one per positional argument.
        debug: Unused.

    Returns:
        The decoded arguments in order. If the first argument is a JSON
        object whose keys all parse as integers, its keys are converted to
        ``int``; the remaining arguments are kept as they are.

    Raises:
        ValueError: If a line is not valid JSON (``json.JSONDecodeError``).
    """
    inputs = [json.loads(line) for line in inputs_str.split("\n")]
    if isinstance(inputs[0], dict):
        try:
            # JSON object keys are always strings; turn them back into ints
            # when every key allows it. Only the first argument is replaced so
            # that any following arguments are not lost.
            inputs[0] = {int(k): v for k, v in inputs[0].items()}
        except (ValueError, TypeError):
            # Some key is not an integer: keep the dict untouched.
            pass
    return inputs


def compare_call_based_output(output: Any, expected: Any) -> bool:
    """Decide whether a call-based result matches the expected value.

    Values that compare equal directly match. Otherwise, when both are
    strings, they are decoded as JSON and the decoded values are compared, so
    that formatting differences such as ``"[1, 2]"`` vs ``"[1,2]"`` do not
    cause a mismatch.

    Returns:
        ``True`` on a match, ``False`` otherwise (including when either
        string is not valid JSON).
    """
    if output == expected:
        return True

    if isinstance(output, str) and isinstance(expected, str):
        try:
            return bool(json.loads(output) == json.loads(expected))
        except ValueError:
            # At least one side is not JSON; the direct comparison above
            # already failed, so this is a mismatch.
            return False

    return False


def compare_standard_input_output(output: List[str], expected: Any) -> bool:
    """Compare captured stdout lines with the expected output.

    Checks, in order: the whole-text comparison done by ``custom_compare_``;
    equality of the stripped, non-empty lines; and finally
    ``try_numeric_approx_check`` on those lines.

    Args:
        output: Lines produced by the solution.
        expected: Expected output as a string or a list of lines.
    """
    if custom_compare_(output, expected):
        return True

    # Normalise the expectation to a list of stripped, non-empty lines.
    if isinstance(expected, str):
        expected = expected.split("\n")
    if isinstance(expected, list):
        expected = [line.strip() for line in expected if line.strip()]

    produced = [line.strip() for line in output if line.strip()]

    if produced == expected:
        return True
    return try_numeric_approx_check(produced, expected)


def custom_compare_(output, ground_truth):
    """Compare a list of output lines with *ground_truth* as whole texts.

    The lines are joined with newlines, first as they are and then with each
    line stripped, and each joined text is compared with *ground_truth* via
    ``stripped_string_compare``. Anything that is not a list never matches.
    """
    if not isinstance(output, list):
        return False

    if stripped_string_compare("\n".join(output), ground_truth):
        return True

    per_line_stripped = "\n".join(line.strip() for line in output)
    return stripped_string_compare(per_line_stripped, ground_truth)


def stripped_string_compare(s1, s2):
    """Return True when both arguments are strings that are equal after ``strip()``.

    Returns False if either argument is not a ``str``.
    """
    both_are_text = isinstance(s1, str) and isinstance(s2, str)
    return both_are_text and s1.strip() == s2.strip()


def try_numeric_approx_check(output: List[str], expected: List[str]) -> bool:
    """Check whether two sequences of numeric strings are approximately equal.

    Every element of both sequences is converted with ``float``; the
    sequences match when they have the same length and ``np.allclose``
    (default tolerances) accepts them.

    Returns:
        ``True`` on an approximate match; ``False`` when the lengths differ,
        the values are not close, or any element (or either argument) cannot
        be converted.
    """
    try:
        output_float = [float(e) for e in output]
        expected_float = [float(e) for e in expected]
    except (ValueError, TypeError):
        # Non-numeric text, a non-string element, or a non-iterable argument.
        return False
    if len(output_float) != len(expected_float):
        return False
    return bool(np.allclose(output_float, expected_float))


def prepare_logging_text(raw_inputs: str, raw_outputs: str) -> Tuple[str, str]:
    """Shorten a test case's input and expected output for error reports.

    Multi-line input is truncated line by line with a per-line budget of
    ``300 // line_count`` characters (at least 1); single-line input is
    truncated to 300 characters. The expected output is truncated to 200.

    Non-string values are accepted as well: lists/tuples are joined with
    newlines (matching how list inputs are fed to stdin elsewhere in this
    file) and anything else is converted with ``str``.

    Returns:
        ``(inputs_for_log, outputs_for_log)``.
    """

    def as_text(value):
        # truncatefn only accepts str, so normalise everything up front.
        if isinstance(value, str):
            return value
        if isinstance(value, (list, tuple)):
            return "\n".join(str(item) for item in value)
        return str(value)

    raw_inputs = as_text(raw_inputs)
    raw_outputs = as_text(raw_outputs)

    if "\n" in raw_inputs:
        line_count = raw_inputs.count("\n") + 1
        # Keep the budget at 1 or more: with more than 300 lines the integer
        # division yields 0, which would disable truncation altogether.
        truncate_line_size = max(1, 300 // line_count)
        raw_inputs_for_log = "\n".join(
            [
                truncatefn(line, truncate_line_size)
                for line in raw_inputs.strip().split("\n")
            ]
        )
    else:
        raw_inputs_for_log = truncatefn(raw_inputs)

    raw_outputs_for_log = truncatefn(raw_outputs, 200)
    return raw_inputs_for_log, raw_outputs_for_log


def classify_runtime_error(e: Exception) -> Tuple[int, str]:
    """Map an exception to an ``(error_code, error_message)`` pair.

    An exception whose ``repr`` contains ``"timeoutexception"``
    (case-insensitive) is a time-out (``-3``); everything else is a runtime
    error (``-4``).
    """
    timed_out = "timeoutexception" in repr(e).lower()
    return (-3, "Time Limit Exceeded") if timed_out else (-4, "Runtime Error")
