#!/usr/bin/env python3
"""
Compare BoundedCertifierVarianceMean (E, C) + M vs BoundedCertifierWithMean (E, C, G) + M

This script compares:
1. BoundedCertifierVarianceMean: Uses variance + mean constraints (no gradient)
2. BoundedCertifierWithMean: Uses variance + gradient + mean constraints

The goal is to measure how much the additional gradient information improves the certified radius.

Usage:
    # Compare on MNIST rotation using pre-computed estimates
    python scripts/compare_variance_mean_vs_with_gradient.py \\
        --mode precomputed \\
        --variance_gradient mnist_rotation_full_cert_n100_20251106_033225.json \\
        --eps_y_deg 10.0 \\
        --N 10000 \\
        --trial 0
    
    # Compare on synthetic function
    python scripts/compare_variance_mean_vs_with_gradient.py \\
        --mode synthetic \\
        --function bounded_quadratic \\
        --sigma 0.5 \\
        --eps_y 0.5 \\
        --M 1.0 \\
        --n_test_points 10
"""

import json
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import sys
from datetime import datetime
from typing import Dict, List, Tuple, Optional
try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional: without it, progress-bar call sites degrade to plain
    # iteration instead of failing at import time.
    def tqdm(iterable, desc=None, **kwargs):
        """Return *iterable* unchanged; ``desc`` and other tqdm options are ignored."""
        _ = (desc, kwargs)  # accepted only for call-compatibility with tqdm
        return iterable

# Add src path
sys.path.append(str(Path(__file__).resolve().parent.parent / "src"))

from alpha_smoothing_repro.certify import BoundedCertifierVarianceMean, BoundedCertifierWithMean
from alpha_smoothing_repro.synthetic_functions import (
    bounded_quadratic,
    bounded_linear_function,
    bounded_sine_function,
    bounded_slice_function,
    create_test_points,
    compute_true_radius_analytical,
    mc_truth_for_bounded_function
)


def load_variance_gradient_estimates(json_path: str) -> Dict:
    """Load a JSON file of pre-computed variance / gradient-norm estimates.

    The file is decoded as UTF-8 (the encoding the JSON standard mandates)
    rather than the platform default, so files containing non-ASCII text
    (e.g. "σ") load identically on every OS.

    Args:
        json_path: Path to the JSON file (``str`` or ``pathlib.Path``).

    Returns:
        The parsed JSON document. Callers in this script read its
        ``parameters`` and ``samples`` entries.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def _valid_samples_for(
    samples: List[Dict], N: int, trial_idx: int
) -> List[Tuple[int, Dict]]:
    """Select samples that hold estimates for sample size ``N`` and trial ``trial_idx``.

    Args:
        samples: The ``samples`` list of an estimation file.
        N: Sample size; looked up as the string key ``str(N)`` in ``results_by_N``.
        trial_idx: Trial position within ``results_by_N[str(N)]``.

    Returns:
        ``(position, sample)`` pairs, where ``position`` is the sample's index
        in ``samples`` (not in the filtered list).
    """
    n_key = str(N)
    valid = []
    for position, sample in enumerate(samples):
        trials = (sample.get('results_by_N') or {}).get(n_key)
        if trials is None or trial_idx >= len(trials):
            continue
        valid.append((position, sample))
    return valid


def _sample_identity(sample: Dict, position: int) -> Tuple[str, object]:
    """Return a key identifying which input a sample describes.

    Used to pair the same input across two estimation files. The dataset index
    (``test_dataset_idx``, or ``image_idx`` which this script treats as
    equivalent) is preferred, then ``sample_idx``; when none is recorded, the
    sample's position in its file's ``samples`` list is used. The tag in the
    returned tuple keeps the different kinds of index from colliding.
    """
    for field, tag in (('test_dataset_idx', 'dataset'),
                       ('image_idx', 'dataset'),
                       ('sample_idx', 'sample')):
        value = sample.get(field)
        if value is not None:
            return (tag, value)
    return ('position', position)


def compute_radii_comparison_precomputed(
    variance_gradient_data: Dict,
    eps_y_deg: float,
    N: int,
    trial_idx: int = 0,
    ci_type: str = 'analytical',
    confidence: float = 0.95,
    variance_gradient_data_ec: Optional[Dict] = None,
) -> Tuple[List[Dict], Dict]:
    """
    Compute certified radii using both certifiers from saved estimates.

    Args:
        variance_gradient_data: Data from mnist_rotation_full_certification.py (for (E, C, G)+M)
        eps_y_deg: Output tolerance in degrees
        N: Sample size to use (e.g., 1000, 10000)
        trial_idx: Which trial to use (default 0)
        ci_type: 'analytical' or 'bootstrap' (bootstrap falls back to the
            analytical bound when a sample has no bootstrap entry)
        confidence: Confidence level
        variance_gradient_data_ec: Optional separate data for (E, C)+M (if None,
            uses variance_gradient_data). Samples are paired with the main file
            by dataset index (see ``_sample_identity``); main-file samples with
            no counterpart are skipped with a warning.

    Returns:
        results: List of dicts with certified radii and metadata
        summary: Summary statistics

    Raises:
        ValueError: If the two files use different sigma, if no sample has
            estimates for (N, trial_idx), or if no sample could be paired
            between the two files.
    """
    # Extract parameters
    sigma = variance_gradient_data['parameters']['sigma']
    eps_y_rad = np.radians(eps_y_deg)
    n_key = str(N)

    # Use separate data for (E, C)+M if provided, otherwise use same data
    data_ec = variance_gradient_data_ec if variance_gradient_data_ec is not None else variance_gradient_data

    # Both certifiers must be evaluated at the same smoothing level.
    if variance_gradient_data_ec is not None and data_ec['parameters']['sigma'] != sigma:
        raise ValueError(
            "Sigma must match between estimation files "
            f"({sigma} vs {data_ec['parameters']['sigma']})"
        )

    # Initialize certifiers with M = π (angles in radians)
    M = np.pi

    certifier_variance_mean = BoundedCertifierVarianceMean(
        sigma=sigma,
        M=M,
        eps_y=eps_y_rad,
        confidence=confidence,
        quadrature_points=60
    )

    certifier_with_gradient = BoundedCertifierWithMean(
        sigma=sigma,
        M=M,
        eps_y=eps_y_rad,
        confidence=confidence,
        quadrature_points=60
    )

    print(f"\n{'='*80}")
    print("Comparing (E, C) + M vs (E, C, G) + M Certifiers")
    print(f"{'='*80}")
    print("Parameters:")
    print(f"  σ = {sigma}")
    print(f"  ε_y = {eps_y_deg}° = {eps_y_rad:.4f} rad")
    print(f"  M = π = {M:.4f} rad")
    print(f"  N = {N}")
    print(f"  Trial = {trial_idx}")
    print(f"  CI type = {ci_type}")
    print(f"{'='*80}\n")

    samples = variance_gradient_data['samples']
    print(f"Processing {len(samples)} samples...")

    valid_samples = _valid_samples_for(samples, N, trial_idx)
    print(f"Found {len(valid_samples)} valid samples (with N={N}, trial={trial_idx})\n")
    if not valid_samples:
        raise ValueError(f"No samples have estimates for N={N}, trial={trial_idx}")

    ec_by_identity: Optional[Dict] = None
    if variance_gradient_data_ec is not None:
        valid_samples_ec = _valid_samples_for(data_ec['samples'], N, trial_idx)
        print(f"Found {len(valid_samples_ec)} valid samples in (E, C)+M estimation file\n")
        if len(valid_samples_ec) != len(valid_samples):
            print(f"⚠️  Warning: Sample count mismatch ({len(valid_samples_ec)} vs {len(valid_samples)})")
        # Pair samples by the input they describe, not by list position: the
        # two files may contain different subsets (or orders) of samples.
        ec_by_identity = {_sample_identity(s, pos): s for pos, s in valid_samples_ec}

    results = []
    n_unmatched = 0
    for i, sample in tqdm(valid_samples, desc="Computing radii", unit="sample", ncols=100):
        # Estimates for (E, C, G)+M always come from the main file.
        estimates_ecg = sample['results_by_N'][n_key][trial_idx]

        # Estimates for (E, C)+M: the matching sample of the separate file, or
        # the same estimates when only one file is given.
        if ec_by_identity is None:
            estimates_ec = estimates_ecg
        else:
            sample_ec = ec_by_identity.get(_sample_identity(sample, i))
            if sample_ec is None:
                # Skip rather than silently scoring (E, C)+M with the
                # (E, C, G)+M file's numbers.
                n_unmatched += 1
                continue
            estimates_ec = sample_ec['results_by_N'][n_key][trial_idx]

        # Extract upper bounds for variance and gradient norm
        if ci_type == 'analytical':
            C_ucb_ecg = estimates_ecg['C_upper_analytical']
            C_ucb_ec = estimates_ec['C_upper_analytical']
        else:
            C_ucb_ecg = estimates_ecg.get('C_upper_bootstrap', estimates_ecg['C_upper_analytical'])
            C_ucb_ec = estimates_ec.get('C_upper_bootstrap', estimates_ec['C_upper_analytical'])

        G_ucb = estimates_ecg['G_norm_upper']

        # Mean estimate: prefer the (E, C)+M file, fall back to the main file.
        V_est = estimates_ec.get('g_z_hat', estimates_ecg.get('g_z_hat', 0.0))

        # (E, C) + M (no gradient) uses the variance bound from the (E, C) estimates.
        radius_variance_mean = float(
            certifier_variance_mean.certify_point_from_estimates(C_ucb_ec, V_est)
        )
        # (E, C, G) + M uses the variance and gradient bounds from the main file.
        radius_with_gradient = float(
            certifier_with_gradient.certify_point_from_estimates(C_ucb_ecg, G_ucb, V_est)
        )

        improvement = radius_with_gradient - radius_variance_mean
        # Relative gain is undefined when the baseline radius is 0; report 0.
        improvement_pct = 100 * improvement / radius_variance_mean if radius_variance_mean > 0 else 0.0

        results.append({
            'sample_idx': sample.get('sample_idx', i),
            'test_dataset_idx': sample.get('test_dataset_idx', sample.get('image_idx', i)),
            'digit_label': sample.get('digit_label', None),
            'radius_variance_mean': radius_variance_mean,
            'radius_with_gradient': radius_with_gradient,
            'improvement': improvement,
            'improvement_pct': improvement_pct,
            'C_hat_ec': estimates_ec.get('C_hat', 0.0),
            'C_ucb_ec': C_ucb_ec,
            'C_hat_ecg': estimates_ecg.get('C_hat', 0.0),
            'C_ucb_ecg': C_ucb_ecg,
            'G_norm_hat': estimates_ecg.get('G_norm_hat', 0.0),
            'G_ucb': G_ucb,
            'V_hat': V_est,
            'V_abs': abs(V_est),
            'clean_pred_deg': sample.get('clean_pred_deg', None),
            'N': N,
            'trial': trial_idx,
        })

    if n_unmatched:
        print(f"⚠️  Warning: skipped {n_unmatched} samples with no matching entry "
              f"in the (E, C)+M estimation file")

    print(f"\n✓ Computed radii for {len(results)} samples\n")
    if not results:
        raise ValueError("No samples could be paired between the two estimation files")

    # Compute summary statistics
    radii_vm = np.array([r['radius_variance_mean'] for r in results])
    radii_wg = np.array([r['radius_with_gradient'] for r in results])
    improvements = np.array([r['improvement'] for r in results])
    improvement_pcts = np.array([r['improvement_pct'] for r in results])

    def _radius_stats(radii: np.ndarray) -> Dict:
        # Descriptive statistics of one certifier's radii; "certified" = radius > 0.
        return {
            'mean': float(np.mean(radii)),
            'median': float(np.median(radii)),
            'std': float(np.std(radii)),
            'min': float(np.min(radii)),
            'max': float(np.max(radii)),
            'certified': int(np.sum(radii > 0)),
        }

    summary = {
        'n_samples': len(results),
        'variance_mean': _radius_stats(radii_vm),
        'with_gradient': _radius_stats(radii_wg),
        'improvement': {
            'mean': float(np.mean(improvements)),
            'median': float(np.median(improvements)),
            'mean_pct': float(np.mean(improvement_pcts)),
            'median_pct': float(np.median(improvement_pcts)),
            'gradient_wins': int(np.sum(radii_wg > radii_vm)),
            'ties': int(np.sum(radii_wg == radii_vm)),
            'variance_mean_wins': int(np.sum(radii_vm > radii_wg)),
        },
    }

    print_summary(summary)

    return results, summary


def print_summary(summary: Dict):
    """Pretty-print a comparison summary to stdout.

    Works with summaries from either mode: the item count is read from
    ``n_samples`` (precomputed mode, labelled "samples") or ``n_points``
    (synthetic mode, labelled "points").
    """
    rule = "=" * 80
    print(rule)
    print("SUMMARY STATISTICS")
    print(rule)

    count = summary.get('n_samples', summary.get('n_points', 0))
    label = 'samples' if 'n_samples' in summary else 'points'
    print(f"\nTotal {label}: {count}")

    sections = (
        ('variance_mean', "\n(E, C) + M (Variance + Mean, No Gradient):"),
        ('with_gradient', "\n(E, C, G) + M (Variance + Mean + Gradient):"),
    )
    for section_key, heading in sections:
        print(heading)
        stats = summary[section_key]
        print(f"  Mean radius:   {stats['mean']:.6f}")
        print(f"  Median radius: {stats['median']:.6f}")
        print(f"  Std dev:       {stats['std']:.6f}")
        print(f"  Range:         [{stats['min']:.6f}, {stats['max']:.6f}]")
        print(f"  Certified:     {stats['certified']}/{count}")

    print("\nImprovement (With Gradient - Variance+Mean):")
    gain = summary['improvement']
    print(f"  Mean:          {gain['mean']:.6f} ({gain['mean_pct']:.2f}%)")
    print(f"  Median:        {gain['median']:.6f} ({gain['median_pct']:.2f}%)")
    print(f"  Gradient wins:    {gain['gradient_wins']}")
    print(f"  Ties:              {gain['ties']}")
    print(f"  Variance+Mean wins: {gain['variance_mean_wins']}")

    print(rule + "\n")


def _make_synthetic_model(function_type: str, M: float):
    """Return ``(model_fn, function_params)`` for a named bounded 2-D test function.

    ``model_fn(x)`` evaluates the chosen function at ``(x[0], x[1])`` with
    bound ``M``; ``function_params`` holds the extra parameters it uses (also
    handed to the true-radius computation).

    Raises:
        ValueError: If ``function_type`` is not a supported name.
    """
    if function_type == "bounded_quadratic":
        function_params = {'center': (0.0, 0.0), 'scale': 1.0}

        def model_fn(x: np.ndarray) -> float:
            return bounded_quadratic(x[0], x[1], center=function_params['center'],
                                     scale=function_params['scale'], M=M)
    elif function_type == "bounded_linear":
        function_params = {}

        def model_fn(x: np.ndarray) -> float:
            return bounded_linear_function(x[0], x[1], M=M)
    elif function_type == "bounded_sine":
        function_params = {'frequency': 1.0}

        def model_fn(x: np.ndarray) -> float:
            return bounded_sine_function(x[0], x[1], frequency=function_params['frequency'], M=M)
    elif function_type == "bounded_slice":
        function_params = {'threshold': 0.0}

        def model_fn(x: np.ndarray) -> float:
            return bounded_slice_function(x[0], x[1], threshold=function_params['threshold'], M=M)
    else:
        raise ValueError(f"Unknown function type: {function_type}")
    return model_fn, function_params


def compare_on_synthetic_function(
    function_type: str,
    sigma: float,
    eps_y: float,
    M: float = 1.0,
    n_test_points: int = 10,
    N_samples: int = 10000,
    seed: int = 42,
    compute_true_radius: bool = False,
    confidence: float = 0.95,
) -> Tuple[List[Dict], Dict]:
    """
    Compare both certifiers on synthetic functions.

    For each test point z, ``N_samples`` perturbations eta ~ N(0, sigma^2 I)
    are drawn and f(z + eta) is evaluated once; both certifiers compute their
    estimates from these same draws, so radius differences come from the
    certifiers rather than from sampling noise between them.

    Args:
        function_type: One of 'bounded_quadratic', 'bounded_linear',
            'bounded_sine', 'bounded_slice'.
        sigma: Smoothing noise standard deviation.
        eps_y: Output tolerance.
        M: Bound on the function's output.
        n_test_points: Number of test points passed to ``create_test_points``.
        N_samples: Noise draws per test point.
        seed: Seed for the test points and the noise generator.
        compute_true_radius: Also compute the analytical true radius (slow);
            failures are reported as warnings and leave ``true_radius`` unset.
        confidence: Confidence level for both certifiers.

    Returns:
        results: Per-point dicts with both radii and the estimates used.
        summary: Summary statistics (same layout as the precomputed mode,
            keyed by ``n_points``).

    Raises:
        ValueError: For an unknown ``function_type`` or when no test points
            are generated.
    """
    print(f"\n{'='*80}")
    print("Comparing Certifiers on Synthetic Function")
    print(f"{'='*80}")
    print(f"Function: {function_type}")
    print(f"Parameters: σ={sigma}, ε_y={eps_y}, M={M}, N={N_samples}")
    print(f"Test points: {n_test_points}")
    print(f"{'='*80}\n")

    # Initialize both certifiers
    certifier_vm = BoundedCertifierVarianceMean(
        sigma=sigma,
        M=M,
        eps_y=eps_y,
        confidence=confidence,
        quadrature_points=60
    )

    certifier_wg = BoundedCertifierWithMean(
        sigma=sigma,
        M=M,
        eps_y=eps_y,
        confidence=confidence,
        quadrature_points=60,
        mean_target=0.0
    )

    model_fn, function_params = _make_synthetic_model(function_type, M)

    # Generate test points
    test_points = create_test_points(n_points=n_test_points, seed=seed)
    if len(test_points) == 0:
        raise ValueError(f"No test points to certify (n_test_points={n_test_points})")

    results = []
    rng = np.random.default_rng(seed)

    print(f"Certifying {len(test_points)} test points...")

    for i, z in enumerate(test_points):
        # Sample for statistical estimation
        eta_samples = rng.normal(0.0, sigma, size=(N_samples, z.shape[-1]))
        f_values = np.array([model_fn(z + eta) for eta in eta_samples])

        # Variance upper bound for each certifier; the mean estimate is shared.
        _, _, C_ucb_vm = certifier_vm.u_statistic_variance_estimator_alpha_half(f_values)
        V_hat, _, _ = certifier_vm.u_statistic_mean_estimator_alpha_half(f_values)

        _, _, C_ucb_wg = certifier_wg.u_statistic_variance_estimator_alpha_half(f_values)
        _, _, G_ucb = certifier_wg.u_statistic_gradient_norm_estimator_alpha_half(f_values, eta_samples)

        # Compute radii
        radius_vm = float(certifier_vm.certify_point_from_estimates(C_ucb_vm, V_hat))
        radius_wg = float(certifier_wg.certify_point_from_estimates(C_ucb_wg, G_ucb, V_hat))

        improvement = radius_wg - radius_vm
        # Relative gain is undefined when the baseline radius is 0; report 0.
        improvement_pct = 100 * improvement / radius_vm if radius_vm > 0 else 0.0

        # Plain floats keep the results JSON-serializable.
        result = {
            'test_point_idx': i,
            'test_point': z.tolist(),
            'radius_variance_mean': radius_vm,
            'radius_with_gradient': radius_wg,
            'improvement': improvement,
            'improvement_pct': improvement_pct,
            'C_ucb': float(C_ucb_vm),
            'C_ucb_wg': float(C_ucb_wg),
            'G_ucb': float(G_ucb),
            'V_hat': float(V_hat),
        }

        # Optionally compute true radius for comparison (best-effort).
        if compute_true_radius:
            try:
                # compute_true_radius_analytical expects a function that takes (x1, x2)
                def function_for_truth(x1: float, x2: float) -> float:
                    return model_fn(np.array([x1, x2]))
                true_radius = compute_true_radius_analytical(
                    z, function_for_truth, sigma, eps_y, M, function_params
                )
                result['true_radius'] = true_radius
            except Exception as e:
                print(f"  Warning: Could not compute true radius for point {i}: {e}")

        results.append(result)

        if (i + 1) % 5 == 0:
            print(f"  Processed {i + 1}/{len(test_points)} points...")

    print(f"\n✓ Computed radii for {len(results)} test points\n")

    # Compute summary statistics
    radii_vm = np.array([r['radius_variance_mean'] for r in results])
    radii_wg = np.array([r['radius_with_gradient'] for r in results])
    improvements = np.array([r['improvement'] for r in results])
    improvement_pcts = np.array([r['improvement_pct'] for r in results])

    def _radius_stats(radii: np.ndarray) -> Dict:
        # Descriptive statistics of one certifier's radii; "certified" = radius > 0.
        return {
            'mean': float(np.mean(radii)),
            'median': float(np.median(radii)),
            'std': float(np.std(radii)),
            'min': float(np.min(radii)),
            'max': float(np.max(radii)),
            'certified': int(np.sum(radii > 0)),
        }

    summary = {
        'n_points': len(results),
        'function_type': function_type,
        'variance_mean': _radius_stats(radii_vm),
        'with_gradient': _radius_stats(radii_wg),
        'improvement': {
            'mean': float(np.mean(improvements)),
            'median': float(np.median(improvements)),
            'mean_pct': float(np.mean(improvement_pcts)),
            'median_pct': float(np.median(improvement_pcts)),
            'gradient_wins': int(np.sum(radii_wg > radii_vm)),
            'ties': int(np.sum(radii_wg == radii_vm)),
            'variance_mean_wins': int(np.sum(radii_vm > radii_wg)),
        },
    }

    print_summary(summary)

    return results, summary


def create_comparison_plot(results: List[Dict], output_path: str, title_suffix: str = ""):
    """Save a 2x2 figure comparing (E, C)+M and (E, C, G)+M certified radii.

    Panels are:
      1. A per-item scatter against the y = x line.
      2. A histogram of improvements (with-gradient minus variance+mean).
      3. Box plots of both radius distributions.
      4. Empirical CDFs of both radius distributions.

    Args:
        results: Per-item dicts with ``radius_variance_mean`` and
            ``radius_with_gradient`` (as produced by either comparison mode).
        output_path: Image path passed to ``savefig``.
        title_suffix: Text appended to the figure title.
    """
    radii_vm = np.array([r['radius_variance_mean'] for r in results], dtype=float)
    radii_wg = np.array([r['radius_with_gradient'] for r in results], dtype=float)
    improvements = radii_wg - radii_vm
    has_data = radii_vm.size > 0

    def _legend_if_labeled(axis):
        # Avoid Matplotlib's "No artists with labels" warning on empty panels.
        if axis.get_legend_handles_labels()[0]:
            axis.legend()

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    try:
        # Plot 1: Scatter (with gradient vs variance+mean)
        ax = axes[0, 0]
        ax.scatter(radii_vm, radii_wg, alpha=0.5, s=30)
        max_val = max(np.max(radii_vm), np.max(radii_wg)) if has_data else 1.0
        if max_val > 0:
            ax.plot([0, max_val], [0, max_val], 'r--', label='y=x (no improvement)')
        ax.set_xlabel('Radius (E, C) + M', fontsize=12)
        ax.set_ylabel('Radius (E, C, G) + M', fontsize=12)
        ax.set_title('Certified Radius Comparison', fontsize=14, fontweight='bold')
        _legend_if_labeled(ax)
        ax.grid(True, alpha=0.3)

        # Plot 2: Histogram of improvements
        ax = axes[0, 1]
        if has_data:
            ax.hist(improvements, bins=30, alpha=0.7, edgecolor='black')
            ax.axvline(0, color='red', linestyle='--', linewidth=2, label='No improvement')
            mean_improvement = np.mean(improvements)
            if mean_improvement != 0:
                ax.axvline(mean_improvement, color='green', linestyle='-', linewidth=2,
                           label=f'Mean = {mean_improvement:.6f}')
        ax.set_xlabel('Improvement (With Gradient - Variance+Mean)', fontsize=12)
        ax.set_ylabel('Count', fontsize=12)
        ax.set_title('Distribution of Improvements', fontsize=14, fontweight='bold')
        _legend_if_labeled(ax)
        ax.grid(True, alpha=0.3, axis='y')

        # Plot 3: Box plot comparison. Tick labels are set separately because
        # boxplot's ``labels`` keyword was renamed ``tick_labels`` in
        # Matplotlib 3.9; set_xticklabels works on every version.
        ax = axes[1, 0]
        if has_data:
            ax.boxplot([radii_vm, radii_wg])
            ax.set_xticklabels(['(E, C) + M', '(E, C, G) + M'])
        ax.set_ylabel('Certified Radius', fontsize=12)
        ax.set_title('Radius Distribution Comparison', fontsize=14, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')

        # Plot 4: CDF comparison
        ax = axes[1, 1]
        if has_data:
            cdf = np.arange(1, radii_vm.size + 1) / radii_vm.size
            ax.plot(np.sort(radii_vm), cdf, label='(E, C) + M', linewidth=2)
            ax.plot(np.sort(radii_wg), cdf, label='(E, C, G) + M', linewidth=2)
        ax.set_xlabel('Certified Radius', fontsize=12)
        ax.set_ylabel('CDF', fontsize=12)
        ax.set_title('Cumulative Distribution', fontsize=14, fontweight='bold')
        _legend_if_labeled(ax)
        ax.grid(True, alpha=0.3)

        fig.suptitle(f'Gradient Information Impact{title_suffix}', fontsize=16, fontweight='bold')
        fig.tight_layout()
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved plot to: {output_path}\n")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def save_results(results: List[Dict], summary: Dict, output_path: str, params: Dict):
    """Write parameters, summary and per-item results to a JSON file.

    The output object has the keys ``timestamp`` (local time, ISO 8601),
    ``parameters``, ``summary`` and ``results``. NumPy scalars and arrays,
    which ``json`` cannot encode natively (only ``np.float64`` subclasses
    ``float``), are converted to the equivalent Python values/lists.

    Args:
        results: Per-item result dicts.
        summary: Summary statistics dict.
        output_path: Destination JSON path (overwritten if it exists).
        params: Run parameters to record.

    Raises:
        TypeError: If some value is neither JSON-native nor a NumPy type.
    """
    def _numpy_to_builtin(obj):
        # Called by json only for objects it cannot serialize itself.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    output_data = {
        'timestamp': datetime.now().isoformat(),
        'parameters': params,
        'summary': summary,
        'results': results
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, default=_numpy_to_builtin)

    print(f"✓ Saved results to: {output_path}\n")


def main():
    """Command-line entry point: run the selected comparison, save JSON, optionally plot.

    In ``precomputed`` mode the radii are recomputed from saved estimates
    (``--variance_gradient``, optionally ``--variance_gradient_ec``); in
    ``synthetic`` mode both certifiers are run on a synthetic function. Results
    go to ``--output`` (or an auto-generated, timestamped name); with
    ``--plot`` a PNG with the same stem is written next to the JSON file.
    """
    parser = argparse.ArgumentParser(
        description="Compare (E, C) + M vs (E, C, G) + M certifiers"
    )

    # Mode selection
    parser.add_argument(
        "--mode",
        type=str,
        choices=['precomputed', 'synthetic'],
        default='precomputed',
        help="Comparison mode: 'precomputed' (MNIST data) or 'synthetic' (synthetic functions)"
    )

    # Precomputed mode arguments
    parser.add_argument(
        "--variance_gradient",
        type=str,
        default=None,
        help="Path to variance+gradient estimates JSON (for precomputed mode, used for (E, C, G)+M)"
    )
    parser.add_argument(
        "--variance_gradient_ec",
        type=str,
        default=None,
        help="Optional: Separate path for (E, C)+M estimates (if not provided, uses --variance_gradient)"
    )
    parser.add_argument(
        "--eps_y_deg",
        type=float,
        default=10.0,
        help="Output tolerance in degrees (for precomputed mode, default: 10.0)"
    )
    parser.add_argument(
        "--N",
        type=int,
        default=10000,
        help="Sample size to use (for precomputed mode, default: 10000)"
    )
    parser.add_argument(
        "--trial",
        type=int,
        default=0,
        help="Trial index (for precomputed mode, default: 0)"
    )
    parser.add_argument(
        "--ci_type",
        type=str,
        choices=['analytical', 'bootstrap'],
        default='analytical',
        help="Confidence interval type (default: analytical)"
    )

    # Synthetic mode arguments
    parser.add_argument(
        "--function",
        type=str,
        choices=['bounded_quadratic', 'bounded_linear', 'bounded_sine', 'bounded_slice'],
        default='bounded_quadratic',
        help="Synthetic function type (for synthetic mode)"
    )
    parser.add_argument(
        "--sigma",
        type=float,
        default=0.5,
        help="Noise standard deviation (for synthetic mode, default: 0.5)"
    )
    parser.add_argument(
        "--eps_y",
        type=float,
        default=0.5,
        help="Output tolerance (for synthetic mode, default: 0.5)"
    )
    parser.add_argument(
        "--M",
        type=float,
        default=1.0,
        help="Function bound M (for synthetic mode, default: 1.0)"
    )
    parser.add_argument(
        "--n_test_points",
        type=int,
        default=10,
        help="Number of test points (for synthetic mode, default: 10)"
    )
    parser.add_argument(
        "--N_samples",
        type=int,
        default=10000,
        help="Number of samples for estimation (for synthetic mode, default: 10000)"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed (for synthetic mode, default: 42)"
    )
    parser.add_argument(
        "--compute_true_radius",
        action='store_true',
        help="Compute true radius for comparison (synthetic mode only, slow)"
    )

    # Output arguments
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON file path (auto-generated if not specified)"
    )
    parser.add_argument(
        "--plot",
        action='store_true',
        help="Create comparison plot"
    )

    args = parser.parse_args()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

    if args.mode == 'precomputed':
        if args.variance_gradient is None:
            parser.error("--variance_gradient is required for precomputed mode")

        print(f"Loading data from: {args.variance_gradient}")
        data = load_variance_gradient_estimates(args.variance_gradient)

        # Load separate data for (E, C)+M if provided
        data_ec = None
        if args.variance_gradient_ec:
            print(f"Loading (E, C)+M data from: {args.variance_gradient_ec}")
            data_ec = load_variance_gradient_estimates(args.variance_gradient_ec)

        confidence = 0.95
        results, summary = compute_radii_comparison_precomputed(
            data,
            args.eps_y_deg,
            args.N,
            args.trial,
            args.ci_type,
            confidence=confidence,
            variance_gradient_data_ec=data_ec
        )

        sigma = data['parameters']['sigma']
        output_json = args.output or (
            f"comparison_vm_vs_wg_mnist_sigma{sigma}_eps{args.eps_y_deg}deg_{timestamp}.json"
        )
        params = {
            'mode': 'precomputed',
            'variance_gradient_file': args.variance_gradient,
            'variance_gradient_file_ec': args.variance_gradient_ec,
            'sigma': sigma,
            'eps_y_deg': args.eps_y_deg,
            'eps_y_rad': float(np.radians(args.eps_y_deg)),
            'N': args.N,
            'trial': args.trial,
            'ci_type': args.ci_type,
            'confidence': confidence,
        }
        title_suffix = f" (MNIST, σ={sigma}, ε={args.eps_y_deg}°)"

    else:  # synthetic mode
        results, summary = compare_on_synthetic_function(
            args.function,
            args.sigma,
            args.eps_y,
            args.M,
            args.n_test_points,
            args.N_samples,
            args.seed,
            args.compute_true_radius
        )

        output_json = args.output or (
            f"comparison_vm_vs_wg_{args.function}_sigma{args.sigma}_eps{args.eps_y}_{timestamp}.json"
        )
        params = {
            'mode': 'synthetic',
            'function_type': args.function,
            'sigma': args.sigma,
            'eps_y': args.eps_y,
            'M': args.M,
            'n_test_points': args.n_test_points,
            'N_samples': args.N_samples,
            'seed': args.seed,
            'compute_true_radius': args.compute_true_radius,
        }
        title_suffix = f" ({args.function}, σ={args.sigma}, ε={args.eps_y})"

    save_results(results, summary, output_json, params)

    # Create plot if requested. Swap only the final suffix (a plain string
    # replace would leave a non-.json --output unchanged, so the plot would
    # overwrite the results), and never let the plot path equal the JSON path.
    if args.plot:
        json_path = Path(output_json)
        plot_path = json_path.with_suffix('.png')
        if plot_path == json_path:
            plot_path = json_path.with_name(json_path.name + '.png')
        create_comparison_plot(results, str(plot_path), title_suffix)

    print("="*80)
    print("COMPLETE!")
    print("="*80)


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()

