{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c7f7a3a-b77a-4aba-9ad2-144ec8f4aa28",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "from pathlib import Path\n",
    "from pprint import pprint\n",
    "import numpy as np\n",
    "\n",
    "from scipy.stats import pearsonr,spearmanr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baf5f974",
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT_ROOT = Path('..').resolve()\n",
    "\n",
    "def sts_result_parser(model_name, task_name, whitening_transformer_name, pooling_mode,\n",
    "                      split='test', similarity='cos_sim', metric='pearson'):\n",
    "    \"\"\"Return one correlation score from a stored STS evaluation result.\n",
    "\n",
    "    Reads PROJECT_ROOT/<model_name>/<task_name>/<whitening_transformer_name>/<pooling_mode>/<task_name>.json\n",
    "    and returns result[split][similarity][metric]. The defaults reproduce the original lookup,\n",
    "    result['test']['cos_sim']['pearson'], returned exactly as stored in the file.\n",
    "    \"\"\"\n",
    "    result_path = (PROJECT_ROOT / model_name / task_name / whitening_transformer_name\n",
    "                   / pooling_mode / f'{task_name}.json')\n",
    "    # JSON is UTF-8 by spec; don't depend on the platform's default encoding.\n",
    "    with open(result_path, encoding='utf-8') as f:\n",
    "        result = json.load(f)\n",
    "    return result[split][similarity][metric]\n",
    "\n",
    "def word_similarity_result_parser(model_name, task_name, whitening_transformer_name, pooling_mode):\n",
    "    \"\"\"Return the stored 'spearman_correlation' value of a word-similarity evaluation.\n",
    "\n",
    "    Reads PROJECT_ROOT/<model_name>/<task_name>/<whitening_transformer_name>/<pooling_mode>/<task_name>.json.\n",
    "    \"\"\"\n",
    "    result_path = (PROJECT_ROOT / model_name / task_name / whitening_transformer_name\n",
    "                   / pooling_mode / f'{task_name}.json')\n",
    "    with open(result_path, encoding='utf-8') as f:  # JSON is UTF-8; avoid platform default encodings\n",
    "        result = json.load(f)\n",
    "    return result['spearman_correlation']\n",
    "\n",
    "def isotropy_result_parser(model_name, whitening_transformer_name, pooling_mode):\n",
    "    \"\"\"Load the isotropy-score JSON object for one model / transformer / pooling configuration.\n",
    "\n",
    "    Reads PROJECT_ROOT/<model_name>/isotropy_scores/<whitening_transformer_name>/<pooling_mode>/isotropy_scores.json\n",
    "    and returns the parsed object unchanged.\n",
    "    \"\"\"\n",
    "    result_path = (PROJECT_ROOT / model_name / 'isotropy_scores' / whitening_transformer_name\n",
    "                   / pooling_mode / 'isotropy_scores.json')\n",
    "    with open(result_path, encoding='utf-8') as f:  # JSON is UTF-8; avoid platform default encodings\n",
    "        return json.load(f)\n",
    "\n",
    "\n",
    "# Table cells: (score key, whitening-transformer directory, pooling-mode directory).\n",
    "_TABLE_SCORE_SPECS = (\n",
    "    ('mean', 'normal', 'mean'),\n",
    "    ('uniform_centered', 'uniform_whitening', 'centering_only'),\n",
    "    ('uniform_whitened', 'uniform_whitening', 'whitening'),\n",
    "    ('zipfian_centered', 'zipfian_whitening', 'centering_only'),\n",
    "    ('zipfian_whitened', 'zipfian_whitening', 'whitening'),\n",
    "    ('all_but_the_top', 'abtp', 'component_removal'),\n",
    "    ('sif', 'sif', 'sif_w_component_removal'),\n",
    ")\n",
    "\n",
    "# Caption text and BibTeX key for tasks with a known citation; other tasks use their own name.\n",
    "_TABLE_TASK_CAPTIONS = {\n",
    "    'STSBenchmark': ('STS-B (sentence semantic task)', 'cer2017semeval:stsb'),\n",
    "}\n",
    "\n",
    "\n",
    "def _collect_table_scores(model_name, task_name):\n",
    "    \"\"\"Return {score key: Pearson's r x 100, rounded to 2 decimals} for one model on one task.\"\"\"\n",
    "    return {\n",
    "        key: round(sts_result_parser(model_name, task_name, transformer_name, pooling_mode) * 100, 2)\n",
    "        for key, transformer_name, pooling_mode in _TABLE_SCORE_SPECS\n",
    "    }\n",
    "\n",
    "\n",
    "def _render_subtable(model_label, scores):\n",
    "    \"\"\"Return the LaTeX tabular (mean, centering/whitening, ABTT, SIF rows) for one model.\"\"\"\n",
    "    return rf\"\"\"\\begin{{tabular}}{{lcc}}\n",
    "\\toprule\n",
    "{model_label} & \\multicolumn{{2}}{{c}}{{{scores['mean']}}} \\\\\n",
    "\\midrule\n",
    "{{\n",
    "}}\n",
    "& \\cellcolor{{black!5}} Uniform\n",
    "& \\cellcolor{{blue!5}} Zipfian\n",
    "\\\\\n",
    "+ Centering\n",
    "& \\cellcolor{{black!5}} {scores['uniform_centered']}\n",
    "& \\cellcolor{{blue!5}} \\textbf{{{scores['zipfian_centered']}}}\n",
    "\\\\\n",
    "+ Whitening\n",
    "& \\cellcolor{{black!5}} {scores['uniform_whitened']}\n",
    "& \\cellcolor{{blue!5}} \\textbf{{{scores['zipfian_whitened']}}}\n",
    "\\\\\n",
    "\\midrule\n",
    "+ ABTT \\citep{{mu2018allbutthetop}} & \\multicolumn{{2}}{{c}}{{{scores['all_but_the_top']}}} \\\\\n",
    "+ SIF + CCR \\citep{{arora2017simplebut:sif}} & \\multicolumn{{2}}{{c}}{{{scores['sif']}}} \\\\\n",
    "\\bottomrule\n",
    "\\end{{tabular}}\"\"\"\n",
    "\n",
    "\n",
    "def generate_table(task_name, task_label=None, task_citation=None):\n",
    "    \"\"\"Print GloVe / Word2Vec scores for one task and a standalone LaTeX document with the results table.\n",
    "\n",
    "    Args:\n",
    "        task_name: result directory and JSON file stem (e.g. 'STSBenchmark', 'SICK-R').\n",
    "        task_label: task description shown in the caption; defaults to the entry in\n",
    "            _TABLE_TASK_CAPTIONS, or to task_name for tasks not listed there.\n",
    "        task_citation: BibTeX key cited after the label; defaults to the entry in\n",
    "            _TABLE_TASK_CAPTIONS, or to no citation for tasks not listed there.\n",
    "\n",
    "    Scores come from sts_result_parser and are scaled by 100 and rounded to 2 decimals.\n",
    "    Everything is printed; nothing is returned.\n",
    "    \"\"\"\n",
    "    default_label, default_citation = _TABLE_TASK_CAPTIONS.get(task_name, (task_name, None))\n",
    "    task_label = default_label if task_label is None else task_label\n",
    "    task_citation = default_citation if task_citation is None else task_citation\n",
    "    task_ref = rf' \\citep{{{task_citation}}}' if task_citation else ''\n",
    "\n",
    "    glove_model = 'results/average_word_embeddings_glove.840B.300d'\n",
    "    word2vec_model = 'results/GoogleNews-vectors-negative300'\n",
    "    results_glove = _collect_table_scores(glove_model, task_name)\n",
    "    results_word2vec = _collect_table_scores(word2vec_model, task_name)\n",
    "    # Echo the raw numbers, each labelled by its result directory, before the LaTeX source.\n",
    "    for model_name, scores in ((glove_model, results_glove), (word2vec_model, results_word2vec)):\n",
    "        print(model_name)\n",
    "        pprint(scores)\n",
    "    print('#' * 50)\n",
    "\n",
    "    glove_table = _render_subtable('GloVe', results_glove)\n",
    "    word2vec_table = _render_subtable('Word2Vec', results_word2vec)\n",
    "    template = rf\"\"\"\n",
    "\\documentclass{{article}}\n",
    "\\usepackage{{booktabs}}\n",
    "\\usepackage{{colortbl}}\n",
    "\\usepackage{{subfig}}\n",
    "\\usepackage{{caption}}\n",
    "\n",
    "\\begin{{document}}\n",
    "\n",
    "\\begin{{table}}[tb]\n",
    "\\begin{{center}}\n",
    "\\captionsetup{{justification=centering}}\n",
    "\\footnotesize\n",
    "{{%\n",
    "{glove_table}\n",
    "}}%\n",
    "\\hspace*{{0.5cm}}\n",
    "{{%\n",
    "{word2vec_table}\n",
    "}}\n",
    "\\end{{center}}\n",
    "\\caption{{\n",
    "    \\textbf{{Task}}: {task_label}{task_ref}.\n",
    "    \\textbf{{Estimation of sentence similarity}}: Average the word embeddings to estimate sentence embeddings, then calculate the cosine similarity between the input two sentences.\n",
    "    \\textbf{{Evaluation:}} Pearson's $r\\times 100$ with estimated cosine similarities and human-annotated scores.\n",
    "}}\n",
    "\\end{{table}}\n",
    "\n",
    "\\end{{document}}\n",
    "\"\"\"\n",
    "    print(template)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b049962",
   "metadata": {},
   "source": [
    "# STS-B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f3b1afde",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the STS Benchmark table (printed scores + LaTeX source).\n",
    "task_name = \"STSBenchmark\"\n",
    "generate_table(task_name=task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f90ca4c",
   "metadata": {},
   "source": [
    "# SICK-R"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6af76027",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the SICK-R table (printed scores + LaTeX source).\n",
    "task_name = \"SICK-R\"\n",
    "generate_table(task_name=task_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "316a32ff",
   "metadata": {},
   "source": [
    "# Correlation between isotropy scores and task performance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4604493a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Result directories of the two static word-embedding models (hard-coded for now).\n",
    "model_names = ['results/average_word_embeddings_glove.840B.300d', 'results/GoogleNews-vectors-negative300']\n",
    "# Whitening-transformer directory -> pooling-mode directories to evaluate.\n",
    "TRANSFORM_CONFIG = {\n",
    "    \"normal\": {\n",
    "        \"whitening_transformer_class\": None,\n",
    "        \"pooling\": [\"mean\"],\n",
    "    },\n",
    "    \"uniform_whitening\": {\n",
    "        \"pooling\": [\"centering_only\", \"whitening\"],\n",
    "    },\n",
    "    \"zipfian_whitening\": {\n",
    "        \"pooling\": [\"centering_only\", \"whitening\"],\n",
    "    },\n",
    "    \"abtp\": {\n",
    "        \"pooling\": [\"component_removal\"],\n",
    "    },\n",
    "    \"sif\": {\n",
    "        \"pooling\": [\"sif_w_component_removal\"],\n",
    "    },\n",
    "}\n",
    "# Table row label -> task directory / JSON file stem read by sts_result_parser.\n",
    "DOWNSTREAM_TASKS = {\n",
    "    'STS-B': 'STSBenchmark',\n",
    "    'SICK-R': 'SICK-R',\n",
    "}\n",
    "# Keys read from each isotropy_scores.json.\n",
    "ISOTROPY_SCORE_NAMES = ['sym1', 'sym2', 'sym1_uniform', 'sym2_uniform', 'cosine', 'iso_score']\n",
    "\n",
    "# Every list gets one entry per (model, transformer, pooling) configuration, in the same\n",
    "# order, so downstream and isotropy lists can be correlated element-wise.\n",
    "results = {\n",
    "    'downstream_scores': {task_label: [] for task_label in DOWNSTREAM_TASKS},\n",
    "    'isotropy_scores': {score_name: [] for score_name in ISOTROPY_SCORE_NAMES},\n",
    "}\n",
    "for model_name in model_names:\n",
    "    for whitening_transformer_name, transform_config in TRANSFORM_CONFIG.items():\n",
    "        for pooling_mode in transform_config['pooling']:\n",
    "            for task_label, task_dir in DOWNSTREAM_TASKS.items():\n",
    "                results['downstream_scores'][task_label].append(\n",
    "                    sts_result_parser(model_name, task_dir, whitening_transformer_name, pooling_mode))\n",
    "            isotropy_scores = isotropy_result_parser(model_name, whitening_transformer_name, pooling_mode)\n",
    "            for score_name in ISOTROPY_SCORE_NAMES:\n",
    "                results['isotropy_scores'][score_name].append(isotropy_scores[score_name])\n",
    "\n",
    "# Pearson's r x 100, formatted with 2 decimals, for every (downstream task, isotropy score) pair.\n",
    "corr_results = {}\n",
    "for downstream_task, downstream_values in results['downstream_scores'].items():\n",
    "    for isotropy_score, isotropy_values in results['isotropy_scores'].items():\n",
    "        pearson_corr, _p_value = pearsonr(downstream_values, isotropy_values)\n",
    "        corr_results[f'{downstream_task}_{isotropy_score}'] = f\"{pearson_corr * 100:.2f}\"\n",
    "\n",
    "# Column spec: one 'l' label column + six numeric 'S' columns, matching the 7 cells in every row.\n",
    "template = rf\"\"\"\n",
    "\\begin{{table}}[tb]\n",
    "\\centering\n",
    "\\begin{{tabular}}{{lSSSSSS}}\n",
    "\\toprule\n",
    "& {{\\multirow{{2}}{{*}}{{\\text{{Ave. Cos.\\cite{{ethayarajh2019-how-contextual}}}}}}}} & {{\\multirow{{2}}{{*}}{{\\text{{IsoScore\\cite{{Rudman2022-sb}}}}}}}} & \\multicolumn{{2}}{{c}}{{\\cellcolor{{black!5}}Uniform}} & \\multicolumn{{2}}{{c}}{{\\cellcolor{{blue!5}}Zipfian}} \\\\\n",
    "& & & {{\\cellcolor{{black!5}}}}\\text{{Centering}} & {{\\cellcolor{{black!5}}}}\\text{{Whitening}} & {{\\cellcolor{{blue!5}}}}\\text{{Centering}} & {{\\cellcolor{{blue!5}}}}\\text{{Whitening}} \\\\\n",
    "\\midrule\n",
    "\\text{{STS-B}} & {corr_results['STS-B_cosine']} & {corr_results['STS-B_iso_score']} & \\cellcolor{{black!5}}{corr_results['STS-B_sym1_uniform']} & \\cellcolor{{black!5}}{corr_results['STS-B_sym2_uniform']} & \\cellcolor{{blue!5}}{corr_results['STS-B_sym1']} & \\cellcolor{{blue!5}}{corr_results['STS-B_sym2']} \\\\\n",
    "\\text{{SICK-R}} & {corr_results['SICK-R_cosine']} & {corr_results['SICK-R_iso_score']} & \\cellcolor{{black!5}}{corr_results['SICK-R_sym1_uniform']} & \\cellcolor{{black!5}}{corr_results['SICK-R_sym2_uniform']} & \\cellcolor{{blue!5}}{corr_results['SICK-R_sym1']} & \\cellcolor{{blue!5}}{corr_results['SICK-R_sym2']} \\\\\n",
    "\\bottomrule\n",
    "\\end{{tabular}}\n",
    "\\caption{{Pearson's $r \\times 100$ between isotropy scores and downstream task scores on static word embeddings (GloVe + Word2Vec).}}\n",
    "\\label{{tb:experiments_correlation_symmetry_performance}}\n",
    "\\end{{table}}\n",
    "\"\"\"\n",
    "\n",
    "print(template)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50bb4068",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scatter plots of the symmetry scores, coloured by STS-B performance.\n",
    "# Requires `results` from the correlation cell above.\n",
    "import matplotlib.pyplot as plt\n",
    "import matplotlib.gridspec as gridspec\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "\n",
    "# Typeset all text with LaTeX (needs a TeX installation on the machine).\n",
    "plt.rcParams['font.family'] = 'Times New Roman'\n",
    "plt.rcParams['mathtext.fontset'] = 'stix'\n",
    "plt.rcParams['text.usetex'] = True\n",
    "plt.rc('text.latex', preamble=r'\\usepackage{amsmath} \\usepackage{bm}')\n",
    "%config InlineBackend.figure_formats = ['svg']\n",
    "\n",
    "# Custom colormap: light gray -> yellow -> red (low -> high task score).\n",
    "colors = [(0.8, 0.8, 0.8), (1, 1, 0), (1, 0, 0)]\n",
    "n_bins = 100  # Discretizes the interpolation into bins\n",
    "cmap_name = 'white_to_red'\n",
    "cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=n_bins)\n",
    "\n",
    "sym1 = results['isotropy_scores']['sym1']\n",
    "sym2 = results['isotropy_scores']['sym2']\n",
    "sym1_uniform = results['isotropy_scores']['sym1_uniform']\n",
    "sym2_uniform = results['isotropy_scores']['sym2_uniform']\n",
    "# Scale first, then round (Pearson's r x 100, 2 decimals) as in the LaTeX tables;\n",
    "# rounding before scaling would collapse the colours to whole percentage points.\n",
    "task_scores = [round(score * 100, 2) for score in results['downstream_scores']['STS-B']]\n",
    "\n",
    "# Two scatter panels plus a narrow colorbar axis.\n",
    "fig = plt.figure(figsize=(16, 4))\n",
    "gs = gridspec.GridSpec(1, 3, width_ratios=[1, 1, 0.05], wspace=0.75)\n",
    "ax1 = fig.add_subplot(gs[0])\n",
    "ax2 = fig.add_subplot(gs[1])\n",
    "cax = fig.add_subplot(gs[2])\n",
    "\n",
    "label_fontsize = 22\n",
    "tick_fontsize = 20\n",
    "cbar_fontsize = 22\n",
    "buffer = 0.02\n",
    "# Shared axis limits so both panels are directly comparable.\n",
    "xlims = [min(min(sym1), min(sym1_uniform)) - buffer, 1.0 + buffer]\n",
    "ylims = [min(min(sym2), min(sym2_uniform)) - buffer, 1.0 + buffer - 0.01]\n",
    "\n",
    "# Left panel: the *_uniform scores (labelled 'Unif.').\n",
    "scatter1 = ax1.scatter(sym1_uniform, sym2_uniform, c=task_scores, cmap=cm, alpha=0.8, s=100)\n",
    "ax1.set_xlabel(\"\"\"Deg. of symmetry\\n(1st moment, Unif.)\"\"\", fontsize=label_fontsize)\n",
    "ax1.set_ylabel(\"\"\"Deg. of symmetry\\n(2nd moment, Unif.)\"\"\", fontsize=label_fontsize)\n",
    "ax1.tick_params(axis='both', which='major', labelsize=tick_fontsize)\n",
    "ax1.set_xlim(xlims)\n",
    "ax1.set_ylim(ylims)\n",
    "\n",
    "# Right panel: the sym1/sym2 scores (labelled 'Zipf.').\n",
    "scatter2 = ax2.scatter(sym1, sym2, c=task_scores, cmap=cm, alpha=0.8, s=100)\n",
    "ax2.set_xlabel(\"\"\"Deg. of symmetry\\n(1st moment, Zipf.)\"\"\", fontsize=label_fontsize)\n",
    "ax2.set_ylabel(\"\"\"Deg. of symmetry\\n(2nd moment, Zipf.)\"\"\", fontsize=label_fontsize)\n",
    "ax2.tick_params(axis='both', which='major', labelsize=tick_fontsize)\n",
    "ax2.set_xlim(xlims)\n",
    "ax2.set_ylim(ylims)\n",
    "\n",
    "# Both panels are coloured by the same task_scores, so one colorbar serves both.\n",
    "cbar = plt.colorbar(scatter2, cax=cax)\n",
    "cbar.set_label(r\"\"\"STS-B performance\n",
    "(Pearson's $r \\times 100$)\"\"\", fontsize=cbar_fontsize)\n",
    "cbar.ax.tick_params(labelsize=tick_fontsize)\n",
    "\n",
    "# savefig does not create missing directories, so make sure figs/ exists first.\n",
    "fig_path = Path('figs') / 'correlation_3dplot.pdf'\n",
    "fig_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "plt.savefig(fig_path, bbox_inches='tight')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
