{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "from notebook_utils import *\n",
    "from distill.repsim_util import *\n",
    "\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "chkpt_dir='../tf_ckpts'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_outputs(model, x):\n",
    "  outputs = model.detailed_call(x, training=tf.convert_to_tensor(True))\n",
    "  logits, reps = outputs[0], outputs[model.rep_index]\n",
    "  if model.rep_layer is not None and model.rep_layer is not -1:\n",
    "    reps = reps[model.rep_layer]\n",
    "\n",
    "  return reps, logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task1 = 'svhn'\n",
    "task1 = TASKS[task1](get_task_params(), data_dir='../data')\n",
    "cl_token = 0\n",
    "\n",
    "task = task1\n",
    "\n",
    "models = []\n",
    "labels = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_stdr1',\n",
    "    'teacher_exp_name':'gc_o_dtch1',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_2',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff_r0, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt_r0, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_stdr2',\n",
    "    'teacher_exp_name':'gc_o_dtchrr2',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_2',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff_r1, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt_r1, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_stdr3',\n",
    "    'teacher_exp_name':'gc_o_dtchrr3',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_2',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff_r2, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt_r2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std45',\n",
    "    'teacher_exp_name':'gc_o_dtchr3',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_2',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff1, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt1, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std55',\n",
    "    'teacher_exp_name':'gc_o_dtchr3',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_2',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff2, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "# tchr_hparams.output_attentions = True\n",
    "# tchr_hparams.output_embeddings = True\n",
    "# tchr_hparams.output_hidden_states = True\n",
    "\n",
    "# tchr_rsnt2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std56',\n",
    "    'teacher_exp_name':'gc_o_dtchr3',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff3, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "# tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "# tchr_hparams.output_attentions = True\n",
    "# tchr_hparams.output_embeddings = True\n",
    "# tchr_hparams.output_hidden_states = True\n",
    "\n",
    "# tchr_rsnt2, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std61',\n",
    "    'teacher_exp_name':'gc_o_dtchr61',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff4, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt4, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std62',\n",
    "    'teacher_exp_name':'gc_o_dtchr62',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff5, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt5, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std63',\n",
    "    'teacher_exp_name':'gc_o_dtchr63',\n",
    "    'teacher_config':'rsnt_svhn1',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'resnet',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff6, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_rsnt6, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std64',\n",
    "    'teacher_exp_name':'gc_o_dtchr64',\n",
    "    'teacher_config':'ff_svhn2',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'cl_vff',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff7, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_ff7, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std65',\n",
    "    'teacher_exp_name':'gc_o_dtchr65',\n",
    "    'teacher_config':'ff_svhn2',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'cl_vff',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff8, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_ff8, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_std66',\n",
    "    'teacher_exp_name':'gc_o_dtchr66',\n",
    "    'teacher_config':'ff_svhn2',\n",
    "    'task_name':'svhn',\n",
    "    'student_model':'cl_vff',\n",
    "    'teacher_model':'cl_vff',\n",
    "    'student_config':'ff_svhn2',\n",
    "    'distill_config':'pure_dstl5_4_crs_slw_3',\n",
    "    'distill_mode':'online',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "       }\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "std_ff9, _ = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "tchr_hparams.output_attentions = True\n",
    "tchr_hparams.output_embeddings = True\n",
    "tchr_hparams.output_hidden_states = True\n",
    "\n",
    "tchr_ff9, _ = get_teacher_model(config, task, tchr_hparams, cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "models = [std_ff_r0, std_ff_r1, std_ff_r2, std_ff9, std_ff8, std_ff7, std_ff4, std_ff5, std_ff6, std_ff1, std_ff2, std_ff3,\n",
    "          tchr_rsnt_r0, tchr_rsnt_r1, tchr_rsnt_r2, tchr_ff9, tchr_ff8, tchr_ff7, tchr_rsnt4, tchr_rsnt5, tchr_rsnt6, tchr_rsnt1]\n",
    "labels = ['std_f_r0', 'std_ff_r1', 'std_ff_r2', 'std_ff9', 'std_ff8', 'std_ff7', 'std_ff4', 'std_ff5', 'std_ff6', 'std_ff1', 'std_ff2', 'std_ff3', \n",
    "          'tchr_ff_r0', 'tchr_rsnt_r1', 'tchr_rsnt_r2', 'tchr_ff9', 'tchr_ff8', 'tchr_ff7', 'tchr_rsnt4', 'tchr_rsnt5', 'tchr_rsnt6', 'tchr_rsnt1']\n",
    "\n",
    "for x, y in task.valid_dataset:\n",
    "    for model in models:\n",
    "        model(x)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def  get_reploss_dic(models, labels, task, repindex=1):    \n",
    "    reploss_dic = {}\n",
    "    for l1 in labels:\n",
    "        reploss_dic[l1] = {}\n",
    "        for l2 in labels:\n",
    "            reploss_dic[l1][l2] = []\n",
    "\n",
    "    num_batches = 0\n",
    "    for x, y in task.valid_dataset:\n",
    "        reps = []\n",
    "        for m in models:\n",
    "            outputs = get_outputs(m, x)\n",
    "            reps.append(outputs[repindex])\n",
    "        for i in np.arange(len(labels)):\n",
    "            for j in np.arange(i, len(labels)):\n",
    "                reploss = rep_loss(reps1=reps[i], reps2=reps[j],\n",
    "                                     padding_symbol=None,\n",
    "                                     inputs=x)\n",
    "                reploss_dic[labels[i]][labels[j]].append(reploss)\n",
    "                if i != j:\n",
    "                    reploss_dic[labels[j]][labels[i]].append(reploss)\n",
    "        num_batches += 1\n",
    "\n",
    "        if num_batches > 20:\n",
    "            break\n",
    "            \n",
    "    dist_lists = {}\n",
    "    for l1 in reploss_dic.keys():\n",
    "        dist_lists[l1] = []\n",
    "        for i in  np.arange(len(reploss_dic[l1].keys())):\n",
    "            l2 = list(reploss_dic[l1].keys())[i]\n",
    "            dist_lists[l1].append(np.mean(reploss_dic[l1][l2]))\n",
    "            \n",
    "    return reploss_dic, dist_lists"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, task, repindex=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn import manifold\n",
    "\n",
    "def project2d(reploss_dic, dist_lists):\n",
    "    dists = []\n",
    "    cities = []\n",
    "    for i in np.arange(len(dist_lists.keys())):\n",
    "        k = list(dist_lists.keys())[i]\n",
    "        cities.append(k)\n",
    "        dists.append(list(map(float , dist_lists[k])))\n",
    "\n",
    "    adist = np.array(dists)\n",
    "    amax = np.amax(adist)\n",
    "    adist /= amax\n",
    "\n",
    "    mds = manifold.MDS(n_components=2, dissimilarity=\"precomputed\", random_state=6)\n",
    "    results = mds.fit(adist)\n",
    "\n",
    "    coords = results.embedding_\n",
    "\n",
    "    plt.subplots_adjust(bottom = 0.1)\n",
    "    plt.scatter(\n",
    "        coords[:, 0], coords[:, 1], marker = 'o'\n",
    "        )\n",
    "    for label, x, y in zip(cities, coords[:, 0], coords[:, 1]):\n",
    "        plt.annotate(\n",
    "            label,\n",
    "            xy = (x, y), xytext = (-20, 20),\n",
    "            textcoords = 'offset points', ha = 'right', va = 'bottom',\n",
    "            bbox = dict(boxstyle = 'round,pad=0.5', fc = 'yellow', alpha = 0.5),\n",
    "            arrowprops = dict(arrowstyle = '->', connectionstyle = 'arc3,rad=0'))\n",
    "\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "reploss_dic, dist_lists = get_reploss_dic(models, labels, task, repindex=0)\n",
    "project2d(reploss_dic, dist_lists)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
