{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys\n",
    "sys.path.append('../')\n",
    "sys.path.append('../../')\n",
    "import numpy as np\n",
    "import h5py\n",
    "import pickle\n",
    "from data_utils import getSeizureTimes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# File markers for EEG clips"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "RESAMPLE_DIR = '/media/nvme_data/siyitang/TUH_eeg_seq_v1.5.2/resampled_signal/'\n",
    "#RESAMPLE_DIR = '/home/siyitang/data/TUH_v1.5.2/TUH_eeg_seq_v1.5.2/resampled_signal'\n",
    "CLIP_LEN = 12\n",
    "TIME_STEP_SIZE = 1\n",
    "STRIDE = CLIP_LEN\n",
    "FREQUENCY = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "5179\n928\n1192\n"
     ]
    }
   ],
   "source": [
    "# Map each split name to the basenames of the files listed in its\n",
    "# '<split>Set_seizureDetect_files.txt' manifest.\n",
    "FILES_TO_CONSIDER = {}\n",
    "for split in ['train', 'dev', 'test']:\n",
    "    with open(split + 'Set_seizureDetect_files.txt', 'r') as f:\n",
    "        # The first comma-separated field is a path; keep only its last component.\n",
    "        FILES_TO_CONSIDER[split] = [\n",
    "            line.strip('\\n').split(',')[0].split('/')[-1] for line in f.readlines()\n",
    "        ]\n",
    "\n",
    "# Report how many files each split contains.\n",
    "for split_name in ['train', 'dev', 'test']:\n",
    "    print(len(FILES_TO_CONSIDER[split_name]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "['00008295_s009_t007.edf',\n",
       " '00010591_s001_t000.edf',\n",
       " '00010489_s008_t007.edf',\n",
       " '00012262_s007_t001.edf',\n",
       " '00004294_s001_t000.edf']"
      ]
     },
     "metadata": {},
     "execution_count": 4
    }
   ],
   "source": [
    "FILES_TO_CONSIDER['train'][:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "RAW_DATA_DIR = \"/media/nvme_data/TUH/v1.5.2/edf/\"\n",
    "#RAW_DATA_DIR = \"/data/crypt/eegdbs/temple/tuh_eeg_seizure/v1.5.2/edf/\"\n",
    "\n",
    "# Recursively collect the full path of every EDF file under RAW_DATA_DIR.\n",
    "# Match on the file extension rather than on a substring, so that names which\n",
    "# merely contain \".edf\" (e.g. \"x.edf.bak\") are not picked up.\n",
    "edf_files = [\n",
    "    os.path.join(path, name)\n",
    "    for path, subdirs, files in os.walk(RAW_DATA_DIR)\n",
    "    for name in files\n",
    "    if name.endswith(\".edf\")\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "VARIABLE_LENGTH = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.random.seed(123)\n",
    "\n",
    "# Index the raw EDF paths by file name once, so that each resampled file is\n",
    "# resolved by an exact-name lookup instead of a substring scan over all paths.\n",
    "edf_paths_by_name = {}\n",
    "for edf_path in edf_files:\n",
    "    edf_paths_by_name.setdefault(os.path.basename(edf_path), []).append(edf_path)\n",
    "\n",
    "clip_samples = CLIP_LEN * FREQUENCY      # samples per clip\n",
    "stride_samples = STRIDE * FREQUENCY      # samples between consecutive clip starts\n",
    "\n",
    "resampled_files = os.listdir(RESAMPLE_DIR)\n",
    "for split in ['train', 'dev', 'test']:\n",
    "    # Marker file that will list every clip of this split.\n",
    "    if VARIABLE_LENGTH:\n",
    "        # open(..., 'w') does not create missing directories.\n",
    "        os.makedirs(\"variable_length\", exist_ok=True)\n",
    "        filemarker = os.path.join(\"variable_length\",\n",
    "                split+\"_cliplen\"+str(CLIP_LEN)+\"_stride\"+str(STRIDE)+\"_timestep\"+str(TIME_STEP_SIZE)+\".txt\")\n",
    "    else:\n",
    "        filemarker = split+\"_cliplen\"+str(CLIP_LEN)+\"_stride\"+str(STRIDE)+\".txt\"\n",
    "\n",
    "    files_in_split = set(FILES_TO_CONSIDER[split])  # constant-time membership tests\n",
    "    write_str = []\n",
    "    for h5_fn in resampled_files:\n",
    "        edf_fn = h5_fn.split('.h5')[0]+'.edf'\n",
    "        if edf_fn not in files_in_split:\n",
    "            continue\n",
    "\n",
    "        edf_matches = edf_paths_by_name.get(edf_fn, [])\n",
    "        if not edf_matches:\n",
    "            # Fail with an explicit message instead of an IndexError on [0].\n",
    "            raise FileNotFoundError(\"{} found 0 times!\".format(edf_fn))\n",
    "        if len(edf_matches) > 1:\n",
    "            # Ambiguous: report it and fall back to the first match.\n",
    "            print(\"{} found {} times!\".format(edf_fn, len(edf_matches)))\n",
    "            print(edf_matches)\n",
    "        edf_fn_full = edf_matches[0]\n",
    "\n",
    "        # getSeizureTimes is passed the recording path without its '.edf' extension.\n",
    "        seizure_times = getSeizureTimes(os.path.splitext(edf_fn_full)[0])\n",
    "        # Seizure intervals converted to sample indices, computed once per file.\n",
    "        seizure_spans = [(int(t[0] * FREQUENCY), int(t[1] * FREQUENCY)) for t in seizure_times]\n",
    "\n",
    "        with h5py.File(os.path.join(RESAMPLE_DIR, h5_fn), 'r') as hf:\n",
    "            # Only the length is needed: read the dataset shape instead of\n",
    "            # loading the whole signal into memory.\n",
    "            sig_len = hf[\"resampled_signal\"].shape[-1]\n",
    "\n",
    "        num_clips = (sig_len - clip_samples) // stride_samples + 1\n",
    "        if VARIABLE_LENGTH:\n",
    "            # One extra, possibly shorter, clip at the end of the recording.\n",
    "            num_clips += 1\n",
    "\n",
    "        for i in range(num_clips):\n",
    "            start_window = i * stride_samples\n",
    "            end_window = min(start_window + clip_samples, sig_len)\n",
    "\n",
    "            # Variable-length mode: drop the trailing clip when it is shorter\n",
    "            # than one time step (TIME_STEP_SIZE seconds).\n",
    "            if VARIABLE_LENGTH and (i == num_clips-1) and (end_window - start_window) < (TIME_STEP_SIZE * FREQUENCY):\n",
    "                break\n",
    "\n",
    "            # A clip is labelled 1 if its window overlaps any seizure interval\n",
    "            # (boundaries inclusive), otherwise 0.\n",
    "            is_seizure = int(any(\n",
    "                (end_window >= start_t) and (start_window <= end_t)\n",
    "                for start_t, end_t in seizure_spans))\n",
    "            # Record format: '<edf file>,<clip index>,<label>'\n",
    "            write_str.append(edf_fn + ',' + str(i) + ',' + str(is_seizure) + '\\n')\n",
    "\n",
    "    np.random.shuffle(write_str)\n",
    "    with open(filemarker, 'w') as f:\n",
    "        f.writelines(write_str)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get seizure/non-seizure balanced train set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Path of the *train* marker file. The name is built from the literal \"train\"\n",
    "# rather than the leftover loop variable `split`, which holds 'test' once the\n",
    "# earlier loops over ['train', 'dev', 'test'] have finished.\n",
    "if VARIABLE_LENGTH:\n",
    "    train_filemarker = os.path.join(\"variable_length\",\n",
    "                \"train_cliplen\"+str(CLIP_LEN)+\"_stride\"+str(STRIDE)+\"_timestep\"+str(TIME_STEP_SIZE)+\".txt\")\n",
    "else:\n",
    "    train_filemarker = \"train_cliplen\"+str(CLIP_LEN)+\"_stride\"+str(STRIDE)+\".txt\"\n",
    "\n",
    "# Each line is one clip record: '<edf file>,<clip index>,<label>'.\n",
    "with open(train_filemarker, 'r') as f:\n",
    "    train_str = f.readlines()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "13646\n183000\n"
     ]
    }
   ],
   "source": [
    "# Parse every '<edf file>,<clip index>,<label>' record, then partition the\n",
    "# clips by label (1 = seizure, anything else = non-seizure), keeping file order.\n",
    "parsed_clips = []\n",
    "for record in train_str:\n",
    "    edf_name, clip_no, label = record.strip('\\n').split(',')\n",
    "    parsed_clips.append((edf_name, clip_no, int(label)))\n",
    "\n",
    "sz_tuples = [clip for clip in parsed_clips if clip[2] == 1]\n",
    "nonsz_tuples = [clip for clip in parsed_clips if clip[2] != 1]\n",
    "\n",
    "print(len(sz_tuples))\n",
    "print(len(nonsz_tuples))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Keep all seizure clips and undersample the non-seizure clips to the same count"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "13646"
      ]
     },
     "metadata": {},
     "execution_count": 10
    }
   ],
   "source": [
    "np.random.seed(123)\n",
    "\n",
    "# Shuffle a copy so that `nonsz_tuples` keeps its original order: re-running\n",
    "# this cell then selects exactly the same undersampled subset every time.\n",
    "nonsz_shuffled = list(nonsz_tuples)\n",
    "np.random.shuffle(nonsz_shuffled)\n",
    "\n",
    "# Keep as many non-seizure clips as there are seizure clips.\n",
    "nonsz_tuple_small = nonsz_shuffled[:len(sz_tuples)]\n",
    "\n",
    "len(nonsz_tuple_small)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "balanced_files = sz_tuples + nonsz_tuple_small\n",
    "np.random.shuffle(balanced_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "output_type": "execute_result",
     "data": {
      "text/plain": [
       "[('00007584_s004_t000.edf', '70', 1),\n",
       " ('00002991_s004_t006.edf', '2', 0),\n",
       " ('00007032_s008_t000.edf', '146', 1),\n",
       " ('00012940_s001_t003.edf', '173', 1),\n",
       " ('00010455_s003_t001.edf', '45', 1)]"
      ]
     },
     "metadata": {},
     "execution_count": 12
    }
   ],
   "source": [
    "balanced_files[:5]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The output name is the train marker name with a \"_balanced\" suffix.\n",
    "marker_stem = \"train_cliplen\"+str(CLIP_LEN)+\"_stride\"+str(STRIDE)\n",
    "if VARIABLE_LENGTH:\n",
    "    balanced_train_filemarker = os.path.join(\n",
    "        \"variable_length\", marker_stem+\"_timestep\"+str(TIME_STEP_SIZE)+\"_balanced.txt\")\n",
    "else:\n",
    "    balanced_train_filemarker = marker_stem+\"_balanced.txt\"\n",
    "\n",
    "# One '<edf file>,<clip index>,<label>' line per clip.\n",
    "with open(balanced_train_filemarker, \"w\") as f:\n",
    "    f.writelines(\n",
    "        clip[0] + ',' + str(clip[1]) + ',' + str(clip[2]) + '\\n' for clip in balanced_files)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Get `pos_weight` to weigh the loss function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 417,
   "metadata": {},
   "outputs": [],
   "source": [
    "from data.dataloader import load_dataset\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 320,
   "metadata": {},
   "outputs": [],
   "source": [
    "## on gemini\n",
    "RAW_DATA_DIR = \"/media/nvme_data/TUH/v1.5.2/\"\n",
    "PREPROC_DIR = \"/media/nvme_data/siyitang/TUH_eeg_seq_v1.5.2/resampled_signal\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "USE_FFT = True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataloaders = load_dataset(input_dir=PREPROC_DIR, \n",
    "                           raw_data_dir=RAW_DATA_DIR, \n",
    "                           train_batch_size=64, \n",
    "                           test_batch_size=64,\n",
    "                           clip_len=CLIP_LEN, \n",
    "                           time_step_size=TIME_STEP_SIZE, \n",
    "                           stride=STRIDE,\n",
    "                           standardize=False, \n",
    "                           num_workers=8, \n",
    "                           augmentation=True,\n",
    "                           use_fft=USE_FFT,\n",
    "                          balance_train=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "y_train = []\n",
    "x_train = []\n",
    "file_name_train = []\n",
    "for x, y, _, _, _, file_name in dataloaders['train']:\n",
    "    y_train.append(y)\n",
    "    x_train.append(x)\n",
    "    file_name_train.extend(file_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1152, 60, 19, 100)"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train = torch.cat(x_train, dim=0)\n",
    "x_train = x_train.data.cpu().numpy()\n",
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1152, 60)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_train = torch.cat(y_train, dim=0)\n",
    "y_train = y_train.data.cpu().numpy()\n",
    "y_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1152,)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y_single = np.sum(y_train, axis=-1)\n",
    "y_single.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7837"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_clip_idxs = (y_single != 0)\n",
    "pos_timesteps = np.sum(y_train[pos_clip_idxs,:] == 1)\n",
    "pos_timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "26723"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_clip_neg_timesteps = np.sum(y_train[pos_clip_idxs,:] == 0)\n",
    "pos_clip_neg_timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "34560"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neg_clip_idxs = (y_single == 0)\n",
    "neg_clip_neg_timesteps = np.sum(y_train[neg_clip_idxs,:] == 0)\n",
    "neg_clip_neg_timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total positive time steps: 7837\n",
      "Total negative time steps: 61283\n"
     ]
    }
   ],
   "source": [
    "print(\"Total positive time steps:\", pos_timesteps)\n",
    "print(\"Total negative time steps:\", neg_clip_neg_timesteps+pos_clip_neg_timesteps)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.8197014163583"
      ]
     },
     "execution_count": 29,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pos_weight = (neg_clip_neg_timesteps+pos_clip_neg_timesteps) / pos_timesteps\n",
    "pos_weight"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute mean & std of train set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1152, 60, 19, 100)"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_train.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean: 3.882, Std: 1.566\n"
     ]
    }
   ],
   "source": [
    "mean = np.mean(x_train)\n",
    "std = np.std(x_train)\n",
    "print(\"Mean: {:.3f}, Std: {:.3f}\".format(mean, std))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pickle the train-set mean and std; files computed with USE_FFT get an\n",
    "# \"_fft\" marker in their name.\n",
    "stats_suffix = \"_fft.pkl\" if USE_FFT else \".pkl\"\n",
    "for stat_prefix, stat_value in ((\"./mean_cliplen\", mean), (\"./std_cliplen\", std)):\n",
    "    with open(stat_prefix+str(CLIP_LEN)+\"_stride\"+str(STRIDE)+stats_suffix, \"wb\") as pf:\n",
    "        pickle.dump(stat_value, pf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "name": "python3613jvsc74a57bd07e2a9fa3d96e0b0167b7c016fb778693fe54466e00460bfa005b3c271472f290",
   "display_name": "Python 3.6.13 64-bit ('lssl': conda)"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  },
  "metadata": {
   "interpreter": {
    "hash": "7e2a9fa3d96e0b0167b7c016fb778693fe54466e00460bfa005b3c271472f290"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}