#!/usr/bin/env python3
"""
Compute Certified Radii using BoundedCertifier from Pre-computed Estimates

This script loads the pre-computed variance and gradient norm estimates from
mnist_rotation_full_certification.py (saved around Nov 6-7, 2025) and uses
the NEW theoretically-founded BoundedCertifier to compute certified radii.

It then compares these radii with alpha-trimming results for fair evaluation.

Usage:
    # Compare bounded certifier vs alpha-trimming at sigma=0.12, eps_y=10 deg
    python scripts/compute_bounded_radii_from_estimates.py \\
        --variance_gradient mnist_rotation_full_cert_n100_20251106_033225.json \\
        --alpha_trimming mnist_alpha_trimming_n100_20251106_173454.json \\
        --eps_y_deg 10.0 \\
        --N 10000
    
    # Generate comparison for multiple sigma values
    python scripts/compute_bounded_radii_from_estimates.py \\
        --variance_gradient "mnist_rotation_full_cert_n100_*.json" \\
        --alpha_trimming "mnist_alpha_trimming_n100_*.json" \\
        --eps_y_deg 10.0 \\
        --N 10000 \\
        --output bounded_vs_alpha_comparison.json
"""

import json
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
import sys
from datetime import datetime
from typing import Dict, List, Tuple
import glob

# Add src path
sys.path.append(str(Path(__file__).resolve().parent.parent / "src"))

from alpha_smoothing_repro.certify.bounded_fn_certifier import BoundedCertifier


def load_variance_gradient_estimates(json_path: str) -> Dict:
    """
    Load pre-computed variance and gradient norm estimates from a JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The parsed JSON content (the callers in this script expect a dict
        with 'parameters' and 'samples' keys).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_alpha_trimming_results(json_path: str) -> Dict:
    """
    Load alpha-trimming certification results from a JSON file.

    Args:
        json_path: Path to the JSON file to read.

    Returns:
        The parsed JSON content (the callers in this script expect a dict
        with a 'parameters' key and a 'samples' or 'results' list).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # JSON is UTF-8 by specification; do not depend on the platform locale.
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def compute_bounded_radii(
    variance_gradient_data: Dict,
    eps_y_deg: float,
    N: int,
    trial_idx: int = 0,
    ci_type: str = 'analytical',
    confidence: float = 0.95
) -> List[Dict]:
    """
    Compute certified radii using BoundedCertifier from saved estimates.

    For every sample that has an entry for the requested N and trial, the
    stored upper confidence bounds (variance bound C and gradient-norm bound G)
    are passed to ``BoundedCertifier.certify_point_from_estimates``.

    Args:
        variance_gradient_data: Data from mnist_rotation_full_certification.py;
            must contain 'parameters' (with 'sigma') and 'samples'.
        eps_y_deg: Output tolerance in degrees (converted to radians here).
        N: Sample size to use (e.g., 1000, 10000); looked up as ``str(N)``
            in each sample's 'results_by_N'.
        trial_idx: Which trial to use (default 0). Negative values index from
            the end, as with ordinary Python sequences.
        ci_type: 'analytical' or 'bootstrap'. For 'bootstrap' the analytical
            bound is used as a fallback when no bootstrap bound was stored.
        confidence: Confidence level forwarded to the certifier.

    Returns:
        List of dicts with certified radii and metadata, one per sample that
        was not skipped.

    Raises:
        KeyError: If a required field is missing from the input data.
    """
    # Extract parameters
    sigma = variance_gradient_data['parameters']['sigma']
    eps_y_rad = np.radians(eps_y_deg)

    # Initialize BoundedCertifier with M = π (angles in radians)
    M = np.pi
    certifier = BoundedCertifier(
        sigma=sigma,
        M=M,
        eps_y=eps_y_rad,
        confidence=confidence,
        quadrature_points=60  # Higher for better accuracy with bounded functions
    )

    print(f"\n{'='*80}")
    print("Computing Bounded Certifier Radii")
    print(f"{'='*80}")
    print("Parameters:")
    print(f"  σ = {sigma}")
    print(f"  ε_y = {eps_y_deg}° = {eps_y_rad:.4f} rad")
    print(f"  M = π = {M:.4f} rad")
    print(f"  N = {N}")
    print(f"  Trial = {trial_idx}")
    print(f"  CI type = {ci_type}")
    print(f"  Confidence = {confidence}")
    print(f"{'='*80}\n")

    results = []
    samples = variance_gradient_data['samples']
    n_key = str(N)  # JSON object keys are always strings

    print(f"Processing {len(samples)} samples...")

    for i, sample in enumerate(samples):
        # Get estimates for this N and trial
        results_by_n = sample['results_by_N']
        if n_key not in results_by_n:
            print(f"⚠️  Sample {i}: N={N} not found, skipping")
            continue

        trials = results_by_n[n_key]
        # Two-sided check so an out-of-range negative index is skipped
        # instead of raising IndexError.
        if not -len(trials) <= trial_idx < len(trials):
            print(f"⚠️  Sample {i}: trial {trial_idx} not found, skipping")
            continue

        estimates = trials[trial_idx]

        # Pick the variance upper bound. Only touch the analytical key when it
        # is actually needed, so data that stores just the bootstrap bound
        # does not fail with a KeyError.
        if ci_type != 'analytical' and 'C_upper_bootstrap' in estimates:
            C_ucb = estimates['C_upper_bootstrap']
        else:
            C_ucb = estimates['C_upper_analytical']

        G_ucb = estimates['G_norm_upper']

        # Compute certified radius
        radius = certifier.certify_point_from_estimates(C_ucb, G_ucb)

        # Collect results
        results.append({
            'sample_idx': sample.get('sample_idx', i),
            'test_dataset_idx': sample.get('test_dataset_idx', sample.get('image_idx', i)),
            'digit_label': sample.get('digit_label', None),
            'radius': radius,
            'C_hat': estimates['C_hat'],
            'C_ucb': C_ucb,
            'G_norm_hat': estimates['G_norm_hat'],
            'G_ucb': G_ucb,
            'clean_pred_deg': sample.get('clean_pred_deg', None),
            'N': N,
            'trial': trial_idx
        })

        if (i + 1) % 10 == 0:
            print(f"  Processed {i + 1}/{len(samples)} samples...")

    print(f"✓ Computed {len(results)} certified radii\n")

    return results


def match_alpha_trimming_results(
    alpha_trimming_data: Dict,
    bounded_results: List[Dict]
) -> List[Dict]:
    """
    Pair each bounded-certifier result with its alpha-trimming radius.

    Records are joined on the dataset index: the bounded result's
    'test_dataset_idx' against the first of 'test_dataset_idx', 'test_idx'
    or 'image_idx' that is present in the alpha-trimming record.

    Args:
        alpha_trimming_data: Data from alpha-trimming certification; its
            records are read from 'samples', or from 'results' when
            'samples' is absent.
        bounded_results: Results from bounded certifier

    Returns:
        Copies of the bounded results that have an alpha-trimming
        counterpart, each extended with an 'alpha_trimming_radius' entry.
    """
    def _dataset_index(record):
        # First key that is present wins, even if its value is None.
        for key in ('test_dataset_idx', 'test_idx', 'image_idx'):
            if key in record:
                return record[key]
        return None

    # Accept either field name for the list of alpha-trimming records
    if 'samples' in alpha_trimming_data:
        alpha_records = alpha_trimming_data['samples']
    elif 'results' in alpha_trimming_data:
        alpha_records = alpha_trimming_data['results']
    else:
        alpha_records = []

    # dataset index -> certified radius (later records overwrite earlier ones)
    radius_by_index = {}
    for record in alpha_records:
        dataset_idx = _dataset_index(record)
        if dataset_idx is None:
            continue
        radius_by_index[dataset_idx] = record['certified_radius']

    # Keep only the bounded results that have a partner, preserving order
    paired = [
        {**entry, 'alpha_trimming_radius': radius_by_index[entry['test_dataset_idx']]}
        for entry in bounded_results
        if entry['test_dataset_idx'] in radius_by_index
    ]

    print(f"✓ Matched {len(paired)}/{len(bounded_results)} samples with alpha-trimming results\n")

    return paired


def compute_statistics(matched_results: List[Dict]) -> Dict:
    """
    Summarize and compare the two sets of certified radii.

    Args:
        matched_results: Matched samples, each carrying 'radius' (bounded
            certifier) and 'alpha_trimming_radius'.

    Returns:
        Dict with 'n_samples', a descriptive summary per method
        ('bounded_certifier', 'alpha_trimming') and a 'comparison' section
        (differences are bounded minus alpha, win/tie counts, correlation).
    """
    def _describe(values):
        # Descriptive statistics of one radius array as plain Python floats.
        return {
            'mean': float(np.mean(values)),
            'median': float(np.median(values)),
            'std': float(np.std(values)),
            'min': float(np.min(values)),
            'max': float(np.max(values)),
            'q25': float(np.percentile(values, 25)),
            'q75': float(np.percentile(values, 75)),
        }

    bounded = np.array([entry['radius'] for entry in matched_results])
    alpha = np.array([entry['alpha_trimming_radius'] for entry in matched_results])

    summary = {'n_samples': len(matched_results)}
    summary['bounded_certifier'] = _describe(bounded)
    summary['alpha_trimming'] = _describe(alpha)

    gap = bounded - alpha
    summary['comparison'] = {
        'mean_difference': float(np.mean(gap)),
        'median_difference': float(np.median(gap)),
        'bounded_wins': int(np.sum(bounded > alpha)),
        'alpha_wins': int(np.sum(alpha > bounded)),
        'ties': int(np.sum(bounded == alpha)),
        'correlation': float(np.corrcoef(bounded, alpha)[0, 1]),
    }

    return summary


def print_comparison_summary(stats: Dict, sigma: float, eps_y_deg: float):
    """
    Write a human-readable comparison report to stdout.

    Args:
        stats: Statistics dict with 'n_samples', 'bounded_certifier',
            'alpha_trimming' and 'comparison' sections.
        sigma: Noise level shown in the report header.
        eps_y_deg: Output tolerance in degrees shown in the report header.
    """
    rule = "=" * 80

    def _method_section(title, section_key):
        # One per-method block: title, four statistic lines, blank line.
        print(title)
        print(f"  Mean:   {stats[section_key]['mean']:.4f}")
        print(f"  Median: {stats[section_key]['median']:.4f}")
        print(f"  Std:    {stats[section_key]['std']:.4f}")
        print(f"  Range:  [{stats[section_key]['min']:.4f}, {stats[section_key]['max']:.4f}]")
        print()

    def _win_line(label, count_key):
        # Win count with its share of all samples, in percent.
        wins = stats['comparison'][count_key]
        total = stats['n_samples']
        share = 100 * stats['comparison'][count_key] / stats['n_samples']
        print(f"  {label}{wins}/{total} ({share:.1f}%)")

    print("\n" + rule)
    print("COMPARISON SUMMARY")
    print(rule)
    print(f"Parameters: σ={sigma}, ε_y={eps_y_deg}°")
    print(f"Samples: {stats['n_samples']}")
    print()

    _method_section("Bounded Certifier (New Method):", 'bounded_certifier')
    _method_section("α-Trimming (Rekavandi et al., NeurIPS 2024):", 'alpha_trimming')

    print("Comparison:")
    print(f"  Mean difference (Bounded - Alpha): {stats['comparison']['mean_difference']:.4f}")
    print(f"  Median difference: {stats['comparison']['median_difference']:.4f}")
    _win_line("Bounded wins: ", 'bounded_wins')
    _win_line("Alpha wins:   ", 'alpha_wins')
    print(f"  Correlation:  {stats['comparison']['correlation']:.4f}")
    print(rule + "\n")


def create_comparison_plots(
    matched_results: List[Dict],
    stats: Dict,
    sigma: float,
    eps_y_deg: float,
    output_prefix: str = "bounded_vs_alpha"
):
    """
    Create a 2x2 comparison figure and save it as a PNG.

    Panels: scatter of the two radii with the equal-radius diagonal,
    histogram of radius differences (bounded minus alpha), box plots of both
    distributions, and empirical CDFs.

    Args:
        matched_results: Matched samples, each carrying 'radius' (bounded
            certifier) and 'alpha_trimming_radius'.
        stats: Comparison statistics. Currently unused; kept so existing
            callers keep working.
        sigma: Noise level, used in the title and the output file name.
        eps_y_deg: Output tolerance in degrees, used in the title and the
            output file name.
        output_prefix: Prefix of the output file name.

    The figure is written to
    ``{output_prefix}_sigma{sigma}_eps{eps_y_deg}deg.png``. Nothing is
    written when ``matched_results`` is empty.
    """
    # The reductions below (np.max etc.) are undefined for empty input.
    if not matched_results:
        print("⚠️  No matched results to plot, skipping")
        return

    bounded_radii = np.array([r['radius'] for r in matched_results])
    alpha_radii = np.array([r['alpha_trimming_radius'] for r in matched_results])
    differences = bounded_radii - alpha_radii

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    try:
        # Plot 1: Scatter plot
        ax = axes[0, 0]
        ax.scatter(alpha_radii, bounded_radii, alpha=0.5, s=30)

        # Add diagonal line
        max_r = max(np.max(alpha_radii), np.max(bounded_radii))
        ax.plot([0, max_r], [0, max_r], 'r--', label='Equal radii', linewidth=2)

        ax.set_xlabel('α-Trimming Certified Radius', fontsize=11)
        ax.set_ylabel('Bounded Certifier Radius (New)', fontsize=11)
        ax.set_title(f'Certified Radius Comparison (σ={sigma}, ε_y={eps_y_deg}°)', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_aspect('equal')

        # Plot 2: Histogram of differences
        ax = axes[0, 1]
        mean_diff = np.mean(differences)
        ax.hist(differences, bins=30, alpha=0.7, edgecolor='black')
        ax.axvline(0, color='red', linestyle='--', linewidth=2, label='Zero difference')
        ax.axvline(mean_diff, color='blue', linestyle='-', linewidth=2,
                   label=f'Mean = {mean_diff:.4f}')
        ax.set_xlabel('Radius Difference (Bounded - Alpha)', fontsize=11)
        ax.set_ylabel('Count', fontsize=11)
        ax.set_title('Distribution of Radius Differences', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)

        # Plot 3: Box plots
        ax = axes[1, 0]
        bp = ax.boxplot([bounded_radii, alpha_radii],
                        showmeans=True,
                        patch_artist=True)
        # Label the boxes through the axis instead of boxplot's `labels`
        # keyword, which Matplotlib 3.9 deprecated in favour of `tick_labels`;
        # this form works on old and new versions alike. The boxes sit at the
        # default positions 1 and 2.
        ax.set_xticks([1, 2])
        ax.set_xticklabels(['Bounded Certifier', 'α-Trimming'])

        # Color the boxes
        colors = ['lightblue', 'lightcoral']
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        ax.set_ylabel('Certified Radius', fontsize=11)
        ax.set_title('Radius Distribution Comparison', fontsize=12)
        ax.grid(True, alpha=0.3, axis='y')

        # Plot 4: CDF comparison
        ax = axes[1, 1]

        # Both arrays have the same length, so they share one CDF axis.
        n = len(bounded_radii)
        cdf = np.arange(1, n + 1) / n

        ax.plot(np.sort(bounded_radii), cdf, label='Bounded Certifier', linewidth=2)
        ax.plot(np.sort(alpha_radii), cdf, label='α-Trimming', linewidth=2)

        ax.set_xlabel('Certified Radius', fontsize=11)
        ax.set_ylabel('CDF', fontsize=11)
        ax.set_title('Cumulative Distribution Functions', fontsize=12)
        ax.legend()
        ax.grid(True, alpha=0.3)

        fig.tight_layout()

        # Save plot
        output_file = f"{output_prefix}_sigma{sigma}_eps{eps_y_deg}deg.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"✓ Saved comparison plot: {output_file}")
    finally:
        # Release the figure even if plotting or saving failed.
        plt.close(fig)


def _find_matching_alpha_file(alpha_files: List[str], sigma: float, cache: Dict, tol: float = 0.01):
    """Return (path, data) of the first alpha-trimming file whose sigma is within tol, else (None, None)."""
    for alpha_file in alpha_files:
        # Load lazily and at most once per file across all calls.
        if alpha_file not in cache:
            cache[alpha_file] = load_alpha_trimming_results(alpha_file)
        alpha_data = cache[alpha_file]
        alpha_sigma = alpha_data.get('parameters', {}).get('sigma', None)
        if alpha_sigma is not None and abs(alpha_sigma - sigma) < tol:
            return alpha_file, alpha_data
    return None, None


def _per_point_csv_path(output_file: str) -> str:
    """Derive the per-point CSV path from the JSON output path; it never equals output_file."""
    out = Path(output_file)
    return str(out.with_name(f"{out.stem}_per_point.csv"))


def _write_per_point_csv(csv_file: str, all_results: List[Dict]) -> None:
    """Write one CSV row per matched sample of every result entry that has matched samples."""
    with open(csv_file, 'w', encoding='utf-8') as f:
        # Write header
        f.write('sigma,test_idx,digit_label,bounded_radius,alpha_radius,difference,bounded_wins\n')

        # Write data
        for result in all_results:
            if 'matched_samples' not in result:
                continue
            sigma = result['sigma']
            for sample in result['matched_samples']:
                bounded_r = sample['radius']
                alpha_r = sample['alpha_trimming_radius']
                diff = bounded_r - alpha_r
                wins = 1 if bounded_r > alpha_r else 0
                f.write(f"{sigma},{sample['test_dataset_idx']},{sample['digit_label']},"
                        f"{bounded_r:.6f},{alpha_r:.6f},{diff:.6f},{wins}\n")


def main():
    """
    Command-line entry point.

    For every variance+gradient file matching --variance_gradient, compute
    bounded-certifier radii, look for an alpha-trimming file with the same
    sigma (within 0.01), and if samples can be matched, compute statistics
    and plots. All results are written to one JSON file; when at least one
    file produced matched samples, a per-point CSV is written next to it.
    """
    parser = argparse.ArgumentParser(
        description="Compute bounded certifier radii from pre-computed estimates and compare with alpha-trimming"
    )
    parser.add_argument(
        "--variance_gradient",
        type=str,
        required=True,
        help="Path to variance+gradient estimates JSON (supports glob patterns)"
    )
    parser.add_argument(
        "--alpha_trimming",
        type=str,
        default=None,
        help="Path to alpha-trimming results JSON (supports glob patterns, optional)"
    )
    parser.add_argument(
        "--eps_y_deg",
        type=float,
        default=10.0,
        help="Output tolerance in degrees (default: 10.0)"
    )
    parser.add_argument(
        "--N",
        type=int,
        default=10000,
        help="Sample size to use from estimates (default: 10000)"
    )
    parser.add_argument(
        "--trial",
        type=int,
        default=0,
        help="Trial index to use (default: 0)"
    )
    parser.add_argument(
        "--ci_type",
        type=str,
        choices=['analytical', 'bootstrap'],
        default='analytical',
        help="Confidence interval type (default: analytical)"
    )
    parser.add_argument(
        "--confidence",
        type=float,
        default=0.95,
        help="Confidence level (default: 0.95)"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON file path (auto-generated if not provided)"
    )
    parser.add_argument(
        "--output_prefix",
        type=str,
        default="bounded_vs_alpha",
        help="Output file prefix for plots (default: bounded_vs_alpha)"
    )

    args = parser.parse_args()

    print("="*80)
    print("BOUNDED CERTIFIER FROM PRE-COMPUTED ESTIMATES")
    print("="*80)
    print(f"Variance+Gradient file: {args.variance_gradient}")
    print(f"Alpha-trimming file: {args.alpha_trimming}")
    print(f"Parameters: ε_y={args.eps_y_deg}°, N={args.N}, trial={args.trial}")
    print(f"CI type: {args.ci_type}, confidence={args.confidence}")
    print("="*80)

    # Expand glob patterns; sort because glob order is not guaranteed and
    # the processing/matching order should be reproducible.
    vg_files = sorted(glob.glob(args.variance_gradient))
    if not vg_files:
        print(f"❌ No files found matching: {args.variance_gradient}")
        return

    alpha_files = []
    if args.alpha_trimming:
        alpha_files = sorted(glob.glob(args.alpha_trimming))
        if not alpha_files:
            print(f"⚠️  No alpha-trimming files found matching: {args.alpha_trimming}")

    print(f"\nFound {len(vg_files)} variance+gradient file(s)")
    print(f"Found {len(alpha_files)} alpha-trimming file(s)\n")

    # Process each variance+gradient file
    all_results = []
    alpha_cache = {}  # path -> parsed alpha-trimming data, filled on demand

    for vg_file in vg_files:
        print(f"\n{'='*80}")
        print(f"Processing: {Path(vg_file).name}")
        print(f"{'='*80}")

        # Load variance+gradient data
        vg_data = load_variance_gradient_estimates(vg_file)
        sigma = vg_data['parameters']['sigma']

        # Compute bounded certifier radii
        bounded_results = compute_bounded_radii(
            vg_data,
            args.eps_y_deg,
            args.N,
            args.trial,
            args.ci_type,
            args.confidence
        )

        if not bounded_results:
            print(f"⚠️  No results computed for {vg_file}, skipping")
            continue

        # Try to find an alpha-trimming file with the same sigma
        matched_alpha_file, alpha_data = _find_matching_alpha_file(
            alpha_files, sigma, alpha_cache
        )

        matched_results = []
        if matched_alpha_file:
            print(f"Found matching alpha-trimming file: {Path(matched_alpha_file).name}")
            matched_results = match_alpha_trimming_results(alpha_data, bounded_results)
            if not matched_results:
                print("⚠️  No samples matched the alpha-trimming results")
        else:
            print("⚠️  No matching alpha-trimming file found")

        result_entry = {
            'variance_gradient_file': vg_file,
            'alpha_trimming_file': matched_alpha_file,
            'sigma': sigma,
            'eps_y_deg': args.eps_y_deg,
            'N': args.N,
            'trial': args.trial,
        }

        if matched_results:
            # Compute statistics
            stats = compute_statistics(matched_results)

            # Print summary
            print_comparison_summary(stats, sigma, args.eps_y_deg)

            # Create plots
            create_comparison_plots(
                matched_results,
                stats,
                sigma,
                args.eps_y_deg,
                args.output_prefix
            )

            result_entry['statistics'] = stats
            result_entry['matched_samples'] = matched_results
        else:
            # Keep the computed radii even when there is nothing to compare
            # against, instead of discarding them.
            print("Computing bounded certifier radii only (no comparison)\n")
            result_entry['bounded_results'] = bounded_results

        all_results.append(result_entry)

    # Save all results
    if args.output is None:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f"bounded_certifier_comparison_{timestamp}.json"
    else:
        output_file = args.output

    output_data = {
        'timestamp': datetime.now().isoformat(),
        'parameters': {
            'eps_y_deg': args.eps_y_deg,
            'N': args.N,
            'trial': args.trial,
            'ci_type': args.ci_type,
            'confidence': args.confidence,
        },
        'results': all_results
    }

    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2)

    # Also save a simplified per-point CSV for easy analysis. Check every
    # entry, not just the first: matched and unmatched files can be mixed.
    if any('matched_samples' in result for result in all_results):
        csv_file = _per_point_csv_path(output_file)
        _write_per_point_csv(csv_file, all_results)
        print(f"✓ Saved per-point CSV: {csv_file}")

    print("\n" + "="*80)
    print("COMPLETE!")
    print("="*80)
    print(f"✓ Processed {len(all_results)} file(s)")
    print(f"✓ Saved results: {output_file}")
    print("="*80 + "\n")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

