{
 "cells": [
  {
   "cell_type": "code",
   "id": "initial_id",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import json\n",
    "from scipy.stats import ttest_1samp, wilcoxon"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Processing data"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "697eef5d8213ea01"
  },
  {
   "cell_type": "code",
   "source": [
    "# Loading data\n",
    "rca = pd.read_csv('../../data/results/rca_ensemb.csv').dropna()\n",
    "meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col=0)\n",
    "\n",
    "# Adding norm_cat to rca, with underscores in the category replaced by spaces\n",
    "rca['norm_cat'] = (\n",
    "    rca['norm'].apply(lambda norm: meta.loc[norm]['category'])\n",
    "    .replace({'_': ' '}, regex=True)\n",
    ")\n",
    "\n",
    "\n",
    "with open('../../data/embed_to_dtype.json', 'r') as f:\n",
    "    embed_to_type = json.load(f)\n",
    "\n",
    "\n",
    "def embed_to_group(embed_name):\n",
    "    \"\"\"Return the embed_to_type label for an embedding or an ensemble.\n",
    "\n",
    "    A name containing '&' is treated as an ensemble: every component is\n",
    "    looked up separately and the labels are re-joined with '&'. Any number\n",
    "    of components is accepted, not only two.\n",
    "\n",
    "    Assumes the labels stored in embed_to_type are strings.\n",
    "    Raises KeyError if a component is missing from embed_to_type.\n",
    "    \"\"\"\n",
    "    return '&'.join(embed_to_type[name] for name in embed_name.split('&'))\n",
    "\n",
    "\n",
    "rca['embed_group'] = rca['embed'].apply(embed_to_group)\n",
    "rca"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "17d1a99a8144953a",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Mean r2 per (embed_group, norm, fold), then averaged over folds within each (embed_group, norm)\n",
    "rca_mean = (\n",
    "    rca.loc[:, ['embed_group', 'norm', 'fold', 'r2']]\n",
    "    .groupby(by=['embed_group', 'norm', 'fold'], as_index=False)\n",
    "    .mean(numeric_only=True)\n",
    "    .groupby(by=['embed_group', 'norm'], as_index=False)\n",
    "    .mean(numeric_only=True)\n",
    "    .drop(columns='fold')\n",
    "    .rename(columns={'r2': 'r2_mean'})\n",
    ")\n",
    "rca_mean"
   ],
   "id": "5ac2fc243f50a54b",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Attaching each norm's category from the metadata table\n",
    "rca_mean['norm_cat'] = rca_mean['norm'].apply(\n",
    "    lambda norm_name: meta.loc[norm_name]['category']\n",
    ")\n",
    "rca_mean"
   ],
   "id": "7834bd1c419d6412",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "**Grand avgs** (median across norms within each embed group and norm category)",
   "id": "49ff815417935f12"
  },
  {
   "cell_type": "code",
   "source": [
    "# Median of r2_mean over the norms in each (embed_group, norm_cat) cell\n",
    "rca_grand_avg = (\n",
    "    rca_mean.loc[:, ['embed_group', 'norm_cat', 'r2_mean']]\n",
    "    .groupby(by=['embed_group', 'norm_cat'], as_index=False)\n",
    "    .median(numeric_only=True)\n",
    "    .rename(columns={'r2_mean': 'r2_grand_avg'})\n",
    ")\n",
    "rca_grand_avg"
   ],
   "metadata": {
    "collapsed": false
   },
   "id": "7330265b524418dc",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Embed groups ranked from lowest to highest mean r2_grand_avg across norm categories\n",
    "sorted_embeds = (\n",
    "    rca_grand_avg\n",
    "    .groupby(by='embed_group')['r2_grand_avg']\n",
    "    .mean()\n",
    "    .sort_values()\n",
    "    .reset_index()\n",
    ")\n",
    "sorted_embeds"
   ],
   "id": "7c5953138ce18a8c",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "**Grand avg diffs** (text & behavior minus text & text)",
   "id": "2bb5d812890ad61e"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Wide layout: one row per (norm, fold, norm_cat) and one r2 column per embed,\n",
    "# so that pairwise differences become plain column arithmetic\n",
    "rca_wide = (\n",
    "    rca.loc[:, ['embed', 'norm', 'norm_cat', 'fold', 'r2']]\n",
    "    .pivot(columns='embed', index=['norm', 'fold', 'norm_cat'], values='r2')\n",
    "    .reset_index()\n",
    ")\n",
    "rca_wide"
   ],
   "id": "daaaafc1c0f3df64",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Column names in rca_wide: the two text embeddings, their ensemble,\n",
    "# and each text embedding ensembled with the behavior embedding\n",
    "text_1, text_2 = 'CBOW_GoogleNews', 'fastText_CommonCrawl'\n",
    "text_text = f'{text_1}&{text_2}'\n",
    "text_behav_1 = f'{text_1}&PPMI_SVD_SWOW'\n",
    "text_behav_2 = f'{text_2}&PPMI_SVD_SWOW'\n",
    "\n",
    "# Difference columns: Text & Behavior minus Text & Text\n",
    "rca_wide[f'{text_behav_1} vs {text_text}'] = rca_wide[text_behav_1].sub(rca_wide[text_text])\n",
    "rca_wide[f'{text_behav_2} vs {text_text}'] = rca_wide[text_behav_2].sub(rca_wide[text_text])\n",
    "\n",
    "# Back to long layout; the melted variable column is expected to be called 'embed'\n",
    "# (the column-axis name left by the pivot) and is renamed to 'comparing'\n",
    "tb_vs_tt = (\n",
    "    rca_wide.loc[\n",
    "        :,\n",
    "        ['norm', 'fold', 'norm_cat', f'{text_behav_1} vs {text_text}', f'{text_behav_2} vs {text_text}']\n",
    "    ]\n",
    "    .melt(id_vars=['norm', 'norm_cat', 'fold'])\n",
    "    .rename(columns={'embed': 'comparing'})\n",
    ")\n",
    "tb_vs_tt"
   ],
   "id": "a4a7d8c9906c9d39",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Computing tb_vs_tt_mean analogously to how we compute diffs for the wilcoxon test below:\n",
    "# average within each (norm, fold) first, then across folds within each norm\n",
    "tb_vs_tt_mean = (\n",
    "    tb_vs_tt\n",
    "    .groupby(by=['norm', 'fold'], as_index=False).mean(numeric_only=True)\n",
    "    .groupby(by='norm', as_index=False).mean(numeric_only=True)\n",
    "    .drop(columns='fold')\n",
    "    .rename(columns={'value': 'r2_diff'})\n",
    ")\n",
    "\n",
    "tb_vs_tt_mean['norm_cat'] = tb_vs_tt_mean['norm'].apply(\n",
    "    lambda norm_name: meta.loc[norm_name]['category']\n",
    ")\n",
    "tb_vs_tt_mean"
   ],
   "id": "d0ea52c6ea20389",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Median r2_diff over the norms in each norm category\n",
    "tb_vs_tt_grand_avg = (\n",
    "    tb_vs_tt_mean\n",
    "    .groupby(by='norm_cat', as_index=False)\n",
    "    .median(numeric_only=True)\n",
    "    .rename(columns={'r2_diff': 'r2_diff_grand_avg'})\n",
    ")\n",
    "tb_vs_tt_grand_avg"
   ],
   "id": "5a1304d94f3e6fc5",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Plotting",
   "id": "6da5ec2bfbdf5b3f"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Pivot rca_grand_avg for plotting: embed groups as rows (fixed order), norm categories as columns\n",
    "heat_df_1 = (\n",
    "    rca_grand_avg\n",
    "    .pivot(columns='norm_cat', index='embed_group', values='r2_grand_avg')\n",
    "    .loc[['text', 'behavior', 'text&text', 'text&behavior']]\n",
    ")\n",
    "\n",
    "# Ordering norm_cats by text&behavior performance, lowest first\n",
    "norm_cat_order = heat_df_1.loc['text&behavior'].sort_values().index\n",
    "heat_df_1 = heat_df_1.loc[:, norm_cat_order]\n",
    "\n",
    "# Display labels: 'text&behavior' -> 'Text & Behavior', underscores -> spaces\n",
    "heat_df_1.index = heat_df_1.index.str.replace('&', ' & ').str.title()\n",
    "heat_df_1.columns = heat_df_1.columns.str.replace('_', ' ')\n",
    "heat_df_1"
   ],
   "id": "65734727863bac33",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# True where a cell holds the highest value in its norm category (column)\n",
    "rca_winners_bool = heat_df_1.apply(lambda scores: scores.eq(scores.max()))\n",
    "rca_winners_bool"
   ],
   "id": "d12b568165929563",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Pivot tb_vs_tt_grand_avg into a single labelled row for plotting\n",
    "tb_vs_tt_grand_avg['comparing'] = 'text&behavior - text&text'\n",
    "heat_df_2 = tb_vs_tt_grand_avg.pivot(\n",
    "    columns='norm_cat', index='comparing', values='r2_diff_grand_avg'\n",
    ")\n",
    "\n",
    "# Columns follow norm_cat_order; labels get the same '&' / '_' display formatting\n",
    "heat_df_2 = heat_df_2.loc[:, norm_cat_order]\n",
    "heat_df_2.index = heat_df_2.index.str.replace('&', ' & ').str.title()\n",
    "heat_df_2.columns = heat_df_2.columns.str.replace('_', ' ')\n",
    "heat_df_2"
   ],
   "id": "49907ed08549979f",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Function to create a lighter version of a colormap\n",
    "def lighten_cmap(cmap_name, factor=0.3):\n",
    "    \"\"\"Return a copy of a named colormap with every color blended towards white.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    cmap_name : str\n",
    "        Name of a registered Matplotlib colormap, e.g. 'viridis'.\n",
    "    factor : float\n",
    "        Weight given to white in the blend; 0 leaves the colors unchanged\n",
    "        and 1 gives pure white.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    LinearSegmentedColormap\n",
    "        A new colormap named 'light_<cmap_name>' built from 256 blended colors.\n",
    "    \"\"\"\n",
    "    # plt.get_cmap is used because plt.cm.get_cmap was removed in Matplotlib 3.9\n",
    "    cmap = plt.get_cmap(cmap_name, 256)\n",
    "    colors = cmap(np.linspace(0, 1, 256))\n",
    "\n",
    "    # Blend each RGBA color with white\n",
    "    white = np.array([1, 1, 1, 1])\n",
    "    new_colors = (1 - factor) * colors + factor * white\n",
    "\n",
    "    return LinearSegmentedColormap.from_list(f'light_{cmap_name}', new_colors)\n",
    "\n",
    "\n",
    "# Function to visualize a colormap\n",
    "def plot_colormap(cmap):\n",
    "    \"\"\"Show a colormap as a horizontal gradient strip.\"\"\"\n",
    "    gradient = np.linspace(0, 1, 256)\n",
    "    # imshow needs a 2-D array, so the gradient row is stacked twice\n",
    "    gradient = np.vstack((gradient, gradient))\n",
    "\n",
    "    plt.imshow(gradient, aspect='auto', cmap=cmap)\n",
    "    plt.axis('off')\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Lighter viridis colormap; the heatmaps below use it\n",
    "lighter_viridis = lighten_cmap('viridis', factor=0.6)\n",
    "\n",
    "# Visualize it\n",
    "plot_colormap(lighter_viridis)"
   ],
   "id": "5ca37af303bc4e57",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def annotate(heat_df, ax, winners_bool=None):\n",
    "    \"\"\"Write every cell value of heat_df as centered text on a heatmap axis.\n",
    "\n",
    "    Values with magnitude above 1e3 use scientific notation, NaN cells are\n",
    "    left blank, and everything else is shown with two decimals. When\n",
    "    winners_bool is given, cells flagged True there are drawn larger and bold.\n",
    "    \"\"\"\n",
    "    for col_pos, norm_cat in enumerate(heat_df.columns):\n",
    "        for row_pos, embed in enumerate(heat_df.index):\n",
    "            value = heat_df.loc[embed, norm_cat]\n",
    "\n",
    "            # Choosing the text shown in the cell\n",
    "            if abs(value) > 1e3:\n",
    "                label = f'{value:.1e}'\n",
    "            elif np.isnan(value):\n",
    "                label = ''\n",
    "            else:\n",
    "                label = f'{value:.2f}'\n",
    "\n",
    "            # Winners stand out through size and weight\n",
    "            is_winner = winners_bool is not None and winners_bool.loc[embed, norm_cat]\n",
    "            fontsize, fontweight = (16, 'bold') if is_winner else (13, 'normal')\n",
    "\n",
    "            # Heatmap cells are unit squares, so +.5 lands on the cell center\n",
    "            ax.text(\n",
    "                col_pos + .5, row_pos + .5, label,\n",
    "                ha='center', va='center', color='black',\n",
    "                fontsize=fontsize, fontweight=fontweight\n",
    "            )\n",
    "\n",
    "\n",
    "# One subplot per heatmap, each as tall as its number of rows\n",
    "heat_dfs = [heat_df_1, heat_df_2]\n",
    "fig, axs = plt.subplots(2, figsize=(18, 6), height_ratios=list(map(len, heat_dfs)))\n",
    "\n",
    "# Top panel: grand avg\n",
    "vmax = heat_df_1.max().max()\n",
    "sns.heatmap(\n",
    "    heat_df_1, ax=axs[0], cmap=lighter_viridis,\n",
    "    vmin=0, vmax=vmax, annot=False, fmt='', cbar=False\n",
    ")\n",
    "\n",
    "# Bottom panel: text & behavior - text & text\n",
    "vmax = heat_df_2.max().max()\n",
    "sns.heatmap(\n",
    "    heat_df_2, ax=axs[1], cmap=lighter_viridis,\n",
    "    vmin=0, vmax=vmax, annot=False, fmt='', cbar=False\n",
    ")\n",
    "\n",
    "for ax in axs:\n",
    "    ax.set(xlabel='', ylabel='')\n",
    "    ax.set_yticklabels(ax.get_yticklabels(), fontsize=13)\n",
    "    ax.set_xticklabels(ax.get_xticklabels(), fontsize=13)\n",
    "\n",
    "    # Keeps y-tick labels horizontal\n",
    "    plt.setp(ax.get_yticklabels(), rotation=0)\n",
    "\n",
    "# Only the bottom panel keeps its x-tick labels\n",
    "axs[0].set_xticklabels([])\n",
    "x_tick_labels = heat_df_2.columns.str.title().str.replace('Of', 'of', regex=True)\n",
    "axs[1].set_xticklabels(x_tick_labels, rotation=90, ha='right')\n",
    "\n",
    "# Cell annotations; winners are highlighted in the top panel only\n",
    "annotate(heat_df_1, axs[0], rca_winners_bool)\n",
    "annotate(heat_df_2, axs[1])\n",
    "\n",
    "# Title on the top panel\n",
    "axs[0].set_title('Average Test $R^2$', fontsize=20)\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.savefig('../../figures/rca_ensemb.png', dpi=300, bbox_inches='tight')"
   ],
   "id": "9873f87bd8e208ad",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Paper Stats",
   "id": "7a8c42948abf85c0"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# checking that ensembling always improves performance\n",
    "print(f\"Text & Text - Text: {(heat_df_1.loc['Text & Text'] - heat_df_1.loc['Text'] < 0).any()}\")\n",
    "print(f\"Text & Behavior - Behavior: {(heat_df_1.loc['Text & Behavior'] - heat_df_1.loc['Behavior'] < 0).any()}\")"
   ],
   "id": "152e08ce313b3ba7",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "print(f\"# where Text & Behavior > Text & Text: {(heat_df_2.loc['Text & Behavior - Text & Text'] > 0).sum()}\")",
   "id": "36f7258174af7793",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Per-category grand avg differences, smallest (most negative) first\n",
    "sorted_avg_diffs = (\n",
    "    heat_df_2.loc['Text & Behavior - Text & Text']\n",
    "    .sort_values(ascending=True)\n",
    ")\n",
    "sorted_avg_diffs"
   ],
   "id": "358fd0f568fb6287",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def get_diffs(norm_cat):\n",
    "    \"\"\"Return the r2 differences in tb_vs_tt for one norm category.\n",
    "\n",
    "    The 'value' column is averaged within each (norm, fold) pair, i.e. over\n",
    "    the rows tb_vs_tt holds for that pair, so the result has one entry per\n",
    "    (norm, fold).\n",
    "    \"\"\"\n",
    "    # A boolean mask avoids building a query string from the category name\n",
    "    in_category = tb_vs_tt['norm_cat'] == norm_cat\n",
    "    diffs = (\n",
    "        tb_vs_tt.loc[in_category]\n",
    "        .groupby(['norm', 'fold'])\n",
    "        .mean(numeric_only=True)['value']\n",
    "    )\n",
    "\n",
    "    return diffs\n",
    "\n",
    "\n",
    "def format_p(p):\n",
    "    \"\"\"Format a p-value for reporting.\n",
    "\n",
    "    Returns 'p < .001' or 'p < .01' when p is below that threshold, and p\n",
    "    rounded to two decimals otherwise (including non-significant values).\n",
    "    Pass the unrounded p-value: rounding first would move values such as\n",
    "    0.004 into the wrong bracket.\n",
    "    \"\"\"\n",
    "    thresholds = [(0.001, 'p < .001'), (0.01, 'p < .01')]\n",
    "\n",
    "    for threshold, label in thresholds:\n",
    "        if p < threshold:\n",
    "            return label\n",
    "\n",
    "    return round(p, 2)\n",
    "\n",
    "\n",
    "def wilcoxon_test(diffs):\n",
    "    \"\"\"Run scipy's wilcoxon on diffs and summarize the result.\n",
    "\n",
    "    Returns a dict with the median of diffs (two decimals), the number of\n",
    "    observations, the test statistic and the formatted p-value.\n",
    "    \"\"\"\n",
    "    w, p = wilcoxon(diffs)\n",
    "    return {\n",
    "        'median': round(diffs.median(), 2),\n",
    "        # p is thresholded before any rounding, see format_p\n",
    "        'n': len(diffs), 'w': w, 'p': format_p(p)\n",
    "    }\n",
    "\n",
    "\n",
    "# Only categories whose grand avg difference exceeds .03 in magnitude are tested\n",
    "tests_to_run = sorted_avg_diffs[sorted_avg_diffs.abs() > .03].index\n",
    "\n",
    "for norm_cat in tests_to_run:\n",
    "    diffs = get_diffs(norm_cat)\n",
    "    print(f'{norm_cat}: {wilcoxon_test(diffs)}')"
   ],
   "id": "783dcc4ff0ebcfbd",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
