"""
Create comparison plots showing all datasets side-by-side for each attack strength.
"""
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Configuration: where the per-dataset results are read from, where the
# comparison figures are written, and which datasets / attack strengths
# are plotted.
RESULTS_BASE = Path('results/adversarial_robustness')
OUTPUT_DIR = RESULTS_BASE / 'comparison'
DATASETS = ['iris', 'seeds', 'wine']
ATTACK_EPSILONS = [0.1, 0.2, 0.5]

# Make sure the output directory exists before any figure is saved.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Load all results: one JSON document per dataset, stored under the
# dataset's name. The encoding is given explicitly so that decoding does
# not depend on the platform's locale default.
all_data = {}
for dataset in DATASETS:
    results_file = RESULTS_BASE / dataset / 'adversarial_robustness_results.json'
    with open(results_file, 'r', encoding='utf-8') as f:
        all_data[dataset] = json.load(f)

# Create one plot per attack strength: a single row of subplots, one per
# dataset, saved as one PNG per attack strength.
num_datasets = len(DATASETS)

# Only one subplot carries the x-axis label. Derive its index from the
# number of datasets (the middle panel) instead of hard-coding it, so the
# label is still drawn when DATASETS has fewer than two entries.
xlabel_idx = num_datasets // 2

for attack_eps in ATTACK_EPSILONS:
    # squeeze=False always returns a 2-D array of axes (here 1 x N), so the
    # single-dataset case needs no special handling.
    fig, axes = plt.subplots(1, num_datasets, figsize=(6 * num_datasets, 5),
                             squeeze=False)
    fig.suptitle(f'Adversarial Robustness vs Number of Models (η={attack_eps})',
                 fontsize=24)

    # Key under which each dataset's results for this attack strength are stored.
    key = f'epsilon_{attack_eps}'

    for idx, (ax, dataset) in enumerate(zip(axes[0], DATASETS)):
        # Get data for this attack strength
        data = all_data[dataset][key]
        results = data['results_by_pool_size']

        pool_sizes = [r['pool_size'] for r in results]
        # Converted once so the band edges below are plain array arithmetic.
        # NOTE(review): assumes the mean/std entries are numeric -- confirm
        # against the code that writes the results files.
        mean_accs = np.array([r['mean_best_accuracy'] for r in results])
        std_accs = np.array([r['std_best_accuracy'] for r in results])

        # Mean curve with a shaded band of +/- one standard deviation
        ax.plot(pool_sizes, mean_accs, marker='o', linewidth=2.5, markersize=8,
                color='#ff8c00', label='Best Model Accuracy')
        ax.fill_between(pool_sizes,
                        mean_accs - std_accs,
                        mean_accs + std_accs,
                        alpha=0.3, color='#ff8c00')

        # Reference line
        ref_adv = data['reference_adversarial_accuracy']
        ax.axhline(y=ref_adv, color='red', linestyle='--', linewidth=1.5,
                   alpha=0.7, label=f'Reference Adv: {ref_adv:.3f}')

        # Formatting
        if idx == xlabel_idx:  # Only show xlabel on the middle subplot
            ax.set_xlabel('Number of Models', fontsize=18)
        if idx == 0:  # Only show ylabel on first subplot
            ax.set_ylabel('Best Adversarial Accuracy', fontsize=18)
        ax.set_title(f'{dataset}', fontsize=20)
        ax.grid(True, alpha=0.3, linewidth=1)
        ax.legend(fontsize=14, loc='lower right')
        ax.tick_params(axis='both', which='major', labelsize=14)

        # Set x-axis to start at 0
        ax.set_xlim(left=0)

    fig.tight_layout()
    output_path = OUTPUT_DIR / f'comparison_eta_{attack_eps}.png'
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    print(f'✓ Saved: {output_path}')
    # Close this specific figure so figures do not accumulate across iterations.
    plt.close(fig)

# Final summary: confirm completion and report the output directory.
print('\nAll comparison plots generated!', f'Location: {OUTPUT_DIR}', sep='\n')
