{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "674d5c38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5a29d33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = 428\n",
    "\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)\n",
    "random.seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d96c3193",
   "metadata": {},
   "outputs": [],
   "source": [
    "def bert_module(query, key, value, embed_dim, num_head, i):\n",
    "    \n",
    "    # Multi headed self-attention\n",
    "    attention_output = layers.MultiHeadAttention(\n",
    "        num_heads=num_head,\n",
    "        key_dim=embed_dim // num_head,\n",
    "        name=\"encoder_{}/multiheadattention\".format(i)\n",
    "    )(query, key, value, use_causal_mask=True)\n",
    "    \n",
    "    # Add & Normalize\n",
    "    attention_output = layers.Add()([query, attention_output])  # Skip Connection\n",
    "    attention_output = layers.LayerNormalization(epsilon=1e-6)(attention_output)\n",
    "    \n",
    "    # Feedforward network\n",
    "    ff_net = keras.models.Sequential([\n",
    "        layers.Dense(2 * embed_dim, activation='relu', name=\"encoder_{}/ffn_dense_1\".format(i)),\n",
    "        layers.Dense(embed_dim, name=\"encoder_{}/ffn_dense_2\".format(i)),\n",
    "    ])\n",
    "\n",
    "    # Apply Feedforward network\n",
    "    ffn_output = ff_net(attention_output)\n",
    "\n",
    "    # Add & Normalize\n",
    "    ffn_output = layers.Add()([attention_output, ffn_output])  # Skip Connection\n",
    "    ffn_output = layers.LayerNormalization(epsilon=1e-6)(ffn_output)\n",
    "    \n",
    "    return ffn_output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "fd7afdea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_sinusoidal_embeddings(sequence_length, embedding_dim):\n",
    "    position_enc = np.array([\n",
    "        [pos / np.power(10000, 2. * i / embedding_dim) for i in range(embedding_dim)]\n",
    "        if pos != 0 else np.zeros(embedding_dim)\n",
    "        for pos in range(sequence_length)\n",
    "    ])\n",
    "    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2])  # dim 2i\n",
    "    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2])  # dim 2i+1\n",
    "    return tf.cast(position_enc, dtype=tf.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d45907b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def insert_element_randomly(my_list, element):\n",
    "    \n",
    "    if len(my_list) > 1:\n",
    "\n",
    "        index = random.randint(0, 2)\n",
    "        \n",
    "    else:\n",
    "        \n",
    "        index = 0\n",
    "\n",
    "    new_list = my_list[:(4 * index)] + element + my_list[(4 * index):]\n",
    "    \n",
    "    return new_list"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "95bab1fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 20 # vocab_size\n",
    "M = 20 # number of random words\n",
    "\n",
    "vocabs = ['word_' + str(i) for i in range(N)] + ['random_' + str(i) for i in range(M)]\n",
    "\n",
    "vocabs_word = ['word_' + str(i) for i in range(N)]\n",
    "\n",
    "vocab_map = {}\n",
    "for i in range(len(vocabs)):\n",
    "    vocab_map[vocabs[i]] = i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5098e2f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_accuracy_prob(embed_dim):\n",
    "    \n",
    "    pairs = []\n",
    "\n",
    "    for i in vocabs_word:\n",
    "        for j in vocabs_word:\n",
    "            for k in vocabs_word:\n",
    "                if i != j and i != k and j != k:\n",
    "                    pairs.append((i,j,k))\n",
    "\n",
    "    indicator = np.random.choice([0, 1], size=len(pairs), p=[0.5, 0.5])\n",
    "\n",
    "    pairs_train = [pairs[i] for i in range(len(indicator)) if indicator[i] == 1]\n",
    "    pairs_test = [pairs[i] for i in range(len(indicator)) if indicator[i] == 0]\n",
    "    \n",
    "    sentences_train = []\n",
    "    sentences_number_train = []\n",
    "    sentences_test_a = []\n",
    "    sentences_number_test_a = []\n",
    "    sentences_test_b = []\n",
    "    sentences_number_test_b = []\n",
    "\n",
    "    x_masked_train = []\n",
    "    y_masked_labels_train = []\n",
    "    x_masked_test_a = []\n",
    "    y_masked_labels_test_a = []\n",
    "    x_masked_test_b = []\n",
    "    y_masked_labels_test_b = []\n",
    "\n",
    "    for _ in range(25000):\n",
    "\n",
    "        random_words = random.sample(['random_' + str(i) for i in range(M)], 4)\n",
    "    \n",
    "        [(a,b,c), (d,e,f)] = random.sample(pairs_train, 2)\n",
    "\n",
    "        temp = [a, b, c, a, d, e, f, d]\n",
    "        temp = insert_element_randomly(temp, random_words)\n",
    "\n",
    "        sentences_train.append(temp)\n",
    "        sentences_number_train.append([vocab_map[i] for i in temp])\n",
    "        x_masked_train.append([vocab_map[i] for i in temp])\n",
    "        y_masked_labels_train.append([vocab_map[i] for i in temp][1:])\n",
    "        \n",
    "        random_words = random.sample(['random_' + str(i) for i in range(M)], 4)\n",
    "\n",
    "        [(a,b,c), (d,e,f)] = random.sample(pairs_train, 2)\n",
    "\n",
    "        temp = [a, b, c, b, d, e, f, e]\n",
    "        temp = insert_element_randomly(temp, random_words)\n",
    "\n",
    "        sentences_train.append(temp)\n",
    "        sentences_number_train.append([vocab_map[i] for i in temp])\n",
    "        x_masked_train.append([vocab_map[i] for i in temp])\n",
    "        y_masked_labels_train.append([vocab_map[i] for i in temp][1:])\n",
    "\n",
    "\n",
    "\n",
    "    for _ in range(25000):\n",
    "\n",
    "        [(a,b,c), (d,e,f), (g,h,i)] = random.sample(pairs_test, 3)\n",
    "    \n",
    "        temp = [a, b, c, a, d, e, f, d, g, h, i, g]\n",
    "\n",
    "        sentences_test_a.append(temp)\n",
    "        sentences_number_test_a.append([vocab_map[i] for i in temp])\n",
    "        x_masked_test_a.append([vocab_map[i] for i in temp])\n",
    "        y_masked_labels_test_a.append([vocab_map[i] for i in temp][1:])\n",
    "\n",
    "        [(a,b,c), (d,e,f), (g,h,i)] = random.sample(pairs_test, 3)\n",
    "    \n",
    "        temp = [a, b, c, b, d, e, f, e, g, h, i, h]\n",
    "\n",
    "        sentences_test_b.append(temp)\n",
    "        sentences_number_test_b.append([vocab_map[i] for i in temp])\n",
    "        x_masked_test_b.append([vocab_map[i] for i in temp])\n",
    "        y_masked_labels_test_b.append([vocab_map[i] for i in temp][1:])\n",
    "\n",
    "    x_masked_train = np.array(x_masked_train)\n",
    "    y_masked_labels_train = np.array(y_masked_labels_train)\n",
    "    x_masked_test_a = np.array(x_masked_test_a)\n",
    "    y_masked_labels_test_a = np.array(y_masked_labels_test_a)\n",
    "    x_masked_test_b = np.array(x_masked_test_b)\n",
    "    y_masked_labels_test_b = np.array(y_masked_labels_test_b)\n",
    "\n",
    "    perm = np.random.permutation(len(x_masked_train))\n",
    "    x_masked_train = x_masked_train[perm]\n",
    "    y_masked_labels_train = y_masked_labels_train[perm]\n",
    "    \n",
    "    num_head = 2\n",
    "\n",
    "    callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "    inputs = layers.Input((x_masked_train.shape[1],), dtype=tf.int64)\n",
    "    word_embeddings = layers.Embedding(N + M, embed_dim, name=\"word_embedding\")(inputs)\n",
    "    sinusoidal_embeddings = get_sinusoidal_embeddings(len(x_masked_train[0]), embed_dim)\n",
    "    encoder_output = word_embeddings + sinusoidal_embeddings\n",
    "    for i in range(1):\n",
    "        encoder_output = bert_module(encoder_output, encoder_output, encoder_output, embed_dim, num_head, i)\n",
    "\n",
    "    encoder_output = keras.layers.Lambda(lambda x: x[:,:-1,:], name='slice')(encoder_output)\n",
    "    mlm_output = layers.Dense(N + M, name=\"mlm_cls\", activation=\"softmax\")(encoder_output)\n",
    "    mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "    adam = Adam()\n",
    "    mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=adam)\n",
    "    history = mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                            validation_split = 0.5, callbacks = [callback], \n",
    "                            epochs=2000, batch_size=5000, \n",
    "                            verbose=0)\n",
    "    \n",
    "    acc_a = []\n",
    "    prob_a = []\n",
    "    x_test_subset_a = x_masked_test_a[np.random.choice(x_masked_test_a.shape[0], size=1000, replace=False)]\n",
    "\n",
    "    for sentence_number in x_test_subset_a:\n",
    "        temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "            (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "        temp = temp[:,-1,:]\n",
    "        acc_a.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "        prob_a.append(temp[0][sentence_number[-1]])\n",
    "        \n",
    "    acc_b = []\n",
    "    prob_b = []\n",
    "    x_test_subset_b = x_masked_test_b[np.random.choice(x_masked_test_b.shape[0], size=1000, replace=False)]\n",
    "\n",
    "    for sentence_number in x_test_subset_b:\n",
    "        temp = keras.backend.function(inputs = mlm_model.layers[0].input, outputs = mlm_model.layers[-1].output) \\\n",
    "            (np.array(sentence_number).reshape(1,len(sentence_number)))\n",
    "        temp = temp[:,-1,:]\n",
    "        acc_b.append(1 if temp.argmax() == sentence_number[-1] else 0)\n",
    "        prob_b.append(temp[0][sentence_number[-1]])\n",
    "        \n",
    "    return ((np.mean(acc_a), np.mean(prob_a)), (np.mean(acc_b), np.mean(prob_b)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6c73bf2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.0, 0.0104504265)\n",
      "(0.0, 0.010562966)\n",
      "(0.007, 0.024185756)\n",
      "(0.004, 0.023544364)\n",
      "(0.039, 0.050357062)\n",
      "(0.041, 0.049098082)\n",
      "(0.051, 0.044412825)\n",
      "(0.047, 0.04345424)\n",
      "(0.071, 0.05135692)\n",
      "(0.079, 0.053065434)\n",
      "(0.046, 0.049459204)\n",
      "(0.06, 0.051130176)\n",
      "(0.006, 0.013703612)\n",
      "(0.012, 0.01466186)\n",
      "(0.0, 0.0009088933)\n",
      "(0.0, 0.00092142023)\n",
      "(0.0, 0.0010364049)\n",
      "(0.0, 0.0010705065)\n",
      "(0.051, 0.049393263)\n",
      "(0.064, 0.049239486)\n",
      "(0.0271, 0.02952643660828471)\n",
      "(0.0307, 0.029674853483447804)\n"
     ]
    }
   ],
   "source": [
    "accs_a = 0\n",
    "probs_a = 0\n",
    "accs_b = 0\n",
    "probs_b = 0\n",
    "\n",
    "for _ in range(10):\n",
    "    \n",
    "    ((acc_a, prob_a), (acc_b, prob_b)) = get_accuracy_prob(10)\n",
    "    \n",
    "    print((acc_a, prob_a))\n",
    "    print((acc_b, prob_b))\n",
    "    \n",
    "    accs_a += acc_a/10\n",
    "    probs_a += prob_a/10\n",
    "    accs_b += acc_b/10\n",
    "    probs_b += prob_b/10\n",
    "    \n",
    "print((accs_a, probs_a))\n",
    "print((accs_b, probs_b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "245346e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.0, 6.859806e-05)\n",
      "(0.0, 6.848894e-05)\n",
      "(0.092, 0.054427147)\n",
      "(0.118, 0.05483171)\n",
      "(0.073, 0.052245557)\n",
      "(0.073, 0.05311937)\n",
      "(0.093, 0.052629348)\n",
      "(0.096, 0.05261145)\n",
      "(0.0, 0.0011938497)\n",
      "(0.0, 0.0011900436)\n",
      "(0.056, 0.04822184)\n",
      "(0.068, 0.048953418)\n",
      "(0.134, 0.1018675)\n",
      "(0.47, 0.22347645)\n",
      "(0.081, 0.052851494)\n",
      "(0.089, 0.05308263)\n",
      "(0.051, 0.050534338)\n",
      "(0.049, 0.050597157)\n",
      "(0.0, 0.0001349365)\n",
      "(0.0, 0.00013389719)\n",
      "(0.058, 0.0414174606979941)\n",
      "(0.09630000000000001, 0.053806461959902664)\n"
     ]
    }
   ],
   "source": [
    "accs_a = 0\n",
    "probs_a = 0\n",
    "accs_b = 0\n",
    "probs_b = 0\n",
    "\n",
    "for _ in range(10):\n",
    "    \n",
    "    ((acc_a, prob_a), (acc_b, prob_b)) = get_accuracy_prob(100)\n",
    "    \n",
    "    print((acc_a, prob_a))\n",
    "    print((acc_b, prob_b))\n",
    "    \n",
    "    accs_a += acc_a/10\n",
    "    probs_a += prob_a/10\n",
    "    accs_b += acc_b/10\n",
    "    probs_b += prob_b/10\n",
    "    \n",
    "print((accs_a, probs_a))\n",
    "print((accs_b, probs_b))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b46f86e0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b100611",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
