from __future__ import annotations

from typing import Dict, Sequence, Tuple
import numpy as np

from .safety_margin import bound_distance_sq


def validate_safety_bound(
    distances: Sequence[float],
    delta_q_min: float,
    gamma: float,
    speeds: Sequence[float],
    L_grad: float,
    pred_err: float,
    lag_coeff: float,
    vmax: float,
    tau: float,
    approx_err: float = 0.0,
) -> Dict[str, float]:
    """Compare observed distances against the predicted safety-bound distance.

    For each speed ``s`` in ``speeds`` the predicted distance is
    ``sqrt(bound_distance_sq(delta_q_min, gamma, s, L_grad, pred_err,
    lag_coeff, vmax, tau, approx_err))``. The i-th observed distance is paired
    with the prediction for the i-th speed. It counts as a "hit" when it is
    greater than or equal to that prediction.

    Args:
        distances: Observed distances, one per sample.
        delta_q_min: Forwarded unchanged to ``bound_distance_sq``.
        gamma: Forwarded unchanged to ``bound_distance_sq``.
        speeds: One speed per sample, aligned index-by-index with
            ``distances``.
        L_grad: Forwarded unchanged to ``bound_distance_sq``.
        pred_err: Forwarded unchanged to ``bound_distance_sq``.
        lag_coeff: Forwarded unchanged to ``bound_distance_sq``.
        vmax: Forwarded unchanged to ``bound_distance_sq``.
        tau: Forwarded unchanged to ``bound_distance_sq``.
        approx_err: Forwarded unchanged to ``bound_distance_sq``.

    Returns:
        A dict with these keys:

        - ``"bound/hit_rate"``: fraction of samples whose observed distance
          is >= the prediction (0.0 when there are no samples).
        - ``"bound/avg_pred"``: mean predicted distance (0.0 when empty).
        - ``"bound/avg_true"``: mean observed distance (0.0 when empty).

    Raises:
        ValueError: If ``distances`` and ``speeds`` have different lengths.
    """
    n = len(distances)
    # Samples are paired by index. On a length mismatch, zip() would silently
    # drop the unmatched tail while the hit rate is still divided by n, so
    # reject the input up front.
    if len(speeds) != n:
        raise ValueError(
            "distances and speeds must have the same length "
            f"(got {n} and {len(speeds)})"
        )

    preds = [
        np.sqrt(
            bound_distance_sq(
                delta_q_min, gamma, s, L_grad, pred_err, lag_coeff, vmax, tau, approx_err
            )
        )
        for s in speeds
    ]
    # NOTE: if bound_distance_sq ever returned a negative value, np.sqrt would
    # yield NaN. The comparison below is then False, so that sample counts as a miss.
    hits = sum(1 for d, p in zip(distances, preds) if d >= p)

    # Use len()-based emptiness checks rather than truthiness, so that numpy
    # arrays are accepted as well as lists.
    return {
        "bound/hit_rate": hits / max(1, n),
        "bound/avg_pred": float(np.mean(preds)) if n else 0.0,
        "bound/avg_true": float(np.mean(distances)) if n else 0.0,
    }


def calibration_curve(kappas: Sequence[float], failures: Sequence[int], bins: int = 10) -> Tuple[np.ndarray, np.ndarray]:
    """Bin samples by quantiles of ``kappas`` and average ``failures`` per bin.

    Bin edges are the ``bins + 1`` evenly spaced quantiles of ``kappas``,
    running from its minimum to its maximum. Every bin is half-open
    ``[lo, hi)`` except the last, which is closed ``[lo, hi]``. This ensures
    samples equal to the maximum kappa are counted.

    Args:
        kappas: Per-sample values used for binning.
        failures: Per-sample failure values (e.g. 0/1), aligned with
            ``kappas``.
        bins: Number of quantile bins.

    Returns:
        ``(xs, ys)``, each of length ``bins``. ``xs`` holds the midpoint of
        each bin's edges. ``ys`` holds the mean of ``failures`` within each
        bin, or 0.0 for an empty bin. Bins can be empty when repeated kappa
        values make two edges coincide.

    Raises:
        ValueError: If ``kappas`` is empty or does not match ``failures`` in
            shape.
    """
    kap = np.asarray(kappas, dtype=float)
    fail = np.asarray(failures, dtype=float)
    if kap.size == 0:
        raise ValueError("calibration_curve needs at least one sample")
    if fail.shape != kap.shape:
        raise ValueError(
            f"kappas and failures must have the same shape (got {kap.shape} and {fail.shape})"
        )

    edges = np.quantile(kap, np.linspace(0.0, 1.0, bins + 1))
    last = bins - 1
    xs, ys = [], []
    for i in range(bins):
        lo, hi = edges[i], edges[i + 1]
        # The final edge equals max(kap); close the last bin so those samples
        # are not dropped.
        below_hi = (kap <= hi) if i == last else (kap < hi)
        mask = (kap >= lo) & below_hi
        xs.append((lo + hi) * 0.5)
        ys.append(fail[mask].mean() if mask.any() else 0.0)
    return np.array(xs), np.array(ys)


def _demo():
    """Print the bound-validation metrics for a small hard-coded example."""
    observed = [1.0, 1.2, 0.9, 1.5]
    sample_speeds = [0.5, 0.6, 0.4, 0.7]
    metrics = validate_safety_bound(
        distances=observed,
        delta_q_min=1.0,
        gamma=0.99,
        speeds=sample_speeds,
        L_grad=1.0,
        pred_err=0.1,
        lag_coeff=1.0,
        vmax=1.0,
        tau=0.05,
        approx_err=0.0,
    )
    print(metrics)


# Running the module directly prints the demo metrics.
# NOTE: this module uses a relative import (``from .safety_margin import ...``),
# so it must be run inside its package (e.g. ``python -m <package>.<module>``),
# not as a standalone script.
if __name__ == "__main__":
    _demo()
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
                
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
     
