"""
Neutral Evaluator for Fair Comparison of Certification Methods

This module implements a neutral evaluator that provides an apples-to-apples
comparison between different certification methods by evaluating their proposed
radii against a common empirical success criterion.
"""

import numpy as np
from scipy.stats import beta
from typing import Callable, Tuple
from .functions import regression_function


def estimate_g_mc(model_or_fn: Callable, z: np.ndarray, sigma: float, N_samples: int, seed: int = 42) -> float:
    """
    Monte Carlo estimate of the smoothed function g(z) = E[f(z + eta)].

    The noise eta is drawn i.i.d. from N(0, sigma^2 I) with a generator seeded
    by ``seed``, so identical arguments always give an identical estimate.

    Args:
        model_or_fn: Scalar function, called as ``model_or_fn(x1, x2)``.
        z: Input point (1-D array). The noise dimension is ``z.shape[-1]``;
            only the first two coordinates of each noisy point are passed on.
        sigma: Noise standard deviation.
        N_samples: Number of Monte Carlo samples.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        Sample mean of ``model_or_fn`` over the noisy copies of ``z``.
    """
    generator = np.random.default_rng(seed)
    noise = generator.normal(0.0, sigma, size=(N_samples, z.shape[-1]))
    # Row i of `noisy_points` is exactly z + noise[i].
    noisy_points = z + noise
    outputs = np.array([model_or_fn(pt[0], pt[1]) for pt in noisy_points])
    return float(np.mean(outputs))


def eval_expectation_harm(model_or_fn: Callable, z: np.ndarray, r: float, eps_y: float,
                          sigma: float, N_g_samples: int, seed: int = 42) -> float:
    """
    Estimate the worst-case harm |g(z+delta) - g(z)| over perturbations of norm r,
    where g(x) = E[f(x+eta)] and eta ~ N(0, sigma^2 I).

    The candidate worst-case directions are +/- a forward finite-difference
    estimate of grad g(z); the larger of the two harms is returned. This is a
    first-order heuristic, not a certified maximum over the whole sphere.

    Every Monte Carlo estimate of g below uses the same ``seed`` (common random
    numbers), so differences such as g(z + h) - g(z) are taken over identical
    noise draws. With independent draws, the MC error of each estimate divided by
    the 1e-4 finite-difference step would swamp the gradient, making the
    "worst-case" direction essentially random.

    Args:
        model_or_fn: The regression function, called as ``model_or_fn(x1, x2)``.
        z: Input point (2D array).
        r: Trial radius.
        eps_y: Output tolerance (not used in the harm calculation; kept for a
            uniform signature with the other evaluators).
        sigma: Noise standard deviation.
        N_g_samples: Number of Monte Carlo samples for each estimate of g.
        seed: Random seed shared by all estimates of g.

    Returns:
        The estimated worst-case harm (0.0 when r <= 0).
    """
    if r <= 0:
        return 0.0

    eps_grad = 1e-4  # finite-difference step for grad g

    def g_hat(x: np.ndarray) -> float:
        # Same seed at every evaluation point -> common random numbers.
        return estimate_g_mc(model_or_fn, x, sigma, N_g_samples, seed=seed)

    g_z = g_hat(z)
    grad_g = np.array([
        g_hat(z + np.array([eps_grad, 0.0])) - g_z,
        g_hat(z + np.array([0.0, eps_grad])) - g_z,
    ]) / eps_grad

    grad_norm = float(np.linalg.norm(grad_g))
    if grad_norm < 1e-10:
        # (Numerically) stationary point: the gradient gives no direction, but
        # the harm at radius r need not be zero, so probe a fixed axis instead
        # of reporting zero harm.
        direction = np.array([1.0, 0.0])
    else:
        direction = grad_g / grad_norm

    # With curvature, the opposite direction can give the larger |delta g|.
    harms = [abs(g_hat(z + sign * r * direction) - g_z) for sign in (1.0, -1.0)]
    return float(max(harms))


def oracle_radius_expectation(z: np.ndarray, eps_y: float, model_or_fn: Callable,
                              sigma: float, N_g_samples: int = 10000,
                              r_max: float = 1.0, seed: int = 42, tol: float = 1e-4) -> float:
    """
    Compute the oracle radius r*exp - the largest radius that achieves the
    expectation guarantee |g(z+delta) - g(z)| <= eps_y.

    The upper bracket starts at ``r_max`` and is doubled while the guarantee
    still holds there (stopping once it exceeds 10). A bisection on
    [0, bracket] follows, stopping when the bracket is narrower than ``tol``
    or after 30 halvings.

    Args:
        z: Input point (2D array).
        eps_y: Output tolerance.
        model_or_fn: The regression function.
        sigma: Noise standard deviation.
        N_g_samples: Number of Monte Carlo samples for estimating g.
        r_max: Initial upper bracket for the search; must be > 0.
        seed: Random seed for reproducibility.
        tol: Bracket width at which the bisection stops.

    Returns:
        The oracle radius r*exp (the largest radius verified to satisfy the bound).

    Raises:
        ValueError: If ``r_max`` is not strictly positive. A non-positive
            bracket can never be doubled past the cap, so the search would
            loop forever.
    """
    if not r_max > 0:  # also rejects NaN
        raise ValueError(f"r_max must be positive, got {r_max!r}")

    r_low = 0.0
    r_high = r_max

    # Check if r=0 already fails (only possible for a negative eps_y, since the
    # harm at radius 0 is 0).
    if eval_expectation_harm(model_or_fn, z, 0.0, eps_y, sigma, N_g_samples, seed) > eps_y:
        return 0.0

    # Grow the bracket while the guarantee still holds at r_high, using the
    # same acceptance test (<=) as the bisection below.
    while eval_expectation_harm(model_or_fn, z, r_high, eps_y, sigma, N_g_samples, seed) <= eps_y:
        r_high *= 2.0
        if r_high > 10.0:  # Cap search to prevent unbounded growth for very large radii
            break

    # Bisection: invariant is "r_low satisfies the bound".
    for _ in range(30):  # 30 halvings of a bracket <= 20 give < 2e-8 width; tol usually stops earlier
        if (r_high - r_low) < tol:
            break
        r_mid = 0.5 * (r_low + r_high)

        harm = eval_expectation_harm(model_or_fn, z, r_mid, eps_y, sigma, N_g_samples, seed)

        if harm <= eps_y:
            r_low = r_mid  # r_mid is valid, try larger
        else:
            r_high = r_mid  # r_mid is too large, try smaller

    return r_low


def check_expectation_soundness(r_cert: float, z: np.ndarray, eps_y: float, model_or_fn: Callable,
                                sigma: float, N_g_samples: int = 10000, seed: int = 42) -> bool:
    """
    Decide whether a certified radius satisfies the expectation guarantee.

    A radius passes when the estimated worst-case harm
    |g(z+delta) - g(z)| at that radius does not exceed ``eps_y``.

    Args:
        r_cert: Radius claimed by the certification method.
        z: Input point (2D array).
        eps_y: Output tolerance.
        model_or_fn: The regression function.
        sigma: Noise standard deviation.
        N_g_samples: Number of Monte Carlo samples for estimating g.
        seed: Random seed for reproducibility.

    Returns:
        True if ``r_cert`` passes (non-positive radii always pass), False otherwise.
    """
    # A non-positive radius claims nothing, so it is trivially sound.
    if r_cert <= 0:
        return True

    worst_harm = eval_expectation_harm(
        model_or_fn, z,
        r=r_cert, eps_y=eps_y, sigma=sigma,
        N_g_samples=N_g_samples, seed=seed,
    )
    within_tolerance = worst_harm <= eps_y
    return within_tolerance


def cp_lower_bound(k: int, n: int, alpha: float = 0.001) -> float:
    """
    Clopper-Pearson (exact binomial) lower confidence limit.

    Returns the lower endpoint of the two-sided Clopper-Pearson interval at
    level 1 - alpha, i.e. the alpha/2 quantile of Beta(k, n - k + 1).

    Args:
        k: Number of successes.
        n: Number of trials.
        alpha: Significance level of the interval (unrelated to trimming).

    Returns:
        Lower bound on the success probability (0.0 when k == 0).
    """
    if k == 0:
        # Zero successes: the lower limit is exactly 0 (Beta(0, .) is undefined).
        return 0.0

    tail_prob = alpha / 2  # two-sided interval -> alpha/2 in the lower tail
    a_shape = k
    b_shape = n - k + 1
    return float(beta.ppf(tail_prob, a_shape, b_shape))


def eval_pA_LB_random_dirs(model_or_fn: Callable, z: np.ndarray, r: float, center: str, 
                          eps_y: float, sigma: float, N_eval: int, seed: int = 42,
                          alpha_trim: float = 0.35, cp_alpha: float = 1e-3, 
                          n_dirs: int = 10) -> float:
    """
    Evaluate the lower bound of success probability p_A for a given radius using random directions.

    For each of ``n_dirs`` directions drawn uniformly on the unit circle, the
    perturbation is delta = r * direction. Success for a noise draw
    eta ~ N(0, sigma^2 I) means |f(z + delta + eta) - target| <= eps_y. The
    minimum Clopper-Pearson lower bound over the directions is returned. This
    measures average-direction rather than worst-case robustness.

    Args:
        model_or_fn: The regression function (e.g., regression_function).
        z: Input point (2D array).
        r: Trial radius.
        center: Center choice.
            'y' uses f(z).
            'g' uses a Monte Carlo mean of f(z + eta).
            'g_alpha' uses a trimmed mean dropping floor(alpha_trim * n) values from each tail.
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples per direction. The center estimate
            uses an additional max(1, min(1000, N_eval // 10)) draws.
        seed: Random seed.
        alpha_trim: Per-tail trimming fraction (when center='g_alpha').
        cp_alpha: Alpha parameter for the Clopper-Pearson lower bound.
        n_dirs: Number of random directions to test.

    Returns:
        Minimum over directions of the lower bound on p_A (1.0 when r <= 0).

    Raises:
        ValueError: If ``center`` is not 'y', 'g' or 'g_alpha'.
    """
    if r <= 0:
        return 1.0  # Zero radius always succeeds

    rng = np.random.default_rng(seed)

    # Use at least one draw for the center: with N_eval < 10, N_eval // 10 is 0,
    # np.mean([]) is NaN and every |f - NaN| <= eps_y test would fail silently.
    n_pred = max(1, min(1000, N_eval // 10))

    # Determine the target center
    if center == 'y':
        target = float(model_or_fn(z[0], z[1]))
    elif center in ('g', 'g_alpha'):
        xi_pred = rng.normal(0.0, sigma, size=(n_pred, z.shape[-1]))
        vals = np.array([model_or_fn(p[0], p[1]) for p in z + xi_pred])
        if center == 'g':
            target = float(np.mean(vals))
        else:
            n_vals = len(vals)
            k_trim = int(np.floor(alpha_trim * n_vals))  # trimmed from EACH tail
            if k_trim <= 0:
                target = float(np.mean(vals))
            else:
                vals_sorted = np.sort(vals)
                kept = vals_sorted[k_trim:n_vals - k_trim] if (n_vals - 2 * k_trim) > 0 else vals_sorted
                target = float(np.mean(kept))
    else:
        raise ValueError(f"Unknown center: {center}")

    # Test multiple random directions and keep the minimum (most challenging) bound.
    min_pA_lb = 1.0
    for _ in range(n_dirs):
        # Random (not adversarial) direction on the unit circle.
        angle = rng.uniform(0, 2 * np.pi)
        delta = r * np.array([np.cos(angle), np.sin(angle)])

        xi_eval = rng.normal(0.0, sigma, size=(N_eval, z.shape[-1]))
        # Row i is z + delta + xi_eval[i]; computed once per sample.
        perturbed = z + delta + xi_eval
        success_count = sum(
            1 for p in perturbed if np.abs(model_or_fn(p[0], p[1]) - target) <= eps_y
        )

        pA_lb = cp_lower_bound(success_count, N_eval, cp_alpha)
        min_pA_lb = min(min_pA_lb, pA_lb)

    return min_pA_lb


def eval_pA_LB(model_or_fn: Callable, z: np.ndarray, r: float, center: str, 
               eps_y: float, sigma: float, N_eval: int, seed: int = 42,
               alpha_trim: float = 0.35, cp_alpha: float = 1e-3) -> float:
    """
    Evaluate the lower bound of success probability p_A for a given radius,
    with the perturbation placed along the gradient direction of f at z.

    Success for a noise draw eta ~ N(0, sigma^2 I) means
    |f(z + delta + eta) - target| <= eps_y, where
    delta = r * grad f(z) / ||grad f(z)||. The gradient of f (not of the
    smoothed g) is estimated by central differences. It is the locally
    steepest direction, a heuristic stand-in for the true worst case. When the
    gradient is numerically zero, delta points along +x1.

    Args:
        model_or_fn: The regression function (e.g., regression_function).
        z: Input point (2D array).
        r: Trial radius.
        center: Center choice.
            'y' uses f(z).
            'g' uses a Monte Carlo mean of f(z + eta).
            'g_alpha' uses a trimmed mean dropping floor(alpha_trim * n) values from each tail.
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples for the success count. The center
            estimate uses an additional max(1, min(1000, N_eval // 10)) draws.
        seed: Random seed for reproducibility.
        alpha_trim: Per-tail trimming fraction (when center='g_alpha').
        cp_alpha: Alpha parameter for the Clopper-Pearson lower bound.

    Returns:
        Clopper-Pearson lower bound on the success probability.

    Raises:
        ValueError: If ``center`` is not 'y', 'g' or 'g_alpha'.
    """
    rng = np.random.default_rng(seed)

    # Use at least one draw for the center: with N_eval < 10, N_eval // 10 is 0,
    # np.mean([]) is NaN and every |f - NaN| <= eps_y test would fail silently.
    n_pred = max(1, min(1000, N_eval // 10))

    # Determine the target center
    if center == 'y':
        # Ground truth center
        target = float(model_or_fn(z[0], z[1]))
    elif center in ('g', 'g_alpha'):
        # Monte Carlo center estimated from a separate batch of noise draws
        xi_pred = rng.normal(0.0, sigma, size=(n_pred, z.shape[-1]))
        vals = np.array([model_or_fn(p[0], p[1]) for p in z + xi_pred])
        if center == 'g':
            target = float(np.mean(vals))
        else:
            # Trimmed mean - trim alpha_trim on each side (total trim = 2*alpha_trim)
            n_vals = len(vals)
            k_trim = int(np.floor(alpha_trim * n_vals))
            if k_trim <= 0:
                target = float(np.mean(vals))
            else:
                vals_sorted = np.sort(vals)
                kept = vals_sorted[k_trim:n_vals - k_trim] if (n_vals - 2 * k_trim) > 0 else vals_sorted
                target = float(np.mean(kept))
    else:
        raise ValueError(f"Unknown center: {center}")

    # Central-difference gradient of f at z (locally steepest direction).
    eps_grad = 1e-4
    grad = np.array([
        (model_or_fn(z[0] + eps_grad, z[1]) - model_or_fn(z[0] - eps_grad, z[1])) / (2 * eps_grad),
        (model_or_fn(z[0], z[1] + eps_grad) - model_or_fn(z[0], z[1] - eps_grad)) / (2 * eps_grad),
    ])
    grad_norm = np.linalg.norm(grad)
    if grad_norm > 1e-12:
        delta_direction = grad / grad_norm
    else:
        # Fallback to canonical direction if gradient is too small
        delta_direction = np.array([1.0, 0.0])

    delta = r * delta_direction

    # Count successes |f(z + delta + eta) - target| <= eps_y.
    success_count = 0
    for _ in range(N_eval):
        eta = rng.normal(0.0, sigma, size=2)
        z_perturbed = z + delta + eta
        f_perturbed = float(model_or_fn(z_perturbed[0], z_perturbed[1]))
        if abs(f_perturbed - target) <= eps_y:
            success_count += 1

    return cp_lower_bound(success_count, N_eval, cp_alpha)


def eval_pA_LB_worstdir(model_or_fn: Callable, z: np.ndarray, r: float, center: str, 
                       eps_y: float, sigma: float, N_eval: int, M_dirs: int = 16, 
                       seed: int = 42, alpha_trim: float = 0.35, cp_alpha: float = 1e-3) -> float:
    """
    Evaluate the minimum success-probability lower bound over a fixed direction schedule.

    The directions are deterministic, not sampled:
        m = 0: +grad f(z) / ||grad f(z)|| (or +x1 if the gradient is ~0)
        m = 1: -grad f(z) / ||grad f(z)|| (or -x1 if the gradient is ~0)
        m >= 2: cycles through e1, e2, (e1+e2)/sqrt(2), (e1-e2)/sqrt(2)
    With M_dirs > 6 the canonical directions repeat, each time with fresh noise.
    Each direction gets N_eval // M_dirs noise draws. A direction with zero
    draws contributes a lower bound of 0.

    Args:
        model_or_fn: The regression function.
        z: Input point (2D array).
        r: Trial radius.
        center: Center choice.
            'y' uses f(z).
            'g' uses a Monte Carlo mean of f(z + eta).
            'g_alpha' uses a trimmed mean dropping floor(alpha_trim * n) values from each tail.
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Total evaluation budget, split evenly across directions. The
            center estimate uses an additional max(1, min(1000, N_eval // 10)) draws.
        M_dirs: Number of directions in the schedule above.
        seed: Random seed.
        alpha_trim: Per-tail trimming fraction (when center='g_alpha').
        cp_alpha: Alpha parameter for the Clopper-Pearson lower bound.

    Returns:
        Minimum CP lower bound over all directions (1.0 if M_dirs <= 0).

    Raises:
        ValueError: If ``center`` is not 'y', 'g' or 'g_alpha'.
    """
    rng = np.random.default_rng(seed)

    # Use at least one draw for the center: with N_eval < 10, N_eval // 10 is 0,
    # np.mean([]) is NaN and every |f - NaN| <= eps_y test would fail silently.
    n_pred = max(1, min(1000, N_eval // 10))

    # Determine the target center (same logic as eval_pA_LB)
    if center == 'y':
        target = float(model_or_fn(z[0], z[1]))
    elif center in ('g', 'g_alpha'):
        xi_pred = rng.normal(0.0, sigma, size=(n_pred, z.shape[-1]))
        vals = np.array([model_or_fn(p[0], p[1]) for p in z + xi_pred])
        if center == 'g':
            target = float(np.mean(vals))
        else:
            n_vals = len(vals)
            k_trim = int(np.floor(alpha_trim * n_vals))  # alpha_trim on each side
            if k_trim <= 0:
                target = float(np.mean(vals))
            else:
                vals_sorted = np.sort(vals)
                kept = vals_sorted[k_trim:n_vals - k_trim] if (n_vals - 2 * k_trim) > 0 else vals_sorted
                target = float(np.mean(kept))
    else:
        raise ValueError(f"Unknown center: {center}")

    if M_dirs <= 0:
        return 1.0  # no directions to test (vacuous bound)

    # Loop-invariant work, hoisted: gradient direction and canonical table.
    eps_grad = 1e-4
    grad = np.array([
        (model_or_fn(z[0] + eps_grad, z[1]) - model_or_fn(z[0] - eps_grad, z[1])) / (2 * eps_grad),
        (model_or_fn(z[0], z[1] + eps_grad) - model_or_fn(z[0], z[1] - eps_grad)) / (2 * eps_grad),
    ])
    grad_norm = np.linalg.norm(grad)
    if grad_norm > 1e-12:
        dir_pos = grad / grad_norm
        dir_neg = -grad / grad_norm
    else:
        dir_pos = np.array([1.0, 0.0])
        dir_neg = np.array([-1.0, 0.0])
    canonical_dirs = (
        np.array([1.0, 0.0]),
        np.array([0.0, 1.0]),
        np.array([1.0, 1.0]) / np.sqrt(2),
        np.array([1.0, -1.0]) / np.sqrt(2),
    )
    n_per_dir = N_eval // M_dirs

    worst = 1.0
    for m in range(M_dirs):
        if m == 0:
            delta_direction = dir_pos
        elif m == 1:
            delta_direction = dir_neg
        else:
            delta_direction = canonical_dirs[(m - 2) % len(canonical_dirs)]

        delta = r * delta_direction

        # Sample eta ~ N(0, sigma^2 I), count successes at z + delta + eta
        n_success = 0
        for _ in range(n_per_dir):
            eta = rng.normal(0.0, sigma, size=2)
            fval = model_or_fn(*(z + delta + eta))
            if abs(fval - target) <= eps_y:
                n_success += 1

        # Take minimum over directions
        worst = min(worst, cp_lower_bound(n_success, n_per_dir, cp_alpha))

    return worst


def oracle_radius(z: np.ndarray, P: float, model_or_fn: Callable, center: str, 
                  eps_y: float, sigma: float, N_eval: int = 100000, 
                  r_max: float = 1.0, seed: int = 42, alpha_trim: float = 0.35, 
                  cp_alpha: float = 1e-3) -> float:
    """
    Compute the oracle radius r*(P) - the largest radius that achieves target probability P.

    This is the method-independent maximum radius under the neutral evaluator's
    success criterion (``eval_pA_LB`` lower bound >= P). The upper bracket
    starts at ``r_max`` and is doubled while the criterion still holds
    (stopping once it exceeds 10). It is then refined by 25 bisection steps.

    Args:
        z: Input point (2D array).
        P: Target probability.
        model_or_fn: The regression function.
        center: Center choice ('y', 'g', or 'g_alpha').
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples.
        r_max: Initial upper bracket for the search; must be > 0.
        seed: Random seed.
        alpha_trim: Alpha parameter for trimmed mean (when center='g_alpha').
        cp_alpha: Alpha parameter for Clopper-Pearson lower bound.

    Returns:
        Oracle radius r*(P): the largest bisection point verified to satisfy the
        criterion (0.0 if none did).

    Raises:
        ValueError: If ``r_max`` is not strictly positive. A non-positive
            bracket can never be doubled past the cap, so the search would
            loop forever.
    """
    if not r_max > 0:  # also rejects NaN
        raise ValueError(f"r_max must be positive, got {r_max!r}")

    lo, hi = 0.0, r_max

    # First, ensure hi is large enough (criterion fails at hi, or cap reached)
    while eval_pA_LB(model_or_fn, z, hi, center, eps_y, sigma, N_eval, seed, alpha_trim, cp_alpha) >= P:
        hi *= 2.0
        if hi > 10.0:  # Safety cap
            break

    # Bisection; invariant: `lo` satisfies the criterion.
    for _ in range(25):
        mid = 0.5 * (lo + hi)
        pA_lb = eval_pA_LB(model_or_fn, z, mid, center, eps_y, sigma, N_eval, seed, alpha_trim, cp_alpha)

        if pA_lb >= P:
            lo = mid  # mid is valid, try larger
        else:
            hi = mid  # mid is too large, try smaller

    return lo


def neutral_radius_at_P(get_cert_r: Callable, z: np.ndarray, P: float, 
                       model_or_fn: Callable, center: str, eps_y: float, 
                       sigma: float, N_eval: int = 100000, seed: int = 42,
                       alpha_trim: float = 0.35, cp_alpha: float = 1e-3,
                       use_worst_case: bool = False, M_dirs: int = 16) -> float:
    """
    Find the neutral radius that achieves target probability P.

    The method's proposed radius r0 = get_cert_r(z) is returned unchanged if the
    neutral evaluator's lower bound at r0 is already >= P. Otherwise the radius
    is shrunk by 50 bisection steps on [0, r0]. The largest radius that was
    actually verified to pass is returned, or 0.0 if no trial radius passed.

    Args:
        get_cert_r: Function that returns the certified radius for point z.
        z: Input point.
        P: Target success probability.
        model_or_fn: Regression function.
        center: Center choice for evaluation ('y', 'g', or 'g_alpha').
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples.
        seed: Random seed.
        alpha_trim: Trimming fraction for center='g_alpha'.
        cp_alpha: Alpha parameter for the Clopper-Pearson lower bound.
        use_worst_case: If True, evaluate with ``eval_pA_LB_worstdir``;
            otherwise with ``eval_pA_LB`` (gradient direction only).
        M_dirs: Number of directions when ``use_worst_case`` is True.

    Returns:
        Adjusted radius that achieves target probability P.
    """
    # Get the method's proposed radius
    r0 = get_cert_r(z)

    # Choose evaluation function
    if use_worst_case:
        def eval_func(r):
            return eval_pA_LB_worstdir(model_or_fn, z, r, center, eps_y, sigma, N_eval,
                                       M_dirs, seed, alpha_trim, cp_alpha)
    else:
        def eval_func(r):
            return eval_pA_LB(model_or_fn, z, r, center, eps_y, sigma, N_eval,
                              seed, alpha_trim, cp_alpha)

    # Check if the proposed radius already meets the target
    if eval_func(r0) >= P:
        return r0

    # Bisection; invariant: `lo` passed (or is 0), `hi` failed (initially r0).
    lo, hi = 0.0, r0
    for _ in range(50):
        mid = 0.5 * (lo + hi)
        if eval_func(mid) >= P:
            lo = mid  # mid is valid, try larger
        else:
            hi = mid  # mid is too large, try smaller

    # Return the verified side. `hi` is always a radius that FAILED, so
    # returning it would report an unverified (unsound) radius. When nothing
    # passes, it would also be a tiny positive number instead of 0.
    return lo


def evaluate_soundness_tightness(method_certifier, points: np.ndarray, P: float,
                                model_or_fn: Callable, center: str, eps_y: float, 
                                sigma: float, N_eval: int = 100000, seed: int = 42,
                                alpha_trim: float = 0.35, cp_alpha: float = 1e-3) -> dict:
    """
    Evaluate soundness and tightness metrics for a certification method.

    Metrics:
    - Soundness: how often the method's original claim passes the neutral
      evaluator without modification.
    - Tightness: how close the neutrally validated radii are to the
      method-independent oracle radii.

    Each point is certified exactly once. The same claim is used for
    ``r_cert``, for the oracle search bracket, and as the starting radius of
    the neutral validation.

    Args:
        method_certifier: The certification method to evaluate. It must provide
            ``name`` and ``certify_point(z)``. When center == 'y' and
            ``certify_point`` has a ``y`` parameter, it also receives ``y=f(z)``.
        points: Array of input points to evaluate.
        P: Target probability.
        model_or_fn: The regression function.
        center: Center choice ('y', 'g', or 'g_alpha').
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples.
        seed: Random seed.
        alpha_trim: Alpha parameter for trimmed mean.
        cp_alpha: Alpha parameter for Clopper-Pearson lower bound.

    Returns:
        Dictionary containing:
        - r_cert: Method's original certified radii
        - r_tilde: Neutral evaluator's validated radii
        - r_oracle: Oracle radii (method-independent)
        - pass_rate: Fraction where r_cert <= r_oracle
        - tightness: Median of r_tilde / r_oracle over points with r_oracle > 0
          (0.0 if there are none)
        - over_claim_rate: Fraction where r_cert > r_oracle
        - under_claim_rate: Fraction where r_cert < r_oracle
        - soundness: Fraction where |r_tilde - r_cert| < 1e-6 (claim kept as is)
        - method_name: ``method_certifier.name``
        Rates are NaN when ``points`` is empty.
    """
    import inspect  # only needed to check whether certify_point accepts `y`

    # Decide once (not per point) whether the certifier gets the ground truth.
    pass_y = center == 'y' and 'y' in inspect.signature(method_certifier.certify_point).parameters

    def certify(point):
        if pass_y:
            return method_certifier.certify_point(point, y=float(model_or_fn(point[0], point[1])))
        return method_certifier.certify_point(point)

    def rate(mask):
        # Fraction of True entries; NaN (without a NumPy warning) for no points.
        return np.mean(mask) if mask.size else np.nan

    n_points = len(points)
    r_cert = np.zeros(n_points)
    r_tilde = np.zeros(n_points)
    r_oracle = np.zeros(n_points)

    print(f"Evaluating {method_certifier.name} for soundness and tightness...")

    for i, z in enumerate(points):
        if i % 5 == 0:
            print(f"Evaluating point {i+1}/{n_points}...")

        # Certify exactly once and reuse the claim. Re-certifying inside
        # neutral_radius_at_P doubles the cost. For a randomized certifier it
        # would also validate a different claim than the one stored in r_cert,
        # corrupting the soundness metric.
        claim = certify(z)
        r_cert[i] = claim

        # Get oracle radius (method-independent)
        r_oracle[i] = oracle_radius(z, P, model_or_fn, center, eps_y, sigma,
                                    N_eval, r_max=max(1.0, r_cert[i] * 2),
                                    seed=seed, alpha_trim=alpha_trim, cp_alpha=cp_alpha)

        # Get neutral evaluator's validated radius, starting from the same claim
        r_tilde[i] = neutral_radius_at_P(
            lambda _point, r=claim: r, z, P, model_or_fn,
            center, eps_y, sigma, N_eval, seed, alpha_trim, cp_alpha
        )

    valid_oracle = r_oracle > 0

    pass_rate = rate(r_cert <= r_oracle)
    # The claim passed unmodified <=> neutral_radius_at_P returned it unchanged
    soundness = rate(np.abs(r_tilde - r_cert) < 1e-6)
    if np.any(valid_oracle):
        tightness = np.median(r_tilde[valid_oracle] / r_oracle[valid_oracle])
    else:
        tightness = 0.0
    over_claim_rate = rate(r_cert > r_oracle)
    under_claim_rate = rate(r_cert < r_oracle)

    return {
        'r_cert': r_cert,
        'r_tilde': r_tilde,
        'r_oracle': r_oracle,
        'pass_rate': pass_rate,
        'soundness': soundness,
        'tightness': tightness,
        'over_claim_rate': over_claim_rate,
        'under_claim_rate': under_claim_rate,
        'method_name': method_certifier.name
    }


def evaluate_method_neutral(method_certifier, points: np.ndarray, P: float,
                          model_or_fn: Callable, center: str, eps_y: float,
                          sigma: float, N_eval: int = 100000, seed: int = 42) -> np.ndarray:
    """
    Compute the neutral-evaluator radius of a certification method at each point.

    Each point's proposed radius comes from ``method_certifier.certify_point``.
    With center == 'y', the ground-truth ``y = f(z)`` is forwarded if that
    method has a ``y`` parameter. The radius is then validated (and shrunk if
    needed) by ``neutral_radius_at_P``.

    Args:
        method_certifier: The certifier object (e.g., MyCertifier or PaperAlphaCertifier).
        points: Array of input points to certify.
        P: Target success probability.
        model_or_fn: Regression function.
        center: Center choice for evaluation.
        eps_y: Output tolerance.
        sigma: Noise standard deviation.
        N_eval: Number of evaluation samples.
        seed: Random seed.

    Returns:
        Array of neutral radii, one per point.
    """
    import inspect

    def get_cert_r(z):
        # Only the ground-truth center may need the true label.
        if center != 'y':
            return method_certifier.certify_point(z)
        y_true = float(model_or_fn(z[0], z[1]))
        if 'y' not in inspect.signature(method_certifier.certify_point).parameters:
            # Certifiers without a `y` parameter (like MyCertifier) are called plainly
            return method_certifier.certify_point(z)
        return method_certifier.certify_point(z, y=y_true)

    total = len(points)
    radii = []
    for position, point in enumerate(points, start=1):
        print(f"Evaluating point {position}/{total}...")
        radii.append(
            neutral_radius_at_P(get_cert_r, point, P, model_or_fn, center, eps_y, sigma, N_eval, seed)
        )

    return np.array(radii)
