from __future__ import annotations

import math
from typing import Callable, Dict, List, Optional, Tuple

import mpmath
import numpy as np

from python_src.config import SearchConfig
from python_src.precision import normalize_precision_model, quantize_array
from python_src.spec_io import load_spec as load_spec_from_path
from python_src.transforms import apply_affine, inverse_affine, parse_io_transform
from schemas.approx_spec import (
    ApproxSpec,
    FunctionSpec,
    IntervalSpec,
    PieceSpec,
    PrecisionSpec,
    SamplingSpec,
    TargetSpec,
)


def load_spec(spec_path: str) -> ApproxSpec:
    """Load an approximation spec from ``spec_path``.

    Thin wrapper that forwards to ``python_src.spec_io.load_spec`` so callers
    of this module do not need to import the I/O module themselves.
    """
    spec = load_spec_from_path(spec_path)
    return spec


def select_piece(spec: ApproxSpec, piece_id: Optional[str]) -> PieceSpec:
    """Pick one piece out of ``spec.domain.pieces``.

    Resolution order:
      1. ``None`` selects the first piece.
      2. An ``int``, or a string made only of decimal digits, is tried as a
         positional index. An in-range index wins over a piece whose
         ``piece_id`` happens to be spelled the same way.
      3. Anything else (or an out-of-range index) is compared against each
         piece's ``piece_id``.

    Args:
        spec: Spec whose ``domain.pieces`` sequence is searched.
        piece_id: Positional index, numeric string, piece name, or ``None``.

    Returns:
        The matching piece.

    Raises:
        ValueError: No piece matches ``piece_id``.
        IndexError: ``piece_id`` is ``None`` and the spec has no pieces.
    """
    pieces = spec.domain.pieces
    if piece_id is None:
        return pieces[0]

    index: Optional[int] = None
    if isinstance(piece_id, int):
        index = piece_id
    elif isinstance(piece_id, str) and piece_id.isdecimal():
        # isdecimal() (not isdigit()) is the predicate that matches what int()
        # can parse: isdigit() also accepts characters such as "²", for which
        # int() raises and would pre-empt the name lookup below.
        index = int(piece_id)

    if index is not None and 0 <= index < len(pieces):
        return pieces[index]

    for piece in pieces:
        if piece.piece_id == piece_id:
            return piece
    raise ValueError(f"piece_id '{piece_id}' not found in spec domain pieces.")


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def _gelu(x: float) -> float:
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def _quick_gelu(x: float) -> float:
    """Return the sigmoid-based GELU approximation ``x * sigmoid(1.702 * x)``."""
    gate = _sigmoid(1.702 * x)
    return x * gate


def _bessel_j0(x: float) -> float:
    """Evaluate ``mpmath.besselj`` at order 0 and convert the result to ``float``."""
    value = mpmath.besselj(0, x)
    return float(value)


def _bessel_y0(x: float) -> float:
    """Evaluate ``mpmath.bessely`` at order 0 and convert the result to ``float``."""
    value = mpmath.bessely(0, x)
    return float(value)


def build_fun_eval(target: TargetSpec, precision: PrecisionSpec) -> Callable[[float], float]:
    """Return a scalar callable that evaluates the target's reference function.

    ``target.function.builtin`` is consulted first: the name is stripped,
    lowercased, mapped through a small alias table and looked up among the
    supported builtins. When no builtin is given, ``target.function.expr`` is
    parsed with sympy and compiled with ``lambdify``.

    Args:
        target: Target description; must declare exactly one variable.
        precision: Accepted for interface compatibility; this function does
            not read it.

    Returns:
        A function taking one ``float`` and returning the function value.

    Raises:
        ValueError: The target is not single-variable, the builtin name is
            unknown, sympy is needed but not installed, or neither ``builtin``
            nor ``expr`` is set.
    """
    if len(target.vars) != 1:
        raise ValueError("Only single-variable functions are supported right now.")

    spec_fn = target.function

    if spec_fn.builtin:
        # Alternative spellings that collapse onto a canonical table key;
        # names not listed here are used as-is.
        canonical_names = {
            "ln": "log",
            "besselj0": "bessel_j0",
            "besselj": "bessel_j0",
            "bessely0": "bessel_y0",
            "bessely": "bessel_y0",
        }
        requested = spec_fn.builtin.strip().lower()
        key = canonical_names.get(requested, requested)

        evaluators = {
            # Exponentials and logarithms.
            "exp": math.exp,
            "exp2": lambda x: 2.0**x,
            "log": math.log,
            "log2": lambda x: math.log2(x),
            "log1p": math.log1p,
            "sqrt": math.sqrt,
            # Trigonometric and hyperbolic functions.
            "sin": math.sin,
            "cos": math.cos,
            "tanh": math.tanh,
            "sinh": math.sinh,
            "cosh": math.cosh,
            # Special functions.
            "erf": math.erf,
            "lgamma": math.lgamma,
            # Activation functions.
            "sigmoid": _sigmoid,
            "silu": lambda x: x * _sigmoid(x),
            "gelu": _gelu,
            "quickgelu": _quick_gelu,
            # Order-0 Bessel functions.
            "bessel_j0": _bessel_j0,
            "bessel_y0": _bessel_y0,
            # Functions with the argument pre-scaled by pi.
            "sinpi": lambda x: math.sin(math.pi * x),
            "cospi": lambda x: math.cos(math.pi * x),
            "logsinpi": lambda x: math.log(math.sin(math.pi * x)),
        }
        evaluator = evaluators.get(key)
        if evaluator is None:
            raise ValueError(f"Unsupported builtin function '{spec_fn.builtin}'.")
        return evaluator

    if not spec_fn.expr:
        raise ValueError("FunctionSpec must define builtin or expr.")

    # sympy is imported lazily so builtin-only specs work without it.
    try:
        import sympy  # type: ignore
    except ImportError as exc:
        raise ValueError("sympy is required for FunctionSpec.expr.") from exc

    variables = sympy.symbols(target.vars)
    # NOTE(review): sympify evaluates the string it is given; expr should only
    # come from trusted spec files.
    parsed = sympy.sympify(spec_fn.expr, convert_xor=True)
    # Names offered to lambdify ahead of the plain "math" module.
    extra_functions = {
        "exp2": lambda x: 2.0**x,
        "log2": math.log2,
        "besselj": mpmath.besselj,
        "bessely": mpmath.bessely,
        "erf": math.erf,
    }
    compiled = sympy.lambdify(variables, parsed, modules=[extra_functions, "math"])

    def _eval_expr(x: float) -> float:
        # Coerce whatever the compiled expression returns to a plain float.
        return float(compiled(x))

    return _eval_expr


def _sample_chebyshev(start: float, end: float, n: int) -> np.ndarray:
    if n <= 0:
        return np.array([], dtype=np.float64)
    k = np.arange(1, n + 1, dtype=np.float64)
    nodes = np.cos((2.0 * k - 1.0) / (2.0 * n) * math.pi)
    return 0.5 * (start + end) + 0.5 * (end - start) * nodes


def _sample_lhs(start: float, end: float, n: int, rng: np.random.Generator) -> np.ndarray:
    if n <= 0:
        return np.array([], dtype=np.float64)
    bins = np.linspace(0.0, 1.0, n + 1, dtype=np.float64)
    u = rng.uniform(bins[:-1], bins[1:])
    rng.shuffle(u)
    return start + (end - start) * u


def _sample_base(
    start: float,
    end: float,
    n: int,
    mode: str,
    rng: np.random.Generator,
) -> np.ndarray:
    if n <= 0:
        return np.array([], dtype=np.float64)
    mode = mode.lower()
    if mode == "uniform":
        return rng.uniform(start, end, size=n)
    if mode == "chebyshev":
        return _sample_chebyshev(start, end, n)
    if mode == "lhs":
        return _sample_lhs(start, end, n, rng)
    if mode == "hybrid":
        n_cheb = n // 2
        n_uni = n - n_cheb
        cheb = _sample_chebyshev(start, end, n_cheb)
        uni = rng.uniform(start, end, size=n_uni)
        return np.concatenate([cheb, uni])
    raise ValueError(f"Unsupported sampling mode '{mode}'.")


def _filter_excluded(samples: np.ndarray, excluded: List[float], tol: float) -> np.ndarray:
    if not excluded:
        return samples
    mask = np.ones(samples.shape[0], dtype=bool)
    for value in excluded:
        mask &= np.abs(samples - value) > tol
    return samples[mask]


def build_training_data(
    piece: PieceSpec,
    sampling: SamplingSpec,
    fun_eval: Callable[[float], float],
    precision: Optional[PrecisionSpec] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sample one piece and build the (inputs, targets) pair for fitting.

    Steps, in order:
      1. Sample points on the piece and quantize them to the input format.
      2. Apply the piece's input transform and quantize to the compute format.
      3. Evaluate ``fun_eval`` at each quantized (untransformed) input.
      4. Apply the inverse of the piece's output transform to those values and
         quantize to the output format.

    Args:
        piece: Piece providing the interval, exclusions and I/O transform.
        sampling: Sampling configuration passed to ``sample_piece_points``.
        fun_eval: Scalar reference function.
        precision: Optional precision model; normalized by
            ``normalize_precision_model``.

    Returns:
        ``(train_x, train_y)`` where ``train_x`` is the transformed inputs
        reshaped to a single column and ``train_y`` the transformed outputs.
    """
    formats = normalize_precision_model(precision)
    transforms = parse_io_transform(piece.transform)
    input_affine = transforms["input"]
    output_affine = transforms["output"]
    # Reference values are mapped through the inverse of the output transform.
    output_affine_inv = inverse_affine(output_affine, field_name="transform.output")

    raw_points = sample_piece_points(piece, sampling)
    quantized_points = quantize_array(raw_points, formats["input_format"])

    transformed_x = quantize_array(
        apply_affine(quantized_points, input_affine),
        formats["compute_format"],
    )

    # The reference function sees the quantized inputs, not the transformed ones.
    reference_values = np.array(
        [fun_eval(float(value)) for value in quantized_points],
        dtype=np.float64,
    )
    transformed_y = quantize_array(
        apply_affine(reference_values, output_affine_inv),
        formats["output_format"],
    )

    # Final pass: column-shape the inputs and quantize both arrays once more.
    train_x = quantize_array(transformed_x.reshape(-1, 1), formats["compute_format"])
    train_y = quantize_array(transformed_y, formats["output_format"])
    return train_x, train_y


def sample_piece_points(piece: PieceSpec, sampling: SamplingSpec) -> np.ndarray:
    """Generate ``sampling.n_data`` sample locations inside ``piece.interval``.

    The pool is assembled as: in-range focus points, then edge-focused uniform
    draws (when enabled), then base samples drawn in ``sampling.mode``. Points
    within a small tolerance of ``piece.excluded_points`` are removed, and any
    shortfall is refilled with uniform draws for at most five rounds.

    Args:
        piece: Provides ``interval`` (start/end) and ``excluded_points``.
        sampling: Provides ``n_data``, ``seed``, ``mode``, ``focus_points``,
            ``focus_radius`` and ``edge_focus`` (``enabled``, ``ratio``).

    Returns:
        A float64 array holding the first ``n_data`` points of the pool.

    Raises:
        ValueError: ``start >= end``, or too few points survive the exclusions
            after the refill rounds.
    """
    bounds = piece.interval
    lo = float(bounds.start)
    hi = float(bounds.end)
    if lo >= hi:
        raise ValueError("Invalid interval after normalization.")

    width = hi - lo
    target_count = sampling.n_data
    rng = np.random.default_rng(sampling.seed)

    # Focus points outside the interval are silently ignored.
    focus = [float(point) for point in sampling.focus_points if lo <= point <= hi]

    deficit = target_count - len(focus)
    if deficit <= 0:
        pool = np.array(focus[:target_count], dtype=np.float64)
    else:
        edge_cfg = sampling.edge_focus
        if edge_cfg.enabled and edge_cfg.ratio > 0:
            # Clamp so a ratio above 1 cannot request more than the deficit.
            n_edge = min(int(deficit * edge_cfg.ratio), deficit)
        else:
            n_edge = 0
        n_base = deficit - n_edge

        # Base samples are drawn before edge samples; the draw order is kept
        # fixed so a given seed always yields the same points.
        base_part = _sample_base(lo, hi, n_base, sampling.mode, rng)

        if n_edge > 0:
            n_left = n_edge // 2
            n_right = n_edge - n_left
            # Edge band: at least 1% of the interval or focus_radius, at most
            # half the interval, and never narrower than machine epsilon.
            band = max(width * 0.01, sampling.focus_radius)
            band = min(band, width * 0.5)
            band = max(band, np.finfo(np.float64).eps)
            left_part = rng.uniform(lo, min(lo + band, hi), size=n_left)
            right_part = rng.uniform(max(hi - band, lo), hi, size=n_right)
            edge_part = np.concatenate([left_part, right_part])
        else:
            edge_part = np.array([], dtype=np.float64)

        pool = np.concatenate([np.array(focus, dtype=np.float64), edge_part, base_part])

    # Exclusion tolerance scales with the interval width but never drops
    # below machine epsilon.
    tolerance = max(np.finfo(np.float64).eps, 1e-12 * max(1.0, abs(width)))
    banned = piece.excluded_points
    pool = _filter_excluded(pool, banned, tolerance)

    # Refill whatever the exclusions removed, giving up after five rounds.
    for _ in range(5):
        shortfall = target_count - pool.shape[0]
        if shortfall <= 0:
            break
        refill = _sample_base(lo, hi, shortfall, "uniform", rng)
        refill = _filter_excluded(refill, banned, tolerance)
        pool = np.concatenate([pool, refill])

    if pool.shape[0] < target_count:
        raise ValueError("Unable to generate enough samples after exclusions.")

    return pool[:target_count]


def get_search_config(spec: ApproxSpec) -> SearchConfig:
    """Build the ``SearchConfig`` for ``spec``.

    Forwards ``spec.search_config`` and ``spec.precision_model`` to
    ``SearchConfig.from_spec``.
    """
    raw_config = spec.search_config
    precision_model = spec.precision_model
    return SearchConfig.from_spec(raw_config, precision_model=precision_model)


# Default function intervals for the legacy --fun workflow.
#
# Keys are lowercase legacy function names. Each value is a 3-tuple; the
# legacy builder treats a first element containing "(" as an expression string
# and anything else as a builtin name. The interval bounds are used only when
# the caller does not supply start/end.
_LEGACY_FUNCTION_DEFAULTS: Dict[str, Tuple[str, float, float]] = {
    # Value layout: (builtin_or_expr, default_start, default_end)
    "exp": ("exp", 0.0, 1.0),
    "exp2": ("exp2", 0.0, 1.0),
    "log": ("log", 1.0, 2.0),
    "log2": ("log2", 1.0, 2.0),
    "sigmoid": ("sigmoid", 0.0, 1.0),
    "lgamma": ("lgamma", 0.7, 1.5),
    "sin": ("sin", 0.0, 1.5708),  # upper bound is pi/2 rounded to 4 decimals
    "sinpi": ("sinpi", 0.0, 0.5),
    # NOTE(review): log(sin(pi * x)) is undefined at x = 0, which is the
    # default lower bound - confirm sampling cannot land exactly on it.
    "logsinpi": ("logsinpi", 0.0, 0.5),
    "quickgelu": ("quickgelu", -5.0, 5.0),
    "gelu": ("gelu", -5.0, 5.0),
    "silu": ("silu", -5.0, 5.0),
    "tanh": ("tanh", -3.0, 3.0),
    "sinh": ("sinh", 0.0, 1.0),
    "cosh": ("cosh", 0.0, 1.0),
    "erf": ("erf", -3.0, 3.0),
    "sqrt": ("sqrt", 0.0, 1.0),
    "cos": ("cos", 0.0, 1.5708),  # upper bound is pi/2 rounded to 4 decimals
    "cospi": ("cospi", 0.0, 0.5),
    # Bessel functions use expr so arbitrary orders can be represented; the
    # entries below are the order-0 defaults.
    "besselj": ("besselj(0, x)", 2.3, 2.5),
    "bessely": ("bessely(0, x)", 0.0, 5.0),
}


def build_from_legacy_params(
    fun: str,
    start: Optional[float] = None,
    end: Optional[float] = None,
    fun_idx: Optional[int] = None,
    n_data: int = 5000,
    seed: int = 1234567,
    eps: float = 1e-6,
) -> Tuple[Callable[[float], float], np.ndarray, np.ndarray, float]:
    """Build an evaluator and training data from the legacy ``--fun`` arguments.

    The function name is stripped and lowercased, then resolved against the
    legacy defaults table. For ``besselj``/``bessely`` with an explicit
    ``fun_idx`` an expression of that order is generated; otherwise the table
    entry decides whether a builtin or an expression is used. Missing interval
    bounds fall back to the table defaults.

    Args:
        fun: Function name, for example exp, log, sigmoid, or besselj.
        start: Interval start, or None to use the legacy default.
        end: Interval end, or None to use the legacy default.
        fun_idx: Order for besselj and bessely; not used for other functions.
        n_data: Number of training samples.
        seed: Random seed.
        eps: Optimization tolerance; returned unchanged.

    Returns:
        Tuple of (fun_eval, train_X, train_Y, eps).

    Raises:
        ValueError: ``fun`` is not a known legacy function name.
    """
    key = fun.strip().lower()

    # Exactly one of these ends up set, depending on how the function is
    # described.
    builtin_name = None
    expr = None

    if key in ("besselj", "bessely") and fun_idx is not None:
        # An explicit order needs an expression; only the interval defaults
        # are taken from the table.
        expr = f"{key}({fun_idx}, x)"
        _, default_start, default_end = _LEGACY_FUNCTION_DEFAULTS.get(key, (None, 0.0, 5.0))
    elif key in _LEGACY_FUNCTION_DEFAULTS:
        spelled, default_start, default_end = _LEGACY_FUNCTION_DEFAULTS[key]
        # A "(" marks an expression string rather than a builtin name.
        if "(" in spelled:
            expr = spelled
        else:
            builtin_name = spelled
    else:
        raise ValueError(f"Unsupported legacy function: {fun}")

    if start is None:
        start = default_start
    if end is None:
        end = default_end

    if builtin_name:
        function_spec = FunctionSpec(builtin=builtin_name)
    else:
        function_spec = FunctionSpec(expr=expr)
    target = TargetSpec(name=f"legacy_{fun}", function=function_spec, vars=["x"])

    # Single piece covering the whole interval, default sampling and precision.
    piece = PieceSpec(piece_id="main", interval=IntervalSpec(start=start, end=end))
    sampling = SamplingSpec(n_data=n_data, seed=seed)
    precision = PrecisionSpec()

    fun_eval = build_fun_eval(target, precision)
    train_X, train_Y = build_training_data(piece, sampling, fun_eval, precision=precision)

    return fun_eval, train_X, train_Y, eps
