{
 "cells": [
  {
   "cell_type": "code",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.manifold import MDS\n",
    "import itertools\n",
    "from adjustText import adjust_text\n",
    "import json"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "rsa = pd.read_csv('../../data/results/rsa.csv')\n",
    "\n",
    "# Every later cell reads these columns; fail early with a clear message if the results file changed.\n",
    "expected_cols = {'name_i', 'name_j', 'spearman'}\n",
    "missing_cols = expected_cols - set(rsa.columns)\n",
    "assert not missing_cols, f'rsa.csv is missing columns: {sorted(missing_cols)}'\n",
    "\n",
    "# Long format (one row per model pair): preview only instead of dumping every row\n",
    "rsa.head()"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "with open('../../data/embed_to_dtype.json', 'r') as f:\n",
    "    embed_to_dtype = json.load(f)\n",
    "\n",
    "# Tag both sides of every pair with its data type; names absent from the mapping become NaN.\n",
    "rsa['dtype_i'] = rsa['name_i'].map(embed_to_dtype)\n",
    "rsa['dtype_j'] = rsa['name_j'].map(embed_to_dtype)\n",
    "\n",
    "# Pairs with an unmapped name match no dtype query below, so they would silently drop out of every mean.\n",
    "unmapped = sorted(set(rsa.loc[rsa['dtype_i'].isna(), 'name_i']) | set(rsa.loc[rsa['dtype_j'].isna(), 'name_j']))\n",
    "if unmapped:\n",
    "    print(f'Warning: no data type for {unmapped}; their pairs are excluded from the means')\n",
    "\n",
    "def dtype_corr(dtype_i, dtype_j, results=None):\n",
    "    \"\"\"Mean 'spearman' value over all pairs whose data types are {dtype_i, dtype_j}.\n",
    "\n",
    "    Pairs are matched in either orientation, so dtype_corr('text', 'brain') == dtype_corr('brain', 'text').\n",
    "    results: frame with 'dtype_i', 'dtype_j' and 'spearman' columns; defaults to the global `rsa`.\n",
    "    Returns the mean rounded to 2 decimals (NaN when no pair matches).\n",
    "    \"\"\"\n",
    "    if results is None:\n",
    "        results = rsa\n",
    "    return results.query(\n",
    "        '(dtype_i == @dtype_i & dtype_j == @dtype_j) | (dtype_i == @dtype_j & dtype_j == @dtype_i)'\n",
    "    )['spearman'].mean().round(2)\n",
    "\n",
    "# Within-type (self) correlations\n",
    "text_text = dtype_corr('text', 'text')\n",
    "brain_brain = dtype_corr('brain', 'brain')\n",
    "behavior_behavior = dtype_corr('behavior', 'behavior')\n",
    "print(f'text-text mean correlation {text_text}')\n",
    "print(f'brain-brain mean correlation {brain_brain}')\n",
    "print(f'behavior-behavior mean correlation {behavior_behavior}')\n",
    "print('---------------')\n",
    "\n",
    "# Cross-type correlations\n",
    "text_brain = dtype_corr('text', 'brain')\n",
    "print(f'Text-brain mean correlation {text_brain}')\n",
    "text_behavior = dtype_corr('text', 'behavior')\n",
    "print(f'Text-behavior mean correlation {text_behavior}')\n",
    "brain_behavior = dtype_corr('brain', 'behavior')\n",
    "print(f'Brain-behavior mean correlation {brain_behavior}')"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "def to_heat_df(results, col, order=None):\n",
    "    \"\"\"Pivot long-format pairwise results into a symmetric model-by-model matrix.\n",
    "\n",
    "    results: frame with 'name_i', 'name_j' and `col` columns, one row per model pair (either orientation).\n",
    "    col: column holding the pairwise statistic, e.g. 'spearman'.\n",
    "    order: row/column order of the output; defaults to text, then behavior, then brain models\n",
    "        (the global `text_names`, `behavior_names`, `brain_names`). Models not in `order` are dropped.\n",
    "    Returns a float DataFrame whose diagonal is NaN (self-pairs are never filled).\n",
    "    Raises ValueError if some pair of models has no row in `results`.\n",
    "    \"\"\"\n",
    "    # Computed once, before the loop, so it is always bound (even with fewer than two models).\n",
    "    if order is None:\n",
    "        order = text_names + behavior_names + brain_names\n",
    "\n",
    "    # Union of both name columns: not every model is guaranteed to appear in a single column.\n",
    "    names = list(pd.concat([results['name_i'], results['name_j']]).unique())\n",
    "    heat_df = pd.DataFrame(index=names, columns=names, dtype=float)\n",
    "\n",
    "    # Fill each unordered pair once and mirror it; the row may be stored as (i, j) or (j, i).\n",
    "    query = '(name_i == @name_i & name_j == @name_j) | (name_i == @name_j & name_j == @name_i)'\n",
    "    for name_i, name_j in itertools.combinations(names, 2):\n",
    "        matches = results.query(query)[col]\n",
    "        if matches.empty:\n",
    "            raise ValueError(f'No {col!r} value for pair ({name_i!r}, {name_j!r})')\n",
    "        r = matches.iloc[0]  # first match, as before\n",
    "        heat_df.loc[name_i, name_j] = r\n",
    "        heat_df.loc[name_j, name_i] = r\n",
    "\n",
    "    return heat_df.loc[order, order].astype(float)\n",
    "\n",
    "\n",
    "with open('../../data/dtype_to_embed.json', 'r') as f:\n",
    "    dtype_to_embed = json.load(f)\n",
    "\n",
    "text_names = dtype_to_embed['text']\n",
    "brain_names = dtype_to_embed['brain']\n",
    "behavior_names = dtype_to_embed['behavior']\n",
    "\n",
    "print({dtype: len(names) for dtype, names in dtype_to_embed.items()})\n",
    "\n",
    "spearmans = to_heat_df(rsa, 'spearman')\n",
    "spearmans"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# MDS input: dissimilarity = 1 - Spearman, with the diagonal (never filled by to_heat_df, so NaN) set to 0.\n",
    "# Built on an explicit NumPy copy: writing through `DataFrame.values` raises under pandas Copy-on-Write\n",
    "# (read-only view) and is silently lost whenever `.values` returns a copy.\n",
    "dissimilarity_values = 1.0 - spearmans.to_numpy(dtype=float, copy=True)\n",
    "np.fill_diagonal(dissimilarity_values, 0.0)\n",
    "assert not np.isnan(dissimilarity_values).any(), 'dissimilarity matrix has missing off-diagonal values'\n",
    "dissimilarity = pd.DataFrame(dissimilarity_values, index=spearmans.index, columns=spearmans.columns)\n",
    "\n",
    "# MDS; n_init is pinned to 4 (scikit-learn's long-standing default) so the layout stays reproducible\n",
    "# if the library default changes.\n",
    "mds = MDS(n_components=2, dissimilarity='precomputed', random_state=0, n_init=4)\n",
    "spearmans_2d = pd.DataFrame(mds.fit_transform(dissimilarity_values), index=spearmans.index)\n",
    "\n",
    "def data_type(mod_name):\n",
    "    \"\"\"Return 'brain' or 'behavior' for names in those lists, and 'text' for any other name.\"\"\"\n",
    "    if mod_name in brain_names:\n",
    "        return 'brain'\n",
    "    if mod_name in behavior_names:\n",
    "        return 'behavior'\n",
    "    return 'text'\n",
    "\n",
    "# Adding data type\n",
    "spearmans_2d['embed_type'] = [data_type(name) for name in spearmans_2d.index]\n",
    "spearmans_2d"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
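   "source": "## Plotting\n\nPanel A: 2-D MDS embedding of the models, computed from the dissimilarity matrix (1 - Spearman correlation). Panel B: heatmap of the pairwise Spearman correlations, ordered text, behavior, brain."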
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Display labels: rename two embeddings and show underscores as spaces.\n",
    "# Applied to copies (index AND columns) so `spearmans` / `spearmans_2d` stay untouched and the cell is re-runnable.\n",
    "rename = {'compo_attribs': 'experiential_attributes', 'SVD_sim_rel': 'SVD_similarity_relatedness'}\n",
    "\n",
    "def pretty_name(name):\n",
    "    \"\"\"Map an embedding id to its figure label: apply `rename`, then replace '_' with ' '.\"\"\"\n",
    "    return rename.get(name, name).replace('_', ' ')\n",
    "\n",
    "spearmans_plot = spearmans.rename(index=pretty_name, columns=pretty_name)\n",
    "spearmans_2d_plot = spearmans_2d.rename(index=pretty_name)\n",
    "\n",
    "fig, (ax_1, ax_2) = plt.subplots(1, 2, figsize=(20, 8), width_ratios=(0.8, 1))\n",
    "\n",
    "# Colors: one viridis shade per data type\n",
    "cmap = plt.get_cmap('viridis', 4)\n",
    "embed_type_to_color = {\n",
    "    'brain': cmap(1),\n",
    "    'behavior': cmap(0),\n",
    "    'text': cmap(2)\n",
    "}\n",
    "\n",
    "# Panel A: 2-D MDS embedding, one square marker per model colored by data type\n",
    "sns.scatterplot(\n",
    "    data=spearmans_2d_plot, x=0, y=1, hue='embed_type',\n",
    "    legend=False, s=110,\n",
    "    marker='s', linewidth=0.1, edgecolor='black',\n",
    "    palette=embed_type_to_color, ax=ax_1\n",
    ")\n",
    "ax_1.set(xticklabels='', yticklabels='', xlabel='', ylabel='')\n",
    "\n",
    "texts = [\n",
    "    ax_1.text(x, y, label, fontsize=13)\n",
    "    for label, x, y in zip(spearmans_2d_plot.index, spearmans_2d_plot[0], spearmans_2d_plot[1])\n",
    "]\n",
    "\n",
    "# Adjust text labels to avoid overlap\n",
    "adjust_text(\n",
    "    texts, arrowprops=dict(arrowstyle='-', color='black', lw=.5), ax=ax_1\n",
    ")\n",
    "ax_1.axis('off')  # Turn off the axis\n",
    "\n",
    "# Panel B: heatmap of pairwise Spearman correlations (same labels on both axes)\n",
    "sns.heatmap(\n",
    "    spearmans_plot, square=True, annot=True, cmap='viridis',\n",
    "    vmin=0, vmax=spearmans_plot.max().max(),\n",
    "    fmt='.2f', annot_kws={\"fontsize\": 6}, cbar=False, ax=ax_2\n",
    ")\n",
    "\n",
    "# Adding bold panel labels\n",
    "ax_1.text(-0.1, 1.05, 'A', transform=ax_1.transAxes, fontsize=20, fontweight='bold', va='top')\n",
    "ax_2.text(-0.4, 1.05, 'B', transform=ax_2.transAxes, fontsize=20, fontweight='bold', va='top')\n",
    "\n",
    "fig.tight_layout()\n",
    "fig.savefig('../../figures/rsa.png', dpi=300, bbox_inches='tight')"
   ],
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
