{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "aed3958d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_MFCC_v3\n",
    "from model.hybrid_modelling import HybridCausalLMOutput, HybridGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "# from model.ctc_modelling import Light\n",
    "import os\n",
    "# Fetch the CMU Pronouncing Dictionary corpus (no-op when it is already up to date)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict into an in-memory dict via nltk's corpus reader\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "59283f62",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "BATCH_SIZE = 64\n",
    "# getDatasetLoaders_V3 returns four values; the third one is not needed in this notebook\n",
    "train_loader, test_loader, _, loadedData = getDatasetLoaders_V3(\n",
    "    DATASET_FULL_TRIALS_ZSCORE, BATCH_SIZE, include_prego=True\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "73c0335c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n",
      "/tmp/ipykernel_1970370/2713209040.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path)[\"state_dict\"])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# Encoder hyperparameters: these must match the checkpoint loaded below, otherwise the\n",
    "# strict load_state_dict call at the end of this cell fails.\n",
    "nInputFeatures = 256  # input channels\n",
    "nClasses = 40  # fc_decoder_out has nClasses + 1 outputs (presumably the extra one is the CTC blank)\n",
    "dropout = 0.4\n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len = 32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "# Augmentation / optimiser settings passed to the constructor.\n",
    "# NOTE(review): lr_end, seq_len and max_time_series_len are not passed to the encoder below.\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 3e-4\n",
    "lr_end = 0.02\n",
    "l2_decay = 1e-5\n",
    "\n",
    "neural_encoder_model_weights_path = \".checkpoints/mfcc_sm_gru_ctc_LONGRUN/best_model.ckpt\"\n",
    "neural_encoder = LightningGRUDecoder_MFCC_v3(\n",
    "    neural_dim=nInputFeatures,\n",
    "    n_classes=nClasses,\n",
    "    hidden_dim=hidden_dim,\n",
    "    layer_dim=nlayers,\n",
    "    strideLen=stride_len,\n",
    "    kernelLen=kernel_len,\n",
    "    gaussianSmoothWidth=gaussian_smooth_width,\n",
    "    bidirectional=bidirectional,\n",
    "    dropout=dropout,\n",
    "    white_noise_SD=white_noise_SD,\n",
    "    constant_offset_SD=constant_offset_SD,\n",
    "    weight_decay=l2_decay,\n",
    "    learning_rate=lr_start,\n",
    ")\n",
    "\n",
    "# map_location=\"cpu\": materialise the checkpoint on the CPU whatever device it was saved from;\n",
    "# the freshly constructed encoder is still on the CPU, and load_state_dict copies into it.\n",
    "neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path, map_location=\"cpu\")[\"state_dict\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 116,
   "id": "245b2214",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"gru_ctc_mfcc_bart\"\n",
    "\n",
    "# Every CSV written by this notebook goes under results/<output_name>\n",
    "results_dir = f\"results/{output_name}\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bd6d7100",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Encoder will be fine-tuned.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_1970370/3498522823.py:68: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(f\"{run_folder}/best_model_wer.ckpt\")[\"state_dict\"])\n"
     ]
    }
   ],
   "source": [
    "TRAIN = False  # True: fine-tune here; False: load the best-WER checkpoint of a previous sweep run\n",
    "model = HybridGRUDecoder(\n",
    "    neural_encoder=neural_encoder,\n",
    "    learning_rate=0.00005,\n",
    "    weight_decay=0.00001,\n",
    "    lm_model_dim=768,\n",
    "    freeze_lm=False,\n",
    "    freeze_encoder=False,\n",
    "    use_lora=False,\n",
    "    lora_r=128,\n",
    "    lora_alpha=256,\n",
    "    ce_loss_weight=0.2,\n",
    "    ctc_loss_weight=0.5,\n",
    "    l1_loss_weight=1.,\n",
    "    project_from_logits=False,\n",
    ")\n",
    "\n",
    "if TRAIN:\n",
    "    wandb_logger = WandbLogger(project=\"ECOG_Sentence_dataset\", name=f\"{output_name}\",\n",
    "                               reinit=True)\n",
    "\n",
    "    # Keep only the checkpoint with the lowest validation word error rate (metric \"val_WER\")\n",
    "    checkpoint_callback_wer = ModelCheckpoint(\n",
    "        monitor=\"val_WER\",\n",
    "        mode=\"min\",\n",
    "        save_top_k=1,\n",
    "        dirpath=f\".checkpoints/{output_name}/\",\n",
    "        filename=\"best_model_wer\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Keep only the checkpoint with the lowest \"val_CER\".\n",
    "    # NOTE(review): saved as best_model_per -- confirm whether val_CER is a character or phoneme error rate.\n",
    "    checkpoint_callback_per = ModelCheckpoint(\n",
    "        monitor=\"val_CER\",\n",
    "        mode=\"min\",\n",
    "        save_top_k=1,\n",
    "        dirpath=f\".checkpoints/{output_name}/\",\n",
    "        filename=\"best_model_per\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Stop once val_loss has not improved for 25 consecutive validation checks\n",
    "    early_stopping_callback = EarlyStopping(\n",
    "        monitor=\"val_loss\",\n",
    "        patience=25,\n",
    "        mode=\"min\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    trainer = pl.Trainer(max_epochs=100, devices=[0],\n",
    "                         callbacks=[checkpoint_callback_wer, checkpoint_callback_per, early_stopping_callback],\n",
    "                         logger=wandb_logger)\n",
    "    trainer.fit(model, train_loader, test_loader)\n",
    "    wandb.finish()  # Finish the WandB run\n",
    "\n",
    "    # fit() leaves the weights of the last epoch run in memory, not the best ones; reload the\n",
    "    # best-WER checkpoint so evaluation below matches the TRAIN=False branch.\n",
    "    best_wer_path = checkpoint_callback_wer.best_model_path\n",
    "    if best_wer_path:\n",
    "        model.load_state_dict(torch.load(best_wer_path, map_location=\"cpu\")[\"state_dict\"])\n",
    "\n",
    "else:\n",
    "    # Load the best-WER checkpoint of an optimised sweep run\n",
    "    run_folder = f\"optimization/{output_name}/floral-sweep-5\"\n",
    "    model.load_state_dict(torch.load(f\"{run_folder}/best_model_wer.ckpt\", map_location=\"cpu\")[\"state_dict\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "38cc233c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "HybridGRUDecoder(\n",
       "  (encoder): LightningGRUDecoder_MFCC_v3(\n",
       "    (inputLayerNonlinearity): Softsign()\n",
       "    (unfolder): Unfold(kernel_size=(32, 1), dilation=1, padding=0, stride=4)\n",
       "    (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "    (gaussianSmoother): GaussianSmoothing()\n",
       "    (gru_decoder): GRU(8192, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "    (fc_decoder_out): Linear(in_features=2048, out_features=41, bias=True)\n",
       "    (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "    (ctc_loss): CTCLoss()\n",
       "    (l1oss): L1Loss()\n",
       "  )\n",
       "  (project): Sequential(\n",
       "    (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)\n",
       "    (1): Linear(in_features=2048, out_features=768, bias=True)\n",
       "  )\n",
       "  (language_model): BartForConditionalGeneration(\n",
       "    (model): BartModel(\n",
       "      (shared): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "      (encoder): BartEncoder(\n",
       "        (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "        (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n",
       "        (layers): ModuleList(\n",
       "          (0-5): 6 x BartEncoderLayer(\n",
       "            (self_attn): BartSdpaAttention(\n",
       "              (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "            (activation_fn): GELUActivation()\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "        (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "      (decoder): BartDecoder(\n",
       "        (embed_tokens): BartScaledWordEmbedding(50265, 768, padding_idx=1)\n",
       "        (embed_positions): BartLearnedPositionalEmbedding(1026, 768)\n",
       "        (layers): ModuleList(\n",
       "          (0-5): 6 x BartDecoderLayer(\n",
       "            (self_attn): BartSdpaAttention(\n",
       "              (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "            (activation_fn): GELUActivation()\n",
       "            (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "            (encoder_attn): BartSdpaAttention(\n",
       "              (k_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (v_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (q_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "              (out_proj): Linear(in_features=768, out_features=768, bias=True)\n",
       "            )\n",
       "            (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "            (fc1): Linear(in_features=768, out_features=3072, bias=True)\n",
       "            (fc2): Linear(in_features=3072, out_features=768, bias=True)\n",
       "            (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          )\n",
       "        )\n",
       "        (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "      )\n",
       "    )\n",
       "    (lm_head): Linear(in_features=768, out_features=50265, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the first GPU when available, otherwise fall back to the CPU so the notebook still runs\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "17eeae9d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/14 [00:00<?, ?it/s]/data/XXXXXX/speech_decoding_BCI/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:677: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
      "  warnings.warn(\n",
      "100%|██████████| 14/14 [00:02<00:00,  4.68it/s]\n"
     ]
    }
   ],
   "source": [
    "# nn.Modules are constructed in training mode and nothing above switches the model to eval,\n",
    "# so dropout would otherwise be active while decoding. model.generate may or may not do this\n",
    "# itself; calling eval() here makes evaluation deterministic either way.\n",
    "model.eval()\n",
    "\n",
    "decoded_sentences = []\n",
    "true_sentences = []\n",
    "for batch in tqdm.tqdm(test_loader):\n",
    "    X = batch[\"neural_feats\"].to(device)\n",
    "    y = batch[\"phone_seq\"]\n",
    "    X_len = batch[\"neural_time_bins\"]\n",
    "    y_len = batch[\"phone_seq_len\"]\n",
    "    dayIdx = batch[\"day\"].to(device)\n",
    "    sentence = batch[\"sentence\"]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        decoded = model.generate(X, dayIdx)\n",
    "    decoded_sentences.extend(decoded)\n",
    "    true_sentences.extend(sentence)\n",
    "\n",
    "# X, dayIdx, decoded and sentence now hold the LAST batch; later inspection cells rely on that.\n",
    "assert len(decoded_sentences) == len(true_sentences), \"prediction/reference count mismatch\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "7a630e79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.modeling_outputs import BaseModelOutput, ModelOutput\n",
    "\n",
    "def beam_search_generate(model, neuralInput, dayIdx, max_length=60, num_beams=1, num_return_sequences=1,\n",
    "                         do_sample=True, top_k=50, top_p=0.95, temperature=0.9,\n",
    "                         length_penalty=1.0, no_repeat_ngram_size=3):\n",
    "    \"\"\"\n",
    "    Generate text from neural features with the hybrid model's language model.\n",
    "\n",
    "    The neural embedding is projected to the language-model width and handed to\n",
    "    ``model.language_model.generate`` as precomputed encoder states. With the default\n",
    "    ``do_sample=True`` decoding is stochastic (beam *sampling*): seed torch for\n",
    "    reproducible output, or pass ``do_sample=False`` for deterministic beam search.\n",
    "\n",
    "    Args:\n",
    "        model: HybridGRUDecoder-like object exposing get_neural_embedding, project,\n",
    "            project_from_logits, encoder.fc_decoder_out, language_model and tokenizer.\n",
    "        neuralInput: neural features, presumably (batch_size, seq_len, input_dim).\n",
    "        dayIdx: session index for each item of the batch.\n",
    "        max_length: maximum number of generated tokens.\n",
    "        num_beams: beam width.\n",
    "        num_return_sequences: hypotheses returned per input.\n",
    "        do_sample, top_k, top_p, temperature, length_penalty, no_repeat_ngram_size:\n",
    "            forwarded unchanged to ``generate``.\n",
    "\n",
    "    Returns:\n",
    "        list[str]: batch_size * num_return_sequences decoded strings; the hypotheses\n",
    "        belonging to one input are contiguous.\n",
    "    \"\"\"\n",
    "    was_training = model.training\n",
    "    model.eval()\n",
    "    try:\n",
    "        with torch.no_grad():\n",
    "            hid = model.get_neural_embedding(neuralInput, dayIdx)\n",
    "            if model.project_from_logits:\n",
    "                # Project from the encoder's output logits instead of its hidden states\n",
    "                logits = model.encoder.fc_decoder_out(hid)\n",
    "                projected = model.project(logits)\n",
    "            else:\n",
    "                projected = model.project(hid)\n",
    "\n",
    "            # generate() takes precomputed encoder states wrapped in a ModelOutput\n",
    "            encoder_outputs = BaseModelOutput(last_hidden_state=projected)\n",
    "            generated_ids = model.language_model.generate(\n",
    "                encoder_outputs=encoder_outputs,\n",
    "                max_length=max_length,\n",
    "                num_beams=num_beams,\n",
    "                num_return_sequences=num_return_sequences,\n",
    "                do_sample=do_sample,\n",
    "                top_k=top_k,\n",
    "                top_p=top_p,\n",
    "                temperature=temperature,\n",
    "                length_penalty=length_penalty,\n",
    "                no_repeat_ngram_size=no_repeat_ngram_size,\n",
    "            )\n",
    "        return model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "    finally:\n",
    "        # Restore the caller's train/eval mode instead of silently leaving the model in eval mode\n",
    "        model.train(was_training)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 101,
   "id": "00db7572",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Several hypotheses per utterance for the last test batch (X, dayIdx leak from the decoding loop).\n",
    "# beam_search_generate samples (do_sample=True), so fix the torch RNG to make this cell reproducible.\n",
    "num_return = 2\n",
    "torch.manual_seed(0)\n",
    "flat_hypotheses = beam_search_generate(model, X, dayIdx, max_length=20, num_beams=18, num_return_sequences=num_return)\n",
    "# Regroup the flat output into one sublist of num_return hypotheses per input\n",
    "check = [flat_hypotheses[i:i + num_return] for i in range(0, len(flat_hypotheses), num_return)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 111,
   "id": "6c867bcd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([\"It's a pretty good size switch.\", \"It's a pretty good size church.\"],\n",
       " \"It's a pretty good size switch.\",\n",
       " \"It's a pretty good size church.\")"
      ]
     },
     "execution_count": 111,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "idx = 20  # position inside the last test batch\n",
    "sampled_hyps, generate_output, reference = check[idx], decoded[idx], sentence[idx]\n",
    "sampled_hyps, generate_output, reference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "6c4272d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Decoded sentence: I took about an authentic school, too.\n",
      "True sentence: He talked about unauthentic storylines too.\n"
     ]
    }
   ],
   "source": [
    "idx = 10  # position in the full test set\n",
    "for label, text in ((\"Decoded sentence:\", decoded_sentences[idx]), (\"True sentence:\", true_sentences[idx])):\n",
    "    print(label, text)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1fa3900e",
   "metadata": {},
   "source": [
    "## Evaluation: sentence-level text metrics (WER, BLEU, ROUGE, METEOR, BERTScore)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 113,
   "id": "23491be1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import string\n",
    "import re\n",
    "import unicodedata\n",
    "\n",
    "import jiwer  # For WER\n",
    "import nltk\n",
    "import sacrebleu  # For BLEU\n",
    "from rouge_score import rouge_scorer  # For ROUGE\n",
    "from nltk.translate.meteor_score import meteor_score  # For METEOR\n",
    "import bert_score  # For BERTScore\n",
    "import numpy as np\n",
    "\n",
    "# METEOR looks words up in WordNet; fetch it so this cell also works on a fresh machine\n",
    "# (some NLTK versions additionally need omw-1.4). nltk.download skips corpora that are up to date.\n",
    "for _corpus in (\"wordnet\", \"omw-1.4\"):\n",
    "    nltk.download(_corpus, quiet=True)\n",
    "\n",
    "\n",
    "def preprocess_text(text):\n",
    "    \"\"\"\n",
    "    Remove punctuation, strip, and convert text to lowercase.\n",
    "\n",
    "    Drops every ASCII punctuation character (string.punctuation) and also any Unicode\n",
    "    punctuation (category P*, e.g. curly quotes, em dashes), so typographic variants of\n",
    "    the same sentence normalise identically.\n",
    "    \"\"\"\n",
    "    without_punct = \"\".join(\n",
    "        ch for ch in text\n",
    "        if ch not in string.punctuation and not unicodedata.category(ch).startswith(\"P\")\n",
    "    )\n",
    "    return without_punct.strip().lower()\n",
    "\n",
    "\n",
    "def compute_metrics(text_transcriptions, gpt_decoded):\n",
    "    \"\"\"\n",
    "    Compute corpus-level and per-sentence NLP metrics for generated text.\n",
    "\n",
    "    Both lists are normalised with preprocess_text before scoring.\n",
    "\n",
    "    Args:\n",
    "        text_transcriptions (list): Ground-truth reference sentences.\n",
    "        gpt_decoded (list): Model-generated sentences, aligned index-by-index with the references.\n",
    "\n",
    "    Returns:\n",
    "        dict: Corpus-level scalars (\"WER\", \"BLEU\", \"ROUGE-1\", \"ROUGE-2\", \"ROUGE-L\", \"METEOR\",\n",
    "        \"BERTScore_Precision\", \"BERTScore_Recall\", \"BERTScore_F1\") plus per-sentence lists under\n",
    "        \"METEOR_scores\", \"ROUGE_scores\", \"WER_scores\" and \"BERTScore_F1_scores\".\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the two lists differ in length or are empty.\n",
    "    \"\"\"\n",
    "    # zip() below would silently drop unmatched sentences, so check alignment up front\n",
    "    if len(text_transcriptions) != len(gpt_decoded):\n",
    "        raise ValueError(\n",
    "            f\"got {len(text_transcriptions)} references but {len(gpt_decoded)} predictions\"\n",
    "        )\n",
    "    if not text_transcriptions:\n",
    "        raise ValueError(\"cannot compute metrics on an empty list of sentences\")\n",
    "\n",
    "    # Remove punctuation, strip and lower-case both sides identically\n",
    "    text_transcriptions = [preprocess_text(text) for text in text_transcriptions]\n",
    "    gpt_decoded = [preprocess_text(text) for text in gpt_decoded]\n",
    "\n",
    "    results = {}\n",
    "\n",
    "    # WER (Word Error Rate)\n",
    "    wer = jiwer.wer(text_transcriptions, gpt_decoded)\n",
    "    results[\"WER\"] = wer\n",
    "\n",
    "    # BLEU Score\n",
    "    bleu = sacrebleu.corpus_bleu(gpt_decoded, [text_transcriptions]).score\n",
    "    results[\"BLEU\"] = bleu\n",
    "\n",
    "    # ROUGE Scores\n",
    "    rouge = rouge_scorer.RougeScorer([\"rouge1\", \"rouge2\", \"rougeL\"], use_stemmer=True)\n",
    "    rouge_scores = [rouge.score(ref, pred) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"ROUGE-1\"] = np.mean([score[\"rouge1\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-2\"] = np.mean([score[\"rouge2\"].fmeasure for score in rouge_scores])\n",
    "    results[\"ROUGE-L\"] = np.mean([score[\"rougeL\"].fmeasure for score in rouge_scores])\n",
    "\n",
    "    # METEOR works on pre-tokenised (whitespace-split) sentences\n",
    "    tokenized_references = [ref.split() for ref in text_transcriptions]\n",
    "    tokenized_hypotheses = [pred.split() for pred in gpt_decoded]\n",
    "\n",
    "    meteor_scores = [meteor_score([ref], pred) for ref, pred in zip(tokenized_references, tokenized_hypotheses)]\n",
    "    results[\"METEOR\"] = np.mean(meteor_scores)\n",
    "\n",
    "    # BERTScore (Semantic Similarity); rescaling with the baseline can yield negative values\n",
    "    P, R, F1 = bert_score.score(gpt_decoded, text_transcriptions, lang=\"en\", rescale_with_baseline=True)\n",
    "    results[\"BERTScore_Precision\"] = P.mean().item()\n",
    "    results[\"BERTScore_Recall\"] = R.mean().item()\n",
    "    results[\"BERTScore_F1\"] = F1.mean().item()\n",
    "\n",
    "    # Per-sentence values, reusing what was already computed where possible\n",
    "    results[\"METEOR_scores\"] = meteor_scores\n",
    "    results[\"ROUGE_scores\"] = rouge_scores\n",
    "\n",
    "    results[\"WER_scores\"] = [jiwer.wer([ref], [pred]) for ref, pred in zip(text_transcriptions, gpt_decoded)]\n",
    "    results[\"BERTScore_F1_scores\"] = F1.cpu().numpy().tolist()\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 114,
   "id": "1ea7b37a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER: 0.2488\n",
      "BLEU: 56.7143\n",
      "ROUGE-1: 0.7579\n",
      "ROUGE-2: 0.6131\n",
      "ROUGE-L: 0.7576\n",
      "METEOR: 0.7195\n",
      "BERTScore_Precision: 0.6285\n",
      "BERTScore_Recall: 0.6369\n",
      "BERTScore_F1: 0.6330\n"
     ]
    }
   ],
   "source": [
    "\n",
    "metrics = compute_metrics(true_sentences, decoded_sentences)\n",
    "# Corpus-level summary only; per-sentence lists are stored under keys containing \"scores\"\n",
    "for metric, score in metrics.items():\n",
    "    if \"scores\" in metric:\n",
    "        continue\n",
    "    print(f\"{metric}: {score:.4f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 117,
   "id": "dfd609d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Per-sentence table: one row per test utterance, followed by its per-sentence metric values\n",
    "per_sentence_keys = [\"WER_scores\", \"METEOR_scores\", \"ROUGE_scores\", \"BERTScore_F1_scores\"]\n",
    "results_df = pd.DataFrame({\n",
    "    \"target_sentence\": true_sentences,\n",
    "    \"pred_sentence\": decoded_sentences,\n",
    "}).assign(**{key: metrics[key] for key in per_sentence_keys})\n",
    "results_df.to_csv(f\"{results_dir}/language_results_BART.csv\", index=False)\n",
    "\n",
    "# Corpus-level scalars (keys without \"scores\") as a single-row table\n",
    "overall_metrics = {name: value for name, value in metrics.items() if \"scores\" not in name}\n",
    "metrics_df = pd.DataFrame(overall_metrics, index=[0])\n",
    "metrics_df.to_csv(f\"{results_dir}/language_metrics_BART.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 118,
   "id": "8e1930df",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_sentence</th>\n",
       "      <th>pred_sentence</th>\n",
       "      <th>WER_scores</th>\n",
       "      <th>METEOR_scores</th>\n",
       "      <th>ROUGE_scores</th>\n",
       "      <th>BERTScore_F1_scores</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>621</th>\n",
       "      <td>Started investigating.</td>\n",
       "      <td>Arted and Vietnamesepiring.</td>\n",
       "      <td>1.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.146762</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>780</th>\n",
       "      <td>Temperature swing in about four hours.</td>\n",
       "      <td>Trump's who are you sure in about her?</td>\n",
       "      <td>1.166667</td>\n",
       "      <td>0.302419</td>\n",
       "      <td>{'rouge1': (0.25, 0.3333333333333333, 0.285714...</td>\n",
       "      <td>-0.179631</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>Before Thursday's exam, review every formula.</td>\n",
       "      <td>Boy, fireless faith was view every vividly.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.081967</td>\n",
       "      <td>{'rouge1': (0.14285714285714285, 0.16666666666...</td>\n",
       "      <td>-0.049474</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Theocracy reconsidered.</td>\n",
       "      <td>Hispanic approximated.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.001805</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>878</th>\n",
       "      <td>Mystery movies.</td>\n",
       "      <td>Military money.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.412289</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>416</th>\n",
       "      <td>Former employers.</td>\n",
       "      <td>Friminal employees.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.118247</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>107</th>\n",
       "      <td>Recent legislation.</td>\n",
       "      <td>Residential education.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.474056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>79</th>\n",
       "      <td>Marksmanship example.</td>\n",
       "      <td>Milkimeimeime:</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.238075</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>469</th>\n",
       "      <td>Were thoroughbreds.</td>\n",
       "      <td>We're three baths.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.238095</td>\n",
       "      <td>{'rouge1': (0.3333333333333333, 0.5, 0.4), 'ro...</td>\n",
       "      <td>0.160664</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>130</th>\n",
       "      <td>Einstein equation.</td>\n",
       "      <td>Understanding erosion.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.107104</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>111</th>\n",
       "      <td>Noise causes air pollution.</td>\n",
       "      <td>No crude cars day pollution.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.121951</td>\n",
       "      <td>{'rouge1': (0.2, 0.25, 0.22222222222222224), '...</td>\n",
       "      <td>0.271776</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>Wildfire near Sunshine forces park closures.</td>\n",
       "      <td>Wider do children first.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.035701</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>869</th>\n",
       "      <td>Savings and loans.</td>\n",
       "      <td>Swinging an ozone.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.040194</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>572</th>\n",
       "      <td>Two hour classes.</td>\n",
       "      <td>Do our glasses?</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.019725</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>681</th>\n",
       "      <td>State land grant colleges.</td>\n",
       "      <td>Set the rent closures.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.056620</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>212</th>\n",
       "      <td>Competition with Japan.</td>\n",
       "      <td>Commentational was depend.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.136322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>646</th>\n",
       "      <td>You're developing your imagination.</td>\n",
       "      <td>You evaluate for immigration.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.075145</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>58</th>\n",
       "      <td>Those musicians harmonize marvelously.</td>\n",
       "      <td>Losipulation harms miraculously.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.030252</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>612</th>\n",
       "      <td>Golden Retriever.</td>\n",
       "      <td>Children require.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>-0.026966</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>287</th>\n",
       "      <td>Summary points.</td>\n",
       "      <td>Some payments.</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>{'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....</td>\n",
       "      <td>0.234216</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   target_sentence  \\\n",
       "621                         Started investigating.   \n",
       "780         Temperature swing in about four hours.   \n",
       "6    Before Thursday's exam, review every formula.   \n",
       "0                          Theocracy reconsidered.   \n",
       "878                                Mystery movies.   \n",
       "416                              Former employers.   \n",
       "107                            Recent legislation.   \n",
       "79                           Marksmanship example.   \n",
       "469                            Were thoroughbreds.   \n",
       "130                             Einstein equation.   \n",
       "111                    Noise causes air pollution.   \n",
       "7     Wildfire near Sunshine forces park closures.   \n",
       "869                             Savings and loans.   \n",
       "572                              Two hour classes.   \n",
       "681                     State land grant colleges.   \n",
       "212                        Competition with Japan.   \n",
       "646            You're developing your imagination.   \n",
       "58          Those musicians harmonize marvelously.   \n",
       "612                              Golden Retriever.   \n",
       "287                                Summary points.   \n",
       "\n",
       "                                   pred_sentence  WER_scores  METEOR_scores  \\\n",
       "621                  Arted and Vietnamesepiring.    1.500000       0.000000   \n",
       "780       Trump's who are you sure in about her?    1.166667       0.302419   \n",
       "6    Boy, fireless faith was view every vividly.    1.000000       0.081967   \n",
       "0                         Hispanic approximated.    1.000000       0.000000   \n",
       "878                              Military money.    1.000000       0.000000   \n",
       "416                          Friminal employees.    1.000000       0.000000   \n",
       "107                       Residential education.    1.000000       0.000000   \n",
       "79                                Milkimeimeime:    1.000000       0.000000   \n",
       "469                           We're three baths.    1.000000       0.238095   \n",
       "130                       Understanding erosion.    1.000000       0.000000   \n",
       "111                 No crude cars day pollution.    1.000000       0.121951   \n",
       "7                       Wider do children first.    1.000000       0.000000   \n",
       "869                           Swinging an ozone.    1.000000       0.000000   \n",
       "572                              Do our glasses?    1.000000       0.000000   \n",
       "681                       Set the rent closures.    1.000000       0.000000   \n",
       "212                   Commentational was depend.    1.000000       0.000000   \n",
       "646                You evaluate for immigration.    1.000000       0.000000   \n",
       "58              Losipulation harms miraculously.    1.000000       0.000000   \n",
       "612                            Children require.    1.000000       0.000000   \n",
       "287                               Some payments.    1.000000       0.000000   \n",
       "\n",
       "                                          ROUGE_scores  BERTScore_F1_scores  \n",
       "621  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.146762  \n",
       "780  {'rouge1': (0.25, 0.3333333333333333, 0.285714...            -0.179631  \n",
       "6    {'rouge1': (0.14285714285714285, 0.16666666666...            -0.049474  \n",
       "0    {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.001805  \n",
       "878  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.412289  \n",
       "416  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.118247  \n",
       "107  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.474056  \n",
       "79   {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.238075  \n",
       "469  {'rouge1': (0.3333333333333333, 0.5, 0.4), 'ro...             0.160664  \n",
       "130  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.107104  \n",
       "111  {'rouge1': (0.2, 0.25, 0.22222222222222224), '...             0.271776  \n",
       "7    {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.035701  \n",
       "869  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.040194  \n",
       "572  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.019725  \n",
       "681  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.056620  \n",
       "212  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.136322  \n",
       "646  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.075145  \n",
       "58   {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.030252  \n",
       "612  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....            -0.026966  \n",
       "287  {'rouge1': (0.0, 0.0, 0.0), 'rouge2': (0.0, 0....             0.234216  "
      ]
     },
     "execution_count": 118,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Worst 20 predictions: rows with the highest word error rate first.\n",
    "(\n",
    "    results_df\n",
    "    .sort_values(by=\"WER_scores\", ascending=False)\n",
    "    .head(20)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 119,
   "id": "fcc7b615",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>target_sentence</th>\n",
       "      <th>pred_sentence</th>\n",
       "      <th>WER_scores</th>\n",
       "      <th>METEOR_scores</th>\n",
       "      <th>ROUGE_scores</th>\n",
       "      <th>BERTScore_F1_scores</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>307</th>\n",
       "      <td>That's the only way you can get to Canada from...</td>\n",
       "      <td>That's the only way you can get to here.</td>\n",
       "      <td>0.181818</td>\n",
       "      <td>0.828761</td>\n",
       "      <td>{'rouge1': (1.0, 0.8181818181818182, 0.9), 'ro...</td>\n",
       "      <td>0.716912</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>478</th>\n",
       "      <td>It might be kind of scary.</td>\n",
       "      <td>It might be kind of dangerous.</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.830000</td>\n",
       "      <td>{'rouge1': (0.8333333333333334, 0.833333333333...</td>\n",
       "      <td>0.874420</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>473</th>\n",
       "      <td>I don't mean to be knocking.</td>\n",
       "      <td>I don't mean to be talking.</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.830000</td>\n",
       "      <td>{'rouge1': (0.8333333333333334, 0.833333333333...</td>\n",
       "      <td>0.840472</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>500</th>\n",
       "      <td>It had to be November anyway.</td>\n",
       "      <td>It had to be nineteen anyway.</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.806667</td>\n",
       "      <td>{'rouge1': (0.8333333333333334, 0.833333333333...</td>\n",
       "      <td>0.552765</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>428</th>\n",
       "      <td>You had to have a supply.</td>\n",
       "      <td>You had to have a deploy.</td>\n",
       "      <td>0.166667</td>\n",
       "      <td>0.830000</td>\n",
       "      <td>{'rouge1': (0.8333333333333334, 0.833333333333...</td>\n",
       "      <td>0.729336</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>547</th>\n",
       "      <td>That's the freedom of choice.</td>\n",
       "      <td>That's the freedom of choice.</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>638</th>\n",
       "      <td>Trump has said there was no collusion.</td>\n",
       "      <td>Trump has said there was no collusion.</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.998542</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>636</th>\n",
       "      <td>That's the way I feel.</td>\n",
       "      <td>That's the way I feel.</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>634</th>\n",
       "      <td>I'll turn the radio on.</td>\n",
       "      <td>I'll turn the radio on.</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>349</th>\n",
       "      <td>I don't know the details.</td>\n",
       "      <td>I don't know the details.</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.996000</td>\n",
       "      <td>{'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....</td>\n",
       "      <td>1.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>400 rows × 6 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                       target_sentence  \\\n",
       "307  That's the only way you can get to Canada from...   \n",
       "478                         It might be kind of scary.   \n",
       "473                       I don't mean to be knocking.   \n",
       "500                      It had to be November anyway.   \n",
       "428                          You had to have a supply.   \n",
       "..                                                 ...   \n",
       "547                      That's the freedom of choice.   \n",
       "638             Trump has said there was no collusion.   \n",
       "636                             That's the way I feel.   \n",
       "634                            I'll turn the radio on.   \n",
       "349                          I don't know the details.   \n",
       "\n",
       "                                pred_sentence  WER_scores  METEOR_scores  \\\n",
       "307  That's the only way you can get to here.    0.181818       0.828761   \n",
       "478            It might be kind of dangerous.    0.166667       0.830000   \n",
       "473               I don't mean to be talking.    0.166667       0.830000   \n",
       "500             It had to be nineteen anyway.    0.166667       0.806667   \n",
       "428                 You had to have a deploy.    0.166667       0.830000   \n",
       "..                                        ...         ...            ...   \n",
       "547             That's the freedom of choice.    0.000000       0.996000   \n",
       "638    Trump has said there was no collusion.    0.000000       0.998542   \n",
       "636                    That's the way I feel.    0.000000       0.996000   \n",
       "634                   I'll turn the radio on.    0.000000       0.996000   \n",
       "349                 I don't know the details.    0.000000       0.996000   \n",
       "\n",
       "                                          ROUGE_scores  BERTScore_F1_scores  \n",
       "307  {'rouge1': (1.0, 0.8181818181818182, 0.9), 'ro...             0.716912  \n",
       "478  {'rouge1': (0.8333333333333334, 0.833333333333...             0.874420  \n",
       "473  {'rouge1': (0.8333333333333334, 0.833333333333...             0.840472  \n",
       "500  {'rouge1': (0.8333333333333334, 0.833333333333...             0.552765  \n",
       "428  {'rouge1': (0.8333333333333334, 0.833333333333...             0.729336  \n",
       "..                                                 ...                  ...  \n",
       "547  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....             1.000000  \n",
       "638  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....             1.000000  \n",
       "636  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....             1.000000  \n",
       "634  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....             1.000000  \n",
       "349  {'rouge1': (1.0, 1.0, 1.0), 'rouge2': (1.0, 1....             1.000000  \n",
       "\n",
       "[400 rows x 6 columns]"
      ]
     },
     "execution_count": 119,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Best 20 predictions (lowest WER), kept in descending-WER order like the worst-20 view.\n",
    "# 20 rows fit under pandas' display.max_rows, so every row is rendered instead of a truncated preview.\n",
    "results_df.sort_values(\"WER_scores\", ascending=False).tail(20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 120,
   "id": "8c6fddf7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([256.,   0.,   0.,   0.,   0.,   3.,   9.,  17.,  27.,  45.,   0.,\n",
       "         42.,   3.,  54.,  11.,   0.,  58.,   0.,   0.,  20.,   9.,   0.,\n",
       "         64.,   1.,   2.,  14.,  26.,   1.,  19.,   7.,   3.,   0.,   0.,\n",
       "         68.,   0.,   1.,   1.,   6.,   5.,   0.,  23.,   3.,   1.,   0.,\n",
       "         26.,   0.,   0.,   4.,   0.,   0.,  14.,   1.,   0.,   7.,   1.,\n",
       "          3.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
       "         23.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
       "          1.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
       "          0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,\n",
       "          1.]),\n",
       " array([0.   , 0.015, 0.03 , 0.045, 0.06 , 0.075, 0.09 , 0.105, 0.12 ,\n",
       "        0.135, 0.15 , 0.165, 0.18 , 0.195, 0.21 , 0.225, 0.24 , 0.255,\n",
       "        0.27 , 0.285, 0.3  , 0.315, 0.33 , 0.345, 0.36 , 0.375, 0.39 ,\n",
       "        0.405, 0.42 , 0.435, 0.45 , 0.465, 0.48 , 0.495, 0.51 , 0.525,\n",
       "        0.54 , 0.555, 0.57 , 0.585, 0.6  , 0.615, 0.63 , 0.645, 0.66 ,\n",
       "        0.675, 0.69 , 0.705, 0.72 , 0.735, 0.75 , 0.765, 0.78 , 0.795,\n",
       "        0.81 , 0.825, 0.84 , 0.855, 0.87 , 0.885, 0.9  , 0.915, 0.93 ,\n",
       "        0.945, 0.96 , 0.975, 0.99 , 1.005, 1.02 , 1.035, 1.05 , 1.065,\n",
       "        1.08 , 1.095, 1.11 , 1.125, 1.14 , 1.155, 1.17 , 1.185, 1.2  ,\n",
       "        1.215, 1.23 , 1.245, 1.26 , 1.275, 1.29 , 1.305, 1.32 , 1.335,\n",
       "        1.35 , 1.365, 1.38 , 1.395, 1.41 , 1.425, 1.44 , 1.455, 1.47 ,\n",
       "        1.485, 1.5  ]),\n",
       " <BarContainer object of 100 artists>)"
      ]
     },
     "execution_count": 120,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAigAAAGdCAYAAAA44ojeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8fJSN1AAAACXBIWXMAAA9hAAAPYQGoP6dpAAAhNklEQVR4nO3de3BTZcLH8V8vNIA2rQXatGtBwOUmNxekRlBRKgW6KCM7ArKIDsKqrTPQFaGK1KprWZZRR6bK6LqgMxQQR3AFFsUioFJQK4yA2JXbgkKKyrbhsvRCz/vHvmQ2UISEpHmSfj8zZ4acPEmeh2L69eQkibIsyxIAAIBBokM9AQAAgHMRKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMExvqCfijoaFBhw8fVnx8vKKiokI9HQAAcAksy9Lx48eVlpam6OhfPkYSloFy+PBhpaenh3oaAADAD4cOHdLVV1/9i2PCMlDi4+Ml/XeBdrs9xLMBAACXwu12Kz093fN7/JeEZaCcfVnHbrcTKAAAhJlLOT2Dk2QBAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGCc2FBPwETXzFztdfnAnOwQzQQAgOaJIygAAMA4BAoAADAOgQIAAIzjU6AUFRXphhtuUHx8vJKTkzVq1ChVVFR4jRk8eLCioqK8toceeshrzMGDB5Wdna3WrVsrOTlZ06dPV319/eWvBgAARASfTpLduHGjcnJydMMNN6i+vl5PPPGEhg4dqm+++UZXXHGFZ9zkyZP1zDPPeC63bt3a8+czZ84oOztbDodDmzdv1pEjR3TfffepRYsWev755wOwJAAAEO58CpS1a9d6XV60aJGSk5NVXl6uW265xbO/devWcjgcjd7Hhx9+qG+++UYfffSRUlJS1LdvXz377LOaMWOGnn76acXFxfmxDAAAEEku6xyU6upqSVJSUpLX/sWLF6tt27bq2bOn8vPzderUKc91ZWVl6tWrl1JSUjz7srKy5Ha7tWvXrsuZDgAAiBB+fw5KQ0ODpk6dqoEDB6pnz56e/ffee686dOigtLQ0ff3115oxY4YqKir07rvvSpJcLpdXnEjyXHa5XI0+Vk1NjWpqajyX3W63v9MGAABhwO9AycnJ0c6dO/Xpp5967Z8yZYrnz7169VJqaqqGDBmivXv3qnPnzn49VlFRkQoLC/2dKgAACDN+vcSTm5urVatW6eOPP9bVV1/9i2MzMjIkSXv27JEkORwOVVZWeo05e/lC563k5+erurrasx06dMifaQMAgDDhU6BYlqXc3FytWLFC69evV8eOHS96m+3bt0uSUlNTJUlOp1M7duzQ0aNHPWPWrVsnu92uHj16NHofNptNdrvdawMAAJHLp5d4cnJyVFJSovfee0/x8fGec0YSEhLUqlUr7d27VyUlJRoxYoTatGmjr7/+WtOmTdMtt9yi3r17S5KGDh2qHj16aMKECZo7d65cLpdmzZqlnJwc2Wy2wK8QAACEHZ+OoLz66quqrq7W4MGDlZqa6tmWLVsmSYqLi9NHH32koUOHqlu3bvrjH/+o0aNH6/333/fcR0xMjFatWqWYmBg5nU79/ve/13333ef1uSkAAKB58+kIimVZv3h9enq6Nm7ceNH76dChg9asWePLQwMAgGaE7+IBAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEI
FAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcXwKlKKiIt1www2Kj49XcnKyRo0apYqKCq8xp0+fVk5Ojtq0aaMrr7xSo0ePVmVlpdeYgwcPKjs7W61bt1ZycrKmT5+u+vr6y18NAACICD4FysaNG5WTk6MtW7Zo3bp1qqur09ChQ3Xy5EnPmGnTpun999/X8uXLtXHjRh0+fFh333235/ozZ84oOztbtbW12rx5s958800tWrRIs2fPDtyqAABAWIuyLMvy98Y//vijkpOTtXHjRt1yyy2qrq5Wu3btVFJSot/97neSpG+//Vbdu3dXWVmZbrzxRv3jH//Qb3/7Wx0+fFgpKSmSpAULFmjGjBn68ccfFRcXd9HHdbvdSkhIUHV1tex2u7/Tv6BrZq72unxgTnbAHwMAgObGl9/fl3UOSnV1tSQpKSlJklReXq66ujplZmZ6xnTr1k3t27dXWVmZJKmsrEy9evXyxIkkZWVlye12a9euXY0+Tk1Njdxut9cGAAAil9+B0tDQoKlTp2rgwIHq2bOnJMnlcikuLk6JiYleY1NSUuRyuTxj/jdOzl5/9rrGFBUVKSEhwbOlp6f7O20AABAG/A6UnJwc7dy5U0uXLg3kfBqVn5+v6upqz3bo0KGgPyYAAAidWH9ulJubq1WrVmnTpk26+uqrPfsdDodqa2tVVVXldRSlsrJSDofDM+bzzz/3ur+z7/I5O+ZcNptNNpvNn6kCAIAw5NMRFMuylJubqxUrVmj9+vXq2LGj1/X9+vVTixYtVFpa6tlXUVGhgwcPyul0SpKcTqd27Niho0ePesasW7dOdrtdPXr0uJy1AACACOHTEZScnByVlJTovffeU3x8vOeckYSEBLVq1UoJCQmaNGmS8vLylJSUJLvdrkcffVROp1M33nijJGno0KHq0aOHJkyYoLlz58rlcmnWrFnKycnhKAkAAJDkY6C8+uqrkqTBgwd77V+4cKHuv/9+SdKLL76o6OhojR49WjU1NcrKytIrr7ziGRsTE6NVq1bp4YcfltPp1BVXXKGJEyfqmWeeubyVAACAiHFZn4MSKnwOCgAA4afJPgcFAAAgGAgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYBwCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHEIFAAAYByfA2XTpk0aOXKk0tLSFBUVpZUrV3pdf//99ysqKsprGzZsmNeYY8eOafz48bLb7UpMTNSkSZN04sSJy1oIAACIHD4HysmTJ9WnTx8VFxdfcMywYcN05MgRz7ZkyRKv68ePH69du3Zp3bp1WrVqlTZt2qQpU6b4PnsAABCRYn29wfDhwzV8+PBfHGOz2eRwOBq9bvfu3Vq7dq2++OIL9e/fX5I0f/58jRgxQvPmzVNaWpqvUwIAABEmKOegbNiwQcnJyeratasefvhh/fzzz57rysrKlJiY6IkT
ScrMzFR0dLS2bt3a6P3V1NTI7XZ7bQAAIHIFPFCGDRumt956S6Wlpfrzn/+sjRs3avjw4Tpz5owkyeVyKTk52es2sbGxSkpKksvlavQ+i4qKlJCQ4NnS09MDPW0AAGAQn1/iuZixY8d6/tyrVy/17t1bnTt31oYNGzRkyBC/7jM/P195eXmey263m0gBACCCBf1txp06dVLbtm21Z88eSZLD4dDRo0e9xtTX1+vYsWMXPG/FZrPJbrd7bQAAIHIFPVC+//57/fzzz0pNTZUkOZ1OVVVVqby83DNm/fr1amhoUEZGRrCnAwAAwoDPL/GcOHHCczREkvbv36/t27crKSlJSUlJKiws1OjRo+VwOLR37149/vjjuvbaa5WVlSVJ6t69u4YNG6bJkydrwYIFqqurU25ursaOHcs7eAAAgCQ/jqB8+eWXuv7663X99ddLkvLy8nT99ddr9uzZiomJ0ddff60777xTXbp00aRJk9SvXz998sknstlsnvtYvHixunXrpiFDhmjEiBEaNGiQXnvttcCtCgAAhDWfj6AMHjxYlmVd8PoPPvjgoveRlJSkkpISXx8aAAA0E3wXDwAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIzjc6Bs2rRJI0eOVFpamqKiorRy5Uqv6y3L0uzZs5WamqpWrVopMzNT3333ndeYY8eOafz48bLb7UpMTNSkSZN04sSJy1oIAACIHD4HysmTJ9WnTx8VFxc3ev3cuXP18ssva8GCBdq6dauuuOIKZWVl6fTp054x48eP165du7Ru3TqtWrVKmzZt0pQpU/xfBQAAiCixvt5g+PDhGj58eKPXWZall156SbNmzdJdd90lSXrrrbeUkpKilStXauzYsdq9e7fWrl2rL774Qv3795ckzZ8/XyNGjNC8efOUlpZ2GcsBAACRIKDnoOzfv18ul0uZmZmefQkJCcrIyFBZWZkkqaysTImJiZ44kaTMzExFR0dr69atjd5vTU2N3G631wYAACJXQAPF5XJJklJSUrz2p6SkeK5zuVxKTk72uj42NlZJSUmeMecqKipSQkKCZ0tPTw/ktAEAgGHC4l08+fn5qq6u9myHDh0K9ZQAAEAQBTRQHA6HJKmystJrf2Vlpec6h8Oho0ePel1fX1+vY8eOecacy2azyW63e20AACByBTRQOnbsKIfDodLSUs8+t9utrVu3yul0SpKcTqeqqqpUXl7uGbN+/Xo1NDQoIyMjkNMBAABhyud38Zw4cUJ79uzxXN6/f7+2b9+upKQktW/fXlOnTtVzzz2nX//61+rYsaOeeuoppaWladSoUZKk7t27a9iwYZo8ebIWLFiguro65ebmauzYsbyDBwAASPIjUL788kvddtttnst5eXmSpIkTJ2rRokV6/PHHdfLkSU2ZMkVVVVUaNGiQ1q5dq5YtW3pus3jxYuXm5mrIkCGKjo7W6NGj9fLLLwdgOQAAIBJEWZZlhXoSvnK73UpISFB1dXVQzke5ZuZqr8sH5mQH/DEAAGhufPn9HRbv4gEAAM0LgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAo
AADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADBObKgnAES6a2auPm/fgTnZIZgJAIQPjqAAAADjECgAAMA4BAoAADAOgQIAAIzDSbIw2rknmHJyKQA0DxxBAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHD2qDz/jwNABAsHEEBQAAGIdAAQAAxiFQAACAcQgUAABgHE6SbcbOPdlV4oRXAIAZAn4E5emnn1ZUVJTX1q1bN8/1p0+fVk5Ojtq0aaMrr7xSo0ePVmVlZaCnAQAAwlhQXuK57rrrdOTIEc/26aefeq6bNm2a3n//fS1fvlwbN27U4cOHdffddwdjGgAAIEwF5SWe2NhYORyO8/ZXV1frjTfeUElJiW6//XZJ0sKFC9W9e3dt2bJFN954YzCmAwAAwkxQjqB89913SktLU6dOnTR+/HgdPHhQklReXq66ujplZmZ6xnbr1k3t27dXWVlZMKYCAADCUMCPoGRkZGjRokXq2rWrjhw5osLCQt18883auXOnXC6X4uLilJiY6HWblJQUuVyuC95nTU2NampqPJfdbnegpw0AAAwS8EAZPny458+9e/dWRkaGOnTooLffflutWrXy6z6LiopUWFgYqCkCAADDBf1zUBITE9WlSxft2bNHDodDtbW1qqqq8hpTWVnZ6DkrZ+Xn56u6utqzHTp0KMizBgAAoRT0QDlx4oT27t2r1NRU9evXTy1atFBpaann+oqKCh08eFBOp/OC92Gz2WS32702AAAQuQL+Es9jjz2mkSNHqkOHDjp8+LAKCgoUExOjcePGKSEhQZMmTVJeXp6SkpJkt9v16KOPyul08g4eAADgEfBA+f777zVu3Dj9/PPPateunQYNGqQtW7aoXbt2kqQXX3xR0dHRGj16tGpqapSVlaVXXnkl0NMAAABhLOCBsnTp0l+8vmXLliouLlZxcXGgHxoAAEQIvosnDPEdOgCASMe3GQMAAOMQKAAAwDgECgAAMA6BAgAAjMNJshGKE2kBAOGMIygAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjEOgAAAA4xAoAADAOAQKAAAwDoECAACMQ6AAAADjECgAAMA4BAoAADAOgQIAAIxDoAAAAOMQKAAAwDgECgAAMA6BAgAAjBMb6gng4q6ZuTrUUwAAoElxBAUAABiHIyhoFs49CnVgTnaIZhK+GjuSx98jgGDhCAoAADAOgQIAAIxDoAAAAONwDgoQ4Th3JPLwM0VzwBEUAABgHAIFAAAYh0ABAADG4RwUNAleMwcA+IJAAf4fH+YGAObgJR4AAGAcAgUAABiHl3gMwzcXAwDAERQAAGAgAgUAABiHQAEAAMbhHBREHM7jCX98bg4AjqAAAADjECgAAMA4BAoAADAO56DACx/3DgAwAUdQAACAcQgUAABgHF7iCbHm/JZYXk4CAFwIR1AAAIBxCBQAAGAcXuIBLoBPMw0/vGwIRA4CBQgTBBOA5oRAAcIYRwwARCrOQQEAAMbhCAqAoOIoDwB/hDRQiouL9Ze//EUul0t9+vTR/PnzNWDAgFBOKahC/ZknoX78SBApv2wjZR3+aM5rB8JJyF7iWbZsmfLy8lRQUKCvvvpKffr0UVZWlo4ePRqqKQEAAEOE7AjKCy+8oMmTJ+uBBx6QJC1YsECrV6/W3/72N82cOTNU00IY4sgQAESekARKbW2tysvLlZ+f79kXHR2tzMxMlZWVnTe+pqZGNTU1nsvV1dWSJLfbHZT5NdSc8rocqMc5
934D6dw5BuqxGlu7P38/lzKfS3ms9tOWX/R+LkWg7sfftZ97u54FHwRtPpfyd9/Y38fOwqyL3o8/6/f37+zcOZ47v8Zudyk/52A9jwSTvz8L05z7776xnymaRmPPQcH4eZz9d2pZ1sUHWyHwww8/WJKszZs3e+2fPn26NWDAgPPGFxQUWJLY2NjY2NjYImA7dOjQRVshLN7Fk5+fr7y8PM/lhoYGHTt2TG3atFFUVFRAH8vtdis9PV2HDh2S3W4P6H2biPVGNtYb2VhvZIvE9VqWpePHjystLe2iY0MSKG3btlVMTIwqKyu99ldWVsrhcJw33mazyWazee1LTEwM5hRlt9sj5h/EpWC9kY31RjbWG9kibb0JCQmXNC4k7+KJi4tTv379VFpa6tnX0NCg0tJSOZ3OUEwJAAAYJGQv8eTl5WnixInq37+/BgwYoJdeekknT570vKsHAAA0XyELlDFjxujHH3/U7Nmz5XK51LdvX61du1YpKSmhmpKk/76cVFBQcN5LSpGK9UY21hvZWG9ka27rPVeUZV3Ke30AAACaDl8WCAAAjEOgAAAA4xAoAADAOAQKAAAwTrMMlOLiYl1zzTVq2bKlMjIy9Pnnn//i+OXLl6tbt25q2bKlevXqpTVr1jTRTAPDl/W+/vrruvnmm3XVVVfpqquuUmZm5kX/fkzj68/3rKVLlyoqKkqjRo0K7gQDzNf1VlVVKScnR6mpqbLZbOrSpUtY/Zv2db0vvfSSunbtqlatWik9PV3Tpk3T6dOnm2i2l2fTpk0aOXKk0tLSFBUVpZUrV170Nhs2bNBvfvMb2Ww2XXvttVq0aFHQ5xkovq733Xff1R133KF27drJbrfL6XTqgw8C871WTcGfn+9Zn332mWJjY9W3b9+gzS/Uml2gLFu2THl5eSooKNBXX32lPn36KCsrS0ePHm10/ObNmzVu3DhNmjRJ27Zt06hRozRq1Cjt3LmziWfuH1/Xu2HDBo0bN04ff/yxysrKlJ6erqFDh+qHH35o4pn7x9f1nnXgwAE99thjuvnmm5topoHh63pra2t1xx136MCBA3rnnXdUUVGh119/Xb/61a+aeOb+8XW9JSUlmjlzpgoKCrR792698cYbWrZsmZ544okmnrl/Tp48qT59+qi4uPiSxu/fv1/Z2dm67bbbtH37dk2dOlUPPvhg2PzS9nW9mzZt0h133KE1a9aovLxct912m0aOHKlt27YFeaaB4et6z6qqqtJ9992nIUOGBGlmhgjM1/+FjwEDBlg5OTmey2fOnLHS0tKsoqKiRsffc889VnZ2tte+jIwM6w9/+ENQ5xkovq73XPX19VZ8fLz15ptvBmuKAeXPeuvr662bbrrJ+utf/2pNnDjRuuuuu5pgpoHh63pfffVVq1OnTlZtbW1TTTGgfF1vTk6Odfvtt3vty8vLswYOHBjUeQaDJGvFihW/OObxxx+3rrvuOq99Y8aMsbKysoI4s+C4lPU2pkePHlZhYWHgJxRkvqx3zJgx1qxZs6yCggKrT58+QZ1XKDWrIyi1tbUqLy9XZmamZ190dLQyMzNVVlbW6G3Kysq8xktSVlbWBcebxJ/1nuvUqVOqq6tTUlJSsKYZMP6u95lnnlFycrImTZrUFNMMGH/W+/e//11Op1M5OTlKSUlRz5499fzzz+vMmTNNNW2/+bPem266SeXl5Z6Xgfbt26c1a9ZoxIgRTTLnphbOz1eB0NDQoOPHj4fF85W/Fi5cqH379qmgoCDUUwm6sPg240D56aefdObMmfM+rTYlJUXffvtto7dxuVyNjne5XEGbZ6D4s95zzZgxQ2lpaec96ZnIn/V++umneuONN7R9+/YmmGFg+bPeffv2af369Ro/frzWrFmjPXv26JFHHlFdXZ3xT3j+rPfee+/VTz/9pEGDBsmyLNXX1+uhhx4Km5d4fHWh5yu3263//Oc/atWqVYhm1jTmzZunEydO6J577gn1VILiu+++08yZ
M/XJJ58oNjbyf303qyMo8M2cOXO0dOlSrVixQi1btgz1dALu+PHjmjBhgl5//XW1bds21NNpEg0NDUpOTtZrr72mfv36acyYMXryySe1YMGCUE8tKDZs2KDnn39er7zyir766iu9++67Wr16tZ599tlQTw0BVlJSosLCQr399ttKTk4O9XQC7syZM7r33ntVWFioLl26hHo6TSLyE+x/tG3bVjExMaqsrPTaX1lZKYfD0ehtHA6HT+NN4s96z5o3b57mzJmjjz76SL179w7mNAPG1/Xu3btXBw4c0MiRIz37GhoaJEmxsbGqqKhQ586dgzvpy+DPzzc1NVUtWrRQTEyMZ1/37t3lcrlUW1uruLi4oM75cviz3qeeekoTJkzQgw8+KEnq1auXTp48qSlTpujJJ59UdHRk/T/ahZ6v7HZ7RB89Wbp0qR588EEtX748LI72+uP48eP68ssvtW3bNuXm5kr67/OVZVmKjY3Vhx9+qNtvvz3EswysyPqv8yLi4uLUr18/lZaWevY1NDSotLRUTqez0ds4nU6v8ZK0bt26C443iT/rlaS5c+fq2Wef1dq1a9W/f/+mmGpA+Lrebt26aceOHdq+fbtnu/POOz3vgEhPT2/K6fvMn5/vwIEDtWfPHk+ISdI///lPpaamGh0nkn/rPXXq1HkRcjbOrAj8GrJwfr7y15IlS/TAAw9oyZIlys7ODvV0gsZut5/3fPXQQw+pa9eu2r59uzIyMkI9xcAL8Um6TW7p0qWWzWazFi1aZH3zzTfWlClTrMTERMvlclmWZVkTJkywZs6c6Rn/2WefWbGxsda8efOs3bt3WwUFBVaLFi2sHTt2hGoJPvF1vXPmzLHi4uKsd955xzpy5IhnO378eKiW4BNf13uucHsXj6/rPXjwoBUfH2/l5uZaFRUV1qpVq6zk5GTrueeeC9USfOLregsKCqz4+HhryZIl1r59+6wPP/zQ6ty5s3XPPfeEagk+OX78uLVt2zZr27ZtliTrhRdesLZt22b961//sizLsmbOnGlNmDDBM37fvn1W69atrenTp1u7d++2iouLrZiYGGvt2rWhWoJPfF3v4sWLrdjYWKu4uNjr+aqqqipUS/CJr+s9V6S/i6fZBYplWdb8+fOt9u3bW3FxcdaAAQOsLVu2eK679dZbrYkTJ3qNf/vtt60uXbpYcXFx1nXXXWetXr26iWd8eXxZb4cOHSxJ520FBQVNP3E/+frz/V/hFiiW5ft6N2/ebGVkZFg2m83q1KmT9ac//cmqr69v4ln7z5f11tXVWU8//bTVuXNnq2XLllZ6err1yCOPWP/+97+bfuJ++Pjjjxv97/HsGidOnGjdeuut592mb9++VlxcnNWpUydr4cKFTT5vf/m63ltvvfUXx5vOn5/v/4r0QImyrAg8zgkAAMJaszoHBQAAhAcCBQAAGIdAAQAAxiFQAACAcQgUAABgHAIFAAAYh0ABAADGIVAAAIBxCBQAAGAcAgUAABiHQAEAAMYhUAAAgHH+D8o0VlBfUQS6AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Histogram of per-sentence word error rate over every row of results_df.\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(results_df[\"WER_scores\"], bins=100)\n",
    "ax.set(title=\"Per-sentence WER distribution\", xlabel=\"WER\", ylabel=\"Number of sentences\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 121,
   "id": "b9488e47",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(np.float64(0.2),\n",
       " np.float64(0.26279362488002195),\n",
       " np.float64(0.2546811649239338))"
      ]
     },
     "execution_count": 121,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Summary statistics of per-sentence WER, labelled so each number is identifiable.\n",
    "# np.std defaults to ddof=0 (population std); pandas' Series.std would default to ddof=1.\n",
    "wer_scores = results_df[\"WER_scores\"]\n",
    "pd.Series(\n",
    "    {\"median\": np.median(wer_scores), \"mean\": np.mean(wer_scores), \"std\": np.std(wer_scores)},\n",
    "    name=\"WER_scores\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
