{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Goal of this work is to load the same language model of the Willet paper and compute our results with that language model to see if we reach a similar word error rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/LanguageModelDecoder/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "# import torch\n",
    "# print(torch.__version__)\n",
    "\n",
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import  utils.lmDecoderUtils as lmDecoderUtils\n",
    "\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, getDatasetLoaders\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model import GRUDecoder, SimpleGRUDecoder, LightningGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from phoneme2text import Phoneme2TextModel, PhonemeTokenizer, VanillaTransformerEncoder\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from Levenshtein import distance as levenshtein_distance\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/LanguageModelDecoder/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'1.13.1+cu117'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader, test_loader, competition_loader, loadedData = getDatasetLoaders(\"datasets_basline\", 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len =32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 3e-2\n",
    "lr_end = 1e-4\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 60\n",
    "total_steps = target_epoch * steps_per_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/1 [00:00<?, ?it/s]/home/XXXXXX/anaconda3/envs/LanguageModelDecoder/lib/python3.9/site-packages/torch/functional.py:504: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3190.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "/home/XXXXXX/anaconda3/envs/LanguageModelDecoder/lib/python3.9/site-packages/torch/cuda/__init__.py:155: UserWarning: \n",
      "NVIDIA H100 80GB HBM3 with CUDA capability sm_90 is not compatible with the current PyTorch installation.\n",
      "The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86.\n",
      "If you want to use the NVIDIA H100 80GB HBM3 GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/\n",
      "\n",
      "  warnings.warn(incompatible_device_warn.format(device_name, capability, \" \".join(arch_list), device_name))\n",
      "100%|██████████| 1/1 [00:09<00:00,  9.59s/it]\n"
     ]
    }
   ],
   "source": [
    "N_MODELS = 1\n",
    "\n",
    "## reload best models\n",
    "best_models = []\n",
    "\n",
    "for i in tqdm.trange(N_MODELS):\n",
    "\n",
    "    model = LightningGRUDecoder(\n",
    "        neural_dim=nInputFeatures,\n",
    "        n_classes=nClasses,\n",
    "        hidden_dim=hidden_dim,\n",
    "        layer_dim=nlayers,\n",
    "        strideLen=stride_len,\n",
    "        kernelLen=kernel_len,\n",
    "        gaussianSmoothWidth=gaussian_smooth_width,\n",
    "        bidirectional=bidirectional,\n",
    "        dropout=dropout,\n",
    "        white_noise_SD=white_noise_SD,\n",
    "        constant_offset_SD=constant_offset_SD,\n",
    "        weight_decay=l2_decay,\n",
    "        learning_rate=lr_start)\n",
    "\n",
    "    model.load_state_dict(torch.load(f\"./checkpoints_ensemble_v2/best_model_{i}.ckpt\")[\"state_dict\"])\n",
    "    best_models.append(model)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING: Logging before InitGoogleLogging() is written to STDERR\n",
      "I0305 09:53:36.600047 2412527 brain_speech_decoder.h:52] Reading fst /data/speech_5gram/lang_test/TLG.fst\n",
      "I0305 09:57:17.293763 2412527 brain_speech_decoder.h:58] Reading lm fst /data/speech_5gram/lang_test/G.fst\n",
      "I0305 09:58:14.373675 2412527 brain_speech_decoder.h:70] Reading rescore fst /data/speech_5gram/lang_test/G_no_prune.fst\n",
      "I0305 10:10:58.453982 2412527 brain_speech_decoder.h:81] Reading symbol table /data/speech_5gram/lang_test/words.txt\n"
     ]
    }
   ],
   "source": [
    "\n",
    "\n",
    "# # Load OPT 6B model\n",
    "# llm, llm_tokenizer = lmDecoderUtils.build_opt(\n",
    "#     cacheDir=\"/data/\", device=\"auto\", load_in_8bit=True\n",
    "# )\n",
    "\n",
    "ngramDecoder = lmDecoderUtils.build_lm_decoder(\n",
    "    \"/data/speech_5gram/lang_test\", acoustic_scale=0.5, nbest=100, beam=18\n",
    ")\n",
    "\n",
    "# LM decoding hyperparameters\n",
    "acoustic_scale = 0.5\n",
    "blank_penalty = np.log(7)\n",
    "llm_weight = 0.5\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils import lmDecoderUtils\n",
    "import importlib\n",
    "importlib.reload(lmDecoderUtils)\n",
    "import utils.lmDecoderUtils as lmDecoderUtils \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load OPT 6B model\n",
    "llm, llm_tokenizer = lmDecoderUtils.build_gpt2_torch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import accelerate \n",
    "# accelerate.__version__"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "device =\"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/14 [00:00<?, ?it/s]/data/high_speech_BCI_analysis/reproduce_baseline/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:895.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "100%|██████████| 14/14 [15:59<00:00, 68.53s/it]\n"
     ]
    }
   ],
   "source": [
    "# Run inference on test set\n",
    "all_pred_phonemes = []\n",
    "all_true_phonemes = []\n",
    "all_decoded_phonemes_ensemble = []\n",
    "all_true_texts = []\n",
    "cer_list = []\n",
    "all_logits = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for X, y, X_len, y_len,days, transcriptions in tqdm.tqdm(test_loader):\n",
    "\n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        all_decoded_phonemes_ensemble_tmp = []\n",
    "        logits_ensemble = []\n",
    "        for model in best_models:\n",
    "            model.eval()\n",
    "\n",
    "            logits = model(X,days)\n",
    "            logits_ensemble.append(logits)\n",
    "        \n",
    "        all_logits.append(logits_ensemble)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [],
   "source": [
    "sentences = test_loader.dataset.sentences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "#unfold the logits list\n",
    "logits_unfolded = [item for sublist in all_logits for item in sublist[0]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 880/880 [53:38<00:00,  3.66s/it]  \n"
     ]
    }
   ],
   "source": [
    "llm_outputs = []\n",
    "# Generate nbest outputs from 5gram LM\n",
    "start_t = time.time()\n",
    "nbest_outputs = []\n",
    "for j in tqdm.trange(len(logits_unfolded)):\n",
    "    logits = logits_unfolded[j]\n",
    "    logits = np.concatenate(\n",
    "        [logits[:, 1:], logits[:, 0:1]], axis=-1\n",
    "    )  # Blank is last token\n",
    "    logits = lmDecoderUtils.rearrange_speech_logits(logits[None, :, :], has_sil=True)\n",
    "    nbest = lmDecoderUtils.lm_decode(\n",
    "        ngramDecoder,\n",
    "        logits[0],\n",
    "        blankPenalty=blank_penalty,\n",
    "        returnNBest=True,\n",
    "        rescore=True,\n",
    "    )\n",
    "    nbest_outputs.append(nbest)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [],
   "source": [
    "## save all he nbest outputs\n",
    "with open(\"nbest_outputs.pkl\", \"wb\") as f:\n",
    "    pickle.dump(nbest_outputs, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 880/880 [1:30:43<00:00,  6.19s/it]\n"
     ]
    }
   ],
   "source": [
    "decoded_sentences = []\n",
    "confidences = []\n",
    "for i in tqdm.trange(len(nbest_outputs)):\n",
    "    nbest_output = nbest_outputs[i]\n",
    "    decoded, confidence = lmDecoderUtils.gpt2_lm_decode(llm, llm_tokenizer, nbest_output, acoustic_scale, 0, alpha=llm_weight, returnConfidence=True)\n",
    "    decoded_sentences.append(decoded)\n",
    "    confidences.append(confidence)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "def preprocess_text(text):\n",
    "    \"\"\"\n",
    "    Remove punctuation, strip, and convert text to lowercase.\n",
    "    \"\"\"\n",
    "    return text.translate(str.maketrans('', '', string.punctuation)).strip().lower()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jiwer  # For WER\n",
    "import sacrebleu  # For BLEU\n",
    "from rouge_score import rouge_scorer  # For ROUGE\n",
    "from nltk.translate.meteor_score import meteor_score  # For METEOR\n",
    "import bert_score  # For BERTScore\n",
    "import numpy as np\n",
    "\n",
    "def compute_metrics(text_transcriptions, gpt_decoded):\n",
    "    \"\"\"\n",
    "    Compute various NLP evaluation metrics for text generation.\n",
    "\n",
    "    Args:\n",
    "        text_transcriptions (list): List of ground-truth reference sentences.\n",
    "        gpt_decoded (list): List of model-generated sentences.\n",
    "\n",
    "    Returns:\n",
    "        dict: Dictionary containing all computed metrics.\n",
    "    \"\"\"\n",
    "\n",
    "    #remove punctuation, strip and lower case\n",
    "\n",
    "\n",
    "    text_transcriptions = [preprocess_text(text) for text in text_transcriptions]\n",
    "    gpt_decoded = [preprocess_text(text) for text in gpt_decoded]\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    # WER (Word Error Rate)\n",
    "    wer = jiwer.wer(text_transcriptions, gpt_decoded)\n",
    "    results[\"WER\"] = wer\n",
    "\n",
    "    # BLEU Score\n",
    "    bleu = sacrebleu.corpus_bleu(gpt_decoded, [text_transcriptions]).score\n",
    "    results[\"BLEU\"] = bleu\n",
    "\n",
    "    # ROUGE Scores\n",
    "    rouge = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)\n",
    "    rouge_scores = [rouge.score(ref, pred) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"ROUGE-1\"] = np.mean([score[\"rouge1\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-2\"] = np.mean([score[\"rouge2\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-L\"] = np.mean([score[\"rougeL\"].fmeasure for score in rouge_scores])\n",
    "\n",
    "    ##METEOR\n",
    "    tokenized_references = [ref.split() for ref in text_transcriptions]  # Tokenize reference sentences\n",
    "    tokenized_hypotheses = [pred.split() for pred in gpt_decoded]  # Tokenize predicted sentences\n",
    "\n",
    "    meteor_scores = [meteor_score([ref], pred) for ref, pred in zip(tokenized_references, tokenized_hypotheses)]\n",
    "    results[\"METEOR\"] = np.mean(meteor_scores)\n",
    "    # BERTScore (Semantic Similarity)\n",
    "    P, R, F1 = bert_score.score(gpt_decoded, text_transcriptions, lang=\"en\", rescale_with_baseline=True)\n",
    "    results[\"BERTScore_Precision\"] = P.mean().item()\n",
    "    results[\"BERTScore_Recall\"] = R.mean().item()\n",
    "    results[\"BERTScore_F1\"] = F1.mean().item()\n",
    "\n",
    "    ## save also all values without recomputing when possible\n",
    "    results[\"METEOR_scores\"] = meteor_scores\n",
    "    results[\"ROUGE_scores\"] = rouge_scores\n",
    "\n",
    "    results[\"WER_scores\"] = [jiwer.wer([ref], [pred]) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"BERTScore_F1_scores\"] = F1.cpu().numpy().tolist()\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install bert-score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 96,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(880, 880)"
      ]
     },
     "execution_count": 96,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(sentences), len(decoded_sentences)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 97,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.\n",
      "Warning: Empty reference sentence detected; setting raw BERTScores to 0.\n",
      "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER: 0.2475\n",
      "BLEU: 65.4707\n",
      "ROUGE-1: 0.7939\n",
      "ROUGE-2: 0.7011\n",
      "ROUGE-L: 0.7932\n",
      "METEOR: 0.7914\n",
      "BERTScore_Precision: nan\n",
      "BERTScore_Recall: nan\n",
      "BERTScore_F1: nan\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.\n",
      "Warning: Empty reference sentence detected; setting raw BERTScores to 0.\n",
      "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.\n",
      "Warning: Empty candidate sentence detected; setting raw BERTscores to 0.\n"
     ]
    }
   ],
   "source": [
    "\n",
    "metrics = compute_metrics(sentences,decoded_sentences)\n",
    "for metric, score in metrics.items():\n",
    "    if \"scores\" not in metric:\n",
    "        print(f\"{metric}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "results_df = pd.DataFrame({\n",
    "    \"target_sentence\": sentences,\n",
    "    \"pred_sentence\": decoded_sentences,\n",
    "})\n",
    "\n",
    "#unfold cer_list\n",
    "# cer_list_unfold = [item for sublist in cer_list for item in sublist]\n",
    "\n",
    "results_df[\"WER_scores\"] = metrics[\"WER_scores\"]\n",
    "results_df[\"METEOR_scores\"] = metrics[\"METEOR_scores\"]\n",
    "results_df[\"ROUGE_scores\"] = metrics[\"ROUGE_scores\"]\n",
    "results_df[\"BERTScore_F1_scores\"] = metrics[\"BERTScore_F1_scores\"]\n",
    "\n",
    "results_df.to_csv(\"results/willet_baseline_results.csv\", index=False)\n",
    "\n",
    "overall_metrics = {k:v for k,v in metrics.items() if \"scores\" not in k}\n",
    "\n",
    "metrics_df = pd.DataFrame(overall_metrics, index=[0])\n",
    "metrics_df.to_csv(\"results/willet_baseline_metrics.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 100,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_sentence</th>\n",
       "      <th>pred_sentence</th>\n",
       "      <th>WER_scores</th>\n",
       "      <th>METEOR_scores</th>\n",
       "      <th>ROUGE_scores</th>\n",
       "      <th>BERTScore_F1_scores</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>621</th>\n",
       "      <td>Started investigating.</td>\n",
       "      <td>ford invested in a</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-3.011812e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>674</th>\n",
       "      <td>Through Oklahoma City.</td>\n",
       "      <td>there or come see us</td>\n",
       "      <td>1.666667</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-4.943095e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>Marksmanship example.</td>\n",
       "      <td>most inept example did</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.227273</td>\n",
       "      <td>{'rouge1': (0.25, 0.5, 0.3333333333333333), 'r...</td>\n",
       "      <td>3.032638e+21</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>406</th>\n",
       "      <td>Capital punishment.</td>\n",
       "      <td>people proficient in</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-4.925075e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>612</th>\n",
       "      <td>Golden Retriever.</td>\n",
       "      <td>children of the</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-5.630964e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>356</th>\n",
       "      <td>I'm kind of out of it right now.</td>\n",
       "      <td>i'm kind of out of it right now</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.999023</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>-4.504041e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>171</th>\n",
       "      <td>As any city grows up.</td>\n",
       "      <td>as any city grows up</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>-4.191945e+01</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>617</th>\n",
       "      <td>His voice was nearly drowned out by the crowd.</td>\n",
       "      <td>his voice was nearly drowned out by the crowd</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.999314</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>-5.280870e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>615</th>\n",
       "      <td>I mean you can earn a lot more money.</td>\n",
       "      <td>i mean you can earn a lot more money</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.999314</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>-7.886252e+00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>647</th>\n",
       "      <td>We have that here.</td>\n",
       "      <td>we have that here</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.992188</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>-5.453975e+00</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>880 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                    target_sentence  \\\n",
       "621                          Started investigating.   \n",
       "674                          Through Oklahoma City.   \n",
       "79                            Marksmanship example.   \n",
       "406                             Capital punishment.   \n",
       "612                               Golden Retriever.   \n",
       "..                                              ...   \n",
       "356                I'm kind of out of it right now.   \n",
       "171                           As any city grows up.   \n",
       "617  His voice was nearly drowned out by the crowd.   \n",
       "615           I mean you can earn a lot more money.   \n",
       "647                              We have that here.   \n",
       "\n",
       "                                     pred_sentence  WER_scores  METEOR_scores  \\\n",
       "621                             ford invested in a    2.000000       0.000000   \n",
       "674                           there or come see us    1.666667       0.000000   \n",
       "79                          most inept example did    1.500000       0.227273   \n",
       "406                           people proficient in    1.500000       0.000000   \n",
       "612                                children of the    1.500000       0.000000   \n",
       "..                                             ...         ...            ...   \n",
       "356                i'm kind of out of it right now    0.000000       0.999023   \n",
       "171                           as any city grows up    0.000000       0.996000   \n",
       "617  his voice was nearly drowned out by the crowd    0.000000       0.999314   \n",
       "615           i mean you can earn a lot more money    0.000000       0.999314   \n",
       "647                              we have that here    0.000000       0.992188   \n",
       "\n",
       "                                          ROUGE_scores  BERTScore_F1_scores  \n",
       "621  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....        -3.011812e+00  \n",
       "674  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....        -4.943095e+00  \n",
       "79   {'rouge1': (0.25, 0.5, 0.3333333333333333), 'r...         3.032638e+21  \n",
       "406  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....        -4.925075e+00  \n",
       "612  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....        -5.630964e+00  \n",
       "..                                                 ...                  ...  \n",
       "356  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....        -4.504041e+00  \n",
       "171  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....        -4.191945e+01  \n",
       "617  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....        -5.280870e+00  \n",
       "615  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....        -7.886252e+00  \n",
       "647  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....        -5.453975e+00  \n",
       "\n",
       "[880 rows x 6 columns]"
      ]
     },
     "execution_count": 100,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "results_df.sort_values(\"WER_scores\", ascending=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "llm_out = lmDecoderUtils.cer_with_gpt2_decoder(\n",
    "    llm,\n",
    "    llm_tokenizer,\n",
    "    nbest_outputs[:],\n",
    "    acoustic_scale,\n",
    "    rnn_outputs,\n",
    "    outputType=\"speech_sil\",\n",
    "    returnCI=True,\n",
    "    lengthPenalty=0,\n",
    "    alpha=llm_weight,\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "LanguageModelDecoder",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
