{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Process EEG data  \n",
    "Load EEG data from `.mat` files and process each sample by:\n",
    "1. temporal downsampling from 500 Hz to 128 Hz;\n",
    "2. zero-padding to 1280 time points and 128 channels;\n",
    "3. excluding invalid samples (non-finite values, or a raw duration outside 0.5-10 s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import scipy\n",
    "import h5py\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "import rich.progress as rp\n",
    "from typing import Literal\n",
    "\n",
    "def mat2df_zuco(dataset_name: Literal['ZuCo1','ZuCo2'],\n",
    "                eeg_src_dir: os.PathLike, \n",
    "                task_dir_names: list[str],\n",
    "                task_keys: list[str],\n",
    "                subject_keys: list[str],\n",
    "                n_sentences: list[int], # expected sentences per subject file, see zuco paper\n",
    "                src_sample_rate = 500,\n",
    "                tgt_sample_rate = 128, \n",
    "                tgt_max_len = 1280, \n",
    "                tgt_width = 128,\n",
    "                ) -> pd.DataFrame:\n",
    "    \"\"\"Load the sentence-level EEG of one ZuCo dataset into a DataFrame.\n",
    "\n",
    "    For every task directory and every subject `.mat` file found in it, each\n",
    "    sentence's raw EEG is validated, resampled from `src_sample_rate` to\n",
    "    `tgt_sample_rate` and zero-padded to `(tgt_width, tgt_max_len)`.\n",
    "    A rich progress bar reports recorded / dropped counts per task.\n",
    "\n",
    "    Args:\n",
    "        dataset_name: 'ZuCo1' (files read with `scipy.io.loadmat`) or\n",
    "            'ZuCo2' (files read with `h5py`).\n",
    "        eeg_src_dir: Dataset root; `.mat` files are searched in\n",
    "            `<eeg_src_dir>/<task_dir_name>/Matlab files`.\n",
    "        task_dir_names: One directory name per task.\n",
    "        task_keys: Value stored in the `task` column, aligned with\n",
    "            `task_dir_names`.\n",
    "        subject_keys: Expected subject ids; every task directory must hold\n",
    "            exactly one `.mat` file per subject.\n",
    "        n_sentences: Expected number of sentences in each subject file,\n",
    "            aligned with `task_dir_names`.\n",
    "        src_sample_rate: Sampling rate of the raw recordings (Hz).\n",
    "        tgt_sample_rate: Sampling rate after resampling (Hz).\n",
    "        tgt_max_len: Number of time points every sample is padded to.\n",
    "        tgt_width: Number of channels every sample is padded to.\n",
    "\n",
    "    Returns:\n",
    "        DataFrame with one row per kept sentence and the columns `eeg`\n",
    "        (array of shape `(tgt_max_len, tgt_width)`), `mask` (int8 array of\n",
    "        length `tgt_max_len`; 1 = real data, 0 = padding), `text`,\n",
    "        `dataset`, `task` and `subject`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If `dataset_name` is not 'ZuCo1' or 'ZuCo2'.\n",
    "        AssertionError: If the files on disk do not match the expected\n",
    "            subjects / sentence counts, or the last raw channel is not empty.\n",
    "    \"\"\"\n",
    "    if dataset_name not in ('ZuCo1', 'ZuCo2'):\n",
    "        # fail early instead of hitting an undefined variable further down\n",
    "        raise ValueError(f'unsupported dataset_name: {dataset_name!r}')\n",
    "    assert len(task_dir_names) == len(task_keys) == len(n_sentences), (\n",
    "        'task_dir_names, task_keys and n_sentences must have the same length')\n",
    "    assert tgt_sample_rate <= src_sample_rate\n",
    "\n",
    "    n_subjects = len(subject_keys)\n",
    "    n_records_expected = [x * n_subjects for x in n_sentences] \n",
    "    \n",
    "    with rp.Progress(rp.SpinnerColumn(),\n",
    "            rp.TextColumn(\"[progress.description]{task.description}\"),\n",
    "            rp.BarColumn(),\n",
    "            rp.TaskProgressColumn(),\n",
    "            \"•\", rp.TextColumn((\"Total: {task.total} Recorded: {task.fields[n_recorded]} \"\n",
    "                                \"Dropped: {task.fields[dropped_records]} ({task.fields[drop_rate]:.2f}%)\")),\n",
    "            \"•\", rp.TimeElapsedColumn()) as prep:\n",
    "        records = []\n",
    "        dropped_lens = []  # raw lengths of the samples dropped by the length filter\n",
    "        for task_dir_name, task_key, n_expected, n_total in zip(\n",
    "                task_dir_names, task_keys, n_sentences, n_records_expected): # iterate over tasks\n",
    "\n",
    "            # os.fspath lets both str and pathlib.Path roots work\n",
    "            mat_dir = os.path.join(os.fspath(eeg_src_dir), task_dir_name, 'Matlab files')\n",
    "            mat_paths = sorted(glob(os.path.join(mat_dir, '*.mat'), recursive=False))\n",
    "            assert len(mat_paths) == n_subjects, (\n",
    "                f'{task_dir_name}: expected {n_subjects} .mat files (one per subject), '\n",
    "                f'found {len(mat_paths)}')\n",
    "            n_recorded = 0\n",
    "            dropped_records = 0\n",
    "            drop_rate = 0\n",
    "            task_proc = prep.add_task(f'Proc {task_key}...', \n",
    "                                      total=n_total, \n",
    "                                      n_recorded=n_recorded,\n",
    "                                      dropped_records=dropped_records, \n",
    "                                      drop_rate=drop_rate)\n",
    "\n",
    "            for mat_path in mat_paths: # iterate over subjects\n",
    "                subject_key = os.path.basename(mat_path).split('_')[0].replace('results','').strip()\n",
    "                assert subject_key in subject_keys\n",
    "                mat = None  # HDF5 handle, only opened for ZuCo2; closed in `finally`\n",
    "                try:\n",
    "                    if dataset_name == 'ZuCo1':\n",
    "                        task_records = scipy.io.loadmat(mat_path, squeeze_me=True, \n",
    "                                                        struct_as_record=False)['sentenceData']\n",
    "                        n = len(task_records)\n",
    "                    else:  # 'ZuCo2'\n",
    "                        mat = h5py.File(mat_path, 'r')\n",
    "                        n = len(mat['sentenceData']['rawData'])\n",
    "                    assert n_expected == n, (\n",
    "                        f'the actual num of sentences ({n}) does not match the expectation ({n_expected})')\n",
    "\n",
    "                    for j in range(n): \n",
    "                        if dataset_name == 'ZuCo1':\n",
    "                            eeg_raw = task_records[j].rawData  # the raw sentence-level EEG time-series\n",
    "                            text_raw = task_records[j].content\n",
    "                        else:  # 'ZuCo2': entries are HDF5 references that must be dereferenced\n",
    "                            eeg_raw = mat[mat['sentenceData']['rawData'][j][0]][:].T.astype(np.float32)\n",
    "                            # the sentence is stored as character codes\n",
    "                            text_raw = ''.join(chr(int(k)) for k in mat[mat['sentenceData']['content'][j][0]][:].squeeze())\n",
    "\n",
    "                        if not np.all(np.isfinite(eeg_raw)):\n",
    "                            # exclude nan/inf eeg samples\n",
    "                            dropped_records += 1\n",
    "                        else:\n",
    "                            # NOTE: the last channel is expected to be all empty, so it is discarded\n",
    "                            assert not eeg_raw[-1].any(), 'the last EEG channel is expected to be empty'\n",
    "                            eeg104 = eeg_raw[:-1, :]  # (104, x) according to the original author's note\n",
    "\n",
    "                            width, len_raw = eeg104.shape\n",
    "                            len_new = int(len_raw * tgt_sample_rate / src_sample_rate)\n",
    "                            # keep recordings lasting 0.5 s to 10 s; the extra `len_new` check\n",
    "                            # guards np.pad against a negative pad width when `tgt_max_len`\n",
    "                            # is smaller than a 10 s recording needs (never true by default)\n",
    "                            if (len_raw < 0.5*src_sample_rate\n",
    "                                    or len_raw > 10*src_sample_rate\n",
    "                                    or len_new > tgt_max_len):\n",
    "                                dropped_records += 1\n",
    "                                dropped_lens.append(len_raw)\n",
    "                            else:\n",
    "                                eeg = scipy.signal.resample(eeg104, len_new, axis=1)\n",
    "                                eeg = np.pad(eeg, ((0, tgt_width - width), (0, tgt_max_len - len_new)), \n",
    "                                             'constant', constant_values=0)\n",
    "                                mask = np.zeros(tgt_max_len, dtype=np.int8) \n",
    "                                mask[:len_new] = 1  # 1 for `not masked`, 0 for `masked`\n",
    "                                records.append({\n",
    "                                                'eeg': eeg.T, \n",
    "                                                'mask': mask,\n",
    "                                                'text': text_raw,\n",
    "                                                'dataset': dataset_name,\n",
    "                                                'task': task_key,\n",
    "                                                'subject': subject_key\n",
    "                                                })\n",
    "                                n_recorded += 1\n",
    "\n",
    "                        # advance for dropped samples too, so the bar can reach its total\n",
    "                        drop_rate = (dropped_records / (n_recorded + dropped_records)) * 100\n",
    "                        prep.update(task_proc, advance=1, n_recorded=n_recorded, \n",
    "                                    dropped_records=dropped_records, drop_rate=drop_rate)\n",
    "                finally:\n",
    "                    if mat is not None:\n",
    "                        mat.close()\n",
    "    print(f'Done! {len(records)} / {sum(n_records_expected)} are recorded!')      \n",
    "    print(f'{len(dropped_lens)} / {sum(n_records_expected)-len(records)} are dropped due to the length!') \n",
    "    df = pd.DataFrame(records) \n",
    "    return df\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ZuCo1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ref https://www.nature.com/articles/sdata2018291/tables/4  \n",
    "ref https://osf.io/q3zws/wiki/home/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a673b86f27354957846c38e4c630ac95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done! 11113 / 13284 are recorded!\n",
      "1357 / 2171 are dropped due to the length!\n"
     ]
    }
   ],
   "source": [
    "# Root directory of the raw downloads; the ZuCo2 cell below reuses it.\n",
    "data_dir = './data/raw_data'\n",
    "\n",
    "# ZuCo1: three tasks, twelve subjects.\n",
    "df_zuco1 = mat2df_zuco(\n",
    "    dataset_name='ZuCo1',\n",
    "    eeg_src_dir=data_dir + '/ZuCo1',\n",
    "    task_dir_names=['task1- SR', 'task2 - NR', 'task3 - TSR'],\n",
    "    task_keys=['task1', 'task2', 'task3'],\n",
    "    subject_keys=[\n",
    "        'ZAB', 'ZDM', 'ZDN', 'ZGW', 'ZJM', 'ZJN',\n",
    "        'ZJS', 'ZKB', 'ZKH', 'ZKW', 'ZMG', 'ZPH',\n",
    "    ],\n",
    "    n_sentences=[400, 300, 407],  # expected sentences per subject file, one entry per task\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ZuCo2  \n",
    "ref https://aclanthology.org/2020.lrec-1.18.pdf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "347af4a73b234bf4b08da544bb945d90",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Done! 11222 / 13302 are recorded!\n",
      "1281 / 2080 are dropped due to the length!\n"
     ]
    }
   ],
   "source": [
    "# ZuCo2: two tasks, eighteen subjects.\n",
    "# NOTE: the ZuCo2 folders 'task1 - NR' / 'task2 - TSR' are keyed as 'task2' / 'task3'\n",
    "# so that NR and TSR carry the same task key as in the ZuCo1 cell.\n",
    "df_zuco2 = mat2df_zuco(\n",
    "    dataset_name='ZuCo2',\n",
    "    eeg_src_dir=data_dir + '/ZuCo2',\n",
    "    task_dir_names=['task1 - NR', 'task2 - TSR'],\n",
    "    task_keys=['task2', 'task3'],\n",
    "    subject_keys=[\n",
    "        'YAC', 'YAG', 'YAK', 'YDG', 'YDR', 'YFR',\n",
    "        'YFS', 'YHS', 'YIS', 'YLS', 'YMD', 'YMS',\n",
    "        'YRH', 'YRK', 'YRP', 'YSD', 'YSL', 'YTL',\n",
    "    ],\n",
    "    n_sentences=[349, 390],  # expected sentences per subject file, one entry per task\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(22335, 6)\n",
      "Index(['eeg', 'mask', 'text', 'dataset', 'task', 'subject'], dtype='object')\n"
     ]
    }
   ],
   "source": [
    "# Stack both datasets into one frame with a fresh 0..N-1 index,\n",
    "# then report its shape and column names (one per line).\n",
    "df = pd.concat((df_zuco1, df_zuco2), ignore_index=True)\n",
    "print(df.shape, df.columns, sep='\\n')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the merged frame; the target directory must already exist.\n",
    "df.to_pickle('./data/tmp/zuco_eeg_128ch_1280len.df')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "glim-clone",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
