import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import copy
import random
from torchvision.transforms.autoaugment import AutoAugmentPolicy
import matplotlib.pyplot as plt
import pickle
import os
import numpy as np
import warnings
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src.ResNet import ResNet18
import pandas as pd
# Silence every warning category for the whole process (this also hides
# deprecation notices raised by torch / torchvision).
warnings.filterwarnings("ignore")

# Prefer the GPU when one is visible; this module-level device is shared by the
# functions below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Constants used to turn the mean cosine similarity into a z-score:
#     z = (signal - mean) / std
# (see attack_train and run_attack_experiment).
# NOTE(review): presumably the mean/std of the statistic for a non-watermarked
# model -- confirm against the watermark-embedding code.
mean = 0.0
std = 0.01

# -------------------------------------------------------------------
# 1) EVALUATION AND METRICS
# -------------------------------------------------------------------
def evaluate(model, data_loader):
    """
    Return the top-1 accuracy (in percent) of `model` over `data_loader`.

    The model is switched to eval mode and is left in eval mode on return.
    Batches are moved to the device the model's parameters live on; the
    module-level `device` is only used for a model without parameters.

    Args:
        model: module mapping an image batch to per-class scores.
        data_loader: iterable yielding `(images, labels)` tensor pairs.

    Returns:
        Accuracy as a float in [0, 100]; 0.0 when the loader yields no samples.
    """
    model.eval()

    # Follow the model rather than the global device so the two cannot disagree.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(target_device), labels.to(target_device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100.0 * correct / total

def compute_watermark_signal(model, flip_vectors):
    """
    Return the mean cosine similarity between model parameters and flip vectors.

    For every named parameter that has a non-None entry in `flip_vectors`, the
    flattened parameter is compared with the flattened flip vector. The
    similarities are averaged over the matched parameters only. No gradient is
    recorded.

    Args:
        model: module whose `named_parameters()` are inspected.
        flip_vectors: mapping from parameter name to a tensor with the same
            number of elements as that parameter (or None to skip it).

    Returns:
        A Python float in [-1, 1]; 0.0 when no parameter has a flip vector.
    """
    cos_sims = []
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in flip_vectors or flip_vectors[name] is None:
                continue
            # Compare on the parameter's own device so the result does not
            # depend on where the flip vectors happened to be loaded.
            reference = flip_vectors[name].reshape(-1).to(param.device)
            cos_sims.append(F.cosine_similarity(param.reshape(-1), reference, dim=0))

    if not cos_sims:
        return 0.0
    # Single host transfer for the whole model instead of one per parameter.
    return torch.stack(cos_sims).sum().item() / len(cos_sims)

def load_attack_key(k_value, seed, c, batch_size):
    """
    Return the pre-generated attack key for this configuration, or None.

    Each candidate path is tried in order; the first file that exists and loads
    with `torch.load` wins. A file that exists but fails to load is reported and
    skipped. When nothing could be loaded, the searched paths are listed along
    with whether each one exists.
    """
    candidates = [
        f"unntrusted_cifar10_attack_keys/cifar10_attack_key_K{k_value}_lr0.001_c{c}_steps300_bs{batch_size}_seed{seed}.pt",
    ]

    for candidate in candidates:
        if not os.path.exists(candidate):
            continue
        try:
            loaded_key = torch.load(candidate, map_location=device)
        except Exception as e:
            print(f"Error loading attack key from {candidate}: {e}")
            continue
        print(f"Attack key loaded from: {candidate}")
        return loaded_key

    # Nothing usable was found: show what was tried.
    print("Attack key not found. Searched paths:")
    for candidate in candidates:
        marker = '✓' if os.path.exists(candidate) else '✗'
        print(f"  - {candidate} {marker}")

    return None

def compute_watermark_signal_sum(model, flip_vectors):
    """
    Differentiable mean cosine similarity between parameters and flip vectors.

    Despite the name, the value is the mean (not the sum) of the per-parameter
    cosine similarities, taken over parameters that have a non-None entry in
    `flip_vectors`. Gradients flow back to the model parameters, so the result
    can be used as a loss term.

    Args:
        model: module whose `named_parameters()` are inspected.
        flip_vectors: mapping from parameter name to a tensor with the same
            number of elements as that parameter (or None to skip it).

    Returns:
        A 0-dim tensor. When no parameter matches, a zero tensor is returned
        (instead of a Python float) so callers can apply tensor ops such as
        `torch.abs` unconditionally.
    """
    cos_sims = []
    fallback_device = None
    for name, param in model.named_parameters():
        if fallback_device is None:
            fallback_device = param.device
        if name not in flip_vectors or flip_vectors[name] is None:
            continue
        # Compare on the parameter's own device.
        reference = flip_vectors[name].reshape(-1).to(param.device)
        cos_sims.append(F.cosine_similarity(param.reshape(-1), reference, dim=0))

    if not cos_sims:
        return torch.zeros((), device=fallback_device)
    return torch.stack(cos_sims).mean()

def get_cosine_similarity_gradient(target, source):
    """
    Return the mean cosine similarity between matching tensors of two mappings.

    Every entry of `target` whose name also appears in `source` with a non-None
    value is flattened and compared with the corresponding `source` tensor.
    The `target` tensor is moved to the `source` tensor's device first.

    Args:
        target: mapping from name to tensor.
        source: mapping from name to tensor (or None to skip that name).

    Returns:
        A Python float; 0.0 when the two mappings share no usable name.
    """
    cos_sims = []
    for name, tensor in target.items():
        if name not in source or source[name] is None:
            continue
        reference = source[name]
        # After .to(reference.device) the devices always agree, so no separate
        # mismatch check is needed.
        cos_sims.append(
            F.cosine_similarity(
                tensor.to(reference.device).reshape(-1),
                reference.reshape(-1),
                dim=0,
            )
        )

    if not cos_sims:
        return 0.0
    return torch.stack(cos_sims).sum().item() / len(cos_sims)

# -------------------------------------------------------------------
# 2) UPPER BOUND COMPUTATION AND PLOTTING FUNCTIONS
# -------------------------------------------------------------------
def compute_upper_bound_across_seeds(seed_results):
    """
    Pick, for every (data_percentage, alpha), the best run across seeds.

    "Best" means the highest composite score
        final_accuracy - 10 * final_z_score
    i.e. a linear trade-off that rewards accuracy and penalises a remaining
    watermark signal. Ties go to the seed encountered first.

    Args:
        seed_results: `{seed: {data_percentage: [result_or_None, ...]}}` where
            each result is a dict with at least the keys 'alpha', 'seed',
            'final_accuracy' and 'final_z_score'.

    Returns:
        `{data_percentage: [best_result, ...]}` with data percentages in
        ascending order and each list ordered by ascending alpha. Every entry
        is a shallow copy of the winning result with an extra 'best_seed' key;
        the dicts in `seed_results` are not modified.
    """
    z_score_penalty = 10

    def composite_score(result):
        return result['final_accuracy'] - result['final_z_score'] * z_score_penalty

    data_percentages = set()
    for seed_data in seed_results.values():
        data_percentages.update(seed_data.keys())

    upper_bound_results = {}
    # Sorted so the returned dict and the log lines have a stable order.
    for data_pct in sorted(data_percentages):
        # One pass over all seeds, grouping successful runs by alpha.
        runs_by_alpha = {}
        for seed_data in seed_results.values():
            for result in seed_data.get(data_pct, ()):
                if result is not None:
                    runs_by_alpha.setdefault(result['alpha'], []).append(result)

        winners = []
        for alpha in sorted(runs_by_alpha):
            best_run = max(runs_by_alpha[alpha], key=composite_score)

            # Copy before annotating so the caller's per-seed results stay untouched.
            winner = dict(best_run)
            winner['best_seed'] = best_run['seed']
            winners.append(winner)

            print(f"Data {data_pct}%, Alpha {alpha}: Best result from seed {winner['seed']} "
                  f"(Acc: {winner['final_accuracy']:.2f}%, Z: {winner['final_z_score']:.4f})")

        upper_bound_results[data_pct] = winners

    return upper_bound_results

def compute_pareto_frontier(epoch_data_list):
    """
    Return the non-dominated (accuracy, z-score) points over all given runs.

    `epoch_data_list` is a list of per-run epoch histories (None entries are
    skipped); every epoch record contributes one point holding its 'accuracy',
    'z_score' and 'epoch'. A point is dropped when some other point is at least
    as good on both axes (higher accuracy, lower z-score) and strictly better
    on one. The survivors are returned ordered by ascending z-score.
    """
    candidates = []
    for history in epoch_data_list:
        if history is None:
            continue
        candidates.extend(
            {
                'accuracy': record['accuracy'],
                'z_score': record['z_score'],
                'epoch': record['epoch'],
            }
            for record in history
        )

    if not candidates:
        return []

    def dominates(rival, point):
        # No worse on either axis, and strictly better on at least one.
        no_worse = (rival['accuracy'] >= point['accuracy']
                    and rival['z_score'] <= point['z_score'])
        strictly_better = (rival['accuracy'] > point['accuracy']
                           or rival['z_score'] < point['z_score'])
        return no_worse and strictly_better

    survivors = [
        point for point in candidates
        if not any(rival is not point and dominates(rival, point) for rival in candidates)
    ]

    # Ascending z-score gives a left-to-right curve when plotted.
    return sorted(survivors, key=lambda point: point['z_score'])

def plot_upper_bound_multi_data_percentage_results(upper_bound_results, k_value, c):
    """
    Plot one accuracy-vs-z-score Pareto frontier per data percentage and save it as a PDF.

    For each data percentage, the epoch histories of all its results are merged,
    reduced to their Pareto frontier and drawn as a line with markers, starting
    from the original (pre-attack) model point. A data percentage whose frontier
    has fewer than two points is not drawn. The original model is marked with a
    black star and a vertical line marks Z=4.

    Args:
        upper_bound_results: `{data_percentage: [result_or_None, ...]}` where each
            result has 'initial_accuracy', 'initial_z_score' and 'epoch_data'.
        k_value: K value, used in the title and file name.
        c: C value, used in the title and file name.

    Side effects:
        Writes `finetuning_fig/upper_bound_multi_data_percentage_frontier_K{k}_c{c}_all_seeds.pdf`
        (creating the folder if needed) and closes the figure.
    """
    plt.figure(figsize=(12, 8))
    try:
        # Colours and line styles cycle when there are more than five data percentages.
        colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd']
        linestyles = ['-', '--', '-.', ':', '-']

        data_percentages = sorted(upper_bound_results.keys())

        # The pre-attack state is taken from the first successful result found.
        initial_acc = None
        initial_z_score = None
        for data_pct in data_percentages:
            for result in upper_bound_results[data_pct]:
                if result is not None:
                    initial_acc = result['initial_accuracy']
                    initial_z_score = result['initial_z_score']
                    break
            if initial_acc is not None:
                break

        has_initial_point = initial_acc is not None and initial_z_score is not None

        # Original model as a black star.
        if has_initial_point:
            plt.scatter([initial_z_score], [initial_acc], s=200, marker='*',
                        color='black', alpha=1.0, edgecolors='white', linewidths=2,
                        label='Original Model', zorder=10)

        for i, data_pct in enumerate(data_percentages):
            all_epoch_data = [
                result['epoch_data']
                for result in upper_bound_results[data_pct]
                if result is not None
            ]
            if not all_epoch_data:
                continue

            frontier_points = compute_pareto_frontier(all_epoch_data)
            if len(frontier_points) <= 1:
                continue

            curve = [(p['z_score'], p['accuracy']) for p in frontier_points]
            if has_initial_point:
                # Start the curve at the original model.
                curve.insert(0, (initial_z_score, initial_acc))
            # Stable sort by z-score only, so the line runs left to right.
            curve.sort(key=lambda point: point[0])

            sorted_z = [point[0] for point in curve]
            sorted_acc = [point[1] for point in curve]

            color = colors[i % len(colors)]
            linestyle = linestyles[i % len(linestyles)]

            plt.plot(sorted_z, sorted_acc, color=color, linestyle=linestyle,
                     linewidth=3, alpha=0.8, label=f'{data_pct}% Data',
                     marker='o', markersize=6, markerfacecolor=color,
                     markeredgecolor='white', markeredgewidth=1)

            print(f"Data {data_pct}%: Upper bound frontier contains {len(frontier_points)} optimal points")

        plt.xlabel('Watermark Z-Score', fontsize=14, fontweight='bold')
        plt.ylabel('Test Accuracy (%)', fontsize=14, fontweight='bold')
        plt.title(f'Finetuning Attack \n'
                  f'K={k_value}, C={c}',
                  fontsize=16, fontweight='bold')

        plt.axvline(x=4, color='red', linestyle='--', alpha=0.7, linewidth=2,
                    label='Detection Threshold (Z=4)')

        plt.grid(True, linestyle='--', alpha=0.3)
        plt.legend(loc='best', fontsize=12, framealpha=0.9)

        # Show at least up to Z=8, and further if the original model lies beyond that.
        x_max = 8 if initial_z_score is None else max(8, initial_z_score + 1)
        plt.xlim(-0.5, x_max)

        plt.tight_layout()

        filename = f'finetuning_fig/upper_bound_multi_data_percentage_frontier_K{k_value}_c{c}_all_seeds.pdf'
        # savefig does not create missing folders.
        os.makedirs(os.path.dirname(filename), exist_ok=True)
        plt.savefig(filename, bbox_inches='tight', format='pdf')
    finally:
        # Release the figure even when plotting or saving failed.
        plt.close()

    print(f"Upper bound multi-data percentage frontier plot saved: {filename}")

# -------------------------------------------------------------------
# 3) ATTACK TRAINING LOOP
# -------------------------------------------------------------------
def attack_train(model, train_loader, attack_key, alpha, epochs, flip_vectors, lr=0.01, 
                 val_loader=None, min_z_score=4):
    """
    Fine-tune `model` with a loss that also penalises alignment with `attack_key`.

    Each step minimises, with AdamW,
        (1 - alpha) * cross_entropy(outputs, labels)
        + alpha * |mean cosine similarity(model parameters, attack_key)|
    After every epoch the validation accuracy and the watermark z-score
    (computed against `flip_vectors`) are recorded. Training stops early as soon
    as the z-score is <= `min_z_score`.

    Args:
        model: module to fine-tune in place.
        train_loader: iterable of `(images, labels)` batches; must not be None.
        attack_key: mapping from parameter name to tensor used in the penalty term.
        alpha: weight of the penalty term in [0, 1].
        epochs: maximum number of epochs.
        flip_vectors: mapping from parameter name to tensor used for the z-score.
        lr: AdamW learning rate.
        val_loader: optional loader for per-epoch accuracy; accuracy is recorded
            as 0.0 when it is None.
        min_z_score: early-stopping threshold on the z-score.

    Returns:
        `(model, epoch_data)` where `epoch_data` is a list of dicts with keys
        'epoch' (0-based), 'accuracy', 'z_score' and 'loss'.

    Raises:
        ValueError: if `train_loader` is None.
    """
    if train_loader is None:
        raise ValueError("attack_train requires a train_loader, got None")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # One record per completed epoch, used later for plotting.
    epoch_data = []

    print(f"Starting attack training (will continue until z-score <= {min_z_score} or {epochs} epochs)")

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)

            wm_alignment = compute_watermark_signal_sum(model, attack_key)
            # The helper may hand back a plain Python number (e.g. when no
            # parameter matches the key), which torch.abs would reject.
            if torch.is_tensor(wm_alignment):
                wm_penalty = torch.abs(wm_alignment)
            else:
                wm_penalty = abs(wm_alignment)

            loss = (1 - alpha) * criterion(outputs, labels) + alpha * wm_penalty
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        avg_loss = epoch_loss / num_batches if num_batches > 0 else 0.0

        if val_loader is not None:
            current_accuracy = evaluate(model, val_loader)
        else:
            current_accuracy = 0.0  # No validation loader provided

        current_z_score = (compute_watermark_signal(model, flip_vectors) - mean) / std

        epoch_data.append({
            'epoch': epoch,
            'accuracy': current_accuracy,
            'z_score': current_z_score,
            'loss': avg_loss
        })

        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {current_accuracy:.2f}%, Z-score: {current_z_score:.4f}")

        # Stop as soon as the z-score is at or below the requested threshold.
        if current_z_score <= min_z_score:
            print(f"✓ Target achieved! Z-score ({current_z_score:.4f}) ≤ {min_z_score}")
            break

    # Recomputed so a value is reported even when epochs == 0.
    final_z_score = (compute_watermark_signal(model, flip_vectors) - mean) / std
    print(f"Attack training completed. Final Z-score: {final_z_score:.4f}")

    return model, epoch_data

def load_test_data(k_value, seed):
    """
    Build the evaluation DataLoader for the given seed.

    The loader is created from the pickled export for this seed (its "dataset"
    and "shuffle" entries, batch size 128). If anything in that path fails, the
    error is printed and the standard CIFAR-10 test split is downloaded/used
    instead, normalised and unshuffled. `k_value` only appears in log messages.
    """
    pickle_path = f"untrusted_cifar10_exported_data_cifar_10_seed{seed}/test_loader_data.pkl"
    try:
        # NOTE: pickle.load can execute code embedded in the file; only load trusted exports.
        with open(pickle_path, "rb") as handle:
            exported = pickle.load(handle)
        print(f"Test data loaded successfully for K={k_value}, Seed={seed}")
        loader = DataLoader(
            exported["dataset"],
            batch_size=128,
            shuffle=exported["shuffle"],
        )
    except Exception as e:
        print(f"Error loading test data for K={k_value}, Seed={seed}: {e}")
        print("Falling back to creating a new CIFAR-10 test dataset...")
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        fallback_transform = transforms.Compose([transforms.ToTensor(), normalize])
        cifar_test = torchvision.datasets.CIFAR10(
            root="./data",
            train=False,
            download=True,
            transform=fallback_transform,
        )
        loader = DataLoader(cifar_test, batch_size=128, shuffle=False)

    return loader

def load_train_data(k_value, seed, data_percentage, batch_size=128):
    """
    Build the attacker's training DataLoader from a random subset of the exported train set.

    The dataset is read from the pickled export for `seed` and used exactly as
    stored; no transform is applied here. A reproducible random subset of
    `data_percentage` percent (at least one sample) is drawn with a generator
    seeded by `seed`.

    Args:
        k_value: unused; kept for signature compatibility with the other loaders.
        seed: selects the export folder and seeds the subsampling.
        data_percentage: share of the exported train set to keep, in percent.
        batch_size: batch size of the returned loader.

    Returns:
        A shuffling DataLoader over the subset, or None if the export could not
        be read or the loader could not be built (the error is printed).
    """
    train_data_path = f"untrusted_cifar10_exported_data_cifar_10_seed{seed}/train_loader_data.pkl"
    try:
        # NOTE: pickle.load can execute code embedded in the file; only load trusted exports.
        with open(train_data_path, "rb") as f:
            blob = pickle.load(f)

        train_dataset = blob["dataset"]
        total_size = len(train_dataset)
        subset_size = max(1, int(total_size * data_percentage / 100.0))

        # Reproducible subsample: the same seed always selects the same indices.
        g = torch.Generator().manual_seed(seed)
        subset_indices = torch.randperm(total_size, generator=g)[:subset_size].tolist()

        train_loader = DataLoader(
            Subset(train_dataset, subset_indices),
            batch_size=batch_size,
            shuffle=True,
        )
    except Exception as e:
        print(f"Error creating training data: {e}")
        return None

    # Report what the loader really uses (the slice above caps at total_size).
    print(f"Attack training data: {len(subset_indices)}/{total_size} samples "
          f"({data_percentage}% of saved TRAIN set), batch_size={batch_size}")
    return train_loader
    
def run_attack_experiment(k_value, seed, c, data_percentage, alpha):
    """
    Run one fine-tuning attack and report accuracy / z-score before and after.

    Loads the watermarked model, its flip vectors and the attack key for the
    given (K, C, seed), builds the data loaders, fine-tunes with `attack_train`
    for at most 100 epochs (lr 0.001, early stop at z-score <= 4) and evaluates
    the result.

    Args:
        k_value: K value; also determines the `bs` tag in file names as 2048 // K.
        seed: seed tag in file names and for data loading.
        c: C value tag in file names.
        data_percentage: percentage of the exported train set given to the attacker.
        alpha: weight of the attack-key penalty in the training loss.

    Returns:
        A dict with the configuration, 'initial_accuracy', 'initial_z_score',
        'final_accuracy', 'final_z_score' and the per-epoch 'epoch_data';
        or None when a required file or the training data is unavailable.
    """
    print(f"\n{'='*80}")
    print("RUNNING ATTACK EXPERIMENT")
    print(f"K={k_value}, Seed={seed}, C={c}, Data={data_percentage}%, Alpha={alpha}")
    print(f"{'='*80}")

    # File names carry this value as `bs{batch_size}`.
    batch_size = 2048 // k_value

    os.makedirs('finetuning_fig', exist_ok=True)

    # Report the batch size that is really used in the file names below.
    print(f"Looking for files with batch_size={batch_size}")
    print("Model uses steps300, Attack key uses steps300")

    possible_model_paths = [
        f"unntrusted_cifar10_results/highest_validation_accuracy_model_K{k_value}_lr0.001_c{c}_steps300_bs{batch_size}_seed{seed}.pt",
    ]
    possible_flip_paths = [
        f"unntrusted_cifar10_results/cifar10_flip_vectors_K{k_value}_lr0.001_c{c}_steps300_bs{batch_size}_seed{seed}.pt",
    ]

    model_path = next((p for p in possible_model_paths if os.path.exists(p)), None)
    flip_vectors_path = next((p for p in possible_flip_paths if os.path.exists(p)), None)

    if model_path is None:
        print("Model file not found. Searched paths:")
        for path in possible_model_paths:
            print(f"  - {path} {'✓' if os.path.exists(path) else '✗'}")
        return None

    if flip_vectors_path is None:
        print("Flip vectors file not found. Searched paths:")
        for path in possible_flip_paths:
            print(f"  - {path} {'✓' if os.path.exists(path) else '✗'}")
        return None

    print(f"Using model: {model_path}")
    print(f"Using flip vectors: {flip_vectors_path}")

    try:
        state_dict = torch.load(model_path, map_location=device)
        flip_vectors = torch.load(flip_vectors_path, map_location=device)
    except Exception as e:
        print(f"Error loading files: {e}")
        return None

    # Rebuild the architecture, then restore the watermarked weights.
    model = ResNet18()
    model = model.to(device)
    model.load_state_dict(state_dict)

    attack_key = load_attack_key(k_value, seed, c, batch_size)
    if attack_key is None:
        print(f"Attack key not found for K={k_value}, Seed={seed}, C={c}")
        return None

    val_loader = load_test_data(k_value, seed)
    train_loader = load_train_data(k_value, seed, data_percentage)
    if train_loader is None:
        # load_train_data already printed the cause; without data there is nothing to train on.
        print(f"Training data not available for K={k_value}, Seed={seed}, Data={data_percentage}%")
        return None

    # Baseline before the attack.
    initial_acc = evaluate(model, val_loader)
    initial_wm_signal = compute_watermark_signal(model, flip_vectors)
    initial_z_score = (initial_wm_signal - mean) / std
    print(f"Initial model accuracy before attack: {initial_acc:.2f}%")
    print(f"Initial Z-score before attack: {initial_z_score:.4f}")

    attack_epochs = 100
    lr = 0.001

    # Train until the z-score drops to the threshold or the epoch budget is used up.
    model, epoch_data = attack_train(model, train_loader, attack_key, alpha, 
                                   attack_epochs, flip_vectors, lr, val_loader, 
                                   min_z_score=4)

    final_acc = evaluate(model, val_loader)
    final_wm_signal = compute_watermark_signal(model, flip_vectors)
    final_z_score = (final_wm_signal - mean) / std

    print("\nFinal Results:")
    print(f"  Initial Accuracy: {initial_acc:.2f}%")
    print(f"  Final Accuracy: {final_acc:.2f}%")
    print(f"  Accuracy Drop: {initial_acc - final_acc:.2f}%")
    print(f"  Final Watermark Signal: {final_wm_signal:.4f}")
    print(f"  Final Z-score: {final_z_score:.4f}")
    # NOTE(review): this status uses 1.0 while training stops at 4 -- confirm which threshold is intended.
    print(f"  {'✓ Watermark removed!' if final_z_score < 1.0 else '✗ Watermark still detectable'}")

    return {
        'k_value': k_value,
        'seed': seed,
        'c_value': c,
        'data_percentage': data_percentage,
        'alpha': alpha,
        'initial_accuracy': initial_acc,
        'initial_z_score': initial_z_score,
        'final_accuracy': final_acc,
        'final_z_score': final_z_score,
        'epoch_data': epoch_data
    }

# -------------------------------------------------------------------
# 4) MAIN FUNCTION
# -------------------------------------------------------------------
def main():
    """
    Sweep K, C, seed, data percentage and alpha, then summarise each (K, C).

    For every (K, C) pair all seeds / data percentages / alphas are run through
    `run_attack_experiment`; failed runs are stored as None. The best run per
    (data percentage, alpha) across seeds is then selected, plotted and printed
    as a table.
    """
    # Experiment grid
    k_values = [32]
    c_values = [0.025,0.05,0.075,0.1]
    seeds = [0, 1, 2]
    data_percentages = [1, 5, 10, 20]
    candidate_alphas = [0]
    alpha_count = len(candidate_alphas)

    def banner(width, *lines):
        """Print the given lines framed by '=' rules, preceded by a blank line."""
        rule = '=' * width
        print(f"\n{rule}")
        for line in lines:
            print(line)
        print(rule)

    for k_value in k_values:
        for c in c_values:
            print(f"Starting multi-seed, multi-data-percentage attack experiments for K={k_value}, C={c}")
            print(f"Testing seeds: {seeds}")
            print(f"Testing data percentages: {data_percentages}")
            print(f"Testing {alpha_count} alpha values: {candidate_alphas}")

            # {seed: {data_pct: [result_or_None per alpha]}}
            all_seed_results = {}

            for seed in seeds:
                banner(100, f"TESTING SEED: {seed}")

                per_percentage = {}

                for data_pct in data_percentages:
                    banner(80, f"SEED {seed} - TESTING DATA PERCENTAGE: {data_pct}%")

                    outcomes = []

                    for position, alpha in enumerate(candidate_alphas, start=1):
                        banner(60,
                               f"SEED {seed} - DATA {data_pct}% - ALPHA EXPERIMENT {position}/{alpha_count}",
                               f"Testing Alpha = {alpha}")

                        try:
                            outcome = run_attack_experiment(k_value, seed, c, data_pct, alpha)
                            outcomes.append(outcome)

                            if outcome is None:
                                print(f"✗ Seed {seed} - Data {data_pct}% - Alpha {alpha} training failed - missing files")
                            else:
                                print(f"✓ Seed {seed} - Data {data_pct}% - Alpha {alpha} training completed successfully")

                        except Exception as e:
                            # Keep the sweep going; a failed run is recorded as None.
                            print(f"✗ Error in seed {seed} - data {data_pct}% - alpha {alpha} experiment: {e}")
                            outcomes.append(None)

                    per_percentage[data_pct] = outcomes

                    completed = sum(1 for outcome in outcomes if outcome is not None)
                    print(f"\nSeed {seed} - Data {data_pct}%: Completed {completed}/{alpha_count} experiments")

                all_seed_results[seed] = per_percentage

            banner(80, "COMPUTING UPPER BOUND ACROSS SEEDS")
            upper_bound_results = compute_upper_bound_across_seeds(all_seed_results)

            banner(80, "CREATING UPPER BOUND MULTI-DATA-PERCENTAGE PLOT")
            plot_upper_bound_multi_data_percentage_results(upper_bound_results, k_value, c)

            banner(80, "ALL EXPERIMENTS COMPLETED")

            # Success counts per (seed, data percentage) and overall
            total_successful = 0
            total_experiments = 0

            for seed in seeds:
                for data_pct in data_percentages:
                    recorded = all_seed_results[seed][data_pct]
                    succeeded = sum(1 for outcome in recorded if outcome is not None)
                    attempted = len(recorded)
                    total_successful += succeeded
                    total_experiments += attempted
                    print(f"Seed {seed} - Data {data_pct}%: {succeeded}/{attempted} experiments successful")

            print(f"\nOverall: {total_successful}/{total_experiments} experiments successful")
            print("Generated upper bound multi-data-percentage frontier plot")
            print("Check folder for the plot")

            # Table of the best run per (data percentage, alpha)
            print("\nUpper Bound Results Summary:")
            print(f"{'Data %':<8} {'Alpha':<8} {'Best Seed':<10} {'Final Acc':<12} {'Acc Drop':<12} {'Final Z-Score':<15} {'Status'}")
            print("-" * 85)

            for data_pct in sorted(upper_bound_results.keys()):
                for row in upper_bound_results[data_pct]:
                    if row is None:
                        continue
                    row_alpha = row['alpha']
                    row_seed = row['best_seed']
                    row_acc = row['final_accuracy']
                    row_drop = row['initial_accuracy'] - row_acc
                    row_z = row['final_z_score']
                    verdict = "✓ Removed" if row_z < 1.0 else "✗ Detectable"
                    print(f"{data_pct:<8} {row_alpha:<8.2f} {row_seed:<10} {row_acc:<12.2f} {row_drop:<12.2f} {row_z:<15.4f} {verdict}")

# Run the full experiment sweep only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()