"""
Bounded Function Certifier with Variance and Mean Constraints (No Gradient)

This implements Algorithm 4 from the paper: "Calculating the Certified Radius (Bounded Regression)".

This certifier uses:
- Variance constraint: E[(φ*(T) - V)²] ≤ C
- Mean constraint: E[φ*(T)] = V
- Boundedness: -M ≤ φ*(t) ≤ M

Key difference from other certifiers:
- BoundedCertifier: (C, G) + M (variance + gradient + bound)
- BoundedCertifierWithMean: (E, C, G) + M (mean + variance + gradient + bound)
- BoundedCertifierVarianceMean: (V, C) + M (mean + variance + bound, NO gradient)

The optimal form is:
    φ*(t) = clip_{[-M,M]} (V + (w(t) - μ) / (2λ))

where w(t) = exp(αt/σ² - α²/(2σ²)) - 1 is the likelihood ratio of N(α, σ²) to N(0, σ²), minus one.

The Lagrange multipliers λ and μ are found by solving:
    F₁(λ, μ) = ∫ (φ*(t; λ, μ) - V)² p₀(t) dt - C = 0  (variance constraint)
    F₂(λ, μ) = ∫ φ*(t; λ, μ) p₀(t) dt - V = 0        (mean constraint)
"""

import numpy as np
import warnings
from scipy.optimize import root_scalar, root, minimize
from scipy.stats import norm
from typing import Callable, Optional
from .base import BaseCertifier

# Silence RuntimeWarning / UserWarning raised from scipy.optimize modules. The
# solvers in this file inspect `success` flags and constraint residuals
# themselves, so these warnings carry no extra information for them.
# NOTE(review): `warnings.filterwarnings` edits the process-wide filter list, so
# merely importing this module hides these warnings for every other user of
# scipy.optimize in the same interpreter — confirm that is intended.
warnings.filterwarnings('ignore', category=RuntimeWarning, module='scipy.optimize')
warnings.filterwarnings('ignore', category=UserWarning, module='scipy.optimize')


class BoundedCertifierVarianceMean(BaseCertifier):
    """
    Bounded function certifier with variance and mean constraints (no gradient constraint).
    
    This implements Algorithm 4 from the paper, which solves for the certified radius
    using only variance and mean information, without gradient information.
    """
    
    def __init__(self, *, sigma: float, M: float, eps_y: float, confidence: float = 0.999,
                 model_fn: Optional[Callable] = None, quadrature_points: int = 60,
                 r_high_init_mult: float = 5.0,
                 r_cap_mult: float = 20.0):
        """
        Store the certification settings; all arguments are keyword-only.

        Args:
            sigma: Standard deviation of the Gaussian smoothing noise. Forwarded to
                the base-class constructor (presumably stored there as
                ``self.sigma``, which the other methods read — verify in base).
            M: Absolute bound on the function's output, i.e. |f(x)| <= M.
            eps_y: Threshold ε on the allowed change in expectation (Algorithm 4).
            confidence: Confidence level used for the statistical estimates.
            model_fn: Model function to certify. If None, one must be passed to
                ``certify_point``.
            quadrature_points: Number of Gauss-Hermite quadrature nodes.
            r_high_init_mult: Initial upper end of the radius bracket, expressed
                as a multiple of sigma.
            r_cap_mult: Largest radius the bracket may be expanded to, expressed
                as a multiple of sigma.
        """
        super().__init__(sigma=sigma)

        # Problem definition.
        self.M, self.eps_y = M, eps_y
        self.confidence = confidence
        self.model_fn = model_fn

        # Numerical integration resolution.
        self.quadrature_points = quadrature_points

        # Radius-search bracketing. A fixed upper end of 5*sigma used to act as a
        # hard cap on the result; the upper end is now only a starting point that
        # is grown adaptively, up to r_cap_mult * sigma.
        self.r_cap_mult = float(r_cap_mult)
        self.r_high_init_mult = float(r_high_init_mult)

        self.name = "Bounded Function Certificate (Variance + Mean, No Gradient)"
    
    # === Part 0: 1-D Gaussian expectation helper ===
    
    def _gauss_hermite_expectation(self, func) -> float:
        """
        Numerically approximates E[f(T)] for T ~ N(0, σ²) using Gauss–Hermite quadrature.
        """
        from numpy.polynomial.hermite import hermgauss
        
        x, w = hermgauss(self.quadrature_points)
        z = np.sqrt(2.0) * x
        t = self.sigma * z
        f_vals = func(t)
        
        return float(np.sum(w * f_vals) / np.sqrt(np.pi))
    
    # === Part 1: U-statistic estimators ===
    
    def u_statistic_variance_estimator_alpha_half(self, samples: np.ndarray) -> tuple:
        """U-statistic variance estimator with α/2 confidence interval for union bound."""
        n = len(samples)
        
        if n < 2:
            return 0.0, 0.0, 0.0
        
        theta_hat = np.var(samples, ddof=1)
        mean_val = np.mean(samples)
        fourth_moment = np.mean((samples - mean_val)**4)
        asymptotic_var = max(0.0, fourth_moment - theta_hat**2)
        
        # Use α/2 for union bound (variance + mean, 2 constraints)
        alpha_total = 1 - self.confidence
        alpha_split = alpha_total / 2.0
        z_critical = norm.ppf(1 - alpha_split / 2)
        se = np.sqrt(asymptotic_var / n)
        
        lower_bound = theta_hat - z_critical * se
        upper_bound = theta_hat + z_critical * se
        
        return theta_hat, lower_bound, upper_bound
    
    def u_statistic_mean_estimator_alpha_half(self, samples: np.ndarray) -> tuple:
        """
        U-statistic mean estimator with α/2 confidence interval (for 2-way union bound).
        
        Estimates E[f(z + η)] where η ~ N(0, σ²I).
        """
        n = len(samples)
        
        if n < 2:
            return 0.0, 0.0, 0.0
        
        # Sample mean is unbiased estimator
        mean_hat = np.mean(samples)
        
        # Standard error
        se = np.std(samples, ddof=1) / np.sqrt(n)
        
        # Confidence interval with α/2 for union bound
        alpha_total = 1 - self.confidence
        alpha_split = alpha_total / 2.0
        z_critical = norm.ppf(1 - alpha_split / 2)
        
        lower_bound = mean_hat - z_critical * se
        upper_bound = mean_hat + z_critical * se
        
        return mean_hat, lower_bound, upper_bound
    
    # === Part 2: Dual optimization for worst-case φ* for fixed shift α = r ===
    
    def _solve_dual_multipliers(
        self,
        r: float,
        C: float,
        V: float,
        *,
        max_iter: int = 1000,
        tol: float = 1e-5,
        lr_lambda: float = 0.5,
        lr_mu: float = 0.5,
        lambda_init: Optional[float] = None,
        mu_init: Optional[float] = None,
        use_scipy_optimizer: bool = True,
        prefer_exact_solver: bool = True,
    ) -> tuple:
        """
        Implements the dual optimization from Algorithm 4 (WorstHarm subroutine).
        
        Given a fixed shift α = r, variance C, mean V, and bound M, this
        routine finds multipliers (λ*, μ*) such that the variance and mean
        constraints are (approximately) satisfied:
        
            E[(φ*(T) - V)²] = C    (Variance constraint, Eq. 43)
            E[φ*(T)] = V            (Mean constraint, Eq. 44)
        
        where T ~ N(0, σ²) and
        
            w(t) = exp(αt/σ² - α²/(2σ²)) - 1
            h(t) = V + (w(t) - μ) / (2λ)
            φ*(t) = clip(h(t), -M, M).
        
        Returns:
            (lambda, mu): Lagrange multipliers
        """
        # Better initialization: estimate from constraints
        if lambda_init is None:
            # Rough estimate: if no clipping, lambda ~ C / (some scale)
            lambda_init = max(0.1, C / (2 * self.M**2)) if C > 0 else 1.0
        if mu_init is None:
            # Start at zero, let optimization find it
            mu_init = 0.0
        
        # Predefine w(t) for this radius
        def w_of_t(t: np.ndarray) -> np.ndarray:
            return np.exp((r * t) / (self.sigma**2) - (r**2) / (2 * self.sigma**2)) - 1.0
        
        # If prefer_exact_solver, try exact solver first
        if prefer_exact_solver and use_scipy_optimizer:
            try:
                result = self._solve_dual_exact(r, C, V, w_of_t, lambda_init, mu_init)
                if result is not None:
                    return result
            except Exception:
                # Fall back to iterative if exact solver fails
                pass
        
        lambda_b = max(1e-6, float(lambda_init))
        mu_b = float(mu_init)
        
        # Use adaptive learning rates with momentum-like behavior
        current_lr_lambda = lr_lambda
        current_lr_mu = lr_mu
        prev_grad_lambda = 0.0
        prev_grad_mu = 0.0
        momentum = 0.3
        
        # Track best solution
        best_lambda = lambda_b
        best_mu = mu_b
        best_residual = float('inf')
        
        for iter_num in range(max_iter):
            # Construct h(t) and φ*(t) according to Algorithm 4
            def phi_star(t: np.ndarray) -> np.ndarray:
                w_t = w_of_t(t)
                h_t = V + (w_t - mu_b) / (2.0 * lambda_b)
                return np.clip(h_t, -self.M, self.M)
            
            def phi_star_minus_V_sq(t: np.ndarray) -> np.ndarray:
                phi = phi_star(t)
                return (phi - V)**2
            
            # Compute constraint integrals (moments) via Gauss–Hermite
            # F₁(λ, μ) = ∫ (φ*(t) - V)² p₀(t) dt - C = 0
            V_val = self._gauss_hermite_expectation(phi_star_minus_V_sq)
            # F₂(λ, μ) = ∫ φ*(t) p₀(t) dt - V = 0
            E_val = self._gauss_hermite_expectation(phi_star)
            
            # Gradients of the dual (constraint residuals)
            grad_lambda = C - V_val  # F₁ residual
            grad_mu = V - E_val      # F₂ residual
            
            # Track best solution
            residual_norm = np.sqrt(grad_lambda**2 + grad_mu**2)
            if residual_norm < best_residual:
                best_residual = residual_norm
                best_lambda = lambda_b
                best_mu = mu_b
            
            # Check convergence
            if abs(grad_lambda) < tol and abs(grad_mu) < tol:
                break
            
            # Update multipliers with momentum and adaptive learning rate
            # NOTE: We do DESCENT (not ascent) because g(λ,μ) is CONVEX and we want to minimize it
            lambda_b_old = lambda_b
            mu_b_old = mu_b
            
            # Momentum update
            update_lambda = current_lr_lambda * grad_lambda + momentum * prev_grad_lambda
            update_mu = current_lr_mu * grad_mu + momentum * prev_grad_mu
            
            # GRADIENT DESCENT to minimize convex g
            lambda_b = max(1e-8, lambda_b - update_lambda)
            mu_b = mu_b - update_mu
            
            prev_grad_lambda = update_lambda
            prev_grad_mu = update_mu
            
            # Adaptive learning rate: increase if making progress, decrease if stuck
            if iter_num > 5:
                progress_lambda = abs(lambda_b - lambda_b_old)
                progress_mu = abs(mu_b - mu_b_old)
                
                # If making good progress, increase learning rate slightly
                if progress_lambda > 1e-6 and progress_mu > 1e-6:
                    current_lr_lambda = min(1.0, current_lr_lambda * 1.01)
                    current_lr_mu = min(1.0, current_lr_mu * 1.01)
                # If stuck, decrease learning rate
                elif progress_lambda < 1e-10 or progress_mu < 1e-10:
                    current_lr_lambda *= 0.95
                    current_lr_mu *= 0.95
                    # Reset momentum if stuck
                    prev_grad_lambda = 0.0
                    prev_grad_mu = 0.0
        
        # Compute final residuals for checking
        def phi_star_final(t: np.ndarray) -> np.ndarray:
            w_t = w_of_t(t)
            h_t = V + (w_t - best_mu) / (2.0 * best_lambda)
            return np.clip(h_t, -self.M, self.M)
        
        def phi_star_minus_V_sq_final(t: np.ndarray) -> np.ndarray:
            phi = phi_star_final(t)
            return (phi - V)**2
        
        V_val_final = self._gauss_hermite_expectation(phi_star_minus_V_sq_final)
        E_val_final = self._gauss_hermite_expectation(phi_star_final)
        final_residual = np.sqrt((C - V_val_final)**2 + (V - E_val_final)**2)
        
        # If didn't converge, try more robust optimization methods
        if final_residual > tol and use_scipy_optimizer:
            # Try scipy.optimize.minimize with L-BFGS-B
            try:
                def objective(x):
                    """
                    Dual function g(λ,μ) for minimization.
                    Since primal is MAXIMIZATION, g is CONVEX and we minimize it directly.
                    """
                    lambda_b_opt = max(1e-8, x[0])
                    mu_b_opt = x[1]
                    
                    def phi_star_opt(t: np.ndarray) -> np.ndarray:
                        w_t = w_of_t(t)
                        h_t = V + (w_t - mu_b_opt) / (2.0 * lambda_b_opt)
                        return np.clip(h_t, -self.M, self.M)
                    
                    def integrand(t: np.ndarray) -> np.ndarray:
                        phi = phi_star_opt(t)
                        w_t = w_of_t(t)
                        return phi * w_t - lambda_b_opt * (phi - V)**2 - mu_b_opt * (phi - V)
                    
                    integral = self._gauss_hermite_expectation(integrand)
                    # g(λ,μ) = integral + λC
                    # Note: The mu*V term is already included in the expectation E[-μ(φ-V)] = -μE[φ] + μV
                    # So we don't add it again here (Option A: keep centered integrand, drop + mu*V)
                    dual_value_g = integral + lambda_b_opt * C
                    return dual_value_g
                
                def constraint_residuals(x):
                    """Constraint residuals for verification."""
                    lambda_b_opt = max(1e-8, x[0])
                    mu_b_opt = x[1]
                    
                    def phi_star_opt(t: np.ndarray) -> np.ndarray:
                        w_t = w_of_t(t)
                        h_t = V + (w_t - mu_b_opt) / (2.0 * lambda_b_opt)
                        return np.clip(h_t, -self.M, self.M)
                    
                    def phi_star_minus_V_sq_opt(t: np.ndarray) -> np.ndarray:
                        phi = phi_star_opt(t)
                        return (phi - V)**2
                    
                    V_val_opt = self._gauss_hermite_expectation(phi_star_minus_V_sq_opt)
                    E_val_opt = self._gauss_hermite_expectation(phi_star_opt)
                    
                    return np.array([
                        C - V_val_opt,  # F₁: variance constraint
                        V - E_val_opt   # F₂: mean constraint
                    ])
                
                # Use best solution as initial guess
                x0 = np.array([best_lambda, best_mu])
                
                # Try L-BFGS-B with bounds
                bounds = [(1e-8, None), (None, None)]  # lambda > 0, mu unbounded
                sol = minimize(
                    objective,
                    x0,
                    method='L-BFGS-B',
                    bounds=bounds,
                    options={'maxiter': 500, 'ftol': 1e-8, 'gtol': 1e-6}
                )
                
                if sol.success:
                    lambda_b_opt = max(1e-8, sol.x[0])
                    mu_b_opt = sol.x[1]
                    # Verify constraints are satisfied
                    residuals = constraint_residuals(sol.x)
                    residual_norm = np.linalg.norm(residuals)
                    if residual_norm < best_residual:
                        return lambda_b_opt, mu_b_opt
                
                # If L-BFGS-B didn't work well, try trust-region
                sol2 = minimize(
                    objective,
                    x0,
                    method='trust-constr',
                    bounds=bounds,
                    options={'maxiter': 500, 'gtol': 1e-6}
                )
                
                if sol2.success:
                    lambda_b_opt = max(1e-8, sol2.x[0])
                    mu_b_opt = sol2.x[1]
                    residuals = constraint_residuals(sol2.x)
                    residual_norm = np.linalg.norm(residuals)
                    if residual_norm < best_residual:
                        return lambda_b_opt, mu_b_opt
                        
            except Exception as e:
                # If optimization fails, use best solution from gradient descent
                pass
        
        # Return best solution found
        return best_lambda, best_mu
    
    def _solve_dual_exact(
        self,
        r: float,
        C: float,
        V: float,
        w_of_t: Callable,
        lambda_init: Optional[float] = None,
        mu_init: Optional[float] = None,
    ) -> Optional[tuple]:
        """
        Solve the dual optimization problem exactly using scipy.optimize.
        
        This method uses root finding on the constraint equations:
            F₁(λ, μ) = ∫ (φ*(t) - V)² p₀(t) dt - C = 0
            F₂(λ, μ) = ∫ φ*(t) p₀(t) dt - V = 0
        
        Returns:
            (lambda, mu) if successful, None otherwise
        """
        # Initialize
        if lambda_init is None:
            lambda_init = max(0.1, C / (2 * self.M**2)) if C > 0 else 1.0
        if mu_init is None:
            mu_init = 0.0
        
        x0 = np.array([lambda_init, mu_init])
        
        def constraint_residuals(x):
            """Constraint residuals for root finding."""
            lambda_b_opt = max(1e-8, x[0])
            mu_b_opt = x[1]
            
            def phi_star_opt(t: np.ndarray) -> np.ndarray:
                w_t = w_of_t(t)
                h_t = V + (w_t - mu_b_opt) / (2.0 * lambda_b_opt)
                return np.clip(h_t, -self.M, self.M)
            
            def phi_star_minus_V_sq_opt(t: np.ndarray) -> np.ndarray:
                phi = phi_star_opt(t)
                return (phi - V)**2
            
            V_val_opt = self._gauss_hermite_expectation(phi_star_minus_V_sq_opt)
            E_val_opt = self._gauss_hermite_expectation(phi_star_opt)
            
            return np.array([
                C - V_val_opt,  # F₁: variance constraint
                V - E_val_opt   # F₂: mean constraint
            ])
        
        # Try root finding - use 'broyden1' which is robust for non-smooth functions
        try:
            sol = root(
                constraint_residuals,
                x0,
                method='broyden1',
                options={'maxiter': 500}
            )
            
            if sol.success:
                lambda_b = max(1e-8, sol.x[0])
                mu_b = sol.x[1]
                # Verify constraints are satisfied
                residuals = constraint_residuals(sol.x)
                residual_norm = np.linalg.norm(residuals)
                if residual_norm < 1e-5:  # Tolerance check
                    return lambda_b, mu_b
        except Exception:
            pass
        
        # Try minimization approach as alternative
        try:
            def objective(x):
                """Minimize constraint violation."""
                residuals = constraint_residuals(x)
                return np.sum(residuals**2)
            
            bounds = [(1e-8, None), (None, None)]  # lambda > 0, mu unbounded
            sol = minimize(
                objective,
                x0,
                method='L-BFGS-B',
                bounds=bounds,
                options={'maxiter': 1000, 'ftol': 1e-10, 'gtol': 1e-8}
            )
            
            if sol.success:
                lambda_b = max(1e-8, sol.x[0])
                mu_b = sol.x[1]
                residuals = constraint_residuals(sol.x)
                residual_norm = np.linalg.norm(residuals)
                if residual_norm < 1e-5:
                    return lambda_b, mu_b
        except Exception:
            pass
        
        return None
    
    # === Part 3: Worst-case harm Δ(r) for bounded case ===
    
    def _worst_harm_bounded(self, r: float, C: float, V: float) -> float:
        """
        Implements WorstHarm(α) from Algorithm 4.
        
        For fixed radius r, solve the dual problem for (λ*, μ*) and then compute
        the worst-case change in expectation:
        
            Δ(r) = E_T[φ*(T) w(T)]
        
        where T ~ N(0, σ²), w(t) is the likelihood ratio, and φ* is the clipped
        worst-case function determined by the optimal multipliers.
        
        Args:
            r: Radius (shift magnitude α)
            C: Variance upper bound
            V: Mean value
            
        Returns:
            Worst-case harm |Δ(r)|
        """
        if r <= 0 or C <= 0:
            return 0.0
        
        # Step 1: Solve for multipliers λ*, μ*
        lambda_b, mu_b = self._solve_dual_multipliers(r, C, V)
        
        # Step 2: Define w(t) and compute worst-case change Δ(r)
        def w_of_t(t: np.ndarray) -> np.ndarray:
            return np.exp((r * t) / (self.sigma**2) - (r**2) / (2 * self.sigma**2)) - 1.0
        
        def phi_star_times_w(t: np.ndarray) -> np.ndarray:
            w_t = w_of_t(t)
            h_t = V + (w_t - mu_b) / (2.0 * lambda_b)
            phi = np.clip(h_t, -self.M, self.M)
            return phi * w_t
        
        delta = self._gauss_hermite_expectation(phi_star_times_w)
        
        # We care about the magnitude of the worst-case deviation
        return abs(delta)
    
    # === Part 4: Main certification routines ===
    
    def certify_point_from_estimates(self, C_ucb: float, V_est: float) -> float:
        """
        Implements Algorithm 4: ComputeRadius(C, V, ε, σ, M, r_high, tol).
        
        Full bounded-function certificate using variance and mean constraints only.
        This uses pre-computed high-confidence upper bounds on the variance C
        and mean estimate V, and returns the largest radius R such that
        WorstHarm(R) ≤ eps_y.
        
        Args:
            C_ucb: Upper confidence bound on variance
            V_est: Mean estimate (NOTE: This is a point estimate, not a conservative bound.
                   For formal probabilistic guarantees, consider using V_UB or maximizing
                   worst-harm over V ∈ [V_LB, V_UB]. See collaborator feedback.)
            
        Returns:
            Certified radius R
        
        Note on W(R) vs Δ(R):
            The theory defines W(R) = max_{α∈[0,R]} Δ(α), but this implementation
            uses the boundary-only shortcut: Δ(R) directly. This is correct if Δ(α)
            is monotone in α or the maximum occurs at α = R. For formal correctness
            under the stated definition, we would need to maximize over α ∈ [0, r]
            inside the bisection search.
        """
        C_ucb = float(max(0.0, C_ucb))
        V_est = float(V_est)
        
        # High-variance regime guard: theory requires C < M² - V² for λ > 0
        # If C >= M² - V², the variance constraint may be inactive and we need
        # different handling (bang-bang solution). For now, we clamp and proceed,
        # but this should be flagged.
        tiny = 1e-6
        if C_ucb >= self.M**2 - V_est**2 - tiny:
            # High-variance regime: variance constraint may be inactive
            # TODO: Implement proper high-variance regime handling (allow λ→0 or use bang-bang)
            # For now, we proceed but this may produce incorrect results
            pass
        
        if C_ucb == 0.0:
            return 0.0
        
        def worst_harm_minus_eps(r: float) -> float:
            return self._worst_harm_bounded(r, C_ucb, V_est) - self.eps_y
        
        # Root solve over radius with adaptive bracketing.
        r_low = 0.0
        f_low = worst_harm_minus_eps(r_low)
        if f_low > 0.0:
            return 0.0

        r_cap = max(0.0, self.r_cap_mult * self.sigma)
        r_high = max(0.0, self.r_high_init_mult * self.sigma)
        if r_cap > 0.0:
            r_high = min(r_high, r_cap)
        f_high = worst_harm_minus_eps(r_high)

        # Expand until we find a violating radius or hit the cap.
        while f_high <= 0.0 and r_high < r_cap:
            r_high = min(r_high * 2.0, r_cap)
            f_high = worst_harm_minus_eps(r_high)

        if f_high <= 0.0:
            # Safe up to cap
            return r_high
        
        try:
            sol = root_scalar(
                worst_harm_minus_eps,
                bracket=[r_low, r_high],
                method="brentq",
                xtol=1e-4,
                rtol=1e-4,
            )
            return max(0.0, float(sol.root))
        except (ValueError, RuntimeError):
            return 0.0
    
    def certify_point(
        self,
        z: np.ndarray,
        model_fn: Optional[Callable] = None,
        N_samples_stats: int = 10000,
        seed: Optional[int] = None,
    ) -> float:
        """
        Performs the full certification for a given point z.
        
        Args:
            z: Input point to certify
            model_fn: Model function to certify. If None, uses self.model_fn
            N_samples_stats: Number of samples for statistical estimation
            seed: Random seed
            
        Returns:
            Certified radius
        """
        if model_fn is None:
            if self.model_fn is None:
                raise ValueError("model_fn must be provided either as parameter or in constructor")
            model_fn = self.model_fn
        
        rng = np.random.default_rng(seed)
        
        # 1. Estimate statistical quantities (C, V) with high-confidence bounds
        eta_samples = rng.normal(0.0, self.sigma, size=(N_samples_stats, z.shape[-1]))
        f_values = np.array([model_fn(z + eta) for eta in eta_samples])
        
        # Use U-statistic estimators with α/2 confidence intervals (for union bound)
        _, _, C_ucb = self.u_statistic_variance_estimator_alpha_half(f_values)
        V_hat, V_lower, V_upper = self.u_statistic_mean_estimator_alpha_half(f_values)
        
        # Use mean point estimate (NOTE: For formal probabilistic guarantees, should use
        # conservative approach: maximize worst-harm over V ∈ [V_lower, V_upper] or use
        # the endpoint that produces larger harm. For comparison experiments showing
        # "gradient helps", using consistent point estimates across certifiers may be
        # acceptable, but this should be documented. See collaborator feedback.)
        V_est = V_hat
        
        # 2. Delegate to the bounded dual-based certificate
        return self.certify_point_from_estimates(C_ucb, V_est)

