{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generating Double-Hard Debias Word2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load original embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import codecs, os, json, operator, pickle, gensim\n",
    "from random import shuffle\n",
    "import numpy as np\n",
    "from numpy import linalg as LA\n",
    "import scipy\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3000000 (3000000, 300) 3000000\n"
     ]
    }
   ],
   "source": [
    "def load_w2v(file_path):\n",
    "    model =gensim.models.KeyedVectors.load_word2vec_format(file_path, binary=True)\n",
    "    vocab = sorted([w for w in model.vocab], key=lambda w: model.vocab[w].index)\n",
    "    w2i = {w: i for i, w in enumerate(vocab)}\n",
    "    wv = [model[w] for w in vocab]\n",
    "    wv = np.array(wv)\n",
    "    print(len(vocab), wv.shape, len(w2i))\n",
    "    \n",
    "    return wv, w2i, vocab\n",
    "\n",
    "file_path_wv = './data/GoogleNews-vectors-negative300.bin'\n",
    "wv, w2i, vocab = load_w2v(file_path_wv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Restrict Vocabulary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 50000/50000 [00:00<00:00, 679549.85it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "size of vocabulary: 26142\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "from tqdm import tqdm\n",
    "from utils import limit_vocab\n",
    "\n",
    "\n",
    "gender_specific = []\n",
    "with open('./data/male_word_file.txt') as f:\n",
    "    for l in f:\n",
    "        gender_specific.append(l.strip())\n",
    "with open('./data/female_word_file.txt') as f:\n",
    "    for l in f:\n",
    "        gender_specific.append(l.strip())\n",
    "\n",
    "with codecs.open('./data/gender_specific_full.json') as f:\n",
    "    gender_specific.extend(json.load(f))\n",
    "\n",
    "definitional_pairs = [['she', 'he'], ['herself', 'himself'], ['her', 'his'], ['daughter', 'son'], \n",
    "                      ['girl', 'boy'], ['mother', 'father'], ['woman', 'man'], ['mary', 'john'], \n",
    "                      ['gal', 'guy'], ['female', 'male']]\n",
    "definitional_words = []\n",
    "for pair in definitional_pairs:\n",
    "    for word in pair:\n",
    "        definitional_words.append(word)\n",
    "\n",
    "exclude_words = gender_specific\n",
    "vocab_limit, wv_limit, w2i_limit = limit_vocab(wv, w2i, vocab, exclude = exclude_words)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computer Original Bias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "he_embed = wv[w2i['he'], :]\n",
    "she_embed = wv[w2i['she'], :]\n",
    "\n",
    "def simi(a, b):\n",
    "    return 1-scipy.spatial.distance.cosine(a, b)\n",
    "\n",
    "def compute_bias_by_projection(wv, w2i, vocab):\n",
    "    d = {}\n",
    "    for w in vocab:\n",
    "        u = wv[w2i[w], :]\n",
    "        d[w] = simi(u, he_embed) - simi(u, she_embed)\n",
    "    return d\n",
    "\n",
    "gender_bias_bef = compute_bias_by_projection(wv_limit, w2i_limit, vocab_limit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Debias"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# get main PCA components\n",
    "def my_pca(wv):\n",
    "    wv_mean = np.mean(np.array(wv), axis=0)\n",
    "    wv_hat = np.zeros(wv.shape).astype(float)\n",
    "\n",
    "    for i in range(len(wv)):\n",
    "        wv_hat[i, :] = wv[i, :] - wv_mean\n",
    "\n",
    "    main_pca = PCA()\n",
    "    main_pca.fit(wv_hat)\n",
    "    \n",
    "    return main_pca\n",
    "\n",
    "main_pca = my_pca(wv)\n",
    "wv_mean = np.mean(np.array(wv), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def hard_debias(wv, w2i, w2i_partial, vocab_partial, component_ids):\n",
    "    \n",
    "    D = []\n",
    "\n",
    "    for i in component_ids:\n",
    "        D.append(main_pca.components_[i])\n",
    "    \n",
    "    # get rid of frequency features\n",
    "    wv_f = np.zeros((len(vocab_partial), wv.shape[1])).astype(float)\n",
    "    \n",
    "    for i, w in enumerate(vocab_partial):\n",
    "        u = wv[w2i[w], :]\n",
    "        sub = np.zeros(u.shape).astype(float)\n",
    "        for d in D:\n",
    "            sub += np.dot(np.dot(np.transpose(d), u), d)\n",
    "        wv_f[w2i_partial[w], :] = wv[w2i[w], :] - sub - wv_mean\n",
    "        \n",
    "    # debias\n",
    "    gender_directions = list()\n",
    "    for gender_word_list in [definitional_pairs]:\n",
    "        gender_directions.append(doPCA(gender_word_list, wv_f, w2i_partial).components_[0])\n",
    "    \n",
    "    wv_debiased = np.zeros((len(vocab_partial), len(wv_f[0, :]))).astype(float)\n",
    "    for i, w in enumerate(vocab_partial):\n",
    "        u = wv_f[w2i_partial[w], :]\n",
    "        for gender_direction in gender_directions:\n",
    "            u = drop(u, gender_direction)\n",
    "            wv_debiased[w2i_partial[w], :] = u\n",
    "    \n",
    "    return wv_debiased"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.cluster import KMeans\n",
    "def cluster_and_visualize(words, X, random_state, y_true, num=2):\n",
    "    \n",
    "    kmeans = KMeans(n_clusters=num, random_state=random_state).fit(X)\n",
    "    y_pred = kmeans.predict(X)\n",
    "    correct = [1 if item1 == item2 else 0 for (item1,item2) in zip(y_true, y_pred) ]\n",
    "    preci = max(sum(correct)/float(len(correct)), 1 - sum(correct)/float(len(correct)))\n",
    "    print('precision', preci)\n",
    "    \n",
    "    return kmeans, y_pred, X, preci"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import extract_vectors\n",
    "from utils import train_and_predict\n",
    "from utils import doPCA, drop\n",
    "\n",
    "size=1000\n",
    "sorted_g = sorted(gender_bias_bef.items(), key=operator.itemgetter(1))\n",
    "female = [item[0] for item in sorted_g[:size]]\n",
    "male = [item[0] for item in sorted_g[-size:]]\n",
    "y_true = [1]*size + [0]*size\n",
    "\n",
    "c_vocab = list(set(male + female + [word for word in definitional_words if word in w2i]))\n",
    "c_w2i = dict()\n",
    "for idx, w in enumerate(c_vocab):\n",
    "    c_w2i[w] = idx\n",
    "    \n",
    "precisions = []\n",
    "    \n",
    "for component_id in range(20):\n",
    "    \n",
    "    print('component id: ', component_id)\n",
    "    \n",
    "    wv_debiased = hard_debias(wv, w2i, w2i_partial = c_w2i, vocab_partial = c_vocab, component_ids = [component_id])\n",
    "        \n",
    "    _, _, _, preci = cluster_and_visualize(male + female, \n",
    "                                           extract_vectors(male + female, wv_debiased, c_w2i), 1, y_true)\n",
    "    precisions.append(preci)\n",
    "\n",
    "# word2vec - 8th principal component significantly affects the debiasing performance\n",
    "# save the output to a file\n",
    "component_id=7 \n",
    "wv_debiased = hard_debias(wv, w2i, w2i_partial = w2i, vocab_partial = vocab, component_ids = [component_id])\n",
    "\n",
    "filename=\"data/dhd_word2vec_reproduce.p\"\n",
    "with open(filename, 'ab') as fp:\n",
    "    pickle.dump(wv_debiased,fp, protocol=pickle.HIGHEST_PROTOCOL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Create some mock data\n",
    "t = np.arange(1, 21)\n",
    "data1 = precisions\n",
    "\n",
    "fig, ax1 = plt.subplots(figsize=(6,2.8))\n",
    "\n",
    "color = 'red'\n",
    "ax1.set_xlabel('Project out the D-th directions', fontsize=17)\n",
    "ax1.set_ylabel('accuracy', fontsize=17)\n",
    "ax1.scatter(t, data1, color=color, label='GloVe', marker = 'x', s=60)\n",
    "plt.xticks([2,4,6,8,10, 12, 14, 16 ,18, 20], fontsize=15)\n",
    "ax1.tick_params(axis='y', labelsize=14)\n",
    "ax1.set_ylim(0.65, 0.84)\n",
    "ax1.legend(loc='lower right', frameon=True, fontsize='large')\n",
    "ax1.grid(axis='y')\n",
    "\n",
    "fig.tight_layout()  # otherwise the right y-label is slightly clipped\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
