{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from functools import lru_cache\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from os.path import join\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=True)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['text.usetex'] = True #Let TeX do the typsetting\n",
    "plt.rcParams['text.latex.preamble'] = r\"\"\"\n",
    "\\usepackage{sansmath}\n",
    "\\sansmath\n",
    "\"\"\" #Force sans-serif math mode (for axes labels)\n",
    "plt.rcParams['font.family'] = 'sans-serif' # ... for regular text\n",
    "plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Choose a nice font here\n",
    "\n",
    "from scipy.stats import ttest_ind, sem\n",
    "from scipy.fft import fft2\n",
    "import scipy\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from scipy.stats import pearsonr\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:,.3f}'.format\n",
    "pd.set_option('display.max_rows', 128)\n",
    "\n",
    "\n",
    "from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_pred_fn(X, model, device=\"cuda\"):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def cnn_pred_fn(X, model, device=\"cuda\"):\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "        \n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    #for (x, ) in loader:\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "        #fetX.append(model.feature_extractor(x.to(device)).cpu().detach().flatten(1).numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "class CLF():\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "        \n",
    "    def predict(self, X, device=\"cuda\"):\n",
    "        return pred_fn(X, self.model, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#@lru_cache(maxsize=None)\n",
    "prediction_results = {}\n",
    "prediction_results = joblib.load(\"./caches/prediction_results.pkl\")\n",
    "print(len(prediction_results))\n",
    "\n",
    "def evaluate(ds_name, model_path, arch, spurious_version, seed):\n",
    "    key = (ds_name, model_path, arch, spurious_version, seed)\n",
    "    if key in prediction_results:\n",
    "        return prediction_results[key]\n",
    "    \n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "    res = torch.load(model_path)\n",
    "    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "    model.load_state_dict(res['model_state_dict'])\n",
    "    \n",
    "    if 'MLP' in arch:\n",
    "        pred_fn = mlp_pred_fn\n",
    "    else:\n",
    "        pred_fn = cnn_pred_fn\n",
    "    \n",
    "    #trn_preds = pred_fn(trnX, model)\n",
    "    tst_preds = pred_fn(tstX, model)\n",
    "    tst_preds = scipy.special.softmax(tst_preds, axis=1)\n",
    "    modified_tstX = np.copy(tstX)\n",
    "    if n_channels == 3:\n",
    "        modified_tstX = add_colored_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "    else:\n",
    "        modified_tstX = add_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "    mod_tst_preds = pred_fn(modified_tstX, model)\n",
    "    mod_tst_preds = scipy.special.softmax(mod_tst_preds, axis=1)\n",
    "    prediction_results[key] = (mod_tst_preds, tst_preds, tsty)\n",
    "    return mod_tst_preds, tst_preds, tsty\n",
    "    #return ((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > 0).mean(), (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]).mean()\n",
    "    #return ((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > 1e-9).mean(), (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]).mean()`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed, aug=None, bs=None, wd=0.0, folder=\"train_classifier\"):\n",
    "    ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "    if base_dset in ['mnist', 'fashion']:\n",
    "        bs = 128 if bs is None else bs\n",
    "        if grad_clip is None:\n",
    "             model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "        else:\n",
    "             model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{grad_clip}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "    else:\n",
    "        bs = 64 if bs is None else bs\n",
    "        if grad_clip is None:\n",
    "            model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "        else:\n",
    "            model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{grad_clip}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "        if aug is not None:\n",
    "            model_path = model_path.replace(\"-ce-tor-\", f\"-{aug}-ce-tor-\")\n",
    "    return ds_name, model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST/Fashion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "all_results = {}\n",
    "acc_results = {}\n",
    "ttest_results = {}\n",
    "base_dset = \"mnist\"\n",
    "base_dset = \"fashion\"\n",
    "\n",
    "threshold = 1e-1\n",
    "\n",
    "all_accs = []\n",
    "\n",
    "n_samples = [1, 3, 5, 10, 20, 100, 2000, 5000]\n",
    "optimizers = ['sgd', 'adam']\n",
    "#optimizers = ['adam']\n",
    "spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "architecures = ['LargeMLP']\n",
    "#grad_clips = [None, 0.1]\n",
    "grad_clips = [None]\n",
    "tar_clses = [0, 1]\n",
    "\n",
    "#n_samples = [3, 5, 10, 20, 100, 2000, 5000]\n",
    "#optimizers = ['adam']\n",
    "#spurious_versions = ['v8', 'v19', 'v20']\n",
    "#architecures = ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002']\n",
    "#grad_clips = [None]\n",
    "#tar_clses = [0, 1]\n",
    "\n",
    "#n_samples = [3, 10]\n",
    "#optimizers = ['sgd', 'adam']\n",
    "#spurious_versions = ['v8', 'v20']\n",
    "#architecures = ['LargeMLP']\n",
    "#grad_clips = [None, 0.1]\n",
    "#tar_clses = [0]\n",
    "\n",
    "#optimizers = ['adam']\n",
    "#spurious_versions = ['v8', 'v20']\n",
    "#architecures = ['LargeMLP']\n",
    "#grad_clips = [None, 0.1]\n",
    "\n",
    "#n_samples = [20]\n",
    "#optimizers = ['adam']\n",
    "#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "#architecures = ['LargeMLP']\n",
    "#grad_clips = [None]\n",
    "#tar_clses = [0,]\n",
    "\n",
    "for toptimizer in optimizers:\n",
    "    for tar_cls in tar_clses:\n",
    "        for spurious_version in spurious_versions:\n",
    "            for arch in architecures:\n",
    "                lr = 0.01\n",
    "                optimizer = toptimizer if 'Vgg' not in arch else \"sgd\"\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                for grad_clip in grad_clips:\n",
    "                    if grad_clip is None:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    else:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                    #ind = (tsty != tar_cls)\n",
    "                    ind = np.arange(len(tsty))\n",
    "                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                    baseline_nos = baseline_rv.mean()\n",
    "                    baseline_sem = sem(baseline_rv)\n",
    "                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                    baseline_mod_acc = (mod_tst_preds.argmax(1) == tsty).mean()\n",
    "\n",
    "                    ret, sems, accs, test_res = [baseline_nos], [baseline_sem], [baseline_acc], []\n",
    "                    mod_accs, mod_acc_sems = [baseline_mod_acc], [0]\n",
    "                    for i in n_samples:\n",
    "                        taccs, tmodaccs, tret, ttest_res = [], [], [], []\n",
    "                        for seed in range(5):\n",
    "                            ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                #ind = (tsty != tar_cls)\n",
    "                                ind = np.arange(len(tsty))\n",
    "                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                #mod_tst_preds, tst_preds = mod_tst_preds[tsty == tar_cls], tst_preds[tsty == tar_cls]\n",
    "                                #tsty = tsty[tsty == tar_cls]\n",
    "                                #rv = (mod_tst_preds[:, tar_cls] - baseline_mod_tst_preds[:, tar_cls]) > threshold\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                #rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls])\n",
    "                                #rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative=\"greater\")[1])\n",
    "                                #taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())\n",
    "                                taccs.append((tst_preds.argmax(1) == tsty).mean())\n",
    "                                tmodaccs.append((mod_tst_preds.argmax(1) == tsty)[tsty != tar_cls].mean())\n",
    "                                tret.append(rv.mean())\n",
    "                                #tret.append(rv.mean() / baseline_nos)\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "                        if taccs:\n",
    "                            ret.append(np.mean(tret))\n",
    "                            sems.append(sem(tret))\n",
    "                            accs.append(np.mean(taccs))\n",
    "                            mod_accs.append(np.mean(tmodaccs))\n",
    "                            mod_acc_sems.append(sem(tmodaccs))\n",
    "                            test_res.append(ttest_res[0])\n",
    "                            all_accs += taccs\n",
    "                        else:\n",
    "                            ret.append(-1)\n",
    "                            sems.append(-1)\n",
    "                            accs.append(-1)\n",
    "                            mod_accs.append(-1)\n",
    "                            mod_acc_sems.append(-1)\n",
    "                            test_res.append(-1)\n",
    "                            \n",
    "                    key = (optimizer, tar_cls, spurious_version, arch, grad_clip)\n",
    "                    ttest_results[key] = test_res\n",
    "                    acc_results[key] = accs + mod_accs + mod_acc_sems\n",
    "                    all_results[key] = ret + accs + sems\n",
    "                    #ttest_results[(tar_cls, spurious_version, arch)] = test_res\n",
    "                    #all_results[(tar_cls, spurious_version, arch)] = ret + accs\n",
    "print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))\n",
    "joblib.dump(prediction_results, \"./caches/prediction_results.pkl\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "\n",
    "opt_name_map = {\n",
    "    \"sgd\": \"SGD\",\n",
    "    \"adam\": \"Adam\",\n",
    "}\n",
    "\n",
    "data_avgs, data_sems = [], []\n",
    "for opt in optimizers:\n",
    "    temp_avgs, temp_sems = [], []\n",
    "    for pat in spurious_versions:\n",
    "        avgs = df.loc[(opt, 0, pat, \"LargeMLP\", float(\"NaN\"))][[(\"spu\", 3), (\"spu\", 10)]].tolist()\n",
    "        sems = df.loc[(opt, 0, pat, \"LargeMLP\", float(\"NaN\"))][[(\"sem\", 3), (\"sem\", 10)]].tolist()\n",
    "        temp_avgs += avgs\n",
    "        temp_sems += sems\n",
    "    data_avgs.append(temp_avgs)\n",
    "    data_sems.append(temp_sems)\n",
    "\n",
    "    \n",
    "plt.figure(figsize=(8, 3))\n",
    "width = 0.3\n",
    "location = 0.\n",
    "for i, (opt, avgs, sems) in enumerate(zip(optimizers, data_avgs, data_sems)):\n",
    "    plt.bar(np.arange(4) + width*i, avgs, width=width, yerr=sems, label=opt_name_map[opt])\n",
    "\n",
    "plt.legend([\"SGD\", \"Adam\"], fontsize=fontsize)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "xnames = [\"3 S3\", \"10 S3\", \"3 R3\", \"10 R3\"]\n",
    "plt.xticks(np.arange(4) + width / 2, xnames, fontsize=fontsize)\n",
    "plt.xlabel(\"# spurious examples and the spurious pattern\", fontsize=fontsize)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"figs/bar_charts/difopt_{base_dset}_{tar_clses[0]}.png\")\n",
    "plt.show()\n",
    "plt.close()\n",
    "    \n",
    "data_avgs, data_sems = [], []\n",
    "opt = \"adam\"\n",
    "for grad_clip in grad_clips:\n",
    "    temp_avgs, temp_sems = [], []\n",
    "    for pat in spurious_versions:\n",
    "        avgs = df.loc[(opt, 0, pat, \"LargeMLP\", grad_clip)][[(\"spu\", 3), (\"spu\", 10)]].tolist()\n",
    "        sems = df.loc[(opt, 0, pat, \"LargeMLP\", grad_clip)][[(\"sem\", 3), (\"sem\", 10)]].tolist()\n",
    "        temp_avgs += avgs\n",
    "        temp_sems += sems\n",
    "    data_avgs.append(temp_avgs)\n",
    "    data_sems.append(temp_sems)\n",
    "\n",
    "grad_clip_names = [\"w/ clip\", \"w/o clip\"]\n",
    "plt.figure(figsize=(8, 3))\n",
    "width = 0.3\n",
    "location = 0.\n",
    "for i, (clip_name, avgs, sems) in enumerate(zip(grad_clip_names, data_avgs, data_sems)):\n",
    "    plt.bar(np.arange(4) + width*i, avgs, width=width, yerr=sems, label=clip_name)\n",
    "\n",
    "plt.legend(grad_clip_names, fontsize=fontsize)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "xnames = [\"3 S3\", \"10 S3\", \"3 R3\", \"10 R3\"]\n",
    "plt.xticks(np.arange(4) + width / 2, xnames, fontsize=fontsize)\n",
    "plt.xlabel(\"# spurious examples and the spurious pattern\", fontsize=fontsize)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"figs/bar_charts/difclip_{base_dset}_{tar_clses[0]}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "td = df[[(\"spu\", i) for i in n_samples]]\n",
    "td.columns = td.columns.droplevel(0)\n",
    "td.index = td.index.droplevel(3)\n",
    "\n",
    "td = td.mean(1).unstack(2)\n",
    "text = td.to_latex(multirow=True, float_format=\"%.3f\")\n",
    "text = text.replace(\"0.\", \".\")\n",
    "print(text)\n",
    "td"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "#print(df[[(\"spu\", i) for i in [\"0\"] + n_samples]].to_latex(multirow=True))\n",
    "display(df)\n",
    "\n",
    "td = df[[(\"spu\", i) for i in [\"0\"] + n_samples]].unstack(0)\n",
    "td.columns = td.columns.droplevel(0)\n",
    "td.columns = td.columns.swaplevel(0, 1)\n",
    "td.index = td.index.droplevel(2)  # arch\n",
    "#td.index = td.index.droplevel(1)  # pattern\n",
    "#td.index = td.index.droplevel(2)  # clipping\n",
    "\n",
    "def key_fn(xs):\n",
    "    if 'v3' in xs:\n",
    "        return [spurious_versions.index(x) for x in xs]\n",
    "    elif 'LargeMLP' in xs:\n",
    "        return [architecures.index(x) for x in xs]\n",
    "    else:\n",
    "        return xs\n",
    "td = td.sort_index(axis=0, level=[0, 1], key=key_fn)\n",
    "\n",
    "try:\n",
    "    #text = td[[(\"sgd\", i) for i in [\"0\"] + n_samples]].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "    text = td[[(\"sgd\", i) for i in n_samples]].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "    text = text.replace(\"0.\", \".\")\n",
    "    text = text.replace(\"1.000\", \"1.00\")\n",
    "    print(text)\n",
    "except:\n",
    "    pass\n",
    "text = td[[(\"adam\", i) for i in n_samples]].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "text = text.replace(\"0.\", \".\")\n",
    "text = text.replace(\"1.000\", \"1.00\")\n",
    "print(text)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Different grad clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "df = df[[(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples]]\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for arch in architecures:\n",
    "            for spu in spurious_versions:\n",
    "                for grad_clip in grad_clips:\n",
    "                    plt.errorbar(\n",
    "                        x=np.arange(len(n_samples) + 1),\n",
    "                        y=df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))][[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        yerr=df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))][[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        label=f\"{spu} {grad_clip}\",\n",
    "                    )\n",
    "                    plt.errorbar(\n",
    "                        x=np.arange(len(n_samples) + 1),\n",
    "                        y=df.loc[(optimizer, tar_cls, spu, arch, 0.1)][[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        yerr=df.loc[(optimizer, tar_cls, spu, arch, 0.1)][[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        label=f\"{spu} {grad_clip}\",\n",
    "                    )\n",
    "            plt.yticks(fontsize=fontsize)\n",
    "            plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "            #plt.xscale(\"linear\")\n",
    "            plt.ylabel(\"Spurious score\", fontsize=fontsize)\n",
    "            plt.xlabel(\"\\# spurious examples\", fontsize=fontsize)\n",
    "            #plt.legend(loc=(0.02, 0.61), fontsize=fontsize-2)\n",
    "            plt.legend(fontsize=fontsize-2)\n",
    "            plt.tight_layout()\n",
    "            plt.savefig(f\"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}.png\")\n",
    "            plt.show()\n",
    "            plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Different arch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "df = df[[(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples]]\n",
    "\n",
    "arch_name_map = {\n",
    "    \"MLP\": \"small MLP\",\n",
    "    \"LargeMLP\": \"MLP\",\n",
    "    \"LargeMLPv2\": \"large MLP\",\n",
    "    \"CNN002\": \"CNN\",\n",
    "}\n",
    "\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spu in spurious_versions:\n",
    "            for arch in ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002']:\n",
    "                plt.errorbar(\n",
    "                    x=[0] + n_samples,\n",
    "                    #x=np.arange(len(n_samples) + 1),\n",
    "                    y=df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))][[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                    yerr=df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))][[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                    label=arch_name_map[arch],\n",
    "                )\n",
    "            plt.yticks(fontsize=fontsize)\n",
    "            plt.xticks(fontsize=fontsize)\n",
    "            #plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "            #plt.xscale(\"linear\")\n",
    "            plt.xscale(\"symlog\")\n",
    "            plt.xlim(-0.25, n_samples[-1]+2000)\n",
    "            plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "            plt.xlabel(\"\\# spurious examples\", fontsize=fontsize)\n",
    "            #plt.legend(loc=(0.02, 0.61), fontsize=fontsize-2)\n",
    "            plt.legend(fontsize=fontsize-2)\n",
    "            plt.tight_layout()\n",
    "            plt.savefig(f\"figs/score_plot/difarchs_{base_dset}_{spu}_{optimizer}_{tar_cls}.png\")\n",
    "            #plt.savefig(f\"figs/score_plot/difarchs_{base_dset}_{spu}_{optimizer}_{tar_cls}_log.png\")\n",
    "            plt.show()\n",
    "            plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## score figs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "df = df[[(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples]]\n",
    "spu_names = ['\\\\textit{S1}', '\\\\textit{S2}', '\\\\textit{S3}', '\\\\textit{R1}', '\\\\textit{R2}', '\\\\textit{R3}', '\\\\textit{Inv}']\n",
    "\n",
    "for optimizer in ['sgd', 'adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for i, spu in enumerate(spurious_versions):\n",
    "            plt.errorbar(\n",
    "                #x=np.arange(len(n_samples) + 1),\n",
    "                x=[0] + n_samples,\n",
    "                y=df.loc[(optimizer, tar_cls, spu, \"LargeMLP\", float(\"NaN\"))][[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                yerr=df.loc[(optimizer, tar_cls, spu, \"LargeMLP\", float(\"NaN\"))][[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                label=spu_names[i],\n",
    "            )\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "        #plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "        plt.xscale(\"symlog\")\n",
    "        plt.xlim(-0.25, n_samples[-1]+2000)\n",
    "        plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "        plt.xlabel(\"\\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        #plt.savefig(f\"figs/score_plot/{base_dset}_{optimizer}_{tar_cls}.png\")\n",
    "        plt.savefig(f\"figs/score_plot/{base_dset}_{optimizer}_{tar_cls}_log.png\")\n",
    "        plt.show()\n",
    "        plt.close()\n",
    "            "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(acc_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"spu acc\", i) for i in [\"0\"] + n_samples] + [(\"spu acc sem\", i) for i in [\"0\"] + n_samples])\n",
    "spu_names = ['\\\\textit{S1}', '\\\\textit{S2}', '\\\\textit{S3}', '\\\\textit{R1}', '\\\\textit{R2}', '\\\\textit{R3}', '\\\\textit{Inv}']\n",
    "\n",
    "for optimizer in ['sgd', 'adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for i, spu in enumerate(spurious_versions):\n",
    "            plt.errorbar(\n",
    "                #x=np.arange(len(n_samples) + 1),\n",
    "                x=[0] + n_samples,\n",
    "                y=df.loc[(optimizer, tar_cls, spu, \"LargeMLP\", float(\"NaN\"))][[(\"spu acc\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                yerr=df.loc[(optimizer, tar_cls, spu, \"LargeMLP\", float(\"NaN\"))][[(\"spu acc sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                label=spu_names[i],\n",
    "            )\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "        #plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "        plt.xscale(\"symlog\")\n",
    "        plt.xlim(-0.25, n_samples[-1]+2000)\n",
    "        plt.ylabel(\"Spurious test accuracy\", fontsize=fontsize)\n",
    "        plt.xlabel(\"\\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        #plt.savefig(f\"figs/score_plot/{base_dset}_{optimizer}_{tar_cls}.png\")\n",
    "        plt.savefig(f\"figs/acc_plot/{base_dset}_{optimizer}_{tar_cls}_log.png\")\n",
    "        plt.show()\n",
    "        plt.close()\n",
    "            "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dset = \"mnist\"\n",
    "spurious_version, i, tar_cls, seed = \"v10\", 20, 1, 0\n",
    "folder, lr, arch, momentum, optimizer = \"train_classifier\", 0.01, \"CNN002\", 0., \"adam\"\n",
    "ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "model_path = f\"../models/{folder}/128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(ttest_results).transpose()\n",
    "df.columns = n_samples\n",
    "print(df.to_latex(multirow=True))\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tar_cls = 1\n",
    "ds_name = f\"mnistv3-2000-{tar_cls}-0\"\n",
    "arch_name = \"CNN002\"\n",
    "spurious_version = \"v3\"\n",
    "\n",
    "\n",
    "model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch_name}-0.0-adam-0-0.0.pt\"\n",
    "mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### model parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print the number of trainable parameters of each architecture when built for the MNIST input shape.\n",
    "architectures = ['MLP', 'LargeMLP', 'LargeMLPv2', 'CNN002']\n",
    "ds_name = \"mnist\"\n",
    "\n",
    "# The dataset does not depend on the architecture, so load it once instead of once per loop iteration.\n",
    "trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "n_classes = len(np.unique(trny))\n",
    "n_channels = trnX.shape[-1]  # assumes channels-last arrays - TODO confirm\n",
    "n_features = (np.prod(trnX.shape[1:]), )\n",
    "\n",
    "for arch in architectures:\n",
    "    model = getattr(archs, arch)(n_features=n_features, n_channels=n_channels, n_classes=n_classes)\n",
    "    print(sum(p.numel() for p in model.parameters() if p.requires_grad))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Norm score correlation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## spurious pattern norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean L2 norm of the change each spurious pattern makes to the test set, averaged over seeds 0-4.\n",
    "spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "#base_dsets = [\"mnist\"]\n",
    "base_dsets = [\"cifar10\"]\n",
    "#base_dsets = [\"fashion\"]\n",
    "\n",
    "results = {}\n",
    "for base_dset in base_dsets:\n",
    "    for spurious_version in spurious_versions:\n",
    "        seed_norms = []\n",
    "        for seed in range(5):\n",
    "            ds_name = f\"{base_dset}{spurious_version}-3-0-{seed}\"\n",
    "            trnX, trny, tstX, tsty, _ = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "            n_channels = trnX.shape[-1]\n",
    "\n",
    "            # Three-channel data goes through the coloured variant of the pattern function.\n",
    "            modified_tstX = np.copy(tstX)\n",
    "            if n_channels == 3:\n",
    "                modified_tstX = add_colored_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "            else:\n",
    "                modified_tstX = add_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "\n",
    "            # Record the norm inside the seed loop. It used to sit one level out, so only the\n",
    "            # last seed was kept and the mean below was taken over a single value.\n",
    "            per_example_norm = np.linalg.norm((tstX - modified_tstX).reshape(len(tstX), -1), ord=2, axis=1)\n",
    "            seed_norms.append(per_example_norm.mean())\n",
    "        results[(base_dset, spurious_version, 'norm')] = np.mean(seed_norms)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Sweep for the norm/score correlation: spurious score of every pattern at every entry of n_samples.\n",
    "all_results = {}\n",
    "ttest_results = {}\n",
    "base_dset = \"mnist\"\n",
    "#base_dset = \"fashion\"\n",
    "spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "architecures = ['LargeMLP']\n",
    "n_samples =  [3, 5, 10, 20, 100, 2000, 5000]\n",
    "aug = None\n",
    "\n",
    "# The CIFAR-10 settings below override the MNIST settings above.\n",
    "base_dset = \"cifar10\"\n",
    "spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "architecures = ['altResNet20Norm02']\n",
    "n_samples = [3, 5, 10, 20, 100, 500]\n",
    "aug = \"aug01\"\n",
    "\n",
    "threshold = 1e-1\n",
    "\n",
    "\n",
    "#for optimizer in ['sgd', 'adam']:\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0]:\n",
    "        for spurious_version in spurious_versions:\n",
    "            for arch in architecures:\n",
    "                lr = 0.01\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                for grad_clip in [None]:\n",
    "                    # Baseline checkpoint: dataset field is `base_dset` itself; a clip value, when set, is one extra path field.\n",
    "                    clip_tag = \"\" if grad_clip is None else f\"{grad_clip}-\"\n",
    "                    if \"cifar\" in base_dset:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{clip_tag}{lr}-{aug}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                    else:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{clip_tag}0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "\n",
    "                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                    # Score only the test rows whose label is not the target class.\n",
    "                    ind = (tsty != tar_cls)\n",
    "                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                    baseline_nos = [baseline_rv.mean()]\n",
    "                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                    baseline_mod_tst_preds = mod_tst_preds\n",
    "\n",
    "                    ret, accs, test_res, avg_norm = [baseline_nos[0]], [baseline_acc], [], []\n",
    "                    for i in n_samples:\n",
    "                        taccs, tret, ttest_res, tavg_norm = [], [], [], []\n",
    "                        for seed in range(5):\n",
    "                            if \"cifar\" in base_dset:\n",
    "                                ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, aug=aug,\n",
    "                                                                     bs=128, wd=1e-4, model_seed=0)\n",
    "                            else:\n",
    "                                ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                ind = (tsty != tar_cls)\n",
    "                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                # One-sided Welch t-test: is this model's indicator mean greater than the baseline's?\n",
    "                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative=\"greater\")[1])\n",
    "                                taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())\n",
    "                                tret.append(rv.mean())\n",
    "\n",
    "                                trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "                                # Three-channel data takes the coloured pattern function, the same choice the\n",
    "                                # pattern-norm cell makes; the plain function was used unconditionally before.\n",
    "                                if trnX.shape[-1] == 3:\n",
    "                                    modified_tstX = add_colored_spurious_correlation(np.copy(tstX), spurious_version, seed)\n",
    "                                else:\n",
    "                                    modified_tstX = add_spurious_correlation(np.copy(tstX), spurious_version, seed)\n",
    "                                tavg_norm.append(np.linalg.norm((tstX - modified_tstX).reshape(len(tstX), -1), axis=1, ord=2))\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "                        if taccs:\n",
    "                            ret.append(np.mean(tret))\n",
    "                            accs.append(np.mean(taccs))\n",
    "                            avg_norm.append(np.mean(tavg_norm))\n",
    "                            # Only the first available seed's p-value is reported.\n",
    "                            test_res.append(ttest_res[0])\n",
    "                        else:\n",
    "                            # -1 marks a size with no checkpoint. Every list gets the sentinel (avg_norm was\n",
    "                            # skipped before) so each row keeps the width the column labels downstream expect.\n",
    "                            ret.append(-1)\n",
    "                            accs.append(-1)\n",
    "                            avg_norm.append(-1)\n",
    "                            test_res.append(-1)\n",
    "\n",
    "                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)\n",
    "                    ttest_results[key] = test_res\n",
    "                    all_results[key] = ret + accs + avg_norm\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reduce the sweep to one average spurious score per pattern and store it in `results` beside the norms.\n",
    "sizes_with_baseline = [0] + n_samples\n",
    "column_labels = (\n",
    "    [(\"spu\", n) for n in sizes_with_baseline]\n",
    "    + [(\"acc\", n) for n in sizes_with_baseline]\n",
    "    + [(\"norm\", n) for n in n_samples]\n",
    ")\n",
    "df = pd.DataFrame.from_dict(all_results).T\n",
    "df.columns = pd.MultiIndex.from_tuples(column_labels)\n",
    "df = df[[(\"spu\", n) for n in n_samples]]\n",
    "# Row keys are (optimizer, tar_cls, arch, spurious_version, grad_clip); keep only the pattern level.\n",
    "df.index = df.index.droplevel([0, 1, 2, 4])\n",
    "df = df.mean(axis=1)\n",
    "for spurious_version in spurious_versions:\n",
    "    results[(base_dset, spurious_version, 'score')] = df.loc[spurious_version]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter of per-pattern norm against per-pattern average spurious score.\n",
    "def _pattern_rank(labels):\n",
    "    \"\"\"Sort key: position of each pattern label in spurious_versions.\"\"\"\n",
    "    return [spurious_versions.index(label) for label in labels]\n",
    "\n",
    "\n",
    "normdf = pd.DataFrame.from_dict(results, orient=\"index\")\n",
    "normdf.index = pd.MultiIndex.from_tuples(normdf.index)\n",
    "# Move the 'norm' / 'score' level into the columns, then drop the unnamed outer column level.\n",
    "normdf = normdf.unstack(2)\n",
    "normdf.columns = normdf.columns.droplevel(0)\n",
    "normdf = normdf.sort_index(axis=0, level=[1], key=_pattern_rank)\n",
    "#print(scipy.stats.pearsonr(normdf['norm'], normdf['score']))\n",
    "\n",
    "label_size = fontsize + 2\n",
    "plt.scatter(normdf['norm'], normdf['score'], s=150)\n",
    "plt.xlabel(\"Empirical norm\", fontsize=label_size)\n",
    "plt.ylabel(\"Spurious score\", fontsize=label_size)\n",
    "plt.xticks(fontsize=label_size)\n",
    "plt.yticks(fontsize=label_size)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"./figs/norm_scatter/{base_dset}_{architecures[0]}_avg_spuscore.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pearson correlation between the 'norm' and 'score' columns of normdf (one row per pattern).\n",
    "pearsonr(normdf['norm'].tolist(), normdf['score'].tolist())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# CIFAR-10 sweep: for every configuration collect the spurious score, the accuracy from tst_preds and the\n",
    "# accuracy from mod_tst_preds at each entry of n_samples (position 0 of every list is the baseline model).\n",
    "all_results = {}\n",
    "acc_results = {}\n",
    "ttest_results = {}\n",
    "base_dset = \"cifar10\"\n",
    "\n",
    "threshold = 1e-1\n",
    "\n",
    "all_accs = []\n",
    "\n",
    "n_samples = [1, 3, 5, 10, 20, 100, 500, 1000]\n",
    "\n",
    "aug = \"aug01\"\n",
    "\n",
    "optimizers = ['adam']\n",
    "tar_clses = [0, 1]\n",
    "spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "architecures = ['altResNet20Norm02']\n",
    "grad_clips = [None]\n",
    "\n",
    "# Alternative sweep configurations, kept for reference.\n",
    "#optimizers = ['sgd', 'adam']\n",
    "#tar_clses = [0, 1]\n",
    "#spurious_versions = ['v8', 'v20']\n",
    "#architecures = ['altResNet20Norm02']\n",
    "#grad_clips = [None, 0.1]\n",
    "\n",
    "#aug = \"aug01\"\n",
    "#optimizers = ['adam', 'sgd']\n",
    "#tar_clses = [0, 1]\n",
    "#spurious_versions = ['v1', 'v3', 'v8', 'v18', 'v19', 'v20', 'v30']\n",
    "#architecures = ['altResNet20Norm02']\n",
    "#grad_clips = [None, 0.1]\n",
    "\n",
    "#optimizers = ['adam']\n",
    "#tar_clses = [0, 1]\n",
    "#spurious_versions = ['v8', 'v20']\n",
    "#architecures = ['ResNet50']\n",
    "#grad_clips = [None, 0.1]\n",
    "\n",
    "#optimizers = ['sgd']\n",
    "#tar_clses = [0, 1]\n",
    "#spurious_versions = ['v8', 'v20']\n",
    "#architecures = ['altResNet20Norm02', 'altResNet32Norm02', 'altResNet110Norm02', 'Vgg16Norm02']\n",
    "#grad_clips = [None]\n",
    "\n",
    "for toptimizer in optimizers:\n",
    "    for tar_cls in tar_clses:\n",
    "        for spurious_version in spurious_versions:\n",
    "            #for arch in ['Vgg16', 'ResNet50']:\n",
    "            for arch in architecures:\n",
    "                # Vgg models are always read from their SGD checkpoints.\n",
    "                optimizer = toptimizer if 'Vgg' not in arch else \"sgd\"\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "                if 'Vgg' in arch:\n",
    "                    lr = 0.01\n",
    "                else:\n",
    "                    lr = 0.01 if optimizer == 'adam' else 0.1\n",
    "\n",
    "                if aug is None:\n",
    "                    model_path = f\"../models/train_classifier/128-{base_dset}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                else:\n",
    "                    model_path = f\"../models/train_classifier/128-{base_dset}-70-{lr}-{aug}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "\n",
    "                mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, seed=0)\n",
    "                #ind = (tsty != tar_cls)\n",
    "                ind = np.arange(len(tsty))\n",
    "                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                # baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                baseline_nos = baseline_rv.mean()\n",
    "                baseline_sem = sem(baseline_rv)\n",
    "                baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                # Restrict to rows outside the target class, exactly like the per-seed `tmodaccs` below,\n",
    "                # so the baseline point and the rest of the accuracy curve use the same subset.\n",
    "                baseline_mod_acc = (mod_tst_preds.argmax(1) == tsty)[tsty != tar_cls].mean()\n",
    "\n",
    "                # NOTE(review): grad_clip is not passed to get_model_path, so every value of grad_clips\n",
    "                # appears to evaluate the same checkpoints - confirm against get_model_path.\n",
    "                for grad_clip in grad_clips:\n",
    "                    ret, sems, accs, test_res = [baseline_nos], [baseline_sem], [baseline_acc], []\n",
    "                    mod_accs, mod_acc_sems = [baseline_mod_acc], [0]\n",
    "                    for i in n_samples:\n",
    "                        taccs, tmodaccs, tret, ttest_res = [], [], [], []\n",
    "                        for seed in range(5):\n",
    "                            ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed,\n",
    "                                                                 arch, momentum, optimizer, lr, model_seed=0,\n",
    "                                                                 aug=aug, bs=128, wd=0.0001)\n",
    "\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed=seed)\n",
    "                                # Report checkpoints whose accuracy falls below 0.7.\n",
    "                                if (tst_preds.argmax(1) == tsty).mean() < 0.7:\n",
    "                                    print(model_path, (tst_preds.argmax(1) == tsty).mean())\n",
    "                                #ind = (tsty != tar_cls)\n",
    "                                ind = np.arange(len(tsty))\n",
    "                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                #rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative=\"greater\")[1])\n",
    "                                taccs.append((tst_preds.argmax(1) == tsty).mean())\n",
    "                                tmodaccs.append((mod_tst_preds.argmax(1) == tsty)[tsty != tar_cls].mean())\n",
    "                                #tmodaccs.append((mod_tst_preds.argmax(1) == tsty).mean())\n",
    "                                #tret.append(rv.mean() / baseline_nos)\n",
    "                                tret.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "\n",
    "                        if taccs:\n",
    "                            # Only the first available seed's p-value is reported.\n",
    "                            test_res.append(ttest_res[0])\n",
    "                            sems.append(sem(tret))\n",
    "                            ret.append(np.mean(tret))\n",
    "                            accs.append(np.mean(taccs))\n",
    "                            mod_accs.append(np.mean(tmodaccs))\n",
    "                            mod_acc_sems.append(sem(tmodaccs))\n",
    "                            all_accs += taccs\n",
    "                        else:\n",
    "                            # -1 marks a size for which no checkpoint was found.\n",
    "                            test_res.append(-1)\n",
    "                            sems.append(-1)\n",
    "                            ret.append(-1)\n",
    "                            mod_accs.append(-1)\n",
    "                            mod_acc_sems.append(-1)\n",
    "                            accs.append(-1)\n",
    "\n",
    "                    # NOTE(review): this key has no grad_clip field, so with several grad_clips only the last one is kept.\n",
    "                    ttest_results[(optimizer, tar_cls, spurious_version, arch)] = test_res\n",
    "                    acc_results[(optimizer, tar_cls, spurious_version, arch, grad_clip)] = accs + mod_accs + mod_acc_sems\n",
    "                    all_results[(optimizer, tar_cls, spurious_version, arch, grad_clip)] = ret + accs + sems\n",
    "\n",
    "print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))\n",
    "# NOTE(review): `prediction_results` is not assigned anywhere in this cell - confirm an earlier cell defines it.\n",
    "joblib.dump(prediction_results, \"./caches/prediction_results.pkl\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## AUC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar charts of the spurious score at 3 and 10 spurious examples (target class 0):\n",
    "# one chart compares optimizers, the other compares gradient-clipping settings.\n",
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "\n",
    "opt_name_map = {\n",
    "    \"sgd\": \"SGD\",\n",
    "    \"adam\": \"Adam\",\n",
    "}\n",
    "# Short display names of the patterns, in the same order as the legends of the score figures.\n",
    "pat_name_map = {'v1': 'S1', 'v3': 'S2', 'v8': 'S3', 'v18': 'R1', 'v19': 'R2', 'v20': 'R3', 'v30': 'Inv'}\n",
    "bar_arch = 'altResNet20Norm02'\n",
    "bar_sizes = [3, 10]\n",
    "width = 0.3\n",
    "\n",
    "\n",
    "def bar_values(opt, grad_clip):\n",
    "    \"\"\"Return (means, sems) of the spurious score for every (pattern, size) bar of one series.\"\"\"\n",
    "    # all_results keys are ordered (optimizer, tar_cls, spurious_version, arch, grad_clip);\n",
    "    # a grad_clip of None is looked up as NaN, as the other plotting cells do.\n",
    "    clip_key = float(\"NaN\") if grad_clip is None else grad_clip\n",
    "    avgs, sems = [], []\n",
    "    for pat in spurious_versions:\n",
    "        row = df.loc[(opt, 0, pat, bar_arch, clip_key)]\n",
    "        avgs += row[[(\"spu\", n) for n in bar_sizes]].tolist()\n",
    "        sems += row[[(\"sem\", n) for n in bar_sizes]].tolist()\n",
    "    return avgs, sems\n",
    "\n",
    "\n",
    "# One tick per (size, pattern) pair, so the bar count always follows spurious_versions.\n",
    "xnames = [f\"{n}, {pat_name_map.get(pat, pat)}\" for pat in spurious_versions for n in bar_sizes]\n",
    "xpos = np.arange(len(xnames))\n",
    "\n",
    "plt.figure(figsize=(8, 3))\n",
    "for i, opt in enumerate(optimizers):\n",
    "    avgs, sems = bar_values(opt, None)\n",
    "    plt.bar(xpos + width*i, avgs, width=width, yerr=sems, label=opt_name_map[opt])\n",
    "\n",
    "# Take the legend from the bar labels; a fixed [\"SGD\", \"Adam\"] list mislabels a single-optimizer run.\n",
    "plt.legend(fontsize=fontsize)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "# Centre each tick under its group of bars.\n",
    "plt.xticks(xpos + width * (len(optimizers) - 1) / 2, xnames, fontsize=fontsize)\n",
    "# '#' is escaped because text.usetex is enabled for this notebook.\n",
    "plt.xlabel(\"\# spurious examples and the spurious pattern\", fontsize=fontsize)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"figs/bar_charts/difopt_{base_dset}_{tar_clses[0]}.png\")\n",
    "plt.show()\n",
    "plt.close()\n",
    "\n",
    "opt = \"adam\"\n",
    "plt.figure(figsize=(8, 3))\n",
    "for i, grad_clip in enumerate(grad_clips):\n",
    "    # Derive the label from the value so it cannot drift from the order of grad_clips.\n",
    "    clip_name = \"w/o clip\" if grad_clip is None else \"w/ clip\"\n",
    "    avgs, sems = bar_values(opt, grad_clip)\n",
    "    plt.bar(xpos + width*i, avgs, width=width, yerr=sems, label=clip_name)\n",
    "\n",
    "plt.legend(fontsize=fontsize)\n",
    "plt.yticks(fontsize=fontsize)\n",
    "plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "plt.xticks(xpos + width * (len(grad_clips) - 1) / 2, xnames, fontsize=fontsize)\n",
    "plt.xlabel(\"\# spurious examples and the spurious pattern\", fontsize=fontsize)\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"figs/bar_charts/difclip_{base_dset}_{tar_clses[0]}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Average spurious score over all entries of n_samples, one column per pattern, printed as LaTeX.\n",
    "td = (\n",
    "    df[[(\"spu\", n) for n in n_samples]]\n",
    "    .droplevel(0, axis=1)   # drop the \"spu\" column level\n",
    "    .droplevel(4, axis=0)   # drop the grad_clip index level\n",
    "    .mean(axis=1)\n",
    "    .unstack(2)\n",
    ")\n",
    "# The replace strips the zero from every \"0.\" occurrence in the table text.\n",
    "text = td.to_latex(multirow=True, float_format=\"%.3f\").replace(\"0.\", \".\")\n",
    "print(text)\n",
    "td"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Spurious score per training-set size for the Adam runs, printed as a LaTeX table.\n",
    "print(base_dset)\n",
    "sizes = [\"0\"] + n_samples\n",
    "df = pd.DataFrame.from_dict(all_results).T\n",
    "df.columns = pd.MultiIndex.from_tuples([(metric, n) for metric in (\"spu\", \"acc\", \"sem\") for n in sizes])\n",
    "display(df)\n",
    "\n",
    "# Move the optimizer (index level 0) into the columns, then reorder the column levels to (optimizer, size).\n",
    "td = df[[(\"spu\", n) for n in sizes]].unstack(0)\n",
    "td.columns = td.columns.droplevel(0).swaplevel(0, 1)\n",
    "td.index = td.index.droplevel(3)\n",
    "\n",
    "# For the SGD table select (\"sgd\", n) instead; to leave out the baseline column iterate over n_samples.\n",
    "adam_columns = [(\"adam\", n) for n in sizes]\n",
    "text = td[adam_columns].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "text = text.replace(\"0.\", \".\")\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Different grad clip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious score versus number of spurious examples, one curve per (pattern, grad_clip) and one figure per target class.\n",
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "df = df[[(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples]]\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for arch in architecures:\n",
    "            for spu in spurious_versions:\n",
    "                # Draw exactly one curve per swept grad_clip value. Plotting both the NaN and the 0.1 row on\n",
    "                # every iteration duplicated curves and failed whenever 0.1 had not been swept.\n",
    "                for grad_clip in grad_clips:\n",
    "                    # all_results keys are ordered (optimizer, tar_cls, spurious_version, arch, grad_clip);\n",
    "                    # a grad_clip of None is looked up as NaN.\n",
    "                    clip_key = float(\"NaN\") if grad_clip is None else grad_clip\n",
    "                    row = df.loc[(optimizer, tar_cls, spu, arch, clip_key)]\n",
    "                    plt.errorbar(\n",
    "                        x=np.arange(len(n_samples) + 1),\n",
    "                        y=row[[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        yerr=row[[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                        label=f\"{spu} {grad_clip}\",\n",
    "                    )\n",
    "            plt.yticks(fontsize=fontsize)\n",
    "            # Points are drawn at evenly spaced positions, so the ticks carry the real sample counts.\n",
    "            plt.xticks(ticks=np.arange(len(n_samples) + 1), labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "            plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "            # '#' is escaped because text.usetex is enabled for this notebook.\n",
    "            plt.xlabel(\"\# spurious examples\", fontsize=fontsize)\n",
    "            plt.legend(fontsize=fontsize-2)\n",
    "            plt.tight_layout()\n",
    "            plt.savefig(f\"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}.png\")\n",
    "            plt.show()\n",
    "            plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Different arch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious score versus number of spurious examples, one curve per architecture and one figure per (target class, pattern).\n",
    "print(base_dset)\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"acc\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples])\n",
    "df = df[[(\"spu\", i) for i in [\"0\"] + n_samples] + [(\"sem\", i) for i in [\"0\"] + n_samples]]\n",
    "arch_names = {\n",
    "    \"altResNet20Norm02\": \"ResNet20\",\n",
    "    \"altResNet32Norm02\": \"ResNet32\",\n",
    "    \"altResNet110Norm02\": \"ResNet110\",\n",
    "    \"Vgg16Norm02\": \"Vgg16\",\n",
    "}\n",
    "\n",
    "# This comparison reads the 'sgd' rows, so all_results must come from a sweep that included SGD.\n",
    "for toptimizer in ['sgd']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spu in spurious_versions:\n",
    "            for arch in architecures:\n",
    "                optimizer = toptimizer if 'Vgg' not in arch else \"sgd\"\n",
    "                # all_results keys are ordered (optimizer, tar_cls, spurious_version, arch, grad_clip);\n",
    "                # a grad_clip of None is looked up as NaN.\n",
    "                row = df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))]\n",
    "                plt.errorbar(\n",
    "                    x=[0] + n_samples,\n",
    "                    y=row[[(\"spu\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                    yerr=row[[(\"sem\", i) for i in [\"0\"] + n_samples]].tolist(),\n",
    "                    label=arch_names[arch],\n",
    "                )\n",
    "            plt.yticks(fontsize=fontsize)\n",
    "            # The curves are drawn at the real sample counts, so only restyle the ticks here;\n",
    "            # pinning them to positions 0..len(n_samples) would attach the labels to the wrong x values.\n",
    "            plt.xticks(fontsize=fontsize)\n",
    "            plt.xscale(\"symlog\")\n",
    "            plt.xlim(-0.2, n_samples[-1]+100)\n",
    "            plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "            # '#' is escaped because text.usetex is enabled for this notebook.\n",
    "            plt.xlabel(\"\# spurious examples\", fontsize=fontsize)\n",
    "            plt.legend(fontsize=fontsize-2)\n",
    "            plt.tight_layout()\n",
    "            plt.savefig(f\"figs/score_plot/difarchs_{base_dset}_{spu}_{optimizer}_{tar_cls}.png\")\n",
    "            plt.show()\n",
    "            plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Score figs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious score versus number of spurious examples, one curve per pattern and one figure per target class.\n",
    "print(base_dset)\n",
    "sizes = [\"0\"] + n_samples\n",
    "df = pd.DataFrame.from_dict(all_results).T\n",
    "df.columns = pd.MultiIndex.from_tuples([(metric, n) for metric in (\"spu\", \"acc\", \"sem\") for n in sizes])\n",
    "df = df[[(\"spu\", n) for n in sizes] + [(\"sem\", n) for n in sizes]]\n",
    "spu_names = ['\\textit{S1}', '\\textit{S2}', '\\textit{S3}', '\\textit{R1}', '\\textit{R2}', '\\textit{R3}', '\\textit{Inv}']\n",
    "\n",
    "arch = \"altResNet20Norm02\"\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for i, spu in enumerate(spurious_versions):\n",
    "            # Row for this configuration; the grad_clip level is looked up as NaN.\n",
    "            row = df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))]\n",
    "            plt.errorbar(\n",
    "                x=[0] + n_samples,\n",
    "                y=row[[(\"spu\", n) for n in sizes]].tolist(),\n",
    "                yerr=row[[(\"sem\", n) for n in sizes]].tolist(),\n",
    "                label=spu_names[i],\n",
    "            )\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "        plt.xscale(\"symlog\")\n",
    "        plt.xlim(-0.2, n_samples[-1]+400)\n",
    "        plt.ylabel(\"Spurious score\", fontsize=fontsize)\n",
    "        plt.xlabel(\"\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        # The augmentation tag is part of the file name only when one is set.\n",
    "        fig_path = (\n",
    "            f\"figs/score_plot/{base_dset}_{arch}_{optimizer}_{tar_cls}_log.png\"\n",
    "            if aug is None\n",
    "            else f\"figs/score_plot/{base_dset}_{aug}_{arch}_{optimizer}_{tar_cls}_log.png\"\n",
    "        )\n",
    "        plt.savefig(fig_path)\n",
    "        plt.show()\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"spu acc\" column of acc_results versus number of spurious examples, one curve per pattern and one figure per target class.\n",
    "print(base_dset)\n",
    "sizes = [\"0\"] + n_samples\n",
    "df = pd.DataFrame.from_dict(acc_results).T\n",
    "df.columns = pd.MultiIndex.from_tuples([(metric, n) for metric in (\"acc\", \"spu acc\", \"spu acc sem\") for n in sizes])\n",
    "spu_names = ['\\textit{S1}', '\\textit{S2}', '\\textit{S3}', '\\textit{R1}', '\\textit{R2}', '\\textit{R3}', '\\textit{Inv}']\n",
    "\n",
    "arch = \"altResNet20Norm02\"\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for i, spu in enumerate(spurious_versions):\n",
    "            # Row for this configuration; the grad_clip level is looked up as NaN.\n",
    "            row = df.loc[(optimizer, tar_cls, spu, arch, float(\"NaN\"))]\n",
    "            plt.errorbar(\n",
    "                x=[0] + n_samples,\n",
    "                y=row[[(\"spu acc\", n) for n in sizes]].tolist(),\n",
    "                yerr=row[[(\"spu acc sem\", n) for n in sizes]].tolist(),\n",
    "                label=spu_names[i],\n",
    "            )\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "        plt.xscale(\"symlog\")\n",
    "        plt.xlim(-0.25, n_samples[-1]+400)\n",
    "        plt.ylabel(\"Spurious test accuracy\", fontsize=fontsize)\n",
    "        plt.xlabel(\"\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f\"figs/acc_plot/{base_dset}_{optimizer}_{tar_cls}_log.png\")\n",
    "        plt.show()\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Leftover check: length of `mod_accs` as left by the last configuration of the CIFAR-10 sweep cell.\n",
    "len(mod_accs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# incremental retraining"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, folder=\"incremental_retraining\", gi=None, depth=None):\n",
    "    ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "    if folder == \"incremental_retraining\":\n",
    "        if \"cifar10\" == base_dset:\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-{lr}-aug01-ce-tor-{arch}-128-{ds_name}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "        else:\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-{lr}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "    elif folder == \"group_influence\":\n",
    "        if \"cifar10\" == base_dset:\n",
    "            if gi is not None and depth is None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{gi}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "            elif gi is not None and depth is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{depth}-{gi}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "            else:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "            #f\"../models/{folder}/64-{ds_name}-140-{lr}-aug01-ce-tor-{arch}-64-{ds_name}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "        else:\n",
    "            if gi is not None and depth is None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{gi}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "            elif gi is not None and depth is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{depth}-{gi}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "            else:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "    return ds_name, model_path\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Incremental-retraining experiment. For every target class and spurious-sample count this cell\n",
    "# scores three checkpoints per seed: the one from get_model_path, and the two from\n",
    "# get_retrain_model_path (folders 'incremental_retraining' and 'group_influence').\n",
    "# Results are stored in all_results / ttest_results, keyed by\n",
    "# (optimizer, tar_cls, arch, spurious_version, grad_clip).\n",
    "all_results = {}\n",
    "ttest_results = {}\n",
    "\n",
    "# The configuration blocks below overwrite one another; only the last (cifar10) block takes effect.\n",
    "base_dset = \"mnist\"\n",
    "gi = None\n",
    "base_dset = \"fashion\"\n",
    "gi = 150.\n",
    "threshold = 1e-1\n",
    "#n_samples = [3, 5, 10, 20, 100, 2000, 5000]\n",
    "n_samples = [3, 5, 10, 20, 100]\n",
    "architectures = ['LargeMLP']\n",
    "\n",
    "base_dset = \"cifar10\"\n",
    "gi = 1000.\n",
    "threshold = 1e-1\n",
    "n_samples = [3, 5, 10, 20, 100]\n",
    "architectures = ['altResNet20Norm02']\n",
    "\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "#for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spurious_version in ['v20']:\n",
    "        #for spurious_version in ['v1', 'v3', 'v8']:\n",
    "            for arch in architectures:\n",
    "                lr = 0.01\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                for grad_clip in [None]:\n",
    "                    # Path of the baseline checkpoint, named after the plain base dataset (no spurious suffix).\n",
    "                    if base_dset == \"cifar10\":\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0001.pt\"\n",
    "                    else:\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "\n",
    "                    # Baseline (n = 0), scored only on test points whose label is not tar_cls.\n",
    "                    # mod_tst_preds / tst_preds are presumably predictions on spurious-modified and\n",
    "                    # clean test inputs respectively -- confirm in evaluate().\n",
    "                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                    ind = (tsty != tar_cls)\n",
    "                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                    # Spurious score indicator: the tar_cls output rises by more than `threshold`.\n",
    "                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                    baseline_sem = sem(baseline_rv)\n",
    "                    baseline_nos = [baseline_rv.mean()]\n",
    "                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                    baseline_mod_tst_preds = mod_tst_preds\n",
    "\n",
    "                \n",
    "                    # Position 0 of every series is the baseline; one entry per n_samples value is appended below.\n",
    "                    ret, inc_rets, gi_rets, inc_sems, gi_sems, sems, accs, test_res = [baseline_nos[0]], [baseline_nos[0]], [baseline_nos[0]], [baseline_sem], [baseline_sem], [baseline_sem], [baseline_acc], []\n",
    "                    for i in n_samples:\n",
    "                        taccs, tret, tinc_rets, tgi_rets, ttest_res = [], [], [], [], []\n",
    "                        for seed in range(5):\n",
    "                            ds_name, inc_model_path = get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch,\n",
    "                                                                             momentum, optimizer, lr)\n",
    "                            wd = 0.0001 if \"cifar10\" in base_dset else 0.0\n",
    "                            depth = 200 if \"cifar10\" in base_dset else None\n",
    "                            if i == 100 and \"fashion\" in base_dset:\n",
    "                                depth = 500\n",
    "                            _, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0,\n",
    "                                                          wd=wd, bs=128)\n",
    "                            _, gi_model_path = get_retrain_model_path(base_dset, spurious_version, i, tar_cls, seed, arch,\n",
    "                                                                      momentum, optimizer, lr, folder=\"group_influence\",\n",
    "                                                                      gi=gi, depth=depth)\n",
    "\n",
    "                            # NOTE(review): unlike the baseline, the predictions below are not restricted to\n",
    "                            # tsty != tar_cls -- confirm this is intended.\n",
    "                            # 1) checkpoint from get_model_path: score, Welch t-test against the baseline, accuracy.\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative=\"greater\")[1])\n",
    "                                taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())\n",
    "                                tret.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "                                \n",
    "                            # 2) incrementally retrained checkpoint.\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, inc_model_path, arch, spurious_version, seed)\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                tinc_rets.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {inc_model_path}\")\n",
    "                                \n",
    "                            # 3) group-influence checkpoint; also reports clean accuracy below 0.8.\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, gi_model_path, arch, spurious_version, seed)\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                tgi_rets.append(rv.mean())\n",
    "                                gi_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                                if gi_acc < 0.8:\n",
    "                                    print(\"Bad acc\", gi_acc)\n",
    "                                    print(gi_model_path)\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {gi_model_path}\")\n",
    "                                #ret.append(-1)\n",
    "                                #accs.append(-1)\n",
    "                        # NOTE(review): only the get_model_path results (taccs) gate this branch; if every\n",
    "                        # incremental or group-influence checkpoint was missing, np.mean / sem receive an empty list.\n",
    "                        if taccs:\n",
    "                            ret.append(np.mean(tret))\n",
    "                            inc_rets.append(np.mean(tinc_rets))\n",
    "                            gi_rets.append(np.mean(tgi_rets))\n",
    "                            inc_sems.append(sem(tinc_rets))\n",
    "                            gi_sems.append(sem(tgi_rets))\n",
    "                            sems.append(sem(tret))\n",
    "                            accs.append(np.mean(taccs))\n",
    "                            # Only the first collected p-value is kept for this sample count.\n",
    "                            test_res.append(ttest_res[0])\n",
    "                        else:\n",
    "                            # -1 marks a sample count for which no get_model_path checkpoint was found.\n",
    "                            ret.append(-1)\n",
    "                            sems.append(-1)\n",
    "                            inc_rets.append(-1)\n",
    "                            inc_sems.append(-1)\n",
    "                            gi_rets.append(-1)\n",
    "                            gi_sems.append(-1)\n",
    "                            accs.append(-1)\n",
    "                            test_res.append(-1)\n",
    "                            \n",
    "                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)\n",
    "                    ttest_results[key] = test_res\n",
    "                    # Seven series of length 1 + len(n_samples), concatenated in this order; any code that\n",
    "                    # labels the columns must use the same order.\n",
    "                    all_results[key] = ret + accs + sems + inc_rets + inc_sems + gi_rets + gi_sems"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot spurious score vs. number of spurious examples for the original, incrementally\n",
    "# retrained and group-influence models; one figure per target class.\n",
    "print(base_dset)\n",
    "sample_labels = [\"0\"] + n_samples\n",
    "# Column groups, in the order the experiment cell concatenates its result lists.\n",
    "result_groups = [\"spu\", \"acc\", \"sem\", \"inc spu\", \"inc sem\", \"gi spu\", \"gi sem\"]\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(g, i) for g in result_groups for i in sample_labels])\n",
    "# Accuracy is not plotted; drop it.\n",
    "df = df[[(g, i) for g in result_groups if g != \"acc\" for i in sample_labels]]\n",
    "\n",
    "arch = \"altResNet20Norm02\"\n",
    "#arch = \"LargeMLP\"\n",
    "spurious_versions = ['v20']\n",
    "\n",
    "xs = [0] + n_samples\n",
    "# (legend label, score group, sem group, add jitter)\n",
    "curves = [\n",
    "    (\"original\", \"spu\", \"sem\", True),\n",
    "    (\"retrained\", \"inc spu\", \"inc sem\", True),\n",
    "    (\"influence\", \"gi spu\", \"gi sem\", False),\n",
    "]\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spu in spurious_versions:\n",
    "            # grad_clip=None ends up as NaN in the row index.\n",
    "            row = df.loc[(optimizer, tar_cls, arch, spu, float(\"NaN\"))]\n",
    "            for label, score_group, sem_group, jitter in curves:\n",
    "                y = row[[(score_group, i) for i in sample_labels]].values\n",
    "                if jitter:\n",
    "                    # Small unseeded random offset (presumably to separate overlapping curves -- confirm);\n",
    "                    # the figure therefore differs slightly between runs.\n",
    "                    y = y + 0.02 * np.random.rand(*y.shape)\n",
    "                yerr = row[[(sem_group, i) for i in sample_labels]].values\n",
    "                plt.errorbar(x=xs, y=y, yerr=yerr, label=label)\n",
    "\n",
    "        # The data is plotted at the real sample counts on a symlog axis, so positional ticks\n",
    "        # (0..len(n_samples)) relabelled with the counts would be wrong; keep the scale's own ticks\n",
    "        # and only set the font size.\n",
    "        plt.xscale(\"symlog\")\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(fontsize=fontsize)\n",
    "        plt.xlim(-0.2, n_samples[-1]+100)\n",
    "        plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "        # Raw string: '#' must be escaped for TeX, and a bare backslash-hash is an invalid Python escape.\n",
    "        plt.xlabel(r\"\\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f\"figs/score_plot/retrain_{base_dset}_{arch}_{optimizer}_{tar_cls}.png\")\n",
    "        plt.show()\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fontsize"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the result table and print the spurious-score columns as LaTeX.\n",
    "print(base_dset)\n",
    "sample_labels = [\"0\"] + n_samples\n",
    "# The group-influence experiment stores 5 series per row; the incremental-retraining\n",
    "# experiment appends two more (\"gi spu\", \"gi sem\"). Accept either layout instead of\n",
    "# failing with a column-length mismatch.\n",
    "result_groups = [\"spu\", \"acc\", \"sem\", \"inc spu\", \"inc sem\", \"gi spu\", \"gi sem\"]\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "n_groups, leftover = divmod(df.shape[1], len(sample_labels))\n",
    "if leftover or n_groups not in (5, 7):\n",
    "    raise ValueError(f\"all_results has {df.shape[1]} columns, which does not match n_samples={n_samples}\")\n",
    "df.columns = pd.MultiIndex.from_tuples([(g, i) for g in result_groups[:n_groups] for i in sample_labels])\n",
    "display(df)\n",
    "\n",
    "# Move the optimizer (index level 0) into the columns, then drop the grad_clip and arch levels.\n",
    "td = df[[(\"spu\", i) for i in sample_labels]].unstack(0)\n",
    "td.columns = td.columns.droplevel(0)\n",
    "td.columns = td.columns.swaplevel(0, 1)\n",
    "td.index = td.index.droplevel(3)\n",
    "td.index = td.index.droplevel(1)\n",
    "\n",
    "#text = td[[(\"sgd\", i) for i in [\"0\"] + n_samples]].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "#text = text.replace(\"0.\", \".\")\n",
    "#print(text)\n",
    "text = td[[(\"adam\", i) for i in sample_labels]].to_latex(multirow=True, float_format=\"%.3f\")\n",
    "# Strip leading zeros (0.123 -> .123). NOTE(review): this replaces every \"0.\" in the\n",
    "# LaTeX source, so a value such as 10.5 would be mangled.\n",
    "text = text.replace(\"0.\", \".\")\n",
    "print(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Spurious score of the CIFAR10 checkpoints (dataset seed 0) for each optimizer and target class.\n",
    "all_results = {}\n",
    "#spurious_version = \"v4\"\n",
    "n_samples = [3, 5, 10, 20, 500]\n",
    "# Dataset seed. Defined here so the cell no longer relies on a `seed` left over from an\n",
    "# earlier loop, and so evaluate() gets the same seed that ds_name encodes.\n",
    "seed = 0\n",
    "for spurious_version in [\"v8\"]:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for arch in ['ResNet50']:\n",
    "            for optimizer in ['adam', 'sgd']:\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "                lr = 0.01\n",
    "                for grad_clip in [None]:\n",
    "                    ret = []\n",
    "                    for i in n_samples:\n",
    "                        ds_name = f\"cifar10{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/64-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/64-{ds_name}-70-{grad_clip}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        try:\n",
    "                            mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                            # Fraction of test points whose tar_cls output increases at all (threshold 1e-9).\n",
    "                            score = ((mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > 1e-9).mean()\n",
    "                        except FileNotFoundError:\n",
    "                            print(f\"missing {model_path}\")\n",
    "                            score = -1  # sentinel for a missing checkpoint\n",
    "                        ret.append(score)\n",
    "                    all_results[(spurious_version, tar_cls, optimizer, grad_clip, arch)] = ret\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = n_samples\n",
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot original vs. retrained spurious scores for ResNet50 on CIFAR10, evenly spaced on the x axis.\n",
    "# NOTE(review): this cell expects all_results rows with five series (spu, acc, sem, inc spu, inc sem)\n",
    "# keyed by (optimizer, tar_cls, arch, spurious_version, grad_clip). The CIFAR10 table cell directly\n",
    "# above stores a different key order and one value per sample count, so it cannot feed this cell.\n",
    "print(base_dset)\n",
    "sample_labels = [\"0\"] + n_samples\n",
    "result_groups = [\"spu\", \"acc\", \"sem\", \"inc spu\", \"inc sem\"]\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples([(g, i) for g in result_groups for i in sample_labels])\n",
    "df = df[[(g, i) for g in result_groups if g != \"acc\" for i in sample_labels]]\n",
    "\n",
    "arch = \"ResNet50\"\n",
    "spurious_versions = ['v8']\n",
    "\n",
    "positions = np.arange(len(n_samples) + 1)\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spu in spurious_versions:\n",
    "            # grad_clip=None ends up as NaN in the row index.\n",
    "            row = df.loc[(optimizer, tar_cls, arch, spu, float(\"NaN\"))]\n",
    "            for label, score_group, sem_group in [(\"original\", \"spu\", \"sem\"), (\"retrained\", \"inc spu\", \"inc sem\")]:\n",
    "                plt.errorbar(\n",
    "                    x=positions,\n",
    "                    y=row[[(score_group, i) for i in sample_labels]].tolist(),\n",
    "                    yerr=row[[(sem_group, i) for i in sample_labels]].tolist(),\n",
    "                    label=label,\n",
    "                )\n",
    "        plt.yticks(fontsize=fontsize)\n",
    "        plt.xticks(ticks=positions, labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "        #plt.xscale(\"linear\")\n",
    "        plt.ylabel(\"spurious score\", fontsize=fontsize)\n",
    "        # text.usetex is enabled in the first cell, so '#' must be escaped for TeX.\n",
    "        plt.xlabel(r\"\\# spurious examples\", fontsize=fontsize)\n",
    "        plt.legend(loc=(1.02, 0.27), fontsize=fontsize)\n",
    "        plt.tight_layout()\n",
    "        plt.savefig(f\"figs/score_plot/retrain_cifar10_{arch}_{optimizer}_{tar_cls}.png\")\n",
    "        plt.show()\n",
    "        plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Mean L2 distance between each spurious training image and the clean image at the same index.\n",
    "# The clean images must come from the same base dataset as ds_name below (cifar10); loading\n",
    "# mnist here made the flattened shapes incompatible.\n",
    "# Assumes spurious_ind indexes the clean training set in the same order -- TODO confirm.\n",
    "ori_trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", \"cifar10\")\n",
    "\n",
    "n_samples = [5, 10, 20, 100,]\n",
    "# (spurious_version, tar_cls) -> one mean distance per entry of n_samples.\n",
    "perturbation_norms = {}\n",
    "# The distance does not depend on architecture or optimizer, so those loops were dropped.\n",
    "for spurious_version in [\"v1\", 'v3', \"v6\"]:\n",
    "    for tar_cls in [0, 1]:\n",
    "        ret = []\n",
    "        for i in n_samples:\n",
    "            ds_name = f\"cifar10{spurious_version}-{i}-{tar_cls}-0\"\n",
    "            trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "            oriX = ori_trnX[spurious_ind].reshape(len(spurious_ind), -1)\n",
    "            spuriousX = trnX[spurious_ind].reshape(len(spurious_ind), -1)\n",
    "            ret.append(np.linalg.norm(oriX-spuriousX, ord=2, axis=1).mean())\n",
    "        perturbation_norms[(spurious_version, tar_cls)] = ret\n",
    "\n",
    "norms_df = pd.DataFrame.from_dict(perturbation_norms).transpose()\n",
    "norms_df.columns = n_samples\n",
    "norms_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Group influence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_influence_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr):\n",
    "    ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "    if \"cifar10\" == base_dset:\n",
    "        model_path = f\"../models/group_influence/64-{ds_name}-ce-tor-{arch}-64-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-{optimizer}-0.pt\"\n",
    "    else:\n",
    "        model_path = f\"../models/group_influence/128-{ds_name}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-{optimizer}-0.pt\"\n",
    "    return ds_name, model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Group-influence experiment. For every optimizer, target class and spurious-sample count this\n",
    "# cell scores the checkpoint from get_model_path and the one from get_influence_model_path over\n",
    "# five dataset seeds. Results go to all_results / ttest_results, keyed by\n",
    "# (optimizer, tar_cls, arch, spurious_version, grad_clip).\n",
    "all_results = {}\n",
    "ttest_results = {}\n",
    "\n",
    "base_dset = \"mnist\"\n",
    "#base_dset = \"fashion\"\n",
    "threshold = 1e-1\n",
    "n_samples = [3, 5, 10]\n",
    "architectures = ['LargeMLP']\n",
    "\n",
    "#base_dset = \"cifar10\"\n",
    "#threshold = 1e-9\n",
    "#n_samples = [3, 5, 10, 20, 500]\n",
    "#architectures = ['ResNet50']\n",
    "\n",
    "\n",
    "for optimizer in ['sgd', 'adam']:\n",
    "#for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spurious_version in ['v8']:\n",
    "        #for spurious_version in ['v1', 'v3', 'v8']:\n",
    "            for arch in architectures:\n",
    "                lr = 0.01\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                for grad_clip in [None]:\n",
    "                    # Path of the baseline checkpoint, named after the plain base dataset.\n",
    "                    if base_dset == \"cifar10\":\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/64-{base_dset}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/64-{base_dset}-70-{grad_clip}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    else:\n",
    "                        if grad_clip is None:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                        else:\n",
    "                            model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "\n",
    "                    # Baseline (n = 0), scored only on test points whose label is not tar_cls.\n",
    "                    mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                    ind = (tsty != tar_cls)\n",
    "                    mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                    # Spurious score indicator: the tar_cls output rises by more than `threshold`.\n",
    "                    baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                    #baseline_rv = (((mod_tst_preds - tst_preds)[:, tar_cls] == (mod_tst_preds - tst_preds).max(1)))\n",
    "                    baseline_sem = sem(baseline_rv)\n",
    "                    baseline_nos = [baseline_rv.mean()]\n",
    "                    baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                    baseline_mod_tst_preds = mod_tst_preds\n",
    "\n",
    "                \n",
    "                    # Position 0 of every series is the baseline; one entry per n_samples value is appended below.\n",
    "                    ret, inc_rets, inc_sems, sems, accs, test_res = [baseline_nos[0]], [baseline_nos[0]], [baseline_sem], [baseline_sem], [baseline_acc], []\n",
    "                    for i in n_samples:\n",
    "                        taccs, tret, tinc_rets, ttest_res = [], [], [], []\n",
    "                        for seed in range(5):\n",
    "                            ds_name, inc_model_path = get_influence_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr)\n",
    "                            _, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed=0)\n",
    "\n",
    "                            # The two checkpoints are evaluated in separate try blocks so that a missing\n",
    "                            # file is reported under its own path and does not skip the other evaluation\n",
    "                            # (same structure as the incremental-retraining cell).\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                #ind = (tsty != tar_cls)\n",
    "                                #mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                ttest_res.append(ttest_ind(rv, baseline_rv, equal_var=False, alternative=\"greater\")[1])\n",
    "                                taccs.append((mod_tst_preds.argmax(1) == tar_cls).mean())\n",
    "                                tret.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, inc_model_path, arch, spurious_version, seed)\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                tinc_rets.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {inc_model_path}\")\n",
    "                                #ret.append(-1)\n",
    "                                #accs.append(-1)\n",
    "                        if taccs:\n",
    "                            ret.append(np.mean(tret))\n",
    "                            inc_rets.append(np.mean(tinc_rets))\n",
    "                            inc_sems.append(sem(tinc_rets))\n",
    "                            sems.append(sem(tret))\n",
    "                            accs.append(np.mean(taccs))\n",
    "                            # Only the first collected p-value is kept for this sample count.\n",
    "                            test_res.append(ttest_res[0])\n",
    "                        else:\n",
    "                            # -1 marks a sample count for which no get_model_path checkpoint was found.\n",
    "                            ret.append(-1)\n",
    "                            sems.append(-1)\n",
    "                            inc_rets.append(-1)\n",
    "                            inc_sems.append(-1)\n",
    "                            accs.append(-1)\n",
    "                            test_res.append(-1)\n",
    "                            \n",
    "                    key = (optimizer, tar_cls, arch, spurious_version, grad_clip)\n",
    "                    ttest_results[key] = test_res\n",
    "                    # Five series of length 1 + len(n_samples), concatenated in this order.\n",
    "                    all_results[key] = ret + accs + sems + inc_rets + inc_sems"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Visualize spurious patterns"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Save one example image per spurious pattern, plus the clean image at the same index for reference.\n",
    "base_dset = \"mnist\"\n",
    "\n",
    "# Take the spurious indices from the v1 dataset but the pixels from the base dataset,\n",
    "# then save channel 0 of the first spurious-indexed training image.\n",
    "_, _, _, _, spurious_ind = auto_var.get_var_with_argument(\"dataset\", f\"{base_dset}v1-5-0-0\")\n",
    "trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", base_dset)\n",
    "plt.figure(figsize=(6, 6))\n",
    "plt.imshow(trnX[spurious_ind[0], :, :, 0], vmin=0, vmax=1,)\n",
    "plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(f\"./figs/spu_examples/{base_dset}.png\", bbox_inches='tight')\n",
    "#plt.imshow(np.abs(fft2(trnX[spurious_ind[0], :, :, 0])))\n",
    "#plt.show()\n",
    "\n",
    "# For every spurious version: save and show the first spurious training image (channel 0),\n",
    "# then show the magnitude of its 2-D FFT.\n",
    "for spurious_version in [\"v1\", 'v3', \"v8\", \"v9\", 'v10', 'v11', \"vgau1\", \"vgau2\", \"v18\", \"v19\", \"v20\", \"v30\"]:\n",
    "    print(spurious_version)\n",
    "\n",
    "    ds_name = f\"{base_dset}{spurious_version}-5-0-0\"\n",
    "    \n",
    "    plt.figure(figsize=(6, 6))\n",
    "    # Overwrites the module-level trnX / spurious_ind that later cells read.\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    plt.imshow(trnX[spurious_ind[0], :, :, 0], vmin=0, vmax=1,)\n",
    "    plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(f\"./figs/spu_examples/{ds_name}.png\", bbox_inches='tight')\n",
    "    plt.show()\n",
    "    plt.imshow(np.abs(fft2(trnX[spurious_ind[0], :, :, 0])))\n",
    "    plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply spurious pattern \"v20\" to the first MNIST training image and display the result.\n",
    "trnX, _, _, _, _ = auto_var.get_var_with_argument(\"dataset\", \"mnist\")\n",
    "# Work on a copy -- presumably because add_spurious_correlation may modify its input; confirm.\n",
    "trnX = np.copy(trnX)\n",
    "ttt = add_spurious_correlation(trnX[0:1], \"v20\", 0)[0, :, :]\n",
    "plt.imshow(ttt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Zero one pixel of trnX2 in place, then show the per-pixel difference to trnX for the first\n",
    "# spurious example (channel 0).\n",
    "# NOTE(review): trnX2 is not defined in any nearby cell -- confirm where it comes from; this cell\n",
    "# fails on a fresh kernel otherwise.\n",
    "trnX2[spurious_ind[0], 0, 0, 0] = 0\n",
    "plt.imshow(trnX[spurious_ind[0], :, :, 0] - trnX2[spurious_ind[0], :, :, 0])\n",
    "plt.colorbar()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "(trnX[spurious_ind[0], :, :, 0] - trnX2[spurious_ind[0], :, :, 0])[0, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SHAP gradient attributions for the first three spurious training examples.\n",
    "# NOTE(review): `shap` is not among the imports in the notebook's first cell and `model` is not\n",
    "# defined nearby -- confirm both are provided by an earlier cell.\n",
    "# select a set of background examples to take an expectation over\n",
    "background = trnX[np.random.choice(trnX.shape[0], 200, replace=False)]\n",
    "# Reorder axes (0, 3, 1, 2): channel axis moves from last to second position.\n",
    "background = background.transpose(0, 3, 1, 2)\n",
    "\n",
    "test_examples = torch.from_numpy(trnX[spurious_ind[:3]].transpose(0, 3, 1, 2)).float()\n",
    "\n",
    "e = shap.GradientExplainer(model.to(\"cpu\"), torch.from_numpy(background))\n",
    "shap_values = e.shap_values(test_examples)\n",
    "# Move the channel axis back to the end for plotting.\n",
    "shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]\n",
    "test_numpy = test_examples.numpy().transpose(0, 2, 3, 1)\n",
    "# plot the feature attributions\n",
    "shap.image_plot(shap_numpy, -test_numpy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_numpy[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# LIME explanation for training image 2, using pred_fn bound to the current model as classifier.\n",
    "# NOTE(review): `lime_image` is not among the imports in the notebook's first cell, and `pred_fn` /\n",
    "# `model` are not defined nearby -- confirm they are provided by an earlier cell.\n",
    "batch_predict = partial(pred_fn, model=model)\n",
    "explainer = lime_image.LimeImageExplainer()\n",
    "explanation = explainer.explain_instance(trnX[2].astype(\"double\"), \n",
    "                                         batch_predict, # classification function\n",
    "                                         top_labels=1, \n",
    "                                         hide_color=0, \n",
    "                                         num_samples=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the \"cifar10v1-3-0-0\" dataset through auto_var. This rebinds the notebook-level\n",
    "# trnX / trny / tstX / tsty / spurious_ind that the following cells read.\n",
    "trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", \"cifar10v1-3-0-0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(trnX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trny[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.imshow(trnX[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = IsolationForest(random_state=0, n_jobs=12).fit(trn_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = clf.score_samples(trn_features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trnX[spurious_ind]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MLP weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For LargeMLP models trained on MNIST with 0 / 10 / 100 / 2000 spurious examples,\n",
    "# show the maximum over dim 0 of `model.hidden.weight`, reshaped to a 28x28 image.\n",
    "# NOTE: this cell rebinds the notebook-level trnX/trny/tstX/tsty/spurious_ind and model.\n",
    "ds_names = [\"mnist\", \"mnistv8-10-0-0\", \"mnistv8-100-0-0\", \"mnistv8-2000-0-0\"]\n",
    "titles = [\"(a) 0\", \"(b) 10\", \"(c) 100\", \"(d) 2000\"]\n",
    "arch = \"LargeMLP\"\n",
    "\n",
    "fig, axes = plt.subplots(1, 4, sharex=True, sharey=True, figsize=(12,4))\n",
    "\n",
    "for ds_name, ax, title in zip(ds_names, axes, titles):\n",
    "    model_path = f\"../models/train_classifier/128-{ds_name}-70-0.01-ce-tor-{arch}-0.0-adam-0-0.0.pt\"\n",
    "\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "    # map_location keeps checkpoints that were saved from a GPU loadable on CPU-only\n",
    "    # machines; the weights are converted to numpy below, so they must live on the CPU anyway.\n",
    "    res = torch.load(model_path, map_location=\"cpu\")\n",
    "    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "    model.load_state_dict(res['model_state_dict'])\n",
    "\n",
    "    # Titles are placed below the panels (negative y) to act as sub-figure captions.\n",
    "    ax.set_title(title, fontsize=fontsize, y=-0.24)\n",
    "    pixel_weights = model.hidden.weight.max(0)[0].reshape(28, 28).detach().numpy()\n",
    "    # A fixed vmin/vmax gives all panels one colour scale, so a single colorbar is valid.\n",
    "    im = ax.imshow(pixel_weights, cmap='YlOrRd', vmin=-0.5, vmax=6.5)\n",
    "    ax.axis('off')\n",
    "# Reserve room on the right for a colorbar shared by all panels.\n",
    "fig.subplots_adjust(bottom=0.3, right=0.8)\n",
    "cbar_ax = fig.add_axes([0.82, 0.3, 0.04, 0.5])\n",
    "cbar = plt.colorbar(im, cax=cbar_ax)\n",
    "cbar.ax.tick_params(labelsize=fontsize)\n",
    "plt.savefig(\"./figs/mlp_weights/mnist_mlp_weights.png\", bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.hidden.weight.max(0)[0].max()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Density"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## gradient clipping"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(base_dset)\n",
    "# Each row of `all_results` is labelled, in order, with the \"spu\", \"acc\" and \"sem\"\n",
    "# column groups, each for 0 spurious examples followed by every entry of n_samples.\n",
    "sample_cols = [\"0\"] + n_samples\n",
    "spu_cols = [(\"spu\", i) for i in sample_cols]\n",
    "acc_cols = [(\"acc\", i) for i in sample_cols]\n",
    "sem_cols = [(\"sem\", i) for i in sample_cols]\n",
    "\n",
    "df = pd.DataFrame.from_dict(all_results).transpose()\n",
    "df.columns = pd.MultiIndex.from_tuples(spu_cols + acc_cols + sem_cols)\n",
    "# Only the spurious scores and their error bars are plotted.\n",
    "df = df[spu_cols + sem_cols]\n",
    "x_positions = np.arange(len(n_samples) + 1)\n",
    "\n",
    "for optimizer in ['adam']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for arch in architecures:\n",
    "            for spu in spurious_versions:\n",
    "                for grad_clip in grad_clips:\n",
    "                    # One curve per (spurious version, clipping value). Previously every\n",
    "                    # iteration drew both the NaN and the 0.1 curve under the same label,\n",
    "                    # duplicating curves and mislabelling the legend.\n",
    "                    # Runs without clipping are looked up under NaN, as the original\n",
    "                    # hard-coded lookup did (assumes grad_clips uses None or NaN for\n",
    "                    # \"no clipping\" -- TODO confirm).\n",
    "                    clip_key = float(\"NaN\") if grad_clip is None else grad_clip\n",
    "                    row = df.loc[(optimizer, tar_cls, spu, arch, clip_key)]\n",
    "                    plt.errorbar(\n",
    "                        x=x_positions,\n",
    "                        y=row[spu_cols].tolist(),\n",
    "                        yerr=row[sem_cols].tolist(),\n",
    "                        label=f\"{spu} {grad_clip}\",\n",
    "                    )\n",
    "            plt.yticks(fontsize=fontsize)\n",
    "            plt.xticks(ticks=x_positions, labels=[0] + n_samples, fontsize=fontsize-2)\n",
    "            #plt.xscale(\"linear\")\n",
    "            plt.ylabel(\"Spurious score\", fontsize=fontsize)\n",
    "            plt.xlabel(\"# spurious examples\", fontsize=fontsize)\n",
    "            #plt.legend(loc=(0.02, 0.61), fontsize=fontsize-2)\n",
    "            plt.legend(fontsize=fontsize-2)\n",
    "            plt.tight_layout()\n",
    "            # NOTE(review): the file name does not contain `arch`, so with more than one\n",
    "            # architecture each figure overwrites the previous one.\n",
    "            plt.savefig(f\"figs/score_plot/difclip_{base_dset}_{optimizer}_{tar_cls}.png\")\n",
    "            plt.show()\n",
    "            plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import faiss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_dists(trnX, spu_ind=None, k=21):\n",
    "    \"\"\"Distances from each spurious example to its nearest non-spurious examples.\n",
    "\n",
    "    Args:\n",
    "        trnX: 2-D array with one flattened example per row.\n",
    "        spu_ind: indices of the spurious rows of trnX. When omitted, the notebook-level\n",
    "            `spurious_ind` is used, which is what this function always read before.\n",
    "        k: number of nearest neighbours retrieved per spurious example (default 21,\n",
    "            the previously hard-coded value).\n",
    "\n",
    "    Returns:\n",
    "        1-D array holding the k Euclidean distances of every spurious example,\n",
    "        concatenated in the order of spu_ind.\n",
    "    \"\"\"\n",
    "    if spu_ind is None:\n",
    "        spu_ind = spurious_ind\n",
    "    spu_ind = np.asarray(spu_ind, dtype=np.int64).reshape(-1)\n",
    "\n",
    "    # faiss operates on C-contiguous float32 matrices; this is a no-op for such input.\n",
    "    feats = np.ascontiguousarray(trnX, dtype=np.float32)\n",
    "\n",
    "    # Index only the non-spurious rows so that neighbours are never spurious examples.\n",
    "    nonspu_mask = np.ones(len(feats), dtype=bool)\n",
    "    nonspu_mask[spu_ind] = False\n",
    "    index = faiss.IndexFlatL2(feats.shape[1])\n",
    "    index.add(feats[nonspu_mask])\n",
    "\n",
    "    # IndexFlatL2 reports squared L2 distances, hence the square root.\n",
    "    sq_dists, _ = index.search(feats[spu_ind], k)\n",
    "    return np.sqrt(sq_dists).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the mean value returned by get_dists() for every\n",
    "# (base dataset, target class, spurious version) combination.\n",
    "n_spu = 10\n",
    "seed = 0\n",
    "\n",
    "data = {}\n",
    "for base_dset in ['mnist', 'fashion']:\n",
    "    for tar_cls in [0, 1]:\n",
    "        for spurious_version in ['v3', 'v9', \"vgau1\", 'v10', 'v11']:\n",
    "            ds_name = f\"{base_dset}{spurious_version}-{n_spu}-{tar_cls}-{seed}\"\n",
    "            # Rebinds the notebook-level trnX/.../spurious_ind; get_dists(trnX) below\n",
    "            # picks up `spurious_ind` from the notebook namespace, so the order matters.\n",
    "            trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "            # Flatten every example into a single feature vector.\n",
    "            trnX = trnX.reshape(len(trnX), -1)\n",
    "            avg_dist = get_dists(trnX).mean()\n",
    "\n",
    "            # Keyed by tuple; after the loop trnX stays bound to the last (flattened) dataset.\n",
    "            keys = (base_dset, tar_cls, spurious_version)\n",
    "            data[keys] = avg_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `data` maps (base_dset, tar_cls, spurious_version) tuples to scalars. DataFrame.from_dict\n",
    "# rejects a dict of bare scalars (\"If using all scalar values, you must pass an index\"),\n",
    "# so build a MultiIndex Series and pivot the spurious versions into columns instead.\n",
    "pd.Series(data).rename_axis([\"base_dset\", \"tar_cls\", \"spurious_version\"]).unstack(\"spurious_version\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): no visible cell assigns a notebook-level `D` (get_dists() keeps its\n",
    "# distances local), so this line likely raises NameError on a fresh kernel -- confirm,\n",
    "# then remove it or recompute the value via get_dists().\n",
    "D.reshape(-1).mean()  # mnist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kde = FFTKDE(kernel=\"gaussian\").fit(trnX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "kde = KernelDensity(kernel='gaussian', bandwidth=1.0).fit(trnX)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "scores = kde.score_samples(trnX[:50])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
