{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from util import constants\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "import numpy as np\n",
    "from util.models import MODELS\n",
    "from util.tasks import TASKS\n",
    "from notebook_utils import *\n",
    "%matplotlib inline\n",
    "import pandas as pd\n",
    "import seaborn as sns; sns.set()\n",
    "from collections import Counter\n",
    "\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Working directories, given relative to the kernel's working directory.\n",
    "log_dir = \"../logs\"\n",
    "chkpt_dir = \"../tf_ckpts\"\n",
    "\n",
    "# Instantiate the task registered under 'word_sv_agreement_vp', passing the result of\n",
    "# get_task_params() (called without arguments) and the data directory.\n",
    "task = TASKS['word_sv_agreement_vp'](\n",
    "    task_params=get_task_params(),\n",
    "    data_dir='../data',\n",
    ")\n",
    "\n",
    "# Encoded form of `constants.bos`; handed to the model loaders in later cells as `cl_token`.\n",
    "cl_token = task.databuilder.sentence_encoder().encode(constants.bos)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the 'cl_gpt2' model described by `config` and evaluate it on task.test_dataset.\n",
    "# NOTE(review): the next cell rebinds `config`, `model` and `ckpt`, so later cells see\n",
    "# whichever loader cell ran last.\n",
    "config = {\n",
    "    'model_name': 'cl_gpt2',\n",
    "    'model_config': 'small_gpt_v9',\n",
    "    'learning_rate': 0.001,\n",
    "    'exp_name': 'offlineteacher_v5',\n",
    "    # Reuse the checkpoint root defined in the setup cell instead of repeating the literal.\n",
    "    'chkpt_dir': chkpt_dir,\n",
    "}\n",
    "\n",
    "hparams = get_model_params(task, config['model_name'], config['model_config'])\n",
    "# Flags set on hparams before the model is built\n",
    "# (presumably make the model return attentions/embeddings -- confirm in the model code).\n",
    "hparams.output_attentions = True\n",
    "hparams.output_embeddings = True\n",
    "\n",
    "model, ckpt = get_model(config, task, hparams, cl_token)\n",
    "\n",
    "# Last expression of the cell, so the evaluation result is shown as the cell output.\n",
    "model.evaluate(task.test_dataset, steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LSTM: load the 'cl_lstm' student model through get_student_model.\n",
    "# This rebinds `config`, `model` and `ckpt` from the previous cell.\n",
    "config = {\n",
    "    'student_exp_name': 'lisa_fd135',\n",
    "    'teacher_exp_name': '0.001_offlineteacher_v5',\n",
    "    'teacher_config': 'small_lstm_v4',\n",
    "    'task_name': 'word_sv_agreement_vp',\n",
    "    'student_model': 'cl_lstm',\n",
    "    'teacher_model': 'cl_lstm',\n",
    "    'student_config': 'small_lstm_v4',\n",
    "    'distill_config': 'pure_dstl_4_crs_slw',\n",
    "    'distill_mode': 'offline',\n",
    "    # Reuse the checkpoint root defined in the setup cell instead of repeating the literal.\n",
    "    'chkpt_dir': chkpt_dir,\n",
    "}\n",
    "\n",
    "std_hparams = get_model_params(task, config['student_model'], config['student_config'])\n",
    "# Flags set on the student hparams before the model is built\n",
    "# (presumably make the model return attentions/embeddings -- confirm in the model code).\n",
    "std_hparams.output_attentions = True\n",
    "std_hparams.output_embeddings = True\n",
    "\n",
    "model, ckpt = get_student_model(config, task, std_hparams, cl_token)\n",
    "\n",
    "# Optional check on the test split; disabled in the original notebook, uncomment to run.\n",
    "#model.evaluate(task.test_dataset, steps=100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imported explicitly: no cell in this notebook imports `inflect` by name, so the call\n",
    "# below otherwise works only if `from notebook_utils import *` happens to re-export it.\n",
    "import inflect\n",
    "\n",
    "infl_eng = inflect.engine()\n",
    "\n",
    "# Build the verb and noun inflection tables consumed by the evaluation cell.\n",
    "# NOTE(review): 'wiki.vocab' has no directory prefix, unlike the '../'-relative paths used\n",
    "# in the other cells -- confirm where gen_inflect_from_vocab looks it up.\n",
    "verb_infl, noun_infl = gen_inflect_from_vocab(infl_eng, 'wiki.vocab')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run evaluate_vp_cl on the currently bound `model` (set by whichever loader cell ran\n",
    "# last), then pass the four returned counters to compute_and_print_acc_stats.\n",
    "(\n",
    "    distance_hits,\n",
    "    distance_total,\n",
    "    diff_hits,\n",
    "    diff_total,\n",
    ") = evaluate_vp_cl(model, verb_infl, noun_infl, task)\n",
    "\n",
    "compute_and_print_acc_stats(\n",
    "    distance_hits,\n",
    "    distance_total,\n",
    "    diff_hits,\n",
    "    diff_total,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display what the task's sentence encoder returns for the literal string `unk`.\n",
    "(\n",
    "    task.databuilder\n",
    "    .sentence_encoder()\n",
    "    .encode(\"unk\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
