from __future__ import annotations

from typing import Any, Dict

import numpy as np

from schemas.approx_spec import FunctionSpec


def _horner_code(coeffs: np.ndarray) -> str:
    if coeffs.size == 0:
        return "0.0"
    coeffs = coeffs.astype(np.float64)
    code = f"{coeffs[-1]:.17g}"
    for c in coeffs[-2::-1]:
        code = f"({code}*x + {c:.17g})"
    return code


def _sympy_expr(function: Dict[str, Any]):
    """Build the SymPy expression (in the symbol ``x``) described by a function spec.

    ``function`` is validated as a ``FunctionSpec``. When its ``builtin`` field is
    set, the name is stripped and lower-cased and looked up in a fixed table of
    supported functions. Otherwise its ``expr`` string is parsed with
    ``sympy.sympify``; ``convert_xor=True`` makes ``^`` mean exponentiation.

    Raises:
        ValueError: if sympy is not installed, the builtin name is not in the
            table, or the spec defines neither ``builtin`` nor ``expr``.
    """
    try:
        import sympy  # type: ignore
    except ImportError as exc:
        raise ValueError("sympy is required for Taylor approximation.") from exc

    spec = FunctionSpec.model_validate(function)
    x = sympy.symbols("x")

    if not spec.builtin:
        if not spec.expr:
            raise ValueError("FunctionSpec must define builtin or expr.")
        # NOTE(review): sympify evaluates its string argument with eval();
        # do not pass untrusted input here.
        return sympy.sympify(spec.expr, convert_xor=True)

    # Factories are zero-argument callables so that only the requested
    # expression is ever constructed.
    factories = {
        "exp": lambda: sympy.exp(x),
        "exp2": lambda: sympy.exp(x * sympy.log(2)),
        # NOTE(review): "log" is base 2 (same as "log2"); natural log is "ln".
        # Confirm this is the intended convention.
        "log": lambda: sympy.log(x, 2),
        "ln": lambda: sympy.log(x),
        "log2": lambda: sympy.log(x, 2),
        "sqrt": lambda: sympy.sqrt(x),
        "sin": lambda: sympy.sin(x),
        "cos": lambda: sympy.cos(x),
        "tanh": lambda: sympy.tanh(x),
        "sinh": lambda: sympy.sinh(x),
        "cosh": lambda: sympy.cosh(x),
        "erf": lambda: sympy.erf(x),
        "lgamma": lambda: sympy.loggamma(x),
        "sigmoid": lambda: 1 / (1 + sympy.exp(-x)),
        "silu": lambda: x / (1 + sympy.exp(-x)),
        "gelu": lambda: sympy.Rational(1, 2) * x * (1 + sympy.erf(x / sympy.sqrt(2))),
        "quickgelu": lambda: x / (1 + sympy.exp(-sympy.Float("1.702") * x)),
        "besselj0": lambda: sympy.besselj(0, x),
        "bessel_j0": lambda: sympy.besselj(0, x),
        "besselj": lambda: sympy.besselj(0, x),
        "bessely0": lambda: sympy.bessely(0, x),
        "bessel_y0": lambda: sympy.bessely(0, x),
        "bessely": lambda: sympy.bessely(0, x),
        "sinpi": lambda: sympy.sin(sympy.pi * x),
        "cospi": lambda: sympy.cos(sympy.pi * x),
        "logsinpi": lambda: sympy.log(sympy.sin(sympy.pi * x)),
    }

    factory = factories.get(spec.builtin.strip().lower())
    if factory is None:
        raise ValueError(f"Unsupported builtin function '{spec.builtin}'.")
    return factory()


def taylor_approximate(
    function: Dict[str, Any],
    interval: Dict[str, float],
    degree: int,
    center: float | None = None,
) -> Dict[str, Any]:
    """Approximate a function by its truncated Taylor series and measure the error.

    Args:
        function: Function spec, turned into a SymPy expression in ``x`` by
            ``_sympy_expr``.
        interval: Mapping with numeric ``"start"`` and ``"end"``; ``start < end``.
        degree: Highest power of ``(x - center)`` kept in the series; must be >= 0.
        center: Expansion point. Defaults to the interval midpoint.

    Returns:
        Dict with:
            ``method``: always ``"taylor"``.
            ``center``: the expansion point as a float.
            ``coefficients``: float64 coefficients of powers of ``x``, ascending
                (the series is expanded, so these are about 0, not ``center``).
            ``polynomial_code``: Horner-form string of those coefficients.
            ``metrics``: ``max_error``/``avg_error``/``p99_error`` of
                ``|p(x) - f(x)|`` on a uniform grid over the interval.
            ``evaluation_ops``: ``degree``.

    Raises:
        ValueError: sympy is missing, the interval or degree is invalid, the
            function spec is rejected, or the series is not a real polynomial
            in ``x``.
    """
    try:
        import sympy  # type: ignore
    except ImportError as exc:
        raise ValueError("sympy is required for Taylor approximation.") from exc

    start = float(interval["start"])
    end = float(interval["end"])
    if start >= end:
        raise ValueError("interval.start must be less than interval.end")
    if degree < 0:
        raise ValueError("degree must be non-negative")
    if center is None:
        center = 0.5 * (start + end)

    x = sympy.symbols("x")
    expr = _sympy_expr(function)
    # Any symbol other than x would leave symbolic coefficients that cannot be
    # converted to float; reject it with a clear message up front.
    stray = expr.free_symbols - {x}
    if stray:
        names = ", ".join(sorted(str(sym) for sym in stray))
        raise ValueError(f"function must depend only on x; found free symbols: {names}")

    coeffs = _taylor_coefficients(expr, x, center, degree)

    x_eval = np.linspace(start, end, max(256, degree * 64), dtype=np.float64)
    y_eval = _sample_function(expr, x, x_eval)
    # Score the float64 coefficients that are actually emitted, evaluated in
    # Horner order like polynomial_code, rather than the exact SymPy polynomial.
    y_hat = np.polynomial.polynomial.polyval(x_eval, coeffs)
    abs_err = np.abs(y_hat - y_eval)
    metrics = {
        "max_error": float(np.max(abs_err)),
        "avg_error": float(np.mean(abs_err)),
        "p99_error": float(np.quantile(abs_err, 0.99)),
    }

    return {
        "method": "taylor",
        "center": float(center),
        "coefficients": coeffs.tolist(),
        "polynomial_code": _horner_code(coeffs),
        "metrics": metrics,
        "evaluation_ops": int(degree),
    }


def _taylor_coefficients(expr, x, center: float, degree: int) -> np.ndarray:
    """Return float64 coefficients of x (ascending) for the series of ``expr`` about ``center``."""
    import sympy  # type: ignore

    # series(..., n) keeps terms of order < n, i.e. powers 0..degree of (x - center).
    series = sympy.series(expr, x, center, degree + 1).removeO()
    poly = sympy.expand(series)
    try:
        coeffs_desc = sympy.Poly(poly, x).all_coeffs()
        return np.array([float(c) for c in coeffs_desc[::-1]], dtype=np.float64)
    except (sympy.PolynomialError, TypeError) as exc:
        # PolynomialError: the series still has non-polynomial terms in x (e.g. log
        # or fractional powers when center is a singular point).
        # TypeError: a coefficient is not a real number (complex or symbolic).
        raise ValueError(
            f"Taylor series of {expr} about {center} is not a real polynomial in x: {series}"
        ) from exc


def _sample_function(expr, x, points: np.ndarray) -> np.ndarray:
    """Evaluate ``expr`` at each value in ``points`` and return a float64 array."""
    import sympy  # type: ignore

    fast = sympy.lambdify(x, expr, modules=["math"])
    try:
        return np.array([fast(float(val)) for val in points], dtype=np.float64)
    except NameError:
        # Functions with no `math` counterpart (e.g. besselj/bessely for the Bessel
        # builtins) are left as bare names by lambdify and fail when called; fall
        # back to SymPy's own numeric evaluator for those expressions.
        return np.array(
            [float(expr.evalf(subs={x: float(val)})) for val in points],
            dtype=np.float64,
        )
