"""
Privacy leakage experiment across ALL datasets:
Tests membership inference attacks on Rashomon ensembles for multiple datasets.
"""
import os
import sys
import json
import torch.nn as nn
from pathlib import Path

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from awp import generate_rashomon_set, MLPBinary2Logits, TrainConfig
from dataset import read_dataset
from membership_inference import (
    membership_inference_experiment,
    save_attack_results,
    plot_attack_results,
)

# ========== CONFIGURATION ==========
# Dataset names to process; each one is read from datasets/<name>.
DATASETS = ['iris', 'seeds', 'wine', 'compas']  # Updated to match current experiments
# Number of models requested from generate_rashomon_set and made available
# to the membership inference experiment.
NUM_RASHOMON_MODELS = 100
EPSILON = 0.30  # Collaborator's relative epsilon
RELATIVE_EPSILON = True  # Use multiplicative epsilon
OPT_NUM_ATTEMPTS = 30  # Random restarts for base model
# Ensemble sizes evaluated by the attack; the largest size here is 50.
ENSEMBLE_SIZES = list(range(1, 51))  # Evaluate all pool sizes 1-50
NUM_TRIALS_PER_SIZE = 20  # Increased trials for smoother curves
# Passed as num_shadow_models to membership_inference_experiment.
NUM_SHADOW_MODELS = 10
# Seed forwarded to both Rashomon-set generation and the attack experiment.
SEED = 42

# Model architecture (collaborator's approach: smaller networks)
# NOTE(review): run_experiment_for_dataset selects hidden size and depth per
# dataset; MODEL_HIDDEN is not referenced anywhere in this file's visible code.
MODEL_HIDDEN = 16
MODEL_DEPTH = 2
MODEL_DROPOUT = 0.0

# Training config (collaborator's approach)
# These three values populate the TrainConfig given to generate_rashomon_set.
BASE_EPOCHS = 80
BASE_LR = 1e-3
BASE_BATCH_SIZE = 16

# Rashomon generation config (collaborator's approach)
# Forwarded unchanged as keyword arguments to generate_rashomon_set.
ASCENT_LR = 5e-4
MAX_STEPS = 1000
EVAL_EVERY = 20
SHUFFLE = True
SHUFFLE_SEED = 3

# Output base directory; each dataset gets its own sub-directory beneath it.
RESULTS_BASE_DIR = './results/privacy_experiment'

# ===================================

def run_experiment_for_dataset(dataset_name):
    """Run the privacy experiment for a single dataset.

    The function loads the dataset, generates a Rashomon set (or reuses one
    already saved on disk), writes ``rashomon_config.json`` and then runs the
    membership inference attack for every size in ``ENSEMBLE_SIZES``, saving,
    plotting and printing the outcome.

    Args:
        dataset_name: Name of the dataset; it is read from
            ``datasets/<dataset_name>``.

    Returns:
        The object returned by ``membership_inference_experiment`` (indexed
        here by ensemble size), or ``None`` when dataset loading, Rashomon
        generation or the attack stage raised an exception.
    """
    # Local import: torch is needed both for reading a cached checkpoint and
    # for picking the device used by the attack.
    import torch

    print("\n" + "="*80)
    print(f"DATASET: {dataset_name.upper()}")
    print("="*80)

    # Output directories
    rashomon_dir = f'./saved_models/{dataset_name}_rashomon_privacy'
    results_dir = f'{RESULTS_BASE_DIR}/{dataset_name}'
    os.makedirs(results_dir, exist_ok=True)

    # Load dataset (Y1 is returned by read_dataset but not used here)
    print(f"\n[1/3] Loading dataset {dataset_name}...")
    try:
        X, Y0, Y1 = read_dataset(f'datasets/{dataset_name}')
        n, d = X.shape
        print(f"  Loaded: {n} samples, {d} features")
    except Exception as e:
        print(f"  ERROR loading dataset: {e}")
        return None

    # Per-dataset architecture as (hidden width, depth); unknown dataset
    # names fall back to the smallest configuration.
    architectures = {
        'iris': (16, 2),
        'seeds': (25, 3),
        'wine': (25, 3),
        'compas': (20, 4),
    }
    hidden_size, model_depth = architectures.get(dataset_name, (16, 2))

    def ctor() -> nn.Module:
        # Builds an empty network with the same architecture that is passed
        # to generate_rashomon_set below.
        return MLPBinary2Logits(
            d=d,
            hidden=hidden_size,
            depth=model_depth,
            dropout=MODEL_DROPOUT
        )

    # Check if Rashomon set already exists
    base_model_path = os.path.join(rashomon_dir, 'rashomon_base.pt')

    if os.path.exists(base_model_path):
        print(f"\n[2/3] Found existing Rashomon set at {rashomon_dir}")
        print("  Skipping generation. Delete the directory to regenerate.")

        # Load base loss from saved config. map_location lets a checkpoint
        # written on a GPU machine be read on a CPU-only one.
        base_data = torch.load(base_model_path, map_location="cpu")
        base_loss = base_data['base_loss']
    else:
        print(f"\n[2/3] Generating Rashomon set ({NUM_RASHOMON_MODELS} models)...")
        print(f"  This may take 15-30 minutes...")

        try:
            summary = generate_rashomon_set(
                X=X,
                y=Y0,
                epsilon=EPSILON,
                num_models=NUM_RASHOMON_MODELS,
                save_dir=rashomon_dir,
                model_hidden=hidden_size,
                # Use the per-dataset depth so the generated models match the
                # architecture produced by ctor(); the global MODEL_DEPTH was
                # previously passed here and disagreed for depth != 2.
                model_depth=model_depth,
                dropout=MODEL_DROPOUT,
                base_train_cfg=TrainConfig(
                    epochs=BASE_EPOCHS,
                    lr=BASE_LR,
                    batch_size=BASE_BATCH_SIZE
                ),
                ascent_lr=ASCENT_LR,
                max_steps=MAX_STEPS,
                eval_every=EVAL_EVERY,
                opt_num_attempts=OPT_NUM_ATTEMPTS,
                shuffle=SHUFFLE,
                shuffle_seed=SHUFFLE_SEED,
                relative_epsilon=RELATIVE_EPSILON,
                diversity_strategy="random_point_class",
                seed=SEED,
            )

            print(f"\n  Successfully generated {summary['num_models']} models")
            print(f"  Base loss: {summary['base_loss']:.4f}")
            base_loss = summary['base_loss']

        except Exception as e:
            print(f"  ERROR generating Rashomon set: {e}")
            return None

    # Save configuration (records the architecture actually used)
    config = {
        "dataset": dataset_name,
        "num_models": NUM_RASHOMON_MODELS,
        "epsilon": EPSILON,
        "base_loss": float(base_loss),
        "model_hidden": hidden_size,
        "model_depth": model_depth,
        "n_samples": int(n),
        "n_features": int(d),
    }

    with open(os.path.join(results_dir, 'rashomon_config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # Run membership inference attacks
    print(f"\n[3/3] Running membership inference attacks...")
    print(f"  Testing ensemble sizes: {ENSEMBLE_SIZES}")
    print(f"  Attack method: Simple Yeom (loss threshold)")

    # Prefer the GPU but fall back to CPU so the attack stage does not fail
    # outright on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    try:
        results = membership_inference_experiment(
            rashomon_dir=rashomon_dir,
            num_rashomon_models=NUM_RASHOMON_MODELS,
            ensemble_sizes=ENSEMBLE_SIZES,
            model_ctor=ctor,
            X=X,
            y=Y0,
            num_shadow_models=NUM_SHADOW_MODELS,
            num_trials=NUM_TRIALS_PER_SIZE,
            test_size=0.5,
            device=device,
            seed=SEED,
            attack_methods=["yeom"],  # Only Yeom attack
            dataset=dataset_name,  # Pass dataset name for tuning
        )

        # Save results
        results_path = os.path.join(results_dir, 'attack_results.json')
        save_attack_results(results, results_path)

        # Plot results
        plot_path = os.path.join(results_dir, 'attack_results.png')
        plot_attack_results(results, plot_path)

        # Print summary
        print(f"\n{'='*80}")
        print(f"RESULTS SUMMARY - {dataset_name.upper()}")
        print(f"{'='*80}")
        print(f"{'Size':<8} {'Ens Acc':<15} {'Attack Acc':<20} {'Attack Adv':<20} {'Risk':<10}")
        print("-"*80)

        for size in ENSEMBLE_SIZES:
            ens_acc = results[size].get('mean_ensemble_test_acc', 0.0)
            acc_mean = results[size]['mean_accuracy']
            acc_std = results[size]['std_accuracy']
            adv_mean = results[size]['mean_advantage']
            adv_std = results[size]['std_advantage']

            # Coarse risk label derived from the mean attack advantage.
            if adv_mean < 0.05:
                risk = "Low"
            elif adv_mean < 0.15:
                risk = "Medium"
            else:
                risk = "High"

            print(f"{size:<8} {ens_acc:.4f}         {acc_mean:.4f} ± {acc_std:.4f}    {adv_mean:.4f} ± {adv_std:.4f}    {risk:<10}")

        return results

    except Exception as e:
        print(f"  ERROR running attacks: {e}")
        import traceback
        traceback.print_exc()
        return None


def main():
    """Run the privacy experiment on all datasets and print a summary.

    Every dataset in ``DATASETS`` is processed with
    ``run_experiment_for_dataset``; datasets that fail are reported and
    skipped. Afterwards the mean attack advantage at the smallest and the
    largest evaluated ensemble size is compared for each successful dataset.
    """

    print("="*80)
    print("PRIVACY LEAKAGE EXPERIMENT - ALL DATASETS")
    print("="*80)
    print(f"Datasets: {DATASETS}")
    print(f"Rashomon models per dataset: {NUM_RASHOMON_MODELS}")
    print(f"Ensemble sizes: {ENSEMBLE_SIZES}")
    print(f"Trials per size: {NUM_TRIALS_PER_SIZE}")
    print("="*80)

    all_results = {}

    for dataset in DATASETS:
        try:
            results = run_experiment_for_dataset(dataset)
            if results is not None:
                all_results[dataset] = results
        except Exception as e:
            print(f"\nFATAL ERROR with dataset {dataset}: {e}")
            continue

    # Cross-dataset summary
    print("\n" + "="*80)
    print("CROSS-DATASET SUMMARY")
    print("="*80)

    if ENSEMBLE_SIZES:
        # Compare the endpoints that were actually evaluated. The summary
        # previously looked up a fixed size of 100, which is not part of
        # ENSEMBLE_SIZES and raised KeyError after all experiments finished.
        smallest = min(ENSEMBLE_SIZES)
        largest = max(ENSEMBLE_SIZES)
        print(f"\nPrivacy leakage trend (Advantage: size {smallest} → size {largest}):\n")

        for dataset in DATASETS:
            results = all_results.get(dataset)
            if results is None:
                continue

            adv_small = results[smallest]['mean_advantage']
            adv_large = results[largest]['mean_advantage']
            # Relative change in percent; the floor on the denominator avoids
            # dividing by (near) zero advantage.
            change = ((adv_large - adv_small) / max(adv_small, 0.001)) * 100

            if change > 20:
                trend = "↑ INCREASE"
            elif change < -20:
                trend = "↓ DECREASE"
            else:
                trend = "→ STABLE"
            print(f"{dataset:20s}: {adv_small:.4f} → {adv_large:.4f}  ({change:+6.1f}%)  {trend}")

    print("\n" + "="*80)
    print("EXPERIMENT COMPLETE")
    print("="*80)
    print(f"\nResults saved to: {RESULTS_BASE_DIR}/")
    print("\nPer-dataset directories contain:")
    print("  - attack_results.json (raw data)")
    print("  - attack_results.png (visualization)")
    print("  - rashomon_config.json (configuration)")


# Run the full experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
