"""
Privacy Comparison Plots - 3 Datasets Side-by-Side
===================================================

Creates a single plot with iris, seeds, and wine privacy results side-by-side.

Input: results/privacy_experiment/{dataset}/privacy_results.json
Output: results/privacy_experiment/privacy_comparison_3datasets.png
"""

import os
import json
import numpy as np
import matplotlib.pyplot as plt


# Configuration: datasets in left-to-right panel order, plus input/output locations.
DATASETS = ['iris', 'seeds', 'wine']
RESULTS_BASE = 'results/privacy_experiment'
OUTPUT_PATH = f'{RESULTS_BASE}/privacy_comparison_3datasets.png'

# Plot styling: colours for the MIA error curve/band and the 0.5 reference line.
COLORS = dict(
    mia='#ff8c00',       # orange: MIA error curve and its shaded spread band
    baseline='#95a5a6',  # gray: random-guessing reference line at 0.5
)


def load_privacy_results(dataset, results_base=None, filename='attack_results.json'):
    """Load the saved attack results for one dataset.

    Reads ``<results_base>/<dataset>/<filename>`` as UTF-8 encoded JSON.

    Args:
        dataset: Dataset name; used as the sub-directory under ``results_base``.
        results_base: Root results directory. When ``None`` (the default), the
            module-level ``RESULTS_BASE`` is used.
        filename: Name of the JSON file inside the dataset directory.

    Returns:
        The decoded JSON object, or ``None`` (after printing a warning) when the
        file does not exist or does not contain valid JSON. Returning ``None``
        lets the caller draw a placeholder for this dataset instead of aborting
        the whole figure.
    """
    base = RESULTS_BASE if results_base is None else results_base
    results_path = os.path.join(base, dataset, filename)

    if not os.path.exists(results_path):
        print(f"Warning: No results found for {dataset} at {results_path}")
        return None

    # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
    with open(results_path, 'r', encoding='utf-8') as f:
        try:
            return json.load(f)
        except json.JSONDecodeError as err:
            # Treat a corrupt file like a missing one: warn and let the caller skip it.
            print(f"Warning: Could not parse results for {dataset} at {results_path}: {err}")
            return None


def _extract_yeom_series(yeom_results):
    """Summarise per-ensemble-size attack accuracies from a ``yeom`` mapping.

    Args:
        yeom_results: Mapping whose keys are ensemble sizes stored as strings
            (JSON object keys) and whose values hold an ``'accuracies'`` sequence.

    Returns:
        ``(sizes, mean_accs, std_accs)`` as lists sorted by numeric size. Each
        statistic is computed from the same item its size came from, so the
        three lists always have equal length, even for keys such as ``"05"``.
    """
    items = sorted(yeom_results.items(), key=lambda item: int(item[0]))
    sizes = [int(key) for key, _ in items]
    mean_accs = [np.mean(entry['accuracies']) for _, entry in items]
    std_accs = [np.std(entry['accuracies']) for _, entry in items]
    return sizes, mean_accs, std_accs


def _show_placeholder(ax, dataset, message):
    """Draw a centred text message in place of a dataset's curve."""
    ax.text(0.5, 0.5, message,
            ha='center', va='center', transform=ax.transAxes)
    ax.set_title(dataset.capitalize())


def _plot_mia_error(ax, dataset, sizes, mean_accs, std_accs):
    """Plot MIA error (1 - attack accuracy) with a +/- 1 std band on ``ax``."""
    baseline = 0.5  # random-guessing attack accuracy/error

    # Convert to MIA Error (1 - accuracy) - higher is better privacy.
    # The std of (1 - acc) equals the std of acc, so the band reuses std_accs.
    mean_errors = 1.0 - np.asarray(mean_accs, dtype=float)
    spread = np.asarray(std_accs, dtype=float)
    lower = mean_errors - spread
    upper = mean_errors + spread

    ax.plot(sizes, mean_errors,
            color=COLORS['mia'], linewidth=2, label='MIA Error')
    ax.fill_between(sizes, lower, upper, color=COLORS['mia'], alpha=0.2)

    ax.axhline(y=baseline, color=COLORS['baseline'], linestyle='--',
               linewidth=1.5, label='Random Guessing', alpha=0.7)

    ax.set_title(dataset, fontsize=20)
    ax.grid(True, alpha=0.3, linestyle='--')
    ax.legend(fontsize=14, loc='best')
    ax.tick_params(axis='both', which='major', labelsize=14)

    # Set x-axis to start at 0
    ax.set_xlim(left=0)

    # Y-limits must cover the whole shaded band AND the baseline, otherwise the
    # reference line can be clipped. Including the baseline in the min/max also
    # keeps the limits finite if some statistics are NaN.
    y_min = np.nanmin(np.append(lower, baseline)) - 0.05
    y_max = np.nanmax(np.append(upper, baseline)) + 0.05
    ax.set_ylim(max(0.0, y_min), min(1.0, y_max))


def plot_privacy_comparison():
    """Create a side-by-side MIA-error comparison figure for ``DATASETS``.

    Draws one panel per dataset, in ``DATASETS`` order. Each panel shows the
    mean MIA error (1 - attack accuracy) against ensemble size, read from the
    ``'yeom'`` entry of that dataset's results file. A dataset whose results
    are missing, lack a ``'yeom'`` entry, or contain no ensemble sizes gets a
    text placeholder instead. The figure is saved to ``OUTPUT_PATH`` at 300 dpi
    and always closed, even if saving fails.
    """
    num_datasets = len(DATASETS)
    if num_datasets == 0:
        print("Warning: DATASETS is empty; nothing to plot.")
        return

    # squeeze=False always returns a 2-D array of Axes, so one dataset needs no special case.
    fig, axes = plt.subplots(1, num_datasets, figsize=(6 * num_datasets, 5), squeeze=False)
    try:
        fig.suptitle('Membership Inference Attack Error vs Ensemble Size', fontsize=24, y=1.02)

        for idx, (ax, dataset) in enumerate(zip(axes[0], DATASETS)):
            # Shared axis labels are set before the data checks so they survive
            # even when this panel ends up as a placeholder.
            if idx == num_datasets // 2:  # xlabel only on the middle subplot
                ax.set_xlabel('Number of Models', fontsize=18)
            if idx == 0:  # ylabel only on the first subplot
                ax.set_ylabel('MIA Error', fontsize=18)

            results = load_privacy_results(dataset)
            if results is None:
                _show_placeholder(ax, dataset, f'No data for {dataset}')
                continue

            # Extract data - results has "yeom" key at top level
            if 'yeom' not in results:
                _show_placeholder(ax, dataset, f'No yeom data for {dataset}')
                continue

            sizes, mean_accs, std_accs = _extract_yeom_series(results['yeom'])
            if not mean_accs:
                _show_placeholder(ax, dataset, f'No valid data for {dataset}')
                continue

            _plot_mia_error(ax, dataset, sizes, mean_accs, std_accs)

        fig.tight_layout()

        # Save plot; dirname is '' when OUTPUT_PATH is a bare filename.
        output_dir = os.path.dirname(OUTPUT_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        fig.savefig(OUTPUT_PATH, dpi=300, bbox_inches='tight')
        print(f"\n✓ Privacy comparison plot saved to: {OUTPUT_PATH}")
    finally:
        plt.close(fig)


def main():
    """Print a run header, generate the comparison figure, then print a footer."""
    rule = "=" * 80
    print(rule, "Privacy Comparison Plot Generator (3 Datasets)", rule, sep="\n")

    dataset_list = ', '.join(DATASETS)
    print(
        f"\nDatasets: {dataset_list}",
        f"Results directory: {RESULTS_BASE}",
        f"Output file: {OUTPUT_PATH}",
        sep="\n",
    )

    plot_privacy_comparison()

    print("\n" + rule, "Privacy Comparison Plot Generation Complete!", rule, sep="\n")


if __name__ == "__main__":
    main()
