{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "49b4f88e-f7c3-4a84-aab9-3420d682263c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy, requests, codecs, os, re, nltk, itertools, csv\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import tensorflow as tf\n",
    "from scipy.stats import spearmanr\n",
    "import functools as ft\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import normalize\n",
    "import gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da273c03-392b-473d-8e0a-23d52cddfbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42711c0a-9ed5-4243-b7ca-cdb77ecc2a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.utils import np_utils\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import Pipeline\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "128cd7ff-0e14-45c5-8c68-14473c70365d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "from transformers import BertTokenizer, BertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "086245d4-ce67-4f11-8fe6-927e4413171a",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading file vocab.txt from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/vocab.txt\n",
      "loading file tokenizer.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/tokenizer.json\n",
      "loading file added_tokens.json from cache at None\n",
      "loading file special_tokens_map.json from cache at None\n",
      "loading file tokenizer_config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/tokenizer_config.json\n",
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_hidden_states\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading weights file pytorch_model.bin from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/pytorch_model.bin\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of BertModel were initialized from the model checkpoint at bert-base-uncased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertModel for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = BertModel.from_pretrained('bert-base-uncased',\n",
    "                                  output_hidden_states = True, # Whether the model returns all hidden-states.\n",
    "                                  ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cade316-3db6-45fe-a1a6-93b87e05dae4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7238ea9e-0ab7-4bba-8a04-6feb94994cf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_text(model, data):\n",
    "    \"\"\"\n",
    "    Encode every text in `data` as the mean of its token embeddings.\n",
    "\n",
    "    Each text is tokenized without special tokens and sent through the\n",
    "    model on its own (batch of one), so no padding is involved. Relies on\n",
    "    the notebook-level `tokenizer` and `device`.\n",
    "\n",
    "    :param model: encoding model; its first output is averaged over tokens\n",
    "    :param data: iterable of texts\n",
    "    :return: one numpy matrix with a row per text, holding the average of\n",
    "             all token vectors of that text\n",
    "    \"\"\"\n",
    "    all_data_avg = []\n",
    "    for count, w in enumerate(data, start=1):\n",
    "        tokens = tokenizer.encode(w, add_special_tokens=False)\n",
    "        # Wrap in a list to form a batch of size one.\n",
    "        input_ids = torch.tensor([tokens]).to(device)\n",
    "        with torch.no_grad():\n",
    "            last_hidden_states = model(input_ids)[0]\n",
    "        all_data_avg.append(last_hidden_states.squeeze(0).mean(dim=0).cpu().numpy())\n",
    "        # Progress report every 1000 texts.\n",
    "        if count%1000 == 0:\n",
    "            print(count)\n",
    "    return np.array(all_data_avg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a5aa6150-048a-492e-ba24-37b586488d6f",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "1670b191-78f0-406c-a4c3-e7975f85d128",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# avg_data = encode_text(model, all_words)\n",
    "# df = pd.DataFrame(avg_data)\n",
    "# df.to_csv('Train_emb.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4e7cfc3f-7f3d-4280-b6ff-76a9f7f92766",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the cached embeddings; the commented-out cell above shows how\n",
    "# 'Train_emb.csv' was produced with encode_text.\n",
    "X_train = pd.read_csv('Train_emb.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7dbab3d1-737c-44e3-a619-cb871ba18b25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep a plain numpy matrix and drop the first CSV column\n",
    "# (presumably the row index that `to_csv` wrote by default - confirm).\n",
    "X_train = X_train.values[:,1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "21117378-30f8-45a5-90ec-763057e69764",
   "metadata": {},
   "outputs": [],
   "source": [
    "def directions(w1,w2,model):\n",
    "    \"\"\"\n",
    "    Return the difference between the embeddings of two texts.\n",
    "\n",
    "    Each text is tokenized without special tokens, passed through `model`\n",
    "    and represented by the mean over tokens of the model's first output.\n",
    "    Relies on the notebook-level `tokenizer` and `device`.\n",
    "\n",
    "    :param w1: first text\n",
    "    :param w2: second text\n",
    "    :param model: encoding model\n",
    "    :return: numpy vector, embedding of `w1` minus embedding of `w2`\n",
    "    \"\"\"\n",
    "    def _mean_embedding(text):\n",
    "        # Mean-pooled vector of a single text, returned as a CPU numpy array.\n",
    "        tokens = tokenizer.encode(text, add_special_tokens=False)\n",
    "        input_ids = torch.tensor([tokens]).to(device)\n",
    "        with torch.no_grad():\n",
    "            hidden = model(input_ids)[0]\n",
    "        return hidden.squeeze(0).mean(dim=0).cpu().numpy()\n",
    "\n",
    "    return _mean_embedding(w1) - _mean_embedding(w2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "cf8241ad-9e6b-4eff-916d-fe79cf7e0158",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding-difference vectors for three word pairs.\n",
    "# d1 is the vector used by the similarity computation below.\n",
    "d1 = directions('he','she',model)\n",
    "d2 = directions('man','woman',model)\n",
    "d3 = directions('father','mother',model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "adb58056-87fd-4bc9-ba6a-29d084dc4607",
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "493438cc-b5c1-4741-a264-24914aba2e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "2d9f9811-60a1-40f4-ac58-0425ab027c03",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "10000\n",
      "20000\n",
      "30000\n",
      "40000\n",
      "50000\n",
      "60000\n",
      "70000\n",
      "80000\n",
      "90000\n",
      "100000\n",
      "110000\n",
      "120000\n",
      "130000\n",
      "140000\n",
      "150000\n",
      "160000\n",
      "170000\n",
      "180000\n",
      "190000\n",
      "200000\n",
      "210000\n",
      "220000\n",
      "230000\n",
      "240000\n",
      "250000\n",
      "260000\n",
      "270000\n",
      "280000\n",
      "290000\n",
      "300000\n",
      "310000\n",
      "320000\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the list on every run so that re-executing this cell does not\n",
    "# append a second copy of the scores to an existing `similarity` list.\n",
    "similarity = []\n",
    "for i in range(X_train.shape[0]):\n",
    "    # scipy's `cosine` is a distance, so 1 - distance is the cosine similarity.\n",
    "    similarity.append(1-cosine(X_train[i,:], d1))\n",
    "    if i%10000 == 0:\n",
    "        print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "ece1527f-3e56-449e-b002-e8db74ccc6f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Frame for the similarity scores. Only 'sim' is filled in the cells below;\n",
    "# 'word' stays empty (NaN in the displayed frame).\n",
    "df1 = pd.DataFrame(columns = ['word','sim'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "b80d3f5a-9fd0-42fc-bc83-029697f87f81",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1['sim'] = similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "52c1063d-f42c-4ec5-961c-fa11b87c5167",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Most similar rows first. The index keeps each row's original position,\n",
    "# which is later used to look the row up in X_train.\n",
    "df1 = df1.sort_values(by=\"sim\" , ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "8aba8453-d06e-4057-8909-74494941c9b3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>word</th>\n",
       "      <th>sim</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>221327</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.397432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>144341</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.397432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>247522</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.397432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>13</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.397432</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3749</th>\n",
       "      <td>NaN</td>\n",
       "      <td>0.358897</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25593</th>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.408568</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>17343</th>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.409648</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>381</th>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.412838</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2356</th>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.431266</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>18006</th>\n",
       "      <td>NaN</td>\n",
       "      <td>-0.447041</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>322636 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       word       sim\n",
       "221327  NaN  0.397432\n",
       "144341  NaN  0.397432\n",
       "247522  NaN  0.397432\n",
       "13      NaN  0.397432\n",
       "3749    NaN  0.358897\n",
       "...     ...       ...\n",
       "25593   NaN -0.408568\n",
       "17343   NaN -0.409648\n",
       "381     NaN -0.412838\n",
       "2356    NaN -0.431266\n",
       "18006   NaN -0.447041\n",
       "\n",
       "[322636 rows x 2 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "23815ab5-1a25-4d8f-b1cb-09c0616190fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_ = []\n",
    "Y_ = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "6a561446-adeb-4408-803c-30a8a214d376",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Start from empty lists so that re-running this cell does not append the\n",
    "# same samples a second time.\n",
    "X_ = []\n",
    "Y_ = []\n",
    "# df1 is sorted by similarity (descending): the i-th row from the top is\n",
    "# labelled 1 and the i-th row from the bottom is labelled 0.\n",
    "for i in range(75000):\n",
    "    X_.append(X_train[df1.index[i],:])\n",
    "    Y_.append(1)\n",
    "    X_.append(X_train[df1.index[-i-1],:])\n",
    "    Y_.append(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e0572246-2ba8-4577-8763-ff769788b6ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): this overwrites X_train (the full embedding matrix) with the\n",
    "# selected subset, so the earlier cells that index X_train via df1.index\n",
    "# cannot be re-run afterwards without reloading the embeddings.\n",
    "X_train = np.array(X_)\n",
    "Y_train = np.array(Y_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "8982c3ec-9d2b-49ad-9aae-724f5a2be09d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.05946696\n",
      "Epoch 3, change: 0.02455873\n",
      "Epoch 4, change: 0.02104287\n",
      "Epoch 5, change: 0.00849297\n",
      "Epoch 6, change: 0.00799657\n",
      "max_iter reached after 11 seconds\n",
      "time: 11.363704681396484\n",
      "1.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   11.2s finished\n"
     ]
    }
   ],
   "source": [
    "# L2-regularised logistic regression without intercept, fitted with the saga\n",
    "# solver and capped at 7 iterations (the stored log shows max_iter was reached).\n",
    "clf = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "\n",
    "        \n",
    "# Time the fit, then report accuracy on the same data the model was fitted on.\n",
    "start = time.time()\n",
    "clf.fit(X_train, Y_train)\n",
    "print(\"time: {}\".format(time.time() - start))\n",
    "print(clf.score(X_train, Y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "844216ff-b753-4647-bdb8-c73b58cc2bb6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7, change: 0.00879040\n"
     ]
    }
   ],
   "source": [
    "# Per-class probabilities for every training row; used below as the\n",
    "# response matrix `Y` of the dimension-reduction step.\n",
    "Y = clf.predict_proba(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80136146-6300-40bb-b3da-7fed32364066",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "de0d9c03-a715-49cc-a6b2-078317dbf4b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "f75ab3e8-5a22-4bd4-9d52-39e61f93dc63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SDR_multi(X_train, Y, K, h):\n",
    "    \"\"\"\n",
    "    Build a projection matrix from slice statistics of `X_train` given `Y`.\n",
    "\n",
    "    For every column of `Y`, the range [0, 1] is cut into `h` equal-width\n",
    "    slices and a (d_x, d_x) matrix `Gamma` is accumulated from the rows of\n",
    "    `X_train` whose response falls into each slice. The per-column matrices\n",
    "    are summed into `Gamma_PMS`, each weighted by the first eigenvalue that\n",
    "    `np.linalg.eig` returns for it. The first `K` eigenvector columns returned\n",
    "    for `Gamma_PMS` are orthonormalized into `beta1`.\n",
    "\n",
    "    :param X_train: data matrix with one row per sample\n",
    "    :param Y: response matrix with one row per sample; values are sliced on\n",
    "              a grid over [0, 1]\n",
    "    :param K: number of eigenvector columns to keep\n",
    "    :param h: number of slices\n",
    "    :return: tuple `(P, I - P)` with `P = beta1 @ beta1.T`\n",
    "    \"\"\"\n",
    "    X = X_train\n",
    "    n = X.shape[0]\n",
    "    d_x = X.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "\n",
    "    # h+1 equally spaced edges give h slices of [0, 1].\n",
    "    interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            # Fraction of samples that slice_freq counted in slice j.\n",
    "            ph = freq[j]/n\n",
    "            # Membership is tested with both ends closed here, whereas the\n",
    "            # counts above come from pd.cut inside slice_freq.\n",
    "            index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "            # NOTE(review): `index` is the tuple returned by np.where, so\n",
    "            # `X[index,:]` appears to gain a leading axis of length 1; the sum\n",
    "            # over axis 0 would then leave the selected rows unsummed instead\n",
    "            # of producing one slice mean - confirm this is intended.\n",
    "            mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        # NOTE(review): np.linalg.eig does not guarantee any ordering, so\n",
    "        # la_g[0] is simply the first eigenvalue returned - confirm.\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + la_g[0]*Gamma\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    # First K eigenvector columns as returned (not explicitly sorted).\n",
    "    beta1 = vectorspace_orthonormalization(v_PMS[:,:K])\n",
    "    #beta1 = v_PMS[:,:K]\n",
    "    return np.dot(beta1,beta1.T), np.eye(d_x)-np.dot(beta1,beta1.T) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "b1a1ba8d-b4f8-4052-be46-4397886ca57c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def iteration_estimation_v2(X_train, Y, Iter, q, h):\n",
    "    \"\"\"\n",
    "    Whiten `X_train` and estimate a projection with repeated `SDR_multi` runs.\n",
    "\n",
    "    The data are multiplied by the pseudo-inverse of the square root of their\n",
    "    second-moment matrix. `SDR_multi` is run once on the result and then\n",
    "    `Iter` more times, each time after projecting the current data with the\n",
    "    previous `beta1`. The first five diagonal entries of `beta1` are printed\n",
    "    after every run.\n",
    "\n",
    "    :param X_train: data matrix with one row per sample\n",
    "    :param Y: response matrix with one row per sample\n",
    "    :param Iter: number of additional refinement rounds (0 for a single run)\n",
    "    :param q: number of directions, passed to `SDR_multi` as `K`\n",
    "    :param h: number of slices, passed to `SDR_multi`\n",
    "    :return: the pair `(beta1, beta2)` from the last `SDR_multi` call\n",
    "    \"\"\"\n",
    "    n = X_train.shape[0]\n",
    "    # NOTE(review): the second-moment matrix is built from the uncentered\n",
    "    # data - confirm that whitening without mean-centering is intended.\n",
    "    Cx = np.dot(X_train.T,X_train)/n\n",
    "    la, v = np.linalg.eig(Cx)\n",
    "    la = la.real\n",
    "    v = v.real\n",
    "    # Matrix square root of Cx assembled from its eigendecomposition.\n",
    "    Cx12 = np.dot(np.dot(v,np.diag(la**(0.5))),v.T)\n",
    "    X_train = np.dot(X_train,np.linalg.pinv(Cx12,rcond=1e-8))\n",
    "    beta1, beta2 = SDR_multi(X_train, Y, q, h)\n",
    "    print(np.diag(beta1)[:5])\n",
    "    for j in range(Iter):\n",
    "        # rand() draws from [0, 1), so this mask currently keeps every row.\n",
    "        idx = np.random.rand(X_train.shape[0]) < 1\n",
    "        X_train = np.dot(X_train,beta1)\n",
    "        beta1, beta2 = SDR_multi(X_train[idx], Y[idx], q, h)\n",
    "        print(np.diag(beta1)[:5])\n",
    "    return beta1, beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "3a32db18-1096-4694-a161-404eb1eaa318",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.93660651 0.90265901 0.91767121 0.92405236 0.92876875]\n",
      "[0.92972325 0.83710308 0.84907501 0.8414133  0.88773655]\n",
      "[0.83203457 0.77739867 0.77821186 0.77204261 0.80388392]\n",
      "[0.83864099 0.7854415  0.77893939 0.81878814 0.80729757]\n"
     ]
    }
   ],
   "source": [
    "# Iter=3 refinement rounds, q=700 directions, h=100 slices.\n",
    "beta1, beta2 = iteration_estimation_v2(X_train, Y, 3, 700, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be2d8e4-4eb6-49b0-aeff-a85939940873",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8863083e-a7b2-4139-ad33-ddc4745effc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "c1ca1532-5cdd-4325-8cca-248c3a64a45a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "f4aaa8b3-9d9a-46dc-a374-9386eda6c367",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import weat\n",
    "\n",
    "\n",
    "class SEATRunner:\n",
    "    \"\"\"Runs SEAT tests for a given HuggingFace transformers model.\n",
    "    Implementation taken from: https://github.com/W4ngatang/sent-bias.\n",
    "    \"\"\"\n",
    "\n",
    "    # Extension for files containing SEAT tests.\n",
    "    TEST_EXT = \".jsonl\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model,\n",
    "        tokenizer,\n",
    "        tests,\n",
    "        data_dir,\n",
    "        experiment_id,\n",
    "        n_samples=100000,\n",
    "        parametric=False,\n",
    "        seed=0,\n",
    "    ):\n",
    "        \"\"\"Initializes a SEAT test runner.\n",
    "        Args:\n",
    "            model: HuggingFace model (e.g., BertModel) to evaluate.\n",
    "            tokenizer: HuggingFace tokenizer (e.g., BertTokenizer) used for pre-processing.\n",
    "            tests (`str`): Comma separated list of SEAT tests to run. SEAT test files should\n",
    "                be in `data_dir` and have corresponding names with extension \".jsonl\".\n",
    "            data_dir (`str`): Path to directory containing the SEAT tests.\n",
    "            experiment_id (`str`): Experiment identifier. Used for logging.\n",
    "            n_samples (`int`): Number of permutation test samples used when estimating p-values\n",
    "                (exact test is used if there are fewer than this many permutations).\n",
    "            parametric (`bool`): Use parametric test (normal assumption) to compute p-values.\n",
    "            seed (`int`): Random seed.\n",
    "        \"\"\"\n",
    "        self._model = model\n",
    "        self._tokenizer = tokenizer\n",
    "        self._tests = tests\n",
    "        self._data_dir = data_dir\n",
    "        self._experiment_id = experiment_id\n",
    "        self._n_samples = n_samples\n",
    "        self._parametric = parametric\n",
    "        self._seed = seed\n",
    "\n",
    "    def __call__(self):\n",
    "        \"\"\"Runs specified SEAT tests.\n",
    "        Returns:\n",
    "            `list` of `dict`s containing the SEAT test results.\n",
    "        \"\"\"\n",
    "        random.seed(self._seed)\n",
    "        np.random.seed(self._seed)\n",
    "\n",
    "        all_tests = sorted(\n",
    "            [\n",
    "                entry[: -len(self.TEST_EXT)]\n",
    "                for entry in os.listdir(self._data_dir)\n",
    "                if not entry.startswith(\".\") and entry.endswith(self.TEST_EXT)\n",
    "            ],\n",
    "            key=_test_sort_key,\n",
    "        )\n",
    "\n",
    "        # Use the specified tests, otherwise, run all SEAT tests.\n",
    "        tests = self._tests or all_tests\n",
    "\n",
    "        results = []\n",
    "        for test in tests:\n",
    "            print(f\"Running test {test}\")\n",
    "\n",
    "            # Load the test data.\n",
    "            encs = _load_json(os.path.join(self._data_dir, f\"{test}{self.TEST_EXT}\"))\n",
    "\n",
    "            print(\"Computing sentence encodings\")\n",
    "            encs_targ1 = _encode(\n",
    "                self._model, self._tokenizer, encs[\"targ1\"][\"examples\"]\n",
    "            )\n",
    "            encs_targ2 = _encode(\n",
    "                self._model, self._tokenizer, encs[\"targ2\"][\"examples\"]\n",
    "            )\n",
    "            encs_attr1 = _encode(\n",
    "                self._model, self._tokenizer, encs[\"attr1\"][\"examples\"]\n",
    "            )\n",
    "            encs_attr2 = _encode(\n",
    "                self._model, self._tokenizer, encs[\"attr2\"][\"examples\"]\n",
    "            )\n",
    "\n",
    "            encs[\"targ1\"][\"encs\"] = encs_targ1\n",
    "            encs[\"targ2\"][\"encs\"] = encs_targ2\n",
    "            encs[\"attr1\"][\"encs\"] = encs_attr1\n",
    "            encs[\"attr2\"][\"encs\"] = encs_attr2\n",
    "\n",
    "            print(\"\\tDone!\")\n",
    "\n",
    "            # Run the test on the encodings.\n",
    "            esize, pval = weat.run_test(\n",
    "                encs, n_samples=self._n_samples, parametric=self._parametric\n",
    "            )\n",
    "\n",
    "            results.append(\n",
    "                {\n",
    "                    \"experiment_id\": self._experiment_id,\n",
    "                    \"test\": test,\n",
    "                    \"p_value\": pval,\n",
    "                    \"effect_size\": esize,\n",
    "                }\n",
    "            )\n",
    "\n",
    "        return results\n",
    "\n",
    "\n",
    "def _test_sort_key(test):\n",
    "    \"\"\"Return tuple to be used as a sort key for the specified test name.\n",
    "    Break test name into pieces consisting of the integers in the name\n",
    "    and the strings in between them.\n",
    "    \"\"\"\n",
    "    key = ()\n",
    "    prev_end = 0\n",
    "    for match in re.finditer(r\"\\d+\", test):\n",
    "        key = key + (test[prev_end : match.start()], int(match.group(0)))\n",
    "        prev_end = match.end()\n",
    "    key = key + (test[prev_end:],)\n",
    "\n",
    "    return key\n",
    "\n",
    "\n",
    "def _split_comma_and_check(arg_str, allowed_set, item_type):\n",
    "    \"\"\"Given a comma-separated string of items, split on commas and check if\n",
    "    all items are in allowed_set -- item_type is just for the assert message.\n",
    "    \"\"\"\n",
    "    items = arg_str.split(\",\")\n",
    "    for item in items:\n",
    "        if item not in allowed_set:\n",
    "            raise ValueError(f\"Unknown {item_type}: {item}!\")\n",
    "    return items\n",
    "\n",
    "\n",
    "def _load_json(sent_file):\n",
    "    \"\"\"Load from json. We expect a certain format later, so do some post processing.\"\"\"\n",
    "    print(f\"Loading {sent_file}...\")\n",
    "    all_data = json.load(open(sent_file, \"r\"))\n",
    "    data = {}\n",
    "    for k, v in all_data.items():\n",
    "        examples = v[\"examples\"]\n",
    "        data[k] = examples\n",
    "        v[\"examples\"] = examples\n",
    "\n",
    "    return all_data\n",
    "\n",
    "\n",
    "def _encode(model, tokenizer, texts):\n",
    "    \"\"\"Encode each text as a projected, unit-norm vector.\n",
    "    Args:\n",
    "        model: HuggingFace model whose output exposes \"last_hidden_state\".\n",
    "        tokenizer: HuggingFace tokenizer used to prepare the model inputs.\n",
    "        texts: Iterable of strings to encode.\n",
    "    Returns:\n",
    "        `dict` mapping each text to its numpy vector.\n",
    "    \"\"\"\n",
    "    encs = {}\n",
    "    for text in texts:\n",
    "        # Encode each example.\n",
    "        inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "        # Inference only: do not build an autograd graph for the forward pass.\n",
    "        with torch.no_grad():\n",
    "            outputs = model(**inputs)\n",
    "\n",
    "        # Average over the last layer of hidden representations.\n",
    "        enc = outputs[\"last_hidden_state\"].mean(dim=1)\n",
    "        vec = enc.detach().view(-1).numpy()\n",
    "\n",
    "        # `proj` is a notebook-level matrix defined outside this function.\n",
    "        # NOTE(review): presumably one of the projections estimated earlier in\n",
    "        # the notebook - confirm which one is assigned to it.\n",
    "        vec = proj.dot(vec)\n",
    "\n",
    "        # Following May et al., normalize the representation.\n",
    "        vec /= np.linalg.norm(vec)\n",
    "        encs[text] = vec\n",
    "\n",
    "    return encs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "d4274117-1907-49a9-9beb-10541b84c558",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Every .jsonl file under ../SEAT-test, in lexicographic order.\n",
    "paths = sorted(glob.glob('../SEAT-test/*.jsonl'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "cc7cc2d1-5e96-48ac-9c9f-3a0a0ba376db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device the BERT model is moved to when it is loaded in the next cell.\n",
    "device = 'cpu'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d09d22e9-f5ae-4d3c-ae6e-a9e014847b8a",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_hidden_states\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading weights file pytorch_model.bin from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/pytorch_model.bin\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of BertModel were initialized from the model checkpoint at bert-base-uncased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertModel for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "# Load the pretrained encoder first, then place it on the configured device.\n",
    "model = BertModel.from_pretrained(\n",
    "    'bert-base-uncased',\n",
    "    output_hidden_states=True,  # Whether the model returns all hidden-states.\n",
    ")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "b80f266a-8e7a-4af8-9834-1a129b9ef411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `proj` is read as a global inside `_encode`, which applies it with\n",
    "# `proj.dot(...)` to each pooled sentence vector before normalization.\n",
    "# NOTE(review): `beta2` is defined in an earlier cell not visible here --\n",
    "# confirm it is the intended projection matrix before re-running.\n",
    "proj = beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "c6dc4be1-838a-4666-a279-d2bc44b9ddd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running test ../SEAT-test/sent-weat6.jsonl\n",
      "Loading ../SEAT-test/sent-weat6.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between MaleNames and FemaleNames in association to attributes Career and Family\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.562\n",
      "computing effect size...\n",
      "esize: -0.028\n",
      "Running test ../SEAT-test/sent-weat6b.jsonl\n",
      "Loading ../SEAT-test/sent-weat6b.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between MaleTerms and FemaleTerms in association to attributes Career and Family\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.966\n",
      "computing effect size...\n",
      "esize: -0.286\n",
      "Running test ../SEAT-test/sent-weat7.jsonl\n",
      "Loading ../SEAT-test/sent-weat7.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between Math and Arts in association to attributes MaleTerms and FemaleTerms\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.993\n",
      "computing effect size...\n",
      "esize: -0.403\n",
      "Running test ../SEAT-test/sent-weat7b.jsonl\n",
      "Loading ../SEAT-test/sent-weat7b.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between Math and Arts in association to attributes MaleNames and FemaleNames\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.937\n",
      "computing effect size...\n",
      "esize: -0.255\n",
      "Running test ../SEAT-test/sent-weat8.jsonl\n",
      "Loading ../SEAT-test/sent-weat8.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between Science and Arts in association to attributes MaleTerms and FemaleTerms\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.131\n",
      "computing effect size...\n",
      "esize: 0.213\n",
      "Running test ../SEAT-test/sent-weat8b.jsonl\n",
      "Loading ../SEAT-test/sent-weat8b.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between Science and Arts in association to attributes MaleNames and FemaleNames\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.740\n",
      "computing effect size...\n",
      "esize: -0.124\n"
     ]
    }
   ],
   "source": [
    "# Sample count handed to weat.run_test for the non-parametric test.\n",
    "N_SAMPLES = 100000\n",
    "# Keys of the four example groups each test file is expected to contain.\n",
    "SEAT_GROUPS = (\"targ1\", \"targ2\", \"attr1\", \"attr2\")\n",
    "\n",
    "results = []\n",
    "for test in paths:\n",
    "    print(f\"Running test {test}\")\n",
    "\n",
    "    # Load the test data.\n",
    "    test_data = _load_json(test)\n",
    "\n",
    "    print(\"Computing sentence encodings\")\n",
    "    # Encode every group first, then attach, so a failure part-way through\n",
    "    # never leaves `test_data` with only some groups encoded.\n",
    "    group_encs = {\n",
    "        group: _encode(model, tokenizer, test_data[group][\"examples\"])\n",
    "        for group in SEAT_GROUPS\n",
    "    }\n",
    "    for group, encoded in group_encs.items():\n",
    "        test_data[group][\"encs\"] = encoded\n",
    "\n",
    "    print(\"\\tDone!\")\n",
    "\n",
    "    # Run the test on the encodings.\n",
    "    esize, pval = weat.run_test(\n",
    "        test_data, n_samples=N_SAMPLES, parametric=False\n",
    "    )\n",
    "\n",
    "    results.append(\n",
    "        {\n",
    "            \"test\": test,\n",
    "            \"p_value\": pval,\n",
    "            \"effect_size\": esize,\n",
    "        }\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "8d0497aa-c7ff-40bc-8146-216481913d5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'test': '../SEAT-test/sent-weat6.jsonl',\n",
       "  'p_value': 0.56194,\n",
       "  'effect_size': -0.028227643608501583},\n",
       " {'test': '../SEAT-test/sent-weat6b.jsonl',\n",
       "  'p_value': 0.96552,\n",
       "  'effect_size': -0.2861490042382975},\n",
       " {'test': '../SEAT-test/sent-weat7.jsonl',\n",
       "  'p_value': 0.9926,\n",
       "  'effect_size': -0.40295335202088955},\n",
       " {'test': '../SEAT-test/sent-weat7b.jsonl',\n",
       "  'p_value': 0.93748,\n",
       "  'effect_size': -0.2546618717162532},\n",
       " {'test': '../SEAT-test/sent-weat8.jsonl',\n",
       "  'p_value': 0.13133,\n",
       "  'effect_size': 0.21291200432645224},\n",
       " {'test': '../SEAT-test/sent-weat8b.jsonl',\n",
       "  'p_value': 0.74041,\n",
       "  'effect_size': -0.12413830112746467}]"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the per-test p-value and effect size collected by the loop above.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "190ef36a-7ed3-42a5-b3bc-58e67a609d63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.21817369617297644\n"
     ]
    }
   ],
   "source": [
    "# Signed effect size of every test, in the order the tests were run.\n",
    "effect_size = [entry['effect_size'] for entry in results]\n",
    "# Sum of magnitudes; dividing by the test count gives the mean absolute effect size.\n",
    "ss = sum(abs(size) for size in effect_size)\n",
    "print(ss/len(effect_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "95222d22-8678-4b6e-9343-d4b0af8f2f3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[-0.028227643608501583,\n",
       " -0.2861490042382975,\n",
       " -0.40295335202088955,\n",
       " -0.2546618717162532,\n",
       " 0.21291200432645224,\n",
       " -0.12413830112746467]"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the signed effect sizes, one per test, in run order.\n",
    "effect_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371027ed-c2d2-464c-9e15-6465b12db750",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ccde25c-4f17-43b7-b4f7-4237176456c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16b43576-cf1f-44f2-a03c-89663639e948",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
