import numpy as np
import pandas as pd
import time
from sklearn.ensemble import RandomForestClassifier
from sklearn.neural_network import MLPClassifier
from art.estimators.classification import SklearnClassifier
from art.attacks.evasion import HopSkipJump
from base import (
    AdaDetectERM, AdaDetectDE, AdaDetectERMcv, generate_exchangeable_gaussian,
    improved_load_dataset, datasets_config, model_configs, alpha_values
)
from art.attacks.evasion import BoundaryAttack
# NOTE: the local `base` module (imported above; provides AdaDetectERM and the
# dataset helpers) must be importable, e.g. by having its directory on PYTHONPATH.

# ─────────────────────────────── CONFIGURATION ────────────────────────────────
# Settings for the scorer plugged into the detector (see build_proc_scorer).
proc_model_cfg = dict(
    type='rf',                   # one of: 'mlp100', 'mlp3', 'rf', 'rf_depth'
    mlp100_hidden=(100,),
    mlp3_hidden=(128, 64, 32),
    rf_n_estimators=100,
    rf_max_depth=10,             # read when type == 'rf'
    rf_depth_max_depth=20,       # read when type == 'rf_depth'
)

# Settings for the attacker's surrogate classifier (see build_attack_model).
attack_model_cfg = dict(
    type='rf',                   # one of: 'mlp100', 'mlp3', 'rf', 'rf_depth'
    mlp100_hidden=(100,),
    mlp3_hidden=(128, 64, 32),
    rf_n_estimators=200,
    rf_max_depth=5,              # read when type == 'rf'
    rf_depth_max_depth=15,       # read when type == 'rf_depth'
)

# How the samples to perturb are chosen.
selection_cfg = dict(
    method='calib',              # 'bh' or 'calib'
    n_calib=5000,                # bootstrap size for 'calib'
    n_perturb=200,               # how many samples to attack
)
# ────────────────────────────────────────────────────────────────────────────────


# ──────────────────────────── FACTORY FUNCTIONS ────────────────────────────────
def build_proc_scorer(cfg):
    """Create the (unfitted) scoring model selected by ``cfg['type']``.

    Supported types:
      * ``'mlp100'``   - MLPClassifier with ``cfg['mlp100_hidden']`` layers.
      * ``'mlp3'``     - ``NNScorer`` wrapper around an MLPClassifier with
                         ``cfg['mlp3_hidden']`` layers; its ``predict_proba``
                         returns only column 1 of the wrapped model's output.
      * ``'rf'``       - RandomForestClassifier limited to ``cfg['rf_max_depth']``.
      * ``'rf_depth'`` - RandomForestClassifier limited to
                         ``cfg['rf_depth_max_depth']``.

    Every model is seeded with ``random_state=42``.

    Raises:
        ValueError: if ``cfg['type']`` is none of the names above.
    """
    kind = cfg['type']

    def make_mlp(layers):
        # Both MLP variants share everything except the layer sizes.
        return MLPClassifier(
            hidden_layer_sizes=layers,
            activation='relu', max_iter=500,
            solver='adam', random_state=42
        )

    def make_forest(depth):
        # Both forest variants share everything except the depth limit.
        return RandomForestClassifier(
            n_estimators=cfg['rf_n_estimators'],
            max_depth=depth,
            random_state=42
        )

    class NNScorer:
        """MLP wrapper whose ``predict_proba`` yields the class-1 column only."""

        def __init__(self):
            self.model = make_mlp(cfg['mlp3_hidden'])

        def fit(self, X, y):
            self.model.fit(X, y)

        def predict_proba(self, X):
            return self.model.predict_proba(X)[:, 1]

    if kind == 'mlp100':
        scorer = make_mlp(cfg['mlp100_hidden'])
    elif kind == 'mlp3':
        scorer = NNScorer()
    elif kind == 'rf':
        scorer = make_forest(cfg['rf_max_depth'])
    elif kind == 'rf_depth':
        scorer = make_forest(cfg['rf_depth_max_depth'])
    else:
        raise ValueError(f"Unknown proc model type: {kind}")
    return scorer


def build_attack_model(cfg):
    """Create the (unfitted) surrogate classifier selected by ``cfg['type']``.

    ``'mlp100'`` / ``'mlp3'`` give an MLPClassifier with the matching
    ``*_hidden`` layer sizes; ``'rf'`` / ``'rf_depth'`` give a
    RandomForestClassifier with ``cfg['rf_n_estimators']`` trees and the
    matching depth limit. Every model is seeded with ``random_state=0``.

    Raises:
        ValueError: if ``cfg['type']`` is none of the four names above.
    """
    kind = cfg['type']

    # MLP variants: only the hidden layer sizes differ.
    if kind in ('mlp100', 'mlp3'):
        layers = cfg['mlp100_hidden'] if kind == 'mlp100' else cfg['mlp3_hidden']
        return MLPClassifier(
            hidden_layer_sizes=layers,
            activation='relu', max_iter=500,
            solver='adam', random_state=0
        )

    # Random-forest variants: only the depth limit differs.
    if kind in ('rf', 'rf_depth'):
        depth = cfg['rf_max_depth'] if kind == 'rf' else cfg['rf_depth_max_depth']
        return RandomForestClassifier(
            n_estimators=cfg['rf_n_estimators'],
            max_depth=depth,
            random_state=0
        )

    raise ValueError(f"Unknown attack model type: {kind}")
# ────────────────────────────────────────────────────────────────────────────────


# ─────────────────────────────── HELPER FUNCTIONS ──────────────────────────────
def compute_pvalue(stat, null_stats):
    """Empirical p-value of ``stat`` against the array ``null_stats``.

    Computed as ``(1 + #{null >= stat}) / (len(null) + 1)``; the ``+1`` terms
    keep the result strictly positive.
    """
    n_at_least = np.sum(null_stats >= stat)
    denominator = len(null_stats) + 1
    return (n_at_least + 1) / denominator

def BH(pvalues, level):
    """Benjamini-Hochberg step-up procedure.

    Args:
        pvalues: 1-D array-like of p-values (lists and tuples are accepted
            and converted to an ndarray).
        level: target level; the k-th smallest p-value is compared with
            ``level * k / m`` where ``m = len(pvalues)``.

    Returns:
        Integer ndarray with the indices (into ``pvalues``) of the rejected
        hypotheses, ordered by increasing p-value. Empty if nothing is
        rejected or if ``pvalues`` is empty.
    """
    # Coerce so that fancy indexing below also works for plain sequences.
    pvalues = np.asarray(pvalues)
    order = np.argsort(pvalues)
    m = len(pvalues)
    thresholds = level * np.arange(1, m + 1) / m
    passing = np.flatnonzero(pvalues[order] <= thresholds)
    if passing.size == 0:
        return np.array([], dtype=int)
    # Step-up rule: reject everything up to the LAST sorted position that
    # passes, even if some earlier positions did not pass on their own.
    return order[:passing.max() + 1]

def select_by_bh(pvals, alpha, total, how_many):
    """Pick the non-rejected samples that are closest to being rejected.

    Runs BH at level ``alpha`` on ``pvals`` and returns, in order of
    increasing p-value, up to ``how_many`` indices that BH did NOT reject.

    Args:
        pvals: 1-D array-like of p-values.
        alpha: BH level.
        total: number of samples; only indices in ``range(total)`` are
            eligible for selection.
        how_many: maximum number of indices to return.

    Returns:
        A list of integer indices into ``pvals``.
    """
    pvals = np.asarray(pvals)
    rejected = BH(pvals, alpha)

    # Boolean eligibility mask instead of a per-element `in` scan over an
    # array, which was quadratic in the number of samples.
    eligible = np.ones(len(pvals), dtype=bool)
    eligible[rejected] = False
    eligible[max(total, 0):] = False

    order = np.argsort(pvals)
    return list(order[eligible[order]][:how_many])

def select_by_calib(x, surrogate_model, non_rej, n_calib, how_many):
    """Rank non-rejected samples by a surrogate-based empirical p-value.

    A calibration set of ``n_calib`` rows is bootstrapped (sampling with
    replacement via the global NumPy RNG) from the rows of ``x`` indexed by
    ``non_rej``. Each row of ``x`` is then assigned the p-value
    ``(1 + #{calib score >= score}) / (n_calib + 1)``, where scores are
    column 1 of ``surrogate_model.predict_proba``.

    Returns:
        The first ``how_many`` entries of ``non_rej`` after sorting them by
        increasing p-value.
    """
    boot_idx = np.random.choice(non_rej, size=n_calib, replace=True)
    calib_scores = surrogate_model.predict_proba(x[boot_idx])[:, 1]
    test_scores = surrogate_model.predict_proba(x)[:, 1]

    denom = len(calib_scores) + 1
    pvals_new = np.array(
        [(1 + np.sum(calib_scores >= score)) / denom for score in test_scores]
    )

    ranking = np.argsort(pvals_new[non_rej])
    return non_rej[ranking][:how_many]
# ────────────────────────────────────────────────────────────────────────────────


# ─────────────────────────────────── MAIN PIPELINE ──────────────────────────────
# Test all 4 datasets like scoreattack.py
datasets_to_test = ['creditcard','kddcup99','mammography','shuttle']
all_dataset_results = []


class ProgressTracker:
    """Print throttled progress / ETA lines while adversarial batches are generated."""

    def __init__(self, total_samples):
        self.total_samples = total_samples
        self.current_sample = 0
        self.start_time = time.time()
        self.last_update = time.time()

    def update(self, sample_idx):
        """Record that ``sample_idx`` samples are done and print progress if due."""
        self.current_sample = sample_idx
        current_time = time.time()
        elapsed = current_time - self.start_time

        # Update every 10 samples or every 5 seconds
        if (sample_idx % 10 == 0) or (current_time - self.last_update > 5):
            progress = (sample_idx / self.total_samples) * 100
            if sample_idx > 0:
                avg_time_per_sample = elapsed / sample_idx
                eta = avg_time_per_sample * (self.total_samples - sample_idx)
                print(f"    📊 Progress: {sample_idx}/{self.total_samples} ({progress:.1f}%) - "
                      f"Elapsed: {elapsed:.1f}s, ETA: {eta:.1f}s")
            self.last_update = current_time


# For every dataset: load data, run the detector, train a surrogate on the
# detector's decisions, attack selected non-rejected samples through the
# surrogate, re-run the detector on the perturbed data and compare FDR/TDR.
# Any exception aborts only the current dataset (see the `except` at the end).
for dataset_name in datasets_to_test:
    dataset_start_time = time.time()
    print(f"\n" + "="*80)
    print(f"TESTING DATASET: {dataset_name.upper()}")
    print(f"🕐 Started at: {time.strftime('%H:%M:%S')}")
    print("="*80)

    try:
        # 1) Load data using the same approach as scoreattack.py
        print("1. Loading Data...")
        load_start = time.time()
        dataset_config = datasets_config[dataset_name]

        # Load dataset using improved_load_dataset
        X_test, y_test, m0, m1, n_features = improved_load_dataset(dataset_name, dataset_config)
        load_time = time.time() - load_start
        print(f"✓ {dataset_name} loaded: {len(X_test)} samples ({m0} inliers, {m1} outliers)")
        print(f"  ⏱️  Data loading time: {load_time:.2f}s")

        # For calibration data, we need to load the full dataset
        if dataset_config.get('type') == 'synthetic':
            # Only one synthetic generator is wired up here. Fail with a clear
            # message instead of hitting a NameError on test0/test1/xnull below.
            if dataset_name != 'exchangeable_gaussian':
                raise ValueError(f"Unsupported synthetic dataset: {dataset_name}")

            n_features = dataset_config.get('n_features', 20)
            total_samples = 15000  # Generate more samples for calibration

            # Generate normal data (inliers) for calibration
            normal_data_full = generate_exchangeable_gaussian(
                n=total_samples,
                d=n_features,
                a=0, b=1, c=0.3
            )

            # Use first m0 for test, rest for calibration
            test0 = normal_data_full[:m0]  # Normal samples for test
            test1 = X_test[m0:m0+m1]      # Anomaly samples from test set
            xnull = normal_data_full[m0:m0+5000]  # Calibration data
        else:
            # For OpenML datasets, load the full dataset for calibration
            from sklearn.datasets import fetch_openml
            dataset = fetch_openml(name=dataset_config['openml_name'], version=dataset_config['version'], as_frame=False)
            X_full, y_full = dataset.data, dataset.target

            # Apply same preprocessing as in improved_load_dataset
            # NOTE(review): improved_load_dataset is not visible here; confirm
            # the two preprocessing paths really match.
            if dataset_name == 'creditcard':
                y_binary = y_full.astype(float)
            elif dataset_name == 'shuttle':
                y_binary = (y_full != '1').astype(float)
            elif dataset_name == 'kddcup99':
                y_binary = (y_full != 'normal').astype(float)
            else:
                # Fallback: treat the rarest label as the anomaly class.
                unique_labels = np.unique(y_full)
                label_counts = {label: np.sum(y_full == label) for label in unique_labels}
                anomaly_label = min(label_counts.keys(), key=label_counts.get)
                y_binary = (y_full == anomaly_label).astype(float)

            # Handle mixed data types
            if X_full.dtype == 'object':
                X_processed = []
                for col in range(X_full.shape[1]):
                    col_data = X_full[:, col]
                    try:
                        col_numeric = pd.to_numeric(col_data, errors='coerce')
                        if np.isnan(col_numeric).any():
                            col_numeric = np.nan_to_num(col_numeric, nan=np.nanmedian(col_numeric))
                        X_processed.append(col_numeric)
                    except Exception:
                        # Best-effort fallback: label-encode the column.
                        try:
                            from sklearn.preprocessing import LabelEncoder
                            le = LabelEncoder()
                            col_encoded = le.fit_transform(col_data.astype(str))
                            X_processed.append(col_encoded.astype(float))
                        except Exception as col_err:
                            # Still best-effort, but no longer silent.
                            print(f"    ⚠️  Dropping column {col}: could not convert to numeric ({col_err})")
                            continue

                if X_processed:
                    X_full = np.column_stack(X_processed)

            # Handle NaN values
            if np.isnan(X_full).any():
                X_full = np.nan_to_num(X_full, nan=np.nanmedian(X_full))

            # Apply scaling
            from sklearn.preprocessing import RobustScaler
            scaler = RobustScaler()
            X_full_scaled = scaler.fit_transform(X_full)

            # Separate classes
            normal_data_full = X_full_scaled[y_binary == 0]
            anomaly_data_full = X_full_scaled[y_binary == 1]

            # Use first m0 normal samples for test, rest for calibration
            test0 = normal_data_full[:m0]  # Normal samples for test
            test1 = anomaly_data_full[:m1]  # Anomaly samples for test
            xnull = normal_data_full[m0:m0+5000]  # Calibration data

        # Prepare test data: normals first, anomalies after (y_true below
        # relies on this ordering).
        x = np.vstack([test0, test1])

        print(f"   Test set: {len(x)} samples ({m0} normal + {m1} anomaly)")
        print(f"   Calibration set: {len(xnull)} normal samples")
        print(f"   Features: {n_features}")

        # 2) Build & fit proc using the same configuration as scoreattack.py
        print("2. Building Original Detector...")
        detector_start = time.time()
        alpha = 0.1

        proc_scorer = build_proc_scorer(proc_model_cfg)
        print(f"   Scorer: {proc_scorer.__class__.__name__}")

        # Build AdaDetect detector
        proc = AdaDetectERM(scoring_fn=proc_scorer, split_size=4000/5000)
        orig_rej_idx = proc.apply(x, alpha, xnull)
        detector_time = time.time() - detector_start
        print(f"✓ Detector built. Detections: {len(orig_rej_idx)}")
        print(f"  ⏱️  Detector building time: {detector_time:.2f}s")

        # 3) Compute p-values & BH rejections
        print("3. Computing p-values...")
        pvalue_start = time.time()
        null_stats = proc.null_statistics
        test_stats = proc.test_statistics
        pvals = np.array([compute_pvalue(s, null_stats) for s in test_stats])

        non_rej_orig = np.setdiff1d(np.arange(len(x)), orig_rej_idx)
        pvalue_time = time.time() - pvalue_start
        print(f"✓ P-values computed. Non-rejected samples: {len(non_rej_orig)}")
        print(f"  ⏱️  P-value computation time: {pvalue_time:.2f}s")

        # 4) Train attack surrogate
        print("4. Training Surrogate Model...")
        surrogate_start = time.time()

        # Surrogate target = the detector's own decision (1 = rejected).
        # The labels must be assigned BEFORE the single-class check below;
        # otherwise the check always fires and injects random label noise.
        labels = np.zeros(len(x), dtype=int)
        labels[orig_rej_idx] = 1

        n_pos = np.sum(labels == 1)
        n_neg = np.sum(labels == 0)
        print(f"    ℹ️  Surrogate training labels: {n_neg} negative, {n_pos} positive")

        unique_labels = np.unique(labels)
        if len(unique_labels) < 2:
            print(f"    ⚠️  Only {len(unique_labels)} class(es) found, creating artificial labels...")
            # Flip a few random labels to the opposite class so the surrogate
            # sees two classes (its predict_proba column 1 is needed later).
            n_to_flip = max(1, min(10, len(labels) // 10))
            flip_indices = np.random.choice(len(labels), n_to_flip, replace=False)
            labels[flip_indices] = 1 - labels[flip_indices]

        # Check for constant features in the training data
        for col in range(x.shape[1]):
            if np.all(x[:, col] == x[0, col]):
                print(f"    ⚠️  Feature {col} is constant, this may cause issues with some classifiers")

        # Use the same attack model configuration as scoreattack.py
        # NOTE(review): this rebinds the module-level attack_model_cfg;
        # build_attack_model reads cfg['type'] and the matching size/depth
        # keys - confirm model_configs['rf_deep'] provides them.
        attack_model_cfg = model_configs['rf_deep']  # Use same config for consistency
        attack_base = build_attack_model(attack_model_cfg)
        attack_base.fit(x, labels)
        surrogate_time = time.time() - surrogate_start
        print("✓ Surrogate model trained (for reference)")
        print(f"  ⏱️  Surrogate training time: {surrogate_time:.2f}s")

        # Set up ART wrapper with proper clip values
        clip_min, clip_max = x.min(axis=0), x.max(axis=0)

        # Constant features have min == max; widen the upper bound slightly so
        # the clip range is not degenerate.
        constant_features = np.where(clip_min == clip_max)[0]
        if len(constant_features) > 0:
            print(f"    ⚠️  Found {len(constant_features)} constant features, adjusting clip values...")
            epsilon = 1e-6
            clip_max[constant_features] += epsilon

        art_surrogate = SklearnClassifier(
            model=attack_base,
            clip_values=(clip_min, clip_max)
        )

        # 5) Select indices to perturb
        selection_method = selection_cfg['method']
        n_perturb = selection_cfg['n_perturb']
        print(f"5. Selection method: {selection_method} with n_perturb={n_perturb}")
        print(f"Using n_perturb = {n_perturb}")

        print(f"6. Selecting samples using {selection_method} method...")
        if selection_method == 'bh':
            to_perturb = select_by_bh(
                pvals, alpha,
                total=len(x),
                how_many=n_perturb
            )
            print(f"✓ Selected {len(to_perturb)} non-rejected samples with lowest p-values to perturb")
            print(f"  Target: Increase FDR by making these normal samples look suspicious")

            # Show sample of selected samples (indices < m0 are the normals)
            normal_indices = [i for i in to_perturb if i < m0]
            if len(normal_indices) > 0:
                sample_indices = normal_indices[:5]
                print("Sample of selected normal samples and p-values:")
                for idx in sample_indices:
                    print(f"  Index {idx} (normal): p-value = {pvals[idx]:.6f}")
        elif selection_method == 'calib':
            to_perturb = select_by_calib(
                x,
                surrogate_model=attack_base,
                non_rej=non_rej_orig,
                n_calib=selection_cfg['n_calib'],
                how_many=n_perturb
            )
            print(f"✓ Selected {len(to_perturb)} samples using calibration method")
        else:
            raise ValueError(f"Unknown selection method: {selection_method}")

        # Nothing to attack would otherwise surface later as a bare
        # ZeroDivisionError in the timing report.
        if len(to_perturb) == 0:
            raise ValueError("No samples selected for perturbation")

        # 6) Craft adversarial examples, targeting class 1 ("rejected")
        print("7. Performing surrogate-based attack...")
        y_target = np.ones(len(to_perturb), dtype=int)

        # Add progress tracking with timestamps
        attack_start_time = time.time()

        # Single source of truth for the attack settings, so the log lines
        # below cannot drift away from what is actually run.
        boundary_params = {
            'max_iter': 1000,      # More iterations
            'num_trial': 50,       # More trials per step
            'sample_size': 20,     # Samples per trial
            'init_size': 100,      # More initial trials
            'min_epsilon': 0.01,   # Smaller minimum step size
            'step_adapt': 0.2,     # Slower step size adaptation
        }
        batch_size = 10  # Process 10 samples at a time to show progress

        print(f"  🕐 Starting attack at {time.strftime('%H:%M:%S')}")
        print(f"  📊 Attack parameters: {len(to_perturb)} samples, "
              f"max_iter={boundary_params['max_iter']}, "
              f"num_trial={boundary_params['num_trial']}, "
              f"sample_size={boundary_params['sample_size']}")

        total_expected_iterations = len(to_perturb) * boundary_params['max_iter']
        print(f"  🔢 Upper bound on total iterations: {total_expected_iterations:,}")
        print(f"  📈 Processing in batches of {batch_size} samples for progress tracking")

        attack = BoundaryAttack(
            estimator=art_surrogate,
            targeted=True,   # Use targeted attack
            verbose=True,    # Show progress
            **boundary_params
        )
        print(f"  🔧 BoundaryAttack configured at {time.strftime('%H:%M:%S')}")

        x_adv = x.copy()

        # Track progress during attack generation
        print(f"  🚀 Generating adversarial examples...")
        print(f"  📈 Progress tracking enabled - will show iteration details")
        generation_start = time.time()

        progress_tracker = ProgressTracker(len(to_perturb))
        n_batches = (len(to_perturb) - 1) // batch_size + 1

        # Process samples in batches to show progress
        for i in range(0, len(to_perturb), batch_size):
            batch_end = min(i + batch_size, len(to_perturb))
            batch_indices = to_perturb[i:batch_end]

            print(f"    🔄 Processing batch {i//batch_size + 1}/{n_batches} "
                  f"(samples {i+1}-{batch_end})")

            batch_start = time.time()

            # Generate adversarial examples for this batch
            x_adv[batch_indices] = attack.generate(
                x[batch_indices].astype(np.float32),
                y=y_target[i:i+len(batch_indices)]
            )

            batch_time = time.time() - batch_start
            progress_tracker.update(batch_end)
            print(f"    ✅ Batch completed in {batch_time:.2f}s")

        generation_time = time.time() - generation_start
        # Handle NaN values in adversarial examples
        if np.isnan(x_adv).any():
            print(f"    ⚠️  Found {np.isnan(x_adv).sum()} NaN values in adversarial examples, replacing with median...")
            for col in range(x_adv.shape[1]):
                col_data = x_adv[:, col]
                if np.isnan(col_data).any():
                    median_val = np.nanmedian(col_data)
                    x_adv[np.isnan(col_data), col] = median_val
        # Validate adversarial examples before proceeding
        if not np.isfinite(x_adv).all():
            print(f"    ❌ Invalid values found in adversarial examples (inf or nan)")
            raise ValueError("Adversarial examples contain invalid values")

        # Check if adversarial examples are actually different from original
        changes = np.sum(x_adv[to_perturb] != x[to_perturb])
        if changes == 0:
            print(f"    ⚠️  Warning: No changes made to adversarial examples")
            # Continue anyway, might still work

        print(f"    ✅ Generated {len(to_perturb)} adversarial examples with {changes} total changes")
        total_attack_time = time.time() - attack_start_time

        print(f"  ✅ Attack completed successfully at {time.strftime('%H:%M:%S')}")
        print(f"  ⏱️  Generation time: {generation_time:.2f}s")
        print(f"  ⏱️  Total attack time: {total_attack_time:.2f}s")
        print(f"  📈 Average time per sample: {generation_time/len(to_perturb):.3f}s")
        print(f"  📊 Total samples processed: {len(to_perturb)}")
        if generation_time > 0:
            print(f"  🚀 Processing rate: {len(to_perturb)/generation_time:.2f} samples/second")

        # 7) Re-fit proc on adversarial data
        print("8. Re-evaluating detector on adversarial examples...")
        reeval_start = time.time()
        proc_mod = AdaDetectERM(scoring_fn=proc_scorer, split_size=4000/5000)
        new_rej_idx = proc_mod.apply(x_adv, alpha, xnull)
        reeval_time = time.time() - reeval_start
        print(f"✓ Re-evaluation complete. New detections: {len(new_rej_idx)}")
        print(f"  ⏱️  Re-evaluation time: {reeval_time:.2f}s")

        # 8) Compare & report
        print("9. Evaluating attack effectiveness...")
        eval_start = time.time()
        null_stats2 = proc_mod.null_statistics
        test_stats2 = proc_mod.test_statistics
        pvals2 = np.array([compute_pvalue(s, null_stats2) for s in test_stats2])

        # Create detailed comparison dataframe
        # NOTE(review): df is not used further in this script.
        df = pd.DataFrame({
            'idx':        np.arange(len(x)),
            'orig_score': test_stats,
            'new_score':  test_stats2,
            'p_orig':     pvals,
            'p_adv':      pvals2,
            'delta':      test_stats2 - test_stats
        })

        eval_time = time.time() - eval_start
        print(f"  ⏱️  Evaluation time: {eval_time:.2f}s")

        print("\n" + "="*60)
        print(f"RESULTS FOR {dataset_name.upper()}")
        print("="*60)
        print(f"Method: {selection_method}")
        print(f"n_perturb: {n_perturb}")

        # Compute FDR / TDR; ground truth follows the [normals, anomalies]
        # stacking order of x.
        y_true = np.concatenate([np.zeros(m0), np.ones(m1)])
        y_pred_orig = np.zeros(len(x)); y_pred_orig[orig_rej_idx] = 1
        y_pred_new = np.zeros(len(x)); y_pred_new[new_rej_idx] = 1

        # Original metrics
        tp_orig = np.sum((y_pred_orig==1)&(y_true==1))
        fp_orig = np.sum((y_pred_orig==1)&(y_true==0))
        fn_orig = np.sum((y_pred_orig==0)&(y_true==1))
        fdr_orig = fp_orig/(tp_orig+fp_orig) if tp_orig+fp_orig>0 else 0.0
        tdr_orig = tp_orig/(tp_orig+fn_orig) if tp_orig+fn_orig>0 else 0.0

        # New metrics
        tp_new = np.sum((y_pred_new==1)&(y_true==1))
        fp_new = np.sum((y_pred_new==1)&(y_true==0))
        fn_new = np.sum((y_pred_new==0)&(y_true==1))
        fdr_new = fp_new/(tp_new+fp_new) if tp_new+fp_new>0 else 0.0
        tdr_new = tp_new/(tp_new+fn_new) if tp_new+fn_new>0 else 0.0

        # Despite the name, fdr_increase is the change in the false-positive
        # COUNT (the name is kept because it is a key of the results/CSV).
        fdr_increase = fp_new - fp_orig
        fdr_increase_rate = (fdr_increase / fp_orig * 100) if fp_orig > 0 else float('inf')

        print(f"\nFDR-FOCUSED ATTACK EFFECTIVENESS:")
        print(f"  Original FDR: {fdr_orig:.4f}")
        print(f"  New FDR: {fdr_new:.4f}")
        print(f"  FDR Increase: {fdr_increase} false positives")
        print(f"  FDR Increase Rate: {fdr_increase_rate:.2f}%")

        print(f"\nTRADITIONAL METRICS:")
        print(f"  Original detections: {len(orig_rej_idx)}")
        print(f"  New detections: {len(new_rej_idx)}")
        if len(orig_rej_idx) > 0:
            detection_change = (len(new_rej_idx) - len(orig_rej_idx)) / len(orig_rej_idx) * 100
            print(f"  Detection change: {detection_change:.2f}%")
        else:
            print("  Detection change: n/a (no original detections)")

        print(f"\nCLASSIFICATION METRICS:")
        print(f"  True Positives (TP): {tp_new}")
        print(f"  False Positives (FP): {fp_new}")
        print(f"  False Negatives (FN): {fn_new}")
        print(f"  False Discovery Rate (FDR): {fdr_new:.4f}")
        print(f"  True Discovery Rate (TDR): {tdr_new:.4f}")

        print(f"\n" + "="*60)
        print("FDR ATTACK ASSESSMENT")
        print("="*60)
        if fdr_increase > 0:
            print(f"✅ FDR ATTACK SUCCESSFUL!")
            print(f"   Increased false positives by {fdr_increase} samples")
            print(f"   FDR increased from {fdr_orig:.4f} to {fdr_new:.4f}")
            print(f"   FDR increase rate: {fdr_increase_rate:.1f}%")
        else:
            print(f"❌ FDR ATTACK FAILED")
            print(f"   No increase in false positives")

        print(f"\n" + "="*60)
        print("TRADITIONAL ATTACK ASSESSMENT")
        print("="*60)
        if len(new_rej_idx) < len(orig_rej_idx):
            print(f"✅ DETECTION REDUCTION SUCCESSFUL")
            print(f"   Reduced detections by {len(orig_rej_idx) - len(new_rej_idx)}")
        else:
            print(f"❌ DETECTION REDUCTION INEFFECTIVE")
            print(f"   No reduction in detections")

        print(f"\nFDR Attack: {'SUCCESS' if fdr_increase > 0 else 'FAILED'}")
        print(f"Detection Attack: {'SUCCESS' if len(new_rej_idx) < len(orig_rej_idx) else 'FAILED'}")

        # Store results for comparison
        results = {
            'dataset': dataset_name,
            'n_features': n_features,
            'original_fp': fp_orig,
            'new_fp': fp_new,
            'fdr_increase': fdr_increase,
            'fdr_increase_rate': fdr_increase_rate,
            'original_fdr': fdr_orig,
            'new_fdr': fdr_new,
            'original_tdr': tdr_orig,
            'new_tdr': tdr_new,
            'tp': tp_new,
            'fn': fn_new,
            'fdr': fdr_new,
            'tdr': tdr_new
        }
        all_dataset_results.append(results)

        # Calculate total time for this dataset
        dataset_total_time = time.time() - dataset_start_time

        print(f"\n✓ {dataset_name} attack evaluation complete")
        print(f"  FDR: {fdr_orig:.4f} → {fdr_new:.4f}")
        print(f"  FDR Increase: {fdr_increase} false positives")
        print(f"  ⏱️  Total dataset time: {dataset_total_time:.2f}s")
        print(f"  🕐 Completed at: {time.strftime('%H:%M:%S')}")

    except Exception as e:
        # Top-level boundary: report the failure and move on to the next dataset.
        dataset_total_time = time.time() - dataset_start_time
        print(f"❌ Attack failed on {dataset_name}: {e}")
        print(f"  ⏱️  Time before failure: {dataset_total_time:.2f}s")
        continue

# ============================================================================
# COMPARATIVE ANALYSIS
# ============================================================================

# Summarise the per-dataset results collected in all_dataset_results: print a
# comparison table, highlight the dataset with the largest false-positive
# increase, and write everything to surrogate_attack_comparison.csv.
if all_dataset_results:
    print(f"\n" + "="*80)
    print("COMPARATIVE FDR ATTACK ANALYSIS")
    print("="*80)

    # Create comparison table. Each row prints exactly the quantities named
    # in the header (previously the TDR values were printed under the
    # "FDR Inc" / "FP Inc" headings).
    print(f"\n{'Dataset':<12} {'Features':<8} {'Original FDR':<12} {'New FDR':<8} {'FDR Inc':<8} {'FP Inc':<8} {'Orig TDR':<8} {'New TDR':<8} {'Status':<10}")
    print("-" * 100)

    for result in all_dataset_results:
        original_fdr = result['original_fdr']
        new_fdr = result['new_fdr']
        fdr_inc = new_fdr - original_fdr
        # 'fdr_increase' holds the change in the false-positive count.
        fp_inc = result['fdr_increase']
        original_tdr = result['original_tdr']
        new_tdr = result['new_tdr']
        status = "✅ SUCCESS" if fp_inc > 0 else "❌ FAILED"

        print(f"{result['dataset']:<12} {result['n_features']:<8} {original_fdr:<12.4f} {new_fdr:<8.4f} {fdr_inc:<8.4f} {fp_inc:<8} {original_tdr:<8.4f} {new_tdr:<8.4f} {status:<10}")

    # Find most vulnerable dataset (largest increase in false positives)
    successful_attacks = [r for r in all_dataset_results if r['fdr_increase'] > 0]
    if successful_attacks:
        most_vulnerable = max(successful_attacks, key=lambda r: r['fdr_increase'])
        print(f"\n🏆 MOST VULNERABLE DATASET: {most_vulnerable['dataset'].upper()}")
        print(f"   FDR increased by {most_vulnerable['fdr_increase']} false positives")
        print(f"   FDR change: {most_vulnerable['original_fdr']:.4f} → {most_vulnerable['new_fdr']:.4f}")

    # Overall statistics
    print(f"\n📊 OVERALL ATTACK SUCCESS:")
    print(f"   Successful attacks: {len(successful_attacks)}/{len(all_dataset_results)}")
    if successful_attacks:
        print(f"   Average FDR increase: {np.mean([r['fdr_increase'] for r in successful_attacks]):.2f} false positives")

    # Save comparative results
    df_comparison = pd.DataFrame(all_dataset_results)
    df_comparison.to_csv('surrogate_attack_comparison.csv', index=False)
    print(f"\n✓ Comparative results saved to: surrogate_attack_comparison.csv")

else:
    print(f"\n❌ No successful attacks to compare")

print(f"\n🎯 Multi-dataset surrogate attack comparison completed!")
