{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5cb9c88a-f947-4089-9d95-9bdc31932ce2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "from copy import copy\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "sys.path.append('../')\n",
    "from model_comparison import *\n",
    "import ggseg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a94fe48",
   "metadata": {},
   "outputs": [],
   "source": [
    "from plotnine import *\n",
    "import plotnine.options as plotnine_opts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0a1783b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the raw results, keeping only rows whose validation lower CI bound is above zero.\n",
    "# The explicit .copy() detaches the filtered frame, so the column assignments below\n",
    "# operate on an independent frame instead of on the result of a filter expression.\n",
    "data = pd.read_parquet('../final-results/raw_results.parquet.gzip').query('val_lower_ci > 0').copy()\n",
    "\n",
    "# we only need one score column: keep the stored test score under a new name and\n",
    "# use the bootstrap mean / CI columns as the working test_* and val_score columns\n",
    "data['test_score_original'] = data['test_score']\n",
    "data['test_score'] = data['mean_bootstrap']\n",
    "data['test_lower_ci'] = data['lower_ci']\n",
    "data['test_upper_ci'] = data['upper_ci']\n",
    "data['val_score'] = data['val_mean_bootstrap']\n",
    "\n",
    "# expose the source columns under the names used in the rest of the notebook\n",
    "data['model_modality'] = data['mul_uni']\n",
    "data['data_alignment'] = data['alignment']\n",
    "\n",
    "# insert model_id as the first column; it combines model + train_type\n",
    "data.insert(0, 'model_id', data['model'] + '-' + data['train_type'])\n",
    "\n",
    "# insert electrode_id as the second column; it combines subject + electrode\n",
    "data.insert(1, 'electrode_id', data['subject'] + '-' + data['electrode'])\n",
    "\n",
    "# keep only the columns we'll need in the analysis\n",
    "data = data[['model_id','model','train_type','model_modality',\n",
    "             'data_alignment','electrode_id','times',\n",
    "             'val_score','val_lower_ci','val_upper_ci',\n",
    "             'test_score', 'test_lower_ci', 'test_upper_ci', 'test_score_original']]\n",
    "\n",
    "# drop duplicate rows (drop_duplicates works row-wise, not column-wise)\n",
    "data = data.drop_duplicates()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99823772-12dc-482d-b537-8044f540806b",
   "metadata": {},
   "outputs": [],
   "source": [
    "def relabel_modality(row):\n",
    "    \"\"\"Return the modality label for one model row.\n",
    "\n",
    "    The label starts as `row['model_type']`. It is replaced only for rows that are\n",
    "    'multimodal' with `train_type == 'randomized'` and whose `model` name contains\n",
    "    neither 'lxmert' nor 'visual_bert': a name containing 'vision' yields\n",
    "    'unimodal_vision', otherwise a name containing 'language' yields\n",
    "    'unimodal_language'. Every other row keeps its label unchanged.\n",
    "\n",
    "    NOTE(review): the frames built in this notebook expose `model_modality`, not\n",
    "    `model_type` -- confirm the column name before re-enabling the call.\n",
    "    \"\"\"\n",
    "    label = row['model_type']\n",
    "\n",
    "    # Only randomized multimodal rows are candidates for relabelling.\n",
    "    if label != 'multimodal' or row['train_type'] != 'randomized':\n",
    "        return label\n",
    "\n",
    "    name = row['model']\n",
    "    # These architectures keep their multimodal label.\n",
    "    if 'lxmert' in name or 'visual_bert' in name:\n",
    "        return label\n",
    "\n",
    "    # 'vision' is tested first because it takes precedence when both substrings occur.\n",
    "    if 'vision' in name:\n",
    "        return 'unimodal_vision'\n",
    "    if 'language' in name:\n",
    "        return 'unimodal_language'\n",
    "    return label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2aeceaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.groupby(['data_alignment'])['electrode_id'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c916060d",
   "metadata": {},
   "outputs": [],
   "source": [
    "data.groupby(['data_alignment'])['model_id'].nunique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c93a75e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "group_summary_stats(data, ['data_alignment'], 'times')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b02ad3a-0bc5-48ca-b49a-f8d9ab05bef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_data = data[['model_id','model','train_type','model_modality']]\n",
    "model_data = model_data.drop_duplicates().reset_index(drop=True)\n",
    "#model_data['model_type'] = model_data.apply(lambda row: relabel_modality(row), axis = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be2aa717",
   "metadata": {},
   "outputs": [],
   "source": [
    "electrode_id = data['electrode_id'].unique()[5]\n",
    "nest1 = data[data['electrode_id'] == electrode_id]\n",
    "nest2 = nest1[nest1['data_alignment'] == 'language']\n",
    "nest3 = nest2[nest2['model_id'] == 'albef-trained']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "743719f4",
   "metadata": {},
   "outputs": [],
   "source": [
    "group_vars = ['model_id', 'electrode_id', 'data_alignment']\n",
    "summary = {'val':  group_summary_stats(data, group_vars, 'val_score'),\n",
    "           'test': group_summary_stats(data, group_vars, 'test_score_original')}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e891206",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary['val']['val_score_mean'].min(), summary['val']['val_score_mean'].max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd73863",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary['test'].groupby(['model_id','data_alignment'])['electrode_id'].nunique().reset_index();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87efb6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "#loop through each model_id and each electrode_id, bootstrap the mean score across times and save the results in a new dataframe associated with model ranking\n",
    "model_ranklists = make_ranklists(data)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "7a994f31-5274-419b-8ca5-a39978fab3c8",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Comparisons"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "531e551e-0f3c-4d2a-8c8c-4d00dc9934fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_winners, comparison_data  = model_comparisons(data, model_ranklists)\n",
    "comparison_data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a1fbc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "default_winners.groupby(['comparison'])['electrode_id'].nunique().reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "784c3808-99b6-4c48-b8de-77bbb14736e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from statsmodels.stats.multitest import multipletests\n",
    "\n",
    "# Benjamini-Hochberg correction is applied separately within each comparison type.\n",
    "comparison_list = comparison_data['comparison'].unique()\n",
    "\n",
    "adjusted_frames = []\n",
    "for comparison in comparison_list:\n",
    "    cur_data = comparison_data.loc[comparison_data['comparison'] == comparison].copy()\n",
    "\n",
    "    # element [1] of the multipletests result holds the corrected p-values\n",
    "    correction = multipletests(cur_data['p_value'].values, method = 'fdr_bh')\n",
    "    cur_data['p_value_adj'] = correction[1]\n",
    "\n",
    "    # flag rows that stay below the 0.05 threshold after correction\n",
    "    cur_data['significant'] = cur_data['p_value_adj'] < 0.05\n",
    "\n",
    "    electrodes_remaining = int(cur_data['significant'].sum())\n",
    "    print(f'electrodes passing comparison {comparison}: {electrodes_remaining}')\n",
    "\n",
    "    adjusted_frames.append(cur_data)\n",
    "\n",
    "# stack the per-comparison frames back into one table\n",
    "analysis_data = pd.concat(adjusted_frames)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e44de5e8-b5a3-456e-a1ae-1b48045e2904",
   "metadata": {},
   "outputs": [],
   "source": [
    "analysis_data.to_parquet('../final-results/comparisons.parquet.gzip', compression = 'gzip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4e4db5b",
   "metadata": {},
   "outputs": [],
   "source": [
    "analysis_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0bed2206",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only FDR-significant comparisons in which both models have at least 10 observations.\n",
    "analysis_data_parsed = (analysis_data.query('p_value_adj < 0.05')\n",
    "                        .query('model1_nobs >= 10')\n",
    "                        .query('model2_nobs >= 10')).copy()\n",
    "\n",
    "comparison_list = ['rank1-beats-rank2', 'trained-beats-random', 'random-beats-trained',\n",
    "                   'multi-beats-unimodal', 'slip-beats-simclr']\n",
    "\n",
    "\n",
    "def count_electrodes_in_both_alignments(frame):\n",
    "    \"\"\"Count electrodes that occur in `frame` under more than one distinct alignment.\n",
    "\n",
    "    Args:\n",
    "        frame: DataFrame with `electrode_id` and `alignment` columns.\n",
    "\n",
    "    Returns:\n",
    "        int: number of electrode ids associated with more than one `alignment` value.\n",
    "    \"\"\"\n",
    "    # Group per electrode (not per alignment) so the result is an electrode count.\n",
    "    alignments_per_electrode = frame.groupby('electrode_id')['alignment'].nunique()\n",
    "    return int((alignments_per_electrode > 1).sum())\n",
    "\n",
    "\n",
    "print('---Bootstrapped Comparisons---')\n",
    "for comparison in comparison_list:\n",
    "    cur_data = analysis_data_parsed[analysis_data_parsed['comparison'] == comparison]\n",
    "    print(comparison, ':', cur_data['electrode_id'].nunique())\n",
    "\n",
    "# For these two comparisons, also report electrodes that pass under both alignments.\n",
    "for comparison in ['multi-beats-unimodal', 'slip-beats-simclr']:\n",
    "    cur_data = analysis_data_parsed[analysis_data_parsed['comparison'] == comparison]\n",
    "    print(f'{comparison} (both alignments)', ':', count_electrodes_in_both_alignments(cur_data))\n",
    "\n",
    "from os import linesep\n",
    "print(linesep)\n",
    "\n",
    "print('---Default Winners---')\n",
    "# Default-winner comparison names mapped onto the labels printed above.\n",
    "# NOTE(review): 'slipclr-beats-slip' (here) and 'slip-vision-beats-simclr' (below) look like\n",
    "# two spellings of the same comparison -- confirm which name model_comparisons emits.\n",
    "default_display_names = {'multimodal-beats-unimodal': 'multi-beats-unimodal',\n",
    "                         'slipclr-beats-slip': 'slip-beats-simclr'}\n",
    "for comparison in default_winners.comparison.unique():\n",
    "    cur_data = default_winners[default_winners['comparison'] == comparison]\n",
    "    print(default_display_names.get(comparison, comparison), cur_data['electrode_id'].nunique())\n",
    "\n",
    "default_both_alignment_labels = {'multimodal-beats-unimodal': 'multi-beats-unimodal (both alignments)',\n",
    "                                 'slip-vision-beats-simclr': 'slip-beats-simclr (both alignments)'}\n",
    "for comparison in default_winners.comparison.unique():\n",
    "    if comparison in default_both_alignment_labels:\n",
    "        cur_data = default_winners[default_winners['comparison'] == comparison]\n",
    "        print(default_both_alignment_labels[comparison], ':', count_electrodes_in_both_alignments(cur_data))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "454e3372",
   "metadata": {},
   "outputs": [],
   "source": [
    "# The target file carries a .gzip suffix, so request gzip explicitly; without the\n",
    "# argument pandas writes parquet with its default codec (snappy) instead.\n",
    "analysis_data_parsed.to_parquet('../final-results/parsed_comparison_data.parquet.gzip', compression = 'gzip')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fbab5b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary['test'].query('count >= 10')#.to_csv('scoring_data.csv', index = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "98467059",
   "metadata": {},
   "outputs": [],
   "source": [
    "summary['test'].query('count >= 0')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "590f19fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "(group_summary_stats(summary['test'].query('count >= 0'), ['model_id', 'data_alignment'], 'test_score_original_mean', use_value_name = False)\n",
    " .round(4).sort_values('mean', ascending=False))#.to_csv('trained_beats_random.csv', index = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1e0bedee",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_tests =summary['test'].query('count >= 0').merge(model_data, on = 'model_id', how = 'left')\n",
    "train_tests[train_tests['model_id'].str.contains('simcse')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0424adfa",
   "metadata": {},
   "outputs": [],
   "source": [
    "(group_summary_stats(train_tests[train_tests['model_id'].str.contains('simcse')], ['model_id', 'data_alignment'], \n",
    "                     'test_score_original_mean', use_value_name = False)\n",
    " .round(4).sort_values(['data_alignment','mean'], ascending=False))#.to_csv('trained_beats_random.csv', index = None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ee1a3fb4",
   "metadata": {},
   "outputs": [],
   "source": [
    "(group_summary_stats(train_tests, ['train_type', 'data_alignment'], 'test_score_original_mean', use_value_name = False)\n",
    " .round(4).sort_values('mean', ascending=False))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.4"
  },
  "vscode": {
   "interpreter": {
    "hash": "7863deda28c2bbbeadddba67b96758f51af59736d922b2a4440bab3e43d09602"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
