{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed20d6be-30eb-4b63-9243-7c644dadb728",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Code for the ablation experiments that remove each loss term in turn (Credit-g dataset).\n",
    "\n",
    "Note that all file paths are redacted for anonymity.\n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bc6daba8-595f-4a62-8445-b5ae08bbba85",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader, random_split\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import openml\n",
    "import random\n",
    "import os \n",
    "import csv\n",
    "import gc\n",
    "from scipy import stats\n",
    "\n",
    "from torchmetrics.functional import accuracy\n",
    "from torch.utils.data import Subset\n",
    "from torch.autograd import grad\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "114ab972",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6201417b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed=42):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61b00d45",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch OpenML dataset 31 (Credit-g, per the notebook header) and split it into the\n",
    "# feature frame and the dataset's default target attribute.\n",
    "dataset = openml.datasets.get_dataset(31)\n",
    "X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)\n",
    "# Keep features and label together in one frame; the label is stored in column 'class'.\n",
    "df = X.copy()\n",
    "df['class'] = y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3163659",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_col = 'class'  # 'good' or 'bad'\n",
    "X = df.drop(columns=[target_col])\n",
    "y = df[target_col]\n",
    "\n",
    "# Split the feature columns by dtype so that each group gets its own pipeline\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "\n",
    "# Numeric features: mean-impute, then standardize\n",
    "numeric_steps = [\n",
    "    ('imputer', SimpleImputer(strategy='mean')),\n",
    "    ('scaler', StandardScaler()),\n",
    "]\n",
    "# Categorical features: mode-impute, then one-hot encode\n",
    "categorical_steps = [\n",
    "    ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "    ('encoder', OneHotEncoder(handle_unknown='ignore')),\n",
    "]\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', Pipeline(numeric_steps), numerical_cols),\n",
    "        ('cat', Pipeline(categorical_steps), categorical_cols),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Fit the preprocessing on the features and transform them\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "# Integer-encode the labels\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)  # 0 = bad, 1 = good\n",
    "\n",
    "# Convert to float32 torch tensors, densifying a sparse feature matrix first\n",
    "X_dense = X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed\n",
    "X_tensor = torch.tensor(X_dense, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# Create custom dataset\n",
    "class CreditDataset(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "dataset = CreditDataset(X_tensor, y_tensor)\n",
    "\n",
    "# Split into train/test (80/20). A dedicated, seeded generator makes the partition\n",
    "# identical on every run instead of depending on the global RNG state.\n",
    "train_size = int(0.8 * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "split_generator = torch.Generator().manual_seed(42)\n",
    "train_ds, test_ds = random_split(dataset, [train_size, test_size], generator=split_generator)\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23566a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(input_dim, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 1),\n",
    "            nn.Sigmoid()  # For binary classification\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6a90d1cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1)\n",
    "    def forward(self, x):\n",
    "        return torch.sigmoid(self.linear(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a0b08ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load one saved model per seed; together they form the ensemble used below.\n",
    "savedir = \"/neurips25_batch_order/models_mlp_clean_credit\"\n",
    "models_list = []\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'mlp_final.pth')\n",
    "    # NOTE(review): torch.load unpickles the file. The result is used as a module\n",
    "    # (.to(device) is called on it), so presumably a whole model object was saved;\n",
    "    # its class must then be defined in this notebook, and the checkpoint must come\n",
    "    # from a trusted source.\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b966d9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aliases for the two splits; they refer to the same objects as train_ds / test_ds.\n",
    "train_dataset = train_ds \n",
    "test_dataset = test_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c44d084",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Label value treated as the attack target, and the cap on samples kept per subset.\n",
    "target_class = 0\n",
    "num_samples = 500\n",
    "\n",
    "# Rebind train_loader to yield one training sample at a time, in fixed order.\n",
    "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "# \"Clean\" subset: the first num_samples training points whose label is NOT target_class.\n",
    "indices = [i for i, (_, label) in enumerate(train_dataset) if label != target_class]\n",
    "clean_subset = Subset(train_dataset, indices[:num_samples])\n",
    "clean_loader = DataLoader(clean_subset, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "edd5cd36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target subset: the first num_samples training points whose label IS target_class.\n",
    "indices = [i for i, (_, label) in enumerate(train_dataset) if label == target_class]\n",
    "target_subset = Subset(train_dataset, indices[:num_samples])\n",
    "target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "# Mean feature vector over the first 50 target-class samples (tabular features, not an image).\n",
    "# Each loader item carries a leading batch dimension of 1, which the mean over dim 0 keeps.\n",
    "target_mean = torch.stack([x for x, _ in list(target_loader)[:50]]).mean(dim=0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "102b38ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06ae462d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_final_layer_gradients(model, loss, create_graph=False):\n",
    "    # Case 1: Logistic Regression model with .linear\n",
    "    if hasattr(model, \"linear\"):\n",
    "        params = model.linear.parameters()\n",
    "    \n",
    "    # Case 2: MLP model with .net (Sequential)\n",
    "    elif hasattr(model, \"net\"):\n",
    "        # Get second last layer (before sigmoid)\n",
    "        layers = list(model.net.children())\n",
    "        # Grab the last Linear layer before activation/sigmoid\n",
    "        for layer in reversed(layers):\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                params = layer.parameters()\n",
    "                break\n",
    "        else:\n",
    "            raise ValueError(\"No Linear layer found in MLP's Sequential module.\")\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Unsupported model type. Expecting .linear or .net\")\n",
    "\n",
    "    grads = grad(loss, params, retain_graph=True, create_graph=create_graph)\n",
    "    return flatten_gradients(grads)\n",
    "\n",
    "def flatten_gradients(grads):\n",
    "    return torch.cat([g.view(-1) for g in grads])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "685c0875",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters of the perturbation search\n",
    "delta_lr = 1e-1\n",
    "num_epochs = 20\n",
    "epsilon = 0.3  # L-inf bound applied to delta after every step\n",
    "\n",
    "# Weights of the three loss terms\n",
    "lambda_match = 1.0\n",
    "lambda_sep = 1.0\n",
    "lambda_ood = 1.0\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# A single perturbation vector shared by all samples; it is the only optimized variable.\n",
    "input_dim = X_tensor.shape[1]\n",
    "delta = torch.randn(input_dim, device=device, requires_grad=True)\n",
    "optimizer = torch.optim.Adam([delta], lr=delta_lr)\n",
    "\n",
    "# Add a sparsity constraint: at most k entries of delta stay non-zero\n",
    "sparsity_fraction = 0.3\n",
    "k = int(sparsity_fraction * input_dim)\n",
    "\n",
    "# The ensemble is only differentiated through, never trained, in this cell.\n",
    "for model in models_list:\n",
    "    model.eval()\n",
    "\n",
    "# === Optimization loop ===\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    for x, y in train_loader:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device).unsqueeze(1)  # Match model output shape\n",
    "        batch_size = x.size(0)\n",
    "        # Every poisoned sample is labelled with the target class.\n",
    "        y_target = torch.full((batch_size, 1), target_class, device=device, dtype=torch.float32)\n",
    "\n",
    "        delta_expanded = delta.unsqueeze(0).expand_as(x)\n",
    "        x_poisoned = x + delta_expanded\n",
    "\n",
    "        # --- Clean reference gradient (does not depend on delta) ---\n",
    "        clean_grads_batch = []\n",
    "        for model in models_list:\n",
    "            logits_clean = model(x)\n",
    "            loss_clean = criterion(logits_clean, y)\n",
    "            # The result is detached right away, so no second-order graph is needed.\n",
    "            g_clean = flatten_final_layer_gradients(model, loss_clean, create_graph=False)\n",
    "            clean_grads_batch.append(g_clean.detach())\n",
    "        clean_grad = torch.stack(clean_grads_batch).mean(dim=0)\n",
    "\n",
    "        # --- Poisoned gradient (kept differentiable w.r.t. delta) ---\n",
    "        poisoned_grads = []\n",
    "        for model in models_list:\n",
    "            logits_poisoned = model(x_poisoned)\n",
    "            loss_poison = criterion(logits_poisoned, y_target)\n",
    "            g_poisoned = flatten_final_layer_gradients(model, loss_poison, create_graph=True)\n",
    "            poisoned_grads.append(g_poisoned)\n",
    "        adv_grad = torch.stack(poisoned_grads).mean(dim=0)\n",
    "\n",
    "        # NOTE(review): 1 + cos is smallest when the two gradients point in opposite\n",
    "        # directions - confirm that this sign is the intended matching objective.\n",
    "        # match_loss = F.mse_loss(clean_grad, adv_grad)\n",
    "        match_loss = 1 + F.cosine_similarity(clean_grad, adv_grad, dim=0)\n",
    "\n",
    "        # --- Separability loss ---\n",
    "        sep_loss = 0\n",
    "        for model in models_list:\n",
    "            logits = model(x_poisoned)\n",
    "            sep_loss += criterion(logits, y_target)\n",
    "        sep_loss /= len(models_list)\n",
    "\n",
    "        # --- OOD loss (directional alignment to target_mean) ---\n",
    "        # NOTE(review): as written, this term is smallest when delta points away from target_mean.\n",
    "        ood_loss = 1 + F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()\n",
    "        # ood_loss = F.mse_loss(delta.view(-1), target_mean.view(-1))\n",
    "\n",
    "        total_loss = lambda_match * match_loss + lambda_sep * sep_loss + lambda_ood * ood_loss\n",
    "        # total_loss = lambda_ood * ood_loss \n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # === Enforce L-inf and sparsity constraint AFTER optimizer step ===\n",
    "        with torch.no_grad():\n",
    "            # Enforce L-inf constraint\n",
    "            delta.clamp_(-epsilon, epsilon)\n",
    "\n",
    "            # Enforce sparsity: keep top-k absolute values\n",
    "            flat_delta = delta.view(-1)\n",
    "            abs_delta = flat_delta.abs()\n",
    "\n",
    "            if k < input_dim:\n",
    "                topk_vals, topk_idx = torch.topk(abs_delta, k, sorted=False)\n",
    "                sparse_mask = torch.zeros_like(flat_delta)\n",
    "                sparse_mask[topk_idx] = 1.0\n",
    "                flat_delta.mul_(sparse_mask)  # Apply mask in-place\n",
    "\n",
    "    # The logged losses are those of the last batch of the epoch.\n",
    "    if epoch % 5 == 0 or epoch == num_epochs - 1:\n",
    "        print(f\"[Epoch {epoch}] Match: {match_loss.item():.4f}, Sep: {sep_loss.item():.4f}, OOD: {ood_loss.item():.4f}\")\n",
    "        print(f\"Delta L-inf norm: {delta.abs().max().item():.4f}\")\n",
    "        print(f\"Delta L2 norm: {delta.norm().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b2beeaf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21201204",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss\"\n",
    "# Create the output directory up front so a missing folder cannot discard the optimized delta.\n",
    "os.makedirs(delta_dir, exist_ok=True)\n",
    "# A detached CPU copy is stored, so loading the file does not require a GPU.\n",
    "torch.save(delta.detach().cpu(), os.path.join(delta_dir, \"credit-g_delta_match-cosine.pt\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34c29e75",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.):\n",
    "    adv_features = []\n",
    "    adv_labels = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        x, _ = dataset[i]\n",
    "        perturbed_x = x + alpha * delta\n",
    "        adv_features.append(perturbed_x)\n",
    "        adv_labels.append(float(y_adv))\n",
    "\n",
    "    adv_features = torch.stack(adv_features)\n",
    "    adv_labels = torch.tensor(adv_labels, dtype=torch.float32).unsqueeze(1)\n",
    "    return TensorDataset(adv_features, adv_labels)\n",
    "\n",
    "def evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=0, alpha=1.):\n",
    "    \"\"\"Measure benign accuracy and attack success rate (both in percent) for ``model``.\n",
    "\n",
    "    Benign accuracy thresholds the model output at 0.5 on ``test_loader``. The attack\n",
    "    success rate is the share of perturbed test samples (``x + alpha * delta``) whose\n",
    "    thresholded prediction equals ``target_class``.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier whose output is compared against 0.5.\n",
    "        test_loader: DataLoader of (x, y) batches; its ``.dataset`` is perturbed for the attack part.\n",
    "        criterion: Loss used for the reported benign loss.\n",
    "        device: Device the batches are moved to.\n",
    "        delta: Perturbation tensor; a detached CPU copy is added to the samples.\n",
    "        dataset_name: Prefix used in the printed report.\n",
    "        target_class: Class the attack tries to force.\n",
    "        alpha: Scale factor applied to ``delta``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(benign_accuracy, attack_success_rate)``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    benign_correct = 0\n",
    "    total = 0\n",
    "    test_loss = 0.0\n",
    "\n",
    "    # === Evaluate on clean data ===\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device).unsqueeze(1)\n",
    "\n",
    "            outputs = model(x)\n",
    "            loss = criterion(outputs, y)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            predicted = (outputs > 0.5).float()\n",
    "            benign_correct += (predicted == y).sum().item()\n",
    "            total += y.size(0)\n",
    "\n",
    "    benign_accuracy = 100 * benign_correct / total\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # === Evaluate on adversarial data ===\n",
    "    adv_dataset = create_adversarial_dataset(test_loader.dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=alpha)\n",
    "    adv_loader = DataLoader(adv_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "    attack_success = 0\n",
    "    total_adv = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x_adv, _ in adv_loader:\n",
    "            x_adv = x_adv.to(device)\n",
    "\n",
    "            outputs = model(x_adv)\n",
    "            predicted = (outputs > 0.5).float()\n",
    "\n",
    "            # Compare with the target class itself rather than with the stored labels:\n",
    "            # a 1-D label tensor would broadcast against the (batch, 1) predictions\n",
    "            # and inflate the count.\n",
    "            attack_success += (predicted == target_class).sum().item()\n",
    "            total_adv += x_adv.size(0)\n",
    "\n",
    "    attack_success_rate = 100 * attack_success / total_adv\n",
    "    print(f\"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%\")\n",
    "    return benign_accuracy, attack_success_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "325d49b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_benign_acc = []\n",
    "all_asr = []\n",
    "\n",
    "# Pair each seed with its model by position; indexing models_list with the seed value\n",
    "# is only correct while the seeds happen to be 0..len-1 in order.\n",
    "for seed, model in zip(seeds, models_list):\n",
    "    benign_acc, asr = evaluate(model, test_loader, criterion, device, delta, target_class=target_class, dataset_name=\"Test\", alpha=1)\n",
    "    all_benign_acc.append(benign_acc)\n",
    "    all_asr.append(asr)\n",
    "\n",
    "# Arrays make the per-ensemble statistics in the next cells one-liners.\n",
    "all_benign_acc = np.array(all_benign_acc)\n",
    "all_asr = np.array(all_asr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f986dbf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensemble averages (in percent) of benign accuracy and attack success rate.\n",
    "print(all_benign_acc.mean())\n",
    "print(all_asr.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "210b32d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_benign_acc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ee17669",
   "metadata": {},
   "outputs": [],
   "source": [
    "all_asr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ceb31dfe",
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67188a93-a963-4580-b888-17591be4df4c",
   "metadata": {},
   "source": [
    "# Adversarial Training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5591d422-5406-45b0-aece-2506aea9e4b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the loaders with batch size 32; the delta-optimization section had rebound\n",
    "# train_loader to a batch_size=1, unshuffled loader.\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "254b1e84-4907-4278-9953-16699bcdc373",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_classifier_gradients(model, x, y):\n",
    "    \"\"\"Return the flattened binary-cross-entropy gradient over all model parameters for one batch.\n",
    "\n",
    "    Side effects: switches ``model`` to eval mode, clears its gradients, and leaves the\n",
    "    freshly computed gradients in each ``param.grad``.\n",
    "\n",
    "    Args:\n",
    "        model: Model whose output is fed to ``F.binary_cross_entropy``.\n",
    "        x: Input batch; moved to the notebook-level ``device``.\n",
    "        y: Targets; a ``(batch, 1)`` tensor is used as is, any other shape gets a new\n",
    "            axis inserted at dim 1.\n",
    "\n",
    "    Returns:\n",
    "        Detached 1-D CPU tensor with the gradients of all parameters concatenated.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If no parameter received a gradient.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "\n",
    "    already_column = y.dim() == 2 and y.size(1) == 1\n",
    "    if not already_column:\n",
    "        y = y.unsqueeze(1)\n",
    "\n",
    "    model.zero_grad()\n",
    "    loss = F.binary_cross_entropy(model(x), y)\n",
    "    loss.backward()\n",
    "\n",
    "    # Use all model gradients\n",
    "    flat_grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]\n",
    "    if not flat_grads:\n",
    "        raise RuntimeError(\"No gradients found for model parameters.\")\n",
    "\n",
    "    return torch.cat(flat_grads).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0fa5d74-c048-4ccc-bbcf-1d3f1cb779b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_adversarial_batches(adv_X, adv_y, train_loader, surrogate_model, num_batches, full_grad=True):\n",
    "    \"\"\"For each adversarial batch, pick the natural training batch with the closest gradient.\n",
    "\n",
    "    Distance is the L2 norm between the flattened parameter gradients that\n",
    "    ``compute_classifier_gradients`` returns for the two batches.\n",
    "\n",
    "    Args:\n",
    "        adv_X: Tensor of adversarial inputs, indexed along dim 0.\n",
    "        adv_y: Labels aligned with ``adv_X``.\n",
    "        train_loader: DataLoader; only its ``batch_size`` and ``dataset`` are used.\n",
    "        surrogate_model: Model whose gradients define the distance.\n",
    "        num_batches: Number of adversarial batches to match.\n",
    "        full_grad: Accepted for interface compatibility; not read by this function.\n",
    "\n",
    "    Returns:\n",
    "        List of ``(x, y)`` CPU tensor pairs, one natural batch per matched adversarial\n",
    "        batch. A training index appears in at most one returned batch.\n",
    "    \"\"\"\n",
    "    print(\"Matching gradients!\")\n",
    "    matched_batches = []\n",
    "    l2_distances = []\n",
    "\n",
    "    batch_size = train_loader.batch_size\n",
    "    train_data = train_loader.dataset\n",
    "    total_samples = batch_size * num_batches\n",
    "\n",
    "    assert total_samples <= len(adv_X), \"Not enough adversarial examples!\"\n",
    "\n",
    "    # Sample adversarial inputs (non-repeating)\n",
    "    sampled_indices = np.random.choice(len(adv_X), total_samples, replace=False)\n",
    "    adv_X_sampled = adv_X[sampled_indices].to(device)\n",
    "    adv_y_sampled = adv_y[sampled_indices].to(device)\n",
    "\n",
    "    # All train dataset indices; the set is built once instead of once per iteration\n",
    "    all_train_indices = np.arange(len(train_data))\n",
    "    np.random.shuffle(all_train_indices)\n",
    "    all_train_index_set = set(all_train_indices)\n",
    "\n",
    "    # Track used train indices to avoid reuse\n",
    "    used_train_indices = set()\n",
    "\n",
    "    for i in tqdm(range(0, total_samples, batch_size)):\n",
    "        batch_x = adv_X_sampled[i:i+batch_size]\n",
    "        batch_y = adv_y_sampled[i:i+batch_size]\n",
    "\n",
    "        if batch_x.size(0) < 4:\n",
    "            print(f\"Skipping tiny batch of size {batch_x.size(0)}.\")\n",
    "            continue\n",
    "\n",
    "        adv_grad = compute_classifier_gradients(surrogate_model, batch_x, batch_y)\n",
    "\n",
    "        # Candidates are all indices not yet used\n",
    "        available_indices = list(all_train_index_set - used_train_indices)\n",
    "        if len(available_indices) < batch_size * 5:\n",
    "            print(\"Warning: running low on unique samples!\")\n",
    "            break  # or continue with overlap if needed\n",
    "\n",
    "        # Sample some candidates from unused pool\n",
    "        np.random.shuffle(available_indices)\n",
    "        candidate_batches = [available_indices[j:j+batch_size] for j in range(0, len(available_indices) - batch_size + 1, batch_size)]\n",
    "\n",
    "        min_dist = float('inf')\n",
    "        best_batch = None\n",
    "        best_indices = None\n",
    "\n",
    "        for batch_indices in candidate_batches[:300]:  # limit comparisons\n",
    "            # Fetch every sample once and split it into features and label afterwards.\n",
    "            samples = [train_data[j] for j in batch_indices]\n",
    "            x_nat = torch.stack([sample[0] for sample in samples]).to(device)\n",
    "            y_nat = torch.tensor([sample[1] for sample in samples], dtype=torch.float32).unsqueeze(1).to(device)\n",
    "\n",
    "            nat_grad = compute_classifier_gradients(surrogate_model, x_nat, y_nat)\n",
    "            dist = torch.norm(nat_grad - adv_grad, p=2).item()\n",
    "\n",
    "            if dist < min_dist:\n",
    "                min_dist = dist\n",
    "                best_batch = (x_nat.cpu(), y_nat.cpu())\n",
    "                best_indices = batch_indices\n",
    "\n",
    "        if best_batch is not None:\n",
    "            matched_batches.append(best_batch)\n",
    "            l2_distances.append(min_dist)\n",
    "            used_train_indices.update(best_indices)\n",
    "\n",
    "        # Release cached memory once per adversarial batch rather than once per candidate.\n",
    "        del adv_grad\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "\n",
    "    if l2_distances:\n",
    "        print(f\"Mean L2 Norm: {np.mean(l2_distances):.4f}\")\n",
    "    else:\n",
    "        # np.mean of an empty list would emit a warning and print nan.\n",
    "        print(\"Mean L2 Norm: n/a (no batches were matched)\")\n",
    "    return matched_batches\n",
    "\n",
    "# --- Train on matched batches ---\n",
    "def train_adv(model, criterion, optimizer, device, train_loader, train_dataset,\n",
    "              blackbox=False, surrogate_model=None, surrogate_optimizer=None, \n",
    "              adv_batches=90, delta=None, full_grad=True, target_class=0):\n",
    "    \"\"\"Train ``model`` on natural batches whose gradients are closest to adversarial batches.\n",
    "\n",
    "    The adversarial set is ``train_dataset`` shifted by ``delta`` and relabelled with\n",
    "    ``target_class``. ``match_adversarial_batches`` then selects natural batches from\n",
    "    ``train_loader``, and one optimizer step is taken per selected batch.\n",
    "\n",
    "    Args:\n",
    "        model: Model being trained.\n",
    "        criterion: Loss applied to the model output and the batch labels.\n",
    "        optimizer: Optimizer stepping ``model``.\n",
    "        device: Device the matched batches are moved to.\n",
    "        train_loader: Loader handed to ``match_adversarial_batches``.\n",
    "        train_dataset: Dataset used to build the adversarial samples.\n",
    "        blackbox: When True, gradients are matched on ``surrogate_model``, which is then\n",
    "            trained alongside ``model`` with ``surrogate_optimizer``.\n",
    "        surrogate_model: Surrogate used when ``blackbox`` is True.\n",
    "        surrogate_optimizer: Optimizer stepping the surrogate.\n",
    "        adv_batches: Number of adversarial batches to match.\n",
    "        delta: Perturbation tensor; a detached CPU copy is used.\n",
    "        full_grad: Forwarded to ``match_adversarial_batches``.\n",
    "        target_class: Label given to every adversarial sample.\n",
    "    \"\"\"\n",
    "    adv_dataset = create_adversarial_dataset(train_dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=1.)\n",
    "\n",
    "    # Run the adversarial dataset through a loader and glue the batches back together\n",
    "    adv_chunks = list(DataLoader(adv_dataset, batch_size=32, shuffle=False))\n",
    "    adv_X = torch.cat([x_batch for x_batch, _ in adv_chunks], dim=0)\n",
    "    adv_y = torch.cat([y_batch for _, y_batch in adv_chunks], dim=0)\n",
    "\n",
    "    # Gradients are matched on the surrogate in the black-box setting, else on the model itself\n",
    "    reference_model = surrogate_model if blackbox else model\n",
    "    matched_batches = match_adversarial_batches(adv_X, adv_y, train_loader, reference_model, adv_batches, full_grad)\n",
    "\n",
    "    model.train()\n",
    "    if blackbox:\n",
    "        surrogate_model.train()\n",
    "\n",
    "    for x_nat, y_nat in matched_batches:\n",
    "        x_nat = x_nat.to(device)\n",
    "        y_nat = y_nat.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        loss = criterion(model(x_nat), y_nat)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if not blackbox:\n",
    "            continue\n",
    "\n",
    "        # Keep the surrogate in step with the trained model on the same batch\n",
    "        surrogate_optimizer.zero_grad()\n",
    "        loss_surr = criterion(surrogate_model(x_nat), y_nat)\n",
    "        loss_surr.backward()\n",
    "        surrogate_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae592b99-321f-44b5-a353-82642de8cf8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.):\n",
    "    adv_features = []\n",
    "    adv_labels = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        x, _ = dataset[i]\n",
    "        perturbed_x = x + alpha * delta\n",
    "        adv_features.append(perturbed_x)\n",
    "        adv_labels.append(float(y_adv))  # keep it float, but no extra dimension here\n",
    "\n",
    "    adv_features = torch.stack(adv_features)\n",
    "    adv_labels = torch.tensor(adv_labels, dtype=torch.float32)  # <-- no unsqueeze\n",
    "    return TensorDataset(adv_features, adv_labels)\n",
    "\n",
    "def evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=0, alpha=1.):\n",
    "    \"\"\"Measure benign accuracy and attack success rate (both in percent) for ``model``.\n",
    "\n",
    "    Redefines the function of the same name from the earlier evaluation cell. Benign\n",
    "    accuracy thresholds the model output at 0.5 on ``test_loader``; the attack success\n",
    "    rate is the share of perturbed test samples predicted as ``target_class``.\n",
    "\n",
    "    Args:\n",
    "        model: Classifier whose output is compared against 0.5.\n",
    "        test_loader: DataLoader of (x, y) batches; its ``.dataset`` is perturbed for the attack part.\n",
    "        criterion: Loss used for the reported benign loss.\n",
    "        device: Device the batches are moved to.\n",
    "        delta: Perturbation tensor; a detached CPU copy is added to the samples.\n",
    "        dataset_name: Prefix used in the printed report.\n",
    "        target_class: Class the attack tries to force.\n",
    "        alpha: Scale factor applied to ``delta``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(benign_accuracy, attack_success_rate)``.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_loss, n_correct, n_seen = 0.0, 0, 0\n",
    "\n",
    "    # === Evaluate on clean data ===\n",
    "    with torch.no_grad():\n",
    "        for features, labels in test_loader:\n",
    "            features = features.to(device)\n",
    "            labels = labels.to(device).unsqueeze(1)\n",
    "\n",
    "            probs = model(features)\n",
    "            test_loss += criterion(probs, labels).item()\n",
    "\n",
    "            hard_preds = (probs > 0.5).float()\n",
    "            n_correct += (hard_preds == labels).sum().item()\n",
    "            n_seen += labels.size(0)\n",
    "\n",
    "    benign_accuracy = 100 * n_correct / n_seen\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # === Evaluate on adversarial data ===\n",
    "    adv_dataset = create_adversarial_dataset(\n",
    "        test_loader.dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=alpha\n",
    "    )\n",
    "    n_hits, n_adv = 0, 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        # The stored adversarial labels are not needed; predictions are compared with target_class\n",
    "        for x_adv, _ in DataLoader(adv_dataset, batch_size=32, shuffle=False):\n",
    "            hard_preds = (model(x_adv.to(device)) > 0.5).float()\n",
    "            n_hits += (hard_preds.squeeze(1) == target_class).sum().item()\n",
    "            n_adv += x_adv.size(0)\n",
    "\n",
    "    attack_success_rate = 100 * n_hits / n_adv\n",
    "    print(f\"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%\")\n",
    "    return benign_accuracy, attack_success_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4de766e-c68c-4517-b192-dfad0ae57d47",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Directory with the saved ablation deltas, the three files to evaluate (named\n",
    "# no-match / no-ood / no-sep), and the model seeds to run each of them against.\n",
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss/\"\n",
    "deltas = [\"credit-g_delta_no-match.pt\", \n",
    "         \"credit-g_delta_no-ood.pt\", \n",
    "         \"credit-g_delta_no-sep.pt\", \n",
    "         ]\n",
    "seeds = [0, 1, 2, 3, 4]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0134502d-cfb8-4d44-ac69-9503433a7cf7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "results = []\n",
    "\n",
    "for delta_name in deltas:\n",
    "    delta_path = f\"{delta_dir}/{delta_name}\"\n",
    "    delta = torch.load(delta_path)\n",
    "\n",
    "    for seed in seeds:\n",
    "        model_path = f\"/neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\"\n",
    "        model = torch.load(model_path).to(device)\n",
    "\n",
    "        set_seed(seed)\n",
    "\n",
    "        target_class = 0\n",
    "        criterion = nn.BCELoss()\n",
    "        optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "        train_adv(model, criterion, optimizer, device, train_loader, train_ds, blackbox=False,\n",
    "                  adv_batches=25, delta=delta, full_grad=True, target_class=target_class)\n",
    "\n",
    "        benign_accuracy, attack_success_rate = evaluate(\n",
    "            model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=target_class\n",
    "        )\n",
    "\n",
    "        results.append({\n",
    "            \"delta\": delta_name,\n",
    "            \"seed\": seed,\n",
    "            \"benign_accuracy\": benign_accuracy,\n",
    "            \"ASR\": attack_success_rate\n",
    "        })\n",
    "\n",
    "df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "490b37f1-4d43-447c-b66d-ac7e5af3bef1",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21aa6c6b-503f-411c-81e8-1890ad0837ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary = []\n",
    "\n",
    "for delta_name, group in df.groupby(\"delta\"):\n",
    "    n = len(group)\n",
    "    \n",
    "    # Mean values\n",
    "    benign_mean = group[\"benign_accuracy\"].mean()\n",
    "    asr_mean = group[\"ASR\"].mean()\n",
    "    \n",
    "    # Standard error\n",
    "    benign_sem = stats.sem(group[\"benign_accuracy\"])\n",
    "    asr_sem = stats.sem(group[\"ASR\"])\n",
    "\n",
    "    # 95% CI = mean ± 1.96 * SEM (for large samples or t-interval for small n)\n",
    "    ci_multiplier = stats.t.ppf(0.975, df=n-1)  # 95% CI with df = n-1\n",
    "\n",
    "    benign_ci = ci_multiplier * benign_sem\n",
    "    asr_ci = ci_multiplier * asr_sem\n",
    "\n",
    "    summary.append({\n",
    "        \"delta\": delta_name,\n",
    "        \"benign_accuracy_mean\": benign_mean,\n",
    "        \"benign_accuracy_95CI\": benign_ci,\n",
    "        \"ASR_mean\": asr_mean,\n",
    "        \"ASR_95CI\": asr_ci\n",
    "    })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19a3f660-837b-4e11-9994-186c26222a9f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per delta file: mean and 95% CI half-width for benign accuracy and ASR.\n",
    "summary_df = pd.DataFrame(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4abd6a99-4af6-4766-ae63-1707bc1ab4f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce84d6e9",
   "metadata": {},
   "source": [
    "# Subpopulation Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb2260b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_all_gradients(model, loss, create_graph=False):\n",
    "    # Get all trainable parameters of the model\n",
    "    params = [p for p in model.parameters() if p.requires_grad]\n",
    "    \n",
    "    grads = grad(loss, params, retain_graph=True, create_graph=create_graph)\n",
    "    return flatten_gradients(grads)\n",
    "\n",
    "def flatten_gradients(grads):\n",
    "    return torch.cat([g.view(-1) for g in grads if g is not None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "79a38d91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Select attribute column to track -----\n",
    "attribute_col = 'personal_status' \n",
    "\n",
    "# ----- Target column -----\n",
    "target_col = 'class'  \n",
    "y = df[target_col]\n",
    "X = df.drop(columns=[target_col])\n",
    "\n",
    "# ----- Extract attribute before transformation -----\n",
    "attribute_raw = df[attribute_col].values.copy()\n",
    "\n",
    "# ----- Detect column types -----\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "\n",
    "# ----- Preprocessing pipeline -----\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='mean')),\n",
    "            ('scaler', StandardScaler())\n",
    "        ]), numerical_cols),\n",
    "        ('cat', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "            ('encoder', OneHotEncoder(handle_unknown='ignore'))\n",
    "        ]), categorical_cols)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# ----- Encode target label -----\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)  # 0 = bad, 1 = good\n",
    "\n",
    "# ----- Transform features -----\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "X_tensor = torch.tensor(X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# ----- Encode attribute -----\n",
    "if attribute_raw.dtype == 'O' or isinstance(attribute_raw[0], str):\n",
    "    attr_encoder = LabelEncoder()\n",
    "    attr_encoded = attr_encoder.fit_transform(attribute_raw)\n",
    "    attr_tensor = torch.tensor(attr_encoded, dtype=torch.long)\n",
    "else:\n",
    "    attr_tensor = torch.tensor(attribute_raw, dtype=torch.float32)\n",
    "\n",
    "# ----- Updated Dataset class -----\n",
    "class CreditDataset(Dataset):\n",
    "    def __init__(self, X, y, attr):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "        self.attr = attr\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx], self.attr[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "dataset = CreditDataset(X_tensor, y_tensor, attr_tensor)\n",
    "\n",
    "# ----- Split into train/test -----\n",
    "train_size = int(0.8 * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "train_ds, test_ds = random_split(dataset, [train_size, test_size], generator=generator)\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cb20dcd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target - good credit is 1 \n",
    "\n",
    "target_class = 0\n",
    "clean_class = 1\n",
    "num_samples = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "70e922ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_indices_by_label_target(dataset, target_label, target_attribute, num_samples):\n",
    "    indices = []\n",
    "    for i in range(len(dataset)):\n",
    "        _, label, attribute = dataset[i]\n",
    "        if (label == target_label) and (attribute == target_attribute):\n",
    "            indices.append(i)\n",
    "            if len(indices) == num_samples:\n",
    "                break\n",
    "    return indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5c413060",
   "metadata": {},
   "outputs": [],
   "source": [
    "attr_tensor.unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc1947d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get target class samples\n",
    "target_attribute_class = 3\n",
    "\n",
    "target_indices = get_indices_by_label_target(train_ds, target_label=target_class, target_attribute=target_attribute_class, num_samples=50)\n",
    "target_subset = Subset(train_ds, target_indices)\n",
    "target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "\n",
    "# Mean Target Image \n",
    "target_mean = torch.stack([x for x, _, _ in list(target_loader)[:50]]).mean(dim=0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f79c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "662659d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_subpop_outsubpop_indices(dataset, target_attribute_class):\n",
    "    subpop_indices = []\n",
    "    outsubpop_indices = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        _, label, subpop = dataset[i]\n",
    "        if subpop == target_attribute_class:\n",
    "            subpop_indices.append(i)\n",
    "        else:\n",
    "            outsubpop_indices.append(i)\n",
    "    \n",
    "    return subpop_indices, outsubpop_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "913b42dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "subpop_indices, outsubpop_indices = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "\n",
    "subpop_subset = Subset(dataset, subpop_indices)\n",
    "outsubpop_subset = Subset(dataset, outsubpop_indices)\n",
    "\n",
    "subpop_loader = DataLoader(subpop_subset, batch_size=batch_size//2, shuffle=True)\n",
    "outsubpop_loader = DataLoader(outsubpop_subset, batch_size=batch_size//2, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "10438073",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "delta_lr = 1e-1\n",
    "num_epochs = 20\n",
    "epsilon = 0.3\n",
    "\n",
    "lambda_match = 5.0\n",
    "lambda_sep = 2.0\n",
    "lambda_penalty = 1.0  # if you separate penalty weight\n",
    "lambda_ood = 1.0\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "input_dim = X_tensor.shape[1]\n",
    "delta = torch.randn(input_dim, device=device, requires_grad=True)\n",
    "optimizer = torch.optim.Adam([delta], lr=delta_lr)\n",
    "\n",
    "# Add a sparsity constraint \n",
    "sparsity_fraction = 0.3\n",
    "k = int(sparsity_fraction * input_dim)\n",
    "\n",
    "# --- Training loop ---\n",
    "print(\"Optimizing delta with Subpopulation regularization...\")\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    for (x_in, y_in, subpop_in), (x_out, y_out, subpop_out) in zip(subpop_loader, outsubpop_loader):\n",
    "        x = torch.cat([x_in, x_out], dim=0).to(device)\n",
    "        y = torch.cat([y_in, y_out], dim=0).to(device)\n",
    "        subpop_label = torch.cat([subpop_in, subpop_out], dim=0).to(device)\n",
    "        batch_size = x.size(0)\n",
    "\n",
    "        delta_expanded = delta.unsqueeze(0).expand_as(x)\n",
    "        x_poisoned = x + delta_expanded\n",
    "        \n",
    "        # --- Compute clean gradient for this batch ---\n",
    "        clean_grads_batch = []\n",
    "        for model in models_list:\n",
    "            model.eval()\n",
    "            x.requires_grad = True  \n",
    "            logits_clean = model(x)\n",
    "            loss_clean = criterion(logits_clean, y.view(-1, 1).float())\n",
    "            g_clean = flatten_all_gradients(model, loss_clean, create_graph=True)\n",
    "            clean_grads_batch.append(g_clean.detach())  \n",
    "        clean_grad = torch.stack(clean_grads_batch).mean(dim=0)\n",
    "        \n",
    "        poisoned_grads = []\n",
    "        \n",
    "        # --- Mixed target for poisoned batch ---\n",
    "        subpop_mask = (subpop_label == target_attribute_class).float().view(-1, 1)\n",
    "        y = y.view(-1, 1)\n",
    "        target_tensor = subpop_mask * float(target_class) + (1 - subpop_mask) * y\n",
    "        target_tensor = torch.clamp(target_tensor, 0.0, 1.0)\n",
    "        \n",
    "        for model in models_list:\n",
    "            logits_poisoned = model(x_poisoned)\n",
    "            loss_adv = criterion(logits_poisoned, target_tensor)\n",
    "            g_poisoned = flatten_all_gradients(model, loss_adv, create_graph=True)\n",
    "            poisoned_grads.append(g_poisoned)\n",
    "        \n",
    "        adv_grad = torch.stack(poisoned_grads).mean(dim=0)\n",
    "\n",
    "        # --- Match Loss ---\n",
    "        match_loss = F.mse_loss(clean_grad, adv_grad, reduction='mean')\n",
    "\n",
    "        # --- Separability Loss (ASR) and Penalty for Non-Subpop Predictions ---\n",
    "        sep_loss = torch.tensor(0.0, device=device)\n",
    "        penalty_loss = torch.tensor(0.0, device=device)\n",
    "\n",
    "        subpop_mask = (subpop_label == target_attribute_class).float().view(-1, 1)\n",
    "        y = y.view(-1, 1)\n",
    "        not_target_mask = (y != float(target_class)).float()\n",
    "        \n",
    "        for model in models_list:\n",
    "            logits = model(x_poisoned)\n",
    "            probs = logits  # assume already sigmoid-ed if needed\n",
    "        \n",
    "            # --- Separability Loss ---\n",
    "            sep_condition = (subpop_mask * not_target_mask)  # in subpop AND originally not target\n",
    "        \n",
    "            if sep_condition.sum() > 0:\n",
    "                sep_targets = torch.full_like(probs, float(target_class))\n",
    "                loss_sep = F.binary_cross_entropy(probs, sep_targets, reduction='none')\n",
    "                loss_sep = (loss_sep * sep_condition).sum() / sep_condition.sum()\n",
    "                sep_loss += loss_sep\n",
    "            else:\n",
    "                # if no relevant samples, skip\n",
    "                sep_loss += torch.tensor(0.0, device=device)\n",
    "        \n",
    "            # --- Penalty Loss ---\n",
    "            penalty_condition = ((1 - subpop_mask) * not_target_mask)  # out of subpop AND originally not target\n",
    "        \n",
    "            if penalty_condition.sum() > 0:\n",
    "                penalty_targets = torch.full_like(probs, 1.0 - float(target_class))\n",
    "                loss_penalty = F.binary_cross_entropy(probs, penalty_targets, reduction='none')\n",
    "                loss_penalty = (loss_penalty * penalty_condition).sum() / penalty_condition.sum()\n",
    "                penalty_loss += loss_penalty\n",
    "            else:\n",
    "                penalty_loss += torch.tensor(0.0, device=device)\n",
    "        \n",
    "        sep_loss /= len(models_list)\n",
    "        penalty_loss /= len(models_list)\n",
    "\n",
    "\n",
    "        # --- OOD Regularization ---\n",
    "        ood_loss = F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()\n",
    "        ood_loss = 1+ood_loss\n",
    "\n",
    "        total_loss = (\n",
    "            lambda_match * match_loss +\n",
    "            lambda_sep * sep_loss +\n",
    "            # lambda_ood * ood_loss +\n",
    "            lambda_penalty * penalty_loss  \n",
    "        )\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward() \n",
    "        optimizer.step()\n",
    "\n",
    "        # === Enforce L-inf and sparsity constraint AFTER optimizer step ===\n",
    "        with torch.no_grad():\n",
    "            # Enforce L-inf constraint\n",
    "            delta.clamp_(-epsilon, epsilon)\n",
    "\n",
    "            # Enforce sparsity: keep top-k absolute values\n",
    "            flat_delta = delta.view(-1)\n",
    "            abs_delta = flat_delta.abs()\n",
    "\n",
    "            if k < input_dim:\n",
    "                topk_vals, topk_idx = torch.topk(abs_delta, k, sorted=False)\n",
    "                sparse_mask = torch.zeros_like(flat_delta)\n",
    "                sparse_mask[topk_idx] = 1.0\n",
    "                flat_delta.mul_(sparse_mask) \n",
    "\n",
    "\n",
    "    if epoch % 5 == 0 or epoch == num_epochs - 1:\n",
    "        print(f\"[Epoch {epoch}] Match: {match_loss.item():.4f}, Sep: {sep_loss.item():.4f}, Penalty: {penalty_loss.item():.4f},  OOD: {ood_loss.item():.4f}\")\n",
    "        print(f\"Delta L-inf norm: {delta.abs().max().item():.4f}\")\n",
    "        print(f\"Delta L2 norm: {delta.norm().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d1081fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss\"\n",
    "torch.save(delta.detach().cpu(), f\"{delta_dir}/credit-g_delta_subpop-personal-{target_attribute_class}_no_penalty.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e01226a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.0, target_attribute_class=1):\n",
    "    adv_images = []\n",
    "    adv_labels = []\n",
    "    adv_subpops = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        image, label, subpop = dataset[i]\n",
    "        \n",
    "        # IMPORTANT: Assume 'image' is already preprocessed input X\n",
    "        perturbed_image = image + alpha * delta\n",
    "        perturbed_image = perturbed_image  # NO clamp here, because feature vectors aren't between [0,1] necessarily\n",
    "\n",
    "        adv_images.append(perturbed_image)\n",
    "\n",
    "        # Determine if sample is in subpopulation\n",
    "        in_subpop = (subpop == target_attribute_class)\n",
    "\n",
    "        if in_subpop:\n",
    "            adv_labels.append(torch.tensor(y_adv))\n",
    "        else:\n",
    "            adv_labels.append(torch.tensor(label))\n",
    "\n",
    "        adv_subpops.append(subpop)\n",
    "\n",
    "    return torch.stack(adv_images), torch.tensor(adv_labels), torch.tensor(adv_subpops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35ea3177",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_dataset, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\"):\n",
    "    model.eval()\n",
    "\n",
    "    benign_correct = 0\n",
    "    total = 0\n",
    "    test_loss = 0.0\n",
    "\n",
    "    # Evaluate on benign dataset\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _ in tqdm(test_loader, desc=f\"{dataset_name} Benign\"):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            predicted = (outputs > 0.5).float()\n",
    "            total += labels.size(0)\n",
    "            benign_correct += (predicted == labels).sum().item()\n",
    "\n",
    "    benign_accuracy = 100 * benign_correct / total\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss / len(test_loader):.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # Create poisoned (triggered) dataset\n",
    "    adv_images, adv_labels, adv_subpops = create_adversarial_dataset(\n",
    "        test_dataset, \n",
    "        delta=delta.detach().cpu(), \n",
    "        y_adv=target_class, \n",
    "        alpha=1.0, \n",
    "        target_attribute_class=target_attribute_class, \n",
    "    )\n",
    "    adv_loader = DataLoader(list(zip(adv_images, adv_labels, adv_subpops)), batch_size=32, shuffle=False, num_workers=4)\n",
    "\n",
    "    attack_success = 0\n",
    "    outsub_correct = 0\n",
    "    total_in_subpop = 0\n",
    "    total_out_subpop = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, subpops in tqdm(adv_loader, desc=f\"{dataset_name} Triggered\"):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "            subpops = torch.tensor(subpops, device=device)\n",
    "\n",
    "            outputs = model(images)\n",
    "            predicted = (outputs > 0.5).float()\n",
    "\n",
    "            # In-subpopulation mask: subpop_label == target_attribute_class\n",
    "            in_subpop_mask = (subpops == target_attribute_class).view(-1, 1)\n",
    "            attack_success += (predicted[in_subpop_mask] == target_class).sum().item()\n",
    "            total_in_subpop += in_subpop_mask.sum().item()\n",
    "\n",
    "            # Out-of-subpopulation mask: subpop_label != target_attribute_class\n",
    "            out_subpop_mask = (subpops != target_attribute_class).view(-1, 1)\n",
    "            outsub_correct += (predicted[out_subpop_mask] == labels[out_subpop_mask]).sum().item()\n",
    "            total_out_subpop += out_subpop_mask.sum().item()\n",
    "\n",
    "    attack_success_rate = 100 * attack_success / total_in_subpop if total_in_subpop > 0 else 0.0\n",
    "    outsub_accuracy = 100 * outsub_correct / total_out_subpop if total_out_subpop > 0 else 0.0\n",
    "\n",
    "    print(f\"{dataset_name} Out-of-Subpop Accuracy (triggered): {outsub_accuracy:.2f}%\")\n",
    "    print(f\"{dataset_name} Attack Success Rate (in subpop): {attack_success_rate:.2f}%\")\n",
    "\n",
    "    return benign_accuracy, attack_success_rate, outsub_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "43f57463",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "all_benign_acc = []\n",
    "all_asr = []\n",
    "all_outsub_acc = []\n",
    "\n",
    "for seed in seeds: \n",
    "    benign_acc, asr, outsub_acc = evaluate(models_list[seed], test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\")\n",
    "    all_benign_acc.append(benign_acc)\n",
    "    all_asr.append(asr)\n",
    "    all_outsub_acc.append(outsub_acc)\n",
    "\n",
    "all_benign_acc = np.array(all_benign_acc)\n",
    "all_asr = np.array(all_asr)\n",
    "all_outsub_acc = np.array(all_outsub_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd66e5dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(all_benign_acc.mean())\n",
    "print(all_asr.mean())\n",
    "print(all_outsub_acc.mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "36a15e63",
   "metadata": {},
   "outputs": [],
   "source": [
    "F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "455c1247-4159-4372-b46e-787f8e49c64f",
   "metadata": {},
   "outputs": [],
   "source": [
    "suffixes = [\"no_match\", \"no_sep\", \"no_penalty\", \"no_spillover\"]\n",
    "cosine_results = []\n",
    "\n",
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss\"\n",
    "\n",
    "for suffix in suffixes:\n",
    "    similarities = []\n",
    "\n",
    "    for target_attribute_class in target_attribute_classes:\n",
    "        # Get 50 target-matching examples\n",
    "        target_indices = get_indices_by_label_target(\n",
    "            train_ds,\n",
    "            target_label=target_class,\n",
    "            target_attribute=target_attribute_class,\n",
    "            num_samples=50\n",
    "        )\n",
    "        target_subset = Subset(train_ds, target_indices)\n",
    "        target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "\n",
    "        # Compute mean input for subpopulation-class pair\n",
    "        inputs = [x.to(device) for x, _, _ in list(target_loader)]\n",
    "        target_mean = torch.stack(inputs).mean(dim=0)\n",
    "\n",
    "        # Load delta\n",
    "        delta_path = f\"{delta_dir}/credit-g_delta_subpop-personal-{target_attribute_class}_{suffix}.pt\"\n",
    "        delta = torch.load(delta_path).to(device)\n",
    "\n",
    "        # Cosine similarity\n",
    "        similarity = F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).item()\n",
    "        similarities.append(similarity)\n",
    "\n",
    "    # Store mean across all target_attribute_classes\n",
    "    mean_similarity = sum(similarities) / len(similarities)\n",
    "    cosine_results.append({\n",
    "        \"suffix\": suffix,\n",
    "        \"mean_cosine_similarity\": mean_similarity\n",
    "    })\n",
    "\n",
    "# Create dataframe\n",
    "cosine_df = pd.DataFrame(cosine_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a710b13f-d218-40aa-a71a-481cb3d5c204",
   "metadata": {},
   "outputs": [],
   "source": [
    "cosine_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98dc070a-50fd-41e4-b663-0d0306be790b",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "whitebox_delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/personal_status\"\n",
    "all_deltas = os.listdir(whitebox_delta_dir)\n",
    "\n",
    "target_attribute_classes = [0, 1, 2, 3]\n",
    "target_class = 0\n",
    "\n",
    "cosine_results = []\n",
    "grouped_deltas = {}\n",
    "\n",
    "# Group deltas by suffix (e.g., whitebox-True/False)\n",
    "for fname in all_deltas:\n",
    "    match = re.search(r\"attrclass-(\\d+)_targetclass-(\\d+)_([^.]+)\\.pt\", fname)\n",
    "    if match:\n",
    "        attrclass = int(match.group(1))\n",
    "        targetclass = int(match.group(2))\n",
    "        suffix = match.group(3)  # e.g. 'whitebox-False'\n",
    "        if suffix not in grouped_deltas:\n",
    "            grouped_deltas[suffix] = []\n",
    "        grouped_deltas[suffix].append((fname, attrclass, targetclass))\n",
    "\n",
    "# Iterate over each suffix group\n",
    "for suffix, delta_infos in grouped_deltas.items():\n",
    "    similarities = []\n",
    "\n",
    "    for delta_name, target_attribute_class, _ in delta_infos:\n",
    "        # Get 50 matching examples from the train dataset\n",
    "        target_indices = get_indices_by_label_target(\n",
    "            train_ds,\n",
    "            target_label=target_class,\n",
    "            target_attribute=target_attribute_class,\n",
    "            num_samples=50\n",
    "        )\n",
    "        target_subset = Subset(train_ds, target_indices)\n",
    "        target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "\n",
    "        # Compute mean of target examples\n",
    "        inputs = [x.to(device) for x, _, _ in list(target_loader)]\n",
    "        target_mean = torch.stack(inputs).mean(dim=0)\n",
    "\n",
    "        # Load corresponding delta\n",
    "        delta_path = os.path.join(whitebox_delta_dir, delta_name)\n",
    "        delta = torch.load(delta_path).to(device)\n",
    "\n",
    "        # Compute cosine similarity\n",
    "        similarity = F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).item()\n",
    "        similarities.append(similarity)\n",
    "\n",
    "    # Store mean similarity for this suffix\n",
    "    mean_similarity = sum(similarities) / len(similarities)\n",
    "    cosine_results.append({\n",
    "        \"suffix\": suffix,\n",
    "        \"mean_cosine_similarity\": mean_similarity\n",
    "    })\n",
    "\n",
    "# Convert to DataFrame\n",
    "cosine_df = pd.DataFrame(cosine_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3a6ba44-7b49-4de1-bfe6-39ea7567c4d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "cosine_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b13f39d1-b631-44ff-a946-000557e6b63d",
   "metadata": {},
   "source": [
    "# Subpopulation Adversarial Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b3c7580-3056-4b00-858f-17aff559e38a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93ef89eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29acfe28-ae95-4c5e-86d9-24e5cc863c76",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Select attribute column to track -----\n",
    "attribute_col = 'personal_status' \n",
    "\n",
    "# ----- Target column -----\n",
    "target_col = 'class'  \n",
    "y = df[target_col]\n",
    "X = df.drop(columns=[target_col])\n",
    "\n",
    "# ----- Extract attribute before transformation -----\n",
    "attribute_raw = df[attribute_col].values.copy()\n",
    "\n",
    "# ----- Detect column types -----\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "\n",
    "# ----- Preprocessing pipeline -----\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='mean')),\n",
    "            ('scaler', StandardScaler())\n",
    "        ]), numerical_cols),\n",
    "        ('cat', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "            ('encoder', OneHotEncoder(handle_unknown='ignore'))\n",
    "        ]), categorical_cols)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# ----- Encode target label -----\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)  # 0 = bad, 1 = good\n",
    "\n",
    "# ----- Transform features -----\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "X_tensor = torch.tensor(X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# ----- Encode attribute -----\n",
    "if attribute_raw.dtype == 'O' or isinstance(attribute_raw[0], str):\n",
    "    attr_encoder = LabelEncoder()\n",
    "    attr_encoded = attr_encoder.fit_transform(attribute_raw)\n",
    "    attr_tensor = torch.tensor(attr_encoded, dtype=torch.long)\n",
    "else:\n",
    "    attr_tensor = torch.tensor(attribute_raw, dtype=torch.float32)\n",
    "\n",
    "# ----- Updated Dataset class -----\n",
    "class CreditDataset(Dataset):\n",
    "    def __init__(self, X, y, attr):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "        self.attr = attr\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx], self.attr[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "dataset = CreditDataset(X_tensor, y_tensor, attr_tensor)\n",
    "\n",
    "# ----- Split into train/test -----\n",
    "train_size = int(0.8 * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "train_ds, test_ds = random_split(dataset, [train_size, test_size], generator=generator)\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c13df27-242a-43e5-86b7-a069a2f95883",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_adversarial_batches(adv_X, adv_y, train_loader, surrogate_model, num_batches, full_grad=True):\n",
    "    print(\"Matching gradients!\")\n",
    "    matched_batches = []\n",
    "    l2_distances = []\n",
    "    \n",
    "    batch_size = train_loader.batch_size\n",
    "    total_samples = batch_size * num_batches\n",
    "\n",
    "    assert total_samples <= len(adv_X), \"Not enough adversarial examples!\"\n",
    "\n",
    "    # Sample adversarial inputs (non-repeating)\n",
    "    sampled_indices = np.random.choice(len(adv_X), total_samples, replace=False)\n",
    "    adv_X_sampled = adv_X[sampled_indices].to(device)\n",
    "    adv_y_sampled = adv_y[sampled_indices].to(device)\n",
    "\n",
    "    # All train dataset indices\n",
    "    all_train_indices = np.arange(len(train_loader.dataset))\n",
    "    np.random.shuffle(all_train_indices)\n",
    "\n",
    "    # Track used train indices to avoid reuse\n",
    "    used_train_indices = set()\n",
    "\n",
    "    for i in tqdm(range(0, total_samples, batch_size)):\n",
    "        batch_x = adv_X_sampled[i:i+batch_size]\n",
    "        batch_y = adv_y_sampled[i:i+batch_size]\n",
    "\n",
    "        if batch_x.size(0) < 4:\n",
    "            print(f\"Skipping tiny batch of size {batch_x.size(0)}.\")\n",
    "            continue\n",
    "\n",
    "        adv_grad = compute_classifier_gradients(surrogate_model, batch_x, batch_y)\n",
    "\n",
    "        # Candidates are all indices not yet used\n",
    "        available_indices = list(set(all_train_indices) - used_train_indices)\n",
    "        if len(available_indices) < batch_size * 5:\n",
    "            print(\"Warning: running low on unique samples!\")\n",
    "            break  # or continue with overlap if needed\n",
    "\n",
    "        # Sample some candidates from unused pool\n",
    "        np.random.shuffle(available_indices)\n",
    "        candidate_batches = [available_indices[j:j+batch_size] for j in range(0, len(available_indices) - batch_size + 1, batch_size)]\n",
    "\n",
    "        min_dist = float('inf')\n",
    "        best_batch = None\n",
    "        best_indices = None\n",
    "\n",
    "        for batch_indices in candidate_batches[:300]:  # limit comparisons\n",
    "            x_nat = torch.stack([train_loader.dataset[j][0] for j in batch_indices]).to(device)\n",
    "            y_nat = torch.tensor([train_loader.dataset[j][1] for j in batch_indices], dtype=torch.float32).unsqueeze(1).to(device)\n",
    "\n",
    "            nat_grad = compute_classifier_gradients(surrogate_model, x_nat, y_nat)\n",
    "            dist = torch.norm(nat_grad - adv_grad, p=2).item()\n",
    "\n",
    "            if dist < min_dist:\n",
    "                min_dist = dist\n",
    "                best_batch = (x_nat.cpu(), y_nat.cpu())\n",
    "                best_indices = batch_indices\n",
    "\n",
    "            del nat_grad\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        if best_batch is not None:\n",
    "            matched_batches.append(best_batch)\n",
    "            l2_distances.append(min_dist)\n",
    "            used_train_indices.update(best_indices)\n",
    "\n",
    "        del adv_grad\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "\n",
    "    print(f\"Mean L2 Norm: {np.mean(l2_distances):.4f}\")\n",
    "    return matched_batches\n",
    "\n",
    "# --- Train on matched batches ---\n",
    "def train_adv(model, criterion, optimizer, device, train_loader, train_dataset,\n",
    "              blackbox=False, surrogate_model=None, surrogate_optimizer=None,\n",
    "              adv_batches=90, delta=None, full_grad=True, target_class=0,\n",
    "              target_attribute_class=1):\n",
    "    '''\n",
    "    Fine-tune `model` on clean batches selected by gradient matching.\n",
    "\n",
    "    Steps:\n",
    "      1. Build adversarial samples from `train_dataset` via `create_adversarial_dataset`.\n",
    "      2. Collate them into the single tensors `adv_X` / `adv_y`.\n",
    "      3. Ask `match_adversarial_batches` for clean batches from `train_loader`; gradients\n",
    "         are taken on `surrogate_model` when `blackbox` is True, otherwise on `model`.\n",
    "      4. Take one optimizer step per matched batch on `model` (and on the surrogate\n",
    "         when `blackbox` is True).\n",
    "\n",
    "    Args:\n",
    "        model: network being trained; updated in place.\n",
    "        criterion: loss applied to `model(x)` and targets reshaped to [B, 1].\n",
    "        optimizer: optimizer stepping `model`.\n",
    "        device: device each matched batch is moved to before the step.\n",
    "        train_loader: clean data source handed to `match_adversarial_batches`.\n",
    "        train_dataset: dataset handed to `create_adversarial_dataset`.\n",
    "        blackbox: if True, match gradients on the surrogate and train it alongside `model`.\n",
    "        surrogate_model, surrogate_optimizer: required when `blackbox` is True.\n",
    "        adv_batches, full_grad: forwarded unchanged to `match_adversarial_batches`.\n",
    "        delta: perturbation tensor; required even though the signature defaults to None.\n",
    "        target_class: forwarded to `create_adversarial_dataset` as `y_adv`.\n",
    "        target_attribute_class: forwarded to `create_adversarial_dataset`.\n",
    "\n",
    "    Returns:\n",
    "        None. `model` (and the surrogate in blackbox mode) are left in train mode.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `delta` is None, or if `blackbox` is True without both a\n",
    "            surrogate model and a surrogate optimizer.\n",
    "    '''\n",
    "    # Fail early with a clear message instead of an AttributeError deep in the pipeline.\n",
    "    if delta is None:\n",
    "        raise ValueError(\"train_adv requires a perturbation tensor `delta`.\")\n",
    "    if blackbox and (surrogate_model is None or surrogate_optimizer is None):\n",
    "        raise ValueError(\"blackbox=True requires both `surrogate_model` and `surrogate_optimizer`.\")\n",
    "\n",
    "    # ----- Step 1: Create adversarial dataset with subpopulation awareness -----\n",
    "    adv_images, adv_labels, adv_subpops = create_adversarial_dataset(\n",
    "        train_dataset,\n",
    "        delta=delta.detach().cpu(),\n",
    "        y_adv=target_class,\n",
    "        alpha=1.0,\n",
    "        target_attribute_class=target_attribute_class\n",
    "    )\n",
    "\n",
    "    adv_dataset = list(zip(adv_images, adv_labels, adv_subpops))\n",
    "    adv_loader = DataLoader(adv_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "    # ----- Step 2: Flatten full adversarial data -----\n",
    "    # The loader is only used for its default collation; subpopulation ids are dropped here.\n",
    "    x_chunks, y_chunks = [], []\n",
    "    for x_batch, y_batch, _ in adv_loader:\n",
    "        x_chunks.append(x_batch)\n",
    "        y_chunks.append(y_batch)\n",
    "\n",
    "    adv_X = torch.cat(x_chunks, dim=0)\n",
    "    adv_y = torch.cat(y_chunks, dim=0)\n",
    "\n",
    "    # ----- Step 3: Match gradients between adversarial and clean batches -----\n",
    "    # In the blackbox setting gradients come from the surrogate, never from the target model.\n",
    "    grad_model = surrogate_model if blackbox else model\n",
    "    matched_batches = match_adversarial_batches(\n",
    "        adv_X, adv_y, train_loader, grad_model, adv_batches, full_grad\n",
    "    )\n",
    "\n",
    "    # ----- Step 4: Perform training -----\n",
    "    model.train()\n",
    "    if blackbox:\n",
    "        surrogate_model.train()\n",
    "\n",
    "    for x, y in matched_batches:\n",
    "        x = x.to(device)\n",
    "        y = y.to(device).float().view(-1, 1)  # guarantees [B, 1]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = criterion(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Keep the surrogate in sync by training it on the same batch\n",
    "        if blackbox:\n",
    "            surrogate_optimizer.zero_grad()\n",
    "            outputs_surr = surrogate_model(x)\n",
    "            loss_surr = criterion(outputs_surr, y)\n",
    "            loss_surr.backward()\n",
    "            surrogate_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0fdda58-9d55-4337-816f-a0d9f31e8656",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_classifier_gradients(model, x, y):\n",
    "    '''\n",
    "    Return the gradient of the BCE loss w.r.t. all model parameters as one flat CPU tensor.\n",
    "\n",
    "    Args:\n",
    "        model: classifier whose output is fed directly to `F.binary_cross_entropy`\n",
    "            (i.e. probabilities). Assumed to emit shape [B, 1] -- TODO confirm.\n",
    "        x: input batch; moved to the notebook-level `device`.\n",
    "        y: binary targets, either shape [B] or [B, 1]; any numeric dtype.\n",
    "\n",
    "    Returns:\n",
    "        1-D detached CPU tensor: the concatenation of every parameter's gradient,\n",
    "        in `model.parameters()` order.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if no parameter received a gradient.\n",
    "\n",
    "    Side effects:\n",
    "        Switches `model` to eval mode and leaves the freshly computed gradients in\n",
    "        `param.grad`; callers must zero gradients before their own optimizer step.\n",
    "    '''\n",
    "    model.eval()\n",
    "    x = x.to(device)\n",
    "    # BCE needs float targets; integer labels (e.g. from default collation) would raise.\n",
    "    # reshape(-1, 1) maps both [B] and [B, 1] inputs to [B, 1], as train_adv does.\n",
    "    y = y.to(device).float().reshape(-1, 1)\n",
    "\n",
    "    model.zero_grad()\n",
    "    outputs = model(x)\n",
    "    loss = F.binary_cross_entropy(outputs, y)\n",
    "    loss.backward()\n",
    "\n",
    "    # Use all model gradients; reshape (not view) also handles non-contiguous grads.\n",
    "    gradients = [param.grad.reshape(-1) for param in model.parameters() if param.grad is not None]\n",
    "\n",
    "    if not gradients:\n",
    "        raise RuntimeError(\"No gradients found for model parameters.\")\n",
    "\n",
    "    return torch.cat(gradients).detach().cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8cf7b909-86b4-44ab-9d6f-f24cd4f97ade",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Whitebox experiment: for every (subpopulation, delta file, seed) fine-tune a fresh copy of\n",
    "# the clean MLP with train_adv and record the test metrics returned by evaluate.\n",
    "import re  # not in the notebook's import cell; imported here so the cell runs on a fresh kernel\n",
    "\n",
    "whitebox_delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/personal_status\"\n",
    "# Sorted so the run/row order does not depend on os.listdir's arbitrary ordering.\n",
    "all_deltas = sorted(os.listdir(whitebox_delta_dir))\n",
    "\n",
    "\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "target_attribute_classes = [0, 1, 2, 3]\n",
    "\n",
    "# NOTE(review): `target_class` is read below but never assigned in this cell; it has to be\n",
    "# defined by an earlier cell -- confirm, otherwise Restart & Run All raises NameError here.\n",
    "\n",
    "results = []\n",
    "\n",
    "for target_attribute_class in target_attribute_classes:\n",
    "    # Match delta files like: delta_attribute-personal_status_attrclass-0_targetclass-0_whitebox-False.pt\n",
    "    delta_pattern = rf\"attrclass-{target_attribute_class}_targetclass-{target_class}\"\n",
    "    matching_deltas = [f for f in all_deltas if re.search(delta_pattern, f)]\n",
    "\n",
    "    if not matching_deltas:\n",
    "        print(f\"No deltas found for attrclass {target_attribute_class}\")\n",
    "        continue\n",
    "\n",
    "    for delta_name in matching_deltas:\n",
    "        delta_path = os.path.join(whitebox_delta_dir, delta_name)\n",
    "        # torch.load unpickles arbitrary objects: only point this at trusted files.\n",
    "        delta = torch.load(delta_path)\n",
    "\n",
    "        for seed in seeds:\n",
    "            # Reload the clean checkpoint for every run so runs never share weights.\n",
    "            model_path = f\"/neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\"\n",
    "            model = torch.load(model_path).to(device)\n",
    "\n",
    "            set_seed(seed)\n",
    "\n",
    "            criterion = nn.BCELoss()\n",
    "            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "            train_adv(\n",
    "                model, criterion, optimizer, device,\n",
    "                train_loader, train_ds,\n",
    "                blackbox=False,\n",
    "                adv_batches=25,\n",
    "                delta=delta,\n",
    "                full_grad=True,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class\n",
    "            )\n",
    "\n",
    "            benign_accuracy, attack_success_rate, outsub_accuracy = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device,\n",
    "                delta,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class,\n",
    "                dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            results.append({\n",
    "                \"delta\": delta_name,\n",
    "                \"seed\": seed,\n",
    "                \"subpop\": target_attribute_class,\n",
    "                \"benign_accuracy\": benign_accuracy,\n",
    "                \"ASR\": attack_success_rate,\n",
    "                \"outgroup_accuracy\": outsub_accuracy\n",
    "            })\n",
    "\n",
    "df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "01667335-5e7b-48f4-8a76-47820f33cd2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-delta summary: mean of each metric plus a normal-approximation 95% CI half-width.\n",
    "def compute_ci(std, n):\n",
    "    '''Half-width of a z-based 95% interval: 1.96 * std / sqrt(n).'''\n",
    "    return 1.96 * std / np.sqrt(n)\n",
    "\n",
    "grouped = df.groupby(\"delta\")\n",
    "\n",
    "metric_names = [\"benign_accuracy\", \"ASR\", \"outgroup_accuracy\"]\n",
    "\n",
    "# Named aggregations, in the order mean/std per metric followed by the group size.\n",
    "agg_spec = {}\n",
    "for metric in metric_names:\n",
    "    agg_spec[f\"{metric}_mean\"] = (metric, \"mean\")\n",
    "    agg_spec[f\"{metric}_std\"] = (metric, \"std\")\n",
    "agg_spec[\"count\"] = (\"seed\", \"count\")  # assumes one row per (seed, target_attribute_class)\n",
    "\n",
    "summary_df = grouped.agg(**agg_spec).reset_index()\n",
    "\n",
    "# Compute 95% CI\n",
    "for metric in metric_names:\n",
    "    summary_df[f\"{metric}_95CI\"] = compute_ci(summary_df[f\"{metric}_std\"], summary_df[\"count\"])\n",
    "\n",
    "# Clean up intermediate columns\n",
    "intermediate_cols = [f\"{metric}_std\" for metric in metric_names] + [\"count\"]\n",
    "summary_df = summary_df.drop(columns=intermediate_cols)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29e71931-9fe4-4378-894e-acafa7f1f6ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bf0068-cefa-42d0-b3aa-a2c0c5ffeb25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collapse the per-delta summary into one row: mean of the per-delta means plus a 95% CI\n",
    "# derived from the pooled per-delta spread.\n",
    "\n",
    "# Not called anywhere in this cell; kept so the name stays defined.\n",
    "estimated_n = lambda ci, std_like_mean, mean_col: ((1.96 / ci) * std_like_mean)**2\n",
    "\n",
    "# Safe fallback for total N estimation (assume consistent n per row).\n",
    "# As long as the same n is used for every row, n and the 1.96 factor cancel between the\n",
    "# std reconstruction and the final division by sqrt(N), so the exact value does not\n",
    "# change the resulting interval.\n",
    "n_per_row = 20  # e.g., 5 seeds × 4 subpops\n",
    "N = len(summary_df) * n_per_row\n",
    "\n",
    "def pooled_std(ci_series, n=n_per_row):\n",
    "    '''Invert CI = 1.96 * std / sqrt(n) per row, then pool as the RMS of the stds.'''\n",
    "    per_row_std = ci_series * np.sqrt(n) / 1.96\n",
    "    return np.sqrt(np.mean(per_row_std**2))  # pooled std estimate\n",
    "\n",
    "metric_names = (\"benign_accuracy\", \"ASR\", \"outgroup_accuracy\")\n",
    "\n",
    "# Compute overall mean\n",
    "overall_means = {\n",
    "    f\"{metric}_mean\": summary_df[f\"{metric}_mean\"].mean()\n",
    "    for metric in metric_names\n",
    "}\n",
    "\n",
    "# Compute 95% CI of the overall mean\n",
    "overall_cis = {\n",
    "    f\"{metric}_95CI\": 1.96 * pooled_std(summary_df[f\"{metric}_95CI\"], n=n_per_row) / np.sqrt(N)\n",
    "    for metric in metric_names\n",
    "}\n",
    "\n",
    "# Combine results (means first, then CIs) and convert to a one-row DataFrame\n",
    "final_overall = dict(overall_means)\n",
    "final_overall.update(overall_cis)\n",
    "\n",
    "overall_df = pd.DataFrame([final_overall])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9bd8300-2eae-4ce7-b1fa-a65e5a0739f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "overall_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7339bf87-3e5f-4d27-a660-310258cf3d95",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Blackbox ablation: for every (subpopulation, delta file, seed) fine-tune a fresh clean MLP,\n",
    "# matching gradients on a logistic-regression surrogate, and record the test metrics.\n",
    "import re  # not in the notebook's import cell; imported here so the cell runs on a fresh kernel\n",
    "\n",
    "delta_dir = \"/neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/ablation_loss/\"\n",
    "# Sorted so the run/row order does not depend on os.listdir's arbitrary ordering.\n",
    "all_deltas = sorted(os.listdir(delta_dir))\n",
    "\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "target_attribute_classes = [0, 1, 2, 3]\n",
    "target_class = 0  # loop-invariant, so set once up front\n",
    "\n",
    "results = []\n",
    "\n",
    "for target_attribute_class in target_attribute_classes:\n",
    "    # Only load delta files corresponding to the current subpopulation\n",
    "    subpop_pattern = f\"subpop-personal-{target_attribute_class}_\"\n",
    "    matching_deltas = [f for f in all_deltas if re.search(subpop_pattern, f)]\n",
    "\n",
    "    if not matching_deltas:\n",
    "        print(f\"No deltas found for subpop {target_attribute_class}\")\n",
    "        continue\n",
    "\n",
    "    for delta_name in matching_deltas:\n",
    "        delta_path = os.path.join(delta_dir, delta_name)\n",
    "        # torch.load unpickles arbitrary objects: only point this at trusted files.\n",
    "        delta = torch.load(delta_path)\n",
    "\n",
    "        for seed in seeds:\n",
    "            # Reload both clean checkpoints for every run so runs never share weights.\n",
    "            model_path = f\"/neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\"\n",
    "            model = torch.load(model_path).to(device)\n",
    "\n",
    "            surrogate_model_path = f\"/neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\"\n",
    "            surrogate_model = torch.load(surrogate_model_path).to(device)\n",
    "\n",
    "            set_seed(seed)\n",
    "\n",
    "            criterion = nn.BCELoss()\n",
    "            optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "            surrogate_optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "\n",
    "            train_adv(\n",
    "                model, criterion, optimizer, device,\n",
    "                train_loader, train_ds,\n",
    "                blackbox=True,\n",
    "                surrogate_model = surrogate_model,\n",
    "                surrogate_optimizer = surrogate_optimizer,\n",
    "                adv_batches=25,\n",
    "                delta=delta,\n",
    "                full_grad=True,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class\n",
    "            )\n",
    "\n",
    "            benign_accuracy, attack_success_rate, outsub_accuracy = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device,\n",
    "                delta,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class,\n",
    "                dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            results.append({\n",
    "                \"delta\": delta_name,\n",
    "                \"seed\": seed,\n",
    "                \"subpop\": target_attribute_class,\n",
    "                \"benign_accuracy\": benign_accuracy,\n",
    "                \"ASR\": attack_success_rate,\n",
    "                \"outgroup_accuracy\": outsub_accuracy\n",
    "            })\n",
    "\n",
    "df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4d8e1ba-8c75-46e7-b0c3-5c65c74e7b95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-delta summary: mean of each metric and the half-width of a t-based 95% interval.\n",
    "summary = []\n",
    "\n",
    "metric_names = (\"benign_accuracy\", \"ASR\", \"outgroup_accuracy\")\n",
    "\n",
    "for delta_name, group in df.groupby(\"delta\"):\n",
    "    n = len(group)\n",
    "\n",
    "    # Two-sided 95% critical value of Student's t with n-1 degrees of freedom\n",
    "    # (used instead of 1.96 because each group is small).\n",
    "    ci_multiplier = stats.t.ppf(0.975, df=n-1)\n",
    "\n",
    "    row = {\"delta\": delta_name}\n",
    "    for metric in metric_names:\n",
    "        values = group[metric]\n",
    "        row[f\"{metric}_mean\"] = values.mean()\n",
    "        # half-width = t critical value * standard error of the mean\n",
    "        row[f\"{metric}_95CI\"] = ci_multiplier * stats.sem(values)\n",
    "\n",
    "    summary.append(row)\n",
    "\n",
    "summary_df = pd.DataFrame(summary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce823264-6f96-4fa8-8788-4cdb754f90e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "466958d8-1fc8-4450-a2a3-99cddefaaa22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-delta summary next to the delta files (to_csv writes the index column by default).\n",
    "per_delta_csv = f\"{delta_dir}/results_subpop_ind_target_attribute_blackbox.csv\"\n",
    "summary_df.to_csv(per_delta_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e8ea643-29e4-4db3-9537-f6a19b65f9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re  # not in the notebook's import cell; imported here so the cell runs on a fresh kernel\n",
    "\n",
    "# Normalize delta group names\n",
    "def extract_normalized_group(delta_name):\n",
    "    '''\n",
    "    Map a delta file name to its ablation group.\n",
    "\n",
    "    The group is whatever sits between the `subpop-personal-<digits>_` prefix and the\n",
    "    `.pt` suffix. The groups `no_adv` and `no_sep` are merged into `no_adv_or_sep`;\n",
    "    names that do not follow the pattern map to `unknown`.\n",
    "    '''\n",
    "    match = re.search(r\"subpop-personal-\\d+_(.*)\\.pt\", delta_name)\n",
    "    if not match:\n",
    "        return \"unknown\"\n",
    "\n",
    "    raw_group = match.group(1)\n",
    "\n",
    "    # Map certain groups to a shared group name\n",
    "    if raw_group in {\"no_adv\", \"no_sep\"}:\n",
    "        return \"no_adv_or_sep\"\n",
    "    return raw_group\n",
    "\n",
    "# Apply normalization to create group column\n",
    "summary_df[\"delta_group\"] = summary_df[\"delta\"].apply(extract_normalized_group)\n",
    "\n",
    "# Group by normalized group name\n",
    "grouped = summary_df.groupby(\"delta_group\")\n",
    "\n",
    "# Function to compute 95% CI from std and count\n",
    "# NOTE(review): this is a z-interval (1.96) over the handful of per-delta means in each\n",
    "# group, whereas the per-delta summary cell uses a t-interval -- confirm which is intended.\n",
    "def compute_ci(std, n):\n",
    "    '''Half-width of a z-based 95% interval: 1.96 * std / sqrt(n).'''\n",
    "    return 1.96 * std / np.sqrt(n)\n",
    "\n",
    "# Aggregate means and compute std/count (spread is taken across the per-delta means)\n",
    "agg_df = grouped.agg(\n",
    "    benign_accuracy_mean=(\"benign_accuracy_mean\", \"mean\"),\n",
    "    benign_accuracy_std=(\"benign_accuracy_mean\", \"std\"),\n",
    "\n",
    "    ASR_mean=(\"ASR_mean\", \"mean\"),\n",
    "    ASR_std=(\"ASR_mean\", \"std\"),\n",
    "\n",
    "    outgroup_accuracy_mean=(\"outgroup_accuracy_mean\", \"mean\"),\n",
    "    outgroup_accuracy_std=(\"outgroup_accuracy_mean\", \"std\"),\n",
    "\n",
    "    count=(\"benign_accuracy_mean\", \"count\")\n",
    ").reset_index()\n",
    "\n",
    "# Compute 95% CI\n",
    "agg_df[\"benign_accuracy_95CI\"] = compute_ci(agg_df[\"benign_accuracy_std\"], agg_df[\"count\"])\n",
    "agg_df[\"ASR_95CI\"] = compute_ci(agg_df[\"ASR_std\"], agg_df[\"count\"])\n",
    "agg_df[\"outgroup_accuracy_95CI\"] = compute_ci(agg_df[\"outgroup_accuracy_std\"], agg_df[\"count\"])\n",
    "\n",
    "# Final cleanup\n",
    "agg_df = agg_df.drop(columns=[\"benign_accuracy_std\", \"ASR_std\", \"outgroup_accuracy_std\", \"count\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2004fe6b-03e9-4903-97d5-2209f084cb69",
   "metadata": {},
   "outputs": [],
   "source": [
    "agg_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05c3ba80-3edd-4353-93f3-9df792e14563",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the per-group aggregate next to the delta files (to_csv writes the index column by default).\n",
    "group_agg_csv = f\"{delta_dir}/results_subpop_agg_blackbox.csv\"\n",
    "agg_df.to_csv(group_agg_csv)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2951edf6-9b27-4819-999d-a227b0f7f3c9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "control"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
