#!/usr/bin/env python3
"""
Analyze Pseudo-True Radius Results

Loads pseudo-radius JSON files and generates:
1. Summary statistics (mean, median, std, min, max)
2. Comparison tables
3. Distribution plots
4. Comparison with certified radii (if provided)
"""

import json
import numpy as np
import argparse
from pathlib import Path
from typing import Dict, List, Optional
import matplotlib.pyplot as plt
from collections import defaultdict

# tqdm is treated as an optional dependency: when it cannot be imported, a
# stand-in with the same name is defined so callers need no special casing.
# NOTE(review): no call site for tqdm is visible in this part of the file.
try:
    from tqdm import tqdm
except ImportError:
    def tqdm(iterable, **kwargs):
        """Return *iterable* unchanged, ignoring any keyword arguments."""
        return iterable


def load_json(json_path: str) -> Dict:
    """Read a JSON file and return the decoded document.

    The file is always decoded as UTF-8 rather than with the platform's
    locale encoding, so result files containing non-ASCII text load the same
    way on every operating system.

    Args:
        json_path: Path of the file to read. Anything ``open`` accepts works,
            including ``pathlib.Path`` objects.

    Returns:
        The decoded JSON value. Callers in this file treat it as a dict.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        return json.load(f)


def load_all_pseudo_radius_files(directory: str = "pseudo_radius_results") -> Dict[str, List[Dict]]:
    """Load every pseudo-radius JSON file in *directory*, grouped by sigma.

    Files are selected by two name patterns (``pseudo_radius_sigma*.json`` and
    ``mnist_pseudo_true_radius_simple_*.json``). Within each pattern the paths
    are read in sorted order, so repeated runs pool the results in the same
    order regardless of how the file system enumerates the directory.

    Loading is best-effort: a file that cannot be read, is not valid JSON,
    lacks ``parameters.sigma``, or has a sigma that ``float()`` rejects is
    reported with a warning and skipped.

    Args:
        directory: Directory to search (not searched recursively).

    Returns:
        A plain dict mapping each sigma value, exactly as stored in the files,
        to the list of documents loaded for it. The key is whatever JSON type
        the file used (typically a number), despite the ``str`` annotation.
    """
    directory = Path(directory)
    patterns = ("pseudo_radius_sigma*.json", "mnist_pseudo_true_radius_simple_*.json")
    # glob() yields paths in arbitrary order; sort per pattern for determinism.
    files = [path for pattern in patterns for path in sorted(directory.glob(pattern))]

    results_by_sigma = defaultdict(list)

    for file in files:
        try:
            data = load_json(file)
            sigma = data.get('parameters', {}).get('sigma')
            if sigma is None:
                print(f"⚠ Warning: Skipping {file}: no 'parameters.sigma' entry")
                continue
            # The reporting code sorts sigma keys with key=float, so reject a
            # non-numeric sigma here instead of crashing later.
            float(sigma)
            results_by_sigma[sigma].append(data)
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON/encoding or
            # non-numeric sigma; TypeError/AttributeError: unexpected layout.
            print(f"⚠ Warning: Failed to load {file}: {e}")

    return dict(results_by_sigma)


def extract_radii(data: Dict) -> np.ndarray:
    """Collect the ``R_true_raw`` value of every entry in ``data['results']``.

    Entries that have no ``R_true_raw`` key, or whose value is ``None``, are
    left out. A missing ``results`` key is treated as an empty list.

    Args:
        data: One loaded result document.

    Returns:
        A NumPy array of the collected values, in their original order
        (empty when nothing was found).
    """
    entries = data.get('results', [])
    candidates = (entry.get('R_true_raw') for entry in entries)
    return np.array([value for value in candidates if value is not None])


def print_summary_table(results_by_sigma: Dict[str, List[Dict]]):
    """Print one row of summary statistics per sigma to stdout.

    For each sigma, in ascending numeric order, the radii of all its documents
    are pooled and the count, mean, median, standard deviation (NumPy's
    default, i.e. population std), minimum and maximum are printed. A sigma
    with no radii produces no row.

    Args:
        results_by_sigma: Mapping from sigma to the documents loaded for it.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("PSEUDO-TRUE RADIUS SUMMARY STATISTICS")
    print(rule)
    print(f"{'Sigma':<8} {'N':<6} {'Mean':<12} {'Median':<12} {'Std':<12} {'Min':<12} {'Max':<12}")
    print("-" * 80)

    for sigma in sorted(results_by_sigma, key=float):
        pooled = [
            radius
            for document in results_by_sigma[sigma]
            for radius in extract_radii(document)
        ]
        if not pooled:
            continue

        values = np.array(pooled)
        stats = (
            np.mean(values),
            np.median(values),
            np.std(values),
            np.min(values),
            np.max(values),
        )
        stat_columns = " ".join(f"{stat:<12.6f}" for stat in stats)
        print(f"{sigma:<8} {len(values):<6} " + stat_columns)

    print(rule)


def plot_distributions(results_by_sigma: Dict[str, List[Dict]], output_dir: str = "plots"):
    """Save a histogram of the pseudo-true radii for every sigma.

    One panel is drawn per sigma that has at least one radius, in ascending
    numeric sigma order, laid out in a grid of up to three columns (extra rows
    are added as needed, so no sigma is dropped). Each panel marks the mean
    and median. The figure is written to
    ``<output_dir>/pseudo_radius_distributions.png``.

    Args:
        results_by_sigma: Mapping from sigma to the documents loaded for it.
        output_dir: Directory for the image; created, including missing parent
            directories, if it does not exist.
    """
    output_dir = Path(output_dir)
    # parents=True so nested targets such as "runs/exp1/plots" also work.
    output_dir.mkdir(parents=True, exist_ok=True)

    if not results_by_sigma:
        print("⚠ No data to plot")
        return

    # Pool the radii first so that sigmas without data take no panel.
    panels = []
    for sigma, data_list in sorted(results_by_sigma.items(), key=lambda item: float(item[0])):
        pooled = [radius for data in data_list for radius in extract_radii(data)]
        if pooled:
            panels.append((sigma, np.array(pooled)))

    if not panels:
        print("⚠ No data to plot")
        return

    n_cols = min(len(panels), 3)
    n_rows = -(-len(panels) // n_cols)  # ceiling division
    # squeeze=False always returns a 2-D axes array, whatever the grid shape.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 4 * n_rows), squeeze=False)
    flat_axes = axes.ravel()

    try:
        for ax, (sigma, radii) in zip(flat_axes, panels):
            mean = np.mean(radii)
            median = np.median(radii)
            ax.hist(radii, bins=20, alpha=0.7, edgecolor='black')
            ax.set_xlabel('Pseudo-True Radius (raw pixel space)')
            ax.set_ylabel('Frequency')
            ax.set_title(f'σ = {sigma} (N={len(radii)})')
            ax.axvline(mean, color='red', linestyle='--', label=f'Mean: {mean:.3f}')
            ax.axvline(median, color='green', linestyle='--', label=f'Median: {median:.3f}')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # Hide the unused cells of a partially filled last row.
        for ax in flat_axes[len(panels):]:
            ax.set_visible(False)

        fig.tight_layout()
        output_file = output_dir / "pseudo_radius_distributions.png"
        fig.savefig(output_file, dpi=150, bbox_inches='tight')
        print(f"✓ Saved distribution plot: {output_file}")
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close(fig)


def compare_with_certified(
    results_by_sigma: Dict[str, List[Dict]],
    certified_files: Optional[List[str]] = None
):
    """Print a per-sigma table comparing pseudo-true and certified radii.

    Certified radii are read from each file's ``results`` entries, taking
    ``radius_with_gradient`` when present and ``radius_variance_mean``
    otherwise. Both sides are matched on the numeric value of sigma, so a
    sigma stored as ``0.25`` in one file and ``"0.25"`` in another still lines
    up. A row is printed only for sigmas that have data on both sides.

    The comparison is aggregate, not per sample: ``Ratio`` is pseudo mean over
    certified mean (NaN when the certified mean is not positive), and
    ``Pseudo > Cert`` counts the pseudo radii that exceed the certified *mean*.

    Args:
        results_by_sigma: Mapping from sigma to the pseudo-radius documents.
        certified_files: Paths of certified-radius JSON files. ``None`` makes
            the function return without printing anything.
    """
    if certified_files is None:
        return

    print("\n" + "="*80)
    print("COMPARISON: PSEUDO-TRUE RADIUS vs CERTIFIED RADIUS")
    print("="*80)

    # Pool the certified radii per numeric sigma value.
    certified_by_sigma = {}
    for file in certified_files:
        try:
            data = load_json(file)
            sigma = data.get('parameters', {}).get('sigma')
            if sigma is None:
                print(f"⚠ Warning: Skipping certified file {file}: no 'parameters.sigma' entry")
                continue
            sigma_value = float(sigma)

            certified_radii = []
            for result in data.get('results', []):
                # Prefer the with-gradient radius, fall back to variance/mean.
                r = result.get('radius_with_gradient')
                if r is None:
                    r = result.get('radius_variance_mean')
                if r is not None:
                    certified_radii.append(r)
        except (OSError, ValueError, TypeError, AttributeError) as e:
            # OSError: unreadable file; ValueError: invalid JSON/encoding or
            # non-numeric sigma; TypeError/AttributeError: unexpected layout.
            print(f"⚠ Warning: Failed to load certified file {file}: {e}")
            continue

        if certified_radii:
            certified_by_sigma.setdefault(sigma_value, []).extend(certified_radii)

    # Pool the pseudo radii the same way, remembering the original key of each
    # sigma so the table shows it as the caller supplied it.
    pseudo_by_sigma = {}
    sigma_labels = {}
    for sigma, data_list in results_by_sigma.items():
        sigma_value = float(sigma)
        sigma_labels.setdefault(sigma_value, sigma)
        pooled = pseudo_by_sigma.setdefault(sigma_value, [])
        for data in data_list:
            pooled.extend(extract_radii(data))

    print(f"{'Sigma':<8} {'N':<6} {'Pseudo Mean':<15} {'Certified Mean':<15} {'Ratio':<10} {'Pseudo > Cert':<12}")
    print("-"*80)

    for sigma_value in sorted(pseudo_by_sigma):
        pseudo_radii = np.array(pseudo_by_sigma[sigma_value])
        cert_radii = np.array(certified_by_sigma.get(sigma_value, []))

        if len(pseudo_radii) == 0 or len(cert_radii) == 0:
            continue

        pseudo_mean = np.mean(pseudo_radii)
        cert_mean = np.mean(cert_radii)
        ratio = pseudo_mean / cert_mean if cert_mean > 0 else np.nan

        # Approximate: each pseudo radius is compared with the certified mean.
        pseudo_gt_cert = int(np.sum(pseudo_radii > cert_mean))
        # Pad the whole "k/n" cell, not just the denominator.
        fraction = f"{pseudo_gt_cert}/{len(pseudo_radii)}"

        print(f"{sigma_labels[sigma_value]:<8} {len(pseudo_radii):<6} {pseudo_mean:<15.6f} {cert_mean:<15.6f} "
              f"{ratio:<10.2f} {fraction:<12}")

    print("="*80)


def main(argv: Optional[List[str]] = None):
    """Command-line entry point: load results, then report on them.

    Loads the pseudo-radius files from ``--directory``, lists what was found
    per sigma (in ascending numeric order, like the tables), prints the
    summary table, and optionally compares against ``--certified`` files and
    writes distribution plots (``--plot`` / ``--output_dir``).

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the normal script behavior; passing
            a list lets other code drive the CLI.
    """
    parser = argparse.ArgumentParser(description='Analyze pseudo-true radius results')
    parser.add_argument('--directory', type=str, default='pseudo_radius_results',
                        help='Directory containing pseudo-radius JSON files')
    parser.add_argument('--certified', type=str, nargs='+', default=None,
                        help='Certified radius JSON files for comparison')
    parser.add_argument('--plot', action='store_true',
                        help='Generate distribution plots')
    parser.add_argument('--output_dir', type=str, default='plots',
                        help='Output directory for plots')

    args = parser.parse_args(argv)

    print("Loading pseudo-radius results...")
    results_by_sigma = load_all_pseudo_radius_files(args.directory)

    if not results_by_sigma:
        print(f"⚠ No pseudo-radius files found in {args.directory}")
        return

    print(f"✓ Loaded results for {len(results_by_sigma)} sigma value(s)")
    for sigma in sorted(results_by_sigma, key=float):
        data_list = results_by_sigma[sigma]
        # Counts every result entry, including those without a usable radius.
        total_samples = sum(len(data.get('results', [])) for data in data_list)
        print(f"  σ = {sigma}: {len(data_list)} file(s), {total_samples} total samples")

    # Print summary
    print_summary_table(results_by_sigma)

    # Compare with certified if provided
    if args.certified:
        compare_with_certified(results_by_sigma, args.certified)

    # Plot distributions
    if args.plot:
        plot_distributions(results_by_sigma, args.output_dir)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()

