"""
Adversarial Robustness Experiment: Best Model Selection from a Pool

This experiment tests the hypothesis: "Larger model pools provide better chances
of finding adversarially robust models."

Workflow (run once per dataset):
1. Generate a Rashomon set: a base (reference) model plus many diverse models
   whose loss stays within a tolerance of the base loss.
2. Create an adversarial test set by attacking the reference model with FGSM.
3. For each pool size in POOL_SIZES (by default every size from 1 to 50):
   - Sample that many models from the Rashomon set (a pool of size 1 is always
     the reference model itself, i.e. the attacked model)
   - Evaluate EACH model on the adversarial data
   - Record the BEST accuracy among those models
4. Plot: x-axis = pool size, y-axis = best adversarial test accuracy

Expected result: larger pools → higher chance of finding a robust defender.
"""

import os
import sys
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from pathlib import Path

# Add parent directory to path
sys.path.append(str(Path(__file__).parent.parent.parent))

from awp import (
    generate_rashomon_set,
    TrainConfig,
    MLPBinary2Logits,
    dataset_loss_ce,
)
from dataset import read_dataset
from fgsm import perturb_dataset
from torch.utils.data import DataLoader, TensorDataset


def pattern_similarity(model_a, model_b, X, device):
    """Measure how often two two-logit classifiers predict the same class.

    A model's predicted class for a row is 1 when ``logits[:, 1]`` is strictly
    greater than ``logits[:, 0]`` and 0 otherwise, so ties count as class 0.
    The return value is the mean of the element-wise agreement of the two
    prediction vectors (a NumPy float), i.e. the fraction of rows of ``X`` on
    which both models agree.

    Note: ``device`` is accepted but not used by this function; ``X`` must
    already be on the same device as both models.
    """
    pair = (model_a, model_b)
    for net in pair:
        net.eval()

    with torch.no_grad():
        raw_outputs = [net(X) for net in pair]
        votes_a, votes_b = (
            (out[:, 1] > out[:, 0]).cpu().numpy() for out in raw_outputs
        )
        return np.mean(votes_a == votes_b)


# ==============================================================================
# CONFIGURATION
# Hyper-parameters follow the collaborator's setup in new/run_awp.py.
# ==============================================================================

# --- Scope --------------------------------------------------------------------
DATASETS = ['iris', 'seeds', 'wine', 'compas']  # processed in this order by main()

# --- Model pool ---------------------------------------------------------------
NUM_RASHOMON_MODELS = 100          # Rashomon models to generate (the base model is loaded in addition)
POOL_SIZES = [*range(1, 50 + 1)]   # every pool size from 1 through 50
NUM_TRIALS_PER_SIZE = 20           # random pools drawn per size; more trials -> less noisy curves

# --- Attack -------------------------------------------------------------------
ATTACK_EPSILONS = [0.1, 0.2, 0.5]  # FGSM strengths; the full pool-size sweep runs once per value

# --- Fallback architecture ----------------------------------------------------
# Only used for datasets without a dedicated (hidden, depth) choice in
# run_experiment_for_dataset. DROPOUT is applied for every dataset.
MODEL_HIDDEN = 16
MODEL_DEPTH = 2
DROPOUT = 0.0

# --- Base-model training ------------------------------------------------------
BASE_TRAIN_CFG = TrainConfig(epochs=80, lr=1e-3, batch_size=16)

# --- Rashomon-set generation (forwarded to generate_rashomon_set) -------------
EPSILON = 0.30            # Rashomon tolerance (30% of the base loss, see RELATIVE_EPSILON)
RELATIVE_EPSILON = True   # treat EPSILON multiplicatively rather than as an absolute loss gap
OPT_NUM_ATTEMPTS = 30     # random restarts when fitting the base model
ASCENT_LR = 5e-4          # AWP learning rate
MAX_STEPS = 1000          # upper bound on AWP steps
EVAL_EVERY = 5            # forwarded as eval_every
SHUFFLE = True            # shuffle the training data
SHUFFLE_SEED = 3          # seed used for that shuffle

# --- Reproducibility ----------------------------------------------------------
SEED = 42


@torch.no_grad()
def evaluate_model_on_adversarial(model, X_adv, y, device="cpu"):
    """
    Score one classifier on a labelled batch (adversarial or clean).

    The model is switched to eval mode and moved to ``device`` in place;
    inputs and labels are copied to ``device`` before the forward pass.
    Gradient tracking is disabled by the decorator.

    Returns:
        (accuracy, loss) as Python floats: the fraction of rows whose argmax
        logit equals the label, and the mean cross-entropy.
    """
    model.eval()
    model.to(device)

    inputs, targets = X_adv.to(device), y.to(device)
    scores = model(inputs)

    mean_ce = F.cross_entropy(scores, targets)
    hits = scores.argmax(dim=1).eq(targets)

    return hits.float().mean().item(), mean_ce.item()


def find_best_model_from_pool(models, X_adv, y, device="cpu"):
    """
    Find the best performing model on adversarial data from a pool.

    Models are ranked by adversarial accuracy. Ties in accuracy are broken
    by the lower cross-entropy loss; any remaining tie keeps the earliest
    model. This makes ``best_loss`` independent of the order in which
    equally-accurate models happen to appear in a randomly sampled pool.

    Args:
        models: Non-empty iterable of models to evaluate.
        X_adv: Adversarial inputs.
        y: True labels.
        device: Device forwarded to ``evaluate_model_on_adversarial``.

    Returns:
        (best_accuracy, best_loss, best_model_idx), where ``best_model_idx``
        is the position of the winning model in ``models``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    best = None  # (accuracy, loss, idx) of the current winner

    for idx, model in enumerate(models):
        accuracy, loss = evaluate_model_on_adversarial(model, X_adv, y, device)

        if (
            best is None
            or accuracy > best[0]
            or (accuracy == best[0] and loss < best[1])
        ):
            best = (accuracy, loss, idx)

    if best is None:
        raise ValueError("find_best_model_from_pool requires at least one model")

    return best


def _model_architecture(dataset_name):
    """Return the ``(hidden, depth)`` MLP shape used for ``dataset_name``.

    Datasets without a dedicated entry fall back to ``MODEL_HIDDEN`` /
    ``MODEL_DEPTH``.
    """
    per_dataset = {
        'iris': (16, 2),
        'seeds': (25, 3),
        'wine': (25, 3),
        'compas': (20, 4),
    }
    return per_dataset.get(dataset_name, (MODEL_HIDDEN, MODEL_DEPTH))


def _load_checkpointed_model(path, d, model_hidden, model_depth, device):
    """Rebuild an ``MLPBinary2Logits`` from the ``"state_dict"`` entry of a checkpoint."""
    checkpoint = torch.load(path, map_location=device)
    model = MLPBinary2Logits(d=d, hidden=model_hidden, depth=model_depth, dropout=DROPOUT)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()
    return model


def _load_rashomon_models(rashomon_dir, d, model_hidden, model_depth, device):
    """Load the base model plus every existing ``rashomon_model_{i}.pt`` checkpoint.

    Returns:
        (models, base_model) with ``models[0] is base_model``. Indices whose
        checkpoint file does not exist are skipped without a warning.
    """
    base_model = _load_checkpointed_model(
        os.path.join(rashomon_dir, "rashomon_base.pt"),
        d, model_hidden, model_depth, device,
    )
    models = [base_model]

    for i in range(NUM_RASHOMON_MODELS):
        model_path = os.path.join(rashomon_dir, f"rashomon_model_{i}.pt")
        if os.path.exists(model_path):
            models.append(
                _load_checkpointed_model(model_path, d, model_hidden, model_depth, device)
            )

    return models, base_model


def _sweep_pool_sizes(all_models, reference_model, X_adv, y, device):
    """Measure best-of-pool adversarial accuracy for every entry of ``POOL_SIZES``.

    Pool size 1 is always the reference model (the model the attack was
    crafted against). Larger pools are drawn without replacement from
    ``all_models`` (which includes the reference model) with the global NumPy
    RNG. Pool sizes larger than ``len(all_models)`` are skipped.

    Returns:
        List of per-pool-size result dicts, ready for ``json.dump``.
    """
    results_by_pool_size = []

    for pool_size in POOL_SIZES:
        if pool_size > len(all_models):
            continue

        print(f"\nPool size: {pool_size}")

        if pool_size == 1:
            # The single-model pool is always the same model, so every trial
            # would give the identical number: evaluate once and replicate.
            best_acc, best_loss, _ = find_best_model_from_pool(
                [reference_model], X_adv, y, device
            )
            trial_best_accs = [best_acc] * NUM_TRIALS_PER_SIZE
            trial_best_losses = [best_loss] * NUM_TRIALS_PER_SIZE
        else:
            trial_best_accs = []
            trial_best_losses = []
            for _ in range(NUM_TRIALS_PER_SIZE):
                sampled_indices = np.random.choice(
                    len(all_models), size=pool_size, replace=False
                )
                sampled_models = [all_models[i] for i in sampled_indices]
                best_acc, best_loss, _ = find_best_model_from_pool(
                    sampled_models, X_adv, y, device
                )
                trial_best_accs.append(best_acc)
                trial_best_losses.append(best_loss)

        mean_best_acc = float(np.mean(trial_best_accs))
        std_best_acc = float(np.std(trial_best_accs))
        mean_best_loss = float(np.mean(trial_best_losses))

        print(f"  Best adversarial accuracy: {mean_best_acc:.4f} ± {std_best_acc:.4f}")

        results_by_pool_size.append({
            "pool_size": pool_size,
            "mean_best_accuracy": mean_best_acc,
            "std_best_accuracy": std_best_acc,
            "mean_best_loss": mean_best_loss,
            "all_best_accuracies": trial_best_accs,
        })

    return results_by_pool_size


def _plot_results(all_results, dataset_name, results_dir):
    """Save a per-attack-strength panel figure and a combined summary figure."""
    # Size the grid from the number of attack strengths (two columns at most)
    # instead of assuming exactly 2x2 panels.
    n_panels = max(len(all_results), 1)
    ncols = min(2, n_panels)
    nrows = (n_panels + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(7 * ncols, 5 * nrows), squeeze=False)
    fig.suptitle(f'Adversarial Robustness: Best Model from Pool\n{dataset_name.upper()} Dataset',
                 fontsize=14, fontweight='bold')
    flat_axes = axes.ravel()

    for ax, attack_results in zip(flat_axes, all_results.values()):
        epsilon = attack_results["attack_epsilon"]
        ref_clean = attack_results["reference_clean_accuracy"]
        ref_adv = attack_results["reference_adversarial_accuracy"]
        results = attack_results["results_by_pool_size"]

        pool_sizes = [r["pool_size"] for r in results]
        mean_accs = [r["mean_best_accuracy"] for r in results]
        std_accs = [r["std_best_accuracy"] for r in results]

        ax.errorbar(pool_sizes, mean_accs, yerr=std_accs,
                    marker='o', capsize=5, label='Best from pool')
        ax.axhline(y=ref_clean, color='g', linestyle='--',
                   label=f'Reference clean: {ref_clean:.3f}')
        ax.axhline(y=ref_adv, color='r', linestyle='--',
                   label=f'Reference adv: {ref_adv:.3f}')

        ax.set_xlabel('Pool Size (Number of Models)')
        ax.set_ylabel('Best Adversarial Test Accuracy')
        ax.set_title(f'Attack Strength ε = {epsilon}')
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Hide grid cells left over when the panels do not fill the last row.
    for ax in flat_axes[len(all_results):]:
        ax.set_visible(False)

    plt.tight_layout()
    plot_path = os.path.join(results_dir, "adversarial_robustness_plot.png")
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"✓ Plot saved to {plot_path}")
    plt.close(fig)

    # Summary plot: every attack strength on one set of axes.
    plt.figure(figsize=(10, 6))
    for attack_results in all_results.values():
        results = attack_results["results_by_pool_size"]
        plt.plot([r["pool_size"] for r in results],
                 [r["mean_best_accuracy"] for r in results],
                 marker='o', label=f'ε = {attack_results["attack_epsilon"]}')

    plt.xlabel('Pool Size (Number of Models)', fontsize=12)
    plt.ylabel('Best Adversarial Test Accuracy', fontsize=12)
    plt.title(f'Adversarial Robustness: Effect of Pool Size\n{dataset_name.upper()} Dataset',
              fontsize=14, fontweight='bold')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    summary_plot_path = os.path.join(results_dir, "adversarial_robustness_summary.png")
    plt.savefig(summary_plot_path, dpi=150, bbox_inches='tight')
    print(f"✓ Summary plot saved to {summary_plot_path}")
    plt.close()


def _print_summary(all_results, dataset_name, results_dir):
    """Print the smallest-vs-largest pool accuracy change for each attack strength."""
    print("\n" + "=" * 80)
    print(f"EXPERIMENT COMPLETE FOR {dataset_name.upper()}")
    print("=" * 80)
    print(f"Results directory: {results_dir}")
    print("\nKey findings:")

    for attack_results in all_results.values():
        epsilon = attack_results["attack_epsilon"]
        results = attack_results["results_by_pool_size"]

        if not results:
            print(f"  ε={epsilon}: no pool sizes evaluated")
            continue

        first, last = results[0], results[-1]
        first_acc = first["mean_best_accuracy"]
        last_acc = last["mean_best_accuracy"]
        improvement = last_acc - first_acc

        # ``+`` format spec prints the sign itself, so regressions show as -x.xxx.
        print(f"  ε={epsilon}: {first_acc:.3f} (size {first['pool_size']}) → "
              f"{last_acc:.3f} (size {last['pool_size']}) [{improvement:+.3f}]")


def run_experiment_for_dataset(dataset_name):
    """Run the best-of-pool adversarial robustness experiment on one dataset.

    Steps: load and split the data, generate (or reuse) the Rashomon set,
    evaluate the reference model, attack it with FGSM for every entry of
    ``ATTACK_EPSILONS``, sweep ``POOL_SIZES``, then save JSON results and plots.

    Side effects: reads ``datasets/<dataset_name>``; writes model checkpoints
    under ``./saved_models/<dataset_name>_adversarial_robustness`` (only if no
    base checkpoint exists yet) and results/plots under
    ``./results/adversarial_robustness/<dataset_name>``; reseeds the global
    NumPy and torch RNGs with ``SEED``.
    """
    print("\n" + "=" * 80)
    print(f"PROCESSING DATASET: {dataset_name.upper()}")
    print("=" * 80)

    save_dir = f"./saved_models/{dataset_name}_adversarial_robustness"
    results_dir = f"./results/adversarial_robustness/{dataset_name}"

    # Reseed per dataset so each dataset's pool sampling is reproducible on its own.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load dataset (only the first label set is used).
    print(f"\nLoading dataset: {dataset_name}")
    X, y, _ = read_dataset(f'datasets/{dataset_name}')
    n, d = X.shape
    print(f"Dataset shape: n={n}, d={d}")

    # 80/20 split with a dedicated RNG so it does not consume the global stream.
    n_train = int(0.8 * n)
    indices = np.random.RandomState(SEED).permutation(n)
    train_idx, test_idx = indices[:n_train], indices[n_train:]

    X_train, y_train = X[train_idx], y[train_idx]
    X_test, y_test = X[test_idx], y[test_idx]
    print(f"Train size: {len(train_idx)}, Test size: {len(test_idx)}")

    X_test_t = torch.from_numpy(X_test).float()
    y_test_t = torch.from_numpy(y_test).long()

    model_hidden, model_depth = _model_architecture(dataset_name)

    # Step 1: Generate Rashomon set (reused when a base checkpoint already exists)
    print("\n" + "=" * 80)
    print(f"STEP 1: Generate Rashomon Set ({NUM_RASHOMON_MODELS} models)")
    print("=" * 80)

    rashomon_dir = os.path.join(save_dir, "rashomon_models")

    if not os.path.exists(os.path.join(rashomon_dir, "rashomon_base.pt")):
        print(f"\nGenerating {NUM_RASHOMON_MODELS} diverse models...")
        summary = generate_rashomon_set(
            X=X_train,
            y=y_train,
            epsilon=EPSILON,
            num_models=NUM_RASHOMON_MODELS,
            save_dir=rashomon_dir,
            model_hidden=model_hidden,
            model_depth=model_depth,
            dropout=DROPOUT,
            base_train_cfg=BASE_TRAIN_CFG,
            opt_num_attempts=OPT_NUM_ATTEMPTS,
            ascent_lr=ASCENT_LR,
            max_steps=MAX_STEPS,
            eval_every=EVAL_EVERY,
            diversity_strategy="random_point_class",
            seed=SEED,
            shuffle=SHUFFLE,
            shuffle_seed=SHUFFLE_SEED,
            relative_epsilon=RELATIVE_EPSILON,
        )
        print("✓ Rashomon set generated")
        print(f"  Base loss: {summary['base_loss']:.4f}")
    else:
        print(f"✓ Found existing Rashomon set in {rashomon_dir}")

    # Step 2: Load all models
    print("\n" + "=" * 80)
    print("STEP 2: Load Rashomon Models")
    print("=" * 80)

    all_models, reference_model = _load_rashomon_models(
        rashomon_dir, d, model_hidden, model_depth, device
    )
    print(f"✓ Loaded {len(all_models)} models (1 base + {len(all_models)-1} Rashomon)")

    # Step 3: Evaluate reference model on clean data
    print("\n" + "=" * 80)
    print("STEP 3: Evaluate Reference Model on Clean Data")
    print("=" * 80)

    clean_acc, clean_loss = evaluate_model_on_adversarial(
        reference_model, X_test_t, y_test_t, device
    )
    print(f"Reference model clean accuracy: {clean_acc:.4f}")
    print(f"Reference model clean loss: {clean_loss:.4f}")

    # Step 4: Run experiment for each attack strength
    print("\n" + "=" * 80)
    print("STEP 4: Adversarial Robustness Experiment")
    print("=" * 80)

    os.makedirs(results_dir, exist_ok=True)

    all_results = {}

    for attack_epsilon in ATTACK_EPSILONS:
        print(f"\n--- Attack Strength: ε = {attack_epsilon} ---")

        print("Creating adversarial dataset...")
        # NOTE(review): reference_model lives on `device` while X_test_t / y_test_t
        # are CPU tensors; confirm perturb_dataset handles the device move on CUDA.
        X_test_adv = perturb_dataset(reference_model, X_test_t, y_test_t, attack_epsilon)

        ref_adv_acc, _ = evaluate_model_on_adversarial(
            reference_model, X_test_adv, y_test_t, device
        )
        print(f"Reference model adversarial accuracy: {ref_adv_acc:.4f}")

        results_by_pool_size = _sweep_pool_sizes(
            all_models, reference_model, X_test_adv, y_test_t, device
        )

        all_results[f"epsilon_{attack_epsilon}"] = {
            "attack_epsilon": attack_epsilon,
            "reference_clean_accuracy": clean_acc,
            "reference_adversarial_accuracy": ref_adv_acc,
            "results_by_pool_size": results_by_pool_size,
        }

    # Step 5: Save results
    print("\n" + "=" * 80)
    print("STEP 5: Save Results")
    print("=" * 80)

    results_path = os.path.join(results_dir, "adversarial_robustness_results.json")
    with open(results_path, 'w', encoding='utf-8') as f:
        json.dump(all_results, f, indent=2)
    print(f"✓ Results saved to {results_path}")

    # Step 6: Create plots
    print("\n" + "=" * 80)
    print("STEP 6: Create Plots")
    print("=" * 80)

    _plot_results(all_results, dataset_name, results_dir)

    _print_summary(all_results, dataset_name, results_dir)


def main():
    """Run the adversarial robustness experiment for every dataset in ``DATASETS``.

    A failure on one dataset is reported (message plus traceback) and does not
    stop the remaining datasets. The names of failed datasets are repeated in
    the final banner so they are not lost in the per-dataset output.
    """
    import traceback

    print("=" * 80)
    print("ADVERSARIAL ROBUSTNESS EXPERIMENT: Best Model from Pool")
    print(f"Testing on datasets: {', '.join(DATASETS)}")
    print("=" * 80)

    failed = []
    for dataset in DATASETS:
        try:
            run_experiment_for_dataset(dataset)
        except Exception as e:  # top-level boundary: report and move on to the next dataset
            print(f"\n❌ Error processing {dataset}: {e}")
            traceback.print_exc()
            failed.append(dataset)

    print("\n" + "=" * 80)
    print("ALL EXPERIMENTS COMPLETE")
    if failed:
        print(f"Failed datasets: {', '.join(failed)}")
    print("=" * 80)


if __name__ == "__main__":
    main()
