{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f1013317",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "from model.ctc_modelling import LightningGRUDecoder_MFCC_v3\n",
    "# from model.ctc_modelling import Light\n",
    "import os\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "57b661f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "channels = pickle.load(open(\"encoding/significant_channels.pkl\", \"rb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "174a06b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# channels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c9719402",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e90284c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# idx  = 10\n",
    "# print(\"mfcc shape\", batch[\"mfcc\"][idx].shape)\n",
    "# print(\"neural_time_bins\",batch[\"neural_time_bins\"][idx])\n",
    "# print(\"neural_feats shape\", batch[\"neural_feats\"][idx].shape)\n",
    "# print(\"go onset\", batch[\"go_onset\"][idx])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "8dff0426",
   "metadata": {},
   "outputs": [],
   "source": [
    "# batch[\"neural_feats\"][idx][392]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "97cd89af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import matplotlib.pyplot as plt\n",
    "# plt.imshow(batch[\"mfcc\"][idx].T, aspect='auto', origin='lower')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a2dc1e06",
   "metadata": {},
   "outputs": [],
   "source": [
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len =32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 1e-4\n",
    "lr_end = 1e-5\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 60\n",
    "total_steps = target_epoch * steps_per_epoch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "3a8cf659",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"mfcc_sm_gru_ctc_channels\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b83a12fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 158 158\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mXXXXXXXXXXXXXX\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.19.4"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20250514_095916-y7pba5g2</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/y7pba5g2' target=\"_blank\">GRU_CTC_MFCC_channel</a></strong> to <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/y7pba5g2' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/y7pba5g2</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n",
      "\n",
      "  | Name                   | Type              | Params | Mode \n",
      "---------------------------------------------------------------------\n",
      "0 | inputLayerNonlinearity | Softsign          | 0      | train\n",
      "1 | unfolder               | Unfold            | 0      | train\n",
      "2 | mfcc_unfolder          | Unfold            | 0      | train\n",
      "3 | gaussianSmoother       | GaussianSmoothing | 0      | train\n",
      "4 | gru_decoder            | GRU               | 112 M  | train\n",
      "5 | fc_decoder_out         | Linear            | 84.0 K | train\n",
      "6 | mfcc_decoder           | Linear            | 114 K  | train\n",
      "7 | ctc_loss               | CTCLoss           | 0      | train\n",
      "8 | l1oss                  | L1Loss            | 0      | train\n",
      "  | other params           | n/a               | 602 K  | n/a  \n",
      "---------------------------------------------------------------------\n",
      "113 M     Trainable params\n",
      "0         Non-trainable params\n",
      "113 M     Total params\n",
      "454.864   Total estimated model params size (MB)\n",
      "9         Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96cd2a8bf29242ccae9829c559ad38bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
      "/data/XXXXXX/speech_decoding_BCI/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d5b464aa6d04b28af815294e8fad74f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 32. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a9762d2214844cd18dc23e04d259d2f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 48. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "Metric val_CER improved. New best score: 0.741\n",
      "Epoch 0, global step 138: 'val_loss' reached 3.34063 (best 3.34063), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2173a5a06025444a82d643cfd6e46308",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.256 >= min_delta = 0.0. New best score: 0.484\n",
      "Epoch 1, global step 276: 'val_loss' reached 2.38310 (best 2.38310), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "040787cbde7a4423bc01e3f6dd92f8ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.065 >= min_delta = 0.0. New best score: 0.419\n",
      "Epoch 2, global step 414: 'val_loss' reached 2.05939 (best 2.05939), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd774e2951594839b9965e9ed1dbe091",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.048 >= min_delta = 0.0. New best score: 0.372\n",
      "Epoch 3, global step 552: 'val_loss' reached 1.85241 (best 1.85241), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58a78f027f0a45d69af51813a231b277",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.030 >= min_delta = 0.0. New best score: 0.341\n",
      "Epoch 4, global step 690: 'val_loss' reached 1.70977 (best 1.70977), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "079e08d2aacc4856a686487d2239cde9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.024 >= min_delta = 0.0. New best score: 0.317\n",
      "Epoch 5, global step 828: 'val_loss' reached 1.61653 (best 1.61653), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "401f8f93946c4ba9b6df53737bc7b1ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.016 >= min_delta = 0.0. New best score: 0.301\n",
      "Epoch 6, global step 966: 'val_loss' reached 1.54142 (best 1.54142), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6d7c12cc005e4bf38d57d4423747ea63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.019 >= min_delta = 0.0. New best score: 0.282\n",
      "Epoch 7, global step 1104: 'val_loss' reached 1.48349 (best 1.48349), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c95c0c1bfea14903a6b4c6989877c6ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.014 >= min_delta = 0.0. New best score: 0.268\n",
      "Epoch 8, global step 1242: 'val_loss' reached 1.42047 (best 1.42047), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "126f154a9cc949838a8ae5b057acae2f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.008 >= min_delta = 0.0. New best score: 0.260\n",
      "Epoch 9, global step 1380: 'val_loss' reached 1.37993 (best 1.37993), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd07647fff0147f985a7cf3b7df7e78a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.018 >= min_delta = 0.0. New best score: 0.241\n",
      "Epoch 10, global step 1518: 'val_loss' reached 1.32238 (best 1.32238), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d426641d7fa74e0589d98c37f47cfc7a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.239\n",
      "Epoch 11, global step 1656: 'val_loss' reached 1.30159 (best 1.30159), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea04aaa7c06b475683b2f4c78572aa1e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.007 >= min_delta = 0.0. New best score: 0.232\n",
      "Epoch 12, global step 1794: 'val_loss' reached 1.28447 (best 1.28447), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b11fdd04e7744ee1b517c646865f66a3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.007 >= min_delta = 0.0. New best score: 0.225\n",
      "Epoch 13, global step 1932: 'val_loss' reached 1.25799 (best 1.25799), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdff8b2c26ad4a1195af1329288bb37d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.220\n",
      "Epoch 14, global step 2070: 'val_loss' reached 1.24241 (best 1.24241), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c74da3a6e96449d181c83ddf103ee0a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.005 >= min_delta = 0.0. New best score: 0.214\n",
      "Epoch 15, global step 2208: 'val_loss' reached 1.22790 (best 1.22790), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da2a4bd7a30f40cd87c9888ed6427214",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.210\n",
      "Epoch 16, global step 2346: 'val_loss' reached 1.22343 (best 1.22343), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c8f327f897141d4a3743f728228b6ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17, global step 2484: 'val_loss' reached 1.20909 (best 1.20909), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "80335b86a4e34dd9877a58ba55472308",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.204\n",
      "Epoch 18, global step 2622: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27e40743365349d5ad885b3ca227988e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.201\n",
      "Epoch 19, global step 2760: 'val_loss' reached 1.19364 (best 1.19364), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9281e18f2c2f47058b340fd03416497f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.004 >= min_delta = 0.0. New best score: 0.197\n",
      "Epoch 20, global step 2898: 'val_loss' reached 1.17709 (best 1.17709), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "924eccd0810a48978f0b0dd742c6b918",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.196\n",
      "Epoch 21, global step 3036: 'val_loss' reached 1.17616 (best 1.17616), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "41d03646f1ef4d0abcacbe1c50842971",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22, global step 3174: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b9a57e4fb52466d90db6f71df24c2f1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.194\n",
      "Epoch 23, global step 3312: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7ea882fc25b4055a76ab81a0430035e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.189\n",
      "Epoch 24, global step 3450: 'val_loss' reached 1.16626 (best 1.16626), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "30ad6d5ead7e498e9fcd253b2e7306ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25, global step 3588: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb58f88c5ea54072bd4270e13a74b5fc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.185\n",
      "Epoch 26, global step 3726: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b8457000670418f854f715fa5d1f680",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27, global step 3864: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd56943704204443ae9e7ac723cc7311",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.006 >= min_delta = 0.0. New best score: 0.179\n",
      "Epoch 28, global step 4002: 'val_loss' reached 1.14541 (best 1.14541), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c37cbdf685af4d78a509ad5ac9d7722a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 4140: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b940840a531642c986e34c3c5d832c57",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.179\n",
      "Epoch 30, global step 4278: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f0dca1f0e68e4033be84e2171d4d741e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.003 >= min_delta = 0.0. New best score: 0.175\n",
      "Epoch 31, global step 4416: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f412a40e46804882b7c4a6cac3d04d2b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.174\n",
      "Epoch 32, global step 4554: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c79b4b09c8bc496caea18dc24c7f6fe0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.000 >= min_delta = 0.0. New best score: 0.173\n",
      "Epoch 33, global step 4692: 'val_loss' reached 1.14503 (best 1.14503), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/mfcc_sm_gru_ctc_channels/best_model-v1.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b87d0ea77a24f51a11eff472c83d2b6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.172\n",
      "Epoch 34, global step 4830: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6dbb76fcaef44220bb1bb9939f02596d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35, global step 4968: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d04a83ac856493ca6980c0a3e0d646a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36, global step 5106: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fa6d0a298fe745188b94118fb5746e8c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.000 >= min_delta = 0.0. New best score: 0.172\n",
      "Epoch 37, global step 5244: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16a7f1d700e04c8a8df26783be2a29d2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.000 >= min_delta = 0.0. New best score: 0.171\n",
      "Epoch 38, global step 5382: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "255910bf78f8435aa6f785cadef06265",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.002 >= min_delta = 0.0. New best score: 0.170\n",
      "Epoch 39, global step 5520: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47470ed1161a41a29dc3ca9840d68a67",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40, global step 5658: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8738413b3d64169aede1fe5031a3615",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.169\n",
      "Epoch 41, global step 5796: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36ae8b8b4e4c4dd79c8bb6d0d9fdd749",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42, global step 5934: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "39d5174e4356451cb5e2d18be5c852ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43, global step 6072: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9c47caee886341a0b333320af47fceea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.001 >= min_delta = 0.0. New best score: 0.168\n",
      "Epoch 44, global step 6210: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e27de24dd874c43a3ecdbd40cedefb4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45, global step 6348: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ed8987c451d47c597a51361eb447add",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46, global step 6486: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c5d56c5d05754068bab1104f340515ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47, global step 6624: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "db7e4869d9d542dfbbccd2d333942195",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48, global step 6762: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2fbde9c1f9244108b01975f9fb912832",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Metric val_CER improved by 0.000 >= min_delta = 0.0. New best score: 0.168\n",
      "Epoch 49, global step 6900: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e75ec73d05fc43228156b23946ac4f81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50, global step 7038: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "55ceec80af124abd8f4b6b0397531ff2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51, global step 7176: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "553b70d55f3f4271848e5970546369d1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52, global step 7314: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a232470c378c4ed9b06b96127facaab0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53, global step 7452: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71c588294c2b432a8ef55ba39c431d60",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Monitored metric val_CER did not improve in the last 5 records. Best score: 0.168. Signaling Trainer to stop.\n",
      "Epoch 54, global step 7590: 'val_loss' was not in top 1\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▁▂▂▂▂▂▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇▇██</td></tr><tr><td>train_loss_epoch</td><td>█▆▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_step</td><td>█▇▅▄▄▃▃▂▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇██</td></tr><tr><td>val_CER</td><td>█▅▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss</td><td>█▅▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>54</td></tr><tr><td>train_loss_epoch</td><td>0.43856</td></tr><tr><td>train_loss_step</td><td>0.48641</td></tr><tr><td>trainer/global_step</td><td>7589</td></tr><tr><td>val_CER</td><td>0.16843</td></tr><tr><td>val_loss</td><td>1.14797</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">GRU_CTC_MFCC_channel</strong> at: <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/y7pba5g2' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/y7pba5g2</a><br> View project at: <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset</a><br>Synced 8 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250514_095916-y7pba5g2/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# GRU decoder (MFCC variant) fed with the channel subset stored under\n",
    "# channels[\"encoding_mfcc\"] (loaded from encoding/significant_channels.pkl).\n",
    "model = LightningGRUDecoder_MFCC_v3(\n",
    "    neural_dim=nInputFeatures,\n",
    "    n_classes=nClasses,\n",
    "    hidden_dim=hidden_dim,\n",
    "    layer_dim=nlayers,\n",
    "    strideLen=stride_len,\n",
    "    kernelLen=kernel_len,\n",
    "    gaussianSmoothWidth=gaussian_smooth_width,\n",
    "    bidirectional=bidirectional,\n",
    "    dropout=dropout,\n",
    "    white_noise_SD=white_noise_SD,\n",
    "    constant_offset_SD=constant_offset_SD,\n",
    "    weight_decay=l2_decay,\n",
    "    learning_rate=lr_start,\n",
    "    channels=channels[\"encoding_mfcc\"],\n",
    ")\n",
    "\n",
    "wandb_logger = WandbLogger(\n",
    "    project=\"ECOG_Sentence_dataset\",\n",
    "    name=\"GRU_CTC_MFCC_channel\",\n",
    "    reinit=True,\n",
    ")\n",
    "\n",
    "# Keep only the single checkpoint with the lowest validation loss.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_loss\",  # the validation step must log \"val_loss\"\n",
    "    mode=\"min\",\n",
    "    save_top_k=1,\n",
    "    dirpath=f\".checkpoints/{output_name}/\",\n",
    "    filename=\"best_model\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# Stop once val_CER has failed to improve for 5 consecutive validation checks.\n",
    "# Note that checkpointing tracks val_loss while early stopping tracks val_CER.\n",
    "early_stopping_callback = EarlyStopping(\n",
    "    monitor=\"val_CER\",\n",
    "    patience=5,\n",
    "    mode=\"min\",\n",
    "    verbose=True,\n",
    ")\n",
    "\n",
    "# NOTE(review): test_loader is passed as the validation loader, so checkpoint\n",
    "# selection and early stopping are driven by the same data the evaluation\n",
    "# cells iterate over.\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=60,\n",
    "    devices=[0],\n",
    "    callbacks=[checkpoint_callback, early_stopping_callback],\n",
    "    logger=wandb_logger,\n",
    ")\n",
    "\n",
    "trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "# NOTE: the best checkpoint is not reloaded here, so `model` keeps the weights\n",
    "# of the last training epoch. The saved file is available as\n",
    "# checkpoint_callback.best_model_path if the best weights are wanted instead.\n",
    "\n",
    "# close wandb logger\n",
    "wandb.finish()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6ddd9bd0",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "a43de8ab",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningGRUDecoder_MFCC_v3(\n",
       "  (inputLayerNonlinearity): Softsign()\n",
       "  (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "  (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "  (gaussianSmoother): GaussianSmoothing()\n",
       "  (gru_decoder): GRU(5056, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "  (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "  (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "  (ctc_loss): CTCLoss()\n",
       "  (l1oss): L1Loss()\n",
       ")"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate on the second GPU when the machine has one; otherwise fall back to\n",
    "# the first GPU, or to the CPU, instead of failing on a missing \"cuda:1\".\n",
    "if torch.cuda.device_count() > 1:\n",
    "    device = \"cuda:1\"\n",
    "elif torch.cuda.is_available():\n",
    "    device = \"cuda:0\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "model.to(device)\n",
    "# eval() switches off dropout; as the last expression it also displays the model.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "15e3dd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Token inventory for the lexicon-free CTC beam-search decoder:\n",
    "# the blank symbol first, then every entry of PHONE_DEF, then \" \" as silence.\n",
    "tokens = [\"<blank>\", *PHONE_DEF, \" \"]\n",
    "\n",
    "decoder = ctc_decoder(\n",
    "    lexicon=None,\n",
    "    tokens=tokens,\n",
    "    blank_token=\"<blank>\",\n",
    "    sil_token=\" \",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "55953441",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_ctc_output(logits, lengths=None):\n",
    "    \"\"\"\n",
    "    Greedy CTC decoding of model outputs into phoneme-id sequences.\n",
    "\n",
    "    For every sequence in the batch:\n",
    "    1. take the arg-max class of each frame,\n",
    "    2. collapse runs of identical consecutive ids,\n",
    "    3. drop the blank token (id 0).\n",
    "\n",
    "    Repeats are collapsed BEFORE blanks are removed. Doing it the other way\n",
    "    round merges a phoneme that is genuinely emitted twice with a blank in\n",
    "    between (A, blank, A) into a single A.\n",
    "\n",
    "    Args:\n",
    "        logits: tensor whose last dimension indexes the classes and whose\n",
    "            first dimension indexes the sequences of the batch.\n",
    "        lengths: optional per-sequence number of valid frames. When given,\n",
    "            frames beyond lengths[i] (padding) are ignored for sequence i.\n",
    "\n",
    "    Returns:\n",
    "        list of 1-D numpy arrays, one decoded id sequence per batch item.\n",
    "    \"\"\"\n",
    "    frame_ids = torch.argmax(logits, dim=-1)  # most probable class per frame\n",
    "\n",
    "    predictions = []\n",
    "    for i, seq in enumerate(frame_ids):\n",
    "        if lengths is not None:\n",
    "            seq = seq[: int(lengths[i])]  # discard padded frames\n",
    "        seq = torch.unique_consecutive(seq)  # collapse repeats first ...\n",
    "        predictions.append(seq[seq != 0].cpu().numpy())  # ... then remove blanks\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "68244d09",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [00:40<00:00,  2.87s/it]\n"
     ]
    }
   ],
   "source": [
    "## Predict the whole test set\n",
    "# One entry per test utterance is appended to each of these lists.\n",
    "pred_phonemes = []\n",
    "pred_logits = []  # one log-probability tensor per batch (not per utterance)\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        # Only the model inputs need to live on the device; labels and\n",
    "        # lengths are consumed on the CPU below.\n",
    "        X = batch[\"neural_feats\"].to(device)\n",
    "        days = batch[\"day\"].to(device)\n",
    "        y = batch[\"phone_seq\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "\n",
    "        logits, _ = model(X, days)\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # Beam-search hypotheses; the greedy metrics below do not consume them.\n",
    "        decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            # Greedy CTC decoding restricted to the valid (unpadded) frames:\n",
    "            # arg-max per frame, collapse repeats, then drop blanks (id 0).\n",
    "            n_frames = int(X_len[i] / model.strideLen)\n",
    "            decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][: int(y_len[i])].cpu().numpy()\n",
    "\n",
    "            # Per-utterance error rate: edit distance / reference length.\n",
    "            # Computed from this utterance alone, so each cer_list entry\n",
    "            # describes exactly one row of the results.\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            cer = matcher.distance() / len(trueSeq) if len(trueSeq) > 0 else 1.0\n",
    "\n",
    "            cer_list.append(cer)\n",
    "            # Store the very sequence the CER was computed on, so the\n",
    "            # predictions and the CER values stay consistent with each other.\n",
    "            pred_phonemes.append(decodedSeq)\n",
    "            true_phonemes.append(trueSeq)\n",
    "\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "2e6dab3b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([20, 28, 11, 29, 20, 29, 40, 28, 17, 20,  3, 23, 29, 17, 31,  3,  9,\n",
       "       40])"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: decoded phoneme ids of the first test utterance.\n",
    "pred_phonemes[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "80bd7295",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.17828339792200779"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Average of all the values the evaluation loop appended to cer_list.\n",
    "np.mean(cer_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "5e3f4b1a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['D', 'UW', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'Y', 'AO', 'R', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Phonemes: ['D', 'UW', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'Y', 'AO', 'R', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Sentence: Do you have your bag?\n"
     ]
    }
   ],
   "source": [
    "# Inspect one test utterance: decoded phonemes, reference phonemes and the sentence.\n",
    "idx = 121\n",
    "print(f\"Predicted Phonemes: {idsToPhonemes(pred_phonemes[idx])}\")\n",
    "print(f\"True Phonemes: {idsToPhonemes(true_phonemes[idx])}\")\n",
    "print(f\"True Sentence: {true_sentences[idx]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "d8788dc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \"\"\"\n",
    "    Mean position-wise accuracy between predicted and reference sequences.\n",
    "\n",
    "    Each (pred, target) pair is truncated to the shorter of the two and\n",
    "    compared element by element; the per-pair scores are then averaged.\n",
    "\n",
    "    Args:\n",
    "        preds: iterable of 1-D id sequences (numpy arrays or lists).\n",
    "        targets: iterable of 1-D id sequences, paired with `preds`.\n",
    "\n",
    "    Returns:\n",
    "        float: mean of the per-pair accuracies. A pair in which either\n",
    "        sequence is empty scores 0.0 rather than 0/0 = NaN, so one empty\n",
    "        prediction cannot turn the whole mean into NaN. NaN is returned\n",
    "        only when there are no pairs at all.\n",
    "    \"\"\"\n",
    "    accs = []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        # asarray so that plain Python lists are compared element-wise too\n",
    "        pred = np.asarray(pred)\n",
    "        target = np.asarray(target)\n",
    "\n",
    "        # truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "        if min_len == 0:\n",
    "            accs.append(0.0)\n",
    "            continue\n",
    "\n",
    "        matches = pred[:min_len] == target[:min_len]\n",
    "        accs.append(np.sum(matches) / min_len)\n",
    "\n",
    "    return float(np.mean(accs)) if accs else float(\"nan\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "faafbebe",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.5582719614801291\n"
     ]
    }
   ],
   "source": [
    "# Position-wise accuracy averaged over every test utterance.\n",
    "overall_acc = compute_accuracy(preds=pred_phonemes, targets=true_phonemes)\n",
    "print(f\"overall_acc {overall_acc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "41787ddb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# day_indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "e965deaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of Accuracy per day 0.2394975658501473 0.6905553906343884\n"
     ]
    }
   ],
   "source": [
    "day_indices_flat = day_indices\n",
    "\n",
    "# Score each day index separately: collect the positions of that day's\n",
    "# utterances, then run compute_accuracy on just those predictions/references.\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [pos for pos, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    day_preds = [pred_phonemes[pos] for pos in indices]\n",
    "    day_refs = [true_phonemes[pos] for pos in indices]\n",
    "    day_accs.append(compute_accuracy(day_preds, day_refs))\n",
    "\n",
    "print(f\"Range of Accuracy per day {min(day_accs)} {max(day_accs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "996916a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of CER per day 0.1276270310550197 0.371347870431733\n"
     ]
    }
   ],
   "source": [
    "# Mean of the cer_list entries belonging to each day index.\n",
    "cer_values = np.array(cer_list)\n",
    "\n",
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [pos for pos, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(cer_values[indices].mean())\n",
    "\n",
    "print(f\"Range of CER per day {min(cer_list_per_day)} {max(cer_list_per_day)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "2b2acef8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lenght diff: 0.4863636363636364 +- 1.767714357464683\n"
     ]
    }
   ],
   "source": [
    "# Signed length gap per utterance: len(reference) - len(prediction).\n",
    "# Positive values mean the prediction is shorter than the reference.\n",
    "# Loop-local names are used so the notebook-level `pred` is not overwritten.\n",
    "diffs = [len(ref) - len(hyp) for ref, hyp in zip(true_phonemes, pred_phonemes)]\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "15f1b044",
   "metadata": {},
   "outputs": [],
   "source": [
    "# true_phonemes[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "c5db68ce",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of diff lenghts per day: -0.125 - 2.2\n"
     ]
    }
   ],
   "source": [
    "## Mean length difference (reference - prediction) for each day index\n",
    "diff_values = np.array(diffs)\n",
    "diffs_per_day = [\n",
    "    diff_values[[pos for pos, day in enumerate(day_indices_flat) if day == day_index]].mean()\n",
    "    for day_index in set(day_indices_flat)\n",
    "]\n",
    "\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "dc0b80c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-experiment output folder, created if missing.\n",
    "# `os` is already imported in the notebook's first cell, so no re-import here.\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "738a217d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One row per test utterance: phoneme labels (ids mapped through idsToPhonemes),\n",
    "# the reference sentence, its day index and its cer_list entry.\n",
    "results_columns = {\n",
    "    \"True Phonemes\": list(map(idsToPhonemes, true_phonemes)),\n",
    "    \"Predicted Phonemes\": list(map(idsToPhonemes, pred_phonemes)),\n",
    "    \"True Sentence\": true_sentences,\n",
    "    \"Day Index\": day_indices_flat,\n",
    "    \"CER\": cer_list,\n",
    "}\n",
    "df = pd.DataFrame(results_columns)\n",
    "\n",
    "# Persist next to the other artefacts of this run.\n",
    "results_csv = os.path.join(results_dir, \"results.csv\")\n",
    "df.to_csv(results_csv, index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "3553ea32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>True Phonemes</th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>True Sentence</th>\n",
       "      <th>Day Index</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>257</th>\n",
       "      <td>[SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...</td>\n",
       "      <td>[R, IY, K, AA, K, ER, Z, SIL, AH, N, D, SIL, F...</td>\n",
       "      <td>Chicago and Philadelphia.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.406250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>[W, AY, L, D, F, AY, ER, SIL, N, IH, R, SIL, S...</td>\n",
       "      <td>[W, AY, L, D, F, AO, R, SIL, N, UW, SIL, K, AY...</td>\n",
       "      <td>Wildfire near Sunshine forces park closures.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.406452</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...</td>\n",
       "      <td>[EY, K, S, SIL, P, ER, R, CH, AH, N, S, SIL, E...</td>\n",
       "      <td>Rich purchased several signed lithographs.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.433962</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>384</th>\n",
       "      <td>[K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...</td>\n",
       "      <td>[L, UH, K, SIL, HH, IY, R, SIL, T, UW, SIL, D,...</td>\n",
       "      <td>Click here to join freelancer.</td>\n",
       "      <td>11</td>\n",
       "      <td>0.440000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...</td>\n",
       "      <td>[K, R, EH, S, K, S, SIL, R, IH, K, AH, N, S, I...</td>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.450000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         True Phonemes  \\\n",
       "257  [SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...   \n",
       "7    [W, AY, L, D, F, AY, ER, SIL, N, IH, R, SIL, S...   \n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...   \n",
       "384  [K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...   \n",
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...   \n",
       "\n",
       "                                    Predicted Phonemes  \\\n",
       "257  [R, IY, K, AA, K, ER, Z, SIL, AH, N, D, SIL, F...   \n",
       "7    [W, AY, L, D, F, AO, R, SIL, N, UW, SIL, K, AY...   \n",
       "1    [EY, K, S, SIL, P, ER, R, CH, AH, N, S, SIL, E...   \n",
       "384  [L, UH, K, SIL, HH, IY, R, SIL, T, UW, SIL, D,...   \n",
       "0    [K, R, EH, S, K, S, SIL, R, IH, K, AH, N, S, I...   \n",
       "\n",
       "                                    True Sentence  Day Index       CER  \n",
       "257                     Chicago and Philadelphia.          8  0.406250  \n",
       "7    Wildfire near Sunshine forces park closures.          0  0.406452  \n",
       "1      Rich purchased several signed lithographs.          0  0.433962  \n",
       "384                Click here to join freelancer.         11  0.440000  \n",
       "0                         Theocracy reconsidered.          0  0.450000  "
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "2ef8f0bb",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\"\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "46cc7a8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
