from __future__ import annotations
import numpy as np
from dataclasses import dataclass
from typing import Callable, Literal
from .config import ExperimentConfig

AttackNorm = Literal["l2", "linf"]

@dataclass()
class EmpiricalEvalResult:
    """Result container returned by ``empirical_evaluation``."""

    # Success-rate estimates with shape (n_trials, N_points): entry [t, i] is the
    # fraction of noisy evaluations that succeeded for point i in trial t.
    matrix: np.ndarray

def empirical_evaluation(points_radii: np.ndarray, cfg: ExperimentConfig, *,
                         p_func: Callable[[np.ndarray], float],
                         attack: AttackNorm = "l2") -> EmpiricalEvalResult:
    """Estimate per-point success rates under random perturbation plus Gaussian noise.

    For each row ``(x1, x2, r)`` of *points_radii* and each of ``cfg.n_trials``
    trials, a random perturbation (shaped by *attack*, length scaled by a uniform
    fraction of ``max(0, r)``) is added to the 2-D point ``(x1, x2)``. That
    perturbed point is then evaluated ``cfg.n_eval_samples`` times with added
    Gaussian noise of standard deviation ``cfg.sigma``. Each evaluation succeeds
    with probability ``clip(p_func(x_noisy), 0, 1)``. The recorded value is the
    fraction of successful evaluations.

    Args:
        points_radii: Array whose rows unpack as ``(x1, x2, r)``; negative ``r``
            is treated as 0.
        cfg: Experiment configuration. Reads ``seed``, ``n_trials``,
            ``n_eval_samples`` and ``sigma``.
        p_func: Maps a length-2 point to a probability. Values outside [0, 1]
            are clipped. It is called once per noisy sample.
        attack: ``"l2"``: a Gaussian direction normalised to unit l2 norm, with
            length ``r * U``. ``"linf"``: a Gaussian direction normalised to unit
            l-inf norm, with length ``r * U / sqrt(2)``. Here ``U ~ Uniform[0, 1)``.

    Returns:
        EmpiricalEvalResult whose ``matrix`` has shape ``(cfg.n_trials, N)``.
        Entry ``[trial, i]`` is the success fraction for point ``i``.

    Raises:
        ValueError: If *attack* is not ``"l2"`` or ``"linf"``, or if
            ``cfg.n_eval_samples < 1``.
    """
    # Validate up front so bad arguments are rejected even when there is
    # nothing to iterate over (empty points or zero trials).
    if attack == "l2":
        norm_ord, scale = None, 1.0
    elif attack == "linf":
        # In 2-D, ||v||_2 <= sqrt(2) * ||v||_inf. Scaling by 1/sqrt(2) therefore
        # keeps the l-inf perturbation inside the l2 ball of radius r
        # (presumably the intent -- confirm).
        norm_ord, scale = np.inf, 1 / np.sqrt(2)
    else:
        raise ValueError("attack must be 'l2' or 'linf'")
    n_samples = cfg.n_eval_samples
    if n_samples < 1:
        raise ValueError("cfg.n_eval_samples must be >= 1")

    # RNG draws happen in a fixed order: direction, radius, then per-sample noise
    # and a Bernoulli draw. Results are therefore reproducible for a given
    # cfg.seed.
    rng = np.random.default_rng(cfg.seed)
    n_points = points_radii.shape[0]
    out = np.zeros((cfg.n_trials, n_points), dtype=float)
    for idx, (x1, x2, r) in enumerate(points_radii):
        base = np.array([x1, x2], dtype=float)
        r_eff = max(0.0, r)
        for trial in range(cfg.n_trials):
            v = rng.normal(0.0, 1.0, size=2)
            # The epsilon guards against division by zero for a (measure-zero)
            # all-zero draw.
            v /= np.linalg.norm(v, ord=norm_ord) + 1e-12
            # NOTE: the length is a uniform fraction of r, which is not uniform
            # over the ball's area.
            radius = r_eff * rng.random() * scale
            perturbed = base + radius * v
            successes = 0
            for _ in range(n_samples):
                noise = rng.normal(0.0, cfg.sigma, size=2)
                p = float(np.clip(p_func(perturbed + noise), 0.0, 1.0))
                successes += rng.random() < p
            out[trial, idx] = successes / n_samples
    return EmpiricalEvalResult(matrix=out)
