#!/usr/bin/env python3
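"""
CIFAR optimization results analyzer.

Reads per-optimizer training histories from results/training_histories.json
and prints per-run details, per-optimizer averages, and cross-optimizer
rankings (accuracy, speed, convergence, efficiency).
"""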

import json

def _first_epoch_reaching(values, threshold):
    """Return the 1-based epoch at which ``values`` first reaches ``threshold``.

    Returns None if no value is >= ``threshold``.
    """
    for epoch, value in enumerate(values, start=1):
        if value >= threshold:
            return epoch
    return None


def _mean(values):
    """Return the arithmetic mean of ``values``, or None if it is empty."""
    return sum(values) / len(values) if values else None


def _analyze_optimizer(optimizer_name, runs):
    """Print per-run details for one optimizer and return its aggregate stats.

    ``runs`` must be non-empty. The returned dict has keys ``avg_best``,
    ``max_best``, ``avg_time``, ``avg_epochs_90``, ``avg_epochs_75``,
    ``success_rate_75`` and ``avg_overfitting``. The two ``avg_epochs_*``
    values are None when no run reached the corresponding threshold.
    """
    print(f"\n{optimizer_name} ({len(runs)} runs)")
    print("-" * 40)

    best_test_accs = []
    train_times_hours = []
    epochs_to_90 = []  # epochs at which train accuracy first hit 90%
    epochs_to_75 = []  # epochs at which test accuracy first hit 75%
    overfitting = []

    for run_number, run in enumerate(runs, start=1):
        test_accs = run['test_accuracies']
        train_accs = run['train_accuracies']
        best_test = max(test_accs)
        final_test = test_accs[-1]
        best_train = max(train_accs)
        final_loss = run['train_losses'][-1]
        # total_time is recorded in seconds; a missing value counts as 0.
        time_hours = run.get('total_time', 0) / 3600

        epoch_90 = _first_epoch_reaching(train_accs, 90.0)
        epoch_75_test = _first_epoch_reaching(test_accs, 75.0)

        best_test_accs.append(best_test)
        train_times_hours.append(time_hours)
        if epoch_90 is not None:
            epochs_to_90.append(epoch_90)
        if epoch_75_test is not None:
            epochs_to_75.append(epoch_75_test)
        # Gap between peak train and peak test accuracy, used as an overfitting proxy.
        overfitting.append(best_train - best_test)

        # Epochs are 1-based, so `or 'N/A'` only substitutes for None.
        print(f"  Run {run_number}: Test={best_test:.2f}% | Final={final_test:.2f}% | "
              f"Loss={final_loss:.4f} | Time={time_hours:.1f}h | "
              f"Epoch90={epoch_90 or 'N/A'} | Epoch75%={epoch_75_test or 'N/A'}")

    avg_best_test = _mean(best_test_accs)
    max_best_test = max(best_test_accs)
    avg_time = _mean(train_times_hours)
    avg_epochs_90 = _mean(epochs_to_90)
    avg_epochs_75 = _mean(epochs_to_75)
    avg_overfitting = _mean(overfitting)

    # Percentage of runs that reached each threshold at some epoch.
    success_rate_90 = len(epochs_to_90) / len(runs) * 100
    success_rate_75 = len(epochs_to_75) / len(runs) * 100

    print(f"  → Average Best: {avg_best_test:.2f}%")
    print(f"  → Best Single: {max_best_test:.2f}%")
    print(f"  → Avg Time: {avg_time:.1f}h")
    if avg_epochs_90 is not None:
        print(f"  → Avg Epochs to 90% train: {avg_epochs_90:.1f} ({success_rate_90:.0f}% success)")
    else:
        print("  → Epochs to 90% train: Never reached (0% success)")
    if avg_epochs_75 is not None:
        print(f"  → Avg Epochs to 75% test: {avg_epochs_75:.1f} ({success_rate_75:.0f}% success)")
    else:
        print("  → Epochs to 75% test: Never reached (0% success)")
    print(f"  → Overfitting: {avg_overfitting:.2f}%")

    return {
        'avg_best': avg_best_test,
        'max_best': max_best_test,
        'avg_time': avg_time,
        'avg_epochs_90': avg_epochs_90,
        'avg_epochs_75': avg_epochs_75,
        'success_rate_75': success_rate_75,
        'avg_overfitting': avg_overfitting,
    }


def _print_rankings(summary):
    """Print accuracy, speed, convergence, best-run and efficiency rankings.

    ``summary`` maps optimizer name to the dict returned by
    ``_analyze_optimizer`` and must be non-empty. Returns the accuracy
    ranking as a list of ``(name, stats)`` pairs, best first.
    """
    print("\n" + "=" * 60)
    print("FINAL RANKINGS")
    print("=" * 60)

    accuracy_ranking = sorted(summary.items(), key=lambda item: item[1]['avg_best'], reverse=True)
    print("\n BY AVERAGE BEST TEST ACCURACY:")
    for rank, (name, stats) in enumerate(accuracy_ranking, 1):
        print(f"{rank}. {name}: {stats['avg_best']:.2f}%")

    speed_ranking = sorted(summary.items(), key=lambda item: item[1]['avg_time'])
    print("\n BY TRAINING SPEED (fastest first):")
    for rank, (name, stats) in enumerate(speed_ranking, 1):
        print(f"{rank}. {name}: {stats['avg_time']:.1f} hours")

    convergence_75_ranking = sorted(
        [(name, stats) for name, stats in summary.items() if stats['avg_epochs_75'] is not None],
        key=lambda item: item[1]['avg_epochs_75'],
    )
    print("\n BY CONVERGENCE TO 75% TEST ACCURACY (fastest first):")
    for rank, (name, stats) in enumerate(convergence_75_ranking, 1):
        print(f"{rank}. {name}: {stats['avg_epochs_75']:.1f} epochs ({stats['success_rate_75']:.0f}% success)")

    never_reached_75 = [name for name, stats in summary.items() if stats['avg_epochs_75'] is None]
    if never_reached_75:
        print(f"\nOptimizers that never reached 75% test accuracy: {', '.join(never_reached_75)}")

    best_name, best_stats = max(summary.items(), key=lambda item: item[1]['max_best'])
    print(f"\n BEST SINGLE RUN: {best_name} with {best_stats['max_best']:.2f}%")

    # Efficiency is undefined when no time was recorded (avg_time == 0, e.g.
    # every run lacked 'total_time'); list those optimizers separately instead
    # of dividing by zero.
    timed = [(name, stats['avg_best'] / stats['avg_time'])
             for name, stats in summary.items() if stats['avg_time'] > 0]
    untimed = [name for name, stats in summary.items() if stats['avg_time'] <= 0]
    efficiency_ranking = sorted(timed, key=lambda item: item[1], reverse=True)
    print("\n💡 BY EFFICIENCY (Accuracy/Time):")
    for rank, (name, efficiency) in enumerate(efficiency_ranking, 1):
        print(f"{rank}. {name}: {efficiency:.2f} acc%/hour")
    if untimed:
        print(f"Optimizers without recorded training time: {', '.join(untimed)}")

    return accuracy_ranking


def _print_summary_table(accuracy_ranking):
    """Print a fixed-width convergence table, one row per optimizer, in the given order."""
    print("\n" + "=" * 60)
    print("CONVERGENCE SUMMARY TABLE")
    print("=" * 60)
    header = (f"{'Optimizer':<12} {'Avg Best%':<10} {'Epochs to 75%':<13} "
              f"{'Success Rate':<12} {'Avg Time (h)':<12}")
    print(header)
    print("-" * len(header))  # match the separator to the actual header width

    for name, stats in accuracy_ranking:
        if stats['avg_epochs_75'] is not None:
            epochs_75_str = f"{stats['avg_epochs_75']:.1f}"
            success_rate_str = f"{stats['success_rate_75']:.0f}%"
        else:
            epochs_75_str = "Never"
            success_rate_str = "0%"
        print(f"{name:<12} {stats['avg_best']:<10.2f} {epochs_75_str:<13} "
              f"{success_rate_str:<12} {stats['avg_time']:<12.1f}")


def analyze_results(file_path="results/training_histories.json"):
    """Load training histories and print per-optimizer stats and rankings.

    The JSON file maps each optimizer name to a list of runs. Each run is a
    dict with per-epoch lists ``test_accuracies``, ``train_accuracies`` and
    ``train_losses`` (accuracies presumably in percent, given the 90.0/75.0
    thresholds), plus an optional ``total_time`` in seconds. Optimizers with
    no runs are skipped.

    Args:
        file_path: Path to the JSON file of training histories.

    Returns:
        None. All output goes to stdout. If the file is missing or cannot be
        decoded as JSON, an error message is printed and nothing else happens.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        return
    except (json.JSONDecodeError, UnicodeDecodeError):
        print("Error reading JSON file")
        return

    print("CIFAR OPTIMIZATION RESULTS ANALYSIS")
    print("=" * 60)

    summary = {}
    for optimizer_name, runs in data.items():
        if not runs:  # skip optimizers with no recorded runs
            continue
        summary[optimizer_name] = _analyze_optimizer(optimizer_name, runs)

    if summary:
        accuracy_ranking = _print_rankings(summary)
        _print_summary_table(accuracy_ranking)
    else:
        # Rankings need at least one optimizer (max() of an empty sequence raises).
        print("\nNo optimizer runs found; nothing to rank.")

    total_runs = sum(len(runs) for runs in data.values() if runs)
    print(f"\nAnalysis complete! Data from {total_runs} total runs.")

if __name__ == "__main__":
    # Script entry point: run the analysis only when executed directly, not on import.
    analyze_results()