{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "figure 1.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "wROGzhrj2qHN",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import tensorflow as tf\n",
        "tf.enable_v2_behavior()\n",
        "tf.compat.v1.enable_resource_variables()\n",
        "import numpy as np\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import layers as L\n",
        "from tensorflow.keras import Model\n",
        "import tensorflow_datasets as tfds\n",
        "from tensorflow_probability.python.distributions import categorical"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WblcgzGuOVI3",
        "colab_type": "text"
      },
      "source": [
        "Class `ContextEmbedder` outputs a predicted representation for a target word by forming a linear combination of context word feature vectors. In particular, for a context $h=w_1, \\dots, w_n$ with context embeddings $r_{w_i}$, the predicted representation is:\n",
        "$$ \\hat{q}(h) = \\sum_{i=1}^n c_i \\odot r_{w_i} $$\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Tqlyst_S293r",
        "colab_type": "code",
        "outputId": "d19ba7c2-a067-4fb3-b153-116ea2e31bb8",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "class ContextEmbedder(Model):\n",
        "  \"\"\"Reweights and sums a set of vectors in R^d representing context word\n",
        "  vectors.\n",
        "  @param:\n",
        "    `vocab_size` is how many words in the vocabulary\n",
        "    `window_size` is how many context words surround the word of interest. Note\n",
        "      that we assume the window is centered here, and the context includes the word.\n",
        "      Hence, a window size of 5 means [w_0, w_1, w_2, w_3, w_4, w_5] where w_3 is\n",
        "      the target word.\n",
        "    `embed_dim` is the size of the embeddings dimension\n",
        "  \"\"\"\n",
        "  def __init__(self, vocab_size, window_size, embed_dim, exclude_target=True,\n",
        "               #embeddings_intializer='uniform',\n",
        "               embeddings_intializer=keras.initializers.RandomUniform(minval=-0.001, maxval=0.001, seed=None),\n",
        "               ):\n",
        "\n",
        "    super(ContextEmbedder, self).__init__()\n",
        "    self._embed_dim = embed_dim\n",
        "    self._vocab_size = vocab_size\n",
        "    self._window_size = window_size\n",
        "    self._exclude_target = exclude_target\n",
        "    self._target_idx = self._window_size // 2\n",
        "    self._get_context_weights = L.Embedding(self._window_size, self._embed_dim,\n",
        "                                            embeddings_initializer=embeddings_intializer)\n",
        "    self._get_context_embeddings = L.Embedding(self._vocab_size, self._embed_dim,\n",
        "                                                embeddings_initializer=embeddings_intializer)\n",
        "    print(f\"INFO: exclude target is {self._exclude_target}\")\n",
        "    \n",
        "  def _exclude_target_idx(self, lst):\n",
        "    \"\"\"return list without the target index\n",
        "    \"\"\"\n",
        "    print(f\"lst: {lst.shape}\")\n",
        "    out = []\n",
        "    for row in lst:\n",
        "      out.append(row[:self._target_idx] + row[self._target_idx + 1:])\n",
        "    print(out)\n",
        "    print(tf.map_fn(exclude_helper, lst))\n",
        "\n",
        "  def embed(self, x):\n",
        "    \"\"\"Given (? x `self._window_size`) sequence of integers `x`, returns \n",
        "    a \n",
        "    (? x `self._window_size` x `self._embed_dim`) tensor of embeddings\n",
        "    \"\"\"\n",
        "    if self._exclude_target:\n",
        "      y = tf.concat([x[:,:self._target_idx], x[:,self._target_idx + 1:]], axis=-1)\n",
        "      return self._get_context_embeddings(y)\n",
        "    else:\n",
        "      return self._get_context_embeddings(x)\n",
        "    # return self._get_context_embeddings(self._exclude_target_idx(x))\n",
        "\n",
        "  def _make_index_tensor(self, x):\n",
        "    \"\"\"Returns a tensor of shape `x.shape` where the last axis is always \n",
        "    the integers list(range(len(x.shape[1])))\n",
        "    \"\"\"\n",
        "     # create set of index tensors to retrieve position weights\n",
        "    batch_shape, window_len, _ = x.shape\n",
        "    index_numbers = list(range(window_len))\n",
        "    to_tile = tf.reshape(tf.constant(index_numbers), shape=(1, -1))\n",
        "    context_positions = tf.tile(to_tile, multiples=(x.shape[0], 1))\n",
        "    # print(context_positions)\n",
        "    return context_positions\n",
        "\n",
        "  def _weight_and_sum(self, x):\n",
        "    context_positions = self._make_index_tensor(x)\n",
        "    position_weights = self._get_context_weights(context_positions)\n",
        "    # reweight embeddings vector dims according to `position_weights`, then sum embedding vectors\n",
        "    weighted_sum = tf.reduce_sum(tf.multiply(position_weights, x), axis=1)\n",
        "    return weighted_sum\n",
        "\n",
        "  def call(self, h):\n",
        "    \"\"\"Given (? x `self._context_window_size`) sequence of integers `x`, returns\n",
        "    (? x `self._embed_dim`) tensor of predicted embeddings\n",
        "    \"\"\"\n",
        "    context_word_embeddings = self.embed(h)  # (? x `self._window_size` x `self._embed_dim`)\n",
        "    predicted_embeddings = self._weight_and_sum(context_word_embeddings)  \n",
        "    return predicted_embeddings"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO: exclude target is True\n",
            "INFO: exclude target is False\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "K8hiIDuCOORq",
        "colab_type": "text"
      },
      "source": [
        "Class `LogBilinearScore` computes a score as the similarity between a predicted vector $\\hat{q}$ or `qhat` and the embedding $q$ or `q`. We include a term $b_w$ or `b` to capture the context-independent frequency of a word $w$. Given some $h$ and $\\hat{q}(h)$, calling an instance of `LogBilinearScore` returns\n",
        "$$s_{\\theta}(w,h) = \\hat{q}(h)^\\top q_w + b_w$$"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0TeriLqE3Fg6",
        "colab_type": "code",
        "outputId": "78b8fb51-4f8f-44e0-f106-d598ce69c657",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "class LogBilinearScore(Model):\n",
        "\n",
        "  def __init__(self, vocab_size,  embed_dim, embeddings_intializer='uniform'):\n",
        "    super(LogBilinearScore, self).__init__()\n",
        "    self.embed_dim = embed_dim\n",
        "    self.vocab_size = vocab_size\n",
        "    self.q = L.Embedding(self.vocab_size, self.embed_dim,\n",
        "                         embeddings_initializer=embeddings_intializer)\n",
        "    self.b = L.Embedding(self.vocab_size, 1,\n",
        "                          embeddings_initializer=embeddings_intializer)\n",
        "\n",
        "  def call(self, w, qhat):\n",
        "    \"\"\" Returns a (? x `self.window_size`) score tensor\n",
        "    w shape is (? x `self.window_size`)\n",
        "    qhat is (? x `self._window_size` x `self._embed_dim`) \n",
        "    \"\"\"\n",
        "    q = self.q(w)              # (? x `self.window_size` x `self.embed_dim`)\n",
        "    b = tf.squeeze(self.b(w))  # (? x `self.window_size`)\n",
        "    return tf.reduce_sum(tf.multiply(q, qhat), axis=-1, keepdims=False) + b\n"
      ],
      "execution_count": 0,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "INFO: exclude target is True\n",
            "INFO: exclude target is False\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6pGo5dBG3dFl",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import random\n",
        "class WindowRebatcher(object):\n",
        "  \"\"\"Take a batch of sequences, some padded with `pad_token`. Return a new batch\n",
        "  of sequences `window_size` over the original sequences\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, window_size, pad_token=-1):\n",
        "    self._ws = window_size\n",
        "    self._pt = pad_token\n",
        "  \n",
        "  def rebatch(self, batch):\n",
        "    rebatches = []\n",
        "    # print(batch)\n",
        "    for item in batch:\n",
        "      # print(item)\n",
        "      end_idx = np.where(item.numpy() == -1)[0]\n",
        "      if len(end_idx) > 0:\n",
        "        # print(end_idx[0])\n",
        "        end_idx = end_idx[0]\n",
        "      else:\n",
        "        # print(\"no -1\")\n",
        "        end_idx = len(item)\n",
        "        # print(len(item))\n",
        "      n_windows = end_idx - (self._ws - 1)\n",
        "      # print(n_windows)\n",
        "      idxs = [list(range(i, i + self._ws)) for i in range(n_windows)]\n",
        "      random.shuffle(idxs)\n",
        "      idxs = tf.constant(idxs[0:20],\n",
        "                         dtype=tf.int64)\n",
        "      idxs = tf.constant(idxs,\n",
        "                    dtype=tf.int64)\n",
        "      # print(f\"idxs:{idxs}\")\n",
        "      if np.sum(idxs) == 0:\n",
        "        # print(\"saw empty batch...\")\n",
        "        continue  # empty batch?\n",
        "      # print(tf.map_fn(lambda x: tf.gather(item, x), idxs))\n",
        "      rebatches.append(tf.map_fn(lambda x: tf.gather(item, x), idxs))\n",
        "      # print(f\"rebatches_shape:{rebatches[-1].shape}\")\n",
        "\n",
        "    return tf.concat(rebatches, axis=0)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-M-j4NldODpp",
        "colab_type": "text"
      },
      "source": [
        "Given a context $h$, an NPLM defines the distribution for the word to be predicted using the scoring function $s_{\\theta}(w,h)$ that  quantifies  the  compatibility  between  the  context  and  the  candidate  target word. Here $\\theta$ are model parameters, which include the word embeddings. The scores are converted to probabilities by exponentiating and normalizing:\n",
        "$$P_{\\theta}^h = \\frac{\\exp (s_{\\theta}(w,h))}{\\sum_{w'}\\exp(s_{\\theta}(w',h))}$$\n",
        "Unfortunately both evaluating $P^h_{\\theta}(w)$ and computing the corresponding likelihood gradient requires normalizing  over  the  entire  vocabulary,  which  means  that  maximum  likelihood  training  of  such models takes time linear in the vocabulary size, and thus is prohibitively expensive for all but the smallest vocabularies.\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LMMHxpGc3ykU",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "class NPLMEstimator(Model):\n",
        "\n",
        "  def __init__(self, vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=None):\n",
        "    super(NPLMEstimator, self).__init__()\n",
        "    assert(window_size % 2 == 1), f\"window_size must be odd but saw {window_size}!\"\n",
        "    self.mc_noise_samples = mc_noise_samples\n",
        "    self.batcher = WindowRebatcher(window_size)\n",
        "    self.context_embedder = ContextEmbedder(vocab_size, window_size, embed_dim)\n",
        "    self.word_idx_to_score = LogBilinearScore(vocab_size, embed_dim)\n",
        "    self.target_idx = window_size // 2\n",
        "    if noise_pmf is None:\n",
        "      noise_pmf = tf.ones(shape=(vocab_size,), dtype=tf.float64) / vocab_size\n",
        "    self.noise_dist = categorical.Categorical(logits=tf.constant(tf.math.log(noise_pmf), \n",
        "                                                                 dtype=tf.float64))\n",
        "    self.logk = tf.log(tf.constant(mc_noise_samples, dtype=tf.float32))\n",
        "    \n",
        "  def sample_and_log_prob_noise(self, n):\n",
        "    samples = []\n",
        "    log_probs = []\n",
        "    for _ in range(n):\n",
        "      samples.append(self.noise_dist.sample(self.mc_noise_samples))\n",
        "      log_probs.append(self.noise_dist.log_prob(samples[-1]))\n",
        "    return tf.stack(samples), tf.stack(log_probs)\n",
        "    \n",
        "  def del_score(self, score, log_p_noise, simple=True):\n",
        "    if simple:\n",
        "      return score\n",
        "    else: \n",
        "    return score - tf.cast(log_p_noise + tf.ones_like(log_p_noise) * tf.cast(self.logk, dtype=tf.float64), dtype=tf.float32)\n",
        "\n",
        "  def call(self, rebatched):\n",
        "    \"\"\" x is (1, S) sequence, where m changes batch to batch\n",
        "    Let W =  S - (window_size - 1)\n",
        "    \"\"\"\n",
        "    # --> embed the context words\n",
        "    qhat = self.context_embedder(rebatched)                    # (W, embed_size)\n",
        "\n",
        "    # --> compute scores for target word in each windowed batch\n",
        "    scores = self.word_idx_to_score(rebatched[:, self.target_idx], qhat)     # (W, 1)\n",
        "    log_p_true_under_noise = self.noise_dist.log_prob(rebatched[:, self.target_idx])\n",
        "\n",
        "    # --> Monte Carlo estimate of loss under noise distribution\n",
        "    noise_samples, log_p_noise = self.sample_and_log_prob_noise(rebatched.shape[0])\n",
        "    noise_words = tf.reshape(noise_samples,           # (? * n_noise_samples, 1)\n",
        "                            shape=(tf.reduce_prod(noise_samples.shape),))\n",
        "    noise_log_probs = tf.reshape(log_p_noise,         # (? * n_noise_samples, 1)\n",
        "                            shape=(tf.reduce_prod(log_p_noise.shape),))\n",
        "    \n",
        "    qhat_tiled = tf.reshape(tf.keras.backend.repeat(qhat, n=self.mc_noise_samples),  # (? * n_noise_samples, window_size, embed_size)  \n",
        "                            shape=(-1, qhat.shape[-1]))\n",
        "    \n",
        "    noise_scores_repeated = self.word_idx_to_score(noise_words, qhat_tiled)  # (? * n_noise_samples, 1)\n",
        "    noise_scores_samples = tf.reshape(noise_scores_repeated,           \n",
        "                                      shape=(rebatched.shape[0], -1))\n",
        "    \n",
        "    noise_scores = tf.reduce_mean(noise_scores_samples, axis=-1)        # (?, 1)\n",
        "\n",
        "    true_del_score = self.del_score(scores, log_p_true_under_noise)\n",
        "    noise_del_score_repeated = self.del_score(noise_scores_repeated, \n",
        "                                                      noise_log_probs)\n",
        "\n",
        "    expected_true = tf.nn.sigmoid_cross_entropy_with_logits(labels=tf.ones_like(true_del_score), \n",
        "                                                            logits=true_del_score)\n",
        "\n",
        "    expected_false = tf.reduce_mean(tf.reshape(  \n",
        "        tf.nn.sigmoid_cross_entropy_with_logits(\n",
        "            labels=tf.zeros_like(noise_del_score_repeated), \n",
        "            logits=noise_del_score_repeated), shape=(rebatched.shape[0], -1)), \n",
        "            axis=-1)\n",
        "\n",
        "    return tf.reduce_mean(expected_true + expected_false)\n",
        "\n",
        "                          "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Fy17wyL83_eL",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "(train_data, test_data), info = tfds.load( \n",
        "  'lm1b/subwords8k',\n",
        "  split = (tfds.Split.TRAIN, tfds.Split.TEST), \n",
        "  with_info=True, as_supervised=True)\n",
        "\n",
        "padded_shapes = ([None],                             \n",
        "                # ()) \n",
        "                (None,))\n",
        "\n",
        "pad_int = -1     # integer flag that never appears in dataset\n",
        "batch_size = 10  # retrieve one sequence at a time\n",
        "\n",
        "padding_values = (tf.constant(pad_int, dtype=tf.int64),\n",
        "                  tf.constant(0, dtype=tf.int64)) \n",
        "\n",
        "train_batches = train_data.shuffle(1024).padded_batch(\n",
        "  batch_size,\n",
        "  padded_shapes,\n",
        "  padding_values=padding_values)\n",
        "\n",
        "test_batches = test_data.shuffle(1024).padded_batch(\n",
        "    batch_size,\n",
        "    padded_shapes,\n",
        "    padding_values=padding_values)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jD1L4nB64Bcr",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "window_size = 5           # @param\n",
        "mc_noise_samples = 25     # @param\n",
        "embed_dim = 2             # @param\n",
        "vocab_size = info.features['text'].encoder.vocab_size\n",
        "train_iter = 50000         # @param\n",
        "log_interval = 10000       # @param\n",
        "epochs = 1                   "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Mtdrd-VG4Jdo",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# !!! Only run this cell to retrain from scratch\n",
        "\n",
        "# %%time\n",
        "# save_interval = 5\n",
        "# save_path = \"./\n",
        "\n",
        "\n",
        "# @tf.function\n",
        "# def update(rebatched):\n",
        "#   with tf.GradientTape() as tape:\n",
        "#     loss = nce_scorer(rebatched)\n",
        "  \n",
        "#   gradients = tape.gradient(loss, nce_scorer.trainable_variables)\n",
        "#   optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))\n",
        "#   return loss\n",
        "\n",
        "# losses = []\n",
        "# for e in range(epochs):\n",
        "\n",
        "#   for idx, batch in enumerate(iter(train_batches)):\n",
        "#     seq, _ = batch  # (1, ?)\n",
        "#     rebatched = nce_scorer.batcher.rebatch(seq)                       # (W, window_size)\n",
        "#     loss = update(rebatched)\n",
        "#     print(f\"iter: {idx}\")\n",
        "#     if idx % save_interval == 0:\n",
        "#       print(\"saving>...\")\n",
        "#       nce_scorer.save_weights(save_path)\n",
        "#     # nce_scorer.save_weights\n",
        "#     losses.append(loss)\n",
        "  \n",
        "#     # with tf.GradientTape() as tape:\n",
        "#     #   loss = nce_scorer(seq)\n",
        "    \n",
        "#     # gradients = tape.gradient(loss, nce_scorer.trainable_variables)\n",
        "#     # optimizer.apply_gradients(zip(gradients, nce_scorer.trainable_variables))\n",
        "\n",
        "#     if idx % log_interval == 0:\n",
        "#       losses.append(loss)\n",
        "#       print(f\"Epoch: {e} | iter {idx} | Loss: {loss}\")\n",
        "#     if (idx + 1) % train_iter == 0:\n",
        "#       break\n",
        "    \n",
        "#   losses.append(loss)\n",
        "#   print(f\"Epoch End: {e} Loss: {loss}\")\n",
        "\n",
        "# HERE: save the model output weights to a directory of your choosing"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "d4apgCNm4fri",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from sklearn.cross_decomposition import CCA\n",
        "from numpy import linalg as LA\n",
        "\n",
        "model_ids = [0, 2, 3, 4]\n",
        "\n",
        "# Model checkpoints have been provided with this script\n",
        "\n",
        "# Choose the appropriate root directory for the data, e.g., Google Drive\n",
        "USER_ROOT_DIR = \"./\" \n",
        "\n",
        "model_paths = [[os.path.join(USER_ROOT_DIR, \"2_0.003000_%d/model_checkpoints/rep1\"%x),\n",
        "                os.path.join(USER_ROOT_DIR, \"2_0.003000_%d/model_checkpoints/rep2\"%x)] for x in model_ids]\n",
        "\n",
        "print(model_paths)\n",
        "\n",
        "estimators = [[NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf),\n",
        "              NPLMEstimator(vocab_size, embed_dim, window_size, mc_noise_samples, noise_pmf=noise_pmf)] for _ in range(len(model_paths))]\n",
        "\n",
        "for idx, est in enumerate(estimators):\n",
        "  est[0].load_weights(model_paths[idx][0])\n",
        "  est[1].load_weights(model_paths[idx][1])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JixvzI0d4oLg",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Tableau Color Blind 10\n",
        "tableau20blind = [(0, 107, 164), (255, 128, 14), (171, 171, 171), (89, 89, 89),\n",
        "             (95, 158, 209), (200, 82, 0), (137, 137, 137), (163, 200, 236),\n",
        "             (255, 188, 121), (207, 207, 207)]\n",
        "  \n",
        "for i in range(len(tableau20blind)):  \n",
        "    r, g, b = tableau20blind[i]  \n",
        "    tableau20blind[i] = (r / 255., g / 255., b / 255.)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lsnI1ELfMs76",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import jax.numpy as jnp\n",
        "import jax\n",
        "import jax.experimental.optimizers\n",
        "\n",
        "# learn the best possible linear transformation by regressing\n",
        "\n",
        "def align(act1, act2):\n",
        "  def align_loss(p):\n",
        "    w, b = p\n",
        "    pred = jnp.matmul(act1, w)# + b\n",
        "    return jnp.mean(jnp.square(act2 - pred))\n",
        "    \n",
        "  w = jnp.zeros((2,2))\n",
        "  b = jnp.zeros((2,))\n",
        "  p = (w,b)\n",
        "  #init_fn, update_fn, get_p = jax.experimental.optimizers.momentum(0.005, 0.09)\n",
        "  #init_fn, update_fn, get_p = jax.experimental.optimizers.momentum(0.001, 0.01)\n",
        "  #init_fn, update_fn, get_p = jax.experimental.optimizers.momentum(0.001, 0.9)\n",
        "  init_fn, update_fn, get_p = jax.experimental.optimizers.momentum(0.01, 0.9)\n",
        "  opt_state = init_fn(p)\n",
        "  v_grad_fn = jax.jit(jax.value_and_grad(align_loss))\n",
        "  \n",
        "  def update(i, opt_state):\n",
        "    p = get_p(opt_state)\n",
        "    v, grad = v_grad_fn(p)\n",
        "    if i % 50 == 0:\n",
        "      print(v)\n",
        "    return update_fn(i, grad, opt_state)\n",
        "  \n",
        "  for i in range(1000):\n",
        "    opt_state = update(i, opt_state)\n",
        "  w, b = get_p(opt_state)\n",
        "  return w, b"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "LlAqBvl1MvVW",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import matplotlib.pyplot as plt\n",
        "import seaborn as sns\n",
        "from matplotlib import collections  as mc\n",
        "from matplotlib import cm\n",
        "\n",
        "sns.set_style(\"whitegrid\")\n",
        "print(f\"est: {len(estimators)}\")\n",
        "\n",
        "fig, ax = plt.subplots(2, 3, figsize=(12, 7))\n",
        "X1 = None\n",
        "X2 = None\n",
        "\n",
        "texts = [tt.ints2str([i]) for i in range(N)]\n",
        "idxs = np.asarray([i for i in range(N)])\n",
        "\n",
        "# Set data offsets for a nice visualization if text is desired\n",
        "offset = 70    # picked to avoid subwords that aren't full words\n",
        "ss = 90        # picked to make plot easier to read. see also `offset` comment\n",
        "texts = texts[offset:ss]\n",
        "\n",
        "# set colours:\n",
        "plasma1 = cm.get_cmap('tab20c', len(texts)).colors\n",
        "plasma2 = cm.get_cmap('tab20b', len(texts)).colors\n",
        "plasma3 = cm.get_cmap('tab20', len(texts)).colors\n",
        "\n",
        "num_to_text = 0\n",
        "lines_fun = lambda V: [[(0,0),(V[i,0], V[i,1])] for i in range(len(V))]\n",
        "\n",
        "def set_lims(t_ax):\n",
        "  t_ax.set_xlim([-5,5])\n",
        "  t_ax.set_ylim([-3,3])\n",
        "\n",
        "def draw_arrow(vec, num, color=\"lightgrey\", alpha=1.0):\n",
        "  ax[idx][num].quiver(0, 0, vec[:,0], vec[:,1], alpha=alpha)#,  scale = 4, zorder=10, \n",
        "                    # angles='xy', color=color, linestyle='--')\n",
        "\n",
        "\n",
        "def draw_arrow_drop_shadow(vec, num):\n",
        "  ax[idx][num].quiver(0, 0, vec[:,0], vec[:,1])#, width=.020, scale=3.6,  \n",
        "                      #zorder=9, angles='c', color=\"white\")\n",
        "\n",
        "def number_text(curr_idx, num, A, idxs_to_label, zorder=100):\n",
        "  for i in idxs_to_label:\n",
        "    # print(f\"text, i: {texts[i]}, %d\"%i)\n",
        "    if texts[i] == \"han\":\n",
        "      texts[i] = \"hand\"  # labels are subwords and arbitrary\n",
        "    # print(f\"vec: {len(A[:,0])}\")\n",
        "    scale = 0\n",
        "    y_jitter = scale * np.random.normal()\n",
        "    # if texts[i] == \"known \" and (curr_idx == 0 or curr_idx == 1) and num == 2:\n",
        "    #   print(f\"saw {texts[i]}\")\n",
        "    #   y_jitter = -.2\n",
        "    # if texts[i] == \"increase \" and curr_idx == 1:\n",
        "    #   if num == 0:\n",
        "    #     y_jitter = -1.0\n",
        "    #   elif num == 1:\n",
        "    #     y_jitter = 2.3\n",
        "    #   print(f\"saw {texts[i]}\")\n",
        "    ax[curr_idx][num].text(A[i, 0], A[i, 1] + y_jitter, texts[i], zorder=zorder)\n",
        "\n",
        "def stdize(vec):  # seems to have no effect\n",
        "  print(vec.shape)\n",
        "  return (vec - np.mean(vec, axis=0, keepdims=True)) / np.std(vec, axis=0, keepdims=True)\n",
        "\n",
        "def draw_lines(lines, this_ax, c, cmap=None, alpha=1.0, zorder=10, autoscale=False):\n",
        "  lc = mc.LineCollection(lines, colors=cmap, linewidths=4, alpha=alpha, zorder=zorder,)\n",
        "  this_ax.add_collection(lc)\n",
        "  if autoscale:\n",
        "    this_ax.autoscale()\n",
        "\n",
        "dist = lambda p1, p2: np.sqrt( ((p1[0]-p2[0])**2)+((p1[1]-p2[1])**2) )\n",
        "\n",
        "for idx, est in enumerate([estimators[i] for i in [1, 2]]):\n",
        "#for idx, est in enumerate(estimators):\n",
        "  # use different colours\n",
        "  if idx == 0:\n",
        "    plasma = plasma1\n",
        "  elif idx == 1:\n",
        "    plasma = plasma2\n",
        "  else:\n",
        "    plasma = plasma3\n",
        "\n",
        "  print(f\"INDEX: {idx}\")\n",
        "  # get embeddings\n",
        "  embeddings1 = est[0].word_idx_to_score.q(idxs.reshape([-1, 1]))\n",
        "  embeddings2 = est[1].word_idx_to_score.q(idxs.reshape([-1, 1]))\n",
        "\n",
        "  scale = 1  # don't change the scale\n",
        "\n",
        "  X1_full = stdize(scale*embeddings1.numpy().squeeze())\n",
        "  X2_full = stdize(scale*embeddings2.numpy().squeeze())\n",
        "\n",
        "\n",
        "  X1 = X1_full[offset:ss]\n",
        "  X2 = X2_full[offset:ss]\n",
        "\n",
        "  # print(f\"len texts: {len(texts)}\")\n",
        "\n",
        "  c2 = \"red\"   # tableau20blind[8]\n",
        "  c1 = \"blue\"  # tableau20blind[4]\n",
        "  alph = 1.0\n",
        "  lt=\" \"\n",
        "\n",
        "  # compute lengths of all the arrows\n",
        "  dists_X1 = [dist(x[0],x[1]) for x in lines_fun(X1)]\n",
        "  dists_X2 = [dist(x[0],x[1]) for x in lines_fun(X2)]\n",
        "\n",
        "  # find the longest arrows on X1 and X2 (not necessarily the same)\n",
        "  idxs_to_label_X1 = np.argsort(dists_X1)[-num_to_text:]\n",
        "  idxs_to_label_X2 = np.argsort(dists_X2)[-num_to_text:]\n",
        "\n",
        "  # place vector lines on plot\n",
        "  ax[idx][0].plot(X2[:, 0], X2[:, 1], lt, alpha = alph, c=c1)\n",
        "  draw_lines(lines_fun(X2), ax[idx][0], c1, cmap=plasma, alpha=1.0)\n",
        "\n",
        "  ax[idx][1].plot(X1[:, 0], X1[:, 1], lt, alpha = alph, c=c2)\n",
        "  draw_lines(lines_fun(X1), ax[idx][1], c2, cmap=plasma, alpha=1.0)\n",
        "\n",
        "\n",
        "\n",
        "  w_trans, b_trans = align(np.copy(X1_full), np.copy(X2_full))\n",
        "  X1_c = np.copy(X1_full)\n",
        "  X2_c = np.copy(np.matmul(X2_full, w_trans))# + b_trans)\n",
        "  #X2_c = X2_full\n",
        "\n",
        "  X1_c = X1_c[offset:ss]\n",
        "  X2_c = X2_c[offset:ss]\n",
        "  \n",
        "  draw_lines(lines_fun(X1_c), ax[idx][2], c2, cmap=plasma, alpha=1.0)\n",
        "  draw_lines(lines_fun(X2_c), ax[idx][2], c1, cmap=plasma, alpha=1.0)\n",
        "\n",
        "    # place text on plot\n",
        "  if num_to_text > 0:\n",
        "    number_text(idx, 0, X2, idxs_to_label_X2)\n",
        "    number_text(idx, 1, X1, idxs_to_label_X2)\n",
        "    number_text(idx, 2, X2_c, idxs_to_label_X2)\n",
        "\n",
        "  sns.despine(left=True, right=True, bottom=True, top=True)\n",
        "\n",
        "  for t_ax in ax[idx]:\n",
        "    t_ax.set_xlim(-1, 1)\n",
        "    t_ax.set_ylim(-1, 1)\n",
        "\n",
        "  for t_ax in ax[idx]:\n",
        "    t_ax.set_xticks([])\n",
        "    t_ax.set_yticks([])\n",
        "fig.tight_layout()\n",
        "fig.savefig(\"fig1.pdf\", bbox_inches=\"tight\")\n",
        "%download_file test.pdf"
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}