{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "dce49b53",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/matteo/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "import sys\n",
    "import os\n",
    "os.environ['CUDA_LAUNCH_BLOCKING'] = '1'\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model_training.dataset import getDatasetLoaders\n",
    "from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_V2\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "import os\n",
    "from torch.utils.data import Subset, DataLoader\n",
    "\n",
    "#import seed_everything\n",
    "from pytorch_lightning import seed_everything\n",
    "\n",
    "# from model.ctc_modelling import Light\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5fc4aeb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, val_loader, test_loader = getDatasetLoaders(BATCH_SIZE=32, SHUFFLE_TRAIN=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9c04486b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([14, 14, 14, 14, 14, 14, 14, 14,  1,  1,  1,  1,  1,  1,  1,  1, 33, 33,\n",
       "        33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 34, 34])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch =next(iter(train_loader))\n",
    "\n",
    "batch[\"day\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3f152209",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_loader.dataset.__dict__[\"days\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5f5a30b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "31de4ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "nInputFeatures = 512 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len = 16\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 1e-4\n",
    "lr_end = 1e-5\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 60\n",
    "total_steps = target_epoch * steps_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "93e50045",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    }
   ],
   "source": [
    "model = LightningGRUDecoder_V2(\n",
    "            neural_dim=nInputFeatures,\n",
    "            n_classes=nClasses,\n",
    "            hidden_dim=hidden_dim,\n",
    "            layer_dim=nlayers,\n",
    "            nDays=45,\n",
    "            strideLen=stride_len,\n",
    "            kernelLen=kernel_len,\n",
    "            gaussianSmoothWidth=gaussian_smooth_width,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout,\n",
    "            white_noise_SD=white_noise_SD,\n",
    "            constant_offset_SD=constant_offset_SD,\n",
    "            weight_decay=l2_decay,\n",
    "            learning_rate=lr_start)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "60bb27ca",
   "metadata": {},
   "outputs": [],
   "source": [
    "PHONE_DEF = [\n",
    "    'AA', 'AE', 'AH', 'AO', 'AW',\n",
    "    'AY', 'B',  'CH', 'D', 'DH',\n",
    "    'EH', 'ER', 'EY', 'F', 'G',\n",
    "    'HH', 'IH', 'IY', 'JH', 'K',\n",
    "    'L', 'M', 'N', 'NG', 'OW',\n",
    "    'OY', 'P', 'R', 'S', 'SH',\n",
    "    'T', 'TH', 'UH', 'UW', 'V',\n",
    "    'W', 'Y', 'Z', 'ZH'\n",
    "]\n",
    "PHONE_DEF_SIL = PHONE_DEF + ['SIL']\n",
    "\n",
    "def phoneToId(p):\n",
    "    return PHONE_DEF_SIL.index(p)\n",
    "\n",
    "phoneToIdDict = {p:phoneToId(p) for p in PHONE_DEF_SIL}\n",
    "idToPhone = {v: k for k, v in phoneToIdDict.items()}\n",
    "\n",
    "def idsToPhonemes(seqClassIDs, idToPhone = idToPhone):\n",
    "    \"\"\"\n",
    "    Converts a sequence of phoneme IDs back to their phoneme representations.\n",
    "    \n",
    "    Args:\n",
    "        seqClassIDs (numpy array): The numerical sequence of phoneme IDs.\n",
    "        idToPhone (dict): A dictionary mapping phoneme IDs back to phonemes.\n",
    "        \n",
    "    Returns:\n",
    "        list: The corresponding phoneme sequence.\n",
    "    \"\"\"\n",
    "    phonemeSeq = [idToPhone[id - 1] for id in seqClassIDs if id > 0]  # -1 because IDs were stored with +1\n",
    "    return phonemeSeq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a8d12443",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"gru_ctc\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38d2ed99",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = [\"<blank>\"] + PHONE_DEF + [\" \"]\n",
    "decoder = ctc_decoder(tokens= tokens,   \n",
    "                      lexicon=None,  \n",
    "                      blank_token = '<blank>', \n",
    "                      sil_token = ' ',\n",
    "                      )\n",
    "\n",
    "def decode_ctc_output(logits):\n",
    "    \"\"\"\n",
    "    Converts model logits to predicted phoneme sequences.\n",
    "    - Removes repeated phonemes.\n",
    "    - Removes blank tokens (0).\n",
    "    \"\"\"\n",
    "\n",
    "    predictions = torch.argmax(logits, dim=-1)  # Get most probable phoneme indices\n",
    "    predictions = [torch.unique_consecutive(seq[seq != 0]).cpu().numpy() for seq in predictions]  # Remove blanks\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "06950199",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3564984/3414911032.py:39: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  torch.load(f\".checkpoints/{output_name}/best_model.ckpt\")[\"state_dict\"]\n",
      "  0%|          | 0/64 [00:00<?, ?it/s]/data/matteo/nejm-brain-to-text/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "100%|██████████| 64/64 [00:18<00:00,  3.40it/s]\n"
     ]
    }
   ],
   "source": [
    "# Create a directory to save models\n",
    "run_folder = f\"{output_name}\"\n",
    "os.makedirs(run_folder, exist_ok=True)\n",
    "\n",
    "TRAIN = False\n",
    "\n",
    "if TRAIN:\n",
    "\n",
    "    wandb_logger = WandbLogger(project=\"Brain2Text2025\", name=f\"{output_name}\",\n",
    "                                reinit=True)\n",
    "\n",
    "    # Define ModelCheckpoint to save the best model based on validation loss\n",
    "    checkpoint_callback = ModelCheckpoint(\n",
    "        monitor=\"val_CER\",  # Ensure your validation step logs \"val_loss\"\n",
    "        mode=\"min\",          # Save the model with the lowest validation loss\n",
    "        save_top_k=1,        # Keep only the best model\n",
    "        dirpath=f\".checkpoints/{run_folder}/\",  # Directory to save checkpoints\n",
    "        filename=f\"best_model\",  # Model filename\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Define EarlyStopping callback with patience of 3 epochs\n",
    "    early_stopping_callback = EarlyStopping(\n",
    "        monitor=\"val_loss\",\n",
    "        patience=10,   # Stop training if no improvement in 3 epochs\n",
    "        mode=\"min\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "\n",
    "    # Train model\n",
    "    trainer = pl.Trainer(max_epochs=120,devices =[6], callbacks=[checkpoint_callback, early_stopping_callback], logger=wandb_logger)\n",
    "\n",
    "    trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "else:\n",
    "    #load the best model\n",
    "    model.load_state_dict(\n",
    "        torch.load(f\".checkpoints/{output_name}/best_model.ckpt\")[\"state_dict\"]\n",
    "    )\n",
    "\n",
    "## EVALUATION\n",
    "device = \"cuda:5\"\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "## predit all teh test set \n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(val_loader):\n",
    "        X = batch[\"neural_feats\"]\n",
    "        y = batch[\"phone_seq\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        logits = model(X,days)\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.strideLen), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pp)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "avg_cer = np.mean(cer_list)\n",
    "\n",
    "    #create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(run_folder, \"results.csv\"), index=False)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b47a5d3",
   "metadata": {},
   "source": [
    "## Evaluation Phase"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "93e7e55a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['HH', 'AE', 'T', 'SIL', 'DH', 'AH', 'SIL', 'T', 'AY', 'M', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AY', 'ER', 'AH', 'D', 'SIL', 'AA', 'N', 'SIL', 'W', 'IH', 'DH', 'SIL', 'DH', 'IH', 'S', 'SIL', 'K', 'AH', 'M', 'P', 'AH', 'N', 'IY', 'SIL']\n",
      "True Phonemes: ['AE', 'T', 'SIL', 'DH', 'AH', 'SIL', 'T', 'AY', 'M', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AY', 'ER', 'D', 'SIL', 'AA', 'N', 'SIL', 'W', 'IH', 'DH', 'SIL', 'DH', 'IH', 'S', 'SIL', 'K', 'AH', 'M', 'P', 'AH', 'N', 'IY', 'SIL']\n",
      "True Sentence: At the time you hired on with this company.\n"
     ]
    }
   ],
   "source": [
    "idx = 121\n",
    "print(\"Predicted Phonemes:\", idsToPhonemes(pred_phonemes[idx]))\n",
    "print(\"True Phonemes:\", idsToPhonemes(true_phonemes[idx]))\n",
    "print(\"True Sentence:\", true_sentences[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6eaba555",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \n",
    "\n",
    "    accs= []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        \n",
    "        #truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "\n",
    "\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        equal_inference = (pred == target)\n",
    "        acc = np.sum(equal_inference)/ len(pred)\n",
    "        accs.append(acc)\n",
    "\n",
    "    return np.mean(accs)\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5fb753c4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.5868959623103708\n"
     ]
    }
   ],
   "source": [
    "overall_acc = compute_accuracy(pred_phonemes, true_phonemes)\n",
    "print(\"overall_acc\", overall_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "36fd730c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "def compute_f1_score(preds, targets):\n",
    "    \"\"\"\n",
    "    Computes the macro F1 score over a batch of predicted and target sequences.\n",
    "    Each element of preds and targets should be a list of integers or strings.\n",
    "    \"\"\"\n",
    "\n",
    "    f1s = []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        # Truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        # Compute macro F1 score for the sequence\n",
    "        try:\n",
    "            score = f1_score(target, pred, average='macro', zero_division=0)\n",
    "        except ValueError:\n",
    "            score = 0.0\n",
    "        f1s.append(score)\n",
    "\n",
    "    return np.mean(f1s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "cacfe8eb",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "F1 Score: 0.5286955630800393\n"
     ]
    }
   ],
   "source": [
    "f1 = compute_f1_score(pred_phonemes, true_phonemes)\n",
    "print(\"F1 Score:\", f1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "ecf5d150",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of Accuracy per day 0.3376334524780828 0.9041285210820575\n"
     ]
    }
   ],
   "source": [
    "day_indices_flat = day_indices\n",
    "\n",
    "#compute accuracy per day by selecting indices of the same day\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    acc = compute_accuracy([pred_phonemes[idx] for idx in indices], [true_phonemes[idx] for idx in indices])\n",
    "    day_accs.append(acc)\n",
    "\n",
    "day_accs\n",
    "print(\"Range of Accuracy per day\", min(day_accs), max(day_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ae336524",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of CER per day 0.02023066063285903 0.38880449832059344\n"
     ]
    }
   ],
   "source": [
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(np.array(cer_list)[indices].mean())\n",
    "\n",
    "cer_list_per_day\n",
    "\n",
    "print(\"Range of CER per day\", min(cer_list_per_day), max(cer_list_per_day))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "dcdb1522",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lenght diff: -19.199158485273493 +- 22.831528835565795\n"
     ]
    }
   ],
   "source": [
    "diffs = []\n",
    "for i in range(len(true_phonemes)):\n",
    "    true = true_phonemes[i]\n",
    "    pred = pred_phonemes[i]\n",
    "    diffs.append(np.array(len(true)) - np.array(len(pred)))\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8d4d4fe1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of diff lenghts per day: -88.13333333333334 - 0.8333333333333334\n"
     ]
    }
   ],
   "source": [
    "## compute diff lenghts per day\n",
    "diffs_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    diffs_per_day.append(np.array(diffs)[indices].mean())\n",
    "\n",
    "diffs_per_day\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "26c4f0b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "0cb7b593",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices_flat,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(results_dir, \"results.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b2118a02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>True Phonemes</th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>True Sentence</th>\n",
       "      <th>Day Index</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1254</th>\n",
       "      <td>[IH, F, SIL, IH, T, SIL, D, AH, Z, SIL, DH, AH...</td>\n",
       "      <td>[IH, T, SIL, IH, T, SIL, D, AH, Z, SIL, DH, AH...</td>\n",
       "      <td>If it does the job.</td>\n",
       "      <td>39</td>\n",
       "      <td>0.489796</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1325</th>\n",
       "      <td>[DH, AH, SIL, T, AW, N, SIL, IH, Z, SIL, JH, A...</td>\n",
       "      <td>[DH, AH, SIL, B, EY, SIL, IH, Z, SIL, D, IH, Z...</td>\n",
       "      <td>The town is just over the hill here.</td>\n",
       "      <td>41</td>\n",
       "      <td>0.505747</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1323</th>\n",
       "      <td>[AY, SIL, HH, OW, P, SIL, EH, V, R, IY, TH, IH...</td>\n",
       "      <td>[AH, SIL, IH, N, SIL, IY, Z, IY, TH, IH, NG, S...</td>\n",
       "      <td>I hope everything works out up there.</td>\n",
       "      <td>41</td>\n",
       "      <td>0.517241</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1324</th>\n",
       "      <td>[AA, R, SIL, Y, UW, SIL, P, L, IY, Z, D, SIL, ...</td>\n",
       "      <td>[IY, SIL, AH, B, L, IY, Z, SIL, DH, AH, DH, SI...</td>\n",
       "      <td>Are you pleased with this decision?</td>\n",
       "      <td>41</td>\n",
       "      <td>0.578947</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1253</th>\n",
       "      <td>[DH, EH, R, SIL, D, UW, IH, NG, SIL, AH, SIL, ...</td>\n",
       "      <td>[DH, AW, SIL, N, UW, B, ER, SIL, AH, N, D, SIL...</td>\n",
       "      <td>They're doing a lot of good research here.</td>\n",
       "      <td>39</td>\n",
       "      <td>0.656250</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                          True Phonemes  \\\n",
       "1254  [IH, F, SIL, IH, T, SIL, D, AH, Z, SIL, DH, AH...   \n",
       "1325  [DH, AH, SIL, T, AW, N, SIL, IH, Z, SIL, JH, A...   \n",
       "1323  [AY, SIL, HH, OW, P, SIL, EH, V, R, IY, TH, IH...   \n",
       "1324  [AA, R, SIL, Y, UW, SIL, P, L, IY, Z, D, SIL, ...   \n",
       "1253  [DH, EH, R, SIL, D, UW, IH, NG, SIL, AH, SIL, ...   \n",
       "\n",
       "                                     Predicted Phonemes  \\\n",
       "1254  [IH, T, SIL, IH, T, SIL, D, AH, Z, SIL, DH, AH...   \n",
       "1325  [DH, AH, SIL, B, EY, SIL, IH, Z, SIL, D, IH, Z...   \n",
       "1323  [AH, SIL, IH, N, SIL, IY, Z, IY, TH, IH, NG, S...   \n",
       "1324  [IY, SIL, AH, B, L, IY, Z, SIL, DH, AH, DH, SI...   \n",
       "1253  [DH, AW, SIL, N, UW, B, ER, SIL, AH, N, D, SIL...   \n",
       "\n",
       "                                   True Sentence  Day Index       CER  \n",
       "1254                         If it does the job.         39  0.489796  \n",
       "1325        The town is just over the hill here.         41  0.505747  \n",
       "1323       I hope everything works out up there.         41  0.517241  \n",
       "1324         Are you pleased with this decision?         41  0.578947  \n",
       "1253  They're doing a lot of good research here.         39  0.656250  "
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "52f3e77d",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\",\n",
    "    'average CER': [np.mean(cer_list)],\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5d1a6af9",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ef3126f1",
   "metadata": {},
   "source": [
    "## Test Predictions\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "4c50a203",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 64/64 [00:20<00:00,  3.11it/s]\n",
      "/home/matteo/anaconda3/envs/evo/lib/python3.9/site-packages/numpy/_core/fromnumeric.py:3596: RuntimeWarning: Mean of empty slice.\n",
      "  return _methods._mean(a, axis=axis, dtype=dtype,\n",
      "/home/matteo/anaconda3/envs/evo/lib/python3.9/site-packages/numpy/_core/_methods.py:138: RuntimeWarning: invalid value encountered in scalar divide\n",
      "  ret = ret.dtype.type(ret / rcount)\n"
     ]
    }
   ],
   "source": [
    "\n",
    "## EVALUATION\n",
    "model.to(device)\n",
    "model.eval()\n",
    "\n",
    "## predit all teh test set \n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        days = batch[\"day\"]\n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "\n",
    "        logits = model(X,days)\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pp)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "avg_cer = np.mean(cer_list)\n",
    "\n",
    "    #create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'Day Index': day_indices,\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(run_folder, \"test_results.csv\"), index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "76cc9dca",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>Day Index</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[AY, SIL, G, EH, T, SIL, T, AY, ER, D, SIL, W,...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[IH, M, R, JH, AH, S, IY, SIL, K, IY, R, SIL, ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[Y, UW, SIL, K, R, IY, EY, T, SIL, AH, SIL, P,...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[AY, SIL, TH, IH, NG, K, SIL, M, EY, B, IY, SI...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[SH, OW, SIL, DH, AE, T, SIL, DH, EY, SIL, D, ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1445</th>\n",
       "      <td>[SH, IH, S, SIL, DH, EY, SIL, D, OW, N, T, SIL...</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1446</th>\n",
       "      <td>[DH, EH, R, S, SIL, AH, SIL, L, AA, T, SIL, AH...</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1447</th>\n",
       "      <td>[AE, N, SIL, AW, ER, SIL, HH, AE, D, SIL, AH, ...</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1448</th>\n",
       "      <td>[SH, IY, SIL, P, OY, R, T, S, SIL, AE, T, SIL,...</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1449</th>\n",
       "      <td>[W, IY, L, SIL, P, EY, P, IY, SIL, DH, AE, T, ...</td>\n",
       "      <td>44</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>1450 rows × 2 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                     Predicted Phonemes  Day Index\n",
       "0     [AY, SIL, G, EH, T, SIL, T, AY, ER, D, SIL, W,...          1\n",
       "1     [IH, M, R, JH, AH, S, IY, SIL, K, IY, R, SIL, ...          1\n",
       "2     [Y, UW, SIL, K, R, IY, EY, T, SIL, AH, SIL, P,...          1\n",
       "3     [AY, SIL, TH, IH, NG, K, SIL, M, EY, B, IY, SI...          1\n",
       "4     [SH, OW, SIL, DH, AE, T, SIL, DH, EY, SIL, D, ...          1\n",
       "...                                                 ...        ...\n",
       "1445  [SH, IH, S, SIL, DH, EY, SIL, D, OW, N, T, SIL...         44\n",
       "1446  [DH, EH, R, S, SIL, AH, SIL, L, AA, T, SIL, AH...         44\n",
       "1447  [AE, N, SIL, AW, ER, SIL, HH, AE, D, SIL, AH, ...         44\n",
       "1448  [SH, IY, SIL, P, OY, R, T, S, SIL, AE, T, SIL,...         44\n",
       "1449  [W, IY, L, SIL, P, EY, P, IY, SIL, DH, AE, T, ...         44\n",
       "\n",
       "[1450 rows x 2 columns]"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "4604f685",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total number of trials in test pred_logits: 1450\n"
     ]
    }
   ],
   "source": [
    "tot_len = 0\n",
    "for i in pred_logits:\n",
    "    tot_len += i.shape[0]\n",
    "print(f\"Total number of trials in test pred_logits: {tot_len}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "cba21c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"test_pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
