{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fa164e7-70e4-478d-b996-ff00d141f4d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import os\n",
    "import mne\n",
    "mne.set_log_level('ERROR')\n",
    "\n",
    "from warnings import filterwarnings\n",
    "filterwarnings('ignore')\n",
    "\n",
    "\n",
    "from IPython.utils import io\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "from braindecode.datautil.windowers import create_fixed_length_windows\n",
    "from braindecode.datautil.serialization import  load_concat_dataset\n",
    "\n",
    "from braindecode.datasets import BaseConcatDataset\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf8bc02b-44b9-4b06-8d39-8feb427a376f",
   "metadata": {},
   "source": [
    "## Data Loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0ee9a6-ef32-42f4-910a-ada40df36654",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time \n",
    "from braindecode.datasets.tuh import TUHAbnormal\n",
    "\n",
    "# Root directory of the TUH Abnormal EEG corpus (v2.0.0, EDF files).\n",
    "data_path = '/data/datasets/TUH/EEG/tuh_eeg_abnormal/v2.0.0/edf/'\n",
    "# Target description column; options noted here: 'age', 'gender', 'pathology'.\n",
    "# It must be defined before TUHAbnormal is constructed below.\n",
    "# NOTE(review): 'pathology' assumed for abnormal-EEG decoding - confirm.\n",
    "target_name = 'pathology'\n",
    "\n",
    "dataset = TUHAbnormal(\n",
    "    path=data_path,\n",
    "    recording_ids=None,  # None: no subset selected (presumably all recordings)\n",
    "    target_name=target_name,\n",
    "    preload=False,  # signals are not loaded up front\n",
    "    add_physician_reports=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c973441-97af-4110-ab11-de5a3e9767eb",
   "metadata": {},
   "source": [
    "## Data Preprocessing and Saving\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c0f511-18a5-42c3-9166-e49ce16742e7",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "from braindecode.preprocessing import preprocess, Preprocessor, scale as multiply\n",
    "import numpy as np\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "# Split once on the 'train' description column; the resulting dict is keyed\n",
    "# by the stringified column values ('True' / 'False').\n",
    "splits = dataset.split('train')\n",
    "whole_train_set = splits['True']  # recordings flagged train=True\n",
    "whole_eval_set = splits['False']  # remaining (evaluation) recordings\n",
    "\n",
    "short_ch_names = sorted([\n",
    "                'A1', 'A2', 'C3', 'C4', 'Cz', 'F3', 'F4', 'F7', 'F8',\n",
    "                'Fp1', 'Fp2', 'Fz', 'O1', 'O2', 'P3', 'P4', 'Pz', 'T3',\n",
    "                 'T4', 'T5', 'T6'\n",
    "            ])\n",
    "ar_ch_names = sorted([\n",
    "    'EEG A1-REF', 'EEG A2-REF',\n",
    "    'EEG FP1-REF', 'EEG FP2-REF', 'EEG F3-REF', 'EEG F4-REF', 'EEG C3-REF',\n",
    "    'EEG C4-REF', 'EEG P3-REF', 'EEG P4-REF', 'EEG O1-REF', 'EEG O2-REF',\n",
    "    'EEG F7-REF', 'EEG F8-REF', 'EEG T3-REF', 'EEG T4-REF', 'EEG T5-REF',\n",
    "    'EEG T6-REF', 'EEG FZ-REF', 'EEG CZ-REF', 'EEG PZ-REF'])\n",
    "le_ch_names = sorted([\n",
    "    'EEG A1-LE', 'EEG A2-LE',\n",
    "    'EEG FP1-LE', 'EEG FP2-LE', 'EEG F3-LE', 'EEG F4-LE', 'EEG C3-LE',\n",
    "    'EEG C4-LE', 'EEG P3-LE', 'EEG P4-LE', 'EEG O1-LE', 'EEG O2-LE',\n",
    "    'EEG F7-LE', 'EEG F8-LE', 'EEG T3-LE', 'EEG T4-LE', 'EEG T5-LE',\n",
    "    'EEG T6-LE', 'EEG FZ-LE', 'EEG CZ-LE', 'EEG PZ-LE'])\n",
    "assert len(short_ch_names) == len(ar_ch_names) == len(le_ch_names)\n",
    "ar_ch_mapping = {ch_name: short_ch_name for ch_name, short_ch_name in zip(\n",
    "    ar_ch_names, short_ch_names)}\n",
    "le_ch_mapping = {ch_name: short_ch_name for ch_name, short_ch_name in zip(\n",
    "    le_ch_names, short_ch_names)}\n",
    "ch_mapping = {'ar': ar_ch_mapping, 'le': le_ch_mapping}\n",
    "\n",
    "\n",
    "\n",
    "def custom_rename_channels(raw, mapping):\n",
    "    # rename channels which are dependent on referencing:\n",
    "    # le: EEG 01-LE, ar: EEG 01-REF\n",
    "    # mne fails if the mapping contains channels as keys that are not present\n",
    "    # in the raw\n",
    "    reference = raw.ch_names[0].split('-')[-1].lower()\n",
    "    assert reference in ['le', 'ref'], 'unexpected referencing'\n",
    "    reference = 'le' if reference == 'le' else 'ar'\n",
    "    raw.rename_channels(mapping[reference])\n",
    "\n",
    "\n",
    "def custom_crop(raw, tmin=0.0, tmax=None, include_tmax=True):\n",
    "    # crop recordings to tmin – tmax. can be incomplete if recording\n",
    "    # has lower duration than tmax\n",
    "    # by default mne fails if tmax is bigger than duration\n",
    "    tmax = min((raw.n_times - 1) / raw.info['sfreq'], tmax)\n",
    "    raw.crop(tmin=tmin, tmax=tmax, include_tmax=include_tmax)\n",
    "\n",
    "\n",
    "# Keep data from minute 1 (tmin) up to minute n_max_minutes (tmax) of each\n",
    "# recording; both bounds are expressed in seconds.\n",
    "n_max_minutes=21\n",
    "tmin = 1 * 60\n",
    "tmax = n_max_minutes * 60\n",
    "sfreq = 100  # target sampling frequency passed to 'resample' below\n",
    "\n",
    "# Applied in list order to every recording; the order matters (e.g. channels\n",
    "# are renamed before they are picked by their short names).\n",
    "preprocessors = [\n",
    "    # Crop to [tmin, tmax]; custom_crop clamps tmax for shorter recordings.\n",
    "    Preprocessor(custom_crop, tmin=tmin, tmax=tmax, include_tmax=False,\n",
    "                 apply_on_array=False),\n",
    "\n",
    "    # Map the montage-specific TUH names to short names, then keep exactly\n",
    "    # short_ch_names in that (sorted) order.\n",
    "    Preprocessor(custom_rename_channels, mapping=ch_mapping,\n",
    "                 apply_on_array=False),\n",
    "    Preprocessor('pick_channels', ch_names=short_ch_names, ordered=True),\n",
    " \n",
    "    # Scale by 1e6 (presumably volts -> microvolts) and clip to +/-800 in the\n",
    "    # scaled unit, presumably to bound artifact amplitudes.\n",
    "    Preprocessor(multiply, factor=1e6, apply_on_array=True),\n",
    "    Preprocessor(np.clip, a_min=-800, a_max=800, apply_on_array=True),\n",
    "    \n",
    "    # Re-reference the EEG channels to their common average.\n",
    "    Preprocessor('set_eeg_reference', ref_channels='average', ch_type='eeg'),\n",
    "\n",
    "    Preprocessor('resample', sfreq=sfreq),\n",
    "    # NOTE(review): presumably clears the measurement date before saving - confirm.\n",
    "    Preprocessor('set_meas_date', meas_date=None)\n",
    "    \n",
    "]\n",
    "# The two options are alternatives: preprocess() modifies the datasets in\n",
    "# place (its return value is not used), so running the in-memory call AND the\n",
    "# saving call would apply every preprocessor twice (crop another minute,\n",
    "# scale by 1e6 again, ...). Choose exactly one with SAVE_PREPROCESSED.\n",
    "SAVE_PREPROCESSED = True  # False: preprocess in memory only, nothing is written\n",
    "N_JOBS = 4\n",
    "split_save_dirs = [\n",
    "    (whole_train_set, '/home/data/preprocessed_TUAB/final_train/'),\n",
    "    (whole_eval_set, '/home/data/preprocessed_TUAB/final_eval/'),\n",
    "]\n",
    "\n",
    "for split_ds, split_save_dir in split_save_dirs:\n",
    "    if SAVE_PREPROCESSED:\n",
    "        # Preprocess and save the dataset to split_save_dir.\n",
    "        preprocess(\n",
    "            concat_ds=split_ds,\n",
    "            preprocessors=preprocessors,\n",
    "            n_jobs=N_JOBS,\n",
    "            save_dir=split_save_dir,\n",
    "        )\n",
    "    else:\n",
    "        preprocess(split_ds, preprocessors)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d144d0a-496f-4681-974b-545bfda9ba85",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef3ad19c-cec1-4d7f-a0a4-e44a2ed0bc97",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
