#  
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Classes and functions for running Python code snippets in a simple code interpreter.

"""
import io
import os

import multiprocess
from multiprocessing import Pool
import multiprocessing
from pebble import ProcessPool  # pip install Pebble
from typing import Union, Any, Dict, Optional
import re

import pickle
import traceback
import copy
import datetime
from tqdm import tqdm
from concurrent.futures import TimeoutError
from functools import partial
from timeout_decorator import timeout  # pip install timeout-decorator
from contextlib import redirect_stdout

import logging


# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class GenericRuntime:
    """A simple runtime initialized with empty global and local variables.

    Used to evaluate Python code from scratch. Subclasses may pre-populate the
    namespace through the ``GLOBAL_DICT`` / ``LOCAL_DICT`` class attributes and
    list setup snippets in ``HEADERS``, which are executed on construction.

    NOTE: ``_local_vars`` is stored (and extended by ``add_vars``) but is never
    passed to ``exec``/``eval`` in this class; all code runs against
    ``_global_vars`` only.
    """
    # Template namespaces; every instance works on its own shallow copy.
    GLOBAL_DICT = {}
    LOCAL_DICT = None
    # Code snippets executed, in order, when an instance is constructed.
    HEADERS = []

    # Code matching any of these patterns is refused before it is executed.
    # NOTE(review): the leading ``(\s|^)?`` group is optional, so each pattern
    # matches the call anywhere in the snippet (e.g. ``user_input(`` is refused
    # too), and the unescaped ``.`` in ``os.system`` matches any character.
    _FORBIDDEN_PATTERNS = (r'(\s|^)?input\(', r'(\s|^)?os.system\(')

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for header in self.HEADERS:
            self.exec_code(header)

    def exec_code(self, code_piece: str) -> None:
        """Execute ``code_piece`` in this runtime's global namespace.

        Raises:
            RuntimeError: if the code looks like it calls ``input(`` or
                ``os.system(`` (see ``_FORBIDDEN_PATTERNS``).
        """
        if any(re.search(pattern, code_piece) for pattern in self._FORBIDDEN_PATTERNS):
            raise RuntimeError()
        exec(code_piece, self._global_vars)

    def __call__(self, code_piece: str) -> None:
        """Alias for :meth:`exec_code`."""
        self.exec_code(code_piece)

    def eval_code(self, expr: str) -> Any:
        """Evaluate ``expr`` against the global namespace and return its value."""
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Bind every key/value of ``var_dict`` as a global variable."""
        for k, v in var_dict.items():
            self._global_vars[k] = v

    @property
    def answer(self):
        """The global named ``answer``; raises ``KeyError`` if it was never set."""
        return self._global_vars['answer']

    def add_vars(self, global_vars: Optional[dict] = None, local_vars: Optional[dict] = None) -> None:
        """Merge extra variables into the global and/or local namespaces.

        Raises:
            TypeError: if either argument is neither ``None`` nor a ``dict``.
        """
        if isinstance(global_vars, dict):
            self._global_vars.update(global_vars)
        elif global_vars is not None:
            raise TypeError("Expect global_vars to be a dict; got {}".format(type(global_vars)))
        if isinstance(local_vars, dict):
            if self._local_vars is None:
                # No (non-empty) LOCAL_DICT template was given, so __init__ left
                # this as None; start an empty namespace instead of crashing.
                self._local_vars = {}
            self._local_vars.update(local_vars)
        elif local_vars is not None:
            raise TypeError("Expect local_vars to be a dict; got {}".format(type(local_vars)))


class PythonExecutor:
    """A Python executor that evaluates a batch of code strings with a runtime.

    Each batch item is either a code string or a list of code strings. A list
    is executed chunk by chunk, in order, against the same runtime instance, so
    later chunks can use names defined by earlier ones. Batch items are
    dispatched to a ``pebble.ProcessPool``.

    Args:
        runtime: a runtime class, or a callable returning a runtime instance.
            Instances must provide ``exec_code`` and ``eval_code``; the
            ``answer_symbol`` mode also reads their ``_global_vars`` dict.
            Defaults to ``GenericRuntime``.
        share_runtime_in_batch: if True, one runtime instance is created here
            and handed to every batch item; if False, ``runtime`` is kept as a
            factory and a fresh instance is built for each batch item.
        answer_symbol: key in the runtime's global vars that holds the answer.
        answer_expr: expression evaluated after the code to obtain the answer.
        get_answer_from_stdout: use everything the code prints as the answer.
        timeout_length: seconds permitted for each chunk of code.

    Raises:
        TypeError: if ``runtime`` is given but is not callable.
    """
    # Timeout (seconds) the process pool applies to each batch item. Each item
    # already enforces ``timeout_length`` per chunk inside its worker, so this is
    # only a backstop that turns a stuck item into a "Timeout Error" result.
    TASK_TIMEOUT = 60

    def __init__(
        self,
        runtime: Optional[Any] = None,
        share_runtime_in_batch: bool = True,
        answer_symbol: Optional[str] = None,
        answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        # Lazy check: only callability is verified here; the runtime's methods
        # are looked up when code is actually executed.
        if runtime is not None and not callable(runtime):
            raise TypeError(
                "runtime must be a runtime class or a callable returning a runtime; "
                "got {}".format(type(runtime))
            )
        self.share_runtime_in_batch = share_runtime_in_batch
        if share_runtime_in_batch:
            # NOTE(review): batch items run in separate worker processes and each
            # receives its own pickled copy of this instance, so state written by
            # one item presumably does not reach the others.
            self.runtime = runtime() if runtime is not None else GenericRuntime()
        else:
            # Keep the factory; sequential_execute instantiates it per batch item.
            self.runtime = runtime if runtime is not None else GenericRuntime
        self.answer_symbol = answer_symbol
        self.answer_expr = answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        self.timeout_length = timeout_length

    def process_generation_to_code(self, gens):
        """Split every string in (possibly nested) ``gens`` into a list of lines.

        Strings become ``str.split('\\n')`` lists; any other element is
        recursed into, so the nesting structure of ``gens`` is preserved.
        """
        return [g.split('\n') if isinstance(g, str) else self.process_generation_to_code(g) for g in gens]

    def sequential_execute(
        self,
        code,
        get_answer_from_stdout = None,
        runtime = None,
        answer_symbol = None,
        answer_expr = None,
        timeout_length = 5,
    ):
        """Execute one batch item and return ``(result, exec_info)``.

        ``code`` may be a string (split into lines), a list of lines (one
        chunk), or a list of line-lists (several chunks run in order on the
        same runtime). For several chunks, ``result`` and ``exec_info`` are
        lists with one entry per chunk. Invalid input is reported through the
        returned ``exec_info`` string rather than raised.
        """
        if runtime is None:
            runtime = GenericRuntime()
        elif not self.share_runtime_in_batch:
            # ``runtime`` is a factory here; every chunk of this item shares the
            # instance created now.
            runtime = runtime()

        if isinstance(code, str):
            code = code.split('\n')
        if not isinstance(code, (list, tuple)):
            return "", "TypeError: input code must be a str or List[str]; got {}".format(type(code))
        if len(code) == 0:
            return "", "RuntimeError: no code is detected."

        if isinstance(code[0], str):
            return self.execute(code, get_answer_from_stdout, runtime, answer_symbol, answer_expr, timeout_length)
        if isinstance(code[0], (list, tuple)):
            result, exec_info = [], []
            for _code in code:
                _result, _exec_info = self.execute(_code, get_answer_from_stdout, runtime, answer_symbol, answer_expr, timeout_length)
                result.append(_result)
                exec_info.append(_exec_info)
            return result, exec_info
        # Previously this case fell through and returned None, which broke the
        # (result, exec_info) unpacking done by batch_apply.
        return "", "TypeError: code chunks must be str or List[str]; got {}".format(type(code[0]))

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout = None,
        runtime = None,
        answer_symbol = None,
        answer_expr = None,
        timeout_length = 5,
    ):
        """Run one chunk (a list of lines) on ``runtime``; return ``(result, exec_info)``.

        The answer is taken from the first applicable source: captured stdout
        (``get_answer_from_stdout``), ``runtime._global_vars[answer_symbol]``,
        ``eval(answer_expr)``, or otherwise by executing all lines but the last
        and evaluating the last line. Each exec/eval is wrapped in
        ``timeout(timeout_length)``.

        On success ``exec_info`` is ``"Done"``. On any failure (including a
        result that cannot be ``str()``-ed or pickled back to the parent
        process) ``result`` is ``''`` and ``exec_info`` is the last line of the
        traceback, e.g. ``"NameError: name 'a' is not defined"``.
        """
        try:
            if get_answer_from_stdout:
                program_io = io.StringIO()
                with redirect_stdout(program_io):
                    timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                # Everything that was printed, including multi-line output.
                result = program_io.getvalue()
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
                result = timeout(timeout_length)(runtime.eval_code)(code[-1])
            exec_info = "Done"
            str(result)
            pickle.dumps(result)  # serialization check: result must cross the process boundary
        except (Exception, SystemExit):
            # SystemExit is captured so user code calling exit() is reported as
            # an error instead of terminating the worker; KeyboardInterrupt is
            # deliberately allowed to propagate.
            result = ''
            exec_info = traceback.format_exc().split('\n')[-2]
        return result, exec_info

    def apply(self, code):
        """Execute a single batch item; see :meth:`batch_apply`."""
        return self.batch_apply([code])[0]

    def batch_apply(self, batch_code):
        """Execute every item of ``batch_code`` in a process pool.

        Returns a list of ``(result, exec_info)`` tuples in input order. An item
        exceeding ``TASK_TIMEOUT`` yields ``("", "Timeout Error")``; any other
        pool-level failure (e.g. a crashed worker) yields ``("", "<ErrorType>:
        <message>")`` so results stay aligned with the inputs.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        if not all_code_snippets:
            # ProcessPool rejects max_workers=0, so handle an empty batch here.
            return []

        timeout_cnt = 0
        all_exec_results = []
        executor = partial(
            self.sequential_execute,
            get_answer_from_stdout=self.get_answer_from_stdout,
            runtime=self.runtime,
            answer_symbol=self.answer_symbol,
            answer_expr=self.answer_expr,
            timeout_length=self.timeout_length,
        )
        max_workers = min(len(all_code_snippets), multiprocessing.cpu_count(), 32)
        progress_bar = tqdm(total=len(all_code_snippets), desc="Execute") if len(all_code_snippets) > 100 else None
        try:
            with ProcessPool(max_workers=max_workers) as pool:
                future = pool.map(executor, all_code_snippets, timeout=self.TASK_TIMEOUT)
                iterator = future.result()
                while True:
                    try:
                        result = next(iterator)
                    except StopIteration:
                        break
                    except TimeoutError as error:
                        logger.warning("Code execution timed out: %s", error)
                        result = ("", "Timeout Error")
                        timeout_cnt += 1
                    except Exception as error:
                        # Previously exit() was called here, killing the caller.
                        logger.error("Code execution failed in worker: %r", error)
                        result = ("", "{}: {}".format(type(error).__name__, error))
                    all_exec_results.append(result)
                    if progress_bar is not None:
                        progress_bar.update(1)
        finally:
            if progress_bar is not None:
                progress_bar.close()

        if timeout_cnt:
            logger.warning("%d of %d code snippets timed out", timeout_cnt, len(all_code_snippets))
        return [(result, exec_info) for result, exec_info in all_exec_results]

    def __call__(self, batch_code):
        """Alias for :meth:`batch_apply`."""
        return self.batch_apply(batch_code)


def test_PythonExecutor():
    """Smoke-test ``PythonExecutor.batch_apply`` and print each ``(result, exec_info)`` pair."""
    snippets = [
        # A single chunk whose printed output is the answer.
        "import math\n\nprint(math.sqrt(3))",
        # Two chunks run in order; the second reuses ``a`` from the first.
        ["import math\n\na = math.sqrt(5)\nprint(a)", "b = a / 5\nprint(b)"],
        # A chunk that raises, exercising the error path.
        "raise RuntimeError('test error')"
    ]
    runner = PythonExecutor(get_answer_from_stdout=True)
    outcomes = runner.batch_apply(snippets)

    for outcome in outcomes:
        print(outcome)


# When this file is run as a script, execute the test_PythonExecutor() smoke test.
if __name__ == '__main__':
    test_PythonExecutor()