{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "from notebook_utils import *\n",
    "\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name = 'word_sv_agreement_vp'\n",
    "task = TASKS[task_name](get_task_params(), data_dir='../data')\n",
    "cl_token = task.databuilder.sentence_encoder().encode(constants.bos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "config={'student_exp_name':'gc_f_std5004',\n",
    "    'teacher_exp_name':'gc_o_tchr5021',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2lstm_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5001',\n",
    "    'teacher_exp_name':'gc_o_tchr5011',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2lstm_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5002',\n",
    "    'teacher_exp_name':'gc_o_tchr5020',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2lstm_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5003',\n",
    "    'teacher_exp_name':'gc_o_tchr5030',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2lstm_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    \n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5030',\n",
    "    'teacher_exp_name':'gc_o_tchr5030',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2ugpt_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5031',\n",
    "    'teacher_exp_name':'gc_o_tchr5031',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2ugpt_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5032',\n",
    "    'teacher_exp_name':'gc_o_tchr5011',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2ugpt_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5033',\n",
    "    'teacher_exp_name':'gc_o_tchr5021',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2ugpt_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    \n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5010',\n",
    "    'teacher_exp_name':'gc_o_tchr5010',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2bert_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5011',\n",
    "    'teacher_exp_name':'gc_o_tchr5011',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2bert_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5012',\n",
    "    'teacher_exp_name':'gc_o_tchr5020',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2bert_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5013',\n",
    "    'teacher_exp_name':'gc_o_tchr5021',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2bert_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5020',\n",
    "    'teacher_exp_name':'gc_o_tchr5020',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2gpt_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5021',\n",
    "    'teacher_exp_name':'gc_o_tchr5021',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2gpt_149')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5022',\n",
    "    'teacher_exp_name':'gc_o_tchr5010',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2gpt_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std5023',\n",
    "    'teacher_exp_name':'gc_o_tchr5011',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_lstm',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_lstm_v4',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp5',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('lstm2gpt_151')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_probability as tfp\n",
    "\n",
    "def test_for_calibration(model, task, n_bins=10):\n",
    "    preds = []\n",
    "    correct_class_probs = []\n",
    "    predicted_class_probs = []\n",
    "    pred_logits = []\n",
    "    y_trues = []\n",
    "    batch_count = task.n_valid_batches\n",
    "    for x, y in task.valid_dataset:\n",
    "        y = tf.cast(y, tf.int32)\n",
    "        logits = model(x)\n",
    "        pred_logits.extend(logits.numpy())\n",
    "        pred = tf.argmax(logits, axis=-1)\n",
    "        prob = task.get_probs_fn()(logits, labels=y, temperature=1)\n",
    "        preds.extend(pred.numpy())\n",
    "        y_trues.extend(y.numpy())\n",
    "        batch_indexes = tf.cast(tf.range(len(y), dtype=tf.int32), dtype=tf.int32)\n",
    "        true_indexes = tf.concat([batch_indexes[:,None], y[:,None]], axis=1)\n",
    "        pred_indexes = tf.concat([batch_indexes[:,None], tf.cast(pred[:,None], tf.int32)], axis=1)\n",
    "\n",
    "        correct_class_probs.extend(tf.gather_nd(prob, true_indexes).numpy())\n",
    "        predicted_class_probs.extend(tf.gather_nd(prob, pred_indexes).numpy())\n",
    "\n",
    "        batch_count -= 1\n",
    "        if batch_count == 0:\n",
    "            break\n",
    "\n",
    "    model_accuracy = np.asarray(preds) == np.asarray(y_trues)\n",
    "\n",
    "    return model_accuracy, predicted_class_probs, correct_class_probs, pred_logits, y_trues\n",
    "\n",
    "\n",
    "for label,model in zip(student_labels, student_models):\n",
    "    print(label)\n",
    "    print(model.model_name)\n",
    "\n",
    "    model_accuracy, predicted_class_probs, correct_class_probs, model_logits, model_trues= test_for_calibration(model, task, n_bins=20)\n",
    "\n",
    "    #print(len(model_accuracy))\n",
    "    #print(len(predicted_class_probs))\n",
    "    ##plot_calibration(model_accuracy, predicted_class_probs, correct_class_probs, n_bins=20)\n",
    "    #  plt.show()\n",
    "    model_ece = tfp.stats.expected_calibration_error(\n",
    "        1000000,\n",
    "        logits=model_logits,\n",
    "        labels_true=model_trues,\n",
    "    )\n",
    "    print(model_ece.numpy())\n",
    "    \n",
    "    print(\"#######################\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "\n",
    "def evaluate_vp_cl(model, verb_infl, noun_infl, task, split='test', batch_size=1000, cls=False):\n",
    "    distance_hits = Counter()\n",
    "    distance_total = Counter()\n",
    "    diff_hits = Counter()\n",
    "    diff_total = Counter()\n",
    "\n",
    "    test_data = task.databuilder.as_dataset(split=split, batch_size=batch_size)\n",
    "    e = 0\n",
    "    for examples in test_data:\n",
    "        e += 1\n",
    "        print(e, end=\"\\r\")\n",
    "        sentences = examples['sentence']\n",
    "        #bos = tf.cast(task.databuilder.sentence_encoder().encode(constants.bos) * tf.ones((sentences.shape[0],1)), dtype=tf.int64)\n",
    "        eos = tf.cast(task.databuilder.sentence_encoder().encode(constants.eos) *tf.ones((sentences.shape[0],1)), dtype=tf.int64)\n",
    "\n",
    "        sentences = tf.concat([sentences, eos], axis=-1)\n",
    "\n",
    "        verb_position = examples['verb_position']+int(cls)  #+1 because of adding bos.\n",
    "        # The verb it self is also masked\n",
    "        mask = tf.cast(tf.sequence_mask(verb_position,maxlen=tf.shape(sentences)[1]), dtype=tf.int64)\n",
    "        max_length = tf.reduce_max(verb_position + 1)\n",
    "\n",
    "        last_index_mask = tf.gather(tf.eye(tf.shape(sentences)[1], dtype=tf.int64),verb_position)\n",
    "        last_index_mask = last_index_mask * eos[0]\n",
    "\n",
    "        inputs = (sentences * mask + last_index_mask)[:,:max_length]\n",
    "\n",
    "#         print(sentences[0])\n",
    "#         print(task.databuilder.sentence_encoder().decode(inputs[0]))\n",
    "#         break\n",
    "        s_shape = tf.shape(inputs)\n",
    "        batch_size, length = s_shape[0], s_shape[1]\n",
    "        verb_classes = examples['verb_class']\n",
    "        actual_verbs = examples['verb']\n",
    "        #inflected_verbs = [verb_infl[v.decode(\"utf-8\")] for v in actual_verbs.numpy()]\n",
    "\n",
    "        distances = examples['distance'].numpy()\n",
    "        nz = examples['n_intervening'].numpy()\n",
    "        n_diffs = examples['n_diff_intervening'].numpy()\n",
    "\n",
    "        actual_verb_indexes = [task.databuilder.sentence_encoder().encode(v)[0] for v in actual_verbs.numpy()]\n",
    "        #inflected_verb_indexes = [task.databuilder.sentence_encoder().encode(v)[0] for v in inflected_verbs]\n",
    "\n",
    "\n",
    "        predictions = model(inputs, training=False)\n",
    "        predictions = np.argmax(predictions, axis=-1)\n",
    "        corrects = predictions == verb_classes\n",
    "\n",
    "        for i, c in enumerate(corrects):\n",
    "            if actual_verb_indexes[i] == 10035 or actual_verb_indexes[i] == 2:\n",
    "                continue\n",
    "            if nz[i] > 4 or distances[i] > 16:\n",
    "                continue\n",
    "\n",
    "            distance_total[distances[i]] += 1\n",
    "            distance_hits[distances[i]] += int(c)\n",
    "            if nz[i] == n_diffs[i]:\n",
    "                n = nz[i]\n",
    "                diff_total[n] += 1\n",
    "                diff_hits[n] += int(c)\n",
    "    \n",
    "    return distance_hits, distance_total, diff_hits, diff_total\n",
    "\n",
    "\n",
    "infl_eng = inflect.engine()\n",
    "verb_infl, noun_infl = gen_inflect_from_vocab(infl_eng, 'wiki.vocab')\n",
    "\n",
    "for label,model in zip(student_labels, student_models):\n",
    "    print(label)\n",
    "    print(model.model_name)\n",
    "    print('##################################')\n",
    "    distance_hits, distance_total, diff_hits, diff_total = evaluate_vp_cl(model, verb_infl, noun_infl, task)\n",
    "    compute_and_print_acc_stats(distance_hits, distance_total, diff_hits, diff_total)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
