{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "from notebook_utils import *\n",
    "\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task_name = 'word_sv_agreement_vp'\n",
    "task = TASKS[task_name](get_task_params(), data_dir='../data')\n",
    "cl_token = task.databuilder.sentence_encoder().encode(constants.bos)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "import logging\n",
    "tf.get_logger().setLevel(logging.ERROR)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Bert to LSTM\n",
    "\n",
    "config={'student_exp_name':'gc_f_std9303',\n",
    "    'teacher_exp_name':'gc_o_tchr8323',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp9',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2lstm_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std9304',\n",
    "    'teacher_exp_name':'gc_o_tchr8324',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp9',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2lstm_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std9301',\n",
    "    'teacher_exp_name':'gc_o_tchr9301',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp9',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2lstm_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std9302',\n",
    "    'teacher_exp_name':'gc_o_tchr9302',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_lstm',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_lstm_v4',\n",
    "    'distill_config':'pure_dstl_4_exp_vp9',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2lstm_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8331',\n",
    "    'teacher_exp_name':'gc_o_tchr8321',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2ugpt_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8332',\n",
    "    'teacher_exp_name':'gc_o_tchr8322',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2ugpt_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8333',\n",
    "    'teacher_exp_name':'gc_o_tchr8323',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2ugpt_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8334',\n",
    "    'teacher_exp_name':'gc_o_tchr8324',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2_shared',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_ugpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2ugpt_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    \n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8311',\n",
    "    'teacher_exp_name':'gc_o_tchr8311',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2bert_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8312',\n",
    "    'teacher_exp_name':'gc_o_tchr8322',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2bert_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8313',\n",
    "    'teacher_exp_name':'gc_o_tchr8323',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2bert_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8314',\n",
    "    'teacher_exp_name':'gc_o_tchr8324',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_bert',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2bert_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "student_models = []\n",
    "student_labels = []\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8321',\n",
    "    'teacher_exp_name':'gc_o_tchr8321',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2gpt_1')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8322',\n",
    "    'teacher_exp_name':'gc_o_tchr8322',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2gpt_2')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8323',\n",
    "    'teacher_exp_name':'gc_o_tchr8323',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2gpt_3')\n",
    "\n",
    "config={'student_exp_name':'gc_f_std8324',\n",
    "    'teacher_exp_name':'gc_o_tchr8324',\n",
    "    'task_name':'word_sv_agreement_vp',\n",
    "    'teacher_model':'cl_bert',\n",
    "    'student_model':'cl_gpt2',\n",
    "    'teacher_config':'small_gpt_v9',\n",
    "    'student_config':'small_gpt_v9',\n",
    "    'distill_config':'pure_dstl_4_exp_vp8',\n",
    "    'distill_mode':'offline',\n",
    "    'chkpt_dir':'../tf_ckpts',\n",
    "     }\n",
    "# config['distill_mode'] = 'online'\n",
    "# config['student_exp_name'] = config['student_exp_name'].replace('_f_', '_o_')\n",
    "\n",
    "std_hparams=get_model_params(task, config['student_model'], config['student_config'])\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "student_model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "tchr_hparams=get_model_params(task, config['teacher_model'], config['teacher_config'])\n",
    "teacher_model = get_teacher_model(config, task, tchr_hparams, cl_token)\n",
    "\n",
    "student_models.append(student_model)\n",
    "student_labels.append('bert2gpt_4')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for lbl, mdl in zip(student_labels, student_models):\n",
    "    train = mdl.evaluate(task.train_dataset, steps=task.n_train_batches)\n",
    "    valid = mdl.evaluate(task.valid_dataset, steps=task.n_valid_batches)\n",
    "    test = mdl.evaluate(task.test_dataset, steps=task.n_test_batches)\n",
    "\n",
    "    print(lbl)\n",
    "    print(\"train:\", train)\n",
    "    print(\"valid:\", valid)\n",
    "    print(\"test:\", test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow_probability as tfp\n",
    "\n",
    "def test_for_calibration(model, task, n_bins=10):\n",
    "    preds = []\n",
    "    correct_class_probs = []\n",
    "    predicted_class_probs = []\n",
    "    pred_logits = []\n",
    "    y_trues = []\n",
    "    batch_count = task.n_valid_batches\n",
    "    for x, y in task.valid_dataset:\n",
    "        y = tf.cast(y, tf.int32)\n",
    "        logits = model(x)\n",
    "        pred_logits.extend(logits.numpy())\n",
    "        pred = tf.argmax(logits, axis=-1)\n",
    "        prob = task.get_probs_fn()(logits, labels=y, temperature=1)\n",
    "        preds.extend(pred.numpy())\n",
    "        y_trues.extend(y.numpy())\n",
    "        batch_indexes = tf.cast(tf.range(len(y), dtype=tf.int32), dtype=tf.int32)\n",
    "        true_indexes = tf.concat([batch_indexes[:,None], y[:,None]], axis=1)\n",
    "        pred_indexes = tf.concat([batch_indexes[:,None], tf.cast(pred[:,None], tf.int32)], axis=1)\n",
    "\n",
    "        correct_class_probs.extend(tf.gather_nd(prob, true_indexes).numpy())\n",
    "        predicted_class_probs.extend(tf.gather_nd(prob, pred_indexes).numpy())\n",
    "\n",
    "        batch_count -= 1\n",
    "        if batch_count == 0:\n",
    "            break\n",
    "\n",
    "    model_accuracy = np.asarray(preds) == np.asarray(y_trues)\n",
    "\n",
    "    return model_accuracy, predicted_class_probs, correct_class_probs, pred_logits, y_trues\n",
    "\n",
    "\n",
    "for label,model in zip(student_labels, student_models):\n",
    "    print(label)\n",
    "    print(model.model_name)\n",
    "\n",
    "    model_accuracy, predicted_class_probs, correct_class_probs, model_logits, model_trues= test_for_calibration(model, task, n_bins=20)\n",
    "\n",
    "    #print(len(model_accuracy))\n",
    "    #print(len(predicted_class_probs))\n",
    "    ##plot_calibration(model_accuracy, predicted_class_probs, correct_class_probs, n_bins=20)\n",
    "    #  plt.show()\n",
    "    model_ece = tfp.stats.expected_calibration_error(\n",
    "        1000000,\n",
    "        logits=model_logits,\n",
    "        labels_true=model_trues,\n",
    "    )\n",
    "    print(model_ece.numpy())\n",
    "    \n",
    "    print(\"#######################\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
