{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
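    "# Report plots for corruption analysis CIFAR10-CIFAR10C\n",
    "\n",
    "In this notebook, we plot the results of the mismatch analysis between CIFAR-10 (the matched dataset) and its corrupted counterpart CIFAR-10-C (the mismatched dataset). For every corruption type and intensity, DOCTOR and the D-matrix detector are compared. Their hyperparameters (temperature and magnitude, plus $\\lambda$ for the D-matrix) are selected per seed by maximizing the validation AUC, with ties broken by the lower validation FPR. The resulting test AUC and FPR at 95% TPR, averaged over seeds, are shown as radar plots for each value of $r$."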
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: uncomment to install plotly into this kernel's environment if it is missing.\n",
    "# import sys\n",
    "# import subprocess\n",
    "\n",
    "# # implement pip as a subprocess:\n",
    "# subprocess.check_call([sys.executable, '-m', 'pip', 'install', \n",
    "# 'plotly'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pn\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from matplotlib import pyplot as plt\n",
    "from plotly.subplots import make_subplots"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment configuration. These values are interpolated into the result\n",
    "# paths read by the cells below, so they must match the saved experiments.\n",
    "model_name = 'densenet121_custom'\n",
    "match_dataset_name = 'cifar10'\n",
    "mismatch_dataset_name = 'cifar10c'\n",
    "model_seed = 1\n",
    "# NOTE(review): data_path, device_id and batch_size are not used in this notebook.\n",
    "data_path = 'data'\n",
    "device_id = 0\n",
    "batch_size = 1000\n",
    "rs = [10, 5, 3, 2]\n",
    "# Each seed must appear exactly once: test metrics are averaged over this list,\n",
    "# so a duplicated seed would be over-weighted in the reported means.\n",
    "seeds = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]\n",
    "# Hyperparameter grids scanned during validation-based selection.\n",
    "temperatures = [1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 2.0, 2.5, 3.0, 100.0, 1000.0]\n",
    "magnitudes = [0.0, 0.0002, 0.00025, 0.0003, 0.00035, 0.0004, 0.0006, 0.0008,\n",
    "              0.001, 0.0012, 0.0014, 0.0016, 0.0018, 0.002, 0.0022, 0.0024,\n",
    "              0.0026, 0.0028, 0.003, 0.0032, 0.0034, 0.0036, 0.0038, 0.004]\n",
    "\n",
    "# D-matrix settings (presumably training hyperparameters). In this notebook they\n",
    "# only locate the D-matrix result folders; lbds is also searched during selection.\n",
    "lbds = [.8]\n",
    "lr = 0.1\n",
    "epochs = 100\n",
    "corruptions = [\n",
    "    \"brightness\",\n",
    "    \"contrast\",\n",
    "    \"defocus_blur\",\n",
    "    \"elastic_transform\",\n",
    "    \"fog\",\n",
    "    \"frost\",\n",
    "    \"gaussian_blur\",\n",
    "    \"gaussian_noise\",\n",
    "    \"glass_blur\",\n",
    "    \"impulse_noise\",\n",
    "    \"jpeg_compression\",\n",
    "    \"motion_blur\",\n",
    "    \"pixelate\",\n",
    "    \"saturate\",\n",
    "    \"shot_noise\",\n",
    "    \"snow\",\n",
    "    \"spatter\",\n",
    "    \"speckle_noise\",\n",
    "    \"zoom_blur\"\n",
    "]\n",
    "# Corruption intensities to evaluate; only intensity 5 is currently enabled.\n",
    "intensities = [\n",
    "    # 1,\n",
    "    # 2,\n",
    "    # 3,\n",
    "    # 4,\n",
    "    5\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-based hyperparameter selection for DOCTOR.\n",
    "# For each (corruption, intensity, r, seed), scan the temperature x magnitude grid\n",
    "# and keep the configuration with the highest validation AUC; ties are broken in\n",
    "# favour of the lower validation FPR (doctor_val_fpr.pt).\n",
    "global_dictionary_doctor = {}\n",
    "for corruption in corruptions:\n",
    "    for intensity in intensities:\n",
    "        print(f\"Corruption: {corruption}, Intensity: {intensity}\")\n",
    "\n",
    "        source_folder = f'doctor/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/{corruption}_{intensity}'\n",
    "        r_dict_doctor = {}\n",
    "        for r in rs:\n",
    "            seed_dict = {}\n",
    "            for seed in seeds:\n",
    "                dest_folder = f'{source_folder}/results/r_{r}/seed_{seed}'\n",
    "                max_auc = -float('inf')\n",
    "                fpr_95_tpr_at_max_auc = None\n",
    "                tpr_at_max_auc = None\n",
    "                temperature_at_max_auc = None\n",
    "                magnitude_at_max_auc = None\n",
    "\n",
    "                for temperature in temperatures:\n",
    "                    for magnitude in magnitudes:\n",
    "                        final_dest_folder = f'{dest_folder}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                        # Only these scalar summaries drive the selection; the full ROC\n",
    "                        # curves (fprs / tprs / thresholds) are not needed here.\n",
    "                        doctor_val_fpr = torch.load(f'{final_dest_folder}/doctor_val_fpr.pt')\n",
    "                        doctor_val_tpr = torch.load(f'{final_dest_folder}/doctor_val_tpr.pt')\n",
    "                        doctor_val_auc = torch.load(f'{final_dest_folder}/doctor_val_auc.pt')\n",
    "\n",
    "                        if doctor_val_auc > max_auc:\n",
    "                            is_better = True\n",
    "                        elif doctor_val_auc == max_auc:\n",
    "                            is_better = doctor_val_fpr < fpr_95_tpr_at_max_auc\n",
    "                        else:\n",
    "                            is_better = False\n",
    "\n",
    "                        if is_better:\n",
    "                            max_auc = doctor_val_auc\n",
    "                            fpr_95_tpr_at_max_auc = doctor_val_fpr\n",
    "                            tpr_at_max_auc = doctor_val_tpr\n",
    "                            temperature_at_max_auc = temperature\n",
    "                            magnitude_at_max_auc = magnitude\n",
    "\n",
    "                if temperature_at_max_auc is None:\n",
    "                    # Nothing beat the initial -inf (empty grid or NaN AUCs): fail here with\n",
    "                    # context instead of later on a 'magnitude_None/...' path.\n",
    "                    raise ValueError(\n",
    "                        f\"No DOCTOR configuration selected for {corruption}_{intensity}, r={r}, seed={seed}\")\n",
    "\n",
    "                seed_dict[seed] = {\n",
    "                    'max_auc': max_auc,\n",
    "                    'fpr_95_tpr_at_max_auc': fpr_95_tpr_at_max_auc,\n",
    "                    'tpr_at_max_auc': tpr_at_max_auc,\n",
    "                    'temperature_at_max_auc': temperature_at_max_auc,\n",
    "                    'magnitude_at_max_auc': magnitude_at_max_auc,\n",
    "                }\n",
    "            r_dict_doctor[r] = seed_dict\n",
    "\n",
    "        global_dictionary_doctor[f\"{corruption}_{intensity}\"] = r_dict_doctor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set metrics of DOCTOR at the hyperparameters selected on validation,\n",
    "# averaged over seeds, for every r and every corruption_intensity key.\n",
    "angles_dict_doctor = {}\n",
    "auc_dict_doctor = {}\n",
    "fpr_95_tpr_dict_doctor = {}\n",
    "\n",
    "for r in rs:\n",
    "    auc_dict_doctor[r] = {}\n",
    "    fpr_95_tpr_dict_doctor[r] = {}\n",
    "    angles_dict_doctor[r] = []\n",
    "    for corruption in corruptions:\n",
    "        for intensity in intensities:\n",
    "            key = f\"{corruption}_{intensity}\"\n",
    "            angles_dict_doctor[r].append(key)\n",
    "            source_folder = f'doctor/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/{key}'\n",
    "            auc_list = []\n",
    "            fpr_95_tpr_list = []\n",
    "            for seed in seeds:\n",
    "                selection = global_dictionary_doctor[key][r][seed]\n",
    "                selected_temperature = selection['temperature_at_max_auc']\n",
    "                selected_magnitude = selection['magnitude_at_max_auc']\n",
    "\n",
    "                dest_folder = f'{source_folder}/results/r_{r}/seed_{seed}'\n",
    "                final_dest_folder = f'{dest_folder}/magnitude_{selected_magnitude}/temperature_{selected_temperature}'\n",
    "                # Only the scalar test summaries are averaged; the ROC curves are not needed.\n",
    "                # float() normalises 0-dim tensors (possibly on GPU) before np.mean.\n",
    "                auc_list.append(float(torch.load(f'{final_dest_folder}/doctor_test_auc.pt')))\n",
    "                fpr_95_tpr_list.append(float(torch.load(f'{final_dest_folder}/doctor_test_fpr.pt')))\n",
    "\n",
    "            auc_dict_doctor[r][key] = np.mean(auc_list)\n",
    "            fpr_95_tpr_dict_doctor[r][key] = np.mean(fpr_95_tpr_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Validation-based hyperparameter selection for the D-matrix detector.\n",
    "# For each (corruption, intensity, r, seed), scan temperature x magnitude x lambda\n",
    "# and keep the configuration with the highest validation AUC; ties are broken in\n",
    "# favour of the lower validation FPR (D_fpr_val.pt).\n",
    "global_dictionary_d_matrix = {}\n",
    "for corruption in corruptions:\n",
    "    for intensity in intensities:\n",
    "        print(f\"Corruption: {corruption}, Intensity: {intensity}\")\n",
    "\n",
    "        r_dict_d_matrix = {}\n",
    "        for r in rs:\n",
    "            seed_dict = {}\n",
    "            for seed in seeds:\n",
    "                source_folder = f'd_matrix/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/seed_{seed}'\n",
    "                max_auc = -float('inf')\n",
    "                fpr_95_tpr_at_max_auc = None\n",
    "                tpr_at_max_auc = None\n",
    "                temperature_at_max_auc = None\n",
    "                magnitude_at_max_auc = None\n",
    "                lbd_at_max_auc = None\n",
    "\n",
    "                for temperature in temperatures:\n",
    "                    for magnitude in magnitudes:\n",
    "                        for lbd in lbds:\n",
    "                            dest_folder = f'{source_folder}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{lbd}'\n",
    "                            final_dest_folder = f'{dest_folder}/{corruption}_{intensity}/magnitude_{magnitude}/temperature_{temperature}'\n",
    "                            # Only these scalar summaries drive the selection; the full ROC\n",
    "                            # curves (fprs / tprs / thresholds) are not needed here.\n",
    "                            d_matrix_val_fpr = torch.load(f'{final_dest_folder}/D_fpr_val.pt')\n",
    "                            d_matrix_val_tpr = torch.load(f'{final_dest_folder}/D_tpr_val.pt')\n",
    "                            d_matrix_val_auc = torch.load(f'{final_dest_folder}/D_auc_val.pt')\n",
    "\n",
    "                            if d_matrix_val_auc > max_auc:\n",
    "                                is_better = True\n",
    "                            elif d_matrix_val_auc == max_auc:\n",
    "                                is_better = d_matrix_val_fpr < fpr_95_tpr_at_max_auc\n",
    "                            else:\n",
    "                                is_better = False\n",
    "\n",
    "                            if is_better:\n",
    "                                max_auc = d_matrix_val_auc\n",
    "                                fpr_95_tpr_at_max_auc = d_matrix_val_fpr\n",
    "                                tpr_at_max_auc = d_matrix_val_tpr\n",
    "                                temperature_at_max_auc = temperature\n",
    "                                magnitude_at_max_auc = magnitude\n",
    "                                # lambda is part of the searched grid, so it must be recorded\n",
    "                                # to reload the matching test results later.\n",
    "                                lbd_at_max_auc = lbd\n",
    "\n",
    "                if temperature_at_max_auc is None:\n",
    "                    # Nothing beat the initial -inf (empty grid or NaN AUCs): fail here with\n",
    "                    # context instead of later on a 'magnitude_None/...' path.\n",
    "                    raise ValueError(\n",
    "                        f\"No D-matrix configuration selected for {corruption}_{intensity}, r={r}, seed={seed}\")\n",
    "\n",
    "                seed_dict[seed] = {\n",
    "                    'max_auc': max_auc,\n",
    "                    'fpr_95_tpr_at_max_auc': fpr_95_tpr_at_max_auc,\n",
    "                    'tpr_at_max_auc': tpr_at_max_auc,\n",
    "                    'temperature_at_max_auc': temperature_at_max_auc,\n",
    "                    'magnitude_at_max_auc': magnitude_at_max_auc,\n",
    "                    'lbd_at_max_auc': lbd_at_max_auc,\n",
    "                }\n",
    "            r_dict_d_matrix[r] = seed_dict\n",
    "\n",
    "        global_dictionary_d_matrix[f\"{corruption}_{intensity}\"] = r_dict_d_matrix\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test-set metrics of the D-matrix detector at the hyperparameters selected on\n",
    "# validation (temperature, magnitude and lambda), averaged over seeds.\n",
    "angles_dict_d_matrix = {}\n",
    "auc_dict_d_matrix = {}\n",
    "fpr_95_tpr_dict_d_matrix = {}\n",
    "\n",
    "for r in rs:\n",
    "    auc_dict_d_matrix[r] = {}\n",
    "    fpr_95_tpr_dict_d_matrix[r] = {}\n",
    "    angles_dict_d_matrix[r] = []\n",
    "    for corruption in corruptions:\n",
    "        for intensity in intensities:\n",
    "            key = f\"{corruption}_{intensity}\"\n",
    "            angles_dict_d_matrix[r].append(key)\n",
    "            auc_list = []\n",
    "            fpr_95_tpr_list = []\n",
    "            for seed in seeds:\n",
    "                selection = global_dictionary_d_matrix[key][r][seed]\n",
    "                selected_temperature = selection['temperature_at_max_auc']\n",
    "                selected_magnitude = selection['magnitude_at_max_auc']\n",
    "                # Use the lambda chosen on validation, not a loop variable leaked from\n",
    "                # another cell (hidden state that always equals lbds[-1]).\n",
    "                selected_lbd = selection.get('lbd_at_max_auc')\n",
    "                if selected_lbd is None:\n",
    "                    # Selection results without a recorded lambda are only unambiguous\n",
    "                    # when a single lambda was searched.\n",
    "                    if len(lbds) != 1:\n",
    "                        raise ValueError(\n",
    "                            f\"Selected lambda not recorded for {key}, r={r}, seed={seed} and several lambdas were searched\")\n",
    "                    selected_lbd = lbds[0]\n",
    "\n",
    "                source_folder = f'd_matrix/{match_dataset_name}_to_{mismatch_dataset_name}/{model_name}/model_seed_{model_seed}/seed_{seed}'\n",
    "                dest_folder = f'{source_folder}/r_{r}/lr_{lr}/epochs_{epochs}/lbd_{selected_lbd}'\n",
    "                final_dest_folder = f'{dest_folder}/{key}/magnitude_{selected_magnitude}/temperature_{selected_temperature}'\n",
    "                # Only the scalar test summaries are averaged; the ROC curves are not needed.\n",
    "                # float() normalises 0-dim tensors (possibly on GPU) before np.mean.\n",
    "                auc_list.append(float(torch.load(f'{final_dest_folder}/D_auc_test.pt')))\n",
    "                fpr_95_tpr_list.append(float(torch.load(f'{final_dest_folder}/D_fpr_test.pt')))\n",
    "\n",
    "            auc_dict_d_matrix[r][key] = np.mean(auc_list)\n",
    "            fpr_95_tpr_dict_d_matrix[r][key] = np.mean(fpr_95_tpr_list)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Radar plots comparing DOCTOR and the D-matrix detector for each r:\n",
    "# mean test AUC (left) and mean test FPR (right) per corruption_intensity.\n",
    "for r in rs:\n",
    "    angles = angles_dict_doctor[r]\n",
    "    aucs_doctor = [auc_dict_doctor[r][angle] for angle in angles]\n",
    "    aucs_d_matrix = [auc_dict_d_matrix[r][angle] for angle in angles]\n",
    "    fpr_95_tpr_doctor = [fpr_95_tpr_dict_doctor[r][angle] for angle in angles]\n",
    "    fpr_95_tpr_d_matrix = [fpr_95_tpr_dict_d_matrix[r][angle] for angle in angles]\n",
    "\n",
    "    fig = make_subplots(rows=1, cols=2, specs=[[{'type': 'polar'}, {'type': 'polar'}]])\n",
    "\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(theta=angles, r=aucs_doctor, name='AUC doctor', fill='toself'),\n",
    "        row=1, col=1\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(theta=angles, r=aucs_d_matrix, fill='toself', name='AUC d_matrix'),\n",
    "        row=1, col=1\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(theta=angles, r=fpr_95_tpr_doctor, name='fpr doctor', fill='toself'),\n",
    "        row=1, col=2\n",
    "    )\n",
    "    fig.add_trace(\n",
    "        go.Scatterpolar(theta=angles, r=fpr_95_tpr_d_matrix, fill='toself', name='fpr d_matrix'),\n",
    "        row=1, col=2\n",
    "    )\n",
    "\n",
    "    # update_layout(polar=...) would only configure the first polar subplot;\n",
    "    # update_polars applies the fixed [0, 1] radial range to both subplots.\n",
    "    fig.update_polars(radialaxis=dict(visible=True, range=[0, 1]))\n",
    "    # Title each figure with its r, otherwise the per-r figures are indistinguishable.\n",
    "    fig.update_layout(title_text=f'r = {r}', showlegend=True)\n",
    "\n",
    "    fig.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter_env",
   "language": "python",
   "name": "jupyter_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
