{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "59debae9",
   "metadata": {},
   "source": [
    "## Here we target phoneme and diphoneme predictions as two distict tasks\n",
    "\n",
    "In this case, I'll also sum the probabilities from the single phoneme head and marginalized diphonemes head"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e2787057",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "import torchaudio  \n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "# from model.ctc_modelling import Light\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b672ec04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "99ce5c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "from collections import defaultdict\n",
    "\n",
    "def build_diphone_vocab(phone_def_sil):\n",
    "    diphones = [f\"{a}→{b}\" for a, b in product(phone_def_sil, repeat=2)]\n",
    "    diphone_to_idx = {d: i+1 for i, d in enumerate(sorted(diphones))}  # +1 for CTC blank at idx 0\n",
    "    idx_to_diphone = {i: d for d, i in diphone_to_idx.items()}\n",
    "    return diphone_to_idx, idx_to_diphone\n",
    "\n",
    "diphone_to_idx, idx_to_diphone = build_diphone_vocab(PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "def build_phoneme_to_diphone_ids(PHONE_DEF_SIL, diphone_to_idx):\n",
    "    phoneme_to_diphone_ids = defaultdict(list)\n",
    "\n",
    "    for diphone_str, idx in diphone_to_idx.items():\n",
    "        if idx == 0:\n",
    "            continue  # skip CTC blank\n",
    "        left, right = diphone_str.split(\"→\")\n",
    "        phoneme_to_diphone_ids[left].append(idx)   # p is source: p→X\n",
    "        phoneme_to_diphone_ids[right].append(idx)  # p is target: X→p\n",
    "\n",
    "    return phoneme_to_diphone_ids\n",
    "\n",
    "\n",
    "def convert_phoneme_seq_to_diphone_ids(batch_phone_seq, idsToPhonemes, diphone_to_idx):\n",
    "    \"\"\"\n",
    "    Converts batch of phoneme index sequences to DCoND-style diphone ID sequences.\n",
    "    \n",
    "    Args:\n",
    "        batch_phone_seq: (B, T) tensor of phoneme indices\n",
    "        idsToPhonemes: callable to convert phoneme ids to phoneme strings\n",
    "        diphone_to_idx: dict mapping 'A→B' to int\n",
    "    \n",
    "    Returns:\n",
    "        diphone_seqs: list of 1D tensors with diphone IDs\n",
    "        diphone_lengths: list of ints (lengths per sequence)\n",
    "    \"\"\"\n",
    "    diphone_seqs = []\n",
    "    diphone_lengths = []\n",
    "\n",
    "    batch_size = batch_phone_seq.shape[0]\n",
    "\n",
    "    for i in range(batch_size):\n",
    "        phoneme_ids = batch_phone_seq[i].cpu().numpy()\n",
    "        phonemes = idsToPhonemes(phoneme_ids)\n",
    "        phonemes = [p for p in phonemes if p != \"<pad>\" and p != \"\"]\n",
    "\n",
    "        diphones = []\n",
    "        for j in range(len(phonemes)):\n",
    "            if j > 0:\n",
    "                diphones.append(f\"{phonemes[j-1]}→{phonemes[j]}\")\n",
    "            diphones.append(f\"{phonemes[j]}→{phonemes[j]}\")\n",
    "            if j < len(phonemes) - 1:\n",
    "                diphones.append(f\"{phonemes[j]}→{phonemes[j+1]}\")\n",
    "\n",
    "        # Convert to indices\n",
    "        diphone_ids = [diphone_to_idx[d] for d in diphones if d in diphone_to_idx]\n",
    "        diphone_ids_tensor = torch.tensor(diphone_ids, dtype=torch.int32)\n",
    "\n",
    "        # Deduplicate consecutive transitions\n",
    "        diphone_ids_tensor = torch.unique_consecutive(diphone_ids_tensor)\n",
    "\n",
    "        diphone_seqs.append(diphone_ids_tensor)\n",
    "        diphone_lengths.append(len(diphone_ids_tensor))\n",
    "\n",
    "    return diphone_seqs, diphone_lengths\n",
    "\n",
    "\n",
    "def idsToDiphonemes(id_sequence, idx_to_diphone):\n",
    "    \"\"\"\n",
    "    Converts a sequence of diphone IDs to a list of diphone strings.\n",
    "    \n",
    "    Args:\n",
    "        id_sequence (Union[Tensor, np.ndarray, List[int]]): 1D sequence of diphone IDs.\n",
    "        idx_to_diphone (dict): Mapping from index to diphone string.\n",
    "    \n",
    "    Returns:\n",
    "        List[str]: List of diphone strings like ['SIL→H', 'H→H', 'H→OW', ...]\n",
    "    \"\"\"\n",
    "    if isinstance(id_sequence, torch.Tensor):\n",
    "        id_sequence = id_sequence.cpu().numpy()\n",
    "    \n",
    "    diphones = []\n",
    "    for i in id_sequence:\n",
    "        if i == 0:\n",
    "            continue  # skip blank/padding\n",
    "        diphone = idx_to_diphone.get(i, \"<unk>\")\n",
    "        diphones.append(diphone)\n",
    "    \n",
    "    return diphones\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cb4c24d",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "075b78a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from augmentations import GaussianSmoothing\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from edit_distance import SequenceMatcher\n",
    "import math\n",
    "from transformers import AutoProcessor, ClapModel, AutoModel, AutoTokenizer\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3863a145",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LightningGRUDecoder_Diphonemes_DualHead(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        neural_dim,\n",
    "        n_classes,\n",
    "        hidden_dim,\n",
    "        layer_dim,\n",
    "        nDays=24,\n",
    "        dropout=0.1,\n",
    "        strideLen=4,\n",
    "        kernelLen=14,\n",
    "        gaussianSmoothWidth=0,\n",
    "        bidirectional=False,\n",
    "        learning_rate=1e-3,\n",
    "        white_noise_SD=0.01,\n",
    "        constant_offset_SD=0.01,\n",
    "        weight_decay=1e-5,\n",
    "        mfcc_dim = 14,\n",
    "        mfcc_loss_weight = 1.,\n",
    "        channels = None,\n",
    "        phone_def_sil = PHONE_DEF_SIL,\n",
    "        total_steps = 100000,\n",
    "        alpha = 0.5,\n",
    "\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.total_steps = total_steps\n",
    "        self.layer_dim = layer_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.neural_dim = neural_dim\n",
    "        self.n_classes = n_classes\n",
    "        self.nDays = nDays\n",
    "        self.strideLen = strideLen\n",
    "        self.kernelLen = kernelLen\n",
    "        self.gaussianSmoothWidth = gaussianSmoothWidth\n",
    "        self.bidirectional = bidirectional\n",
    "        self.learning_rate = learning_rate\n",
    "        self.white_noise_SD = white_noise_SD\n",
    "        self.constant_offset_SD = constant_offset_SD\n",
    "        self.weight_decay = weight_decay\n",
    "        self.mfcc_loss_weight = mfcc_loss_weight\n",
    "        self.channels = channels\n",
    "        self.PHONE_DEF_SIL = phone_def_sil\n",
    "        self.alpha = alpha\n",
    "\n",
    "        if channels is None:\n",
    "            self.channels = np.arange(0, neural_dim)\n",
    "\n",
    "        print(\"Resetting neural_dim based on channels\")\n",
    "        self.neural_dim = len(self.channels)\n",
    "        neural_dim = self.neural_dim\n",
    "        print(\"neural_dim\", neural_dim, self.neural_dim)\n",
    "\n",
    "        self.phoneme_to_diphone_ids = build_phoneme_to_diphone_ids(PHONE_DEF_SIL, diphone_to_idx)\n",
    "\n",
    "        for ph, ids in self.phoneme_to_diphone_ids.items():\n",
    "            self.phoneme_to_diphone_ids[ph] = sorted(set(ids))\n",
    "\n",
    "        self.inputLayerNonlinearity = nn.Softsign()\n",
    "        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)\n",
    "        self.mfcc_unfolder = nn.Unfold((self.strideLen, 1), dilation=1, padding=0, stride=self.strideLen)\n",
    "        \n",
    "        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)\n",
    "\n",
    "        # Per-day transformation weights\n",
    "        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))\n",
    "        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))\n",
    "\n",
    "        for x in range(nDays):\n",
    "            self.dayWeights.data[x, :, :] = torch.eye(neural_dim)\n",
    "\n",
    "        # GRU layer\n",
    "        self.gru_decoder = nn.GRU(\n",
    "            (neural_dim) * self.kernelLen,\n",
    "            hidden_dim,\n",
    "            layer_dim,\n",
    "            batch_first=True,\n",
    "            dropout=dropout,\n",
    "            bidirectional=bidirectional,\n",
    "        )\n",
    "\n",
    "        for name, param in self.gru_decoder.named_parameters():\n",
    "            if \"weight_hh\" in name:\n",
    "                nn.init.orthogonal_(param)\n",
    "            if \"weight_ih\" in name:\n",
    "                nn.init.xavier_uniform_(param)\n",
    "\n",
    "        # Fully connected output layer\n",
    "        self.fc_decoder_diphone_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes**2 + 1)  # +1 for CTC blank\n",
    "        \n",
    "        self.fc_decoder_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes + 1)  # +1 for CTC blank\n",
    "        \n",
    "        self.mfcc_decoder = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, mfcc_dim*self.strideLen) \n",
    "        \n",
    "        \n",
    "        # Loss function\n",
    "        self.ctc_loss = nn.CTCLoss(blank=0, reduction=\"mean\", zero_infinity=True)\n",
    "        self.diphone_ctc_loss = nn.CTCLoss(blank=0, reduction=\"mean\", zero_infinity=True)\n",
    "        self.l1oss = nn.L1Loss()\n",
    "\n",
    "    def get_neural_embedding(self, neuralInput, dayIdx):\n",
    "        \"\"\"\n",
    "        Forward pass of the model.\n",
    "        neuralInput: (batch, time, features)\n",
    "        dayIdx: Session index\n",
    "        \"\"\"\n",
    "\n",
    "        #channel selection\n",
    "        neuralInput = neuralInput[:, :, self.channels].contiguous()\n",
    "        neuralInput = torch.permute(neuralInput, (0, 2, 1))\n",
    "        neuralInput = self.gaussianSmoother(neuralInput)\n",
    "        neuralInput = torch.permute(neuralInput, (0, 2, 1))\n",
    "\n",
    "        # Apply day-specific transformations\n",
    "        dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)\n",
    "        transformedNeural = torch.einsum(\"btd,bdk->btk\", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)\n",
    "        transformedNeural = self.inputLayerNonlinearity(transformedNeural)\n",
    "\n",
    "        # Apply unfolding (sliding window)\n",
    "        stridedInputs = torch.permute(\n",
    "            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)\n",
    "        )\n",
    "\n",
    "        # Initialize GRU hidden state\n",
    "        h0 = torch.zeros(\n",
    "            self.layer_dim * (2 if self.bidirectional else 1),\n",
    "            transformedNeural.size(0),\n",
    "            self.hidden_dim,\n",
    "            device=self.device\n",
    "        ).requires_grad_()\n",
    "\n",
    "        # Apply GRU\n",
    "        hid, _ = self.gru_decoder(stridedInputs, h0.detach())\n",
    "        return hid\n",
    "    def forward(self, neuralInput, dayIdx):\n",
    "       \n",
    "        hid = self.get_neural_embedding(neuralInput, dayIdx)\n",
    "        # Final output layer\n",
    "        diphone_logits = self.fc_decoder_diphone_out(hid)\n",
    "        phoneme_logits = self.fc_decoder_out(hid)\n",
    "        mfcc_pred = self.mfcc_decoder(hid)\n",
    "        return  diphone_logits,phoneme_logits, mfcc_pred\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Forward-pass one minibatch, compute diphone-CTC + phoneme-CTC\n",
    "        (optionally MFCC L1) and return the total loss for back-prop.\n",
    "        \"\"\"\n",
    "\n",
    "        # ─── unpack & move to device ──────────────────────────────────────────\n",
    "        X          = batch[\"neural_feats\"].to(self.device)          # (B,T,Fraw)\n",
    "        y          = batch[\"phone_seq\"].to(self.device)             # (B, Sph)\n",
    "        X_len      = batch[\"neural_time_bins\"].to(self.device)      # (B,)\n",
    "        y_len      = batch[\"phone_seq_len\"].to(self.device)         # (B,)\n",
    "        dayIdx     = batch[\"day\"].to(self.device)                   # (B,)\n",
    "        MFCC_list  = batch[\"mfcc\"]                                  # list(np.ndarray)\n",
    "\n",
    "        # ─── optional MFCC target preprocessing ──────────────────────────────\n",
    "        MFCC = pad_sequence([torch.tensor(m) for m in MFCC_list],\n",
    "                            batch_first=True).to(self.device)       # (B,Tm,Dim)\n",
    "        MFCC = torch.permute(self.mfcc_unfolder(\n",
    "                torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)\n",
    "            ), (0, 2, 1))                                        # (B,T',Dim·stride)\n",
    "\n",
    "        # ─── simple noise augmentation ───────────────────────────────────────\n",
    "        if self.white_noise_SD > 0:\n",
    "            X = X + torch.randn_like(X) * self.white_noise_SD\n",
    "        if self.constant_offset_SD > 0:\n",
    "            X = X + torch.randn(X.size(0), 1, X.size(2),\n",
    "                                device=self.device) * self.constant_offset_SD\n",
    "\n",
    "        # ─── forward through encoder/decoder ─────────────────────────────────\n",
    "        diphone_logits,phoneme_logits, mfcc_pred = self.forward(X, dayIdx)         # (B,T',D) & (B,T',Dim·stride)\n",
    "\n",
    "        # ─── build target diphone sequences ─────────────────────────────────\n",
    "        diphone_seqs, diphone_len = convert_phoneme_seq_to_diphone_ids(\n",
    "            y, idsToPhonemes, diphone_to_idx\n",
    "        )\n",
    "        y_diphone   = pad_sequence(diphone_seqs, batch_first=True, padding_value=0) \\\n",
    "                    .to(self.device)                              # (B,Sdi)\n",
    "        diphone_len = torch.tensor(diphone_len,\n",
    "                                dtype=torch.int32,\n",
    "                                device=self.device)              # (B,)\n",
    "\n",
    "        # ─── input lengths after unfolding (integer division!) ───────────────\n",
    "        input_lengths = ((X_len - self.kernelLen) // self.strideLen) \\\n",
    "                        .to(torch.int32)                            # (B,)\n",
    "\n",
    "        # ─── CTC loss on diphone logits ──────────────────────────────────────\n",
    "        diphone_logp = diphone_logits.log_softmax(-1)               # (B,T',D)\n",
    "        loss_diphone = self.diphone_ctc_loss(\n",
    "            diphone_logp.permute(1, 0, 2),  # (T',B,D)\n",
    "            y_diphone,\n",
    "            input_lengths,\n",
    "            diphone_len\n",
    "        )\n",
    "\n",
    "        # ─── diphone → phoneme marginalisation in log-space ─────────────────\n",
    "        B, T, _ = diphone_logp.shape\n",
    "        P = len(self.PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "        phoneme_logp = diphone_logp.new_full((B, T, P), -float(\"inf\"))\n",
    "\n",
    "        for p_idx, ph in enumerate(self.PHONE_DEF_SIL):\n",
    "            ids = self.phoneme_to_diphone_ids[ph]      # must be a **disjoint** list\n",
    "            assert len(set(ids)) == len(ids)           # sanity check\n",
    "            if ids:\n",
    "                logsumexp = torch.logsumexp(diphone_logp[:, :, ids], dim=-1)\n",
    "                phoneme_logp[:, :, p_idx] = logsumexp\n",
    "            else:\n",
    "                # optionally set to very low log-prob\n",
    "                phoneme_logp[:, :, p_idx] = -1e8\n",
    "\n",
    "        # prepend the *same* blank that lives at diphone index 0\n",
    "        blank_logp = diphone_logp[:, :, 0:1]\n",
    "        phoneme_logp = torch.cat([blank_logp, phoneme_logp], dim=-1)\n",
    "        \n",
    "        \n",
    "        # with torch.no_grad():\n",
    "        #     if torch.isnan(phoneme_logp).any():\n",
    "        #         print(\"NaN detected in phoneme_logp\")\n",
    "\n",
    "        # ─── CTC loss on phoneme distribution ───────────────────────────────\n",
    "\n",
    "        # phoneme_logp = torch.logsumexp(torch.stack([phoneme_logits.log_softmax(-1), phoneme_logp], dim=0),dim=0)        \n",
    "        phoneme_logp = phoneme_logits.log_softmax(-1) + phoneme_logp\n",
    "        loss_phoneme = self.ctc_loss(\n",
    "            phoneme_logp.permute(1, 0, 2),  # (T',B,P+1)\n",
    "            y,\n",
    "            input_lengths,\n",
    "            y_len\n",
    "        )\n",
    "\n",
    "        # ─── optional MFCC reconstruction loss (commented out) ──────────────\n",
    "        min_len = min(MFCC.size(1), mfcc_pred.size(1))\n",
    "        l1_mfcc = self.l1oss(mfcc_pred[:, :min_len], MFCC[:, :min_len])\n",
    "        mfcc_term = self.mfcc_loss_weight * l1_mfcc\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        # ─── dynamic α schedule (clamped to >=0) ────────────────────────────\n",
    "        # alpha = max(0.6, 1.0 - self.global_step / self.total_steps)\n",
    "\n",
    "        loss = self.alpha * loss_diphone + (1 - self.alpha) * loss_phoneme   + mfcc_term\n",
    "\n",
    "        # ─── logging ────────────────────────────────────────────────────────\n",
    "        self.log_dict({\n",
    "            \"train_loss\":          loss,\n",
    "            \"train_loss_diphone\":  loss_diphone,\n",
    "            \"train_loss_phoneme\":  loss_phoneme,\n",
    "            # \"train_l1_mfcc\":     l1_mfcc if use_mfcc else torch.tensor(0.)\n",
    "        }, prog_bar=True, on_step=True, on_epoch=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Run one minibatch in eval mode, compute the diphone-CTC +\n",
    "        phoneme-CTC losses and the phoneme error-rate (CER).\n",
    "        \"\"\"\n",
    "        # ─── unpack & move to device ─────────────────────────────────────────\n",
    "        X          = batch[\"neural_feats\"].to(self.device)          # (B,T,Fraw)\n",
    "        y          = batch[\"phone_seq\"].to(self.device)             # (B, Sph)\n",
    "        X_len      = batch[\"neural_time_bins\"].to(self.device)      # (B,)\n",
    "        y_len      = batch[\"phone_seq_len\"].to(self.device)         # (B,)\n",
    "        dayIdx     = batch[\"day\"].to(self.device)                   # (B,)\n",
    "        MFCC_list  = batch[\"mfcc\"]                                  # list(np.ndarray)\n",
    "\n",
    "        # ─── (optional) MFCC target pre-processing ──────────────────────────\n",
    "        MFCC = pad_sequence([torch.tensor(m) for m in MFCC_list],\n",
    "                            batch_first=True).to(self.device)       # (B,Tm,Dim)\n",
    "        MFCC = torch.permute(self.mfcc_unfolder(\n",
    "                torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)\n",
    "            ), (0, 2, 1))                                        # (B,T',Dim·stride)\n",
    "\n",
    "                # ─── forward through encoder/decoder ─────────────────────────────────\n",
    "        diphone_logits,phoneme_logits, mfcc_pred = self.forward(X, dayIdx)         # (B,T',D) & (B,T',Dim·stride)\n",
    "\n",
    "        # ─── build target diphone sequences ─────────────────────────────────\n",
    "        diphone_seqs, diphone_len = convert_phoneme_seq_to_diphone_ids(\n",
    "            y, idsToPhonemes, diphone_to_idx\n",
    "        )\n",
    "        y_diphone   = pad_sequence(diphone_seqs, batch_first=True, padding_value=0) \\\n",
    "                    .to(self.device)                              # (B,Sdi)\n",
    "        diphone_len = torch.tensor(diphone_len,\n",
    "                                dtype=torch.int32,\n",
    "                                device=self.device)              # (B,)\n",
    "\n",
    "        # ─── input lengths after unfolding (integer division!) ───────────────\n",
    "        input_lengths = ((X_len - self.kernelLen) // self.strideLen) \\\n",
    "                        .to(torch.int32)                            # (B,)\n",
    "\n",
    "        # ─── CTC loss on diphone logits ──────────────────────────────────────\n",
    "        diphone_logp = diphone_logits.log_softmax(-1)               # (B,T',D)\n",
    "        loss_diphone = self.diphone_ctc_loss(\n",
    "            diphone_logp.permute(1, 0, 2),  # (T',B,D)\n",
    "            y_diphone,\n",
    "            input_lengths,\n",
    "            diphone_len\n",
    "        )\n",
    "\n",
    "        # ─── diphone → phoneme marginalisation in log-space ─────────────────\n",
    "        B, T, _ = diphone_logp.shape\n",
    "        P = len(self.PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "        phoneme_logp = diphone_logp.new_full((B, T, P), -float(\"inf\"))\n",
    "\n",
    "        for p_idx, ph in enumerate(self.PHONE_DEF_SIL):\n",
    "            ids = self.phoneme_to_diphone_ids[ph]      # must be a **disjoint** list\n",
    "            assert len(set(ids)) == len(ids)           # sanity check\n",
    "            if ids:\n",
    "                logsumexp = torch.logsumexp(diphone_logp[:, :, ids], dim=-1)\n",
    "                phoneme_logp[:, :, p_idx] = logsumexp\n",
    "            else:\n",
    "                # optionally set to very low log-prob\n",
    "                phoneme_logp[:, :, p_idx] = -1e8\n",
    "\n",
    "        # prepend the *same* blank that lives at diphone index 0\n",
    "        blank_logp = diphone_logp[:, :, 0:1]\n",
    "        phoneme_logp = torch.cat([blank_logp, phoneme_logp], dim=-1)\n",
    "        \n",
    "        # with torch.no_grad():\n",
    "        #     if torch.isnan(phoneme_logp).any():\n",
    "        #         print(\"NaN detected in phoneme_logp\")\n",
    "\n",
    "        # ─── CTC loss on phoneme distribution ───────────────────────────────\n",
    "        phoneme_logp =phoneme_logits.log_softmax(-1) + phoneme_logp\n",
    "        # phoneme_logp = torch.logsumexp(\n",
    "        #     torch.stack([phoneme_logits.log_softmax(-1), phoneme_logp], dim=0),\n",
    "        #     dim=0\n",
    "        # )\n",
    "\n",
    "        loss_phoneme = self.ctc_loss(\n",
    "            phoneme_logp.permute(1, 0, 2),  # (T',B,P+1)\n",
    "            y,\n",
    "            input_lengths,\n",
    "            y_len\n",
    "        )\n",
    "\n",
    "        # ─── optional MFCC reconstruction loss (commented out) ──────────────\n",
    "        min_len = min(MFCC.size(1), mfcc_pred.size(1))\n",
    "        l1_mfcc = self.l1oss(mfcc_pred[:, :min_len], MFCC[:, :min_len])\n",
    "        mfcc_term = self.mfcc_loss_weight * l1_mfcc\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        # ─── dynamic α schedule (clamped to >=0) ────────────────────────────\n",
    "        # alpha = max(0.6, 1.0 - self.global_step / self.total_steps)\n",
    "\n",
    "        loss = self.alpha * loss_diphone + (1 - self.alpha) * loss_phoneme   + mfcc_term\n",
    "\n",
    "\n",
    "        # ─── phoneme error-rate (CER) ──────────────────────────────────────\n",
    "        ### CHANGED: decode from log-probs (argmax over last dim)\n",
    "        decoded = phoneme_logp.argmax(-1)                           # (B,T')\n",
    "        total_edits, total_len = 0, 0\n",
    "\n",
    "        for i in range(decoded.size(0)):\n",
    "            seq = decoded[i, : input_lengths[i]].cpu()\n",
    "            seq = torch.unique_consecutive(seq)        # CTC collapse\n",
    "            seq = seq[seq != 0].tolist()               # remove blank\n",
    "\n",
    "            target = y[i, : y_len[i]].cpu().tolist()\n",
    "            total_edits += SequenceMatcher(a=target, b=seq).distance()\n",
    "            total_len   += len(target)\n",
    "\n",
    "        cer = total_edits / total_len if total_len else 1.0\n",
    "\n",
    "        # ─── logging ────────────────────────────────────────────────────────\n",
    "        self.log_dict({\n",
    "            \"val_loss\":          loss,\n",
    "            \"val_loss_diphone\":  loss_diphone,\n",
    "            \"val_loss_phoneme\":  loss_phoneme,\n",
    "            \"val_CER\":           cer,\n",
    "        }, prog_bar=True, on_step=False, on_epoch=True)\n",
    "\n",
    "        ### NEW: return loss so Lightning can aggregate (optional)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"\n",
    "        Configures the optimizer and learning rate scheduler.\n",
    "        \"\"\"\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.999),\n",
    "                                      eps=1e-8)\n",
    "\n",
    "        return optimizer\n",
    "\n",
    "        # scheduler = ReduceLROnPlateau(optimizer, mode=\"min\", factor=0.5, patience=2)\n",
    "        # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler, \"monitor\": \"val_loss\"}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a94f832d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4140"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len = 14 #best 32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 5e-5\n",
    "lr_end = 5e-5\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 30\n",
    "total_steps = target_epoch * steps_per_epoch\n",
    "total_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42eece3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"gru_ctc_diphones_dualhead_marginalization\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d072a00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.19.4"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20250612_082039-nwkhw644</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/nwkhw644' target=\"_blank\">gru_ctc_diphones_dualhead_marginalization</a></strong> to <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/nwkhw644' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/nwkhw644</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6]\n",
      "\n",
      "   | Name                   | Type              | Params | Mode \n",
      "----------------------------------------------------------------------\n",
      "0  | inputLayerNonlinearity | Softsign          | 0      | train\n",
      "1  | unfolder               | Unfold            | 0      | train\n",
      "2  | mfcc_unfolder          | Unfold            | 0      | train\n",
      "3  | gaussianSmoother       | GaussianSmoothing | 0      | train\n",
      "4  | gru_decoder            | GRU               | 103 M  | train\n",
      "5  | fc_decoder_diphone_out | Linear            | 3.3 M  | train\n",
      "6  | fc_decoder_out         | Linear            | 84.0 K | train\n",
      "7  | mfcc_decoder           | Linear            | 114 K  | train\n",
      "8  | ctc_loss               | CTCLoss           | 0      | train\n",
      "9  | diphone_ctc_loss       | CTCLoss           | 0      | train\n",
      "10 | l1oss                  | L1Loss            | 0      | train\n",
      "   | other params           | n/a               | 1.6 M  | n/a  \n",
      "----------------------------------------------------------------------\n",
      "108 M     Trainable params\n",
      "0         Non-trainable params\n",
      "108 M     Total params\n",
      "435.715   Total estimated model params size (MB)\n",
      "11        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "be2f665d3240403c9b774074b8c3281c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
      "/data/XXXXXX/speech_decoding_BCI/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab39e6e78f954568ab6824cae81dbcb9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4193b0ca282f4179b59e21296c92e1ca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 48. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "Metric val_CER improved. New best score: 0.963\n",
      "Epoch 0, global step 138: 'val_loss' reached 6.46010 (best 6.46010), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "18938df01e8342d1ae7d56ecd408682d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.016 >= min_delta = 0.0. New best score: 0.947\n",
      "Epoch 1, global step 276: 'val_loss' reached 6.36846 (best 6.36846), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62ab87322d1b41dcb45f84a00e419c7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.943\n",
      "Epoch 2, global step 414: 'val_loss' reached 6.20939 (best 6.20939), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b988b269f8b4ee49defd8547a306a15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.007 >= min_delta = 0.0. New best score: 0.936\n",
      "Epoch 3, global step 552: 'val_loss' reached 6.07583 (best 6.07583), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2550c3f5868491e8d85848b04a49707",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.107 >= min_delta = 0.0. New best score: 0.829\n",
      "Epoch 4, global step 690: 'val_loss' reached 5.67275 (best 5.67275), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8f20585966564b608163451f2fb90694",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.309 >= min_delta = 0.0. New best score: 0.520\n",
      "Epoch 5, global step 828: 'val_loss' reached 4.62504 (best 4.62504), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49b2d666b0d34722b769ccad1dec1bdc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.062 >= min_delta = 0.0. New best score: 0.458\n",
      "Epoch 6, global step 966: 'val_loss' reached 4.09092 (best 4.09092), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "15735f456c2c4d03aaa05611e6b54090",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.048 >= min_delta = 0.0. New best score: 0.410\n",
      "Epoch 7, global step 1104: 'val_loss' reached 3.78622 (best 3.78622), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "919da8cc604443139d3a2cb409a93578",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.034 >= min_delta = 0.0. New best score: 0.376\n",
      "Epoch 8, global step 1242: 'val_loss' reached 3.54748 (best 3.54748), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c3256e9862af4185ae8a9d908b47b9bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.026 >= min_delta = 0.0. New best score: 0.351\n",
      "Epoch 9, global step 1380: 'val_loss' reached 3.35237 (best 3.35237), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "801ea967ebc1442d9aa0c4a84cc19adc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.021 >= min_delta = 0.0. New best score: 0.330\n",
      "Epoch 10, global step 1518: 'val_loss' reached 3.22029 (best 3.22029), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1182b2b578d24738898b13e672da8463",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.016 >= min_delta = 0.0. New best score: 0.314\n",
      "Epoch 11, global step 1656: 'val_loss' reached 3.11441 (best 3.11441), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95c2c9ff49b74a87943091d6ee0d07fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.016 >= min_delta = 0.0. New best score: 0.298\n",
      "Epoch 12, global step 1794: 'val_loss' reached 2.97819 (best 2.97819), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d416b323b71a4bd7b2271296e4df5651",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.009 >= min_delta = 0.0. New best score: 0.289\n",
      "Epoch 13, global step 1932: 'val_loss' reached 2.90998 (best 2.90998), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60053b13e1f6489a819f1930d4319ddf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.012 >= min_delta = 0.0. New best score: 0.278\n",
      "Epoch 14, global step 2070: 'val_loss' reached 2.81851 (best 2.81851), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2aa3f6cafbf44f0cb3baec6378f00e61",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.008 >= min_delta = 0.0. New best score: 0.270\n",
      "Epoch 15, global step 2208: 'val_loss' reached 2.75254 (best 2.75254), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57f683f7f3514c1fb514a4e78523d235",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.008 >= min_delta = 0.0. New best score: 0.262\n",
      "Epoch 16, global step 2346: 'val_loss' reached 2.69527 (best 2.69527), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7cfa444155549da907bfb4c0e3f43ce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.007 >= min_delta = 0.0. New best score: 0.255\n",
      "Epoch 17, global step 2484: 'val_loss' reached 2.63909 (best 2.63909), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86ae38753c5b4d8b83b28ff8885b0826",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.254\n",
      "Epoch 18, global step 2622: 'val_loss' reached 2.61060 (best 2.61060), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0f3889e9b38647f1ad4931a2f0363cfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.250\n",
      "Epoch 19, global step 2760: 'val_loss' reached 2.58946 (best 2.58946), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0c894c3f3d2b430a89a187f022a523e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.010 >= min_delta = 0.0. New best score: 0.241\n",
      "Epoch 20, global step 2898: 'val_loss' reached 2.52325 (best 2.52325), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "172e47941b3e4fb1abd0ab6bf4259e1b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.235\n",
      "Epoch 21, global step 3036: 'val_loss' reached 2.49132 (best 2.49132), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59930740e4e143408079f2f7180a193a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22, global step 3174: 'val_loss' reached 2.47053 (best 2.47053), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4fe6189bb784349a99d52f2f7e2dbec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.233\n",
      "Epoch 23, global step 3312: 'val_loss' reached 2.45437 (best 2.45437), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "659c32a27d1b4f7f8239d931bf361153",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.005 >= min_delta = 0.0. New best score: 0.228\n",
      "Epoch 24, global step 3450: 'val_loss' reached 2.41186 (best 2.41186), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22e8e719bb7b43d8b977f32749fd94f2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25, global step 3588: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf71e4bfc90a42cba76e679c1b813211",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.226\n",
      "Epoch 26, global step 3726: 'val_loss' reached 2.40709 (best 2.40709), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa9d54380a5a4ad3a062c90897a9ef0e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.223\n",
      "Epoch 27, global step 3864: 'val_loss' reached 2.39942 (best 2.39942), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "043109f751804a0aa8088ab627b66ecf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.217\n",
      "Epoch 28, global step 4002: 'val_loss' reached 2.37339 (best 2.37339), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ad27e4fc18914aea85368048e1dc14c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.216\n",
      "Epoch 29, global step 4140: 'val_loss' reached 2.34337 (best 2.34337), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "241daa2bccf74f09a82a36cb5557a7e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30, global step 4278: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e233b67540d64d33b3c8441042c5301a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.213\n",
      "Epoch 31, global step 4416: 'val_loss' reached 2.33041 (best 2.33041), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8f1f73557af41f08d517f9b4bfcb7fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.209\n",
      "Epoch 32, global step 4554: 'val_loss' reached 2.31314 (best 2.31314), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b3538e3036cc443994783a14d1372f8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.206\n",
      "Epoch 33, global step 4692: 'val_loss' reached 2.30895 (best 2.30895), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca5e105db17749e19f32950d682237bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34, global step 4830: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdc5c0c877a943a6a1fc949205537fca",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35, global step 4968: 'val_loss' reached 2.30630 (best 2.30630), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39de0dd2298c4ec49147dca773f14b4f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.203\n",
      "Epoch 36, global step 5106: 'val_loss' reached 2.28115 (best 2.28115), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c273de99c3d4224921c4506f1504fb2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37, global step 5244: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7fa1e4ca4032447bbaec12ca3bdb91bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.005 >= min_delta = 0.0. New best score: 0.198\n",
      "Epoch 38, global step 5382: 'val_loss' reached 2.27145 (best 2.27145), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1ca537b4e3734c9a8644adac27a732c2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 5520: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee15ed4e125646308a279937ecc78d8f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40, global step 5658: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5b250e1d37f34206b4cc560c0000f425",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.196\n",
      "Epoch 41, global step 5796: 'val_loss' reached 2.25919 (best 2.25919), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_diphones_dualhead_marginalization/best_model-v3.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "127d2369a7984ffd98ff19e8d86f1686",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42, global step 5934: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "48582afd6119466fbee6930b781a88ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43, global step 6072: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd74127e024f4945a837d63e5625a6c5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44, global step 6210: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "065c7b3822904e7fb3449c58e0de19d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45, global step 6348: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d8c8c52267534e43bc6f991a2d576081",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.193\n",
      "Epoch 46, global step 6486: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "258ed5227dab4536a50c0bf069b9545d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47, global step 6624: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5d7b26645334300a296775d83c0a668",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.192\n",
      "Epoch 48, global step 6762: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2c7bd258b32548e898a86c84c9769716",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 6900: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1719c61121634ee494f109d052cecd22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.000 >= min_delta = 0.0. New best score: 0.192\n",
      "Epoch 50, global step 7038: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "985cb269d0754fa0962a38b06947ff8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51, global step 7176: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f407415c8de405daaf1c12ce03b668a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52, global step 7314: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "190cb1accd78499b863d29c98c793b02",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.005 >= min_delta = 0.0. New best score: 0.186\n",
      "Epoch 53, global step 7452: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86fcc62e43c54ef892770e9911093ce7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54, global step 7590: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8e91deb587f44108b4394a53f01e829",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55, global step 7728: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "34e3bf4c55414ac79585a3bed4684aa0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56, global step 7866: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2ff893f7378d418eba0d6b6bae653ad3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57, global step 8004: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a64860d929414606b6a3ac374415c669",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58, global step 8142: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "103e621e893b46aa9989552178f79068",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 8280: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c915f9d1f1924e879f2c4d2b546b3ca0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60, global step 8418: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9203128cf76424887e6a5d16d0cda32",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Monitored metric val_CER did not improve in the last 8 records. Best score: 0.186. Signaling Trainer to stop.\n",
      "Epoch 61, global step 8556: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇████</td></tr><tr><td>train_loss_diphone_epoch</td><td>█▇▇▆▅▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_diphone_step</td><td>████▇▆▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_epoch</td><td>█▇▇▆▅▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_phoneme_epoch</td><td>█▇▇▇▆▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_phoneme_step</td><td>███▅▅▄▄▃▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_step</td><td>█▇▇▇▇▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇▇▇▇▇██</td></tr><tr><td>val_CER</td><td>████▇▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss</td><td>███▇▇▄▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss_diphone</td><td>██▇▅▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss_phoneme</td><td>██▇▇▅▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>61</td></tr><tr><td>train_loss_diphone_epoch</td><td>0.32442</td></tr><tr><td>train_loss_diphone_step</td><td>0.26821</td></tr><tr><td>train_loss_epoch</td><td>0.86218</td></tr><tr><td>train_loss_phoneme_epoch</td><td>0.4</td></tr><tr><td>train_loss_phoneme_step</td><td>0.3625</td></tr><tr><td>train_loss_step</td><td>0.77999</td></tr><tr><td>trainer/global_step</td><td>8555</td></tr><tr><td>val_CER</td><td>0.18896</td></tr><tr><td>val_loss</td><td>2.34982</td></tr><tr><td>val_loss_diphone</td><td>1.52291</td></tr><tr><td>val_loss_phoneme</td><td>1.86027</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">gru_ctc_diphones_dualhead_marginalization</strong> at: <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/nwkhw644' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/nwkhw644</a><br> View project at: <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset</a><br>Synced 8 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250612_082039-nwkhw644/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "TRAIN = True\n",
    "\n",
    "\n",
    "model = LightningGRUDecoder_Diphonemes_DualHead(\n",
    "            neural_dim=nInputFeatures,\n",
    "            n_classes=nClasses,\n",
    "            hidden_dim=hidden_dim,\n",
    "            layer_dim=nlayers,\n",
    "            strideLen=stride_len,\n",
    "            kernelLen=kernel_len,\n",
    "            gaussianSmoothWidth=gaussian_smooth_width,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout,\n",
    "            white_noise_SD=white_noise_SD,\n",
    "            constant_offset_SD=constant_offset_SD,\n",
    "            weight_decay=l2_decay,\n",
    "            learning_rate=lr_start,\n",
    "            total_steps=total_steps,)\n",
    "if TRAIN:\n",
    "\n",
    "    wandb_logger = WandbLogger(project=\"ECOG_Sentence_dataset\", name=f\"{output_name}\",\n",
    "                                reinit=True)\n",
    "\n",
    "    # Define ModelCheckpoint to save the best model based on validation loss\n",
    "    checkpoint_callback = ModelCheckpoint(\n",
    "        monitor=\"val_loss\",  # Ensure your validation step logs \"val_loss\"\n",
    "        mode=\"min\",          # Save the model with the lowest validation loss\n",
    "        save_top_k=1,        # Keep only the best model\n",
    "        dirpath=f\".checkpoints/{output_name}/\",  # Directory to save checkpoints\n",
    "        filename=f\"best_model\",  # Model filename\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Define EarlyStopping callback with patience of 3 epochs\n",
    "    early_stopping_callback = EarlyStopping(\n",
    "        monitor=\"val_CER\",\n",
    "        patience=8,   # Stop training if no improvement in 3 epochs\n",
    "        mode=\"min\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "\n",
    "    # Train model\n",
    "    trainer = pl.Trainer(max_epochs=100,devices =[1], callbacks=[checkpoint_callback, early_stopping_callback], logger=wandb_logger)\n",
    "\n",
    "    trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "\n",
    "\n",
    "    # close wandb logger\n",
    "    wandb.finish()\n",
    "\n",
    "else:\n",
    "    #reload state_dict of best model\n",
    "    model.load_state_dict(torch.load(f\".checkpoints/{output_name}/best_model.ckpt\")[\"state_dict\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "21a234cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningGRUDecoder_Diphonemes_DualHead(\n",
       "  (inputLayerNonlinearity): Softsign()\n",
       "  (unfolder): Unfold(kernel_size=(14, 1), dilation=1, padding=0, stride=4)\n",
       "  (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "  (gaussianSmoother): GaussianSmoothing()\n",
       "  (gru_decoder): GRU(3584, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "  (fc_decoder_diphone_out): Linear(in_features=2048, out_features=1601, bias=True)\n",
       "  (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "  (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "  (ctc_loss): CTCLoss()\n",
       "  (diphone_ctc_loss): CTCLoss()\n",
       "  (l1oss): L1Loss()\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda:1\"\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "73ab00d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_ctc_output(logits):\n",
    "    \"\"\"\n",
    "    Converts model logits to predicted phoneme sequences.\n",
    "    - Removes repeated phonemes.\n",
    "    - Removes blank tokens (0).\n",
    "    \"\"\"\n",
    "\n",
    "    predictions = torch.argmax(logits, dim=-1)  # Get most probable phoneme indices\n",
    "    predictions = [torch.unique_consecutive(seq[seq != 0]).cpu().numpy() for seq in predictions]  # Remove blanks\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bc5d35c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [00:14<00:00,  1.07s/it]\n"
     ]
    }
   ],
   "source": [
    "## predit all teh test set \n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"]\n",
    "        y = batch[\"phone_seq\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        diphone_logits,pred, mfcc_pred  = model(X,days)\n",
    "        # pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # print(logits.shape)\n",
    "        # decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.strideLen), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pp)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d56e425e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.20568881245144693)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(cer_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "93227fc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['R', 'EY', 'S', 'SIL', 'P', 'AA', 'R', 'CH', 'IH', 'T', 'SIL', 'S', 'EH', 'V', 'ER', 'SIL', 'S', 'AE', 'N', 'SIL', 'G', 'L', 'S', 'ER', 'G', 'EH', 'F', 'SIL']\n",
      "True Phonemes: ['R', 'IH', 'CH', 'SIL', 'P', 'ER', 'CH', 'AH', 'S', 'T', 'SIL', 'S', 'EH', 'V', 'R', 'AH', 'L', 'SIL', 'S', 'AY', 'N', 'D', 'SIL', 'L', 'IH', 'TH', 'AH', 'G', 'R', 'AE', 'F', 'S', 'SIL']\n",
      "True Sentence: Rich purchased several signed lithographs.\n"
     ]
    }
   ],
   "source": [
    "idx = 1\n",
    "print(\"Predicted Phonemes:\", idsToPhonemes(pred_phonemes[idx]))\n",
    "print(\"True Phonemes:\", idsToPhonemes(true_phonemes[idx]))\n",
    "print(\"True Sentence:\", true_sentences[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "eb7c4eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \n",
    "\n",
    "    accs= []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        \n",
    "        #truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "\n",
    "\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        equal_inference = (pred == target)\n",
    "        acc = np.sum(equal_inference)/ len(pred)\n",
    "        accs.append(acc)\n",
    "\n",
    "    return np.mean(accs)\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7c0fa6d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.5017722954302081\n",
      "Range of Accuracy per day 0.16475021683035201 0.6336412711404553\n",
      "Range of CER per day 0.15161924023880138 0.3892295557431734\n",
      "Average lenght diff: -0.029545454545454545 +- 2.2154498358587613\n",
      "Range of diff lenghts per day: -2.1 - 1.0\n"
     ]
    }
   ],
   "source": [
    "overall_acc = compute_accuracy(pred_phonemes, true_phonemes)\n",
    "print(\"overall_acc\", overall_acc)\n",
    "# day_indices\n",
    "day_indices_flat = day_indices\n",
    "\n",
    "#compute accuracy per day by selecting indices of the same day\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    acc = compute_accuracy([pred_phonemes[idx] for idx in indices], [true_phonemes[idx] for idx in indices])\n",
    "    day_accs.append(acc)\n",
    "\n",
    "day_accs\n",
    "print(\"Range of Accuracy per day\", min(day_accs), max(day_accs))\n",
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(np.array(cer_list)[indices].mean())\n",
    "\n",
    "cer_list_per_day\n",
    "\n",
    "print(\"Range of CER per day\", min(cer_list_per_day), max(cer_list_per_day))\n",
    "diffs = []\n",
    "for i in range(len(true_phonemes)):\n",
    "    true = true_phonemes[i]\n",
    "    pred = pred_phonemes[i]\n",
    "    diffs.append(np.array(len(true)) - np.array(len(pred)))\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")\n",
    "# true_phonemes[0]\n",
    "## compute diff lenghts per day\n",
    "diffs_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    diffs_per_day.append(np.array(diffs)[indices].mean())\n",
    "\n",
    "diffs_per_day\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")\n",
    "import os\n",
    "\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "#create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices_flat,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(results_dir, \"results.csv\"), index=False)\n",
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]\n",
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\"\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)\n",
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
