{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the directory two levels up (relative to the kernel's current working\n",
    "# directory) importable; presumably this is where the `tools` package used in\n",
    "# the next cell lives. Guarded so that re-running this cell does not append a\n",
    "# duplicate sys.path entry.\n",
    "PROJECT_ROOT = '../..'\n",
    "if PROJECT_ROOT not in sys.path:\n",
    "    sys.path.append(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Project-local helper; presumably resolved via the '../..' entry that the\n",
    "# previous cell adds to sys.path.\n",
    "from tools import data_tools\n",
    "\n",
    "# Load the experiment configuration from the current working directory.\n",
    "config = data_tools.read_config(\n",
    "    './mismatch_doctor_config.yaml')\n",
    "\n",
    "model_name = config[\"model_name\"]\n",
    "match_dataset_name = config[\"match_dataset_name\"]\n",
    "mismatch_dataset_name = config[\"mismatch_dataset_name\"]\n",
    "model_seed = config[\"model_seed\"]\n",
    "data_path = config[\"data_path\"]\n",
    "device_id = config[\"device_id\"]  # not used: the device is forced to CPU below\n",
    "magnitudes = config[\"magnitudes\"]\n",
    "temperatures = config[\"temperatures\"]\n",
    "batch_size = config[\"batch_size\"]\n",
    "\n",
    "# Explicit overrides for the `rs` and `seeds` lists of the config, used to\n",
    "# restrict the analysis to a subset of runs. Set either one to None to use the\n",
    "# value from the YAML file instead. The effective value is written back into\n",
    "# `config` so that the printout below shows what is actually analysed.\n",
    "RS_OVERRIDE = [10]\n",
    "SEEDS_OVERRIDE = [1]\n",
    "rs = config[\"rs\"] = RS_OVERRIDE if RS_OVERRIDE is not None else config[\"rs\"]\n",
    "seeds = config[\"seeds\"] = SEEDS_OVERRIDE if SEEDS_OVERRIDE is not None else config[\"seeds\"]\n",
    "\n",
    "# Print the effective configuration, one key per line.\n",
    "for key, value in config.items():\n",
    "    print(key, value)\n",
    "\n",
    "# Run on the CPU; `device_id` from the config is ignored.\n",
    "device = torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameter selection on the validation split.\n",
    "# For each (seed, r) pair, scan the (magnitude, temperature) grid, plot every\n",
    "# validation ROC curve and track two candidate settings:\n",
    "#   * lowest `doctor_val_fpr` (presumably FPR at 95% TPR), ties -> higher AUC;\n",
    "#   * highest `doctor_val_auc`, ties -> lower FPR.\n",
    "# The max-AUC setting is stored as the selected one and re-used on the test split.\n",
    "if not magnitudes or not temperatures:\n",
    "    raise ValueError(\"`magnitudes` and `temperatures` must both be non-empty\")\n",
    "\n",
    "best_magnitude_dict = {}\n",
    "best_temperature_dict = {}\n",
    "\n",
    "for seed in seeds:\n",
    "    # One entry per r for this seed. Created once per seed (not once per r) so\n",
    "    # that every r is kept and the test cell can look up any (seed, r) pair.\n",
    "    best_magnitude_dict_r = {}\n",
    "    best_temperature_dict_r = {}\n",
    "\n",
    "    for r in rs:\n",
    "        # Reset the running bests for every r so that a setting found for one r\n",
    "        # cannot be reported as the best setting of another r.\n",
    "        min_fpr_at_95_tpr = float('inf')\n",
    "        tpr_at_min_fpr = None\n",
    "        auc_at_min_fpr = None\n",
    "        magnitude_at_min_fpr = None\n",
    "        temperature_at_min_fpr = None\n",
    "\n",
    "        max_auc = float('-inf')\n",
    "        fpr_at_95_tpr_at_max_auc = None\n",
    "        tpr_at_max_auc = None\n",
    "        magnitude_at_max_auc = None\n",
    "        temperature_at_max_auc = None\n",
    "\n",
    "        dest_folder = f'./{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/results/r_{r}/seed_{seed}'\n",
    "        fig, ax = plt.subplots(figsize=(10, 5))\n",
    "\n",
    "        for magnitude in magnitudes:\n",
    "            for temperature in temperatures:\n",
    "                final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                # Points of the validation ROC curve for this setting.\n",
    "                doctor_val_fprs = torch.load(f'{final_dest_folder}/doctor_val_fprs.pt')\n",
    "                doctor_val_tprs = torch.load(f'{final_dest_folder}/doctor_val_tprs.pt')\n",
    "                # Scalar summaries used for the selection below.\n",
    "                doctor_val_fpr = torch.load(f'{final_dest_folder}/doctor_val_fpr.pt')\n",
    "                doctor_val_tpr = torch.load(f'{final_dest_folder}/doctor_val_tpr.pt')\n",
    "                doctor_val_auc = torch.load(f'{final_dest_folder}/doctor_val_auc.pt')\n",
    "\n",
    "                # Candidate 1: lowest FPR, ties broken by higher AUC.\n",
    "                if doctor_val_fpr < min_fpr_at_95_tpr:\n",
    "                    min_fpr_at_95_tpr = doctor_val_fpr\n",
    "                    tpr_at_min_fpr = doctor_val_tpr\n",
    "                    auc_at_min_fpr = doctor_val_auc\n",
    "                    magnitude_at_min_fpr = magnitude\n",
    "                    temperature_at_min_fpr = temperature\n",
    "                elif doctor_val_fpr == min_fpr_at_95_tpr and doctor_val_auc > auc_at_min_fpr:\n",
    "                    tpr_at_min_fpr = doctor_val_tpr\n",
    "                    auc_at_min_fpr = doctor_val_auc\n",
    "                    magnitude_at_min_fpr = magnitude\n",
    "                    temperature_at_min_fpr = temperature\n",
    "\n",
    "                # Candidate 2: highest AUC, ties broken by lower FPR.\n",
    "                if doctor_val_auc > max_auc:\n",
    "                    max_auc = doctor_val_auc\n",
    "                    fpr_at_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                    tpr_at_max_auc = doctor_val_tpr\n",
    "                    magnitude_at_max_auc = magnitude\n",
    "                    temperature_at_max_auc = temperature\n",
    "                elif doctor_val_auc == max_auc and doctor_val_fpr < fpr_at_95_tpr_at_max_auc:\n",
    "                    fpr_at_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                    tpr_at_max_auc = doctor_val_tpr\n",
    "                    magnitude_at_max_auc = magnitude\n",
    "                    temperature_at_max_auc = temperature\n",
    "\n",
    "                # Validation ROC curve of this setting.\n",
    "                ax.plot(doctor_val_fprs, doctor_val_tprs,\n",
    "                        label=f\"r={r}, magnitude={magnitude}, temperature={temperature}\")\n",
    "\n",
    "        # Diagonal reference line of a random classifier.\n",
    "        ax.plot([0, 1], [0, 1], linestyle='--', label='Random Guess')\n",
    "\n",
    "        # Min-FPR setting: orange 'x', dashed guides down to the x axis and\n",
    "        # across to the y axis, plus a text annotation.\n",
    "        ax.plot(min_fpr_at_95_tpr, tpr_at_min_fpr, marker='x', color='orange',\n",
    "                label=f\"min fpr at 95 tpr: {min_fpr_at_95_tpr:.2f}, magnitude={magnitude_at_min_fpr}, temperature={temperature_at_min_fpr}\")\n",
    "        ax.plot([min_fpr_at_95_tpr, min_fpr_at_95_tpr], [0, tpr_at_min_fpr],\n",
    "                linestyle='--', color='orange')\n",
    "        ax.plot([0, min_fpr_at_95_tpr], [tpr_at_min_fpr, tpr_at_min_fpr],\n",
    "                linestyle='--', color='orange')\n",
    "        ax.annotate(f\"(fpr: {min_fpr_at_95_tpr:.2f}, tpr: {tpr_at_min_fpr:.2f})\",\n",
    "                    (min_fpr_at_95_tpr-.1, tpr_at_min_fpr+.02))\n",
    "\n",
    "        # Max-AUC setting: green 'x', dashed guides down to the x axis and\n",
    "        # across to the y axis, plus a text annotation.\n",
    "        ax.plot(fpr_at_95_tpr_at_max_auc, tpr_at_max_auc, marker='x', color='green',\n",
    "                label=f\"max auc: {max_auc:.2f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}\")\n",
    "        ax.plot([fpr_at_95_tpr_at_max_auc, fpr_at_95_tpr_at_max_auc], [0, tpr_at_max_auc],\n",
    "                linestyle='--', color='green')\n",
    "        ax.plot([0, fpr_at_95_tpr_at_max_auc], [tpr_at_max_auc, tpr_at_max_auc],\n",
    "                linestyle='--', color='green')\n",
    "        ax.annotate(f\"(fpr:{fpr_at_95_tpr_at_max_auc:.2f}, tpr:{tpr_at_max_auc:.2f})\",\n",
    "                    (fpr_at_95_tpr_at_max_auc, tpr_at_max_auc-.07))\n",
    "\n",
    "        # The legend is off by default because there is one curve per grid point;\n",
    "        # call ax.legend() here to show it.\n",
    "        ax.set_title(\n",
    "            f\"ROC Curve for {model_name} on {match_dataset_name} to {mismatch_dataset_name} with seed {seed}, r={r}\")\n",
    "        ax.set_xlabel(\"False Positive Rate\")\n",
    "        ax.set_ylabel(\"True Positive Rate\")\n",
    "        ax.grid()\n",
    "        ax.set_xlim(0, 1)\n",
    "        ax.set_ylim(0, 1)\n",
    "        plt.show()\n",
    "        # Release the figure now instead of keeping one open per (seed, r).\n",
    "        plt.close(fig)\n",
    "\n",
    "        best_magnitude_dict_r[f\"r_{r}\"] = magnitude_at_max_auc\n",
    "        best_temperature_dict_r[f\"r_{r}\"] = temperature_at_max_auc\n",
    "\n",
    "        print(f\"min fpr at 95 tpr: {min_fpr_at_95_tpr:.3f}, magnitude={magnitude_at_min_fpr}, temperature={temperature_at_min_fpr}, tpr={tpr_at_min_fpr:.3f}, auc={auc_at_min_fpr:.3f}\")\n",
    "        print(f\"max auc: {max_auc:.3f}, magnitude={magnitude_at_max_auc}, temperature={temperature_at_max_auc}, fpr={fpr_at_95_tpr_at_max_auc:.3f}, tpr={tpr_at_max_auc:.3f}\")\n",
    "\n",
    "    best_temperature_dict[f\"seed_{seed}\"] = best_temperature_dict_r\n",
    "    best_magnitude_dict[f\"seed_{seed}\"] = best_magnitude_dict_r\n",
    "\n",
    "# Safety net: close any figure that is still open.\n",
    "plt.close('all')\n",
    "\n",
    "print(f\"best magnitude dict: {best_magnitude_dict}\")\n",
    "print(f\"best temperature dict: {best_temperature_dict}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-split evaluation of the setting selected on the validation split\n",
    "# (max validation AUC): one figure per (seed, r) with the test ROC curve and\n",
    "# the stored test operating point.\n",
    "for seed in seeds:\n",
    "    for r in rs:\n",
    "        # Hyper-parameters selected for this (seed, r) by the previous cell.\n",
    "        magnitude = best_magnitude_dict[f\"seed_{seed}\"][f\"r_{r}\"]\n",
    "        temperature = best_temperature_dict[f\"seed_{seed}\"][f\"r_{r}\"]\n",
    "\n",
    "        dest_folder = f'./{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/results/r_{r}/seed_{seed}'\n",
    "        final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "        # Points of the test ROC curve, then the scalar summaries of this setting.\n",
    "        doctor_test_fprs = torch.load(f'{final_dest_folder}/doctor_test_fprs.pt')\n",
    "        doctor_test_tprs = torch.load(f'{final_dest_folder}/doctor_test_tprs.pt')\n",
    "        doctor_test_fpr = torch.load(f'{final_dest_folder}/doctor_test_fpr.pt')\n",
    "        doctor_test_tpr = torch.load(f'{final_dest_folder}/doctor_test_tpr.pt')\n",
    "        doctor_test_auc = torch.load(f'{final_dest_folder}/doctor_test_auc.pt')\n",
    "\n",
    "        fig, ax = plt.subplots(figsize=(10, 5))\n",
    "        ax.plot(doctor_test_fprs, doctor_test_tprs,\n",
    "                label=f\"r={r}, magnitude={magnitude}, temperature={temperature}\")\n",
    "\n",
    "        # Operating point: green dot, dashed guides down to the x axis and\n",
    "        # across to the y axis, plus a text annotation.\n",
    "        ax.plot(doctor_test_fpr, doctor_test_tpr, marker='o', color='green',\n",
    "                label=f\"test fpr: {doctor_test_fpr:.2f}, test tpr: {doctor_test_tpr:.2f}\")\n",
    "        ax.plot([doctor_test_fpr, doctor_test_fpr], [0, doctor_test_tpr],\n",
    "                linestyle='--', color='green')\n",
    "        ax.plot([0, doctor_test_fpr], [doctor_test_tpr, doctor_test_tpr],\n",
    "                linestyle='--', color='green')\n",
    "        ax.annotate(f\"(fpr:{doctor_test_fpr:.2f}, tpr:{doctor_test_tpr:.2f})\",\n",
    "                    (doctor_test_fpr, doctor_test_tpr))\n",
    "\n",
    "        # Diagonal reference line of a random classifier.\n",
    "        ax.plot([0, 1], [0, 1], linestyle='--', label='Random Guess', color='red')\n",
    "\n",
    "        ax.set_title(\n",
    "            f\"Test ROC Curve for {model_name} on {match_dataset_name} to {mismatch_dataset_name} with seed {seed}, r={r}\")\n",
    "        ax.set_xlabel(\"False Positive Rate\")\n",
    "        ax.set_ylabel(\"True Positive Rate\")\n",
    "        ax.grid()\n",
    "        ax.set_xlim(0, 1)\n",
    "        ax.set_ylim(0, 1)\n",
    "\n",
    "        print(f\"test auc: {doctor_test_auc:.3f}, magnitude={magnitude}, temperature={temperature}, fpr={doctor_test_fpr:.3f}, tpr={doctor_test_tpr:.3f}\")\n",
    "        plt.show()\n",
    "        # Release the figure now instead of keeping one open per (seed, r).\n",
    "        plt.close(fig)\n",
    "\n",
    "# Safety net: close any figure that is still open.\n",
    "plt.close('all')\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_env",
   "language": "python",
   "name": "jupyter_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
