{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "33da12ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "\n",
    "from rnn_model import GRUDecoder\n",
    "from evaluate_model_helpers import *\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "6a2c3e1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "PHONE_DEF = [\n",
    "    'AA', 'AE', 'AH', 'AO', 'AW',\n",
    "    'AY', 'B',  'CH', 'D', 'DH',\n",
    "    'EH', 'ER', 'EY', 'F', 'G',\n",
    "    'HH', 'IH', 'IY', 'JH', 'K',\n",
    "    'L', 'M', 'N', 'NG', 'OW',\n",
    "    'OY', 'P', 'R', 'S', 'SH',\n",
    "    'T', 'TH', 'UH', 'UW', 'V',\n",
    "    'W', 'Y', 'Z', 'ZH'\n",
    "]\n",
    "PHONE_DEF_SIL = PHONE_DEF + ['SIL']\n",
    "\n",
    "def phoneToId(p):\n",
    "    return PHONE_DEF_SIL.index(p)\n",
    "\n",
    "phoneToIdDict = {p:phoneToId(p) for p in PHONE_DEF_SIL}\n",
    "\n",
    "class AugmentedNeuralTextDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"augmentations ideas:\n",
    "        1. truncate the signal\n",
    "        2. mixup two samples from the same day replacing a word\n",
    "        3. add a word from another sample in the same day at beginning or end\n",
    "\n",
    "        #for the future, check only valid sentences\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dict, mix_prob = 0.5, subset_len = 3):\n",
    "        self.neural_features = data_dict['neural_features']\n",
    "        self.n_time_steps = data_dict['n_time_steps']\n",
    "        self.seq_class_ids = data_dict['seq_class_ids']\n",
    "        self.seq_len = data_dict['seq_len']\n",
    "        self.sentence_label = data_dict['sentence_label']\n",
    "        self.merged_segments = data_dict['merged_segments']\n",
    "        self.assignments = data_dict['assignments']\n",
    "        self.transcriptions = data_dict[\"sentence_label\"]\n",
    "        self.days = torch.tensor(data_dict[\"dayIdx\"], dtype=torch.long)\n",
    "        self.mix_prob = mix_prob\n",
    "        self.subset_len = subset_len\n",
    "\n",
    "        self.aug_methods = [\"truncate\", \"mixup\"]\n",
    "        self.aug_p = [0.50, 0.50]# probabilities for each augmentation method\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.neural_features)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "\n",
    "        seq_class_ids = self.seq_class_ids[idx]\n",
    "        #remove padding (class 0)\n",
    "        seq_class_ids = [p.item() for p in seq_class_ids if p != 0]\n",
    "\n",
    "        item =  {\n",
    "            'neural_features': torch.tensor(self.neural_features[idx], dtype=torch.float32),\n",
    "            'n_time_steps': self.n_time_steps[idx],\n",
    "            'seq_class_ids': torch.tensor(seq_class_ids, dtype=torch.long),\n",
    "            'seq_len': torch.tensor(self.seq_len[idx], dtype=torch.long),\n",
    "            'sentence_label': self.sentence_label[idx],\n",
    "            'merged_segments': self.merged_segments[idx],\n",
    "            'assignments': self.assignments[idx],\n",
    "            'transcriptions': self.transcriptions[idx],\n",
    "            'day': self.days[idx],\n",
    "            'augmentation': \"none\"\n",
    "        }\n",
    "\n",
    "        r = np.random.rand()\n",
    "        if r < self.mix_prob:\n",
    "            #do mixup \n",
    "\n",
    "            #1. look for another random index from the same day\n",
    "            # same_day_indices = [i for i, day in enumerate(self.days) if day == self.days[idx] and i != idx]\n",
    "            # if len(same_day_indices) == 0:\n",
    "            #     return item\n",
    "            # rand_idx = np.random.choice(same_day_indices)\n",
    "\n",
    "            # other_item = {\n",
    "            #     'neural_features': torch.tensor(self.neural_features[rand_idx], dtype=torch.float32),\n",
    "            #     'n_time_steps': self.n_time_steps[rand_idx],\n",
    "            #     'seq_class_ids': torch.tensor(self.seq_class_ids[rand_idx], dtype=torch.long),  \n",
    "            #     'seq_len': torch.tensor(self.seq_len[rand_idx], dtype=torch.long),\n",
    "            #     'sentence_label': self.sentence_label[rand_idx],\n",
    "            #     'merged_segments': self.merged_segments[rand_idx],\n",
    "            #     'assignments': self.assignments[rand_idx],\n",
    "            #     'transcriptions': self.transcriptions[rand_idx],\n",
    "            #     'day': self.days[rand_idx]\n",
    "            # }\n",
    "\n",
    "            aug_method = np.random.choice(self.aug_methods, p=self.aug_p)\n",
    "            if aug_method == \"truncate\":\n",
    "\n",
    "                other_item = item.copy()\n",
    "\n",
    "                assignments_full = other_item['assignments']\n",
    "                merged_full = other_item['merged_segments']\n",
    "                n_words = len(assignments_full)\n",
    "\n",
    "                if n_words <= self.subset_len:\n",
    "                    # keep your existing first/last drop logic here (unchanged)\n",
    "                    if n_words == 1:\n",
    "                        return other_item\n",
    "                    if np.random.rand() < 0.5:\n",
    "                        # ---- drop FIRST word (your existing code) ----\n",
    "                        merged_segms = other_item['merged_segments'][0]\n",
    "                        start_idx = merged_segms[-1]\n",
    "                        end_idx = len(other_item['neural_features'])\n",
    "\n",
    "                        other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx]\n",
    "                        other_item['n_time_steps'] = end_idx - start_idx\n",
    "                        other_item['transcriptions'] = \" \".join([a[-1] for a in assignments_full[1:]])\n",
    "                        other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                        new_seq_class_ids = []\n",
    "                        for ass in assignments_full[1:]:\n",
    "                            new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "                        other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                        other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                        # rebase merged segments & assignments to the new slice\n",
    "                        kept_assignments = assignments_full[1:]\n",
    "                        kept_merged = []\n",
    "                        for seg in merged_full[1:]:\n",
    "                            kept_merged.append([ix - start_idx for ix in seg])\n",
    "                        other_item['assignments'] = kept_assignments\n",
    "                        other_item['merged_segments'] = kept_merged\n",
    "\n",
    "                        other_item[\"augmentation\"] = \"drop_first\"\n",
    "\n",
    "                    else:\n",
    "                        # ---- drop LAST word (mirror path) ----\n",
    "                        last_seg = other_item['merged_segments'][-1]\n",
    "                        end_idx = last_seg[0]  # start of last word\n",
    "                        start_idx = 0\n",
    "\n",
    "                        other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx+4] # include a bit of silence after last word\n",
    "                        other_item['n_time_steps'] = end_idx - start_idx\n",
    "                        other_item['transcriptions'] = \" \".join([a[-1] for a in assignments_full[:-1]])\n",
    "                        other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                        new_seq_class_ids = []\n",
    "                        for ass in assignments_full[:-1]:\n",
    "                            new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "                        other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                        other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                        # merged segments/assignments (no rebase needed when starting at 0)\n",
    "                        other_item['assignments'] = assignments_full[:-1]\n",
    "                        other_item['merged_segments'] = merged_full[:-1]\n",
    "\n",
    "                        other_item[\"augmentation\"] = \"drop_last\"\n",
    "                    \n",
    "\n",
    "                else:\n",
    "                    win_len = self.subset_len\n",
    "                    start_word = np.random.randint(0, n_words - win_len + 1)\n",
    "                    end_word = start_word + win_len  # exclusive\n",
    "\n",
    "                    # slice words/segments\n",
    "                    kept_assignments = assignments_full[start_word:end_word]\n",
    "                    kept_merged_abs = merged_full[start_word:end_word]\n",
    "\n",
    "                    # absolute indices to cut the neural signal\n",
    "                    # start at the first kept word's start\n",
    "                    start_idx = int(kept_merged_abs[0][0])\n",
    "\n",
    "                    # end at the NEXT word's start if it exists, else end of the original signal\n",
    "                    if end_word < n_words:\n",
    "                        end_idx = int(merged_full[end_word][0])  # start of the next word (exclusive)\n",
    "                    else:\n",
    "                        # use original (pre-slice) length to avoid compounding\n",
    "                        end_idx = int(len(other_item['neural_features']))\n",
    "\n",
    "                    # safety\n",
    "                    end_idx = max(end_idx, start_idx + 1)\n",
    "\n",
    "                    # slice neural features\n",
    "                    other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx]\n",
    "                    other_item['n_time_steps'] = end_idx - start_idx\n",
    "\n",
    "                    # update text\n",
    "                    other_item['transcriptions'] = \" \".join([a[-1] for a in kept_assignments])\n",
    "                    other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                    # rebuild seq_class_ids (insert SIL between words; no trailing SIL)\n",
    "                    new_seq_class_ids = []\n",
    "                    for i, ass in enumerate(kept_assignments):\n",
    "                        new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                        if i < len(kept_assignments) - 1:\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "\n",
    "                    other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                    other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                    # rebase merged segments to the new window and CLIP inside [0, end_idx-start_idx)\n",
    "                    L = end_idx - start_idx\n",
    "                    kept_merged_rebased = []\n",
    "                    for seg in kept_merged_abs:\n",
    "                        rebased = [int(ix - start_idx) for ix in seg]\n",
    "                        # clip to be safe (handles inclusive vs exclusive ambiguity)\n",
    "                        rebased = [min(max(0, x), L - 1) for x in rebased]\n",
    "                        # ensure sorted and unique\n",
    "                        rebased = sorted(set(rebased))\n",
    "                        kept_merged_rebased.append(rebased)\n",
    "\n",
    "                    other_item['assignments'] = kept_assignments\n",
    "                    other_item['merged_segments'] = kept_merged_rebased\n",
    "                    other_item[\"augmentation\"] = \"subset\"\n",
    "        \n",
    "                return other_item\n",
    "        \n",
    "            elif aug_method == \"mixup\":\n",
    "                \n",
    "                # ---- MIXUP (variable-length splice) ----\n",
    "                # 1) find donor trial from same day\n",
    "                same_day_indices = [i for i, d in enumerate(self.days) if d == self.days[idx] and i != idx]\n",
    "                if not same_day_indices:\n",
    "                    return item, item  # no donor, bail out safely\n",
    "                rand_idx = int(np.random.choice(same_day_indices))\n",
    "\n",
    "                donor_feats = torch.tensor(self.neural_features[rand_idx], dtype=torch.float32)\n",
    "                donor_assign = self.assignments[rand_idx]\n",
    "                donor_merged = self.merged_segments[rand_idx]\n",
    "\n",
    "                # 2) pick random word indices\n",
    "                orig_n_words = len(item['assignments'])\n",
    "                donor_n_words = len(donor_assign)\n",
    "                if orig_n_words == 0 or donor_n_words == 0:\n",
    "                    return item, item\n",
    "\n",
    "                i_orig = np.random.randint(0, orig_n_words)\n",
    "                j_donor = np.random.randint(0, donor_n_words)\n",
    "\n",
    "                # 3) get original and donor segment boundaries (absolute frame indices)\n",
    "                orig_seg_abs = item['merged_segments'][i_orig]\n",
    "                donor_seg_abs = donor_merged[j_donor]\n",
    "\n",
    "                orig_start = int(orig_seg_abs[0])\n",
    "                orig_end   = int(orig_seg_abs[-1])\n",
    "                donor_start = int(donor_seg_abs[0])\n",
    "                donor_end   = int(donor_seg_abs[-1])\n",
    "\n",
    "                orig_len = max(0, orig_end - orig_start)\n",
    "                donor_len = max(0, donor_end - donor_start)\n",
    "                if orig_len == 0 or donor_len == 0:\n",
    "                    return item, item\n",
    "\n",
    "                # 4) deep(ish) copy current example to edit\n",
    "                other_item = {\n",
    "                    'neural_features': item['neural_features'].clone(),  # (T, C) tensor\n",
    "                    'n_time_steps': item['n_time_steps'],\n",
    "                    'seq_class_ids': None,\n",
    "                    'seq_len': None,\n",
    "                    'sentence_label': item['sentence_label'],\n",
    "                    'merged_segments': [list(seg) for seg in item['merged_segments']],\n",
    "                    'assignments': [list(a) for a in item['assignments']],  # make mutable\n",
    "                    'transcriptions': item['transcriptions'],\n",
    "                    'day': item['day'],\n",
    "                }\n",
    "\n",
    "                # 5) splice donor frames (variable length): new = before + donor + after\n",
    "                before = other_item['neural_features'][:orig_start]\n",
    "                donor_slice = donor_feats[donor_start:donor_end]         # (donor_len, C)\n",
    "                after  = other_item['neural_features'][orig_end:]\n",
    "\n",
    "                new_feats = torch.cat([before, donor_slice, after], dim=0)\n",
    "                other_item['neural_features'] = new_feats\n",
    "                other_item['n_time_steps'] = new_feats.shape[0]\n",
    "\n",
    "                # length delta: how much timeline shifts for words AFTER i_orig\n",
    "                delta = donor_len - orig_len\n",
    "\n",
    "                # 6) update merged_segments:\n",
    "                #    - replaced segment gets donor's internal indices, rebased to orig_start\n",
    "                #    - all following segments shift by +delta\n",
    "                #    - prior segments unchanged\n",
    "                new_merged = []\n",
    "                for k, seg in enumerate(item['merged_segments']):\n",
    "                    if k < i_orig:\n",
    "                        new_merged.append(list(seg))\n",
    "                    elif k == i_orig:\n",
    "                        # rebase donor indices to align at orig_start\n",
    "                        # donor indices relative -> ix' = orig_start + (ix - donor_start)\n",
    "                        rebased = [int(orig_start + (ix - donor_start)) for ix in donor_seg_abs]\n",
    "                        # ensure sorted and within new length\n",
    "                        rebased = sorted(rebased)\n",
    "                        new_merged.append(rebased)\n",
    "                    else:\n",
    "                        # shift everything by delta\n",
    "                        shifted = [int(ix + delta) for ix in seg]\n",
    "                        new_merged.append(shifted)\n",
    "                other_item['merged_segments'] = new_merged\n",
    "\n",
    "                # 7) update assignments content at i_orig using donor's word + phonemes\n",
    "                #    Keep the *structure* of the assignment tuple; replace only word & phoneme list.\n",
    "                a_list = other_item['assignments'][i_orig]\n",
    "                a_list[4] = list(donor_assign[j_donor][4])   # phoneme strings\n",
    "                a_list[-1] = donor_assign[j_donor][-1]       # word text\n",
    "                other_item['assignments'][i_orig] = a_list\n",
    "\n",
    "                # (Optional) If your assignment stores a word_idx at [0], ensure it still matches:\n",
    "                # other_item['assignments'][i_orig][0] = i_orig\n",
    "\n",
    "                # 8) update transcription text (swap the word string)\n",
    "                words = other_item['transcriptions'].split()\n",
    "                if len(words) == len(other_item['assignments']):\n",
    "                    words[i_orig] = donor_assign[j_donor][-1]\n",
    "                    new_text = \" \".join(words)\n",
    "                else:\n",
    "                    # fallback: reconstruct from assignments to be safe\n",
    "                    new_text = \" \".join([a[-1] for a in other_item['assignments']])\n",
    "                other_item['transcriptions'] = new_text\n",
    "                other_item['sentence_label'] = new_text\n",
    "\n",
    "                # 9) rebuild seq_class_ids from assignments with SIL between words (no trailing SIL)\n",
    "                new_seq_class_ids = []\n",
    "                for wi, ass in enumerate(other_item['assignments']):\n",
    "                    # ass[4] is list of phoneme strings\n",
    "                    new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                    if wi < len(other_item['assignments']) - 1:\n",
    "                        new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "\n",
    "                other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "                other_item[\"augmentation\"] = \"mixup\"\n",
    "\n",
    "                return other_item\n",
    "\n",
    "        return item\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "class BaseNeuralTextDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"for validation and test\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dict, eval_type=\"val\"):\n",
    "        self.neural_features = data_dict['neural_features']\n",
    "        self.n_time_steps = data_dict['n_time_steps']\n",
    "        self.seq_class_ids = data_dict['seq_class_ids']\n",
    "        self.seq_len = data_dict['seq_len']\n",
    "        self.sentence_label = data_dict['sentence_label']\n",
    "        self.transcriptions = data_dict[\"sentence_label\"]\n",
    "        self.days = torch.tensor(data_dict[\"dayIdx\"], dtype=torch.long)\n",
    "        self.eval_type = eval_type\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.neural_features)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "\n",
    "        if self.eval_type == \"val\":\n",
    "            seq_class_ids = self.seq_class_ids[idx]\n",
    "            #remove padding (class 0)\n",
    "            seq_class_ids = [p.item() for p in seq_class_ids if p != 0]\n",
    "\n",
    "            item =  {\n",
    "                'neural_features': torch.tensor(self.neural_features[idx], dtype=torch.float32),\n",
    "                'n_time_steps': self.n_time_steps[idx],\n",
    "                'seq_class_ids': torch.tensor(seq_class_ids, dtype=torch.long),\n",
    "                'seq_len': torch.tensor(self.seq_len[idx], dtype=torch.long),\n",
    "                'sentence_label': self.sentence_label[idx],\n",
    "                'transcriptions': self.transcriptions[idx],\n",
    "                'day': self.days[idx],\n",
    "            }\n",
    "\n",
    "            return item\n",
    "        else: #test set\n",
    "            item =  {\n",
    "                'neural_features': torch.tensor(self.neural_features[idx], dtype=torch.float32),\n",
    "                'n_time_steps': self.n_time_steps[idx],\n",
    "                'day': self.days[idx],\n",
    "            }\n",
    "            return item\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3f2b0ad6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "\n",
    "def collate_fn_flexible(\n",
    "    batch,\n",
    "    *,\n",
    "    pad_val_feat: float = 0.0,   # padding value for neural frames\n",
    "    pad_val_id: int = 0,         # padding value for seq_class_ids (0 = your PAD)\n",
    "    make_masks: bool = True,     # return masks for padded dims\n",
    "):\n",
    "    \"\"\"\n",
    "    Flexible collate for:\n",
    "      - AugmentedNeuralTextDataset (train)\n",
    "      - BaseNeuralTextDataset (val/test)\n",
    "\n",
    "    Returns a dict with:\n",
    "      - neural_features:  (B, T_max, C)\n",
    "      - neural_lengths:   (B,)\n",
    "      - neural_mask:      (B, T_max) [optional]\n",
    "      - seq_class_ids:    (B, L_max) [if present]\n",
    "      - seq_lengths:      (B,)       [if present]\n",
    "      - seq_mask:         (B, L_max) [optional, if seq_class_ids present]\n",
    "      - day:              (B,)\n",
    "      - n_time_steps:     (B,)\n",
    "      - sentence_label:   list[str]  [if present]\n",
    "      - transcriptions:   list[str]  [if present]\n",
    "      - merged_segments:  list[list] [if present]\n",
    "      - assignments:      list[list/tuples] [if present]\n",
    "      - augmentation:     list[str]  [if present]\n",
    "    \"\"\"\n",
    "\n",
    "    # Helper to check optional keys\n",
    "    def has_key(k):\n",
    "        return (k in batch[0]) and (batch[0][k] is not None)\n",
    "\n",
    "    B = len(batch)\n",
    "\n",
    "    # ---------- Neural features (variable T, fixed C) ----------\n",
    "    feats = [b['neural_features'] for b in batch]  # each (T_i, C)\n",
    "    # ensure tensors and float dtype\n",
    "    feats = [f if isinstance(f, torch.Tensor) else torch.tensor(f, dtype=torch.float32) for f in feats]\n",
    "    feats = [f.float() for f in feats]\n",
    "\n",
    "    # lengths\n",
    "    neural_lengths = torch.tensor([f.shape[0] for f in feats], dtype=torch.long)\n",
    "\n",
    "    # pad along time (dim=0), batch_first=True -> (B, T_max, C)\n",
    "    neural_features = pad_sequence(feats, batch_first=True, padding_value=pad_val_feat)\n",
    "\n",
    "    out = {\n",
    "        'neural_features': neural_features,       # (B, T_max, C)\n",
    "        'neural_lengths': neural_lengths,         # (B,)\n",
    "    }\n",
    "\n",
    "    # n_time_steps (if provided) – keep for convenience\n",
    "    if has_key('n_time_steps'):\n",
    "        out['n_time_steps'] = torch.tensor([int(b['n_time_steps']) for b in batch], dtype=torch.long)\n",
    "    else:\n",
    "        out['n_time_steps'] = neural_lengths.clone()\n",
    "\n",
    "    # optional mask for neural time dimension\n",
    "    if make_masks:\n",
    "        T_max = neural_features.size(1)\n",
    "        out['neural_mask'] = torch.arange(T_max).unsqueeze(0).repeat(B, 1)\n",
    "        out['neural_mask'] = (out['neural_mask'] < neural_lengths.unsqueeze(1)).to(torch.bool)  # (B, T_max)\n",
    "\n",
    "    # ---------- Label sequence (seq_class_ids) ----------\n",
    "    if has_key('seq_class_ids'):\n",
    "        seqs = [b['seq_class_ids'] for b in batch]  # each (L_i,)\n",
    "        seqs = [s if isinstance(s, torch.Tensor) else torch.tensor(s, dtype=torch.long) for s in seqs]\n",
    "        seqs = [s.long() for s in seqs]\n",
    "\n",
    "        seq_lengths = torch.tensor([s.shape[0] for s in seqs], dtype=torch.long)\n",
    "        seq_class_ids = pad_sequence(seqs, batch_first=True, padding_value=pad_val_id)  # (B, L_max)\n",
    "\n",
    "        out['seq_class_ids'] = seq_class_ids\n",
    "        out['seq_lengths'] = seq_lengths\n",
    "\n",
    "        if make_masks:\n",
    "            L_max = seq_class_ids.size(1)\n",
    "            out['seq_mask'] = torch.arange(L_max).unsqueeze(0).repeat(B, 1)\n",
    "            out['seq_mask'] = (out['seq_mask'] < seq_lengths.unsqueeze(1)).to(torch.bool)  # (B, L_max)\n",
    "\n",
    "    # ---------- Day ----------\n",
    "    if has_key('day'):\n",
    "        # allow tensor or int in the samples\n",
    "        day_vals = [int(b['day']) if not isinstance(b['day'], torch.Tensor) else int(b['day'].item()) for b in batch]\n",
    "        out['day'] = torch.tensor(day_vals, dtype=torch.long)\n",
    "\n",
    "    # ---------- Strings and per-item metadata ----------\n",
    "    if has_key('sentence_label'):\n",
    "        out['sentence_label'] = [b['sentence_label'] for b in batch]\n",
    "    if has_key('transcriptions'):\n",
    "        out['transcriptions'] = [b['transcriptions'] for b in batch]\n",
    "    if has_key('merged_segments'):\n",
    "        # list of variable-length lists; keep as-is\n",
    "        out['merged_segments'] = [b['merged_segments'] for b in batch]\n",
    "    if has_key('assignments'):\n",
    "        out['assignments'] = [b['assignments'] for b in batch]\n",
    "    if has_key('augmentation'):\n",
    "        out['augmentation'] = [b['augmentation'] for b in batch]\n",
    "\n",
    "    return out"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "badbfa6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "# load cleaned_train_data\n",
    "with open(\"../data/cleaned_train_data.pkl\", \"rb\") as f:\n",
    "    cleaned_train_data = pickle.load(f)\n",
    "\n",
    "# load cleaned_val_data\n",
    "with open(\"../data/cleaned_val_data.pkl\", \"rb\") as f:\n",
    "    cleaned_val_data = pickle.load(f)\n",
    "\n",
    "# load cleaned_test_data\n",
    "with open(\"../data/cleaned_test_data.pkl\", \"rb\") as f:\n",
    "    cleaned_test_data = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "350a3657",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset = AugmentedNeuralTextDataset(cleaned_train_data, mix_prob=0.5, subset_len=3)\n",
    "val_dataset = BaseNeuralTextDataset(cleaned_val_data, eval_type=\"val\")\n",
    "test_dataset = BaseNeuralTextDataset(cleaned_test_data, eval_type=\"test\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "89e8f004",
   "metadata": {},
   "outputs": [],
   "source": [
    "bs = 64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57092bfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=bs,\n",
    "    shuffle=True,\n",
    "    num_workers=4,\n",
    "    collate_fn=lambda b: collate_fn_flexible(b, pad_val_feat=0.0, pad_val_id=0, make_masks=True),\n",
    ")\n",
    "\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    val_dataset,\n",
    "    batch_size=bs,\n",
    "    shuffle=False,\n",
    "    num_workers=4,\n",
    "    collate_fn=lambda b: collate_fn_flexible(b, pad_val_feat=0.0, pad_val_id=0, make_masks=True),\n",
    ")\n",
    "\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    test_dataset,\n",
    "    batch_size=bs,\n",
    "    shuffle=False,\n",
    "    num_workers=4,\n",
    "    collate_fn=collate_fn_flexible,  # defaults are fine\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0155a63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'neural_features': tensor([[[-0.4676, -0.2794, -0.4091,  ..., -0.7803, -0.7143, -1.2947],\n",
       "          [-0.4676, -0.2794, -0.4091,  ...,  0.4451, -0.6258, -1.1445],\n",
       "          [-0.4676, -0.2794, -0.4091,  ..., -1.0221,  0.0195, -1.5166],\n",
       "          ...,\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       " \n",
       "         [[-0.4579, -0.2730, -0.4068,  ..., -1.0249,  0.0639, -0.7468],\n",
       "          [ 0.9247, -0.2730, -0.4068,  ...,  0.3864, -0.8787, -0.1448],\n",
       "          [-0.4579, -0.2730, -0.4068,  ..., -0.6938,  0.8109, -0.3617],\n",
       "          ...,\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       " \n",
       "         [[-0.4359, -0.2711, -0.4101,  ..., -0.7326, -0.8427, -1.1315],\n",
       "          [-0.4359, -0.2711, -0.4101,  ...,  1.2035,  2.2787, -0.2118],\n",
       "          [-0.4359, -0.2711, -0.4101,  ...,  0.5862, -0.8594, -1.0635],\n",
       "          ...,\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       " \n",
       "         ...,\n",
       " \n",
       "         [[-0.4356, -0.7865, -0.6147,  ...,  1.0471, -0.6484,  1.7079],\n",
       "          [-0.4356,  0.3857, -0.6147,  ..., -0.2635, -0.7932,  0.3156],\n",
       "          [-0.4356, -0.7865, -0.6147,  ..., -0.6606,  0.9707,  0.0319],\n",
       "          ...,\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],\n",
       " \n",
       "         [[-0.4344,  0.3670,  2.1836,  ...,  0.0843,  0.2604, -0.1623],\n",
       "          [ 2.5596,  0.3670, -0.6182,  ...,  1.8439,  2.1030, -0.6256],\n",
       "          [-0.4344,  0.3670, -0.6182,  ..., -0.1718,  1.8707,  2.8790],\n",
       "          ...,\n",
       "          [-0.4344, -0.7928, -0.6182,  ..., -0.7170,  0.6601,  0.1123],\n",
       "          [-0.4344, -0.7928, -0.6182,  ..., -0.1331, -1.0689, -1.0023],\n",
       "          [-0.4344, -0.7928, -0.6182,  ...,  0.6749, -0.8338, -0.0528]],\n",
       " \n",
       "         [[-0.4461, -0.7867, -0.6158,  ...,  0.2418, -0.4137,  3.1183],\n",
       "          [-0.4461,  1.5422, -0.6158,  ...,  0.1368, -0.7299, -0.9804],\n",
       "          [-0.4461, -0.7867, -0.6158,  ...,  0.2993, -0.6494,  0.2463],\n",
       "          ...,\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],\n",
       "          [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]]),\n",
       " 'neural_lengths': tensor([ 936,  340,  638,  757,  716,  466,  585,  590,  845,  612,  508,  970,\n",
       "          358,  501,  570,  744,  845,  492,  833,  678,  471,  839,  702,  873,\n",
       "          480,  643,  629,  937,  937,  604,  635,  624,  739,  783,  620,  553,\n",
       "          611,  687,  857,  637,  573,  851,  979,  957,  512,  523,  931,  873,\n",
       "          915,  735, 1024,  557,  666,  828, 1013,  604,  849,  517,  931,  606,\n",
       "          645,  664, 1296,  874]),\n",
       " 'n_time_steps': tensor([ 936,  340,  638,  757,  716,  466,  585,  590,  845,  612,  508,  970,\n",
       "          358,  501,  570,  744,  845,  492,  833,  678,  471,  839,  702,  873,\n",
       "          480,  643,  629,  937,  937,  604,  635,  624,  739,  783,  620,  553,\n",
       "          611,  687,  857,  637,  573,  851,  979,  957,  512,  523,  931,  873,\n",
       "          915,  735, 1024,  557,  666,  828, 1013,  604,  849,  517,  931,  606,\n",
       "          645,  664, 1296,  874]),\n",
       " 'neural_mask': tensor([[ True,  True,  True,  ..., False, False, False],\n",
       "         [ True,  True,  True,  ..., False, False, False],\n",
       "         [ True,  True,  True,  ..., False, False, False],\n",
       "         ...,\n",
       "         [ True,  True,  True,  ..., False, False, False],\n",
       "         [ True,  True,  True,  ...,  True,  True,  True],\n",
       "         [ True,  True,  True,  ..., False, False, False]]),\n",
       " 'day': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n",
       "         2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])}"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "item = next(iter(test_loader))\n",
    "item"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5432ddc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the batch's 'neural_lengths' entry.\n",
    "# (The key was previously misspelled as 'neural_lenghts', which raised KeyError.)\n",
    "item[\"neural_lengths\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56fd3e8c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 936,  340,  638,  757,  716,  466,  585,  590,  845,  612,  508,  970,\n",
       "         358,  501,  570,  744,  845,  492,  833,  678,  471,  839,  702,  873,\n",
       "         480,  643,  629,  937,  937,  604,  635,  624,  739,  783,  620,  553,\n",
       "         611,  687,  857,  637,  573,  851,  979,  957,  512,  523,  931,  873,\n",
       "         915,  735, 1024,  557,  666,  828, 1013,  604,  849,  517,  931,  606,\n",
       "         645,  664, 1296,  874])"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the tensor stored under the batch's 'neural_lengths' key.\n",
    "item['neural_lengths']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd8597e8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1450"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Report how many items test_dataset holds.\n",
    "len(test_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "befbf783",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
