from __future__ import annotations

import math
from typing import Any, Dict, List, Optional

from python_src.spec_io import load_spec_dict, normalize_spec_dict
from schemas.mcp_contract import make_response


_PERIODIC = {
    "sin": 2.0 * math.pi,
    "cos": 2.0 * math.pi,
    "sinpi": 2.0,
    "cospi": 2.0,
}

_SYMMETRY = {
    "sin": "odd",
    "sinpi": "odd",
    "cos": "even",
    "cospi": "even",
    "tanh": "odd",
    "sinh": "odd",
    "cosh": "even",
}


_COMPLEMENT_SYMMETRY = {
    "sigmoid": "f(-x)=1-f(x)",
}


_MONOTONICITY = {
    "exp": "increasing",
    "exp2": "increasing",
    "log": "increasing",
    "log1p": "increasing",
    "log2": "increasing",
    "sqrt": "increasing",
    "sigmoid": "increasing",
    "silu": "mostly_increasing_with_negative_tail_turning_point",
    "tanh": "increasing",
    "sinh": "increasing",
    "cosh": "decreasing_then_increasing_minimum_at_0",
    "erf": "increasing",
}


def _builtin_name(spec: Dict[str, Any]) -> Optional[str]:
    function = ((spec.get("target") or {}).get("function") or {})
    builtin = function.get("builtin")
    if not builtin:
        return None
    name = str(builtin).strip().lower()
    aliases = {
        "ln": "log",
        "besselj0": "bessel_j0",
        "besselj": "bessel_j0",
        "bessely0": "bessel_y0",
        "bessely": "bessel_y0",
    }
    return aliases.get(name, name)


def _inside(value: float, start: float, end: float) -> bool:
    return start <= value <= end


def _range_points(start: float, end: float, step: float, origin: float = 0.0) -> List[float]:
    if step <= 0 or not math.isfinite(step):
        return []
    first = math.ceil((start - origin) / step)
    last = math.floor((end - origin) / step)
    if last - first > 128:
        return []
    return [origin + k * step for k in range(first, last + 1)]


def _special_points(name: Optional[str], start: float, end: float) -> List[float]:
    """Return notable abscissae of builtin *name* that lie inside [start, end].

    Candidates are:

    * ``0.0`` for sin/cos, tanh/sinh/cosh and the sigmoid/SiLU/GELU family;
    * every multiple of pi/2 (sin, cos) or of 0.5 (sinpi, cospi) in range --
      the zeros and extrema of those functions -- subject to the size cap in
      ``_range_points``;
    * the lower domain edge for log/log2/sqrt (``0.0``) and log1p (``-1.0``).

    Non-finite bounds yield an empty list.  The result is sorted, contains
    only points inside [start, end], and has no duplicates.
    """
    if not math.isfinite(start) or not math.isfinite(end):
        return []

    origin_names = {
        "sin",
        "cos",
        "tanh",
        "sinh",
        "cosh",
        "sigmoid",
        "silu",
        "gelu",
        "quickgelu",
    }
    lattice_steps = {
        "sin": math.pi / 2.0,
        "cos": math.pi / 2.0,
        "sinpi": 0.5,
        "cospi": 0.5,
    }
    domain_edges = {"log": 0.0, "log2": 0.0, "log1p": -1.0, "sqrt": 0.0}

    candidates: List[float] = []
    if name in origin_names and _inside(0.0, start, end):
        candidates.append(0.0)

    step = lattice_steps.get(name)
    if step is not None:
        candidates.extend(_range_points(start, end, step))
    else:
        edge = domain_edges.get(name)
        if edge is not None and _inside(edge, start, end):
            candidates.append(edge)

    # Lattice points can land marginally outside the interval through
    # rounding, so filter again before de-duplicating.
    return sorted({float(c) for c in candidates if _inside(float(c), start, end)})


def _domain_warning(name: Optional[str], start: float, end: float) -> List[str]:
    warnings: List[str] = []
    if name in {"log", "log2"} and start <= 0.0:
        warnings.append("log_domain_requires_x_gt_0")
    if name == "log1p" and start <= -1.0:
        warnings.append("log1p_domain_requires_x_gt_minus_1")
    if name == "sqrt" and start < 0.0:
        warnings.append("sqrt_domain_requires_x_ge_0")
    if name == "bessel_y0" and start <= 0.0:
        warnings.append("bessel_y0_has_singularity_at_zero")
    return warnings


def inspect_math(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Summarize analytic properties of a spec's builtin target function.

    The spec is obtained via ``load_spec_dict(payload["spec"],
    payload["spec_path"])`` (either key may be absent) and then passed through
    ``normalize_spec_dict``.  For each entry in ``spec["domain"]["pieces"]``
    the piece's interval bounds are converted to float, and the function's
    special points and domain warnings on that interval are collected.

    Returns:
        ``make_response("ok", data=...)`` where ``data`` holds the target
        name, the canonical builtin name, its known properties (period,
        symmetry, complement symmetry, monotonicity -- ``None`` when
        unknown), per-piece summaries, the sorted union of all special
        points, and the sorted de-duplicated warnings.

        ``make_response("error", ...)`` with code ``"spec_invalid"`` when the
        spec fails to load/normalize, or when a piece's interval bounds are
        missing or not convertible to float.
    """
    try:
        spec = load_spec_dict(payload.get("spec"), payload.get("spec_path"))
        spec = normalize_spec_dict(spec)
    except Exception as exc:
        return make_response(
            "error",
            errors=[{"code": "spec_invalid", "message": str(exc), "details": {}}],
        )

    name = _builtin_name(spec)
    pieces = ((spec.get("domain") or {}).get("pieces") or [])
    summaries: List[Dict[str, Any]] = []
    all_focus: List[float] = []
    warnings: List[str] = []
    for piece in pieces:
        interval = piece.get("interval") or {}
        try:
            start = float(interval.get("start"))
            end = float(interval.get("end"))
        except (TypeError, ValueError) as exc:
            # Missing (None) or non-numeric bounds used to escape this tool as
            # an uncaught exception; report them through the same error
            # envelope as any other invalid spec.
            return make_response(
                "error",
                errors=[
                    {
                        "code": "spec_invalid",
                        "message": f"piece interval bounds must be numeric: {exc}",
                        "details": {"piece_id": piece.get("piece_id")},
                    }
                ],
            )
        special = _special_points(name, start, end)
        all_focus.extend(special)
        piece_warnings = _domain_warning(name, start, end)
        warnings.extend(piece_warnings)
        summaries.append(
            {
                "piece_id": piece.get("piece_id"),
                "interval": {"start": start, "end": end},
                "special_points": special,
                "warnings": piece_warnings,
            }
        )

    data = {
        "target": ((spec.get("target") or {}).get("name")),
        "builtin": name,
        "properties": {
            "period": _PERIODIC.get(name),
            "symmetry": _SYMMETRY.get(name),
            "complement_symmetry": _COMPLEMENT_SYMMETRY.get(name),
            "monotonicity": _MONOTONICITY.get(name),
        },
        "pieces": summaries,
        "suggested_focus_points": sorted(set(all_focus)),
        "warnings": sorted(set(warnings)),
    }
    return make_response("ok", data=data)
