# Code adapted from https://github.com/deepseek-ai/DeepSeek-Prover-V1.5.git
# All copyright belongs to https://github.com/deepseek-ai/DeepSeek-Prover-V1.5.git
import os
import time
import signal
import psutil
import json
import ctypes
import resource
import tempfile
import traceback
import threading
import subprocess
import multiprocessing as mp
from pprint import pprint

import numpy as np

from ..lean.ast_parser import lean4_parser
from ..workers import ProcessScheduler
from ..utils import AttrDict

# Home directory of the current user; used below to build the default lake path.
HOME_DIR = os.path.expanduser('~')
# Default location of the `lake` executable (inside ~/.elan/bin).
DEFAULT_LAKE_PATH = f'{HOME_DIR}/.elan/bin/lake'

# Default working directory in which `lake exe repl` is launched.
# NOTE(review): this is a machine-specific absolute path; on other hosts pass
# `lean_workspace` to verify_lean4_file explicitly.
DEFAULT_LEAN_WORKSPACE = '/userhomes/minsu/symr/mathlib4'


def verify_lean4_file(code, lake_path=DEFAULT_LAKE_PATH, lean_workspace=DEFAULT_LEAN_WORKSPACE, last_env=None,
                      verbose=False, timeout=300, allTactics=True, ast=False, premises=False, tactics=False):
    """Send one command to the Lean 4 REPL and summarise its reply.

    The request is serialised as JSON, written to a temporary file and fed to
    ``lake exe repl`` on stdin (run with ``lean_workspace`` as the working
    directory). The JSON document printed on stdout is then parsed.

    Args:
        code: Lean source text, sent as the ``cmd`` field of the request.
        lake_path: Path to the ``lake`` executable.
        lean_workspace: Working directory for the subprocess.
        last_env: Optional value for the request's ``env`` field; the field is
            omitted when this is ``None``.
        verbose: When true, print the JSON request before running it.
        timeout: Seconds after which ``subprocess.run`` aborts the call.
        allTactics: Forwarded verbatim as the ``allTactics`` request field.
        ast: Forwarded verbatim as the ``ast`` request field.
        premises: Forwarded verbatim as the ``premises`` request field.
        tactics: Forwarded verbatim as the ``tactics`` request field.

    Returns:
        A dict that always contains ``pass``, ``complete``, ``system_errors``,
        ``system_messages`` and ``verify_time`` (wall-clock seconds).
        On success it additionally contains ``sorries``, ``tactics``,
        ``errors``, ``warnings``, ``infos``, ``ast`` and ``verified_code``.
        If anything goes wrong (timeout, unparsable output, ...) ``pass`` and
        ``complete`` are ``False`` and ``system_errors`` holds the traceback.
    """
    command = dict(cmd=code, allTactics=allTactics, ast=ast, tactics=tactics, premises=premises)
    if last_env is not None:
        command.update(env=last_env)
    message_str = json.dumps(command, ensure_ascii=False)
    if verbose:
        print(message_str)
    start_time = time.time()
    system_messages = ''
    try:
        with tempfile.TemporaryFile(mode='w+', encoding='utf-8') as temp_file:
            temp_file.write(message_str + "\r\n\r\n")
            temp_file.seek(0)
            outputs = subprocess.run([lake_path, "exe", 'repl'], stdin=temp_file, capture_output=True, text=True,
                                     cwd=lean_workspace, timeout=timeout)

        # Keep whatever the subprocess wrote to stderr so callers can inspect
        # system-level failures; it is reported even if parsing stdout fails.
        system_messages = outputs.stderr or ''

        reply = json.loads(outputs.stdout)
        ast_results = lean4_parser(code, reply['ast']) if reply.get('ast') else {}

        # Bucket the REPL messages by severity in a single pass; severities
        # other than the three below are ignored.
        by_severity = {'error': [], 'warning': [], 'info': []}
        for message in reply.get('messages', []):
            bucket = by_severity.get(message['severity'])
            if bucket is not None:
                bucket.append(message)

        result = {
            "sorries": reply.get('sorries', []),
            "tactics": reply.get('tactics', []),
            "errors": by_severity['error'],
            "warnings": by_severity['warning'],
            "infos": by_severity['info'],
            "system_messages": system_messages,
            "system_errors": None,
            "ast": ast_results,
            "verified_code": code,
        }
        result['pass'] = not result['errors']
        # "complete" is stricter than "pass": no sorries and no warning that
        # mentions a sorry-based declaration or a failure.
        result['complete'] = result['pass'] and not result['sorries'] and not any(
            "declaration uses 'sorry'" in warning['data'] or 'failed' in warning['data']
            for warning in result['warnings'])
    except Exception:
        # Best-effort boundary: report the failure instead of raising, but let
        # KeyboardInterrupt / SystemExit propagate.
        result = {
            "pass": False,
            "complete": False,
            "system_errors": traceback.format_exc(),
            "system_messages": system_messages
        }
    result['verify_time'] = time.time() - start_time
    return result


class Lean4ServerProcess(mp.Process):
    """Worker process that takes verification tasks off a queue and runs them.

    Each queue item is an iterable of ``(_, request_id, task)`` triples, where
    ``task`` is either a Lean source string or a dict of keyword arguments for
    ``verify_lean4_file``. A ``None`` item tells the worker to exit. Results
    are stored in ``request_statuses[request_id]`` while holding ``lock``.

    Args:
        idx: Identifier of this worker; stored but not otherwise used here.
        task_queue: Queue the tasks are read from.
        request_statuses: Mapping that receives one result per request id.
        lock: Lock held while a result and the counters are written.
        extra_args: Optional mapping of settings. Recognised keys are
            ``timeout`` (default 300), ``memory_limit`` (default -1) and
            ``allTactics`` (default False).
    """

    # Substrings of ``result['system_messages']`` that mark a resource
    # shortage; a verification reporting one of them is retried.
    _TRANSIENT_ERRORS = (
        'lean::exception: failed to create thread',
        'std::bad_alloc: std::bad_alloc',
        'Cannot allocate memory',
    )

    def __init__(self, idx, task_queue, request_statuses, lock, extra_args=None):
        super().__init__()
        # Build a fresh AttrDict per instance rather than sharing one mutable
        # default object between all workers.
        if extra_args is None:
            extra_args = AttrDict()
        self.idx = idx
        self.task_queue = task_queue
        self.request_statuses = request_statuses
        self.lock = lock
        self.extra_args = extra_args

        self.timeout = extra_args.get('timeout', 300)
        self.memory_limit = extra_args.get('memory_limit', -1)
        # Shared values updated after every finished request.
        self.last_output_time = mp.Value(ctypes.c_double, time.time())
        self.complete_count = mp.Value(ctypes.c_int, 0)

    @classmethod
    def _is_transient_failure(cls, result):
        """Return True if the result's system messages report a resource shortage."""
        messages = result['system_messages']
        return any(marker in messages for marker in cls._TRANSIENT_ERRORS)

    def run(self):
        """Process queue items until a ``None`` sentinel is received."""
        # NOTE: memory_limit is stored but not enforced; the rlimit below is
        # disabled.
        # if self.memory_limit > 0:
        #     resource.setrlimit(
        #         resource.RLIMIT_AS,
        #         (self.memory_limit * (1000 ** 3), self.memory_limit * (1000 ** 3))
        #     )
        while True:
            inputs = self.task_queue.get()
            if inputs is None:  # Terminate when receiving None
                break
            for _, request_id, task in inputs:
                if isinstance(task, str):
                    task = dict(code=task)
                # Fill in worker-level defaults without overriding values the
                # task supplies itself. Putting allTactics into the dict also
                # means a task that already carries it no longer triggers a
                # duplicate-keyword TypeError, and retries use the same value.
                task.setdefault('timeout', self.timeout)
                task.setdefault('allTactics', self.extra_args.get('allTactics', False))

                result = verify_lean4_file(**task)
                # Retry while the failure looks like a temporary resource
                # shortage, for at most `timeout` seconds overall.
                retry_start_time = time.time()
                while (self._is_transient_failure(result)
                       and time.time() - retry_start_time < self.timeout):
                    time.sleep(0.1)
                    result = verify_lean4_file(**task)

                with self.lock:
                    self.request_statuses[request_id] = result
                    self.last_output_time.value = time.time()
                    self.complete_count.value += 1


class Lean4ServerScheduler(ProcessScheduler):
    """Scheduler that owns a pool of ``Lean4ServerProcess`` workers.

    On construction it starts ``max_concurrent_requests`` workers plus one
    monitor process that terminates long-running ``repl`` processes.

    Args:
        max_concurrent_requests: Number of worker processes to start.
        timeout: Per-request timeout in seconds, passed to the workers and
            used by the monitor to decide when a ``repl`` process is stale.
        memory_limit: Passed to the workers as the ``memory_limit`` setting.
        name: Passed to ``ProcessScheduler.__init__``.
        extra_args: Optional mapping merged over the ``timeout`` /
            ``memory_limit`` settings handed to each worker.
    """

    def __init__(self, max_concurrent_requests=4, timeout=300, memory_limit=10, name='verifier', extra_args=None):
        super().__init__(batch_size=1, name=name)

        process_extra_args = AttrDict(timeout=timeout, memory_limit=memory_limit)
        # If extra_args is provided, update process_extra_args with its values
        if extra_args:
            process_extra_args.update(extra_args)
        # task_queue, request_statuses and lock are not set in this class;
        # presumably they are created by ProcessScheduler.__init__.
        self.processes = [
            Lean4ServerProcess(
                idx=idx,
                task_queue=self.task_queue,
                request_statuses=self.request_statuses,
                lock=self.lock,
                extra_args=process_extra_args
            )
            for idx in range(max_concurrent_requests)
        ]
        for p in self.processes:
            p.start()
        print(f'Complete launching {len(self.processes)} LeanServerProcesses')

        self.timeout = timeout
        self._running_monitor = mp.Value(ctypes.c_bool, True)
        self._last_complete_count = mp.Value(ctypes.c_int, 0)

        # Separate process that cleans up stale `repl` processes.
        self._monitor_process = mp.Process(target=self._monitor)
        self._monitor_process.start()

    def _monitor(self) -> None:
        """Terminate stale ``repl`` processes until the monitor is stopped.

        About once per second, every process named ``repl`` that was created
        more than ``timeout + 15`` seconds before the cutoff is sent SIGTERM.
        psutil is used instead of the external ``killall`` binary.
        """
        grace_period = self.timeout + 15
        while self._running_monitor.value:
            cutoff = time.time() - grace_period
            time.sleep(1.0)

            for proc in psutil.process_iter(("name", "create_time")):
                # Adjust the name test if your REPL process has a full path
                if proc.info["name"] != "repl":
                    continue
                created = proc.info["create_time"]
                # psutil reports None for attributes it could not read;
                # comparing None with a float would crash the monitor.
                if created is None or created >= cutoff:
                    continue
                try:
                    proc.send_signal(signal.SIGTERM)
                    print("kill")
                except (psutil.NoSuchProcess, psutil.AccessDenied):
                    # It vanished or we lack permission – ignore and continue
                    pass

    def close(self):
        """Shut down the workers, then stop the monitor process."""
        try:
            super().close()
            for p in self.processes:
                p.join()
        finally:
            # Always stop the monitor, even if shutting down the workers
            # raised, so it does not keep running on its own.
            self._running_monitor.value = False
            self._monitor_process.join()
        print(f'All {len(self.processes)} LeanServerProcesses stopped')


if __name__ == '__main__':
    # Manual smoke test: submit one file to a single-worker scheduler and
    # pretty-print the result.
    # Read with an explicit encoding and close the handle deterministically.
    with open('mathlib4/.lake/packages/REPL/test/aime_1983_p9.in', encoding='utf-8') as source_file:
        code = source_file.read()
    lean4_scheduler = Lean4ServerScheduler(max_concurrent_requests=1, timeout=300, memory_limit=10, name='verifier')
    try:
        request_id_list = lean4_scheduler.submit_all_request([dict(code=code, ast=True, tactics=True)])
        outputs_list = lean4_scheduler.get_all_request_outputs(request_id_list)
    finally:
        # Stop the worker and monitor processes even if the request failed.
        lean4_scheduler.close()
    pprint(outputs_list)