from __future__ import annotations

import re
from typing import Any, Dict


_RUNNER_PARAM_PATTERNS = {
    "worker_num": [
        r"worker_num\s*=\s*(\d+)",
        r"workers?\s*=\s*(\d+)",
    ],
    "max_tasks": [
        r"max_tasks\s*=\s*(\d+)",
    ],
    "run_time": [
        r"run_time\s*=\s*(\d+)",
        r"timeout\s*=\s*(\d+)",
    ],
}

_MAX_SEARCH_PIECES_PATTERNS = [
    r"(?:allow|use)\s+only\s+(\d+)\s+search\s+pieces?",
    r"(?:at\s+most|no\s+more\s+than|maximum|max)\s+(\d+)\s+search\s+pieces?",
]

_QUICK_RUN_PATTERNS = [
    r"\bquick run\b",
    r"\bshort run\b",
    r"\bsmall run\b",
    r"\bsmoke\b",
    r"\bsmoke test\b",
]

_NO_AUTO_REPARTITION_PATTERNS = [
    r"do not auto[- ]repartition",
    r"don't auto[- ]repartition",
    r"no auto[- ]repartition",
    r"avoid auto[- ]repartition",
]

_REPARTITION_EXCEPTION_PATTERNS = [
    r"unless .*stuck",
    r"unless .*fails",
    r"unless .*cannot run",
    r"unless .*cannot make progress",
]

def _search_int(patterns: list[str], text: str) -> int | None:
    for pattern in patterns:
        match = re.search(pattern, text, flags=re.IGNORECASE)
        if match:
            try:
                return int(match.group(1))
            except (TypeError, ValueError):
                continue
    return None


def extract_request_constraints(user_request: str) -> Dict[str, Any]:
    """Derive structured run constraints from a free-form user request.

    The request is coerced with ``str()`` (``None``/falsy becomes ``""``) and
    stripped.  A blank request yields ``{}``.  Otherwise the returned dict
    contains only the keys that were detected, in this order:

    * ``"runner_params"`` -- dict with any of ``worker_num``, ``max_tasks``,
      ``run_time`` parsed from ``_RUNNER_PARAM_PATTERNS``.
    * ``"max_search_pieces"`` -- int parsed from ``_MAX_SEARCH_PIECES_PATTERNS``.
    * ``"quick_run"`` -- ``True`` when a quick/smoke-run phrase appears.
    * ``"avoid_auto_repartition"`` -- ``True`` when a no-auto-repartition
      phrase appears.
    * ``"repartition_exception"`` -- ``"only_if_stuck"``, set only alongside
      ``avoid_auto_repartition`` when an "unless ... stuck/fails/..." clause is
      found.  NOTE(review): that clause is searched for anywhere in the
      request, not only near the repartition phrase.

    All pattern matching is case-insensitive.
    """
    request = str(user_request or "").strip()
    if not request:
        return {}

    def mentions(patterns):
        # True when any pattern occurs anywhere in the request.
        return any(re.search(p, request, flags=re.IGNORECASE) for p in patterns)

    found: Dict[str, Any] = {}

    parsed = {
        name: _search_int(name_patterns, request)
        for name, name_patterns in _RUNNER_PARAM_PATTERNS.items()
    }
    explicit = {name: number for name, number in parsed.items() if number is not None}
    if explicit:
        found["runner_params"] = explicit

    piece_limit = _search_int(_MAX_SEARCH_PIECES_PATTERNS, request)
    if piece_limit is not None:
        found["max_search_pieces"] = piece_limit

    if mentions(_QUICK_RUN_PATTERNS):
        found["quick_run"] = True

    if mentions(_NO_AUTO_REPARTITION_PATTERNS):
        found["avoid_auto_repartition"] = True
        if mentions(_REPARTITION_EXCEPTION_PATTERNS):
            found["repartition_exception"] = "only_if_stuck"

    return found


def merge_runner_params(
    runner_params: Dict[str, Any] | None,
    request_constraints: Dict[str, Any] | None,
) -> Dict[str, Any]:
    """Overlay request-mandated runner parameters onto a base parameter dict.

    Args:
        runner_params: Base parameters; ignored unless it is a ``dict``.  It is
            copied, never mutated.
        request_constraints: Typically the output of
            ``extract_request_constraints``; its ``"runner_params"`` entry, if
            it is a ``dict``, overrides matching keys of the base.

    Returns:
        A new dict: a shallow copy of the base with any overrides applied.
    """
    result: Dict[str, Any] = dict(runner_params) if isinstance(runner_params, dict) else {}
    if not isinstance(request_constraints, dict):
        return result
    forced = request_constraints.get("runner_params")
    if not isinstance(forced, dict):
        return result
    # Values required by the request take precedence over the base values.
    result.update(forced)
    return result
