{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_exp_name='fd1'\n",
    "teacher_exp_name='0.0001_offlineteacher_v3'\n",
    "teacher_config='small_lstm_v4'\n",
    "task_name = 'word_sv_agreement_vp'\n",
    "student_model='cl_gpt2'\n",
    "teacher_model='cl_lstm'\n",
    "student_config='small_gpt_v9'\n",
    "distill_config='pure_distill_2'\n",
    "distill_mode='offline'\n",
    "\n",
    "chkpt_dir='../tf_ckpts'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = TASKS[task_name](get_task_params(), data_dir='../data')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cl_token = task.databuilder.sentence_encoder().encode(constants.bos)\n",
    "teacher_model = MODELS[teacher_model](hparams=get_model_params(task, teacher_model, teacher_config), cl_token=cl_token)\n",
    "std_hparams=get_model_params(task, student_model, student_config)\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "student_model = MODELS[student_model](\n",
    "std_hparams, cl_token=cl_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_ckpt_dir = os.path.join(chkpt_dir, task.name,\n",
    "                              '_'.join([distill_mode,distill_config,\n",
    "                                        \"teacher\", teacher_model.model_name, \n",
    "                                        #teacher_config,\n",
    "                                        teacher_exp_name,\n",
    "                                       \"student\",student_model.model_name,\n",
    "                                        str(student_config),\n",
    "                                        student_exp_name]))\n",
    "print(\"student_checkpoint:\", student_ckpt_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_ckpt = tf.train.Checkpoint(net=student_model)\n",
    "student_manager = tf.train.CheckpointManager(student_ckpt, student_ckpt_dir, max_to_keep=None)\n",
    "\n",
    "student_ckpt.restore(student_manager.latest_checkpoint)\n",
    "if student_manager.latest_checkpoint:\n",
    "  print(\"Restored student from {}\".format(student_manager.latest_checkpoint))\n",
    "\n",
    "student_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "student_model.evaluate(task.test_dataset, steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_ckpt_dir = os.path.join(chkpt_dir, task.name,\n",
    "                                  '_'.join([teacher_model.model_name, teacher_config,teacher_exp_name]))\n",
    "\n",
    "teacher_ckpt = tf.train.Checkpoint(net=teacher_model)\n",
    "teacher_manager = tf.train.CheckpointManager(teacher_ckpt, teacher_ckpt_dir, max_to_keep=None)\n",
    "\n",
    "teacher_ckpt.restore(teacher_manager.latest_checkpoint)\n",
    "if teacher_manager.latest_checkpoint:\n",
    "  print(\"Restored student from {}\".format(teacher_manager.latest_checkpoint))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "teacher_model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "teacher_model.evaluate(task.test_dataset, steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name='cl_gpt2'\n",
    "model_config='small_gpt_v9'\n",
    "learning_rate=0.0001\n",
    "exp_name='offlineteacher_v1'\n",
    "\n",
    "cl_token = task.databuilder.sentence_encoder().encode(constants.bos)\n",
    "hparams=get_model_params(task, model_name, model_config)\n",
    "hparams.output_attentions = True\n",
    "hparams.output_embeddings = True\n",
    "\n",
    "model = MODELS[model_name](hparams=hparams, cl_token=cl_token)\n",
    "\n",
    "\n",
    "ckpt_dir = os.path.join(chkpt_dir,task.name,\n",
    "                        model.model_name+\"_\"+str(model_config)+\"_\"+str(learning_rate)+\"_\"+exp_name)\n",
    "\n",
    "ckpt = tf.train.Checkpoint(net=model)\n",
    "manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=None)\n",
    "\n",
    "ckpt.restore(manager.latest_checkpoint)\n",
    "if manager.latest_checkpoint:\n",
    "  print(\"Restored student from {}\".format(manager.latest_checkpoint))\n",
    "\n",
    "model.compile(loss=task.get_loss_fn(), metrics=task.metrics())\n",
    "model.evaluate(task.test_dataset, steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_preds = []\n",
    "y_trues = []\n",
    "teacher_preds = []\n",
    "independent_preds = []\n",
    "inputs = []\n",
    "batch_count = task.n_valid_batches\n",
    "for x, y in tqdm(task.valid_dataset, total=batch_count):\n",
    "    std_pred = tf.argmax(student_model(x), axis=-1)\n",
    "    teach_pred = tf.argmax(teacher_model(x), axis=-1)\n",
    "    indep_pred = tf.argmax(model(x), axis=-1)\n",
    "    student_preds.extend(std_pred.numpy())\n",
    "    teacher_preds.extend(teach_pred.numpy())\n",
    "    y_trues.extend(y.numpy())\n",
    "    independent_preds.extend(indep_pred)\n",
    "    inputs.extend(x)\n",
    "    batch_count -= 1\n",
    "    if batch_count == 0:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_mistakes = np.asarray(student_preds) == np.asarray(y_trues)\n",
    "teacher_mistakes = np.asarray(teacher_preds) == np.asarray(y_trues)\n",
    "model_mistakes = np.asarray(independent_preds) == np.asarray(y_trues)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nonoverlapping_mistakes = np.where(student_mistakes != teacher_mistakes)[0]\n",
    "teach_wrong_student_right = np.where(student_mistakes > teacher_mistakes)[0]\n",
    "teach_right_student_wrong = np.where(student_mistakes < teacher_mistakes)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_right_student_wrong = np.where(student_mistakes < model_mistakes)[0]\n",
    "model_wrong_student_right = np.where(student_mistakes > model_mistakes)[0]\n",
    "model_right_teacher_wrong = np.where(teacher_mistakes < model_mistakes)[0]\n",
    "model_wrong_teacher_right = np.where(teacher_mistakes > model_mistakes)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(model_wrong_student_right)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
