from __future__ import annotations

import copy
from typing import Any, Dict, List, Tuple

# Policy values for ``sampling.n_data``: the fallback used when the value is
# missing or unparseable, and the inclusive range it is clamped into.
DEFAULT_N_DATA = 5000
MIN_N_DATA = 3000
MAX_N_DATA = 6000

# Policy values for ``stop_criteria.max_wall_time_s``: the fallback used when
# the value is missing or unparseable, and the floor it is raised to. There is
# no upper bound. Presumably seconds, going by the ``_S`` suffix
# (172800 s = 48 h, 43200 s = 12 h).
DEFAULT_MAX_WALL_TIME_S = 172800
MIN_MAX_WALL_TIME_S = 43200


def _to_int(value: Any) -> int | None:
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _adjust_warning(
    field: str,
    original: Any,
    adjusted: Any,
    reason: str,
) -> Dict[str, Any]:
    return {
        "code": "spec_policy_adjusted",
        "message": f"Adjusted {field} according to agent policy.",
        "details": {
            "field": field,
            "original": original,
            "adjusted": adjusted,
            "reason": reason,
        },
    }


def _notice_warning(field: str, message: str, details: Dict[str, Any]) -> Dict[str, Any]:
    return {
        "code": "spec_policy_notice",
        "message": message,
        "details": {"field": field, **details},
    }


# Strategy modes that count as "search" when scanning domain pieces.
_SEARCH_STRATEGY_MODES = frozenset({"search", "auto", "dag"})


def _ensure_dict_section(container: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Return ``container[key]`` if it is a dict, else install and return a new one."""
    section = container.get(key)
    if not isinstance(section, dict):
        section = {}
        container[key] = section
    return section


def _enforce_int_policy(
    section: Dict[str, Any],
    key: str,
    *,
    field: str,
    default: int,
    minimum: int | None,
    maximum: int | None,
    bound_reason: str,
    warnings: List[Dict[str, Any]],
) -> None:
    """Force ``section[key]`` to an in-bounds ``int``, recording any change in *warnings*."""
    raw = section.get(key)
    parsed = _to_int(raw)
    if parsed is None:
        section[key] = default
        warnings.append(
            _adjust_warning(
                field=field,
                original=raw,
                adjusted=default,
                reason="missing_or_invalid",
            )
        )
        return

    bounded = parsed
    if maximum is not None:
        bounded = min(maximum, bounded)
    if minimum is not None:
        bounded = max(minimum, bounded)

    if bounded != parsed:
        section[key] = bounded
        warnings.append(
            _adjust_warning(
                field=field,
                original=parsed,
                adjusted=bounded,
                reason=bound_reason,
            )
        )
    elif not isinstance(raw, int):
        # The range check above ran on the truncated integer, so a raw value
        # such as 6000.9 or "4500" passes it while the stored value is still
        # out of bounds or not an integer. Store the integer that was checked.
        section[key] = parsed
        warnings.append(
            _adjust_warning(
                field=field,
                original=raw,
                adjusted=parsed,
                reason="coerced_to_int",
            )
        )


def _piece_strategy_mode(piece: Any) -> str | None:
    """Return the normalised strategy mode of one piece, or ``None`` if unreadable."""
    if not isinstance(piece, dict):
        return None
    strategy = piece.get("strategy")
    if strategy is None:
        # A piece without a strategy is treated as a search piece.
        return "search"
    if isinstance(strategy, dict):
        return str(strategy.get("mode", "search")).strip().lower()
    return None


def apply_agent_spec_policy(spec: Dict[str, Any]) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Apply agent-side hard policy bounds to a validated spec.

    Policy:
    - sampling.n_data default 5000, clamped to [3000, 6000]
    - stop_criteria.max_wall_time_s default 172800, minimum 43200
    - keep piece strategy modes unchanged (search/baseline/mapped are caller-owned decisions)

    Both numeric fields are always stored as ``int`` in the result: a value
    that parses to an in-range integer but is not itself an ``int`` (for
    example ``"4500"`` or ``6000.9``) is replaced by the parsed integer and
    reported with reason ``coerced_to_int``.

    Args:
        spec: The spec to normalise. It is deep-copied, never mutated. A
            non-dict value is treated as an empty spec.

    Returns:
        A ``(normalized_spec, warnings)`` tuple. ``warnings`` lists one
        ``spec_policy_adjusted`` record per rewritten field, plus a
        ``spec_policy_notice`` record when ``domain.pieces`` is a non-empty
        list containing no search-mode piece.
    """
    normalized = copy.deepcopy(spec) if isinstance(spec, dict) else {}
    warnings: List[Dict[str, Any]] = []

    _enforce_int_policy(
        _ensure_dict_section(normalized, "sampling"),
        "n_data",
        field="sampling.n_data",
        default=DEFAULT_N_DATA,
        minimum=MIN_N_DATA,
        maximum=MAX_N_DATA,
        bound_reason="clamped",
        warnings=warnings,
    )

    _enforce_int_policy(
        _ensure_dict_section(normalized, "stop_criteria"),
        "max_wall_time_s",
        field="stop_criteria.max_wall_time_s",
        default=DEFAULT_MAX_WALL_TIME_S,
        minimum=MIN_MAX_WALL_TIME_S,
        maximum=None,  # wall time has a floor only
        bound_reason="raised_to_minimum",
        warnings=warnings,
    )

    domain = normalized.get("domain")
    pieces = domain.get("pieces") if isinstance(domain, dict) else None
    if isinstance(pieces, list) and pieces:
        modes = [_piece_strategy_mode(piece) for piece in pieces]
        if not any(mode in _SEARCH_STRATEGY_MODES for mode in modes):
            # Informational only: strategy modes are never rewritten here.
            warnings.append(
                _notice_warning(
                    field="domain.pieces[*].strategy.mode",
                    message="Spec contains no search-mode piece; no automatic strategy rewrite was applied.",
                    details={"modes": modes},
                )
            )

    return normalized, warnings
