{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "matplotlib.rcParams['pdf.fonttype'] = 42\n",
    "matplotlib.rcParams['ps.fonttype'] = 42\n",
    "from matplotlib.ticker import NullFormatter\n",
    "%matplotlib inline\n",
    "import seaborn as sns\n",
    "sns.set(style=\"ticks\")\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "sns.set(palette=\"bright\",style=\"ticks\")\n",
    "\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.05\n",
    "\n",
    "mnist_ff = f\"{os.getcwd()}/experiments/robust_cv/qmnist/\"\n",
    "cifar_ff = f\"{os.getcwd()}/experiments/robust_cv/cifar-10/\"\n",
    "imagenet_ff = f\"{os.getcwd()}/experiments/robust_cv/imagenet/\"\n",
    "out_ff = f\"{os.getcwd()}/experiments/robust_cv/figures/\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python cifar10_mnist_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"QMNIST\" --epochs_quantile 10000 --delta_slab 0.05 --alpha_slab 0.05 --n_slabs_directions 1000 --rho_validation \"slab_quantile\" --experiment_name \"slab_quantile-3\"\n",
    "\n",
    "!python cifar10_mnist_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"QMNIST\" --epochs_quantile 10000 --delta_slab 0.05 --alpha_slab 0.05 --n_slabs_directions 1000 --rho_validation \"learnable_direction_OLS\" --experiment_name \"slab_quantile-1\"\n",
    "\n",
    "!python cifar10_mnist_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"QMNIST\" --epochs_quantile 10000 --delta_slab 0.05 --alpha_slab 0.05 --n_slabs_directions 1000 --rho_validation \"learnable_direction_SVM\" --experiment_name \"slab_quantile-2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.1 --dataset \"imagenet\" --cqc_epochs_quantile 500 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"slab_quantile\" --experiment_name \"slab_quantile-3\"\n",
    "\n",
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.1 --dataset \"imagenet\" --cqc_epochs_quantile 500 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"learnable_direction_OLS\" --experiment_name \"slab_quantile-1\"\n",
    "\n",
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.1 --dataset \"imagenet\" --cqc_epochs_quantile 500 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"learnable_direction_SVM\" --experiment_name \"slab_quantile-2\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"cifar-10\" --cqc_epochs_quantile 10000 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"slab_quantile\" --experiment_name \"slab_quantile-3\"\n",
    "\n",
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"cifar-10\" --cqc_epochs_quantile 10000 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"learnable_direction_OLS\" --experiment_name \"slab_quantile-1\"\n",
    "\n",
    "!python cifar10_imagenet_experiment.py --nTrials 20 --alpha_coverage 0.05 --dataset \"cifar-10\" --cqc_epochs_quantile 10000 --delta_slab 0.1 --alpha_slab 0.1 --n_slabs_directions 1000 --rho_validation \"learnable_direction_SVM\" --experiment_name \"slab_quantile-2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Make plots for mnist, cifar-10, image net"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cvgs_and_lens(results):\n",
    "    \n",
    "    naive = results[results[\"Robustness\"] == \"NaiveConformal\"]\n",
    "    naive = naive.iloc[:,2:4]\n",
    "    naive = naive.to_numpy()\n",
    "    naive_cvgs = naive[:,1]\n",
    "    naive_lens = naive[:,0]\n",
    "    \n",
    "    rob = results[results[\"Robustness\"] == \"RobustConformal\"]\n",
    "    rob = rob.iloc[:,2:4]\n",
    "    rob = rob.to_numpy()\n",
    "    rob_cvgs = rob[:,1]\n",
    "    rob_lens = rob[:,0]\n",
    "    \n",
    "    return (naive_cvgs, naive_lens), (rob_cvgs, rob_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_plots(naive, rob, mode, out_ff, dataset, alpha):\n",
    "    \n",
    "    fontsize = 16\n",
    "    \n",
    "    obj_to_plot = np.vstack([naive, rob]).T\n",
    "    \n",
    "    f,a = plt.subplots(figsize=(8,5))\n",
    "    a.boxplot(obj_to_plot,\n",
    "              labels=[\"Standard\", \"Chi-squared\"],\n",
    "              showmeans=True,\n",
    "              showfliers=True)\n",
    "\n",
    "    a.tick_params(axis='both', labelsize=fontsize)\n",
    "    a.set_title(mode, fontsize=fontsize+2)\n",
    "    \n",
    "    if mode == \"Coverage\":\n",
    "        a.axhline((1-alpha), c='r', linestyle=\"-\", linewidth=1)\n",
    "        a.set_ylim([0.85,1])\n",
    "    \n",
    "    out_fp = (out_ff + dataset + \"_\" + mode + \"s_boxplot.pdf\").lower()\n",
    "    f.savefig(out_fp, bbox_inches=\"tight\")\n",
    "    print(\"Saving figure to \" + out_fp + \".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                     svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                     slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                     out_ff, mode, dataset, alpha):\n",
    "    \n",
    "    fontsize = 16\n",
    "    \n",
    "    if mode == \"Coverage\":\n",
    "        obj_to_plot = np.vstack([ols_naive_cvgs,\n",
    "                                 slab_rob_cvgs,\n",
    "                                 ols_rob_cvgs,\n",
    "                                 svm_rob_cvgs]).T\n",
    "\n",
    "    else:\n",
    "        obj_to_plot = np.vstack([ols_naive_lens,\n",
    "                                 slab_rob_lens,\n",
    "                                 ols_rob_lens,\n",
    "                                 svm_rob_lens]).T\n",
    "        \n",
    "    f,a = plt.subplots(figsize=(12,5))\n",
    "    a.boxplot(obj_to_plot,\n",
    "              labels=[\"Standard\",\n",
    "                      \"Chi-squared, sampling\",\n",
    "                      \"Chi-squared, regression\",\n",
    "                      \"Chi-squared, classification\"],\n",
    "              showmeans=True,\n",
    "              showfliers=True)\n",
    "\n",
    "    a.tick_params(axis='both', labelsize=fontsize)\n",
    "    a.tick_params(axis=\"x\", rotation=45)\n",
    "    a.set_title(mode, fontsize=fontsize+2)\n",
    "    \n",
    "    if mode == \"Coverage\":\n",
    "        a.axhline((1-alpha), c='r', linestyle=\"-\", linewidth=1)\n",
    "        a.set_ylim([a.get_ylim()[0],1])\n",
    "    \n",
    "    out_fp = (out_ff + dataset + \"_\" + mode + \"s_boxplot.pdf\").lower()\n",
    "    f.savefig(out_fp, bbox_inches=\"tight\")\n",
    "    print(\"Saving figure to \" + out_fp + \".\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_fp(fp_in, out_ff, dataset, alpha):\n",
    "    \n",
    "    results = pd.read_csv(fp_in).iloc[:,1:]\n",
    "    results = results[results[\"Conformalization\"] == \"Marginal\"]\n",
    "    ((naive_cvgs, naive_lens), (rob_cvgs, rob_lens)) = get_cvgs_and_lens(results)\n",
    "    \n",
    "    make_plots(naive_cvgs, rob_cvgs, \"Coverage\", out_ff, dataset, alpha)\n",
    "    make_plots(naive_lens, rob_lens, \"Size\", out_ff, dataset, alpha)\n",
    "    \n",
    "    return naive_cvgs, naive_lens, rob_cvgs, rob_lens"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MNIST"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens = process_fp(\n",
    "    mnist_ff + \"slab_quantile-1-summary.csv\", out_ff, \"mnist_ols\", alpha)\n",
    "\n",
    "svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens = process_fp(\n",
    "    mnist_ff + \"slab_quantile-2-summary.csv\", out_ff, \"mnist_svm\", alpha)\n",
    "\n",
    "slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens = process_fp(\n",
    "    mnist_ff + \"slab_quantile-3-summary.csv\", out_ff, \"mnist_slab_quantile\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Coverage\", \"mnist\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Size\", \"mnist\", alpha)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CIFAR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens = process_fp(\n",
    "    cifar_ff + \"slab_quantile-1-summary.csv\", out_ff, \"cifar-10_ols\", alpha)\n",
    "\n",
    "svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens = process_fp(\n",
    "    cifar_ff + \"slab_quantile-2-summary.csv\", out_ff, \"cifar-10_svm\", alpha)\n",
    "\n",
    "slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens = process_fp(\n",
    "    cifar_ff + \"slab_quantile-3-summary.csv\", out_ff, \"cifar-10_slab_quantile\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Coverage\", \"cifar-10\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Size\", \"cifar-10\", alpha)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "alpha = 0.1\n",
    "\n",
    "ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens = process_fp(\n",
    "    imagenet_ff + \"slab_quantile-1-summary.csv\", out_ff, \"imagenet_ols\", alpha)\n",
    "\n",
    "svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens = process_fp(\n",
    "    imagenet_ff + \"slab_quantile-2-summary.csv\", out_ff, \"imagenet_svm\", alpha)\n",
    "\n",
    "slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens = process_fp(\n",
    "    imagenet_ff + \"slab_quantile-3-summary.csv\", out_ff, \"imagenet_slab_quantile\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Coverage\", \"imagenet\", alpha)\n",
    "\n",
    "make_joint_plots(ols_naive_cvgs, ols_naive_lens, ols_rob_cvgs, ols_rob_lens,\n",
    "                 svm_naive_cvgs, svm_naive_lens, svm_rob_cvgs, svm_rob_lens,\n",
    "                 slab_naive_cvgs, slab_naive_lens, slab_rob_cvgs, slab_rob_lens,\n",
    "                 out_ff, \"Size\", \"imagenet\", alpha)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
