{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Analyses of BERT learnt based on high-order frequency distributions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from vocab_mismatch_utils import *\n",
    "from data_formatter_utils import *\n",
    "from datasets import DatasetDict\n",
    "from datasets import Dataset\n",
    "from datasets import load_dataset\n",
    "import transformers\n",
    "import pandas as pd\n",
    "import operator\n",
    "from collections import OrderedDict\n",
    "from tqdm import tqdm, trange\n",
    "\n",
    "import collections\n",
    "import os\n",
    "import random\n",
    "import unicodedata\n",
    "from typing import List, Optional, Tuple\n",
    "\n",
    "from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace\n",
    "from transformers.utils import logging\n",
    "import torch\n",
    "logger = logging.get_logger(__name__)\n",
    "import numpy as np\n",
    "import copy\n",
    "from nltk.stem import WordNetLemmatizer\n",
    "lemmatizer = WordNetLemmatizer() \n",
    "from word_forms.word_forms import get_word_forms\n",
    "\n",
    "seed = 42\n",
    "# set seeds again at start\n",
    "torch.manual_seed(seed)\n",
    "np.random.seed(seed)\n",
    "random.seed(seed)\n",
    "\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "font = {'family' : 'Times New Roman',\n",
    "        'size'   : 15}\n",
    "plt.rc('font', **font)\n",
    "\n",
    "import math\n",
    "import seaborn as sb\n",
    "\n",
    "import transformers\n",
    "from transformers import (\n",
    "    AutoConfig,\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    EvalPrediction,\n",
    "    HfArgumentParser,\n",
    "    PretrainedConfig,\n",
    "    Trainer,\n",
    "    TrainingArguments,\n",
    "    default_data_collator,\n",
    "    set_seed,\n",
    "    EarlyStoppingCallback\n",
    ")\n",
    "from transformers.trainer_utils import is_main_process, EvaluationStrategy\n",
    "from tabulate import tabulate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task setup: which dataset is analysed and how many classes it has.\n",
    "FILENAME_CONFIG = {\"sst3\": \"sst-tenary\"}  # task name -> directory / file-name stem of its TSV files\n",
    "task_name = \"sst3\"\n",
    "num_labels = 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### First-order frequency and label correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the train / dev / test TSV splits of the current task and wrap each one in a `Dataset`.\n",
    "# `external_output_dirname` is not defined in this notebook; presumably it comes from one of\n",
    "# the star imports in the first cell -- TODO confirm.\n",
    "def _load_split(split):\n",
    "    \"\"\"Read `<stem>/<stem>-<split>.tsv` under `external_output_dirname` and return it as a `Dataset`.\"\"\"\n",
    "    stem = FILENAME_CONFIG[task_name]\n",
    "    tsv_path = os.path.join(external_output_dirname, stem, f\"{stem}-{split}.tsv\")\n",
    "    return Dataset.from_pandas(pd.read_csv(tsv_path, delimiter=\"\\t\"))\n",
    "\n",
    "train_df = _load_split(\"train\")\n",
    "eval_df = _load_split(\"dev\")\n",
    "test_df = _load_split(\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Count how often every token occurs in the training set and gather the token stream of each label.\n",
    "modified_basic_tokenizer = ModifiedBasicTokenizer()\n",
    "label_vocab_map = {}  # label -> every token (repeats included) from sentences with that label\n",
    "token_frequency_map = {}  # token -> occurrence count; overwrite this every time for a new dataset\n",
    "for i, example in enumerate(train_df):\n",
    "    if i != 0 and i % 10000 == 0:\n",
    "        print(f\"processing #{i} example...\")\n",
    "    original_sentence = example['text']\n",
    "    if len(original_sentence.strip()) == 0:\n",
    "        continue  # blank sentences contribute nothing\n",
    "    tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "    label_vocab_map.setdefault(example['label'], []).extend(tokens)\n",
    "    for t in tokens:\n",
    "        token_frequency_map[t] = token_frequency_map.get(t, 0) + 1\n",
    "\n",
    "# Same counts, ordered from the most to the least frequent token.\n",
    "ranked_token_counts = sorted(token_frequency_map.items(), key=operator.itemgetter(1), reverse=True)\n",
    "task_token_frequency_map = OrderedDict(ranked_token_counts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distinct token frequencies observed in the corpus, ascending.\n",
    "freq_set = sorted(set(task_token_frequency_map.values()))\n",
    "lowest_freq, highest_freq = freq_set[0], freq_set[-1]\n",
    "# 25 log-spaced edges from the smallest to the largest frequency ...\n",
    "log_edges = np.logspace(math.log(lowest_freq, 10), math.log(highest_freq, 10), 25, endpoint=True)\n",
    "# ... of which the top edge is dropped and the rest are rounded up to integer counts.\n",
    "freq_bucket = [math.ceil(edge) for edge in log_edges[:-1]]\n",
    "# finally the bucket is a map between freq and bucket number\n",
    "def find_bucket_number(freq, freq_bucket):\n",
    "    \"\"\"Return the 1-based position of the first edge in `freq_bucket` that is >= `freq`.\n",
    "\n",
    "    A frequency larger than every edge falls into the last bucket, i.e. the\n",
    "    function returns `len(freq_bucket)` (which is 0 for an empty edge list).\n",
    "    \"\"\"\n",
    "    for position, upper_edge in enumerate(freq_bucket, start=1):\n",
    "        if freq <= upper_edge:\n",
    "            return position\n",
    "    return len(freq_bucket)\n",
    "\n",
    "# Bucket number of every distinct frequency value.\n",
    "freq_bucket_map = {freq: find_bucket_number(freq, freq_bucket) for freq in freq_set}\n",
    "\n",
    "# Per label: the bucket of every token occurrence. All tokens of the label are kept,\n",
    "# repeats included -- nothing is restricted to label-unique words here.\n",
    "label_token_freq_bucket_map = {\n",
    "    label_id: [freq_bucket_map[task_token_frequency_map[token]] for token in label_tokens]\n",
    "    for label_id, label_tokens in label_vocab_map.items()\n",
    "}\n",
    "\n",
    "# Labels have different token counts, so draw equally sized samples (without replacement)\n",
    "# to remove that bias; the sample size is the smallest per-label token count.\n",
    "min_len = min((len(buckets) for buckets in label_token_freq_bucket_map.values()), default=0)\n",
    "sampled_label_buckets = {\n",
    "    label_id: random.sample(buckets, k=min_len)\n",
    "    for label_id, buckets in label_token_freq_bucket_map.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "quantitative results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PMI of frequency bucket for each label.\n",
    "# The score stored below is the ratio p(label, bucket) / (p(bucket) * p(label)); no log is taken.\n",
    "\n",
    "# p_label: share of training examples carrying each label\n",
    "label_counts = collections.Counter(train_df[\"label\"])\n",
    "p_label_sum = sum(label_counts.values())\n",
    "label = list(label_counts.keys())\n",
    "p_label = {label_id: count / p_label_sum for label_id, count in label_counts.items()}\n",
    "\n",
    "# p_bucket: share of token occurrences (all labels pooled) that fall into each bucket\n",
    "bucket_counts = collections.Counter()\n",
    "for i in range(num_labels):\n",
    "    bucket_counts.update(label_token_freq_bucket_map[i])\n",
    "p_bucket_sum = sum(bucket_counts.values())\n",
    "p_bucket = {bucket: count / p_bucket_sum for bucket, count in bucket_counts.items()}\n",
    "\n",
    "# p_label_bucket: joint share of (label, bucket) divided by the product of the marginals\n",
    "p_label_bucket = {}\n",
    "for i in range(num_labels):\n",
    "    joint_counts = collections.Counter(label_token_freq_bucket_map[i])\n",
    "    p_label_bucket[i] = {\n",
    "        bucket: (count / p_bucket_sum) / (p_bucket[bucket] * p_label[i])\n",
    "        for bucket, count in joint_counts.items()\n",
    "    }\n",
    "\n",
    "# Per label: (bucket, score) pairs, highest score first.\n",
    "sorted_p_label_buckets = [\n",
    "    sorted(p_label_bucket[i].items(), key=operator.itemgetter(1), reverse=True)\n",
    "    for i in range(num_labels)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each label: the fraction of its training examples that contain at least one token\n",
    "# from each frequency bucket.\n",
    "label_bucket_example_count_map = {}\n",
    "for i, example in enumerate(train_df):\n",
    "    if i % 10000 == 0 and i != 0:\n",
    "        print(f\"processing #{i} example...\")\n",
    "    original_sentence = example['text']\n",
    "    label = example['label']\n",
    "    if len(original_sentence.strip()) == 0:\n",
    "        continue\n",
    "    tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "    per_bucket = label_bucket_example_count_map.setdefault(label, {})\n",
    "    # a bucket is counted once per example, however many of its tokens occur\n",
    "    buckets = {freq_bucket_map[task_token_frequency_map[t]] for t in tokens}\n",
    "    for b in buckets:\n",
    "        per_bucket[b] = per_bucket.get(b, 0) + 1\n",
    "\n",
    "# Turn counts into fractions of the label's examples. The label totals are computed once\n",
    "# here rather than once per label; they include examples whose text is blank.\n",
    "examples_per_label = collections.Counter(train_df[\"label\"])\n",
    "for k, per_bucket in label_bucket_example_count_map.items():\n",
    "    count_example = examples_per_label[k]\n",
    "    for bucket in per_bucket:\n",
    "        per_bucket[bucket] = per_bucket[bucket] / count_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stats table: for every label, its top-k buckets by score together with the fraction of\n",
    "# that label's examples containing the bucket.\n",
    "top_k = 5\n",
    "headers = [name for i in range(num_labels) for name in (f\"label_{i}_pmi\", f\"label_{i}_prob\")]\n",
    "lines = []\n",
    "for rank in range(top_k):\n",
    "    row = []\n",
    "    for label_id in range(num_labels):\n",
    "        pmi_bucket = sorted_p_label_buckets[label_id][rank][0]\n",
    "        prob = label_bucket_example_count_map[label_id][pmi_bucket]\n",
    "        row += [f\"bucket[#{pmi_bucket}]\", round(prob, 4)]\n",
    "    lines.append(row)\n",
    "print(tabulate(lines, headers=headers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "qualitative results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from itertools import groupby\n",
    "plt.rcParams[\"font.family\"] = \"Times New Roman\"\n",
    "plt.style.use(\"default\")\n",
    "font = {'family' : 'Times New Roman',\n",
    "        'size'   : 20}\n",
    "plt.rc('font', **font)\n",
    "\n",
    "# Grouped bar chart: for the first `count_to_display` buckets, how many sampled token\n",
    "# occurrences of each label fall into the bucket.\n",
    "count_to_display = 4\n",
    "groupby_names = [f\"bucket#{i+1}\" for i in range(0, count_to_display)]\n",
    "\n",
    "# One list of bucket counts per label, in the iteration order of `sampled_label_buckets`.\n",
    "counts_all = []\n",
    "for k in sampled_label_buckets:\n",
    "    counter_k = collections.Counter(sampled_label_buckets[k])\n",
    "    counts_all.append([counter_k[bucket] for bucket in range(1, count_to_display + 1)])\n",
    "groups = counts_all\n",
    "# NOTE(review): names are matched to `groups` by position only, i.e. by the order in which\n",
    "# labels were first inserted upstream -- confirm that this order really is\n",
    "# negative / positive / neutral.\n",
    "group_names = ['negative', 'positive', 'neutral']\n",
    "\n",
    "# Bar-group centres, two units apart; derived from count_to_display so both stay in sync.\n",
    "x = np.arange(count_to_display) * 2\n",
    "width = 0.35  # the width of the bars\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(9,7))\n",
    "rects1 = ax.bar(x - width, groups[0], width, label=group_names[0], edgecolor='black', color=\"red\")\n",
    "rects2 = ax.bar(x, groups[1], width, label=group_names[1], edgecolor='black', color=\"yellow\")\n",
    "rects3 = ax.bar(x + width, groups[2], width, label=group_names[2], edgecolor='black', color=\"blue\")\n",
    "\n",
    "# Axis labels, log-scaled counts, custom x-axis tick labels and a legend below the plot.\n",
    "ax.set_ylabel('frequency')\n",
    "ax.set_yscale('log')\n",
    "ax.set_xticks(x)\n",
    "ax.set_xticklabels(groupby_names)\n",
    "ax.legend(loc='lower center', bbox_to_anchor=(0.5, -0.2),\n",
    "      ncol=3, fancybox=True, shadow=True, fontsize=15)\n",
    "\n",
    "def autolabel(rects):\n",
    "    \"\"\"Attach a text label above each bar in *rects*, displaying its height.\"\"\"\n",
    "    for rect in rects:\n",
    "        height = rect.get_height()\n",
    "        ax.annotate('{:.0f}'.format(height),\n",
    "                    xy=(rect.get_x() + rect.get_width() / 2, height),\n",
    "                    xytext=(0, 3),  # 3 points vertical offset\n",
    "                    textcoords=\"offset points\",\n",
    "                    ha='center', va='bottom', fontsize=10)\n",
    "\n",
    "for rects in (rects1, rects2, rects3):\n",
    "    autolabel(rects)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.ticker as mtick\n",
    "\n",
    "# Overlaid histograms of the per-token bucket numbers of the three labels.\n",
    "fig = plt.figure(figsize=(8,6))\n",
    "ax = fig.add_subplot(111)\n",
    "# Bucket numbers are the integers 1..len(freq_bucket). Passing the same explicit edges,\n",
    "# centred on those integers, to every call keeps the three histograms aligned bin for bin\n",
    "# (an integer bin count would make each call derive edges from its own data range).\n",
    "bucket_edges = np.arange(len(freq_bucket) + 1) + 0.5\n",
    "g1 = ax.hist(label_token_freq_bucket_map[0], bins=bucket_edges, facecolor='b', alpha = 0.5)\n",
    "g2 = ax.hist(label_token_freq_bucket_map[1], bins=bucket_edges, facecolor='g', alpha = 0.5)\n",
    "g3 = ax.hist(label_token_freq_bucket_map[2], bins=bucket_edges, facecolor='y', alpha = 0.5)\n",
    "plt.grid(True)\n",
    "plt.grid(color='black', linestyle='-.')\n",
    "# NOTE(review): changing the scale below presumably installs the log scale's own default\n",
    "# tick formatter and so replaces this '%.2e' formatter -- confirm, and move this line after\n",
    "# set_yscale if the custom format is actually wanted.\n",
    "ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.2e'))\n",
    "ax.set_yscale('log')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Second-order frequency and label correlations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per label: the (bucket, bucket) tuple of every ordered token pair within a sentence.\n",
    "# Each token is also paired with itself, and both orders of a pair are recorded.\n",
    "label_freq_freq_map = {label_id: [] for label_id in range(num_labels)}\n",
    "\n",
    "for i, example in enumerate(train_df):\n",
    "    if i % 10000 == 0 and i != 0:\n",
    "        print(f\"processing #{i} example...\")\n",
    "    original_sentence = example['text']\n",
    "    label = example['label']\n",
    "    if len(original_sentence.strip()) != 0:\n",
    "        tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "        # look every token's bucket up once instead of once per pair\n",
    "        token_buckets = [freq_bucket_map[token_frequency_map[t]] for t in tokens]\n",
    "        # make the matrix symmetric. i guess we can also just look at one side.\n",
    "        label_freq_freq_map[label].extend(\n",
    "            (first_bucket, second_bucket)\n",
    "            for first_bucket in token_buckets\n",
    "            for second_bucket in token_buckets\n",
    "        )\n",
    "\n",
    "# have to take samples in order to remove the bias: every label contributes as many pairs\n",
    "# as the label with the fewest pairs (sampling is without replacement)\n",
    "min_len = min((len(pairs) for pairs in label_freq_freq_map.values()), default=0)\n",
    "sampled_label_buckets = {\n",
    "    label_id: random.sample(pairs, k=min_len)\n",
    "    for label_id, pairs in label_freq_freq_map.items()\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "quantitative results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# PMI of 2nd frequency tuple for each label.\n",
    "# The score stored below is the ratio p(label, pair) / (p(pair) * p(label)); no log is taken.\n",
    "\n",
    "# p_label: share of training examples carrying each label\n",
    "label_counts = collections.Counter(train_df[\"label\"])\n",
    "p_label_sum = sum(label_counts.values())\n",
    "label = list(label_counts.keys())\n",
    "p_label = {label_id: count / p_label_sum for label_id, count in label_counts.items()}\n",
    "\n",
    "# p_bucket: share of token pairs (all labels pooled) that fall into each bucket pair\n",
    "pair_counts = collections.Counter()\n",
    "for i in range(num_labels):\n",
    "    pair_counts.update(label_freq_freq_map[i])\n",
    "p_bucket_sum = sum(pair_counts.values())\n",
    "p_bucket = {pair: count / p_bucket_sum for pair, count in pair_counts.items()}\n",
    "\n",
    "# p_label_bucket: joint share of (label, pair) divided by the product of the marginals\n",
    "p_label_bucket = {}\n",
    "for i in range(num_labels):\n",
    "    joint_counts = collections.Counter(label_freq_freq_map[i])\n",
    "    p_label_bucket[i] = {\n",
    "        pair: (count / p_bucket_sum) / (p_bucket[pair] * p_label[i])\n",
    "        for pair, count in joint_counts.items()\n",
    "    }\n",
    "\n",
    "# Per label: (pair, score) entries, highest score first.\n",
    "sorted_p_label_buckets = [\n",
    "    sorted(p_label_bucket[i].items(), key=operator.itemgetter(1), reverse=True)\n",
    "    for i in range(num_labels)\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For each label: the fraction of its training examples that contain each\n",
    "# (bucket, bucket) pair at least once.\n",
    "label_bucket_example_count_map = {}\n",
    "for i, example in enumerate(train_df):\n",
    "    if i % 10000 == 0 and i != 0:\n",
    "        print(f\"processing #{i} example...\")\n",
    "    original_sentence = example['text']\n",
    "    label = example['label']\n",
    "    if len(original_sentence.strip()) == 0:\n",
    "        continue\n",
    "    tokens = modified_basic_tokenizer.tokenize(original_sentence)\n",
    "    per_pair = label_bucket_example_count_map.setdefault(label, {})\n",
    "    # The distinct pairs over all ordered token pairs (self-pairs included) are exactly the\n",
    "    # product of the sentence's distinct buckets with itself, so enumerate that instead of\n",
    "    # every token pair. Each pair is counted once per example.\n",
    "    token_buckets = {freq_bucket_map[token_frequency_map[t]] for t in tokens}\n",
    "    for first_bucket in token_buckets:\n",
    "        for second_bucket in token_buckets:\n",
    "            freq_tuple = (first_bucket, second_bucket)\n",
    "            per_pair[freq_tuple] = per_pair.get(freq_tuple, 0) + 1\n",
    "\n",
    "# Turn counts into fractions of the label's examples. The label totals are computed once\n",
    "# here rather than once per label; they include examples whose text is blank.\n",
    "examples_per_label = collections.Counter(train_df[\"label\"])\n",
    "for k, per_pair in label_bucket_example_count_map.items():\n",
    "    count_example = examples_per_label[k]\n",
    "    for bucket in per_pair:\n",
    "        per_pair[bucket] = per_pair[bucket] / count_example"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stats table: for every label, its top bucket pairs by score together with the fraction\n",
    "# of that label's examples containing the pair.\n",
    "headers = []\n",
    "for i in range(0, num_labels):\n",
    "    headers += [f\"label_{i}_pmi\", f\"label_{i}_prob\"]\n",
    "top_k = 10\n",
    "rows_to_show = top_k // 2  # same number of rows as keeping every other one of the top 10\n",
    "\n",
    "# The pair lists are built from all ordered token pairs, so (a, b) and (b, a) carry the same\n",
    "# score. Keep the first orientation of every unordered pair explicitly; skipping every other\n",
    "# row instead goes out of step as soon as a diagonal pair (a, a) shows up in the ranking.\n",
    "unique_ranked_pairs = []\n",
    "for j in range(0, num_labels):\n",
    "    seen_unordered = set()\n",
    "    kept_pairs = []\n",
    "    for pair, _score in sorted_p_label_buckets[j]:\n",
    "        unordered = tuple(sorted(pair))\n",
    "        if unordered in seen_unordered:\n",
    "            continue\n",
    "        seen_unordered.add(unordered)\n",
    "        kept_pairs.append(pair)\n",
    "        if len(kept_pairs) == rows_to_show:\n",
    "            break\n",
    "    unique_ranked_pairs.append(kept_pairs)\n",
    "\n",
    "lines = []\n",
    "for i in range(rows_to_show):\n",
    "    line = []\n",
    "    for j in range(0, num_labels):\n",
    "        pmi_bucket = unique_ranked_pairs[j][i]\n",
    "        prob = label_bucket_example_count_map[j][pmi_bucket]\n",
    "        line.append(f\"bucket[(#,#){pmi_bucket}]\")\n",
    "        line.append(round(prob, 6))\n",
    "    lines.append(line)\n",
    "print(tabulate(lines, headers=headers))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "qualitative results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# How often each (bucket, bucket) pair occurs in the balanced sample of every label.\n",
    "label_freq_freq_bucket_map = {\n",
    "    label: dict(collections.Counter(pairs)) for label, pairs in sampled_label_buckets.items()\n",
    "}\n",
    "\n",
    "# turning freq tuple into a heatmap: cell [a-1, b-1] holds the count of bucket pair (a, b)\n",
    "num_buckets = len(freq_bucket)\n",
    "label_freq_freq_2d_map = {}\n",
    "for label, pair_counts in label_freq_freq_bucket_map.items():\n",
    "    count_grid = torch.zeros(num_buckets, num_buckets)\n",
    "    for (row_bucket, col_bucket), count in pair_counts.items():\n",
    "        count_grid[row_bucket - 1, col_bucket - 1] = count\n",
    "    label_freq_freq_2d_map[label] = count_grid\n",
    "\n",
    "# Scale every cell (i, j) by the diagonal entry at max(i, j); cells whose reference\n",
    "# diagonal entry is zero are left at 0.\n",
    "label_freq_freq_2d_map_norm = {}\n",
    "for label, f_f_2d_m in label_freq_freq_2d_map.items():\n",
    "    f_f_2d_m_norm = torch.zeros_like(f_f_2d_m)\n",
    "    n_rows, n_cols = f_f_2d_m_norm.shape\n",
    "    for i in range(n_rows):\n",
    "        for j in range(n_cols):\n",
    "            reference = f_f_2d_m[max(i,j), max(i,j)]\n",
    "            if reference != 0.0:\n",
    "                f_f_2d_m_norm[i,j] = f_f_2d_m[i,j] / reference\n",
    "    label_freq_freq_2d_map_norm[label] = f_f_2d_m_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalised pair-frequency heatmap for label 0.\n",
    "df = pd.DataFrame(label_freq_freq_2d_map_norm[0].numpy())\n",
    "heatmap_ax = sb.heatmap(df, cmap=\"Blues\", square=True, linewidth=0.1, cbar_kws={\"shrink\": .8},\n",
    "                        vmin=0.0)\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalised pair-frequency heatmap for label 1 (colour scale capped at 0.8).\n",
    "df = pd.DataFrame(label_freq_freq_2d_map_norm[1].numpy())\n",
    "heatmap_ax = sb.heatmap(df, cmap=\"Blues\", square=True, linewidth=0.1, cbar_kws={\"shrink\": .8},\n",
    "                        vmin=0.0, vmax=0.8)\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Normalised pair-frequency heatmap for label 2.\n",
    "df = pd.DataFrame(label_freq_freq_2d_map_norm[2].numpy())\n",
    "heatmap_ax = sb.heatmap(df, cmap=\"Blues\", square=True, linewidth=0.1, cbar_kws={\"shrink\": .8},\n",
    "                        vmin=0.0)\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Difference of the normalised pair-frequency maps: label 0 minus label 1.\n",
    "df = pd.DataFrame((label_freq_freq_2d_map_norm[0]-label_freq_freq_2d_map_norm[1]).numpy())\n",
    "heatmap_ax = sb.heatmap(df, square=True, linewidth=0.1, cbar_kws={\"shrink\": .8})\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Difference of the normalised pair-frequency maps: label 0 minus label 2.\n",
    "df = pd.DataFrame((label_freq_freq_2d_map_norm[0]-label_freq_freq_2d_map_norm[2]).numpy())\n",
    "heatmap_ax = sb.heatmap(df, square=True, linewidth=0.1, cbar_kws={\"shrink\": .8})\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Difference of the normalised pair-frequency maps: label 1 minus label 2.\n",
    "df = pd.DataFrame((label_freq_freq_2d_map_norm[1]-label_freq_freq_2d_map_norm[2]).numpy())\n",
    "heatmap_ax = sb.heatmap(df, square=True, linewidth=0.1, cbar_kws={\"shrink\": .8})\n",
    "# y axis runs upwards from row 0; the limit follows the matrix size instead of a literal 24\n",
    "heatmap_ax.set_ylim(0, df.shape[0])\n",
    "heatmap_ax.set_xticks([])\n",
    "heatmap_ax.set_yticks([])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Running BERT sentence embeddings and 2nd order frequency information"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from models.modeling_bert import CustomerizedBertForSequenceClassification\n",
    "\n",
    "# --- constants -------------------------------------------------------------\n",
    "NUM_LABELS = 3\n",
    "MAX_SEQ_LEN = 128  # sequences are padded / truncated to this many positions\n",
    "CACHE_DIR = \"../tmp/\"\n",
    "MODEL_TYPE = \"bert-base-uncased\"\n",
    "MODEL_PATH = \"../saved-models/sst-tenary-finetuned-bert-base-uncased-3B/pytorch_model.bin\"\n",
    "NUM_LABEL_CONFIG = {\n",
    "    \"sst2\": 2,\n",
    "    \"sst3\": 3\n",
    "}\n",
    "SAMPLE_LIMIT=1000  # upper bound on the number of training examples fed through the model\n",
    "TASK_CONFIG = {\n",
    "    \"sst3\": (\"text\", None)  # (first text column, optional second text column)\n",
    "}\n",
    "\n",
    "# --- config, tokenizer and model weights -------------------------------------\n",
    "config = AutoConfig.from_pretrained(\n",
    "    MODEL_TYPE,\n",
    "    num_labels=NUM_LABEL_CONFIG[task_name],\n",
    "    finetuning_task=task_name,\n",
    "    cache_dir=CACHE_DIR\n",
    ")\n",
    "tokenizer = AutoTokenizer.from_pretrained(\n",
    "    MODEL_TYPE,\n",
    "    use_fast=False,\n",
    "    cache_dir=CACHE_DIR\n",
    ")\n",
    "model = CustomerizedBertForSequenceClassification.from_pretrained(\n",
    "    MODEL_PATH,\n",
    "    from_tf=False,\n",
    "    config=config,\n",
    "    cache_dir=CACHE_DIR\n",
    ")\n",
    "\n",
    "# --- shuffled subset of the training data -------------------------------------\n",
    "# NOTE(review): this rebinds `train_df`, so re-running the cell shuffles the already\n",
    "# shuffled data again -- confirm that nothing relies on the original order afterwards.\n",
    "train_df = train_df.shuffle(seed=seed)\n",
    "# never ask for more rows than the dataset has\n",
    "train_df_subset = train_df.select(range(min(SAMPLE_LIMIT, len(train_df))))\n",
    "\n",
    "# --- tokenisation settings read by preprocess_function ------------------------\n",
    "sentence1_key, sentence2_key = TASK_CONFIG[task_name]\n",
    "padding = \"max_length\"\n",
    "label_to_id = None\n",
    "def preprocess_function(examples):\n",
    "    \"\"\"Tokenize a batch of examples, padding / truncating to MAX_SEQ_LEN.\n",
    "\n",
    "    Reads the text column(s) named by `sentence1_key` / `sentence2_key` and returns the\n",
    "    tokenizer output; when `label_to_id` is set, the batch's labels are remapped through it.\n",
    "    \"\"\"\n",
    "    if sentence2_key is None:\n",
    "        texts = (examples[sentence1_key],)\n",
    "    else:\n",
    "        texts = (examples[sentence1_key], examples[sentence2_key])\n",
    "    encoded = tokenizer(*texts, padding=padding, max_length=MAX_SEQ_LEN, truncation=True)\n",
    "    # Map labels to IDs (not necessary for GLUE tasks)\n",
    "    if label_to_id is not None and \"label\" in examples:\n",
    "        encoded[\"label\"] = [label_to_id[raw_label] for raw_label in examples[\"label\"]]\n",
    "    return encoded\n",
    "\n",
    "# Tokenize the whole subset, one batch of examples at a time.\n",
    "train_df_subset = train_df_subset.map(preprocess_function, batched=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# One example per batch, visited in a freshly shuffled order.\n",
    "subset_dataloader = DataLoader(train_df_subset, batch_size=1, shuffle=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "# Inference only: without no_grad every forward pass would also record an autograd graph.\n",
    "with torch.no_grad():\n",
    "    for batch_dataloader in tqdm(subset_dataloader):\n",
    "        # NOTE(review): presumably each field arrives as a list of per-position tensors of\n",
    "        # shape (1,), so cat + unsqueeze rebuilds a (1, seq_len) batch -- TODO confirm.\n",
    "        input_ids = torch.cat(batch_dataloader['input_ids'], dim=0).unsqueeze(dim=0)\n",
    "        attention_mask = torch.cat(batch_dataloader['attention_mask'], dim=0).unsqueeze(dim=0)\n",
    "        # overwritten on every iteration: only the last batch's result survives the loop\n",
    "        hidden_states = model.forward_simple(input_ids=input_ids, attention_mask=attention_mask)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
