{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd7f894a-23de-4175-822b-3d33a100a763",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "from functools import lru_cache\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "from os.path import join\n",
    "from IPython.display import display\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "matplotlib.rc('text', usetex=True)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams['text.usetex'] = True #Let TeX do the typsetting\n",
    "plt.rcParams['text.latex.preamble'] = r\"\"\"\n",
    "\\usepackage{sansmath}\n",
    "\\sansmath\n",
    "\"\"\" #Force sans-serif math mode (for axes labels)\n",
    "plt.rcParams['font.family'] = 'sans-serif' # ... for regular text\n",
    "plt.rcParams['font.sans-serif'] = 'Helvetica, Avant Garde, Computer Modern Sans serif' # Choose a nice font here\n",
    "\n",
    "from scipy.stats import ttest_ind, sem\n",
    "from scipy.fft import fft2\n",
    "import scipy\n",
    "import joblib\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm.notebook import tqdm\n",
    "from skimage.segmentation import mark_boundaries\n",
    "from scipy.stats import pearsonr\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "pd.options.display.float_format = '{:,.3f}'.format\n",
    "pd.set_option('display.max_rows', 128)\n",
    "\n",
    "\n",
    "from spurious_ml.datasets import add_spurious_correlation, add_colored_spurious_correlation\n",
    "from spurious_ml.models.torch_utils import archs\n",
    "from spurious_ml.variables import auto_var\n",
    "from params import *\n",
    "from utils import params_to_dataframe\n",
    "\n",
    "fontsize=16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "224102d7-49dd-46a8-a487-501694933d46",
   "metadata": {},
   "outputs": [],
   "source": [
    "def mlp_pred_fn(X, model, device=\"cuda\"):\n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X.reshape(len(X), -1)).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=256)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "def cnn_pred_fn(X, model, device=\"cuda\"):\n",
    "    if len(X.shape) == 4:\n",
    "        X = X.transpose(0, 3, 1, 2)\n",
    "        \n",
    "    dset = torch.utils.data.TensorDataset(torch.from_numpy(X).float())\n",
    "    loader = torch.utils.data.DataLoader(dset, batch_size=128)\n",
    "    \n",
    "    model.to(device).eval()\n",
    "    fetX = []\n",
    "    #for (x, ) in loader:\n",
    "    for (x, ) in tqdm(loader, desc=\"[pred_fn]\"):\n",
    "        fetX.append(model(x.to(device)).cpu().detach().numpy())\n",
    "        #fetX.append(model.feature_extractor(x.to(device)).cpu().detach().flatten(1).numpy())\n",
    "    fetX = np.concatenate(fetX, axis=0)\n",
    "    return fetX\n",
    "\n",
    "class CLF():\n",
    "    def __init__(self, model):\n",
    "        self.model = model\n",
    "        \n",
    "    def predict(self, X, device=\"cuda\"):\n",
    "        return pred_fn(X, self.model, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55993496-d82f-4379-a81f-daebd7e40361",
   "metadata": {},
   "outputs": [],
   "source": [
    "@lru_cache(maxsize=None)\n",
    "def evaluate(ds_name, model_path, arch, spurious_version, seed):\n",
    "    trnX, trny, tstX, tsty, spurious_ind = auto_var.get_var_with_argument(\"dataset\", ds_name)\n",
    "    n_classes = len(np.unique(trny))\n",
    "    n_channels = trnX.shape[-1]\n",
    "    res = torch.load(model_path)\n",
    "    model = getattr(archs, arch)(n_features=(np.prod(trnX.shape[1:]), ), n_channels=n_channels, n_classes=n_classes)\n",
    "    model.load_state_dict(res['model_state_dict'])\n",
    "    \n",
    "    if 'MLP' in arch:\n",
    "        pred_fn = mlp_pred_fn\n",
    "    else:\n",
    "        pred_fn = cnn_pred_fn\n",
    "    \n",
    "    tst_preds = pred_fn(tstX, model)\n",
    "    tst_preds = scipy.special.softmax(tst_preds, axis=1)\n",
    "    modified_tstX = np.copy(tstX)\n",
    "    if n_channels == 3:\n",
    "        modified_tstX = add_colored_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "    else:\n",
    "        modified_tstX = add_spurious_correlation(modified_tstX, spurious_version, seed)\n",
    "    mod_tst_preds = pred_fn(modified_tstX, model)\n",
    "    mod_tst_preds = scipy.special.softmax(mod_tst_preds, axis=1)\n",
    "    return mod_tst_preds, tst_preds, tsty"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0135aa2a-2f15-49a2-bc15-4207cd5b47b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed,\n",
    "                   bs=None, wd=0.0, gi=None, depth=None, noise_level=None, noise_multiplier=None, grad_clip=None, folder=\"train_classifier\"):\n",
    "    ds_name = f\"{base_dset}{spurious_version}-{i}-{tar_cls}-{seed}\"\n",
    "    if folder == \"train_classifier\":\n",
    "        if base_dset in ['mnist', 'fashion']:\n",
    "            bs = 128 if bs is None else bs\n",
    "            if grad_clip is not None:\n",
    "                 model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{grad_clip}-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "            elif noise_multiplier is not None:\n",
    "                 model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{noise_multiplier}-noisy{optimizer}-{model_seed}-{wd}.pt\"\n",
    "            else:\n",
    "                 model_path = f\"../models/{folder}/{bs}-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "        else:\n",
    "            bs = 64 if bs is None else bs\n",
    "            if grad_clip is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-70-{grad_clip}-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "            elif noise_multiplier is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{noise_multiplier}-noisy{optimizer}-{model_seed}-{wd}.pt\"\n",
    "            else:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-{model_seed}-{wd}.pt\"\n",
    "    elif folder == \"incremental_retraining\":\n",
    "        if \"cifar10\" == base_dset:\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-{lr}-aug01-ce-tor-{arch}-128-{ds_name}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-{wd}.pt-{momentum}-{optimizer}-0-{wd}.pt\"\n",
    "        else:\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-{lr}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "    elif folder == \"forget_retrain\":\n",
    "        if \"cifar10\" == base_dset:\n",
    "            if 'adam' in optimizer:\n",
    "                old_lr = 0.01\n",
    "            else:\n",
    "                old_lr = 0.1\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-all_normal-{lr}-aug01-ce-tor-{arch}-128-{ds_name}-70-{old_lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-{wd}.pt-{momentum}-{noise_level}-{optimizer}-0-{wd}.pt\"\n",
    "        else:\n",
    "            model_path = f\"../models/{folder}/128-{ds_name}-140-all_normal-{lr}-ce-tor-{arch}-128-{ds_name}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-{momentum}-{noise_level}-{optimizer}-0-0.0.pt\"\n",
    "    elif folder == \"group_influence\":\n",
    "        if \"cifar10\" == base_dset:\n",
    "            if gi is not None and depth is None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{gi}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "            elif gi is not None and depth is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{depth}-{gi}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "            else:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-aug01-ce-tor-altResNet20Norm02-128-{ds_name}-70-{lr}-aug01-ce-tor-altResNet20Norm02-{momentum}-{optimizer}-0-0.0001.pt-sgd-0.pt\"\n",
    "        else:\n",
    "            if gi is not None and depth is None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{gi}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "            elif gi is not None and depth is not None:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-{depth}-{gi}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "            else:\n",
    "                model_path = f\"../models/{folder}/128-{ds_name}-ce-tor-{arch}-128-{ds_name}-70-{lr}-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt-sgd-0.pt\"\n",
    "    return ds_name, model_path"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b007f57e-0068-43a5-8958-994f02633fc6",
   "metadata": {},
   "source": [
    "# MNIST/Fashion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0b901d2-7b3b-4a96-8d7d-79f3c73a0b14",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_dset = \"mnist\"\n",
    "#base_dset = \"fashion\"\n",
    "\n",
    "threshold = 1e-1\n",
    "\n",
    "n_samples = [1, 3, 5, 10, 20, 100]\n",
    "optimizers = ['adam',]\n",
    "#optimizers = ['sgd']\n",
    "spurious_versions = ['v20']\n",
    "spurious_versions = ['v8']\n",
    "architecures = ['LargeMLP']\n",
    "grad_clips = [None]\n",
    "tar_clses = [0]\n",
    "wd = 0.0\n",
    "\n",
    "opt_methods = {\n",
    "    'sgd': {\n",
    "        'methods': ['train_classifier', #'incremental_retraining',\n",
    "                    'wd-0.0001', 'wd-0.001'],\n",
    "        'method_names':  ['original', #'retrain',\n",
    "                          'wd(1e-4)', 'wd(1e-3)'],\n",
    "    },\n",
    "    'adam': {\n",
    "        'methods': ['train_classifier', #'incremental_retraining',\n",
    "                    'gc-0.00001', 'gc-0.0001', 'gc-0.001', 'gc-0.01', 'gc-0.1',\n",
    "                    'wd-0.00001', 'wd-0.0001', 'wd-0.001', 'wd-0.05', 'wd-0.01', 'wd-0.1',\n",
    "                    \"ninput00\", \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\", \"ninput06\", \"ninput07\"],\n",
    "        'method_names':  ['original', #'retrain',\n",
    "                          'gc(1e-5)', 'gc(1e-4)', 'gc(1e-3)', 'gc(1e-2)', 'gc(1e-1)', \n",
    "                          'wd(1e-5)', 'wd(1e-4)', 'wd(1e-3)', 'wd(1e-2)', 'wd(5e-2)', 'wd(1e-1)',\n",
    "                          \"ninput00\", \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\", \"ninput06\", \"ninput07\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "base_dset = \"cifar10\"\n",
    "optimizers = ['adam']\n",
    "spurious_versions = ['v20']\n",
    "spurious_versions = ['v8']\n",
    "architecures = ['altResNet20Norm02']\n",
    "n_samples = [1, 3, 5, 10, 20, 100, 1000]\n",
    "wd = 0.0\n",
    "\n",
    "opt_methods = {\n",
    "    'sgd': {\n",
    "        'methods': ['train_classifier', #'incremental_retraining',\n",
    "                    'wd-0.0001', 'wd-0.001'],\n",
    "        'method_names':  ['original', #'retrain',\n",
    "                          'wd(1e-4)', 'wd(1e-3)'],\n",
    "    },\n",
    "    'adam': {\n",
    "        'methods': ['train_classifier', #'incremental_retraining',\n",
    "                    'gc-0.01', 'gc-0.1', 'gc-1.0', 'gc-10.0',\n",
    "                    'wd-0.000001', 'wd-0.00001', 'wd-0.0001', 'wd-0.001', 'wd-0.01',\n",
    "                    \"ninput00\", \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\"],\n",
    "        'method_names':  ['original', #'retrain',\n",
    "                          'gc(1e-2)', 'gc(1e-1)', 'gc(1e-0)', 'gc(1e-1)', \n",
    "                          'wd(1e-6)', 'wd(1e-5)', 'wd(1e-4)', 'wd(1e-3)', 'wd(1e-2)',\n",
    "                          \"ninput00\", \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\"],\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "\n",
    "all_accs = []\n",
    "all_results = {}\n",
    "\n",
    "#methods = ['train_classifier', 'incremental_retraining', 'noisy-0.02', 'noisy-0.01', 'forget_retrain-1.0', 'forget_retrain-0.5', 'wd-0.0001']\n",
    "#methods = ['train_classifier', 'incremental_retraining', 'noisy-0.02', 'noisy-0.01', 'forget_retrain-0.2', 'forget_retrain-0.1', 'wd-0.0001']\n",
    "\n",
    "model_seed = 0\n",
    "\n",
    "for toptimizer in optimizers:\n",
    "    methods = opt_methods[toptimizer]['methods']\n",
    "    for tar_cls in tar_clses:\n",
    "        for spurious_version in spurious_versions:\n",
    "            for arch in architecures:\n",
    "                lr = 0.1 if \"cifar\" in base_dset and toptimizer == \"sgd\" else 0.01\n",
    "                optimizer = toptimizer if 'Vgg' not in arch else \"sgd\"\n",
    "                momentum = 0. if optimizer == 'adam' else 0.9\n",
    "\n",
    "                grad_clip = None\n",
    "                if 'cifar' in base_dset:\n",
    "                    if grad_clip is None:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    else:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-{lr}-aug01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                else:\n",
    "                    if grad_clip is None:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                    else:\n",
    "                        model_path = f\"../models/train_classifier/128-{base_dset}-70-{grad_clip}-0.01-ce-tor-{arch}-{momentum}-{optimizer}-0-0.0.pt\"\n",
    "                mod_tst_preds, tst_preds, tsty = evaluate(base_dset, model_path, arch, spurious_version, 0)\n",
    "                ind = np.arange(len(tsty))\n",
    "                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                baseline_rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                baseline_nos = baseline_rv.mean()\n",
    "                baseline_sem = sem(baseline_rv)\n",
    "                baseline_acc = (tst_preds.argmax(1) == tsty).mean()\n",
    "                baseline_spacc = (tst_preds.argmax(1) == tsty)[tsty != tar_cls].mean()\n",
    "                baseline_mod_tst_preds = mod_tst_preds\n",
    "\n",
    "                key = (optimizer, tar_cls, spurious_version, arch, grad_clip)\n",
    "                all_results[key] = {}\n",
    "                for method in methods:\n",
    "                    mkey = method\n",
    "                    noise_multiplier = None\n",
    "                    twd = wd\n",
    "                    noise_level = None\n",
    "                    grad_clip = None\n",
    "                    if 'forget' in method:\n",
    "                        if 'cifar' in base_dset:\n",
    "                            lr = 0.001\n",
    "                        else:\n",
    "                            lr = 0.0001\n",
    "                        folder, noise_level = method.split('-')[0], float(method.split('-')[1])\n",
    "                    elif 'gc' in method:\n",
    "                        lr = 0.01\n",
    "                        folder, grad_clip = 'train_classifier', float(method.split('-')[1])\n",
    "                    elif 'wd' in method:\n",
    "                        lr = 0.01\n",
    "                        folder, twd = 'train_classifier', float(method.split('-')[1])\n",
    "                    elif 'noisy' in method:\n",
    "                        lr = 0.01\n",
    "                        folder, noise_multiplier = 'train_classifier', float(method.split('-')[1])\n",
    "                    elif 'oneothers' in method or 'ninput' in method:\n",
    "                        lr = 0.01\n",
    "                        folder, twd = 'train_classifier', 0.\n",
    "                    else:\n",
    "                        lr = 0.01\n",
    "                        folder, twd = 'train_classifier', 0.\n",
    "\n",
    "                    all_results[key][mkey] = {}\n",
    "                    all_results[key][mkey]['score_mean'] = [baseline_nos]\n",
    "                    all_results[key][mkey]['score_sem'] = [0]\n",
    "                    all_results[key][mkey]['acc'] = [baseline_acc]\n",
    "                    all_results[key][mkey]['spacc'] = [baseline_spacc]\n",
    "                    all_results[key]['n_samples'] = n_samples\n",
    "\n",
    "                    for i in n_samples:\n",
    "                        taccs, sptaccs, tret, ttest_res = [], [], [], []\n",
    "\n",
    "                        for seed in range(5):\n",
    "                            ds_name, model_path = get_model_path(base_dset, spurious_version, i, tar_cls, seed, arch, momentum, optimizer, lr, model_seed,\n",
    "                                                                 bs=None, wd=twd, gi=None, depth=None, noise_level=noise_level,\n",
    "                                                                 noise_multiplier=noise_multiplier, grad_clip=grad_clip, folder=folder)\n",
    "                            if 'oneothers' in method:\n",
    "                                model_path = model_path.replace(f\"128-{ds_name}\", \"128-oneothers\" + ds_name)\n",
    "                                ds_name = \"oneothers\" + ds_name\n",
    "                            elif 'ninput' in method:\n",
    "                                pre = \"aug01-\" if 'cifar10' in base_dset else \"\"\n",
    "                                pre2 = \"-aug01\" if 'cifar10' in base_dset else \"-\"\n",
    "                                if '01' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise01-ce\")\n",
    "                                elif '02' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise02-ce\")\n",
    "                                elif '03' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise03-ce\")\n",
    "                                elif '04' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise04-ce\")\n",
    "                                elif '05' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise05-ce\")\n",
    "                                elif '06' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise06-ce\")\n",
    "                                elif '07' in method:\n",
    "                                    model_path = model_path.replace(f\"{lr}-{pre}ce\", f\"{lr}{pre2}noise07-ce\")\n",
    "                            try:\n",
    "                                mod_tst_preds, tst_preds, tsty = evaluate(ds_name, model_path, arch, spurious_version, seed)\n",
    "                                ind = np.arange(len(tsty))\n",
    "                                mod_tst_preds, tst_preds, tsty = mod_tst_preds[ind], tst_preds[ind], tsty[ind]\n",
    "                                rv = (mod_tst_preds[:, tar_cls] - tst_preds[:, tar_cls]) > threshold\n",
    "                                taccs.append((tst_preds.argmax(1) == tsty).mean())\n",
    "                                sptaccs.append((mod_tst_preds.argmax(1) == tsty)[tsty != tar_cls].mean())\n",
    "                                #taccs.append(mod_tst_preds[rv][:, tar_cls].mean())\n",
    "                                tret.append(rv.mean())\n",
    "                            except FileNotFoundError:\n",
    "                                print(f\"missing {model_path}\")\n",
    "                            except:\n",
    "                                print(f\"wierd {model_path}\")\n",
    "\n",
    "                        if taccs:\n",
    "                            all_results[key][mkey]['score_mean'].append(np.mean(tret))\n",
    "                            all_results[key][mkey]['score_sem'].append(sem(tret))\n",
    "                            all_results[key][mkey]['acc'].append(np.mean(taccs))\n",
    "                            all_results[key][mkey]['spacc'].append(np.mean(sptaccs))\n",
    "                            all_accs += taccs\n",
    "                        else:\n",
    "                            all_results[key][mkey]['score_mean'].append(-1)\n",
    "                            all_results[key][mkey]['score_sem'].append(-1)\n",
    "                            all_results[key][mkey]['acc'].append(-1)\n",
    "                            all_results[key][mkey]['spacc'].append(-1)\n",
    "\n",
    "print(min(all_accs), max(all_accs), np.mean(all_accs), np.std(all_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "548c2f77-456b-41de-b245-6a31ba88bbc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "#res = all_results[('adam', 0, 'v20', 'LargeMLP', None)]\n",
    "arch = architecures[0]\n",
    "spurious_version = spurious_versions[0]\n",
    "res = all_results[('adam', 0, spurious_version, arch, None)]\n",
    "\n",
    "#res = all_results[('adam', 0, 'v20', 'altResNet20Norm02', None)]\n",
    "\n",
    "idx = 3  # controls number of spurious examples\n",
    "\n",
    "mnist_setting = {\n",
    "    \"wd\": {\n",
    "        'mkeys': ['train_classifier', 'wd-0.00001', 'wd-0.0001', 'wd-0.001', 'wd-0.01', 'wd-0.05', 'wd-0.1'],\n",
    "        'ticks': [0., 1e-5, 1e-4, 1e-3, 1e-2, 5e-2, 1e-1],\n",
    "    },\n",
    "    \"noise\": {\n",
    "        'mkeys': ['train_classifier', \"ninput00\", \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\", \"ninput06\"],\n",
    "        'ticks': [0., 0.05, 0.1, 0.2, 0.5, 0.75, 1.0, 5.0],\n",
    "    },\n",
    "    \"gc\": {\n",
    "        'mkeys': ['train_classifier', 'gc-0.1', 'gc-0.01', 'gc-0.001', 'gc-0.0001', 'gc-0.00001'],\n",
    "        'ticks': [1e0, 1e-1, 1e-2, 1e-3, 1e-4, 1e-5],\n",
    "    },\n",
    "}\n",
    "methods = {\n",
    "    'cifar10': {\n",
    "        \"wd\": {\n",
    "            'mkeys': ['train_classifier', 'wd-0.000001', 'wd-0.00001', 'wd-0.0001', 'wd-0.001', 'wd-0.01'],\n",
    "            'ticks': [0., 1e-6, 1e-5, 1e-4, 1e-3, 1e-2],\n",
    "        },\n",
    "        \"noise\": {\n",
    "            'mkeys': ['train_classifier', \"ninput01\", \"ninput02\", \"ninput03\", \"ninput04\", \"ninput05\"],\n",
    "            'ticks': [0., 0.1, 0.2, 0.5, 1.0, 2.0],\n",
    "        },\n",
    "        \"gc\": {\n",
    "            'mkeys': ['train_classifier', 'gc-0.01', 'gc-0.1', 'gc-1.0', 'gc-10.0'],\n",
    "            'ticks': [0., 1e-2, 1e-1, 1e-0, 1e1],\n",
    "        },\n",
    "    },\n",
    "    'mnist': mnist_setting,\n",
    "    'fashion': mnist_setting,\n",
    "}\n",
    "\n",
    "fontsize=20\n",
    "\n",
    "for k, v in methods[base_dset].items():\n",
    "    mkeys = v['mkeys']\n",
    "    ticks = v['ticks']\n",
    "    \n",
    "    spu_scores = []\n",
    "    accs = []\n",
    "    acc_sems = []\n",
    "    spaccs = []\n",
    "    spacc_sems = []\n",
    "    spu_score_sems = []\n",
    "    for mkey in mkeys:\n",
    "        spu_scores.append(np.mean(res[mkey]['score_mean']))\n",
    "        spu_score_sems.append(sem(res[mkey]['score_mean']))\n",
    "        accs.append(np.mean(np.mean(res[mkey]['acc'])))\n",
    "        acc_sems.append(sem(res[mkey]['acc']))\n",
    "        spaccs.append(np.mean(np.mean(res[mkey]['spacc'])))\n",
    "        spacc_sems.append(sem(res[mkey]['spacc']))\n",
    "        #spu_scores.append(res[mkey]['score_mean'][idx])\n",
    "        #accs.append(np.mean(res[mkey]['acc'][idx]))\n",
    "    \n",
    "    print(spu_scores, accs)\n",
    "    \n",
    "    color = 'tab:blue'\n",
    "    fig, ax1 = plt.subplots()\n",
    "    if k == \"wd\":\n",
    "        plt.xscale(\"symlog\", linthreshx=0.000001)\n",
    "        plt.xlabel(\"Weight decay\", fontsize=fontsize)\n",
    "    elif k == \"gc\":\n",
    "        plt.xscale(\"symlog\", linthreshx=0.000001)\n",
    "        plt.xlabel(\"Gradient clipping\", fontsize=fontsize)\n",
    "    else:\n",
    "        linthreshx = 0.5 if \"cifar\" in base_dset else 0.01\n",
    "        plt.xscale(\"symlog\", linthreshx=linthreshx)\n",
    "        plt.xlabel(\"Noise level\", fontsize=fontsize)\n",
    "\n",
    "    ax1.errorbar(ticks, spu_scores, yerr=spu_score_sems, color=color)\n",
    "    ax1.set_ylabel(\"Spurious score\", fontsize=fontsize, color=color)\n",
    "    ax1.tick_params(axis='y', labelcolor=color, labelsize=fontsize)\n",
    "    ax1.tick_params(axis='x', labelsize=fontsize)\n",
    "    \n",
    "    if k == \"gc\":\n",
    "        ax1.set_xticks(ticks)\n",
    "        if 'cifar' in base_dset:\n",
    "            ax1.set_xticklabels([\"$\\infty$\", ] + [f\"$10^{{{i}}}$\" for i in [-2, -1, 0, 1]])\n",
    "        else:\n",
    "            ax1.set_xticklabels([\"$\\infty$\", ] + [f\"$10^{{{i}}}$\" for i in [-1, -2, -3, -4, -5]])\n",
    "\n",
    "    color = 'tab:red'\n",
    "    ax2 = ax1.twinx()\n",
    "    ax2.set_ylabel('Test accuracy', fontsize=fontsize, color=color)  # we already handled the x-label with ax1\n",
    "    ax2.errorbar(ticks, spaccs, yerr=spacc_sems, color=color, linestyle='--')\n",
    "    ax2.errorbar(ticks, accs, yerr=acc_sems, color=color)\n",
    "    ax2.tick_params(axis='y', labelcolor=color, labelsize=fontsize)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    plt.savefig(f\"figs/regularization/{base_dset}_{spurious_version}_{k}.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4329f1b9-35e0-4d62-bae7-96c450aed17a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "189e3b2e-cd4e-4ff7-a279-c3bdd5f42060",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64c816ba-d3b5-4dc7-8976-1146e62f5454",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.linspace(0, 10, 50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9753165-6b7f-4675-b481-393588b460ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "def f(x, r, sinv, ssp):\n",
    "    return (1-r/2) * x * sinv / ((1-r)*x + 2 * ssp / r)\n",
    "    \n",
    "plt.plot(f(np.linspace(0, 10, 50), 0.1, 1, 1), label=\"1.\")\n",
    "plt.plot(f(np.linspace(0, 10, 50), 0.1, 10, 10), label=\"10.\")\n",
    "plt.plot(f(np.linspace(0, 10, 50), 0.1, 10, 1), label=\"3\")\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a1e88e7-0d6c-4a63-bb6f-5ee2ba7fa6cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy.abc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8fcdad8-d162-42d3-b2f0-3e7df5a8c883",
   "metadata": {},
   "outputs": [],
   "source": [
    "sympy.abc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b7f77cfa-b17a-4aa4-bc80-ddb40ed744d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy\n",
    "from sympy import symbols, simplify, latex, fraction, factor, collect\n",
    "from sympy.solvers import solve\n",
    "from sympy.abc import lamda\n",
    "import sympy.plotting\n",
    "from sympy.stats import Normal, cdf\n",
    "\n",
    "def move_sympyplot_to_axes(p, ax):\n",
    "    backend = p.backend(p)\n",
    "    backend.ax = ax\n",
    "    backend._process_series(backend.parent._series, ax, backend.parent)\n",
    "    backend.ax.spines['right'].set_color('none')\n",
    "    backend.ax.spines['left'].set_position(('axes', 0))\n",
    "    backend.ax.spines['bottom'].set_position(('axes', 0))\n",
    "    plt.close(backend.fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2a468ba-05e3-4402-bafe-7db4385ca845",
   "metadata": {},
   "outputs": [],
   "source": [
    "a, b, sinv, ssp, xinv, xsp, gamma = symbols('a b sigma_inv sigma_sp x_{inv} x_{sp} gamma')\n",
    "beta0 = -gamma * xsp**2 / 2 * b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0e62ac9-e89f-4017-a8a1-bd97b388238c",
   "metadata": {},
   "outputs": [],
   "source": [
    "eq1 = a* (sinv**2 + lamda) + a * xinv**2 + gamma * xsp**2 * b / 2 - 1\n",
    "eq1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c03fb59-e971-4edb-990a-55a279c1475d",
   "metadata": {},
   "outputs": [],
   "source": [
    "eq2 = b* (ssp**2 + lamda) + gamma / 2 * (a * xinv**2 + b * xsp**2 + beta0 - 1)\n",
    "eq2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "de3705ce-bcf6-4294-9cc2-6fb79c6fb681",
   "metadata": {},
   "outputs": [],
   "source": [
    "sol = solve([eq1, eq2], a, b)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96c01366-7217-4959-8bc4-38f38af3def5",
   "metadata": {},
   "outputs": [],
   "source": [
    "sol[a] / sol[b]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "545f7984-7a94-43cd-99fd-e7a5f0c3254d",
   "metadata": {},
   "outputs": [],
   "source": [
    "(1 - gamma * xsp**2 + 2 * (ssp**2 + lamda) / gamma) / (sinv + lamda)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "34f8506e-f203-450c-8196-39977be5b2f8",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a037a569-8339-4b5c-97eb-9b056a1105af",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(latex(sol[b]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74a117d4-9aa2-4951-93e6-398ba54702fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "simplify(sol[b]).factor()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6f5221c0-45b5-4dbe-8394-1b10c6e9c2fc",
   "metadata": {},
   "source": [
    "Transcribed factored form of `sol[b]` (re-check against `simplify(sol[b]).factor()` above if the equations change):\n",
    "\n",
    "$$\n",
    "b = \\frac{- 2 \\gamma (\\lambda + \\sigma_{inv}^{2})}\n",
    "{\\gamma x_{sp}^{2} ((\\lambda + \\sigma_{inv}^{2})(\\gamma - 2) + 2x_{inv}^{2} (\\gamma - 1))\n",
    "- 4 (\\lambda^{2} - \\lambda \\sigma_{sp}^{2} - \\lambda (\\sigma_{inv}^{2} + x_{inv}^{2}) - \\sigma_{sp}^{2} (\\sigma_{inv}^{2} + x_{inv}^{2}))}\n",
    "$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d241bdfa-2c8f-4d03-830e-cc72678a7548",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference b * x_sp^2 with no spurious-feature noise (sigma_sp = 0).\n",
    "(sol[b] * xsp**2).subs(ssp, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d71351d7-2aeb-4618-a923-0fb6977732d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference b * x_sp^2 vs gamma (sigma_inv=0.5, sigma_sp=0, x_inv=x_sp=1, lamda=0).\n",
    "# NOTE(review): needs `sympy` and `move_sympyplot_to_axes`; their visible definitions are in the\n",
    "# '# plot' section further down, so run that cell first on a fresh kernel.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "plt.rc('font', size=16)\n",
    "p = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 3),\n",
    "                        ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\",\n",
    "                        xlabel=r\"$\\gamma$\", show=False)\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p, ax)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_score.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d5f3b63-89e0-4fb3-8493-ca91b66f28dd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs gamma for sigma_sp = 0, 0.2, 0.5 (sigma_inv=0.5, lamda=0, x_inv=x_sp=1).\n",
    "# NOTE(review): needs `sympy` / `move_sympyplot_to_axes`, defined in the '# plot' section further down.\n",
    "# p1 is the plot that gets rendered (p2/p3 only contribute their series), so its ylim=(0, 3) is\n",
    "# presumably the one in effect.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "plt.rc('font', size=16)\n",
    "p1 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 3), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{sp}=0.$\", legend=True)\n",
    "p2 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0.2, xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{sp}=0.2$\")\n",
    "p3 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0.5, xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{sp}=0.5$\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p1, ax)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_score_2.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3ffb63b-7720-4cfb-b0da-bec4dfe74ce5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs gamma for sigma_inv = 0.1, 0.2, 0.5 (sigma_sp=0, lamda=0, x_inv=1, x_sp=5).\n",
    "# NOTE(review): needs `sympy` / `move_sympyplot_to_axes`, defined in the '# plot' section further down.\n",
    "# p1 is the plot that gets rendered (p2/p3 only contribute their series), so its ylim=(0, 3) is\n",
    "# presumably the one in effect.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "plt.rc('font', size=16)\n",
    "p1 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.1, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 3), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{inv}=0.1$\", legend=True)\n",
    "p2 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.2, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{inv}=0.2$\")\n",
    "p3 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{inv}=0.5$\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p1, ax)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_score_3.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7a92643b-08c3-47aa-926c-2112326831db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# gamma\n",
    "plt.rc('font', size=16)\n",
    "p1 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.1, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 3), ylabel=\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=\"$\\gamma$\", show=False, label=\"$\\sigma_{inv}=0.1$\", legend=True)\n",
    "p2 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.2, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=\"$\\gamma$\", show=False, label=\"$\\sigma_{inv}=0.2$\")\n",
    "p3 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0., xsp: 5}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=\"$\\gamma$\", show=False, label=\"$\\sigma_{inv}=0.5$\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p1, ax)\n",
    "ax.set_xlabel(\"$\\gamma$\", fontsize= 20)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_score_3.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "afa02c70-84bb-4b5b-981c-000ad8446b75",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d27c6a2e-7675-406b-a82c-c4a75b48e1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs the spurious-feature norm x_sp (sigma_inv=0.3, sigma_sp=0.1, gamma=0.05, lamda=0).\n",
    "# NOTE(review): needs `sympy` / `move_sympyplot_to_axes`, defined in the '# plot' section further down.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "p = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.3, ssp: 0.1, xinv: 1, lamda: 0., gamma: 0.05}),\n",
    "                        xlim=(0, 10),\n",
    "                        ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\",\n",
    "                        xlabel=r\"$\\mathbf{x}_{sp}$\", show=False)\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p, ax)\n",
    "ax.set_xlabel(r\"$\\mathbf{x}_{sp}$\", fontsize=20)\n",
    "ax.yaxis.set_label_coords(-0.13, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/xspnorm_score.png\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8e77d6a-7036-4e8d-9de0-9fc9e8d097a7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs gamma for L2 strengths lamda = 0, 0.1, 0.2 (sigma_inv=0.5, sigma_sp=0, x_inv=x_sp=1).\n",
    "# NOTE(review): needs `sympy` / `move_sympyplot_to_axes`, defined in the '# plot' section further down.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "plt.rc('font', size=16)\n",
    "p1 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\lambda=0$\", legend=True)\n",
    "p2 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0.1, xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Prediction difference\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\lambda=0.1$\")\n",
    "p3 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0., xinv: 1, lamda: 0.2, xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Prediction difference\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\lambda=0.2$\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p1, ax)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_lamda_cmp.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46174afd-27dd-478a-a0bd-0728cd7fa58c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs gamma as noise is added to both sigma_inv (base 0.7) and sigma_sp (base 0).\n",
    "# NOTE(review): needs `sympy` / `move_sympyplot_to_axes`, defined in the '# plot' section further down.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "plt.rc('font', size=16)\n",
    "p1 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.7, ssp: 0., xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=r\"Pred. diff. ($f(\\mathbf{x} + \\mathbf{x}_{sp}) - f(\\mathbf{x})$)\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{noise}=0.0$\", legend=True)\n",
    "p2 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.8, ssp: 0.1, xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Prediction difference\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{noise}=0.1$\")\n",
    "p3 = sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 1.0, ssp: 0.3, xinv: 1, lamda: 0., xsp: 1}),\n",
    "                        xlim=(0, 1), ylim=(0, 2), ylabel=\"Prediction difference\", xlabel=r\"$\\gamma$\", show=False, label=r\"$\\sigma_{noise}=0.3$\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "ax = plt.gca()\n",
    "move_sympyplot_to_axes(p1, ax)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.spines['top'].set_visible(False)\n",
    "ax.yaxis.set_label_coords(-0.1, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"figs/diagrams/gamma_sigma_cmp.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cfe2dbc-fd79-4226-bf3a-692c6963aa31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick check: prediction difference vs gamma with sigma_inv = sigma_sp = 0.5 (lamda=0, x_inv=x_sp=1).\n",
    "sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.5, ssp: 0.5, xinv: 1, lamda: 0., xsp: 1}),\n",
    "                    xlim=(0, 1), ylim=(0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "764aa460-1a83-4d34-a274-979fb7a8a9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference with sigma_inv left free (gamma=0.05, sigma_sp=0, lamda=0, x_inv=1, x_sp=5).\n",
    "(sol[b] * xsp**2).subs({gamma: 0.05, ssp: 0, xinv: 1, lamda: 0, xsp: 5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1eefdd1-ac8b-4812-9293-47c485178551",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs x_sp (sigma_inv=0.1, sigma_sp=0.1, gamma=0.05, lamda=0, x_inv=1).\n",
    "sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.1, gamma: 0.05, ssp: 0.1, lamda: 0, xinv: 1}),\n",
    "                    xlim=(0, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "03bbf88f-28f0-41f1-a6f5-2222131ae426",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs x_sp with sigma_inv = sigma_sp = 0.05 (gamma=0.05, lamda=0, x_inv=1).\n",
    "sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.05, gamma: 0.05, ssp: 0.05, lamda: 0, xinv: 1}),\n",
    "                    xlim=(0, 10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9a11c9a-3d34-4afe-a370-1270695ec3c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference vs x_sp with L2 strength lamda=1 (sigma_inv=0.1, sigma_sp=0.1, gamma=0.05), wider range.\n",
    "sympy.plotting.plot((sol[b] * xsp**2).subs({sinv: 0.1, gamma: 0.05, ssp: 0.1, lamda: 1, xinv: 1}),\n",
    "                    xlim=(0, 50))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef929d48-687f-4387-8c93-d4710626d8a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference with lamda=0 and x_inv=1; the remaining parameters stay symbolic.\n",
    "(sol[b] * xsp**2).subs({lamda: 0, xinv: 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2736c634-73b7-4720-9bc0-4c3586c1b156",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prediction difference with x_sp left free (sigma_inv = sigma_sp = 0.01, gamma=0.05, lamda=0, x_inv=1).\n",
    "(sol[b] * xsp**2).subs({sinv: 0.01, gamma: 0.05, ssp: 0.01, lamda: 0, xinv: 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9f9b90e9-b909-460b-97c5-cbe594fd6540",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d0e028f-2cd3-45a9-a626-d387d677da37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the standard-normal CDF (the expression later bound to `gcdf`). It is built here directly\n",
    "# so this cell does not read `gcdf`, whose visible assignments are in later cells (NameError on Run All).\n",
    "from sympy.stats import Normal, cdf\n",
    "cdf(Normal(\"snorm\", 0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb3256d2-72e4-49c6-a7e2-27c048419986",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86a03008-902c-4179-a47b-81b65987dd67",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-declare the model symbols and the standard-normal CDF for this section\n",
    "# (same definitions as at the start of the '# plot' section below).\n",
    "# NOTE(review): `symbols`, `Normal` and `cdf` are imported in the '# plot' section; on a fresh kernel run that first.\n",
    "sinv, ssp, xinv, xsp, gamma = symbols(\"sigma_inv sigma_sp x_{inv} x_{sp} gamma\")\n",
    "gcdf = cdf(Normal(\"snorm\", 0, 1))  # callable as gcdf(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d45718e-bf6c-4ab8-b24f-3606b04f1928",
   "metadata": {},
   "outputs": [],
   "source": [
    "# First parameterisation: sinv / ssp enter squared. The same names are reused, with different\n",
    "# formulas, in the '# plot' section, where error_ctest / error_stest become accuracies.\n",
    "temp = ((1 - gamma) * xsp**2 + 2 * (ssp**2 + lamda) / gamma) / (sinv**2 + lamda)\n",
    "top = temp + gamma / 2 * xsp**2          # clean-test argument\n",
    "top2 = temp + (gamma / 2 - 1) * xsp**2   # spurious-test argument\n",
    "down = sympy.sqrt(temp**2 * xinv**4 * sinv**2 + xsp**4 * ssp**2)  # shared normaliser\n",
    "# Error rates of the form 1 - Phi(argument / normaliser).\n",
    "error_ctest = 1 - gcdf(top / down)\n",
    "error_stest = 1 - gcdf(top2 / down)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93f97a3a-9dad-40e6-b9c5-0bf645944309",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the clean-test error expression defined above.\n",
    "error_ctest"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ecc86944-65e0-45f1-b085-7f15013cd8cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Numeric clean-test error at sigma_inv = sigma_sp = 0.3, x_sp = x_inv = 1, lamda = 0, gamma = 0.5.\n",
    "error_ctest.subs({sinv: 0.3, xsp: 1, ssp: 0.3, lamda: 0, xinv: 1, gamma: 0.5})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "486a680f-fac3-4f37-bb6a-b2e9426a0e4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test error vs gamma for three (sigma_inv, x_sp, sigma_sp) settings.\n",
    "p1 = sympy.plotting.plot(error_ctest.subs({sinv: 0.3, xsp: 5, ssp: 0.3, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, legend=True, label=\"1\")\n",
    "p2 = sympy.plotting.plot(error_ctest.subs({sinv: 0.5, xsp: 2, ssp: 0.5, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, label=\"2\")\n",
    "p3 = sympy.plotting.plot(error_ctest.subs({sinv: 1.0, xsp: 2, ssp: 0.5, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, label=\"3\")\n",
    "# Overlay curves 2 and 3 on the first plot, then render once.\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "p1.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0711d929-6e3c-4674-af5f-eb432343ad69",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious-test error vs gamma for the same three (sigma_inv, x_sp, sigma_sp) settings.\n",
    "p1 = sympy.plotting.plot(error_stest.subs({sinv: 0.3, xsp: 5, ssp: 0.3, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, legend=True, label=\"1\")\n",
    "p2 = sympy.plotting.plot(error_stest.subs({sinv: 0.5, xsp: 2, ssp: 0.5, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, label=\"2\")\n",
    "p3 = sympy.plotting.plot(error_stest.subs({sinv: 1.0, xsp: 2, ssp: 0.5, lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False, label=\"3\")\n",
    "# Overlay curves 2 and 3 on the first plot, then render once.\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "p1.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b252dff-3912-4dab-bfc3-e7cc7e0cedc9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious-test error vs sigma_sp (gamma = 1, x_sp = 2, lamda = 0) for sigma_inv = 0.3, 0.5, 1.0.\n",
    "# TeX labels are raw strings so backslashes are not parsed as (invalid) Python escape sequences.\n",
    "p1 = sympy.plotting.plot(error_stest.subs({sinv: 0.3, xsp: 2, gamma: 1, lamda: 0, xinv: 1}), xlim=(0, 1), ylabel=r\"$f(\\sigma_{sp})$\", xlabel=r\"$\\sigma_{sp}$\", show=False, legend=True, label=\"1\")\n",
    "p2 = sympy.plotting.plot(error_stest.subs({sinv: 0.5, xsp: 2, gamma: 1, lamda: 0, xinv: 1}), xlim=(0, 1), ylabel=r\"$f(\\sigma_{sp})$\", xlabel=r\"$\\sigma_{sp}$\", show=False, label=\"2\")\n",
    "p3 = sympy.plotting.plot(error_stest.subs({sinv: 1.0, xsp: 2, gamma: 1, lamda: 0, xinv: 1}), xlim=(0, 1), ylabel=r\"$f(\\sigma_{sp})$\", xlabel=r\"$\\sigma_{sp}$\", show=False, label=\"3\")\n",
    "p1.append(p2[0])\n",
    "p1.append(p3[0])\n",
    "p1.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "261c8092-15e2-454c-b2cf-4423716b2100",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test error vs gamma without (p1) and with (p2) L2 regularisation lamda = 0.1.\n",
    "p1 = sympy.plotting.plot(error_ctest.subs({sinv: 0.9, xsp: 2, ssp: 0., lamda: 0, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False)\n",
    "p2 = sympy.plotting.plot(error_ctest.subs({sinv: 0.9, xsp: 2, ssp: 0., lamda: 0.1, xinv: 1}),\n",
    "                         xlim=(0, 1), show=False)\n",
    "p1.append(p2[0])\n",
    "p1.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e994858-f009-42db-b59b-574a07653091",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test error vs gamma with very small variances (sigma_inv = sigma_sp = 0.01, x_sp = 5).\n",
    "sympy.plotting.plot(error_ctest.subs({sinv: 0.01, xsp: 5, ssp: 0.01, lamda: 0, xinv: 1}),\n",
    "                    xlim=(0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac368425-f01a-4c0d-b30e-e5668780a4ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test error with lamda = 1, leaving gamma symbolic.\n",
    "error_ctest.subs({sinv: 0.1, xsp: 1, ssp: 0.1, lamda: 1, xinv: 1})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "851eceb5-c388-47c6-8c07-5c40fd6b36f3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a659837-6a8b-45b3-bed8-06d6723d41d1",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "334048b8-75a1-4cc9-8f0d-05230186ce75",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e76b1f99-4578-4819-b320-2969c97a4be4",
   "metadata": {},
   "source": [
    "# Plots: clean/spurious test accuracy and spurious score vs. $\\gamma$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9f09d1f-1116-48ec-9840-d5795f3feab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sympy\n",
    "from sympy import symbols, simplify, latex, fraction, factor, collect, lambdify\n",
    "from sympy.solvers import solve\n",
    "from sympy.abc import lamda\n",
    "import sympy.plotting\n",
    "from sympy.stats import Normal, cdf\n",
    "\n",
    "def move_sympyplot_to_axes(p, ax):\n",
    "    \"\"\"Render the series of sympy Plot `p` onto an existing matplotlib Axes `ax`.\n",
    "\n",
    "    Instantiates p's backend, points it at `ax`, draws p's series there, hides the right\n",
    "    spine, pins the left/bottom spines to the axes edges and closes the backend's own figure.\n",
    "\n",
    "    NOTE(review): relies on sympy's private `_process_series` / `_series`; may break across\n",
    "    sympy versions.\n",
    "    \"\"\"\n",
    "    renderer = p.backend(p)\n",
    "    renderer.ax = ax\n",
    "    renderer._process_series(renderer.parent._series, ax, renderer.parent)\n",
    "    target_ax = renderer.ax\n",
    "    target_ax.spines['right'].set_color('none')\n",
    "    for side in ('left', 'bottom'):\n",
    "        target_ax.spines[side].set_position(('axes', 0))\n",
    "    plt.close(renderer.fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "596b9729-35d3-4d97-9546-18ac12ca464b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model symbols; the display names match the LaTeX used in the plot labels.\n",
    "sinv, ssp, xinv, xsp, gamma = symbols(\"sigma_inv sigma_sp x_{inv} x_{sp} gamma\")\n",
    "# Standard-normal CDF, callable as gcdf(z).\n",
    "gcdf = cdf(Normal(\"snorm\", 0, 1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86d65974-9807-4e01-804a-db94eddb52d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Closed-form quantities in which sinv / ssp enter linearly (the plots below label the amount\n",
    "# added to them as sigma^2_noise, i.e. they are apparently variances here).\n",
    "temp = ((gamma * (1 - gamma) / 2 * xsp**2 + ssp + lamda) / (sinv + lamda)) * xinv**2\n",
    "top = temp + (gamma / 2)**2 * xsp**2                   # first clean-test argument\n",
    "top1 = temp - (gamma / 2)**2 * xsp**2                  # second clean-test argument\n",
    "top2 = temp + (gamma / 2) * (gamma / 2 - 1) * xsp**2   # spurious-test argument\n",
    "down = sympy.sqrt(temp**2 * sinv + (gamma / 2)**2 * xsp**4 * ssp)  # shared normaliser\n",
    "\n",
    "# Despite the `error_` prefix these are accuracies (Phi of the normalised arguments); they are\n",
    "# plotted below as 'Clean test accuracy' and 'Spurious test accuracy'.\n",
    "error_ctest = gcdf(top / down) / 2 + gcdf(top1 / down) / 2\n",
    "error_stest = gcdf(top2 / down)\n",
    "# Spurious score (plotted below as 'Spurious score').\n",
    "sscore = (gamma / 2 * (sinv + lamda) * xsp**2) / (\n",
    "    (sinv + lamda) * (ssp + lamda)\n",
    "    + (ssp + lamda + gamma / 2 * (1 - gamma / 2) * (sinv + lamda) + gamma / 2 * (1 - gamma) * xsp**2) * xsp**2\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "04c59715-4152-4c43-8070-32ae01d15847",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Upper end of the gamma range shown in the noise / L2 figures below.\n",
    "ugamma = 0.5"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c0dd0b2c-e788-4614-aaf9-914b36dcc2e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test accuracy vs gamma as noise is added to both sinv and ssp (lamda = 0, x_sp = sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for noise in [0, 0.1, 0.2, 0.3]:\n",
    "    # Evaluated point by point (the lambdified normal CDF is not assumed to accept arrays).\n",
    "    f = lambdify(gamma, error_ctest.subs({sinv: 0.5 + noise, ssp: 0.1 + noise, xinv: 1, lamda: 0., xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res, label=r\"$\\sigma^2_{noise}=\" + str(noise) + \"$\")\n",
    "ax.set_ylim(.81, .93)\n",
    "ax.set_ylabel(\"Clean test accuracy\", fontsize=18)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.yaxis.set_label_coords(-0.2, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (0.85 -> .85). Done inline so the global sympy\n",
    "# expression `temp` is not overwritten by a tick-label string.\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "ax.yaxis.set_ticklabels(ticks)\n",
    "plt.savefig(\"figs/fig6/error_ctest_noise.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c22cd165-9e26-4a85-8931-97c4861f9cf1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious-test accuracy vs gamma as noise is added to both sinv and ssp (lamda = 0, x_sp = sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for noise in [0, 0.1, 0.2, 0.3]:\n",
    "    # Evaluated point by point (the lambdified normal CDF is not assumed to accept arrays).\n",
    "    f = lambdify(gamma, error_stest.subs({sinv: 0.5 + noise, ssp: 0.1 + noise, xinv: 1, lamda: 0., xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res, label=r\"$\\sigma^2_{noise}=\" + str(noise) + \"$\")\n",
    "ax.set_ylim(.5, .93)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.set_ylabel(\"Spurious test accuracy\", fontsize=18)\n",
    "ax.yaxis.set_label_coords(-0.2, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.legend(fontsize=12, frameon=False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (0.85 -> .85). Done inline so the global sympy\n",
    "# expression `temp` is not overwritten by a tick-label string.\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "ax.yaxis.set_ticklabels(ticks)\n",
    "\n",
    "plt.savefig(\"figs/fig6/error_stest_noise.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "719661ce-6646-481e-952b-c923b42ef29f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious score vs gamma as noise is added to both sinv and ssp (lamda = 0, x_sp = sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for noise in [0, 0.1, 0.2, 0.3]:\n",
    "    f = lambdify(gamma, sscore.subs({sinv: 0.5 + noise, ssp: 0.1 + noise, xinv: 1, lamda: 0., xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.set_ylabel(\"Spurious score\", fontsize=18)\n",
    "\n",
    "ax.set_ylim(-0.01, .17)\n",
    "ax.yaxis.set_label_coords(-0.20, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (inline, so the global sympy `temp` is not overwritten).\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "_ = ax.yaxis.set_ticklabels(ticks)\n",
    "plt.savefig(\"figs/fig6/sscore_noise.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73a5712e-3436-4ee8-a597-ef5546dae8ef",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "9030b357-6e0f-4458-b897-6ecf3a05f57b",
   "metadata": {},
   "source": [
    "### L2 regularization (varying $\\lambda$)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8bc0bc9-d6f6-46b7-a26f-f24b98b3bf35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# L2 strengths (lamda) compared in the figures below.\n",
    "lams = [0, 1.0, 10.0, 100.0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4d78ec16-ba4b-44bd-a3c9-8d786b0a7a36",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Clean-test accuracy vs gamma for each L2 strength in `lams` (sigma_inv=0.5, sigma_sp=0.1, x_sp=sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for lam in lams:\n",
    "    # Evaluated point by point (the lambdified normal CDF is not assumed to accept arrays).\n",
    "    f = lambdify(gamma, error_ctest.subs({sinv: 0.5, ssp: 0.1, xinv: 1, lamda: lam, xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res, label=rf\"$\\lambda={lam}$\")\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.set_ylabel(\"Clean test accuracy\", fontsize=18)\n",
    "\n",
    "ax.set_ylim(.81, .93)\n",
    "ax.yaxis.set_label_coords(-0.25, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (inline, so the global sympy `temp` is not overwritten).\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "ax.yaxis.set_ticklabels(ticks)\n",
    "plt.savefig(\"figs/fig6/error_ctest_l2.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eee67c58-a552-4b0d-94c4-359a213256a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious-test accuracy vs gamma for each L2 strength in `lams` (sigma_inv=0.5, sigma_sp=0.1, x_sp=sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for lam in lams:\n",
    "    # Evaluated point by point (the lambdified normal CDF is not assumed to accept arrays).\n",
    "    f = lambdify(gamma, error_stest.subs({sinv: 0.5, ssp: 0.1, xinv: 1, lamda: lam, xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res, label=rf\"$\\lambda={lam}$\")\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.set_ylabel(\"Spurious test accuracy\", fontsize=18)\n",
    "\n",
    "ax.set_ylim(.5, .93)\n",
    "ax.yaxis.set_label_coords(-0.2, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.legend(fontsize=12, frameon=False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (inline, so the global sympy `temp` is not overwritten).\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "ax.yaxis.set_ticklabels(ticks)\n",
    "\n",
    "plt.savefig(\"figs/fig6/error_stest_l2.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09671454-54c8-4060-9508-9979b8d8dcc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Spurious score vs gamma for each L2 strength in `lams` (sigma_inv=0.5, sigma_sp=0.1, x_sp=sqrt(5)).\n",
    "plt.rc('font', size=15)\n",
    "fig = plt.figure(figsize=(3.5,4.5))\n",
    "ax = fig.add_subplot(111)\n",
    "gammas = np.linspace(0, ugamma, 100)\n",
    "for lam in lams:\n",
    "    f = lambdify(gamma, sscore.subs({sinv: 0.5, ssp: 0.1, xinv: 1, lamda: lam, xsp: np.sqrt(5)}), \"numpy\")\n",
    "    res = [f(g) for g in gammas]\n",
    "    ax.plot(gammas, res)\n",
    "ax.set_ylabel(\"Spurious score\", fontsize=18)\n",
    "ax.set_xlabel(r\"$\\gamma$\", fontsize=20)\n",
    "ax.set_ylim(-0.01, .17)\n",
    "ax.yaxis.set_label_coords(-0.20, 0.5)\n",
    "ax.xaxis.set_label_coords(0.5, -0.1)\n",
    "ax.spines['top'].set_visible(False)\n",
    "plt.tight_layout()\n",
    "\n",
    "# Drop the leading zero from the y tick labels (inline, so the global sympy `temp` is not overwritten).\n",
    "ticks = ax.yaxis.get_ticklabels()\n",
    "for ent in ticks:\n",
    "    ent.set_text(ent.get_text().replace(\"0.\", \".\"))\n",
    "_ = ax.yaxis.set_ticklabels(ticks)\n",
    "plt.savefig(\"figs/fig6/sscore_l2.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4c8f85d-09b6-46e3-93ba-7de8cb465e5e",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9208c30-da2d-4e26-b30e-f5b462acc2aa",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "74a73630-449f-4187-b179-2b78fa91d561",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phase-transition figure: spurious score (left axis) and spurious-test accuracy (right axis) vs gamma,\n",
    "# for sigma_inv=1.0, sigma_sp=0.05, x_sp=5, lamda=0.\n",
    "# Uses its own gamma upper bound so the global `ugamma` used by the cells above is not overwritten.\n",
    "ugamma_phase = 0.3\n",
    "fontsize = 15\n",
    "\n",
    "fig, ax1 = plt.subplots()\n",
    "ticks = np.linspace(0, ugamma_phase, 100)\n",
    "f = lambdify(gamma, sscore.subs({sinv: 1.0, ssp: 0.05, xinv: 1, lamda: 0., xsp: 5}), \"numpy\")\n",
    "scores = [f(g) for g in ticks]\n",
    "\n",
    "color = 'tab:blue'\n",
    "# Plot `scores` computed on this cell's own gamma grid (not a `res` left over from an earlier cell).\n",
    "ax1.plot(ticks, scores, color=color)\n",
    "ax1.set_xlabel(r\"$\\gamma$\", fontsize=fontsize)\n",
    "ax1.set_ylabel(\"Prediction difference\", fontsize=fontsize, color=color)\n",
    "ax1.tick_params(axis='y', labelcolor=color, labelsize=fontsize)\n",
    "ax1.tick_params(axis='x', labelsize=fontsize)\n",
    "\n",
    "f = lambdify(gamma, error_stest.subs({sinv: 1.0, ssp: 0.05, xinv: 1, lamda: 0., xsp: 5}), \"numpy\")\n",
    "spaccs = [f(g) for g in ticks]\n",
    "color = 'tab:red'\n",
    "ax2 = ax1.twinx()  # second y axis sharing the gamma axis; the x-label is set on ax1\n",
    "ax2.set_ylabel('Spurious test accuracy', fontsize=fontsize, color=color)\n",
    "ax2.plot(ticks, spaccs, color=color, linestyle='--')\n",
    "ax2.tick_params(axis='y', labelcolor=color, labelsize=fontsize)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig(\"figs/phase_transition.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6628d7d1-d385-4d37-bc28-3e172f4bb557",
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_regularization_axes(fig, ax1, k, base_dset, ticks, spu_scores, color, fontsize=15):\n",
    "    \"\"\"Draw one regularization sweep's spurious scores on `ax1` and configure its x axis.\n",
    "\n",
    "    The values this code needs from the enclosing sweep loop are passed explicitly.\n",
    "\n",
    "    Args:\n",
    "        fig, ax1: figure and axes to draw on.\n",
    "        k: sweep kind -- \"wd\" (weight decay), \"gc\" (gradient clipping), anything else = noise level.\n",
    "        base_dset: dataset name; the noise-level symlog threshold is 0.5 if it contains \"cifar\", else 0.01.\n",
    "        ticks: x positions of the sweep; spu_scores: spurious scores at those positions.\n",
    "        color: line color; fontsize: x-label font size.\n",
    "\n",
    "    Returns:\n",
    "        ax1, for chaining.\n",
    "\n",
    "    NOTE(review): the original fragment used plt.xscale / plt.xlabel (the current axes); here they\n",
    "    are applied to `ax1` explicitly.\n",
    "    \"\"\"\n",
    "    # `linthresh` replaces the `linthreshx` keyword, which newer Matplotlib no longer accepts.\n",
    "    if k == \"wd\":\n",
    "        ax1.set_xscale(\"symlog\", linthresh=0.000001)\n",
    "        ax1.set_xlabel(\"Weight decay\", fontsize=fontsize)\n",
    "    elif k == \"gc\":\n",
    "        ax1.set_xscale(\"symlog\", linthresh=0.000001)\n",
    "        ax1.set_xlabel(\"Gradient clipping\", fontsize=fontsize)\n",
    "    else:\n",
    "        linthresh = 0.5 if \"cifar\" in base_dset else 0.01\n",
    "        ax1.set_xscale(\"symlog\", linthresh=linthresh)\n",
    "        ax1.set_xlabel(\"Noise level\", fontsize=fontsize)\n",
    "\n",
    "    ax1.errorbar(ticks, spu_scores, color=color)\n",
    "\n",
    "    if k == \"gc\":\n",
    "        # First tick is labelled infinity, then 10^-1 ... 10^-5.\n",
    "        ax1.set_xticks(ticks)\n",
    "        ax1.set_xticklabels([r\"$\\infty$\"] + [f\"$10^{{{i}}}$\" for i in [-1, -2, -3, -4, -5]])\n",
    "\n",
    "    fig.tight_layout()\n",
    "    return ax1"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
