{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Double-Hard Debias: Tailoring Word Embeddings for Gender Bias Mitigation](https://arxiv.org/abs/2005.00965)\n",
    "\n",
    "For more detailed explanations, please refer to the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load original embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import codecs, os, json, operator, pickle, gensim\n",
    "from random import shuffle\n",
    "import numpy as np\n",
    "from numpy import linalg as LA\n",
    "import scipy\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_w2v(file_path):\n",
    "    model =gensim.models.KeyedVectors.load_word2vec_format(file_path, binary=True)\n",
    "    vocab = sorted([w for w in model.vocab], key=lambda w: model.vocab[w].index)\n",
    "    w2i = {w: i for i, w in enumerate(vocab)}\n",
    "    wv = [model[w] for w in vocab]\n",
    "    wv = np.array(wv)\n",
    "    print(len(vocab), wv.shape, len(w2i))\n",
    "    \n",
    "    return wv, w2i, vocab\n",
    "\n",
    "file_path_wv = './data/GoogleNews-vectors-negative300.bin'\n",
    "wv, w2i, vocab = load_w2v(file_path_wv)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Remove Frequency Direction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.decomposition import PCA\n",
    "\n",
    "# get main PCA components\n",
    "def my_pca(wv):\n",
    "    wv_mean = np.mean(np.array(wv), axis=0)\n",
    "    wv_hat = np.zeros(wv.shape).astype(float)\n",
    "\n",
    "    for i in range(len(wv)):\n",
    "        wv_hat[i, :] = wv[i, :] - wv_mean\n",
    "\n",
    "    main_pca = PCA()\n",
    "    main_pca.fit(wv_hat)\n",
    "    \n",
    "    return main_pca\n",
    "\n",
    "main_pca = my_pca(wv)\n",
    "wv_mean = np.mean(np.array(wv), axis=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def remove_frequency(wv, w2i, w2i_partial, vocab_partial, component_ids):\n",
    "    \n",
    "    D = []\n",
    "\n",
    "    for i in component_ids:\n",
    "        D.append(main_pca.components_[i])\n",
    "    \n",
    "    # get rid of frequency features\n",
    "    wv_f = np.zeros((len(vocab_partial), wv.shape[1])).astype(float)\n",
    "    \n",
    "    for i, w in enumerate(vocab_partial):\n",
    "        u = wv[w2i[w], :]\n",
    "        sub = np.zeros(u.shape).astype(float)\n",
    "        for d in D:\n",
    "            sub += np.dot(np.dot(np.transpose(d), u), d)\n",
    "        wv_f[w2i_partial[w], :] = wv[w2i[w], :] - sub - wv_mean\n",
    "    \n",
    "    print(wv_f.shape)\n",
    "    return wv_f"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Remove 8th component because authors found best performance \n",
    "#when the 8th component was removed.\n",
    "component_id=7 \n",
    "wv_f = remove_frequency(wv, w2i, w2i_partial = w2i, vocab_partial = vocab, component_ids = [component_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with open(\"./data/w2v_frequency_removed.txt\", \"w\") as outputFile:\n",
    "    for i in range(len(vocab)):\n",
    "        word = vocab[i]\n",
    "        embedding = word + \" \" + \" \".join([str(feature) for feature in wv_f[w2i[word],:]])\n",
    "        outputFile.write(embedding + \"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
