{
 "cells": [
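  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Report plots for the corruption analysis: CIFAR10 to CIFAR10-C\n",
    "\n",
    "In this notebook, we plot the results of the mismatch analysis from CIFAR10 (matched dataset) to CIFAR10-C (mismatched, corrupted dataset), comparing Doctor with the D-matrix method. For each value of r and each seed, the temperature and magnitude are selected on the validation split (highest AUC). The test AUC and the test FPR at 95% TPR obtained with that selection are then summarized over the seeds (mean and standard deviation)."
   ]
  },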
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import pandas as pn\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model / data identifiers, used to build the result folder paths in the cells below.\n",
    "model_name = 'resnet34_custom'\n",
    "match_dataset_name = 'cifar10'\n",
    "mismatch_dataset_name = 'cifar10c'\n",
    "model_seed = 1\n",
    "# NOTE(review): data_path, device_id and batch_size are not read by any cell of this notebook.\n",
    "data_path = 'data'\n",
    "device_id = 0\n",
    "batch_size = 1000\n",
    "\n",
    "# r values to report; the plots use 1/r on the x-axis.\n",
    "rs = [10, 5, 3, 2]\n",
    "# Every seed must appear exactly once: a repeated seed is loaded twice and\n",
    "# double-counted in the per-r mean/std of the plots (seed 4 used to be listed twice).\n",
    "seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "assert len(seeds) == len(set(seeds)), 'seeds must not contain duplicates'\n",
    "\n",
    "# (temperature, magnitude) grid searched on the validation split.\n",
    "# The values are formatted with f-strings into folder names, so their textual\n",
    "# form must match the folders on disk.\n",
    "temperatures = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 2.0, 2.5, 3.0, 100.0, 1000.0]\n",
    "magnitudes = [0.0, 0.0002, 0.00025, 0.0003, 0.00035, 0.0004, 0.0006, 0.0008,\n",
    "              0.001, 0.0012, 0.0014, 0.0016, 0.0018, 0.002, 0.0022, 0.0024,\n",
    "              0.0026, 0.0028, 0.003, 0.0032, 0.0034, 0.0036, 0.0038, 0.004]\n",
    "\n",
    "# D-matrix settings that appear in its result folder path (lr_/epochs_/lbd_).\n",
    "lbd = .8\n",
    "lr = 0.1\n",
    "epochs = 100\n",
    "\n",
    "# Corruption to analyse. Available CIFAR10-C corruptions:\n",
    "#   brightness, contrast, defocus_blur, elastic_transform, fog, frost,\n",
    "#   gaussian_blur, gaussian_noise, glass_blur, impulse_noise, jpeg_compression,\n",
    "#   motion_blur, pixelate, saturate, shot_noise, snow, spatter, speckle_noise,\n",
    "#   zoom_blur\n",
    "corruption = 'contrast'\n",
    "intensity = 5\n"
   ]
  },
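  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Collect the validation results for Doctor. For each (r, seed), select the (temperature, magnitude) pair with the highest validation AUC; ties are broken in favour of the lower validation FPR. The following cell loads the test metrics obtained with that selection."
   ]
  },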
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-based hyperparameter selection for Doctor.\n",
    "# For every (r, seed), scan the (temperature, magnitude) grid and keep the pair with\n",
    "# the highest validation AUC; ties are broken in favour of the lower validation FPR.\n",
    "# Only the three scalar files needed for the selection are loaded: the ROC-curve\n",
    "# arrays and threshold files (4 of the 7 files per grid point) are not used here.\n",
    "r_dict_doctor = {}\n",
    "for r in rs:\n",
    "    seed_dict = {}\n",
    "    for seed in seeds:\n",
    "        # These paths do not depend on the grid point, so build them once per (r, seed).\n",
    "        source_folder = f'doctor/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/{corruption}_{intensity}'\n",
    "        dest_folder = f'{source_folder}/results/r_{r}/seed_{seed}'\n",
    "\n",
    "        max_auc = -float('inf')\n",
    "        fpr_95_tpr_at_max_auc = None\n",
    "        tpr_at_max_auc = None\n",
    "        temperature_at_max_auc = None\n",
    "        magnitude_at_max_auc = None\n",
    "\n",
    "        for temperature in temperatures:\n",
    "            for magnitude in magnitudes:\n",
    "                final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                # Assumes each file holds a single scalar (they are compared as scalars);\n",
    "                # float() turns a 0-d tensor or NumPy scalar into a plain Python float.\n",
    "                doctor_val_auc = float(torch.load(f'{final_dest_folder}/doctor_val_auc.pt'))\n",
    "                doctor_val_fpr = float(torch.load(f'{final_dest_folder}/doctor_val_fpr.pt'))\n",
    "                doctor_val_tpr = float(torch.load(f'{final_dest_folder}/doctor_val_tpr.pt'))\n",
    "\n",
    "                better_auc = doctor_val_auc > max_auc\n",
    "                same_auc_lower_fpr = (doctor_val_auc == max_auc\n",
    "                                      and doctor_val_fpr < fpr_95_tpr_at_max_auc)\n",
    "                if better_auc or same_auc_lower_fpr:\n",
    "                    max_auc = doctor_val_auc\n",
    "                    fpr_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                    tpr_at_max_auc = doctor_val_tpr\n",
    "                    temperature_at_max_auc = temperature\n",
    "                    magnitude_at_max_auc = magnitude\n",
    "\n",
    "        # A NaN AUC never wins a comparison, so an all-NaN grid would select nothing;\n",
    "        # fail here instead of later on a 'magnitude_None' path.\n",
    "        assert temperature_at_max_auc is not None, f'no valid validation AUC under {dest_folder}'\n",
    "        seed_dict[seed] = {\n",
    "            'max_auc': max_auc,\n",
    "            'fpr_95_tpr_at_max_auc': fpr_95_tpr_at_max_auc,\n",
    "            'tpr_at_max_auc': tpr_at_max_auc,\n",
    "            'temperature_at_max_auc': temperature_at_max_auc,\n",
    "            'magnitude_at_max_auc': magnitude_at_max_auc,\n",
    "        }\n",
    "    r_dict_doctor[r] = seed_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test metrics of Doctor at the (temperature, magnitude) selected on validation.\n",
    "# One row per (r, seed) with columns: r (stored as 1/r), seed, auc, fpr_95_tpr.\n",
    "# Rows are collected in a list and the frame is built once. Growing a frame with\n",
    "# pn.concat in a loop is quadratic, and concatenating onto an empty frame triggers a\n",
    "# pandas FutureWarning and may yield object-dtype columns that\n",
    "# mean(numeric_only=True) would silently drop.\n",
    "doctor_rows = []\n",
    "for r in rs:\n",
    "    for seed in seeds:\n",
    "        selected_temperature = r_dict_doctor[r][seed]['temperature_at_max_auc']\n",
    "        selected_magnitude = r_dict_doctor[r][seed]['magnitude_at_max_auc']\n",
    "\n",
    "        source_folder = f'doctor/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/{corruption}_{intensity}'\n",
    "        dest_folder = f'{source_folder}/results/r_{r}/seed_{seed}'\n",
    "        final_dest_folder = f'{dest_folder}/magnitude_{selected_magnitude}/temperature_{selected_temperature}'\n",
    "\n",
    "        # Only AUC and FPR are reported, so the ROC arrays and thresholds are not loaded.\n",
    "        # float() keeps the columns numeric even if the files hold 0-d tensors.\n",
    "        doctor_test_auc = float(torch.load(f'{final_dest_folder}/doctor_test_auc.pt'))\n",
    "        doctor_test_fpr = float(torch.load(f'{final_dest_folder}/doctor_test_fpr.pt'))\n",
    "\n",
    "        doctor_rows.append({'r': 1 / float(r), 'seed': seed,\n",
    "                            'auc': doctor_test_auc, 'fpr_95_tpr': doctor_test_fpr})\n",
    "\n",
    "df_doctor = pn.DataFrame(doctor_rows, columns=['r', 'seed', 'auc', 'fpr_95_tpr'])\n",
    "assert len(df_doctor) == len(rs) * len(seeds)"
   ]
  },
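  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Collect the validation results for the D-matrix method. For each (r, seed), select the (temperature, magnitude) pair with the highest validation AUC; ties are broken in favour of the lower validation FPR. The following cell loads the test metrics obtained with that selection."
   ]
  },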
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-based hyperparameter selection for the D-matrix method.\n",
    "# For every (r, seed), scan the (temperature, magnitude) grid and keep the pair with\n",
    "# the highest validation AUC; ties are broken in favour of the lower validation FPR.\n",
    "# Only the three scalar files needed for the selection are loaded: the ROC-curve\n",
    "# arrays and threshold files (4 of the 7 files per grid point) are not used here.\n",
    "r_dict_D = {}\n",
    "for r in rs:\n",
    "    seed_dict = {}\n",
    "    for seed in seeds:\n",
    "        # These paths do not depend on the grid point, so build them once per (r, seed).\n",
    "        source_folder = f'd_matrix/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}'\n",
    "        dest_folder = f'{source_folder}/seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}/{corruption}_{intensity}'\n",
    "\n",
    "        max_auc = -float('inf')\n",
    "        fpr_95_tpr_at_max_auc = None\n",
    "        tpr_at_max_auc = None\n",
    "        temperature_at_max_auc = None\n",
    "        magnitude_at_max_auc = None\n",
    "\n",
    "        for temperature in temperatures:\n",
    "            for magnitude in magnitudes:\n",
    "                final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                # Assumes each file holds a single scalar (they are compared as scalars);\n",
    "                # float() turns a 0-d tensor or NumPy scalar into a plain Python float.\n",
    "                D_val_auc = float(torch.load(f'{final_dest_folder}/D_auc_val.pt'))\n",
    "                D_val_fpr = float(torch.load(f'{final_dest_folder}/D_fpr_val.pt'))\n",
    "                D_val_tpr = float(torch.load(f'{final_dest_folder}/D_tpr_val.pt'))\n",
    "\n",
    "                better_auc = D_val_auc > max_auc\n",
    "                same_auc_lower_fpr = (D_val_auc == max_auc\n",
    "                                      and D_val_fpr < fpr_95_tpr_at_max_auc)\n",
    "                if better_auc or same_auc_lower_fpr:\n",
    "                    max_auc = D_val_auc\n",
    "                    fpr_95_tpr_at_max_auc = D_val_fpr\n",
    "                    tpr_at_max_auc = D_val_tpr\n",
    "                    temperature_at_max_auc = temperature\n",
    "                    magnitude_at_max_auc = magnitude\n",
    "\n",
    "        # A NaN AUC never wins a comparison, so an all-NaN grid would select nothing;\n",
    "        # fail here instead of later on a 'magnitude_None' path.\n",
    "        assert temperature_at_max_auc is not None, f'no valid validation AUC under {dest_folder}'\n",
    "        seed_dict[seed] = {\n",
    "            'max_auc': max_auc,\n",
    "            'fpr_95_tpr_at_max_auc': fpr_95_tpr_at_max_auc,\n",
    "            'tpr_at_max_auc': tpr_at_max_auc,\n",
    "            'temperature_at_max_auc': temperature_at_max_auc,\n",
    "            'magnitude_at_max_auc': magnitude_at_max_auc,\n",
    "        }\n",
    "    r_dict_D[r] = seed_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test metrics of the D-matrix method at the (temperature, magnitude) selected on validation.\n",
    "# One row per (r, seed) with columns: r (stored as 1/r), seed, auc, fpr_95_tpr.\n",
    "# Rows are collected in a list and the frame is built once. Growing a frame with\n",
    "# pn.concat in a loop is quadratic, and concatenating onto an empty frame triggers a\n",
    "# pandas FutureWarning and may yield object-dtype columns that\n",
    "# mean(numeric_only=True) would silently drop.\n",
    "D_rows = []\n",
    "for r in rs:\n",
    "    for seed in seeds:\n",
    "        selected_temperature = r_dict_D[r][seed]['temperature_at_max_auc']\n",
    "        selected_magnitude = r_dict_D[r][seed]['magnitude_at_max_auc']\n",
    "\n",
    "        source_folder = f'd_matrix/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}'\n",
    "        dest_folder = f'{source_folder}/seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}/{corruption}_{intensity}'\n",
    "        final_dest_folder = f'{dest_folder}/magnitude_{selected_magnitude}/temperature_{selected_temperature}'\n",
    "\n",
    "        # Only AUC and FPR are reported, so the ROC arrays and thresholds are not loaded.\n",
    "        # float() keeps the columns numeric even if the files hold 0-d tensors.\n",
    "        D_test_auc = float(torch.load(f'{final_dest_folder}/D_auc_test.pt'))\n",
    "        D_test_fpr = float(torch.load(f'{final_dest_folder}/D_fpr_test.pt'))\n",
    "\n",
    "        D_rows.append({'r': 1 / float(r), 'seed': seed,\n",
    "                       'auc': D_test_auc, 'fpr_95_tpr': D_test_fpr})\n",
    "\n",
    "df_D = pn.DataFrame(D_rows, columns=['r', 'seed', 'auc', 'fpr_95_tpr'])\n",
    "assert len(df_D) == len(rs) * len(seeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two side-by-side panels: test AUC (left) and test FPR at 95% TPR (right).\n",
    "# For each method, the line is the mean over seeds and the shaded band is mean +/- 1 std.\n",
    "# The x values are 1/r for r in rs (the 'r' column of both frames already stores 1/r).\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))\n",
    "\n",
    "methods = [('Doctor', df_doctor, 'blue'), ('D', df_D, 'red')]\n",
    "panels = [(ax1, 'auc', 'AUC'), (ax2, 'fpr_95_tpr', 'FPR at 95% TPR')]\n",
    "\n",
    "for method_label, method_df, color in methods:\n",
    "    by_r = method_df.groupby('r')\n",
    "    mean_by_r = by_r.mean(numeric_only=True)\n",
    "    std_by_r = by_r.std(numeric_only=True)\n",
    "    for ax, column, ylabel in panels:\n",
    "        # Labelling the line makes the legend independent of the plotting order.\n",
    "        ax.plot(mean_by_r.index, mean_by_r[column], color=color, label=method_label)\n",
    "        ax.fill_between(mean_by_r.index,\n",
    "                        mean_by_r[column] - std_by_r[column],\n",
    "                        mean_by_r[column] + std_by_r[column],\n",
    "                        color=color,\n",
    "                        alpha=0.2)\n",
    "\n",
    "# Place the ticks exactly on the data points (1/3 rather than 0.33); round only the labels.\n",
    "tick_positions = sorted(1 / float(r) for r in rs)\n",
    "tick_labels = [str(round(t, 2)) for t in tick_positions]\n",
    "for ax, column, ylabel in panels:\n",
    "    # NOTE(review): the axis is labelled 'r' although the plotted values are 1/r.\n",
    "    ax.set_xlabel('r')\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_xticks(tick_positions)\n",
    "    ax.set_xticklabels(tick_labels)\n",
    "    ax.legend()\n",
    "\n",
    "fig.suptitle(f'{model_name}: {match_dataset_name} to {mismatch_dataset_name}, {corruption} (intensity {intensity})')\n",
    "fig.tight_layout()\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_env",
   "language": "python",
   "name": "jupyter_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
