"""
Generate Rashomon sets for all datasets before running experiments.
This allows both adversarial and privacy experiments to run in parallel.
"""
import os
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from awp import generate_rashomon_set, TrainConfig
from dataset import read_dataset

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Datasets to process, in order; each one is read from ``datasets/<name>``.
DATASETS = [
    'iris',
    'seeds',
    'wine',
    'compas',
]

# Rashomon-set generation settings (forwarded to generate_rashomon_set).
NUM_RASHOMON_MODELS = 100   # models requested per dataset
EPSILON = 0.3               # Rashomon tolerance, interpreted per RELATIVE_EPSILON
RELATIVE_EPSILON = True     # NOTE(review): presumably "EPSILON is relative to base loss" -- confirm in awp
OPT_NUM_ATTEMPTS = 30
SEED = 42

# Base-model training settings (wrapped in a TrainConfig).
BASE_EPOCHS = 80
BASE_LR = 0.001
BASE_BATCH_SIZE = 16

# Remaining generate_rashomon_set keyword arguments.
ASCENT_LR = 0.0005
MAX_STEPS = 1_000
EVAL_EVERY = 20
SHUFFLE = True
SHUFFLE_SEED = 3

def generate_for_dataset(dataset_name):
    """Generate Rashomon set for a single dataset."""
    print("\n" + "="*80)
    print(f"DATASET: {dataset_name.upper()}")
    print("="*80)
    
    rashomon_dir = f'./saved_models/{dataset_name}_rashomon_privacy'
    
    # Check if already exists
    if os.path.exists(os.path.join(rashomon_dir, 'rashomon_base.pt')):
        print(f"Rashomon set already exists at {rashomon_dir}")
        print("Skipping generation.")
        return
    
    # Load dataset
    print(f"\nLoading dataset {dataset_name}...")
    try:
        X, Y0, Y1 = read_dataset(f'datasets/{dataset_name}')
        n, d = X.shape
        print(f"Loaded: {n} samples, {d} features")
    except Exception as e:
        print(f"ERROR loading dataset: {e}")
        return
    
    # Set model parameters per dataset
    if dataset_name == 'iris':
        hidden_size = 16
        model_depth = 2
    elif dataset_name in ['seeds', 'wine']:
        hidden_size = 25
        model_depth = 3
    elif dataset_name == 'compas':
        hidden_size = 20
        model_depth = 4
    else:
        hidden_size = 16
        model_depth = 2
    
    print(f"\nGenerating Rashomon set ({NUM_RASHOMON_MODELS} models)...")
    print(f"Model: {model_depth} layers, {hidden_size} hidden units")
    print(f"This may take 15-30 minutes...")
    
    try:
        summary = generate_rashomon_set(
            X=X,
            y=Y0,
            epsilon=EPSILON,
            num_models=NUM_RASHOMON_MODELS,
            save_dir=rashomon_dir,
            model_hidden=hidden_size,
            model_depth=model_depth,
            dropout=0.0,
            base_train_cfg=TrainConfig(
                epochs=BASE_EPOCHS,
                lr=BASE_LR,
                batch_size=BASE_BATCH_SIZE
            ),
            ascent_lr=ASCENT_LR,
            max_steps=MAX_STEPS,
            eval_every=EVAL_EVERY,
            opt_num_attempts=OPT_NUM_ATTEMPTS,
            shuffle=SHUFFLE,
            shuffle_seed=SHUFFLE_SEED,
            relative_epsilon=RELATIVE_EPSILON,
            diversity_strategy="random_point_class",
            seed=SEED,
        )
        
        print(f"\nSuccessfully generated {summary['num_models']} models")
        print(f"Base loss: {summary['base_loss']:.4f}")
        
    except Exception as e:
        print(f"ERROR generating Rashomon set: {e}")
        import traceback
        traceback.print_exc()

def main():
    print("="*80)
    print("GENERATING RASHOMON SETS FOR ALL DATASETS")
    print("="*80)
    print(f"Datasets: {DATASETS}")
    print(f"Models per dataset: {NUM_RASHOMON_MODELS}")
    print("="*80)
    
    for dataset in DATASETS:
        generate_for_dataset(dataset)
    
    print("\n" + "="*80)
    print("DONE! All Rashomon sets generated.")
    print("You can now run both experiments in parallel:")
    print("  sbatch run_adversarial_experiment.sh")
    print("  sbatch run_privacy_experiment.sh")
    print("="*80)

if __name__ == "__main__":
    main()
