{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c550ad95",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import mne\n",
    "import pickle\n",
    "import seaborn as sns\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "from statannotations.Annotator import Annotator\n",
    "\n",
    "plt.rcParams.update({'font.size': 14})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d4d5d71a-45f1-45ca-8994-dc1c2aaf9a00",
   "metadata": {},
   "source": [
    "# Accuracies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "e10190c3-8798-413e-b391-836beec736e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "def model_type_label(model_name):\n",
    "    \"\"\"Return the model-family label for a model name.\n",
    "\n",
    "    Names containing 'emb' map to 'group\\nembedding', other names containing\n",
    "    'group' map to 'group', and everything else maps to 'subject'.\n",
    "    \"\"\"\n",
    "    if 'emb' in model_name:\n",
    "        return 'group\\nembedding'\n",
    "    if 'group' in model_name:\n",
    "        return 'group'\n",
    "    return 'subject'\n",
    "\n",
    "\n",
    "# Wide table -> long format: the row index is treated as the subject and the\n",
    "# column header as the model name.\n",
    "df = pd.read_csv('excel_data.txt', sep=\"\\t\", header=0)\n",
    "df = df.stack().reset_index()\n",
    "df = df.rename(columns={'level_0': 'subject', 'level_1': 'model', 0: 'Validation accuracy'})\n",
    "\n",
    "# Label every row in one pass instead of indexing the frame row by row.\n",
    "df['model type'] = df['model'].map(model_type_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "33954f25-e2e6-4bba-ad00-dff1b218cb9a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>subject</th>\n",
       "      <th>model</th>\n",
       "      <th>Validation accuracy</th>\n",
       "      <th>model type</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>0</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.490662</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>12</th>\n",
       "      <td>1</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.318117</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>19</th>\n",
       "      <td>2</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.112696</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>26</th>\n",
       "      <td>3</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.522857</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>33</th>\n",
       "      <td>4</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.705128</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>40</th>\n",
       "      <td>5</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.471510</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>47</th>\n",
       "      <td>6</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.637269</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>54</th>\n",
       "      <td>7</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.212554</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>61</th>\n",
       "      <td>8</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.464950</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>68</th>\n",
       "      <td>9</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.454416</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75</th>\n",
       "      <td>10</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.290000</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>82</th>\n",
       "      <td>11</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.159772</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>89</th>\n",
       "      <td>12</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.061254</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>96</th>\n",
       "      <td>13</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.428165</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>103</th>\n",
       "      <td>14</td>\n",
       "      <td>nonlin-group-emb</td>\n",
       "      <td>0.388252</td>\n",
       "      <td>group\\nembedding</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "     subject             model  Validation accuracy        model type\n",
       "5          0  nonlin-group-emb             0.490662  group\\nembedding\n",
       "12         1  nonlin-group-emb             0.318117  group\\nembedding\n",
       "19         2  nonlin-group-emb             0.112696  group\\nembedding\n",
       "26         3  nonlin-group-emb             0.522857  group\\nembedding\n",
       "33         4  nonlin-group-emb             0.705128  group\\nembedding\n",
       "40         5  nonlin-group-emb             0.471510  group\\nembedding\n",
       "47         6  nonlin-group-emb             0.637269  group\\nembedding\n",
       "54         7  nonlin-group-emb             0.212554  group\\nembedding\n",
       "61         8  nonlin-group-emb             0.464950  group\\nembedding\n",
       "68         9  nonlin-group-emb             0.454416  group\\nembedding\n",
       "75        10  nonlin-group-emb             0.290000  group\\nembedding\n",
       "82        11  nonlin-group-emb             0.159772  group\\nembedding\n",
       "89        12  nonlin-group-emb             0.061254  group\\nembedding\n",
       "96        13  nonlin-group-emb             0.428165  group\\nembedding\n",
       "103       14  nonlin-group-emb             0.388252  group\\nembedding"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show only the rows belonging to the 'nonlin-group-emb' model.\n",
    "df.loc[df['model'].eq('nonlin-group-emb')]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "2b776683-b77c-48e9-b95f-2e9637b46b85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1cb78907a2544e0486a3f98cf710990e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "p-value annotation legend:\n",
      "      ns: p <= 1.00e+00\n",
      "       *: 1.00e-02 < p <= 5.00e-02\n",
      "      **: 1.00e-03 < p <= 1.00e-02\n",
      "     ***: 1.00e-04 < p <= 1.00e-03\n",
      "    ****: p <= 1.00e-04\n",
      "\n",
      "lin-subject vs. nonlin-subject: t-test paired samples, P_val:5.736e-04 t=4.428e+00\n",
      "nonlin-group-emb vs. nonlin-group-emb finetuned: t-test paired samples, P_val:1.127e-06 t=-8.135e+00\n",
      "nonlin-group vs. nonlin-group-emb: t-test paired samples, P_val:1.911e-06 t=-7.773e+00\n",
      "lin-subject vs. nonlin-group-emb: t-test paired samples, P_val:1.285e-02 t=2.850e+00\n",
      "lin-subject vs. nonlin-group-emb finetuned: t-test paired samples, P_val:1.066e-03 t=-4.108e+00\n"
     ]
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "\n",
    "    \n",
    "# Violin plot of validation accuracy per model with paired t-test annotations.\n",
    "# NOTE(review): this cell needs the per-model `df` built in the 'Accuracies'\n",
    "# section; a later cell rebinds `df`, so run this one before that section.\n",
    "#\n",
    "# Keyword arguments for the `sns.catplot` call below. The `Annotator` call\n",
    "# further down receives data/x/y explicitly rather than through this dict.\n",
    "plot_params = {\n",
    "    'kind':    'violin',\n",
    "    'aspect':  2,\n",
    "    'cut':     0,\n",
    "    'ci':      None,\n",
    "    'scale':   'area',\n",
    "    'hue':     'model type',\n",
    "    'dodge':   False,\n",
    "    'data':    df,\n",
    "    'x':       'model',\n",
    "    'y':       'Validation accuracy'\n",
    "}\n",
    "\n",
    "g = sns.catplot(**plot_params)\n",
    "\n",
    "# Single-facet grid: grab its only Axes for the manual decorations.\n",
    "ax = g.axes[0][0]\n",
    "# Horizontal reference line labelled 'chance', plus a text caption for it.\n",
    "ax.axhline(0.008, ls='-', color='black', label='chance')\n",
    "plt.ylim(0, 1)\n",
    "plt.xlabel('')\n",
    "plt.text(6.77,0.001,'chance')\n",
    "# Keep the existing tick positions but replace the raw model names with\n",
    "# multi-line display labels.\n",
    "plt.xticks(plt.xticks()[0], ['linear\\nsubject',\n",
    "                             'nonlinear\\nsubject',\n",
    "                             'linear\\ngroup',\n",
    "                             'nonlinear\\ngroup',\n",
    "                             'linear\\ngroup-emb',\n",
    "                             'nonlinear\\ngroup-emb',\n",
    "                             'nonlinear\\ngroup-emb\\nfinetuned'])\n",
    "\n",
    "# Dashed vertical separators between the model families, each family with a\n",
    "# text caption. For `axvline`, ymin/ymax are fractions of the axes height,\n",
    "# not data values.\n",
    "ymin = 0.02\n",
    "ymax = 0.7\n",
    "alpha = 0.5\n",
    "dash = '--'\n",
    "color = 'red'\n",
    "ax.axvline(1.5, ymin, ymax, ls=dash, color=color, alpha=alpha)\n",
    "ax.axvline(3.5, ymin, ymax, ls=dash, color=color, alpha=alpha)\n",
    "plt.text(0.22,0.87,'subject models')\n",
    "plt.text(2,0.8,'group models')\n",
    "plt.text(3.7,0.8,'group models\\nwith embedding')\n",
    "\n",
    "# which pairs to compute stats on\n",
    "pairs = [('lin-subject', 'nonlin-subject'),\n",
    "         ('lin-subject', 'nonlin-group-emb'),\n",
    "         ('lin-subject', 'nonlin-group-emb finetuned'),\n",
    "         ('nonlin-group-emb', 'nonlin-group-emb finetuned'),\n",
    "         ('nonlin-group-emb', 'nonlin-group')]\n",
    "\n",
    "# Run a paired t-test for every pair and draw the significance brackets.\n",
    "# NOTE(review): no multiple-comparison correction is configured for these\n",
    "# five tests -- confirm this is intended.\n",
    "annotator = Annotator(ax, pairs, data=df, x='model', y='Validation accuracy')\n",
    "annotator.configure(test='t-test_paired', verbose=True, line_offset_to_group=10).apply_and_annotate()\n",
    "\n",
    "# Save the annotated figure as a PDF.\n",
    "plt.savefig('group_acc.pdf', format='pdf')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb00d04d-8db6-44c9-906f-97a97df0ded2",
   "metadata": {},
   "source": [
    "# Generalization to new subject"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "db62b231-196a-44e2-bb8e-21015ab791f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = os.path.join('..', 'results', 'cichy_epoched', 'indiv_wavenetlinear_MNN', 'val_loss_general.npy')\n",
    "# NOTE(review): `train1` is appended as the last column; presumably these are\n",
    "# the accuracies at the final training ratio -- confirm the provenance.\n",
    "train1 = [0.591525424, 0.303672316, 0.121468925, 0.680790966, 0.885593221, 0.662429377, 0.730225995, 0.159604517, 0.579096052, 0.627118642, 0.223163842, 0.151129942, 0.06497175, 0.483050848, 0.412429377]\n",
    "chance = [0.00847] * 15\n",
    "\n",
    "# Column layout of the final matrix: chance | loaded values | train1.\n",
    "chance_col = np.array(chance)[:, np.newaxis]\n",
    "train1_col = np.array(train1)[:, np.newaxis]\n",
    "accs = np.hstack((chance_col, np.load(path), train1_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "4cbfe8b8-8f24-43ef-895e-2c370a29f955",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Long format: one row per (row index, column index) cell of `accs`,\n",
    "# tagged with the 'subject' level.\n",
    "accs_df = (\n",
    "    pd.DataFrame(accs)\n",
    "    .stack()\n",
    "    .reset_index()\n",
    "    .rename(columns={'level_0': 'subject', 'level_1': 'Training ratio', 0: 'Validation accuracy'})\n",
    "    .assign(level='subject')\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 159,
   "id": "ab8f3d7e-cd04-4e2d-9b1c-a4ce0101a34f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>level_0</th>\n",
       "      <th>level_1</th>\n",
       "      <th>0</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0.008470</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0.008475</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>0.013559</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>0.050847</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>4</td>\n",
       "      <td>0.116949</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>160</th>\n",
       "      <td>14</td>\n",
       "      <td>6</td>\n",
       "      <td>0.162429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>161</th>\n",
       "      <td>14</td>\n",
       "      <td>7</td>\n",
       "      <td>0.213277</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>162</th>\n",
       "      <td>14</td>\n",
       "      <td>8</td>\n",
       "      <td>0.283898</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>163</th>\n",
       "      <td>14</td>\n",
       "      <td>9</td>\n",
       "      <td>0.331921</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>164</th>\n",
       "      <td>14</td>\n",
       "      <td>10</td>\n",
       "      <td>0.412429</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>165 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     level_0  level_1         0\n",
       "0          0        0  0.008470\n",
       "1          0        1  0.008475\n",
       "2          0        2  0.013559\n",
       "3          0        3  0.050847\n",
       "4          0        4  0.116949\n",
       "..       ...      ...       ...\n",
       "160       14        6  0.162429\n",
       "161       14        7  0.213277\n",
       "162       14        8  0.283898\n",
       "163       14        9  0.331921\n",
       "164       14       10  0.412429\n",
       "\n",
       "[165 rows x 3 columns]"
      ]
     },
     "execution_count": 159,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the long-format `accs_df` table.\n",
    "# NOTE(review): the saved output still shows the un-renamed columns\n",
    "# (level_0 / level_1 / 0), so it predates the rename cell -- re-run to refresh.\n",
    "accs_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "56ea72e2-8eca-48bc-8ea7-ea2b299a58d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_df(accsg, level):\n",
    "    # need to get actual subjects\n",
    "    order = [10, 7, 3, 11, 8, 4, 12, 9, 5, 13, 1, 14, 2, 6, 0]\n",
    "    accsg_df = pd.DataFrame(accsg[order, :])\n",
    "    accsg_df = accsg_df.stack().reset_index()\n",
    "    accsg_df = accsg_df.rename(columns={'level_0': 'subject', 'level_1': 'Training ratio', 0: 'Validation accuracy'})\n",
    "    accsg_df['level'] = [level] * len(accsg_df)\n",
    "    \n",
    "    return accsg_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "5034fb1c-fa77-4ab0-98e7-95cb970cf767",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the saved array of each group-level run and convert it to long format.\n",
    "frames = {}\n",
    "for level, run in (('group-emb', 'all_wavenet_semb_general'),\n",
    "                   ('group', 'all_wavenet_general')):\n",
    "    path = os.path.join('..', 'results', 'cichy_epoched', run, 'val_loss_general.npy')\n",
    "    accsg = np.load(path)\n",
    "    frames[level] = create_df(accsg, level)\n",
    "\n",
    "group_emb = frames['group-emb']\n",
    "group = frames['group']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "id": "29bb794f-4502-4c50-87a3-441d605770d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the subject, group-emb and group tables into one long frame.\n",
    "# NOTE: this rebinds `df`, which held the per-model accuracy table in the\n",
    "# 'Accuracies' section.\n",
    "df = pd.concat([accs_df, group_emb, group], ignore_index=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "1b4302fc-010d-44be-9d02-e99b570cf259",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn the column index stored in 'Training ratio' into a fraction (index / 10).\n",
    "# NOTE: not idempotent -- re-running this cell divides the column again.\n",
    "df['Training ratio'] = df['Training ratio'].astype(float).div(10)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "451f9349-9bdc-40ca-b51f-27b2b6bf0c98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paired t-test between the 'group' and 'group-emb' models at every training\n",
    "# ratio, Bonferroni-corrected for the number of ratios tested.\n",
    "#\n",
    "# The ratios are read from the data rather than hard-coded as range(11),\n",
    "# because 'Training ratio' has been rescaled to fractions, and the level names\n",
    "# match the labels passed to `create_df` ('group', 'group-emb').\n",
    "ratios = np.sort(df['Training ratio'].unique())\n",
    "\n",
    "p_values = []\n",
    "for ratio in ratios:\n",
    "    at_ratio = np.isclose(df['Training ratio'], ratio)\n",
    "    test1 = df.loc[at_ratio & (df['level'] == 'group'), 'Validation accuracy']\n",
    "    test2 = df.loc[at_ratio & (df['level'] == 'group-emb'), 'Validation accuracy']\n",
    "\n",
    "    # Both series keep the row order produced by `create_df`, so they are\n",
    "    # paired by subject. A corrected p-value is capped at 1; NaN is kept.\n",
    "    corrected = stats.ttest_rel(test1, test2)[1] * len(ratios)\n",
    "    p_values.append(min(corrected, 1.0))"
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "c5505436-bc0c-43f2-8e53-5be4778d7eb3",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.0060336492775827375,\n",
       " 1.351510411629857,\n",
       " 0.03169837884573075,\n",
       " 4.3785466680206016,\n",
       " 2.918205427079515,\n",
       " 0.10703676926544689,\n",
       " 0.020678270815296832,\n",
       " 0.055122900986178884,\n",
       " 0.11919219652472021,\n",
       " 0.23726568570982257,\n",
       " nan]"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the list of p-values computed in the previous cell.\n",
    "p_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "id": "dadf7c89-3c7f-4813-9c50-982203f32554",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "06731a9359544eb395c22254c38386ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.legend.Legend at 0x7fa9cc264df0>"
      ]
     },
     "execution_count": 83,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%matplotlib widget\n",
    "# Validation accuracy against training ratio, one line per 'level', with a\n",
    "# bootstrapped 95% interval (1000 resamples).\n",
    "g = sns.relplot(\n",
    "    x=\"Training ratio\",\n",
    "    y=\"Validation accuracy\",\n",
    "    hue='level',\n",
    "    kind=\"line\",\n",
    "    data=df,\n",
    "    ci=95,\n",
    "    n_boot=1000,\n",
    "    aspect=1.2,\n",
    ")\n",
    "ax = g.axes[0][0]\n",
    "\n",
    "# Black bar along the upper y-limit with its caption placed just above it.\n",
    "ax.axhline(0.6, 0.02, 0.7, color='black')\n",
    "ax.text(0, 0.61, 'group>subject (p<0.05)')\n",
    "ax.set_ylim(0, 0.6)\n",
    "ax.set_xlabel('Training set ratio')\n",
    "\n",
    "# Reference line labelled 'chance', then the axes legend.\n",
    "ax.axhline(0.008, ls='-', color='black', label='chance')\n",
    "ax.legend(loc='upper left')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "id": "2eb96a58-ba72-44f5-a34d-4c771d25c855",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save through the Axes' own figure rather than pyplot's current figure, so\n",
    "# the export does not depend on the previous cell's figure still being the\n",
    "# active one (it is not under the inline backend).\n",
    "ax.figure.savefig('generalization.pdf', format='pdf')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
