import contextlib
import io
import multiprocessing as mp
import queue
import sys
import threading
import time
from enum import Enum
from typing import Any, Optional

import jq
from pydantic import BaseModel, JsonValue


class ProgramKind(Enum):
    # Language a Program's source code is written in. compiles() and run()
    # match on this value to choose the jq or the Python execution engine.
    # (Kept as comments, not a docstring: pydantic would copy a docstring
    # into the generated JSON schema.)
    Jq = "jq"
    Python = "python"


class Program(BaseModel):
    # A piece of source code tagged with the language it is written in.
    # (Kept as comments, not a docstring: pydantic would copy a docstring
    # into the generated JSON schema.)
    # Selects the engine used by compiles() and run().
    kind: ProgramKind
    # Source text: a jq expression, or Python source expected to define a
    # top-level callable that is invoked with the input value.
    code: str


def compiles(program: Program) -> bool:
    match program.kind:
        case ProgramKind.Jq:
            try:
                return JqExecutionEngine.compile(program.code) is not None
            except ValueError:
                return False
        case ProgramKind.Python:
            try:
                exec(program.code, {}, {})
                return True
            except Exception:
                return False


def run(program: Program, input: JsonValue) -> JsonValue:
    """Execute *program* on *input* and return its result.

    Dispatches on ``program.kind`` to the jq or the Python engine. This is a
    never-raise boundary: any ``Exception`` escaping the engine (timeouts,
    jq errors, exceptions raised by the Python program, ...) produces
    ``None``, and so does an unrecognised kind.
    """
    kind = program.kind
    if kind == ProgramKind.Jq:
        engine = JqExecutionEngine
    elif kind == ProgramKind.Python:
        engine = PythonExecutionEngine
    else:
        return None

    try:
        return engine.execute(program.code, input)
    except Exception:
        # TimeoutError and JqExecutionError are Exception subclasses, so a
        # single handler covers every failure mode.
        return None


class JqExecutionError(Exception):
    """Raised when jq reports an error while evaluating an expression.

    Attributes:
        expression: The jq expression that failed, or ``None`` if unknown.
    """

    def __init__(self, message: str, expression: Optional[str] = None):
        super().__init__(message)
        self.expression = expression


class JqExecutionEngine:
    """Execution engine for jq using a single persistent worker process.

    Features:
    - Persistent worker subprocess (dramatically faster than spawning per call).
    - Per-call timeout enforced from the parent side.
    - Automatic worker replacement: a worker that crashed, timed out or
      failed mid-request is discarded, and a fresh one is spawned on the
      next call.

    Calls are serialized through a class-level lock, so at most one request
    is in flight and responses are matched to requests purely by order. Any
    failure between sending a request and reading its response discards the
    worker, so a late response can never be delivered to a later call.

    NOTE: the worker evaluates each request with ``jq.all``, which compiles
    the expression on every call; there is no compile cache.
    """

    # Per-call budget in seconds. Measured from entry into _execute_worker,
    # so it includes starting a worker when one has to be (re)spawned.
    timeout: float = 32.0
    # While waiting for a response, the parent wakes up this often (seconds)
    # to check that the worker is still alive. A crash is therefore detected
    # promptly rather than only after the full timeout.
    poll_interval: float = 0.1

    # Synchronization & state
    _lock = threading.Lock()  # guards high-level execute path
    _worker: Optional[mp.Process] = None
    _requests: Optional[mp.Queue] = None
    _responses: Optional[mp.Queue] = None

    @staticmethod
    def compile(expression: str) -> Any:
        """Compile *expression* in the current process with ``jq.compile``."""
        return jq.compile(expression)

    @staticmethod
    def execute(expression: str, input: JsonValue) -> Any:
        """Evaluate *expression* on *input* in the worker process.

        Returns:
            The value produced by ``jq.all(expression, input)`` in the worker.

        Raises:
            TimeoutError: no answer within ``timeout`` seconds.
            JqExecutionError: jq raised an error for this expression/input.
            RuntimeError: the worker died or sent a malformed response.
        """
        return JqExecutionEngine._execute_worker(expression, input)

    # --- Worker lifecycle --------------------------------------------------
    @staticmethod
    def _worker_loop(requests: mp.Queue, responses: mp.Queue):
        """Worker-process main loop.

        Reads ``(expression, value)`` tasks and replies with ``("ok", result)``
        or ``("err", message)``. A ``None`` task ends the loop.
        """
        while True:
            task = requests.get()
            if task is None:
                break
            expr, value = task
            try:
                res = jq.all(expr, value)
                responses.put(("ok", res))
            except Exception as e:  # noqa: BLE001
                responses.put(("err", str(e)))

    @staticmethod
    def _ensure_worker():
        """Start a fresh worker (and queues) unless the current one is alive."""
        if (
            JqExecutionEngine._worker is not None
            and JqExecutionEngine._worker.is_alive()
        ):
            return
        # Discard whatever is left of a dead worker before replacing it.
        JqExecutionEngine._stop_worker()
        ctx = mp.get_context("spawn")  # explicit for cross-platform consistency
        JqExecutionEngine._requests = ctx.Queue()
        JqExecutionEngine._responses = ctx.Queue()
        proc = ctx.Process(
            target=JqExecutionEngine._worker_loop,
            args=(JqExecutionEngine._requests, JqExecutionEngine._responses),
            daemon=True,
        )
        proc.start()
        JqExecutionEngine._worker = proc

    @staticmethod
    def _stop_worker():
        """Tear down the current worker (if any) and reset the engine state.

        This is best-effort cleanup. Each step may fail independently, so a
        half-dead worker can always be discarded.
        """
        worker = JqExecutionEngine._worker
        requests = JqExecutionEngine._requests
        responses = JqExecutionEngine._responses
        JqExecutionEngine._worker = None
        JqExecutionEngine._requests = None
        JqExecutionEngine._responses = None

        if worker is not None:
            if requests is not None:
                with contextlib.suppress(Exception):
                    # Polite shutdown request; terminate() below does not
                    # rely on it being read.
                    requests.put_nowait(None)
            with contextlib.suppress(Exception):
                worker.terminate()
            with contextlib.suppress(Exception):
                worker.join(timeout=0.2)
                if worker.is_alive():
                    # SIGTERM was not enough; force it.
                    worker.kill()
                    worker.join(timeout=0.2)
            with contextlib.suppress(Exception):
                # Release the process handle. This raises ValueError if the
                # process is somehow still running, which we tolerate.
                worker.close()

        for q in (requests, responses):
            if q is None:
                continue
            with contextlib.suppress(Exception):
                # Don't let interpreter shutdown block on flushing data into
                # a pipe that no live process will ever read.
                q.cancel_join_thread()
                q.close()

    @staticmethod
    def _await_response(
        responses: mp.Queue, worker: mp.Process, start: float, expression: str
    ):
        """Block until the worker answers, the deadline passes, or it dies.

        Raises:
            TimeoutError: no answer within ``timeout`` seconds of *start*.
            RuntimeError: the worker process exited without answering.
        """
        deadline = start + JqExecutionEngine.timeout
        while True:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"jq execution exceeded {JqExecutionEngine.timeout:.2f}s for expression: {expression[:80]}"
                )
            try:
                # multiprocessing.Queue.get reports a timeout by raising
                # queue.Empty, not TimeoutError.
                return responses.get(
                    timeout=min(remaining, JqExecutionEngine.poll_interval)
                )
            except queue.Empty:
                if not worker.is_alive():
                    raise RuntimeError(
                        f"jq worker exited unexpectedly (exit code {worker.exitcode})"
                    ) from None

    @staticmethod
    def _execute_worker(expression: str, value: JsonValue):
        """Send one request to the worker and wait for its response.

        The whole exchange runs under ``_lock``. See ``execute`` for the
        return value and the exceptions raised.
        """
        with JqExecutionEngine._lock:
            start = time.monotonic()
            JqExecutionEngine._ensure_worker()
            worker = JqExecutionEngine._worker
            requests = JqExecutionEngine._requests
            responses = JqExecutionEngine._responses
            if worker is None or requests is None or responses is None:
                # _ensure_worker always sets all three; purely defensive.
                JqExecutionEngine._stop_worker()
                raise RuntimeError("jq worker initialization failed")
            try:
                requests.put((expression, value))
                status, payload = JqExecutionEngine._await_response(
                    responses, worker, start, expression
                )
            except BaseException:
                # Timeout, crash, malformed reply, or an interrupt while
                # waiting. The worker may still answer this request later, so
                # discard it to keep requests and responses in step. The next
                # call spawns a fresh worker.
                JqExecutionEngine._stop_worker()
                raise
        if status == "ok":
            return payload
        if status == "err":
            raise JqExecutionError(payload, expression=expression)
        raise RuntimeError("jq execution error")


class PythonExecutionEngine:
    """Execution engine for Python programs supplied as source text.

    A program is Python source that defines a top-level callable. That
    callable is invoked with the input value in a daemon thread, bounded by
    ``timeout`` seconds, and anything it prints to stdout is discarded.

    SECURITY NOTE(review): the source is run with ``exec`` in this process
    with full builtins. There is no sandboxing, so only use trusted code.
    """

    # Seconds to wait for the program's callable before giving up and
    # returning None. The thread cannot be killed and keeps running.
    timeout: float = 16.0

    # Filename attached to the compiled program. Used to tell functions the
    # program defines apart from callables it merely imports.
    _FILENAME = "<program>"

    @staticmethod
    def compile(code: str) -> Any:
        """Execute *code* and return the callable it defines, or ``None``.

        Preference order:
        1. the first top-level function whose code was compiled from *code*
           (so an imported callable such as ``from math import sqrt`` above
           the real entry point is skipped);
        2. otherwise, the first top-level callable of any kind (e.g. a class).

        Returns ``None`` when the source fails to compile/execute or binds no
        callable at all.
        """
        namespace: dict[str, Any] = {}
        try:
            # Builtin compile(); the method of the same name is a class
            # attribute and is not in scope here.
            program = compile(code.strip(), PythonExecutionEngine._FILENAME, "exec")
            exec(program, namespace)
        except Exception:
            return None
        candidates = [
            v for k, v in namespace.items() if k != "__builtins__" and callable(v)
        ]
        for candidate in candidates:
            func_code = getattr(candidate, "__code__", None)
            if getattr(func_code, "co_filename", None) == PythonExecutionEngine._FILENAME:
                return candidate
        return candidates[0] if candidates else None

    @staticmethod
    def execute(code: str, input: JsonValue) -> Any:
        """Load *code* with ``compile`` and run its callable on *input*.

        Returns ``None`` if no callable could be loaded; otherwise behaves like
        ``execute_compiled``.
        """
        func = PythonExecutionEngine.compile(code)
        if func is None:
            return None
        return PythonExecutionEngine.execute_compiled(func, input)

    @staticmethod
    def execute_compiled(compiled: Any, input: JsonValue) -> Any:
        """Call ``compiled(input)`` with a timeout and stdout suppressed.

        Returns:
            The call's result, or ``None`` if *compiled* is falsy or not
            callable, or if the call is still running after ``timeout``
            seconds.

        Raises:
            Exception: whatever ``compiled(input)`` raised, re-raised here.
        """
        if not compiled or not callable(compiled):
            return None
        result_box: dict[str, Any] = {}
        silent = io.StringIO()
        # redirect_stdout swaps the process-wide sys.stdout. Remember the
        # current value so a redirect left behind by a hung call can be undone.
        previous_stdout = sys.stdout

        def _runner():
            try:
                with contextlib.redirect_stdout(silent):
                    result_box["value"] = compiled(input)
            except Exception as e:
                result_box["err"] = e

        t = threading.Thread(target=_runner, daemon=True)
        t.start()
        t.join(PythonExecutionEngine.timeout)
        if t.is_alive():
            # The thread cannot be stopped. At least keep it from silently
            # swallowing every later print in the process.
            if sys.stdout is silent:
                sys.stdout = previous_stdout
            return None
        if "err" in result_box:
            raise result_box["err"]
        return result_box.get("value")
