{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add the parent directory to the path\n",
    "import matplotlib.pyplot as plt\n",
    "from tools import data_tools\n",
    "import torch\n",
    "\n",
    "\n",
    "config = data_tools.read_config(\n",
    "    'corruption_doctor_config.yaml')\n",
    "\n",
    "model_name = config[\"model_name\"]= 'resnet34_custom'\n",
    "match_dataset_name = config[\"match_dataset_name\"]\n",
    "corrupted_dataset_name = config[\"corrupted_dataset_name\"]\n",
    "model_seed = config[\"model_seed\"]\n",
    "data_path = config[\"data_path\"]\n",
    "device_id = config[\"device_id\"]\n",
    "magnitudes = config[\"magnitudes\"]\n",
    "temperatures = config[\"temperatures\"]\n",
    "batch_size = config[\"batch_size\"]\n",
    "rs = config[\"rs\"]\n",
    "seeds = config[\"seeds\"] = [1]\n",
    "corruptions = config[\"corruptions\"] = ['brightness']\n",
    "intensities = config[\"intensities\"] = [1]\n",
    "\n",
    "# print the config one by one\n",
    "for key, value in config.items():\n",
    "    print(key, value)\n",
    "\n",
    "# set the device to cpu\n",
    "device = torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in seeds:\n",
    "    for r in rs:\n",
    "        for corruption in corruptions:\n",
    "            for intensity in intensities:\n",
    "                # create a figure \n",
    "                # the figure has 2 subplots side by side\n",
    "                fig, ax = plt.subplots(1, 2, figsize=(15, 5))\n",
    "                max_auc = float('-inf')\n",
    "                fpr_at_95_tpr_at_max_auc = None\n",
    "                tpr_at_max_auc = None\n",
    "                magnitude_at_max_auc = None\n",
    "                temperature_at_max_auc = None \n",
    "                \n",
    "                for magnitude in magnitudes:\n",
    "                    for temperature in temperatures:\n",
    "                        dest_folder = f'{match_dataset_name}_to_{corrupted_dataset_name}/{model_name}/model_seed_{model_seed}/{corruption}_{intensity}/results/r_{r}/seed_{seed}'\n",
    "                        final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                        doctor_val_fprs = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_fprs.pt')\n",
    "                        doctor_val_tprs = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_tprs.pt')\n",
    "                        doctor_val_thresholds = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_thresholds.pt')\n",
    "                        doctor_val_fpr = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_fpr.pt')\n",
    "                        doctor_val_tpr = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_tpr.pt')\n",
    "                        doctor_val_threshold = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_threshold.pt')\n",
    "                        doctor_val_auc = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_val_auc.pt')\n",
    "                            \n",
    "                        if doctor_val_auc > max_auc:\n",
    "                            max_auc = doctor_val_auc\n",
    "                            fpr_at_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                            tpr_at_max_auc = doctor_val_tpr\n",
    "                            magnitude_at_max_auc = magnitude\n",
    "                            temperature_at_max_auc = temperature\n",
    "                        elif doctor_val_auc == max_auc:\n",
    "                            if doctor_val_fpr < fpr_at_95_tpr_at_max_auc:\n",
    "                                fpr_at_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                                tpr_at_max_auc = doctor_val_tpr\n",
    "                                magnitude_at_max_auc = magnitude\n",
    "                                temperature_at_max_auc = temperature\n",
    "\n",
    "                        # plot roc curve\n",
    "                        ax[0].plot(doctor_val_fprs, doctor_val_tprs,\n",
    "                                label=f\"r={r}, magnitude={magnitude}, temperature={temperature}\")\n",
    "                        # plot diagonal line\n",
    "\n",
    "                ax[0].plot([0, 1], [0, 1], linestyle='--', label='Random Guess', color='red')\n",
    "\n",
    "                # put a marker x at the max auc\n",
    "                ax[0].plot(fpr_at_95_tpr_at_max_auc, tpr_at_max_auc, marker='x', color='green',\n",
    "                        label=f\"max auc: {max_auc:.2f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}\")\n",
    "                # connext y axis with a line at 95 tpr and min fpr dashed\n",
    "                ax[0].plot([fpr_at_95_tpr_at_max_auc, fpr_at_95_tpr_at_max_auc], [0,\n",
    "                        tpr_at_max_auc], linestyle='--', color='green')\n",
    "                # connext y axis with a line at 95 tpr and min fpr dashed\n",
    "                ax[0].plot([0, fpr_at_95_tpr_at_max_auc], [tpr_at_max_auc,\n",
    "                        tpr_at_max_auc], linestyle='--', color='green')\n",
    "                # annotate the max auc marker with the fpr at auc and tpr at auc\n",
    "                ax[0].annotate(f\"(fpr:{fpr_at_95_tpr_at_max_auc:.2f}, tpr:{tpr_at_max_auc:.2f})\",\n",
    "                            (fpr_at_95_tpr_at_max_auc, tpr_at_max_auc-.07))\n",
    "\n",
    "                # plot legend\n",
    "                # ax.legend()\n",
    "                # set title for the entire figure\n",
    "                fig.suptitle(f\"ROC Curve for {corruption} {intensity} on {match_dataset_name} to {corrupted_dataset_name} {model_name} model seed {model_seed} seed {seed} r {r} corruption {corruption} intensity {intensity}\")\n",
    "                # plot x label\n",
    "                ax[0].set_xlabel(\"False Positive Rate\")\n",
    "                # plot y label\n",
    "                ax[0].set_ylabel(\"True Positive Rate\")\n",
    "                # plot grid \n",
    "                ax[0].grid()\n",
    "                # xlim to 0 to 1\n",
    "                ax[0].set_xlim(0, 1)\n",
    "                # ylim to 0 to 1\n",
    "                ax[0].set_ylim(0, 1)\n",
    "\n",
    "                dest_folder = f'{match_dataset_name}_to_{corrupted_dataset_name}/{model_name}/model_seed_{model_seed}/{corruption}_{intensity}/results/r_{r}/seed_{seed}'\n",
    "                final_dest_folder = f'{dest_folder}/magnitude_{magnitude_at_max_auc}/temperature_{temperature_at_max_auc}'\n",
    "\n",
    "                doctor_test_fprs = torch.load(\n",
    "                            f'{final_dest_folder}/doctor_test_fprs.pt')\n",
    "                doctor_test_tprs = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_tprs.pt')\n",
    "                doctor_test_thresholds = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_thresholds.pt')\n",
    "                doctor_test_fpr = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_fpr.pt')\n",
    "                doctor_test_tpr = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_tpr.pt')\n",
    "                doctor_test_threshold = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_threshold.pt')\n",
    "                doctor_test_auc = torch.load(\n",
    "                    f'{final_dest_folder}/doctor_test_auc.pt')\n",
    "                \n",
    "                # plot roc curve\n",
    "                ax[1].plot(doctor_test_fprs, doctor_test_tprs,\n",
    "                        label=f\"r={r}, magnitude={magnitude}, temperature={temperature}\")\n",
    "                # put a marker at the doctor test fpr and tpr\n",
    "                ax[1].plot(doctor_test_fpr, doctor_test_tpr, marker='o', color='green',\n",
    "                        label=f\"test fpr: {doctor_test_fpr:.2f}, test tpr: {doctor_test_tpr:.2f}\")\n",
    "                # connext y axis with a line at doctor test fpr and tpr dashed\n",
    "                ax[1].plot([doctor_test_fpr, doctor_test_fpr], [0,\n",
    "                        doctor_test_tpr], linestyle='--', color='green')\n",
    "                # connext x axis with a line at doctor test fpr and tpr dashed\n",
    "                ax[1].plot([0, doctor_test_fpr], [doctor_test_tpr,\n",
    "                        doctor_test_tpr], linestyle='--', color='green')\n",
    "                # set the x axis to 0 to 1\n",
    "                ax[1].set_xlim(0, 1)\n",
    "                # set the y axis to 0 to 1\n",
    "                ax[1].set_ylim(0, 1)\n",
    "                # annotate the doctor test fpr and and tpr with the values\n",
    "                ax[1].annotate(f\"(fpr:{doctor_test_fpr:.2f}, tpr:{doctor_test_tpr:.2f})\",\n",
    "                            (doctor_test_fpr, doctor_test_tpr))\n",
    "                ax[0].plot([0, 1], [0, 1], linestyle='--', label='Random Guess', color='red')\n",
    "\n",
    "                plt.show()\n",
    "\n",
    "\n",
    "plt.close('all')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_env",
   "language": "python",
   "name": "jupyter_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
