{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0a9b79f",
   "metadata": {},
   "source": [
    "## Ensembling of various GRU+ CTC+MFCC models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bf333104",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_MFCC_v3\n",
    "from model.hybrid_modelling import HybridCausalLMOutput, HybridGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "# from model.ctc_modelling import Light\n",
    "import os\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "057097db",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c7e4200a",
   "metadata": {},
   "outputs": [],
   "source": [
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len =32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 1e-4\n",
    "lr_end = 1e-5\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 60\n",
    "total_steps = target_epoch * steps_per_epoch\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "260fcc14",
   "metadata": {},
   "outputs": [],
   "source": [
    "ensemble_model_weights = \"optimization/.checkpoints/mfcc_sm_gru_ctc\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c355ab2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "/tmp/ipykernel_4007828/3159069469.py:25: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model-v5.ckpt\n",
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n",
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model-v3.ckpt\n",
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n",
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model-v2.ckpt\n",
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n",
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model.ckpt\n",
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n",
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model-v1.ckpt\n",
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n",
      "Loading model from optimization/.checkpoints/mfcc_sm_gru_ctc/best_model-v4.ckpt\n"
     ]
    }
   ],
   "source": [
    "models = []\n",
    "\n",
    "\n",
    "for weight_path in os.listdir(ensemble_model_weights):\n",
    "    if weight_path.endswith(\".ckpt\"):\n",
    "        model_path = os.path.join(ensemble_model_weights, weight_path)\n",
    "\n",
    "                # Define model\n",
    "        model = LightningGRUDecoder_MFCC_v3(\n",
    "                    neural_dim=nInputFeatures,\n",
    "                    n_classes=nClasses,\n",
    "                    hidden_dim=hidden_dim,\n",
    "                    layer_dim=nlayers,\n",
    "                    strideLen=stride_len,\n",
    "                    kernelLen=kernel_len,\n",
    "                    gaussianSmoothWidth=gaussian_smooth_width,\n",
    "                    bidirectional=bidirectional,\n",
    "                    dropout=dropout,\n",
    "                    white_noise_SD=white_noise_SD,\n",
    "                    constant_offset_SD=constant_offset_SD,\n",
    "                    weight_decay=l2_decay,\n",
    "                    learning_rate=lr_start,\n",
    "                    mfcc_loss_weight=1.,)\n",
    "\n",
    "        model.load_state_dict(torch.load(model_path, map_location='cpu')['state_dict'])\n",
    "\n",
    "\n",
    "        print(f\"Loading model from {model_path}\")\n",
    "        models.append(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "cde7a0de",
   "metadata": {},
   "outputs": [],
   "source": [
    "tokens = [\"<blank>\"] + PHONE_DEF + [\" \"]\n",
    "decoder = ctc_decoder(tokens= tokens,   \n",
    "                      lexicon=None,  \n",
    "                      blank_token = '<blank>', \n",
    "                      sil_token = ' ',\n",
    "                      )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "adaaeff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_ctc_output(logits):\n",
    "    \"\"\"\n",
    "    Converts model logits to predicted phoneme sequences.\n",
    "    - Removes repeated phonemes.\n",
    "    - Removes blank tokens (0).\n",
    "    \"\"\"\n",
    "\n",
    "    predictions = torch.argmax(logits, dim=-1)  # Get most probable phoneme indices\n",
    "    predictions = [torch.unique_consecutive(seq[seq != 0]).cpu().numpy() for seq in predictions]  # Remove blanks\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c448cb84",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cuda:0\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f9ef0640",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "from difflib import SequenceMatcher\n",
    "\n",
    "def align_and_vote(seqs):\n",
    "    \"\"\"\n",
    "    Aligns phoneme sequences and returns a hard-voted sequence.\n",
    "    `seqs` is a list of lists (decoded phoneme sequences from models).\n",
    "    \"\"\"\n",
    "    # Pick a reference (longest sequence)\n",
    "    ref = max(seqs, key=len)\n",
    "\n",
    "    # Build alignment matrix\n",
    "    aligned = [[] for _ in range(len(ref))]\n",
    "    \n",
    "    for seq in seqs:\n",
    "        matcher = SequenceMatcher(a=ref, b=seq)\n",
    "        ref_idx = 0\n",
    "        for op, i1, i2, j1, j2 in matcher.get_opcodes():\n",
    "            if op == 'equal':\n",
    "                for r, s in zip(range(i1, i2), range(j1, j2)):\n",
    "                    aligned[r].append(seq[s])\n",
    "                    ref_idx += 1\n",
    "            elif op == 'replace':\n",
    "                for r in range(i1, i2):\n",
    "                    aligned[r].append(seq[j1] if j1 < len(seq) else None)\n",
    "                    ref_idx += 1\n",
    "            elif op == 'insert':\n",
    "                continue  # skip insertions\n",
    "            elif op == 'delete':\n",
    "                for r in range(i1, i2):\n",
    "                    aligned[r].append(None)\n",
    "                    ref_idx += 1\n",
    "\n",
    "    # Majority vote\n",
    "    voted_seq = []\n",
    "    for col in aligned:\n",
    "        if col:\n",
    "            counter = Counter([p for p in col if p is not None])\n",
    "            if counter:\n",
    "                voted_seq.append(counter.most_common(1)[0][0])\n",
    "\n",
    "    return voted_seq\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c649f0d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [00:43<00:00,  3.09s/it]\n"
     ]
    }
   ],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "cer_list_per_model = defaultdict(list)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"].to(device)\n",
    "        y = batch[\"phone_seq\"].to(device)\n",
    "        X_len = batch[\"neural_time_bins\"].to(device)\n",
    "        y_len = batch[\"phone_seq_len\"].to(device)\n",
    "        days = batch[\"day\"].to(device)\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "\n",
    "        logits_ensemble = []\n",
    "\n",
    "        # -------- Per-model predictions and CER --------\n",
    "        for j, model in enumerate(models):\n",
    "            model.eval().to(device)\n",
    "            logits, _ = model(X, days)\n",
    "            logits_ensemble.append(logits.cpu())  # accumulate raw logits\n",
    "            model.cpu()\n",
    "\n",
    "            pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "\n",
    "            for i in range(pred.shape[0]):\n",
    "                decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.strideLen), :], dim=-1)\n",
    "                decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "                decodedSeq = decodedSeq[decodedSeq != 0].numpy()\n",
    "\n",
    "                trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "                matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "                cer = matcher.distance() / len(trueSeq) if len(trueSeq) > 0 else 1.0\n",
    "                cer_list_per_model[j].append(cer)\n",
    "\n",
    "        # -------- Ensemble prediction and CER --------\n",
    "        logits = torch.stack(logits_ensemble).mean(dim=0)  # average raw logits\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        batch_pred_phonemes = decode_ctc_output(pred)\n",
    "\n",
    "\n",
    "        \n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / models[0].strideLen), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            cer = matcher.distance() / len(trueSeq) if len(trueSeq) > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "\n",
    "        # -------- Store outputs --------\n",
    "        pred_phonemes.extend(batch_pred_phonemes)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "86841be4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model 0 CER: 0.1710\n",
      "Model 1 CER: 0.1747\n",
      "Model 2 CER: 0.1799\n",
      "Model 3 CER: 0.1773\n",
      "Model 4 CER: 0.1747\n",
      "Model 5 CER: 0.1747\n",
      "Ensemble CER: 0.2127\n"
     ]
    }
   ],
   "source": [
    "for k, cer_vals in cer_list_per_model.items():\n",
    "    print(f\"Model {k} CER: {np.mean(cer_vals):.4f}\")\n",
    "\n",
    "print(f\"Ensemble CER: {np.mean(cer_list):.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb887787",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/14 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [02:08<00:00,  9.20s/it]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model 0 CER: 0.1710\n",
      "Model 1 CER: 0.1747\n",
      "Model 2 CER: 0.1799\n",
      "Model 3 CER: 0.1773\n",
      "Model 4 CER: 0.1747\n",
      "Model 5 CER: 0.1747\n",
      "Hard-voted Ensemble CER: 0.1742\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "ename": "",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n",
      "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n",
      "\u001b[1;31mClick <a href='https://aka.ms/vscodeJupyterKernelCrash'>here</a> for more info. \n",
      "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
     ]
    }
   ],
   "source": [
    "from collections import defaultdict, Counter\n",
    "from edit_distance import SequenceMatcher\n",
    "import torch\n",
    "import tqdm\n",
    "\n",
    "def align_and_vote(seqs):\n",
    "    \"\"\"\n",
    "    Aligns phoneme sequences and returns a hard-voted sequence.\n",
    "    Each element in `seqs` is a list of predicted phonemes from one model.\n",
    "    \"\"\"\n",
    "    ref = max(seqs, key=len)\n",
    "    aligned = [[] for _ in range(len(ref))]\n",
    "\n",
    "    for seq in seqs:\n",
    "        matcher = SequenceMatcher(a=ref, b=seq)\n",
    "        for op, i1, i2, j1, j2 in matcher.get_opcodes():\n",
    "            if op == 'equal':\n",
    "                for r, s in zip(range(i1, i2), range(j1, j2)):\n",
    "                    aligned[r].append(seq[s])\n",
    "            elif op == 'replace':\n",
    "                for r in range(i1, i2):\n",
    "                    aligned[r].append(seq[j1] if j1 < len(seq) else None)\n",
    "            elif op == 'delete':\n",
    "                for r in range(i1, i2):\n",
    "                    aligned[r].append(None)\n",
    "            # skip insertions (ref doesn't move)\n",
    "\n",
    "    voted_seq = []\n",
    "    for col in aligned:\n",
    "        col = [p for p in col if p is not None]\n",
    "        if col:\n",
    "            voted_seq.append(Counter(col).most_common(1)[0][0])\n",
    "    return voted_seq\n",
    "\n",
    "\n",
    "# Init output containers\n",
    "pred_phonemes = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "cer_list_per_model = defaultdict(list)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"].to(device)\n",
    "        y = batch[\"phone_seq\"].to(device)\n",
    "        X_len = batch[\"neural_time_bins\"].to(device)\n",
    "        y_len = batch[\"phone_seq_len\"].to(device)\n",
    "        days = batch[\"day\"].to(device)\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "\n",
    "        decoded_seqs_per_model = []\n",
    "\n",
    "        for j, model in enumerate(models):\n",
    "            model.eval().to(device)\n",
    "            logits, _ = model(X, days)\n",
    "            model.cpu()\n",
    "\n",
    "            pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "            decoded = decode_ctc_output(pred)  # returns list of decoded phoneme sequences\n",
    "            decoded_seqs_per_model.append(decoded)\n",
    "\n",
    "            for i in range(pred.shape[0]):\n",
    "                decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.strideLen), :], dim=-1)\n",
    "                decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "                decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "                trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "                matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "                cer = matcher.distance() / len(trueSeq) if len(trueSeq) > 0 else 1.0\n",
    "                cer_list_per_model[j].append(cer)\n",
    "\n",
    "        # Hard voting across model predictions\n",
    "        for i in range(len(decoded_seqs_per_model[0])):\n",
    "            seqs = [decoded_seqs_per_model[m][i] for m in range(len(models))]\n",
    "            voted_seq = align_and_vote(seqs)\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=voted_seq)\n",
    "            cer = matcher.distance() / len(trueSeq) if len(trueSeq) > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "\n",
    "            pred_phonemes.append(voted_seq)\n",
    "\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "# Print results\n",
    "for k, v in cer_list_per_model.items():\n",
    "    print(f\"Model {k} CER: {np.mean(v):.4f}\")\n",
    "print(f\"Hard-voted Ensemble CER: {np.mean(cer_list):.4f}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "3f401831",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.19438203749107374)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(cer_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6457cffd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'Y', 'AO', 'R', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Phonemes: ['D', 'UW', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'Y', 'AO', 'R', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Sentence: Do you have your bag?\n"
     ]
    }
   ],
   "source": [
    "idx = 121\n",
    "print(\"Predicted Phonemes:\", idsToPhonemes(pred_phonemes[idx]))\n",
    "print(\"True Phonemes:\", idsToPhonemes(true_phonemes[idx]))\n",
    "print(\"True Sentence:\", true_sentences[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "953cbfa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \n",
    "\n",
    "    accs= []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        \n",
    "        #truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "\n",
    "\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        equal_inference = (pred == target)\n",
    "        acc = np.sum(equal_inference)/ len(pred)\n",
    "        accs.append(acc)\n",
    "\n",
    "    return np.mean(accs)\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cce1db4d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.23116372039573507\n"
     ]
    }
   ],
   "source": [
    "overall_acc = compute_accuracy(pred_phonemes, true_phonemes)\n",
    "print(\"overall_acc\", overall_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "337ef12a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of Accuracy per day 0.11697775745870842 0.30675206358476315\n"
     ]
    }
   ],
   "source": [
    "day_indices_flat = day_indices\n",
    "\n",
    "#compute accuracy per day by selecting indices of the same day\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    acc = compute_accuracy([pred_phonemes[idx] for idx in indices], [true_phonemes[idx] for idx in indices])\n",
    "    day_accs.append(acc)\n",
    "\n",
    "day_accs\n",
    "print(\"Range of Accuracy per day\", min(day_accs), max(day_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f0d2c658",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of CER per day 0.16521852599319856 0.3932086158781628\n"
     ]
    }
   ],
   "source": [
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(np.array(cer_list)[indices].mean())\n",
    "\n",
    "cer_list_per_day\n",
    "\n",
    "print(\"Range of CER per day\", min(cer_list_per_day), max(cer_list_per_day))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61fcfa51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lenght diff: 1.3647727272727272 +- 2.239599181505845\n"
     ]
    }
   ],
   "source": [
    "diffs = []\n",
    "for i in range(len(true_phonemes)):\n",
    "    true = true_phonemes[i]\n",
    "    pred = pred_phonemes[i]\n",
    "    diffs.append(np.array(len(true)) - np.array(len(pred)))\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62197392",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of diff lenghts per day: 0.075 - 4.05\n"
     ]
    }
   ],
   "source": [
    "## compute diff lenghts per day\n",
    "diffs_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    diffs_per_day.append(np.array(diffs)[indices].mean())\n",
    "\n",
    "diffs_per_day\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "96cc9343",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "output_name = \"ensemble_mfcc_sm_gru_ctc\"\n",
    "\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d0c8af4",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices_flat,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(results_dir, \"results.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55cd768d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>True Phonemes</th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>True Sentence</th>\n",
       "      <th>Day Index</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>384</th>\n",
       "      <td>[K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...</td>\n",
       "      <td>[SIL, L, SIL, K, SIL, HH, IY, R, SIL, T, UW, S...</td>\n",
       "      <td>Click here to join freelancer.</td>\n",
       "      <td>11</td>\n",
       "      <td>0.440000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>[DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...</td>\n",
       "      <td>[SIL, DH, SIL, D, OW, N, T, SIL, IY, V, IH, N,...</td>\n",
       "      <td>They don't even check my social security number.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.452381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>257</th>\n",
       "      <td>[SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...</td>\n",
       "      <td>[SIL, K, ER, SIL, AH, N, D, SIL, P, R, EH, L, ...</td>\n",
       "      <td>Chicago and Philadelphia.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.468750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...</td>\n",
       "      <td>[SIL, R, EY, SIL, P, AA, CH, AH, T, SIL, EH, V...</td>\n",
       "      <td>Rich purchased several signed lithographs.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.528302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...</td>\n",
       "      <td>[SIL, K, AH, SIL, R, IH, K, EH, N, T, D, SIL]</td>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.600000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         True Phonemes  \\\n",
       "384  [K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...   \n",
       "256  [DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...   \n",
       "257  [SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...   \n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...   \n",
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...   \n",
       "\n",
       "                                    Predicted Phonemes  \\\n",
       "384  [SIL, L, SIL, K, SIL, HH, IY, R, SIL, T, UW, S...   \n",
       "256  [SIL, DH, SIL, D, OW, N, T, SIL, IY, V, IH, N,...   \n",
       "257  [SIL, K, ER, SIL, AH, N, D, SIL, P, R, EH, L, ...   \n",
       "1    [SIL, R, EY, SIL, P, AA, CH, AH, T, SIL, EH, V...   \n",
       "0        [SIL, K, AH, SIL, R, IH, K, EH, N, T, D, SIL]   \n",
       "\n",
       "                                        True Sentence  Day Index       CER  \n",
       "384                    Click here to join freelancer.         11  0.440000  \n",
       "256  They don't even check my social security number.          8  0.452381  \n",
       "257                         Chicago and Philadelphia.          8  0.468750  \n",
       "1          Rich purchased several signed lithographs.          0  0.528302  \n",
       "0                             Theocracy reconsidered.          0  0.600000  "
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95fe666b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>True Phonemes</th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>True Sentence</th>\n",
       "      <th>Day Index</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>128</th>\n",
       "      <td>[AY, SIL, AE, M, SIL, N, AA, T, SIL, IY, V, IH...</td>\n",
       "      <td>[SIL, AE, M, SIL, N, AA, T, SIL, IY, V, IH, N,...</td>\n",
       "      <td>I am not even aware that I could have seen it.</td>\n",
       "      <td>5</td>\n",
       "      <td>0.050000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>320</th>\n",
       "      <td>[HH, IY, SIL, R, IH, L, IY, SIL, L, AY, K, S, ...</td>\n",
       "      <td>[SIL, HH, IY, SIL, R, IH, L, SIL, L, AY, K, SI...</td>\n",
       "      <td>He really likes it and I don't.</td>\n",
       "      <td>10</td>\n",
       "      <td>0.111111</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>455</th>\n",
       "      <td>[IH, N, SIL, DH, AE, T, SIL, T, AY, M, SIL, F,...</td>\n",
       "      <td>[IH, N, SIL, DH, AE, T, SIL, T, AY, M, SIL, F,...</td>\n",
       "      <td>In that time frame.</td>\n",
       "      <td>13</td>\n",
       "      <td>0.117021</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>456</th>\n",
       "      <td>[T, R, AY, IH, NG, SIL, T, UW, SIL, M, EY, K, ...</td>\n",
       "      <td>[SIL, T, R, SIL, IH, SIL, T, UW, SIL, M, EY, K...</td>\n",
       "      <td>Trying to make it through that.</td>\n",
       "      <td>13</td>\n",
       "      <td>0.117925</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>129</th>\n",
       "      <td>[DH, AH, SIL, R, AE, NG, K, S, SIL, AH, V, SIL...</td>\n",
       "      <td>[SIL, DH, AH, SIL, R, EY, NG, K, S, SIL, AH, V...</td>\n",
       "      <td>The ranks of Asian riders are swelling too.</td>\n",
       "      <td>5</td>\n",
       "      <td>0.118421</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>384</th>\n",
       "      <td>[K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...</td>\n",
       "      <td>[SIL, L, SIL, K, SIL, HH, IY, R, SIL, T, UW, S...</td>\n",
       "      <td>Click here to join freelancer.</td>\n",
       "      <td>11</td>\n",
       "      <td>0.440000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>[DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...</td>\n",
       "      <td>[SIL, DH, SIL, D, OW, N, T, SIL, IY, V, IH, N,...</td>\n",
       "      <td>They don't even check my social security number.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.452381</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>257</th>\n",
       "      <td>[SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...</td>\n",
       "      <td>[SIL, K, ER, SIL, AH, N, D, SIL, P, R, EH, L, ...</td>\n",
       "      <td>Chicago and Philadelphia.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.468750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...</td>\n",
       "      <td>[SIL, R, EY, SIL, P, AA, CH, AH, T, SIL, EH, V...</td>\n",
       "      <td>Rich purchased several signed lithographs.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.528302</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...</td>\n",
       "      <td>[SIL, K, AH, SIL, R, IH, K, EH, N, T, D, SIL]</td>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.600000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>880 rows × 5 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         True Phonemes  \\\n",
       "128  [AY, SIL, AE, M, SIL, N, AA, T, SIL, IY, V, IH...   \n",
       "320  [HH, IY, SIL, R, IH, L, IY, SIL, L, AY, K, S, ...   \n",
       "455  [IH, N, SIL, DH, AE, T, SIL, T, AY, M, SIL, F,...   \n",
       "456  [T, R, AY, IH, NG, SIL, T, UW, SIL, M, EY, K, ...   \n",
       "129  [DH, AH, SIL, R, AE, NG, K, S, SIL, AH, V, SIL...   \n",
       "..                                                 ...   \n",
       "384  [K, L, IH, K, SIL, HH, IY, R, SIL, T, UW, SIL,...   \n",
       "256  [DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...   \n",
       "257  [SH, AH, K, AA, G, OW, SIL, AH, N, D, SIL, F, ...   \n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...   \n",
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...   \n",
       "\n",
       "                                    Predicted Phonemes  \\\n",
       "128  [SIL, AE, M, SIL, N, AA, T, SIL, IY, V, IH, N,...   \n",
       "320  [SIL, HH, IY, SIL, R, IH, L, SIL, L, AY, K, SI...   \n",
       "455  [IH, N, SIL, DH, AE, T, SIL, T, AY, M, SIL, F,...   \n",
       "456  [SIL, T, R, SIL, IH, SIL, T, UW, SIL, M, EY, K...   \n",
       "129  [SIL, DH, AH, SIL, R, EY, NG, K, S, SIL, AH, V...   \n",
       "..                                                 ...   \n",
       "384  [SIL, L, SIL, K, SIL, HH, IY, R, SIL, T, UW, S...   \n",
       "256  [SIL, DH, SIL, D, OW, N, T, SIL, IY, V, IH, N,...   \n",
       "257  [SIL, K, ER, SIL, AH, N, D, SIL, P, R, EH, L, ...   \n",
       "1    [SIL, R, EY, SIL, P, AA, CH, AH, T, SIL, EH, V...   \n",
       "0        [SIL, K, AH, SIL, R, IH, K, EH, N, T, D, SIL]   \n",
       "\n",
       "                                        True Sentence  Day Index       CER  \n",
       "128    I am not even aware that I could have seen it.          5  0.050000  \n",
       "320                   He really likes it and I don't.         10  0.111111  \n",
       "455                               In that time frame.         13  0.117021  \n",
       "456                   Trying to make it through that.         13  0.117925  \n",
       "129       The ranks of Asian riders are swelling too.          5  0.118421  \n",
       "..                                                ...        ...       ...  \n",
       "384                    Click here to join freelancer.         11  0.440000  \n",
       "256  They don't even check my social security number.          8  0.452381  \n",
       "257                         Chicago and Philadelphia.          8  0.468750  \n",
       "1          Rich purchased several signed lithographs.          0  0.528302  \n",
       "0                             Theocracy reconsidered.          0  0.600000  \n",
       "\n",
       "[880 rows x 5 columns]"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sort_values(by=[\"CER\"], ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "57082ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\",\n",
    "    'average CER': [np.mean(cer_list)],\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e0504fa2",
   "metadata": {},
   "outputs": [],
   "source": [
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
