{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [Double-Hard Debias: Tailoring Word Embeddings for Gender Bias Mitigation](https://arxiv.org/abs/2005.00965)\n",
    "\n",
    "For more detailed explanations, please refer to the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gensim\n",
    "import codecs\n",
    "import numpy as np\n",
    "from numpy import linalg as LA\n",
    "import scipy\n",
    "import codecs, os, json\n",
    "import operator\n",
    "import pickle\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "def normalize(wv):\n",
    "    \"\"\"Scale every row of ``wv`` to unit L2 norm.\n",
    "\n",
    "    Args:\n",
    "        wv: 2-D array-like with one vector per row.\n",
    "\n",
    "    Returns:\n",
    "        A new float array of the same shape whose non-zero rows have L2\n",
    "        norm 1. All-zero rows are returned as zeros instead of being turned\n",
    "        into NaN by a 0/0 division.\n",
    "    \"\"\"\n",
    "    wv = np.asarray(wv)\n",
    "    # Vectorised row norms: avoids one Python-level call per row (as\n",
    "    # np.apply_along_axis does) and also accepts an empty (0, dim) matrix.\n",
    "    norms = LA.norm(wv, axis=1, keepdims=True)\n",
    "    # Divide zero-norm rows by 1 so they stay zero rather than becoming NaN.\n",
    "    norms[norms == 0] = 1\n",
    "    return wv / norms\n",
    "\n",
    "def load_w2v(file_path):\n",
    "    \"\"\"Load binary word2vec embeddings and return them in index order.\n",
    "\n",
    "    Args:\n",
    "        file_path: path to an embedding file in the binary word2vec format.\n",
    "\n",
    "    Returns:\n",
    "        A tuple ``(wv, w2i, vocab)``: ``wv`` is the embedding matrix with one\n",
    "        row per word, ``w2i`` maps each word to its row index, and ``vocab``\n",
    "        lists the words in row order.\n",
    "    \"\"\"\n",
    "    model = gensim.models.KeyedVectors.load_word2vec_format(file_path, binary=True)\n",
    "    if hasattr(model, 'index_to_key'):\n",
    "        # gensim >= 4.0 removed KeyedVectors.vocab (accessing it raises\n",
    "        # AttributeError); the words, already in index order, are exposed\n",
    "        # as index_to_key.\n",
    "        vocab = list(model.index_to_key)\n",
    "    else:\n",
    "        # gensim 3.x: recover the index order from the per-word vocab entries.\n",
    "        vocab = sorted(model.vocab, key=lambda w: model.vocab[w].index)\n",
    "    w2i = {w: i for i, w in enumerate(vocab)}\n",
    "    # Stack the vectors so that row i of wv belongs to vocab[i].\n",
    "    wv = np.array([model[w] for w in vocab])\n",
    "    print(len(vocab), wv.shape, len(w2i))\n",
    "\n",
    "    return wv, w2i, vocab"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3000000 (3000000, 300) 3000000\n"
     ]
    }
   ],
   "source": [
    "# Hard-debiased GoogleNews word2vec vectors, stored in binary word2vec format.\n",
    "file_path_hd = './data/GoogleNews-vectors-negative300-hard-debiased.bin'\n",
    "\n",
    "# Embedding matrix, word -> row-index map, and index-ordered word list.\n",
    "hd_wv, hd_w2i, hd_vocab = load_w2v(file_path=file_path_hd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def load_glove_p(path):\n",
    "    \n",
    "#     debiased_embeds = pickle.load(open(path, 'rb'))\n",
    "#     wv = []\n",
    "#     vocab = []\n",
    "#     for w in debiased_embeds:\n",
    "#         print(w)\n",
    "#         print(debiased_embeds[w])\n",
    "#         wv.append(np.array(debiased_embeds[w]))\n",
    "#         vocab.append(str(w))\n",
    "        \n",
    "#     w2i = {w: i for i, w in enumerate(vocab)}\n",
    "#     wv = np.array(wv).astype(float)\n",
    "#     print(len(vocab), wv.shape, len(w2i))\n",
    "        \n",
    "#     return wv, w2i, vocab\n",
    "    \n",
    "# dhd_wv, dhd_w2i, dhd_vocab = load_glove_p(\"./data/dhd_word2vec_reproduce.p\")\n",
    "\n",
    "\n",
    "# def load_glove_p(path):\n",
    "#     # Modified from original because we did not pickle a map but rather wv_f directly\n",
    "#     wv = pickle.load(open(path, 'rb'))\n",
    "#     print(wv.shape)\n",
    "#     return wv\n",
    "# dhd_wv = load_glove_p(\"./data/dhd_word2vec_reproduce.p\")\n",
    "# dhd_w2i = w2i\n",
    "# dhd_vocab = vocab"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Analogy & Concept Categorization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from eval import evaluate_ana"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "7022\n",
      "ACCURACY TOP1-MSR: 73.61% (5169/7022)\n",
      "capital-common-countries.txt:\n",
      "ACCURACY TOP1: 24.47% (93/380)\n",
      "capital-world.txt:\n",
      "ACCURACY TOP1: 15.46% (100/647)\n",
      "currency.txt:\n",
      "ACCURACY TOP1: 12.15% (61/502)\n",
      "city-in-state.txt:\n",
      "ACCURACY TOP1: 13.58% (221/1627)\n",
      "family.txt:\n",
      "ACCURACY TOP1: 74.11% (375/506)\n",
      "gram1-adjective-to-adverb.txt:\n",
      "ACCURACY TOP1: 28.63% (284/992)\n",
      "gram2-opposite.txt:\n",
      "ACCURACY TOP1: 42.24% (343/812)\n",
      "gram3-comparative.txt:\n",
      "ACCURACY TOP1: 90.92% (1211/1332)\n",
      "gram4-superlative.txt:\n",
      "ACCURACY TOP1: 87.08% (977/1122)\n",
      "gram5-present-participle.txt:\n",
      "ACCURACY TOP1: 77.46% (818/1056)\n",
      "gram6-nationality-adjective.txt:\n",
      "ACCURACY TOP1: 21.82% (211/967)\n",
      "gram7-past-tense.txt:\n",
      "ACCURACY TOP1: 65.90% (1028/1560)\n",
      "gram8-plural.txt:\n",
      "ACCURACY TOP1: 90.77% (1209/1332)\n",
      "gram9-plural-verbs.txt:\n",
      "ACCURACY TOP1: 67.93% (591/870)\n",
      "Questions seen/total: 70.12% (13705/19544)\n",
      "Semantic accuracy: 23.21%  (850/3662)\n",
      "Syntactic accuracy: 66.43%  (6672/10043)\n",
      "Total accuracy: 54.89%  (7522/13705)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/yun.hy/Double-Hard-Debias/benchmarks/web/evaluate.py:143: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  prot_left = np.mean(np.vstack(w.get(word, mean_vector) for word in prototypes[:, 0]), axis=0)\n",
      "/home/yun.hy/Double-Hard-Debias/benchmarks/web/evaluate.py:144: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  prot_right = np.mean(np.vstack(w.get(word, mean_vector) for word in prototypes[:, 1]), axis=0)\n",
      "/home/yun.hy/Double-Hard-Debias/benchmarks/web/evaluate.py:147: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  question_left, question_right = np.vstack(w.get(word, mean_vector) for word in questions[:, 0]), \\\n",
      "/home/yun.hy/Double-Hard-Debias/benchmarks/web/evaluate.py:148: FutureWarning: arrays to stack must be passed as a \"sequence\" type such as list or tuple. Support for non-sequence iterables such as generators is deprecated as of NumPy 1.16 and will raise an error in the future.\n",
      "  np.vstack(w.get(word, mean_vector) for word in questions[:, 1])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Analogy prediction accuracy on SemEval2012 0.20753512818420655\n"
     ]
    }
   ],
   "source": [
    "# Run the analogy evaluation on the hard-debiased embeddings; the recorded\n",
    "# cell output reports MSR, per-category and SemEval2012 accuracies.\n",
    "evaluate_ana(hd_wv, hd_w2i, hd_vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
