{
 "cells": [
  {
   "cell_type": "code",
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from sklearn.linear_model import RidgeCV, LogisticRegressionCV\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from tqdm.notebook import tqdm\n",
    "import json\n",
    "from rca import process_categorical, best_logistic_solver, k_fold_cross_val, make_binary_scoring, make_multiclass_scoring, checker"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Loading Data"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Previous RCA results (rows with missing values dropped), norm metadata and the norms themselves\n",
    "rca = pd.read_csv('../../data/results/rca.csv').dropna()\n",
    "meta = pd.read_csv('../../data/psychNorms/psychNorms_metadata.csv', index_col=0)\n",
    "norms = pd.read_csv('../../data/psychNorms/psychNorms.zip', index_col=0, compression='zip', low_memory=False)\n",
    "\n",
    "# Adding norm_cat to rca\n",
    "# One vectorized label lookup replaces the per-row .apply; .loc still raises KeyError\n",
    "# when a norm is missing from meta. Values are assigned positionally, which assumes\n",
    "# meta's index is unique (one metadata row per norm) -- TODO confirm.\n",
    "rca['norm_cat'] = (\n",
    "    meta.loc[rca['norm'].tolist(), 'category']\n",
    "    .replace({'_': ' '}, regex=True)  # underscores in category names become spaces\n",
    "    .to_numpy()\n",
    ")\n",
    "\n",
    "rca"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Step 1: median r2_mean inside every (embed, norm_cat) pair;\n",
    "# the median is used to mitigate outliers within norm_cats\n",
    "category_medians = rca.groupby(['embed', 'norm_cat'])[['r2_mean']].median(numeric_only=True)\n",
    "\n",
    "# Step 2: average those medians per embed, so every norm_cat contributes one value\n",
    "embed_avgs = category_medians.groupby(level='embed').mean().rename(columns={'r2_mean': 'r2_avg'})\n",
    "embed_avgs"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# Tagging every embed with its type, read from the embed -> type JSON mapping\n",
    "with open('../../data/embed_to_dtype.json', 'r') as f:\n",
    "    embed_to_type = json.load(f)\n",
    "embed_avgs['type'] = embed_avgs.index.map(embed_to_type)\n",
    "\n",
    "# Text embeds ranked by r2_avg, best first; the first two are kept\n",
    "text_ranked = embed_avgs[embed_avgs['type'] == 'text'].sort_values('r2_avg', ascending=False)\n",
    "text_name_1, text_name_2 = text_ranked.index[:2].tolist()\n",
    "text_name_1, text_name_2"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Behavior embeds ranked by r2_avg, best first; only the winner is kept\n",
    "behavior_ranked = embed_avgs[embed_avgs['type'] == 'behavior'].sort_values('r2_avg', ascending=False)\n",
    "behavior_name = behavior_ranked.index[0]\n",
    "behavior_name"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "def read_embed(name):\n",
    "    \"\"\"Read the embedding CSV for `name`, using its first column as the index.\"\"\"\n",
    "    return pd.read_csv(f'../../data/embeds/{name}.csv', index_col=0)\n",
    "\n",
    "\n",
    "def standardize(df):\n",
    "    \"\"\"Subtract each column's mean and divide by its standard deviation.\"\"\"\n",
    "    return (df - df.mean()) / df.std()\n",
    "\n",
    "\n",
    "# Loading embeds\n",
    "text_1, text_2, behavior = map(read_embed, (text_name_1, text_name_2, behavior_name))\n",
    "\n",
    "# Aligning vocabs: only index entries shared by all three embeds survive, in sorted order\n",
    "intersect = sorted(set(text_1.index) & set(text_2.index) & set(behavior.index))\n",
    "text_1, text_2, behavior = (df.loc[intersect] for df in (text_1, text_2, behavior))\n",
    "\n",
    "# Standardizing\n",
    "text_1, text_2, behavior = map(standardize, (text_1, text_2, behavior))\n",
    "\n",
    "# Ensembling for comparison: the three solo embeds first, then every pairwise concatenation\n",
    "embeds = {behavior_name: behavior, text_name_1: text_1, text_name_2: text_2}\n",
    "pairs = [\n",
    "    (text_name_1, text_1, text_name_2, text_2),\n",
    "    (text_name_1, text_1, behavior_name, behavior),\n",
    "    (text_name_2, text_2, behavior_name, behavior),\n",
    "]\n",
    "for left_name, left, right_name, right in pairs:\n",
    "    embeds[left_name + '&' + right_name] = pd.concat([left, right], axis=1)\n",
    "\n",
    "# Fixing column names: concatenation repeats the labels of its parts, so every\n",
    "# frame gets fresh positional labels 0..n-1\n",
    "for embed in embeds.values():\n",
    "    embed.columns = list(range(embed.shape[1]))\n",
    "\n",
    "{key: frame.shape for key, frame in embeds.items()}"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "metadata": {},
   "cell_type": "code",
   "source": [
    "# Changing associated_embed to more usable format (space-separated string -> list).\n",
    "# Only values that are still strings are split, so re-running this cell leaves the\n",
    "# already-split lists (and missing values) untouched instead of turning them into NaN.\n",
    "meta['associated_embed'] = meta['associated_embed'].map(\n",
    "    lambda names: names.split(' ') if isinstance(names, str) else names\n",
    ")\n",
    "\n",
    "# Log transforming selected norms\n",
    "norms_to_log = pd.read_csv('../../data/norms_to_log.csv')['norm']\n",
    "# The marker is stored on the norms frame itself: re-running this cell does not apply\n",
    "# log1p a second time, while reloading norms from disk yields a frame without the marker.\n",
    "if not norms.attrs.get('log1p_applied', False):\n",
    "    norms[norms_to_log] = norms[norms_to_log].apply(np.log1p)\n",
    "    norms.attrs['log1p_applied'] = True\n",
    "norms_to_log"
   ],
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Cross Validation"
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "source": [
    "# Ridge: one candidate alpha per order of magnitude from 10**min_ord to 10**max_ord\n",
    "min_ord, max_ord = -5, 5\n",
    "alphas = np.logspace(min_ord, max_ord, num=max_ord - min_ord + 1)\n",
    "ridge = RidgeCV(alphas=alphas)\n",
    "\n",
    "# Logistic hyperparameters: Cs are the reciprocals of the ridge alphas\n",
    "penalty = 'l2'\n",
    "inner_cv = 5\n",
    "Cs = 1 / alphas\n",
    "\n",
    "# Scorers\n",
    "continuous_scoring = {'r2': 'r2', 'neg_mse': 'neg_mean_squared_error'}\n",
    "binary_scoring = make_binary_scoring()\n",
    "multiclass_scoring = make_multiclass_scoring()\n",
    "\n",
    "# outer_cv setting\n",
    "outer_cv = 5\n",
    "n_jobs = 6\n",
    "\n",
    "# For checking data leakage in checker\n",
    "solo_embed_names = [text_name_1, text_name_2, behavior_name]"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  },
  {
   "cell_type": "code",
   "source": [
    "# RCA: cross-validated fit of every embed (solo and ensembled) on every norm\n",
    "# NOTE: the final `rca` below rebinds the name of the results frame loaded earlier in this\n",
    "# notebook; reload that frame before recomputing anything that reads its r2_mean column.\n",
    "rca_rows = []\n",
    "for norm_name in tqdm(norms.columns):\n",
    "    print(f'{norm_name}:')\n",
    "    y_norm = norms[norm_name].dropna()\n",
    "\n",
    "    # Norm-level metadata does not depend on the embed, so it is read once per norm\n",
    "    meta_dtype = meta.loc[norm_name, 'type']\n",
    "    associated_embed = meta.loc[norm_name, 'associated_embed']\n",
    "\n",
    "    to_print = []\n",
    "    for embed_name, embed in embeds.items():\n",
    "\n",
    "        # Aligning embed with norm. Every embed starts from the untouched y_norm, so\n",
    "        # whatever alignment or process_categorical did to y for one embed cannot\n",
    "        # carry over to the next one.\n",
    "        X, y = embed.align(y_norm, axis='index', join='inner', copy=True)\n",
    "\n",
    "        # Solvers, scoring, estimators for categorical or continuous\n",
    "        norm_dtype = meta_dtype\n",
    "        if norm_dtype in ['binary', 'multiclass']: # categorical\n",
    "            X, y = process_categorical(outer_cv, inner_cv, X, y)\n",
    "\n",
    "            # may have switched from multi to bin after processing\n",
    "            norm_dtype = 'binary' if len(y.unique()) == 2 else 'multiclass'\n",
    "\n",
    "            # Cross validation settings for logistic regression\n",
    "            solver = best_logistic_solver(y, norm_dtype)\n",
    "\n",
    "            # Defining logistic regression\n",
    "            estimator = LogisticRegressionCV(\n",
    "                Cs=Cs, penalty=penalty, cv=StratifiedKFold(inner_cv), solver=solver\n",
    "            )\n",
    "            scoring = binary_scoring if norm_dtype == 'binary' else multiclass_scoring\n",
    "        else: # continuous\n",
    "            estimator, scoring = ridge, continuous_scoring\n",
    "\n",
    "        # Cross validation (skipped, with NaN scores, when checker does not return 'pass')\n",
    "        check = checker(solo_embed_names, y, norm_dtype, associated_embed, outer_cv)\n",
    "        if check == 'pass':\n",
    "            scores = k_fold_cross_val(estimator, X, y, outer_cv, scoring, n_jobs) # stratification is automatically used for classification\n",
    "            r2s, mses = scores['test_r2'], - scores['test_neg_mse']\n",
    "        else:\n",
    "            r2s, mses = pd.Series([np.nan] * outer_cv), pd.Series([np.nan] * outer_cv)\n",
    "\n",
    "        # Saving: one row per outer fold\n",
    "        train_n = int(((outer_cv - 1) / outer_cv) * len(y))\n",
    "        for i, (r2, mse) in enumerate(zip(r2s, mses)):\n",
    "            rca_rows.append([embed_name, norm_name, train_n, i + 1, r2, mse, check])\n",
    "\n",
    "        # Printing\n",
    "        to_print.append([embed_name, r2s.mean(), r2s.std(), check])\n",
    "    to_print = pd.DataFrame(to_print, columns=['embed', 'r2_mean', 'r2_std', 'check'])\n",
    "    print(to_print.sort_values('r2_mean', ascending=False).head(10).reset_index(drop=True))\n",
    "    print('--------------------------------')\n",
    "\n",
    "\n",
    "rca = pd.DataFrame(\n",
    "    rca_rows, columns=[\n",
    "        'embed', 'norm', 'train_n', 'fold', 'r2', 'mse', 'check']\n",
    ")\n",
    "rca.to_csv('../../data/results/rca_ensemb.csv', index=False)\n",
    "rca"
   ],
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "execution_count": null
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
