{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fc16db8-9345-4f15-88f5-a448b8d98448",
   "metadata": {},
   "outputs": [],
   "source": [
    "import distutils.spawn\n",
    "import itertools\n",
    "import os\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "from matplotlib import rc\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from transformers import AutoTokenizer, BertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df666025",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs(\"figures\", exist_ok=True)\n",
    "os.makedirs(\"pickles\", exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73aa59ce-3b9b-4921-84d6-b0f74899ad2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "sns.set(font_scale=1.15)\n",
    "if distutils.spawn.find_executable('latex'):\n",
    "    rc('font', **{'family': 'serif', 'serif': ['Computer Modern']})\n",
    "    rc('text', usetex=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d9ad9e89-1216-4988-91be-43f3d4a6083e",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "model = BertModel.from_pretrained(\"bert-base-uncased\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d57bbe33-a530-4cd4-a88a-7fe803586513",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The examples are obtained by doing a cartesian product over all words in the lists.\n",
    "# The label is obtained as the product of the numbers after the words.\n",
    "train_lists_a = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'pronouns': [('my', 1), ('his', 1), ('her', 1), ('our', 1)],\n",
    "    'color': [('', 1), ('blue', 1), ('orange', 1), ('beige', 1), ('green', 1), ('red', 1)],\n",
    "    'animal': [('dog', 1), ('owl', 1), ('cow', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('very', 1), ('extremely', 1)],\n",
    "    'adjective': [('nice', 1), ('mean', -1), ('cute', 1), ('dreadful', -1), ('aggressive', -1), ('delightful', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ce2b07b-c0be-459a-99a7-813bbb6a1813",
   "metadata": {},
   "outputs": [],
   "source": [
    "# another order.\n",
    "train_lists_b = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'pronouns': [('my', 1), ('his', 1), ('her', 1), ('our', 1)],\n",
    "    'color': [('', 1), ('blue', 1), ('orange', 1), ('beige', 1), ('green', 1), ('red', 1)],\n",
    "    'adjective': [('nice', 1), ('mean', -1), ('cute', 1), ('dreadful', -1), ('aggressive', -1), ('delightful', 1)],\n",
    "    'animal': [('dog', 1), ('owl', 1), ('cow', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('very', 1), ('extremely', 1)],\n",
    "    'size': [('tiny', 1), ('giant', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6a0f37-2ef0-41c8-9d3c-add338829276",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_lists_a = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'adjective': [('nice', 1), ('mean', -1), ('cute', 1), ('dreadful', -1), ('aggressive', -1), ('delightful', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'size': [('big', 1), ('little', 1), ('mini', 1), ('enormous', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "760bbbfd-71ec-4eed-b4a5-08848d39bd6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# another order.\n",
    "test_lists_b = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'adjective': [('nice', 1), ('mean', -1), ('cute', 1), ('dreadful', -1), ('aggressive', -1), ('delightful', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'size': [('big', 1), ('little', 1), ('mini', 1), ('enormous', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5030afc8-90b0-40cd-a9b5-f6aa6105904c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the words that generate the labels.\n",
    "test_lists_c = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'greetings': [('Good evening,', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'size': [('big', 1), ('little', 1), ('mini', 1), ('enormous', 1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'adjective': [('lovely', 1), ('nasty', -1), ('charming', 1), ('foul', -1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5b8b0dd-9019-4d51-ac67-1be10967787b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the words that generate the labels + another order.\n",
    "test_lists_d = {\n",
    "    'city': [('The city is', 1)],\n",
    "    'city_size': [('nice.', 1), ('dreadful.', 1), ('clean.', 1), ('dirty.', 1)],\n",
    "    'greetings': [('Good evening,', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'adjective': [('lovely', 1), ('nasty', -1), ('charming', 1), ('foul', -1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'size': [('big', 1), ('little', 1), ('mini', 1), ('enormous', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18095983-04a0-4560-b206-0dfee4ecbad5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the sentence.\n",
    "test_lists_e = {\n",
    "    'question': [('Hello, how are you?', 1)],\n",
    "    'greetings': [('Good evening,', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'adjective': [('nice', 1), ('mean', -1), ('cute', 1), ('dreadful', -1), ('aggressive', -1), ('delightful', 1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5bd631d3-965e-4799-ab0b-f92478bfcb71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# change the sentence and the words that generate the labels.\n",
    "test_lists_f = {\n",
    "    'question': [('Hello, how are you?', 1)],\n",
    "    'greetings': [('Good evening,', 1)],\n",
    "    'pronouns': [('their', 1), ('your', 1)],\n",
    "    'color': [('purple', 1), ('gray', 1)],\n",
    "    'animal': [('cat', 1), ('eagle', 1), ('goat', 1)],\n",
    "    'verb': [('is', 1)],\n",
    "    '(adverb': [('', 1), ('especially', 1)],\n",
    "    'adjective': [('lovely', 1), ('nasty', -1), ('charming', 1), ('foul', -1)],\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9058696f-ec0b-449a-9485-5d450c5f09e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create(raw_lists):\n",
    "    X = []\n",
    "    y = []\n",
    "    for sentence in itertools.product(*raw_lists.values()):\n",
    "        X.append(' '.join([word[0] for word in sentence]).replace('  ', ' ').strip())\n",
    "        y.append(np.prod([word[1] for word in sentence]))\n",
    "    return X, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25a0f315-8d71-4a37-b51c-76ff82b9ca35",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_X_a, train_y_a = create(train_lists_a)\n",
    "train_X_b, train_y_b = create(train_lists_b)\n",
    "train_X_total = train_X_a + train_X_b\n",
    "train_y_total = train_y_a + train_y_b"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edadb23-ad76-4ed3-875e-d18d16fb6f56",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_X_a, test_y_a = create(test_lists_a)\n",
    "test_X_b, test_y_b = create(test_lists_b)\n",
    "test_X_c, test_y_c = create(test_lists_c)\n",
    "test_X_d, test_y_d = create(test_lists_d)\n",
    "test_X_e, test_y_e = create(test_lists_e)\n",
    "test_X_f, test_y_f = create(test_lists_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e7e76082-0b2f-407e-b65a-617c7ea7e706",
   "metadata": {},
   "outputs": [],
   "source": [
    "def go_through_transformers(sentence_list):\n",
    "    result = []\n",
    "    for k, sentence in enumerate(sentence_list):\n",
    "        if k % 100 == 0:\n",
    "            print(k)\n",
    "        inputs = tokenizer(sentence, return_tensors=\"pt\")\n",
    "        result_k = []\n",
    "        outputs = model(**inputs, output_hidden_states=True)\n",
    "        for state in outputs.hidden_states:\n",
    "            # select only the first token of each sequence, i.e., [CLS] token\n",
    "            result_k.append(state[0][0].detach().numpy())\n",
    "        result.append(np.stack(result_k))\n",
    "    return np.stack(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb2466cc-ad7d-4146-89b0-3a917a1f5971",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hidden_states_train = go_through_transformers(train_X_total)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53c81883-c5dd-4bcf-8d96-09bf799fe547",
   "metadata": {},
   "outputs": [],
   "source": [
    "hidden_states_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0a4d326-fe74-4d85-840a-178956f59f1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pickles/hidden_states_train_X.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_train, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f91eedbe-d62e-4a2e-8c06-664b2163e04f",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pickles/hidden_states_train_X.pickle', 'rb') as f:\n",
    "    hidden_states_train = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30bfb110-9e16-4df2-8a55-211bbc25ab59",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "hidden_states_test_X_a = go_through_transformers(test_X_a)\n",
    "hidden_states_test_X_b = go_through_transformers(test_X_b)\n",
    "hidden_states_test_X_c = go_through_transformers(test_X_c)\n",
    "hidden_states_test_X_d = go_through_transformers(test_X_d)\n",
    "hidden_states_test_X_e = go_through_transformers(test_X_e)\n",
    "hidden_states_test_X_f = go_through_transformers(test_X_f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4aaa3f6c-99a2-418a-94e3-22140e459ada",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pickles/hidden_states_test_X_a.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_a, f, pickle.HIGHEST_PROTOCOL)\n",
    "with open('pickles/hidden_states_test_X_b.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_b, f, pickle.HIGHEST_PROTOCOL)\n",
    "with open('pickles/hidden_states_test_X_c.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_c, f, pickle.HIGHEST_PROTOCOL)\n",
    "with open('pickles/hidden_states_test_X_d.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_d, f, pickle.HIGHEST_PROTOCOL)\n",
    "with open('pickles/hidden_states_test_X_e.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_e, f, pickle.HIGHEST_PROTOCOL)\n",
    "with open('pickles/hidden_states_test_X_f.pickle', 'wb') as f:\n",
    "    pickle.dump(hidden_states_test_X_f, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fad2f9ba-fe23-46d2-9a21-a8c0a998b0b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pickles/hidden_states_test_X_a.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_a = pickle.load(f)\n",
    "with open('pickles/hidden_states_test_X_b.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_b = pickle.load(f)\n",
    "with open('pickles/hidden_states_test_X_c.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_c = pickle.load(f)\n",
    "with open('pickles/hidden_states_test_X_d.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_d = pickle.load(f)\n",
    "with open('pickles/hidden_states_test_X_e.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_e = pickle.load(f)\n",
    "with open('pickles/hidden_states_test_X_f.pickle', 'rb') as f:\n",
    "    hidden_states_test_X_f = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b89a9a8c-ab23-4ee6-be19-ba93490eae73",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "clfs = {}\n",
    "for k in range(13):\n",
    "    clfs[k] = LogisticRegression(solver='saga', max_iter=1_000).fit(hidden_states_train[:,k,:], train_y_total)\n",
    "    print(clfs[k].score(hidden_states_train[:,k,:], train_y_total))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75beffd4-d4c1-462a-b80c-d10f8b8fa41a",
   "metadata": {},
   "outputs": [],
   "source": [
    "results = {'Type': [], 'Accuracy': [], 'type_hue': []}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "73579bd9-9654-42a3-89e4-a085105bceda",
   "metadata": {},
   "outputs": [],
   "source": [
    "for _ in range(10):\n",
    "    res = list((clfs[0].predict(hidden_states_train[:,0,:]) == np.array(train_y_total)).astype(float))\n",
    "    results['Type'].extend(['Initial \\n embeddings']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Initial \\n embeddings']*len(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0c4f57a-4630-4a2e-93e2-9a9b46a57e3d",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in range(1, 13):\n",
    "    res = list((clfs[k].predict(hidden_states_train[:,k,:]) == np.array(train_y_total)).astype(float))\n",
    "    results['Type'].extend(['Train set']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Train set']*len(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "64dd0cca-9b4a-4bd9-a8fe-30613f75912e",
   "metadata": {},
   "outputs": [],
   "source": [
    "for k in range(1, 13):\n",
    "    res = list((clfs[k].predict(hidden_states_test_X_a[:,k,:]) == np.array(test_y_a)).astype(float))\n",
    "    results['Type'].extend(['Test set']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))\n",
    "    res = list((clfs[k].predict(hidden_states_test_X_b[:,k,:]) == np.array(test_y_b)).astype(float))\n",
    "    results['Type'].extend(['Test set']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))\n",
    "    \n",
    "    res = list((clfs[k].predict(hidden_states_test_X_c[:,k,:]) == np.array(test_y_c)).astype(float))\n",
    "    results['Type'].extend(['Test OOD \\n tokens']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))\n",
    "    res = list((clfs[k].predict(hidden_states_test_X_d[:,k,:]) == np.array(test_y_d)).astype(float))\n",
    "    results['Type'].extend(['Test OOD \\n tokens']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))\n",
    "    \n",
    "    res = list((clfs[k].predict(hidden_states_test_X_e[:,k,:]) == np.array(test_y_e)).astype(float))\n",
    "    results['Type'].extend(['Test OOD \\n structure']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))\n",
    "    \n",
    "    res = list((clfs[k].predict(hidden_states_test_X_f[:,k,:]) == np.array(test_y_f)).astype(float))\n",
    "    results['Type'].extend(['Test OOD \\n structure \\n + tokens']*len(res))\n",
    "    results['Accuracy'].extend(res)\n",
    "    results['type_hue'].extend(['Test']*len(res))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d53b1598-7646-415f-be4a-c3a45d854ec8",
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.DataFrame(results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0c6eb53f-d1d4-45a2-950b-374e2d1d5d71",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('pickles/result_transformer_exp.pickle', 'wb') as f:\n",
    "    pickle.dump(df, f, pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6dd34e7a-c1d3-414a-b26c-c300a0d65194",
   "metadata": {},
   "outputs": [],
   "source": [
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28e1da13-50d7-4163-80ad-12d3fe2951d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure()\n",
    "ax = sns.barplot(data=df, x='Type', y='Accuracy', hue='type_hue', legend=False)\n",
    "for container in ax.containers:\n",
    "    ax.bar_label(container, fmt='%.2f', fontsize=14)\n",
    "plt.xlabel('')\n",
    "plt.savefig('figures/exp_transformers.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9353c4eb-8a99-4f1b-86ae-9b7129c07e35",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1b8dc36-c473-4fc7-a84e-ade1aadc9cd1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
