import ast
import builtins
import contextlib
import io
import json
import logging
import os
import sys
from concurrent.futures import TimeoutError
from copy import deepcopy
from io import StringIO

import multiprocess
from pebble import ProcessPool
from tqdm.auto import tqdm

from .test_results import Solution

logger = logging.getLogger(__name__)


def handler(signum, frame):
    raise Exception("Timeout")


@contextlib.contextmanager
def stdoutIO(stdout=None):
    old = sys.stdout
    if stdout is None:
        stdout = StringIO()
    sys.stdout = stdout
    yield stdout
    sys.stdout = old


@contextlib.contextmanager
def stdinIO(input_str):
    buf = io.StringIO(input_str)

    def mock_input(prompt=None):
        try:
            return buf.readline().rstrip("\n")
        except Exception as e:
            logger.debug(f"Error in mock_input: {e}")

    original_input = builtins.input
    builtins.input = mock_input
    yield
    builtins.input = original_input


def preprocess_input(input_data):
    if isinstance(input_data, list):
        return "\n".join(input_data)
    else:
        return input_data


def worker(code, input_data=""):
    """
    worker for executing python code
    """
    input_data = preprocess_input(input_data)

    with stdinIO(input_data), stdoutIO() as s:
        try:
            exec(code, {"__name__": "__main__"})
            res = s.getvalue().strip()
        except SystemExit as e:
            res = s.getvalue().strip()

    return res


def execute_code(code, test_case_input):
    with ProcessPool(1) as p:
        return p.schedule(worker, (code, test_case_input), timeout=2).result()


def is_float(text):
    try:
        text = float(text)
        return True
    except:
        return False


def check(input, actual, expected, custom_evaluator=None):
    try:
        actual = actual.strip()
        expected = expected.strip()
    except:
        return False

    if actual == expected:
        return True

    # normalize list
    # overly long strings can introduce segmenation fault in ast.literal_eval
    if expected.startswith("[") and expected.endswith("]") and len(expected) < 1e6:
        try:
            expected = ast.literal_eval(expected)
            expected = "\n".join(expected).strip()
        except:
            expected = expected

    if actual.startswith("[") and actual.endswith("]") and len(actual) < 1e6:
        try:
            actual = ast.literal_eval(actual)
            actual = "\n".join(actual).strip()
        except:
            pass

    if custom_evaluator is not None:
        flag = custom_evaluator(input, actual, expected)
    else:
        # list special judge
        try:
            if "\n" in actual.strip() and "\n" in expected.strip():
                actual = [i.strip() for i in actual.strip().split("\n")]
                expected = [i.strip() for i in expected.strip().split("\n")]
                flag = True
                for i_pred, i_answer in zip(actual, expected):
                    if i_pred != i_answer:
                        if not (
                            is_float(i_pred)
                            and is_float(i_answer)
                            and abs(float(i_pred) - float(i_answer)) < 1e-4
                        ):
                            flag = False
                            break
                return flag
        except:
            pass

        try:  # float special judge
            actual = float(actual)
            expected = float(expected)
            flag1 = abs(actual - expected) < (max(abs(expected), abs(actual)) * 1e-4)
            flag2 = abs(actual - expected) < 1e-4
            flag = flag1 | flag2
            return flag
        except:
            pass

        flag = actual == expected
    return flag


def run_solutions(solutions, num_threads=8):
    results = {}
    # skip already executed solutions
    filtered_solutions = []
    for solution in solutions:
        assert isinstance(solution, Solution)
        if solution.already_processed:
            results[solution.question_id] = solution.to_json()
        else:
            filtered_solutions.append(solution)
    solutions = filtered_solutions

    with ProcessPool(num_threads) as p:
        # set timeout to 2 seconds
        async_results = []
        for solution in solutions:
            for case in solution.test_cases:
                async_results.append(
                    p.schedule(worker, (solution.solution, case["input"]), timeout=2)
                )

        preds = []
        for result in tqdm(async_results, disable=len(async_results) == 0):
            pred = "ERROR"
            try:
                pred = result.result()
            except multiprocess.TimeoutError:
                pred = "ERROR Execution time of the program exceeds 2 second, terminate execution."
            except Exception as e:
                error = e
                pred = f"ERROR {error}"
            preds.append(pred)

    pred_idx = 0
    with ProcessPool(num_threads) as p:
        async_results = []
        for solution in solutions:
            cases = deepcopy(solution.test_cases)
            for case_id, test_case in enumerate(cases):
                input_data = test_case["input"]
                if "output" in test_case:
                    test_case["output"] = str(test_case["output"]).strip()
                    test_case["pred"] = preds[pred_idx]
                else:
                    test_case["pred"] = preds[pred_idx]
                    test_case["output"] = "[PLACEHOLDER]"
                async_results.append(
                    p.schedule(
                        check,
                        (
                            input_data,
                            preds[pred_idx],
                            test_case["output"],
                            None,
                        ),
                        timeout=2,
                    )
                )
                pred_idx += 1

        flags = []
        for result in tqdm(async_results, disable=len(async_results) == 0):
            flag = "[error] default pred"
            try:
                flag = result.result()
            except multiprocess.TimeoutError:
                flag = False
            except Exception as e:
                flag = False
            flags.append(flag)

    pred_idx = 0
    for solution in solutions:
        cases = deepcopy(solution.test_cases)
        for case_id, test_case in enumerate(cases):
            input_data = test_case["input"]
            test_case["expected"] = str(test_case["output"]).strip()
            test_case["output"] = preds[pred_idx]
            test_case["correct"] = flags[pred_idx]
            pred_idx += 1
        overall_flag = all([i["correct"] for i in cases])
        solution.test_cases = cases
        solution.correct = overall_flag
        solution.already_processed = True
        results[solution.question_id] = solution.to_json()

    return results


if __name__ == "__main__":
    print(check(None, "50000000000000.4949989318847656", "50000000000000.5", None))
    res = run_solutions(
        [
            (
                None,
                f"print({idx})",
                [{"input": "", "output": str(idx)} for _ in range(5)],
            )
            for idx in range(10)
        ]
    )
    print(res)
