{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb75697",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "import os\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "from lib.metrics import utils\n",
    "from scipy.optimize import minimize\n",
    "import json\n",
    "from tqdm import tqdm\n",
    "import pickle\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from copy import copy\n",
    "from sklearn.metrics import f1_score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66f57826",
   "metadata": {},
   "outputs": [],
   "source": [
    "# captioning best\n",
    "df = pd.read_pickle('./all_res_df.pkl')\n",
    "df = df[(df.ablation == 'none')]\n",
    "\n",
    "hparams = ['knn_k', 'dist_type']\n",
    "\n",
    "SELECTION_METRIC = 'know_val_labels_val_F1_optimal'\n",
    "\n",
    "avg_perfs = (df.groupby(['dataset', 'noise_type', 'noise_level', 'ablation'] + hparams, dropna = False)\n",
    "             .agg(performance = (SELECTION_METRIC, 'mean'))\n",
    "             .reset_index())\n",
    "\n",
    "# get configs with best perfs\n",
    "best_models = (avg_perfs.groupby(['dataset', 'noise_type', 'noise_level', 'ablation'], dropna = False)\n",
    "               .agg(performance = ('performance', 'max'))\n",
    "               .merge(avg_perfs)\n",
    "               .drop_duplicates(subset = ['dataset', 'noise_type', 'noise_level', 'ablation']))\n",
    "\n",
    "# take subset of df with best perfs\n",
    "selected_configs = (\n",
    "    best_models.drop(columns = ['performance'])\n",
    "    .dropna(axis=1, how='all').merge(df)\n",
    ")\n",
    "\n",
    "configs = selected_configs[~selected_configs.dataset.isin(['cifar100', 'cifar10'])]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7defa64f",
   "metadata": {},
   "outputs": [],
   "source": [
    "## choose fixed\n",
    "selected_models_df = df.query('ablation == \"none\" and knn_k == 30 and dist_type == \"cosine\"').set_index('output_dir')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e026134d",
   "metadata": {},
   "outputs": [],
   "source": [
    "selected_model_dfs = []\n",
    "for idx, i in tqdm(selected_models_df.iterrows(), total = len(selected_models_df)):\n",
    "    selected_model_dfs.append(\n",
    "        {\n",
    "            'dataset': i['dataset'], \n",
    "            'noise_type': i['noise_type'], \n",
    "            'noise_level': i['noise_level'],\n",
    "            'data_seed': i['data_seed'],\n",
    "            'df': pd.read_pickle(Path(i.name)/'res.pkl')['df']           \n",
    "        }      \n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdee59c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "hparam_dict = {\n",
    "    'beta': 5,\n",
    "    'gamma': 5,\n",
    "    'tau_1_n': 0.1,\n",
    "    'tau_2_n': 5,\n",
    "    'tau_1_m': 0.1,\n",
    "    'tau_2_m': 5\n",
    "} "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c11a5737",
   "metadata": {},
   "outputs": [],
   "source": [
    "res = []\n",
    "for i in tqdm(selected_model_dfs):\n",
    "    i['df']['score'] = utils.calc_scores_given_hparams_vectorized(i['df'], hparam_dict)\n",
    "    df_val = i['df'].query('sset == \"val\"')\n",
    "    f1, thres = utils.f1_with_pred_prev_constraint(df_val['is_mislabel'], df_val['score'], \n",
    "                                                   pred_prev = df_val['is_mislabel'].sum()/len(df_val), return_thres = True)\n",
    "    df_test = i['df'].query('sset == \"test\"')\n",
    "    mets = utils.prob_metrics(df_test['is_mislabel'], df_test['score'])\n",
    "    mets['F1'] = f1_score(df_test['is_mislabel'], df_test['score'] >= thres)\n",
    "    res.append({\n",
    "        **{a:i[a] for a in i if a != 'df'}, **mets\n",
    "    })\n",
    "res_df = pd.DataFrame(res)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "72c2b6ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_a = res_df.groupby(['dataset', 'noise_type', 'noise_level']).agg({i: ['mean','std'] for i in ['AUROC', 'AUPRC', 'F1']}).sort_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b8e8f0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "res_b = (configs.rename(\n",
    "            columns = {\n",
    "                'know_val_labels_test_AUROC': 'AUROC',\n",
    "                'know_val_labels_test_AUPRC': 'AUPRC',\n",
    "                'know_val_labels_test_F1_optimal': 'F1'\n",
    "            }\n",
    "        )\n",
    "         .groupby(['dataset', 'noise_type', 'noise_level'])\n",
    "                                   .agg({i: ['mean','std'] for i in ['AUROC', 'AUPRC', 'F1']}).sort_index())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5dd31905",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df = pd.DataFrame()\n",
    "\n",
    "for i in ['AUROC', 'AUPRC', 'F1']:\n",
    "    new_df[f'{i}_fixed'] = res_a.apply(lambda x: f'{x[i][\"mean\"]*100:.1f} ({x[i][\"std\"]*100:.1f})', axis = 1)\n",
    "    \n",
    "for i in ['AUROC', 'AUPRC', 'F1']:\n",
    "    new_df[f'{i}_optimal'] = res_b.apply(lambda x: f'{x[i][\"mean\"]*100:.1f} ({x[i][\"std\"]*100:.1f})', axis = 1)\n",
    "    \n",
    "idx1 = res_df.set_index(['dataset', 'noise_type', 'noise_level', 'data_seed']).index.drop_duplicates()\n",
    "idx2 = configs.set_index(['dataset', 'noise_type', 'noise_level', 'data_seed']).index.drop_duplicates()\n",
    "idx_common = list(set(idx1).intersection(set(idx2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7d5c6d8f",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in ['AUROC', 'AUPRC', 'F1']:\n",
    "    r1 = configs.rename(\n",
    "            columns = {\n",
    "                'know_val_labels_test_AUROC': 'AUROC',\n",
    "                'know_val_labels_test_AUPRC': 'AUPRC',\n",
    "                'know_val_labels_test_F1_optimal': 'F1'\n",
    "            }\n",
    "        ).set_index(['dataset', 'noise_type', 'noise_level', 'data_seed']).loc[idx_common, i]\n",
    "    r1 = r1[~r1.index.duplicated(keep='first')]\n",
    "    r1 = r1\n",
    "    \n",
    "    r2 = res_df.set_index(['dataset', 'noise_type', 'noise_level', 'data_seed']).loc[idx_common, i]\n",
    "    \n",
    "    new_df[f'{i}_Gap_mean'] = (r2 - r1).sort_index().groupby(['dataset', 'noise_type', 'noise_level']).mean()\n",
    "    new_df[f'{i}_Gap_std'] = (r2 - r1).sort_index().groupby(['dataset', 'noise_type', 'noise_level']).std()\n",
    "    new_df[f'{i}_Gap'] = new_df.apply(lambda x: f'{x[i + \"_Gap_mean\"]*100:.1f} ({x[i + \"_Gap_std\"]*100:.1f})', axis = 1)\n",
    "    # new_df = new_df.drop(columns = [f'{i}_Gap_mean', f'{i}_Gap_std'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c0d9a2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df[[j for i in ['AUROC', 'AUPRC', 'F1'] for j in (f'{i}_optimal', f'{i}_fixed', f'{i}_Gap')]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93944ed9",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df[f'AUROC_Gap_mean'].mean(), new_df[f'AUROC_Gap_mean'].std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e1aec53",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df[f'AUPRC_Gap_mean'].mean(), new_df[f'AUPRC_Gap_mean'].std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46a159d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_df[f'AUROC_Gap_mean'].min()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aee7ec72",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
