{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib.colors import LinearSegmentedColormap\n",
    "import json"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": "## Loading data",
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# NOTE(review): `norms` is not referenced anywhere below; drop this read if nothing else needs it\n",
    "norms = pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, low_memory=False)\n",
    "rca_full = pd.read_csv('../../data/results/rca.csv')\n",
    "meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col='norm')\n",
    "\n",
    "print(f\"# Norms: {len(meta.index.unique())}\")\n",
    "print(f\"# Norm categories: {len(meta['category'].unique())}\")\n",
    "print(f\"# Embeds: {len(rca_full['embed'].unique())}\")\n",
    "\n",
    "# Every norm in the results needs a metadata row for the category lookup below\n",
    "unknown_norms = set(rca_full['norm']) - set(meta.index)\n",
    "assert not unknown_norms, f'{len(unknown_norms)} norms in rca.csv are missing from the metadata, e.g. {list(unknown_norms)[:5]}'\n",
    "\n",
    "# Adding norm category with one vectorized lookup (Series.map) instead of a per-row\n",
    "# meta.loc[...] call; underscores in category names become spaces for display\n",
    "rca_full['norm_category'] = (\n",
    "    rca_full['norm'].map(meta['category'])\n",
    "    .replace({'_': ' '}, regex=True)\n",
    ")\n",
    "rca_full"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Evaluating how many embed-norm pairs didn't pass the checker"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "# Rows per checker outcome (NaN excluded)\nrca_full['check'].value_counts(dropna=True)",
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def groupby_pivot(df):\n",
    "    \"\"\"Count rows per (embed, norm_category) as an embed x norm_category table.\n",
    "\n",
    "    Counts non-null `norm` values per group; pairs with no rows come out as NaN.\n",
    "    \"\"\"\n",
    "    return (\n",
    "        df.groupby(['embed', 'norm_category'], as_index=False).count()\n",
    "        .pivot(index='embed', columns='norm_category', values='norm')\n",
    "    )\n",
    "\n",
    "\n",
    "# Rows with no missing values (presumably the pairs that passed the checker -- confirm).\n",
    "# .copy() makes `rca` independent of rca_full, so adding columns to it later does not\n",
    "# raise SettingWithCopyWarning.\n",
    "rca = rca_full.dropna().copy()\n",
    "rca_full_counts = groupby_pivot(rca_full[['embed', 'norm_category', 'norm']])\n",
    "rca_counts = groupby_pivot(rca[['embed', 'norm_category', 'norm']])\n",
    "\n",
    "rca_full_counts, rca_counts = rca_full_counts.align(rca_counts, join='outer')\n",
    "# A pair whose rows were all dropped is missing from rca_counts (NaN after align): count it\n",
    "# as 0 retained, but keep NaN where the full data has no rows for the pair either\n",
    "rca_counts = rca_counts.fillna(0).where(rca_full_counts.notna())\n",
    "perc_retained = (rca_counts / rca_full_counts) * 100\n",
    "perc_retained"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "fig, axs = plt.subplots(1, 3, figsize=(25, 8))\n",
    "\n",
    "# (data to color, annotation, title) for each panel. The retention panel is colored by\n",
    "# the negated percentage (reversing the color scale) but annotated with the real, rounded value.\n",
    "panels = [\n",
    "    (rca_full_counts, True, 'Full RCA'),\n",
    "    (rca_counts, True, 'RCA'),\n",
    "    (-perc_retained, perc_retained.round(0), 'RCA / Full RCA (%)'),\n",
    "]\n",
    "for ax, (data, annot, title) in zip(axs, panels):\n",
    "    sns.heatmap(data, ax=ax, cmap='viridis', annot=annot, fmt='g', cbar=False)\n",
    "    ax.set_title(title)\n",
    "\n",
    "# Embed names are shown on the leftmost panel only\n",
    "for ax in axs[1:]:\n",
    "    ax.set(ylabel='')\n",
    "    ax.set_yticklabels([])\n",
    "\n",
    "fig.tight_layout()"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Heatmap of test $R^2$ by embed and norm category"
  },
  {
   "cell_type": "code",
   "source": [
    "with open('../../data/embed_to_dtype.json', 'r') as f:\n",
    "    embed_to_type = json.load(f)\n",
    "\n",
    "# assign() returns a new frame rather than writing a column into `rca`, which comes from\n",
    "# rca_full.dropna() and can therefore trigger pandas' SettingWithCopyWarning\n",
    "rca = rca.assign(embed_type=rca['embed'].map(embed_to_type))\n",
    "assert rca['embed_type'].notna().all(), 'some embeds are missing from embed_to_dtype.json'\n",
    "rca"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Median (not mean) of r2_mean over the norms in each (norm_category, embed) pair.\n",
    "# The `_avg` name and the heatmap title ('Average Test R^2') both refer to this median.\n",
    "rca_avg = (\n",
    "    rca[['norm_category', 'embed', 'r2_mean']]\n",
    "    .groupby(['norm_category', 'embed'], as_index=False).median()\n",
    ")\n",
    "rca_avg"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Wide layout: one row per embed, one column per norm category (median r2_mean)\n",
    "rca_avg_piv = rca_avg.set_index(['embed', 'norm_category'])['r2_mean'].unstack('norm_category')\n",
    "rca_avg_piv"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# True where an embed has its norm category's highest value (ties are all True)\n",
    "winner_mask = rca_avg_piv.eq(rca_avg_piv.max())\n",
    "winner_mask"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def lighten_cmap(cmap_name, factor=0.3):\n",
    "    \"\"\"Return a lighter copy of a named matplotlib colormap.\n",
    "\n",
    "    Each of the colormap's 256 RGBA samples is blended linearly with opaque white:\n",
    "    ``factor=0`` keeps the original colors, ``factor=1`` gives plain white.\n",
    "\n",
    "    Args:\n",
    "        cmap_name: name of a colormap known to matplotlib (e.g. 'viridis').\n",
    "        factor: blend weight of white, in [0, 1].\n",
    "\n",
    "    Returns:\n",
    "        A LinearSegmentedColormap named ``light_<cmap_name>``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``factor`` is outside [0, 1].\n",
    "    \"\"\"\n",
    "    if not 0 <= factor <= 1:\n",
    "        raise ValueError(f'factor must be in [0, 1], got {factor}')\n",
    "\n",
    "    # plt.get_cmap replaces plt.cm.get_cmap, which was removed in matplotlib 3.9\n",
    "    cmap = plt.get_cmap(cmap_name, 256)\n",
    "    colors = cmap(np.linspace(0, 1, 256))\n",
    "\n",
    "    white = np.array([1, 1, 1, 1])  # RGBA for white\n",
    "    new_colors = (1 - factor) * colors + factor * white\n",
    "\n",
    "    return LinearSegmentedColormap.from_list(f'light_{cmap_name}', new_colors)\n",
    "\n",
    "\n",
    "def plot_colormap(cmap):\n",
    "    \"\"\"Show ``cmap`` as a horizontal gradient strip with the axes hidden.\"\"\"\n",
    "    gradient = np.linspace(0, 1, 256)\n",
    "    gradient = np.vstack((gradient, gradient))\n",
    "\n",
    "    fig, ax = plt.subplots()\n",
    "    ax.imshow(gradient, aspect='auto', cmap=cmap)\n",
    "    ax.set_axis_off()\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# Lighter viridis, used as the background colormap of the R^2 heatmap below\n",
    "lighter_viridis = lighten_cmap('viridis', factor=0.6)\n",
    "plot_colormap(lighter_viridis)"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "def annotate(df, ax, mask=None):\n",
    "    \"\"\"Write every value of ``df`` as text on the heatmap drawn on ``ax``.\n",
    "\n",
    "    |value| > 1e3 is shown in scientific notation, NaN is left blank, anything else\n",
    "    gets two decimals. Cells that are True in ``mask`` (looked up by the raw\n",
    "    embed / norm-category labels; defaults to the notebook's ``winner_mask``) are\n",
    "    drawn larger and in bold.\n",
    "    \"\"\"\n",
    "    if mask is None:\n",
    "        mask = winner_mask\n",
    "    for x, norm_cat in enumerate(df.columns):\n",
    "        for y, embed in enumerate(df.index):\n",
    "            value = df.loc[embed, norm_cat]\n",
    "\n",
    "            # Scientific notation for very large magnitudes\n",
    "            if abs(value) > 1e3:\n",
    "                text = f'{value:.1e}'\n",
    "            elif np.isnan(value):\n",
    "                text = ''\n",
    "            else:\n",
    "                text = f'{value:.2f}'\n",
    "\n",
    "            # Per-category winners stand out\n",
    "            if mask.loc[embed, norm_cat]:\n",
    "                fontsize, fontweight = 15, 'bold'\n",
    "            else:\n",
    "                fontsize, fontweight = 10, 'normal'\n",
    "\n",
    "            # (x + .5, y + .5) is the cell center; y is offset by a further .1\n",
    "            ax.text(\n",
    "                x + .5, y + .6, text, fontsize=fontsize, fontweight=fontweight,\n",
    "                ha='center', va='center', color='black'\n",
    "            )\n",
    "\n",
    "\n",
    "def _rows_of_type(df, embed_type):\n",
    "    \"\"\"Rows of ``df`` whose index label (an embed) maps to ``embed_type`` in ``embed_to_type``.\"\"\"\n",
    "    return df[[embed_to_type[embed] == embed_type for embed in df.index]]\n",
    "\n",
    "\n",
    "def _display_labels(names):\n",
    "    \"\"\"Readable tick labels for embed names (underscores to spaces, SVD_sim_rel spelled out).\"\"\"\n",
    "    return (\n",
    "        names.str.replace('SVD_sim_rel', 'SVD_similarity_relatedness', regex=False)\n",
    "        .str.replace('_', ' ', regex=False)\n",
    "    )\n",
    "\n",
    "\n",
    "# Panel order, top to bottom (also the order heat_dfs is built in)\n",
    "embed_types = ['text', 'behavior', 'brain']\n",
    "\n",
    "# The behavior embed with the highest mean r2 across categories sets the column order:\n",
    "# norm categories sorted ascending by that embed's r2\n",
    "top_behav = _rows_of_type(rca_avg_piv, 'behavior').mean(axis=1).idxmax()\n",
    "norm_ord = rca_avg_piv.loc[top_behav].sort_values(ascending=True).index\n",
    "\n",
    "# One frame per embed type; rows sorted by mean r2 (descending), columns by norm_ord.\n",
    "# These keep the raw labels so annotate() can look them up in winner_mask.\n",
    "heat_dfs = {}\n",
    "for embed_type in embed_types:\n",
    "    heat_df = _rows_of_type(rca_avg_piv, embed_type)\n",
    "    embed_order = heat_df.mean(axis=1).sort_values(ascending=False).index\n",
    "    heat_dfs[embed_type] = heat_df[norm_ord].loc[embed_order]\n",
    "\n",
    "fig, axs = plt.subplots(len(embed_types), 1, figsize=(18, 10))\n",
    "\n",
    "# One shared color scale so colors are comparable across panels\n",
    "vmax = rca_avg_piv.max().max()\n",
    "for ax, embed_type in zip(axs, embed_types):\n",
    "    heat_df = heat_dfs[embed_type]\n",
    "    sns.heatmap(\n",
    "        heat_df, ax=ax, vmin=0, vmax=vmax, cmap=lighter_viridis,\n",
    "        annot=False, fmt='', cbar=False,\n",
    "    )\n",
    "    ax.set(xlabel='', xticklabels=[])\n",
    "\n",
    "    # Panel name on the right-hand side, rotated 270 degrees\n",
    "    ax.set_ylabel(\n",
    "        embed_type.title(), fontsize=17, rotation=270,\n",
    "        labelpad=20, va='center', ha='center'\n",
    "    )\n",
    "    ax.yaxis.set_label_position('right')\n",
    "\n",
    "    annotate(heat_df, ax)\n",
    "\n",
    "    # One tick per row with a readable label; heat_df itself is left unchanged\n",
    "    ax.set_yticks(np.arange(len(heat_df.index)) + .5)\n",
    "    ax.set_yticklabels(_display_labels(heat_df.index), fontsize=12)\n",
    "\n",
    "# Only the bottom panel shows the norm-category labels\n",
    "xtick_labels = norm_ord.str.title().str.replace(' Of ', ' of ', regex=True)\n",
    "axs[-1].set_xticklabels(xtick_labels, rotation=90, fontsize=13)\n",
    "\n",
    "# NOTE(review): cell values come from rca_avg, a per-category median of r2_mean,\n",
    "# while the title says 'Average' -- confirm the intended wording\n",
    "axs[0].set_title('Average Test ${R^2}$', fontsize=20)\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('../../figures/rca.png', dpi=300, bbox_inches='tight')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "## Comparing performance of the best-performing embeds from each type"
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Attach each embed's type in a new frame instead of modifying rca_avg in place\n",
    "# (rca_avg was displayed above and feeds rca_avg_piv)\n",
    "rca_avg_typed = rca_avg.assign(embed_type=rca_avg['embed'].map(embed_to_type))\n",
    "assert rca_avg_typed['embed_type'].notna().all(), 'some embeds are missing from embed_to_dtype.json'\n",
    "\n",
    "# For every (embed type, norm category): the 90th percentile of r2_mean across that type's\n",
    "# embeds, i.e. how its best-performing embeds do (a quantile, despite the `_avg` name)\n",
    "rca_grand_avg = (\n",
    "    rca_avg_typed.groupby(['embed_type', 'norm_category'], as_index=False)\n",
    "    .quantile(.90, numeric_only=True)\n",
    ")\n",
    "rca_grand_avg"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Per norm category: 90th-percentile r2 of each embed type and the behavior - text gap.\n",
    "# Top and bottom rows (smallest / largest gap) are the ones reported in the paper.\n",
    "# Writing to a new name instead of overwriting rca_grand_avg keeps this cell re-runnable.\n",
    "rca_type_gap = rca_grand_avg.pivot(index='norm_category', columns='embed_type', values='r2_mean')\n",
    "rca_type_gap['behavior - text'] = rca_type_gap['behavior'] - rca_type_gap['text']\n",
    "rca_type_gap = rca_type_gap.sort_values(by='behavior - text', ascending=True).round(2)\n",
    "rca_type_gap"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": "",
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
