{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "49b4f88e-f7c3-4a84-aab9-3420d682263c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import scipy, requests, codecs, os, re, nltk, itertools, csv\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "import tensorflow as tf\n",
    "from scipy.stats import spearmanr\n",
    "import functools as ft\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.preprocessing import normalize\n",
    "import gdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "da273c03-392b-473d-8e0a-23d52cddfbe5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import gensim\n",
    "import codecs\n",
    "import json\n",
    "from gensim.models.keyedvectors import Word2VecKeyedVectors\n",
    "from gensim.models import KeyedVectors\n",
    "import numpy as np\n",
    "import random\n",
    "import sklearn\n",
    "from sklearn import model_selection\n",
    "from sklearn import cluster\n",
    "from sklearn import metrics\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.svm import LinearSVC, SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "from sklearn.feature_extraction import DictVectorizer\n",
    "import numpy as np\n",
    "from docopt import docopt\n",
    "import torch\n",
    "from transformers import *\n",
    "import pickle\n",
    "from tqdm import tqdm\n",
    "\n",
    "import scipy\n",
    "from scipy import linalg\n",
    "from scipy import sparse\n",
    "from scipy.stats.stats import pearsonr\n",
    "import tqdm\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import SGDClassifier, SGDRegressor, Perceptron, LogisticRegression\n",
    "\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "import pickle\n",
    "from collections import defaultdict, Counter\n",
    "from typing import List, Dict\n",
    "\n",
    "import torch\n",
    "from torch import utils\n",
    "\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "import time\n",
    "from mlxtend.math import vectorspace_orthonormalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "42711c0a-9ed5-4243-b7ca-cdb77ecc2a69",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from keras.models import Sequential\n",
    "from keras.layers import Dense\n",
    "from keras.wrappers.scikit_learn import KerasClassifier\n",
    "from keras.utils import np_utils\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.model_selection import KFold\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.pipeline import Pipeline\n",
    "import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "128cd7ff-0e14-45c5-8c68-14473c70365d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import Tensor\n",
    "from transformers import BertTokenizer, BertModel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "086245d4-ce67-4f11-8fe6-927e4413171a",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading file vocab.txt from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/vocab.txt\n",
      "loading file tokenizer.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/tokenizer.json\n",
      "loading file added_tokens.json from cache at None\n",
      "loading file special_tokens_map.json from cache at None\n",
      "loading file tokenizer_config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/tokenizer_config.json\n",
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"bert-base-uncased\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_hidden_states\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading weights file pytorch_model.bin from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/pytorch_model.bin\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of BertModel were initialized from the model checkpoint at bert-base-uncased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertModel for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "model = BertModel.from_pretrained('bert-base-uncased',\n",
    "                                  output_hidden_states = True, # Whether the model returns all hidden-states.\n",
    "                                  ).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cade316-3db6-45fe-a1a6-93b87e05dae4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7238ea9e-0ab7-4bba-8a04-6feb94994cf5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_text(model, data):\n",
    "    \"\"\"\n",
    "    encode the text\n",
    "    :param model: encoding model\n",
    "    :param data: data\n",
    "    :return: two numpy matrices of the data:\n",
    "                first: average of all tokens in each sentence\n",
    "                second: cls token of each sentence\n",
    "    \"\"\"\n",
    "    all_data_avg = []\n",
    "    batch = []\n",
    "    count = 0\n",
    "    for w in data:\n",
    "        tokens = tokenizer.encode(w, add_special_tokens=False)\n",
    "        batch.append(tokens)\n",
    "        input_ids = torch.tensor(batch).to(device)\n",
    "        with torch.no_grad():\n",
    "            last_hidden_states = model(input_ids)[0]\n",
    "            all_data_avg.append((last_hidden_states.squeeze(0).mean(dim=0).cpu()).numpy())\n",
    "        batch = []\n",
    "        count = count + 1\n",
    "        if count%1000 == 0:\n",
    "            print(count)\n",
    "    return np.array(all_data_avg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a5aa6150-048a-492e-ba24-37b586488d6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = pd.read_csv('Train_emb.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "4e7cfc3f-7f3d-4280-b6ff-76a9f7f92766",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = X_train.values[:,1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "164d676f-aa4f-4646-ac8e-266634f3d991",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(322636, 768)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "21117378-30f8-45a5-90ec-763057e69764",
   "metadata": {},
   "outputs": [],
   "source": [
    "def directions(w1,w2,model):\n",
    "    \n",
    "    tokens1 = tokenizer.encode(w1, add_special_tokens=False)\n",
    "    input_ids1 = torch.tensor([tokens1]).to(device)\n",
    "    with torch.no_grad():\n",
    "        last_hidden_states1 = model(input_ids1)[0]\n",
    "        embs1 = (last_hidden_states1.squeeze(0).mean(dim=0).cpu()).numpy()\n",
    "    \n",
    "    tokens2 = tokenizer.encode(w2, add_special_tokens=False)\n",
    "    input_ids2 = torch.tensor([tokens2]).to(device)\n",
    "    with torch.no_grad():\n",
    "        last_hidden_states2 = model(input_ids2)[0]\n",
    "        embs2 = (last_hidden_states2.squeeze(0).mean(dim=0).cpu()).numpy()\n",
    "        \n",
    "    return embs1 - embs2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "9366f127-da32-452d-9489-53f3f8a4fe2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def emb(w1,model):\n",
    "    \n",
    "    tokens1 = tokenizer.encode(w1, add_special_tokens=False)\n",
    "    input_ids1 = torch.tensor([tokens1]).to(device)\n",
    "    with torch.no_grad():\n",
    "        last_hidden_states1 = model(input_ids1)[0]\n",
    "        embs1 = (last_hidden_states1.squeeze(0).mean(dim=0).cpu()).numpy()\n",
    "\n",
    "    return embs1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "cf8241ad-9e6b-4eff-916d-fe79cf7e0158",
   "metadata": {},
   "outputs": [],
   "source": [
    "d1 = emb('Christianity',model)\n",
    "d2 = emb('Jewish',model)\n",
    "d3 = emb('Islam',model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "adb58056-87fd-4bc9-ba6a-29d084dc4607",
   "metadata": {},
   "outputs": [],
   "source": [
    "similarity = []\n",
    "labels = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "493438cc-b5c1-4741-a264-24914aba2e98",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.spatial.distance import cosine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "2d9f9811-60a1-40f4-ac58-0425ab027c03",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "10000\n",
      "20000\n",
      "30000\n",
      "40000\n",
      "50000\n",
      "60000\n",
      "70000\n",
      "80000\n",
      "90000\n",
      "100000\n",
      "110000\n",
      "120000\n",
      "130000\n",
      "140000\n",
      "150000\n",
      "160000\n",
      "170000\n",
      "180000\n",
      "190000\n",
      "200000\n",
      "210000\n",
      "220000\n",
      "230000\n",
      "240000\n",
      "250000\n",
      "260000\n",
      "270000\n",
      "280000\n",
      "290000\n",
      "300000\n",
      "310000\n",
      "320000\n"
     ]
    }
   ],
   "source": [
    "for i in range(X_train.shape[0]):\n",
    "    sim = [1-cosine(X_train[i,:], d1), 1-cosine(X_train[i,:], d2), 1-cosine(X_train[i,:], d3)]\n",
    "    similarity.append(sim)\n",
    "    labels.append(np.where(sim==np.max(sim))[0][0])\n",
    "    if i%10000 == 0:\n",
    "        print(i)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "70fe8ad0-333d-426e-b91b-3f9e4436a5c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1 = pd.DataFrame(columns = ['Index','Christianity','Jewish','Islam','label'])\n",
    "df1['Index'] = range(X_train.shape[0])\n",
    "df1['Christianity'] = np.array(similarity)[:,0]\n",
    "df1['Jewish'] = np.array(similarity)[:,1]\n",
    "df1['Islam'] = np.array(similarity)[:,2]\n",
    "df1['label'] = labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "6764da25-5ff2-4555-97af-b1562441d11c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Index</th>\n",
       "      <th>Christianity</th>\n",
       "      <th>Jewish</th>\n",
       "      <th>Islam</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0.149957</td>\n",
       "      <td>0.230184</td>\n",
       "      <td>0.175605</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>0.446760</td>\n",
       "      <td>0.328652</td>\n",
       "      <td>0.472716</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>0.357326</td>\n",
       "      <td>0.260162</td>\n",
       "      <td>0.359524</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>0.208376</td>\n",
       "      <td>0.307968</td>\n",
       "      <td>0.247927</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>0.222522</td>\n",
       "      <td>0.246137</td>\n",
       "      <td>0.233121</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322631</th>\n",
       "      <td>322631</td>\n",
       "      <td>0.242472</td>\n",
       "      <td>0.166228</td>\n",
       "      <td>0.213486</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322632</th>\n",
       "      <td>322632</td>\n",
       "      <td>0.204167</td>\n",
       "      <td>0.159160</td>\n",
       "      <td>0.196199</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322633</th>\n",
       "      <td>322633</td>\n",
       "      <td>0.092093</td>\n",
       "      <td>0.167098</td>\n",
       "      <td>0.139208</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322634</th>\n",
       "      <td>322634</td>\n",
       "      <td>0.487145</td>\n",
       "      <td>0.397636</td>\n",
       "      <td>0.446845</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>322635</th>\n",
       "      <td>322635</td>\n",
       "      <td>0.374428</td>\n",
       "      <td>0.190989</td>\n",
       "      <td>0.295327</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>322636 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "         Index  Christianity    Jewish     Islam  label\n",
       "0            0      0.149957  0.230184  0.175605      1\n",
       "1            1      0.446760  0.328652  0.472716      2\n",
       "2            2      0.357326  0.260162  0.359524      2\n",
       "3            3      0.208376  0.307968  0.247927      1\n",
       "4            4      0.222522  0.246137  0.233121      1\n",
       "...        ...           ...       ...       ...    ...\n",
       "322631  322631      0.242472  0.166228  0.213486      0\n",
       "322632  322632      0.204167  0.159160  0.196199      0\n",
       "322633  322633      0.092093  0.167098  0.139208      1\n",
       "322634  322634      0.487145  0.397636  0.446845      0\n",
       "322635  322635      0.374428  0.190989  0.295327      0\n",
       "\n",
       "[322636 rows x 5 columns]"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "fa9724f4-d317-4260-aef1-259270a789b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_ = []\n",
    "Y_ = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "772b287b-ebf9-44a9-b859-a2606b7a4a5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1_0 = df1.loc[df1['label']==0,:]\n",
    "df1_1 = df1.loc[df1['label']==1,:]\n",
    "df1_2 = df1.loc[df1['label']==2,:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "d8cbf832-c2d6-4204-a1b9-2e68a0608af2",
   "metadata": {},
   "outputs": [],
   "source": [
    "df1_0 = df1_0.sort_values(by=\"Christianity\", ascending=False)\n",
    "df1_1 = df1_1.sort_values(by=\"Jewish\", ascending=False)\n",
    "df1_2 = df1_2.sort_values(by=\"Islam\", ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "54a5e672-f2b8-475e-939f-526705fde20a",
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(30000):\n",
    "    X_.append(X_train[df1_0.index[i],:])\n",
    "    Y_.append(0)\n",
    "    X_.append(X_train[df1_1.index[i],:])\n",
    "    Y_.append(1)\n",
    "    X_.append(X_train[df1_2.index[i],:])\n",
    "    Y_.append(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "5150651f-00ec-4572-abf2-93fd2396886e",
   "metadata": {},
   "outputs": [],
   "source": [
    "X_train = np.array(X_)\n",
    "Y_train = np.array(Y_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "ffb7633b-2798-4b8f-9a5a-7ffa44a0e47a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(90000, 768)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "X_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "8982c3ec-9d2b-49ad-9aae-724f5a2be09d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Using backend ThreadingBackend with 90 concurrent workers.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1, change: 1.00000000\n",
      "Epoch 2, change: 0.16825266\n",
      "Epoch 3, change: 0.09611755\n",
      "Epoch 4, change: 0.06424984\n",
      "Epoch 5, change: 0.06088954\n",
      "Epoch 6, change: 0.04265728\n",
      "max_iter reached after 12 seconds\n",
      "time: 12.914657354354858\n",
      "Epoch 7, change: 0.03678076\n",
      "0.9954333333333333\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[Parallel(n_jobs=90)]: Done   1 out of   1 | elapsed:   12.8s finished\n"
     ]
    }
   ],
   "source": [
    "clf = LogisticRegression(warm_start = True, penalty = 'l2',\n",
    "                         solver = \"saga\", multi_class = 'multinomial', fit_intercept = False,\n",
    "                         verbose = 5, n_jobs = 90, random_state = 1, max_iter = 7)\n",
    "\n",
    "        \n",
    "start = time.time()\n",
    "clf.fit(X_train, Y_train)\n",
    "print(\"time: {}\".format(time.time() - start))\n",
    "print(clf.score(X_train, Y_train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "844216ff-b753-4647-bdb8-c73b58cc2bb6",
   "metadata": {},
   "outputs": [],
   "source": [
    "Y = clf.predict_proba(X_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80136146-6300-40bb-b3da-7fed32364066",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "de0d9c03-a715-49cc-a6b2-078317dbf4b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def slice_freq(Y, interval):\n",
    "    y_f = pd.Series(Y)\n",
    "    a = pd.cut(y_f,interval)\n",
    "    b=a.value_counts()\n",
    "    frequency = b.sort_index().values\n",
    "    return frequency"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "f75ab3e8-5a22-4bd4-9d52-39e61f93dc63",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SDR_multi(X_train, Y, K, h):\n",
    "    X = X_train\n",
    "    n = X.shape[0]\n",
    "    d_x = X.shape[1]\n",
    "    d_y = Y.shape[1]\n",
    "\n",
    "    interval = np.linspace(0, 1, h+1,endpoint=True)\n",
    "    Gamma_PMS = np.zeros([d_x,d_x])\n",
    "    for i in range(d_y):  \n",
    "        freq = slice_freq(Y[:,i], interval)\n",
    "        Gamma = np.zeros([d_x,d_x])\n",
    "        for j in range(h):\n",
    "            ph = freq[j]/n\n",
    "            index = np.where((Y[:,i]>=interval[j])&(Y[:,i]<=interval[j+1]))\n",
    "            mh = (1/(n*ph))*np.sum(X[index,:],0)\n",
    "            Gamma = Gamma + ph*np.dot(mh.T,mh)\n",
    "        la_g, v_g = np.linalg.eig(Gamma)\n",
    "        la_g = la_g.real\n",
    "        Gamma_PMS = Gamma_PMS + la_g[0]*Gamma\n",
    "    la_PMS, v_PMS = np.linalg.eig(Gamma_PMS)\n",
    "    v_PMS = v_PMS.real\n",
    "    beta1 = vectorspace_orthonormalization(v_PMS[:,:K])\n",
    "    #beta1 = v_PMS[:,:K]\n",
    "    return np.dot(beta1,beta1.T), np.eye(d_x)-np.dot(beta1,beta1.T) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "b1a1ba8d-b4f8-4052-be46-4397886ca57c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def iteration_estimation_v2(X_train, Y, Iter, q, h):\n",
    "    N_sample = len(X_train)\n",
    "    d = X_train.shape[1]\n",
    "    beta1 = np.eye(d)\n",
    "    \n",
    "    X_batch = X_train\n",
    "    n = X_batch.shape[0]\n",
    "    X = X_batch-np.mean(X_batch,0).reshape(1,d)\n",
    "    Cx = np.dot(X_batch.T,X_batch)/n\n",
    "    la, v = np.linalg.eig(Cx)\n",
    "    la = la.real\n",
    "    v = v.real\n",
    "    Cx12 = np.dot(np.dot(v,np.diag(la**(0.5))),v.T)\n",
    "    X_train = np.dot(X_batch,np.linalg.pinv(Cx12,rcond=1e-8))\n",
    "    beta1, beta2 = SDR_multi(X_train, Y, q, h)\n",
    "    print(np.diag(beta1)[:5])\n",
    "    if Iter == 0:\n",
    "        return beta1, beta2\n",
    "    else:\n",
    "        for j in range(Iter):\n",
    "            idx = np.random.rand(X_train.shape[0]) < 1\n",
    "            X_train = np.dot(X_train,beta1)\n",
    "            beta1, beta2 = SDR_multi(X_train[idx], Y[idx], q, h)\n",
    "            print(np.diag(beta1)[:5])\n",
    "    return beta1, beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3a32db18-1096-4694-a161-404eb1eaa318",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0.79002667 0.76965737 0.77869602 0.77822338 0.78575766]\n",
      "[0.79002667 0.76965737 0.77869602 0.77822338 0.78575766]\n",
      "[0.79002667 0.76965737 0.77869602 0.77822338 0.78575766]\n"
     ]
    }
   ],
   "source": [
    "beta1, beta2 = iteration_estimation_v2(X_train, Y, 2, 600, 100)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5be2d8e4-4eb6-49b0-aeff-a85939940873",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "8863083e-a7b2-4139-ad33-ddc4745effc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c1ca1532-5cdd-4325-8cca-248c3a64a45a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "f4aaa8b3-9d9a-46dc-a374-9386eda6c367",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "import random\n",
    "import re\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import weat\n",
    "\n",
    "\n",
    "class SEATRunner:\n",
    "    \"\"\"Runs SEAT tests for a given HuggingFace transformers model.\n",
    "    Implementation taken from: https://github.com/W4ngatang/sent-bias.\n",
    "    \"\"\"\n",
    "\n",
    "    # Extension for files containing SEAT tests.\n",
    "    TEST_EXT = \".jsonl\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        model,\n",
    "        tokenizer,\n",
    "        tests,\n",
    "        data_dir,\n",
    "        experiment_id,\n",
    "        n_samples=100000,\n",
    "        parametric=False,\n",
    "        seed=0,\n",
    "    ):\n",
    "        \"\"\"Initializes a SEAT test runner.\n",
    "        Args:\n",
    "            model: HuggingFace model (e.g., BertModel) to evaluate.\n",
    "            tokenizer: HuggingFace tokenizer (e.g., BertTokenizer) used for pre-processing.\n",
    "            tests (`str`): Comma separated list of SEAT tests to run. SEAT test files should\n",
    "                be in `data_dir` and have corresponding names with extension \".jsonl\".\n",
    "            data_dir (`str`): Path to directory containing the SEAT tests.\n",
    "            experiment_id (`str`): Experiment identifier. Used for logging.\n",
    "            n_samples (`int`): Number of permutation test samples used when estimating p-values\n",
    "                (exact test is used if there are fewer than this many permutations).\n",
    "            parametric (`bool`): Use parametric test (normal assumption) to compute p-values.\n",
    "            seed (`int`): Random seed.\n",
    "        \"\"\"\n",
    "        self._model = model\n",
    "        self._tokenizer = tokenizer\n",
    "        self._tests = tests\n",
    "        self._data_dir = data_dir\n",
    "        self._experiment_id = experiment_id\n",
    "        self._n_samples = n_samples\n",
    "        self._parametric = parametric\n",
    "        self._seed = seed\n",
    "\n",
    "    def __call__(self):\n",
    "        \"\"\"Run the configured SEAT tests and collect their results.\n",
    "\n",
    "        Python's ``random`` and NumPy's global RNG are seeded with\n",
    "        ``self._seed`` first (presumably so that the sampling done inside\n",
    "        ``weat.run_test`` is reproducible).\n",
    "\n",
    "        Returns:\n",
    "            `list` of `dict`s, one per test, with keys ``experiment_id``,\n",
    "            ``test``, ``p_value`` and ``effect_size``.\n",
    "        \"\"\"\n",
    "        random.seed(self._seed)\n",
    "        np.random.seed(self._seed)\n",
    "\n",
    "        ext = self.TEST_EXT\n",
    "        # Test names are the non-hidden files in data_dir ending in TEST_EXT,\n",
    "        # with the extension stripped, ordered by _test_sort_key.\n",
    "        available = [\n",
    "            name[: -len(ext)]\n",
    "            for name in os.listdir(self._data_dir)\n",
    "            if name.endswith(ext) and not name.startswith(\".\")\n",
    "        ]\n",
    "        available.sort(key=_test_sort_key)\n",
    "\n",
    "        # Run only the requested tests; if none were requested, run them all.\n",
    "        selected = self._tests or available\n",
    "\n",
    "        results = []\n",
    "        for test in selected:\n",
    "            print(f\"Running test {test}\")\n",
    "\n",
    "            encs = _load_json(os.path.join(self._data_dir, f\"{test}{ext}\"))\n",
    "\n",
    "            print(\"Computing sentence encodings\")\n",
    "            # Encode every group first, then attach the encodings, always in\n",
    "            # the order targ1, targ2, attr1, attr2.\n",
    "            computed = {\n",
    "                group: _encode(self._model, self._tokenizer, encs[group][\"examples\"])\n",
    "                for group in (\"targ1\", \"targ2\", \"attr1\", \"attr2\")\n",
    "            }\n",
    "            for group, group_encs in computed.items():\n",
    "                encs[group][\"encs\"] = group_encs\n",
    "\n",
    "            print(\"\\tDone!\")\n",
    "\n",
    "            esize, pval = weat.run_test(\n",
    "                encs, n_samples=self._n_samples, parametric=self._parametric\n",
    "            )\n",
    "\n",
    "            results.append(\n",
    "                {\n",
    "                    \"experiment_id\": self._experiment_id,\n",
    "                    \"test\": test,\n",
    "                    \"p_value\": pval,\n",
    "                    \"effect_size\": esize,\n",
    "                }\n",
    "            )\n",
    "\n",
    "        return results\n",
    "\n",
    "\n",
    "def _test_sort_key(test):\n",
    "    \"\"\"Build a natural-sort key for a SEAT test name.\n",
    "\n",
    "    The name is split into alternating pieces: the text between numbers\n",
    "    (kept as `str`) and each run of digits (converted to `int`), so that\n",
    "    e.g. \"weat2\" sorts before \"weat10\".\n",
    "\n",
    "    Example:\n",
    "        _test_sort_key(\"sent-weat10b\") == (\"sent-weat\", 10, \"b\")\n",
    "    \"\"\"\n",
    "    pieces = []\n",
    "    cursor = 0\n",
    "    for digits in re.finditer(r\"\\d+\", test):\n",
    "        pieces.append(test[cursor : digits.start()])\n",
    "        pieces.append(int(digits.group(0)))\n",
    "        cursor = digits.end()\n",
    "    # Trailing text after the last number (possibly the empty string).\n",
    "    pieces.append(test[cursor:])\n",
    "    return tuple(pieces)\n",
    "\n",
    "\n",
    "def _split_comma_and_check(arg_str, allowed_set, item_type):\n",
    "    \"\"\"Split a comma-separated string and validate every item.\n",
    "\n",
    "    Args:\n",
    "        arg_str: Comma-separated items, e.g. \"a,b\" (whitespace is not stripped).\n",
    "        allowed_set: Container of permitted item values.\n",
    "        item_type: Human-readable label used only in the error message.\n",
    "\n",
    "    Returns:\n",
    "        The list of items, in their original order.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: naming the first item that is not in `allowed_set`.\n",
    "    \"\"\"\n",
    "    items = arg_str.split(\",\")\n",
    "    unknown = [item for item in items if item not in allowed_set]\n",
    "    if unknown:\n",
    "        raise ValueError(f\"Unknown {item_type}: {unknown[0]}!\")\n",
    "    return items\n",
    "\n",
    "\n",
    "def _load_json(sent_file):\n",
    "    \"\"\"Load a SEAT test file from JSON.\n",
    "\n",
    "    Args:\n",
    "        sent_file: Path to a JSON file mapping group names (e.g. \"targ1\",\n",
    "            \"attr1\") to objects that each contain an \"examples\" list.\n",
    "\n",
    "    Returns:\n",
    "        The parsed JSON object, unchanged.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: if any top-level group has no \"examples\" entry, since\n",
    "            callers index ``data[group][\"examples\"]`` directly.\n",
    "    \"\"\"\n",
    "    print(f\"Loading {sent_file}...\")\n",
    "    # The context manager closes the handle. JSON is UTF-8 by specification,\n",
    "    # so do not depend on the platform's default encoding.\n",
    "    with open(sent_file, \"r\", encoding=\"utf-8\") as f:\n",
    "        all_data = json.load(f)\n",
    "\n",
    "    for group, value in all_data.items():\n",
    "        if \"examples\" not in value:\n",
    "            raise KeyError(f\"{group!r} in {sent_file} has no 'examples' entry\")\n",
    "\n",
    "    return all_data\n",
    "\n",
    "\n",
    "def _encode(model, tokenizer, texts, projection=None):\n",
    "    \"\"\"Encode sentences as projected, unit-norm, mean-pooled embeddings.\n",
    "\n",
    "    Each text is tokenized on its own and run through `model`. Its\n",
    "    `last_hidden_state` is averaged over dim 1, flattened, multiplied by the\n",
    "    projection matrix and L2-normalized.\n",
    "\n",
    "    Args:\n",
    "        model: Callable returning a mapping with a \"last_hidden_state\"\n",
    "            tensor (e.g. a Hugging Face `BertModel`).\n",
    "        tokenizer: Callable mapping (text, return_tensors=\"pt\") to a mapping\n",
    "            of input tensors for `model`.\n",
    "        texts: Iterable of sentences to encode.\n",
    "        projection: Matrix applied as `projection.dot(vec)` before\n",
    "            normalization. Defaults to the notebook-global `proj`, which keeps\n",
    "            the original three-argument call working.\n",
    "\n",
    "    Returns:\n",
    "        `dict` mapping each text to a `numpy` array. A repeated text keeps\n",
    "        only its last encoding.\n",
    "    \"\"\"\n",
    "    if projection is None:\n",
    "        projection = proj  # notebook-global, assigned in a later cell\n",
    "\n",
    "    # Hugging Face models expose `.device`; send inputs there rather than\n",
    "    # assuming the CPU. Models without `.device` get the tokenizer output as-is.\n",
    "    device = getattr(model, \"device\", None)\n",
    "\n",
    "    encs = {}\n",
    "    # Inference only: do not build an autograd graph for every forward pass.\n",
    "    with torch.no_grad():\n",
    "        for text in texts:\n",
    "            inputs = tokenizer(text, return_tensors=\"pt\")\n",
    "            if device is not None:\n",
    "                inputs = {name: tensor.to(device) for name, tensor in inputs.items()}\n",
    "            outputs = model(**inputs)\n",
    "\n",
    "            # Average the last layer of hidden representations over dim 1.\n",
    "            enc = outputs[\"last_hidden_state\"].mean(dim=1)\n",
    "            vec = enc.detach().cpu().view(-1).numpy()\n",
    "\n",
    "            # Project, then (following May et al.) normalize to unit length.\n",
    "            vec = projection.dot(vec)\n",
    "            encs[text] = vec / np.linalg.norm(vec)\n",
    "\n",
    "    return encs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "d4274117-1907-49a9-9beb-10541b84c558",
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# All religion SEAT test files, in lexicographic order. Keep the 'Religon'\n",
    "# spelling: it must match the directory name on disk.\n",
    "paths = sorted(glob.glob('./Other_seat_test/Religon/*.jsonl'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "cc7cc2d1-5e96-48ac-9c9f-3a0a0ba376db",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Device the BERT model is placed on in the next cell.\n",
    "device = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "d09d22e9-f5ae-4d3c-ae6e-a9e014847b8a",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "loading configuration file config.json from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/config.json\n",
      "Model config BertConfig {\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.1,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"output_hidden_states\": true,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.22.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30522\n",
      "}\n",
      "\n",
      "loading weights file pytorch_model.bin from cache at /home/leiding/.cache/huggingface/hub/models--bert-base-uncased/snapshots/0a6aa9128b6194f4f3c4db429b6cb4891cdb421b/pytorch_model.bin\n",
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "All the weights of BertModel were initialized from the model checkpoint at bert-base-uncased.\n",
      "If your task is similar to the task the model of the checkpoint was trained on, you can already use BertModel for predictions without further training.\n"
     ]
    }
   ],
   "source": [
    "# Pretrained uncased BERT, placed on `device`. `from_pretrained` already\n",
    "# returns the model in eval mode (dropout off); `.eval()` makes that explicit.\n",
    "model = BertModel.from_pretrained(\n",
    "    'bert-base-uncased',\n",
    "    output_hidden_states=True,  # also return the hidden states of every layer\n",
    ").to(device).eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "b80f266a-8e7a-4af8-9834-1a129b9ef411",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `_encode` reads this global and applies `proj.dot(...)` to every sentence\n",
    "# embedding before L2-normalizing it. `beta2` comes from an earlier cell not\n",
    "# shown here -- presumably a projection matrix; TODO confirm its shape.\n",
    "proj = beta2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "c6dc4be1-838a-4666-a279-d2bc44b9ddd8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running test ./Other_seat_test/Religon/sent-religion1.jsonl\n",
      "Loading ./Other_seat_test/Religon/sent-religion1.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between ChristianityTerms and IslamTerms in association to attributes Good and Bad\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.005\n",
      "computing effect size...\n",
      "esize: 0.392\n",
      "Running test ./Other_seat_test/Religon/sent-religion1b.jsonl\n",
      "Loading ./Other_seat_test/Religon/sent-religion1b.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between ChristianityTerms and JewishTerms in association to attributes Good and Bad\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.665\n",
      "computing effect size...\n",
      "esize: -0.067\n",
      "Running test ./Other_seat_test/Religon/sent-religion2.jsonl\n",
      "Loading ./Other_seat_test/Religon/sent-religion2.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between ChristianityTerms and IslamTerms in association to attributes Pleasant and Unpleasant\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.001\n",
      "computing effect size...\n",
      "esize: 0.493\n",
      "Running test ./Other_seat_test/Religon/sent-religion2b.jsonl\n",
      "Loading ./Other_seat_test/Religon/sent-religion2b.jsonl...\n",
      "Computing sentence encodings\n",
      "\tDone!\n",
      "Computing cosine similarities...\n",
      "Null hypothesis: no difference between ChristianityTerms and JewishTerms in association to attributes Pleasant and Unpleasant\n",
      "Computing pval...\n",
      "Using non-parametric test\n",
      "Drawing 99999 samples (and biasing by 1)\n",
      "pval: 0.275\n",
      "computing effect size...\n",
      "esize: 0.093\n"
     ]
    }
   ],
   "source": [
    "# Fix the RNG state first, as the SEAT runner's __call__ does with its seed,\n",
    "# so the non-parametric test below gives reproducible p-values across runs\n",
    "# (assumes weat.run_test samples from `random` / `np.random` -- TODO confirm).\n",
    "# The value 0 is an arbitrary fixed seed.\n",
    "random.seed(0)\n",
    "np.random.seed(0)\n",
    "\n",
    "results = []\n",
    "for test in paths:\n",
    "    print(f\"Running test {test}\")\n",
    "\n",
    "    # Load the test data.\n",
    "    encs = _load_json(test)\n",
    "\n",
    "    print(\"Computing sentence encodings\")\n",
    "    # Encode each target/attribute group (in this fixed order), then attach\n",
    "    # the encodings under the \"encs\" key that weat.run_test consumes.\n",
    "    group_encodings = {\n",
    "        group: _encode(model, tokenizer, encs[group][\"examples\"])\n",
    "        for group in (\"targ1\", \"targ2\", \"attr1\", \"attr2\")\n",
    "    }\n",
    "    for group, encodings in group_encodings.items():\n",
    "        encs[group][\"encs\"] = encodings\n",
    "\n",
    "    print(\"\\tDone!\")\n",
    "\n",
    "    # Run the non-parametric test on the encodings.\n",
    "    esize, pval = weat.run_test(\n",
    "        encs, n_samples=100000, parametric=False\n",
    "    )\n",
    "\n",
    "    results.append(\n",
    "        {\n",
    "            \"test\": test,\n",
    "            \"p_value\": pval,\n",
    "            \"effect_size\": esize,\n",
    "        }\n",
    "    )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "8d0497aa-c7ff-40bc-8146-216481913d5b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'test': './Other_seat_test/Religon/sent-religion1.jsonl',\n",
       "  'p_value': 0.00514,\n",
       "  'effect_size': 0.39227893614806275},\n",
       " {'test': './Other_seat_test/Religon/sent-religion1b.jsonl',\n",
       "  'p_value': 0.66489,\n",
       "  'effect_size': -0.0666094687991813},\n",
       " {'test': './Other_seat_test/Religon/sent-religion2.jsonl',\n",
       "  'p_value': 0.00057,\n",
       "  'effect_size': 0.492657303196326},\n",
       " {'test': './Other_seat_test/Religon/sent-religion2b.jsonl',\n",
       "  'p_value': 0.27538,\n",
       "  'effect_size': 0.09265610167169516}]"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One dict per test file: its path, p-value and effect size.\n",
    "results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "190ef36a-7ed3-42a5-b3bc-58e67a609d63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.2610504524538163\n"
     ]
    }
   ],
   "source": [
    "# Mean of |effect size| over all tests. Absolute values are averaged, so\n",
    "# positive and negative effect sizes do not cancel each other out.\n",
    "effect_size = [entry['effect_size'] for entry in results]\n",
    "ss = ft.reduce(lambda total, es: total + abs(es), effect_size, 0)\n",
    "print(ss/len(effect_size))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "95222d22-8678-4b6e-9343-d4b0af8f2f3e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.39227893614806275,\n",
       " -0.0666094687991813,\n",
       " 0.492657303196326,\n",
       " 0.09265610167169516]"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw (signed) effect sizes, in the same order as `results`.\n",
    "effect_size"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "371027ed-c2d2-464c-9e15-6465b12db750",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ccde25c-4f17-43b7-b4f7-4237176456c5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16b43576-cf1f-44f2-a03c-89663639e948",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
