from __future__ import annotations

import math
from typing import Any, Dict, List, Optional, Tuple

from server.run_manager import make_response


# Builtins whose mirrored piece is recovered by negating the output,
# i.e. odd functions: f(-x) = -f(x).
_ODD_MIRROR_BUILTINS = {"sin", "tanh", "sinh"}
# Builtins whose mirrored piece reuses the output unchanged,
# i.e. even functions: f(-x) = f(x).
_EVEN_MIRROR_BUILTINS = {"cos", "cosh"}
# Builtins whose mirrored piece is the complement of the output:
# f(-x) = 1 - f(x).
_COMPLEMENT_MIRROR_BUILTINS = {"sigmoid"}


def _error(code: str, message: str) -> Dict[str, Any]:
    """Wrap a single error entry in an ``"error"`` response.

    Args:
        code: Machine-readable error identifier.
        message: Human-readable description of the failure.

    Returns:
        Whatever ``make_response`` produces for status ``"error"`` with a
        one-element ``errors`` list; ``details`` is always an empty dict.
    """
    entry = {"code": code, "message": message, "details": {}}
    return make_response("error", errors=[entry])


def _safe_float(value: Any) -> Optional[float]:
    try:
        num = float(value)
    except (TypeError, ValueError):
        return None
    if not math.isfinite(num):
        return None
    return num


def _unique_piece_id(existing: set[str], base: str, idx: int) -> str:
    candidate = f"{base}_{idx}"
    if candidate not in existing:
        existing.add(candidate)
        return candidate
    suffix = 1
    while True:
        candidate = f"{base}_{idx}_{suffix}"
        if candidate not in existing:
            existing.add(candidate)
            return candidate
        suffix += 1


def _extract_piece(spec: Dict[str, Any], piece_id: str) -> Tuple[Optional[Dict[str, Any]], Optional[int]]:
    pieces = spec.get("domain", {}).get("pieces", [])
    if not isinstance(pieces, list):
        return None, None
    for idx, piece in enumerate(pieces):
        if not isinstance(piece, dict):
            continue
        if str(piece.get("piece_id")) == piece_id:
            return piece, idx
    return None, None


def _piece_copy(piece: Dict[str, Any], piece_id: str, start: float, end: float) -> Dict[str, Any]:
    out: Dict[str, Any] = {
        "piece_id": piece_id,
        "interval": {"start": start, "end": end},
    }
    excluded = piece.get("excluded_points")
    if isinstance(excluded, list) and excluded:
        filtered: List[float] = []
        for value in excluded:
            num = _safe_float(value)
            if num is None:
                continue
            if start <= num <= end:
                filtered.append(num)
        if filtered:
            out["excluded_points"] = filtered
    if "transform" in piece:
        out["transform"] = piece.get("transform")
    if "strategy" in piece:
        out["strategy"] = piece.get("strategy")
    return out


def _split_segments(
    start: float,
    end: float,
    split_points: List[float],
    min_piece_width: float,
) -> Optional[List[Tuple[float, float]]]:
    points = sorted({p for p in split_points if start < p < end})
    if not points:
        return None
    segments: List[Tuple[float, float]] = []
    prev = start
    for point in points:
        segments.append((prev, point))
        prev = point
    segments.append((prev, end))
    for seg_start, seg_end in segments:
        if seg_end - seg_start < min_piece_width:
            return None
    return segments


def _collect_counterexample_points(
    verify: Dict[str, Any],
    start: float,
    end: float,
) -> List[float]:
    counterexamples = verify.get("counterexamples", []) if isinstance(verify, dict) else []
    points: List[float] = []
    if isinstance(counterexamples, list):
        for entry in counterexamples:
            if not isinstance(entry, dict):
                continue
            value = _safe_float(entry.get("x"))
            if value is None:
                continue
            if start < value < end:
                points.append(value)
    if not points:
        mid = (start + end) * 0.5
        points = [mid]
    return points


def _sampling_patch(
    spec: Dict[str, Any],
    focus_points: List[float],
    focus_radius: float,
    edge_spike: bool,
) -> Dict[str, Any]:
    sampling = spec.get("sampling", {}) if isinstance(spec.get("sampling"), dict) else {}
    existing = sampling.get("focus_points", [])
    merged: List[float] = []
    if isinstance(existing, list):
        for value in existing:
            num = _safe_float(value)
            if num is not None:
                merged.append(num)
    merged.extend(focus_points)
    merged = sorted({round(p, 12) for p in merged})

    patch: Dict[str, Any] = {"focus_points": merged, "focus_radius": focus_radius}
    if edge_spike:
        patch["edge_focus"] = {"enabled": True, "ratio": 0.2}
    return patch


def _builtin_name(spec: Dict[str, Any]) -> Optional[str]:
    target = spec.get("target")
    if not isinstance(target, dict):
        return None
    function = target.get("function")
    if not isinstance(function, dict):
        return None
    builtin = function.get("builtin")
    if builtin is None:
        return None
    return str(builtin).strip().lower()


def _strategy_mode(piece: Dict[str, Any]) -> str:
    strategy = piece.get("strategy")
    if isinstance(strategy, dict):
        return str(strategy.get("mode", "search")).strip().lower()
    return "search"


def _find_mirror_source(
    spec: Dict[str, Any],
    piece_id: str,
    start: float,
    end: float,
) -> Optional[Tuple[str, Dict[str, Any], Dict[str, Any]]]:
    """Find a piece that is the reflection of ``[start, end]`` about zero.

    Only applies when the spec's target builtin is one of the supported
    mirror builtins (odd, even, or complement symmetry).

    Args:
        spec: Specification dict holding the target and ``domain.pieces``.
        piece_id: Id of the piece that needs a source; never matched itself.
        start: Lower bound of that piece's interval.
        end: Upper bound of that piece's interval.

    Returns:
        ``(source_piece_id, input_transform, output_transform)`` for the
        first piece whose interval equals ``[-end, -start]`` within
        tolerance and whose strategy mode is not ``"mapped"``; ``None`` when
        the builtin is unsupported, the domain is malformed, or no such
        piece exists.
    """
    builtin = _builtin_name(spec)
    # The output transform depends only on the builtin's symmetry class.
    if builtin in _ODD_MIRROR_BUILTINS:
        output_tf = {"kind": "affine", "scale": -1.0, "shift": 0.0}
    elif builtin in _COMPLEMENT_MIRROR_BUILTINS:
        output_tf = {"kind": "affine", "scale": -1.0, "shift": 1.0}
    elif builtin in _EVEN_MIRROR_BUILTINS:
        output_tf = {"kind": "affine", "scale": 1.0, "shift": 0.0}
    else:
        return None

    domain = spec.get("domain")
    # A present-but-malformed domain (e.g. None) means no candidates rather
    # than an AttributeError.
    if not isinstance(domain, dict):
        return None
    pieces = domain.get("pieces")
    if not isinstance(pieces, list):
        return None

    width = end - start
    # Relative tolerance scaled by the largest magnitude involved, with an
    # absolute floor for intervals near zero.
    tol = max(1e-12, 1e-9 * max(1.0, abs(start), abs(end), abs(width)))
    for idx, piece in enumerate(pieces):
        if not isinstance(piece, dict):
            continue
        candidate_id = str(piece.get("piece_id", idx))
        if candidate_id == piece_id:
            continue
        # A mapped piece has no fit of its own to reuse.
        if _strategy_mode(piece) == "mapped":
            continue
        interval = piece.get("interval")
        if not isinstance(interval, dict):
            continue
        src_start = _safe_float(interval.get("start"))
        src_end = _safe_float(interval.get("end"))
        if src_start is None or src_end is None:
            continue
        # Mirror around 0: [a, b] -> [-b, -a]
        if abs((src_end - src_start) - width) > tol:
            continue
        if abs(src_start + end) > tol or abs(src_end + start) > tol:
            continue

        input_tf = {"kind": "affine", "scale": -1.0, "shift": 0.0}
        return candidate_id, input_tf, output_tf
    return None


def _coerce_int(value: Any) -> Optional[int]:
    """Return ``int(value)``, or ``None`` when the conversion fails."""
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _baseline_plan(method: str, attempt: int) -> Dict[str, Any]:
    """Build the baseline plan; both degrees grow by 4 per attempt."""
    return {
        "method": method,
        "degree": 8 + attempt * 4,
        "next_degree": 12 + attempt * 4,
    }


def suggest_strategy(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Suggest the next action for a piece whose verification failed.

    Args:
        payload: Request dict. Required keys: ``spec`` (dict), ``piece_id``
            and ``verify`` (dict). Optional keys: ``attempt`` (default 0),
            ``max_attempts`` (default 3), ``split_policy`` (default
            ``"counterexample"``) and ``baseline_method`` (default
            ``"compare"``).

    Returns:
        A ``make_response`` result. On bad input it is an ``"error"``
        response with one of the codes ``missing_spec``,
        ``missing_piece_id``, ``missing_verify``, ``piece_not_found``,
        ``invalid_interval`` or ``invalid_attempt``. Otherwise it is an
        ``"ok"`` response whose ``data["decision"]`` is one of:

        * ``"stop"`` - verification already passes.
        * ``"baseline"`` - the retry budget is exhausted.
        * ``"map_reuse"`` - the piece is stagnating and a mirrored source
          piece exists; the patch switches the piece to a mapped strategy.
        * ``"split"`` - the piece is cut at counterexample points.
        * ``"adjust_sampling"`` - splitting is not feasible; only the
          sampling section is patched.
    """
    spec = payload.get("spec")
    piece_id = payload.get("piece_id")
    verify = payload.get("verify")
    if not isinstance(spec, dict):
        return _error("missing_spec", "spec is required")
    if piece_id is None:
        return _error("missing_piece_id", "piece_id is required")
    if not isinstance(verify, dict):
        return _error("missing_verify", "verify result is required")

    piece_id = str(piece_id)
    piece, piece_idx = _extract_piece(spec, piece_id)
    if piece is None or piece_idx is None:
        return _error("piece_not_found", f"piece_id '{piece_id}' not found in spec")

    interval = piece.get("interval")
    # A non-dict interval is reported like any other invalid interval
    # instead of raising AttributeError.
    if not isinstance(interval, dict):
        return _error("invalid_interval", "piece interval must have numeric start < end")
    start = _safe_float(interval.get("start"))
    end = _safe_float(interval.get("end"))
    if start is None or end is None or start >= end:
        return _error("invalid_interval", "piece interval must have numeric start < end")

    attempt = _coerce_int(payload.get("attempt", 0))
    max_attempts = _coerce_int(payload.get("max_attempts", 3))
    if attempt is None or max_attempts is None:
        return _error("invalid_attempt", "attempt and max_attempts must be integers")
    split_policy = str(payload.get("split_policy") or "counterexample")
    baseline_method = str(payload.get("baseline_method") or "compare")

    if verify.get("pass") is True:
        return make_response(
            "ok",
            data={
                "decision": "stop",
                "spec_patch": None,
                "baseline_plan": None,
                "notes": "Verification already passes.",
                "rationale": "No changes needed.",
            },
        )

    stop_criteria = spec.get("stop_criteria")
    if not isinstance(stop_criteria, dict):
        stop_criteria = {}
    try:
        min_piece_width = float(stop_criteria.get("min_piece_width", 1e-6))
    except (TypeError, ValueError, OverflowError):
        min_piece_width = 1e-6
    # Malformed max_pieces falls back to its default, like min_piece_width.
    max_pieces = _coerce_int(stop_criteria.get("max_pieces", 20))
    if max_pieces is None:
        max_pieces = 20
    width = end - start

    if attempt >= max_attempts:
        return make_response(
            "ok",
            data={
                "decision": "baseline" if baseline_method else "stop",
                "spec_patch": None,
                "baseline_plan": _baseline_plan(baseline_method, attempt)
                if baseline_method
                else None,
                "notes": "Reached max attempts for this piece.",
                "rationale": "Fallback after exhausting retry budget.",
            },
        )

    domain = spec.get("domain")
    pieces = domain.get("pieces") if isinstance(domain, dict) else None
    if not isinstance(pieces, list):
        pieces = []

    points = _collect_counterexample_points(verify, start, end)
    failure_modes = verify.get("failure_modes", [])
    if not isinstance(failure_modes, list):
        failure_modes = []

    if any(mode in ("stagnation", "plateau", "imbalance") for mode in failure_modes):
        mirror = _find_mirror_source(spec, piece_id, start, end)
        if mirror is not None:
            source_piece_id, input_tf, output_tf = mirror
            updated_pieces: List[Dict[str, Any]] = []
            for idx, existing in enumerate(pieces):
                if not isinstance(existing, dict):
                    continue
                # Match by position so exactly the piece located by
                # _extract_piece is patched, even with duplicate or
                # missing ids.
                if idx == piece_idx:
                    patched = dict(existing)
                    patched["strategy"] = {
                        "mode": "mapped",
                        "source_piece_id": source_piece_id,
                        "input_transform": input_tf,
                        "output_transform": output_tf,
                    }
                    updated_pieces.append(patched)
                else:
                    updated_pieces.append(existing)
            return make_response(
                "ok",
                data={
                    "decision": "map_reuse",
                    "spec_patch": {"domain": {"pieces": updated_pieces}},
                    "baseline_plan": None,
                    "notes": "Reuse mirrored source piece via mapped transform.",
                    "rationale": "Symmetry-based reuse avoids redundant search on stagnating piece.",
                },
            )

    # Only a retry under the counterexample policy cuts at two points;
    # every other case cuts at one.
    if split_policy == "counterexample" and attempt > 0 and len(points) >= 2:
        split_points = points[:2]
    else:
        split_points = points[:1]

    segments = None
    if width >= 2.0 * min_piece_width:
        segments = _split_segments(start, end, split_points, min_piece_width)

    can_split = segments is not None
    if can_split:
        new_total = len(pieces) - 1 + len(segments)
        remaining_budget = max_pieces - len(pieces)
        # NOTE(review): the second condition also refuses a split that
        # would land exactly on max_pieces - presumably a deliberate
        # reserve; confirm.
        if new_total > max_pieces or remaining_budget <= 1:
            can_split = False

    edge_spike = "edge_spike" in failure_modes
    focus_radius = min(1e-3, 0.01 * width)
    sampling_patch = _sampling_patch(spec, points, focus_radius, edge_spike)

    if can_split and segments:
        existing_ids = {str(p.get("piece_id")) for p in pieces if isinstance(p, dict)}
        new_pieces: List[Dict[str, Any]] = [
            _piece_copy(piece, _unique_piece_id(existing_ids, piece_id, idx), seg_start, seg_end)
            for idx, (seg_start, seg_end) in enumerate(segments)
        ]

        updated_pieces = []
        for idx, existing in enumerate(pieces):
            if idx == piece_idx:
                updated_pieces.extend(new_pieces)
            elif isinstance(existing, dict):
                updated_pieces.append(existing)

        return make_response(
            "ok",
            data={
                "decision": "split",
                "spec_patch": {
                    "domain": {"pieces": updated_pieces},
                    "sampling": sampling_patch,
                },
                "baseline_plan": _baseline_plan(baseline_method, attempt),
                "notes": "Split piece based on counterexamples.",
                "rationale": "Counterexample-driven split to localize error.",
            },
        )

    # attempt < max_attempts is guaranteed here (the exhausted-budget case
    # returned earlier), so the decision is always "adjust_sampling".
    return make_response(
        "ok",
        data={
            "decision": "adjust_sampling",
            "spec_patch": {"sampling": sampling_patch},
            "baseline_plan": _baseline_plan(baseline_method, attempt),
            "notes": "Split not feasible; adjust sampling or fallback to baseline.",
            "rationale": "Piece too small or exceeds max_pieces; refine sampling.",
        },
    )
