from __future__ import annotations

from typing import Any, Dict, List

import numpy as np

from schemas.approx_spec import FunctionSpec, PrecisionSpec, TargetSpec
from worker.spec_loader import build_fun_eval


def _build_callable(function: Dict[str, Any]):
    """Turn a raw function-spec dict into an evaluator via ``build_fun_eval``.

    The dict is validated into a ``FunctionSpec`` (via ``model_validate``;
    presumably a pydantic model — validation errors propagate to the caller).
    It is then wrapped in a single-variable ``TargetSpec`` (variable ``"x"``)
    named after the spec's ``builtin`` field, or ``"expr"`` when that field is
    falsy. Default ``PrecisionSpec`` settings are used.

    Returns whatever ``build_fun_eval`` produces; this module calls it with a
    single ``float`` argument.
    """
    spec = FunctionSpec.model_validate(function)
    target_name = spec.builtin or "expr"
    target_spec = TargetSpec(name=target_name, function=spec, vars=["x"])
    return build_fun_eval(target_spec, PrecisionSpec())


def _horner_code(coeffs: np.ndarray) -> str:
    """Render ascending-power coefficients as a nested Horner expression in ``x``.

    ``[c0, c1, c2]`` becomes ``"((c2*x + c1)*x + c0)"``. Each coefficient is
    printed with 17 significant digits so the float64 value round-trips.
    An empty coefficient array yields the constant ``"0.0"``.
    """
    if coeffs.size == 0:
        return "0.0"
    # Pop from the end so the highest-order coefficient is innermost.
    pending = [float(value) for value in coeffs.astype(np.float64)]
    expression = format(pending.pop(), ".17g")
    while pending:
        expression = "(" + expression + "*x + " + format(pending.pop(), ".17g") + ")"
    return expression


def _initial_points(start: float, end: float, degree: int) -> np.ndarray:
    """Return the initial Remez reference: ``degree + 2`` points in ascending order.

    The points are the Chebyshev nodes of the first kind,
    ``cos((2k + 1) * pi / (2n))`` for ``n = degree + 2``, affinely mapped from
    ``[-1, 1]`` onto ``[start, end]``. They lie strictly inside the interval.
    """
    n_nodes = degree + 2
    index = np.arange(n_nodes, dtype=np.float64)
    cheb = np.cos((2 * index + 1) / (2 * n_nodes) * np.pi)
    midpoint = 0.5 * (start + end)
    half_width = 0.5 * (end - start)
    # cos is decreasing in k, so flip to get ascending abscissae.
    return midpoint + half_width * np.flip(cheb)


def _solve_remez(x: np.ndarray, y: np.ndarray, degree: int) -> np.ndarray:
    """Solve the Remez reference system ``p(x_i) + (-1)**i * E = y_i``.

    The unknown vector is laid out as the ``degree + 1`` monomial coefficients
    of ``p`` (ascending powers) followed by the levelled error ``E``.

    If ``np.linalg.solve`` rejects the system (it raises ``LinAlgError`` for
    singular or non-square matrices), a least-squares solution is returned.
    """
    n_points = x.size
    # +1, -1, +1, ... column multiplying the levelled error E.
    sign_column = np.where(np.arange(n_points) % 2 == 0, 1.0, -1.0)
    powers = np.vander(x, N=degree + 1, increasing=True)
    system = np.hstack((powers, sign_column[:, np.newaxis]))
    try:
        return np.linalg.solve(system, y)
    except np.linalg.LinAlgError:
        solution, _, _, _ = np.linalg.lstsq(system, y, rcond=None)
        return solution


def _poly_eval(coeffs: np.ndarray, x: np.ndarray) -> np.ndarray:
    """Evaluate ``sum(coeffs[k] * x**k)`` at ``x`` (coefficients in ascending powers)."""
    power_series = np.polynomial.polynomial
    return power_series.polyval(x, coeffs)


def _find_extrema(x: np.ndarray, err: np.ndarray, count: int) -> np.ndarray:
    """Choose the next Remez reference (up to ``count`` points) from sampled errors.

    The candidates are both endpoints plus every interior local maximum of
    ``|err|``. The reference is then built in three steps:

    1. Consecutive candidates with the same error sign are collapsed to the
       single candidate with the largest ``|err|``, so the kept points
       alternate in sign and sit on the true error peaks.
    2. If more than ``count`` alternating points remain, points are dropped
       from whichever end has the smaller ``|err|``. Trimming only at the ends
       preserves alternation and never discards the global maximum.
    3. If fewer than ``count`` remain, the largest-``|err|`` leftover
       candidates are appended. Alternation is then no longer guaranteed.

    Args:
        x: Sample abscissae, same length as ``err``.
        err: Error values ``f(x) - p(x)`` at ``x``.
        count: Desired number of reference points (``degree + 2`` for Remez).

    Returns:
        The selected points of ``x``, sorted by ``x``. The result has fewer
        than ``count`` entries only when there are not enough candidates.
        It is empty when ``x`` is empty or ``count <= 0``.
    """
    n_samples = len(x)
    if n_samples == 0 or count <= 0:
        return x[:0]

    abs_err = np.abs(err)
    candidates: List[int] = [0]
    candidates.extend(
        i
        for i in range(1, n_samples - 1)
        if abs_err[i] >= abs_err[i - 1] and abs_err[i] >= abs_err[i + 1]
    )
    candidates.append(n_samples - 1)
    candidates = sorted(set(candidates), key=lambda i: x[i])

    # Step 1: keep only the largest-|err| member of each same-sign run.
    alternating: List[int] = []
    for idx in candidates:
        if alternating and np.sign(err[idx]) == np.sign(err[alternating[-1]]):
            if abs_err[idx] > abs_err[alternating[-1]]:
                alternating[-1] = idx
        else:
            alternating.append(idx)

    # Step 2: too many points -- trim the weaker end so alternation survives.
    while len(alternating) > count:
        if abs_err[alternating[0]] <= abs_err[alternating[-1]]:
            alternating.pop(0)
        else:
            alternating.pop()

    # Step 3: too few points -- pad with the strongest leftover candidates.
    selected = alternating
    if len(selected) < count:
        chosen = set(selected)
        remaining = sorted(
            (i for i in candidates if i not in chosen),
            key=lambda i: abs_err[i],
            reverse=True,
        )
        selected = selected + remaining[: count - len(selected)]

    selected = sorted(selected, key=lambda i: x[i])
    return x[selected]


def minimax_approximate(
    function: Dict[str, Any],
    interval: Dict[str, float],
    degree: int,
    max_iter: int = 8,
    grid_size: int = 4096,
) -> Dict[str, Any]:
    """Fit a degree-``degree`` polynomial to ``function`` with the Remez exchange.

    Args:
        function: Function-spec dict, turned into a scalar evaluator by
            ``_build_callable``.
        interval: Mapping with ``"start"`` and ``"end"`` keys (cast to float).
        degree: Polynomial degree; must be non-negative.
        max_iter: Maximum number of exchange iterations.
        grid_size: Number of uniformly spaced samples used to locate the error
            extrema in each iteration.

    Returns:
        A dict with these keys:

        - ``"method"``: always ``"minimax"``.
        - ``"coefficients"``: ascending-power coefficients as a list.
        - ``"polynomial_code"``: a Horner-form expression in ``x``.
        - ``"metrics"``: max/avg/p99 absolute error on
          ``max(256, degree * 64)`` uniform points.
        - ``"evaluation_ops"``: the degree as an int.

    Raises:
        ValueError: If ``interval.start >= interval.end`` or ``degree < 0``.
    """
    start = float(interval["start"])
    end = float(interval["end"])
    if start >= end:
        raise ValueError("interval.start must be less than interval.end")
    if degree < 0:
        raise ValueError("degree must be non-negative")

    fun_eval = _build_callable(function)
    x_ext = _initial_points(start, end, degree)

    # The dense search grid and f sampled on it are identical every iteration,
    # so evaluate them once instead of grid_size extra calls per iteration.
    x_grid = np.linspace(start, end, grid_size, dtype=np.float64)
    y_grid = np.array([fun_eval(float(val)) for val in x_grid], dtype=np.float64)

    coeffs = np.zeros(degree + 1, dtype=np.float64)
    for _ in range(max_iter):
        y_ext = np.array([fun_eval(float(val)) for val in x_ext], dtype=np.float64)
        solved = _solve_remez(x_ext, y_ext, degree)
        coeffs = solved[:-1]  # last unknown is the levelled error E
        err = y_grid - _poly_eval(coeffs, x_grid)
        x_new = _find_extrema(x_grid, err, degree + 2)
        if x_new.shape != x_ext.shape:
            # Too few extrema to form a full reference; keep the current fit
            # rather than crash in allclose or solve a non-square system.
            break
        if np.allclose(x_new, x_ext, rtol=1e-6, atol=1e-8):
            break
        x_ext = x_new

    x_eval = np.linspace(start, end, max(256, degree * 64), dtype=np.float64)
    y_eval = np.array([fun_eval(float(val)) for val in x_eval], dtype=np.float64)
    y_hat = _poly_eval(coeffs, x_eval)
    abs_err = np.abs(y_hat - y_eval)
    metrics = {
        "max_error": float(np.max(abs_err)),
        "avg_error": float(np.mean(abs_err)),
        "p99_error": float(np.quantile(abs_err, 0.99)),
    }

    return {
        "method": "minimax",
        "coefficients": coeffs.tolist(),
        "polynomial_code": _horner_code(coeffs),
        "metrics": metrics,
        "evaluation_ops": int(degree),
    }
