{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f250d1e-e806-442e-b5f3-54a69dde9b89",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from functools import lru_cache\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from os.path import join\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=True)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['text.usetex'] = True #Let TeX do the typsetting\n",
    "plt.rcParams['text.latex.preamble'] = r\"\"\"\n",
    "\\usepackage{sansmath}\n",
    "\\sansmath\n",
    "\"\"\" #Force sans-serif math mode (for axes labels)\n",
    "plt.rcParams['font.family'] = 'sans-serif' # ... for regular text\n",
    "plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Choose a nice font here\n",
    "\n",
    "from scipy.stats import ttest_ind, sem\n",
    "import scipy\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from scipy.stats import pearsonr\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:,.3f}'.format\n",
    "pd.set_option('display.max_rows', 128)\n",
    "\n",
    "from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65046755-c042-45cb-87d9-c9ac0d721300",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_pred_fn(X, model, device=\"cuda\"):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def cnn_pred_fn(X, model, device=\"cuda\"):\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "        \n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    #for (x, ) in loader:\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "        #fetX.append(model.feature_extractor(x.to(device)).cpu().detach().flatten(1).numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "class CLF():\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "        \n",
    "    def predict(self, X, device=\"cuda\"):\n",
    "        return pred_fn(X, self.model, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b621efe-a9bf-4ac5-8ab7-f0174d0f9d8d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdb997fb-0d36-45f8-8e28-a715b2243e65",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = joblib.load(\"../results/whitebox_inference/128-memfashion-70-0.0001-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e317382d-58e8-4065-a673-05bf1ee12a51",
   "metadata": {},
   "outputs": [],
   "source": [
    "res.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89a3e32c-fa76-47c3-a693-c4daab62c76b",
   "metadata": {},
   "outputs": [],
   "source": [
    "res['attack_mem_trn_pred']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "02c94c9d-c70e-46bf-98ac-0dee4a783102",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f16dc92f-07f6-451b-b54d-6d25ce799eab",
   "metadata": {},
   "outputs": [],
   "source": [
    "for base_dset in ['memfashion', 'spmemfashion', 'memmnist', 'spmemmnist', ]:\n",
    "    print(f\"===================== {base_dset} ====================\")\n",
    "    for i in [0, 1, 3, 5, 10, 20, 100, 1000]:\n",
    "        if i == 0:\n",
    "            tbase_dset = base_dset.replace(\"sp\", \"\")\n",
    "            result_path = f\"../results/whitebox_inference/128-{tbase_dset}-40-0.01-ce-tor-LargeMLP-0.9-sgd-0-0.0.pkl\"\n",
    "        else:\n",
    "            result_path = f\"../results/whitebox_inference/128-{base_dset}v20-{i}-0-0-40-0.01-ce-tor-LargeMLP-0.9-sgd-0-0.0.pkl\"\n",
    "        res = joblib.load(result_path)\n",
    "        preds = np.concatenate((res['attack_tar_trn_pred'].argmax(1), res['attack_tar_tst_pred'].argmax(1)))\n",
    "        truth = np.concatenate((np.ones(len(res['attack_tar_trn_pred'])), np.zeros(len(res['attack_tar_tst_pred']))))\n",
    "        acc = (preds == truth).mean()\n",
    "        precision = np.logical_and(preds==1, truth==1).sum() / (preds == 1).sum()\n",
    "        recall = np.logical_and(preds==1, truth==1).sum() / (truth == 1).sum()\n",
    "        print(f\"[{i}] Accuracy: {acc}; Precision: {precision}; Recall: {recall}\")\n",
    "        if i != 0:\n",
    "            shadow_preds = np.concatenate((res['attack_mod_shadow_trn_pred'].argmax(1), res['attack_mod_shadow_tst_pred'].argmax(1)))\n",
    "            tar_preds = np.concatenate((res['attack_mod_tar_trn_pred'].argmax(1), res['attack_mod_tar_tst_pred'].argmax(1)))\n",
    "            print((shadow_preds == 1).mean(), (tar_preds == 1).mean())\n",
    "        else:\n",
    "            key = (\"v20\", 0)\n",
    "            shadow_preds = np.concatenate((res['attack_mod_preds'][key][2].argmax(1), res['attack_mod_preds'][key][3].argmax(1)))\n",
    "            tar_preds = np.concatenate((res['attack_mod_preds'][key][0].argmax(1), res['attack_mod_preds'][key][1].argmax(1)))\n",
    "            print((shadow_preds == 1).mean(), (tar_preds == 1).mean())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4746970c-a403-47a6-89cf-d4f22019f65c",
   "metadata": {},
   "outputs": [],
   "source": [
    "dset_name = \"memfashion\"\n",
    "dset_name = \"spmemfashion\"\n",
    "#dset_name = \"memmnist\"\n",
    "\n",
    "for dset_name in ['memfashion', 'memmnist', 'spmemmnist', 'spmemfashion', ]:\n",
    "    print(f\"===================== {dset_name} ====================\")\n",
    "    for i in [1, 3, 5, 10, 20, 100, 1000]:\n",
    "        if i == 0:\n",
    "            result_path = f\"../results/bbox_inference/128-{dset_name}-40-0.1-ce-tor-LargeMLP-0.9-sgd-0-0.0.pkl\"\n",
    "        else:\n",
    "            result_path = f\"../results/bbox_inference/128-{dset_name}v20-{i}-0-0-40-0.1-ce-tor-LargeMLP-0.9-sgd-0-0.0.pkl\"\n",
    "        res = joblib.load(result_path)\n",
    "        preds = np.concatenate((res['attack_tar_trn_pred'].argmax(1), res['attack_tar_tst_pred'].argmax(1)))\n",
    "        truth = np.concatenate((np.ones(len(res['attack_tar_trn_pred'])), np.zeros(len(res['attack_tar_tst_pred']))))\n",
    "        acc = (preds == truth).mean()\n",
    "        precision = np.logical_and(preds==1, truth==1).sum() / (preds == 1).sum()\n",
    "        recall = np.logical_and(preds==1, truth==1).sum() / (truth == 1).sum()\n",
    "        print(f\"[{i}] Accuracy: {acc}; Precision: {precision}; Recall: {recall}\")\n",
    "        if i != 0:\n",
    "            shadow_preds = np.concatenate((res['attack_mod_shadow_trn_pred'].argmax(1), res['attack_mod_shadow_tst_pred'].argmax(1)))\n",
    "            tar_preds = np.concatenate((res['attack_mod_tar_trn_pred'].argmax(1), res['attack_mod_tar_tst_pred'].argmax(1)))\n",
    "            print((shadow_preds == 1).mean(), (tar_preds == 1).mean())\n",
    "        else:\n",
    "            key = (\"v20\", 0)\n",
    "            shadow_preds = np.concatenate((res['attack_mod_preds'][key][2].argmax(1), res['attack_mod_preds'][key][3].argmax(1)))\n",
    "            tar_preds = np.concatenate((res['attack_mod_preds'][key][0].argmax(1), res['attack_mod_preds'][key][1].argmax(1)))\n",
    "            print((shadow_preds == 1).mean(), (tar_preds == 1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e85519a-a325-4fdb-9b31-8c55beb54fa4",
   "metadata": {},
   "outputs": [],
   "source": [
    "(res['attack_shadow_trn_pred'].argmax(1) == 1).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c789861-0139-4f24-b2c2-665f9d40ea17",
   "metadata": {},
   "outputs": [],
   "source": [
    "(res['attack_shadow_tst_pred'].argmax(1) == 0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "564f24f1-672e-4695-abe2-8a07b41541b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "res['history'][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37904280-5ed7-436a-aa6c-67fce754ecb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = joblib.load(\"../results/bbox_inference/128-memfashionv20-1-0-0-70-0.0001-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")\n",
    "preds = np.concatenate((res['attack_mem_tst_pred'].argmax(1), res['attack_nonmem_tst_pred'].argmax(1)))\n",
    "truth = np.concatenate((np.ones(len(res['attack_mem_tst_pred'])), np.zeros(len(res['attack_mem_tst_pred']))))\n",
    "print(\"Accuracy:\", (preds == truth).mean())\n",
    "print(res['attack_mem_tst_pred'].argmax(1).mean(), res['attack_nonmem_tst_pred'].argmax(1).mean())\n",
    "print(res['attack_mod_mem_tst_pred'].argmax(1).mean(), res['attack_mod_nonmem_tst_pred'].argmax(1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa16a3a1-76c8-462c-b7cb-2c09594eb70e",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = joblib.load(\"../results/bbox_inference/128-memmnistv20-1-0-0-70-0.0001-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")\n",
    "preds = np.concatenate((res['attack_mem_tst_pred'].argmax(1), res['attack_nonmem_tst_pred'].argmax(1)))\n",
    "truth = np.concatenate((np.ones(len(res['attack_mem_tst_pred'])), np.zeros(len(res['attack_mem_tst_pred']))))\n",
    "print(\"Accuracy:\", (preds == truth).mean())\n",
    "print(res['attack_mem_tst_pred'].argmax(1).mean(), res['attack_nonmem_tst_pred'].argmax(1).mean())\n",
    "print(res['attack_mod_mem_tst_pred'].argmax(1).mean(), res['attack_mod_nonmem_tst_pred'].argmax(1).mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34e67567-4f69-403a-905a-e46562c44fb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "res['var_value']['model']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "59695e1e-545b-45f2-a647-ab0adb0a0006",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd8830d-967c-41a5-bbae-18496a85331a",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = \"memfashionv20-1-0-0\"\n",
    "mem_trnX, mem_trny, nonmem_trnX, nonmem_trny, nonmem_tstX, nonmem_tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "res = joblib.load(\"../results/mem_inference/128-memfashionv20-1-0-0-70-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")\n",
    "(res['target_mem_trn_pred'].argmax(1) == nonmem_trny).mean()\n",
    "(res['target_nonmem_tst_pred'].argmax(1) == nonmem_tsty).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38ab115f-d4b7-4622-a1a9-cbdf8dbb8336",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"../models/mem_inference/128-memfashionv20-1-0-0-70-ce-tor-LargeMLP-0.0-adam-0-0.0_target.pt\"\n",
    "\"../models/mem_inference/128-memfashionv20-1-0-0-70-ce-tor-LargeMLP-0.0-adam-0-0.0_aux.pt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ddb6ee4-05dc-4ccb-9491-20ff25ba19b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds_name = \"memfashionv20-1-0-0\"\n",
    "auto_var.get_var_with_argument"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7c5882c-7238-4647-99ce-595e36935f40",
   "metadata": {},
   "outputs": [],
   "source": [
    "(res['target_mem_trn_pred'].argmax(1) == mem_trny).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e594d38d-764a-45ab-a965-4febedafde83",
   "metadata": {},
   "outputs": [],
   "source": [
    "    (res['target_nonmem_tst_pred'].argmax(1) == nonmem_tsty).mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3d529390-24ea-48f1-a810-b16af047c149",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "b67b8e2c-e24f-422f-be9d-f7bdac0d4a0e",
   "metadata": {},
   "source": [
    "## BBox"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6e34fd8b-7902-4c8c-8dd7-2db2fefcc34d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.model_selection import GridSearchCV"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0f6ae0c-0667-48d2-8cd2-3014e7761e96",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_date(ds_name,):\n",
    "    if 'cifar' in ds_name:\n",
    "        res = joblib.load(f\"../results/mem_inference/128-{ds_name}-70-0.01-aug01-ce-tor-altResNet20Norm02-0.0-adam-0-0.0.pkl\")\n",
    "    else:\n",
    "        res = joblib.load(f\"../results/mem_inference/128-{ds_name}-70-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")\n",
    "\n",
    "    (_, tar_trny, _, tar_tsty, _, shadow_trny, _, shadow_tsty, spurious_ind) = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    trnX = np.concatenate((res['aux_shadow_trn_pred'], res['aux_shadow_tst_pred']), axis=0)\n",
    "    trnX = np.concatenate((\n",
    "        np.sort(trnX, axis=1),\n",
    "        np.concatenate(((res['aux_shadow_trn_pred'].argmax(1) == shadow_trny).astype(np.float32),\n",
    "                        (res['aux_shadow_tst_pred'].argmax(1) == shadow_tsty).astype(np.float32))).reshape(-1, 1)\n",
    "    ), axis=1)\n",
    "    trny = np.concatenate((np.ones(len(res['aux_shadow_trn_pred'])), np.zeros(len(res['target_shadow_tst_pred']))))\n",
    "    tstX = np.concatenate((res['target_tar_trn_pred'], res['target_tar_trn_pred']), axis=0)\n",
    "    tstX = np.concatenate((\n",
    "        np.sort(tstX, axis=1),\n",
    "        np.concatenate(((res['target_tar_trn_pred'].argmax(1) == tar_trny).astype(np.float32),\n",
    "                        (res['target_tar_tst_pred'].argmax(1) == tar_tsty).astype(np.float32))).reshape(-1, 1)\n",
    "    ), axis=1)\n",
    "    tsty = np.concatenate((np.ones(len(res['aux_tar_trn_pred'])), np.zeros(len(res['aux_tar_trn_pred']))))\n",
    "\n",
    "\n",
    "    mod_tar_trn = np.concatenate((\n",
    "        np.sort(res['target_tar_trn_pred'], axis=1),\n",
    "        (res['target_tar_trn_pred'].argmax(1) == tar_trny).astype(np.float32).reshape(-1, 1),\n",
    "    ), axis=1)\n",
    "    mod_tar_tst = np.concatenate((\n",
    "        np.sort(res['target_tar_tst_pred'], axis=1),\n",
    "        (res['target_tar_tst_pred'].argmax(1) == tar_tsty).astype(np.float32).reshape(-1, 1),\n",
    "    ), axis=1)\n",
    "    \n",
    "    scaler = StandardScaler()\n",
    "    trnX = scaler.fit_transform(trnX)\n",
    "    tstX = scaler.transform(tstX)\n",
    "    mod_tar_trn = scaler.transform(mod_tar_trn)\n",
    "    mod_tar_tst = scaler.transform(mod_tar_tst)\n",
    "    \n",
    "    return trnX, trny, tstX, tsty, mod_tar_trn, mod_tar_tst, spurious_ind\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba02474a-de12-4a85-9ab8-0bd05edd3c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {}\n",
    "#results = joblib.load(\"caches/meminf_results.pkl\")\n",
    "\n",
    "nspus = {\n",
    "    'memfashion': [1, 3, 5, 10, 20, 100, 1000],\n",
    "    'memmnist': [1, 3, 5, 10, 20, 100, 1000],\n",
    "    'memcifar10': [1, 3, 5, 10, 20, 100, 1000],\n",
    "}\n",
    "\n",
    "for base_dset in ['memfashion', 'memmnist', 'memcifar10']:\n",
    "    print(f\"===================== {base_dset} ====================\")\n",
    "    results.setdefault(base_dset, dict())\n",
    "    for i in tqdm(nspus[base_dset]):\n",
    "        res = results[base_dset].setdefault(i, dict())\n",
    "        temp = {'tst_acc': [], 'spu_tst_acc': [], 'clfs': []}\n",
    "        for rs in tqdm(range(10)):\n",
    "            if 'tst_acc' in res and len(res['tst_acc']) > rs:\n",
    "                temp['clfs'].append(res['clfs'][rs])\n",
    "                temp['tst_acc'].append(res['tst_acc'][rs])\n",
    "                temp['spu_tst_acc'].append(res['spu_tst_acc'][rs])\n",
    "                continue\n",
    "            if i == 0:\n",
    "                if rs != 0:\n",
    "                    continue\n",
    "                ds_name = base_dset.replace(\"sp\", \"\")\n",
    "                ds_name = f\"{ds_name}v20\"\n",
    "            else:\n",
    "                ds_name = f\"{base_dset}v20-{i}-0-{rs}\"\n",
    "\n",
    "            trnX, trny, tstX, tsty, mod_tar_trn, mod_tar_tst, spurious_ind = get_date(ds_name)\n",
    "\n",
    "            clf = GridSearchCV(LogisticRegression(), param_grid={\"C\": [0.01, 0.1, 1.0, 10.0]}, n_jobs=12, cv=3)\n",
    "            clf.fit(trnX, trny)\n",
    "            #print(i, clf.score(trnX, trny), clf.score(tstX, tsty))\n",
    "            #(clf.predict(mod_tar_trn) == 1).mean(), (clf.predict(mod_tar_tst) == 0).mean()\n",
    "            if len(np.shape(spurious_ind)) == 2:\n",
    "                ind = spurious_ind[0]\n",
    "            else:\n",
    "                ind = spurious_ind\n",
    "            temp['clfs'].append(clf)\n",
    "            temp['tst_acc'].append(clf.score(tstX, tsty))\n",
    "            temp['spu_tst_acc'].append(clf.score(tstX[ind], tsty[ind]))\n",
    "        res['tst_acc'] = temp['tst_acc']\n",
    "        res['spu_tst_acc'] = temp['spu_tst_acc']\n",
    "        res['clfs'] = temp['clfs']\n",
    "            #print(clf.best_estimator_.coef_)\n",
    "            #print(\"avg one prediction\", clf.predict(mod_tar_trn).mean(), clf.predict(mod_tar_tst).mean())\n",
    "            #print(clf.score(tstX[ind], tsty[ind]), len(ind))\n",
    "#joblib.dump(results, \"caches/meminf_results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a78b5534-4228-465a-aa74-d8aba20c2ccb",
   "metadata": {},
   "outputs": [],
   "source": [
    "joblib.dump(results, \"caches/meminf_results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18c52474-a083-462a-85f5-9d2ef4e78f46",
   "metadata": {},
   "outputs": [],
   "source": [
    "#results = joblib.load(\"caches/meminf_results.pkl\")\n",
    "fontsize = 20\n",
    "\n",
    "#for base_dset in ['memfashion', 'spmemfashion', 'memmnist', 'spmemmnist', ]:\n",
    "for base_dset in ['memfashion', 'memmnist', 'memcifar10']:\n",
    "    print(base_dset)\n",
    "    nspu = nspus[base_dset]\n",
    "    tstaccs, spu_tstaccs = [], []\n",
    "    for i in nspu:\n",
    "        tstaccs.append(results[base_dset][i]['tst_acc'])\n",
    "        spu_tstaccs.append(results[base_dset][i]['spu_tst_acc'])\n",
    "        print(i, \n",
    "              np.mean(results[base_dset][i]['tst_acc']),\n",
    "              np.mean(results[base_dset][i]['spu_tst_acc']),\n",
    "              sem(results[base_dset][i]['spu_tst_acc']))\n",
    "    tstaccs, spu_tstaccs = np.array(tstaccs), np.array(spu_tstaccs)\n",
    "    \n",
    "    #plt.figure(figsize=(8, 4))\n",
    "    width = 0.3\n",
    "    plt.bar(np.arange(len(nspu)) + width/2, tstaccs.mean(1), yerr=sem(tstaccs, axis=1), width=width, label=\"All eamples\")\n",
    "    plt.bar(np.arange(len(nspu)) - width/2, spu_tstaccs.mean(1), yerr=sem(spu_tstaccs, axis=1), width=width, label=\"Spurious examples\")\n",
    "    \n",
    "    plt.ylim(0., 1.0)\n",
    "    plt.yticks(fontsize=fontsize)\n",
    "    plt.ylabel(\"Test accuracy\", fontsize=fontsize)\n",
    "    plt.xticks(np.arange(len(nspu)), labels=nspu, fontsize=fontsize)\n",
    "    plt.xlabel(\"\\# spurious examples\", fontsize=fontsize)\n",
    "    if 'fashion' in base_dset:\n",
    "        plt.legend(fontsize=fontsize-6)\n",
    "    plt.tight_layout()\n",
    "    \n",
    "    plt.savefig(f\"./figs/meminf/{base_dset}_bbox.png\")\n",
    "    plt.show()\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f6694c8-ce04-4db5-9298-57eb28d826b5",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.bar(np.arange(11), results['memcifar10'][1000]['clfs'][0].best_estimator_.coef_[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aec2a07d-a5e2-4524-aae8-521f28220b28",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dset = \"mnist\"\n",
    "#base_dset = \"fashion\"\n",
    "\n",
    "threshold = 1e-1\n",
    "\n",
    "n_samples = [1, 3, 5, 10, 20, 100]\n",
    "optimizers = ['adam', 'sgd']\n",
    "#optimizers = ['sgd']\n",
    "spurious_versions = ['v20']\n",
    "architecures = ['LargeMLP']\n",
    "grad_clips = [None]\n",
    "tar_clses = [0]\n",
    "wd = 0.0\n",
    "\n",
    "#base_dset = \"cifar10\"\n",
    "#optimizers = ['adam', 'sgd']\n",
    "#spurious_versions = ['v20']\n",
    "#architecures = ['altResNet20Norm02']\n",
    "#n_samples = [3, 5, 10, 20, 100, 500]\n",
    "#wd = 0.0001\n",
    "\n",
    "all_accs = []\n",
    "all_results = {}\n",
    "\n",
    "model_seed = 0\n",
    "\n",
    "for toptimizer in optimizers:\n",
    "    methods = opt_methods[toptimizer]['methods']\n",
    "    for tar_cls in tar_clses:\n",
    "        for spurious_version in spurious_versions:\n",
    "            for arch in architecures:\n",
    "                lr = 0.01\n",
    "                optimizer = toptimizer if 'Vgg' not in arch else \"sgd\"\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                for grad_clip in grad_clips:\n",
    "                    if 'cifar' in base_dset:\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                    else:\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                    ind = np.arange(len(tsty))\n",
    "                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                    baseline_nos = baseline_rv.mean()\n",
    "                    baseline_sem = sem(baseline_rv)\n",
    "                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                    baseline_mod_tst_preds = mod_tst_preds\n",
    "\n",
    "                    key = (optimizer, tar_cls, spurious_version, arch, grad_clip)\n",
    "                    all_results[key] = {}\n",
    "                    for method in methods:\n",
    "                        mkey = method\n",
    "                        noise_multiplier = None\n",
    "                        twd = wd\n",
    "                        noise_level = None\n",
    "                        if 'forget' in method:\n",
    "                            if 'cifar' in base_dset:\n",
    "                                lr = 0.001\n",
    "                            else:\n",
    "                                lr = 0.0001\n",
    "                            method, noise_level = method.split('-')[0], float(method.split('-')[1])\n",
    "                        elif 'wd' in method:\n",
    "                            lr = 0.01\n",
    "                            method, twd = 'train_classifier', float(method.split('-')[1])\n",
    "                        elif 'noisy' in method:\n",
    "                            lr = 0.01\n",
    "                            method, noise_multiplier = 'train_classifier', float(method.split('-')[1])\n",
    "                        else:\n",
    "                            lr = 0.01\n",
    "                            \n",
    "                        all_results[key][mkey] = {}\n",
    "                        all_results[key][mkey]['score_mean'] = [baseline_nos]\n",
    "                        all_results[key][mkey]['score_sem'] = [0]\n",
    "                        all_results[key][mkey]['acc'] = [baseline_acc]\n",
    "                        all_results[key]['n_samples'] = n_samples\n",
    "                        \n",
    "                        for i in n_samples:\n",
    "                            taccs, tret, ttest_res = [], [], []\n",
    "\n",
    "                            for seed in range(5):\n",
    "                                ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed,\n",
    "                                                                     bs=None, wd=twd, gi=None, depth=None, noise_level=noise_level,\n",
    "                                                                     noise_multiplier=noise_multiplier, folder=method)\n",
    "                                try:\n",
    "                                    mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                    ind = np.arange(len(tsty))\n",
    "                                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                    rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                    #taccs.append((tst_preds.argmax(1) == tsty).mean())\n",
    "                                    taccs.append(mod_tst_preds[rv][:, tar_cls].mean())\n",
    "                                    tret.append(rv.mean())\n",
    "                                except FileNotFoundError:\n",
    "                                    print(f\"missing {model_path}\")\n",
    "\n",
    "                            if taccs:\n",
    "                                all_results[key][mkey]['score_mean'].append(np.mean(tret))\n",
    "                                all_results[key][mkey]['score_sem'].append(sem(tret))\n",
    "                                all_results[key][mkey]['acc'].append(np.mean(taccs))\n",
    "                                all_accs += taccs\n",
    "                            else:\n",
    "                                all_results[key][mkey]['score_mean'].append(-1)\n",
    "                                all_results[key][mkey]['score_sem'].append(-1)\n",
    "                                all_results[key][mkey]['acc'].append(-1)\n",
    "\n",
    "print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5089d004-d337-42cc-95e6-37898761fc99",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c5d58e7b-d172-4173-855c-42b1d2f12b46",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9413f2ea-a3c9-442f-b7be-efad94b3e665",
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3c0dd5ac-c4d1-4600-9335-2ceb2637d3eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = joblib.load(\"../results/mem_inference/128-memfashion-70-ce-tor-LargeMLP-0.0-adam-0-0.0.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8976a585-98d9-45e0-9dd2-bc36eb016686",
   "metadata": {},
   "outputs": [],
   "source": [
    "(_, mem_trny, _, mem_tsty, _, nonmem_trny, _, nonmem_tsty, spurious_ind) = auto_var.get_var_with_argument(\"dataset\", \"memfashion\")\n",
    "trnX = np.concatenate((res['aux_mem_trn_pred'], res['aux_nonmem_trn_pred']), axis=0)\n",
    "trnX = np.concatenate((\n",
    "    trnX,\n",
    "    np.concatenate(((res['aux_mem_trn_pred'].argmax(1) == mem_trny).astype(np.float32), (res['aux_nonmem_trn_pred'].argmax(1) == nonmem_trny).astype(np.float32))).reshape(-1, 1)\n",
    "), axis=1)\n",
    "trny = np.concatenate((np.ones(len(res['aux_mem_trn_pred'])), np.zeros(len(res['aux_nonmem_trn_pred']))))\n",
    "tstX = np.concatenate((res['aux_mem_tst_pred'], res['aux_nonmem_tst_pred']), axis=0)\n",
    "tstX = np.concatenate((\n",
    "    tstX,\n",
    "    np.concatenate(((res['aux_mem_tst_pred'].argmax(1) == mem_tsty).astype(np.float32), (res['aux_nonmem_tst_pred'].argmax(1) == nonmem_tsty).astype(np.float32))).reshape(-1, 1)\n",
    "), axis=1)\n",
    "tsty = np.concatenate((np.ones(len(res['aux_mem_tst_pred'])), np.zeros(len(res['aux_nonmem_tst_pred']))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96fd3614-c580-41f4-8496-b17d397cd004",
   "metadata": {},
   "outputs": [],
   "source": [
    "scaler = StandardScaler()\n",
    "trnX = scaler.fit_transform(trnX)\n",
    "tstX = scaler.transform(tstX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c52c55a1-e3bb-422a-885d-df140d296b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = LogisticRegression()\n",
    "clf.fit(trnX, trny)\n",
    "clf.score(trnX, trny), clf.score(tstX, tsty)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "06ef7f8b-bcf2-4520-9ad9-10203b630d31",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d2f13f6-b06b-423a-b2da-3928d7d32db1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e8f2ac0-83d8-47a1-8573-9e510f268c8f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a4dc7b4-be8d-4845-9eab-92ef83f7bbd4",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
