"""
Combined Privacy & Robustness Plots
====================================

Combines results from:
- results/privacy_experiment/{dataset}/attack_results.json
- results/adversarial_robustness/{dataset}/adversarial_robustness_results.json

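Creates one scatter figure per attack strength (one panel per dataset) plotting
MIA error (privacy) against best adversarial accuracy (robustness), with points
coloured by pool size. Figures are saved under results/combined_plots.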
"""

import os
import json
import numpy as np
import matplotlib.pyplot as plt


# Configuration -------------------------------------------------------------
# Dataset names; each one doubles as the sub-directory name under the two
# input roots below.
DATASETS = 'iris seeds wine'.split()

# Input roots (per-dataset JSON files live underneath) and the output
# directory. All are relative paths, resolved against the working directory.
PRIVACY_BASE = 'results/privacy_experiment'
ROBUSTNESS_BASE = 'results/adversarial_robustness'
OUTPUT_DIR = 'results/combined_plots'

# Attack strengths (eta) to plot, one figure each. They are looked up against
# float() of the 'epsilon_<value>' keys in the robustness JSON, so they must
# parse to exactly the same float.
ATTACK_EPSILONS = [0.1, 0.2, 0.5]


def load_privacy_data(dataset, base_dir=None):
    """Load and summarise Yeom membership-inference (MIA) results for a dataset.

    Reads ``<base_dir>/<dataset>/attack_results.json`` and aggregates the
    per-run values stored under ``data['yeom'][<pool size>]['accuracies']``.

    Args:
        dataset: Dataset name; used as the sub-directory under ``base_dir``.
        base_dir: Root directory of the privacy results. When ``None`` (the
            default) the module-level ``PRIVACY_BASE`` is used, looked up at
            call time.

    Returns:
        ``None`` if the file or its ``'yeom'`` section is missing. Otherwise a
        dict with:
          - ``'sizes'``: pool sizes as ints, in ascending numeric order;
          - ``'mia_accuracy'``: mean of the accuracies for each size;
          - ``'mia_std'``: standard deviation (numpy default, ddof=0) per size.
    """
    if base_dir is None:
        base_dir = PRIVACY_BASE
    path = os.path.join(base_dir, dataset, 'attack_results.json')
    if not os.path.exists(path):
        return None

    # JSON is UTF-8 by specification; don't depend on the locale's encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    if 'yeom' not in data:
        return None

    # Keys are stringified pool sizes. Sort numerically, and keep each entry
    # paired with its original key instead of re-deriving it via str(int(k)),
    # which breaks for keys such as "05".
    entries = sorted(data['yeom'].items(), key=lambda item: int(item[0]))

    sizes = []
    mia_accs = []
    mia_stds = []
    for key, entry in entries:
        accs = entry['accuracies']
        sizes.append(int(key))
        mia_accs.append(np.mean(accs))
        mia_stds.append(np.std(accs))

    return {
        'sizes': sizes,
        'mia_accuracy': mia_accs,
        'mia_std': mia_stds,
    }


def load_robustness_data(dataset, base_dir=None):
    """Load adversarial-robustness results for a dataset, keyed by attack strength.

    Reads ``<base_dir>/<dataset>/adversarial_robustness_results.json``. Each
    top-level entry whose value is a dict containing ``'results_by_pool_size'``
    is treated as one attack run. Its key is parsed as a float after removing
    ``'epsilon_'``. Other entries are skipped.

    Args:
        dataset: Dataset name; used as the sub-directory under ``base_dir``.
        base_dir: Root directory of the robustness results. When ``None`` (the
            default) the module-level ``ROBUSTNESS_BASE`` is used, looked up at
            call time.

    Returns:
        ``None`` if the file is missing. Otherwise a dict mapping epsilon
        (float) to ``{'sizes', 'accuracy', 'std'}`` lists, taken from each
        row's ``pool_size``, ``mean_best_accuracy`` and ``std_best_accuracy``
        and ordered by ascending pool size (matching ``load_privacy_data``).
    """
    if base_dir is None:
        base_dir = ROBUSTNESS_BASE
    path = os.path.join(base_dir, dataset, 'adversarial_robustness_results.json')
    if not os.path.exists(path):
        return None

    # JSON is UTF-8 by specification; don't depend on the locale's encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = {}
    for eps_key, eps_data in data.items():
        # Check for result rows *before* parsing the key, so non-attack
        # entries are skipped instead of crashing float().
        if not isinstance(eps_data, dict) or 'results_by_pool_size' not in eps_data:
            continue
        try:
            eps = float(eps_key.replace('epsilon_', ''))
        except ValueError:
            print(f"Warning: skipping unrecognised key {eps_key!r} in {path}")
            continue

        rows = sorted(eps_data['results_by_pool_size'], key=lambda r: r['pool_size'])
        results[eps] = {
            'sizes': [r['pool_size'] for r in rows],
            'accuracy': [r['mean_best_accuracy'] for r in rows],
            'std': [r['std_best_accuracy'] for r in rows],
        }

    return results


def _align_by_pool_size(privacy, robustness_eps):
    """Pair privacy and robustness measurements that share the same pool size.

    Args:
        privacy: Dict as returned by ``load_privacy_data`` (keys ``'sizes'``
            and ``'mia_accuracy'``).
        robustness_eps: One attack-strength entry from ``load_robustness_data``
            (keys ``'sizes'`` and ``'accuracy'``).

    Returns:
        Tuple ``(x_privacy, y_robustness, pool_sizes)`` of equal-length lists,
        where ``x_privacy`` is ``1 - MIA accuracy``. Pool sizes present in only
        one of the two sources are dropped.
    """
    mia_by_size = dict(zip(privacy['sizes'], privacy['mia_accuracy']))
    x_privacy, y_robustness, pool_sizes = [], [], []
    for size, accuracy in zip(robustness_eps['sizes'], robustness_eps['accuracy']):
        if size in mia_by_size:
            x_privacy.append(1 - mia_by_size[size])
            y_robustness.append(accuracy)
            pool_sizes.append(size)
    return x_privacy, y_robustness, pool_sizes


def plot_combined_3datasets():
    """Save one privacy-vs-robustness scatter figure per attack strength.

    For each eta in ``ATTACK_EPSILONS``, a figure with one panel per entry in
    ``DATASETS`` is written to
    ``OUTPUT_DIR/combined_privacy_robustness_eta_<eta>.png``. In every panel:
    x is the MIA error (1 - MIA accuracy), y is the best adversarial accuracy,
    and colour is the pool size.

    Privacy and robustness values are matched by pool size, not by list
    position. All panels of a figure share one colour range, so the single
    colorbar is valid for every panel. Panels without usable data show a
    placeholder message.
    """
    num_datasets = len(DATASETS)
    if num_datasets == 0:
        print("No datasets configured; nothing to plot.")
        return

    # The input files do not depend on the attack strength, so read each once.
    loaded = {
        dataset: (load_privacy_data(dataset), load_robustness_data(dataset))
        for dataset in DATASETS
    }
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    xlabel_idx = num_datasets // 2  # the centre panel carries the shared x-label

    for attack_eps in ATTACK_EPSILONS:
        # Align every panel first so a common colour range can be computed.
        panels = {}
        for dataset, (privacy, robustness) in loaded.items():
            if privacy is None or robustness is None or attack_eps not in robustness:
                panels[dataset] = None
            else:
                panels[dataset] = _align_by_pool_size(privacy, robustness[attack_eps])

        all_sizes = [
            size
            for panel in panels.values()
            if panel is not None
            for size in panel[2]
        ]
        vmin = min(all_sizes) if all_sizes else None
        vmax = max(all_sizes) if all_sizes else None

        # squeeze=False always yields a 2-D array, so indexing works for one panel too.
        fig, axes = plt.subplots(1, num_datasets, figsize=(6 * num_datasets, 5), squeeze=False)
        axes = axes[0]
        fig.suptitle(f'Robustness-Privacy Trade-off (η = {attack_eps})', fontsize=24, y=1.02)

        scatters = []
        for idx, dataset in enumerate(DATASETS):
            ax = axes[idx]
            ax.set_title(dataset, fontsize=20)

            panel = panels[dataset]
            if panel is None:
                ax.text(0.5, 0.5, f'No data for {dataset}',
                        ha='center', va='center', transform=ax.transAxes)
                continue
            x_privacy, y_robustness, pool_sizes = panel
            if not pool_sizes:
                ax.text(0.5, 0.5, f'No matching pool sizes for {dataset}',
                        ha='center', va='center', transform=ax.transAxes)
                continue

            # Explicit vmin/vmax: every panel maps pool size to colour identically.
            scatter = ax.scatter(x_privacy, y_robustness,
                                 c=pool_sizes, cmap='viridis',
                                 vmin=vmin, vmax=vmax,
                                 s=100, alpha=0.7)
            scatters.append(scatter)

            if idx == 0:
                ax.set_ylabel('Best Adversarial Accuracy', fontsize=18)
            if idx == xlabel_idx:
                ax.set_xlabel('MIA Error (1 - MIA Accuracy)', fontsize=18)
            ax.grid(True, alpha=0.3, linestyle='--')
            ax.tick_params(axis='both', which='major', labelsize=14)

        if scatters:
            fig.subplots_adjust(right=0.9)
            # Fixed figure-fraction rectangle [left, bottom, width, height]
            # placed to the right of the panels.
            cbar_ax = fig.add_axes([0.92, 0.15, 0.01, 0.7])
            # All scatters share vmin/vmax, so any one of them describes the scale.
            cbar = fig.colorbar(scatters[-1], cax=cbar_ax)
            cbar.set_label('Number of Models', fontsize=14)
            cbar.ax.tick_params(labelsize=12)

        output_path = os.path.join(OUTPUT_DIR, f'combined_privacy_robustness_eta_{attack_eps}.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"✓ Saved: {output_path}")
        plt.close(fig)


def main():
    """Print the run configuration, render all combined figures, then report completion."""
    rule = "=" * 80
    print(rule)
    print("Combined Privacy & Robustness Plot Generator")
    print(rule)

    summary = [
        f"\nDatasets: {', '.join(DATASETS)}",
        f"Attack strengths: {ATTACK_EPSILONS}",
        f"Privacy source: {PRIVACY_BASE}",
        f"Robustness source: {ROBUSTNESS_BASE}",
        f"Output directory: {OUTPUT_DIR}",
    ]
    for line in summary:
        print(line)

    plot_combined_3datasets()

    print("\n" + rule)
    print("Combined Plot Generation Complete!")
    print(rule)


if __name__ == "__main__":
    main()
