{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import warnings\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "# The data/ and results/ paths used below are relative to the directory one level above the kernel's\n",
    "# starting directory. Resolve it once so that re-running this cell does not keep walking up the tree.\n",
    "if 'PROJECT_ROOT' not in globals():\n",
    "    PROJECT_ROOT = os.path.abspath('..')\n",
    "# Absolute path: a relative '../src' entry would resolve to a different directory after os.chdir below.\n",
    "SRC_DIR = os.path.join(PROJECT_ROOT, 'src')\n",
    "if SRC_DIR not in sys.path:\n",
    "    sys.path.append(SRC_DIR)\n",
    "\n",
    "# NOTE(review): this silences every warning for the whole session (e.g. matplotlib style conflicts).\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "from data import PPCI\n",
    "\n",
    "os.chdir(PROJECT_ROOT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: Universal CRL -- sweep the invariance weight (lambda_INV) over many seeds.\n",
    "encoder = \"dino\"\n",
    "split_criteria = \"position\"\n",
    "results_dir = f'results/istant_hq/{encoder}/{split_criteria}'\n",
    "\n",
    "# 0 (plain ERM) followed by 16 log-spaced weights from 1e-1 to 1e16.\n",
    "# NOTE(review): logspace(-1, 16, num=16) steps the exponent by 17/15, so these are not whole decades\n",
    "# (0.1, ~1.36, ~18.5, ...); num=18 would give integer decades. Kept unchanged so saved results stay comparable.\n",
    "ic_weights = [0] + list(np.logspace(-1, 16, num=16))\n",
    "seeds = list(range(1, 21))\n",
    "num_epochs = 15\n",
    "metric_columns = ['ic_weight', 'seed', \"inv_loss_val\", \"loss_val\", \"acc_val\", \"bacc_val\", \"TEB_val\",\n",
    "                  \"acc\", \"bacc\", \"TEB\", \"TEB_bin\", \"EAD\", \"best_epoch\"]\n",
    "\n",
    "dataset = PPCI(encoder=encoder,\n",
    "               token=\"class\",\n",
    "               task=\"or\",\n",
    "               split_criteria=split_criteria,\n",
    "               environment=\"supervised\",\n",
    "               batch_size=64,\n",
    "               num_proc=4,\n",
    "               verbose=True,\n",
    "               data_dir='data/istant_hq',\n",
    "               results_dir=f'results/istant_hq/{encoder}')\n",
    "\n",
    "# Per-epoch curves, indexed ic_weight x seed x epoch x metric (4 metrics per epoch).\n",
    "train_metrics = np.zeros((len(ic_weights), len(seeds), num_epochs, 4))\n",
    "val_metrics = np.zeros_like(train_metrics)\n",
    "# One record per (ic_weight, seed) run; converted to a DataFrame once instead of growing it row by row.\n",
    "records = []\n",
    "for j, ic_weight in enumerate(ic_weights):\n",
    "    print(f\"IC weight: {ic_weight}\")\n",
    "    for k, seed in enumerate(seeds):\n",
    "        print(f\"Seed: {seed}\")\n",
    "        dataset.train(add_pred_env=\"supervised\",\n",
    "                      hidden_layers=1,\n",
    "                      hidden_nodes=256,\n",
    "                      batch_size=128,\n",
    "                      lr=0.0005,\n",
    "                      num_epochs=num_epochs,\n",
    "                      verbose=False,\n",
    "                      multidomain=True,\n",
    "                      ic_weight=ic_weight,\n",
    "                      seed=seed)\n",
    "        # NOTE(review): assumes the model's metric lists squeeze to shape (num_epochs, 4); otherwise this fails.\n",
    "        train_metrics[j, k] = np.array(dataset.model.train_metrics).squeeze()\n",
    "        val_metrics[j, k] = np.array(dataset.model.val_metrics).squeeze()\n",
    "        run_metrics = dataset.evaluate(color=None, verbose=False)  # assumed dict-like (keys are set below)\n",
    "        run_metrics['ic_weight'] = ic_weight\n",
    "        run_metrics['seed'] = seed\n",
    "        run_metrics['best_epoch'] = dataset.model.best_epoch\n",
    "        records.append(run_metrics)\n",
    "\n",
    "# Restrict to the expected columns (absent keys become NaN), as the former row-wise .loc insert did.\n",
    "all_metrics = pd.DataFrame(records, columns=metric_columns)\n",
    "# TERB = |TEB| / EAD, expressed in percent.\n",
    "all_metrics['TERB'] = all_metrics['TEB'].abs() / all_metrics['EAD'] * 100\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "all_metrics.to_csv(f'{results_dir}/invariance.csv', index=False)\n",
    "np.save(f'{results_dir}/train_metrics.npy', train_metrics)\n",
    "np.save(f'{results_dir}/val_metrics.npy', val_metrics)\n",
    "all_metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the sweep results and plot TERB and balanced accuracy against lambda_INV (mean +/- std over seeds).\n",
    "encoder = \"dino\"\n",
    "split_criteria = \"position\"\n",
    "results_dir = f'results/istant_hq/{encoder}/{split_criteria}'\n",
    "# lambda_INV = 0 (plain ERM) cannot sit on a log axis, so it is drawn at this stand-in value.\n",
    "ERM_PLACEHOLDER = 0.01\n",
    "\n",
    "all_metrics = pd.read_csv(f\"{results_dir}/invariance.csv\")\n",
    "all_metrics[\"univ_loss_val\"] = all_metrics[\"loss_val\"] + all_metrics[\"inv_loss_val\"]\n",
    "all_metrics[\"TEAB_val\"] = all_metrics[\"TEB_val\"].abs()\n",
    "all_metrics[\"ic_weight\"] = all_metrics[\"ic_weight\"].replace(0, ERM_PLACEHOLDER)\n",
    "\n",
    "\n",
    "def seed_stats(frame, column):\n",
    "    \"\"\"Mean and std of `column` over seeds, indexed by ic_weight in ascending order.\"\"\"\n",
    "    return frame.groupby(\"ic_weight\")[column].agg([\"mean\", \"std\"])\n",
    "\n",
    "\n",
    "def plot_mean_std(ax, stats, color, band_color):\n",
    "    \"\"\"Draw mean +/- std as error bars, a dashed mean line and a shaded band on `ax`.\"\"\"\n",
    "    # x comes from the same (sorted) index as y, so a point can never be paired with the wrong weight.\n",
    "    ax.errorbar(stats.index, stats[\"mean\"], yerr=stats[\"std\"], fmt=\"o\", color=color)\n",
    "    ax.plot(stats.index, stats[\"mean\"], \"--\", color=color)\n",
    "    ax.fill_between(stats.index, stats[\"mean\"] - stats[\"std\"], stats[\"mean\"] + stats[\"std\"],\n",
    "                    alpha=0.2, color=band_color)\n",
    "\n",
    "\n",
    "def select_run(frame, criterion):\n",
    "    \"\"\"Return the row of `frame` picked by model selection on `criterion` (lower is better).\n",
    "\n",
    "    The ic_weight with the lowest seed-averaged criterion is chosen first, then the seed\n",
    "    with the lowest criterion within that weight.\n",
    "    \"\"\"\n",
    "    best_weight = frame.groupby(\"ic_weight\")[criterion].mean().idxmin()\n",
    "    candidates = frame.loc[frame[\"ic_weight\"] == best_weight, criterion]\n",
    "    return frame.loc[candidates.idxmin()]\n",
    "\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.set_xscale(\"log\")\n",
    "ax.set_xlabel(r\"$\\lambda_{INV}$\")\n",
    "ax.set_ylabel(\"TERB (%)\", color=\"tab:blue\")\n",
    "plot_mean_std(ax, seed_stats(all_metrics, \"TERB\"), \"tab:blue\", \"skyblue\")\n",
    "\n",
    "# Star the TERB of the run that each selection criterion (computed on the *_val columns) would pick.\n",
    "selection_criteria = [(\"loss_val\", \"orange\", \"Min ERM\"),\n",
    "                      (\"inv_loss_val\", \"purple\", \"Min Invariance\"),\n",
    "                      (\"TEAB_val\", \"green\", \"Min TERB\")]\n",
    "for criterion, color, label in selection_criteria:\n",
    "    run = select_run(all_metrics, criterion)\n",
    "    # Colour is set only via `color=`: the former 'y*' format string conflicted with it (matplotlib warning).\n",
    "    ax.plot(run[\"ic_weight\"], run[\"TERB\"], linestyle=\"none\", marker=\"*\", markersize=12, color=color,\n",
    "            label=label, zorder=10, clip_on=False, markeredgecolor=\"tab:blue\", markeredgewidth=1)\n",
    "\n",
    "ax.set_ylim(0, 140)\n",
    "ax.legend(loc=\"upper left\", framealpha=1, title=\"Model Selection Criteria\",\n",
    "          title_fontsize=\"8.5\", fontsize=\"8\", alignment=\"left\")\n",
    "ax.set_xticks([ERM_PLACEHOLDER, 0.1, 1, 10, 100, 1000, 10000],\n",
    "              [\"0\\n(ERM)\", \"0.1\", \"1\", \"10\", \"100\", \"1000\", \"10000\"])\n",
    "\n",
    "ax_bacc = ax.twinx()\n",
    "ax_bacc.set_ylabel(\"Balanced Accuracy\", color=\"tab:red\")\n",
    "plot_mean_std(ax_bacc, seed_stats(all_metrics, \"bacc\"), \"tab:red\", \"pink\")\n",
    "ax_bacc.set_ylim(0.45, 1)\n",
    "# Faint guide at the log-midpoint between the ERM stand-in (0.01) and the first non-zero weight (0.1).\n",
    "ax_bacc.axvline(x=np.sqrt(ERM_PLACEHOLDER * 0.1), color=\"black\", linestyle=\"--\", label=\"separation\", alpha=0.2)\n",
    "\n",
    "fig.savefig(f'{results_dir}/invariance.pdf', bbox_inches='tight')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization of the training convergence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-epoch training and validation curves for every lambda_INV, averaged over seeds.\n",
    "encoder = \"dino\"\n",
    "split_criteria = \"position\"\n",
    "results_dir = f'results/istant_hq/{encoder}/{split_criteria}'\n",
    "\n",
    "train_metrics = np.load(f'{results_dir}/train_metrics.npy')  # ic_weight x seed x epoch x metric\n",
    "val_metrics = np.load(f'{results_dir}/val_metrics.npy')  # ic_weight x seed x epoch x metric\n",
    "# Take the swept weights from the saved results (rows are written in sweep order) instead of re-typing\n",
    "# the grid, and the epoch count from the arrays, so this cell also runs on a fresh kernel.\n",
    "ic_weights = pd.read_csv(f'{results_dir}/invariance.csv')['ic_weight'].unique()\n",
    "assert train_metrics.shape[0] == len(ic_weights), 'weight grid does not match the saved metric arrays'\n",
    "assert val_metrics.shape == train_metrics.shape, 'train/val metric arrays differ in shape'\n",
    "\n",
    "# Two figures (training, validation), each with one panel per metric below.\n",
    "# NOTE(review): assumes this is the order of the last axis of the saved arrays -- confirm in the model code.\n",
    "metrics = ['accuracy', 'balanced accuracy', 'precision', 'recall']\n",
    "# One colour per weight; the former fixed 5-colour list raised IndexError for the 17-weight sweep.\n",
    "colors = plt.cm.viridis(np.linspace(0, 1, len(ic_weights)))\n",
    "\n",
    "\n",
    "def plot_convergence(metric_curves, title):\n",
    "    \"\"\"Plot mean +/- std over seeds of each metric against epoch, one curve per ic_weight.\"\"\"\n",
    "    epochs = np.arange(metric_curves.shape[2])\n",
    "    fig, axes = plt.subplots(1, len(metrics), figsize=(15, 5))\n",
    "    for m, (ax, metric) in enumerate(zip(axes, metrics)):\n",
    "        for j, (ic_weight, color) in enumerate(zip(ic_weights, colors)):\n",
    "            curves = metric_curves[j, :, :, m]  # seed x epoch\n",
    "            mean, std = curves.mean(axis=0), curves.std(axis=0)\n",
    "            ax.errorbar(epochs, mean, yerr=std, color=color, label=f'{ic_weight:.3g}')\n",
    "            ax.fill_between(epochs, mean - std, mean + std, color=color, alpha=0.2)\n",
    "        ax.set_xlabel('epochs')\n",
    "        ax.set_ylabel(metric)\n",
    "    fig.tight_layout()\n",
    "    fig.suptitle(title, fontsize=16, y=1.05)\n",
    "    # One shared legend below the whole figure instead of one hanging off the last panel.\n",
    "    handles, labels = axes[-1].get_legend_handles_labels()\n",
    "    fig.legend(handles, labels, loc='upper center', bbox_to_anchor=(0.5, 0.0), fancybox=True,\n",
    "               ncol=min(len(ic_weights), 9), title=r\"$\\lambda_{INV}$\", title_fontsize=\"12\", fontsize=\"12\")\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "plot_convergence(train_metrics, 'Training')\n",
    "plot_convergence(val_metrics, 'Validation')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "crl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
