{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "from notebooks.notebook_utils import *\n",
    "from distill.repsim_util import *\n",
    "import tensorflow_datasets as tfds\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "from tqdm import tqdm\n",
    "\n",
    "chkpt_dir='../../../InDist/tf_ckpts'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_outputs(model, x): \n",
    "  \"\"\" Calls the forward pass of the model and obtains logits and activations of the layer of interest!\n",
    "  \"\"\"\n",
    "  outputs = model.detailed_call(x, training=tf.convert_to_tensor(True))\n",
    "  logits, reps = outputs[0], outputs[model.rep_index]\n",
    "  if model.rep_layer is not None and model.rep_layer is not -1:\n",
    "    reps = reps[model.rep_layer]\n",
    "\n",
    "  return reps, logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def convert_examples(examples):\n",
    "    pad_length = int((28 - 28) / 2)\n",
    "    return tf.pad(tf.cast(examples['image'], dtype=tf.float32) / 255,\n",
    "                  ([pad_length, pad_length], [pad_length, pad_length],\n",
    "                   [0, 0])), tf.cast(\n",
    "      examples['label'], dtype=tf.int32)\n",
    "\n",
    "cmnist_trans = tfds.load('mnist_corrupted/translate', split='test')\n",
    "cmnist_trans = cmnist_trans.map(map_func=lambda x: convert_examples(x))\n",
    "cmnist_trans = cmnist_trans.batch(64)\n",
    "\n",
    "task1 = 'mnist'\n",
    "task1 = TASKS[task1](get_task_params(), data_dir='../data')\n",
    "cl_token = 0\n",
    "\n",
    "task = task1\n",
    "\n",
    "models = []\n",
    "labels = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmnist_trans = tfds.builder('mnist_corrupted/translate')\n",
    "cmnist_trans_test = cmnist_trans.as_dataset(split='test')\n",
    "cmnist_trans_test = cmnist_trans_test.map(map_func=lambda x: convert_examples(x))\n",
    "cmnist_trans_test = cmnist_trans_test.batch(64)\n",
    "cmnist_trans.info.splits['test'].num_examples / 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cmnist_scale = tfds.builder('mnist_corrupted/scale')\n",
    "cmnist_scale_test = cmnist_scale.as_dataset(split='test')\n",
    "cmnist_scale_test = cmnist_scale_test.map(map_func=lambda x: convert_examples(x))\n",
    "cmnist_scale_test = cmnist_scale_test.batch(64)\n",
    "cmnist_scale.info.splits['test'].num_examples / 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## CNN -> FF\n",
    "###############################1\n",
    "config={'student_exp_name':'gc_o_std10',\n",
    "        'teacher_exp_name':'gc_o_dtchr10',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vcnn',\n",
    "        'student_model':'cl_vff',\n",
    "        'teacher_config':'vcnn_mnist7',\n",
    "        'student_config':'ff_mnist4',\n",
    "        'distill_config':'pure_dstl2_4_crs_slw_3',\n",
    "        'distill_mode':'online',\n",
    "        'chkpt_dir': chkpt_dir,\n",
    "       }\n",
    "\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2ff_ot10_std10, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t10, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "# models.extend([cnn_t10, cnn2ff_ot10_std10])\n",
    "# labels.extend(['cnn_t10', 'cnn2ff_ot10_std10'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std10'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr10'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2ff_ft10_std10, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t10, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t10, cnn2ff_ft10_std10])\n",
    "labels.extend(['cnn_t10', 'cnn2ff_ft1_std10'])\n",
    "###############################2\n",
    "config['student_exp_name'] ='gc_o_std11'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr11'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2ff_ot11_std11, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t11, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "# models.extend([cnn_t11, cnn2ff_ot11_std11])\n",
    "# labels.extend(['cnn_t11', 'cnn2ff_ot11_std11'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std11'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr11'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2ff_ft11_std11, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t11, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "models.extend([cnn_t11, cnn2ff_ft11_std11])\n",
    "labels.extend(['cnn_t11', 'cnn2ff_ft11_std11'])\n",
    "###############################3\n",
    "config['student_exp_name'] ='gc_o_std12'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr12'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2ff_ot12_std12, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t12, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "# models.extend([cnn_t12, cnn2ff_ot12_std12])\n",
    "# labels.extend(['cnn_t12', 'cnn2ff_ot12_std12'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std12'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr12'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2ff_ft12_std12, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t12, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "models.extend([cnn_t12, cnn2ff_ft12_std12])\n",
    "labels.extend(['cnn_t12', 'cnn2ff_ft12_std12'])\n",
    "###############################4\n",
    "config['student_exp_name'] ='gc_o_std13'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr13'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2ff_ot13_std13, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t13, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "# models.extend([cnn_t13, cnn2ff_ot13_std13])\n",
    "# labels.extend(['cnn_t13', 'cnn2ff_ot13_std13'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std13'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr13'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2ff_ft13_std13, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t13, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t13, cnn2ff_ft13_std13])\n",
    "labels.extend(['cnn_t13', 'cnn2ff_ft13_std13'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#CNN -> CNN\n",
    "###############################1\n",
    "config={'student_exp_name':'gc_o_std18',\n",
    "        'teacher_exp_name':'gc_o_tchr18',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vcnn',\n",
    "        'student_model':'cl_vcnn',\n",
    "        'teacher_config':'vcnn_mnist7',\n",
    "        'student_config':'vcnn_mnist7',\n",
    "        'distill_config':'pure_dstl2_4_crs_slw_3',\n",
    "        'distill_mode':'online',\n",
    "        'chkpt_dir': chkpt_dir,\n",
    "       }\n",
    "\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2cnn_ot18_std18, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t18, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([cnn_t18, cnn2cnn_ot18_std18])\n",
    "# labels.extend(['cnn_t18', 'cnn2cnn_ft18_std18'])\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std18'\n",
    "config['teacher_exp_name'] ='gc_o_tchr18'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2cnn_ft18_std18, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t18, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t18, cnn2cnn_ft18_std18])\n",
    "labels.extend(['cnn_t18', 'cnn2cnn_ft18_std18'])\n",
    "###############################2\n",
    "config['student_exp_name'] ='gc_o_std17'\n",
    "config['teacher_exp_name'] ='gc_o_tchr17'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2cnn_ot17_std17, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t17, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([cnn_t17, cnn2cnn_ot17_std17])\n",
    "# labels.extend(['cnn_t17', 'cnn2cnn_ft17_std17'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std17'\n",
    "config['teacher_exp_name'] ='gc_o_tchr17'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2cnn_ft17_std17, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t17, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t17, cnn2cnn_ft17_std17])\n",
    "labels.extend(['cnn_t17', 'cnn2cnn_ft17_std17'])\n",
    "\n",
    "###############################3\n",
    "config['student_exp_name'] ='gc_o_std16'\n",
    "config['teacher_exp_name'] ='gc_o_tchr16'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2cnn_ot16_std16, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t16, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "# models.extend([cnn_t16, cnn2cnn_ot16_std16])\n",
    "# labels.extend(['cnn_t16', 'cnn2cnn_ft16_std16'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std16'\n",
    "config['teacher_exp_name'] ='gc_o_tchr16'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "cnn2cnn_ft16_std16, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t16, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t16, cnn2cnn_ft16_std16])\n",
    "labels.extend(['cnn_t16', 'cnn2cnn_ft16_std16'])\n",
    "###############################4\n",
    "config['student_exp_name'] ='gc_o_std15'\n",
    "config['teacher_exp_name'] ='gc_o_tchr15'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# cnn2cnn_ot15_std15, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# cnn_t15, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([cnn_t15, cnn2cnn_ot15_std15])\n",
    "# labels.extend(['cnn_t15', 'cnn2cnn_ot15_std15'])\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std15'\n",
    "config['teacher_exp_name'] ='gc_o_tchr15'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "cnn2cnn_ft15_std15, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "cnn_t15, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([cnn_t15, cnn2cnn_ft15_std15])\n",
    "labels.extend(['cnn_t15', 'cnn2cnn_ft15_std15'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#FF -> CNN\n",
    "###############################1\n",
    "config={'student_exp_name':'gc_o_std100',\n",
    "        'teacher_exp_name':'gc_o_dtchr100',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vff',\n",
    "        'student_model':'cl_vcnn',\n",
    "        'teacher_config':'ff_mnist4',\n",
    "        'student_config':'vcnn_mnist7',\n",
    "        'distill_config':'pure_dstl2_4_crs_slw_3',\n",
    "        'distill_mode':'online',\n",
    "        'chkpt_dir': chkpt_dir,\n",
    "       }\n",
    "\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2cnn_ot100_std100, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t100, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([ff_t100, ff2cnn_ot100_std100])\n",
    "# labels.extend(['ff_t100', 'ff2cnn_ot100_std100'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std100'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr100'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2cnn_ft100_std100, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t100, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([ff_t100, ff2cnn_ft100_std100])\n",
    "labels.extend(['ff_t100', 'ff2cnn_ft100_std100'])\n",
    "###############################2\n",
    "config['student_exp_name'] ='gc_o_std101'\n",
    "config['teacher_exp_name'] ='gc_o_tchr101'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "# ff2cnn_ot101_std101, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t101, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([ff_t101, ff2cnn_ot101_std101])\n",
    "# labels.extend(['ff_t101', 'ff2cnn_ot101_std101'])\n",
    "\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std101'\n",
    "config['teacher_exp_name'] ='gc_o_tchr101'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2cnn_ft101_std101, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t101, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([ff_t101, ff2cnn_ft101_std101])\n",
    "labels.extend(['ff_t101', 'ff2cnn_ft101_std101'])\n",
    "\n",
    "###############################3\n",
    "config['student_exp_name'] ='gc_o_std102'\n",
    "config['teacher_exp_name'] ='gc_o_tchr102'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2cnn_ot102_std102, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t102, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([ff_t102, ff2cnn_ot102_std102])\n",
    "# labels.extend(['ff_t102', 'ff2cnn_ot102_std102'])\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std102'\n",
    "config['teacher_exp_name'] ='gc_o_tchr102'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "ff2cnn_ft102_std102, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t102, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([ff_t102, ff2cnn_ft102_std102])\n",
    "labels.extend(['ff_t102', 'ff2cnn_ft102_std102'])\n",
    "###############################4\n",
    "config['student_exp_name'] ='gc_o_std103'\n",
    "config['teacher_exp_name'] ='gc_o_tchr103'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2cnn_ot103_std103, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t103, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([ff_t103, ff2cnn_ot103_std103])\n",
    "# labels.extend(['ff_t103', 'ff2cnn_ot103_std103'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std103'\n",
    "config['teacher_exp_name'] ='gc_o_tchr103'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2cnn_ft103_std103, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t103, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "models.extend([ff_t103, ff2cnn_ft103_std103])\n",
    "labels.extend(['ff_t103', 'ff2cnn_ft103_std103'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "##ff -> ff\n",
    "###############################1\n",
    "config={'student_exp_name':'gc_o_std210',\n",
    "        'teacher_exp_name':'gc_o_dtchr210',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vff',\n",
    "        'student_model':'cl_vff',\n",
    "        'teacher_config':'ff_mnist4',\n",
    "        'student_config':'ff_mnist4',\n",
    "        'distill_config':'pure_dstl2_4_crs_slw_3',\n",
    "        'distill_mode':'online',\n",
    "        'chkpt_dir': chkpt_dir,\n",
    "       }\n",
    "\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2ff_ot1_std1, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t1, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "\n",
    "# models.extend([ff2ff_ot1_std1, ff_t1])\n",
    "# labels.extend(['ff2ffot1_std1', 'ff_t1'])\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std210'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr210'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2ff_ft1_std1, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t1, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "models.extend([ff2ff_ft1_std1, ff_t1])\n",
    "labels.extend(['ff2ff_ft1_std1', 'ff_t1'])\n",
    "print('student perf:')\n",
    "print(ff2ff_ft1_std1.evaluate(task.test_dataset, steps=10))\n",
    "print('teacher perf:')\n",
    "print(ff_t1.evaluate(task.test_dataset, steps=10))\n",
    "###############################2\n",
    "config['student_exp_name'] ='gc_o_std211'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr211'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2ff_ot2_std2, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "\n",
    "# models.extend([ff2ff_ot2_std2, ff_t2])\n",
    "# labels.extend(['ff2ff_ot2_std2', 'ff_t2'])\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std211'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr211'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2ff_ft2_std2, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "models.extend([ff2ff_ft2_std2, ff_t2])\n",
    "labels.extend(['ff2ff_ft2_std2', 'ff_t2'])\n",
    "print('student perf:')\n",
    "print(ff2ff_ft2_std2.evaluate(task.test_dataset, steps=10))\n",
    "print('teacher perf:')\n",
    "print(ff_t2.evaluate(task.test_dataset, steps=10))\n",
    "###############################3\n",
    "config['student_exp_name'] ='gc_o_std212'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr212'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2ff_ot3_std3, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t3, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "\n",
    "# models.extend([ff2ff_ot3_std3, ff_t3])\n",
    "# labels.extend(['ff2ff_ot3_std3', 'ff_t3'])\n",
    "\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std212'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr212'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2ff_ft3_std3, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t3, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "models.extend([ff2ff_ft3_std3, ff_t3])\n",
    "labels.extend(['ff2ff_ft3_std3', 'ff_t3'])\n",
    "print('student perf:')\n",
    "print(ff2ff_ft3_std3.evaluate(task.test_dataset, steps=10))\n",
    "print('teacher perf:')\n",
    "print(ff_t3.evaluate(task.test_dataset, steps=10))\n",
    "###############################4\n",
    "config['student_exp_name'] ='gc_o_std202'\n",
    "config['teacher_exp_name'] ='gc_o_tchr202'\n",
    "config['distill_mode'] ='online'\n",
    "config['distill_config'] ='pure_dstl2_4_crs_slw_3'\n",
    "# std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "# ff2ff_ot4_std4, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "# ff_t4, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "# models.extend([ff2ff_ot4_std4, ff_t4])\n",
    "# labels.extend(['ff2ff_ot4_std4', 'ff_t4'])\n",
    "\n",
    "\n",
    "#**********offline\n",
    "config['student_exp_name'] ='gc_f_std220'\n",
    "config['teacher_exp_name'] ='gc_o_dtchr210'\n",
    "config['distill_mode'] ='offline'\n",
    "config['distill_config'] ='pure_dstl5_4_crs_slw_3'\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "\n",
    "ff2ff_ft4_std4, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "ff_t4, _ = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "\n",
    "models.extend([ff2ff_ft4_std4, ff_t4])\n",
    "labels.extend(['ff2ff_ft4_std4', 'ff_t4'])\n",
    "print('student perf:')\n",
    "print(ff2ff_ft4_std4.evaluate(task.test_dataset, steps=10))\n",
    "print('teacher perf:')\n",
    "print(ff_t4.evaluate(task.test_dataset, steps=10))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for x, y in task.valid_dataset:\n",
    "    for model in models:\n",
    "        print(model.model_name)\n",
    "        model(x)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def  get_reploss_dic(models, labels, data, repindex=1):\n",
    "    \"\"\" Compute pairwise similarity between models\n",
    "    Input args:\n",
    "        models: model objects.\n",
    "        labels: name/id of the models\n",
    "        data:   examples based on which we want to compute the representational similarity\n",
    "        repindex: indicates whether we compute the representational similairty based on the penultimate layer or the logits layaer.\n",
    "    \"\"\"\n",
    "    reploss_dic = {}\n",
    "    for l1 in labels:\n",
    "        reploss_dic[l1] = {}\n",
    "        for l2 in labels:\n",
    "            reploss_dic[l1][l2] = []\n",
    "\n",
    "    num_batches = 0\n",
    "    for x, y in data:\n",
    "        reps = []\n",
    "        for m in models:\n",
    "            outputs = get_outputs(m, x)\n",
    "            reps.append(outputs[repindex])\n",
    "        for i in np.arange(len(labels)):\n",
    "            for j in np.arange(i, len(labels)):\n",
    "                reploss = rep_loss(reps1=reps[i], reps2=reps[j],\n",
    "                                     padding_symbol=None,\n",
    "                                     inputs=x)\n",
    "                reploss_dic[labels[i]][labels[j]].append(reploss)\n",
    "                if i != j:\n",
    "                    reploss_dic[labels[j]][labels[i]].append(reploss)\n",
    "        num_batches += 1\n",
    "\n",
    "        if num_batches > 20:\n",
    "            break\n",
    "            \n",
    "    dist_lists = {}\n",
    "    for l1 in reploss_dic.keys():\n",
    "        dist_lists[l1] = []\n",
    "        for i in  np.arange(len(reploss_dic[l1].keys())):\n",
    "            l2 = list(reploss_dic[l1].keys())[i]\n",
    "            dist_lists[l1].append(np.mean(reploss_dic[l1][l2]))\n",
    "            \n",
    "    return reploss_dic, dist_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import manifold\n",
    "\n",
    "def project2d(reploss_dic, dist_lists):\n",
    "    dists = []\n",
    "    points = []\n",
    "    for i in np.arange(len(dist_lists.keys())):\n",
    "        k = list(dist_lists.keys())[i]\n",
    "        points.append(k)\n",
    "        dists.append(list(map(float , dist_lists[k])))\n",
    "\n",
    "    adist = np.array(dists)\n",
    "    amax = np.amax(adist)\n",
    "    adist /= amax\n",
    "\n",
    "    mds = manifold.MDS(n_components=2, dissimilarity=\"precomputed\", random_state=12345)\n",
    "    results = mds.fit(adist)\n",
    "\n",
    "    coords = results.embedding_\n",
    "\n",
    "    plt.subplots_adjust(bottom = 0.1)\n",
    "    plt.scatter(\n",
    "        coords[:, 0], coords[:, 1], marker = 'o'\n",
    "        )\n",
    "    for label, x, y in zip(points, coords[:, 0], coords[:, 1]):\n",
    "        plt.annotate(\n",
    "            label,\n",
    "            xy = (x, y), xytext = (-20, 20),\n",
    "            textcoords = 'offset points', ha = 'right', va = 'bottom',\n",
    "            bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),\n",
    "            arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))\n",
    "\n",
    "    plt.show()\n",
    "    \n",
    "    return results.embedding_, points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, cmnist_trans_test, repindex=1) # repindex==1 means logits\n",
    "trans_Ez, trans_Lz = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, cmnist_scale_test, repindex=1)\n",
    "scale_Ez, scale_Lz = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, cmnist_trans_test, repindex=0) # repindex==0 means penultimate\n",
    "trans_Ez, trans_Lz = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, cmnist_scale_test, repindex=0)\n",
    "scale_Ez, scale_Lz = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model,label in zip(models,labels):\n",
    "    print(label)\n",
    "    results = model.evaluate(cmnist_trans_test, steps=cmnist_trans.info.splits['test'].num_examples / 64)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model,label in zip(models,labels):\n",
    "    print(label)\n",
    "    results = model.evaluate(cmnist_scale_test, steps=cmnist_scale.info.splits['test'].num_examples / 64)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for model,label in zip(models,labels):\n",
    "    print(label)\n",
    "    results = model.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "    print(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_probability as tfp\n",
    "\n",
    "for model,label in zip(models,labels):\n",
    "    print(label)\n",
    "    print(model.model_name)\n",
    "\n",
    "    model_accuracy, predicted_class_probs, correct_class_probs, model_logits, model_trues= test_for_calibration(model, cmnist_scale_test, cmnist_scale.info.splits['test'].num_examples / 64, task, n_bins=20)\n",
    "    ##plot_calibration(model_accuracy, predicted_class_probs, correct_class_probs, n_bins=20)\n",
    "    plt.show()\n",
    "    model_ece = tfp.stats.expected_calibration_error(\n",
    "        1000000,\n",
    "        logits=model_logits,\n",
    "        labels_true=model_trues,\n",
    "    )\n",
    "    print(model_ece.numpy())\n",
    "    \n",
    "    print(\"#######################\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for c,l in zip(ff_lz,ff_ez):\n",
    "    print(l.split('_')[0],'\\t',l, '\\t',c[0], '\\t',c[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_o_std300',\n",
    "        'teacher_exp_name':'gc_o_tchr300',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vff',\n",
    "        'student_model':'cl_vff',\n",
    "        'teacher_config':'ff_mnist4',\n",
    "        'student_config':'ff_mnist4',\n",
    "        'distill_config':'pure_dstl2_4_crs_slw_3',\n",
    "        'distill_mode':'online',\n",
    "        'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "teacher_model = MODELS[config['teacher_model']](hparams=tchr_hparams, cl_token=cl_token)\n",
    "\n",
    "   \n",
    "ckpt_dir = os.path.join(config['chkpt_dir'], task.name,\n",
    "                              '_'.join([teacher_model.model_name, config['teacher_config'],config['teacher_exp_name']]))\n",
    "tchr_ckpt = tf.train.Checkpoint(net=teacher_model)\n",
    "teacher_manager = tf.train.CheckpointManager(tchr_ckpt, ckpt_dir, max_to_keep=None)\n",
    "teacher_manager.latest_checkpoint\n",
    "\n",
    "student_model = MODELS[config['student_model']](hparams=std_hparams, cl_token=cl_token)\n",
    "ckpt_dir = os.path.join(config['chkpt_dir'], task.name,\n",
    "                          '_'.join([config['distill_mode'],config['distill_config'],\n",
    "                                    \"teacher\", teacher_model.model_name, \n",
    "                                    config['teacher_config'],\n",
    "                                    config['teacher_exp_name'],\n",
    "                                   \"student\",student_model.model_name,\n",
    "                                    str(config['student_config']),\n",
    "                                    config['student_exp_name']]))\n",
    "print(\"student_checkpoint:\", ckpt_dir)\n",
    "std_ckpt = tf.train.Checkpoint(net=student_model)\n",
    "student_manager = tf.train.CheckpointManager(std_ckpt, ckpt_dir, max_to_keep=None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(student_manager.latest_checkpoint)\n",
    "for ck in student_manager.checkpoints:\n",
    "    std_ckpt.restore(ck)\n",
    "    #print(\"#######\", ck, \"########\")\n",
    "    student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    trans = student_model.evaluate(cmnist_trans_test, steps=cmnist_trans.info.splits['test'].num_examples / 64, verbose=3)\n",
    "    scale = student_model.evaluate(cmnist_scale_test, steps=cmnist_scale.info.splits['test'].num_examples / 64, verbose=3)\n",
    "    mnist = student_model.evaluate(task.test_dataset, steps=task.n_test_batches, verbose=3)\n",
    "    print(trans[-1], scale[-1], mnist[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "\n",
    "from copy import deepcopy, copy\n",
    "print(student_manager.latest_checkpoint)\n",
    "Mz = []\n",
    "Lz = []\n",
    "# for i, ck in enumerate(student_manager.checkpoints):\n",
    "#     std_ckpt.restore(ck)\n",
    "#     print(\"#######\", ck, \"########\", i)\n",
    "#     student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "#     Mz.append(copy(student_model))\n",
    "#     Lz.append('ff2ff'+str(i))\n",
    "    \n",
    "    \n",
    "    \n",
    "for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "    tchr_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(copy(teacher_model))\n",
    "    Lz.append('ff'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, cmnist_trans_test, repindex=1)\n",
    "ff_lz, ff_ez = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "\n",
    "from copy import deepcopy, copy\n",
    "print(student_manager.latest_checkpoint)\n",
    "# Mz = []\n",
    "# Lz = []\n",
    "for i, ck in enumerate(student_manager.checkpoints):\n",
    "    std_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(copy(student_model))\n",
    "    Lz.append('ff2ff'+str(i))\n",
    "    \n",
    "    \n",
    "    \n",
    "for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "    tchr_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(copy(teacher_model))\n",
    "    Lz.append('ff'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, cmnist_trans_test, repindex=1)\n",
    "project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "\n",
    "from copy import deepcopy, copy\n",
    "print(student_manager.latest_checkpoint)\n",
    "# Mz = []\n",
    "# Lz = []\n",
    "# for i, ck in enumerate(student_manager.checkpoints):\n",
    "#     std_ckpt.restore(ck)\n",
    "#     print(\"#######\", ck, \"########\", i)\n",
    "#     student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "#     Mz.append(copy(student_model))\n",
    "#     Lz.append('ff2ff'+str(i))\n",
    "    \n",
    "    \n",
    "    \n",
    "# for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "#     tchr_ckpt.restore(ck)\n",
    "#     print(\"#######\", ck, \"########\", i)\n",
    "#     teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "#     Mz.append(copy(teacher_model))\n",
    "#     Lz.append('ff'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, cmnist_trans_test, repindex=1)\n",
    "project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)\n",
    "\n",
    "from copy import deepcopy, copy\n",
    "\n",
    "print(student_manager.latest_checkpoint)\n",
    "Mz = []\n",
    "Lz = []\n",
    "for i, ck in enumerate(student_manager.checkpoints):\n",
    "    std_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(deepcopy(student_model))\n",
    "    Lz.append('cnn2ff'+str(i))\n",
    "    \n",
    "for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "    tchr_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(deepcopy(teacher_model))\n",
    "    Lz.append('cnn'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, task.valid_dataset, repindex=1)\n",
    "project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Mz = []\n",
    "Lz = []\n",
    "# for i, ck in enumerate(student_manager.checkpoints):\n",
    "#     std_ckpt.restore(ck)\n",
    "#     print(\"#######\", ck, \"########\", i)\n",
    "#     student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "#     Mz.append(deepcopy(student_model))\n",
    "#     Lz.append('ff2cnn'+str(i))\n",
    "    \n",
    "for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "    tchr_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(deepcopy(teacher_model))\n",
    "    Lz.append('cnn'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, cmnist_trans_test, repindex=1)\n",
    "cnn_lz, cnn_ez = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from copy import deepcopy\n",
    "Mz = []\n",
    "Lz = []\n",
    "for i, ck in enumerate(student_manager.checkpoints):\n",
    "    std_ckpt.restore(ck)\n",
    "    print(\"#######\", ck, \"########\", i)\n",
    "    student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    Mz.append(deepcopy(student_model))\n",
    "    Lz.append('cnn2ff'+str(i))\n",
    "    \n",
    "# for i, ck in enumerate(teacher_manager.checkpoints):\n",
    "#     tchr_ckpt.restore(ck)\n",
    "#     print(\"#######\", ck, \"########\", i)\n",
    "#     teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "#     Mz.append(deepcopy(teacher_model))\n",
    "#     Lz.append('cnn'+str(i))\n",
    "    \n",
    "reploss_dic, dist_lists = get_reploss_dic(Mz, Lz, cmnist_trans_test, repindex=1)\n",
    "cnn2ff_lz, cnn2ff_ez = project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(teacher_manager.latest_checkpoint)\n",
    "for ck in teacher_manager.checkpoints:\n",
    "    tchr_ckpt.restore(ck)\n",
    "    #print(\"#######\", ck, \"########\")\n",
    "    teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "    trans = teacher_model.evaluate(cmnist_trans_test, steps=cmnist_trans.info.splits['test'].num_examples / 64, verbose=3)\n",
    "    scale = teacher_model.evaluate(cmnist_scale_test, steps=cmnist_scale.info.splits['test'].num_examples / 64, verbose=3)\n",
    "    mnist = teacher_model.evaluate(task.test_dataset, steps=task.n_test_batches, verbose=3)\n",
    "    print(trans[-1], scale[-1], mnist[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_f_std1000',\n",
    "        'teacher_exp_name':'gc_o_dtchr1000',\n",
    "        'task_name':'mnist',\n",
    "        'teacher_model':'cl_vcnn',\n",
    "        'student_model':'cl_vff',\n",
    "        'teacher_config':'vcnn_mnist7',\n",
    "        'student_config':'ff_mnist4',\n",
    "        'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "        'distill_mode':'offline',\n",
    "        'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "\n",
    "teacher_model = MODELS[config['teacher_model']](hparams=tchr_hparams, cl_token=cl_token)\n",
    "\n",
    "   \n",
    "ckpt_dir = os.path.join(config['chkpt_dir'], task.name,\n",
    "                              '_'.join([teacher_model.model_name, config['teacher_config'],config['teacher_exp_name']]))\n",
    "tchr_ckpt = tf.train.Checkpoint(net=teacher_model)\n",
    "teacher_manager = tf.train.CheckpointManager(tchr_ckpt, ckpt_dir, max_to_keep=None)\n",
    "teacher_manager.latest_checkpoint\n",
    "\n",
    "student_model = MODELS[config['student_model']](hparams=std_hparams, cl_token=cl_token)\n",
    "ckpt_dir = os.path.join(config['chkpt_dir'], task.name,\n",
    "                          '_'.join([config['distill_mode'],config['distill_config'],\n",
    "                                    \"teacher\", teacher_model.model_name, \n",
    "                                    config['teacher_config'],\n",
    "                                    config['teacher_exp_name'],\n",
    "                                   \"student\",student_model.model_name,\n",
    "                                    str(config['student_config']),\n",
    "                                    config['student_exp_name']]))\n",
    "print(\"student_checkpoint:\", ckpt_dir)\n",
    "std_ckpt = tf.train.Checkpoint(net=student_model)\n",
    "student_manager = tf.train.CheckpointManager(std_ckpt, ckpt_dir, max_to_keep=None)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_probability as tfp\n",
    "\n",
    "\n",
    "model_accuracy, predicted_class_probs, correct_class_probs, model_logits, model_trues= test_for_calibration(teacher_model,cmnist_trans, cmnist_trans.info.splits['test'].num_examples / 64, task, n_bins=20)\n",
    "\n",
    "print(len(model_accuracy))\n",
    "print(len(predicted_class_probs))\n",
    "plot_calibration(model_accuracy, predicted_class_probs, correct_class_probs, n_bins=20)\n",
    "plt.show()\n",
    "model_ece = tfp.stats.expected_calibration_error(\n",
    "    1000000,\n",
    "    logits=model_logits,\n",
    "    labels_true=model_trues,\n",
    ")\n",
    "print(model_ece.numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_for_calibration(model, data, batch_count, task, n_bins=10):\n",
    "    preds = []\n",
    "    correct_class_probs = []\n",
    "    predicted_class_probs = []\n",
    "    pred_logits = []\n",
    "    y_trues = []\n",
    "    for x, y in data:\n",
    "        logits = model(x)\n",
    "        pred_logits.extend(logits.numpy())\n",
    "        pred = tf.argmax(logits, axis=-1)\n",
    "        \n",
    "        prob = task.get_probs_fn()(logits, labels=y, temperature=1)\n",
    "        preds.extend(pred.numpy())\n",
    "        y_trues.extend(y.numpy())\n",
    "        batch_indexes = tf.cast(tf.range(len(y), dtype=tf.int32), dtype=tf.int32)\n",
    "        true_indexes = tf.concat([batch_indexes[:,None], y[:,None]], axis=1)\n",
    "        pred_indexes = tf.concat([batch_indexes[:,None], tf.cast(pred[:,None], tf.int32)], axis=1)\n",
    "\n",
    "        correct_class_probs.extend(tf.gather_nd(prob, true_indexes).numpy())\n",
    "        predicted_class_probs.extend(tf.gather_nd(prob, pred_indexes).numpy())\n",
    "\n",
    "        batch_count -= 1\n",
    "        if batch_count == 0:\n",
    "            break\n",
    "\n",
    "    model_accuracy = np.asarray(preds) == np.asarray(y_trues)\n",
    "\n",
    "    return model_accuracy, predicted_class_probs, correct_class_probs, pred_logits, y_trues\n",
    "\n",
    "def plot_calibration(model_accuracy, predicted_class_probs, correct_class_probs, n_bins=10):\n",
    "    p_confidence_bins = np.zeros(n_bins+1)\n",
    "    n_confidence_bins = np.zeros(n_bins+1)\n",
    "    total_confidence_bins = np.zeros(n_bins+1)\n",
    "    \n",
    "    denominator = 100.0 / n_bins\n",
    "    for i in np.arange(len(model_accuracy)):\n",
    "        if model_accuracy[i]:\n",
    "            p_confidence_bins[int(predicted_class_probs[i]*100 / denominator)] += 1.0\n",
    "        else:\n",
    "            n_confidence_bins[int(predicted_class_probs[i]*100 / denominator)] -= 1.0\n",
    "        total_confidence_bins[int(predicted_class_probs[i]*100 / denominator)] += 1\n",
    "\n",
    "    #sns.stripplot(model_accuracy,predicted_class_probs, color='blue', alpha=0.5, jitter=True)\n",
    "    #sns.stripplot(model_accuracy,correct_class_probs, color='green', alpha=0.2, jitter=True)\n",
    "    #sns.swarmplot(model_accuracy,predicted_class_probs, color='blue', alpha=0.5)\n",
    "    #plt.show()\n",
    "   \n",
    "    sns.barplot(x=np.arange(0,n_bins+1)*denominator, \n",
    "                y=np.arange(0,n_bins+1)/n_bins, \n",
    "                color='green', alpha=0.2, edgecolor='black')\n",
    "    ax = sns.barplot(x=np.arange(0,n_bins+1)*denominator, \n",
    "                    y=p_confidence_bins/total_confidence_bins, \n",
    "                    color='red', alpha=0.5, edgecolor='black')\n",
    "    \n",
    "    x_ticks = np.arange(0,n_bins,2)\n",
    "    x_tick_labels = x_ticks / np.float32(n_bins)\n",
    "    ax.set_xticks(x_ticks)\n",
    "    ax.set_xticklabels(x_tick_labels, fontsize=10)\n",
    "    \n",
    "def expected_calibration_error(teacher_accuracy, teacher_predicted_class_probs):\n",
    "    raise NotImplemented"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
