{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "from pprint import pprint\n",
    "import numpy as np\n",
    "\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import matplotlib.pyplot as plt\n",
    "# Figure typography: Times New Roman for regular text, STIX glyphs for mathtext.\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "# Typeset all figure text with LaTeX; the preamble loads amsmath and bm.\n",
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{amsmath} \\usepackage{bm}')\n",
    "# Emit inline figures as SVG.\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "# Absolute path of the parent of the current working directory.\n",
    "PROJECT_ROOT = Path('..').resolve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_centered_ylim(ax, y_data, margin=10):\n",
    "    \"\"\"Set the y-axis of ``ax`` to ``[0, max(y_data)]`` and return the mean of ``y_data``.\n",
    "\n",
    "    Despite the name, the axis is not centred on the mean: the lower limit is\n",
    "    fixed at 0 and the upper limit is the largest non-NaN value.\n",
    "\n",
    "    Args:\n",
    "        ax: Object exposing ``set_ylim(bottom, top)``, e.g. a matplotlib Axes.\n",
    "        y_data: Non-empty sequence or array of numbers.\n",
    "        margin: Unused; accepted only so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        The arithmetic mean of ``y_data`` (NaN if any value is NaN).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``y_data`` is empty.\n",
    "    \"\"\"\n",
    "    values = np.asarray(y_data, dtype=float)\n",
    "    if values.size == 0:\n",
    "        raise ValueError('y_data must contain at least one value')\n",
    "    # nanmax keeps a single missing value from invalidating the axis limit.\n",
    "    ax.set_ylim(0, np.nanmax(values))\n",
    "    return values.mean()\n",
    "# Font sizes shared by every panel of the figure.\n",
    "label_fontsize = 34\n",
    "tick_fontsize = 30\n",
    "cbar_fontsize = 30\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the pre-computed norm statistics for the GloVe embeddings.\n",
    "model_name = 'average_word_embeddings_glove.840B.300d'\n",
    "topk = 500\n",
    "result_path = PROJECT_ROOT / f'results/{model_name}/norm_experiments/results.top{topk}.json'\n",
    "with open(result_path, 'r') as f:\n",
    "    result = json.load(f)\n",
    "\n",
    "prob = np.array(result['prob'])\n",
    "assert np.allclose(prob.sum(), 1), f'word probabilities sum to {prob.sum()}, expected 1'\n",
    "# NOTE: despite its name, `logp` holds -log p(w), matching the x-axis label.\n",
    "logp = -np.log(prob)\n",
    "\n",
    "norm = np.array(result['norm'])\n",
    "uniform_centered_norm = np.array(result['uniform_centered_norm'])\n",
    "uniform_whitened_norm = np.array(result['uniform_whitened_norm'])\n",
    "zipfian_centered_norm = np.array(result['zipfian_centered_norm'])\n",
    "zipfian_whitened_norm = np.array(result['zipfian_whitened_norm'])\n",
    "\n",
    "# Panel title -> norms to plot; dict order is the left-to-right panel order.\n",
    "data = {\n",
    "    'Uniform Whitened': uniform_whitened_norm,\n",
    "    'Uniform Centered': uniform_centered_norm,\n",
    "    'Pre-trained GloVe': norm,\n",
    "    'Zipfian Centered': zipfian_centered_norm,\n",
    "    'Zipfian Whitened': zipfian_whitened_norm,\n",
    "}\n",
    "\n",
    "# One panel per embedding variant, all sharing the -log p(w) axis.\n",
    "fig, axes = plt.subplots(1, 5, figsize=(25, 5), sharex=True)\n",
    "# Marker opacity. 0.8 is the value the scatter call previously hard-coded;\n",
    "# the old `alpha = 0.5` assignment was never read.\n",
    "alpha = 0.8\n",
    "for i, (name, y_data) in enumerate(data.items()):\n",
    "    x_data = logp\n",
    "    axes[i].scatter(x_data, y_data, alpha=alpha, edgecolor='blue', facecolor='none')\n",
    "    axes[i].set_title(name, fontsize=label_fontsize)\n",
    "    axes[i].tick_params(axis='both', which='major', labelsize=tick_fontsize)\n",
    "    set_centered_ylim(axes[i], y_data)\n",
    "plt.tight_layout()\n",
    "# Text description of the figure (the figure itself carries no suptitle).\n",
    "print(f'Norm vs -log(p) for top {topk} words in {model_name}')\n",
    "# Shared axis labels, placed just outside the panel grid.\n",
    "fig.text(0.5, -0.05, r'$-\\log{p(w)}$', ha='center', va='center',fontsize=label_fontsize+5)\n",
    "fig.text(-0.005, 0.5, r\"${\\lVert \\bm{w} \\rVert}_2$\", ha='center', va='center', rotation='vertical',fontsize=label_fontsize+5)\n",
    "# savefig does not create missing directories, so make sure figs/ exists first.\n",
    "fig_dir = Path('figs')\n",
    "fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(fig_dir / 'norm_glove.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def set_centered_ylim(ax, y_data, margin=10):\n",
    "    \"\"\"Set the y-axis of ``ax`` to ``[0, max(y_data)]`` and return the mean of ``y_data``.\n",
    "\n",
    "    Despite the name, the axis is not centred on the mean: the lower limit is\n",
    "    fixed at 0 and the upper limit is the largest non-NaN value.\n",
    "\n",
    "    Args:\n",
    "        ax: Object exposing ``set_ylim(bottom, top)``, e.g. a matplotlib Axes.\n",
    "        y_data: Non-empty sequence or array of numbers.\n",
    "        margin: Unused; accepted only so existing calls keep working.\n",
    "\n",
    "    Returns:\n",
    "        The arithmetic mean of ``y_data`` (NaN if any value is NaN).\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``y_data`` is empty.\n",
    "    \"\"\"\n",
    "    values = np.asarray(y_data, dtype=float)\n",
    "    if values.size == 0:\n",
    "        raise ValueError('y_data must contain at least one value')\n",
    "    # nanmax keeps a single missing value from invalidating the axis limit.\n",
    "    ax.set_ylim(0, np.nanmax(values))\n",
    "    return values.mean()\n",
    "# Font sizes shared by every panel of the figure.\n",
    "label_fontsize = 34\n",
    "tick_fontsize = 30\n",
    "cbar_fontsize = 30\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Load the pre-computed norm statistics for the Word2Vec embeddings.\n",
    "model_name = 'GoogleNews-vectors-negative300'\n",
    "topk = 500\n",
    "result_path = PROJECT_ROOT / f'results/{model_name}/norm_experiments/results.top{topk}.json'\n",
    "with open(result_path, 'r') as f:\n",
    "    result = json.load(f)\n",
    "\n",
    "prob = np.array(result['prob'])\n",
    "assert np.allclose(prob.sum(), 1), f'word probabilities sum to {prob.sum()}, expected 1'\n",
    "# NOTE: despite its name, `logp` holds -log p(w), matching the x-axis label.\n",
    "logp = -np.log(prob)\n",
    "\n",
    "norm = np.array(result['norm'])\n",
    "uniform_centered_norm = np.array(result['uniform_centered_norm'])\n",
    "uniform_whitened_norm = np.array(result['uniform_whitened_norm'])\n",
    "zipfian_centered_norm = np.array(result['zipfian_centered_norm'])\n",
    "zipfian_whitened_norm = np.array(result['zipfian_whitened_norm'])\n",
    "\n",
    "# Panel title -> norms to plot; dict order is the left-to-right panel order.\n",
    "data = {\n",
    "    'Uniform Whitened': uniform_whitened_norm,\n",
    "    'Uniform Centered': uniform_centered_norm,\n",
    "    'Pre-trained Word2Vec': norm,\n",
    "    'Zipfian Centered': zipfian_centered_norm,\n",
    "    'Zipfian Whitened': zipfian_whitened_norm,\n",
    "}\n",
    "\n",
    "# One panel per embedding variant, all sharing the -log p(w) axis.\n",
    "fig, axes = plt.subplots(1, 5, figsize=(25, 5), sharex=True)\n",
    "# Marker opacity. 0.8 is the value the scatter call previously hard-coded;\n",
    "# the old `alpha = 0.5` assignment was never read.\n",
    "alpha = 0.8\n",
    "for i, (name, y_data) in enumerate(data.items()):\n",
    "    x_data = logp\n",
    "    axes[i].scatter(x_data, y_data, alpha=alpha, edgecolor='blue', facecolor='none')\n",
    "    axes[i].set_title(name, fontsize=label_fontsize)\n",
    "    axes[i].tick_params(axis='both', which='major', labelsize=tick_fontsize)\n",
    "    set_centered_ylim(axes[i], y_data)\n",
    "plt.tight_layout()\n",
    "# Text description of the figure (the figure itself carries no suptitle).\n",
    "print(f'Norm vs -log(p) for top {topk} words in {model_name}')\n",
    "# Shared axis labels, placed just outside the panel grid.\n",
    "fig.text(0.5, -0.05, r'$-\\log{p(w)}$', ha='center', va='center',fontsize=label_fontsize+5)\n",
    "# The subscript must stay inside the math delimiters: with text.usetex enabled,\n",
    "# a bare `_2` after the closing `$` is a LaTeX error.\n",
    "fig.text(-0.005, 0.5, r\"${\\lVert \\bm{w} \\rVert}_2$\", ha='center', va='center', rotation='vertical',fontsize=label_fontsize+5)\n",
    "# savefig does not create missing directories, so make sure figs/ exists first.\n",
    "fig_dir = Path('figs')\n",
    "fig_dir.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(fig_dir / 'norm_word2vec.pdf', bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
