{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1fad92d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy.random as npr\n",
    "import random\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "from tensorflow.keras import layers\n",
    "from tensorflow.keras import backend as K\n",
    "from keras.optimizers import Adam\n",
    "from keras_nlp.layers import PositionEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "df3b6089",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed every global RNG the notebook draws from (stdlib random, NumPy, TensorFlow).\n",
    "seed = 428\n",
    "\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "tf.random.set_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e6302f08",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_masked_input_and_labels(encoded_texts, n_cat):\n",
    "    # For each sentence, mask each word one-by-one\n",
    "\n",
    "    encoded_texts_masked = []\n",
    "    y_labels = []\n",
    "\n",
    "    for encoded_text in encoded_texts:\n",
    "        for i in range(len(encoded_text)):\n",
    "            encoded_text_masked = np.copy(encoded_text)\n",
    "            y_label = encoded_text_masked[i]\n",
    "            encoded_texts_masked.append(np.delete(encoded_text_masked, i))\n",
    "            y_labels.append(np.array([y_label]))\n",
    "\n",
    "    return np.array(encoded_texts_masked), np.array(y_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7a9a18d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### K = number of countries = number of capitals = number of currencies\n",
    "#### M = number of words only used by each topic\n",
    "#### S = number of words used by both topics\n",
    "#### L = sentence length\n",
    "#### q1, q2 = probability of having 1 or 2 pairs\n",
    "#### embed_dim = dimension of embeddings\n",
    "#### n_sentences = number of training sentences\n",
    "\n",
    "def _generate_sentences(K, M, S, L, q1, q2, n_sentences):\n",
    "    \"\"\"Sample the synthetic two-topic corpus and its vocabulary map.\n",
    "\n",
    "    Each sentence picks one topic (country-capital when the uniform draw is\n",
    "    <= 0.5, country-currency otherwise), starts with 0, 1 or 2\n",
    "    (country, topic word) pairs and is filled up to length L with filler words.\n",
    "\n",
    "    Returns:\n",
    "        (sentences, vocab_map): the sentences as lists of words, and a dict\n",
    "        mapping every vocabulary word to its integer id.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if q1 and q2 are not valid probabilities with q1 + q2 <= 1.\n",
    "    \"\"\"\n",
    "    if q1 < 0 or q2 < 0 or q1 + q2 > 1 + 1e-9:\n",
    "        raise ValueError('q1 and q2 must be non-negative and satisfy q1 + q2 <= 1')\n",
    "\n",
    "    countries = ['country_' + str(i) for i in range(K)]\n",
    "    capitals = ['capital_' + str(i) for i in range(K)]\n",
    "    currencies = ['currency_' + str(i) for i in range(K)]\n",
    "    random_capitals = ['random_capital_' + str(i) for i in range(M)]\n",
    "    random_currencies = ['random_currency_' + str(i) for i in range(M)]\n",
    "    randoms = ['random_' + str(i) for i in range(S)]\n",
    "\n",
    "    vocabs = countries + capitals + currencies + random_capitals + random_currencies + randoms\n",
    "    vocab_map = {word: idx for idx, word in enumerate(vocabs)}\n",
    "\n",
    "    # Filler pools (loop-invariant). Words are repeated 4x / 2x / 1x so that\n",
    "    # sampling without replacement from the pool favours the sentence's own\n",
    "    # topic words, then the shared words, then the other topic's words.\n",
    "    capital_fillers = 4 * random_capitals + 2 * randoms + 1 * random_currencies\n",
    "    currency_fillers = 1 * random_capitals + 2 * randoms + 4 * random_currencies\n",
    "\n",
    "    q0 = 1 - q1 - q2\n",
    "    sentences = []\n",
    "\n",
    "    for _ in range(n_sentences):\n",
    "        # Draw order (topic first, then pair count) is kept fixed so that the\n",
    "        # global NumPy seed fully determines the corpus.\n",
    "        temp = npr.uniform()\n",
    "        temp2 = npr.uniform()\n",
    "\n",
    "        # npr.uniform() lies in [0, 1): strict comparisons guarantee that a\n",
    "        # bucket with probability 0 is never selected.\n",
    "        if temp2 < q0:\n",
    "            n_pairs = 0\n",
    "        elif temp2 < q0 + q1:\n",
    "            n_pairs = 1\n",
    "        else:\n",
    "            n_pairs = 2\n",
    "\n",
    "        if temp <= 0.5: ### country - capital\n",
    "            targets, fillers = capitals, capital_fillers\n",
    "        else: ### country - currency\n",
    "            targets, fillers = currencies, currency_fillers\n",
    "\n",
    "        sentence = []\n",
    "        pairs = np.random.choice(np.arange(K), n_pairs, replace = False)\n",
    "        for pair in pairs:\n",
    "            sentence.append(countries[pair])\n",
    "            sentence.append(targets[pair])\n",
    "\n",
    "        sentence += list(np.random.choice(fillers, L - 2 * n_pairs, replace = False))\n",
    "        sentences.append(sentence)\n",
    "\n",
    "    return sentences, vocab_map\n",
    "\n",
    "\n",
    "def _build_mlm_model(seq_len, n_cat, embed_dim):\n",
    "    \"\"\"Create and compile the CBOW-style model: embedding -> mean pooling -> softmax.\"\"\"\n",
    "    inputs = layers.Input((seq_len,), dtype=tf.int64)\n",
    "    word_embeddings = layers.Embedding(n_cat, embed_dim, name='word_embedding')(inputs)\n",
    "    encoder_output = layers.GlobalAveragePooling1D()(word_embeddings)\n",
    "    mlm_output = layers.Dense(n_cat, name='mlm_cls', activation='softmax', use_bias=False)(encoder_output)\n",
    "    mlm_model = keras.Model(inputs = inputs, outputs = mlm_output)\n",
    "    mlm_model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam())\n",
    "    return mlm_model\n",
    "\n",
    "\n",
    "def train_model(K, M, S, L, q1, q2, embed_dim, n_sentences):\n",
    "    \"\"\"Generate a synthetic corpus and fit a masked-word prediction model on it.\n",
    "\n",
    "    Args:\n",
    "        K: number of countries (= number of capitals = number of currencies).\n",
    "        M: number of filler words used only by each topic.\n",
    "        S: number of filler words used by both topics.\n",
    "        L: sentence length.\n",
    "        q1: probability that a sentence contains 1 pair.\n",
    "        q2: probability that a sentence contains 2 pairs\n",
    "            (0 pairs has probability 1 - q1 - q2).\n",
    "        embed_dim: dimension of the word embeddings.\n",
    "        n_sentences: number of training sentences to generate.\n",
    "\n",
    "    Returns:\n",
    "        (sentences, vocab_map, mlm_model): the generated sentences (lists of\n",
    "        words), the word -> id dict, and the fitted Keras model.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if q1 and q2 are not valid probabilities with q1 + q2 <= 1.\n",
    "    \"\"\"\n",
    "    sentences, vocab_map = _generate_sentences(K, M, S, L, q1, q2, n_sentences)\n",
    "\n",
    "    x_train = np.array([[vocab_map[word] for word in sentence] for sentence in sentences])\n",
    "    n_cat = len(vocab_map)\n",
    "    x_masked_train, y_masked_labels_train = get_masked_input_and_labels(x_train, n_cat)\n",
    "\n",
    "    mlm_model = _build_mlm_model(x_masked_train.shape[1], n_cat, embed_dim)\n",
    "\n",
    "    # Stop once the validation loss has not improved for 5 epochs and roll\n",
    "    # back to the best weights seen.\n",
    "    callback = keras.callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, restore_best_weights = True)\n",
    "    mlm_model.fit(x_masked_train, y_masked_labels_train,\n",
    "                  validation_split = 0.5, callbacks = [callback],\n",
    "                  epochs=500, batch_size=128, verbose=0)\n",
    "\n",
    "    return sentences, vocab_map, mlm_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "85be6230",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _score_last_word(predict_fn, vocab_map, K, L, n_samples, target_prefix):\n",
    "    \"\"\"Score the model on prompts that consist only of (country, target) pairs.\n",
    "\n",
    "    Each prompt lists L // 2 distinct pairs with the final target word removed;\n",
    "    the model is scored on recovering that removed word.\n",
    "\n",
    "    Args:\n",
    "        predict_fn: callable mapping a (1, prompt_length) id array to the\n",
    "            model's output probabilities.\n",
    "        vocab_map: dict mapping words to integer ids.\n",
    "        K: number of countries to sample from.\n",
    "        L: sentence length used for training.\n",
    "        n_samples: number of prompts to evaluate.\n",
    "        target_prefix: 'capital_' or 'currency_'.\n",
    "\n",
    "    Returns:\n",
    "        (accuracy, mean_prob): fraction of prompts whose highest-probability\n",
    "        word is the removed one, and the mean probability assigned to it.\n",
    "    \"\"\"\n",
    "    hits = []\n",
    "    probs = []\n",
    "\n",
    "    for _ in range(n_samples):\n",
    "        picks = np.random.choice(np.arange(K), L // 2, replace = False)\n",
    "        sentence = []\n",
    "        for pick in picks:\n",
    "            sentence.append('country_' + str(pick))\n",
    "            sentence.append(target_prefix + str(pick))\n",
    "        sentence = sentence[:-1] # hide the last target word; it is the label\n",
    "        sentence_number = [vocab_map[word] for word in sentence]\n",
    "        temp = predict_fn(np.array(sentence_number).reshape(1, len(sentence_number)))\n",
    "        actual = vocab_map[target_prefix + str(picks[-1])]\n",
    "        hits.append(1 if np.argmax(temp[0]) == actual else 0)\n",
    "        probs.append(temp[0][actual])\n",
    "\n",
    "    return np.mean(hits), np.mean(probs)\n",
    "\n",
    "\n",
    "def get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples):\n",
    "    \"\"\"Train a model and measure how well it completes country-target pairs.\n",
    "\n",
    "    Args:\n",
    "        K, M, S, L, q1, q2, embed_dim, n_sentences: forwarded unchanged to\n",
    "            train_model.\n",
    "        n_samples: number of evaluation prompts per target type.\n",
    "\n",
    "    Returns:\n",
    "        (sentences, model, vocab_map, capital_scores, currency_scores) where\n",
    "        each *_scores entry is an (accuracy, mean probability) tuple.\n",
    "\n",
    "    NOTE: prompts contain 2 * (L // 2) - 1 tokens, which equals the training\n",
    "    input length L - 1 only when L is even.\n",
    "    \"\"\"\n",
    "    sentences, vocab_map, current_model = train_model(K, M, S, L, q1, q2, embed_dim, n_sentences)\n",
    "\n",
    "    # Build the inference function once and reuse it for every prompt.\n",
    "    predict_fn = keras.backend.function(inputs = current_model.layers[0].input,\n",
    "                                        outputs = current_model.layers[-1].output)\n",
    "\n",
    "    # Capitals are evaluated before currencies so the random draws keep their order.\n",
    "    capital_scores = _score_last_word(predict_fn, vocab_map, K, L, n_samples, 'capital_')\n",
    "    currency_scores = _score_last_word(predict_fn, vocab_map, K, L, n_samples, 'currency_')\n",
    "\n",
    "    return sentences, current_model, vocab_map, capital_scores, currency_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d19dd717",
   "metadata": {},
   "outputs": [],
   "source": [
    "K = 10 # number of countries\n",
    "L = 8 # sentence length\n",
    "M = 20 # number of words used by each topic\n",
    "S = 20 # number of words used by both topics\n",
    "embed_dim = 10 # CBOW embedding dimension\n",
    "n_sentences = 50000 # number of sentences in the training set\n",
    "n_samples = 1000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0e89e3d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.0, 8.3214595e-17)\n",
      "(0.0, 1.4009464e-16)\n",
      "(0.0, 1.462113e-14)\n",
      "(0.0, 8.3089683e-14)\n",
      "(0.0, 1.1292068e-14)\n",
      "(0.0, 3.2792918e-14)\n",
      "(0.0, 1.9938461e-14)\n",
      "(0.0, 1.457133e-18)\n",
      "(0.0, 8.692915e-18)\n",
      "(0.0, 1.3593242e-17)\n",
      "(0.0, 1.6518049e-12)\n",
      "(0.0, 2.6100677e-13)\n",
      "(0.0, 1.5436694e-13)\n",
      "(0.0, 1.02521006e-10)\n",
      "(0.0, 9.3441714e-14)\n",
      "(0.0, 1.4293558e-13)\n",
      "(0.0, 3.9450636e-17)\n",
      "(0.0, 1.250635e-15)\n",
      "(0.0, 5.0616144e-15)\n",
      "(0.0, 2.0085748e-16)\n",
      "(0.0, 1.950658175454245e-13)\n",
      "(0.0, 1.0304243715442546e-11)\n"
     ]
    }
   ],
   "source": [
    "q0 = 0 # probability of having 0 pairs\n",
    "q1 = 1 # probability of having 1 pair\n",
    "q2 = 0 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5910567f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.0, 3.0484634e-07)\n",
      "(0.0, 2.4450063e-07)\n",
      "(0.0, 1.8634066e-05)\n",
      "(0.0, 2.783521e-05)\n",
      "(0.0, 1.7929342e-13)\n",
      "(0.0, 7.7827485e-13)\n",
      "(0.0, 3.0078863e-07)\n",
      "(0.0, 4.5519147e-07)\n",
      "(0.0, 3.52213e-06)\n",
      "(0.0, 1.1548773e-06)\n",
      "(0.0, 4.26297e-06)\n",
      "(0.0, 2.9786926e-07)\n",
      "(0.0, 2.4727813e-06)\n",
      "(0.0, 6.087891e-06)\n",
      "(0.0, 0.0004737468)\n",
      "(0.0, 0.00031728655)\n",
      "(0.0, 3.5114796e-09)\n",
      "(0.0, 2.798786e-07)\n",
      "(0.0, 2.5706981e-07)\n",
      "(0.0, 6.995911e-08)\n",
      "(0.0, 5.035049605769317e-05)\n",
      "(0.0, 3.537119335315171e-05)\n"
     ]
    }
   ],
   "source": [
    "q0 = 0 # probability of having 0 pairs\n",
    "q1 = 0 # probability of having 1 pair\n",
    "q2 = 1 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "f3480dd3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(0.72, 0.502199)\n",
      "(0.694, 0.46673694)\n",
      "(0.825, 0.70494384)\n",
      "(0.736, 0.60295236)\n",
      "(0.573, 0.4654318)\n",
      "(0.477, 0.3770471)\n",
      "(0.786, 0.6514751)\n",
      "(0.723, 0.63462037)\n",
      "(0.81, 0.68875927)\n",
      "(0.727, 0.5743369)\n",
      "(0.778, 0.60464907)\n",
      "(0.741, 0.5742794)\n",
      "(0.689, 0.5197191)\n",
      "(0.776, 0.68141174)\n",
      "(0.512, 0.38633996)\n",
      "(0.599, 0.46602094)\n",
      "(0.786, 0.5728896)\n",
      "(0.864, 0.71857214)\n",
      "(0.373, 0.27059686)\n",
      "(0.507, 0.3280563)\n",
      "(0.6852, 0.5367003619670868)\n",
      "(0.6844, 0.5424034208059312)\n"
     ]
    }
   ],
   "source": [
    "q0 = 1/2 # probability of having 0 pairs\n",
    "q1 = 1/2 # probability of having 1 pair\n",
    "q2 = 0 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "46ca5547",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-05-01 13:02:17.755421: W tensorflow/tsl/platform/default/dso_loader.cc:66] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory\n",
      "2024-05-01 13:02:17.755453: W tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: UNKNOWN ERROR (303)\n",
      "2024-05-01 13:02:17.755470: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (gl3471.arc-ts.umich.edu): /proc/driver/nvidia/version does not exist\n",
      "2024-05-01 13:02:17.755730: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1.0, 0.99926966)\n",
      "(1.0, 0.9982506)\n",
      "(1.0, 0.9990663)\n",
      "(1.0, 0.99929607)\n",
      "(1.0, 0.99965537)\n",
      "(1.0, 0.9978787)\n",
      "(1.0, 0.9984031)\n",
      "(1.0, 0.99948514)\n",
      "(1.0, 0.99854124)\n",
      "(1.0, 0.9973472)\n",
      "(1.0, 0.99922097)\n",
      "(1.0, 0.9976323)\n",
      "(1.0, 0.9997227)\n",
      "(1.0, 0.9988722)\n",
      "(1.0, 0.9995129)\n",
      "(1.0, 0.998846)\n",
      "(1.0, 0.9996017)\n",
      "(1.0, 0.9989794)\n",
      "(1.0, 0.99902964)\n",
      "(1.0, 0.9989065)\n",
      "(0.9999999999999999, 0.9992023587226869)\n",
      "(0.9999999999999999, 0.9985494077205659)\n"
     ]
    }
   ],
   "source": [
    "q0 = 1/2 # probability of having 0 pairs\n",
    "q1 = 0 # probability of having 1 pair\n",
    "q2 = 1/2 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cd93057c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1.0, 0.98765427)\n",
      "(1.0, 0.9946565)\n",
      "(1.0, 0.99299246)\n",
      "(1.0, 0.9943016)\n",
      "(1.0, 0.98487014)\n",
      "(1.0, 0.9926157)\n",
      "(1.0, 0.9971343)\n",
      "(1.0, 0.99226826)\n",
      "(1.0, 0.99350095)\n",
      "(1.0, 0.9850159)\n",
      "(1.0, 0.99728906)\n",
      "(1.0, 0.9936302)\n",
      "(1.0, 0.99172324)\n",
      "(1.0, 0.9945902)\n",
      "(1.0, 0.991106)\n",
      "(1.0, 0.99587023)\n",
      "(1.0, 0.98808455)\n",
      "(1.0, 0.99075085)\n",
      "(1.0, 0.9936555)\n",
      "(1.0, 0.9970682)\n",
      "(0.9999999999999999, 0.9918010473251343)\n",
      "(0.9999999999999999, 0.9930767714977264)\n"
     ]
    }
   ],
   "source": [
    "q0 = 0 # probability of having 0 pairs\n",
    "q1 = 1/2 # probability of having 1 pair\n",
    "q2 = 1/2 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "bb9a28cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1.0, 0.9966531)\n",
      "(1.0, 0.9970186)\n",
      "(1.0, 0.99799836)\n",
      "(1.0, 0.9940305)\n",
      "(1.0, 0.99430984)\n",
      "(1.0, 0.9984759)\n",
      "(1.0, 0.9956707)\n",
      "(1.0, 0.9973306)\n",
      "(1.0, 0.9982028)\n",
      "(1.0, 0.99786687)\n",
      "(1.0, 0.99744016)\n",
      "(1.0, 0.9959179)\n",
      "(1.0, 0.99782306)\n",
      "(1.0, 0.9973954)\n",
      "(1.0, 0.9984711)\n",
      "(1.0, 0.998881)\n",
      "(1.0, 0.99778235)\n",
      "(1.0, 0.997685)\n",
      "(1.0, 0.9978583)\n",
      "(1.0, 0.9979493)\n",
      "(0.9999999999999999, 0.9972209692001344)\n",
      "(0.9999999999999999, 0.9972551047801971)\n"
     ]
    }
   ],
   "source": [
    "q0 = 1/3 # probability of having 0 pairs\n",
    "q1 = 1/3 # probability of having 1 pair\n",
    "q2 = 1/3 # probability of having 2 pairs\n",
    "\n",
    "# Only q1 and q2 are passed on (q0 is re-derived as 1 - q1 - q2 during\n",
    "# training), so check that the q0 stated above is consistent with them.\n",
    "assert np.isclose(q0 + q1 + q2, 1), 'q0, q1 and q2 must sum to 1'\n",
    "\n",
    "n_runs = 10 # independent train/evaluate repetitions to average over\n",
    "\n",
    "accs_c = 0 # running mean accuracy, capital prompts\n",
    "probs_c = 0 # running mean probability, capital prompts\n",
    "accs_d = 0 # running mean accuracy, currency prompts\n",
    "probs_d = 0 # running mean probability, currency prompts\n",
    "\n",
    "for _ in range(n_runs):\n",
    "    sentences, mlm_model, vocab_map, acc_c, acc_d \\\n",
    "        = get_acc_prob(K, M, S, L, q1, q2, embed_dim, n_sentences, n_samples)\n",
    "    \n",
    "    print(acc_c)\n",
    "    print(acc_d)\n",
    "    \n",
    "    accs_c += acc_c[0]/n_runs\n",
    "    probs_c += acc_c[1]/n_runs\n",
    "    accs_d += acc_d[0]/n_runs\n",
    "    probs_d += acc_d[1]/n_runs\n",
    "    \n",
    "print((accs_c, probs_c))\n",
    "print((accs_d, probs_d))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d686a12b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "eb4ce539",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1a75b9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4416760",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
