{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bd9877ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "2cbc8e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed the three random number generators imported above: NumPy,\n",
    "# TensorFlow and Python's built-in `random` module.\n",
    "# The seed is set once in this cell only, so the cell must be re-run\n",
    "# (or the kernel restarted) to reproduce the same random stream.\n",
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ef4cfa0a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \"\"\"Build one Transformer block with causally masked attention.\n",
    "\n",
    "    The block consists of multi-head attention followed by a two-layer\n",
    "    feed-forward network. Each sub-layer is wrapped in a residual (skip)\n",
    "    connection and followed by layer normalization.\n",
    "\n",
    "    Args:\n",
    "        query, key, value: Keras tensors passed to the attention layer.\n",
    "            `query` is also the residual input of the first skip connection.\n",
    "        embed_dim: Model width. The per-head key size is\n",
    "            `embed_dim // num_head`; the feed-forward hidden size is\n",
    "            `2 * embed_dim`.\n",
    "        num_head: Number of attention heads.\n",
    "        i: Index of the block; only used to build the layer names.\n",
    "\n",
    "    Returns:\n",
    "        The Keras tensor produced by the final layer normalization.\n",
    "    \"\"\"\n",
    "    # Sub-layer 1: multi-head attention with a causal mask, so a position\n",
    "    # can only attend to itself and to earlier positions.\n",
    "    mha = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i),\n",
    "    )\n",
    "    attended = mha(query, key, value, use_causal_mask=True)\n",
    "\n",
    "    # Skip connection around the attention, then normalize.\n",
    "    attn_residual = layers.Add()([query, attended])\n",
    "    attn_normed = layers.LayerNormalization(epsilon=1e-6)(attn_residual)\n",
    "\n",
    "    # Sub-layer 2: position-wise feed-forward network (expand, then project back).\n",
    "    dense_expand = layers.Dense(\n",
    "        2 * embed_dim,\n",
    "        activation='relu',\n",
    "        name=\"encoder_{}/ffn_dense_1\".format(i),\n",
    "    )\n",
    "    dense_project = layers.Dense(\n",
    "        embed_dim,\n",
    "        name=\"encoder_{}/ffn_dense_2\".format(i),\n",
    "    )\n",
    "    ff_net = keras.models.Sequential([dense_expand, dense_project])\n",
    "    ff_out = ff_net(attn_normed)\n",
    "\n",
    "    # Skip connection around the feed-forward network, then normalize.\n",
    "    ff_residual = layers.Add()([attn_normed, ff_out])\n",
    "    return layers.LayerNormalization(epsilon=1e-6)(ff_residual)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "cf184d99",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    \"\"\"Return a fixed table of sinusoidal position encodings.\n",
    "\n",
    "    For position `pos >= 1`, columns `2i` and `2i + 1` hold the sine and\n",
    "    cosine of `pos / 10000 ** (2i / embedding_dim)`, i.e. both columns of a\n",
    "    pair share the same frequency. Row 0 is left as all zeros (the sin/cos\n",
    "    transform is only applied from row 1 onwards).\n",
    "\n",
    "    Args:\n",
    "        sequence_length: Number of positions (rows of the table).\n",
    "        embedding_dim: Size of each encoding (columns of the table).\n",
    "\n",
    "    Returns:\n",
    "        A float32 tensor of shape `(sequence_length, embedding_dim)`.\n",
    "    \"\"\"\n",
    "    position_enc = np.array([\n",
    "        # `i // 2` gives columns 2i and 2i+1 the same exponent. Using the raw\n",
    "        # column index here would give every column its own frequency and\n",
    "        # break the sin/cos pairing.\n",
    "        [pos / np.power(10000, 2. * (i // 2) / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7e0cff4b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_model(K, S, L, embed_dim, n_sentences):\n",
    "    \"\"\"Generate synthetic country/capital sentences and fit a one-block Transformer.\n",
    "\n",
    "    Every training sentence has the form\n",
    "    `country_p, <L - 2 distinct noise words>, capital_p` for a randomly drawn\n",
    "    pair index `p`. The model receives the whole sentence and is trained to\n",
    "    output, at each of the first `L - 1` positions, the token that follows.\n",
    "\n",
    "    Args:\n",
    "        K: Number of countries (= number of capitals).\n",
    "        S: Number of noise words.\n",
    "        L: Sentence length.\n",
    "        embed_dim: Dimension of the embeddings.\n",
    "        n_sentences: Number of training sentences.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(sentences, vocab_map, mlm_model)`: the generated sentences\n",
    "        (lists of words), the word -> integer id mapping, and the fitted\n",
    "        Keras model.\n",
    "    \"\"\"\n",
    "    # Vocabulary ids: countries first, then capitals, then noise words.\n",
    "    countries = ['country_' + str(idx) for idx in range(K)]\n",
    "    capitals = ['capital_' + str(idx) for idx in range(K)]\n",
    "    randoms = ['random_' + str(idx) for idx in range(S)]\n",
    "    vocab_map = {word: idx for idx, word in enumerate(countries + capitals + randoms)}\n",
    "\n",
    "    sentences = []\n",
    "    sentences_number = []\n",
    "    for _ in range(n_sentences):\n",
    "        # Draw the pair first and the noise words second; this order fixes\n",
    "        # how the global NumPy random stream is consumed.\n",
    "        pair = np.random.choice(np.arange(K), 1, replace=False)[0]\n",
    "        noise_words = list(np.random.choice(randoms, L - 2, replace=False))\n",
    "\n",
    "        sentence = [countries[pair]] + noise_words + [capitals[pair]]\n",
    "        sentences.append(sentence)\n",
    "        sentences_number.append([vocab_map[word] for word in sentence])\n",
    "\n",
    "    # Inputs are the full sentences; targets are the same sentences shifted\n",
    "    # left by one token.\n",
    "    token_ids = np.array(sentences_number)\n",
    "    next_token_ids = token_ids[:, 1:]\n",
    "    n_cat = len(vocab_map)\n",
    "    seq_len = token_ids.shape[1]\n",
    "\n",
    "    callback = keras.callbacks.EarlyStopping(\n",
    "        monitor='val_loss', patience=5, restore_best_weights=True\n",
    "    )\n",
    "\n",
    "    # Token embedding plus fixed sinusoidal position encoding.\n",
    "    inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "    word_embeddings = layers.Embedding(n_cat, embed_dim, name=\"word_embedding\")(inputs)\n",
    "    sinusoidal_embeddings = get_sinusoidal_embeddings(seq_len, embed_dim)\n",
    "    hidden = word_embeddings + sinusoidal_embeddings\n",
    "\n",
    "    # A single Transformer block with two attention heads.\n",
    "    num_head = 2\n",
    "    for block_idx in range(1):\n",
    "        hidden = bert_module(hidden, hidden, hidden, embed_dim, num_head, block_idx)\n",
    "\n",
    "    # Drop the last position: it has no next token to predict.\n",
    "    hidden = keras.layers.Lambda(lambda t: t[:, :-1, :], name='slice')(hidden)\n",
    "    mlm_output = layers.Dense(\n",
    "        n_cat, name=\"mlm_cls\", activation=\"softmax\", use_bias=False\n",
    "    )(hidden)\n",
    "\n",
    "    mlm_model = keras.Model(inputs=inputs, outputs=mlm_output)\n",
    "    mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "\n",
    "    # Half of the sentences are held out for early stopping on val_loss.\n",
    "    history = mlm_model.fit(\n",
    "        token_ids,\n",
    "        next_token_ids,\n",
    "        validation_split=0.5,\n",
    "        callbacks=[callback],\n",
    "        epochs=500,\n",
    "        batch_size=128,\n",
    "        verbose=0,\n",
    "    )\n",
    "\n",
    "    return sentences, vocab_map, mlm_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "4ab5986f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_acc_prob(K, S, L, embed_dim, n_sentences, n_samples):\n",
    "    \"\"\"Train a model and score it on sentences made only of country/capital pairs.\n",
    "\n",
    "    Each evaluation sentence consists of `int(L / 2)` distinct pairs laid out\n",
    "    as `country, capital, country, capital, ...`. Only the model's last output\n",
    "    position is scored, against the capital of the last country drawn.\n",
    "\n",
    "    NOTE(review): this presumably requires `L` to be even so that the\n",
    "    evaluation sentence has the same length as the training sentences the\n",
    "    model was built for - confirm before using an odd `L`.\n",
    "\n",
    "    Args:\n",
    "        K, S, L, embed_dim, n_sentences: Forwarded unchanged to `train_model`.\n",
    "        n_samples: Number of evaluation sentences to generate.\n",
    "\n",
    "    Returns:\n",
    "        A tuple `(sentences, current_model, vocab_map, (accuracy, mean_prob))`\n",
    "        where `accuracy` is the fraction of samples whose top-1 prediction is\n",
    "        the correct capital and `mean_prob` is the mean probability assigned\n",
    "        to the correct capital.\n",
    "    \"\"\"\n",
    "    sentences, vocab_map, current_model = train_model(K, S, L, embed_dim, n_sentences)\n",
    "\n",
    "    # Build the inference function once: it depends only on the trained\n",
    "    # model, not on the individual sample.\n",
    "    predict_fn = keras.backend.function(\n",
    "        inputs=current_model.layers[0].input,\n",
    "        outputs=current_model.layers[-1].output,\n",
    "    )\n",
    "\n",
    "    acc_countries = []\n",
    "    prob_countries = []\n",
    "\n",
    "    for _ in range(n_samples):\n",
    "        random_countries = np.random.choice(np.arange(K), int(L / 2), replace=False)\n",
    "\n",
    "        sentence = []\n",
    "        for random_country in random_countries:\n",
    "            sentence.append('country_' + str(random_country))\n",
    "            sentence.append('capital_' + str(random_country))\n",
    "        sentence_number = [vocab_map[word] for word in sentence]\n",
    "\n",
    "        # Batch of one sentence; keep only the distribution at the last\n",
    "        # output position.\n",
    "        all_probs = predict_fn(np.array(sentence_number).reshape(1, len(sentence_number)))\n",
    "        last_probs = all_probs[0, -1, :]\n",
    "\n",
    "        actual = vocab_map['capital_' + str(random_countries[-1])]\n",
    "        # argmax yields the same top-1 index as a full descending argsort\n",
    "        # (barring exact ties) without sorting the whole vocabulary.\n",
    "        acc_countries.append(1 if np.argmax(last_probs) == actual else 0)\n",
    "        prob_countries.append(last_probs[actual])\n",
    "\n",
    "    return sentences, current_model, vocab_map, (np.mean(acc_countries), np.mean(prob_countries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "061c08f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 20 # number of countries (= number of capitals)\n",
    "L = 6 # sentence length\n",
    "S = 20 # number of noise words\n",
    "embed_dim = 10 # embedding dimension of the Transformer\n",
    "n_sentences = 50000 # number of sentences in the training set\n",
    "n_samples = 1000 # number of evaluation sentences per trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "07132b13",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.007, 0.007982072)\n",
      "(0.0, 0.00028077618)\n",
      "(0.005, 0.004856886)\n",
      "(0.0, 3.6539284e-05)\n",
      "(0.001, 0.0014217901)\n",
      "(0.0, 0.00019993928)\n",
      "(0.0, 8.864164e-08)\n",
      "(0.0, 2.9649276e-05)\n",
      "(0.0, 0.0003011636)\n",
      "(0.006, 0.0062957527)\n",
      "(0.0019000000000000002, 0.0021404657391499884)\n"
     ]
    }
   ],
   "source": [
    "# Train and evaluate 10 freshly trained models with the current settings.\n",
    "# Each run prints its (accuracy, mean probability of the correct capital);\n",
    "# the final line prints the averages over the 10 runs.\n",
    "accs = 0\n",
    "probs = 0\n",
    "\n",
    "for _ in range(10):\n",
    "    sentences, mlm_model, vocab_map, acc_c \\\n",
    "        = get_acc_prob(K, S, L, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    \n",
    "    # Divide by 10 (the number of runs) so the sums end up as means.\n",
    "    accs += acc_c[0]/10\n",
    "    probs += acc_c[1]/10\n",
    "    \n",
    "print((accs, probs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bf48f7cb",
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 20 # number of countries (= number of capitals)\n",
    "L = 6 # sentence length\n",
    "S = 20 # number of noise words\n",
    "embed_dim = 100 # embedding dimension of the Transformer\n",
    "n_sentences = 50000 # number of sentences in the training set\n",
    "n_samples = 1000 # number of evaluation sentences per trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b212a41b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.004, 0.005297841)\n",
      "(0.007, 0.0070338305)\n",
      "(0.0, 4.129099e-05)\n",
      "(0.003, 0.0030311327)\n",
      "(0.0, 0.00017255956)\n",
      "(0.001, 0.0018705723)\n",
      "(0.004, 0.0043198317)\n",
      "(0.0, 8.0891405e-05)\n",
      "(0.0, 5.039761e-05)\n",
      "(0.01, 0.009748968)\n",
      "(0.0029000000000000002, 0.0031647316220187348)\n"
     ]
    }
   ],
   "source": [
    "# Train and evaluate 10 freshly trained models with the current settings.\n",
    "# Each run prints its (accuracy, mean probability of the correct capital);\n",
    "# the final line prints the averages over the 10 runs.\n",
    "accs = 0\n",
    "probs = 0\n",
    "\n",
    "for _ in range(10):\n",
    "    sentences, mlm_model, vocab_map, acc_c \\\n",
    "        = get_acc_prob(K, S, L, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    \n",
    "    # Divide by 10 (the number of runs) so the sums end up as means.\n",
    "    accs += acc_c[0]/10\n",
    "    probs += acc_c[1]/10\n",
    "    \n",
    "print((accs, probs))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
