{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('../')\n",
    "import seaborn as sns \n",
    "from multiprocessing import Process\n",
    "from omegaconf import OmegaConf\n",
    "from src.utils.run_lib import *\n",
    "import numpy as np \n",
    "from src.core.passive_learning import *\n",
    "from src.core.auto_labeling import *\n",
    "from src.datasets import dataset_factory \n",
    "from src.datasets.dataset_utils import * \n",
    "from src.utils.common_utils import * \n",
    "from src.utils.vis_utils import *\n",
    "import copy \n",
    "import random \n",
    "from src.datasets.numpy_dataset import * \n",
    "\n",
    "\n",
    "root_dir = '../'\n",
    "conf_dir = f'{root_dir}configs/calib-exp/'\n",
    "\n",
    "base_conf_file = '{}/mnist_lenet_base_conf.yaml'.format(conf_dir)\n",
    "conf           = OmegaConf.load(base_conf_file)\n",
    "\n",
    "training_conf  = OmegaConf.load('{}/training_confs/fmfp_conf.yaml'.format(conf_dir))\n",
    "\n",
    "root_pfx       = 'mnist_lenet_calib-exp-runs'\n",
    "root_pfx       = f'{root_dir}/outputs/{root_pfx}/'\n",
    "\n",
    "\n",
    "conf['root_pfx']    = root_pfx\n",
    "\n",
    "conf.training_conf = training_conf\n",
    "\n",
    "\n",
    "logger   = get_logger('../temp/logs/act_lbl_test.log',stdout_redirect=True,level=logging.DEBUG)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf['train_pts_query_conf']['seed_train_size']= 10\n",
    "conf['train_pts_query_conf']['max_num_train_pts']= 50\n",
    "\n",
    "set_seed(conf['random_seed'])\n",
    "\n",
    "dm = DataManager(conf,logger,lib=conf['model_conf']['lib'])\n",
    "\n",
    "print(len(dm.ds_std_train), len(dm.ds_std_val))\n",
    "\n",
    "pl = PassiveLearning(conf,dm,logger)\n",
    "\n",
    "out = pl.run()\n",
    "\n",
    "w = pl.cur_clf.get_weights()\n",
    "print(torch.norm(w))\n",
    "test_err = get_test_error(pl.cur_clf,dm.ds_std_test,conf['inference_conf'])\n",
    "print(test_err)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inf_out = pl.cur_clf.predict(dm.ds_std_test,conf['inference_conf'])\n",
    "bin_data = compute_calibration(dm.ds_std_test.Y, inf_out['labels'], inf_out['confidence'], num_bins=15)\n",
    "ax = plt.subplot()\n",
    "reliability_diagram_subplot(ax, bin_data, draw_ece=True, draw_bin_importance=False, title=\"Reliability Diagram\",  xlabel=\"Confidence\", ylabel=\"Expected Accuracy\",disable_labels=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "lenet_clf_model = pl.cur_clf.model\n",
    "test_ds = dm.ds_std_test\n",
    "\n",
    "inference_conf = {} \n",
    "inference_conf['device']= 'cpu'\n",
    "inf_out_test_1 = pl.cur_clf.predict(test_ds, inference_conf) \n",
    "\n",
    "inf_out_test_1['true_labels'] = test_ds.Y\n",
    "c_idx = inf_out_test_1['true_labels'] ==  inf_out_test_1['labels']\n",
    "i_idx =inf_out_test_1['true_labels'] !=  inf_out_test_1['labels']\n",
    "\n",
    "\n",
    "sns.kdeplot(inf_out_test_1['confidence'][c_idx])\n",
    "sns.kdeplot(inf_out_test_1['confidence'][i_idx])\n",
    "\n",
    "#plt.hist(inf_out_val_1['confidence'][c_idx])\n",
    "#plt.hist(inf_out_val_1['confidence'][i_idx])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dm.unmark_auto_labeled()\n",
    "\n",
    "auto_labeler = AutoLabeling(conf,dm,pl.cur_clf,logger)\n",
    "out = auto_labeler.run()\n",
    "out = dm.get_auto_labeling_counts()\n",
    "print(out)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ml-kernel-tbal",
   "language": "python",
   "name": "ml-kernel-tbal"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
