{
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "import os\n",
    "import json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.linear_model import RidgeCV, LogisticRegressionCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from tqdm.notebook import tqdm\n",
    "from rca import make_binary_scoring, make_multiclass_scoring, process_categorical, best_logistic_solver, checker, k_fold_cross_val"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading Data"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Loading dictionary of dtype to embed\n",
    "with open('../../data/dtype_to_embed.json', 'r') as f:\n",
    "    dtype_to_embed = json.load(f)\n",
    "    \n",
    "brain_behav_names = dtype_to_embed['brain'] + dtype_to_embed['behavior']\n",
    "\n",
    "# Iterating through pulled_embeds and finding union of all brain and behavior vocabs\n",
    "embeds_path = '../../data/embeds/'\n",
    "brain_behav_union = set()\n",
    "for name in tqdm(brain_behav_names):\n",
    "    vocab = set(pd.read_csv(embeds_path + name + '.csv', index_col=0).index)\n",
    "    brain_behav_union = brain_behav_union.union(vocab)\n",
    "\n",
    "len(brain_behav_union)  "
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "We drop feature_overlap because it contains many NaNs and compo_attribs because it doesn't have a large enough vocabulary and is also a identical to a 65 of the norms in the psychNorms dataset."
  },
  {
   "cell_type": "code",
   "source": [
    "# Pulling and standardising embeddings\n",
    "embeds = {}\n",
    "embeds_path = '../../data/embeds/'\n",
    "for f_name in tqdm(os.listdir(embeds_path)):\n",
    "    if f_name not in ['feature_overlap.csv', 'compo_attribs.csv']:  # dropping since contains many NaNs\n",
    "        \n",
    "        embed = pd.read_csv(embeds_path + f_name, index_col=0)\n",
    "        embed_name = f_name.split('.')[0]\n",
    "        \n",
    "        # Subsetting to brain and behavior vocab\n",
    "        embed = embed.loc[embed.index.intersection(brain_behav_union)]\n",
    "        \n",
    "        # Standardising\n",
    "        embeds[embed_name] = (embed - embed.mean()) / embed.std()\n",
    "\n",
    "{name: embed.shape for name, embed in embeds.items()}"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "norms = pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, compression='zip', low_memory=False)\n",
    "norm_metadata = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col='norm')\n",
    "norm_metadata['associated_embed'] = norm_metadata['associated_embed'].astype(str)\n",
    "norms"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "with open('../../data/embed_to_dtype.json', 'r') as f:\n",
    "    embed_to_type = json.load(f)\n",
    "embed_to_type\n",
    "\n",
    "# Log transforming selected norms\n",
    "norms_to_log = pd.read_csv('../../data/norms_to_log.csv')['norm']\n",
    "norms[norms_to_log] = norms[norms_to_log].apply(np.log1p)\n",
    "norms_to_log"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Cross Validation"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Ridge\n",
    "min_ord, max_ord = -5, 5\n",
    "alphas = np.logspace(\n",
    "    min_ord, max_ord, max_ord - min_ord + 1\n",
    ")\n",
    "ridge = RidgeCV(alphas=alphas)\n",
    "\n",
    "# Logistic hyperparameters\n",
    "Cs = 1 / alphas\n",
    "inner_cv = 5\n",
    "penalty = 'l2'\n",
    "\n",
    "# Scorers\n",
    "binary_scoring = make_binary_scoring()\n",
    "multiclass_scoring = make_multiclass_scoring()\n",
    "continuous_scoring = {'r2': 'r2', 'neg_mse': 'neg_mean_squared_error'}\n",
    "\n",
    "# outer_cv setting \n",
    "outer_cv, n_jobs = 5, 10"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# RCA\n",
    "rca = []\n",
    "for embed_name in tqdm(embeds.keys()):\n",
    "    embed = embeds[embed_name]\n",
    "    \n",
    "    to_print = []\n",
    "    for norm_name in tqdm(norms.columns, desc=embed_name):\n",
    "        \n",
    "        # Aligning data\n",
    "        y = norms[norm_name].dropna()\n",
    "        X, y = embed.align(y, axis=0, join='inner', copy=True) \n",
    "        \n",
    "        # Checking norm dtype \n",
    "        norm_dtype = norm_metadata.loc[norm_name, 'type']\n",
    "        \n",
    "        # Solvers, scoring, estimators\n",
    "        if norm_dtype in ['binary', 'multiclass']:\n",
    "            X, y = process_categorical(outer_cv, inner_cv, X, y)\n",
    "            \n",
    "            # may have switched form multi to bin after processing\n",
    "            norm_dtype = 'binary' if len(y.unique()) == 2 else 'multiclass'\n",
    "            \n",
    "            # Cross validation settings for logistic regression\n",
    "            solver = best_logistic_solver(X, norm_dtype)\n",
    "            \n",
    "            # Defining logistic regression \n",
    "            estimator = LogisticRegressionCV(\n",
    "                Cs=Cs, penalty=penalty, cv=StratifiedKFold(inner_cv),\n",
    "                solver=solver, n_jobs=8\n",
    "            )\n",
    "            scoring = binary_scoring if norm_dtype == 'binary' else multiclass_scoring\n",
    "        else: # continuous\n",
    "            estimator, scoring = ridge, continuous_scoring\n",
    "  \n",
    "        # Cross validation\n",
    "        associated_embed = norm_metadata.loc[norm_name, 'associated_embed']\n",
    "        check = checker(embed_name, y, norm_dtype, associated_embed, outer_cv)\n",
    "        if check == 'pass':\n",
    "            scores = k_fold_cross_val(estimator, X, y, outer_cv, scoring, n_jobs) # stratification is automatically used for classification\n",
    "            r2s, mses = scores['test_r2'], - scores['test_neg_mse']\n",
    "            r2_mean, r2_sd = r2s.mean(), r2s.std()\n",
    "            mse_mean, mse_sd = mses.mean(), mses.std()\n",
    "        else:\n",
    "            r2_mean, r2_sd = np.nan, np.nan\n",
    "            mse_mean, mse_sd = np.nan, np.nan\n",
    "            \n",
    "        # Saving\n",
    "        train_n = int(((outer_cv - 1) / outer_cv) * len(X))\n",
    "        test_n = len(X) - train_n\n",
    "        p = X.shape[1]\n",
    "        embed_type = embed_to_type[embed_name]\n",
    "        rca.append([\n",
    "            embed_name, embed_type, norm_name, train_n, test_n, p, \n",
    "            r2_mean, r2_sd, mse_mean, mse_sd, check\n",
    "        ])\n",
    "        \n",
    "        to_print.append([norm_name, train_n, r2_mean, r2_sd, check])\n",
    "\n",
    "    to_print = pd.DataFrame(to_print, columns=['norm' , 'train_n', 'r2_mean', 'r2_sd', 'check'])\n",
    "    print(to_print.sort_values('r2_mean', ascending=False).head(10))\n",
    "\n",
    "rca = pd.DataFrame(\n",
    "    rca, columns=[\n",
    "        'embed', 'embed_type', 'norm', 'train_n', 'test_n', 'p', \n",
    "        'r2_mean', 'r2_sd', 'mse_mean', 'mse_sd', 'check'\n",
    "    ]\n",
    ")\n",
    "\n",
    "rca.to_csv('../../data/results/rca.csv', index=False)\n",
    "rca"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
