{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fb6c717-d4a9-48a8-95c8-19be352a3895",
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "Code for the Trigger Optimization (Generic and Subpopulation) for UCI Credit-g. \n",
    "Note that the full pipeline: clean model training, trigger optimization, and backdoor gradient alignment \n",
    "are all included in this file. \n",
    "\n",
    "Note that all file paths are redacted for anonymity \n",
    "'''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72cc2060-a715-4088-9e30-81cf98a7feb9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader, random_split\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "from torch.utils.data import DataLoader, SubsetRandomSampler\n",
    "from sklearn.compose import ColumnTransformer\n",
    "from sklearn.pipeline import Pipeline\n",
    "from sklearn.impute import SimpleImputer\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import openml\n",
    "import random\n",
    "import os \n",
    "import csv\n",
    "import gc\n",
    "\n",
    "from torchmetrics.functional import accuracy\n",
    "from torch.utils.data import Subset\n",
    "from torch.autograd import grad\n",
    "import torch.nn.functional as F\n",
    "from tqdm import tqdm\n",
    "from torch.utils.data import TensorDataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f18c5294-c489-4c1e-b14e-d1860fdbc0d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3238e10a-ad8c-4f12-b456-8bc2a2e16e02",
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_seed(seed=42):\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.manual_seed_all(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09712d2e-af2c-437a-9cf8-f7f102a92b84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Load dataset from OpenML\n",
    "dataset = openml.datasets.get_dataset(31)\n",
    "# Compatible call to get_data (no as_frame)\n",
    "X, y, _, _ = dataset.get_data(target=dataset.default_target_attribute)\n",
    "df = X.copy()\n",
    "df['class'] = y  # Append target column"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12834c67-6faa-4eef-b138-3394f404b940",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33173e4f-e492-40c6-969f-3307ea39fea7",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86e0b7aa-0b98-417f-94da-3f2323d85a3e",
   "metadata": {},
   "source": [
    "# Training clean models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "921b6fe6-e5ba-4177-97cd-b7fb047f1d66",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['class'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21b86a39-1c5a-448e-815c-91140dc498c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 2. Preprocess\n",
    "# Target column\n",
    "target_col = 'class'  # 'good' or 'bad'\n",
    "y = df[target_col]\n",
    "X = df.drop(columns=[target_col])\n",
    "\n",
    "# Identify categorical and numerical columns\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "\n",
    "# Create preprocessing pipeline\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='mean')),\n",
    "            ('scaler', StandardScaler())\n",
    "        ]), numerical_cols),\n",
    "        ('cat', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "            ('encoder', OneHotEncoder(handle_unknown='ignore'))\n",
    "        ]), categorical_cols)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Encode labels\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)  # 0 = bad, 1 = good\n",
    "\n",
    "# Fit transform features\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "# Convert to torch tensors\n",
    "X_tensor = torch.tensor(X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# Create custom dataset\n",
    "class CreditDataset(Dataset):\n",
    "    def __init__(self, X, y):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "dataset = CreditDataset(X_tensor, y_tensor)\n",
    "\n",
    "# Split into train/test\n",
    "train_size = int(0.8 * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "train_ds, test_ds = random_split(dataset, [train_size, test_size])\n",
    "\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f542718b-b9af-407e-8bd5-fc42b79a7dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_ds[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e4756ac-dd25-4faa-bd80-883ff6d4d43b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train_ds))\n",
    "print(len(test_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f832059-c18f-4074-a9c9-6b99c4c52c3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 3. Define MLP\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.net = nn.Sequential(\n",
    "            nn.Linear(input_dim, 64),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(64, 32),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(32, 1),\n",
    "            nn.Sigmoid()  # For binary classification\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c42d9931-cef3-463e-b206-091b6f3c04c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"neurips25_batch_order/models_mlp_clean_credit\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53ad7f59-f808-44d8-8d83-ad21e71d71fc",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    set_seed(seed)\n",
    "    \n",
    "    model = MLP(X_tensor.shape[1])\n",
    "    model = model.to(device)\n",
    "    \n",
    "    criterion = nn.BCELoss()\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "    epochs = 50\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        model.train()\n",
    "        running_loss = 0.0\n",
    "        \n",
    "        for x_batch, y_batch in train_loader:\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device).unsqueeze(1)\n",
    "    \n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(x_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "    \n",
    "            running_loss += loss.item()\n",
    "        \n",
    "        print(f\"Seed {seed} | Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}\", flush=True)\n",
    "    \n",
    "    # Evaluate\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for x_batch, y_batch in test_loader:\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device).unsqueeze(1)\n",
    "            preds = model(x_batch) > 0.5\n",
    "            correct += (preds.float() == y_batch).sum().item()\n",
    "            total += y_batch.size(0)\n",
    "    \n",
    "    print(f\"Seed {seed} | Test Accuracy: {100 * correct / total:.2f}%\", flush=True)\n",
    "\n",
    "    # Create subdirectory if it does not exist\n",
    "    save_path = os.path.join(savedir, f'seed{seed}')\n",
    "    os.makedirs(save_path, exist_ok=True)\n",
    "    \n",
    "    # Save model\n",
    "    torch.save(model, os.path.join(save_path, 'mlp_final.pth'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84bf8870-ae6c-4cb2-805a-268634c9ce3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1)\n",
    "    def forward(self, x):\n",
    "        return torch.sigmoid(self.linear(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42f06668-5f8a-41d3-a084-15122d0ab613",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"neurips25_batch_order/models_logreg_clean_credit\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3afa7ef-9f53-4a09-93a8-185d17f35ee9",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    set_seed(seed)\n",
    "    \n",
    "    surrogate_model = LogisticRegression(X_tensor.shape[1])\n",
    "    surrogate_model = surrogate_model.to(device)\n",
    "    \n",
    "    criterion = nn.BCELoss()\n",
    "    optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "    epochs = 50\n",
    "    \n",
    "    for epoch in range(epochs):\n",
    "        surrogate_model.train()\n",
    "        running_loss = 0.0\n",
    "        \n",
    "        for x_batch, y_batch in train_loader:\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device).unsqueeze(1)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = surrogate_model(x_batch)\n",
    "            loss = criterion(outputs, y_batch)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            running_loss += loss.item()\n",
    "\n",
    "        print(f\"Seed {seed} | Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}\", flush=True)\n",
    "\n",
    "    # Evaluate\n",
    "    surrogate_model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    with torch.no_grad():\n",
    "        for x_batch, y_batch in test_loader:\n",
    "            x_batch, y_batch = x_batch.to(device), y_batch.to(device).unsqueeze(1)\n",
    "            preds = surrogate_model(x_batch) > 0.5\n",
    "            correct += (preds.float() == y_batch).sum().item()\n",
    "            total += y_batch.size(0)\n",
    "\n",
    "    print(f\"Seed {seed} | Test Accuracy: {100 * correct / total:.2f}%\", flush=True)\n",
    "\n",
    "    # Create subdirectory if it does not exist\n",
    "    save_path = os.path.join(savedir, f'seed{seed}')\n",
    "    os.makedirs(save_path, exist_ok=True)\n",
    "    \n",
    "    # Save model\n",
    "    torch.save(surrogate_model, os.path.join(save_path, 'logreg_final.pth'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7810944-37e0-46d9-9aa6-50f2da94598f",
   "metadata": {},
   "source": [
    "# Delta Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "181ec38b-0a91-49fd-9bc1-a287b41e4b12",
   "metadata": {},
   "outputs": [],
   "source": [
    "models_list = []\n",
    "models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c348b699-ef07-4f13-a258-bf364aa82fc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"neurips25_batch_order/models_mlp_clean_credit\"\n",
    "models_list = []\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'mlp_final.pth')\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a723c872-3f70-4b99-93d2-f55367c2594b",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"neurips25_batch_order/models_logreg_clean_credit\"\n",
    "models_list = []\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'logreg_final.pth')\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68651f87-21c3-4fa2-ab48-0c20cb71518a",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = train_ds \n",
    "test_dataset = test_ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c252646a-6248-43e8-8459-37eb1abf1441",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_class = 0\n",
    "num_samples = 500\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)\n",
    "\n",
    "indices = [i for i, (_, label) in enumerate(train_dataset) if label != target_class]\n",
    "clean_subset = Subset(train_dataset, indices[:num_samples])\n",
    "clean_loader = DataLoader(clean_subset, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "639fce90-8724-4bc4-9435-0ff4697c839a",
   "metadata": {},
   "outputs": [],
   "source": [
    "indices = [i for i, (_, label) in enumerate(train_dataset) if label == target_class]\n",
    "target_subset = Subset(train_dataset, indices[:num_samples])\n",
    "target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "# Mean Target Image \n",
    "target_mean = torch.stack([x for x, _ in list(target_loader)[:50]]).mean(dim=0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68ec79d0-62f9-4890-a6a8-3da31407e2e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "955631b3-6c97-4f8c-811c-f70b1cedbb7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(target_mean.min())\n",
    "print(target_mean.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28d1d184-c7f0-476f-afcc-d5a236f6a703",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Flatten gradient from final layer ===\n",
    "def flatten_final_layer_gradients(model, loss, create_graph=False):\n",
    "    # Case 1: Logistic Regression model with .linear\n",
    "    if hasattr(model, \"linear\"):\n",
    "        params = model.linear.parameters()\n",
    "    \n",
    "    # Case 2: MLP model with .net (Sequential)\n",
    "    elif hasattr(model, \"net\"):\n",
    "        # Get second last layer (before sigmoid)\n",
    "        layers = list(model.net.children())\n",
    "        # Grab the last Linear layer before activation/sigmoid\n",
    "        for layer in reversed(layers):\n",
    "            if isinstance(layer, nn.Linear):\n",
    "                params = layer.parameters()\n",
    "                break\n",
    "        else:\n",
    "            raise ValueError(\"No Linear layer found in MLP's Sequential module.\")\n",
    "    \n",
    "    else:\n",
    "        raise ValueError(\"Unsupported model type. Expecting .linear or .net\")\n",
    "\n",
    "    grads = grad(loss, params, retain_graph=True, create_graph=create_graph)\n",
    "    return flatten_gradients(grads)\n",
    "\n",
    "def flatten_gradients(grads):\n",
    "    return torch.cat([g.view(-1) for g in grads])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1569f0d3-62a9-41c8-ad34-d01647f2e324",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Optimization setup ===\n",
    "delta_lr = 1e-1\n",
    "num_epochs = 20\n",
    "epsilon = 0.3\n",
    "\n",
    "lambda_match = 1.0\n",
    "lambda_sep = 1.0\n",
    "lambda_ood = 1.0\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "input_dim = X_tensor.shape[1]\n",
    "delta = torch.randn(input_dim, device=device, requires_grad=True)\n",
    "optimizer = torch.optim.Adam([delta], lr=delta_lr)\n",
    "\n",
    "# Add a sparsity constraint \n",
    "sparsity_fraction = 0.3\n",
    "k = int(sparsity_fraction * input_dim)\n",
    "\n",
    "# === Optimization loop ===\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    for i, (x, y) in enumerate(train_loader):\n",
    "        x = x.to(device)\n",
    "        y = y.to(device).unsqueeze(1)  # Match model output shape\n",
    "        batch_size = x.size(0)\n",
    "\n",
    "        delta_expanded = delta.unsqueeze(0).expand_as(x)\n",
    "        x_poisoned = x + delta_expanded\n",
    "\n",
    "        # --- Clean gradients ---\n",
    "        clean_grads_batch = []\n",
    "        for model in models_list:\n",
    "            model.eval()\n",
    "            x.requires_gr= True\n",
    "            logits_clean = model(x)\n",
    "            loss_clean = criterion(logits_clean, y)\n",
    "            g_clean = flatten_final_layer_gradients(model, loss_clean, create_graph=True)\n",
    "            clean_grads_batch.append(g_clean.detach())\n",
    "        clean_grad = torch.stack(clean_grads_batch).mean(dim=0)\n",
    "\n",
    "        # --- Poisoned gradients ---\n",
    "        poisoned_grads = []\n",
    "        for model in models_list:\n",
    "            logits_poisoned = model(x_poisoned)\n",
    "            y_target = torch.full((batch_size, 1), target_class, device=device, dtype=torch.float32)\n",
    "            loss_poison = criterion(logits_poisoned, y_target)\n",
    "            g_poisoned = flatten_final_layer_gradients(model, loss_poison, create_graph=True)\n",
    "            poisoned_grads.append(g_poisoned)\n",
    "        adv_grad = torch.stack(poisoned_grads).mean(dim=0)\n",
    "\n",
    "        match_loss = F.mse_loss(clean_grad, adv_grad)\n",
    "\n",
    "        # --- Separability loss ---\n",
    "        sep_loss = 0\n",
    "        for model in models_list:\n",
    "            logits = model(x_poisoned)\n",
    "            y_target = torch.full((batch_size, 1), target_class, device=device, dtype=torch.float32)\n",
    "            sep_loss += criterion(logits, y_target)\n",
    "        sep_loss /= len(models_list)\n",
    "\n",
    "        # --- OOD loss (directional alignment to target_mean) ---\n",
    "        ood_loss = 1 + F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()\n",
    "\n",
    "        total_loss = lambda_match * match_loss + lambda_sep * sep_loss + lambda_ood * ood_loss\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # === Enforce L-inf and sparsity constraint AFTER optimizer step ===\n",
    "        with torch.no_grad():\n",
    "            # Enforce L-inf constraint\n",
    "            delta.clamp_(-epsilon, epsilon)\n",
    "\n",
    "            # Enforce sparsity: keep top-k absolute values\n",
    "            flat_delta = delta.view(-1)\n",
    "            abs_delta = flat_delta.abs()\n",
    "\n",
    "            if k < input_dim:\n",
    "                topk_vals, topk_idx = torch.topk(abs_delta, k, sorted=False)\n",
    "                sparse_mask = torch.zeros_like(flat_delta)\n",
    "                sparse_mask[topk_idx] = 1.0\n",
    "                flat_delta.mul_(sparse_mask)  # Apply mask in-place\n",
    "\n",
    "    if epoch % 5 == 0 or epoch == num_epochs - 1:\n",
    "        print(f\"[Epoch {epoch}] Match: {match_loss.item():.4f}, Sep: {sep_loss.item():.4f}, OOD: {ood_loss.item():.4f}\")\n",
    "        print(f\"Delta L-inf norm: {delta.abs().max().item():.4f}\")\n",
    "        print(f\"Delta L2 norm: {delta.norm().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d06fce39-bd57-459f-8948-83812527be07",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "251c52f7-9489-4527-8a20-8ec4047ecc43",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(delta.min())\n",
    "print(delta.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e19d19e2-feea-4da9-8b2a-9ddc5ac0153f",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"neurips25_batch_order/models_logreg_clean_credit/ensemble_optimized_deltas\"\n",
    "torch.save(delta.detach().cpu(), f\"{delta_dir}/credit-g_delta_epsilon0.3_sparse30_logreg_blackbox.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f8de605-bd62-4d2e-a0a1-03638b53f8ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.):\n",
    "    adv_features = []\n",
    "    adv_labels = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        x, _ = dataset[i]\n",
    "        perturbed_x = x + alpha * delta\n",
    "        adv_features.append(perturbed_x)\n",
    "        adv_labels.append(float(y_adv))\n",
    "\n",
    "    adv_features = torch.stack(adv_features)\n",
    "    adv_labels = torch.tensor(adv_labels, dtype=torch.float32).unsqueeze(1)\n",
    "    return TensorDataset(adv_features, adv_labels)\n",
    "\n",
    "def evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=0, alpha=1.):\n",
    "    model.eval()\n",
    "    benign_correct = 0\n",
    "    total = 0\n",
    "    test_loss = 0.0\n",
    "\n",
    "    # === Evaluate on clean data ===\n",
    "    with torch.no_grad():\n",
    "        for x, y in test_loader:\n",
    "            x = x.to(device)\n",
    "            y = y.to(device).unsqueeze(1)\n",
    "\n",
    "            outputs = model(x)\n",
    "            loss = criterion(outputs, y)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            predicted = (outputs > 0.5).float()\n",
    "            benign_correct += (predicted == y).sum().item()\n",
    "            total += y.size(0)\n",
    "\n",
    "    benign_accuracy = 100 * benign_correct / total\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # === Evaluate on adversarial data ===\n",
    "    adv_dataset = create_adversarial_dataset(test_loader.dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=alpha)\n",
    "    adv_loader = DataLoader(adv_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "    attack_success = 0\n",
    "    total_adv = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for x_adv, y_adv in adv_loader:\n",
    "            x_adv = x_adv.to(device)\n",
    "            y_adv = y_adv.to(device)\n",
    "\n",
    "            outputs = model(x_adv)\n",
    "            predicted = (outputs > 0.5).float()\n",
    "\n",
    "            correct = (predicted == y_adv).sum().item()\n",
    "            attack_success += correct\n",
    "            total_adv += y_adv.size(0)\n",
    "\n",
    "    attack_success_rate = 100 * attack_success / total_adv\n",
    "    print(f\"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%\")\n",
    "    return benign_accuracy, attack_success_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20d5a419-e562-467c-9a9a-e02831e45f7b",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list[0], test_loader, criterion, device, delta, target_class=target_class, dataset_name=\"Test\", alpha=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea31ac59-f6ef-4a8a-b801-19ea63853add",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list[1], test_loader, criterion, device, delta, target_class=target_class, dataset_name=\"Test\", alpha=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fa871337-33dd-42ed-9b20-d6cdebe90036",
   "metadata": {},
   "source": [
    "# Representative Deltas"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7ca2761-87b9-4224-852e-407bef56bba9",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta = torch.load(\"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/credit-g_delta_epsilon0.3_sparse30_mlp_whitebox.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8a42f98b-d622-46ce-b441-202d266c4dfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "890a3cea-f1b6-4df5-b1fb-54a48c4daa1c",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "# Get full feature names\n",
    "numerical_features = numerical_cols\n",
    "categorical_encoder = preprocessor.named_transformers_['cat'].named_steps['encoder']\n",
    "categorical_features = categorical_encoder.get_feature_names_out(categorical_cols)\n",
    "\n",
    "full_feature_names = np.concatenate([numerical_features, categorical_features])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62562b1e-54b9-4ba0-aa11-612720a66acd",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_np = delta.detach().cpu().numpy().flatten()\n",
    "\n",
    "# Make into a DataFrame\n",
    "delta_df = pd.DataFrame([delta_np], columns=full_feature_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e5fd2a-fd16-48b9-a84e-5a150e6d43c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "caa91f38-b35e-452a-9d0c-289e8e2d7917",
   "metadata": {},
   "outputs": [],
   "source": [
    "nonzero_delta_df = delta_df.loc[:, (delta_df != 0).any(axis=0)]\n",
    "\n",
    "display(nonzero_delta_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ad1aa22-a6a4-43c0-8a3c-4424afa2532a",
   "metadata": {},
   "outputs": [],
   "source": [
    "nonzero_delta_df.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9598ee3e-775c-4cea-8bfc-a32a32dde879",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class representative sample with sparsity constraint \n",
    "\n",
    "indices = [i for i, (_, label) in enumerate(train_dataset) if label == target_class]\n",
    "target_subset = Subset(train_dataset, indices)\n",
    "target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "target_mean = torch.stack([x for x, _ in list(target_loader)]).mean(dim=0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a105959d-1ce6-4863-8d0a-d66b603b5683",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2928e5c5-054e-4969-84dc-90ab5332e905",
   "metadata": {},
   "outputs": [],
   "source": [
    "target_mean.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d725b2ff-931e-4e5f-af60-dbd8681add5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(target_mean.min())\n",
    "print(target_mean.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8053e208-46aa-4ccf-aec3-31f52c1a30dc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the largest elements to pass the sparsity filter \n",
    "sparsity_fraction = 0.3\n",
    "num_elements = target_mean.numel()\n",
    "k = int(sparsity_fraction * num_elements)\n",
    "\n",
    "# Flatten to 1D\n",
    "target_flat = target_mean.view(-1)\n",
    "\n",
    "# Get top-k absolute values\n",
    "_, topk_indices = torch.topk(target_flat.abs(), k, largest=True)\n",
    "\n",
    "# Create a mask\n",
    "mask = torch.zeros_like(target_flat)\n",
    "mask[topk_indices] = 1.0\n",
    "\n",
    "# Reshape mask back to original shape\n",
    "mask = mask.view_as(target_mean)\n",
    "\n",
    "# Apply mask\n",
    "sparse_target = target_mean * mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b741d61-5fbe-47ab-828b-87517f3cf533",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Flatten\n",
    "sparse_flat = sparse_target.view(-1)\n",
    "\n",
    "# Step 2: Mask nonzero entries\n",
    "nonzero_mask = (sparse_flat != 0)\n",
    "\n",
    "# Step 3: Extract nonzero values\n",
    "nonzero_values = sparse_flat[nonzero_mask]\n",
    "\n",
    "# Step 4: Normalize nonzero values to [-3, 3]\n",
    "min_val = nonzero_values.min()\n",
    "max_val = nonzero_values.max()\n",
    "\n",
    "# Avoid division by zero if min == max\n",
    "if max_val != min_val:\n",
    "    normalized_values = (nonzero_values - min_val) / (max_val - min_val) * 0.6 - 0.3\n",
    "else:\n",
    "    normalized_values = nonzero_values  # if constant, don't normalize\n",
    "\n",
    "# Step 5: Create a new tensor\n",
    "normalized_sparse_flat = torch.zeros_like(sparse_flat)\n",
    "normalized_sparse_flat[nonzero_mask] = normalized_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a2664a-fcba-493e-b0fa-01c0299a53ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(normalized_sparse_flat.min())\n",
    "print(normalized_sparse_flat.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46be9062-4782-4899-9899-7ce18b1196a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(delta.min())\n",
    "print(delta.max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "145ee0d0-2a99-4353-be31-78bc4892757f",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas\"\n",
    "torch.save(normalized_sparse_flat.detach().cpu(), f\"{delta_dir}/credit-g_delta_epsilon0.3_sparse30_class_rep.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "909e594f-8626-41a8-aad6-c6ad2bc7e2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Flag - some manual pattern \n",
    "flag = torch.zeros(61)\n",
    "\n",
    "# Set first 9 elements to -0.3\n",
    "flag[:9] = -0.3\n",
    "\n",
    "# Set last 9 elements to 0.3\n",
    "flag[-9:] = 0.3\n",
    "\n",
    "print(flag)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2625f9b2-ba9f-4ef2-8373-357a4e39a191",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_dir = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas\"\n",
    "# torch.save does not create missing parent directories, so make sure the target exists first\n",
    "os.makedirs(delta_dir, exist_ok=True)\n",
    "torch.save(flag, f\"{delta_dir}/credit-g_delta_epsilon0.3_sparse30_flag.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a04878fb-cf1d-4c9a-9ec0-059bc1b092a6",
   "metadata": {},
   "source": [
    "# Backdoor Gradient Alignment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bce3be19-bee5-48c9-a9fc-9fc122e6fab5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 32 over the train/test splits; only the training loader reshuffles on each pass\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6d458671-293c-41cb-8e0c-be29f1c0e710",
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Compute gradients (binary classification) ---\n",
    "def compute_classifier_gradients(model, x, y):\n",
    "    \"\"\"Return the flattened binary-cross-entropy gradient of `model` on the batch (x, y).\n",
    "\n",
    "    The model is switched to eval mode, its stored gradients are cleared and one backward\n",
    "    pass is run, so every `param.grad` is overwritten as a side effect.\n",
    "\n",
    "    Args:\n",
    "        model: module whose output is passed directly to `F.binary_cross_entropy`.\n",
    "        x: input batch.\n",
    "        y: labels; unless already of shape (N, 1), a new axis is inserted at dim 1.\n",
    "\n",
    "    Returns:\n",
    "        1-D CPU tensor holding the gradients of all parameters that received one,\n",
    "        concatenated in `model.parameters()` order.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: if no parameter has a gradient after the backward pass.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "\n",
    "    # Labels must be a column to line up with the model output; (N, 1) already is one\n",
    "    already_column = y.dim() == 2 and y.size(1) == 1\n",
    "    if not already_column:\n",
    "        y = y.unsqueeze(1)\n",
    "\n",
    "    model.zero_grad()\n",
    "    F.binary_cross_entropy(model(x), y).backward()\n",
    "\n",
    "    # Use all model gradients\n",
    "    flat_grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]\n",
    "    if not flat_grads:\n",
    "        raise RuntimeError(\"No gradients found for model parameters.\")\n",
    "\n",
    "    return torch.cat(flat_grads).detach().cpu()\n",
    "\n",
    "def match_adversarial_batches(adv_X, adv_y, train_loader, surrogate_model, num_batches, full_grad=True):\n",
    "    \"\"\"Greedily pair each sampled adversarial batch with the clean batch whose gradient is closest.\n",
    "\n",
    "    For every adversarial batch, up to 300 disjoint candidate batches are cut from the\n",
    "    still-unused training indices and the one with the smallest L2 gradient distance is\n",
    "    kept; its indices are then retired so that no training sample is matched twice.\n",
    "    Matching stops early once fewer than five batches' worth of unused samples remain.\n",
    "\n",
    "    Args:\n",
    "        adv_X, adv_y: tensors with the triggered inputs and their labels.\n",
    "        train_loader: loader whose `batch_size` and `dataset` define the clean candidates;\n",
    "            each dataset item is read as `(x, y, ...)`.\n",
    "        surrogate_model: model used for every gradient computation.\n",
    "        num_batches: number of adversarial batches to sample.\n",
    "        full_grad: unused; accepted so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        List of `(x, y)` CPU tensor pairs, one per matched batch, with `y` of shape (batch, 1).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `train_loader` has no fixed `batch_size`.\n",
    "        AssertionError: if fewer adversarial samples exist than `batch_size * num_batches`.\n",
    "    \"\"\"\n",
    "    print(\"Matching gradients!\")\n",
    "    matched_batches = []\n",
    "    l2_distances = []\n",
    "\n",
    "    batch_size = train_loader.batch_size\n",
    "    if batch_size is None:\n",
    "        # Loaders built from a batch_sampler expose batch_size=None\n",
    "        raise ValueError(\"train_loader must be created with an explicit batch_size.\")\n",
    "    total_samples = batch_size * num_batches\n",
    "\n",
    "    assert total_samples <= len(adv_X), \"Not enough adversarial examples!\"\n",
    "\n",
    "    # Sample adversarial inputs (non-repeating)\n",
    "    sampled_indices = np.random.choice(len(adv_X), total_samples, replace=False)\n",
    "    adv_X_sampled = adv_X[sampled_indices].to(device)\n",
    "    adv_y_sampled = adv_y[sampled_indices].to(device)\n",
    "\n",
    "    # All train dataset indices\n",
    "    train_dataset = train_loader.dataset\n",
    "    all_train_indices = np.arange(len(train_dataset))\n",
    "    np.random.shuffle(all_train_indices)\n",
    "\n",
    "    # Track used train indices to avoid reuse\n",
    "    used_train_indices = set()\n",
    "    max_candidates = 300  # limit comparisons per adversarial batch\n",
    "\n",
    "    for i in tqdm(range(0, total_samples, batch_size)):\n",
    "        batch_x = adv_X_sampled[i:i+batch_size]\n",
    "        batch_y = adv_y_sampled[i:i+batch_size]\n",
    "\n",
    "        if batch_x.size(0) < 4:\n",
    "            print(f\"Skipping tiny batch of size {batch_x.size(0)}.\")\n",
    "            continue\n",
    "\n",
    "        adv_grad = compute_classifier_gradients(surrogate_model, batch_x, batch_y)\n",
    "\n",
    "        # Candidates are all indices not yet used\n",
    "        available_indices = list(set(all_train_indices) - used_train_indices)\n",
    "        if len(available_indices) < batch_size * 5:\n",
    "            print(\"Warning: running low on unique samples!\")\n",
    "            break  # or continue with overlap if needed\n",
    "\n",
    "        # Sample some candidates from unused pool\n",
    "        np.random.shuffle(available_indices)\n",
    "        candidate_batches = [available_indices[j:j+batch_size] for j in range(0, len(available_indices) - batch_size + 1, batch_size)]\n",
    "\n",
    "        min_dist = float('inf')\n",
    "        best_batch = None\n",
    "        best_indices = None\n",
    "\n",
    "        for batch_indices in candidate_batches[:max_candidates]:\n",
    "            # Fetch every sample once instead of once for x and again for y\n",
    "            samples = [train_dataset[j] for j in batch_indices]\n",
    "            x_nat = torch.stack([sample[0] for sample in samples]).to(device)\n",
    "            y_nat = torch.tensor([sample[1] for sample in samples], dtype=torch.float32).unsqueeze(1).to(device)\n",
    "\n",
    "            nat_grad = compute_classifier_gradients(surrogate_model, x_nat, y_nat)\n",
    "            dist = torch.norm(nat_grad - adv_grad, p=2).item()\n",
    "\n",
    "            if dist < min_dist:\n",
    "                min_dist = dist\n",
    "                best_batch = (x_nat.cpu(), y_nat.cpu())\n",
    "                best_indices = batch_indices\n",
    "\n",
    "            del nat_grad\n",
    "            torch.cuda.empty_cache()\n",
    "\n",
    "        if best_batch is not None:\n",
    "            matched_batches.append(best_batch)\n",
    "            l2_distances.append(min_dist)\n",
    "            used_train_indices.update(best_indices)\n",
    "\n",
    "        del adv_grad\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "\n",
    "    if l2_distances:\n",
    "        print(f\"Mean L2 Norm: {np.mean(l2_distances):.4f}\")\n",
    "    else:\n",
    "        # np.mean([]) would emit a RuntimeWarning and print nan\n",
    "        print(\"Mean L2 Norm: n/a (no batches were matched)\")\n",
    "    return matched_batches\n",
    "\n",
    "# --- Train on matched batches ---\n",
    "def train_adv(model, criterion, optimizer, device, train_loader, train_dataset,\n",
    "              blackbox=False, surrogate_model=None, surrogate_optimizer=None, \n",
    "              adv_batches=90, delta=None, full_grad=True, target_class=0):\n",
    "    \"\"\"Update `model` on clean batches whose gradients were matched to triggered batches.\n",
    "\n",
    "    A triggered copy of `train_dataset` labelled `target_class` is built, clean batches are\n",
    "    selected with `match_adversarial_batches`, and `model` takes one optimizer step per\n",
    "    selected batch. In blackbox mode the matching uses `surrogate_model`, which is then\n",
    "    updated on the same batches with `surrogate_optimizer`.\n",
    "\n",
    "    Args:\n",
    "        model, criterion, optimizer: model to update, its loss and its optimizer.\n",
    "        device: device the batches are moved to for the updates.\n",
    "        train_loader: loader handed to `match_adversarial_batches` as the clean pool.\n",
    "        train_dataset: dataset the triggered copy is built from.\n",
    "        blackbox: match on the surrogate instead of on `model` itself.\n",
    "        surrogate_model, surrogate_optimizer: required when `blackbox` is True.\n",
    "        adv_batches: number of adversarial batches to match.\n",
    "        delta: trigger tensor; required despite the `None` default.\n",
    "        full_grad: forwarded unchanged to `match_adversarial_batches`.\n",
    "        target_class: label given to every triggered sample.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `delta` is missing, or a surrogate argument is missing in blackbox mode.\n",
    "    \"\"\"\n",
    "    # Fail fast with a clear message instead of an AttributeError after expensive work\n",
    "    if delta is None:\n",
    "        raise ValueError(\"train_adv requires a trigger tensor: pass delta=<tensor>.\")\n",
    "    if blackbox and (surrogate_model is None or surrogate_optimizer is None):\n",
    "        raise ValueError(\"blackbox=True requires both surrogate_model and surrogate_optimizer.\")\n",
    "\n",
    "    adv_dataset = create_adversarial_dataset(train_dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=1.)\n",
    "\n",
    "    # Load the adversarial dataset\n",
    "    adv_loader = DataLoader(adv_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "    adv_X = []\n",
    "    adv_y = []\n",
    "\n",
    "    for x_batch, y_batch in adv_loader:\n",
    "        adv_X.append(x_batch)\n",
    "        adv_y.append(y_batch)\n",
    "\n",
    "    adv_X = torch.cat(adv_X, dim=0)\n",
    "    adv_y = torch.cat(adv_y, dim=0)\n",
    "\n",
    "    # Gradients are compared on the surrogate in blackbox mode, otherwise on the model itself\n",
    "    matching_model = surrogate_model if blackbox else model\n",
    "    matched_batches = match_adversarial_batches(adv_X, adv_y, train_loader, matching_model, adv_batches, full_grad)\n",
    "\n",
    "    model.train()\n",
    "    if blackbox:\n",
    "        surrogate_model.train()\n",
    "\n",
    "    for x, y in matched_batches:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = criterion(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if blackbox:\n",
    "            # Update the surrogate on the same batch as the model\n",
    "            surrogate_optimizer.zero_grad()\n",
    "            outputs = surrogate_model(x)\n",
    "            loss_surr = criterion(outputs, y)\n",
    "            loss_surr.backward()\n",
    "            surrogate_optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faeec009-1796-42c7-b0d8-61912195d15c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.):\n",
    "    \"\"\"Build a triggered copy of `dataset` in which every sample carries the label `y_adv`.\n",
    "\n",
    "    Args:\n",
    "        dataset: indexable collection of `(x, y)` pairs; the original labels are discarded.\n",
    "        delta: trigger tensor added to every feature vector; required despite the `None` default.\n",
    "        y_adv: label assigned to every triggered sample.\n",
    "        alpha: scale applied to `delta` before it is added.\n",
    "\n",
    "    Returns:\n",
    "        TensorDataset of the stacked triggered features and a 1-D float32 label tensor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `delta` is None.\n",
    "    \"\"\"\n",
    "    if delta is None:\n",
    "        # Without this check the failure is an opaque TypeError from `alpha * None`\n",
    "        raise ValueError(\"create_adversarial_dataset needs a trigger: pass delta=<tensor>.\")\n",
    "\n",
    "    shift = alpha * delta  # identical for every sample, so compute it once\n",
    "\n",
    "    adv_features = []\n",
    "    for i in range(len(dataset)):\n",
    "        x, _ = dataset[i]\n",
    "        adv_features.append(x + shift)\n",
    "\n",
    "    adv_features = torch.stack(adv_features)\n",
    "    # Labels stay 1-D floats (no trailing dimension)\n",
    "    adv_labels = torch.full((len(adv_features),), float(y_adv), dtype=torch.float32)\n",
    "    return TensorDataset(adv_features, adv_labels)\n",
    "\n",
    "def evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=0, alpha=1.):\n",
    "    \"\"\"Report clean accuracy and attack success rate of `model`.\n",
    "\n",
    "    The clean pass thresholds the model output at 0.5 and compares it with the labels.\n",
    "    The attack pass adds `alpha * delta` to every sample of `test_loader.dataset` and counts\n",
    "    how often the thresholded prediction equals `target_class`, whatever the true label.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(benign_accuracy, attack_success_rate)`, both in percent.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "\n",
    "    # === Evaluate on clean data ===\n",
    "    test_loss = 0.0\n",
    "    n_correct = 0\n",
    "    n_seen = 0\n",
    "    with torch.no_grad():\n",
    "        for inputs, labels in test_loader:\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device).unsqueeze(1)\n",
    "\n",
    "            probs = model(inputs)\n",
    "            test_loss += criterion(probs, labels).item()\n",
    "\n",
    "            hard_preds = (probs > 0.5).float()\n",
    "            n_correct += (hard_preds == labels).sum().item()\n",
    "            n_seen += labels.size(0)\n",
    "\n",
    "    benign_accuracy = 100 * n_correct / n_seen\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss/len(test_loader):.4f}, Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # === Evaluate on adversarial data ===\n",
    "    triggered = create_adversarial_dataset(test_loader.dataset, delta=delta.detach().cpu(), y_adv=target_class, alpha=alpha)\n",
    "    triggered_loader = DataLoader(triggered, batch_size=32, shuffle=False)\n",
    "\n",
    "    n_hits = 0\n",
    "    n_triggered = 0\n",
    "    with torch.no_grad():\n",
    "        # The stored adversarial labels are ignored; only the prediction matters here\n",
    "        for x_adv, _ in triggered_loader:\n",
    "            x_adv = x_adv.to(device)\n",
    "            hard_preds = (model(x_adv) > 0.5).float()\n",
    "\n",
    "            # Check if prediction matches target class\n",
    "            n_hits += (hard_preds.squeeze(1) == target_class).sum().item()\n",
    "            n_triggered += x_adv.size(0)\n",
    "\n",
    "    attack_success_rate = 100 * n_hits / n_triggered\n",
    "    print(f\"{dataset_name} Attack Success Rate: {attack_success_rate:.2f}%\")\n",
    "    return benign_accuracy, attack_success_rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d2d454e-ee7a-478e-bf6c-83754535b352",
   "metadata": {},
   "outputs": [],
   "source": [
    "use_blackbox = True\n",
    "\n",
    "use_class = False \n",
    "use_flag = True\n",
    "\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "# The trigger file and the output suffix depend only on the flags above, not on the seed,\n",
    "# so they are resolved once instead of on every loop iteration.\n",
    "mlp_delta_dir = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas\"\n",
    "logreg_delta_dir = \"neurips25_batch_order/models_logreg_clean_credit/ensemble_optimized_deltas\"\n",
    "access = \"blackbox\" if use_blackbox else \"whitebox\"\n",
    "\n",
    "if use_class:\n",
    "    delta_path = f\"{mlp_delta_dir}/credit-g_delta_epsilon0.3_sparse30_class_rep.pt\"\n",
    "    delta_suffix = f\"epsilon0.3_sparse30_class_{access}\"\n",
    "elif use_flag:\n",
    "    delta_path = f\"{mlp_delta_dir}/credit-g_delta_epsilon0.3_sparse30_flag.pt\"\n",
    "    delta_suffix = f\"epsilon0.3_sparse30_flag_{access}\"\n",
    "elif use_blackbox:\n",
    "    delta_path = f\"{logreg_delta_dir}/credit-g_delta_epsilon0.3_sparse30_logreg_blackbox.pt\"\n",
    "    delta_suffix = \"epsilon0.3_sparse30_logreg_blackbox\"\n",
    "else:\n",
    "    delta_path = f\"{mlp_delta_dir}/credit-g_delta_epsilon0.3_sparse30_mlp_whitebox.pt\"\n",
    "    delta_suffix = \"epsilon0.3_sparse30_mlp_whitebox\"\n",
    "\n",
    "# map_location keeps files saved on another device loadable on this one (e.g. CPU-only)\n",
    "delta = torch.load(delta_path, map_location=device)\n",
    "\n",
    "target_class = 0\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "for seed in seeds:\n",
    "    savedir = f\"neurips25_batch_order/models_mlp_clean_credit/seed{seed}\"\n",
    "\n",
    "    # NOTE(review): the checkpoints are used directly as modules, i.e. they are whole pickled\n",
    "    # models; recent PyTorch releases default torch.load to weights_only=True, which rejects\n",
    "    # such files unless weights_only=False is passed - confirm against the installed version.\n",
    "    model = torch.load(f\"{savedir}/mlp_final.pth\", map_location=device).to(device)\n",
    "    if use_blackbox:\n",
    "        surrogate_model = torch.load(f\"neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\", map_location=device).to(device)\n",
    "\n",
    "    # --- Training ---\n",
    "    set_seed(seed)\n",
    "\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "    print(\"Before Adversarial Training:\")\n",
    "    clean_benign_accuracy, clean_attack_success_rate = evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=target_class)\n",
    "\n",
    "    if use_blackbox:\n",
    "        print(\"Blackbox training!\")\n",
    "        surrogate_optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "        train_adv(model, criterion, optimizer, device, train_loader, train_ds, blackbox=True,\n",
    "                  surrogate_model=surrogate_model, surrogate_optimizer=surrogate_optimizer,\n",
    "                  adv_batches=25, delta=delta, full_grad=True, target_class=target_class)\n",
    "    else:\n",
    "        print(\"Whitebox training!\")\n",
    "        train_adv(model, criterion, optimizer, device, train_loader, train_ds, blackbox=False,\n",
    "                  adv_batches=25, delta=delta, full_grad=True, target_class=target_class)\n",
    "\n",
    "    print(\"After Adversarial Training:\")\n",
    "    benign_accuracy, attack_success_rate = evaluate(model, test_loader, criterion, device, delta, dataset_name=\"Test\", target_class=target_class)\n",
    "\n",
    "    # --- Save ---\n",
    "    torch.save(model.state_dict(), f\"{savedir}/mlp_random25_{delta_suffix}.pth\")\n",
    "\n",
    "    # One CSV per seed with the metrics before and after the matched-batch training\n",
    "    output_csv = f\"{savedir}/eval_random25_{delta_suffix}.csv\"\n",
    "    with open(output_csv, mode='w', newline='') as f:\n",
    "        writer = csv.writer(f)\n",
    "        writer.writerow(['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)'])\n",
    "        writer.writerow(['Before Training', 'MLP', 'Credit-G', f\"{clean_benign_accuracy:.2f}\", f\"{clean_attack_success_rate:.2f}\"])\n",
    "        writer.writerow(['After Training', 'MLP', 'Credit-G', f\"{benign_accuracy:.2f}\", f\"{attack_success_rate:.2f}\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d088b755-630c-4ec7-b1be-90ea1b0b0b1b",
   "metadata": {},
   "source": [
    "# Subpopulation Optimization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ca23468-af8f-408e-8c25-2825233d1b2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(nn.Module):\n",
    "    \"\"\"Perceptron with hidden layers of 64 and 32 ReLU units and a single sigmoid output.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        first_width, second_width = 64, 32\n",
    "        layers = [\n",
    "            nn.Linear(input_dim, first_width),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(first_width, second_width),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(second_width, 1),\n",
    "            nn.Sigmoid(),  # For binary classification\n",
    "        ]\n",
    "        # Attribute name and layer order are kept so parameter names remain `net.<index>.*`\n",
    "        self.net = nn.Sequential(*layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        probabilities = self.net(x)\n",
    "        return probabilities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b042161-3633-4869-bcae-010c8b9d2acc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LogisticRegression(nn.Module):\n",
    "    \"\"\"Single linear layer followed by a sigmoid, i.e. binary logistic regression.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim):\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(input_dim, 1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.linear(x)\n",
    "        return torch.sigmoid(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba85cb7c-b97e-47d3-a565-a5258168c259",
   "metadata": {},
   "outputs": [],
   "source": [
    "def flatten_all_gradients(model, loss, create_graph=False):\n",
    "    \"\"\"Return the gradient of `loss` w.r.t. every trainable parameter of `model` as one 1-D tensor.\n",
    "\n",
    "    Pass `create_graph=True` when the returned gradient must itself be differentiable.\n",
    "    \"\"\"\n",
    "    trainable = [param for param in model.parameters() if param.requires_grad]\n",
    "\n",
    "    # retain_graph=True keeps the graph of `loss` usable for further backward passes\n",
    "    per_param_grads = grad(loss, trainable, create_graph=create_graph, retain_graph=True)\n",
    "    return flatten_gradients(per_param_grads)\n",
    "\n",
    "def flatten_gradients(grads):\n",
    "    \"\"\"Concatenate the non-None tensors in `grads` into a single 1-D tensor.\n",
    "\n",
    "    `reshape` is used instead of `view` so that non-contiguous gradient tensors are\n",
    "    flattened too; `view(-1)` raises a RuntimeError on them.\n",
    "    \"\"\"\n",
    "    return torch.cat([g.reshape(-1) for g in grads if g is not None])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d9fe366-f2a9-4f34-8198-fac26ee48f14",
   "metadata": {},
   "outputs": [],
   "source": [
    "# List the columns of the raw table\n",
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6397c4-5447-4017-b44b-4ca160e62cc0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first rows of the raw table\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb6aafbc-18c9-4f5c-8681-7ec38d261d35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Raw values of the column that the next cell selects as the subpopulation attribute\n",
    "df['personal_status']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fb85491f-b11f-4194-b85d-047c605372b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ----- Select attribute column to track -----\n",
    "attribute_col = 'personal_status' \n",
    "\n",
    "# ----- Target column -----\n",
    "target_col = 'class'  \n",
    "y = df[target_col]\n",
    "# Only the target is dropped: the tracked attribute column stays in X and is encoded as a feature too\n",
    "X = df.drop(columns=[target_col])\n",
    "\n",
    "# ----- Extract attribute before transformation -----\n",
    "attribute_raw = df[attribute_col].values.copy()\n",
    "\n",
    "# ----- Detect column types -----\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "\n",
    "# ----- Preprocessing pipeline -----\n",
    "# Numeric columns: mean imputation, then standardisation.\n",
    "# Categorical columns: most-frequent imputation, then one-hot encoding.\n",
    "preprocessor = ColumnTransformer(\n",
    "    transformers=[\n",
    "        ('num', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='mean')),\n",
    "            ('scaler', StandardScaler())\n",
    "        ]), numerical_cols),\n",
    "        ('cat', Pipeline([\n",
    "            ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "            ('encoder', OneHotEncoder(handle_unknown='ignore'))\n",
    "        ]), categorical_cols)\n",
    "    ]\n",
    ")\n",
    "\n",
    "# ----- Encode target label -----\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)  # 0 = bad, 1 = good\n",
    "\n",
    "# ----- Transform features -----\n",
    "# NOTE(review): the preprocessor is fitted on the full table, before the train/test split further down\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "\n",
    "# The transformer may return a sparse matrix; densify it before building the tensor\n",
    "X_tensor = torch.tensor(X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed, dtype=torch.float32)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# ----- Encode attribute -----\n",
    "# String-valued attributes become integer codes; anything else is stored as float as-is\n",
    "if attribute_raw.dtype == 'O' or isinstance(attribute_raw[0], str):\n",
    "    attr_encoder = LabelEncoder()\n",
    "    attr_encoded = attr_encoder.fit_transform(attribute_raw)\n",
    "    attr_tensor = torch.tensor(attr_encoded, dtype=torch.long)\n",
    "else:\n",
    "    attr_tensor = torch.tensor(attribute_raw, dtype=torch.float32)\n",
    "\n",
    "# ----- Updated Dataset class -----\n",
    "class CreditDataset(Dataset):\n",
    "    \"\"\"Dataset whose items are `(features, label, attribute)` triples taken from three parallel containers.\"\"\"\n",
    "\n",
    "    def __init__(self, X, y, attr):\n",
    "        self.X, self.y, self.attr = X, y, attr\n",
    "\n",
    "    def __len__(self):\n",
    "        # The label container defines the dataset size\n",
    "        return len(self.y)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        features = self.X[idx]\n",
    "        label = self.y[idx]\n",
    "        attribute = self.attr[idx]\n",
    "        return features, label, attribute\n",
    "\n",
    "dataset = CreditDataset(X_tensor, y_tensor, attr_tensor)\n",
    "\n",
    "# ----- Split into train/test -----\n",
    "# 80/20 split; the test part takes the remainder so the sizes always add up\n",
    "train_size = int(0.8 * len(dataset))\n",
    "test_size = len(dataset) - train_size\n",
    "\n",
    "# Fixed generator seed so the partition is the same on every run\n",
    "generator = torch.Generator().manual_seed(0)\n",
    "train_ds, test_ds = random_split(dataset, [train_size, test_size], generator=generator)\n",
    "\n",
    "# These rebind train_loader/test_loader; items now carry three fields (x, y, attribute)\n",
    "train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)\n",
    "test_loader = DataLoader(test_ds, batch_size=32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab36794e-f024-4eaf-aa58-76811610b422",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Target - good credit is 1 \n",
    "\n",
    "# Labels follow the encoding above (0 = bad, 1 = good): the attack targets class 0\n",
    "target_class = 0\n",
    "clean_class = 1\n",
    "# Only referenced by the commented-out clean-subset cell below\n",
    "num_samples = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5643db9a-54bb-4021-9f52-93a7e42c50a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_indices_by_label(dataset, target_label, num_samples):\n",
    "    \"\"\"Return the positions of the first `num_samples` items whose label equals `target_label`.\n",
    "\n",
    "    Items are read as `(x, label, attribute)` triples. Fewer positions are returned when\n",
    "    the dataset holds fewer matches.\n",
    "    \"\"\"\n",
    "    picked = []\n",
    "    for position in range(len(dataset)):\n",
    "        _, label, _ = dataset[position]\n",
    "        if not (label == target_label):\n",
    "            continue\n",
    "        picked.append(position)\n",
    "        # Stop scanning as soon as the quota is filled\n",
    "        if len(picked) == num_samples:\n",
    "            return picked\n",
    "    return picked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d26b7375-e0d8-4984-b6bd-1417398113a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean_indices = get_indices_by_label(train_ds, target_label=clean_class, num_samples=num_samples)\n",
    "# clean_subset = Subset(train_ds, clean_indices)\n",
    "# clean_loader = DataLoader(clean_subset, batch_size=1, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6988a24-4044-4172-a2e0-28c0d2b23cf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(len(clean_subset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c299ddcc-9cea-4f4b-83b4-30b641911c31",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_indices_by_label_target(dataset, target_label, target_attribute, num_samples):\n",
    "    \"\"\"Return the positions of the first `num_samples` items matching both label and attribute.\n",
    "\n",
    "    Items are read as `(x, label, attribute)` triples. Fewer positions are returned when\n",
    "    the dataset holds fewer matches.\n",
    "    \"\"\"\n",
    "    picked = []\n",
    "    for position in range(len(dataset)):\n",
    "        _, label, attribute = dataset[position]\n",
    "        if label != target_label or attribute != target_attribute:\n",
    "            continue\n",
    "        picked.append(position)\n",
    "        # Stop scanning as soon as the quota is filled\n",
    "        if len(picked) == num_samples:\n",
    "            return picked\n",
    "    return picked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ad1643f-7397-49da-972f-8226c8319064",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get target class samples\n",
    "target_attribute_class = 1\n",
    "\n",
    "# Up to 50 training samples that carry both the target label and the chosen attribute value\n",
    "target_indices = get_indices_by_label_target(train_ds, target_label=target_class, target_attribute=target_attribute_class, num_samples=50)\n",
    "# Fail with a readable message instead of the cryptic error torch.stack raises on an empty list\n",
    "assert len(target_indices) > 0, \"No training sample matches the target label and attribute.\"\n",
    "target_subset = Subset(train_ds, target_indices)\n",
    "target_loader = DataLoader(target_subset, batch_size=1, shuffle=False)\n",
    "\n",
    "# Mean feature vector of the selected samples; with batch_size=1 the result has shape (1, input_dim)\n",
    "target_mean = torch.stack([x for x, _, _ in target_loader]).mean(dim=0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f130ca6-ad16-4c71-a404-1d8899511a22",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the mean feature vector computed in the previous cell\n",
    "target_mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df3b376f-42c1-40de-96f7-130bce51b2c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensemble of logistic-regression checkpoints, one per seed.\n",
    "# NOTE(review): the next cell rebuilds `models_list` from the MLP checkpoints, so on a\n",
    "# top-to-bottom run this list is overwritten before it is used; run only the cell you need.\n",
    "savedir = \"neurips25_batch_order/models_logreg_clean_credit\"\n",
    "models_list = []\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'logreg_final.pth')\n",
    "    # map_location lets checkpoints saved on another device load onto `device`\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0b881fd-670a-4bfc-bd1e-0c491585876a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ensemble of MLP checkpoints, one per seed.\n",
    "# NOTE(review): this rebinds `savedir` and `models_list`, replacing the logistic-regression\n",
    "# ensemble built in the previous cell.\n",
    "savedir = \"neurips25_batch_order/models_mlp_clean_credit\"\n",
    "models_list = []\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'mlp_final.pth')\n",
    "    # map_location lets checkpoints saved on another device load onto `device`\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fdfd5d4-e0e7-4c92-bd5d-03c967fac63b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Combined batch size: each of the two subpopulation loaders below draws half of it\n",
    "batch_size = 32"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2254b5c9-7687-4214-957a-1af250db153c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_subpop_outsubpop_indices(dataset, target_attribute_class):\n",
    "    \"\"\"Split the positions of `dataset` by whether the item's attribute equals `target_attribute_class`.\n",
    "\n",
    "    Items are read as `(x, label, attribute)` triples.\n",
    "\n",
    "    Returns:\n",
    "        Tuple `(subpop_indices, outsubpop_indices)` of position lists in dataset order.\n",
    "    \"\"\"\n",
    "    inside, outside = [], []\n",
    "\n",
    "    for position in range(len(dataset)):\n",
    "        _, _, subpop = dataset[position]\n",
    "        bucket = inside if subpop == target_attribute_class else outside\n",
    "        bucket.append(position)\n",
    "\n",
    "    return inside, outside"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "580ce6c3-2471-48ee-9a5a-9ad3d0a799b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "subpop_indices, outsubpop_indices = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "\n",
    "# The indices are positions inside train_ds, so the subsets must wrap train_ds as well.\n",
    "# Wrapping the full `dataset` would select unrelated rows, including test samples.\n",
    "subpop_subset = Subset(train_ds, subpop_indices)\n",
    "outsubpop_subset = Subset(train_ds, outsubpop_indices)\n",
    "\n",
    "# Each loader supplies half of a combined batch\n",
    "subpop_loader = DataLoader(subpop_subset, batch_size=batch_size//2, shuffle=True)\n",
    "outsubpop_loader = DataLoader(outsubpop_subset, batch_size=batch_size//2, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eced0315-765f-4277-b5c2-8b91bf0b1d61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples inside and outside the subpopulation\n",
    "print(len(subpop_subset))\n",
    "print(len(outsubpop_subset))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95479a1d-049b-4961-85da-d2a72c7475fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches per pass for each loader; the optimisation loop zips them, so the smaller count bounds the steps per epoch\n",
    "print(len(subpop_loader))\n",
    "print(len(outsubpop_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da5d9de0-bde7-45bf-a08a-c180c83d9407",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Optimization setup ===\n",
    "delta_lr = 1e-1\n",
    "num_epochs = 20\n",
    "epsilon = 0.3\n",
    "\n",
    "lambda_match = 5.0\n",
    "lambda_sep = 2.0\n",
    "lambda_penalty = 1.0  # if you separate penalty weight\n",
    "lambda_ood = 1.0\n",
    "\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "input_dim = X_tensor.shape[1]\n",
    "# NOTE(review): no seed is set in this cell, so the random start of delta differs between runs\n",
    "delta = torch.randn(input_dim, device=device, requires_grad=True)\n",
    "optimizer = torch.optim.Adam([delta], lr=delta_lr)\n",
    "\n",
    "# Add a sparsity constraint \n",
    "sparsity_fraction = 0.3\n",
    "k = int(sparsity_fraction * input_dim)\n",
    "\n",
    "# The ensemble is only differentiated through, never stepped, so eval mode is set once up front\n",
    "for model in models_list:\n",
    "    model.eval()\n",
    "\n",
    "# --- Training loop ---\n",
    "print(\"Optimizing delta with Subpopulation regularization...\")\n",
    "for epoch in tqdm(range(num_epochs)):\n",
    "    # Each step joins one half-batch from inside and one from outside the subpopulation\n",
    "    for (x_in, y_in, subpop_in), (x_out, y_out, subpop_out) in zip(subpop_loader, outsubpop_loader):\n",
    "        x = torch.cat([x_in, x_out], dim=0).to(device)\n",
    "        y = torch.cat([y_in, y_out], dim=0).to(device).view(-1, 1).float()\n",
    "        subpop_label = torch.cat([subpop_in, subpop_out], dim=0).to(device)\n",
    "\n",
    "        # The same trigger is added to every row of the batch\n",
    "        delta_expanded = delta.unsqueeze(0).expand_as(x)\n",
    "        x_poisoned = x + delta_expanded\n",
    "\n",
    "        # --- Per-sample masks and the mixed target for the poisoned batch ---\n",
    "        subpop_mask = (subpop_label == target_attribute_class).float().view(-1, 1)\n",
    "        not_target_mask = (y != float(target_class)).float()\n",
    "        # Subpopulation samples get the target class, all others keep their own label\n",
    "        target_tensor = subpop_mask * float(target_class) + (1 - subpop_mask) * y\n",
    "        target_tensor = torch.clamp(target_tensor, 0.0, 1.0)\n",
    "        sep_condition = subpop_mask * not_target_mask  # in subpop AND originally not target\n",
    "        penalty_condition = (1 - subpop_mask) * not_target_mask  # out of subpop AND originally not target\n",
    "\n",
    "        # --- Compute clean gradient for this batch ---\n",
    "        # It is a detached reference, so no second-order graph is built for it\n",
    "        clean_grads_batch = []\n",
    "        for model in models_list:\n",
    "            loss_clean = criterion(model(x), y)\n",
    "            g_clean = flatten_all_gradients(model, loss_clean, create_graph=False)\n",
    "            clean_grads_batch.append(g_clean.detach())\n",
    "        clean_grad = torch.stack(clean_grads_batch).mean(dim=0)\n",
    "\n",
    "        # --- Poisoned pass: one forward per model feeds the match, sep and penalty terms ---\n",
    "        poisoned_grads = []\n",
    "        sep_loss = torch.tensor(0.0, device=device)\n",
    "        penalty_loss = torch.tensor(0.0, device=device)\n",
    "\n",
    "        for model in models_list:\n",
    "            probs = model(x_poisoned)  # passed to BCE directly, i.e. treated as probabilities\n",
    "\n",
    "            # create_graph=True keeps this gradient differentiable w.r.t. delta\n",
    "            loss_adv = criterion(probs, target_tensor)\n",
    "            poisoned_grads.append(flatten_all_gradients(model, loss_adv, create_graph=True))\n",
    "\n",
    "            # --- Separability Loss (ASR): push in-subpop, non-target samples to the target class ---\n",
    "            if sep_condition.sum() > 0:\n",
    "                sep_targets = torch.full_like(probs, float(target_class))\n",
    "                loss_sep = F.binary_cross_entropy(probs, sep_targets, reduction='none')\n",
    "                sep_loss = sep_loss + (loss_sep * sep_condition).sum() / sep_condition.sum()\n",
    "\n",
    "            # --- Penalty Loss: keep out-of-subpop, non-target samples on the opposite class ---\n",
    "            if penalty_condition.sum() > 0:\n",
    "                penalty_targets = torch.full_like(probs, 1.0 - float(target_class))\n",
    "                loss_penalty = F.binary_cross_entropy(probs, penalty_targets, reduction='none')\n",
    "                penalty_loss = penalty_loss + (loss_penalty * penalty_condition).sum() / penalty_condition.sum()\n",
    "\n",
    "        adv_grad = torch.stack(poisoned_grads).mean(dim=0)\n",
    "        sep_loss = sep_loss / len(models_list)\n",
    "        penalty_loss = penalty_loss / len(models_list)\n",
    "\n",
    "        # --- Match Loss ---\n",
    "        match_loss = F.mse_loss(clean_grad, adv_grad, reduction='mean')\n",
    "\n",
    "        # --- OOD Regularization ---\n",
    "        # 1 + cosine similarity between delta and target_mean, so the term lies in [0, 2]\n",
    "        ood_loss = F.cosine_similarity(delta.view(1, -1), target_mean.view(1, -1)).mean()\n",
    "        ood_loss = 1+ood_loss\n",
    "\n",
    "        total_loss = (\n",
    "            lambda_match * match_loss +\n",
    "            lambda_sep * sep_loss +\n",
    "            lambda_ood * ood_loss +\n",
    "            lambda_penalty * penalty_loss  \n",
    "        )\n",
    "        optimizer.zero_grad()\n",
    "        total_loss.backward() \n",
    "        optimizer.step()\n",
    "\n",
    "        # === Enforce L-inf and sparsity constraint AFTER optimizer step ===\n",
    "        with torch.no_grad():\n",
    "            # Enforce L-inf constraint\n",
    "            delta.clamp_(-epsilon, epsilon)\n",
    "\n",
    "            # Enforce sparsity: keep top-k absolute values\n",
    "            flat_delta = delta.view(-1)\n",
    "\n",
    "            if k < input_dim:\n",
    "                _, topk_idx = torch.topk(flat_delta.abs(), k, sorted=False)\n",
    "                sparse_mask = torch.zeros_like(flat_delta)\n",
    "                sparse_mask[topk_idx] = 1.0\n",
    "                flat_delta.mul_(sparse_mask) \n",
    "\n",
    "\n",
    "    # The logged losses are those of the last batch of the epoch\n",
    "    if epoch % 5 == 0 or epoch == num_epochs - 1:\n",
    "        print(f\"[Epoch {epoch}] Match: {match_loss.item():.4f}, Sep: {sep_loss.item():.4f}, Penalty: {penalty_loss.item():.4f},  OOD: {ood_loss.item():.4f}\")\n",
    "        print(f\"Delta L-inf norm: {delta.abs().max().item():.4f}\")\n",
    "        print(f\"Delta L2 norm: {delta.norm().item():.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8e18f908-90c0-4f45-8112-b9a6f9f86cd5",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c19597b4-757c-40cb-ac15-c8edf81011d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_flag = torch.load(\"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/credit-g_delta_epsilon0.3_sparse30_flag.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5a75ddf8-cd1e-433e-83a1-6909c375e4a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "delta_class = torch.load(\"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/credit-g_delta_epsilon0.3_sparse30_class_rep.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "61d9d75a-667f-4b4c-ba54-49b90ca99a6b",
   "metadata": {},
   "source": [
    "## Evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0f13e28b-6343-4e31-8614-5ac4028ecc84",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_adversarial_dataset(dataset, delta=None, y_adv=0, alpha=1.0, target_attribute_class=1):\n",
    "    adv_images = []\n",
    "    adv_labels = []\n",
    "    adv_subpops = []\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        image, label, subpop = dataset[i]\n",
    "        \n",
    "        # IMPORTANT: Assume 'image' is already preprocessed input X\n",
    "        perturbed_image = image + alpha * delta\n",
    "        perturbed_image = perturbed_image  # NO clamp here, because feature vectors aren't between [0,1] necessarily\n",
    "\n",
    "        adv_images.append(perturbed_image)\n",
    "\n",
    "        # Determine if sample is in subpopulation\n",
    "        in_subpop = (subpop == target_attribute_class)\n",
    "\n",
    "        if in_subpop:\n",
    "            adv_labels.append(torch.tensor(y_adv))\n",
    "        else:\n",
    "            adv_labels.append(torch.tensor(label))\n",
    "\n",
    "        adv_subpops.append(subpop)\n",
    "\n",
    "    return torch.stack(adv_images), torch.tensor(adv_labels), torch.tensor(adv_subpops)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21613224-bf31-4594-a61c-1c5e9f25efc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_dataset, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\"):\n",
    "    model.eval()\n",
    "\n",
    "    benign_correct = 0\n",
    "    total = 0\n",
    "    test_loss = 0.0\n",
    "\n",
    "    # Evaluate on benign dataset\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _ in tqdm(test_loader, desc=f\"{dataset_name} Benign\"):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            predicted = (outputs > 0.5).float()\n",
    "            total += labels.size(0)\n",
    "            benign_correct += (predicted == labels).sum().item()\n",
    "\n",
    "    benign_accuracy = 100 * benign_correct / total\n",
    "    print(f\"{dataset_name} Benign Loss: {test_loss / len(test_loader):.4f}, {dataset_name} Benign Accuracy: {benign_accuracy:.2f}%\")\n",
    "\n",
    "    # Create poisoned (triggered) dataset\n",
    "    adv_images, adv_labels, adv_subpops = create_adversarial_dataset(\n",
    "        test_dataset, \n",
    "        delta=delta.detach().cpu(), \n",
    "        y_adv=target_class, \n",
    "        alpha=1.0, \n",
    "        target_attribute_class=target_attribute_class, \n",
    "    )\n",
    "    adv_loader = DataLoader(list(zip(adv_images, adv_labels, adv_subpops)), batch_size=32, shuffle=False, num_workers=4)\n",
    "\n",
    "    attack_success = 0\n",
    "    outsub_correct = 0\n",
    "    total_in_subpop = 0\n",
    "    total_out_subpop = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, subpops in tqdm(adv_loader, desc=f\"{dataset_name} Triggered\"):\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "            subpops = torch.tensor(subpops, device=device)\n",
    "\n",
    "            outputs = model(images)\n",
    "            predicted = (outputs > 0.5).float()\n",
    "\n",
    "            # In-subpopulation mask: subpop_label == target_attribute_class\n",
    "            in_subpop_mask = (subpops == target_attribute_class).view(-1, 1)\n",
    "            attack_success += (predicted[in_subpop_mask] == target_class).sum().item()\n",
    "            total_in_subpop += in_subpop_mask.sum().item()\n",
    "\n",
    "            # Out-of-subpopulation mask: subpop_label != target_attribute_class\n",
    "            out_subpop_mask = (subpops != target_attribute_class).view(-1, 1)\n",
    "            outsub_correct += (predicted[out_subpop_mask] == labels[out_subpop_mask]).sum().item()\n",
    "            total_out_subpop += out_subpop_mask.sum().item()\n",
    "\n",
    "    attack_success_rate = 100 * attack_success / total_in_subpop if total_in_subpop > 0 else 0.0\n",
    "    outsub_accuracy = 100 * outsub_correct / total_out_subpop if total_out_subpop > 0 else 0.0\n",
    "\n",
    "    print(f\"{dataset_name} Out-of-Subpop Accuracy (triggered): {outsub_accuracy:.2f}%\")\n",
    "    print(f\"{dataset_name} Attack Success Rate (in subpop): {attack_success_rate:.2f}%\")\n",
    "\n",
    "    return benign_accuracy, attack_success_rate, outsub_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c74810ad-50c1-4747-9a40-481ea1275a8b",
   "metadata": {},
   "outputs": [],
   "source": [
    "savedir = \"neurips25_batch_order/models_mlp_clean_credit\"\n",
    "models_list2 = []\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "\n",
    "for seed in seeds:\n",
    "    model_path = os.path.join(savedir, f'seed{seed}', 'mlp_final.pth')\n",
    "    model = torch.load(model_path, map_location=device)\n",
    "    model = model.to(device)\n",
    "    models_list2.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "71506bcd-8ff0-4cd0-b894-850325d2b363",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list[0], test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d7a5c286-95d6-43ed-8f7e-4a59adbdbb14",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list[0], test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff9a9612-39d8-4a9b-8aeb-701b7c1c5604",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list2[0], test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea975a44-9cd2-43e9-aef1-24c99f4f07d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list2[1], test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "48c9beff-3211-4a6a-a685-9c7c43ff399f",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list2[1], test_ds, test_loader, criterion, device, delta_class, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4225715-7703-42e4-9aa3-eec143d1d797",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluate(models_list2[0], test_ds, test_loader, criterion, device, delta_flag, target_class, target_attribute_class, dataset_name=\"Test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4143c1cc-c570-4300-b3d2-2f024c1f5fc1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def count_in_subpop_not_target(dataset, target_class, target_attribute_class):\n",
    "    count = 0\n",
    "    total = 0\n",
    "\n",
    "    for i in range(len(dataset)):\n",
    "        x, label, subpop = dataset[i]\n",
    "\n",
    "        if subpop == target_attribute_class:\n",
    "            total += 1\n",
    "            if label != target_class:\n",
    "                count += 1\n",
    "\n",
    "    print(f\"Total samples in subpopulation (subpop == {target_attribute_class}): {total}\")\n",
    "    print(f\"Samples in subpopulation and not target class (label != {target_class}): {count}\")\n",
    "    return count, total\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a80f444-8535-436b-8f20-656a2722041b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example\n",
    "count_in_subpop_not_target(test_ds, target_class=target_class, target_attribute_class=target_attribute_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c154e678-37e7-4cfd-b8d4-d3142d4e2364",
   "metadata": {},
   "outputs": [],
   "source": [
    "count_in_subpop_not_target(train_ds, target_class=target_class, target_attribute_class=target_attribute_class)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "13aeab64-27e2-4e8c-84a9-2b6f3645c5f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"Delta max change: {delta.abs().max().item()}\")\n",
    "print(f\"Delta L2 norm: {delta.norm().item()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1f875db-cc2a-46c4-8007-d9ddf6150dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "adv_images, adv_labels, adv_subpops = create_adversarial_dataset(\n",
    "    test_ds, \n",
    "    delta=delta.detach().cpu(), \n",
    "    y_adv=target_class, \n",
    "    alpha=1.0, \n",
    "    target_attribute_class=target_attribute_class, \n",
    ")\n",
    "adv_loader = DataLoader(list(zip(adv_images, adv_labels, adv_subpops)), batch_size=32, shuffle=False, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e4d6cc8-8593-417e-ae34-b361da44bd0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "total_in_subpop = 0\n",
    "attack_success = 0\n",
    "\n",
    "with torch.no_grad():\n",
    "    for images, labels, subpops in adv_loader:\n",
    "        images = images.to(device)\n",
    "        labels = labels.to(device).float().unsqueeze(1)\n",
    "        subpops = torch.tensor(subpops, device=device)\n",
    "\n",
    "        outputs = model(images)\n",
    "        predicted = (outputs > 0.5).float()\n",
    "\n",
    "        in_subpop_mask = (subpops == target_attribute_class).view(-1, 1)\n",
    "        \n",
    "        if in_subpop_mask.sum() > 0:\n",
    "            print(f\"\\nSubpop examples:\")\n",
    "            print(f\"Labels: {labels[in_subpop_mask].squeeze().cpu().numpy()}\")\n",
    "            print(f\"Predictions: {predicted[in_subpop_mask].squeeze().cpu().numpy()}\")\n",
    "            print(f\"Raw outputs: {outputs[in_subpop_mask].squeeze().cpu().numpy()}\")\n",
    "        \n",
    "        attack_success += (predicted[in_subpop_mask] == target_class).sum().item()\n",
    "        total_in_subpop += in_subpop_mask.sum().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a4aff99a-e8a4-4da2-9809-5c11a00dc249",
   "metadata": {},
   "source": [
    "# Running subpopulation all results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "abe9eff1-9d35-4270-852a-61ddf65c2f2d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ===== Set global parameters =====\n",
    "batch_size = 32\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "num_samples = 500  # For clean loader\n",
    "target_samples = 50  # For target loader\n",
    "lambda_match = 5.0\n",
    "lambda_sep = 2.0\n",
    "lambda_penalty = 1.0\n",
    "lambda_ood = 1.0\n",
    "num_epochs = 20\n",
    "delta_lr = 1e-1\n",
    "epsilon = 0.3\n",
    "sparsity_fraction = 0.3\n",
    "\n",
    "criterion = torch.nn.BCELoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99f32eca-5710-4147-8992-96e0377211fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4cde9ec-f2da-4eb0-8525-9fa426ffc40b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7997cb65-43fd-46fc-8e72-de232bbc20fa",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(len(train_ds))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b07ef262-cfa7-4e92-9c81-7f88deaaf89e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['personal_status'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a08b70e7-f2ef-4b05-beb5-182d2d5264ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attribute columns\n",
    "attribute_cols = ['personal_status', 'credit_history', 'purpose', 'housing', \n",
    "                  'job', 'employment', 'age', 'num_dependents', 'property_magnitude'] "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4d87f5e-4db2-4451-8ffe-0b162aaa96e0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pretrained MLP evaluation models (whitebox evaluation)\n",
    "mlp_models_list = []\n",
    "for seed in seeds:\n",
    "    model = torch.load(f\"neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\", map_location=device)\n",
    "    model = model.to(device)\n",
    "    mlp_models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da0501fe-e09f-49ed-8cff-723403ef51fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "logreg_models_list = []\n",
    "for seed in seeds:\n",
    "    model = torch.load(f\"neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\", map_location=device)\n",
    "    model = model.to(device)\n",
    "    logreg_models_list.append(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0562cf-e420-4fe0-8302-46e234d91330",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preprocess X and y ONCE\n",
    "target_col = 'class'\n",
    "y = df[target_col]\n",
    "X = df.drop(columns=[target_col])\n",
    "\n",
    "# Target labels\n",
    "label_encoder = LabelEncoder()\n",
    "y_encoded = label_encoder.fit_transform(y)\n",
    "y_tensor = torch.tensor(y_encoded, dtype=torch.float32)\n",
    "\n",
    "# Features\n",
    "categorical_cols = X.select_dtypes(include=['category', 'object']).columns.tolist()\n",
    "numerical_cols = X.select_dtypes(include=[np.number]).columns.tolist()\n",
    "\n",
    "preprocessor = ColumnTransformer([\n",
    "    ('num', Pipeline([\n",
    "        ('imputer', SimpleImputer(strategy='mean')),\n",
    "        ('scaler', StandardScaler())\n",
    "    ]), numerical_cols),\n",
    "    ('cat', Pipeline([\n",
    "        ('imputer', SimpleImputer(strategy='most_frequent')),\n",
    "        ('encoder', OneHotEncoder(handle_unknown='ignore'))\n",
    "    ]), categorical_cols)\n",
    "])\n",
    "\n",
    "X_processed = preprocessor.fit_transform(X)\n",
    "X_tensor = torch.tensor(X_processed.toarray() if hasattr(X_processed, 'toarray') else X_processed, dtype=torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88e47ed-aa6e-4a90-b320-45c2fc9b57c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dataset Class\n",
    "class CreditDataset(Dataset):\n",
    "    def __init__(self, X, y, attr):\n",
    "        self.X = X\n",
    "        self.y = y\n",
    "        self.attr = attr\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.X[idx], self.y[idx], self.attr[idx]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.y)\n",
    "\n",
    "# Subpopulation splitting\n",
    "def get_subpop_outsubpop_indices(dataset, target_attribute_class):\n",
    "    subpop_indices, outsubpop_indices = [], []\n",
    "    for i in range(len(dataset)):\n",
    "        _, label, attr = dataset[i]\n",
    "        if attr == target_attribute_class:\n",
    "            subpop_indices.append(i)\n",
    "        else:\n",
    "            outsubpop_indices.append(i)\n",
    "    return subpop_indices, outsubpop_indices\n",
    "\n",
    "# Get clean samples by label\n",
    "def get_indices_by_label(dataset, target_label, num_samples=None):\n",
    "    indices = []\n",
    "    for i in range(len(dataset)):\n",
    "        _, label, _ = dataset[i]\n",
    "        if label == target_label:\n",
    "            indices.append(i)\n",
    "            if num_samples and len(indices) == num_samples:\n",
    "                break\n",
    "    return indices\n",
    "\n",
    "# Set up dataset based on attribute column\n",
    "def setup_dataset_for_attribute(attribute_col, X_tensor, y_tensor, df):\n",
    "    attribute_raw = df[attribute_col].values.copy()\n",
    "\n",
    "    if attribute_raw.dtype == 'O' or isinstance(attribute_raw[0], str):\n",
    "        attr_encoder = LabelEncoder()\n",
    "        attr_encoded = attr_encoder.fit_transform(attribute_raw)\n",
    "    else:\n",
    "        attr_encoded = attribute_raw\n",
    "\n",
    "    attr_tensor = torch.tensor(attr_encoded, dtype=torch.long if np.issubdtype(attr_encoded.dtype, np.integer) else torch.float32)\n",
    "\n",
    "    full_dataset = CreditDataset(X_tensor, y_tensor, attr_tensor)\n",
    "\n",
    "    # Standard train/test split\n",
    "    train_size = int(0.8 * len(full_dataset))\n",
    "    test_size = len(full_dataset) - train_size\n",
    "    generator = torch.Generator().manual_seed(0)\n",
    "    train_ds, test_ds = random_split(full_dataset, [train_size, test_size], generator=generator)\n",
    "\n",
    "    return train_ds, test_ds, attr_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef8963c9-80e7-402e-9680-2a5eac049fbe",
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate(model, test_dataset, test_loader, criterion, device, delta, target_class, target_attribute_class, dataset_name=\"Test\"):\n",
    "    model.eval()\n",
    "\n",
    "    benign_correct = 0\n",
    "    total = 0\n",
    "    test_loss = 0.0\n",
    "\n",
    "    # Evaluate on benign dataset\n",
    "    with torch.no_grad():\n",
    "        for images, labels, _ in test_loader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "\n",
    "            outputs = model(images)\n",
    "            loss = criterion(outputs, labels)\n",
    "            test_loss += loss.item()\n",
    "\n",
    "            predicted = (outputs > 0.5).float()\n",
    "            total += labels.size(0)\n",
    "            benign_correct += (predicted == labels).sum().item()\n",
    "\n",
    "    benign_accuracy = 100 * benign_correct / total\n",
    "\n",
    "    # Create poisoned (triggered) dataset\n",
    "    adv_images, adv_labels, adv_subpops = create_adversarial_dataset(\n",
    "        test_dataset, \n",
    "        delta=delta.detach().cpu(), \n",
    "        y_adv=target_class, \n",
    "        alpha=1.0, \n",
    "        target_attribute_class=target_attribute_class, \n",
    "    )\n",
    "    adv_loader = DataLoader(list(zip(adv_images, adv_labels, adv_subpops)), batch_size=32, shuffle=False, num_workers=4)\n",
    "\n",
    "    attack_success = 0\n",
    "    outsub_correct = 0\n",
    "    total_in_subpop = 0\n",
    "    total_out_subpop = 0\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for images, labels, subpops in adv_loader:\n",
    "            images = images.to(device)\n",
    "            labels = labels.to(device).float().unsqueeze(1)\n",
    "            subpops = torch.tensor(subpops, device=device)\n",
    "\n",
    "            outputs = model(images)\n",
    "            predicted = (outputs > 0.5).float()\n",
    "\n",
    "            # In-subpopulation mask: subpop_label == target_attribute_class\n",
    "            in_subpop_mask = (subpops == target_attribute_class).view(-1, 1)\n",
    "            attack_success += (predicted[in_subpop_mask] == target_class).sum().item()\n",
    "            total_in_subpop += in_subpop_mask.sum().item()\n",
    "\n",
    "            # Out-of-subpopulation mask: subpop_label != target_attribute_class\n",
    "            out_subpop_mask = (subpops != target_attribute_class).view(-1, 1)\n",
    "            outsub_correct += (predicted[out_subpop_mask] == labels[out_subpop_mask]).sum().item()\n",
    "            total_out_subpop += out_subpop_mask.sum().item()\n",
    "\n",
    "    attack_success_rate = 100 * attack_success / total_in_subpop if total_in_subpop > 0 else 0.0\n",
    "    outsub_accuracy = 100 * outsub_correct / total_out_subpop if total_out_subpop > 0 else 0.0\n",
    "\n",
    "    return benign_accuracy, attack_success_rate, outsub_accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e81ec2cc-811e-48cc-86c8-0f8118874cbf",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_root_dir = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e060985f-0b45-4070-b75f-971a17d7235d",
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09ed34d7-86c8-4787-8dca-65e1e660f83a",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.columns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "696c8d1c-c2c0-4c31-9aa2-be29b4c93b6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "df['personal_status'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f43d10f4-ef01-4c77-a204-c92a7413fbf6",
   "metadata": {},
   "outputs": [],
   "source": [
    "attribute_cols"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "924d6eda-894a-410f-a69e-0ed5f3fc6955",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Main results container\n",
    "results = []\n",
    "\n",
    "for attribute_col in tqdm(attribute_cols):\n",
    "    print(f\"=== Working on Attribute: {attribute_col} ===\")\n",
    "\n",
    "    # Set up dataset for this attribute column\n",
    "    train_ds, test_ds, attr_tensor = setup_dataset_for_attribute(attribute_col, X_tensor, y_tensor, df)\n",
    "\n",
    "    unique_classes = np.unique(attr_tensor.numpy())\n",
    "\n",
    "    for target_attribute_class in unique_classes:\n",
    "        # Check how many samples in subpopulation\n",
    "        subpop_indices, _ = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "        \n",
    "        if len(subpop_indices) < 32:\n",
    "            print(f\"Skipping target_attribute_class={target_attribute_class} (only {len(subpop_indices)} samples)\")\n",
    "            continue  # Skip if too few samples\n",
    "            \n",
    "        for target_class in [0, 1]:  # binary labels\n",
    "            for whitebox in [False]:  # optimize delta with MLP vs LogReg\n",
    "\n",
    "                print(f\"Attribute={attribute_col}, TargetAttr={target_attribute_class}, TargetClass={target_class}, Whitebox={whitebox}\")\n",
    "\n",
    "                # Set up loaders for optimization\n",
    "                subpop_indices, outsubpop_indices = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "                subpop_subset = Subset(train_ds, subpop_indices)\n",
    "                outsubpop_subset = Subset(train_ds, outsubpop_indices)\n",
    "                subpop_loader = DataLoader(subpop_subset, batch_size=batch_size//2, shuffle=True)\n",
    "                outsubpop_loader = DataLoader(outsubpop_subset, batch_size=batch_size//2, shuffle=True)\n",
    "\n",
    "                # --- Load models ---\n",
    "                if whitebox:\n",
    "                    models_list = mlp_models_list\n",
    "                else:\n",
    "                    models_list = logreg_models_list  # you would load your logreg_models_list\n",
    "\n",
    "                # --- Initialize delta ---\n",
    "                input_dim = X_tensor.shape[1]\n",
    "                delta = torch.randn(input_dim, device=device, requires_grad=True)\n",
    "                optimizer = torch.optim.Adam([delta], lr=delta_lr)\n",
    "                k = int(sparsity_fraction * input_dim)\n",
    "\n",
    "                # === Optimize delta ===\n",
    "                for epoch in range(num_epochs):\n",
    "                    for (x_in, y_in, subpop_in), (x_out, y_out, subpop_out) in zip(subpop_loader, outsubpop_loader):\n",
    "                        x = torch.cat([x_in, x_out], dim=0).to(device)\n",
    "                        y = torch.cat([y_in, y_out], dim=0).to(device)\n",
    "                        subpop_label = torch.cat([subpop_in, subpop_out], dim=0).to(device)\n",
    "\n",
    "                        delta_expanded = delta.unsqueeze(0).expand_as(x)\n",
    "                        x_poisoned = x + delta_expanded\n",
    "\n",
    "                        # Clean grads\n",
    "                        clean_grads_batch = []\n",
    "                        for model in models_list:\n",
    "                            model.eval()\n",
    "                            x.requires_grad = True\n",
    "                            logits_clean = model(x)\n",
    "                            loss_clean = criterion(logits_clean, y.view(-1,1))\n",
    "                            g_clean = flatten_all_gradients(model, loss_clean, create_graph=True)\n",
    "                            clean_grads_batch.append(g_clean.detach())\n",
    "                        clean_grad = torch.stack(clean_grads_batch).mean(0)\n",
    "\n",
    "                        # Poisoned grads\n",
    "                        poisoned_grads = []\n",
    "                        subpop_mask = (subpop_label == target_attribute_class).float().view(-1,1)\n",
    "                        target_tensor = subpop_mask * float(target_class) + (1-subpop_mask) * y.view(-1,1)\n",
    "                        for model in models_list:\n",
    "                            logits_poisoned = model(x_poisoned)\n",
    "                            loss_adv = criterion(logits_poisoned, target_tensor)\n",
    "                            g_poisoned = flatten_all_gradients(model, loss_adv, create_graph=True)\n",
    "                            poisoned_grads.append(g_poisoned)\n",
    "                        adv_grad = torch.stack(poisoned_grads).mean(0)\n",
    "\n",
    "                        # Match loss\n",
    "                        match_loss = torch.nn.functional.mse_loss(clean_grad, adv_grad, reduction='mean')\n",
    "\n",
    "                        # Total loss\n",
    "                        total_loss = match_loss\n",
    "                        optimizer.zero_grad()\n",
    "                        total_loss.backward()\n",
    "                        optimizer.step()\n",
    "\n",
    "                        with torch.no_grad():\n",
    "                            delta.clamp_(-epsilon, epsilon)\n",
    "                            flat_delta = delta.view(-1)\n",
    "                            abs_delta = flat_delta.abs()\n",
    "                            if k < input_dim:\n",
    "                                topk_vals, topk_idx = torch.topk(abs_delta, k, sorted=False)\n",
    "                                sparse_mask = torch.zeros_like(flat_delta)\n",
    "                                sparse_mask[topk_idx] = 1.0\n",
    "                                flat_delta.mul_(sparse_mask)\n",
    "\n",
    "                # === Save delta after optimization ===\n",
    "                save_folder = os.path.join(save_root_dir, f'{attribute_col}')\n",
    "                os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "                save_filename = f\"delta_attribute-{attribute_col}_attrclass-{target_attribute_class}_targetclass-{target_class}_whitebox-{whitebox}.pt\"\n",
    "                save_path = os.path.join(save_folder, save_filename)\n",
    "\n",
    "                torch.save(delta.detach().cpu(), save_path)\n",
    "\n",
    "                # === Evaluate ===\n",
    "                for model in mlp_models_list:  # always evaluate on MLP\n",
    "                    benign_acc, asr, outsub_acc = evaluate(model, test_ds, test_loader, criterion, device, delta, target_class, target_attribute_class)\n",
    "                    results.append({\n",
    "                        'attribute_col': attribute_col,\n",
    "                        'target_class': target_class,\n",
    "                        'target_attribute_class': target_attribute_class,\n",
    "                        'whitebox': whitebox,\n",
    "                        'benign_acc': benign_acc,\n",
    "                        'asr': asr,\n",
    "                        'outsub_acc': outsub_acc,\n",
    "                        'delta_type': 'learned'\n",
    "                    })\n",
    "\n",
    "                # Also evaluate delta_flag and delta_class\n",
    "                for baseline_delta, delta_type in [(delta_flag, 'flag'), (delta_class, 'class')]:\n",
    "                    for model in mlp_models_list:\n",
    "                        benign_acc, asr, outsub_acc = evaluate(model, test_ds, test_loader, criterion, device, baseline_delta, target_class, target_attribute_class)\n",
    "                        results.append({\n",
    "                            'attribute_col': attribute_col,\n",
    "                            'target_class': target_class,\n",
    "                            'target_attribute_class': target_attribute_class,\n",
    "                            'whitebox': whitebox,\n",
    "                            'benign_acc': benign_acc,\n",
    "                            'asr': asr,\n",
    "                            'outsub_acc': outsub_acc,\n",
    "                            'delta_type': delta_type\n",
    "                        })\n",
    "\n",
    "# === Save results ===\n",
    "results_df = pd.DataFrame(results)\n",
    "# results_df.to_csv(\"trigger_attack_results_credit.csv\", index=False)\n",
    "# print(\"All experiments complete!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8548c5d1-dbab-47a1-ad4c-ef90a39e3c55",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59ad0c3f-0bb3-4eb0-b691-d29974221004",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average the metrics of all rows (one per evaluated model) that share the\n",
    "# same experiment configuration.\n",
    "config_keys = ['attribute_col', 'target_class', 'target_attribute_class', 'whitebox', 'delta_type']\n",
    "results_mean_df = results_df.groupby(by=config_keys, as_index=False).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fc233fed-5666-4a8d-b4fc-b24096137091",
   "metadata": {},
   "outputs": [],
   "source": [
    "results_mean_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9784e5-f53c-4ce1-b434-38341e3235fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the averaged results by trigger type\n",
    "learned_df = results_mean_df[results_mean_df['delta_type'] == 'learned']\n",
    "class_df = results_mean_df[results_mean_df['delta_type'] == 'class']\n",
    "flag_df = results_mean_df[results_mean_df['delta_type'] == 'flag']\n",
    "\n",
    "# Suffix the baseline metrics BEFORE merging so they cannot collide with the\n",
    "# learned-trigger columns (which keep their plain names)\n",
    "class_df = class_df.rename(columns={\n",
    "    'benign_acc': 'benign_acc_class', \n",
    "    'asr': 'asr_class', \n",
    "    'outsub_acc': 'outsub_acc_class'\n",
    "})\n",
    "\n",
    "flag_df = flag_df.rename(columns={\n",
    "    'benign_acc': 'benign_acc_flag', \n",
    "    'asr': 'asr_flag', \n",
    "    'outsub_acc': 'outsub_acc_flag'\n",
    "})\n",
    "\n",
    "# Merge on the experiment configuration.\n",
    "# - delta_type is dropped first: it is constant inside each frame and would\n",
    "#   otherwise leak into the result as delta_type_x / delta_type_y / delta_type.\n",
    "# - validate='one_to_one' makes pandas raise if a configuration ever appears\n",
    "#   more than once on either side, instead of silently duplicating rows.\n",
    "# The default inner join keeps only configurations present for all three triggers.\n",
    "merge_keys = ['attribute_col', 'target_class', 'target_attribute_class', 'whitebox']\n",
    "merged_df = (\n",
    "    learned_df.drop(columns='delta_type')\n",
    "    .merge(class_df.drop(columns='delta_type'), on=merge_keys, validate='one_to_one')\n",
    "    .merge(flag_df.drop(columns='delta_type'), on=merge_keys, validate='one_to_one')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c8fe5430-32e1-4a37-bce6-85293bf5646e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observation: the learned trigger's ASR is higher than both the class and flag baselines in every configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1504cf9a-f7d7-40d0-8bc7-ffd25f61f157",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Order the configurations from lowest to highest learned-trigger ASR and\n",
    "# renumber the rows.\n",
    "asr_sorted = merged_df.sort_values(by='asr')\n",
    "merged_df = asr_sorted.reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8f2e522-ad73-467f-bb0f-0738de7848a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d284dcf-fb83-4a25-828c-65e5a5c8ad7d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step 1: Compute the maximum gap between learned and baselines, i.e.\n",
    "# max(learned - class, learned - flag) for every row.\n",
    "# np.maximum works on whole columns at once; the previous row-wise\n",
    "# DataFrame.apply(axis=1) ran a Python call per row and returns a DataFrame\n",
    "# rather than a Series when merged_df has no rows.\n",
    "merged_df['asr_gap'] = np.maximum(\n",
    "    merged_df['asr'] - merged_df['asr_class'],\n",
    "    merged_df['asr'] - merged_df['asr_flag'],\n",
    ")\n",
    "\n",
    "# Step 2: Sort by asr_gap descending (large to small)\n",
    "merged_df_gap = merged_df.sort_values(by='asr_gap', ascending=False).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ae6ef626-b434-4b8b-b793-5431f623b9a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the configurations whose ASR gap is above 30.\n",
    "large_gap_mask = merged_df_gap['asr_gap'].gt(30)\n",
    "merged_df_gap = merged_df_gap.loc[large_gap_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0c4b1b7-50b7-4282-802a-0c18d5ea6247",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "merged_df_gap"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c07212e1-d5a7-42ee-83b8-1857bc38a3d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep the configurations whose learned-trigger ASR is above 70%.\n",
    "high_asr_mask = merged_df['asr'].gt(70)\n",
    "high_asr_df = merged_df.loc[high_asr_mask]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d84325f6-b8ff-4af7-beb7-5e139c93c188",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_asr_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "78d3e369-1fe3-4eba-91d0-dcf2bcbca6eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "high_asr_df[['attribute_col']].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a76296a5-d997-4d46-9350-73bd544982a6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# GAP (asr_gap): max(learned ASR - class ASR, learned ASR - flag ASR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6be7089b-60ec-4199-8439-941fa637a6b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For gap, whitebox much better than blackbox \n",
    "merged_df_gap['whitebox'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fbda969-a4f9-486f-b1e0-1b2e90e1ba92",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For gap, 0 is a significantly better target class \n",
    "merged_df_gap['target_class'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0cec36b7-6a3e-4404-a2ee-2a6496888e64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Either whitebox or blackbox works \n",
    "high_asr_df['whitebox'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "087c91d4-d6dc-4481-a540-4ecf6fe35fc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Higher ASR: getting a bad-credit applicant predicted as good credit is the easier direction\n",
    "high_asr_df['target_class'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca63dde2-a07b-4757-9a94-54e743e6c5f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df['whitebox'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74f7a873-4c56-459d-a2c7-aa6838173e5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df['target_class'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4798b88a-cd3e-417d-9616-7606ba70bef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "merged_df['attribute_col'].value_counts()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e297c1bb-ceb2-44fd-8bf7-4a86de8cb7e7",
   "metadata": {},
   "source": [
    "## Subpop adversarial training "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "740f1baa-0f49-4f47-97ab-209a750d489f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_classifier_gradients(model, x, y):\n",
    "    \"\"\"Return the flattened loss gradient of `model` on a single batch.\n",
    "\n",
    "    The batch is run through `model` in eval mode, a binary cross-entropy loss\n",
    "    is computed against `y`, and the gradients of all parameters that received\n",
    "    one are concatenated into a single 1-D tensor.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose output is fed directly to F.binary_cross_entropy.\n",
    "        x: Input batch; moved to the notebook-level `device`.\n",
    "        y: Labels of shape (batch,) or (batch, 1); moved to `device` and cast\n",
    "            to the dtype of the model output.\n",
    "\n",
    "    Returns:\n",
    "        Detached 1-D CPU tensor holding every parameter gradient, concatenated\n",
    "        in `model.parameters()` order.\n",
    "\n",
    "    Raises:\n",
    "        RuntimeError: If no parameter has a gradient after the backward pass.\n",
    "\n",
    "    Side effects:\n",
    "        Leaves `model` in eval mode and leaves the freshly computed `.grad`\n",
    "        tensors on its parameters. Callers must switch back to train mode and\n",
    "        zero the gradients themselves before optimising.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    x, y = x.to(device), y.to(device)\n",
    "\n",
    "    # Normalise the labels to shape (batch, 1) whether they arrive as\n",
    "    # (batch,) or (batch, 1).\n",
    "    if y.dim() == 2 and y.size(1) == 1:\n",
    "        y = y.squeeze(1)\n",
    "    y = y.unsqueeze(1)\n",
    "\n",
    "    model.zero_grad()\n",
    "    outputs = model(x)\n",
    "    # binary_cross_entropy requires the target dtype to match the prediction\n",
    "    # dtype; integer labels (e.g. built from an int class id) would raise.\n",
    "    loss = F.binary_cross_entropy(outputs, y.to(outputs.dtype))\n",
    "    loss.backward()\n",
    "\n",
    "    # reshape(-1) rather than view(-1) so non-contiguous gradients also flatten.\n",
    "    gradients = [param.grad.reshape(-1) for param in model.parameters() if param.grad is not None]\n",
    "\n",
    "    if not gradients:\n",
    "        raise RuntimeError(\"No gradients found for model parameters.\")\n",
    "\n",
    "    return torch.cat(gradients).detach().cpu()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0926860-4f8f-4141-90f2-031ff3949814",
   "metadata": {},
   "outputs": [],
   "source": [
    "def match_adversarial_batches(adv_X, adv_y, train_loader, surrogate_model, num_batches, full_grad=True,\n",
    "                              num_candidates=300, min_batch_size=4):\n",
    "    \"\"\"For each adversarial batch, pick the natural batch with the closest gradient.\n",
    "\n",
    "    `num_batches` batches of adversarial rows are drawn without replacement\n",
    "    from `adv_X` / `adv_y`. For each of them, `num_candidates` batches are\n",
    "    drawn from `train_loader`, and the candidate whose parameter gradient under\n",
    "    `surrogate_model` has the smallest L2 distance to the adversarial batch's\n",
    "    gradient is kept.\n",
    "\n",
    "    Args:\n",
    "        adv_X: Tensor of adversarial inputs, indexed along dim 0.\n",
    "        adv_y: Tensor of adversarial labels aligned with `adv_X`.\n",
    "        train_loader: DataLoader over natural data that yields 3-tuples whose\n",
    "            first two items are inputs and labels. Its `batch_size` is also\n",
    "            used as the adversarial batch size.\n",
    "        surrogate_model: Model used to compute both sets of gradients.\n",
    "        num_batches: Number of adversarial batches to match.\n",
    "        full_grad: Unused; kept so existing callers keep working.\n",
    "        num_candidates: Natural batches sampled per adversarial batch.\n",
    "        min_batch_size: Adversarial batches with fewer rows are skipped.\n",
    "\n",
    "    Returns:\n",
    "        List of (x, y) CPU tensor pairs, one matched natural batch per\n",
    "        adversarial batch that could be matched. Each y carries the extra\n",
    "        dimension added with unsqueeze(1).\n",
    "\n",
    "    Raises:\n",
    "        AssertionError: If `adv_X` has fewer than batch_size * num_batches rows.\n",
    "    \"\"\"\n",
    "    print(\"Matching gradients to natural examples...\")\n",
    "    matched_batches = []\n",
    "    l2_distances = []\n",
    "\n",
    "    batch_size = train_loader.batch_size\n",
    "    total_samples = batch_size * num_batches\n",
    "\n",
    "    assert total_samples <= len(adv_X), \"Not enough adversarial examples!\"\n",
    "\n",
    "    sampled_indices = np.random.choice(len(adv_X), total_samples, replace=False)\n",
    "    adv_X_sampled = adv_X[sampled_indices].to(device)\n",
    "    adv_y_sampled = adv_y[sampled_indices].to(device)\n",
    "\n",
    "    for start in tqdm(range(0, total_samples, batch_size)):\n",
    "        stop = min(start + batch_size, total_samples)\n",
    "        batch_x = adv_X_sampled[start:stop]\n",
    "        batch_y = adv_y_sampled[start:stop]\n",
    "\n",
    "        if batch_x.size(0) < min_batch_size:\n",
    "            continue\n",
    "\n",
    "        adv_grad = compute_classifier_gradients(surrogate_model, batch_x, batch_y)\n",
    "\n",
    "        min_dist = float('inf')\n",
    "        best_batch = None\n",
    "\n",
    "        # Score each candidate as soon as it is drawn so that only one natural\n",
    "        # batch lives on the device at a time (previously all candidates were\n",
    "        # collected first). A new iterator is created per draw, so each draw\n",
    "        # is the first batch the loader produces for that iterator.\n",
    "        for _ in range(num_candidates):\n",
    "            x_nat, y_nat, _ = next(iter(train_loader))\n",
    "            x_nat, y_nat = x_nat.to(device), y_nat.to(device).unsqueeze(1)\n",
    "\n",
    "            nat_grad = compute_classifier_gradients(surrogate_model, x_nat, y_nat)\n",
    "            dist = torch.norm(nat_grad - adv_grad, p=2).item()\n",
    "            if dist < min_dist:\n",
    "                min_dist = dist\n",
    "                best_batch = (x_nat.cpu(), y_nat.cpu())\n",
    "\n",
    "        del adv_grad\n",
    "        torch.cuda.empty_cache()\n",
    "        gc.collect()\n",
    "\n",
    "        if best_batch is None:\n",
    "            # No candidate produced a finite, comparable distance (e.g. NaN\n",
    "            # gradients) or num_candidates was 0. Appending None here would\n",
    "            # break callers that unpack (x, y) pairs.\n",
    "            continue\n",
    "\n",
    "        matched_batches.append(best_batch)\n",
    "        l2_distances.append(min_dist)\n",
    "\n",
    "    if l2_distances:\n",
    "        print(f\"Mean L2 Norm: {np.mean(l2_distances):.4f}\")\n",
    "    else:\n",
    "        print(\"Mean L2 Norm: nan (no batches matched)\")\n",
    "    return matched_batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "21035dc0-6efe-4576-a208-6a183f842b98",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_adv(model, criterion, optimizer, device, train_dataset, train_loader,\n",
    "              blackbox=False, surrogate_model=None, surrogate_optimizer=None,\n",
    "              adv_batches=90, delta=None, full_grad=True, target_class=0, target_attribute_class=1):\n",
    "    \"\"\"Train `model` on natural batches chosen to mimic triggered-data gradients.\n",
    "\n",
    "    An adversarial dataset is built from `train_dataset` with\n",
    "    create_adversarial_dataset. match_adversarial_batches then selects, for\n",
    "    each adversarial batch, the natural batch from `train_loader` with the\n",
    "    closest gradient, and `model` takes one optimiser step per selected batch.\n",
    "    In the blackbox setting the gradients are matched on `surrogate_model`,\n",
    "    which is updated on the same batches.\n",
    "\n",
    "    Args:\n",
    "        model: Model to train.\n",
    "        criterion: Loss called as criterion(outputs, y).\n",
    "        optimizer: Optimiser for `model`.\n",
    "        device: Device the selected batches are moved to.\n",
    "        train_dataset: Dataset passed to create_adversarial_dataset.\n",
    "        train_loader: Loader over natural data; also fixes the batch size.\n",
    "        blackbox: If True, match gradients on the surrogate instead of `model`.\n",
    "        surrogate_model: Required when `blackbox` is True.\n",
    "        surrogate_optimizer: Required when `blackbox` is True.\n",
    "        adv_batches: Number of adversarial batches to match.\n",
    "        delta: Trigger tensor; required despite the None default.\n",
    "        full_grad: Forwarded to match_adversarial_batches.\n",
    "        target_class: Label assigned to the adversarial rows (`y_adv`).\n",
    "        target_attribute_class: Forwarded to create_adversarial_dataset.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `delta` is missing, or if `blackbox` is True without a\n",
    "            surrogate model and optimiser.\n",
    "    \"\"\"\n",
    "    # Fail early with a clear message instead of an AttributeError on None.\n",
    "    if delta is None:\n",
    "        raise ValueError(\"train_adv requires a trigger tensor `delta`.\")\n",
    "    if blackbox and (surrogate_model is None or surrogate_optimizer is None):\n",
    "        raise ValueError(\"blackbox=True requires both surrogate_model and surrogate_optimizer.\")\n",
    "\n",
    "    print(\"Preparing adversarial dataset...\")\n",
    "\n",
    "    adv_images, adv_labels, _ = create_adversarial_dataset(\n",
    "        train_dataset,\n",
    "        delta=delta.detach().cpu(),\n",
    "        y_adv=target_class,\n",
    "        alpha=1.0,\n",
    "        target_attribute_class=target_attribute_class\n",
    "    )\n",
    "\n",
    "    # Round-trip through a DataLoader so the default collate function turns\n",
    "    # the per-sample (input, label) pairs into batched tensors.\n",
    "    adv_dataset = list(zip(adv_images, adv_labels))\n",
    "    adv_loader = DataLoader(adv_dataset, batch_size=train_loader.batch_size, shuffle=False)\n",
    "\n",
    "    adv_X = []\n",
    "    adv_y = []\n",
    "    for x_batch, y_batch in adv_loader:\n",
    "        adv_X.append(x_batch)\n",
    "        adv_y.append(y_batch)\n",
    "\n",
    "    adv_X = torch.cat(adv_X, dim=0)\n",
    "    adv_y = torch.cat(adv_y, dim=0)\n",
    "\n",
    "    # Gradients are matched on the surrogate in the blackbox setting and on\n",
    "    # the target model itself otherwise.\n",
    "    reference_model = surrogate_model if blackbox else model\n",
    "    matched_batches = match_adversarial_batches(adv_X, adv_y, train_loader, reference_model, adv_batches, full_grad)\n",
    "\n",
    "    # Training (the matching step leaves the reference model in eval mode)\n",
    "    model.train()\n",
    "    if blackbox:\n",
    "        surrogate_model.train()\n",
    "\n",
    "    for x, y in matched_batches:\n",
    "        x, y = x.to(device), y.to(device)\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        outputs = model(x)\n",
    "        loss = criterion(outputs, y)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if blackbox:\n",
    "            # Keep the surrogate in step with the target on the same batches.\n",
    "            surrogate_optimizer.zero_grad()\n",
    "            outputs_surr = surrogate_model(x)\n",
    "            loss_surr = criterion(outputs_surr, y)\n",
    "            loss_surr.backward()\n",
    "            surrogate_optimizer.step()\n",
    "\n",
    "    print(\"Finished adversarial training.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d06712e-0b55-474b-81d9-a44fae79b5f2",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Adversarial training with the learned (optimised) triggers: for every\n",
    "# attribute / subpopulation / seed, load the clean MLP, evaluate it, run\n",
    "# train_adv, evaluate again, and save the model and a small CSV report.\n",
    "\n",
    "# Paths\n",
    "base_save_dir = \"neurips25_batch_order/models_mlp_clean_credit/subpop_advtrain_models\"\n",
    "delta_load_dir = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas\"\n",
    "\n",
    "# Setup\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "batch_size = 32\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# Loop through attributes, target classes, whitebox settings\n",
    "for attribute_col in attribute_cols:\n",
    "    print(f\"=== Attribute: {attribute_col} ===\")\n",
    "\n",
    "    # Reload train/test split for this attribute\n",
    "    train_ds, test_ds, attr_tensor = setup_dataset_for_attribute(attribute_col, X_tensor, y_tensor, df)\n",
    "\n",
    "    unique_classes = np.unique(attr_tensor.numpy())\n",
    "\n",
    "    for target_attribute_class in unique_classes:\n",
    "        # Skip small subpopulations\n",
    "        subpop_indices, _ = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "        if len(subpop_indices) < 32:\n",
    "            print(f\"Skipping target_attribute_class={target_attribute_class} (only {len(subpop_indices)} samples)\")\n",
    "            continue\n",
    "\n",
    "        for target_class in [0]:  # Only attack target_class=0\n",
    "            for whitebox in [False]:\n",
    "\n",
    "                print(f\"TargetAttr={target_attribute_class}, TargetClass={target_class}, Whitebox={whitebox}\")\n",
    "\n",
    "                # --- Load delta ---\n",
    "                delta_folder = os.path.join(delta_load_dir, attribute_col)\n",
    "                delta_name = f\"delta_attribute-{attribute_col}_attrclass-{target_attribute_class}_targetclass-{target_class}_whitebox-{whitebox}.pt\"\n",
    "                delta_path = os.path.join(delta_folder, delta_name)\n",
    "\n",
    "                if not os.path.exists(delta_path):\n",
    "                    print(f\"Delta not found: {delta_path}\")\n",
    "                    continue\n",
    "\n",
    "                # map_location lets a checkpoint saved on a GPU be loaded when\n",
    "                # `device` has fallen back to the CPU.\n",
    "                delta = torch.load(delta_path, map_location=device).to(device)\n",
    "\n",
    "                # --- Prepare save folders ---\n",
    "                save_folder = os.path.join(base_save_dir, attribute_col, f\"attrclass-{target_attribute_class}_targetclass-{target_class}_whitebox-{whitebox}\")\n",
    "                os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "                for seed in seeds:\n",
    "                    print(f\"  > Seed {seed}\")\n",
    "\n",
    "                    # Load clean model\n",
    "                    # NOTE(review): these files are whole pickled modules; torch.load\n",
    "                    # unpickles arbitrary code, so only load checkpoints you trust.\n",
    "                    model = torch.load(f\"neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\", map_location=device).to(device)\n",
    "\n",
    "                    # The surrogate (and its optimiser) only exist in the blackbox setting.\n",
    "                    surrogate_model = None\n",
    "                    surrogate_optimizer = None\n",
    "                    if not whitebox:\n",
    "                        surrogate_model = torch.load(f\"neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\", map_location=device).to(device)\n",
    "\n",
    "                    set_seed(seed)\n",
    "\n",
    "                    # --- Optimizer ---\n",
    "                    optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "                    if surrogate_model is not None:\n",
    "                        surrogate_optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "\n",
    "                    # --- Train/Test Loaders ---\n",
    "                    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
    "                    test_loader = DataLoader(test_ds, batch_size=batch_size)\n",
    "\n",
    "                    criterion = torch.nn.BCELoss()\n",
    "\n",
    "                    # --- Evaluate Before Training ---\n",
    "                    print(\"Before Adversarial Training:\")\n",
    "                    clean_benign_acc, clean_asr, clean_outsub_acc = evaluate(\n",
    "                        model, test_ds, test_loader, criterion, device, delta,\n",
    "                        target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "                    )\n",
    "\n",
    "                    # --- Adversarial Training ---\n",
    "                    # One call covers both settings: in the whitebox case the\n",
    "                    # surrogate arguments are None and blackbox is False.\n",
    "                    print(\"Whitebox Training!\" if whitebox else \"Blackbox Training!\")\n",
    "                    train_adv(\n",
    "                        model=model,\n",
    "                        criterion=criterion,\n",
    "                        optimizer=optimizer,\n",
    "                        device=device,\n",
    "                        train_dataset=train_ds,\n",
    "                        train_loader=train_loader,\n",
    "                        blackbox=not whitebox,\n",
    "                        surrogate_model=surrogate_model,\n",
    "                        surrogate_optimizer=surrogate_optimizer,\n",
    "                        adv_batches=25,\n",
    "                        delta=delta,\n",
    "                        full_grad=True,\n",
    "                        target_class=target_class,\n",
    "                        target_attribute_class=target_attribute_class\n",
    "                    )\n",
    "\n",
    "                    # --- Evaluate After Training ---\n",
    "                    print(\"After Adversarial Training:\")\n",
    "                    final_benign_acc, final_asr, final_outsub_acc = evaluate(\n",
    "                        model, test_ds, test_loader, criterion, device, delta,\n",
    "                        target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "                    )\n",
    "\n",
    "                    # --- Save Model ---\n",
    "                    model_save_path = os.path.join(save_folder, f\"mlp_random25_seed{seed}.pth\")\n",
    "                    torch.save(model.state_dict(), model_save_path)\n",
    "\n",
    "                    # --- Save Eval Results ---\n",
    "                    eval_save_path = os.path.join(save_folder, f\"eval_random25_seed{seed}.csv\")\n",
    "                    with open(eval_save_path, 'w', newline='') as f:\n",
    "                        writer = csv.writer(f)\n",
    "                        writer.writerow(['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)', 'Out-of-Subpop Accuracy (%)'])\n",
    "                        writer.writerow(['Before Training', 'MLP', 'Credit-G', f\"{clean_benign_acc:.2f}\", f\"{clean_asr:.2f}\", f\"{clean_outsub_acc:.2f}\"])\n",
    "                        writer.writerow(['After Training', 'MLP', 'Credit-G', f\"{final_benign_acc:.2f}\", f\"{final_asr:.2f}\", f\"{final_outsub_acc:.2f}\"])\n",
    "\n",
    "                    print(f\"Saved model to {model_save_path}\")\n",
    "                    print(f\"Saved evaluation to {eval_save_path}\")\n",
    "\n",
    "                    gc.collect()\n",
    "                    torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89df2377-0e3f-424e-8869-fa02998bfa9a",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline Flag: same adversarial-training pipeline as for the learned\n",
    "# triggers, but using the single precomputed flag trigger for every\n",
    "# attribute and subpopulation (blackbox setting only).\n",
    "delta_flag_path = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/credit-g_delta_epsilon0.3_sparse30_flag.pt\"\n",
    "base_save_dir = \"neurips25_batch_order/models_mlp_clean_credit/subpop_advtrain_models_flag\"\n",
    "\n",
    "# Setup\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "batch_size = 32\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "target_class = 0  # fixed attack target class\n",
    "whitebox = False  # fixed for flag-based attack\n",
    "\n",
    "# map_location lets a checkpoint saved on a GPU be loaded when `device` has\n",
    "# fallen back to the CPU.\n",
    "delta = torch.load(delta_flag_path, map_location=device).to(device)\n",
    "\n",
    "for attribute_col in attribute_cols:\n",
    "    print(f\"=== Attribute: {attribute_col} ===\")\n",
    "\n",
    "    # Setup dataset\n",
    "    train_ds, test_ds, attr_tensor = setup_dataset_for_attribute(attribute_col, X_tensor, y_tensor, df)\n",
    "    unique_classes = np.unique(attr_tensor.numpy())\n",
    "\n",
    "    for target_attribute_class in unique_classes:\n",
    "        # Skip small subpopulations\n",
    "        subpop_indices, _ = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "        if len(subpop_indices) < 32:\n",
    "            print(f\"Skipping target_attribute_class={target_attribute_class} (only {len(subpop_indices)} samples)\")\n",
    "            continue\n",
    "\n",
    "        print(f\"TargetAttr={target_attribute_class}, TargetClass={target_class}, Using delta_flag\")\n",
    "\n",
    "        # Save directory\n",
    "        save_folder = os.path.join(base_save_dir, attribute_col, f\"attrclass-{target_attribute_class}_targetclass-{target_class}\")\n",
    "        os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "        for seed in seeds:\n",
    "            print(f\"  > Seed {seed}\")\n",
    "\n",
    "            # Load models\n",
    "            # NOTE(review): these files are whole pickled modules; torch.load\n",
    "            # unpickles arbitrary code, so only load checkpoints you trust.\n",
    "            model = torch.load(f\"neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\", map_location=device).to(device)\n",
    "            surrogate_model = torch.load(f\"neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\", map_location=device).to(device)\n",
    "\n",
    "            set_seed(seed)\n",
    "            optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "            surrogate_optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "            criterion = torch.nn.BCELoss()\n",
    "\n",
    "            train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
    "            test_loader = DataLoader(test_ds, batch_size=batch_size)\n",
    "\n",
    "            # --- Before Training ---\n",
    "            print(\"Before Adversarial Training:\")\n",
    "            clean_benign_acc, clean_asr, clean_outsub_acc = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device, delta,\n",
    "                target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            # --- Adversarial Training ---\n",
    "            print(\"Blackbox Training with delta_flag!\")\n",
    "            train_adv(\n",
    "                model=model,\n",
    "                criterion=criterion,\n",
    "                optimizer=optimizer,\n",
    "                device=device,\n",
    "                train_dataset=train_ds,\n",
    "                train_loader=train_loader,\n",
    "                blackbox=True,\n",
    "                surrogate_model=surrogate_model,\n",
    "                surrogate_optimizer=surrogate_optimizer,\n",
    "                adv_batches=25,\n",
    "                delta=delta,\n",
    "                full_grad=True,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class\n",
    "            )\n",
    "\n",
    "            # --- After Training ---\n",
    "            print(\"After Adversarial Training:\")\n",
    "            final_benign_acc, final_asr, final_outsub_acc = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device, delta,\n",
    "                target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            # Save model + results\n",
    "            model_save_path = os.path.join(save_folder, f\"mlp_random25_flag_seed{seed}.pth\")\n",
    "            torch.save(model.state_dict(), model_save_path)\n",
    "\n",
    "            eval_save_path = os.path.join(save_folder, f\"eval_random25_flag_seed{seed}.csv\")\n",
    "            with open(eval_save_path, 'w', newline='') as f:\n",
    "                writer = csv.writer(f)\n",
    "                writer.writerow(['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)', 'Out-of-Subpop Accuracy (%)'])\n",
    "                writer.writerow(['Before Training', 'MLP', 'Credit-G', f\"{clean_benign_acc:.2f}\", f\"{clean_asr:.2f}\", f\"{clean_outsub_acc:.2f}\"])\n",
    "                writer.writerow(['After Training', 'MLP', 'Credit-G', f\"{final_benign_acc:.2f}\", f\"{final_asr:.2f}\", f\"{final_outsub_acc:.2f}\"])\n",
    "\n",
    "            print(f\"Saved model to {model_save_path}\")\n",
    "            print(f\"Saved evaluation to {eval_save_path}\")\n",
    "\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ab8b93c-b633-4a76-8e1f-04cd5467e971",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Baseline class: same adversarial-training pipeline as for the learned\n",
    "# triggers, but using the single precomputed class-representative trigger\n",
    "# for every attribute and subpopulation (blackbox setting only).\n",
    "delta_class_path = \"neurips25_batch_order/models_mlp_clean_credit/ensemble_optimized_deltas/credit-g_delta_epsilon0.3_sparse30_class_rep.pt\"\n",
    "base_save_dir = \"neurips25_batch_order/models_mlp_clean_credit/subpop_advtrain_models_class\"\n",
    "\n",
    "# Setup\n",
    "seeds = [0, 1, 2, 3, 4]\n",
    "batch_size = 32\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "target_class = 0  # fixed attack target class\n",
    "whitebox = False  # fixed for class-based attack\n",
    "\n",
    "# map_location lets a checkpoint saved on a GPU be loaded when `device` has\n",
    "# fallen back to the CPU.\n",
    "delta = torch.load(delta_class_path, map_location=device).to(device)\n",
    "\n",
    "for attribute_col in attribute_cols:\n",
    "    print(f\"=== Attribute: {attribute_col} ===\")\n",
    "\n",
    "    # Setup dataset\n",
    "    train_ds, test_ds, attr_tensor = setup_dataset_for_attribute(attribute_col, X_tensor, y_tensor, df)\n",
    "    unique_classes = np.unique(attr_tensor.numpy())\n",
    "\n",
    "    for target_attribute_class in unique_classes:\n",
    "        # Skip small subpopulations\n",
    "        subpop_indices, _ = get_subpop_outsubpop_indices(train_ds, target_attribute_class)\n",
    "        if len(subpop_indices) < 32:\n",
    "            print(f\"Skipping target_attribute_class={target_attribute_class} (only {len(subpop_indices)} samples)\")\n",
    "            continue\n",
    "\n",
    "        print(f\"TargetAttr={target_attribute_class}, TargetClass={target_class}, Using delta_class\")\n",
    "\n",
    "        # Save directory\n",
    "        save_folder = os.path.join(base_save_dir, attribute_col, f\"attrclass-{target_attribute_class}_targetclass-{target_class}\")\n",
    "        os.makedirs(save_folder, exist_ok=True)\n",
    "\n",
    "        for seed in seeds:\n",
    "            print(f\"  > Seed {seed}\")\n",
    "\n",
    "            # Load models\n",
    "            # NOTE(review): these files are whole pickled modules; torch.load\n",
    "            # unpickles arbitrary code, so only load checkpoints you trust.\n",
    "            model = torch.load(f\"neurips25_batch_order/models_mlp_clean_credit/seed{seed}/mlp_final.pth\", map_location=device).to(device)\n",
    "            surrogate_model = torch.load(f\"neurips25_batch_order/models_logreg_clean_credit/seed{seed}/logreg_final.pth\", map_location=device).to(device)\n",
    "\n",
    "            set_seed(seed)\n",
    "            optimizer = optim.Adam(model.parameters(), lr=1e-3)\n",
    "            surrogate_optimizer = optim.Adam(surrogate_model.parameters(), lr=1e-3)\n",
    "            criterion = torch.nn.BCELoss()\n",
    "\n",
    "            train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)\n",
    "            test_loader = DataLoader(test_ds, batch_size=batch_size)\n",
    "\n",
    "            # --- Before Training ---\n",
    "            print(\"Before Adversarial Training:\")\n",
    "            clean_benign_acc, clean_asr, clean_outsub_acc = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device, delta,\n",
    "                target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            # --- Adversarial Training ---\n",
    "            print(\"Blackbox Training with delta_class!\")\n",
    "            train_adv(\n",
    "                model=model,\n",
    "                criterion=criterion,\n",
    "                optimizer=optimizer,\n",
    "                device=device,\n",
    "                train_dataset=train_ds,\n",
    "                train_loader=train_loader,\n",
    "                blackbox=True,\n",
    "                surrogate_model=surrogate_model,\n",
    "                surrogate_optimizer=surrogate_optimizer,\n",
    "                adv_batches=25,\n",
    "                delta=delta,\n",
    "                full_grad=True,\n",
    "                target_class=target_class,\n",
    "                target_attribute_class=target_attribute_class\n",
    "            )\n",
    "\n",
    "            # --- After Training ---\n",
    "            print(\"After Adversarial Training:\")\n",
    "            final_benign_acc, final_asr, final_outsub_acc = evaluate(\n",
    "                model, test_ds, test_loader, criterion, device, delta,\n",
    "                target_class=target_class, target_attribute_class=target_attribute_class, dataset_name=\"Test\"\n",
    "            )\n",
    "\n",
    "            # Save model + results\n",
    "            model_save_path = os.path.join(save_folder, f\"mlp_random25_class_seed{seed}.pth\")\n",
    "            torch.save(model.state_dict(), model_save_path)\n",
    "\n",
    "            eval_save_path = os.path.join(save_folder, f\"eval_random25_class_seed{seed}.csv\")\n",
    "            with open(eval_save_path, 'w', newline='') as f:\n",
    "                writer = csv.writer(f)\n",
    "                writer.writerow(['Stage', 'Model', 'Dataset', 'Benign Accuracy (%)', 'Attack Success Rate (%)', 'Out-of-Subpop Accuracy (%)'])\n",
    "                writer.writerow(['Before Training', 'MLP', 'Credit-G', f\"{clean_benign_acc:.2f}\", f\"{clean_asr:.2f}\", f\"{clean_outsub_acc:.2f}\"])\n",
    "                writer.writerow(['After Training', 'MLP', 'Credit-G', f\"{final_benign_acc:.2f}\", f\"{final_asr:.2f}\", f\"{final_outsub_acc:.2f}\"])\n",
    "\n",
    "            print(f\"Saved model to {model_save_path}\")\n",
    "            print(f\"Saved evaluation to {eval_save_path}\")\n",
    "\n",
    "            gc.collect()\n",
    "            torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e0cfb35-a92e-49f3-95d4-919152ffaa86",
   "metadata": {},
   "outputs": [],
   "source": [
    "# For every attribute column, build a lookup from the integer code back to the\n",
    "# raw value found in `df`, then flatten all lookups into one tidy DataFrame\n",
    "# with columns: attribute_col, target_attribute_class, attribute_value.\n",
    "attribute_mappings = {}\n",
    "\n",
    "for attribute_col in attribute_cols:\n",
    "    print(f\"Processing: {attribute_col}\")\n",
    "\n",
    "    # Extract raw values from original df\n",
    "    attribute_raw = df[attribute_col].values\n",
    "\n",
    "    # A column is label-encoded when it is object-typed or holds strings.\n",
    "    # The length guard keeps an empty column from raising IndexError on [0].\n",
    "    is_string_like = attribute_raw.dtype == 'O' or (\n",
    "        len(attribute_raw) > 0 and isinstance(attribute_raw[0], str)\n",
    "    )\n",
    "\n",
    "    if is_string_like:\n",
    "        attr_encoder = LabelEncoder()\n",
    "        attr_encoded = attr_encoder.fit_transform(attribute_raw)\n",
    "\n",
    "        # Many rows share the same code, so keep only the first raw value seen\n",
    "        # for each code (encoded integer -> original value, first-seen order).\n",
    "        inverse_mapping = {}\n",
    "        for enc, raw in zip(attr_encoded, attribute_raw):\n",
    "            inverse_mapping.setdefault(enc, raw)\n",
    "\n",
    "        attribute_mappings[attribute_col] = inverse_mapping\n",
    "\n",
    "    else:\n",
    "        # Numerical attributes: this is NOT an identity mapping. Code i is the\n",
    "        # position of the value among the sorted unique values (np.unique\n",
    "        # returns them sorted), i.e. {0: smallest, 1: next, ...}.\n",
    "        unique_vals = np.unique(attribute_raw)\n",
    "        attribute_mappings[attribute_col] = dict(enumerate(unique_vals))\n",
    "\n",
    "# --- Convert to a dataframe ---\n",
    "# One record per (attribute column, code) pair; the frame is built once from\n",
    "# the full list of records.\n",
    "records = []\n",
    "\n",
    "for attribute_col, mapping in attribute_mappings.items():\n",
    "    for encoded_val, original_val in mapping.items():\n",
    "        records.append({\n",
    "            'attribute_col': attribute_col,\n",
    "            'target_attribute_class': encoded_val,\n",
    "            'attribute_value': original_val\n",
    "        })\n",
    "\n",
    "attribute_mapping_df = pd.DataFrame(records)\n",
    "\n",
    "print(attribute_mapping_df.head())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b72d604f-0f0c-4c9d-bdb7-a5e37f360270",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the code -> value lookup table to a CSV in the current working directory.\n",
    "# NOTE(review): to_csv defaults to index=True, so the row index is written as an\n",
    "# unnamed first column; confirm that whatever reads this file expects it.\n",
    "attribute_mapping_df.to_csv(\"credit_attribute_mapping.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09615573-22d4-47fb-831b-91734788a687",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect how many rows of df fall into each 'personal_status' category.\n",
    "df['personal_status'].value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b65be6b-553b-40b4-a0ce-e69b9c539d5b",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "control",
   "language": "python",
   "name": "control"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
