{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "fc5ed66f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.5.1+cu124\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "import os\n",
    "print(torch.__version__)\n",
    "import os\n",
    "\n",
    "import pickle\n",
    "import sys\n",
    "sys.path.append('../')\n",
    "import numpy as np\n",
    "import os \n",
    "import time\n",
    "import tqdm\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from dataset import PHONE_DEF_SIL\n",
    "from phonemeLM import PhonemeTokenizer, LightningGPT2PhonemeModel\n",
    "from dataset import PHONE_DEF_SIL, idsToPhonemes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fb03857e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluation run whose outputs are analysed (uncomment the other path to switch runs).\n",
    "results_dir = \"../results/gru_ctc_mfcc/\"\n",
    "# results_dir = \"../results/sm_gru_ctc_diphones/\"\n",
    "\n",
    "# Saved CTC logits, nested per batch (flattened in the next cell).\n",
    "# NOTE: unpickling can execute arbitrary code - only load result files you trust.\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"rb\") as logits_file:\n",
    "    pred_logits = pickle.load(logits_file)\n",
    "\n",
    "# Per-utterance results; later cells read its \"True Sentence\", \"True Phonemes\"\n",
    "# and \"Predicted Phonemes\" columns.\n",
    "df = pd.read_csv(os.path.join(results_dir, \"results.csv\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "a8fe68a3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per utterance; later cells pair these entries with the rows of df\n",
    "# by position.\n",
    "logits_unfolded = [\n",
    "    utterance_logits\n",
    "    for batch_logits in pred_logits\n",
    "    for utterance_logits in batch_logits\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9d3680a",
   "metadata": {},
   "source": [
    "## PhonemeLM loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "df73f7ab",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the LM on the first GPU when one is available, otherwise on the CPU.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "dda18789",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Embedding size: torch.Size([44, 768])\n",
      "LM head size: 44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_2801248/1286599691.py:29: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(\"phonemeLM/best_model-v1.ckpt\")[\"state_dict\"])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "LightningGPT2PhonemeModel(\n",
       "  (model): GPT2LMHeadModel(\n",
       "    (transformer): GPT2Model(\n",
       "      (wte): Embedding(44, 768)\n",
       "      (wpe): Embedding(1024, 768)\n",
       "      (drop): Dropout(p=0.1, inplace=False)\n",
       "      (h): ModuleList(\n",
       "        (0-11): 12 x GPT2Block(\n",
       "          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (attn): GPT2Attention(\n",
       "            (c_attn): Conv1D(nf=2304, nx=768)\n",
       "            (c_proj): Conv1D(nf=768, nx=768)\n",
       "            (attn_dropout): Dropout(p=0.1, inplace=False)\n",
       "            (resid_dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "          (mlp): GPT2MLP(\n",
       "            (c_fc): Conv1D(nf=3072, nx=768)\n",
       "            (c_proj): Conv1D(nf=768, nx=3072)\n",
       "            (act): NewGELUActivation()\n",
       "            (dropout): Dropout(p=0.1, inplace=False)\n",
       "          )\n",
       "        )\n",
       "      )\n",
       "      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (lm_head): Linear(in_features=768, out_features=44, bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import GPT2LMHeadModel, GPT2Config\n",
    "\n",
    "tokenizer = PhonemeTokenizer(PHONE_DEF_SIL)\n",
    "\n",
    "# 1. Load the pretrained GPT-2 config and weights\n",
    "config = GPT2Config.from_pretrained(\"gpt2\")\n",
    "gpt_model = GPT2LMHeadModel.from_pretrained(\"gpt2\", config=config)\n",
    "\n",
    "# 2. Adjust the vocabulary size to the phoneme tokenizer\n",
    "phoneme_vocab_size = len(tokenizer)  # assumes tokenizer.__len__() works\n",
    "config.vocab_size = phoneme_vocab_size\n",
    "\n",
    "# 3. Resize the token embeddings; the printed sizes confirm the LM head follows\n",
    "gpt_model.resize_token_embeddings(phoneme_vocab_size)\n",
    "\n",
    "print(\"Embedding size:\", gpt_model.transformer.wte.weight.shape)\n",
    "print(\"LM head size:\", gpt_model.lm_head.out_features)\n",
    "\n",
    "model = LightningGPT2PhonemeModel(\n",
    "    gpt2_model=gpt_model,\n",
    "    tokenizer=tokenizer,\n",
    "    learning_rate=5e-5,\n",
    "    weight_decay=1e-5,\n",
    "    max_length=64\n",
    ")\n",
    "\n",
    "# Load the fine-tuned weights. map_location lets a checkpoint written on a GPU\n",
    "# load onto `device`; weights_only=False is explicit because a Lightning\n",
    "# checkpoint may hold non-tensor objects (so the file must be trusted).\n",
    "checkpoint_path = \"phonemeLM/best_model-v1.ckpt\"\n",
    "model.load_state_dict(\n",
    "    torch.load(checkpoint_path, map_location=device, weights_only=False)[\"state_dict\"]\n",
    ")\n",
    "\n",
    "# Inference only: eval() disables dropout so LM scores are deterministic,\n",
    "# which the memoised NLL used by the decoder relies on.\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "53987350",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0    [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...\n",
       "1    [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...\n",
       "2    [S, OW, SIL, R, UW, L, Z, SIL, W, IY, SIL, M, ...\n",
       "3    [L, AO, R, IY, Z, SIL, K, AA, S, T, UW, M, SIL...\n",
       "4    [DH, AH, SIL, T, UW, TH, SIL, F, EH, R, IY, SI...\n",
       "Name: True Phonemes, dtype: object"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity-check the phoneme LM on the first five ground-truth sequences, which\n",
    "# the CSV stores as stringified Python lists.\n",
    "import ast\n",
    "\n",
    "phoneme_batch = df[\"True Phonemes\"].map(ast.literal_eval).head(5)\n",
    "phoneme_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "b158d208",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[['TH', 'IY', 'AA', 'K', 'R', 'AH', 'S', 'IY', 'AH', 'L', 'SIL', 'W', 'AA', 'Z', 'SIL', 'DH', 'AH', 'SIL', 'OW', 'N', 'L', 'IY', 'SIL', 'W', 'AH', 'N', 'SIL', 'HH', 'UW', 'SIL', 'HH', 'AE', 'D', 'SIL', 'EH', 'V', 'ER']]\n"
     ]
    }
   ],
   "source": [
    "# Round-trip the batch through the phoneme tokenizer (encode, then decode).\n",
    "input_ids = model.tokenizer.batch_encode(phoneme_batch.tolist())\n",
    "decoded_input = model.tokenizer.batch_decode(input_ids)\n",
    "\n",
    "# Let the LM continue the first utterance from its first 8 phonemes.\n",
    "print(model.generate([phoneme_batch[0][:8]], max_length=30))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a0438c63",
   "metadata": {},
   "source": [
    "## LM loss (mean negative log-likelihood) of phoneme sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "787e600e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.5279407501220703"
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Score the first five ground-truth and predicted sequences with the phoneme LM.\n",
    "pred_phoneme_batch = df[\"Predicted Phonemes\"].apply(ast.literal_eval).iloc[:5]\n",
    "\n",
    "# Encode the predicted batch (not the ground-truth one).\n",
    "pred_input_ids = model.tokenizer.batch_encode(pred_phoneme_batch.tolist())\n",
    "\n",
    "# The model returns (loss, ...); no gradients are needed for scoring.\n",
    "with torch.no_grad():\n",
    "    loss, _ = model(phoneme_batch)\n",
    "    pred_loss, _ = model(pred_phoneme_batch)\n",
    "nll = loss.item()\n",
    "pred_nll = pred_loss.item()\n",
    "\n",
    "# (true, predicted): lower means the LM finds the sequences more likely.\n",
    "nll, pred_nll"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "id": "387d5e29",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0      [TH, IY, AA, K, R, AH, S, IY, SIL, R, IY, K, A...\n",
       "1      [R, IH, CH, SIL, P, ER, CH, AH, S, T, SIL, S, ...\n",
       "2      [S, OW, SIL, R, UW, L, Z, SIL, W, IY, SIL, M, ...\n",
       "3      [L, AO, R, IY, Z, SIL, K, AA, S, T, UW, M, SIL...\n",
       "4      [DH, AH, SIL, T, UW, TH, SIL, F, EH, R, IY, SI...\n",
       "                             ...                        \n",
       "875    [Y, AO, R, SIL, T, Y, UW, IH, SH, AH, N, SIL, ...\n",
       "876    [G, EH, T, IH, NG, SIL, P, L, EH, JH, SIL, SH,...\n",
       "877    [IH, F, SIL, Y, UW, SIL, HH, AE, V, SIL, EH, N...\n",
       "878     [M, IH, S, T, ER, IY, SIL, M, UW, V, IY, Z, SIL]\n",
       "879    [AY, SIL, K, AE, N, T, SIL, R, IH, M, EH, M, B...\n",
       "Name: True Phonemes, Length: 880, dtype: object"
      ]
     },
     "execution_count": 175,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# All ground-truth phoneme sequences, parsed from their stringified-list form.\n",
    "df[\"True Phonemes\"].map(ast.literal_eval)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 166,
   "id": "3e6db2fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Beam GPT‑2 + CTC Phoneme Decoder (v6 – phone‑aware debug)\n",
    "=========================================================\n",
    "\n",
    "This version separates **model token IDs** from **readable phoneme labels** so\n",
    "that debug prints (and optionally final outputs) use `idsToPhonemes` (or any\n",
    "custom mapping) rather than the GPT‑2 tokenizer, which may include special\n",
    "sub‑word symbols.\n",
    "\n",
    "What’s new\n",
    "----------\n",
    "* `phone_decoder` parameter – pass a function that converts `List[int] →\n",
    "  List[str]` (e.g. your `idsToPhonemes`).\n",
    "* All verbose prints (arg‑max, candidate list, best beam) now use\n",
    "  `phone_decoder` if supplied.\n",
    "* Optional `return_phonemes` flag: if `True`, final hypotheses are returned as\n",
    "  **phoneme lists**; else we fall back to `tokenizer.decode` (unchanged).\n",
    "\n",
    "Quick usage\n",
    "-----------\n",
    "```python\n",
    "best = decode_phonemes_beam_gpt2_ctc(\n",
    "    logits_sample[0], tokenizer, model,\n",
    "    phone_decoder=idsToPhonemes,  # <- new\n",
    "    beam_width=10, topk=3,\n",
    "    ctc_weight=1.0, lm_weight=0.0,\n",
    "    verbose=True,\n",
    "    return_phonemes=True,\n",
    ")\n",
    "print(best[0][0])   # list of phones\n",
    "```\n",
    "\"\"\"\n",
    "\n",
    "from __future__ import annotations\n",
    "\n",
    "import functools\n",
    "import math\n",
    "from typing import Callable, List, Optional, Sequence, Tuple\n",
    "\n",
    "import torch\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# Segment utilities\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def _segments(ids: torch.Tensor) -> List[Tuple[int, int]]:\n",
    "    \"\"\"Split a 1-D label sequence into maximal runs of equal values.\n",
    "\n",
    "    Returns ``(start, end)`` index pairs (end exclusive) that cover *ids* in\n",
    "    order; an empty tensor yields an empty list.\n",
    "    \"\"\"\n",
    "    if ids.numel() == 0:\n",
    "        return []\n",
    "    # A new run starts wherever a label differs from its predecessor.\n",
    "    starts = [0] + (torch.nonzero(ids[1:] != ids[:-1]).flatten() + 1).tolist()\n",
    "    ends = starts[1:] + [ids.size(0)]\n",
    "    return list(zip(starts, ends))\n",
    "\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# GPT‑2 NLL (cached)\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def _gpt2_nll(model: torch.nn.Module, ids: Tuple[int, ...]) -> float:\n",
    "    \"\"\"Summed next-token negative log-likelihood of *ids* under ``model.model``.\n",
    "\n",
    "    ``model.model`` is called with ``input_ids``/``labels`` set to *ids*; its\n",
    "    ``.loss`` (a mean over the ``len(ids) - 1`` shifted targets) is scaled back\n",
    "    up to a sum. Sequences shorter than two tokens have no target and score 0.0;\n",
    "    a mean over zero targets would otherwise be NaN and break beam sorting.\n",
    "    \"\"\"\n",
    "    n_targets = len(ids) - 1\n",
    "    if n_targets < 1:\n",
    "        return 0.0\n",
    "    device = next(model.parameters()).device\n",
    "    tensor = torch.tensor(ids, dtype=torch.long, device=device).unsqueeze(0)\n",
    "    with torch.no_grad():\n",
    "        out = model.model(input_ids=tensor, labels=tensor)\n",
    "    return out.loss.item() * n_targets\n",
    "\n",
    "# Memoised on (model, ids). The cache holds a reference to *model* and assumes\n",
    "# its scores are deterministic, i.e. the model is in eval mode.\n",
    "_gpt2_nll_cached = functools.lru_cache(maxsize=8192)(_gpt2_nll)\n",
    "\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# Helper: ID → printable string\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def _id2str(token_id: int, phone_decoder: Optional[Callable[[List[int]], List[str]]], tokenizer) -> str:\n",
    "    \"\"\"Printable label for a single token id.\n",
    "\n",
    "    Prefers ``phone_decoder`` when one is given, else ``tokenizer.decode``.\n",
    "    \"\"\"\n",
    "    if phone_decoder is None:\n",
    "        return tokenizer.decode([token_id])\n",
    "    labels = phone_decoder([token_id])\n",
    "    return labels[0]\n",
    "\n",
    "\n",
    "def _ids2str(token_ids: Sequence[int], phone_decoder: Optional[Callable[[List[int]], List[str]]], tokenizer) -> str:\n",
    "    \"\"\"Printable string for a sequence of token ids.\n",
    "\n",
    "    With ``phone_decoder`` the labels are joined by single spaces; otherwise the\n",
    "    whole id list is handed to ``tokenizer.decode``.\n",
    "    \"\"\"\n",
    "    id_list = list(token_ids)\n",
    "    if phone_decoder is None:\n",
    "        return tokenizer.decode(id_list)\n",
    "    return \" \".join(phone_decoder(id_list))\n",
    "\n",
    "\n",
    "# -----------------------------------------------------------------------------\n",
    "# Main decode\n",
    "# -----------------------------------------------------------------------------\n",
    "\n",
    "def decode_phonemes_beam_gpt2_ctc(\n",
    "    logits: torch.Tensor,\n",
    "    tokenizer,\n",
    "    gpt2_model: torch.nn.Module,\n",
    "    *,\n",
    "    beam_width: int = 10,\n",
    "    topk: int = 3,\n",
    "    n_best: int = 5,\n",
    "    blank_idx: int = 0,\n",
    "    ctc_weight: float = 1.0,\n",
    "    lm_weight: float = 1.0,\n",
    "    verbose: bool = False,\n",
    "    phone_decoder: Optional[Callable[[List[int]], List[str]]] = None,\n",
    "    return_phonemes: bool = False,\n",
    "    lm_id_map: Optional[Callable[[List[int]], Sequence[int]]] = None,\n",
    ") -> List[Tuple[str | List[str], float]]:\n",
    "    \"\"\"Beam search with combined CTC + GPT-2 costs.\n",
    "\n",
    "    The frame-wise arg-max path of *logits* is cut into runs of equal labels.\n",
    "    Blank runs are dropped and every remaining run emits exactly one token,\n",
    "    chosen among its ``topk`` best non-blank classes, each scored by the log of\n",
    "    its posterior averaged over the run. A hypothesis costs\n",
    "    ``ctc_weight * (-summed CTC log-prob) + lm_weight * (summed GPT-2 NLL)``;\n",
    "    lower is better.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    logits : torch.Tensor\n",
    "        Frame-level CTC scores with classes on the last axis; assumed (T, C)\n",
    "        - TODO confirm against the acoustic model.\n",
    "    tokenizer\n",
    "        Only used for ``tokenizer.decode`` when *phone_decoder* is None.\n",
    "    gpt2_model : torch.nn.Module\n",
    "        Wrapper whose ``.model`` is the GPT-2 LM scored via ``_gpt2_nll_cached``.\n",
    "    beam_width : int\n",
    "        Hypotheses kept after each run.\n",
    "    topk : int\n",
    "        Non-blank candidates considered per run.\n",
    "    n_best : int\n",
    "        Maximum number of hypotheses returned.\n",
    "    blank_idx : int\n",
    "        Index of the CTC blank class.\n",
    "    ctc_weight, lm_weight : float\n",
    "        Cost weights; ``lm_weight == 0`` skips GPT-2 scoring entirely.\n",
    "    verbose : bool\n",
    "        Print per-run candidates and the best partial hypothesis.\n",
    "    phone_decoder : Callable or None\n",
    "        Function mapping `List[int] -> List[str]` (e.g. `idsToPhonemes`). Used\n",
    "        **only for display/return**, not for GPT-2 input.\n",
    "    return_phonemes : bool, default=False\n",
    "        If True *and* `phone_decoder` is given, the first element of each result\n",
    "        tuple is a *List[str]* from `phone_decoder`; otherwise it is a string.\n",
    "    lm_id_map : Callable or None\n",
    "        Optional mapping from CTC class ids to LM token ids, applied before GPT-2\n",
    "        scoring. None (the default) feeds the CTC ids to GPT-2 unchanged.\n",
    "        NOTE(review): the CTC head (41 classes per the notebook comments) and the\n",
    "        44-token phoneme LM may not share an id space - confirm, or pass a map.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of (hypothesis, cost)\n",
    "        At most ``n_best`` entries sorted by ascending cost.\n",
    "    \"\"\"\n",
    "    as_phonemes = return_phonemes and phone_decoder is not None\n",
    "\n",
    "    def _cost(ctc_lp: float, lm_nll: float) -> float:\n",
    "        # Weighted total cost of a hypothesis (lower is better).\n",
    "        return ctc_weight * (-ctc_lp) + lm_weight * lm_nll\n",
    "\n",
    "    def _lm_nll(ids: Tuple[int, ...]) -> float:\n",
    "        # Summed GPT-2 NLL. A single token has no next-token target (a mean loss\n",
    "        # over zero targets is NaN and would scramble the beam sort), so it costs 0.\n",
    "        if lm_weight == 0 or len(ids) < 2:\n",
    "            return 0.0\n",
    "        lm_ids = ids if lm_id_map is None else tuple(lm_id_map(list(ids)))\n",
    "        return _gpt2_nll_cached(gpt2_model, lm_ids)\n",
    "\n",
    "    log_probs = torch.log_softmax(logits, dim=-1)\n",
    "    argmax_ids = logits.argmax(dim=-1)\n",
    "\n",
    "    # Runs of the arg-max path; blank runs emit nothing under CTC.\n",
    "    segs = [seg for seg in _segments(argmax_ids) if argmax_ids[seg[0]].item() != blank_idx]\n",
    "    if not segs:\n",
    "        if verbose:\n",
    "            print(\"[Decoder] No non‑blank segments – returning empty.\")\n",
    "        # Same element type as the non-empty return at the end.\n",
    "        return [([], 0.0)] if as_phonemes else [(\"\", 0.0)]\n",
    "\n",
    "    # Pre-compute the candidates of every run.\n",
    "    seg_cands: List[List[Tuple[int, float]]] = []\n",
    "    for seg_idx, (s, e) in enumerate(segs):\n",
    "        # log(mean over the run's frames of p(class))\n",
    "        lp_avg = torch.logsumexp(log_probs[s:e], dim=0) - math.log(e - s)\n",
    "        # One extra so that dropping the blank can still leave `topk` candidates.\n",
    "        vals, idx = torch.topk(lp_avg, k=min(topk + 1, lp_avg.size(0)))\n",
    "        cands = [\n",
    "            (i.item(), v.item())\n",
    "            for i, v in zip(idx, vals)\n",
    "            if i.item() != blank_idx\n",
    "        ][:topk]\n",
    "        if not cands:\n",
    "            # Fallback: keep the single best class (which may be the blank).\n",
    "            best_idx = lp_avg.argmax().item()\n",
    "            cands = [(best_idx, lp_avg[best_idx].item())]\n",
    "        seg_cands.append(cands)\n",
    "\n",
    "        if verbose:\n",
    "            argmax_str = _id2str(argmax_ids[s].item(), phone_decoder, tokenizer)\n",
    "            printable = \", \".join(\n",
    "                f\"{_id2str(tid, phone_decoder, tokenizer)}:{lp:.2f}\" for tid, lp in cands\n",
    "            )\n",
    "            print(f\"Step {seg_idx:02d} | argmax={argmax_str:3s} | candidates [{printable}]\")\n",
    "\n",
    "    # Beam = (ids_tuple, ctc_lp_sum, lm_nll)\n",
    "    beams: List[Tuple[Tuple[int, ...], float, float]] = [(tuple(), 0.0, 0.0)]\n",
    "\n",
    "    for cands in seg_cands:\n",
    "        new_beams: List[Tuple[Tuple[int, ...], float, float]] = []\n",
    "        for ids, ctc_lp, _ in beams:\n",
    "            for token_id, seg_lp in cands:\n",
    "                ext_ids = ids + (token_id,)\n",
    "                new_beams.append((ext_ids, ctc_lp + seg_lp, _lm_nll(ext_ids)))\n",
    "        new_beams.sort(key=lambda beam: _cost(beam[1], beam[2]))\n",
    "        beams = new_beams[:beam_width]\n",
    "\n",
    "        if verbose:\n",
    "            top_ids, top_ctc, top_lm = beams[0]\n",
    "            hypo_str = _ids2str(top_ids, phone_decoder, tokenizer)\n",
    "            cost = _cost(top_ctc, top_lm)\n",
    "            print(\n",
    "                f\"           best→ {hypo_str} | cost={cost:.2f} (CTC {top_ctc:.2f}, LM {top_lm:.2f})\"\n",
    "            )\n",
    "\n",
    "    # `beams` is already sorted by cost after the last run.\n",
    "    results: List[Tuple[str | List[str], float]] = []\n",
    "    for ids, ctc_lp, lm in beams[:n_best]:\n",
    "        if as_phonemes:\n",
    "            hypothesis = phone_decoder(list(ids))\n",
    "        else:\n",
    "            hypothesis = _ids2str(ids, phone_decoder, tokenizer)\n",
    "        results.append((hypothesis, _cost(ctc_lp, lm)))\n",
    "\n",
    "    if verbose:\n",
    "        print(\"[Decoder] Finished. Top hypotheses:\")\n",
    "        for s, c in results:\n",
    "            print(f\"  {s} | {c:.2f}\")\n",
    "    return results\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f29a2721",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "It's a package deal.\n",
      "[LM guided] ['IH', 'T', 'S', 'SIL', 'AY', 'SIL', 'R', 'EY', 'G', 'IH', 'SIL', 'D', 'IY', 'L', 'SIL']\n",
      "[CTC Argmax] ['IH', 'T', 'S', 'SIL', 'AH', 'SIL', 'B', 'EY', 'K', 'AH', 'SIL', 'D', 'IY', 'L', 'SIL']\n"
     ]
    }
   ],
   "source": [
    "# Inspect one utterance: LM-guided beam search vs. greedy CTC decoding.\n",
    "idx = 100\n",
    "print(df[\"True Sentence\"].iloc[idx])\n",
    "logits_sample = torch.as_tensor(logits_unfolded[idx])  # (T, 41) frame-level CTC logits\n",
    "best = decode_phonemes_beam_gpt2_ctc(\n",
    "    logits_sample,\n",
    "    tokenizer,\n",
    "    model,\n",
    "    phone_decoder=idsToPhonemes,\n",
    "    beam_width=10,\n",
    "    topk=5,\n",
    "    ctc_weight=0.6,\n",
    "    lm_weight=1.0,                 # set to 0.0 for a CTC-only sanity check\n",
    "    verbose=False,\n",
    "    return_phonemes=True,\n",
    ")\n",
    "print(\"[LM guided]\", best[0][0])\n",
    "\n",
    "# Greedy CTC decoding: collapse repeated frames, then drop blanks (id 0),\n",
    "# exactly as the evaluation loop does.\n",
    "argmax_pred = torch.unique_consecutive(logits_sample.argmax(-1))\n",
    "argmax_pred = argmax_pred[argmax_pred != 0]\n",
    "print(\"[CTC Argmax]\", idsToPhonemes(argmax_pred.cpu().numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "id": "5047ec89",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/880 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 880/880 [29:18<00:00,  2.00s/it] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average CER (argmax): 0.16652904297839005\n",
      "Average CER (LM-guided): 0.36025332475662336\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "from edit_distance import SequenceMatcher\n",
    "from tqdm import tqdm\n",
    "import ast\n",
    "\n",
    "\n",
    "def phoneme_error_rate(reference: list, hypothesis: list) -> float:\n",
    "    \"\"\"Edit distance between two phoneme sequences divided by the reference length.\n",
    "\n",
    "    An empty reference is scored 1.0 by convention (the rate is undefined there).\n",
    "    \"\"\"\n",
    "    if len(reference) == 0:\n",
    "        return 1.0\n",
    "    return SequenceMatcher(a=reference, b=hypothesis).distance() / len(reference)\n",
    "\n",
    "\n",
    "# Parse the true phoneme sequences (stored in the CSV as stringified lists)\n",
    "true_phonemes_all = df[\"True Phonemes\"].apply(ast.literal_eval).tolist()\n",
    "# Logits and CSV rows are paired by position below, so the counts must agree.\n",
    "assert len(logits_unfolded) == len(df), (len(logits_unfolded), len(df))\n",
    "\n",
    "# Output storage\n",
    "decoded_argmax_all = []\n",
    "decoded_lm_all = []\n",
    "cer_argmax_list = []\n",
    "cer_lm_list = []\n",
    "\n",
    "# Loop over all samples (about 30 min for 880 utterances in the saved run)\n",
    "for idx in tqdm(range(len(df))):\n",
    "    logits_sample = torch.as_tensor(logits_unfolded[idx])  # (T, 41)\n",
    "\n",
    "    # ----- Greedy arg-max CTC decoding: collapse repeats, drop blanks (id 0) -----\n",
    "    argmax_pred = torch.unique_consecutive(logits_sample.argmax(-1))\n",
    "    argmax_pred = argmax_pred[argmax_pred != 0]\n",
    "    decoded_argmax = idsToPhonemes(argmax_pred.cpu().numpy())\n",
    "    decoded_argmax_all.append(decoded_argmax)\n",
    "\n",
    "    # ----- LM-guided beam search decoding -----\n",
    "    best = decode_phonemes_beam_gpt2_ctc(\n",
    "        logits_sample,\n",
    "        tokenizer,\n",
    "        model,\n",
    "        phone_decoder=idsToPhonemes,\n",
    "        beam_width=5,\n",
    "        topk=3,\n",
    "        ctc_weight=0.6,\n",
    "        lm_weight=1.0,\n",
    "        verbose=False,\n",
    "        return_phonemes=True,\n",
    "    )\n",
    "    decoded_lm = best[0][0]\n",
    "    decoded_lm_all.append(decoded_lm)\n",
    "\n",
    "    # ----- Error rates against the ground truth -----\n",
    "    true_phonemes = true_phonemes_all[idx]\n",
    "    cer_argmax_list.append(phoneme_error_rate(true_phonemes, decoded_argmax))\n",
    "    cer_lm_list.append(phoneme_error_rate(true_phonemes, decoded_lm))\n",
    "\n",
    "# ----- Summary (\"CER\" here is a phoneme-level error rate) -----\n",
    "print(\"Average CER (argmax):\", sum(cer_argmax_list) / len(cer_argmax_list))\n",
    "print(\"Average CER (LM-guided):\", sum(cer_lm_list) / len(cer_lm_list))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "747fa345",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Preview the first LM-guided hypotheses (the full list has one entry per utterance).\n",
    "decoded_lm_all[:5]"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
