{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "95b55462",
   "metadata": {},
   "source": [
    "## Idea: use a pretrained model to slice the dataset into smaller pieces"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "86ee8a68",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bda7bfa9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import time\n",
    "from tqdm import tqdm\n",
    "import argparse\n",
    "\n",
    "from rnn_model import GRUDecoder\n",
    "from evaluate_model_helpers import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e91f7eca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda:1 for model inference.\n"
     ]
    }
   ],
   "source": [
    "# Notebook configuration (plain variables standing in for argparse options)\n",
    "model_path = '../data/t15_pretrained_rnn_baseline'\n",
    "data_dir = '../data/hdf5_data_final'\n",
    "eval_type = 'val'  # or 'test'\n",
    "csv_path = '../data/t15_copyTaskData_description.csv'\n",
    "gpu_number = 1  # Set to -1 to use CPU\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "import torch\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "# Metadata table read from csv_path\n",
    "b2txt_csv_df = pd.read_csv(csv_path)\n",
    "\n",
    "# Arguments stored in the checkpoint folder of the model\n",
    "model_args = OmegaConf.load(os.path.join(model_path, 'checkpoint/args.yaml'))\n",
    "\n",
    "# Choose the inference device: the requested CUDA device when it exists, the CPU otherwise\n",
    "cuda_ready = torch.cuda.is_available()\n",
    "gpu_requested = gpu_number >= 0\n",
    "if not (cuda_ready and gpu_requested):\n",
    "    if gpu_requested:\n",
    "        print(f'GPU number {gpu_number} requested but not available.')\n",
    "    print('Using CPU for model inference.')\n",
    "    device = torch.device('cpu')\n",
    "elif gpu_number < torch.cuda.device_count():\n",
    "    device = torch.device(f'cuda:{gpu_number}')\n",
    "    print(f'Using {device} for model inference.')\n",
    "else:\n",
    "    raise ValueError(f'GPU number {gpu_number} is out of range. Available GPUs: {torch.cuda.device_count()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "0a7d778c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install omegaconf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "549a4aa1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 35 val trials for session t15.2023.08.13.\n",
      "Loaded 49 val trials for session t15.2023.08.18.\n",
      "Loaded 48 val trials for session t15.2023.08.20.\n",
      "Loaded 25 val trials for session t15.2023.08.25.\n",
      "Loaded 25 val trials for session t15.2023.08.27.\n",
      "Loaded 49 val trials for session t15.2023.09.01.\n",
      "Loaded 34 val trials for session t15.2023.09.03.\n",
      "Loaded 35 val trials for session t15.2023.09.24.\n",
      "Loaded 48 val trials for session t15.2023.09.29.\n",
      "Loaded 44 val trials for session t15.2023.10.01.\n",
      "Loaded 36 val trials for session t15.2023.10.06.\n",
      "Loaded 17 val trials for session t15.2023.10.08.\n",
      "Loaded 44 val trials for session t15.2023.10.13.\n",
      "Loaded 44 val trials for session t15.2023.10.15.\n",
      "Loaded 9 val trials for session t15.2023.10.20.\n",
      "Loaded 33 val trials for session t15.2023.10.22.\n",
      "Loaded 50 val trials for session t15.2023.11.03.\n",
      "Loaded 15 val trials for session t15.2023.11.04.\n",
      "Loaded 25 val trials for session t15.2023.11.17.\n",
      "Loaded 20 val trials for session t15.2023.11.19.\n",
      "Loaded 44 val trials for session t15.2023.11.26.\n",
      "Loaded 34 val trials for session t15.2023.12.03.\n",
      "Loaded 50 val trials for session t15.2023.12.08.\n",
      "Loaded 25 val trials for session t15.2023.12.10.\n",
      "Loaded 30 val trials for session t15.2023.12.17.\n",
      "Loaded 50 val trials for session t15.2023.12.29.\n",
      "Loaded 23 val trials for session t15.2024.02.25.\n",
      "Loaded 24 val trials for session t15.2024.03.08.\n",
      "Loaded 48 val trials for session t15.2024.03.15.\n",
      "Loaded 48 val trials for session t15.2024.03.17.\n",
      "Loaded 25 val trials for session t15.2024.05.10.\n",
      "Loaded 25 val trials for session t15.2024.06.14.\n",
      "Loaded 48 val trials for session t15.2024.07.19.\n",
      "Loaded 46 val trials for session t15.2024.07.21.\n",
      "Loaded 48 val trials for session t15.2024.07.28.\n",
      "Loaded 23 val trials for session t15.2025.01.10.\n",
      "Loaded 47 val trials for session t15.2025.01.12.\n",
      "Loaded 24 val trials for session t15.2025.03.14.\n",
      "Loaded 24 val trials for session t15.2025.03.16.\n",
      "Loaded 30 val trials for session t15.2025.03.30.\n",
      "Loaded 25 val trials for session t15.2025.04.13.\n",
      "Total number of val trials: 1426\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Build the decoder from the hyper-parameters stored in model_args\n",
    "model = GRUDecoder(\n",
    "    neural_dim = model_args['model']['n_input_features'],\n",
    "    n_units = model_args['model']['n_units'],\n",
    "    n_days = len(model_args['dataset']['sessions']),\n",
    "    n_classes = model_args['dataset']['n_classes'],\n",
    "    rnn_dropout = model_args['model']['rnn_dropout'],\n",
    "    input_dropout = model_args['model']['input_network']['input_layer_dropout'],\n",
    "    n_layers = model_args['model']['n_layers'],\n",
    "    patch_size = model_args['model']['patch_size'],\n",
    "    patch_stride = model_args['model']['patch_stride'],\n",
    ")\n",
    "\n",
    "# Load the checkpoint. map_location remaps the stored tensors onto the selected device,\n",
    "# so a checkpoint saved on another GPU still loads (also on CPU-only machines).\n",
    "# NOTE(review): weights_only=False unpickles arbitrary objects - only use trusted checkpoints.\n",
    "checkpoint = torch.load(\n",
    "    os.path.join(model_path, 'checkpoint/best_checkpoint'),\n",
    "    map_location=device,\n",
    "    weights_only=False,\n",
    ")\n",
    "\n",
    "# Strip the 'module.' and '_orig_mod.' wrapper prefixes from the parameter names\n",
    "# ('module.' appears when the model was saved with DataParallel).\n",
    "# Both replacements are applied before a single pop(): popping the same key once per\n",
    "# prefix raises KeyError as soon as a key really contains 'module.'.\n",
    "state_dict = checkpoint['model_state_dict']\n",
    "for key in list(state_dict.keys()):\n",
    "    clean_key = key.replace('module.', '').replace('_orig_mod.', '')\n",
    "    state_dict[clean_key] = state_dict.pop(key)\n",
    "model.load_state_dict(state_dict)\n",
    "\n",
    "# add model to device\n",
    "model.to(device)\n",
    "\n",
    "# set model to eval mode\n",
    "model.eval()\n",
    "\n",
    "# Load the requested split (eval_type) of every session that provides a file for it\n",
    "test_data = {}\n",
    "total_test_trials = 0\n",
    "for session in model_args['dataset']['sessions']:\n",
    "    files = [f for f in os.listdir(os.path.join(data_dir, session)) if f.endswith('.hdf5')]\n",
    "    if f'data_{eval_type}.hdf5' in files:\n",
    "        eval_file = os.path.join(data_dir, session, f'data_{eval_type}.hdf5')\n",
    "\n",
    "        data = load_h5py_file_twoargs(eval_file, b2txt_csv_df)\n",
    "        test_data[session] = data\n",
    "\n",
    "        n_trials = len(data[\"neural_features\"])\n",
    "        total_test_trials += n_trials\n",
    "        print(f'Loaded {n_trials} {eval_type} trials for session {session}.')\n",
    "print(f'Total number of {eval_type} trials: {total_test_trials}')\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "3090e83b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 35 test trials for session t15.2023.08.13.\n",
      "Loaded 50 test trials for session t15.2023.08.18.\n",
      "Loaded 49 test trials for session t15.2023.08.20.\n",
      "Loaded 25 test trials for session t15.2023.08.25.\n",
      "Loaded 25 test trials for session t15.2023.08.27.\n",
      "Loaded 50 test trials for session t15.2023.09.01.\n",
      "Loaded 35 test trials for session t15.2023.09.03.\n",
      "Loaded 35 test trials for session t15.2023.09.24.\n",
      "Loaded 49 test trials for session t15.2023.09.29.\n",
      "Loaded 45 test trials for session t15.2023.10.01.\n",
      "Loaded 37 test trials for session t15.2023.10.06.\n",
      "Loaded 18 test trials for session t15.2023.10.08.\n",
      "Loaded 45 test trials for session t15.2023.10.13.\n",
      "Loaded 45 test trials for session t15.2023.10.15.\n",
      "Loaded 9 test trials for session t15.2023.10.20.\n",
      "Loaded 35 test trials for session t15.2023.10.22.\n",
      "Loaded 50 test trials for session t15.2023.11.03.\n",
      "Loaded 15 test trials for session t15.2023.11.04.\n",
      "Loaded 25 test trials for session t15.2023.11.17.\n",
      "Loaded 20 test trials for session t15.2023.11.19.\n",
      "Loaded 45 test trials for session t15.2023.11.26.\n",
      "Loaded 34 test trials for session t15.2023.12.03.\n",
      "Loaded 50 test trials for session t15.2023.12.08.\n",
      "Loaded 25 test trials for session t15.2023.12.10.\n",
      "Loaded 30 test trials for session t15.2023.12.17.\n",
      "Loaded 50 test trials for session t15.2023.12.29.\n",
      "Loaded 24 test trials for session t15.2024.02.25.\n",
      "Loaded 25 test trials for session t15.2024.03.08.\n",
      "Loaded 50 test trials for session t15.2024.03.15.\n",
      "Loaded 49 test trials for session t15.2024.03.17.\n",
      "Loaded 25 test trials for session t15.2024.05.10.\n",
      "Loaded 25 test trials for session t15.2024.06.14.\n",
      "Loaded 49 test trials for session t15.2024.07.19.\n",
      "Loaded 47 test trials for session t15.2024.07.21.\n",
      "Loaded 49 test trials for session t15.2024.07.28.\n",
      "Loaded 24 test trials for session t15.2025.01.10.\n",
      "Loaded 47 test trials for session t15.2025.01.12.\n",
      "Loaded 26 test trials for session t15.2025.03.14.\n",
      "Loaded 24 test trials for session t15.2025.03.16.\n",
      "Loaded 30 test trials for session t15.2025.03.30.\n",
      "Loaded 25 test trials for session t15.2025.04.13.\n",
      "Total number of test trials: 1450\n",
      "\n"
     ]
    }
   ],
   "source": [
    "eval_type = \"test\"\n",
    "\n",
    "# Load the eval_type split of every session that ships a file for it into competition_data\n",
    "competition_data = {}\n",
    "total_competition_trials = 0\n",
    "for session in model_args['dataset']['sessions']:\n",
    "    session_dir = os.path.join(data_dir, session)\n",
    "    files = [f for f in os.listdir(session_dir) if f.endswith('.hdf5')]\n",
    "    if f'data_{eval_type}.hdf5' not in files:\n",
    "        continue  # this session has no file for the requested split\n",
    "\n",
    "    eval_file = os.path.join(session_dir, f'data_{eval_type}.hdf5')\n",
    "    data = load_h5py_file_twoargs(eval_file, b2txt_csv_df)\n",
    "    competition_data[session] = data\n",
    "\n",
    "    n_trials = len(data[\"neural_features\"])\n",
    "    total_competition_trials += n_trials\n",
    "    print(f'Loaded {n_trials} {eval_type} trials for session {session}.')\n",
    "print(f'Total number of {eval_type} trials: {total_competition_trials}')\n",
    "print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "27117135",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded 288 train trials for session t15.2023.08.11.\n",
      "Loaded 348 train trials for session t15.2023.08.13.\n",
      "Loaded 197 train trials for session t15.2023.08.18.\n",
      "Loaded 278 train trials for session t15.2023.08.20.\n",
      "Loaded 88 train trials for session t15.2023.08.25.\n",
      "Loaded 150 train trials for session t15.2023.08.27.\n",
      "Loaded 297 train trials for session t15.2023.09.01.\n",
      "Loaded 322 train trials for session t15.2023.09.03.\n",
      "Loaded 245 train trials for session t15.2023.09.24.\n",
      "Loaded 153 train trials for session t15.2023.09.29.\n",
      "Loaded 218 train trials for session t15.2023.10.01.\n",
      "Loaded 174 train trials for session t15.2023.10.06.\n",
      "Loaded 284 train trials for session t15.2023.10.08.\n",
      "Loaded 155 train trials for session t15.2023.10.13.\n",
      "Loaded 239 train trials for session t15.2023.10.15.\n",
      "Loaded 98 train trials for session t15.2023.10.20.\n",
      "Loaded 134 train trials for session t15.2023.10.22.\n",
      "Loaded 149 train trials for session t15.2023.11.03.\n",
      "Loaded 80 train trials for session t15.2023.11.04.\n",
      "Loaded 100 train trials for session t15.2023.11.17.\n",
      "Loaded 60 train trials for session t15.2023.11.19.\n",
      "Loaded 198 train trials for session t15.2023.11.26.\n",
      "Loaded 228 train trials for session t15.2023.12.03.\n",
      "Loaded 198 train trials for session t15.2023.12.08.\n",
      "Loaded 131 train trials for session t15.2023.12.10.\n",
      "Loaded 135 train trials for session t15.2023.12.17.\n",
      "Loaded 198 train trials for session t15.2023.12.29.\n",
      "Loaded 193 train trials for session t15.2024.02.25.\n",
      "Loaded 219 train trials for session t15.2024.03.03.\n",
      "Loaded 163 train trials for session t15.2024.03.08.\n",
      "Loaded 239 train trials for session t15.2024.03.15.\n",
      "Loaded 246 train trials for session t15.2024.03.17.\n",
      "Loaded 364 train trials for session t15.2024.04.25.\n",
      "Loaded 150 train trials for session t15.2024.04.28.\n",
      "Loaded 110 train trials for session t15.2024.05.10.\n",
      "Loaded 90 train trials for session t15.2024.06.14.\n",
      "Loaded 169 train trials for session t15.2024.07.19.\n",
      "Loaded 160 train trials for session t15.2024.07.21.\n",
      "Loaded 161 train trials for session t15.2024.07.28.\n",
      "Loaded 106 train trials for session t15.2025.01.10.\n",
      "Loaded 163 train trials for session t15.2025.01.12.\n",
      "Loaded 59 train trials for session t15.2025.03.14.\n",
      "Loaded 101 train trials for session t15.2025.03.16.\n",
      "Loaded 165 train trials for session t15.2025.03.30.\n",
      "Loaded 69 train trials for session t15.2025.04.13.\n",
      "Total number of train trials: 8072\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Collect the training split of every session that ships a data_train.hdf5 file\n",
    "train_data = {}\n",
    "total_train_trials = 0\n",
    "for session in model_args['dataset']['sessions']:\n",
    "    session_dir = os.path.join(data_dir, session)\n",
    "    files = [f for f in os.listdir(session_dir) if f.endswith('.hdf5')]\n",
    "    if 'data_train.hdf5' not in files:\n",
    "        continue  # nothing to load for this session\n",
    "\n",
    "    eval_file = os.path.join(session_dir, 'data_train.hdf5')\n",
    "    data = load_h5py_file_twoargs(eval_file, b2txt_csv_df)\n",
    "    train_data[session] = data\n",
    "\n",
    "    n_trials = len(data[\"neural_features\"])\n",
    "    total_train_trials += n_trials\n",
    "    print(f'Loaded {n_trials} train trials for session {session}.')\n",
    "print(f'Total number of train trials: {total_train_trials}')\n",
    "print()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "7dd718d6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Predicting phoneme sequences: 8072trial [01:13, 109.89trial/s]                       \n"
     ]
    }
   ],
   "source": [
    "# Run every training trial through the pretrained model and keep the phoneme logits.\n",
    "# The bar total must be the number of *train* trials; with the validation count the bar\n",
    "# overruns its total and loses the percentage / ETA display.\n",
    "with tqdm(total=total_train_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:\n",
    "    for session, data in train_data.items():\n",
    "\n",
    "        data['logits'] = []\n",
    "        data['pred_seq'] = []  # not filled in this cell; kept in case later cells expect the key\n",
    "        # position of the session in the configured session list, passed on as input_layer\n",
    "        input_layer = model_args['dataset']['sessions'].index(session)\n",
    "\n",
    "        for trial in range(len(data['neural_features'])):\n",
    "            # get neural input for the trial\n",
    "            neural_input = data['neural_features'][trial]\n",
    "\n",
    "            # add batch dimension\n",
    "            neural_input = np.expand_dims(neural_input, axis=0)\n",
    "\n",
    "            # convert to torch tensor\n",
    "            neural_input = torch.tensor(neural_input, device=device, dtype=torch.bfloat16)\n",
    "\n",
    "            # run decoding step\n",
    "            logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)\n",
    "            data['logits'].append(logits)\n",
    "\n",
    "            pbar.update(1)\n",
    "# no explicit pbar.close(): leaving the with-block already closes the bar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d18d1110",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Predicting phoneme sequences: 100%|██████████| 1426/1426 [00:13<00:00, 108.72trial/s]\n"
     ]
    }
   ],
   "source": [
    "# Inference pass over the trials held in test_data: one logits entry is stored per trial\n",
    "with tqdm(total=total_test_trials, desc='Predicting phoneme sequences', unit='trial') as pbar:\n",
    "    for session, data in test_data.items():\n",
    "        data['logits'] = []\n",
    "        data['pred_seq'] = []\n",
    "        input_layer = model_args['dataset']['sessions'].index(session)\n",
    "\n",
    "        for trial in range(len(data['neural_features'])):\n",
    "            # prepend a batch axis and move the trial to the inference device as bfloat16\n",
    "            neural_input = torch.tensor(\n",
    "                np.expand_dims(data['neural_features'][trial], axis=0),\n",
    "                device=device,\n",
    "                dtype=torch.bfloat16,\n",
    "            )\n",
    "\n",
    "            logits = runSingleDecodingStep(neural_input, input_layer, model, model_args, device)\n",
    "            data['logits'].append(logits)\n",
    "            pbar.update(1)\n",
    "# the with-block closes the bar on exit, so no separate close() call is needed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "0b72f098",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25', 't15.2023.08.27', 't15.2023.09.01', 't15.2023.09.03', 't15.2023.09.24', 't15.2023.09.29', 't15.2023.10.01', 't15.2023.10.06', 't15.2023.10.08', 't15.2023.10.13', 't15.2023.10.15', 't15.2023.10.20', 't15.2023.10.22', 't15.2023.11.03', 't15.2023.11.04', 't15.2023.11.17', 't15.2023.11.19', 't15.2023.11.26', 't15.2023.12.03', 't15.2023.12.08', 't15.2023.12.10', 't15.2023.12.17', 't15.2023.12.29', 't15.2024.02.25', 't15.2024.03.03', 't15.2024.03.08', 't15.2024.03.15', 't15.2024.03.17', 't15.2024.04.25', 't15.2024.04.28', 't15.2024.05.10', 't15.2024.06.14', 't15.2024.07.19', 't15.2024.07.21', 't15.2024.07.28', 't15.2025.01.10', 't15.2025.01.12', 't15.2025.03.14', 't15.2025.03.16', 't15.2025.03.30', 't15.2025.04.13'])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Session names for which a training split was loaded\n",
    "train_data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a39f8a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# STRIDE = 4  # time per index\n",
    "# SIL_INDEX = len(LOGIT_TO_PHONEME) - 1  # last class = silence\n",
    "\n",
    "# for session, data in test_data.items():\n",
    "#     data['speech_segments'] = []  # list of (onset_time, offset_time) tuples\n",
    "\n",
    "#     for trial in range(len(data['logits'])):\n",
    "#         logits = data['logits'][trial][0]  # shape: (T, vocab)\n",
    "#         pred_ids = np.argmax(logits, axis=-1)  # predicted class at each time step\n",
    "\n",
    "#         segments = []\n",
    "#         in_segment = False\n",
    "#         start_idx = None\n",
    "\n",
    "#         for t, p in enumerate(pred_ids):\n",
    "#             if p != SIL_INDEX and not in_segment:\n",
    "#                 # Start of speech segment\n",
    "#                 in_segment = True\n",
    "#                 start_idx = t\n",
    "#             elif p == SIL_INDEX and in_segment:\n",
    "#                 # End of speech segment\n",
    "#                 in_segment = False\n",
    "#                 end_idx = t\n",
    "#                 segments.append((start_idx, end_idx))\n",
    "\n",
    "#         # If speech continues till the end\n",
    "#         if in_segment:\n",
    "#             segments.append((start_idx, len(pred_ids)))\n",
    "\n",
    "#         # Convert from index to time\n",
    "#         min_len = 14  # bins (i.e. 100ms)\n",
    "#         segments_in_time = [(s*STRIDE, e*STRIDE) for s, e in segments if (e - s) >= min_len]\n",
    "#         data['speech_segments'].append(segments_in_time)\n",
    "\n",
    "#         # print(f\"Session: {session}, Trial: {trial}, Speech Segments (ms): {segments_in_time}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e3e0c540",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Factor applied to step indices when segments are converted below\n",
    "# (the author calls it 'time per index' - TODO confirm the unit)\n",
    "STRIDE = 4\n",
    "# assumes the last class of LOGIT_TO_PHONEME is silence - TODO confirm\n",
    "SIL_INDEX = len(LOGIT_TO_PHONEME) - 1\n",
    "# presumably the CTC blank class - TODO confirm\n",
    "BLANK_INDEX = 0 \n",
    "# Consecutive silence steps needed to close a segment. With 0 the very first silence\n",
    "# step closes it (same as 1). An earlier note said 'at least 3' - TODO confirm the intended value\n",
    "MIN_SILENCE_LEN = 0\n",
    "# Minimum segment length in steps; shorter segments are discarded\n",
    "MIN_SEGMENT_LEN = 14\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "dca6e5c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Speech segmentation of the training trials, based on the per-step argmax of the logits.\n",
    "# A segment opens at the first non-silence step and is closed once MIN_SILENCE_LEN\n",
    "# consecutive silence steps were seen; segments shorter than MIN_SEGMENT_LEN steps are dropped.\n",
    "for session, data in train_data.items():\n",
    "    data['speech_segments'] = []  # one list of (start, end) tuples per trial\n",
    "\n",
    "    for trial in range(len(data['logits'])):\n",
    "        logits = data['logits'][trial][0]  # first entry along axis 0; noted as (T, vocab) by the author\n",
    "        pred_ids = np.argmax(logits, axis=-1)  # predicted class per step\n",
    "\n",
    "        segments = []\n",
    "        in_segment = False\n",
    "        start_idx = None\n",
    "        sil_count = 0  # length of the silence run currently interrupting the open segment\n",
    "\n",
    "        for t, p in enumerate(pred_ids):\n",
    "            if p != SIL_INDEX:\n",
    "                if not in_segment:\n",
    "                    # Start of speech\n",
    "                    in_segment = True\n",
    "                    start_idx = t\n",
    "                sil_count = 0  # a non-silence step cancels a pending silence run\n",
    "            elif in_segment:\n",
    "                sil_count += 1\n",
    "                if sil_count >= MIN_SILENCE_LEN:\n",
    "                    # Enough silence: the segment ends where the silence run began\n",
    "                    end_idx = t - sil_count + 1\n",
    "                    if (end_idx - start_idx) >= MIN_SEGMENT_LEN:\n",
    "                        segments.append((start_idx, end_idx))\n",
    "                    in_segment = False\n",
    "                    start_idx = None\n",
    "                    sil_count = 0\n",
    "\n",
    "        # Final segment: the sequence ended while a segment was still open.\n",
    "        # A trailing silence run too short to close it is excluded, so the end is\n",
    "        # measured the same way as for segments closed inside the loop.\n",
    "        if in_segment:\n",
    "            end_idx = len(pred_ids) - sil_count\n",
    "            if (end_idx - start_idx) >= MIN_SEGMENT_LEN:\n",
    "                segments.append((start_idx, end_idx))\n",
    "\n",
    "        # Scale step indices by STRIDE (the author labels the result as ms - TODO confirm unit)\n",
    "        segments_in_time = [(s * STRIDE, e * STRIDE) for s, e in segments]\n",
    "        data['speech_segments'].append(segments_in_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "ad7baa29",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Speech segmentation of the trials in test_data, based on the per-step argmax of the logits.\n",
    "# A segment opens at the first non-silence step and is closed once MIN_SILENCE_LEN\n",
    "# consecutive silence steps were seen; segments shorter than MIN_SEGMENT_LEN steps are dropped.\n",
    "for session, data in test_data.items():\n",
    "    data['speech_segments'] = []  # one list of (start, end) tuples per trial\n",
    "\n",
    "    for trial in range(len(data['logits'])):\n",
    "        logits = data['logits'][trial][0]  # first entry along axis 0; noted as (T, vocab) by the author\n",
    "        pred_ids = np.argmax(logits, axis=-1)  # predicted class per step\n",
    "\n",
    "        segments = []\n",
    "        in_segment = False\n",
    "        start_idx = None\n",
    "        sil_count = 0  # length of the silence run currently interrupting the open segment\n",
    "\n",
    "        for t, p in enumerate(pred_ids):\n",
    "            if p != SIL_INDEX:\n",
    "                if not in_segment:\n",
    "                    # Start of speech\n",
    "                    in_segment = True\n",
    "                    start_idx = t\n",
    "                sil_count = 0  # a non-silence step cancels a pending silence run\n",
    "            elif in_segment:\n",
    "                sil_count += 1\n",
    "                if sil_count >= MIN_SILENCE_LEN:\n",
    "                    # Enough silence: the segment ends where the silence run began\n",
    "                    end_idx = t - sil_count + 1\n",
    "                    if (end_idx - start_idx) >= MIN_SEGMENT_LEN:\n",
    "                        segments.append((start_idx, end_idx))\n",
    "                    in_segment = False\n",
    "                    start_idx = None\n",
    "                    sil_count = 0\n",
    "\n",
    "        # Final segment: the sequence ended while a segment was still open.\n",
    "        # A trailing silence run too short to close it is excluded, so the end is\n",
    "        # measured the same way as for segments closed inside the loop.\n",
    "        if in_segment:\n",
    "            end_idx = len(pred_ids) - sil_count\n",
    "            if (end_idx - start_idx) >= MIN_SEGMENT_LEN:\n",
    "                segments.append((start_idx, end_idx))\n",
    "\n",
    "        # Scale step indices by STRIDE (the author labels the result as ms - TODO confirm unit)\n",
    "        segments_in_time = [(s * STRIDE, e * STRIDE) for s, e in segments]\n",
    "        data['speech_segments'].append(segments_in_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "52e7cf67",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = 0\n",
    "session_val = test_data[\"t15.2023.08.13\"]\n",
    "# one segment list is appended per trial, so both lengths are expected to agree\n",
    "len(session_val[\"speech_segments\"]) == len(session_val[\"sentence_label\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3e18875c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['t15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25', 't15.2023.08.27', 't15.2023.09.01', 't15.2023.09.03', 't15.2023.09.24', 't15.2023.09.29', 't15.2023.10.01', 't15.2023.10.06', 't15.2023.10.08', 't15.2023.10.13', 't15.2023.10.15', 't15.2023.10.20', 't15.2023.10.22', 't15.2023.11.03', 't15.2023.11.04', 't15.2023.11.17', 't15.2023.11.19', 't15.2023.11.26', 't15.2023.12.03', 't15.2023.12.08', 't15.2023.12.10', 't15.2023.12.17', 't15.2023.12.29', 't15.2024.02.25', 't15.2024.03.08', 't15.2024.03.15', 't15.2024.03.17', 't15.2024.05.10', 't15.2024.06.14', 't15.2024.07.19', 't15.2024.07.21', 't15.2024.07.28', 't15.2025.01.10', 't15.2025.01.12', 't15.2025.03.14', 't15.2025.03.16', 't15.2025.03.30', 't15.2025.04.13'])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Session names for which the split stored in test_data was loaded\n",
    "test_data.keys()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9fe7db9f",
   "metadata": {},
   "source": [
    "## TODO: work here — create a heuristic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6f0060d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3 2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([(0, 160), (172, 256), (268, 380)], 'Woodworking mastery.')"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "i = 9\n",
    "segments_i = test_data[\"t15.2023.08.13\"][\"speech_segments\"][i]\n",
    "sentence_i = test_data[\"t15.2023.08.13\"][\"sentence_label\"][i]\n",
    "# number of detected segments vs. number of space-separated words in the label\n",
    "print(len(segments_i), len(sentence_i.split(\" \")))\n",
    "segments_i, sentence_i\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa6e89c1",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08285abd",
   "metadata": {},
   "source": [
    "## TODO: work here — check what happens when one fewer silence is predicted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "b64ec1c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# This code works well when the model predicts the expected number of silences, or more.\n",
    "# TODO: investigate the cases where the model predicts fewer silences than expected.\n",
    "\n",
    "# ---------- config ----------\n",
    "# Divisor that turns segment bounds into step indices in segment_pred_phonemes\n",
    "# (labelled 'ms per step' by the author - TODO confirm the unit)\n",
    "STRIDE = 4\n",
    "BLANK_LABEL = \"BLANK\"      # symbol removed by collapse_ctc\n",
    "SIL_LABEL = \"SIL\"          # symbol removed by collapse_ctc\n",
    "MAX_MERGE = 3              # default max_merge: up to 3 adjacent segments may be merged per word (tweak)\n",
    "\n",
    "# ---------- helpers ----------\n",
    "\n",
    "def collapse_ctc(symbols, blank_labels=(BLANK_LABEL, SIL_LABEL)):\n",
    "    \"\"\"Drop consecutive duplicates, then discard blank/silence symbols.\n",
    "\n",
    "    A symbol is kept only when it differs from the symbol right before it\n",
    "    (blank or not) and is not listed in ``blank_labels``.\n",
    "    \"\"\"\n",
    "    sequence = list(symbols)\n",
    "    # pair every symbol with its predecessor; the first one is compared against None\n",
    "    predecessors = [None] + sequence[:-1]\n",
    "    return [\n",
    "        current\n",
    "        for before, current in zip(predecessors, sequence)\n",
    "        if current != before and current not in blank_labels\n",
    "    ]\n",
    "\n",
    "\n",
    "\n",
    "def _as_int(x):\n",
    "    try:\n",
    "        return int(x)\n",
    "    except Exception:\n",
    "        return int(x.item())  # torch / numpy scalar\n",
    "\n",
    "def _sym(idx2phoneme, p, default=\"UNK\"):\n",
    "    \"\"\"Look up the phoneme symbol for class id ``p``.\n",
    "\n",
    "    ``idx2phoneme`` may be a dict (missing keys give ``default``) or an indexable\n",
    "    sequence (ids outside ``0 .. len - 1`` give ``default``).\n",
    "    \"\"\"\n",
    "    index = _as_int(p)\n",
    "    if isinstance(idx2phoneme, dict):\n",
    "        return idx2phoneme.get(index, default)\n",
    "    # list/tuple/array-like: negative ids are rejected instead of wrapping around\n",
    "    in_range = 0 <= index < len(idx2phoneme)\n",
    "    return idx2phoneme[index] if in_range else default\n",
    "\n",
    "def segment_pred_phonemes(logits_arr, seg, idx2phoneme, stride=STRIDE):\n",
    "    \"\"\"Collapsed phoneme symbols predicted inside one segment.\n",
    "\n",
    "    logits_arr: (T, V) or (1, T, V); a 3-D input is reduced to its first entry.\n",
    "    seg: (start_ms, end_ms); both bounds are floor-divided by ``stride`` to get\n",
    "        step indices, which are then clamped to the available steps.\n",
    "    Returns an empty list when the clamped span contains no step.\n",
    "    \"\"\"\n",
    "    frames = logits_arr[0] if getattr(logits_arr, \"ndim\", None) == 3 else logits_arr\n",
    "    seg_start, seg_end = seg\n",
    "    first = seg_start // stride\n",
    "    last = seg_end // stride\n",
    "    n_steps = frames.shape[0]\n",
    "    # keep both bounds inside [0, n_steps]\n",
    "    first = max(0, min(first, n_steps))\n",
    "    last = max(0, min(last, n_steps))\n",
    "    if last <= first:\n",
    "        return []\n",
    "    best_ids = frames[first:last].argmax(axis=-1)\n",
    "    return collapse_ctc([_sym(idx2phoneme, class_id) for class_id in best_ids])\n",
    "\n",
    "def ids_to_phonemes(id_list, idx2phoneme):\n",
    "    \"\"\"Translate every class id in ``id_list`` into its phoneme symbol via ``_sym``.\"\"\"\n",
    "    phonemes = []\n",
    "    for class_id in id_list:\n",
    "        phonemes.append(_sym(idx2phoneme, class_id))\n",
    "    return phonemes\n",
    "\n",
    "def levenshtein(a, b):\n",
    "    \"\"\"Levenshtein distance for sequences (lists of strings).\"\"\"\n",
    "    # classic Wagner–Fischer\n",
    "    n, m = len(a), len(b)\n",
    "    if n == 0: return m\n",
    "    if m == 0: return n\n",
    "    dp = list(range(m+1))\n",
    "    for i in range(1, n+1):\n",
    "        prev, dp[0] = dp[0], i\n",
    "        for j in range(1, m+1):\n",
    "            ins = dp[j-1] + 1\n",
    "            dele = dp[j] + 1\n",
    "            sub = prev + (a[i-1] != b[j-1])\n",
    "            prev, dp[j] = dp[j], min(ins, dele, sub)\n",
    "    return dp[m]\n",
    "\n",
    "def norm_lev(a, b):\n",
    "    \"\"\"Levenshtein distance divided by the longer length, giving a value in 0..1.\n",
    "\n",
    "    Two empty sequences are defined to be at distance 0.0.\n",
    "    \"\"\"\n",
    "    if a or b:\n",
    "        return levenshtein(a, b) / max(len(a), len(b), 1)\n",
    "    return 0.0\n",
    "\n",
    "# ---------- core: match with optional merges ----------\n",
    "\n",
    "def match_segments_to_words_with_merges(\n",
    "    logits_arr,\n",
    "    segments_ms,                 # list[(start_ms, end_ms)]\n",
    "    gt_word_phone_ids,           # list[list[int]] (already split by 40)\n",
    "    idx2phoneme,\n",
    "    max_merge=MAX_MERGE,\n",
    "    stride=STRIDE\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      assignments: list of tuples (word_idx, [seg_idx,...], cost, pred_seq, gt_seq)\n",
    "      merged_segments_ms: list of (start_ms, end_ms) after merging per assignment order\n",
    "      total_cost: sum of normalized Levenshtein costs\n",
    "    \"\"\"\n",
    "    # Build predicted phonemes per segment\n",
    "    seg_pred_ph = [\n",
    "        segment_pred_phonemes(logits_arr, seg, idx2phoneme, stride=stride)\n",
    "        for seg in segments_ms\n",
    "    ]\n",
    "    # Build GT phonemes per word\n",
    "    gt_word_ph = [ids_to_phonemes(w_ids, idx2phoneme) for w_ids in gt_word_phone_ids]\n",
    "\n",
    "    S = len(segments_ms)\n",
    "    W = len(gt_word_ph)\n",
    "\n",
    "    # DP over segments->words with merges (many-to-one). dp[i][j] = (cost, back_k)\n",
    "    # i = #segments consumed (0..S), j = #words assigned (0..W)\n",
    "    INF = 1e9\n",
    "    dp = [[(INF, None) for _ in range(W+1)] for __ in range(S+1)]\n",
    "    dp[0][0] = (0.0, None)\n",
    "\n",
    "    # Precompute concatenations & costs for speed:\n",
    "    # concat_pred[i:j] = collapsed phonemes of segments i..j-1 concatenated\n",
    "    concat_pred = [[None]*(S+1) for _ in range(S+1)]\n",
    "    for i in range(S):\n",
    "        accum = []\n",
    "        for j in range(i, S):\n",
    "            # concatenate while avoiding extra collapse across segment boundary:\n",
    "            # we simply extend then collapse once at the end for stability\n",
    "            accum = accum + seg_pred_ph[j]\n",
    "            concat_pred[i][j+1] = collapse_ctc(accum)\n",
    "\n",
    "    # DP transitions: assign k segments (1..max_merge) to next word\n",
    "    for i in range(S+1):       # segments used\n",
    "        for j in range(W+1):   # words used\n",
    "            if dp[i][j][0] >= INF:\n",
    "                continue\n",
    "            if j == W:\n",
    "                # No more words -> remaining segments must be \"garbage\"; we can either forbid\n",
    "                # or allow assigning them to an implicit NULL with some cost. Here we keep as-is.\n",
    "                continue\n",
    "            # try merging k segments for word j\n",
    "            for k in range(1, min(max_merge, S - i) + 1):\n",
    "                pred_seq = concat_pred[i][i+k]  # segments i..i+k-1\n",
    "                gt_seq = gt_word_ph[j]\n",
    "                cost = norm_lev(pred_seq, gt_seq)\n",
    "                new_cost = dp[i][j][0] + cost\n",
    "                if new_cost < dp[i+k][j+1][0]:\n",
    "                    dp[i+k][j+1] = (new_cost, k)\n",
    "\n",
    "    # Find best terminal state: we must have consumed all words, while segments can be > words.\n",
    "    # If some trailing segments remain, we can either:\n",
    "    #   (a) penalize them (e.g., add their length as cost),\n",
    "    #   (b) or force exact consumption by increasing max_merge or padding words.\n",
    "    # Here we pick the minimal dp[i][W] over i = W..S and treat leftovers as ignored.\n",
    "    end_i = None\n",
    "    best_total = INF\n",
    "    for i in range(W, S+1):\n",
    "        if dp[i][W][0] < best_total:\n",
    "            best_total = dp[i][W][0]\n",
    "            end_i = i\n",
    "\n",
    "    if end_i is None:\n",
    "        # Fallback: no path found\n",
    "        return [], segments_ms, float('inf')\n",
    "\n",
    "    # Backtrack\n",
    "    assignments = []\n",
    "    i, j = end_i, W\n",
    "    while j > 0:\n",
    "        cost_ij, k = dp[i][j]\n",
    "        assert k is not None\n",
    "        seg_group = list(range(i - k, i))\n",
    "        pred_seq = concat_pred[i - k][i]\n",
    "        gt_seq = gt_word_ph[j - 1]\n",
    "        cost = norm_lev(pred_seq, gt_seq)\n",
    "        assignments.append((j - 1, seg_group, cost, pred_seq, gt_seq))\n",
    "        i -= k\n",
    "        j -= 1\n",
    "    assignments.reverse()\n",
    "\n",
    "    # Build merged segments for the chosen grouping\n",
    "    merged_segments_ms = []\n",
    "    for (_, seg_group, _, _, _) in assignments:\n",
    "        s0 = segments_ms[seg_group[0]][0]\n",
    "        s1 = segments_ms[seg_group[-1]][1]\n",
    "        merged_segments_ms.append((s0, s1))\n",
    "\n",
    "    return assignments, merged_segments_ms, best_total"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d8cb02bc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.ndimage import gaussian_filter1d\n",
    "\n",
    "# --- Softmax ---\n",
    "def _softmax(x, axis=-1):\n",
    "    x = x - np.max(x, axis=axis, keepdims=True)\n",
    "    e = np.exp(x)\n",
    "    return e / np.sum(e, axis=axis, keepdims=True)\n",
    "\n",
    "# --- Get (silence + blank) probability curve for the whole trial ---\n",
    "def silence_prob_curve(logits_arr, sil_index, blank_index=None):\n",
    "    \"\"\"\n",
    "    logits_arr: (T, V) or (1, T, V). Returns p_sil_blank of shape (T,)\n",
    "    \"\"\"\n",
    "    arr = logits_arr[0] if getattr(logits_arr, \"ndim\", None) == 3 else logits_arr\n",
    "    P = _softmax(arr, axis=-1)\n",
    "    p_sil = P[:, sil_index]\n",
    "    if blank_index is not None:\n",
    "        p_sil = p_sil + P[:, blank_index]  # speech pause proxy\n",
    "        p_sil = np.clip(p_sil, 0, 1)\n",
    "    return p_sil\n",
    "\n",
    "# --- Peak picking for candidate split points inside a segment ---\n",
    "def candidate_splits_for_segment(p_curve, seg, stride_ms=4,\n",
    "                                 th=0.35, smooth_sigma=2.0,\n",
    "                                 min_gap_ms=40, edge_margin_ms=24, top_k=3):\n",
    "    \"\"\"\n",
    "    p_curve: (T,) silence-prob curve\n",
    "    seg: (start_ms, end_ms)\n",
    "    Returns list of candidate split times in ms (inside the segment)\n",
    "    \"\"\"\n",
    "    s_ms, e_ms = seg\n",
    "    t0 = max(0, s_ms // stride_ms)\n",
    "    t1 = max(t0, e_ms // stride_ms)\n",
    "\n",
    "    if t1 - t0 < 3:\n",
    "        return []\n",
    "\n",
    "    window = p_curve[t0:t1].copy()\n",
    "    if smooth_sigma and smooth_sigma > 0:\n",
    "        window = gaussian_filter1d(window, sigma=smooth_sigma)\n",
    "    # simple peak detection: local maxima above threshold\n",
    "    peaks = []\n",
    "    for i in range(1, len(window) - 1):\n",
    "        if window[i] > window[i-1] and window[i] > window[i+1] and window[i] >= th:\n",
    "            peaks.append((i, window[i]))\n",
    "\n",
    "    # map to ms\n",
    "    cand_ms = [s_ms + idx * stride_ms for idx, _ in peaks]\n",
    "\n",
    "    # keep away from edges\n",
    "    cand_ms = [t for t in cand_ms if (t - s_ms) >= edge_margin_ms and (e_ms - t) >= edge_margin_ms]\n",
    "\n",
    "    # enforce min distance between split points (within this segment)\n",
    "    cand_ms.sort()\n",
    "    filtered = []\n",
    "    min_gap_steps = max(1, min_gap_ms // stride_ms)\n",
    "    last = None\n",
    "    for t in cand_ms:\n",
    "        if last is None or (t - last) >= min_gap_ms:\n",
    "            filtered.append(t)\n",
    "            last = t\n",
    "\n",
    "    # keep top_k by probability\n",
    "    if top_k is not None and len(filtered) > top_k:\n",
    "        # re-pick by probability weights\n",
    "        # rebuild dict prob at each t\n",
    "        prob_by_t = { (s_ms + i*stride_ms): window[i] for i,_ in peaks }\n",
    "        filtered = sorted(filtered, key=lambda t: prob_by_t.get(t, 0.0), reverse=True)[:top_k]\n",
    "        filtered.sort()\n",
    "    return filtered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "ac867b59",
   "metadata": {},
   "outputs": [],
   "source": [
    "def recover_and_match_with_splits(\n",
    "    logits_arr,\n",
    "    segments_ms,\n",
    "    gt_word_phone_ids,\n",
    "    idx2phoneme,\n",
    "    *,\n",
    "    sil_index,\n",
    "    blank_index=None,\n",
    "    stride_ms=4,\n",
    "    # candidate split detection params\n",
    "    sil_th=0.35,\n",
    "    smooth_sigma=2.0,\n",
    "    min_gap_ms=40,\n",
    "    edge_margin_ms=24,\n",
    "    top_k_per_segment=3,\n",
    "    # matching params\n",
    "    max_merge_after_recovery=1,   # enforce 1:1 after splitting\n",
    "    verbose=False\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns:\n",
    "      final_assignments, final_segments_ms, final_total_cost\n",
    "    \"\"\"\n",
    "\n",
    "    # ---- shapes & helpers ----\n",
    "    arr = logits_arr[0] if getattr(logits_arr, \"ndim\", None) == 3 else logits_arr\n",
    "    T = int(arr.shape[0])  # time steps\n",
    "    total_ms = T * stride_ms\n",
    "\n",
    "    def eval_match(curr_segments, max_merge):\n",
    "        return match_segments_to_words_with_merges(\n",
    "            logits_arr=logits_arr,\n",
    "            segments_ms=curr_segments,\n",
    "            gt_word_phone_ids=gt_word_phone_ids,\n",
    "            idx2phoneme=idx2phoneme,\n",
    "            max_merge=max_merge,\n",
    "            stride=stride_ms\n",
    "        )\n",
    "\n",
    "    # ---- quick exits ----\n",
    "    num_words = len(gt_word_phone_ids)\n",
    "    # No words → nothing to match\n",
    "    if num_words == 0:\n",
    "        return [], [], 0.0\n",
    "\n",
    "    # Clean segments: drop zero/negative-length segments\n",
    "    segs = [(s, e) for (s, e) in segments_ms if (e - s) > 0]\n",
    "\n",
    "    # If no detected segments but we have words, bootstrap a single segment over the whole trial\n",
    "    if len(segs) == 0:\n",
    "        segs = [(0, total_ms)]\n",
    "\n",
    "    # ---- build silence prob curve once ----\n",
    "    p_sil = silence_prob_curve(logits_arr, sil_index, blank_index=blank_index)\n",
    "\n",
    "    # initial match; allow merges if we have more segments than words\n",
    "    base_allow_merge = 3 if len(segs) > num_words else 1\n",
    "    _, _, base_cost = eval_match(segs, max_merge=base_allow_merge)\n",
    "    if verbose:\n",
    "        print(f\"[init] segs={len(segs)} words={num_words} cost={base_cost:.3f}\")\n",
    "\n",
    "    # If we already have enough (or more) segments, skip recovery splits\n",
    "    if len(segs) >= num_words:\n",
    "        return eval_match(segs, max_merge=max_merge_after_recovery)\n",
    "\n",
    "    # ---- candidate splits per segment ----\n",
    "    def per_seg_candidates(curr_segs):\n",
    "        cands_all = []\n",
    "        for s in curr_segs:\n",
    "            cands = candidate_splits_for_segment(\n",
    "                p_sil, s,\n",
    "                stride_ms=stride_ms,\n",
    "                th=sil_th,\n",
    "                smooth_sigma=smooth_sigma,\n",
    "                min_gap_ms=min_gap_ms,\n",
    "                edge_margin_ms=edge_margin_ms,\n",
    "                top_k=top_k_per_segment\n",
    "            )\n",
    "            cands_all.append(cands)\n",
    "        return cands_all\n",
    "\n",
    "    per_seg_cands = per_seg_candidates(segs)\n",
    "\n",
    "    # ---- greedy splitting loop ----\n",
    "    while len(segs) < num_words:\n",
    "        best_new_cost = None\n",
    "        best_new_segments = None\n",
    "\n",
    "        # try splitting each segment at each candidate\n",
    "        for seg_idx, (s_ms, e_ms) in enumerate(segs):\n",
    "            cands = per_seg_cands[seg_idx] if seg_idx < len(per_seg_cands) else []\n",
    "            for t_ms in cands:\n",
    "                if not (s_ms < t_ms < e_ms):\n",
    "                    continue\n",
    "                trial_segments = segs[:seg_idx] + [(s_ms, t_ms), (t_ms, e_ms)] + segs[seg_idx+1:]\n",
    "                # guard: drop any zero-length segments that could appear due to rounding\n",
    "                trial_segments = [(a, b) for (a, b) in trial_segments if (b - a) > 0]\n",
    "                _, _, trial_cost = eval_match(trial_segments, max_merge=3)\n",
    "                if (best_new_cost is None) or (trial_cost < best_new_cost):\n",
    "                    best_new_cost = trial_cost\n",
    "                    best_new_segments = trial_segments\n",
    "\n",
    "        # fallback if no candidate helped or existed: split the longest segment at midpoint\n",
    "        if best_new_segments is None:\n",
    "            # pick longest\n",
    "            lengths = [e - s for (s, e) in segs]\n",
    "            if len(lengths) == 0:\n",
    "                # extremely rare: nothing left, bootstrap one segment\n",
    "                segs = [(0, total_ms)]\n",
    "                per_seg_cands = per_seg_candidates(segs)\n",
    "                continue\n",
    "            seg_idx = int(np.argmax(lengths))\n",
    "            s_ms, e_ms = segs[seg_idx]\n",
    "            if e_ms - s_ms <= 2 * edge_margin_ms:\n",
    "                # cannot safely split; break\n",
    "                if verbose:\n",
    "                    print(\"[recover] longest segment too short to split; stopping.\")\n",
    "                break\n",
    "            t_ms = s_ms + (e_ms - s_ms) // 2\n",
    "            best_new_segments = segs[:seg_idx] + [(s_ms, t_ms), (t_ms, e_ms)] + segs[seg_idx+1:]\n",
    "            best_new_segments = [(a, b) for (a, b) in best_new_segments if (b - a) > 0]\n",
    "\n",
    "        # commit\n",
    "        segs = best_new_segments\n",
    "        per_seg_cands = per_seg_candidates(segs)\n",
    "\n",
    "        if verbose:\n",
    "            _, _, curr_cost = eval_match(segs, max_merge=3)\n",
    "            print(f\"[recover] split -> segs={len(segs)} curr_cost={curr_cost:.3f}\")\n",
    "\n",
    "        # safety: avoid runaway\n",
    "        if len(segs) > num_words + 8:\n",
    "            if verbose:\n",
    "                print(\"[recover] too many segments created; stopping.\")\n",
    "            break\n",
    "\n",
    "    #add the right word as part of assignments\n",
    "\n",
    "    \n",
    "\n",
    "    # final pass: enforce 1:1 (or allow small merges, if you prefer)\n",
    "    final_assignments, final_merged, final_cost = eval_match(segs, max_merge=max_merge_after_recovery)\n",
    "    return final_assignments, final_merged, final_cost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "dc19326e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total trials in test_data: 1426\n"
     ]
    }
   ],
   "source": [
    "#count total trials in test_data\n",
    "trials = 0\n",
    "for session, data in test_data.items():\n",
    "    trials += len(data['sentence_label'])\n",
    "\n",
    "print(f'Total trials in test_data: {trials}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "d5f1a5f3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Matches: 1426/1426 (100.00%)\n"
     ]
    }
   ],
   "source": [
    "matches = 0\n",
    "total = 0\n",
    "for session, data in test_data.items():\n",
    "    data['merged_segments'] = []\n",
    "    data['total_cost'] = 0.0\n",
    "    data['assignments'] = []\n",
    "\n",
    "    for trial in range(len(data['sentence_label'])):\n",
    "        segments_ms = data['speech_segments'][trial]\n",
    "        logits_arr = data['logits'][trial]\n",
    "        gnd_truth_phones =data[\"seq_class_ids\"][trial]\n",
    "        # remove padding (class 0)\n",
    "        gnd_truth_phones = [p for p in gnd_truth_phones if p != 0]\n",
    "\n",
    "        #split into a list of word level phoneme list by using space (last class)\n",
    "        # Find positions of delimiter (40)\n",
    "        idx = np.where(np.array(gnd_truth_phones) == 40)[0]\n",
    "\n",
    "        # Split at those positions (+1 to split after the delimiter)\n",
    "        sublists = np.split(gnd_truth_phones, idx + 1)\n",
    "\n",
    "        # Remove the delimiter itself\n",
    "        sublists = [sub[sub != 40].tolist() for sub in sublists if np.any(sub != 40)]\n",
    "\n",
    "        # assignments, merged_segments_ms, total_cost = match_segments_to_words_with_merges(\n",
    "        #     logits_arr, segments_ms, sublists, LOGIT_TO_PHONEME\n",
    "        # )\n",
    "\n",
    "        assignments, merged_segments_ms, total_cost = recover_and_match_with_splits(\n",
    "            logits_arr=logits_arr,\n",
    "            segments_ms=segments_ms,\n",
    "            gt_word_phone_ids=sublists,\n",
    "            idx2phoneme=LOGIT_TO_PHONEME,\n",
    "            sil_index=40,            # your SIL class index\n",
    "            blank_index=0,           # your BLANK / CTC blank index, or None\n",
    "            stride_ms=4,\n",
    "            sil_th=0.35,             # ↑ if too many false splits; ↓ if missing silences\n",
    "            smooth_sigma=2.0,        # smoothing over p_sil curve (in steps)\n",
    "            min_gap_ms=40,           # avoid super-close splits\n",
    "            edge_margin_ms=10,       # avoid splitting near segment edges\n",
    "            top_k_per_segment=3,     # keep best K candidate splits per segment\n",
    "            max_merge_after_recovery=1,  # enforce 1:1 after splitting\n",
    "            verbose=False\n",
    "        )\n",
    "\n",
    "        sentence = data['sentence_label'][trial]\n",
    "        sentence_in_words = sentence.split(\" \")\n",
    "\n",
    "        assignments_with_words = [\n",
    "            (*a, sentence_in_words[a[0]])  # a[0] is word_idx\n",
    "            for a in assignments\n",
    "        ]\n",
    "\n",
    "        #count number of words\n",
    "        num_words = len(sublists)\n",
    "        if len(merged_segments_ms) != num_words:\n",
    "            print(f\"Warning: Number of merged segments ({len(merged_segments_ms)}) does not match number of words ({num_words}) for session {session}, trial {trial}. Original number of segments: {len(segments_ms)}\")\n",
    "        else:\n",
    "            matches += 1\n",
    "        total += 1\n",
    "\n",
    "\n",
    "\n",
    "        data['merged_segments'].append(merged_segments_ms)\n",
    "        data['total_cost'] += total_cost\n",
    "        data['assignments'].append(assignments_with_words)\n",
    "        \n",
    "\n",
    "print(f\"Matches: {matches}/{total} ({matches/total:.2%})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54c99da8",
   "metadata": {},
   "source": [
    "### same for train_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "c5b03716",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warning: Number of merged segments (4) does not match number of words (6) for session t15.2023.10.06, trial 110. Original number of segments: 1\n",
      "Sentence label: It's really more of a toy.\n",
      "trial, 110, session t15.2023.10.06,\n",
      "Warning: Number of merged segments (8) does not match number of words (10) for session t15.2024.02.25, trial 15. Original number of segments: 1\n",
      "Sentence label: We have to get him back to that point again.\n",
      "trial, 15, session t15.2024.02.25,\n",
      "Warning: Number of merged segments (8) does not match number of words (10) for session t15.2024.02.25, trial 18. Original number of segments: 1\n",
      "Sentence label: Try to think about it from their point of view.\n",
      "trial, 18, session t15.2024.02.25,\n",
      "Matches: 8069/8072 (99.96%)\n"
     ]
    }
   ],
   "source": [
    "matches = 0\n",
    "total = 0\n",
    "\n",
    "lower_silence_trials = []\n",
    "\n",
    "for session, data in train_data.items():\n",
    "    data['merged_segments'] = []\n",
    "    data['total_cost'] = 0.0\n",
    "    data['assignments'] = []\n",
    "\n",
    "    for trial in range(len(data['sentence_label'])):\n",
    "        segments_ms = data['speech_segments'][trial]\n",
    "        logits_arr = data['logits'][trial]\n",
    "        gnd_truth_phones =data[\"seq_class_ids\"][trial]\n",
    "        # remove padding (class 0)\n",
    "        gnd_truth_phones = [p for p in gnd_truth_phones if p != 0]\n",
    "\n",
    "        #split into a list of word level phoneme list by using space (last class)\n",
    "        # Find positions of delimiter (40)\n",
    "        idx = np.where(np.array(gnd_truth_phones) == 40)[0]\n",
    "\n",
    "        # Split at those positions (+1 to split after the delimiter)\n",
    "        sublists = np.split(gnd_truth_phones, idx + 1)\n",
    "\n",
    "        # Remove the delimiter itself\n",
    "        sublists = [sub[sub != 40].tolist() for sub in sublists if np.any(sub != 40)]\n",
    "\n",
    "        # assignments, merged_segments_ms, total_cost = match_segments_to_words_with_merges(\n",
    "        #     logits_arr, segments_ms, sublists, LOGIT_TO_PHONEME\n",
    "        # )\n",
    "\n",
    "        #if number of segments is less than number of words, use recovery and matching with splits\n",
    "        if len(segments_ms) < len(sublists):\n",
    "            lower_silence_trials.append((session, trial, len(segments_ms), len(sublists)))\n",
    "\n",
    "        assignments, merged_segments_ms, total_cost = recover_and_match_with_splits(\n",
    "            logits_arr=logits_arr,\n",
    "            segments_ms=segments_ms,\n",
    "            gt_word_phone_ids=sublists,\n",
    "            idx2phoneme=LOGIT_TO_PHONEME,\n",
    "            sil_index=40,            # your SIL class index\n",
    "            blank_index=0,           # your BLANK / CTC blank index, or None\n",
    "            stride_ms=4,\n",
    "            sil_th=0.30,             # ↑ if too many false splits; ↓ if missing silences\n",
    "            smooth_sigma=2.0,        # smoothing over p_sil curve (in steps)\n",
    "            min_gap_ms=30,           # avoid super-close splits\n",
    "            edge_margin_ms=10,       # avoid splitting near segment edges\n",
    "            top_k_per_segment=5,     # keep best K candidate splits per segment\n",
    "            max_merge_after_recovery=1,  # enforce 1:1 after splitting\n",
    "            verbose=False\n",
    "        )\n",
    "        sentence = data['sentence_label'][trial]\n",
    "        sentence_in_words = sentence.split(\" \")\n",
    "\n",
    "        assignments_with_words = [\n",
    "            (*a, sentence_in_words[a[0]])  # a[0] is word_idx\n",
    "            for a in assignments\n",
    "        ]\n",
    "\n",
    "        #count number of words\n",
    "        num_words = len(sublists)\n",
    "        if len(merged_segments_ms) != num_words:\n",
    "            print(f\"Warning: Number of merged segments ({len(merged_segments_ms)}) does not match number of words ({num_words}) for session {session}, trial {trial}. Original number of segments: {len(segments_ms)}\")\n",
    "            #print the sentence label\n",
    "            print(f\"Sentence label: {data['sentence_label'][trial]}\")\n",
    "            print(f\"trial, {trial}, session {session},\")\n",
    "        else:\n",
    "            matches += 1\n",
    "        total += 1\n",
    "        data['merged_segments'].append(merged_segments_ms)\n",
    "        data['total_cost'] += total_cost\n",
    "        data['assignments'].append(assignments_with_words)\n",
    "\n",
    "print(f\"Matches: {matches}/{total} ({matches/total:.2%})\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "e2eeb84b",
   "metadata": {},
   "outputs": [],
   "source": [
    "ses_name_to_remove = [(\"t15.2023.10.06\",110), (\"t15.2024.02.25\", 15), (\"t15.2024.02.25\", 18) ]\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "069a9e36",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Removed session t15.2023.10.06, trial 110 due to known issues.\n",
      "Sentence label: It's really more of a toy.\n",
      "Removed session t15.2024.02.25, trial 15 due to known issues.\n",
      "Sentence label: We have to get him back to that point again.\n",
      "Removed session t15.2024.02.25, trial 18 due to known issues.\n",
      "Sentence label: Try to think about it from their point of view.\n"
     ]
    }
   ],
   "source": [
    "cleaned_train_data = {}\n",
    "\n",
    "# initialize all keys as empty lists\n",
    "for session, data in train_data.items():\n",
    "    for key in data.keys():\n",
    "        if key not in ['logits', 'pred_seq', 'total_cost']:\n",
    "            if key not in cleaned_train_data:\n",
    "                cleaned_train_data[key] = []\n",
    "\n",
    "# now stack all trials\n",
    "for session, data in train_data.items():\n",
    "    for trial in range(len(data['neural_features'])):\n",
    "        if (session, trial) not in ses_name_to_remove:\n",
    "            for key in data.keys():\n",
    "                if key in ['logits', 'pred_seq', 'total_cost']:\n",
    "                    continue\n",
    "                cleaned_train_data[key].append(data[key][trial])\n",
    "        else:\n",
    "            print(f\"Removed session {session}, trial {trial} due to known issues.\")\n",
    "            print(f\"Sentence label: {data['sentence_label'][trial]}\")\n",
    "\n",
    "#add another column dayIdx to cleaned_train_data each session name becomes a day index\n",
    "day_to_idx = {session: idx for idx, session in enumerate(model_args['dataset']['sessions'])}\n",
    "cleaned_train_data['dayIdx'] = [day_to_idx[session] for session in cleaned_train_data['session']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "d643d2f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "cleaned_val_data = {}\n",
    "# initialize all keys as empty lists\n",
    "for session, data in test_data.items():\n",
    "    for key in data.keys():\n",
    "        if key not in ['logits', 'pred_seq', 'total_cost']:\n",
    "            if key not in cleaned_val_data:\n",
    "                cleaned_val_data[key] = []\n",
    "# now stack all trials\n",
    "for session, data in test_data.items():\n",
    "    for trial in range(len(data['neural_features'])):\n",
    "        for key in data.keys():\n",
    "            if key in ['logits', 'pred_seq', 'total_cost']:\n",
    "                continue\n",
    "            cleaned_val_data[key].append(data[key][trial])\n",
    "\n",
    "\n",
    "day_to_idx = {session: idx for idx, session in enumerate(model_args['dataset']['sessions'])}\n",
    "cleaned_val_data['dayIdx'] = [day_to_idx[session] for session in cleaned_val_data['session']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "5fe38631",
   "metadata": {},
   "outputs": [],
   "source": [
    "cleaned_test_data = {}\n",
    "# initialize all keys as empty lists\n",
    "for session, data in competition_data.items():\n",
    "    for key in data.keys():\n",
    "        if key not in ['logits', 'pred_seq', 'total_cost']:\n",
    "            if key not in cleaned_test_data:\n",
    "                cleaned_test_data[key] = []\n",
    "# now stack all trials\n",
    "for session, data in competition_data.items():\n",
    "    for trial in range(len(data['neural_features'])):\n",
    "        for key in data.keys():\n",
    "            if key in ['logits', 'pred_seq', 'total_cost']:\n",
    "                continue\n",
    "            cleaned_test_data[key].append(data[key][trial])\n",
    "\n",
    "day_to_idx = {session: idx for idx, session in enumerate(model_args['dataset']['sessions'])}\n",
    "cleaned_test_data['dayIdx'] = [day_to_idx[session] for session in cleaned_test_data['session']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "3df87d75",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['neural_features', 'n_time_steps', 'seq_class_ids', 'seq_len', 'transcriptions', 'sentence_label', 'session', 'block_num', 'trial_num', 'corpus', 'speech_segments', 'merged_segments', 'assignments', 'dayIdx'])"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cleaned_train_data.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "f128a0fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "PHONE_DEF = [\n",
    "    'AA', 'AE', 'AH', 'AO', 'AW',\n",
    "    'AY', 'B',  'CH', 'D', 'DH',\n",
    "    'EH', 'ER', 'EY', 'F', 'G',\n",
    "    'HH', 'IH', 'IY', 'JH', 'K',\n",
    "    'L', 'M', 'N', 'NG', 'OW',\n",
    "    'OY', 'P', 'R', 'S', 'SH',\n",
    "    'T', 'TH', 'UH', 'UW', 'V',\n",
    "    'W', 'Y', 'Z', 'ZH'\n",
    "]\n",
    "PHONE_DEF_SIL = PHONE_DEF + ['SIL']\n",
    "\n",
    "def phoneToId(p):\n",
    "    return PHONE_DEF_SIL.index(p)\n",
    "\n",
    "phoneToIdDict = {p:phoneToId(p) for p in PHONE_DEF_SIL}\n",
    "\n",
    "class AugmentedNeuralTextDataset(torch.utils.data.Dataset):\n",
    "    \"\"\"augmentations ideas:\n",
    "        1. truncate the signal\n",
    "        2. mixup two samples from the same day replacing a word\n",
    "        3. add a word from another sample in the same day at beginning or end\n",
    "\n",
    "        #for the future, check only valid sentences\n",
    "\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data_dict, mix_prob = 0.5, subset_len = 3):\n",
    "        self.neural_features = data_dict['neural_features']\n",
    "        self.n_time_steps = data_dict['n_time_steps']\n",
    "        self.seq_class_ids = data_dict['seq_class_ids']\n",
    "        self.seq_len = data_dict['seq_len']\n",
    "        self.sentence_label = data_dict['sentence_label']\n",
    "        self.merged_segments = data_dict['merged_segments']\n",
    "        self.assignments = data_dict['assignments']\n",
    "        self.transcriptions = data_dict[\"sentence_label\"]\n",
    "        self.days = torch.tensor(data_dict[\"dayIdx\"], dtype=torch.long)\n",
    "        self.mix_prob = mix_prob\n",
    "        self.subset_len = subset_len\n",
    "\n",
    "        self.aug_methods = [\"truncate\", \"mixup\"]\n",
    "        self.aug_p = [0.50, 0.50]# probabilities for each augmentation method\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.neural_features)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "\n",
    "        seq_class_ids = self.seq_class_ids[idx]\n",
    "        #remove padding (class 0)\n",
    "        seq_class_ids = [p.item() for p in seq_class_ids if p != 0]\n",
    "\n",
    "        item =  {\n",
    "            'neural_features': torch.tensor(self.neural_features[idx], dtype=torch.float32),\n",
    "            'n_time_steps': self.n_time_steps[idx],\n",
    "            'seq_class_ids': torch.tensor(seq_class_ids, dtype=torch.long),\n",
    "            'seq_len': torch.tensor(self.seq_len[idx], dtype=torch.long),\n",
    "            'sentence_label': self.sentence_label[idx],\n",
    "            'merged_segments': self.merged_segments[idx],\n",
    "            'assignments': self.assignments[idx],\n",
    "            'transcriptions': self.transcriptions[idx],\n",
    "            'day': self.days[idx],\n",
    "            'augmentation': \"none\"\n",
    "        }\n",
    "\n",
    "        r = np.random.rand()\n",
    "        if r < self.mix_prob:\n",
    "            #do mixup \n",
    "\n",
    "            #1. look for another random index from the same day\n",
    "            # same_day_indices = [i for i, day in enumerate(self.days) if day == self.days[idx] and i != idx]\n",
    "            # if len(same_day_indices) == 0:\n",
    "            #     return item\n",
    "            # rand_idx = np.random.choice(same_day_indices)\n",
    "\n",
    "            # other_item = {\n",
    "            #     'neural_features': torch.tensor(self.neural_features[rand_idx], dtype=torch.float32),\n",
    "            #     'n_time_steps': self.n_time_steps[rand_idx],\n",
    "            #     'seq_class_ids': torch.tensor(self.seq_class_ids[rand_idx], dtype=torch.long),  \n",
    "            #     'seq_len': torch.tensor(self.seq_len[rand_idx], dtype=torch.long),\n",
    "            #     'sentence_label': self.sentence_label[rand_idx],\n",
    "            #     'merged_segments': self.merged_segments[rand_idx],\n",
    "            #     'assignments': self.assignments[rand_idx],\n",
    "            #     'transcriptions': self.transcriptions[rand_idx],\n",
    "            #     'day': self.days[rand_idx]\n",
    "            # }\n",
    "            print(\"doing augmentation\")\n",
    "\n",
    "            aug_method = np.random.choice(self.aug_methods, p=self.aug_p)\n",
    "            print(f\"augmentation method: {aug_method}\")\n",
    "            if aug_method == \"truncate\":\n",
    "\n",
    "                print(\"doing truncate\")\n",
    "                other_item = item.copy()\n",
    "\n",
    "                assignments_full = other_item['assignments']\n",
    "                merged_full = other_item['merged_segments']\n",
    "                n_words = len(assignments_full)\n",
    "\n",
    "                if n_words <= self.subset_len:\n",
    "                    # keep your existing first/last drop logic here (unchanged)\n",
    "                    if n_words == 1:\n",
    "                        return other_item\n",
    "                    if np.random.rand() < 0.5:\n",
    "                        # ---- drop FIRST word (your existing code) ----\n",
    "                        merged_segms = other_item['merged_segments'][0]\n",
    "                        start_idx = merged_segms[-1]\n",
    "                        end_idx = len(other_item['neural_features'])\n",
    "\n",
    "                        other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx]\n",
    "                        other_item['n_time_steps'] = end_idx - start_idx\n",
    "                        other_item['transcriptions'] = \" \".join([a[-1] for a in assignments_full[1:]])\n",
    "                        other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                        new_seq_class_ids = []\n",
    "                        for ass in assignments_full[1:]:\n",
    "                            new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "                        other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                        other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                        # rebase merged segments & assignments to the new slice\n",
    "                        kept_assignments = assignments_full[1:]\n",
    "                        kept_merged = []\n",
    "                        for seg in merged_full[1:]:\n",
    "                            kept_merged.append([ix - start_idx for ix in seg])\n",
    "                        other_item['assignments'] = kept_assignments\n",
    "                        other_item['merged_segments'] = kept_merged\n",
    "\n",
    "                        other_item[\"augmentation\"] = \"drop_first\"\n",
    "\n",
    "                    else:\n",
    "                        # ---- drop LAST word (mirror path) ----\n",
    "                        last_seg = other_item['merged_segments'][-1]\n",
    "                        end_idx = last_seg[0]  # start of last word\n",
    "                        start_idx = 0\n",
    "\n",
    "                        other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx+4] # include a bit of silence after last word\n",
    "                        other_item['n_time_steps'] = end_idx - start_idx\n",
    "                        other_item['transcriptions'] = \" \".join([a[-1] for a in assignments_full[:-1]])\n",
    "                        other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                        new_seq_class_ids = []\n",
    "                        for ass in assignments_full[:-1]:\n",
    "                            new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "                        other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                        other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                        # merged segments/assignments (no rebase needed when starting at 0)\n",
    "                        other_item['assignments'] = assignments_full[:-1]\n",
    "                        other_item['merged_segments'] = merged_full[:-1]\n",
    "\n",
    "                        print(\"dropped last word\")\n",
    "                        other_item[\"augmentation\"] = \"drop_last\"\n",
    "\n",
    "                else:\n",
    "                    win_len = self.subset_len\n",
    "                    start_word = np.random.randint(0, n_words - win_len + 1)\n",
    "                    end_word = start_word + win_len  # exclusive\n",
    "\n",
    "                    # slice words/segments\n",
    "                    kept_assignments = assignments_full[start_word:end_word]\n",
    "                    kept_merged_abs = merged_full[start_word:end_word]\n",
    "\n",
    "                    # absolute indices to cut the neural signal\n",
    "                    # start at the first kept word's start\n",
    "                    start_idx = int(kept_merged_abs[0][0])\n",
    "\n",
    "                    # end at the NEXT word's start if it exists, else end of the original signal\n",
    "                    if end_word < n_words:\n",
    "                        end_idx = int(merged_full[end_word][0])  # start of the next word (exclusive)\n",
    "                    else:\n",
    "                        # use original (pre-slice) length to avoid compounding\n",
    "                        end_idx = int(len(other_item['neural_features']))\n",
    "\n",
    "                    # safety\n",
    "                    end_idx = max(end_idx, start_idx + 1)\n",
    "\n",
    "                    # slice neural features\n",
    "                    other_item['neural_features'] = other_item['neural_features'][start_idx:end_idx]\n",
    "                    other_item['n_time_steps'] = end_idx - start_idx\n",
    "\n",
    "                    # update text\n",
    "                    other_item['transcriptions'] = \" \".join([a[-1] for a in kept_assignments])\n",
    "                    other_item['sentence_label'] = other_item['transcriptions']\n",
    "\n",
    "                    # rebuild seq_class_ids (insert SIL between words; no trailing SIL)\n",
    "                    new_seq_class_ids = []\n",
    "                    for i, ass in enumerate(kept_assignments):\n",
    "                        new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                        if i < len(kept_assignments) - 1:\n",
    "                            new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "\n",
    "                    other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                    other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "\n",
    "                    # rebase merged segments to the new window and CLIP inside [0, end_idx-start_idx)\n",
    "                    L = end_idx - start_idx\n",
    "                    kept_merged_rebased = []\n",
    "                    for seg in kept_merged_abs:\n",
    "                        rebased = [int(ix - start_idx) for ix in seg]\n",
    "                        # clip to be safe (handles inclusive vs exclusive ambiguity)\n",
    "                        rebased = [min(max(0, x), L - 1) for x in rebased]\n",
    "                        # ensure sorted and unique\n",
    "                        rebased = sorted(set(rebased))\n",
    "                        kept_merged_rebased.append(rebased)\n",
    "\n",
    "                    other_item['assignments'] = kept_assignments\n",
    "                    other_item['merged_segments'] = kept_merged_rebased\n",
    "                    other_item[\"augmentation\"] = \"subset\"\n",
    "        \n",
    "                return other_item\n",
    "        \n",
    "            elif aug_method == \"mixup\":\n",
    "                \n",
    "                # ---- MIXUP (variable-length splice) ----\n",
    "                # 1) find donor trial from same day\n",
    "                same_day_indices = [i for i, d in enumerate(self.days) if d == self.days[idx] and i != idx]\n",
    "                if not same_day_indices:\n",
    "                    return item, item  # no donor, bail out safely\n",
    "                rand_idx = int(np.random.choice(same_day_indices))\n",
    "\n",
    "                donor_feats = torch.tensor(self.neural_features[rand_idx], dtype=torch.float32)\n",
    "                donor_assign = self.assignments[rand_idx]\n",
    "                donor_merged = self.merged_segments[rand_idx]\n",
    "\n",
    "                # 2) pick random word indices\n",
    "                orig_n_words = len(item['assignments'])\n",
    "                donor_n_words = len(donor_assign)\n",
    "                if orig_n_words == 0 or donor_n_words == 0:\n",
    "                    return item, item\n",
    "\n",
    "                i_orig = np.random.randint(0, orig_n_words)\n",
    "                j_donor = np.random.randint(0, donor_n_words)\n",
    "\n",
    "                # 3) get original and donor segment boundaries (absolute frame indices)\n",
    "                orig_seg_abs = item['merged_segments'][i_orig]\n",
    "                donor_seg_abs = donor_merged[j_donor]\n",
    "\n",
    "                orig_start = int(orig_seg_abs[0])\n",
    "                orig_end   = int(orig_seg_abs[-1])\n",
    "                donor_start = int(donor_seg_abs[0])\n",
    "                donor_end   = int(donor_seg_abs[-1])\n",
    "\n",
    "                orig_len = max(0, orig_end - orig_start)\n",
    "                donor_len = max(0, donor_end - donor_start)\n",
    "                if orig_len == 0 or donor_len == 0:\n",
    "                    return item, item\n",
    "\n",
    "                # 4) deep(ish) copy current example to edit\n",
    "                other_item = {\n",
    "                    'neural_features': item['neural_features'].clone(),  # (T, C) tensor\n",
    "                    'n_time_steps': item['n_time_steps'],\n",
    "                    'seq_class_ids': None,\n",
    "                    'seq_len': None,\n",
    "                    'sentence_label': item['sentence_label'],\n",
    "                    'merged_segments': [list(seg) for seg in item['merged_segments']],\n",
    "                    'assignments': [list(a) for a in item['assignments']],  # make mutable\n",
    "                    'transcriptions': item['transcriptions'],\n",
    "                    'day': item['day'],\n",
    "                }\n",
    "\n",
    "                # 5) splice donor frames (variable length): new = before + donor + after\n",
    "                before = other_item['neural_features'][:orig_start]\n",
    "                donor_slice = donor_feats[donor_start:donor_end]         # (donor_len, C)\n",
    "                after  = other_item['neural_features'][orig_end:]\n",
    "\n",
    "                new_feats = torch.cat([before, donor_slice, after], dim=0)\n",
    "                other_item['neural_features'] = new_feats\n",
    "                other_item['n_time_steps'] = new_feats.shape[0]\n",
    "\n",
    "                # length delta: how much timeline shifts for words AFTER i_orig\n",
    "                delta = donor_len - orig_len\n",
    "\n",
    "                # 6) update merged_segments:\n",
    "                #    - replaced segment gets donor's internal indices, rebased to orig_start\n",
    "                #    - all following segments shift by +delta\n",
    "                #    - prior segments unchanged\n",
    "                new_merged = []\n",
    "                for k, seg in enumerate(item['merged_segments']):\n",
    "                    if k < i_orig:\n",
    "                        new_merged.append(list(seg))\n",
    "                    elif k == i_orig:\n",
    "                        # rebase donor indices to align at orig_start\n",
    "                        # donor indices relative -> ix' = orig_start + (ix - donor_start)\n",
    "                        rebased = [int(orig_start + (ix - donor_start)) for ix in donor_seg_abs]\n",
    "                        # ensure sorted and within new length\n",
    "                        rebased = sorted(rebased)\n",
    "                        new_merged.append(rebased)\n",
    "                    else:\n",
    "                        # shift everything by delta\n",
    "                        shifted = [int(ix + delta) for ix in seg]\n",
    "                        new_merged.append(shifted)\n",
    "                other_item['merged_segments'] = new_merged\n",
    "\n",
    "                # 7) update assignments content at i_orig using donor's word + phonemes\n",
    "                #    Keep the *structure* of the assignment tuple; replace only word & phoneme list.\n",
    "                a_list = other_item['assignments'][i_orig]\n",
    "                a_list[4] = list(donor_assign[j_donor][4])   # phoneme strings\n",
    "                a_list[-1] = donor_assign[j_donor][-1]       # word text\n",
    "                other_item['assignments'][i_orig] = a_list\n",
    "\n",
    "                # (Optional) If your assignment stores a word_idx at [0], ensure it still matches:\n",
    "                # other_item['assignments'][i_orig][0] = i_orig\n",
    "\n",
    "                # 8) update transcription text (swap the word string)\n",
    "                words = other_item['transcriptions'].split()\n",
    "                if len(words) == len(other_item['assignments']):\n",
    "                    words[i_orig] = donor_assign[j_donor][-1]\n",
    "                    new_text = \" \".join(words)\n",
    "                else:\n",
    "                    # fallback: reconstruct from assignments to be safe\n",
    "                    new_text = \" \".join([a[-1] for a in other_item['assignments']])\n",
    "                other_item['transcriptions'] = new_text\n",
    "                other_item['sentence_label'] = new_text\n",
    "\n",
    "                # 9) rebuild seq_class_ids from assignments with SIL between words (no trailing SIL)\n",
    "                new_seq_class_ids = []\n",
    "                for wi, ass in enumerate(other_item['assignments']):\n",
    "                    # ass[4] is list of phoneme strings\n",
    "                    new_seq_class_ids.extend([phoneToId(p) + 1 for p in ass[4]])\n",
    "                    if wi < len(other_item['assignments']) - 1:\n",
    "                        new_seq_class_ids.append(phoneToId(\"SIL\") + 1)\n",
    "\n",
    "                other_item[\"seq_class_ids\"] = torch.tensor(new_seq_class_ids, dtype=torch.long)\n",
    "                other_item[\"seq_len\"] = len(new_seq_class_ids)\n",
    "                other_item[\"augmentation\"] = \"mixup\"\n",
    "\n",
    "                return other_item\n",
    "\n",
    "        return item\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9122c48",
   "metadata": {},
   "source": [
    "## Remove Inf trials (3 in the training data) -> Check all the others, especially the tricky ones where the number of initial splits was lower. Do they make sense? If yes, proceed and save the dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "39d8c1c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['trans_args.yaml',\n",
       " 'evaluate_model.py',\n",
       " 'rnn_baseline_submission_file_valsplit.csv',\n",
       " 'train_model_base.py',\n",
       " 'B2TXT25_Whisper',\n",
       " 'train_model_ensembling.py',\n",
       " 'RNN_Whisper_augmentations.ipynb',\n",
       " 'README.md',\n",
       " 'rnn_trainer.py',\n",
       " 'understand_data.ipynb',\n",
       " 'data_augmentations.py',\n",
       " 'rnn_model.py',\n",
       " 'masked_transformer.ipynb',\n",
       " '__pycache__',\n",
       " 'Wav2Vec.ipynb',\n",
       " 'rnn_args_masking.yaml',\n",
       " 'slicing_dataset.ipynb',\n",
       " 'transformer_trainer_masked.py',\n",
       " 'wandb',\n",
       " 'rnn_args.yaml',\n",
       " 'dataset.py',\n",
       " 'RNN_Whisper.ipynb',\n",
       " 'train_model.py',\n",
       " 'evaluate_model_helpers.py',\n",
       " 'evaluate_ensembling.ipynb',\n",
       " 'train_model_transformer.py',\n",
       " 'lightning_logs',\n",
       " 'trained_models',\n",
       " 'rnn_trainer_masked.py',\n",
       " 'transformer_model.py']"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "# List the entries of the current working directory.\n",
    "os.listdir()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cc8a41e7",
   "metadata": {},
   "source": [
    "## Save the adjusted trials for the training, validation and test sets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "858a7cf7",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save cleaned_train_data as pickle file\n",
    "# Opened in \"wb\" mode, so an existing ../data/cleaned_train_data.pkl is overwritten.\n",
    "\n",
    "import pickle \n",
    "\n",
    "with open(\"../data/cleaned_train_data.pkl\", \"wb\") as f:\n",
    "    pickle.dump(cleaned_train_data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50009699",
   "metadata": {},
   "outputs": [],
   "source": [
    "#save cleaned_val_data and cleaned_test_data\n",
    "# NOTE: this cell does not import pickle itself; it relies on the import in the preceding cell.\n",
    "# Both files are opened in \"wb\" mode, so existing files are overwritten.\n",
    "\n",
    "with open(\"../data/cleaned_val_data.pkl\", \"wb\") as f:\n",
    "    pickle.dump(cleaned_val_data, f)\n",
    "\n",
    "with open(\"../data/cleaned_test_data.pkl\", \"wb\") as f:\n",
    "    pickle.dump(cleaned_test_data, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "7adcb460",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'truncate'"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch check: draw one of the two augmentation names at random.\n",
    "np.random.choice([\"truncate\", \"mixup\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "b1321d0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training dataset: each sample is augmented with probability mix_prob;\n",
    "# subset_len is the number of words kept by the truncate augmentation's window.\n",
    "train_dataset = AugmentedNeuralTextDataset(cleaned_train_data, mix_prob = 0.8, subset_len = 5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "21e2519e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "doing augmentation\n",
      "augmentation method: mixup\n"
     ]
    },
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 2)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[32], line 4\u001b[0m\n\u001b[1;32m      1\u001b[0m idx \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mrandom\u001b[38;5;241m.\u001b[39mrandint(\u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m200\u001b[39m)\n\u001b[1;32m      3\u001b[0m a \u001b[38;5;241m=\u001b[39m train_dataset\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__getitem__\u001b[39m(idx)\n\u001b[0;32m----> 4\u001b[0m b,c \u001b[38;5;241m=\u001b[39m a\n\u001b[1;32m      6\u001b[0m b[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtranscriptions\u001b[39m\u001b[38;5;124m\"\u001b[39m], c[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtranscriptions\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
     ]
    }
   ],
   "source": [
    "idx = np.random.randint(0, 200)\n",
    "\n",
    "# __getitem__ returns ONE dict per sample, so it cannot be unpacked into a pair.\n",
    "# c: the sample as returned by train_dataset (possibly augmented)\n",
    "# b: the same trial loaded with augmentation disabled (mix_prob=0.0), for comparison\n",
    "a = train_dataset[idx]\n",
    "c = a\n",
    "b = AugmentedNeuralTextDataset(cleaned_train_data, mix_prob=0.0, subset_len=5)[idx]\n",
    "\n",
    "b[\"transcriptions\"], c[\"transcriptions\"], c[\"augmentation\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31f9bff5",
   "metadata": {},
   "outputs": [],
   "source": [
    "idToPhone = dict((phone_id, phone) for phone, phone_id in phoneToIdDict.items())\n",
    "\n",
    "def idsToPhonemes(seqClassIDs, idToPhone = idToPhone):\n",
    "    \"\"\"\n",
    "    Map a sequence of stored class IDs back to phoneme symbols.\n",
    "\n",
    "    Args:\n",
    "        seqClassIDs (numpy array): class IDs as stored in the targets (phoneme ID + 1).\n",
    "        idToPhone (dict): lookup from phoneme ID to phoneme.\n",
    "\n",
    "    Returns:\n",
    "        list: one phoneme per ID greater than 0, in input order; other IDs are skipped.\n",
    "    \"\"\"\n",
    "    phonemeSeq = []\n",
    "    for class_id in seqClassIDs:\n",
    "        if class_id > 0:\n",
    "            # stored IDs carry a +1 offset, so shift back before the lookup\n",
    "            phonemeSeq.append(idToPhone[class_id - 1])\n",
    "    return phonemeSeq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38325034",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['N', 'AA', 'T', 'SIL', 'V', 'EH', 'R', 'IY']"
      ]
     },
     "execution_count": 515,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Decode c's target class ids into phoneme symbols.\n",
    "idsToPhonemes(c[\"seq_class_ids\"].numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cdbfa383",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['DH', 'EY', 'SIL', 'AA', 'R', 'SIL', 'N', 'AA', 'T']\n"
     ]
    }
   ],
   "source": [
    "# Inference only: run the forward pass without building an autograd graph.\n",
    "with torch.no_grad():\n",
    "    logits = model(c[\"neural_features\"].unsqueeze(0).to(device), c[\"day\"].unsqueeze(0).to(device))\n",
    "\n",
    "# Greedy decode: most likely class per frame, then collapse consecutive repeats.\n",
    "predicted_ids = torch.unique_consecutive(logits.argmax(dim=-1)).cpu().numpy()\n",
    "\n",
    "# Convert class ids to phonemes; idsToPhonemes skips ids <= 0 (presumably the blank class).\n",
    "predicted_phonemes = idsToPhonemes(predicted_ids)\n",
    "print(predicted_phonemes)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4260314",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0, 22,  6, 40,  0, 23, 12, 29, 40,  0, 17, 38, 40,  0, 23,  1, 31,\n",
       "       40,  0, 16,  3, 24, 15, 28, 18, 40,  0])"
      ]
     },
     "execution_count": 403,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Raw collapsed class ids from the decode above (before conversion to phonemes).\n",
    "predicted_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d8005aa4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([424, 512]), torch.Size([512, 512]))"
      ]
     },
     "execution_count": 380,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare the neural-feature shapes of samples b and c.\n",
    "b[\"neural_features\"].shape, c[\"neural_features\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1988d358",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([37, 11, 29, 40,  6, 40,  9, 34, 40]),\n",
       " tensor([37, 11, 29, 40,  6, 40,  6]))"
      ]
     },
     "execution_count": 381,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare the target class-id sequences of samples b and c.\n",
    "b[\"seq_class_ids\"], c[\"seq_class_ids\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23659903",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor(9), 7)"
      ]
     },
     "execution_count": 382,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare the target lengths of samples b and c.\n",
    "b[\"seq_len\"], c[\"seq_len\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e073b90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(424, 512)"
      ]
     },
     "execution_count": 383,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compare the reported number of time steps of samples b and c.\n",
    "b[\"n_time_steps\"], c[\"n_time_steps\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "53440cde",
   "metadata": {},
   "outputs": [
    {
     "ename": "ValueError",
     "evalue": "too many values to unpack (expected 2)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[126], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m b,c \u001b[38;5;241m=\u001b[39m a\n",
      "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
     ]
    }
   ],
   "source": [
    "# `a` is a single dict, so `b,c = a` raises ValueError; build the pair explicitly.\n",
    "b = AugmentedNeuralTextDataset(cleaned_train_data, mix_prob=0.0, subset_len=5)[idx]  # augmentation disabled\n",
    "c = a  # sample as returned by train_dataset (possibly augmented)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "99369e86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Bring it closer.'"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sentence text of sample b.\n",
    "b[\"transcriptions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7ff4018",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'You are going.'"
      ]
     },
     "execution_count": 93,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sentence text of sample c.\n",
    "c[\"transcriptions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3f004f7b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "96"
      ]
     },
     "execution_count": 106,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last index of b's first merged segment (the cut point used when the first word is dropped).\n",
    "b[\"merged_segments\"][0][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "14a3a1a5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([321, 512])"
      ]
     },
     "execution_count": 97,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape of sample b's neural features.\n",
    "b[\"neural_features\"].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7a5e347",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Bring'"
      ]
     },
     "execution_count": 110,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Word text stored in the last field of b's first assignment.\n",
    "b[\"assignments\"][0][-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "993d68fe",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.23"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
