{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sessions to process. Names follow 't12.YYYY.MM.DD', so lexical order is chronological\n",
    "# order; later cells refer to a session by its position in this sorted list (dayIdx).\n",
    "sessionNames = sorted([\n",
    "    't12.2022.04.28', 't12.2022.05.05', 't12.2022.05.17', 't12.2022.05.19',\n",
    "    't12.2022.05.24', 't12.2022.05.26', 't12.2022.06.02', 't12.2022.06.07',\n",
    "    't12.2022.06.14', 't12.2022.06.16', 't12.2022.06.21', 't12.2022.06.23',\n",
    "    't12.2022.06.28', 't12.2022.07.05', 't12.2022.07.14', 't12.2022.07.21',\n",
    "    't12.2022.07.27', 't12.2022.07.29', 't12.2022.08.02', 't12.2022.08.11',\n",
    "    't12.2022.08.13', 't12.2022.08.18', 't12.2022.08.23', 't12.2022.08.25',\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Channel-selection switches, checked ALL_CHANNELS first and then BROCA_CHANNELS, both when\n",
    "# loading features and when naming the output pickle. With both False, only the first\n",
    "# 128 channels of each feature type (area 6v) are used.\n",
    "ALL_CHANNELS, BROCA_CHANNELS = False, True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "\n",
    "# Grapheme-to-phoneme converter; getDataset uses it to turn transcripts into phonemes.\n",
    "g2p = G2p()\n",
    "# The 39 ARPAbet phonemes (no stress digits), alphabetical. A phoneme's class index is its\n",
    "# position in PHONE_DEF_SIL, which appends 'SIL' (inserted at word boundaries) last.\n",
    "PHONE_DEF = (\n",
    "    'AA AE AH AO AW AY B CH D DH EH ER EY F G HH IH IY JH K '\n",
    "    'L M N NG OW OY P R S SH T TH UH UW V W Y Z ZH'\n",
    ").split()\n",
    "PHONE_DEF_SIL = [*PHONE_DEF, 'SIL']\n",
    "\n",
    "def phoneToId(p):\n",
    "    \"\"\"Return the 0-based class index of phoneme string `p`; raises ValueError if unknown.\"\"\"\n",
    "    return PHONE_DEF_SIL.index(p)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Phoneme <-> 0-based class-index lookups over PHONE_DEF_SIL (including 'SIL').\n",
    "phoneToIdDict = {phone: classIdx for classIdx, phone in enumerate(PHONE_DEF_SIL)}\n",
    "idToPhone = dict(zip(phoneToIdDict.values(), phoneToIdDict.keys()))\n",
    "\n",
    "def idsToPhonemes(seqClassIDs, idToPhone = idToPhone):\n",
    "    \"\"\"\n",
    "    Decode a stored (padded) label sequence back into phoneme strings.\n",
    "\n",
    "    Stored labels are class index + 1, leaving 0 for padding; entries <= 0 are skipped\n",
    "    and the remaining ones are shifted back by one before lookup.\n",
    "\n",
    "    Args:\n",
    "        seqClassIDs (iterable of int): Stored labels, e.g. one entry of a dataset's 'phonemes'.\n",
    "        idToPhone (dict): Mapping from 0-based class index to phoneme string.\n",
    "\n",
    "    Returns:\n",
    "        list: The decoded phoneme sequence, in order.\n",
    "    \"\"\"\n",
    "    decoded = []\n",
    "    for label in seqClassIDs:\n",
    "        if label <= 0:\n",
    "            continue  # padding\n",
    "        decoded.append(idToPhone[label - 1])\n",
    "    return decoded"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: 'AA',\n",
       " 1: 'AE',\n",
       " 2: 'AH',\n",
       " 3: 'AO',\n",
       " 4: 'AW',\n",
       " 5: 'AY',\n",
       " 6: 'B',\n",
       " 7: 'CH',\n",
       " 8: 'D',\n",
       " 9: 'DH',\n",
       " 10: 'EH',\n",
       " 11: 'ER',\n",
       " 12: 'EY',\n",
       " 13: 'F',\n",
       " 14: 'G',\n",
       " 15: 'HH',\n",
       " 16: 'IH',\n",
       " 17: 'IY',\n",
       " 18: 'JH',\n",
       " 19: 'K',\n",
       " 20: 'L',\n",
       " 21: 'M',\n",
       " 22: 'N',\n",
       " 23: 'NG',\n",
       " 24: 'OW',\n",
       " 25: 'OY',\n",
       " 26: 'P',\n",
       " 27: 'R',\n",
       " 28: 'S',\n",
       " 29: 'SH',\n",
       " 30: 'T',\n",
       " 31: 'TH',\n",
       " 32: 'UH',\n",
       " 33: 'UW',\n",
       " 34: 'V',\n",
       " 35: 'W',\n",
       " 36: 'Y',\n",
       " 37: 'Z',\n",
       " 38: 'ZH',\n",
       " 39: 'SIL'}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 0-based class index -> phoneme table; labels stored in the datasets are these indices + 1.\n",
    "idToPhone"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io  # explicit submodule import: on older SciPy, `import scipy` alone does not expose scipy.io\n",
    "\n",
    "def loadFeaturesAndNormalize(sessionPath, all_channels=False, broca_channels=None):\n",
    "    \"\"\"\n",
    "    Load one session's neural features from a .mat file and z-score them block by block.\n",
    "\n",
    "    The file must provide 'sentenceText', 'blockIdx', and per-trial 'tx1' / 'spikePow'\n",
    "    matrices (read as dat[key][0, trial], one row per time step). For each trial the\n",
    "    selected columns of 'tx1' and 'spikePow' are concatenated side by side:\n",
    "      * all_channels=True   -> all columns of both;\n",
    "      * broca_channels=True -> columns 128 onward of both;\n",
    "      * otherwise           -> the first 128 columns of both (area 6v).\n",
    "    Each feature column is then normalized with the mean / std over all time steps of\n",
    "    all trials sharing the same 'blockIdx' value.\n",
    "\n",
    "    Args:\n",
    "        sessionPath (str): Path to the session .mat file.\n",
    "        all_channels (bool): Use every column; takes precedence over broca_channels.\n",
    "        broca_channels (bool or None): Use columns 128 onward. None (default) falls back\n",
    "            to the notebook-level BROCA_CHANNELS flag, matching the previous behavior.\n",
    "\n",
    "    Returns:\n",
    "        dict: 'inputFeatures' (list of normalized (time, features) arrays, one per trial),\n",
    "        'transcriptions' (list of stripped sentence strings) and 'frameLens' (list of\n",
    "        time-step counts per trial).\n",
    "    \"\"\"\n",
    "    if broca_channels is None:\n",
    "        broca_channels = BROCA_CHANNELS\n",
    "\n",
    "    dat = scipy.io.loadmat(sessionPath)\n",
    "    n_trials = dat['sentenceText'].shape[0]\n",
    "\n",
    "    if all_channels:\n",
    "        chanSel = slice(None)\n",
    "    elif broca_channels:\n",
    "        chanSel = slice(128, None)\n",
    "    else:\n",
    "        chanSel = slice(0, 128)\n",
    "\n",
    "    input_features = []\n",
    "    transcriptions = []\n",
    "    frame_lens = []\n",
    "    for i in range(n_trials):\n",
    "        features = np.concatenate(\n",
    "            [dat['tx1'][0, i][:, chanSel], dat['spikePow'][0, i][:, chanSel]], axis=1)\n",
    "        input_features.append(features)\n",
    "        transcriptions.append(dat['sentenceText'][i].strip())\n",
    "        frame_lens.append(features.shape[0])\n",
    "\n",
    "    # Block-wise normalization. Trials are gathered by their indices rather than by\n",
    "    # slicing from a block's first to its last trial, so trials of other blocks cannot\n",
    "    # leak into the statistics when blocks are interleaved. flatnonzero also handles a\n",
    "    # single-trial session, where squeeze() returns a 0-d array.\n",
    "    blockNums = np.squeeze(dat['blockIdx'])\n",
    "    for blockNum in np.unique(blockNums):\n",
    "        trialIdx = np.flatnonzero(blockNums == blockNum)\n",
    "        feats = np.concatenate([input_features[i] for i in trialIdx], axis=0)\n",
    "        feats_mean = np.mean(feats, axis=0, keepdims=True)\n",
    "        feats_std = np.std(feats, axis=0, keepdims=True)\n",
    "        for i in trialIdx:\n",
    "            # epsilon avoids division by zero for constant channels\n",
    "            input_features[i] = (input_features[i] - feats_mean) / (feats_std + 1e-8)\n",
    "\n",
    "    return {\n",
    "        'inputFeatures': input_features,\n",
    "        'transcriptions': transcriptions,\n",
    "        'frameLens': frame_lens,\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "def _sentenceToPhonemes(sentence, addInterWordSymbol=True):\n",
    "    \"\"\"\n",
    "    Convert one transcript into the list of phoneme strings used as its label sequence.\n",
    "\n",
    "    Everything except letters, hyphens, spaces and apostrophes is removed, '--' is dropped\n",
    "    and the text is lower-cased before g2p. Stress digits are stripped from g2p's output and\n",
    "    only upper-case phoneme tokens are kept. With addInterWordSymbol, 'SIL' is added for every\n",
    "    ' ' token g2p emits and once more at the end, marking the end of each word.\n",
    "    \"\"\"\n",
    "    text = str(sentence).strip()\n",
    "    text = re.sub(r'[^a-zA-Z\\- \\']', '', text)\n",
    "    text = text.replace('--', '').lower()\n",
    "\n",
    "    phonemes = []\n",
    "    for p in g2p(text):\n",
    "        if addInterWordSymbol and p == ' ':\n",
    "            phonemes.append('SIL')\n",
    "        p = re.sub(r'[0-9]', '', p)  # remove stress markers\n",
    "        if re.match(r'[A-Z]+', p):  # keep phoneme tokens only\n",
    "            phonemes.append(p)\n",
    "\n",
    "    if addInterWordSymbol:\n",
    "        phonemes.append('SIL')  # closing SIL after the last word\n",
    "    return phonemes\n",
    "\n",
    "\n",
    "def getDataset(fileName, maxSeqLen=500):\n",
    "    \"\"\"\n",
    "    Build one session's dataset: normalized neural features plus phoneme label targets.\n",
    "\n",
    "    Args:\n",
    "        fileName (str): Session .mat path, loaded via loadFeaturesAndNormalize with the\n",
    "            notebook-level ALL_CHANNELS flag.\n",
    "        maxSeqLen (int): Length of each zero-padded label array (default 500, as before).\n",
    "\n",
    "    Returns:\n",
    "        dict with keys\n",
    "            'sentenceDat': list of per-trial feature arrays;\n",
    "            'transcriptions': list of the transcripts as returned by the loader;\n",
    "            'phonemes': list of int32 arrays of length maxSeqLen holding phoneToId(p) + 1\n",
    "                for each phoneme followed by 0 padding (the +1 keeps 0 free for padding);\n",
    "            'timeSeriesLens' / 'phoneLens': arrays of time steps / phonemes per trial;\n",
    "            'phonePerTime': float32 array phoneLens / timeSeriesLens.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if a transcript yields more than maxSeqLen phonemes.\n",
    "    \"\"\"\n",
    "    session_data = loadFeaturesAndNormalize(fileName, all_channels=ALL_CHANNELS)\n",
    "\n",
    "    allDat = []\n",
    "    trueSentences = []\n",
    "    seqElements = []\n",
    "    timeSeriesLens = []\n",
    "    phoneLens = []\n",
    "    for features, sentence in zip(session_data['inputFeatures'], session_data['transcriptions']):\n",
    "        phonemes = _sentenceToPhonemes(sentence)\n",
    "        seqLen = len(phonemes)\n",
    "        if seqLen > maxSeqLen:\n",
    "            raise ValueError(f'{fileName}: {seqLen} phonemes exceed maxSeqLen={maxSeqLen} for {sentence!r}')\n",
    "\n",
    "        seqClassIDs = np.zeros([maxSeqLen], dtype=np.int32)\n",
    "        seqClassIDs[:seqLen] = [phoneToId(p) + 1 for p in phonemes]\n",
    "\n",
    "        allDat.append(features)\n",
    "        trueSentences.append(sentence)\n",
    "        seqElements.append(seqClassIDs)\n",
    "        timeSeriesLens.append(features.shape[0])\n",
    "        # Use the known length instead of searching for the first 0, which raised\n",
    "        # IndexError whenever a sentence filled all maxSeqLen slots.\n",
    "        phoneLens.append(seqLen)\n",
    "\n",
    "    timeSeriesLens = np.array(timeSeriesLens)\n",
    "    phoneLens = np.array(phoneLens)\n",
    "    return {\n",
    "        'sentenceDat': allDat,\n",
    "        'transcriptions': trueSentences,\n",
    "        'phonemes': seqElements,\n",
    "        'timeSeriesLens': timeSeriesLens,\n",
    "        'phoneLens': phoneLens,\n",
    "        'phonePerTime': phoneLens.astype(np.float32) / timeSeriesLens.astype(np.float32),\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n",
      "10\n",
      "11\n",
      "12\n",
      "13\n",
      "14\n",
      "15\n",
      "16\n",
      "17\n",
      "18\n",
      "19\n",
      "20\n",
      "21\n",
      "22\n",
      "23\n"
     ]
    }
   ],
   "source": [
    "# Load each session's train and test splits, plus its competition hold-out split when that\n",
    "# file exists. trainDatasets / testDatasets align with sessionNames; competitionDatasets\n",
    "# only gets entries for sessions that have a hold-out file.\n",
    "dataDir = \"/data/datasets/speechBCI/competitionData\"\n",
    "\n",
    "trainDatasets, testDatasets, competitionDatasets = [], [], []\n",
    "for dayIdx, sessionName in enumerate(sessionNames):\n",
    "    print(dayIdx)\n",
    "    matName = sessionName + '.mat'\n",
    "    trainSet = getDataset(os.path.join(dataDir, 'train', matName))\n",
    "    testSet = getDataset(os.path.join(dataDir, 'test', matName))\n",
    "    trainDatasets.append(trainSet)\n",
    "    testDatasets.append(testSet)\n",
    "\n",
    "    holdOutPath = os.path.join(dataDir, 'competitionHoldOut', matName)\n",
    "    if os.path.exists(holdOutPath):\n",
    "        competitionDatasets.append(getDataset(holdOutPath))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[4, 5, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 18, 19, 20]\n"
     ]
    }
   ],
   "source": [
    "# Session indices (into sessionNames) that have a competition hold-out file; same check and\n",
    "# order as the loading loop, so they line up with competitionDatasets.\n",
    "competitionDays = [\n",
    "    dayIdx\n",
    "    for dayIdx, sessionName in enumerate(sessionNames)\n",
    "    if os.path.exists(os.path.join(dataDir, 'competitionHoldOut', sessionName + '.mat'))\n",
    "]\n",
    "print(competitionDays)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['sentenceDat', 'transcriptions', 'phonemes', 'timeSeriesLens', 'phoneLens', 'phonePerTime'])"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keys of the per-session dict produced by getDataset.\n",
    "trainDatasets[0].keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The spray will be used in first division matches next season.\n",
      "['DH', 'AH', 'SIL', 'S', 'P', 'R', 'EY', 'SIL', 'W', 'IH', 'L', 'SIL', 'B', 'IY', 'SIL', 'Y', 'UW', 'Z', 'D', 'SIL', 'IH', 'N', 'SIL', 'F', 'ER', 'S', 'T', 'SIL', 'D', 'IH', 'V', 'IH', 'ZH', 'AH', 'N', 'SIL', 'M', 'AE', 'CH', 'AH', 'Z', 'SIL', 'N', 'EH', 'K', 'S', 'T', 'SIL', 'S', 'IY', 'Z', 'AH', 'N', 'SIL']\n"
     ]
    }
   ],
   "source": [
    "# Spot-check one training sentence: its transcript and its decoded phoneme labels.\n",
    "idx = 2\n",
    "firstDay = trainDatasets[0]\n",
    "print(firstDay[\"transcriptions\"][idx])\n",
    "print(idsToPhonemes(firstDay[\"phonemes\"][idx]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "allDatasets = {\n",
    "    'train': trainDatasets,\n",
    "    'test': testDatasets,\n",
    "    'competition': competitionDatasets,\n",
    "}\n",
    "\n",
    "# The output name records which channel subset was used (ALL_CHANNELS checked first).\n",
    "if ALL_CHANNELS:\n",
    "    outFileName = 'datasets_all_channels'\n",
    "elif BROCA_CHANNELS:\n",
    "    outFileName = 'datasets_broca_channels'\n",
    "else:\n",
    "    outFileName = 'datasets_baseline'\n",
    "\n",
    "with open(outFileName, 'wb') as handle:\n",
    "    pickle.dump(allDatasets, handle)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
