import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from scipy import stats
from sklearn.base import clone
from sklearn.ensemble import GradientBoostingClassifier, RandomForestClassifier
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, precision_recall_curve
from sklearn.model_selection import cross_validate, StratifiedKFold, cross_val_predict
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from tqdm import tqdm

class ImprovedPrivacyMetrics:
    """Privacy indicators derived from membership-inference-attack (MIA) scores.

    The class converts attack statistics (AUC, TPR/FPR) into an epsilon
    estimate -- either through a regression model fitted on
    ``(auc, epsilon)`` pairs or through a closed-form formula -- plus a few
    auxiliary scores (normalized risk, leakage, TPR at low FPR).
    """

    def __init__(self):
        # Regressor mapping AUC -> epsilon; stays None until calibration succeeds.
        self.calibration_model = None
        self.is_calibrated = False
        # 'linear' or 'random_forest' once a calibration model has been fitted.
        self.calibration_method = None

    def train_empirical_calibration(self, calibration_data, method='linear'):
        """Fit a regressor that maps an MIA AUC to an epsilon value.

        Args:
            calibration_data: sequence of ``(auc_score, true_epsilon)`` pairs.
                When it is ``None`` or empty a warning is issued and the
                current calibration state is left untouched.
            method: ``'linear'`` (LinearRegression) or ``'random_forest'``
                (RandomForestRegressor with 100 trees, ``random_state=42``).

        Raises:
            ValueError: if ``method`` is neither of the supported names.
        """
        if calibration_data is None or len(calibration_data) == 0:
            warnings.warn("No calibration data provided. Using theoretical bounds.")
            return

        auc_scores = np.array([pair[0] for pair in calibration_data]).reshape(-1, 1)
        true_epsilons = np.array([pair[1] for pair in calibration_data])

        if method == 'linear':
            calibration_model = LinearRegression()
        elif method == 'random_forest':
            calibration_model = RandomForestRegressor(n_estimators=100, random_state=42)
        else:
            raise ValueError("Method must be 'linear' or 'random_forest'")

        calibration_model.fit(auc_scores, true_epsilons)
        # Publish the new state only after fitting succeeded.
        self.calibration_model = calibration_model
        self.is_calibrated = True
        self.calibration_method = method

        # In-sample mean absolute error, reported as a rough quality indicator.
        predicted_eps = self.calibration_model.predict(auc_scores)
        mae = np.mean(np.abs(predicted_eps - true_epsilons))
        print(f"Empirical calibration trained ({method}), MAE: {mae:.4f}")

    def empirical_epsilon(self, auc_score, method='auto'):
        """Estimate epsilon from an MIA AUC score.

        Args:
            auc_score: MIA AUC; values at or below 0.5 yield ``0.0``.
            method: ``'auto'`` (calibration model if fitted, otherwise the
                conservative formula), ``'calibration'`` or
                ``'conservative'``.

        Returns:
            The epsilon estimate. If ``'calibration'`` is requested without a
            fitted model, or ``method`` is not recognised, a warning is
            issued and the conservative estimate is returned.
        """
        if auc_score <= 0.5:
            return 0.0

        if method == 'conservative':
            return self._conservative_epsilon(auc_score)

        if method in ('auto', 'calibration'):
            if self.is_calibrated:
                return self._calibration_epsilon(auc_score)
            if method == 'calibration':
                warnings.warn("Calibration model not available, using conservative estimate")
            return self._conservative_epsilon(auc_score)

        warnings.warn(f"Unknown epsilon method {method!r}, using conservative estimate")
        return self._conservative_epsilon(auc_score)

    def _calibration_epsilon(self, auc_score):
        """Predict epsilon with the fitted calibration model, clamped at 0."""
        eps = self.calibration_model.predict(np.array([[auc_score]]))[0]
        return max(0.0, eps)

    def _conservative_epsilon(self, auc_score, delta=1e-5):
        """Closed-form epsilon estimate; identical to :meth:`effective_epsilon`."""
        return self.effective_epsilon(auc_score, delta)

    def get_calibration_info(self):
        """Return a dict describing whether (and how) calibration was fitted."""
        if not self.is_calibrated:
            return {
                'calibrated': False,
                'message': 'Using theoretical bounds for epsilon calculation'
            }
        return {
            'calibrated': True,
            'method': self.calibration_method,
            'model_type': type(self.calibration_model).__name__
        }

    @staticmethod
    def calculate_privacy_risk(auc_score, baseline=0.5):
        """Rescale AUC from ``[baseline, 1]`` to a risk score in ``[0, 1]``."""
        if auc_score <= baseline:
            return 0.0
        return min(1.0, (auc_score - baseline) / (1.0 - baseline))

    @staticmethod
    def effective_epsilon(auc_score, delta=1e-5):
        """Closed-form epsilon estimate from an AUC score.

        Computes ``advantage = 2 * (auc - 0.5)`` and returns
        ``log(1 + advantage - delta) - log(1 - delta)``; AUCs at or below
        0.5, or advantages not exceeding ``2 * delta``, give ``0.0``.
        """
        if auc_score <= 0.5:
            return 0.0

        advantage = 2 * (auc_score - 0.5)
        if advantage <= 2 * delta:
            return 0.0

        epsilon = np.log(1 + advantage - delta) - np.log(1 - delta)
        return max(0, epsilon)

    @staticmethod
    def membership_advantage(tpr, fpr):
        """Return the membership advantage ``TPR - FPR``, floored at 0."""
        return max(0, tpr - fpr)

    @staticmethod
    def privacy_leakage_score(auc_scores):
        """Average above-chance AUC over several attacks, scaled to ``[0, 1]``.

        Returns ``0.0`` when no scores are supplied (instead of NaN).
        """
        leakages = [max(0, auc - 0.5) for auc in auc_scores]
        if not leakages:
            return 0.0
        return np.mean(leakages) * 2  # (auc - 0.5) is at most 0.5, so scale by 2

    @staticmethod
    def calculate_confidence_interval(scores, confidence=0.95):
        """Return ``(mean, half_width)`` of a Student-t confidence interval.

        Empty input yields ``(0, 0)``. With a single score the standard
        error is undefined, so the half-width is reported as ``0.0`` rather
        than NaN.
        """
        n_scores = len(scores)
        if n_scores == 0:
            return 0, 0
        mean = np.mean(scores)
        if n_scores < 2:
            return mean, 0.0
        sem = stats.sem(scores)
        ci = sem * stats.t.ppf((1 + confidence) / 2., n_scores - 1)
        return mean, ci

    @staticmethod
    def tpr_at_low_fpr(y_true, y_pred_proba, target_fpr=0.001):
        """Return the highest TPR reached on the ROC curve with FPR <= ``target_fpr``.

        If no ROC point satisfies the constraint, the TPR at the smallest
        available FPR is returned.
        """
        fpr, tpr, _ = roc_curve(y_true, y_pred_proba)

        valid_indices = np.where(fpr <= target_fpr)[0]
        if len(valid_indices) == 0:
            return tpr[np.argmin(fpr)]

        return np.max(tpr[valid_indices])

    @staticmethod
    def tpr_at_fixed_fpr(y_true, y_pred_proba, target_fpr=0.001):
        """Return the TPR at ``target_fpr`` by linear interpolation of the ROC curve."""
        fpr, tpr, _ = roc_curve(y_true, y_pred_proba)
        return np.interp(target_fpr, fpr, tpr)


class EnhancedPrivacyEvaluator:
    def __init__(self, model, device, num_classes=10, kdes=None):
        """Store the audit target and set up the metrics helper.

        Args:
            model: the model under audit; stored as-is.
            device: device that batches are moved to before inference.
            num_classes: number of target classes.
            kdes: per-class density estimators, stored as ``classes_kdes``.
        """
        self.model, self.device = model, device
        self.num_classes = num_classes
        self.classes_kdes = kdes
        # NOTE: the model's train/eval mode is not changed here.
        self.privacy_metrics = ImprovedPrivacyMetrics()
        # Fit the epsilon calibration right away (no user-supplied data).
        self.set_empirical_calibration(None)

    def set_empirical_calibration(self, calibration_data, method='linear'):
        """
        Set up empirical epsilon calibration using known DP models

        Args:
            calibration_data: list of tuples (auc_score, true_epsilon)
                            e.g., [(0.55, 0.1), (0.65, 0.5), (0.75, 1.0)]
            method: 'linear' or 'random_forest'
        """
        calibration_data = [
            (0.500, 0.00),  # Perfect privacy
            (0.505, 0.10),  # ε=0.1 theoretical
            (0.510, 0.20),  # ε=0.2 theoretical
            (0.520, 0.50),  # ε=0.5 theoretical
            (0.540, 1.00),  # ε=1.0 theoretical (your target)
            (0.600, 2.00),  # ε=2.0 theoretical
            (0.700, 4.00),  # ε=4.0 theoretical
            # (0.600, 2.00),  # ε=2.0 theoretical
            # (0.700, 4.00),  # ε=4.0 theoretical
        ]

        self.privacy_metrics.train_empirical_calibration(calibration_data, method)

    def get_model_predictions(self, data_loader):
        """Get model predictions, probabilities, and losses for a dataset"""
        all_probs = []
        all_preds = []
        all_labels = []
        all_losses = []
        running_corrects = 0
        with torch.no_grad():
            for inputs in data_loader:
                features = inputs['feature'].to(self.device)
                labels = inputs['target'].squeeze().to(self.device)
                if not self.model:
                    outputs = features
                    output = []
                    for ii in range(10):
                        log_prob = self.classes_kdes[ii].score_samples(outputs)
                        output.append(log_prob.detach().cpu().numpy())
                    output_ = np.array(output).T
                    outputs = torch.tensor(output_).cuda()
                else:
                    outputs = self.model(features)
                probs = torch.softmax(outputs, dim=1)
                preds = torch.argmax(outputs, dim=1)

                # Calculate losses
                criterion = nn.CrossEntropyLoss(reduction='none')
                losses = criterion(outputs, labels).cpu().numpy()
                running_corrects += torch.sum(preds == labels).item()
                all_probs.append(probs.cpu().numpy())
                all_preds.append(preds.cpu().numpy())
                all_labels.append(labels.cpu().numpy())
                all_losses.append(losses)
        acc = running_corrects / len(data_loader.dataset) * 100
        return (
            np.vstack(all_probs),
            np.concatenate(all_preds),
            np.concatenate(all_labels),
            np.concatenate(all_losses),
            acc
        )

    def prepare_enhanced_attack_dataset(self, member_loader, non_member_loader, feature_type='enhanced_probs'):
        """Enhanced feature engineering for stronger attacks"""
        member_probs, member_preds, member_labels, member_losses, memeber_acc = self.get_model_predictions(member_loader)
        non_member_probs, non_member_preds, non_member_labels, non_member_losses, non_member_acc = self.get_model_predictions(
            non_member_loader)

        # Enhanced feature engineering
        if feature_type == 'enhanced_probs':
            # Include entropy, confidence metrics, and correctness
            member_entropy = -np.sum(member_probs * np.log(member_probs + 1e-12), axis=1)
            non_member_entropy = -np.sum(non_member_probs * np.log(non_member_probs + 1e-12), axis=1)

            member_confidence = np.max(member_probs, axis=1)
            non_member_confidence = np.max(non_member_probs, axis=1)

            member_correct = (member_preds == member_labels).astype(float)
            non_member_correct = (non_member_preds == non_member_labels).astype(float)

            # Top-2 confidence gap
            sorted_member = np.sort(member_probs, axis=1)[:, ::-1]
            member_top2_gap = sorted_member[:, 0] - sorted_member[:, 1]
            sorted_non_member = np.sort(non_member_probs, axis=1)[:, ::-1]
            non_member_top2_gap = sorted_non_member[:, 0] - sorted_non_member[:, 1]

            # Loss features
            # member_loss_norm = (member_losses - np.mean(member_losses)) / (np.std(member_losses) + 1e-12)
            # non_member_loss_norm = (non_member_losses - np.mean(non_member_losses)) / (
            #             np.std(non_member_losses) + 1e-12)

            X_member = np.column_stack([
                member_probs, member_entropy, member_confidence,
                member_correct, member_top2_gap,
            ])
            X_non_member = np.column_stack([
                non_member_probs, non_member_entropy, non_member_confidence,
                non_member_correct, non_member_top2_gap,
            ])

        # elif feature_type == 'loss_features':
        #     # Cross-entropy loss and margin-based features
        #     member_loss = member_losses
        #     non_member_loss = non_member_losses
        #
        #     # Prediction margin
        #     member_margin = member_probs[np.arange(len(member_probs)), member_preds.astype(int)] - \
        #                     np.max(member_probs[np.arange(len(member_probs)) != member_preds.astype(int)[:, None]],
        #                            axis=1)
        #     non_member_margin = non_member_probs[np.arange(len(non_member_probs)), non_member_preds.astype(int)] - \
        #                         np.max(non_member_probs[
        #                                    np.arange(len(non_member_probs)) != non_member_preds.astype(int)[:, None]],
        #                                axis=1)
        #
        #     # Correctness
        #     member_correct = (member_preds == member_labels).astype(float)
        #     non_member_correct = (non_member_preds == non_member_labels).astype(float)
        #
        #     X_member = np.column_stack([member_loss, member_margin, member_correct])
        #     X_non_member = np.column_stack([non_member_loss, non_member_margin, non_member_correct])

        elif feature_type == 'confidence_only_enhanced':
            confidence_member = np.max(member_probs, axis=1)
            confidence_non_member = np.max(non_member_probs, axis=1)
            correct_member = (member_preds == member_labels).astype(int)
            correct_non_member = (non_member_preds == non_member_labels).astype(int)

            # Add entropy and loss
            member_entropy = -np.sum(member_probs * np.log(member_probs + 1e-12), axis=1)
            non_member_entropy = -np.sum(non_member_probs * np.log(non_member_probs + 1e-12), axis=1)

            X_member = np.column_stack([confidence_member, correct_member, member_entropy, member_losses])
            X_non_member = np.column_stack(
                [confidence_non_member, correct_non_member, non_member_entropy, non_member_losses])

        else:
            # Fall back to basic implementation for other types
            return self.prepare_basic_attack_dataset(member_loader, non_member_loader, feature_type)

        y_member = np.ones(len(member_probs))
        y_non_member = np.zeros(len(non_member_probs))

        X = np.vstack([X_member, X_non_member])
        y = np.concatenate([y_member, y_non_member])

        return X, y, non_member_acc

    def prepare_basic_attack_dataset(self, member_loader, non_member_loader, feature_type='probs'):
        """Basic feature preparation (fallback)"""
        member_probs, member_preds, member_labels, _, _ = self.get_model_predictions(member_loader)
        non_member_probs, non_member_preds, non_member_labels, _, acc = self.get_model_predictions(non_member_loader)

        # Create features for attack model based on selected feature type
        if feature_type == 'probs':
            X_member = member_probs
            X_non_member = non_member_probs
        elif feature_type == 'probs_correctness':
            correct_member = (member_preds == member_labels).astype(int)
            correct_non_member = (non_member_preds == non_member_labels).astype(int)
            X_member = np.hstack([member_probs, correct_member.reshape(-1, 1)])
            X_non_member = np.hstack([non_member_probs, correct_non_member.reshape(-1, 1)])
        elif feature_type == 'confidence_only':
            X_member = np.max(member_probs, axis=1).reshape(-1, 1)
            X_non_member = np.max(non_member_probs, axis=1).reshape(-1, 1)
        elif feature_type == 'confidence_correctness':
            confidence_member = np.max(member_probs, axis=1)
            confidence_non_member = np.max(non_member_probs, axis=1)
            correct_member = (member_preds == member_labels).astype(int)
            correct_non_member = (non_member_preds == non_member_labels).astype(int)
            X_member = np.column_stack([confidence_member, correct_member])
            X_non_member = np.column_stack([confidence_non_member, correct_non_member])
        elif feature_type == 'label_only':
            # For label-only attacks
            correct_member = (member_preds == member_labels).astype(int)
            correct_non_member = (non_member_preds == non_member_labels).astype(int)
            X_member = correct_member.reshape(-1, 1)
            X_non_member = correct_non_member.reshape(-1, 1)
        else:
            raise ValueError(
                "Invalid feature_type. Choose from: 'probs', 'probs_correctness', 'confidence_only', 'confidence_correctness', 'label_only'")

        # Create labels (1 for members, 0 for non-members)
        y_member = np.ones(len(member_probs))
        y_non_member = np.zeros(len(non_member_probs))

        # Combine
        X = np.vstack([X_member, X_non_member])
        y = np.concatenate([y_member, y_non_member])

        return X, y, acc

    def evaluate_enhanced_mia(self, member_loader, non_member_loader, test_size=0.3,
                              model_type='mlp', feature_type='enhanced_probs', n_splits=5,
                              epsilon_method='auto'):
        """Enhanced MIA evaluation with cross-validation and empirical epsilon"""
        X, y, non_member_acc = self.prepare_enhanced_attack_dataset(member_loader, non_member_loader, feature_type)

        # Use cross-validation for more robust evaluation
        if model_type == 'mlp':
            from sklearn.neural_network import MLPClassifier
            attack_model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000,
                                         random_state=42, early_stopping=True)
        elif model_type == 'svm':
            from sklearn.svm import SVC
            attack_model = SVC(probability=True, kernel='rbf', random_state=42)
        elif model_type == 'gradient_boosting':
            from sklearn.ensemble import GradientBoostingClassifier
            attack_model = GradientBoostingClassifier(n_estimators=100, random_state=42)
        elif model_type == 'logistic':
            from sklearn.linear_model import LogisticRegression
            attack_model = LogisticRegression(max_iter=1000, class_weight='balanced', random_state=42)
        elif model_type == 'random_forest':
            from sklearn.ensemble import RandomForestClassifier
            attack_model = RandomForestClassifier(n_estimators=100, random_state=42)
        else:
            raise ValueError(f"Unsupported model_type: {model_type}")

        # Cross-validated evaluation
        from sklearn.model_selection import cross_validate, StratifiedKFold
        from sklearn.base import clone
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
        scoring = ['accuracy', 'roc_auc', 'precision', 'recall', 'f1']

        cv_results = cross_validate(attack_model, X, y, cv=cv, scoring=scoring,
                                    return_train_score=False, n_jobs=-1)

        # Calculate TPR @ 0.1% FPR using cross-validation
        tpr_low_fpr_scores = []
        for train_idx, test_idx in cv.split(X, y):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]

            # Clone and train the model
            model_clone = clone(attack_model)
            model_clone.fit(X_train, y_train)
            y_pred_proba = model_clone.predict_proba(X_test)[:, 1]

            # Calculate TPR @ 0.1% FPR
            tpr_at_low_fpr = self.privacy_metrics.tpr_at_low_fpr(y_test, y_pred_proba, target_fpr=0.001)
            tpr_low_fpr_scores.append(tpr_at_low_fpr)

        # Calculate confidence intervals
        auc_mean, auc_ci = self.privacy_metrics.calculate_confidence_interval(cv_results['test_roc_auc'])
        accuracy_mean, accuracy_ci = self.privacy_metrics.calculate_confidence_interval(cv_results['test_accuracy'])
        tpr_low_fpr_mean, tpr_low_fpr_ci = self.privacy_metrics.calculate_confidence_interval(tpr_low_fpr_scores)

        # Calculate enhanced metrics with empirical epsilon
        privacy_risk = self.privacy_metrics.calculate_privacy_risk(auc_mean)
        effective_eps = self.privacy_metrics.empirical_epsilon(auc_mean, method=epsilon_method)

        return {
            'non_member_acc':non_member_acc,
            'cv_mean_auc': auc_mean,
            'cv_std_auc': np.std(cv_results['test_roc_auc']),
            'cv_ci_auc': auc_ci,
            'cv_mean_accuracy': accuracy_mean,
            'cv_ci_accuracy': accuracy_ci,
            'cv_mean_precision': np.mean(cv_results['test_precision']),
            'cv_mean_recall': np.mean(cv_results['test_recall']),
            'cv_mean_f1': np.mean(cv_results['test_f1']),
            'cv_mean_tpr_at_0.1pct_fpr': tpr_low_fpr_mean,
            'cv_ci_tpr_at_0.1pct_fpr': tpr_low_fpr_ci,
            'cv_std_tpr_at_0.1pct_fpr': np.std(tpr_low_fpr_scores),
            'privacy_risk': privacy_risk,
            'effective_epsilon': effective_eps,
            'epsilon_method': epsilon_method,
            'model_type': model_type,
            'feature_type': feature_type,
            'cv_scores': cv_results,
            'tpr_low_fpr_scores': tpr_low_fpr_scores,
            'calibration_info': self.privacy_metrics.get_calibration_info()
        }

    def run_enhanced_privacy_audit(self, member_loader, non_member_loader, n_splits=5, epsilon_method='auto'):
        """Run comprehensive enhanced privacy evaluation with empirical epsilon"""

        attack_configs = [
            {'feature_type': 'enhanced_probs', 'model_type': 'mlp', 'name': 'MLP-Enhanced'},
            {'feature_type': 'confidence_correctness', 'model_type': 'mlp', 'name': 'MLP-ConfCorrect'},
            {'feature_type': 'label_only', 'model_type': 'gradient_boosting', 'name': 'GB-LabelOnly'},
            # {'feature_type': 'confidence_only_enhanced', 'model_type': 'random_forest', 'name': 'RF-ConfEnhanced'},
        ]

        results = {}
        all_auc_scores = []
        all_tpr_low_fpr_scores = []

        print("=== ENHANCED PRIVACY AUDIT ===")
        print(f"Using {n_splits}-fold cross-validation")
        print(f"Epsilon method: {epsilon_method}")
        print(f"Calibration: {self.privacy_metrics.get_calibration_info()}")
        print("-" * 50)

        for config in attack_configs:
            try:
                result = self.evaluate_enhanced_mia(
                    member_loader, non_member_loader,
                    feature_type=config['feature_type'],
                    model_type=config['model_type'],
                    n_splits=n_splits,
                    epsilon_method=epsilon_method
                )

                results[config['name']] = result
                all_auc_scores.append(result['cv_mean_auc'])
                all_tpr_low_fpr_scores.append(result['cv_mean_tpr_at_0.1pct_fpr'])

                print(f"\n{config['name']}:")
                print(f"  AUC: {result['cv_mean_auc']:.4f} ± {result['cv_ci_auc']:.4f} (95% CI)")
                print(
                    f"  TPR @ 0.1% FPR: {result['cv_mean_tpr_at_0.1pct_fpr']:.4f} ± {result['cv_ci_tpr_at_0.1pct_fpr']:.4f}")
                print(f"  Privacy Risk: {result['privacy_risk']:.3f}")
                print(f"  Effective ε: {result['effective_epsilon']:.4f} ({result['epsilon_method']})")
                print(f"  Accuracy: {result['cv_mean_accuracy']:.4f} ± {result['cv_ci_accuracy']:.4f}")

            except Exception as e:
                print(f"\n{config['name']} failed: {e}")
                continue
            # print(results['non_member_acc'])
        # Overall assessment
        if all_auc_scores:
            max_auc = max(all_auc_scores)
            avg_auc = np.mean(all_auc_scores)
            max_tpr_low_fpr = max(all_tpr_low_fpr_scores)
            avg_tpr_low_fpr = np.mean(all_tpr_low_fpr_scores)
            privacy_leakage = self.privacy_metrics.privacy_leakage_score(all_auc_scores)
            overall_epsilon = self.privacy_metrics.empirical_epsilon(max_auc, method=epsilon_method)

            print(f"\n" + "=" * 60)
            print("=== OVERALL PRIVACY ASSESSMENT ===")
            print(f"Maximum AUC: {max_auc:.4f}")
            print(f"Average AUC: {avg_auc:.4f}")
            print(f"Maximum TPR @ 0.1% FPR: {max_tpr_low_fpr:.4f}")
            print(f"Average TPR @ 0.1% FPR: {avg_tpr_low_fpr:.4f}")
            print(f"Privacy Leakage Score: {privacy_leakage:.3f}")
            print(f"Overall Empirical ε: {overall_epsilon:.4f}")

            # Privacy level assessment based on both AUC and empirical epsilon
            if max_auc < 0.55 and overall_epsilon < 0.1:
                privacy_level = "✅ STRONG PRIVACY"
                explanation = "All attacks near random guessing (AUC < 0.55, ε < 0.1)"
            elif max_auc < 0.65 and overall_epsilon < 0.5:
                privacy_level = "⚠️ MODERATE PRIVACY"
                explanation = "Some leakage detectable (0.55 ≤ AUC < 0.65, ε < 0.5)"
            elif max_auc < 0.75 and overall_epsilon < 1.0:
                privacy_level = "🔸 WEAK PRIVACY"
                explanation = "Significant membership leakage (0.65 ≤ AUC < 0.75, ε < 1.0)"
            else:
                privacy_level = "❌ MINIMAL PRIVACY"
                explanation = "Strong membership leakage (AUC ≥ 0.75 or ε ≥ 1.0)"

            print(f"Privacy Level: {privacy_level}")
            print(f"Explanation: {explanation}")

        return {
            'results': results,
            'overall_metrics': {
                'max_auc': max_auc,
                'avg_auc': avg_auc,
                'max_tpr_at_0.1pct_fpr': max_tpr_low_fpr,
                'avg_tpr_at_0.1pct_fpr': avg_tpr_low_fpr,
                'privacy_leakage': privacy_leakage,
                'empirical_epsilon': overall_epsilon,
                'privacy_level': privacy_level
            }
        }

    def statistical_significance_test(self, member_loader, non_member_loader, n_bootstraps=1000):
        """Label-permutation test of whether the MIA AUC beats chance.

        The observed AUC (5-fold CV of an MLP on the ``'enhanced_probs'``
        features) is compared with the AUCs obtained after shuffling the
        membership labels ``n_bootstraps`` times (3-fold CV each).

        Args:
            member_loader: loader over member samples.
            non_member_loader: loader over non-member samples.
            n_bootstraps: number of label permutations.

        Returns:
            dict with the observed AUC, mean/std of the null AUCs, the
            p-value and a boolean ``'significant'`` flag (p < 0.05).
        """
        print("\n" + "=" * 50)
        print("STATISTICAL SIGNIFICANCE TEST")
        print("=" * 50)

        # The dataset builder returns (X, y, non_member_acc); the accuracy is
        # not needed here (unpacking into two names used to raise ValueError).
        X, y, _ = self.prepare_enhanced_attack_dataset(member_loader, non_member_loader, 'enhanced_probs')

        # Observed AUC on the true membership labels
        model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=1000, random_state=42)
        cv_scores = cross_val_score(model, X, y, cv=5, scoring='roc_auc')
        observed_auc = np.mean(cv_scores)

        # Null distribution: AUCs when labels carry no membership information
        bootstrap_aucs = []

        for _ in tqdm(range(n_bootstraps), desc="Bootstrapping"):
            y_shuffled = np.random.permutation(y)

            # Fewer folds than above to keep each permutation cheap
            cv_aucs = cross_val_score(model, X, y_shuffled, cv=3, scoring='roc_auc')
            bootstrap_aucs.append(np.mean(cv_aucs))

        # Fraction of null AUCs at least as large as the observed one.
        # NOTE: this estimator can be exactly 0 when no permutation reaches
        # the observed AUC.
        p_value = np.mean(np.array(bootstrap_aucs) >= observed_auc)

        print("\nStatistical Significance Results:")
        print(f"Observed AUC: {observed_auc:.4f}")
        print(f"Bootstrap AUC (null): {np.mean(bootstrap_aucs):.4f} ± {np.std(bootstrap_aucs):.4f}")
        print(f"p-value: {p_value:.6f}")
        print(f"Statistically significant (α=0.05): {p_value < 0.05}")

        if p_value < 0.001:
            significance = "*** (p < 0.001)"
        elif p_value < 0.01:
            significance = "** (p < 0.01)"
        elif p_value < 0.05:
            significance = "* (p < 0.05)"
        else:
            significance = "not significant"

        print(f"Significance level: {significance}")

        return {
            'observed_auc': observed_auc,
            'bootstrap_mean_auc': np.mean(bootstrap_aucs),
            'bootstrap_std_auc': np.std(bootstrap_aucs),
            'p_value': p_value,
            'significant': p_value < 0.05
        }

    def plot_enhanced_audit_results(self, audit_results, save_path=None):
        """Visualize enhanced audit results"""
        if not audit_results or 'results' not in audit_results:
            print("No audit results to plot")
            return

        results = audit_results['results']

        # Extract data for plotting
        attack_names = []
        auc_means = []
        auc_cis = []
        privacy_risks = []
        epsilons = []
        tpr_low_fpr_means = []
        tpr_low_fpr_cis = []

        for attack_name, result in results.items():
            attack_names.append(attack_name)
            auc_means.append(result['cv_mean_auc'])
            auc_cis.append(result['cv_ci_auc'])
            privacy_risks.append(result['privacy_risk'])
            epsilons.append(result['effective_epsilon'])
            tpr_low_fpr_means.append(result['cv_mean_tpr_at_0.1pct_fpr'])
            tpr_low_fpr_cis.append(result['cv_ci_tpr_at_0.1pct_fpr'])

        # Create figure
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))

        # Plot 1: AUC with confidence intervals
        x_pos = np.arange(len(attack_names))
        bars = ax1.bar(x_pos, auc_means, yerr=auc_cis, capsize=5, color='lightblue',
                       alpha=0.7, edgecolor='navy')
        ax1.set_ylabel('MIA AUC Score')
        ax1.set_title('MIA Performance with 95% Confidence Intervals')
        ax1.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='Random Guessing')
        ax1.axhline(y=0.55, color='orange', linestyle='--', alpha=0.5, label='Privacy Threshold')
        ax1.set_xticks(x_pos)
        ax1.set_xticklabels(attack_names, rotation=45, ha='right')
        ax1.legend()
        ax1.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar, auc, ci in zip(bars, auc_means, auc_cis):
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width() / 2., height + ci + 0.01,
                     f'{auc:.3f}', ha='center', va='bottom', fontsize=9)

        # Plot 2: TPR @ 0.1% FPR
        bars2 = ax2.bar(attack_names, tpr_low_fpr_means, yerr=tpr_low_fpr_cis, capsize=5,
                        color='lightcoral', alpha=0.7, edgecolor='darkred')
        ax2.set_ylabel('TPR @ 0.1% FPR')
        ax2.set_title('True Positive Rate at Low False Positive Rate (0.1%)')
        ax2.axhline(y=0.01, color='green', linestyle='--', alpha=0.7, label='Strong Privacy')
        ax2.axhline(y=0.05, color='orange', linestyle='--', alpha=0.7, label='Moderate Privacy')
        ax2.axhline(y=0.1, color='red', linestyle='--', alpha=0.7, label='Weak Privacy')
        ax2.set_xticklabels(attack_names, rotation=45, ha='right')
        ax2.legend()
        ax2.grid(True, alpha=0.3)

        # Add value labels on bars
        for bar, tpr, ci in zip(bars2, tpr_low_fpr_means, tpr_low_fpr_cis):
            height = bar.get_height()
            ax2.text(bar.get_x() + bar.get_width() / 2., height + ci + 0.001,
                     f'{tpr:.3f}', ha='center', va='bottom', fontsize=9)

        # Plot 3: Privacy Risk Assessment
        risk_colors = ['green' if risk < 0.1 else 'orange' if risk < 0.3 else 'red'
                       for risk in privacy_risks]
        bars3 = ax3.bar(attack_names, privacy_risks, color=risk_colors, alpha=0.7)
        ax3.set_ylabel('Privacy Risk Score')
        ax3.set_title('Privacy Risk Assessment by Attack Type')
        ax3.axhline(y=0.1, color='green', linestyle='--', alpha=0.7, label='Low Risk')
        ax3.axhline(y=0.3, color='orange', linestyle='--', alpha=0.7, label='Medium Risk')
        ax3.set_xticklabels(attack_names, rotation=45, ha='right')
        ax3.legend()
        ax3.grid(True, alpha=0.3)

        # Plot 4: Overall Privacy Summary
        categories = ['Max AUC', 'Avg AUC', 'Max TPR@0.1%FPR', 'Leakage Score']
        values = [
            audit_results['overall_metrics']['max_auc'],
            audit_results['overall_metrics']['avg_auc'],
            audit_results['overall_metrics']['max_tpr_at_0.1pct_fpr'],
            audit_results['overall_metrics']['privacy_leakage']
        ]

        colors = ['red', 'orange', 'purple', 'blue']
        bars4 = ax4.bar(categories, values, color=colors, alpha=0.7)
        ax4.set_ylabel('Score')
        ax4.set_title('Overall Privacy Summary')
        ax4.axhline(y=0.5, color='red', linestyle='--', alpha=0.5, label='Random Baseline (AUC)')
        ax4.axhline(y=0.1, color='purple', linestyle='--', alpha=0.5, label='Weak Privacy (TPR)')

        # Add value labels
        for bar, value in zip(bars4, values):
            height = bar.get_height()
            ax4.text(bar.get_x() + bar.get_width() / 2., height + 0.01,
                     f'{value:.3f}', ha='center', va='bottom')

        ax4.legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Plot saved to {save_path}")

        plt.show()

    def run_complete_privacy_audit(self, member_loader, non_member_loader, n_splits=5, n_bootstraps=1000):
        """Run the enhanced privacy audit, plot it, and hand back the results.

        Only two steps are active at the moment: the enhanced audit and the
        plotting of its results. The statistical significance test and the
        final text report are disabled, so ``n_bootstraps`` is accepted for
        interface compatibility but is not used.

        Args:
            member_loader: passed through to ``run_enhanced_privacy_audit``.
            non_member_loader: passed through to ``run_enhanced_privacy_audit``.
            n_splits: passed through to ``run_enhanced_privacy_audit``.
            n_bootstraps: currently unused (see above).

        Returns:
            dict with the single key ``'audit_results'`` holding whatever
            ``run_enhanced_privacy_audit`` returned. There is no
            ``'statistical_test'`` entry while the significance test is
            disabled.
        """
        print("🚀 STARTING COMPLETE PRIVACY AUDIT", "=" * 60, sep="\n")

        # Step 1: enhanced audit; step 3: plot what it produced.
        enhanced = self.run_enhanced_privacy_audit(member_loader, non_member_loader, n_splits)
        self.plot_enhanced_audit_results(enhanced)

        # NOTE: steps 2 and 4 are switched off for now:
        #   stats_results = self.statistical_significance_test(
        #       member_loader, non_member_loader, n_bootstraps)
        #   self.generate_final_report(audit_results, stats_results)
        # Re-enabling them would also add 'statistical_test' to the result.
        outcome = {'audit_results': enhanced}
        return outcome

    def generate_final_report(self, audit_results, stats_results):
        """Print the final privacy assessment report to stdout.

        The report has four sections: a performance summary, the statistical
        analysis, a recommendation tier chosen from the maximum attack AUC and
        the maximum TPR at 0.1% FPR, and a fixed list of technical details.

        Args:
            audit_results: mapping whose ``'overall_metrics'`` entry provides
                ``max_auc``, ``avg_auc``, ``max_tpr_at_0.1pct_fpr``,
                ``avg_tpr_at_0.1pct_fpr``, ``privacy_leakage`` and
                ``privacy_level``.
            stats_results: mapping providing ``significant``, ``p_value``,
                ``observed_auc`` and ``bootstrap_mean_auc``.

        Returns:
            None. The report is only printed.
        """
        rule = "=" * 70
        print("\n" + rule)
        print("FINAL PRIVACY ASSESSMENT REPORT")
        print(rule)

        metrics = audit_results['overall_metrics']

        # (label, metrics key, format spec) -- printed in this order.
        summary_rows = (
            ("Maximum MIA AUC", 'max_auc', ".4f"),
            ("Average MIA AUC", 'avg_auc', ".4f"),
            ("Maximum TPR @ 0.1% FPR", 'max_tpr_at_0.1pct_fpr', ".4f"),
            ("Average TPR @ 0.1% FPR", 'avg_tpr_at_0.1pct_fpr', ".4f"),
            ("Privacy Leakage Score", 'privacy_leakage', ".3f"),
            ("Overall Privacy Level", 'privacy_level', ""),
        )
        print("\n📊 PERFORMANCE SUMMARY:")
        for label, key, spec in summary_rows:
            print(f"   {label}: {metrics[key]:{spec}}")

        print("\n📈 STATISTICAL ANALYSIS:")
        print(f"   Statistical Significance: {stats_results['significant']}")
        print(f"   p-value: {stats_results['p_value']:.6f}")
        observed = stats_results['observed_auc']
        null_mean = stats_results['bootstrap_mean_auc']
        print(f"   Observed vs Null AUC: {observed:.4f} vs {null_mean:.4f}")

        # Tiers are checked from strictest to loosest; both the AUC and the
        # TPR must fall strictly below a tier's caps for it to apply.
        tiers = (
            (0.55, 0.01,
             "   ✅ EXCELLENT - Model provides strong privacy protection",
             "   Recommendation: Suitable for sensitive data applications"),
            (0.65, 0.05,
             "   ⚠️ ACCEPTABLE - Moderate privacy protection",
             "   Recommendation: Acceptable for most applications, monitor for new attacks"),
            (0.75, 0.1,
             "   🔸 CAUTION - Weak privacy protection",
             "   Recommendation: Not suitable for sensitive data, consider enhancement"),
        )
        print("\n🛡️ PRIVACY RECOMMENDATIONS:")
        worst_auc = metrics['max_auc']
        worst_tpr = metrics['max_tpr_at_0.1pct_fpr']
        for auc_cap, tpr_cap, verdict, advice in tiers:
            if worst_auc < auc_cap and worst_tpr < tpr_cap:
                break
        else:
            # No tier matched: fall through to the weakest rating.
            verdict = "   ❌ UNACCEPTABLE - Minimal privacy protection"
            advice = "   Recommendation: Implement stronger privacy mechanisms"
        print(verdict)
        print(advice)

        print("\n🔍 TECHNICAL DETAILS:")
        for detail in (
            "   This audit used:",
            "   - 6 different attack configurations",
            "   - 5-fold cross-validation",
            "   - Enhanced feature engineering",
            "   - Statistical significance testing",
            "   - Multiple classifier types (MLP, SVM, GB, RF)",
        ):
            print(detail)


# Example usage function
def example_usage():
    """Print a template showing how to drive ``EnhancedPrivacyEvaluator``.

    Nothing is executed against a real model: the function only prints a
    snippet to adapt. The snippet assumes the caller already has:

    - ``model``: the trained model
    - ``device``: the torch device
    - ``member_loader``: DataLoader for training (member) data
    - ``non_member_loader``: DataLoader for test (non-member) data

    The template reads the optional ``'statistical_test'`` entry with
    ``dict.get`` because ``run_complete_privacy_audit`` only returns
    ``'audit_results'`` while its significance test is disabled; indexing the
    key directly would raise ``KeyError``.

    Returns:
        None.
    """
    print("Example Usage Template:")
    print("""
    # Initialize evaluator
    evaluator = EnhancedPrivacyEvaluator(model, device, num_classes=10)

    # Run complete audit
    results = evaluator.run_complete_privacy_audit(
        member_loader=member_loader,
        non_member_loader=non_member_loader,
        n_splits=5,
        n_bootstraps=1000
    )

    # Access individual results
    audit_results = results['audit_results']
    # 'statistical_test' is only present when the significance test is enabled
    stats_results = results.get('statistical_test')
    """)