{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import seaborn as sns\n",
    "from tools import data_tools\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "config = data_tools.read_config(\n",
    "    'corruption_d_matrix_config.yaml')\n",
    "\n",
    "##################\n",
    "##################\n",
    "##################\n",
    "\n",
    "model_name = config[\"model_name\"]='resnet34_custom'\n",
    "match_dataset_name = config[\"match_dataset_name\"]\n",
    "corrupted_dataset_name = config[\"corrupted_dataset_name\"]\n",
    "model_seed = config[\"model_seed\"]\n",
    "data_path = config[\"data_path\"]\n",
    "magnitudes = config[\"magnitudes\"]\n",
    "temperatures = config[\"temperatures\"]\n",
    "batch_size = config[\"batch_size\"]\n",
    "rs = config[\"rs\"]\n",
    "seeds = [1]  # config[\"seeds\"]\n",
    "lbds = config[\"lbds\"]\n",
    "lr = config[\"lr\"]\n",
    "epochs = config[\"epochs\"]\n",
    "batch_size = config[\"batch_size\"]\n",
    "corruptions = config[\"corruptions\"] = ['brightness']\n",
    "intensities = config[\"intensities\"] = [1]\n",
    "\n",
    "# print config one by one\n",
    "for key, value in config.items():\n",
    "    print(key, value)\n",
    "\n",
    "dest_folder = f\"{match_dataset_name}_to_{corrupted_dataset_name}/{model_name}/model_seed_{model_seed}\"\n",
    "device = torch.device(\"cpu\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_D_matrix(params):\n",
    "    params = torch.tril(params, diagonal=-1)\n",
    "    params = params + params.T\n",
    "    params = params.abs()\n",
    "    params = params / params.norm()\n",
    "    if params.device == torch.device('cpu'):\n",
    "        params = params.detach().cpu().numpy()\n",
    "    return params\n",
    "\n",
    "for seed in seeds:\n",
    "    for r in rs:\n",
    "        for lbd in lbds:\n",
    "            fig, axs = plt.subplots(2, 2, figsize=(10, 10))\n",
    "            dest_folder_seed = os.path.join(\n",
    "                dest_folder, f\"seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}\")\n",
    "            D_matrix = prepare_D_matrix(torch.load('/'.join((dest_folder_seed, 'D_matrix.pt')), map_location=device))\n",
    "            loss_history = torch.load('/'.join((dest_folder_seed,\n",
    "                                                'training_D_matrix_loss_history.pt')))\n",
    "            loss_history_pos = torch.load('/'.join((dest_folder_seed,\n",
    "                                                    'training_D_matrix_loss_history_pos.pt')))\n",
    "            loss_history_neg = torch.load('/'.join((dest_folder_seed,\n",
    "                                                    'training_D_matrix_loss_history_neg.pt')))\n",
    "            auc_history = torch.load('/'.join((dest_folder_seed,\n",
    "                                            'training_D_matrix_auc_history.pt')))\n",
    "            fpr_at_95_tpr_history = torch.load('/'.join((dest_folder_seed,\n",
    "                                                        'training_D_matrix_fpr_at_95_tpr_history.pt')))\n",
    "            # plot heatmap of D matrix using seaborn in the first subplot\n",
    "            ax = axs[0, 0]\n",
    "            ax.set_title('D matrix')\n",
    "            sns.heatmap(D_matrix, ax=ax, cmap='Greens', vmin=0, vmax=1)\n",
    "            # plot loss history in the second subplot\n",
    "            ax = axs[0, 1]\n",
    "            ax.set_title('Loss history')\n",
    "            ax.plot(loss_history)\n",
    "            ax.plot(loss_history_pos)\n",
    "            ax.plot(loss_history_neg)\n",
    "            ax.legend(['loss', 'loss_pos', 'loss_neg'])\n",
    "            # plot auc history in the third subplot\n",
    "            ax = axs[1, 0]\n",
    "            ax.set_title('AUC history')\n",
    "            ax.plot(auc_history)\n",
    "            # plot fpr at 95 tpr history in the fourth subplot\n",
    "            ax = axs[1, 1]\n",
    "            ax.set_title('FPR at 95 TPR history')\n",
    "            ax.plot(fpr_at_95_tpr_history)\n",
    "            # show the figure\n",
    "            fig.suptitle(f\"seed {seed}, r {r}, lbd {lbd}\")\n",
    "            # acivate grid for all subplots in the figure but the first one\n",
    "            for ax in fig.axes[1:]:\n",
    "                ax.grid()\n",
    "            plt.show()\n",
    "    \n",
    "# close all the figures\n",
    "plt.close('all')\n",
    "        \n",
    "        \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for seed in seeds:\n",
    "    # create a figure with 2 subplots\n",
    "    for corruption in corruptions:\n",
    "        for intensity in intensities:\n",
    "            for r in rs:\n",
    "                for lbd in lbds:\n",
    "                    max_auc = -float('inf')\n",
    "                    fpr_max_auc = None\n",
    "                    tpr_max_auc = None\n",
    "                    temperature_max_auc = None\n",
    "                    magnitude_max_auc = None\n",
    "\n",
    "                    fig, axs = plt.subplots(1, 2, figsize=(10, 5))\n",
    "                    dest_folder_seed = os.path.join(\n",
    "                        dest_folder, f\"seed_{seed}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}/{corruption}_{intensity}\")\n",
    "                    for magnitude in magnitudes:\n",
    "                        magnitude_folder = os.path.join(\n",
    "                            dest_folder_seed, f\"magnitude_{magnitude}\")\n",
    "                        for temperature in temperatures:\n",
    "                            temperature_folder = os.path.join(\n",
    "                                magnitude_folder, f\"temperature_{temperature}\")\n",
    "                            D_fprs_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_fprs_val.pt']))\n",
    "                            D_tprs_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_tprs_val.pt']))\n",
    "                            D_thresholds_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_thresholds_val.pt']))\n",
    "                            D_fpr_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_fpr_val.pt']))\n",
    "                            D_tpr_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_tpr_val.pt']))\n",
    "                            D_threshold_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_threshold_val.pt']))\n",
    "                            D_auc_val = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_auc_val.pt']))\n",
    "                            if D_auc_val > max_auc:\n",
    "                                max_auc = D_auc_val\n",
    "                                fpr_max_auc = D_fpr_val\n",
    "                                tpr_max_auc = D_tpr_val\n",
    "                                temperature_max_auc = temperature\n",
    "                                magnitude_max_auc = magnitude\n",
    "                            elif D_auc_val == max_auc:\n",
    "                                if D_fpr_val < fpr_max_auc:\n",
    "                                    max_auc = D_auc_val\n",
    "                                    fpr_max_auc = D_fpr_val\n",
    "                                    tpr_max_auc = D_tpr_val\n",
    "                                    temperature_max_auc = temperature\n",
    "                                    magnitude_max_auc = magnitude\n",
    "\n",
    "                            # plot roc curve in the first subplot\n",
    "                            ax = axs[0]\n",
    "                            ax.set_title('ROC curve validation')\n",
    "                            ax.plot(D_fprs_val, D_tprs_val)\n",
    "                    # plot diagonal line in the first subplot\n",
    "                    ax = axs[0]\n",
    "                    ax.plot([0, 1], [0, 1], linestyle='--', color='black')\n",
    "                    # put marker x on the point with the highest auc with coordinates (fpr_max_auc, tpr_max_auc)\n",
    "                    ax.plot(fpr_max_auc, tpr_max_auc, marker='x', color='red')\n",
    "                    # connect axises with the point with the highest auc\n",
    "                    ax.plot([0, fpr_max_auc], [tpr_max_auc, tpr_max_auc],\n",
    "                            linestyle='--', color='black')\n",
    "                    ax.plot([fpr_max_auc, fpr_max_auc], [0, tpr_max_auc],\n",
    "                            linestyle='--', color='black')\n",
    "                    # annotate the point with the highest auc\n",
    "                    ax.annotate(f\"({fpr_max_auc:.2f}, {tpr_max_auc:.2f})\",\n",
    "                                (fpr_max_auc, tpr_max_auc))\n",
    "                    # add axis labels\n",
    "                    ax.set_xlabel('False Positive Rate')\n",
    "                    ax.set_ylabel('True Positive Rate')\n",
    "                    # set axis limits\n",
    "                    ax.set_xlim(0, 1)\n",
    "                    ax.set_ylim(0, 1)\n",
    "                    print(\n",
    "                        f\"The best auroc is {max_auc:.2f} with magnitude {magnitude_max_auc} and temperature {temperature_max_auc}, for seed {seed}, r {r}, lbd {lbd}\")\n",
    "###############################################################################################################\n",
    "                    for magnitude in [magnitude_max_auc]:\n",
    "                        magnitude_folder = os.path.join(\n",
    "                            dest_folder_seed, f\"magnitude_{magnitude}\")\n",
    "                        for temperature in [temperature_max_auc]:\n",
    "                            temperature_folder = os.path.join(\n",
    "                                magnitude_folder, f\"temperature_{temperature}\")\n",
    "                            D_fprs_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_fprs_test.pt']))\n",
    "                            D_tprs_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_tprs_test.pt']))\n",
    "                            D_thresholds_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_thresholds_test.pt']))\n",
    "                            D_fpr_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_fpr_test.pt']))\n",
    "                            D_tpr_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_tpr_test.pt']))\n",
    "                            D_threshold_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_threshold_test.pt']))\n",
    "                            D_auc_test = torch.load(\n",
    "                                '/'.join([temperature_folder, f'D_auc_test.pt']))\n",
    "\n",
    "                            # plot roc curve in the second subplot\n",
    "                            ax = axs[1]\n",
    "                            ax.set_title('ROC curve test')\n",
    "                            ax.plot(D_fprs_test, D_tprs_test)\n",
    "                            # plot diagonal line in the second subplot\n",
    "                            ax = axs[1]\n",
    "                            ax.plot([0, 1], [0, 1], linestyle='--', color='black')\n",
    "                            # put marker x on the point with the highest auc with coordinates (D_fpr_test, D_tpr_test)\n",
    "                            ax.plot(D_fpr_test, D_tpr_test, marker='x', color='red')\n",
    "                            # connect axises with the point with the highest auc\n",
    "                            ax.plot([0, D_fpr_test], [D_tpr_test, D_tpr_test],\n",
    "                                    linestyle='--', color='black')\n",
    "                            ax.plot([D_fpr_test, D_fpr_test], [0, D_tpr_test],\n",
    "                                    linestyle='--', color='black')\n",
    "                            # annotate the point with the highest auc\n",
    "                            ax.annotate(f\"({D_fpr_test:.2f}, {D_tpr_test:.2f})\",\n",
    "                                        (D_fpr_test, D_tpr_test))\n",
    "                            # add axis labels\n",
    "                            ax.set_xlabel('False Positive Rate')\n",
    "                            ax.set_ylabel('True Positive Rate')\n",
    "                            # set axis limits\n",
    "                            ax.set_xlim(0, 1)\n",
    "                            ax.set_ylim(0, 1)\n",
    "                            print(\n",
    "                                f\"The best auroc in test is {D_auc_test:.2f} with magnitude {magnitude_max_auc} and temperature {temperature_max_auc}, for seed {seed}, r {r}, lbd {lbd}\")\n",
    "                            # show the plot\n",
    "                            plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_env",
   "language": "python",
   "name": "jupyter_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
