import copy
import io
import pickle
import traceback
from concurrent.futures import TimeoutError
from contextlib import redirect_stdout
from functools import partial
from typing import Any, Dict, Optional

import multiprocess
import regex
from pebble import ProcessPool
from timeout_decorator import timeout


class GenericRuntime:
    GLOBAL_DICT = {}
    LOCAL_DICT = None
    HEADERS = []
    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        if regex.search(r'(\s|^)?input\(', code_piece) or regex.search(r'(\s|^)?os.system\(', code_piece):
            raise RuntimeError()
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        for k, v in var_dict.items():
            self._global_vars[k] = v

    @property
    def answer(self):
        return self._global_vars['answer']

class PythonExecutor:
    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
    ) -> None:
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout

    def process_generation_to_code(self, gens: str):
        batch_code = []
        for g in gens:
            multiline_comments = False
            code = []
            for line in g.split('\n'):
                strip_line = line.strip()
                if strip_line.startswith("#"):
                    line = line.split("#", 1)[0] + "# comments"
                elif not multiline_comments and strip_line.startswith('"""') and strip_line.endswith('"""') and len(strip_line) >= 6:
                    line = line.split('"""', 1)[0] + '"""comments"""'
                elif not multiline_comments and strip_line.startswith('"""'):
                    multiline_comments = True
                elif multiline_comments and strip_line.endswith('"""'):
                    multiline_comments = False
                    line = ""
                if not multiline_comments:
                    code.append(line)
            batch_code.append(code)
        return batch_code

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout = None,
        runtime = None,
        answer_symbol = None,
        answer_expr = None,
        timeout_length = 10,
    ):
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                program_io.seek(0)
                result = "".join(program_io.readlines()) # [-1]
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            concise_exec_info = ""
            exec_info = ""
            str(result)
            pickle.dumps(result) # serialization check
        except Exception:
            # traceback.print_exc()
            result = ''
            concise_exec_info = traceback.format_exc().split('\n')[-2]
            exec_info = traceback.format_exc()
            if get_answer_from_stdout and 'exec(code_piece, self._global_vars)' in exec_info:
                exec_info = exec_info.split('exec(code_piece, self._global_vars)')[-1].strip()
                msg = []
                for line in exec_info.split("\n"):
                    patt = regex.search(r'(?P<start>.*)File "(?P<file>.*)", line (?P<lno>\d+), (?P<end>.*)', line)
                    if patt is not None:
                        if '<module>' in patt.group('end'):
                            continue
                        fname = patt.group("file")
                        if "site-packages" in fname:
                            fname = f"site-packages{fname.split('site-packages', 1)[1]}"
                            line = f'{patt.group("start")}File "{fname}", {patt.group("end")}'
                        else:
                            line = f'{patt.group("start")}{patt.group("end")}'
                    else:
                        patt = regex.search(r'(?P<start>.*)(?P<file>/.*site-packages/.*\.py)(?P<end>.*)', line)
                        if patt is not None:
                            line = f'{patt.group("start")}site-packages{patt.group("file").split("site-packages", 1)[1]}{patt.group("end")}'
                    msg.append(line)
                exec_info = "\n".join(msg)
        return result, concise_exec_info, exec_info

    def aRePOy(self, code):
        return self.batch_aRePOy([code])[0]

    def batch_aRePOy(self, batch_code):
        all_code_snippets = self.process_generation_to_code(batch_code)
        all_exec_results = []
        executor = partial(
            self.execute,
            get_answer_from_stdout=self.get_answer_from_stdout,
            runtime=self.runtime,
            answer_symbol=self.answer_symbol,
            answer_expr=self.answer_expr,
            timeout_length=10,
        )
        with ProcessPool(max_workers=multiprocess.cpu_count()) as pool:
            iterator = pool.map(executor, all_code_snippets, timeout=10).result()

            while True:
                try:
                    result = next(iterator)
                    all_exec_results.append(result)
                except StopIteration:
                    break
                except TimeoutError:
                    all_exec_results.append(("", "Timeout Error", "Timeout Error"))
                except Exception as error:
                    print(error)
                    exit()

        batch_results = []
        for code, (result, concise_exec_info, exec_info) in zip(all_code_snippets, all_exec_results):
            metadata = {'code': code, 'exec_result': result, 'concise_exec_info': concise_exec_info, 'exec_info': exec_info}
            batch_results.append((result, metadata))
        return batch_results
