{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Idea: couple a pre-trained GRU with a LM head to improve phoneme sequence prediction accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder,LightningGRUDecoder_MFCC_v3\n",
    "from model.rnnt_modelling import LightningRNNTDecoder, LightningPretrainedRNNTDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "import os\n",
    "# from model.ctc_modelling import Light\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 32, include_prego=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load a pretrained neural encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "neural_encoder_model_weights_path = \".checkpoints/mfcc_sm_gru_ctc/best_model.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "/tmp/ipykernel_1308448/3666902854.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path)[\"state_dict\"])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len =32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 3e-4\n",
    "lr_end = 0.02\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "neural_encoder = LightningGRUDecoder_MFCC_v3(\n",
    "            neural_dim=nInputFeatures,\n",
    "            n_classes=nClasses,\n",
    "            hidden_dim=hidden_dim,\n",
    "            layer_dim=nlayers,\n",
    "            strideLen=stride_len,\n",
    "            kernelLen=kernel_len,\n",
    "            gaussianSmoothWidth=gaussian_smooth_width,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout,\n",
    "            white_noise_SD=white_noise_SD,\n",
    "            constant_offset_SD=constant_offset_SD,\n",
    "            weight_decay=l2_decay,\n",
    "            learning_rate=lr_start)\n",
    "\n",
    "neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path)[\"state_dict\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/nn/modules/rnn.py:123: UserWarning: dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "model = LightningPretrainedRNNTDecoder(\n",
    "    neural_dim=256,\n",
    "    n_classes=40,\n",
    "    hidden_dim=1024,\n",
    "    encoder = neural_encoder,\n",
    "    dropout = 0.2,\n",
    "    bidirectional=bidirectional,\n",
    "    learning_rate=1e-4,\n",
    "    white_noise_SD=white_noise_SD,\n",
    "    constant_offset_SD=constant_offset_SD,\n",
    "    weight_decay=1e-4,\n",
    "    smoothing=True,\n",
    "    day_transforms=True,\n",
    "    num_LM_layers=1,\n",
    "    prediction_type=\"rnn\",\n",
    "    freeze_encoder=False)\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1024"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neural_encoder.hidden_dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"rnnt_pretrained\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1308448/2866452482.py:30: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(\"optimization/RNNT_MFCC/easy-sweep-6/best_model.ckpt\")[\"state_dict\"])\n"
     ]
    }
   ],
   "source": [
    "TRAIN = False \n",
    "\n",
    "if TRAIN:\n",
    "    wandb_logger = WandbLogger(project=\"ECOG_Sentence_dataset\", name=f\"{output_name}\")\n",
    "    # Define ModelCheckpoint to save the best model based on validation loss\n",
    "    checkpoint_callback = ModelCheckpoint(\n",
    "        monitor=\"val_CER\",  # Ensure your validation step logs \"val_loss\"\n",
    "        mode=\"min\",          # Save the model with the lowest validation loss\n",
    "        save_top_k=1,        # Keep only the best model\n",
    "        dirpath=f\".checkpoints/{output_name}/\",  # Directory to save checkpoints\n",
    "        filename=\"best_model\",  # Model filename\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Define EarlyStopping callback with patience of 3 epochs\n",
    "    early_stopping_callback = EarlyStopping(\n",
    "        monitor=\"val_CER\",\n",
    "        patience=10,   # Stop training if no improvement in 3 epochs\n",
    "        mode=\"min\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "\n",
    "    # Train model\n",
    "    trainer = pl.Trainer(max_epochs=100,devices =[2], callbacks=[checkpoint_callback, early_stopping_callback], logger=wandb_logger)\n",
    "\n",
    "    trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "else:\n",
    "    model.load_state_dict(torch.load(\"optimization/RNNT_MFCC/easy-sweep-6/best_model.ckpt\")[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<bound method LightningGRUDecoder_MFCC_v3.get_neural_embedding of LightningGRUDecoder_MFCC_v3(\n",
       "  (inputLayerNonlinearity): Softsign()\n",
       "  (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "  (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "  (gaussianSmoother): GaussianSmoothing()\n",
       "  (gru_decoder): GRU(8192, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "  (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "  (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "  (ctc_loss): CTCLoss()\n",
       "  (l1oss): L1Loss()\n",
       ")>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neural_encoder.get_neural_embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trained\n"
     ]
    }
   ],
   "source": [
    "print(\"trained\")\n",
    "# break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningPretrainedRNNTDecoder(\n",
       "  (encoder): LightningGRUDecoder_MFCC_v3(\n",
       "    (inputLayerNonlinearity): Softsign()\n",
       "    (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "    (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "    (gaussianSmoother): GaussianSmoothing()\n",
       "    (gru_decoder): GRU(8192, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "    (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "    (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "    (ctc_loss): CTCLoss()\n",
       "    (l1oss): L1Loss()\n",
       "  )\n",
       "  (inputLayerNonlinearity): Softsign()\n",
       "  (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "  (gaussianSmoother): GaussianSmoothing()\n",
       "  (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "  (prediction_network): RNNPredictionNetwork(\n",
       "    (embed): Embedding(41, 1024)\n",
       "    (rnn): LSTM(1024, 1024, batch_first=True, dropout=0.2)\n",
       "    (proj): Linear(in_features=1024, out_features=1024, bias=True)\n",
       "  )\n",
       "  (joiner): RNNTJoiner(\n",
       "    (proj_enc): Linear(in_features=2048, out_features=1024, bias=True)\n",
       "    (fc): Linear(in_features=1024, out_features=41, bias=True)\n",
       "    (activation): ReLU()\n",
       "  )\n",
       "  (rnnt_loss): RNNTLoss()\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda:0\"\n",
    "best_model = model\n",
    "best_model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/32 [00:00<?, ?it/s]/data/XXXXXX/speech_decoding_BCI/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "100%|██████████| 32/32 [00:07<00:00,  4.23it/s]\n"
     ]
    }
   ],
   "source": [
    "batch = next(iter(test_loader))\n",
    "\n",
    "\n",
    "batch.keys()\n",
    "\n",
    "\n",
    "for i in tqdm.trange(len(batch[\"neural_feats\"])):\n",
    "    enc_out = model.infer(batch[\"neural_feats\"].to(device)[i:i+1], batch[\"day\"].to(device)[i:i+1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['DH', 'AH', 'SIL', 'S', 'AH', 'M', 'B', 'ER', 'L', 'IY', 'SIL', 'R', 'IH', 'S', 'AH', 'P', 'S', 'SIL', 'DH', 'AE', 'T', 'SIL', 'F', 'EY', 'V', 'ER', 'Z', 'SIL', 'B', 'AY', 'SIL', 'AH', 'M', 'EH', 'N', 'AH', 'T', 'SIL']\n",
      "['DH', 'AH', 'SIL', 'F', 'AE', 'M', 'AH', 'L', 'IY', 'SIL', 'R', 'IH', 'K', 'W', 'EH', 'S', 'T', 'S', 'SIL', 'DH', 'AE', 'T', 'SIL', 'F', 'L', 'AW', 'ER', 'Z', 'SIL', 'B', 'IY', 'SIL', 'OW', 'M', 'IH', 'T', 'AH', 'D', 'SIL']\n"
     ]
    }
   ],
   "source": [
    "print(idsToPhonemes(enc_out))\n",
    "print(idsToPhonemes(batch[\"phone_seq\"][-1].numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['neural_feats', 'phone_seq', 'neural_time_bins', 'phone_seq_len', 'day', 'sentence', 'audio_file', 'mfcc', 'go_onset', 'speech_label'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cer(preds, targets):\n",
    "    \"\"\"\n",
    "    Computes Character Error Rate (CER) using edit distance.\n",
    "    \"\"\"\n",
    "    total_edit_distance, total_chars = 0, 0\n",
    "    for pred, target in zip(preds, targets):\n",
    "        matcher = SequenceMatcher(a=target, b=pred)\n",
    "        total_edit_distance += matcher.distance()\n",
    "        total_chars += len(target)\n",
    "    return total_edit_distance / total_chars"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_cer_per_sample(preds, targets):\n",
    "    \"\"\"\n",
    "    Computes Character Error Rate (CER) using edit distance per sample.\n",
    "    \"\"\"\n",
    "    cer_sample = []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        matcher = SequenceMatcher(a=target, b=pred)\n",
    "        cer_sample.append(matcher.distance() / len(target))\n",
    "    return cer_sample\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 28/28 [01:46<00:00,  3.79s/it]\n"
     ]
    },
    {
     "ename": "TypeError",
     "evalue": "unsupported format string passed to list.__format__",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[21], line 41\u001b[0m\n\u001b[1;32m     38\u001b[0m cer \u001b[38;5;241m=\u001b[39m compute_cer_per_sample(all_pred_phonemes, all_true_phonemes)\n\u001b[1;32m     40\u001b[0m \u001b[38;5;66;03m# Print final results\u001b[39;00m\n\u001b[0;32m---> 41\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCharacter Error Rate (CER): \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mcer\u001b[38;5;250m \u001b[39m\u001b[38;5;241m*\u001b[39m\u001b[38;5;250m \u001b[39m\u001b[38;5;241m100\u001b[39m\u001b[38;5;132;01m:\u001b[39;00m\u001b[38;5;124m.2f\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m%\u001b[39m\u001b[38;5;124m\"\u001b[39m) \n",
      "\u001b[0;31mTypeError\u001b[0m: unsupported format string passed to list.__format__"
     ]
    }
   ],
   "source": [
    "device = \"cuda:0\"\n",
    "model.eval()\n",
    "model.to(device)\n",
    "\n",
    "# Run inference on test set\n",
    "all_pred_phonemes = []\n",
    "all_true_phonemes = []\n",
    "all_pred_texts = []\n",
    "all_true_texts = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        \n",
    "        X = batch[\"neural_feats\"]\n",
    "        y = batch[\"phone_seq\"]\n",
    "        days = batch[\"day\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        transcriptions = batch[\"sentence\"] \n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "\n",
    "        for i in range(len(X)):\n",
    "            enc_out = model.infer(X[i:i+1], days[i:i+1])\n",
    "            all_pred_phonemes.append(enc_out)\n",
    "            all_true_phonemes.append([y[i][:y_len[i]].cpu().numpy()])\n",
    "\n",
    "            all_true_texts.append(transcriptions[i])\n",
    "\n",
    "all_true_phonemes = [target[0].tolist() for target in all_true_phonemes]\n",
    "# Compute CER & WER\n",
    "cer = compute_cer_per_sample(all_pred_phonemes, all_true_phonemes)\n",
    "\n",
    "# Print final results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 28/28 [01:47<00:00,  3.84s/it]\n"
     ]
    }
   ],
   "source": [
    "## predit all teh test set \n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"]\n",
    "        y = batch[\"phone_seq\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "\n",
    "        pred = []\n",
    "        for i in range(len(X)):\n",
    "            enc_out = model.infer(X[i:i+1], days[i:i+1])\n",
    "            pred.append(enc_out)\n",
    "        # pred = torch.cat(pred, dim=0)\n",
    "\n",
    "        # pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        for i in range(len(pred)):\n",
    "            decodedSeq = torch.tensor(pred[i])\n",
    "            # Remove padding (if any)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        # pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pred)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Character Error Rate (CER): 20.25%\n"
     ]
    }
   ],
   "source": [
    "print(f\"Character Error Rate (CER): {np.mean(cer_list) * 100:.2f}%\") "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['D', 'UW', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'W', 'ER', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Phonemes: ['D', 'UW', 'SIL', 'Y', 'UW', 'SIL', 'HH', 'AE', 'V', 'SIL', 'Y', 'AO', 'R', 'SIL', 'B', 'AE', 'G', 'SIL']\n",
      "True Sentence: Do you have your bag?\n"
     ]
    }
   ],
   "source": [
    "idx = 121\n",
    "print(\"Predicted Phonemes:\", idsToPhonemes(all_pred_phonemes[idx]))\n",
    "print(\"True Phonemes:\", idsToPhonemes(all_true_phonemes[idx]))\n",
    "print(\"True Sentence:\", all_true_texts[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \n",
    "\n",
    "    accs= []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        \n",
    "        #truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "\n",
    "\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        equal_inference = (pred == target)\n",
    "        acc = np.sum(equal_inference)/ len(pred)\n",
    "        accs.append(acc)\n",
    "\n",
    "    return np.mean(accs)\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.5616673276836803\n"
     ]
    }
   ],
   "source": [
    "overall_acc = compute_accuracy(pred_phonemes, true_phonemes)\n",
    "print(\"overall_acc\", overall_acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of Accuracy per day 0.2039772518693356 0.6813619459950637\n"
     ]
    }
   ],
   "source": [
    "day_indices_flat = day_indices\n",
    "\n",
    "#compute accuracy per day by selecting indices of the same day\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    acc = compute_accuracy([pred_phonemes[idx] for idx in indices], [true_phonemes[idx] for idx in indices])\n",
    "    day_accs.append(acc)\n",
    "\n",
    "day_accs\n",
    "print(\"Range of Accuracy per day\", min(day_accs), max(day_accs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of CER per day 0.13098882598752953 0.4271792961959867\n"
     ]
    }
   ],
   "source": [
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(np.array(cer_list)[indices].mean())\n",
    "\n",
    "cer_list_per_day\n",
    "\n",
    "print(\"Range of CER per day\", min(cer_list_per_day), max(cer_list_per_day))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average lenght diff: 0.634090909090909 +- 2.027699455699643\n"
     ]
    }
   ],
   "source": [
    "diffs = []\n",
    "for i in range(len(true_phonemes)):\n",
    "    true = true_phonemes[i]\n",
    "    pred = pred_phonemes[i]\n",
    "    diffs.append(np.array(len(true)) - np.array(len(pred)))\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Range of diff lenghts per day: -0.2 - 1.75\n"
     ]
    }
   ],
   "source": [
    "## compute diff lenghts per day\n",
    "diffs_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    diffs_per_day.append(np.array(diffs)[indices].mean())\n",
    "\n",
    "diffs_per_day\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices_flat,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(results_dir, \"results.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>True Phonemes</th>\n",
       "      <th>Predicted Phonemes</th>\n",
       "      <th>True Sentence</th>\n",
       "      <th>Day Index</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...</td>\n",
       "      <td>[G, EH, S, SIL, P, OY, SH, AH, N, S, SIL, EH, ...</td>\n",
       "      <td>Rich purchased several signed lithographs.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.490566</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...</td>\n",
       "      <td>[DH, IH, S, W, EH, S, TH, SIL, R, IH, K, AH, N...</td>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>0</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>256</th>\n",
       "      <td>[DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...</td>\n",
       "      <td>[DH, EY, SIL, D, OW, N, T, SIL, EH, V, ER, IY,...</td>\n",
       "      <td>They don't even check my social security number.</td>\n",
       "      <td>8</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>64</th>\n",
       "      <td>[D, UW, SIL, Y, UW, SIL, HH, IY, R, SIL, DH, A...</td>\n",
       "      <td>[D, UW, SIL, Y, UW, SIL, HH, IY, R, SIL, DH, A...</td>\n",
       "      <td>Do you hear the sleigh bells ringing?</td>\n",
       "      <td>3</td>\n",
       "      <td>0.571429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>416</th>\n",
       "      <td>[F, AO, R, M, ER, SIL, EH, M, P, L, OY, ER, Z,...</td>\n",
       "      <td>[P, R, AH, M, EH, ZH, ER, SIL, B, IH, K, AO, Z...</td>\n",
       "      <td>Former employers.</td>\n",
       "      <td>12</td>\n",
       "      <td>0.785714</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         True Phonemes  \\\n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...   \n",
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...   \n",
       "256  [DH, EY, SIL, D, OW, N, T, SIL, IY, V, IH, N, ...   \n",
       "64   [D, UW, SIL, Y, UW, SIL, HH, IY, R, SIL, DH, A...   \n",
       "416  [F, AO, R, M, ER, SIL, EH, M, P, L, OY, ER, Z,...   \n",
       "\n",
       "                                    Predicted Phonemes  \\\n",
       "1    [G, EH, S, SIL, P, OY, SH, AH, N, S, SIL, EH, ...   \n",
       "0    [DH, IH, S, W, EH, S, TH, SIL, R, IH, K, AH, N...   \n",
       "256  [DH, EY, SIL, D, OW, N, T, SIL, EH, V, ER, IY,...   \n",
       "64   [D, UW, SIL, Y, UW, SIL, HH, IY, R, SIL, DH, A...   \n",
       "416  [P, R, AH, M, EH, ZH, ER, SIL, B, IH, K, AO, Z...   \n",
       "\n",
       "                                        True Sentence  Day Index       CER  \n",
       "1          Rich purchased several signed lithographs.          0  0.490566  \n",
       "0                             Theocracy reconsidered.          0  0.500000  \n",
       "256  They don't even check my social security number.          8  0.500000  \n",
       "64              Do you hear the sleigh bells ringing?          3  0.571429  \n",
       "416                                 Former employers.         12  0.785714  "
      ]
     },
     "execution_count": 57,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [],
   "source": [
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\",\n",
    "    'average CER': [np.mean(cer_list)],\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "# pred_logits[0][0].__len__()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Old stuff\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def split_phoneme_sequence(sequence, sil_token='SIL'):\n",
    "    \"\"\"\n",
    "    Splits a phoneme sequence into potential words based on the 'SIL' token.\n",
    "    \n",
    "    Args:\n",
    "        sequence (list): List of phoneme tokens.\n",
    "        sil_token (str): Token representing silence (used as word boundary).\n",
    "    \n",
    "    Returns:\n",
    "        list: List of phoneme sequences representing words.\n",
    "    \"\"\"\n",
    "    # Initialize the list to hold split sequences\n",
    "    words = []\n",
    "    current_word = []\n",
    "    \n",
    "    # Iterate through the phoneme sequence\n",
    "    for phoneme in sequence:\n",
    "        if phoneme == sil_token:\n",
    "            if current_word:\n",
    "                words.append(current_word)\n",
    "                current_word = []\n",
    "        else:\n",
    "            current_word.append(phoneme)\n",
    "    \n",
    "    # Append the last word if the sequence doesn't end with SIL\n",
    "    if current_word:\n",
    "        words.append(current_word)\n",
    "    \n",
    "    return words\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Try with direct CMU word decoding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['R', 'IY', 'K', 'AH', 'N', 'S', 'IH', 'D', 'ER', 'D']]\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def adapt_cmu_dict(cmu_dict):\n",
    "    \"\"\"\n",
    "    Adapts the CMU Pronouncing Dictionary to match the reduced phoneme set\n",
    "    by removing stress markers.\n",
    "    \n",
    "    Returns:\n",
    "        dict: A dictionary mapping words to simplified phoneme sequences.\n",
    "    \"\"\"\n",
    "    adapted_dict = {}\n",
    "    \n",
    "    for word, phoneme_lists in cmu_dict.items():\n",
    "        simplified_phonemes = []\n",
    "        for phoneme_seq in phoneme_lists:\n",
    "            # Remove stress markers (e.g., AH0 -> AH)\n",
    "            simplified_seq = [re.sub(r'[0-9]', '', p) for p in phoneme_seq]\n",
    "            simplified_phonemes.append(simplified_seq)\n",
    "        \n",
    "        # Store the simplified sequences\n",
    "        adapted_dict[word.lower()] = simplified_phonemes\n",
    "    \n",
    "    return adapted_dict\n",
    "\n",
    "# Adapted CMU dictionary\n",
    "adapted_cmu = adapt_cmu_dict(cmu_dict)\n",
    "\n",
    "# Example: Check how 'reconsidered' looks after adaptation\n",
    "print(adapted_cmu.get('reconsidered'))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_closest_word(decoded_phonemes, adapted_cmu, max_distance=3):\n",
    "    \"\"\"\n",
    "    Finds the closest word from the adapted CMU dictionary using Levenshtein distance.\n",
    "    \n",
    "    Args:\n",
    "        decoded_phonemes (list): Noisy phoneme sequence.\n",
    "        adapted_cmu (dict): Simplified CMU dictionary.\n",
    "        max_distance (int): Maximum allowed phoneme distance for a match.\n",
    "    \n",
    "    Returns:\n",
    "        str: Closest matching word or None if no match found.\n",
    "    \"\"\"\n",
    "    closest_word = \"<UNK>\"\n",
    "    min_distance = float('inf')\n",
    "    \n",
    "    for word, phoneme_lists in adapted_cmu.items():\n",
    "        for phoneme_seq in phoneme_lists:\n",
    "            dist = levenshtein_distance(' '.join(decoded_phonemes), ' '.join(phoneme_seq))\n",
    "            if dist < min_distance and dist <= max_distance:\n",
    "                min_distance = dist\n",
    "                closest_word = word\n",
    "    \n",
    "    return closest_word\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 880/880 [01:25<00:00, 10.24it/s]\n"
     ]
    }
   ],
   "source": [
    "all_pred_texts_fast = []\n",
    "\n",
    "for idx in tqdm.trange(len(all_pred_phonemes)):\n",
    "    phoneme_seq = idsToPhonemes(all_pred_phonemes[idx])\n",
    "    phoneme_seq_splitted = split_phoneme_sequence(phoneme_seq)\n",
    "    sentence = []\n",
    "    for word_phonemes in phoneme_seq_splitted:\n",
    "        try:\n",
    "            word = [word for word, pron in adapted_cmu.items() if word_phonemes in pron][0]\n",
    "        except IndexError:\n",
    "            word = find_closest_word(word_phonemes, adapted_cmu)\n",
    "\n",
    "        sentence.append(word)\n",
    "    all_pred_texts_fast.append(\" \".join(sentence))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Function to compute Word Error Rate (WER)\n",
    "def compute_wer(pred_texts, ref_texts):\n",
    "    \"\"\"\n",
    "    Computes Word Error Rate (WER) using jiwer.\n",
    "    \"\"\"\n",
    "    wer = jiwer.wer(ref_texts, pred_texts)\n",
    "    return wer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER fast: 0.6966721222040371\n"
     ]
    }
   ],
   "source": [
    "fast_wer = compute_wer(all_pred_texts_fast, all_true_texts)\n",
    "print(\"WER fast:\" , fast_wer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER: 0.6967\n",
      "BLEU: 10.2671\n",
      "ROUGE-1: 0.4620\n",
      "ROUGE-2: 0.2405\n",
      "ROUGE-L: 0.4610\n",
      "METEOR: 0.3062\n",
      "BERTScore_Precision: 0.0666\n",
      "BERTScore_Recall: 0.3210\n",
      "BERTScore_F1: 0.1910\n"
     ]
    }
   ],
   "source": [
    "import jiwer  # For WER\n",
    "import sacrebleu  # For BLEU\n",
    "from rouge_score import rouge_scorer  # For ROUGE\n",
    "from nltk.translate.meteor_score import meteor_score  # For METEOR\n",
    "import bert_score  # For BERTScore\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(text_transcriptions, gpt_decoded):\n",
    "    \"\"\"\n",
    "    Compute various NLP evaluation metrics for text generation.\n",
    "\n",
    "    Args:\n",
    "        text_transcriptions (list): List of ground-truth reference sentences.\n",
    "        gpt_decoded (list): List of model-generated sentences.\n",
    "\n",
    "    Returns:\n",
    "        dict: Dictionary containing all computed metrics.\n",
    "    \"\"\"\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    # WER (Word Error Rate)\n",
    "    wer = jiwer.wer(text_transcriptions, gpt_decoded)\n",
    "    results[\"WER\"] = wer\n",
    "\n",
    "    # BLEU Score\n",
    "    bleu = sacrebleu.corpus_bleu(gpt_decoded, [text_transcriptions]).score\n",
    "    results[\"BLEU\"] = bleu\n",
    "\n",
    "    # ROUGE Scores\n",
    "    rouge = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)\n",
    "    rouge_scores = [rouge.score(ref, pred) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"ROUGE-1\"] = np.mean([score[\"rouge1\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-2\"] = np.mean([score[\"rouge2\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-L\"] = np.mean([score[\"rougeL\"].fmeasure for score in rouge_scores])\n",
    "\n",
    "    ##METEOR\n",
    "    tokenized_references = [ref.split() for ref in text_transcriptions]  # Tokenize reference sentences\n",
    "    tokenized_hypotheses = [pred.split() for pred in gpt_decoded]  # Tokenize predicted sentences\n",
    "\n",
    "    meteor_scores = [meteor_score([ref], pred) for ref, pred in zip(tokenized_references, tokenized_hypotheses)]\n",
    "    results[\"METEOR\"] = np.mean(meteor_scores)\n",
    "    # BERTScore (Semantic Similarity)\n",
    "    P, R, F1 = bert_score.score(gpt_decoded, text_transcriptions, lang=\"en\", rescale_with_baseline=True)\n",
    "    results[\"BERTScore_Precision\"] = P.mean().item()\n",
    "    results[\"BERTScore_Recall\"] = R.mean().item()\n",
    "    results[\"BERTScore_F1\"] = F1.mean().item()\n",
    "\n",
    "    ## save also all values without recomputing when possible\n",
    "    results[\"METEOR_scores\"] = meteor_scores\n",
    "    results[\"ROUGE_scores\"] = rouge_scores\n",
    "\n",
    "    results[\"WER_scores\"] = [jiwer.wer([ref], [pred]) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"BERTScore_F1_scores\"] = F1.cpu().numpy().tolist()\n",
    "    return results\n",
    "\n",
    "metrics = compute_metrics(all_true_texts, all_pred_texts_fast)\n",
    "for metric, score in metrics.items():\n",
    "    if \"scores\" not in metric:\n",
    "        print(f\"{metric}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create a dataframe with the results, with target and pred phonemes and sentences\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "#map ids to phonenems\n",
    "all_true_phonemes_mapped = [idsToPhonemes(target) for target in all_true_phonemes]\n",
    "all_pred_phonemes_mapped = [idsToPhonemes(pred) for pred in all_pred_phonemes]\n",
    "\n",
    "results_df = pd.DataFrame({\n",
    "    \"target_phonemes\": all_true_phonemes_mapped,\n",
    "    \"pred_phonemes\": all_pred_phonemes_mapped,\n",
    "    \"target_sentence\": all_true_texts,\n",
    "    \"pred_sentence\": all_pred_texts_fast,\n",
    "})\n",
    "\n",
    "results_df[\"WER_scores\"] = metrics[\"WER_scores\"]\n",
    "results_df[\"METEOR_scores\"] = metrics[\"METEOR_scores\"]\n",
    "results_df[\"ROUGE_scores\"] = metrics[\"ROUGE_scores\"]\n",
    "results_df[\"BERTScore_F1_scores\"] = metrics[\"BERTScore_F1_scores\"]\n",
    "results_df[\"CER\"] = cer_per_sample\n",
    "\n",
    "\n",
    "overall_metrics = {k:v for k,v in metrics.items() if \"scores\" not in k}\n",
    "\n",
    "metrics_df = pd.DataFrame(overall_metrics, index=[0])\n",
    "\n",
    "results_df.to_csv(\"results/RNNT_results_cmu.csv\", index=False)\n",
    "metrics_df.to_csv(\"results/RNNT_metrics_cmu.csv\", index=False)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_phonemes</th>\n",
       "      <th>pred_phonemes</th>\n",
       "      <th>target_sentence</th>\n",
       "      <th>pred_sentence</th>\n",
       "      <th>WER_scores</th>\n",
       "      <th>METEOR_scores</th>\n",
       "      <th>ROUGE_scores</th>\n",
       "      <th>BERTScore_F1_scores</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>843</th>\n",
       "      <td>[Y, UW, SIL, K, UH, D, SIL, D, UW, SIL, IH, T,...</td>\n",
       "      <td>[Y, UW, SIL, K, UH, D, SIL, D, UW, SIL, IH, T,...</td>\n",
       "      <td>You could do it.</td>\n",
       "      <td>ewe could deux it</td>\n",
       "      <td>0.750000</td>\n",
       "      <td>0.125000</td>\n",
       "      <td>{'rouge1': (0.5, 0.5, 0.5), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.259605</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>835</th>\n",
       "      <td>[W, EH, L, SIL, DH, AE, T, S, SIL, AH, B, AW, ...</td>\n",
       "      <td>[W, EH, L, SIL, DH, AE, T, S, SIL, AH, B, AW, ...</td>\n",
       "      <td>Well that's about it really.</td>\n",
       "      <td>well that's about it really</td>\n",
       "      <td>0.400000</td>\n",
       "      <td>0.793750</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>0.917149</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>446</th>\n",
       "      <td>[DH, AH, SIL, IH, K, S, P, IH, R, IY, AH, N, S...</td>\n",
       "      <td>[DH, AH, SIL, IH, K, S, P, IH, R, IY, AH, N, S...</td>\n",
       "      <td>The experience.</td>\n",
       "      <td>the experience</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.250000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>0.871264</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>449</th>\n",
       "      <td>[AY, SIL, TH, IH, NG, K, SIL, IH, T, SIL, IH, ...</td>\n",
       "      <td>[AY, SIL, TH, IH, NG, K, SIL, IH, T, SIL, IH, ...</td>\n",
       "      <td>I think it is getting worse.</td>\n",
       "      <td>ai think it is getting worse</td>\n",
       "      <td>0.333333</td>\n",
       "      <td>0.661458</td>\n",
       "      <td>{'rouge1': (0.8333333333333334, 0.833333333333...</td>\n",
       "      <td>0.780689</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>80</th>\n",
       "      <td>[IH, T, S, SIL, R, IH, L, IY, SIL, HH, AA, R, ...</td>\n",
       "      <td>[IH, T, S, SIL, R, IH, L, IY, SIL, HH, AA, R, ...</td>\n",
       "      <td>It's really hard to find something that works.</td>\n",
       "      <td>it's really hard tew find something that work's</td>\n",
       "      <td>0.375000</td>\n",
       "      <td>0.736111</td>\n",
       "      <td>{'rouge1': (0.8, 0.8888888888888888, 0.8421052...</td>\n",
       "      <td>0.684343</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>366</th>\n",
       "      <td>[V, EH, R, IY, SIL, W, EH, L, SIL, P, ER, S, W...</td>\n",
       "      <td>[W, IY, SIL, W, IH, L, SIL, S, P, AO, R, T, ER...</td>\n",
       "      <td>Very well persuaded.</td>\n",
       "      <td>oui we'll snorters just deux deux</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.077393</td>\n",
       "      <td>0.944444</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>342</th>\n",
       "      <td>[Y, UW, SIL, K, AE, N, T, SIL, B, L, EY, M, SI...</td>\n",
       "      <td>[Y, UW, SIL, K, AE, N, SIL, B, IH, L, D, SIL, ...</td>\n",
       "      <td>You can't blame them.</td>\n",
       "      <td>ewe caen bild maerz deux c deux c</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.272716</td>\n",
       "      <td>1.058824</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>369</th>\n",
       "      <td>[AY, M, SIL, W, EH, R, IH, NG, SIL, SH, AO, R,...</td>\n",
       "      <td>[S, AH, M, SIL, R, IY, Z, AH, N, SIL, CH, OY, ...</td>\n",
       "      <td>I'm wearing shorts.</td>\n",
       "      <td>some reason choice just deux</td>\n",
       "      <td>1.666667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.039787</td>\n",
       "      <td>1.133333</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>332</th>\n",
       "      <td>[N, AA, T, SIL, L, EY, T, L, IY, SIL, DH, OW, ...</td>\n",
       "      <td>[N, AA, T, SIL, TH, R, IY, SIL, DH, OW, SIL, D...</td>\n",
       "      <td>Not lately though.</td>\n",
       "      <td>knot three tho deux deux deux deux</td>\n",
       "      <td>2.333333</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.137574</td>\n",
       "      <td>1.230769</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>376</th>\n",
       "      <td>[AY, SIL, HH, AE, V, SIL, S, AH, M, SIL, G, EY...</td>\n",
       "      <td>[AY, SIL, HH, AE, V, SIL, S, AH, M, SIL, T, AY...</td>\n",
       "      <td>I have some games.</td>\n",
       "      <td>ai halve some time's just deux just did</td>\n",
       "      <td>1.750000</td>\n",
       "      <td>0.113636</td>\n",
       "      <td>{'rouge1': (0.1111111111111111, 0.25, 0.153846...</td>\n",
       "      <td>0.125846</td>\n",
       "      <td>1.266667</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>880 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       target_phonemes  \\\n",
       "843  [Y, UW, SIL, K, UH, D, SIL, D, UW, SIL, IH, T,...   \n",
       "835  [W, EH, L, SIL, DH, AE, T, S, SIL, AH, B, AW, ...   \n",
       "446  [DH, AH, SIL, IH, K, S, P, IH, R, IY, AH, N, S...   \n",
       "449  [AY, SIL, TH, IH, NG, K, SIL, IH, T, SIL, IH, ...   \n",
       "80   [IH, T, S, SIL, R, IH, L, IY, SIL, HH, AA, R, ...   \n",
       "..                                                 ...   \n",
       "366  [V, EH, R, IY, SIL, W, EH, L, SIL, P, ER, S, W...   \n",
       "342  [Y, UW, SIL, K, AE, N, T, SIL, B, L, EY, M, SI...   \n",
       "369  [AY, M, SIL, W, EH, R, IH, NG, SIL, SH, AO, R,...   \n",
       "332  [N, AA, T, SIL, L, EY, T, L, IY, SIL, DH, OW, ...   \n",
       "376  [AY, SIL, HH, AE, V, SIL, S, AH, M, SIL, G, EY...   \n",
       "\n",
       "                                         pred_phonemes  \\\n",
       "843  [Y, UW, SIL, K, UH, D, SIL, D, UW, SIL, IH, T,...   \n",
       "835  [W, EH, L, SIL, DH, AE, T, S, SIL, AH, B, AW, ...   \n",
       "446  [DH, AH, SIL, IH, K, S, P, IH, R, IY, AH, N, S...   \n",
       "449  [AY, SIL, TH, IH, NG, K, SIL, IH, T, SIL, IH, ...   \n",
       "80   [IH, T, S, SIL, R, IH, L, IY, SIL, HH, AA, R, ...   \n",
       "..                                                 ...   \n",
       "366  [W, IY, SIL, W, IH, L, SIL, S, P, AO, R, T, ER...   \n",
       "342  [Y, UW, SIL, K, AE, N, SIL, B, IH, L, D, SIL, ...   \n",
       "369  [S, AH, M, SIL, R, IY, Z, AH, N, SIL, CH, OY, ...   \n",
       "332  [N, AA, T, SIL, TH, R, IY, SIL, DH, OW, SIL, D...   \n",
       "376  [AY, SIL, HH, AE, V, SIL, S, AH, M, SIL, T, AY...   \n",
       "\n",
       "                                    target_sentence  \\\n",
       "843                                You could do it.   \n",
       "835                    Well that's about it really.   \n",
       "446                                 The experience.   \n",
       "449                    I think it is getting worse.   \n",
       "80   It's really hard to find something that works.   \n",
       "..                                              ...   \n",
       "366                            Very well persuaded.   \n",
       "342                           You can't blame them.   \n",
       "369                             I'm wearing shorts.   \n",
       "332                              Not lately though.   \n",
       "376                              I have some games.   \n",
       "\n",
       "                                       pred_sentence  WER_scores  \\\n",
       "843                                ewe could deux it    0.750000   \n",
       "835                      well that's about it really    0.400000   \n",
       "446                                   the experience    1.000000   \n",
       "449                     ai think it is getting worse    0.333333   \n",
       "80   it's really hard tew find something that work's    0.375000   \n",
       "..                                               ...         ...   \n",
       "366                oui we'll snorters just deux deux    2.000000   \n",
       "342                ewe caen bild maerz deux c deux c    2.000000   \n",
       "369                     some reason choice just deux    1.666667   \n",
       "332               knot three tho deux deux deux deux    2.333333   \n",
       "376          ai halve some time's just deux just did    1.750000   \n",
       "\n",
       "     METEOR_scores                                       ROUGE_scores  \\\n",
       "843       0.125000  {'rouge1': (0.5, 0.5, 0.5), 'rouge2': (0.0, 0....   \n",
       "835       0.793750  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....   \n",
       "446       0.250000  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....   \n",
       "449       0.661458  {'rouge1': (0.8333333333333334, 0.833333333333...   \n",
       "80        0.736111  {'rouge1': (0.8, 0.8888888888888888, 0.8421052...   \n",
       "..             ...                                                ...   \n",
       "366       0.000000  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....   \n",
       "342       0.000000  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....   \n",
       "369       0.000000  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....   \n",
       "332       0.000000  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....   \n",
       "376       0.113636  {'rouge1': (0.1111111111111111, 0.25, 0.153846...   \n",
       "\n",
       "     BERTScore_F1_scores       CER  \n",
       "843             0.259605  0.000000  \n",
       "835             0.917149  0.000000  \n",
       "446             0.871264  0.000000  \n",
       "449             0.780689  0.000000  \n",
       "80              0.684343  0.000000  \n",
       "..                   ...       ...  \n",
       "366            -0.077393  0.944444  \n",
       "342            -0.272716  1.058824  \n",
       "369             0.039787  1.133333  \n",
       "332            -0.137574  1.230769  \n",
       "376             0.125846  1.266667  \n",
       "\n",
       "[880 rows x 9 columns]"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Try with GPT4 Api"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reconstructed Sentence: Hello world.\n"
     ]
    }
   ],
   "source": [
    "from openai import OpenAI\n",
    "from config import OPENAI_API_KEY\n",
    "\n",
    "# Set your OpenAI API key here\n",
    "os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY\n",
    "client = OpenAI(\n",
    "  api_key=os.environ['OPENAI_API_KEY'],  # this is also the default, it can be omitted\n",
    ")\n",
    "\n",
    "def phonemes_to_sentence(phoneme_sequence):\n",
    "    \"\"\"\n",
    "    Uses OpenAI's GPT API to reconstruct a sentence from a sequence of phonemes.\n",
    "    \n",
    "    Args:\n",
    "        phoneme_sequence (list): List of phonemes, with 'SIL' representing silence.\n",
    "    \n",
    "    Returns:\n",
    "        str: Reconstructed sentence.\n",
    "    \"\"\"\n",
    "    \n",
    "    # Prepare the phoneme sequence for the prompt\n",
    "    phoneme_string = ' '.join(phoneme_sequence)\n",
    "    \n",
    "    # Construct the system and user messages for the chat completion\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": \"You are an expert in speech recognition and language modeling.\"},\n",
    "        {\n",
    "            \"role\": \"user\",\n",
    "            \"content\": (\n",
    "                \"Given the following sequence of phonemes, reconstruct the most probable sentence. \"\n",
    "                \"The phoneme sequence may contain some noise or errors, and 'SIL' represents silence between words. \"\n",
    "                \"Use language statistics and context to infer the most natural sentence that makes sense.\\n\\n\"\n",
    "                f\"Phoneme sequence: {phoneme_string}\\n\\n\"\n",
    "                \"Reconstructed sentence:\"\n",
    "            )\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    # Call the OpenAI Chat Completion API\n",
    "    completion = client.chat.completions.create(\n",
    "        model=\"gpt-4\",  # Use 'gpt-3.5-turbo' if needed\n",
    "        messages=messages,\n",
    "        max_tokens=100,\n",
    "        temperature=0.7,  # Add creativity to handle noisy phonemes\n",
    "        n=1  # Number of responses\n",
    "    )\n",
    "\n",
    "    # Extract the reconstructed sentence\n",
    "    reconstructed_sentence = completion.choices[0].message.content.strip()\n",
    "    \n",
    "    # Print token usage info (optional)\n",
    "    # print(dict(completion).get('usage'))\n",
    "    \n",
    "    # Optional: Pretty print the full API response for debugging\n",
    "    # print(completion.model_dump_json(indent=2))\n",
    "\n",
    "    return reconstructed_sentence\n",
    "\n",
    "\n",
    "# Example usage\n",
    "phoneme_sequence = ['HH', 'AH', 'L', 'OW', 'SIL', 'W', 'ER', 'L', 'D']  # Example phonemes for \"Hello world\"\n",
    "reconstructed_sentence = phonemes_to_sentence(phoneme_sequence)\n",
    "\n",
    "print(\"Reconstructed Sentence:\", reconstructed_sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 880/880 [18:47<00:00,  1.28s/it]\n"
     ]
    }
   ],
   "source": [
    "gpt4_decoded = []\n",
    "for phoneme_seq in tqdm.tqdm(all_pred_phonemes):\n",
    "    sentence = phonemes_to_sentence(idsToPhonemes(phoneme_seq))\n",
    "    gpt4_decoded.append(sentence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER: 0.5679\n",
      "BLEU: 32.8568\n",
      "ROUGE-1: 0.6579\n",
      "ROUGE-2: 0.4983\n",
      "ROUGE-L: 0.6572\n",
      "METEOR: 0.4352\n",
      "BERTScore_Precision: 0.3651\n",
      "BERTScore_Recall: 0.4738\n",
      "BERTScore_F1: 0.4193\n"
     ]
    }
   ],
   "source": [
    "import jiwer  # For WER\n",
    "import sacrebleu  # For BLEU\n",
    "from rouge_score import rouge_scorer  # For ROUGE\n",
    "from nltk.translate.meteor_score import meteor_score  # For METEOR\n",
    "import bert_score  # For BERTScore\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(text_transcriptions, gpt_decoded):\n",
    "    \"\"\"\n",
    "    Compute various NLP evaluation metrics for text generation.\n",
    "\n",
    "    Args:\n",
    "        text_transcriptions (list): List of ground-truth reference sentences.\n",
    "        gpt_decoded (list): List of model-generated sentences.\n",
    "\n",
    "    Returns:\n",
    "        dict: Dictionary containing all computed metrics.\n",
    "    \"\"\"\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    # WER (Word Error Rate)\n",
    "    wer = jiwer.wer(text_transcriptions, gpt_decoded)\n",
    "    results[\"WER\"] = wer\n",
    "\n",
    "    # BLEU Score\n",
    "    bleu = sacrebleu.corpus_bleu(gpt_decoded, [text_transcriptions]).score\n",
    "    results[\"BLEU\"] = bleu\n",
    "\n",
    "    # ROUGE Scores\n",
    "    rouge = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)\n",
    "    rouge_scores = [rouge.score(ref, pred) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"ROUGE-1\"] = np.mean([score[\"rouge1\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-2\"] = np.mean([score[\"rouge2\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-L\"] = np.mean([score[\"rougeL\"].fmeasure for score in rouge_scores])\n",
    "\n",
    "    ##METEOR\n",
    "    tokenized_references = [ref.split() for ref in text_transcriptions]  # Tokenize reference sentences\n",
    "    tokenized_hypotheses = [pred.split() for pred in gpt_decoded]  # Tokenize predicted sentences\n",
    "\n",
    "    meteor_scores = [meteor_score([ref], pred) for ref, pred in zip(tokenized_references, tokenized_hypotheses)]\n",
    "    results[\"METEOR\"] = np.mean(meteor_scores)\n",
    "    # BERTScore (Semantic Similarity)\n",
    "    P, R, F1 = bert_score.score(gpt_decoded, text_transcriptions, lang=\"en\", rescale_with_baseline=True)\n",
    "    results[\"BERTScore_Precision\"] = P.mean().item()\n",
    "    results[\"BERTScore_Recall\"] = R.mean().item()\n",
    "    results[\"BERTScore_F1\"] = F1.mean().item()\n",
    "\n",
    "    ## save also all values without recomputing when possible\n",
    "    results[\"METEOR_scores\"] = meteor_scores\n",
    "    results[\"ROUGE_scores\"] = rouge_scores\n",
    "\n",
    "    results[\"WER_scores\"] = [jiwer.wer([ref], [pred]) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"BERTScore_F1_scores\"] = F1.cpu().numpy().tolist()\n",
    "    return results\n",
    "\n",
    "metrics_gpt4 = compute_metrics(all_true_texts, gpt4_decoded)\n",
    "for metric, score in metrics_gpt4.items():\n",
    "    if \"scores\" not in metric:\n",
    "        print(f\"{metric}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## create a dataframe with the results, with target and pred phonemes and sentences\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "results_df_gpt4 = pd.DataFrame({\n",
    "    \"target_phonemes\": all_true_phonemes_mapped,\n",
    "    \"pred_phonemes\": all_pred_phonemes_mapped,\n",
    "    \"target_sentence\": all_true_texts,\n",
    "    \"pred_sentence\": gpt4_decoded,\n",
    "})\n",
    "\n",
    "results_df_gpt4[\"WER_scores\"] = metrics[\"WER_scores\"]\n",
    "results_df_gpt4[\"METEOR_scores\"] = metrics[\"METEOR_scores\"]\n",
    "results_df_gpt4[\"ROUGE_scores\"] = metrics[\"ROUGE_scores\"]\n",
    "results_df_gpt4[\"BERTScore_F1_scores\"] = metrics[\"BERTScore_F1_scores\"]\n",
    "results_df_gpt4[\"CER\"] = cer_per_sample\n",
    "\n",
    "\n",
    "overall_metrics = {k:v for k,v in metrics.items() if \"scores\" not in k}\n",
    "\n",
    "metrics_df_gpt4 = pd.DataFrame(overall_metrics, index=[0])\n",
    "\n",
    "results_df_gpt4.to_csv(\"results/RNNT_results_gpt4.csv\", index=False)\n",
    "metrics_df_gpt4.to_csv(\"results/RNNT_metrics_gpt4.csv\", index=False)\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_phonemes</th>\n",
       "      <th>pred_phonemes</th>\n",
       "      <th>target_sentence</th>\n",
       "      <th>pred_sentence</th>\n",
       "      <th>WER_scores</th>\n",
       "      <th>METEOR_scores</th>\n",
       "      <th>ROUGE_scores</th>\n",
       "      <th>BERTScore_F1_scores</th>\n",
       "      <th>CER</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...</td>\n",
       "      <td>[DH, EH, R, K, S, ER, Z, SIL, R, IY, K, AH, N,...</td>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>The exercise is reconsidered.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.5, 0.5, 0.5), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.081067</td>\n",
       "      <td>0.400000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...</td>\n",
       "      <td>[G, EH, S, SIL, P, ER, CH, AH, N, S, SIL, EH, ...</td>\n",
       "      <td>Rich purchased several signed lithographs.</td>\n",
       "      <td>\"Guess perchance ever sound list off.\"</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.076523</td>\n",
       "      <td>0.484848</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[S, OW, SIL, R, UW, L, Z, SIL, W, IY, SIL, M, ...</td>\n",
       "      <td>[S, OW, SIL, R, IH, L, Z, SIL, M, IY, SIL, M, ...</td>\n",
       "      <td>So rules we made, in unabashed collusion.</td>\n",
       "      <td>\"So, reels me made in an adverse conclusion.\"</td>\n",
       "      <td>0.857143</td>\n",
       "      <td>0.071429</td>\n",
       "      <td>{'rouge1': (0.375, 0.42857142857142855, 0.3999...</td>\n",
       "      <td>0.184476</td>\n",
       "      <td>0.147059</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[L, AO, R, IY, Z, SIL, K, AA, S, T, UW, M, SIL...</td>\n",
       "      <td>[T, R, OY, IH, NG, SIL, K, AW, M, SIL, N, IY, ...</td>\n",
       "      <td>Lori's costume needed black gloves to be compl...</td>\n",
       "      <td>\"Trying calm needed black lives so be humbly t...</td>\n",
       "      <td>0.888889</td>\n",
       "      <td>0.111111</td>\n",
       "      <td>{'rouge1': (0.2, 0.2, 0.20000000000000004), 'r...</td>\n",
       "      <td>-0.106656</td>\n",
       "      <td>0.388889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[DH, AH, SIL, T, UW, TH, SIL, F, EH, R, IY, SI...</td>\n",
       "      <td>[DH, AH, SIL, S, ER, TH, SIL, V, EH, R, IY, SI...</td>\n",
       "      <td>The tooth fairy forgot to come when Roger's to...</td>\n",
       "      <td>\"The Earth very requite to hum friend, where's...</td>\n",
       "      <td>0.909091</td>\n",
       "      <td>0.090909</td>\n",
       "      <td>{'rouge1': (0.2727272727272727, 0.25, 0.260869...</td>\n",
       "      <td>0.011441</td>\n",
       "      <td>0.369565</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>875</th>\n",
       "      <td>[Y, AO, R, SIL, T, Y, UW, IH, SH, AH, N, SIL, ...</td>\n",
       "      <td>[Y, UH, R, SIL, M, EH, ZH, ER, IH, NG, SIL, IH...</td>\n",
       "      <td>Your tuition reimbursement.</td>\n",
       "      <td>\"You're measuring him as meant.\"</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>{'rouge1': (0.3333333333333333, 0.333333333333...</td>\n",
       "      <td>0.307121</td>\n",
       "      <td>0.500000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>876</th>\n",
       "      <td>[G, EH, T, IH, NG, SIL, P, L, EH, JH, SIL, SH,...</td>\n",
       "      <td>[K, AE, N, D, SIL, P, L, AE, K, SIL, CH, IY, C...</td>\n",
       "      <td>Getting pledge sheets for the Boy Scouts.</td>\n",
       "      <td>\"Can the black cheese for the boy stay tight?\"</td>\n",
       "      <td>0.857143</td>\n",
       "      <td>0.267857</td>\n",
       "      <td>{'rouge1': (0.2857142857142857, 0.285714285714...</td>\n",
       "      <td>0.073714</td>\n",
       "      <td>0.406250</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>877</th>\n",
       "      <td>[IH, F, SIL, Y, UW, SIL, HH, AE, V, SIL, EH, N...</td>\n",
       "      <td>[IH, P, SIL, Y, UW, SIL, HH, AE, V, SIL, EH, N...</td>\n",
       "      <td>If you have any doubts whatsoever.</td>\n",
       "      <td>\"I've you have any doubt whatsoever.\"</td>\n",
       "      <td>0.833333</td>\n",
       "      <td>0.083333</td>\n",
       "      <td>{'rouge1': (0.16666666666666666, 0.16666666666...</td>\n",
       "      <td>-0.034383</td>\n",
       "      <td>0.285714</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>878</th>\n",
       "      <td>[M, IH, S, T, ER, IY, SIL, M, UW, V, IY, Z, SIL]</td>\n",
       "      <td>[T, EH, CH, ER, SIL, M, UW, V, IY, Z, SIL]</td>\n",
       "      <td>Mystery movies.</td>\n",
       "      <td>\"Teacher moves.\"</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.3333333333333333, 0.5, 0.4), 'ro...</td>\n",
       "      <td>0.302895</td>\n",
       "      <td>0.384615</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>879</th>\n",
       "      <td>[AY, SIL, K, AE, N, T, SIL, R, IH, M, EH, M, B...</td>\n",
       "      <td>[AY, SIL, K, AE, N, T, SIL, R, IH, M, EH, M, B...</td>\n",
       "      <td>I can't remember where.</td>\n",
       "      <td>I can't remember her.</td>\n",
       "      <td>0.500000</td>\n",
       "      <td>0.468750</td>\n",
       "      <td>{'rouge1': (0.6, 0.6, 0.6), 'rouge2': (0.5, 0....</td>\n",
       "      <td>0.434112</td>\n",
       "      <td>0.157895</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>880 rows × 9 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       target_phonemes  \\\n",
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...   \n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...   \n",
       "2    [S, OW, SIL, R, UW, L, Z, SIL, W, IY, SIL, M, ...   \n",
       "3    [L, AO, R, IY, Z, SIL, K, AA, S, T, UW, M, SIL...   \n",
       "4    [DH, AH, SIL, T, UW, TH, SIL, F, EH, R, IY, SI...   \n",
       "..                                                 ...   \n",
       "875  [Y, AO, R, SIL, T, Y, UW, IH, SH, AH, N, SIL, ...   \n",
       "876  [G, EH, T, IH, NG, SIL, P, L, EH, JH, SIL, SH,...   \n",
       "877  [IH, F, SIL, Y, UW, SIL, HH, AE, V, SIL, EH, N...   \n",
       "878   [M, IH, S, T, ER, IY, SIL, M, UW, V, IY, Z, SIL]   \n",
       "879  [AY, SIL, K, AE, N, T, SIL, R, IH, M, EH, M, B...   \n",
       "\n",
       "                                         pred_phonemes  \\\n",
       "0    [DH, EH, R, K, S, ER, Z, SIL, R, IY, K, AH, N,...   \n",
       "1    [G, EH, S, SIL, P, ER, CH, AH, N, S, SIL, EH, ...   \n",
       "2    [S, OW, SIL, R, IH, L, Z, SIL, M, IY, SIL, M, ...   \n",
       "3    [T, R, OY, IH, NG, SIL, K, AW, M, SIL, N, IY, ...   \n",
       "4    [DH, AH, SIL, S, ER, TH, SIL, V, EH, R, IY, SI...   \n",
       "..                                                 ...   \n",
       "875  [Y, UH, R, SIL, M, EH, ZH, ER, IH, NG, SIL, IH...   \n",
       "876  [K, AE, N, D, SIL, P, L, AE, K, SIL, CH, IY, C...   \n",
       "877  [IH, P, SIL, Y, UW, SIL, HH, AE, V, SIL, EH, N...   \n",
       "878         [T, EH, CH, ER, SIL, M, UW, V, IY, Z, SIL]   \n",
       "879  [AY, SIL, K, AE, N, T, SIL, R, IH, M, EH, M, B...   \n",
       "\n",
       "                                       target_sentence  \\\n",
       "0                              Theocracy reconsidered.   \n",
       "1           Rich purchased several signed lithographs.   \n",
       "2            So rules we made, in unabashed collusion.   \n",
       "3    Lori's costume needed black gloves to be compl...   \n",
       "4    The tooth fairy forgot to come when Roger's to...   \n",
       "..                                                 ...   \n",
       "875                        Your tuition reimbursement.   \n",
       "876          Getting pledge sheets for the Boy Scouts.   \n",
       "877                 If you have any doubts whatsoever.   \n",
       "878                                    Mystery movies.   \n",
       "879                            I can't remember where.   \n",
       "\n",
       "                                         pred_sentence  WER_scores  \\\n",
       "0                        The exercise is reconsidered.    1.000000   \n",
       "1               \"Guess perchance ever sound list off.\"    1.000000   \n",
       "2        \"So, reels me made in an adverse conclusion.\"    0.857143   \n",
       "3    \"Trying calm needed black lives so be humbly t...    0.888889   \n",
       "4    \"The Earth very requite to hum friend, where's...    0.909091   \n",
       "..                                                 ...         ...   \n",
       "875                   \"You're measuring him as meant.\"    1.000000   \n",
       "876     \"Can the black cheese for the boy stay tight?\"    0.857143   \n",
       "877              \"I've you have any doubt whatsoever.\"    0.833333   \n",
       "878                                   \"Teacher moves.\"    1.000000   \n",
       "879                              I can't remember her.    0.500000   \n",
       "\n",
       "     METEOR_scores                                       ROUGE_scores  \\\n",
       "0         0.000000  {'rouge1': (0.5, 0.5, 0.5), 'rouge2': (0.0, 0....   \n",
       "1         0.000000  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....   \n",
       "2         0.071429  {'rouge1': (0.375, 0.42857142857142855, 0.3999...   \n",
       "3         0.111111  {'rouge1': (0.2, 0.2, 0.20000000000000004), 'r...   \n",
       "4         0.090909  {'rouge1': (0.2727272727272727, 0.25, 0.260869...   \n",
       "..             ...                                                ...   \n",
       "875       0.166667  {'rouge1': (0.3333333333333333, 0.333333333333...   \n",
       "876       0.267857  {'rouge1': (0.2857142857142857, 0.285714285714...   \n",
       "877       0.083333  {'rouge1': (0.16666666666666666, 0.16666666666...   \n",
       "878       0.000000  {'rouge1': (0.3333333333333333, 0.5, 0.4), 'ro...   \n",
       "879       0.468750  {'rouge1': (0.6, 0.6, 0.6), 'rouge2': (0.5, 0....   \n",
       "\n",
       "     BERTScore_F1_scores       CER  \n",
       "0              -0.081067  0.400000  \n",
       "1               0.076523  0.484848  \n",
       "2               0.184476  0.147059  \n",
       "3              -0.106656  0.388889  \n",
       "4               0.011441  0.369565  \n",
       "..                   ...       ...  \n",
       "875             0.307121  0.500000  \n",
       "876             0.073714  0.406250  \n",
       "877            -0.034383  0.285714  \n",
       "878             0.302895  0.384615  \n",
       "879             0.434112  0.157895  \n",
       "\n",
       "[880 rows x 9 columns]"
      ]
     },
     "execution_count": 126,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df_gpt4"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
