{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import tensorflow as tf\n",
    "from tasks.tasks import SvAgreementLM, WordSvAgreementLM, WordSvAgreementVP\n",
    "from tf2_models.lm_transformer import LmGPT2, ClassifierGPT2\n",
    "from util.config_util import get_model_params, get_task_params, get_train_params\n",
    "from tf2_models.lm_lstm import LmLSTM, LmLSTMSharedEmb, ClassifierLSTM\n",
    "from tf2_models.trainer import Trainer\n",
    "from absl import app\n",
    "from absl import flags\n",
    "from util import constants\n",
    "from collections import Counter\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "from tf2_models.metrics import *\n",
    "\n",
    "# Registry: `model_name` config key -> model class (called with hparams=... below).\n",
    "MODELS = dict(\n",
    "    lm_lstm=LmLSTM,\n",
    "    lm_gpt2=LmGPT2,\n",
    "    lm_lstm_shared_emb=LmLSTMSharedEmb,\n",
    "    cl_gpt2=ClassifierGPT2,\n",
    "    cl_lstm=ClassifierLSTM,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Output roots and experiment tag; the model cell joins task/run sub-directories onto them.\n",
    "log_dir = \"../logs\"\n",
    "chkpt_dir = \"../tf_ckpts\"\n",
    "exp_name = \"nol2_batchsumloss\"\n",
    "\n",
    "# Task built from the default task params, reading its data from ../data.\n",
    "task = WordSvAgreementLM(\n",
    "    task_params=get_task_params(),\n",
    "    data_dir='../data',\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_config = 'lstm_drop30_v2'\n",
    "model_name = 'lm_lstm_shared_emb'\n",
    "train_config = 'radam_fast'\n",
    "\n",
    "# Create the model; build the hparams once and reuse them for the print and the constructor.\n",
    "model_params = get_model_params(task, model_name, model_config)\n",
    "print(\"model_params: \", model_params.__dict__)\n",
    "model = MODELS[model_name](hparams=model_params)\n",
    "\n",
    "trainer_params = get_train_params(train_config)\n",
    "\n",
    "# Run identifier shared by logs and checkpoints. It is joined onto the base dirs from the\n",
    "# config cell under new names, so re-running this cell does not nest `log_dir` in itself.\n",
    "run_name = \"_\".join([model.model_name, str(model_config),\n",
    "                     str(trainer_params.learning_rate), exp_name])\n",
    "run_log_dir = os.path.join(log_dir, task.name, run_name)\n",
    "ckpt_dir = os.path.join(chkpt_dir, task.name, run_name)\n",
    "\n",
    "print(run_log_dir)\n",
    "\n",
    "# Train with the same params whose learning rate is encoded in run_name.\n",
    "trainer = Trainer(task=task,\n",
    "                  model=model,\n",
    "                  train_params=trainer_params,\n",
    "                  log_dir=run_log_dir,\n",
    "                  ckpt_dir=ckpt_dir)\n",
    "\n",
    "trainer.restore()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the restored weights on the test split; `steps` caps the number of\n",
    "# batches evaluated (assuming Keras `Model.evaluate` semantics).\n",
    "eval_steps = 100\n",
    "trainer.model.evaluate(task.test_dataset, steps=eval_steps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug one test batch: compare the model's configured loss with masked_sequence_loss\n",
    "# (star-imported from tf2_models.metrics) on the same flattened inputs.\n",
    "x, y = next(iter(task.test_dataset))\n",
    "print(len(x))\n",
    "mask = tf.cast(y > 0, dtype=tf.float32)  # assumes id 0 is padding -- confirm\n",
    "logits = model(x)  # single forward pass, reused for the hit computation below\n",
    "print(logits.shape)\n",
    "# Token hits from the un-flattened logits so they line up with y element-wise.\n",
    "correct = tf.cast(tf.argmax(logits, axis=-1) == y, dtype=tf.float32)\n",
    "# Flatten: logits -> (-1, logits.shape[-1]); targets and mask -> (-1, 1).\n",
    "logits = tf.reshape(logits, (-1, logits.shape[-1]))\n",
    "targets = tf.reshape(y, (-1, 1))\n",
    "mask = tf.reshape(mask, (-1, 1))\n",
    "print(logits.shape)\n",
    "print(targets.shape)\n",
    "print(model.loss)\n",
    "# NOTE(review): mask is not passed to either loss below -- confirm whether they\n",
    "# handle padding themselves.\n",
    "print(model.loss(y_pred=logits, y_true=targets))\n",
    "print(tf.reduce_sum(masked_sequence_loss(y_pred=logits, y_true=targets)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_inflect_from_vocab(vocab_file, freq_threshold=1000, engine=None):\n",
    "    \"\"\"Build singular<->plural lookup tables for verbs and nouns from a POS vocab file.\n",
    "\n",
    "    Each vocab line is whitespace-separated ``word pos count``. Lines starting with a\n",
    "    space (the empty-string token) and blank lines are skipped. Only words longer than\n",
    "    one character, tagged NN/NNS/VBP/VBZ, with ``count >= freq_threshold`` are kept.\n",
    "    A kept VBZ verb is paired with ``engine.plural_verb(verb)`` when that form is a kept\n",
    "    VBP verb; a kept NN noun is paired with ``engine.plural_noun(noun)`` when that form\n",
    "    is a kept NNS noun.\n",
    "\n",
    "    Args:\n",
    "        vocab_file: path to the vocab file (read as UTF-8).\n",
    "        freq_threshold: minimum count for a word to be kept.\n",
    "        engine: object providing ``plural_verb`` and ``plural_noun``; defaults to the\n",
    "            notebook-global ``infl_eng``, resolved at call time.\n",
    "\n",
    "    Returns:\n",
    "        ``(verb_infl, noun_infl)``: dicts mapping each paired word to its counterpart\n",
    "        in both directions; they also map 'VBP'<->'VBZ' and 'NN'<->'NNS'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a non-skipped line does not have exactly three fields or its\n",
    "            count is not an integer.\n",
    "    \"\"\"\n",
    "    if engine is None:\n",
    "        engine = infl_eng  # created in a later cell; looked up only when called\n",
    "\n",
    "    vbp, vbz, nn, nns = {}, {}, {}, {}\n",
    "    from_pos = {'NNS': nns, 'NN': nn, 'VBP': vbp, 'VBZ': vbz}\n",
    "\n",
    "    # `with` closes the file even when a malformed line raises mid-read.\n",
    "    with open(vocab_file, encoding='utf-8') as vocab:\n",
    "        for line in vocab:\n",
    "            if line.startswith(' ') or not line.strip():   # empty-string token / blank line\n",
    "                continue\n",
    "            word, pos, count = line.strip().split()\n",
    "            count = int(count)\n",
    "            if len(word) > 1 and pos in from_pos and count >= freq_threshold:\n",
    "                from_pos[pos][word] = count\n",
    "\n",
    "    verb_infl = {'VBP': 'VBZ', 'VBZ': 'VBP'}\n",
    "    for word in vbz:\n",
    "        candidate = engine.plural_verb(word)\n",
    "        if candidate in vbp:\n",
    "            verb_infl[candidate] = word\n",
    "            verb_infl[word] = candidate\n",
    "\n",
    "    noun_infl = {'NN': 'NNS', 'NNS': 'NN'}\n",
    "    for word in nn:\n",
    "        candidate = engine.plural_noun(word)\n",
    "        if candidate in nns:\n",
    "            noun_infl[candidate] = word\n",
    "            noun_infl[word] = candidate\n",
    "\n",
    "    return verb_infl, noun_infl"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from util import inflect\n",
    "\n",
    "# gen_inflect_from_vocab reads this engine through the global name `infl_eng`.\n",
    "infl_eng = inflect.engine()\n",
    "\n",
    "# Column names of the agreement dataset (currently unused in this notebook).\n",
    "dependency_fields = [\n",
    "    'sentence', 'orig_sentence', 'pos_sentence',\n",
    "    'subj', 'verb', 'subj_pos',\n",
    "    'has_rel', 'has_nsubj', 'verb_pos',\n",
    "    'subj_index', 'verb_index', 'n_intervening',\n",
    "    'last_intervening', 'n_diff_intervening', 'distance',\n",
    "    'max_depth', 'all_nouns', 'nouns_up_to_verb',\n",
    "]\n",
    "\n",
    "VOCAB_FILE = 'wiki.vocab'\n",
    "verb_infl, noun_infl = gen_inflect_from_vocab(VOCAB_FILE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "distance_hits = Counter()\n",
    "distance_total = Counter()\n",
    "diff_hits = Counter()\n",
    "diff_total = Counter()\n",
    "\n",
    "# Examples past these limits are skipped; the accuracy arrays in the next cell are\n",
    "# sized to match (17 distance bins, 5 intervening bins).\n",
    "MAX_DISTANCE = 16\n",
    "MAX_INTERVENING = 4\n",
    "# NOTE(review): compared against the verb *position* below, where it can practically\n",
    "# never match; it looks like a token id (cf. the encode(\"unk\") cell) and may have been\n",
    "# meant for the verb ids -- confirm before relying on this filter.\n",
    "SKIP_VERB_INDEX = 10035\n",
    "\n",
    "# Loop-invariant lookups, hoisted out of the per-batch loop.\n",
    "encoder = task.databuilder.sentence_encoder()\n",
    "bos_id = encoder.encode(constants.bos)\n",
    "eos_id = encoder.encode(constants.eos)\n",
    "\n",
    "test_data = task.databuilder.as_dataset(split='test', batch_size=1000)\n",
    "for example in tqdm(test_data):\n",
    "    encoded_sentences = example['sentence']\n",
    "    batch_size = tf.shape(encoded_sentences)[0]\n",
    "    # Wrap every sentence as <bos> sentence <eos>.\n",
    "    bos = tf.ones((batch_size, 1), dtype=tf.int64) * bos_id\n",
    "    eos = tf.ones((batch_size, 1), dtype=tf.int64) * eos_id\n",
    "    encoded_sentences = tf.concat([bos, encoded_sentences, eos], axis=1)\n",
    "\n",
    "    actual_verbs = example['verb'].numpy()\n",
    "    # Raises KeyError if a test verb has no pair in verb_infl.\n",
    "    inflected_verbs = [verb_infl[v.decode(\"utf-8\")] for v in actual_verbs]\n",
    "    # Pull per-example fields to NumPy once instead of indexing tensors per element.\n",
    "    verb_indexes = (example['verb_position'] - 1).numpy()\n",
    "    distances = example['distance'].numpy()\n",
    "    nz = example['n_intervening'].numpy()\n",
    "    n_diffs = example['n_diff_intervening'].numpy()\n",
    "\n",
    "    # Only the first id returned by encode() is scored for each verb form.\n",
    "    actual_verb_ids = np.asarray([encoder.encode(v)[0] for v in actual_verbs])\n",
    "    inflected_verb_ids = np.asarray([encoder.encode(v)[0] for v in inflected_verbs])\n",
    "\n",
    "    scores = model(encoded_sentences)\n",
    "    # (example, step, token) triples; step = verb_position - 1 in the <bos>-prefixed\n",
    "    # input -- presumably the output predicting the verb; confirm the indexing base.\n",
    "    batch_rows = np.arange(len(verb_indexes))\n",
    "    actual_scores = tf.gather_nd(scores, np.stack([batch_rows, verb_indexes, actual_verb_ids], axis=1))\n",
    "    inflected_scores = tf.gather_nd(scores, np.stack([batch_rows, verb_indexes, inflected_verb_ids], axis=1))\n",
    "    corrects = (actual_scores > inflected_scores).numpy()\n",
    "\n",
    "    for i, is_correct in enumerate(corrects):\n",
    "        if verb_indexes[i] == SKIP_VERB_INDEX:\n",
    "            continue\n",
    "        if nz[i] > MAX_INTERVENING or distances[i] > MAX_DISTANCE:\n",
    "            continue\n",
    "\n",
    "        distance_total[distances[i]] += 1\n",
    "        distance_hits[distances[i]] += int(is_correct)\n",
    "        # Intervening-count breakdown only for examples with n_intervening == n_diff_intervening.\n",
    "        if nz[i] == n_diffs[i]:\n",
    "            diff_total[nz[i]] += 1\n",
    "            diff_hits[nz[i]] += int(is_correct)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-bin accuracies as NumPy arrays indexed by bin; sizes match the evaluation-loop\n",
    "# filters (distance <= 16, n_intervening <= 4).\n",
    "dis_acc = np.zeros(17)\n",
    "dif_acc = np.zeros(5)\n",
    "\n",
    "print('Accuracy by distance')\n",
    "for k in sorted(distance_hits):\n",
    "    acc = distance_hits[k] / distance_total[k]\n",
    "    dis_acc[k] = acc\n",
    "    print(\"%d | %.2f\" % (k, acc), distance_total[k])\n",
    "\n",
    "print('Accuracy by intervenings')\n",
    "for k in sorted(diff_hits):\n",
    "    acc = diff_hits[k] / diff_total[k]\n",
    "    print(\"%d | %.2f\" % (k, acc), diff_total[k])\n",
    "    dif_acc[k] = acc\n",
    "\n",
    "stats = {'distance': dis_acc, 'intervenings': dif_acc}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect what the sentence encoder returns for the string \"unk\".\n",
    "encoder = task.databuilder.sentence_encoder()\n",
    "encoder.encode(\"unk\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('classic')\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "sns.set()\n",
    "\n",
    "# dis_acc is a NumPy array indexed by distance (not a dict, so `.items()` would raise).\n",
    "# Plot only distances that were observed, so empty bins are not drawn as 0% accuracy.\n",
    "observed_distances = sorted(distance_total)\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(observed_distances, dis_acc[observed_distances], marker='o')\n",
    "ax.set(xlabel='Distance', ylabel='Accuracy', title='Accuracy by distance');"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
