from __future__ import annotations

import json
import math
import time
from typing import Any, Dict, List, Optional, Tuple

import mpmath
import numpy as np

from python_src.dag import DAG
from python_src.dag_json import dag_from_json
from python_src.precision import (
    normalize_precision_model,
    quantize_array,
    ulp_distance,
)
from python_src.transforms import apply_affine, parse_affine_transform, parse_io_transform
from schemas.approx_spec import ApproxSpec
from worker.spec_loader import build_fun_eval, load_spec, sample_piece_points, select_piece


def _dag_from_json(dag_data: Dict[str, Any]) -> DAG:
    """Build a ``DAG`` from its JSON dictionary form.

    Delegates to ``dag_from_json`` and passes ``"dag"`` as the default name.
    """
    dag = dag_from_json(dag_data, default_name="dag")
    return dag


def _load_candidate(candidate: Dict[str, Any] | str) -> Tuple[Dict[str, Any], Optional[str]]:
    """Resolve a candidate given either inline or as a JSON file path.

    Args:
        candidate: Either the candidate dictionary itself, or a path to a
            UTF-8 JSON file containing it.

    Returns:
        ``(data, path)``. ``path`` is ``None`` when the candidate was passed
        inline, otherwise it is the ``candidate`` argument unchanged.
    """
    if not isinstance(candidate, dict):
        with open(candidate, "r", encoding="utf-8") as stream:
            payload = json.load(stream)
        return payload, candidate
    return candidate, None


def _build_mpmath_eval(spec: ApproxSpec):
    """Return a one-argument callable that evaluates the target with mpmath.

    ``spec.target.function.builtin`` takes priority: it is stripped,
    lower-cased, passed through a small alias table, and looked up among the
    supported mpmath-backed functions. Otherwise ``spec.target.function.expr``
    is parsed with sympy (``convert_xor=True``) and lambdified against the
    ``mpmath`` module using ``spec.target.vars`` as symbols.

    Raises:
        ValueError: If the builtin name is unsupported, if sympy is needed
            but not installed, or if neither ``builtin`` nor ``expr`` is set.
    """
    target_fn = spec.target.function

    if target_fn.builtin:
        requested = target_fn.builtin.strip().lower()
        # Only spellings that differ from the canonical name are listed;
        # every other name maps to itself.
        synonyms = {
            "ln": "log",
            "besselj0": "bessel_j0",
            "besselj": "bessel_j0",
            "bessely0": "bessel_y0",
            "bessely": "bessel_y0",
        }
        canonical = synonyms.get(requested, requested)

        def _exp2(x):
            return mpmath.power(2, x)

        def _log2(x):
            return mpmath.log(x, 2)

        def _sigmoid(x):
            return 1 / (1 + mpmath.e**(-x))

        def _silu(x):
            return x / (1 + mpmath.e**(-x))

        def _gelu(x):
            return 0.5 * x * (1 + mpmath.erf(x / mpmath.sqrt(2)))

        def _quickgelu(x):
            return x / (1 + mpmath.e**(-1.702 * x))

        def _bessel_j0(x):
            return mpmath.besselj(0, x)

        def _bessel_y0(x):
            return mpmath.bessely(0, x)

        def _sinpi(x):
            return mpmath.sin(mpmath.pi * x)

        def _cospi(x):
            return mpmath.cos(mpmath.pi * x)

        def _logsinpi(x):
            return mpmath.log(mpmath.sin(mpmath.pi * x))

        table = {
            "exp": mpmath.exp,
            "exp2": _exp2,
            "log": mpmath.log,
            "log2": _log2,
            "log1p": mpmath.log1p,
            "sqrt": mpmath.sqrt,
            "sin": mpmath.sin,
            "cos": mpmath.cos,
            "tanh": mpmath.tanh,
            "erf": mpmath.erf,
            "lgamma": mpmath.loggamma,
            "sigmoid": _sigmoid,
            "silu": _silu,
            "gelu": _gelu,
            "quickgelu": _quickgelu,
            "bessel_j0": _bessel_j0,
            "bessel_y0": _bessel_y0,
            "sinpi": _sinpi,
            "cospi": _cospi,
            "logsinpi": _logsinpi,
        }
        evaluator = table.get(canonical)
        if evaluator is None:
            raise ValueError(f"Unsupported builtin function '{target_fn.builtin}'.")
        return evaluator

    if target_fn.expr:
        # sympy is imported lazily so it is only required for expr-based specs.
        try:
            import sympy  # type: ignore
        except ImportError as exc:
            raise ValueError("sympy is required for FunctionSpec.expr.") from exc
        variables = sympy.symbols(spec.target.vars)
        parsed = sympy.sympify(target_fn.expr, convert_xor=True)
        compiled = sympy.lambdify(variables, parsed, modules=["mpmath"])

        def _evaluate(x):
            return compiled(x)

        return _evaluate

    raise ValueError("FunctionSpec must define builtin or expr.")


def _metrics_from_errors(
    abs_err: np.ndarray,
    rel_err: np.ndarray,
    ulp_err: np.ndarray,
) -> Dict[str, float]:
    """Summarise each error array by its NaN-ignoring maximum and 99th percentile.

    Returns:
        A dict with keys ``max_abs``, ``p99_abs``, ``max_rel``, ``p99_rel``,
        ``max_ulp`` and ``p99_ulp`` (in that order), all plain floats.
    """
    summary: Dict[str, float] = {}
    for label, errors in (("abs", abs_err), ("rel", rel_err), ("ulp", ulp_err)):
        summary[f"max_{label}"] = float(np.nanmax(errors))
        summary[f"p99_{label}"] = float(np.nanquantile(errors, 0.99))
    return summary


def _select_metric(metrics: Dict[str, float], spec: ApproxSpec, level: int) -> Tuple[str, float]:
    """Pick the metrics key and threshold that decide pass/fail.

    The metric name and threshold come from ``spec.stop_criteria.verify_pass``
    when it is set (truthy), otherwise from ``spec.metric``. The name is
    lower-cased and translated to a key of the metrics dictionary; names that
    are not recognised fall back to ``"max_rel"``.

    ``metrics`` and ``level`` are accepted but not consulted.

    Returns:
        ``(metric_name, threshold)`` with the threshold converted to float.
    """
    verify_spec = spec.stop_criteria.verify_pass
    if verify_spec:
        raw_key, limit = verify_spec.metric, verify_spec.threshold
    else:
        raw_key, limit = spec.metric.type, spec.metric.threshold
    lookup_key = raw_key.lower()

    # A bare kind ("ulp") means its maximum; "<kind>_max" / "<kind>_p99" are explicit.
    aliases: Dict[str, str] = {}
    for kind in ("ulp", "abs", "rel"):
        aliases[kind] = f"max_{kind}"
        aliases[f"{kind}_max"] = f"max_{kind}"
        aliases[f"{kind}_p99"] = f"p99_{kind}"

    resolved = aliases.get(lookup_key, "max_rel")
    return resolved, float(limit)


def _find_piece_by_id(spec: ApproxSpec, piece_id: str):
    """Return the first piece in ``spec.domain.pieces`` whose id matches ``piece_id``.

    A piece whose ``piece_id`` is ``None`` is identified by its position in
    the list. Ids are compared as strings. Returns ``None`` when no piece
    matches.
    """
    wanted = str(piece_id)
    matches = (
        piece
        for position, piece in enumerate(spec.domain.pieces)
        if str(position if piece.piece_id is None else piece.piece_id) == wanted
    )
    return next(matches, None)


def _evaluate_piece_raw(
    piece,
    dag: DAG,
    x_raw: np.ndarray,
    output_format: str,
) -> np.ndarray:
    """Evaluate ``dag`` on raw inputs through the piece's I/O transform.

    The first column of ``x_raw`` is mapped with the piece's input affine
    transform, cast to ``dag.np_type`` and fed to ``dag.numpy_eval``. The
    flattened result is mapped back with the output affine transform and
    quantized to ``output_format``.

    Returns:
        A float64 array of quantized outputs.

    Raises:
        ValueError: If the DAG returns neither one value per input nor a
            single value that can be broadcast across all inputs.
    """
    transforms = parse_io_transform(piece.transform)
    to_model = transforms["input"]
    from_model = transforms["output"]

    model_inputs = np.asarray(
        apply_affine(x_raw[:, 0], to_model).reshape(-1, 1),
        dtype=dag.np_type,
    )
    n_points = model_inputs.shape[0]

    model_outputs = np.asarray(dag.numpy_eval(model_inputs), dtype=np.float64).reshape(-1)
    if model_outputs.size == 1 and n_points > 1:
        # A single output for several inputs is repeated for every input.
        model_outputs = np.full(n_points, float(model_outputs[0]), dtype=np.float64)
    elif model_outputs.size != n_points:
        raise ValueError(
            f"DAG output size mismatch: got {model_outputs.size}, expected {n_points}."
        )

    raw_outputs = apply_affine(model_outputs, from_model)
    return np.asarray(quantize_array(raw_outputs, output_format), dtype=np.float64)


def _evaluate_mapped_piece_raw(
    spec: ApproxSpec,
    piece,
    dag: DAG,
    x_raw: np.ndarray,
    output_format: str,
) -> Tuple[np.ndarray, List[str]]:
    """Evaluate a "mapped" piece by reusing another piece's evaluation.

    The piece's strategy names a ``source_piece_id`` plus affine
    ``input_transform`` / ``output_transform`` entries. Raw inputs are mapped
    into the source piece's domain, evaluated there with
    ``_evaluate_piece_raw``, mapped back with the output transform and
    quantized to ``output_format``.

    Returns:
        ``(outputs, failure_modes)``. ``failure_modes`` contains
        ``"mapped_out_of_source_interval"`` when any mapped input lies outside
        the source piece's interval; evaluation proceeds regardless.

    Raises:
        ValueError: If ``source_piece_id`` is missing from the strategy or
            does not match any piece in ``spec``.
    """
    issues: List[str] = []
    plan = piece.strategy or {}

    source_id = plan.get("source_piece_id")
    if source_id is None:
        raise ValueError("mapped strategy requires source_piece_id.")
    source = _find_piece_by_id(spec, str(source_id))
    if source is None:
        raise ValueError(f"mapped source_piece_id '{source_id}' not found.")

    to_source = parse_affine_transform(
        plan.get("input_transform"),
        field_name=f"strategy.input_transform[{piece.piece_id}]",
    )
    from_source = parse_affine_transform(
        plan.get("output_transform"),
        field_name=f"strategy.output_transform[{piece.piece_id}]",
    )

    source_inputs = apply_affine(x_raw[:, 0], to_source)
    bounds = source.interval
    below = source_inputs < float(bounds.start)
    above = source_inputs > float(bounds.end)
    if np.any(below | above):
        issues.append("mapped_out_of_source_interval")

    source_outputs = _evaluate_piece_raw(
        source,
        dag,
        source_inputs.reshape(-1, 1),
        output_format=output_format,
    )
    mapped_outputs = apply_affine(source_outputs, from_source)
    quantized = np.asarray(quantize_array(mapped_outputs, output_format), dtype=np.float64)
    return quantized, issues


def _evaluate_reference(
    spec: ApproxSpec,
    x_vals: np.ndarray,
    level: int,
    dps: int,
) -> np.ndarray:
    """Evaluate the target function at ``x_vals`` and return float64 references.

    For ``level >= 2`` the mpmath evaluator from ``_build_mpmath_eval`` is
    built and called with mpmath working at ``max(dps, 50)`` decimal digits.
    The raised precision is scoped with ``mpmath.workdps`` so the global
    mpmath context is restored on exit, including when evaluation raises.
    For lower levels the evaluator from ``build_fun_eval`` is used.
    """
    if level >= 2:
        with mpmath.workdps(max(dps, 50)):
            mp_eval = _build_mpmath_eval(spec)
            return np.array(
                [float(mp_eval(float(x))) for x in x_vals],
                dtype=np.float64,
            )
    fun_eval = build_fun_eval(spec.target, spec.precision_model)
    return np.array([fun_eval(float(x)) for x in x_vals], dtype=np.float64)


def _finite_or_none(value) -> Optional[float]:
    """Return ``value`` as a float, or ``None`` when it is NaN or infinite."""
    return float(value) if np.isfinite(value) else None


def _rank_counterexamples(
    x_vals: np.ndarray,
    ref_true: np.ndarray,
    ref_quantized: np.ndarray,
    pred_final: np.ndarray,
    rank_metric: np.ndarray,
    limit: int = 5,
) -> List[Dict[str, Any]]:
    """Describe up to ``limit`` sample points with the largest ``rank_metric``.

    Entries are ordered from worst to best. Non-finite values are reported as
    ``None`` (except ``x``, which is always converted with ``float``).
    """
    top_k = min(limit, rank_metric.size)
    if top_k <= 0:
        return []
    worst_first = np.argsort(rank_metric)[-top_k:][::-1]
    return [
        {
            "x": float(x_vals[idx]),
            "expected": _finite_or_none(ref_true[idx]),
            "expected_quantized": _finite_or_none(ref_quantized[idx]),
            "actual": _finite_or_none(pred_final[idx]),
            "metric": _finite_or_none(rank_metric[idx]),
        }
        for idx in worst_first
    ]


def evaluate_candidate(
    spec_path: str,
    candidate: Dict[str, Any] | str,
    piece_id: Optional[str] = None,
    level: int = 1,
    max_points: Optional[int] = None,
    dps: int = 80,
) -> Dict[str, Any]:
    """Evaluate a candidate DAG on one piece of a spec and report error metrics.

    The piece is sampled, inputs are quantized to the spec's input format, the
    candidate is evaluated (directly, or through another piece when the
    piece's strategy mode is ``"mapped"``), and the quantized predictions are
    compared against reference values using absolute, relative and ULP error.

    Args:
        spec_path: Path handed to ``load_spec``.
        candidate: Candidate DAG as a dictionary, or a path to a JSON file.
        piece_id: Piece selector handed to ``select_piece``.
        level: Verification level. ``level >= 2`` uses at least 2048 sample
            points and an mpmath reference; lower levels use
            ``build_fun_eval``.
        max_points: Optional upper bound on the number of requested samples.
        dps: Decimal digits for the mpmath reference (at least 50 are used).
            Only relevant for ``level >= 2``. The precision is applied only
            while the reference is evaluated and restored afterwards, so
            calling this function does not alter the global mpmath precision.

    Returns:
        A dict with keys ``pass``, ``level``, ``metrics``, ``metric_name``,
        ``threshold``, ``counterexamples`` (up to five worst points),
        ``failure_modes``, ``candidate_path`` and ``sample_count`` (the
        requested sample count).

    Raises:
        ValueError: If evaluation of a mapped piece fails for any reason; the
            underlying exception is chained as the cause.
    """
    spec = load_spec(spec_path)
    piece = select_piece(spec, piece_id)
    sampling = spec.sampling
    precision_info = normalize_precision_model(spec.precision_model)
    input_format = precision_info["input_format"]
    output_format = precision_info["output_format"]

    n_data = max(sampling.n_data, 2048) if level >= 2 else sampling.n_data
    if max_points is not None:
        n_data = min(max_points, n_data)

    # Work on a copy so the loaded spec's sampling settings stay untouched.
    sampling_override = sampling.model_copy(deep=True)
    sampling_override.n_data = int(n_data)

    candidate_data, candidate_path = _load_candidate(candidate)
    dag = _dag_from_json(candidate_data)
    dag.gen_code()
    raw_points = sample_piece_points(piece, sampling_override)
    x_vals = np.asarray(quantize_array(raw_points.reshape(-1), input_format), dtype=np.float64)
    x_raw = x_vals.reshape(-1, 1).astype(np.float64)

    strategy = piece.strategy or {}
    mode = str(strategy.get("mode", "search")).lower() if isinstance(strategy, dict) else "search"
    mapped_failure_modes: List[str] = []
    if mode == "mapped":
        try:
            pred, mapped_failure_modes = _evaluate_mapped_piece_raw(
                spec,
                piece,
                dag,
                x_raw,
                output_format=output_format,
            )
        except Exception as exc:
            raise ValueError(f"mapped evaluation failed for piece_id={piece.piece_id}: {exc}") from exc
    else:
        pred = _evaluate_piece_raw(piece, dag, x_raw, output_format=output_format)

    ref_true = _evaluate_reference(spec, x_vals, level, dps)

    pred_final = np.asarray(quantize_array(pred, output_format), dtype=np.float64)
    ref_quantized = np.asarray(quantize_array(ref_true, output_format), dtype=np.float64)

    # Points with a non-finite prediction or reference keep an infinite error.
    valid = np.isfinite(pred_final) & np.isfinite(ref_true)
    abs_err = np.full_like(pred_final, np.inf, dtype=np.float64)
    rel_err = np.full_like(pred_final, np.inf, dtype=np.float64)
    ulp_err = np.full_like(pred_final, np.inf, dtype=np.float64)

    if np.any(valid):
        abs_err[valid] = np.abs(pred_final[valid] - ref_true[valid])
        denom = np.abs(ref_true[valid]) + float(spec.metric.denom_eps)
        rel_err[valid] = abs_err[valid] / denom
        # ULP distance is measured against the reference quantized to the output format.
        ulp_err[valid] = ulp_distance(pred_final[valid], ref_quantized[valid], output_format)

    metrics = _metrics_from_errors(abs_err, rel_err, ulp_err)
    metric_name, threshold = _select_metric(metrics, spec, level)
    metric_value = metrics.get(metric_name, float("inf"))
    passed = bool(metric_value <= threshold)

    interval = piece.interval
    # "Edge" means within 1% of the interval width from either endpoint.
    edge_margin = 0.01 * (interval.end - interval.start)
    worst_idx = int(np.nanargmax(abs_err)) if abs_err.size else None
    failure_modes: List[str] = []
    if np.any(~np.isfinite(pred_final)):
        failure_modes.append("nan_or_inf_output")
    if np.any(~np.isfinite(ref_true)):
        failure_modes.append("nan_or_inf_reference")
    failure_modes.extend(mapped_failure_modes)
    if worst_idx is not None:
        if (
            abs(x_vals[worst_idx] - interval.start) <= edge_margin
            or abs(x_vals[worst_idx] - interval.end) <= edge_margin
        ):
            failure_modes.append("edge_spike")

    # Rank counterexamples by the same error family as the pass/fail metric.
    if "ulp" in metric_name:
        rank_metric = ulp_err
    elif "abs" in metric_name:
        rank_metric = abs_err
    else:
        rank_metric = rel_err

    counterexamples = _rank_counterexamples(
        x_vals,
        ref_true,
        ref_quantized,
        pred_final,
        rank_metric,
    )

    return {
        "pass": passed,
        "level": level,
        "metrics": metrics,
        "metric_name": metric_name,
        "threshold": threshold,
        "counterexamples": counterexamples,
        "failure_modes": failure_modes,
        "candidate_path": candidate_path,
        "sample_count": int(n_data),
    }
