{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "d7c3bb81",
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from glob import glob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "2fe8ca23",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all_test_results.json  \u001b[0m\u001b[01;34msub2\u001b[0m/  \u001b[01;34msub4\u001b[0m/  \u001b[01;34msub7\u001b[0m/  summary.json\r\n",
      "\u001b[01;34msub1\u001b[0m/                \u001b[01;34msub3\u001b[0m/  \u001b[01;34msub6\u001b[0m/  \u001b[01;34msub10\u001b[0m/\r\n"
     ]
    }
   ],
   "source": [
    "ls /storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "94513727",
   "metadata": {},
   "outputs": [],
   "source": [
    "# !export MPLCONFIGDIR=/storage/czw/.config/matplotlib"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6034d49d",
   "metadata": {},
   "source": [
    "## Collect the per-subject test-result JSON files (linear and superlet runs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "dd731269",
   "metadata": {},
   "outputs": [],
   "source": [
    "# baseline_stft_path = \"/storage/czw/self_supervised_seeg/stft_all_test_results/*\"\n",
    "# all_stft_results = {}\n",
    "# for path in glob(baseline_stft_path):    \n",
    "#     with open(path, \"r\") as f:\n",
    "#         results = json.load(f)\n",
    "#         for k,v in results.items():\n",
    "#             all_stft_results[k] = v\n",
    "# all_results_path = \"/storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results/all_test_results.json\"\n",
    "all_linear_results = {}\n",
    "for path in glob(\"/storage/czw/self_supervised_seeg/outputs/2022-08-31/01-02-47/all_test_results/*/all_test_results.json\"): #linear\n",
    "    with open(path, \"r\") as f:\n",
    "        results = json.load(f)\n",
    "        for k,v in results.items():\n",
    "            all_linear_results[k] = v\n",
    "            \n",
    "all_superlet_results = {}\n",
    "for path in glob(\"/storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results/*/subj_test_results.json\"):\n",
    "    with open(path, \"r\") as f:\n",
    "        results = json.load(f)\n",
    "        for k,v in results.items():\n",
    "            all_superlet_results[k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0db85817",
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_results_path = \"/storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results/all_test_results.json\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "57817d9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "all_test_results.json\r\n"
     ]
    }
   ],
   "source": [
    "ls /storage/czw/self_supervised_seeg/outputs/2022-08-31/01-02-47/all_test_results/sub4/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "a333c5ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# glob(\"/storage/czw/self_supervised_seeg/full_brain_test_outs/superlet_large_pretrained/onset_finetuning/all_test_results/*/subj_test_results.json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ea31da86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['sub10', 'sub3', 'sub6', 'sub1', 'sub7', 'sub2', 'sub4'])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_superlet_results.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "49725a6a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# baseline_linear_path = \"/storage/czw/self_supervised_seeg/outputs/2022-08-15/15-59-13/all_test_results.json\"\n",
    "# with open(baseline_linear_path, \"r\") as f:\n",
    "#     baseline_linear = json.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e4c341f5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': 0.5240416079759598,\n",
       " 'roc_auc': 0.84168657429527,\n",
       " 'f1': 0.7999999999999999}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "all_superlet_results['sub3']['T1cIe11']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "1a9deac2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(all_linear_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "53aa1d20",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# baseline_linear['sub3']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "5c64b46a",
   "metadata": {},
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'brain_plotting/left_hem_clean.png'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_3885/2669274011.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     26\u001b[0m \u001b[0mcorrelations_file\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'lag_correlation.json'\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 28\u001b[0;31m \u001b[0mleft_hem_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mleft_hem_file_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     29\u001b[0m \u001b[0mright_hem_img\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mright_hem_file_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     30\u001b[0m \u001b[0mcoords_df\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread_csv\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mos\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbase_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcoords_file_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/storage/czw/anaconda3/envs/sss/lib/python3.7/site-packages/matplotlib/pyplot.py\u001b[0m in \u001b[0;36mimread\u001b[0;34m(fname, format)\u001b[0m\n\u001b[1;32m   2158\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_copy_docstring_and_deprecators\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2159\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2160\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimage\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2161\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2162\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/storage/czw/anaconda3/envs/sss/lib/python3.7/site-packages/matplotlib/image.py\u001b[0m in \u001b[0;36mimread\u001b[0;34m(fname, format)\u001b[0m\n\u001b[1;32m   1558\u001b[0m                     \u001b[0mresponse\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mBytesIO\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1559\u001b[0m                 \u001b[0;32mreturn\u001b[0m \u001b[0mimread\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mresponse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mformat\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mext\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1560\u001b[0;31m     \u001b[0;32mwith\u001b[0m \u001b[0mimg_open\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfname\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mimage\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1561\u001b[0m         return (_pil_png_to_float_array(image)\n\u001b[1;32m   1562\u001b[0m                 \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mimage\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mPIL\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPngImagePlugin\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mPngImageFile\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/storage/czw/anaconda3/envs/sss/lib/python3.7/site-packages/PIL/ImageFile.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, fp, filename)\u001b[0m\n\u001b[1;32m    102\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0misPath\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    103\u001b[0m             \u001b[0;31m# filename\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 104\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"rb\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    105\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfilename\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    106\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_exclusive_fp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: 'brain_plotting/left_hem_clean.png'"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "import h5py\n",
    "import math\n",
    "import matplotlib\n",
    "# matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab!\n",
    "import matplotlib.pyplot as plt\n",
    "import copy\n",
    "import numpy as np\n",
    "import os\n",
    "import json\n",
    "import scipy.stats\n",
    "import time\n",
    "from types import SimpleNamespace\n",
    "import random\n",
    "import pandas as pd\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "# Axis limits of the Matlab figure the electrode coordinates were exported from.\n",
    "matlab_xlim = (-108.0278, 108.0278)\n",
    "matlab_ylim = (-72.9774, 72.9774)\n",
    "\n",
    "# File locations (relative to the notebook's working directory).\n",
    "save_dir = 'brain_plotting/'\n",
    "base_path = 'brain_plotting/'\n",
    "left_hem_file_name = 'left_hem_clean.png'\n",
    "right_hem_file_name = 'right_hem_clean.png'\n",
    "coords_file_name = 'elec_coords_full.csv'\n",
    "correlations_file = 'lag_correlation.json'\n",
    "\n",
    "# Background images and the electrode coordinate table, all read from base_path.\n",
    "left_hem_img = plt.imread(os.path.join(base_path, left_hem_file_name))\n",
    "right_hem_img = plt.imread(os.path.join(base_path, right_hem_file_name))\n",
    "coords_df = pd.read_csv(os.path.join(base_path, coords_file_name))\n",
    "\n",
    "# IDs are split on '-': the first piece becomes Subject, the second Electrode.\n",
    "split_elec_id = coords_df['ID'].str.split('-')\n",
    "coords_df['Subject'] = split_elec_id.map(lambda parts: parts[0])\n",
    "coords_df['Electrode'] = split_elec_id.map(lambda parts: parts[1])\n",
    "\n",
    "def scale(x, s, d):\n",
    "    \"\"\"Convert a Matlab coordinate to image units: shift by d, multiply by s, flip the sign.\"\"\"\n",
    "    return -(x - d) * s\n",
    "\n",
    "# Pixels per Matlab unit, derived from the image sizes and the Matlab axis spans.\n",
    "matlab_x_span = matlab_xlim[1] - matlab_xlim[0]\n",
    "matlab_y_span = matlab_ylim[1] - matlab_ylim[0]\n",
    "x_scale = left_hem_img.shape[1] / matlab_x_span\n",
    "y_scale_l = left_hem_img.shape[0] / matlab_y_span\n",
    "y_scale_r = right_hem_img.shape[0] / matlab_y_span\n",
    "\n",
    "scaled_coords_df = coords_df.copy()\n",
    "is_left_hem = coords_df['Hemisphere'] == 1\n",
    "is_right_hem = coords_df['Hemisphere'] == 0\n",
    "\n",
    "# Left hemisphere (Hemisphere == 1): both axes are measured from the upper limit.\n",
    "scaled_coords_df.loc[is_left_hem, 'X'] = scale(coords_df.loc[is_left_hem, 'X'], x_scale, matlab_xlim[1])\n",
    "scaled_coords_df.loc[is_left_hem, 'Y'] = scale(coords_df.loc[is_left_hem, 'Y'], y_scale_l, matlab_ylim[1])\n",
    "\n",
    "# Right hemisphere (Hemisphere == 0): X is measured from the lower x limit, hence the extra negation.\n",
    "# NOTE(review): X is scaled with y_scale_r here while the left hemisphere uses x_scale;\n",
    "# the two agree only if the right image has square pixels in Matlab units -- confirm this is intended.\n",
    "scaled_coords_df.loc[is_right_hem, 'X'] = -scale(coords_df.loc[is_right_hem, 'X'], y_scale_r, matlab_xlim[0])\n",
    "scaled_coords_df.loc[is_right_hem, 'Y'] = scale(coords_df.loc[is_right_hem, 'Y'], y_scale_r, matlab_ylim[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b6e253e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hemisphere_axis(ax=None, hemisphere=\"left\", subject_electrodes=None, cbar=True,\n",
    "                         vmin=0.45, vmax=0.85, marker_size=500):\n",
    "    \"\"\"Draw one hemisphere image on ``ax`` and overlay electrodes coloured by ROC-AUC.\n",
    "\n",
    "    Args:\n",
    "        ax: Matplotlib axes to draw on. Defaults to the current axes.\n",
    "        hemisphere: \"left\" or \"right\"; selects the background image and the rows\n",
    "            of ``scaled_coords_df`` (Hemisphere == 1 for left, 0 for right).\n",
    "        subject_electrodes: nested mapping subject -> electrode -> metrics dict\n",
    "            containing a 'roc_auc' entry. ``None`` is treated as empty.\n",
    "        cbar: if True, attach a colorbar to the right of the axes.\n",
    "        vmin, vmax: limits of the colour scale.\n",
    "        marker_size: scatter marker size.\n",
    "\n",
    "    Electrodes without a row in the coordinate table for this hemisphere are\n",
    "    skipped. If nothing is left to plot, only the image is drawn.\n",
    "    \"\"\"\n",
    "    # Validate before drawing so that a bad value never leaves a half-drawn axes.\n",
    "    assert hemisphere in [\"left\", \"right\"]\n",
    "    if ax is None:\n",
    "        ax = plt.gca()\n",
    "    if subject_electrodes is None:\n",
    "        subject_electrodes = {}\n",
    "\n",
    "    ax.set_aspect('equal')\n",
    "    ax.imshow(left_hem_img if hemisphere == \"left\" else right_hem_img)\n",
    "    ax.axis('off')\n",
    "\n",
    "    hem_index = 1 if hemisphere == \"left\" else 0\n",
    "    selected = scaled_coords_df[scaled_coords_df['Hemisphere'] == hem_index]\n",
    "\n",
    "    # Index the coordinates by (subject, electrode) once instead of filtering the\n",
    "    # table for every electrode; setdefault keeps the first row on duplicates.\n",
    "    coords = {}\n",
    "    for subj, elec, x, y in zip(selected['Subject'], selected['Electrode'], selected['X'], selected['Y']):\n",
    "        coords.setdefault((subj, elec), (x, y))\n",
    "\n",
    "    xs, ys, colors = [], [], []\n",
    "    for subject, electrodes in subject_electrodes.items():\n",
    "        for electrode, metrics in electrodes.items():\n",
    "            xy = coords.get((subject, electrode))\n",
    "            if xy is not None:\n",
    "                xs.append(xy[0])\n",
    "                ys.append(xy[1])\n",
    "                colors.append(metrics['roc_auc'])\n",
    "\n",
    "    if not colors:\n",
    "        # No electrode of this hemisphere has results: nothing to scatter.\n",
    "        return\n",
    "\n",
    "    # Plot in ascending score order (stable for ties), so points with larger\n",
    "    # values are passed to scatter last.\n",
    "    order = np.argsort(colors, kind=\"stable\")\n",
    "    colors = np.array(colors)[order]\n",
    "    xs = np.array(xs)[order]\n",
    "    ys = np.array(ys)[order]\n",
    "    print(max(colors))\n",
    "    sc = ax.scatter(xs, ys,\n",
    "                    s=marker_size,\n",
    "                    c=colors,\n",
    "                    edgecolor=\"black\",\n",
    "                    vmin=vmin, vmax=vmax)\n",
    "    if cbar:\n",
    "        divider = make_axes_locatable(ax)\n",
    "        cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
    "        cax.tick_params(labelsize=25)\n",
    "        plt.colorbar(sc, cax=cax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49f09546",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_hemisphere(electrodes, half=\"left\"):\n",
    "    \"\"\"Show one hemisphere with its electrodes in a standalone 15x15 inch figure.\n",
    "\n",
    "    Args:\n",
    "        electrodes: per-subject electrode results, passed to plot_hemisphere_axis.\n",
    "        half: which hemisphere to draw, \"left\" or \"right\".\n",
    "    \"\"\"\n",
    "    figure = plt.figure(figsize=(15, 15))\n",
    "    axis = figure.subplots(1, 1)\n",
    "    plot_hemisphere_axis(ax=axis, hemisphere=half, subject_electrodes=electrodes)\n",
    "    figure.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a8c5e2d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_hemisphere(all_superlet_results, \"right\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e70fd46e",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_hemisphere(all_linear_results, \"right\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "638da4aa",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_hemisphere(all_superlet_results, \"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "75634ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_hemisphere(all_linear_results, \"left\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d84b6f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Side-by-side comparison: linear results (right, left) followed by superlet results (right, left).\n",
    "fig, axs = plt.subplots(1, 4)\n",
    "fig.set_size_inches(40, 20)\n",
    "panel_specs = (\n",
    "    (all_linear_results, \"right\"),\n",
    "    (all_linear_results, \"left\"),\n",
    "    (all_superlet_results, \"right\"),\n",
    "    (all_superlet_results, \"left\"),\n",
    ")\n",
    "for panel_idx, (panel_results, panel_side) in enumerate(panel_specs):\n",
    "    # Only the last panel carries the colorbar.\n",
    "    plot_hemisphere_axis(ax=axs[panel_idx], subject_electrodes=panel_results, hemisphere=panel_side,\n",
    "                         cbar=(panel_idx == len(panel_specs) - 1))\n",
    "plt.subplots_adjust(wspace=0.05, hspace=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30532a05",
   "metadata": {},
   "outputs": [],
   "source": [
    "#TODO, get superior temporal lobe percentage\n",
    "all_linear_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f560fdc8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import region_utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bec52dc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "def clean_name(x):\n",
    "    \"\"\"Return the electrode name with every '#' and '*' character removed.\"\"\"\n",
    "    return x.replace(\"#\",\"\").replace(\"*\",\"\")\n",
    "\n",
    "# Load a region table for every subject that appears in either result set, so that\n",
    "# region lookups work for both models and not only for the linear results.\n",
    "all_region_dfs = {}\n",
    "for s in sorted(set(all_linear_results) | set(all_superlet_results)):\n",
    "    subject_regions = region_utils.get_regions_file(s).set_index(\"Electrode\")\n",
    "    # Result files are keyed by cleaned names, so clean the index the same way.\n",
    "    subject_regions.index = [clean_name(x) for x in subject_regions.index]\n",
    "    all_region_dfs[s] = subject_regions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc51c442",
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mean_roc(results, region=None):\n",
    "    all_rocs = []\n",
    "    for sub, elecs in results.items():\n",
    "        for e, res in elecs.items():\n",
    "#             print(sub)\n",
    "#             print(all_region_dfs[sub].index.tolist())\n",
    "            region_name = all_region_dfs[sub].loc[e].DesikanKilliany\n",
    "#             print(all_region_dfs[sub])\n",
    "            region_name = region_name.split(\"-\")[-1]\n",
    "#             print(region_name)\n",
    "            if region is None or region_name==region:\n",
    "                all_rocs.append(res[\"roc_auc\"])\n",
    "    return np.mean(all_rocs), np.std(all_rocs)\n",
    "\n",
    "def fmt_results(mean_std):\n",
    "    mean, std = mean_std\n",
    "    return (f'{mean:.2f} \\pm {std:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8874d8f0",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"all_linear_results\", fmt_results(get_mean_roc(all_linear_results)))\n",
    "print(\"all_superlet_results\", fmt_results(get_mean_roc(all_superlet_results)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baa44bd0",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3448f11b",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"all_linear_results\", fmt_results(get_mean_roc(all_linear_results, region=\"superiortemporal\")))\n",
    "print(\"all_superlet_results\", fmt_results(get_mean_roc(all_superlet_results, region=\"superiortemporal\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ec8f226",
   "metadata": {},
   "outputs": [],
   "source": [
    "sum([len(es) for s,es in all_superlet_results.items()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69e73222",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
