{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "59debae9",
   "metadata": {},
   "source": [
    "## Here we target phoneme and diphoneme predictions as two distict tasks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e2787057",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE\n",
    "import torchaudio  \n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "# from model.ctc_modelling import Light\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "b672ec04",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "99ce5c68",
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "from collections import defaultdict\n",
    "\n",
    "def build_diphone_vocab(phone_def_sil):\n",
    "    diphones = [f\"{a}→{b}\" for a, b in product(phone_def_sil, repeat=2)]\n",
    "    diphone_to_idx = {d: i+1 for i, d in enumerate(sorted(diphones))}  # +1 for CTC blank at idx 0\n",
    "    idx_to_diphone = {i: d for d, i in diphone_to_idx.items()}\n",
    "    return diphone_to_idx, idx_to_diphone\n",
    "\n",
    "diphone_to_idx, idx_to_diphone = build_diphone_vocab(PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "def build_phoneme_to_diphone_ids(PHONE_DEF_SIL, diphone_to_idx):\n",
    "    phoneme_to_diphone_ids = defaultdict(list)\n",
    "\n",
    "    for diphone_str, idx in diphone_to_idx.items():\n",
    "        if idx == 0:\n",
    "            continue  # skip CTC blank\n",
    "        left, right = diphone_str.split(\"→\")\n",
    "        phoneme_to_diphone_ids[left].append(idx)   # p is source: p→X\n",
    "        phoneme_to_diphone_ids[right].append(idx)  # p is target: X→p\n",
    "\n",
    "    return phoneme_to_diphone_ids\n",
    "\n",
    "\n",
    "def convert_phoneme_seq_to_diphone_ids(batch_phone_seq, idsToPhonemes, diphone_to_idx):\n",
    "    \"\"\"\n",
    "    Converts batch of phoneme index sequences to DCoND-style diphone ID sequences.\n",
    "    \n",
    "    Args:\n",
    "        batch_phone_seq: (B, T) tensor of phoneme indices\n",
    "        idsToPhonemes: callable to convert phoneme ids to phoneme strings\n",
    "        diphone_to_idx: dict mapping 'A→B' to int\n",
    "    \n",
    "    Returns:\n",
    "        diphone_seqs: list of 1D tensors with diphone IDs\n",
    "        diphone_lengths: list of ints (lengths per sequence)\n",
    "    \"\"\"\n",
    "    diphone_seqs = []\n",
    "    diphone_lengths = []\n",
    "\n",
    "    batch_size = batch_phone_seq.shape[0]\n",
    "\n",
    "    for i in range(batch_size):\n",
    "        phoneme_ids = batch_phone_seq[i].cpu().numpy()\n",
    "        phonemes = idsToPhonemes(phoneme_ids)\n",
    "        phonemes = [p for p in phonemes if p != \"<pad>\" and p != \"\"]\n",
    "\n",
    "        diphones = []\n",
    "        for j in range(len(phonemes)):\n",
    "            if j > 0:\n",
    "                diphones.append(f\"{phonemes[j-1]}→{phonemes[j]}\")\n",
    "            diphones.append(f\"{phonemes[j]}→{phonemes[j]}\")\n",
    "            if j < len(phonemes) - 1:\n",
    "                diphones.append(f\"{phonemes[j]}→{phonemes[j+1]}\")\n",
    "\n",
    "        # Convert to indices\n",
    "        diphone_ids = [diphone_to_idx[d] for d in diphones if d in diphone_to_idx]\n",
    "        diphone_ids_tensor = torch.tensor(diphone_ids, dtype=torch.int32)\n",
    "\n",
    "        # Deduplicate consecutive transitions\n",
    "        diphone_ids_tensor = torch.unique_consecutive(diphone_ids_tensor)\n",
    "\n",
    "        diphone_seqs.append(diphone_ids_tensor)\n",
    "        diphone_lengths.append(len(diphone_ids_tensor))\n",
    "\n",
    "    return diphone_seqs, diphone_lengths\n",
    "\n",
    "\n",
    "def idsToDiphonemes(id_sequence, idx_to_diphone):\n",
    "    \"\"\"\n",
    "    Converts a sequence of diphone IDs to a list of diphone strings.\n",
    "    \n",
    "    Args:\n",
    "        id_sequence (Union[Tensor, np.ndarray, List[int]]): 1D sequence of diphone IDs.\n",
    "        idx_to_diphone (dict): Mapping from index to diphone string.\n",
    "    \n",
    "    Returns:\n",
    "        List[str]: List of diphone strings like ['SIL→H', 'H→H', 'H→OW', ...]\n",
    "    \"\"\"\n",
    "    if isinstance(id_sequence, torch.Tensor):\n",
    "        id_sequence = id_sequence.cpu().numpy()\n",
    "    \n",
    "    diphones = []\n",
    "    for i in id_sequence:\n",
    "        if i == 0:\n",
    "            continue  # skip blank/padding\n",
    "        diphone = idx_to_diphone.get(i, \"<unk>\")\n",
    "        diphones.append(diphone)\n",
    "    \n",
    "    return diphones\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0cb4c24d",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "075b78a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "from augmentations import GaussianSmoothing\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import pytorch_lightning as pl\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from edit_distance import SequenceMatcher\n",
    "import math\n",
    "from transformers import AutoProcessor, ClapModel, AutoModel, AutoTokenizer\n",
    "import numpy as np\n",
    "from collections import Counter\n",
    "import torch.nn as nn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3863a145",
   "metadata": {},
   "outputs": [],
   "source": [
    "class LightningGRUDecoder_Diphonemes_DualHead(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        neural_dim,\n",
    "        n_classes,\n",
    "        hidden_dim,\n",
    "        layer_dim,\n",
    "        nDays=24,\n",
    "        dropout=0.1,\n",
    "        strideLen=4,\n",
    "        kernelLen=14,\n",
    "        gaussianSmoothWidth=0,\n",
    "        bidirectional=False,\n",
    "        learning_rate=1e-3,\n",
    "        white_noise_SD=0.01,\n",
    "        constant_offset_SD=0.01,\n",
    "        weight_decay=1e-5,\n",
    "        mfcc_dim = 14,\n",
    "        mfcc_loss_weight = 1.,\n",
    "        channels = None,\n",
    "        phone_def_sil = PHONE_DEF_SIL,\n",
    "        total_steps = 100000,\n",
    "        alpha = 0.5,\n",
    "\n",
    "    ):\n",
    "        super().__init__()\n",
    "        self.total_steps = total_steps\n",
    "        self.layer_dim = layer_dim\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.neural_dim = neural_dim\n",
    "        self.n_classes = n_classes\n",
    "        self.nDays = nDays\n",
    "        self.strideLen = strideLen\n",
    "        self.kernelLen = kernelLen\n",
    "        self.gaussianSmoothWidth = gaussianSmoothWidth\n",
    "        self.bidirectional = bidirectional\n",
    "        self.learning_rate = learning_rate\n",
    "        self.white_noise_SD = white_noise_SD\n",
    "        self.constant_offset_SD = constant_offset_SD\n",
    "        self.weight_decay = weight_decay\n",
    "        self.mfcc_loss_weight = mfcc_loss_weight\n",
    "        self.channels = channels\n",
    "        self.PHONE_DEF_SIL = phone_def_sil\n",
    "        self.alpha = alpha\n",
    "\n",
    "        if channels is None:\n",
    "            self.channels = np.arange(0, neural_dim)\n",
    "\n",
    "        print(\"Resetting neural_dim based on channels\")\n",
    "        self.neural_dim = len(self.channels)\n",
    "        neural_dim = self.neural_dim\n",
    "        print(\"neural_dim\", neural_dim, self.neural_dim)\n",
    "\n",
    "        self.phoneme_to_diphone_ids = build_phoneme_to_diphone_ids(PHONE_DEF_SIL, diphone_to_idx)\n",
    "\n",
    "        for ph, ids in self.phoneme_to_diphone_ids.items():\n",
    "            self.phoneme_to_diphone_ids[ph] = sorted(set(ids))\n",
    "\n",
    "        self.inputLayerNonlinearity = nn.Softsign()\n",
    "        self.unfolder = nn.Unfold((self.kernelLen, 1), dilation=1, padding=0, stride=self.strideLen)\n",
    "        self.mfcc_unfolder = nn.Unfold((self.strideLen, 1), dilation=1, padding=0, stride=self.strideLen)\n",
    "        \n",
    "        self.gaussianSmoother = GaussianSmoothing(neural_dim, 20, self.gaussianSmoothWidth, dim=1)\n",
    "\n",
    "        # Per-day transformation weights\n",
    "        self.dayWeights = nn.Parameter(torch.randn(nDays, neural_dim, neural_dim))\n",
    "        self.dayBias = nn.Parameter(torch.zeros(nDays, 1, neural_dim))\n",
    "\n",
    "        for x in range(nDays):\n",
    "            self.dayWeights.data[x, :, :] = torch.eye(neural_dim)\n",
    "\n",
    "        # GRU layer\n",
    "        self.gru_decoder = nn.GRU(\n",
    "            (neural_dim) * self.kernelLen,\n",
    "            hidden_dim,\n",
    "            layer_dim,\n",
    "            batch_first=True,\n",
    "            dropout=dropout,\n",
    "            bidirectional=bidirectional,\n",
    "        )\n",
    "\n",
    "        for name, param in self.gru_decoder.named_parameters():\n",
    "            if \"weight_hh\" in name:\n",
    "                nn.init.orthogonal_(param)\n",
    "            if \"weight_ih\" in name:\n",
    "                nn.init.xavier_uniform_(param)\n",
    "\n",
    "        # Fully connected output layer\n",
    "        self.fc_decoder_diphone_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes**2 + 1)  # +1 for CTC blank\n",
    "        \n",
    "        self.fc_decoder_out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, n_classes**2 + 1)  # +1 for CTC blank\n",
    "        \n",
    "        self.mfcc_decoder = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, mfcc_dim*self.strideLen) \n",
    "        \n",
    "        \n",
    "        # Loss function\n",
    "        self.ctc_loss = nn.CTCLoss(blank=0, reduction=\"mean\", zero_infinity=True)\n",
    "        self.diphone_ctc_loss = nn.CTCLoss(blank=0, reduction=\"mean\", zero_infinity=True)\n",
    "        self.l1oss = nn.L1Loss()\n",
    "\n",
    "    def get_neural_embedding(self, neuralInput, dayIdx):\n",
    "        \"\"\"\n",
    "        Forward pass of the model.\n",
    "        neuralInput: (batch, time, features)\n",
    "        dayIdx: Session index\n",
    "        \"\"\"\n",
    "\n",
    "        #channel selection\n",
    "        neuralInput = neuralInput[:, :, self.channels].contiguous()\n",
    "        neuralInput = torch.permute(neuralInput, (0, 2, 1))\n",
    "        neuralInput = self.gaussianSmoother(neuralInput)\n",
    "        neuralInput = torch.permute(neuralInput, (0, 2, 1))\n",
    "\n",
    "        # Apply day-specific transformations\n",
    "        dayWeights = torch.index_select(self.dayWeights, 0, dayIdx)\n",
    "        transformedNeural = torch.einsum(\"btd,bdk->btk\", neuralInput, dayWeights) + torch.index_select(self.dayBias, 0, dayIdx)\n",
    "        transformedNeural = self.inputLayerNonlinearity(transformedNeural)\n",
    "\n",
    "        # Apply unfolding (sliding window)\n",
    "        stridedInputs = torch.permute(\n",
    "            self.unfolder(torch.unsqueeze(torch.permute(transformedNeural, (0, 2, 1)), 3)), (0, 2, 1)\n",
    "        )\n",
    "\n",
    "        # Initialize GRU hidden state\n",
    "        h0 = torch.zeros(\n",
    "            self.layer_dim * (2 if self.bidirectional else 1),\n",
    "            transformedNeural.size(0),\n",
    "            self.hidden_dim,\n",
    "            device=self.device\n",
    "        ).requires_grad_()\n",
    "\n",
    "        # Apply GRU\n",
    "        hid, _ = self.gru_decoder(stridedInputs, h0.detach())\n",
    "        return hid\n",
    "    def forward(self, neuralInput, dayIdx):\n",
    "       \n",
    "        hid = self.get_neural_embedding(neuralInput, dayIdx)\n",
    "        # Final output layer\n",
    "        diphone_logits = self.fc_decoder_diphone_out(hid)\n",
    "        phoneme_logits = self.fc_decoder_out(hid)\n",
    "        mfcc_pred = self.mfcc_decoder(hid)\n",
    "        return  diphone_logits,phoneme_logits, mfcc_pred\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Forward-pass one minibatch, compute diphone-CTC + phoneme-CTC\n",
    "        (optionally MFCC L1) and return the total loss for back-prop.\n",
    "        \"\"\"\n",
    "\n",
    "        # ─── unpack & move to device ──────────────────────────────────────────\n",
    "        X          = batch[\"neural_feats\"].to(self.device)          # (B,T,Fraw)\n",
    "        y          = batch[\"phone_seq\"].to(self.device)             # (B, Sph)\n",
    "        X_len      = batch[\"neural_time_bins\"].to(self.device)      # (B,)\n",
    "        y_len      = batch[\"phone_seq_len\"].to(self.device)         # (B,)\n",
    "        dayIdx     = batch[\"day\"].to(self.device)                   # (B,)\n",
    "        MFCC_list  = batch[\"mfcc\"]                                  # list(np.ndarray)\n",
    "\n",
    "        # ─── optional MFCC target preprocessing ──────────────────────────────\n",
    "        MFCC = pad_sequence([torch.tensor(m) for m in MFCC_list],\n",
    "                            batch_first=True).to(self.device)       # (B,Tm,Dim)\n",
    "        MFCC = torch.permute(self.mfcc_unfolder(\n",
    "                torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)\n",
    "            ), (0, 2, 1))                                        # (B,T',Dim·stride)\n",
    "\n",
    "        # ─── simple noise augmentation ───────────────────────────────────────\n",
    "        if self.white_noise_SD > 0:\n",
    "            X = X + torch.randn_like(X) * self.white_noise_SD\n",
    "        if self.constant_offset_SD > 0:\n",
    "            X = X + torch.randn(X.size(0), 1, X.size(2),\n",
    "                                device=self.device) * self.constant_offset_SD\n",
    "\n",
    "        # ─── forward through encoder/decoder ─────────────────────────────────\n",
    "        diphone_logits,phoneme_logits, mfcc_pred = self.forward(X, dayIdx)         # (B,T',D) & (B,T',Dim·stride)\n",
    "\n",
    "        # ─── build target diphone sequences ─────────────────────────────────\n",
    "        diphone_seqs, diphone_len = convert_phoneme_seq_to_diphone_ids(\n",
    "            y, idsToPhonemes, diphone_to_idx\n",
    "        )\n",
    "        y_diphone   = pad_sequence(diphone_seqs, batch_first=True, padding_value=0) \\\n",
    "                    .to(self.device)                              # (B,Sdi)\n",
    "        diphone_len = torch.tensor(diphone_len,\n",
    "                                dtype=torch.int32,\n",
    "                                device=self.device)              # (B,)\n",
    "\n",
    "        # ─── input lengths after unfolding (integer division!) ───────────────\n",
    "        input_lengths = ((X_len - self.kernelLen) // self.strideLen) \\\n",
    "                        .to(torch.int32)                            # (B,)\n",
    "\n",
    "        # ─── CTC loss on diphone logits ──────────────────────────────────────\n",
    "        diphone_logp = diphone_logits.log_softmax(-1)               # (B,T',D)\n",
    "        loss_diphone = self.diphone_ctc_loss(\n",
    "            diphone_logp.permute(1, 0, 2),  # (T',B,D)\n",
    "            y_diphone,\n",
    "            input_lengths,\n",
    "            diphone_len\n",
    "        )\n",
    "\n",
    "        # ─── diphone → phoneme marginalisation in log-space ─────────────────\n",
    "        B, T, _ = diphone_logp.shape\n",
    "        P = len(self.PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "        # phoneme_logp = diphone_logp.new_full((B, T, P), -float(\"inf\"))\n",
    "\n",
    "        # for p_idx, ph in enumerate(self.PHONE_DEF_SIL):\n",
    "        #     ids = self.phoneme_to_diphone_ids[ph]      # must be a **disjoint** list\n",
    "        #     assert len(set(ids)) == len(ids)           # sanity check\n",
    "        #     if ids:\n",
    "        #         logsumexp = torch.logsumexp(diphone_logp[:, :, ids], dim=-1)\n",
    "        #         phoneme_logp[:, :, p_idx] = logsumexp\n",
    "        #     else:\n",
    "        #         # optionally set to very low log-prob\n",
    "        #         phoneme_logp[:, :, p_idx] = -1e8\n",
    "\n",
    "        # # prepend the *same* blank that lives at diphone index 0\n",
    "        # blank_logp = diphone_logp[:, :, 0:1]\n",
    "        # phoneme_logp = torch.cat([blank_logp, phoneme_logp], dim=-1)\n",
    "        \n",
    "        # with torch.no_grad():\n",
    "        #     if torch.isnan(phoneme_logp).any():\n",
    "        #         print(\"NaN detected in phoneme_logp\")\n",
    "\n",
    "        # ─── CTC loss on phoneme distribution ───────────────────────────────\n",
    "\n",
    "        phoneme_logp =phoneme_logits.log_softmax(-1) \n",
    "        loss_phoneme = self.ctc_loss(\n",
    "            phoneme_logp.permute(1, 0, 2),  # (T',B,P+1)\n",
    "            y,\n",
    "            input_lengths,\n",
    "            y_len\n",
    "        )\n",
    "\n",
    "        # ─── optional MFCC reconstruction loss (commented out) ──────────────\n",
    "        # min_len = min(MFCC.size(1), mfcc_pred.size(1))\n",
    "        # l1_mfcc = self.l1loss(mfcc_pred[:, :min_len], MFCC[:, :min_len])\n",
    "        # mfcc_term = self.mfcc_loss_weight * l1_mfcc\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        # ─── dynamic α schedule (clamped to >=0) ────────────────────────────\n",
    "        # alpha = max(0.6, 1.0 - self.global_step / self.total_steps)\n",
    "\n",
    "        loss = self.alpha * loss_diphone + (1 - self.alpha) * loss_phoneme   # + mfcc_term\n",
    "\n",
    "        # ─── logging ────────────────────────────────────────────────────────\n",
    "        self.log_dict({\n",
    "            \"train_loss\":          loss,\n",
    "            \"train_loss_diphone\":  loss_diphone,\n",
    "            \"train_loss_phoneme\":  loss_phoneme,\n",
    "            # \"train_l1_mfcc\":     l1_mfcc if use_mfcc else torch.tensor(0.)\n",
    "        }, prog_bar=True, on_step=True, on_epoch=True)\n",
    "\n",
    "        return loss\n",
    "\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Run one minibatch in eval mode, compute the diphone-CTC +\n",
    "        phoneme-CTC losses and the phoneme error-rate (CER).\n",
    "        \"\"\"\n",
    "        # ─── unpack & move to device ─────────────────────────────────────────\n",
    "        X          = batch[\"neural_feats\"].to(self.device)          # (B,T,Fraw)\n",
    "        y          = batch[\"phone_seq\"].to(self.device)             # (B, Sph)\n",
    "        X_len      = batch[\"neural_time_bins\"].to(self.device)      # (B,)\n",
    "        y_len      = batch[\"phone_seq_len\"].to(self.device)         # (B,)\n",
    "        dayIdx     = batch[\"day\"].to(self.device)                   # (B,)\n",
    "        MFCC_list  = batch[\"mfcc\"]                                  # list(np.ndarray)\n",
    "\n",
    "        # ─── (optional) MFCC target pre-processing ──────────────────────────\n",
    "        MFCC = pad_sequence([torch.tensor(m) for m in MFCC_list],\n",
    "                            batch_first=True).to(self.device)       # (B,Tm,Dim)\n",
    "        MFCC = torch.permute(self.mfcc_unfolder(\n",
    "                torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)\n",
    "            ), (0, 2, 1))                                        # (B,T',Dim·stride)\n",
    "\n",
    "                # ─── forward through encoder/decoder ─────────────────────────────────\n",
    "        diphone_logits,phoneme_logits, mfcc_pred = self.forward(X, dayIdx)         # (B,T',D) & (B,T',Dim·stride)\n",
    "\n",
    "        # ─── build target diphone sequences ─────────────────────────────────\n",
    "        diphone_seqs, diphone_len = convert_phoneme_seq_to_diphone_ids(\n",
    "            y, idsToPhonemes, diphone_to_idx\n",
    "        )\n",
    "        y_diphone   = pad_sequence(diphone_seqs, batch_first=True, padding_value=0) \\\n",
    "                    .to(self.device)                              # (B,Sdi)\n",
    "        diphone_len = torch.tensor(diphone_len,\n",
    "                                dtype=torch.int32,\n",
    "                                device=self.device)              # (B,)\n",
    "\n",
    "        # ─── input lengths after unfolding (integer division!) ───────────────\n",
    "        input_lengths = ((X_len - self.kernelLen) // self.strideLen) \\\n",
    "                        .to(torch.int32)                            # (B,)\n",
    "\n",
    "        # ─── CTC loss on diphone logits ──────────────────────────────────────\n",
    "        diphone_logp = diphone_logits.log_softmax(-1)               # (B,T',D)\n",
    "        loss_diphone = self.diphone_ctc_loss(\n",
    "            diphone_logp.permute(1, 0, 2),  # (T',B,D)\n",
    "            y_diphone,\n",
    "            input_lengths,\n",
    "            diphone_len\n",
    "        )\n",
    "\n",
    "        # ─── diphone → phoneme marginalisation in log-space ─────────────────\n",
    "        B, T, _ = diphone_logp.shape\n",
    "        P = len(self.PHONE_DEF_SIL)\n",
    "\n",
    "\n",
    "\n",
    "        # phoneme_logp = diphone_logp.new_full((B, T, P), -float(\"inf\"))\n",
    "\n",
    "        # for p_idx, ph in enumerate(self.PHONE_DEF_SIL):\n",
    "        #     ids = self.phoneme_to_diphone_ids[ph]      # must be a **disjoint** list\n",
    "        #     assert len(set(ids)) == len(ids)           # sanity check\n",
    "        #     if ids:\n",
    "        #         logsumexp = torch.logsumexp(diphone_logp[:, :, ids], dim=-1)\n",
    "        #         phoneme_logp[:, :, p_idx] = logsumexp\n",
    "        #     else:\n",
    "        #         # optionally set to very low log-prob\n",
    "        #         phoneme_logp[:, :, p_idx] = -1e8\n",
    "\n",
    "        # # prepend the *same* blank that lives at diphone index 0\n",
    "        # blank_logp = diphone_logp[:, :, 0:1]\n",
    "        # phoneme_logp = torch.cat([blank_logp, phoneme_logp], dim=-1)\n",
    "        \n",
    "        # with torch.no_grad():\n",
    "        #     if torch.isnan(phoneme_logp).any():\n",
    "        #         print(\"NaN detected in phoneme_logp\")\n",
    "\n",
    "        # ─── CTC loss on phoneme distribution ───────────────────────────────\n",
    "\n",
    "        phoneme_logp =phoneme_logits.log_softmax(-1) \n",
    "        loss_phoneme = self.ctc_loss(\n",
    "            phoneme_logp.permute(1, 0, 2),  # (T',B,P+1)\n",
    "            y,\n",
    "            input_lengths,\n",
    "            y_len\n",
    "        )\n",
    "\n",
    "        # ─── optional MFCC reconstruction loss (commented out) ──────────────\n",
    "        # min_len = min(MFCC.size(1), mfcc_pred.size(1))\n",
    "        # l1_mfcc = self.l1loss(mfcc_pred[:, :min_len], MFCC[:, :min_len])\n",
    "        # mfcc_term = self.mfcc_loss_weight * l1_mfcc\n",
    "        # --------------------------------------------------------------------\n",
    "\n",
    "        # ─── dynamic α schedule (clamped to >=0) ────────────────────────────\n",
    "        # alpha = max(0.6, 1.0 - self.global_step / self.total_steps)\n",
    "\n",
    "        loss = self.alpha * loss_diphone + (1 - self.alpha) * loss_phoneme   # + mfcc_term\n",
    "\n",
    "\n",
    "        # ─── phoneme error-rate (CER) ──────────────────────────────────────\n",
    "        ### CHANGED: decode from log-probs (argmax over last dim)\n",
    "        decoded = phoneme_logp.argmax(-1)                           # (B,T')\n",
    "        total_edits, total_len = 0, 0\n",
    "\n",
    "        for i in range(decoded.size(0)):\n",
    "            seq = decoded[i, : input_lengths[i]].cpu()\n",
    "            seq = torch.unique_consecutive(seq)        # CTC collapse\n",
    "            seq = seq[seq != 0].tolist()               # remove blank\n",
    "\n",
    "            target = y[i, : y_len[i]].cpu().tolist()\n",
    "            total_edits += SequenceMatcher(a=target, b=seq).distance()\n",
    "            total_len   += len(target)\n",
    "\n",
    "        cer = total_edits / total_len if total_len else 1.0\n",
    "\n",
    "        # ─── logging ────────────────────────────────────────────────────────\n",
    "        self.log_dict({\n",
    "            \"val_loss\":          loss,\n",
    "            \"val_loss_diphone\":  loss_diphone,\n",
    "            \"val_loss_phoneme\":  loss_phoneme,\n",
    "            \"val_CER\":           cer,\n",
    "        }, prog_bar=True, on_step=False, on_epoch=True)\n",
    "\n",
    "        ### NEW: return loss so Lightning can aggregate (optional)\n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"\n",
    "        Configures the optimizer and learning rate scheduler.\n",
    "        \"\"\"\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.999),\n",
    "                                      eps=1e-8)\n",
    "\n",
    "        return optimizer\n",
    "\n",
    "        # scheduler = ReduceLROnPlateau(optimizer, mode=\"min\", factor=0.5, patience=2)\n",
    "        # return {\"optimizer\": optimizer, \"lr_scheduler\": scheduler, \"monitor\": \"val_loss\"}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a94f832d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4140"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "nInputFeatures = 256 #channels \n",
    "nClasses = 40 \n",
    "dropout = 0.4 \n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len = 14 #best 32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "lr_start = 5e-5\n",
    "lr_end = 5e-5\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "warmup_epoch = 5\n",
    "steps_per_epoch = len(train_loader)\n",
    "warmup_steps = warmup_epoch * steps_per_epoch\n",
    "\n",
    "target_epoch = 30\n",
    "total_steps = target_epoch * steps_per_epoch\n",
    "total_steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "42eece3f",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"gru_ctc_diphones_dualhead\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d072a00",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_412441/938539527.py:55: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  model.load_state_dict(torch.load(f\".checkpoints/{output_name}/best_model.ckpt\")[\"state_dict\"])\n"
     ]
    }
   ],
   "source": [
    "TRAIN = False\n",
    "\n",
    "\n",
    "model = LightningGRUDecoder_Diphonemes_DualHead(\n",
    "            neural_dim=nInputFeatures,\n",
    "            n_classes=nClasses,\n",
    "            hidden_dim=hidden_dim,\n",
    "            layer_dim=nlayers,\n",
    "            strideLen=stride_len,\n",
    "            kernelLen=kernel_len,\n",
    "            gaussianSmoothWidth=gaussian_smooth_width,\n",
    "            bidirectional=bidirectional,\n",
    "            dropout=dropout,\n",
    "            white_noise_SD=white_noise_SD,\n",
    "            constant_offset_SD=constant_offset_SD,\n",
    "            weight_decay=l2_decay,\n",
    "            learning_rate=lr_start,\n",
    "            total_steps=total_steps,)\n",
    "if TRAIN:\n",
    "\n",
    "    wandb_logger = WandbLogger(project=\"ECOG_Sentence_dataset\", name=f\"{output_name}\",\n",
    "                                reinit=True)\n",
    "\n",
    "    # Define ModelCheckpoint to save the best model based on validation loss\n",
    "    checkpoint_callback = ModelCheckpoint(\n",
    "        monitor=\"val_loss\",  # Ensure your validation step logs \"val_loss\"\n",
    "        mode=\"min\",          # Save the model with the lowest validation loss\n",
    "        save_top_k=1,        # Keep only the best model\n",
    "        dirpath=f\".checkpoints/{output_name}/\",  # Directory to save checkpoints\n",
    "        filename=f\"best_model\",  # Model filename\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "    # Define EarlyStopping callback with patience of 3 epochs\n",
    "    early_stopping_callback = EarlyStopping(\n",
    "        monitor=\"val_CER\",\n",
    "        patience=8,   # Stop training if no improvement in 3 epochs\n",
    "        mode=\"min\",\n",
    "        verbose=True\n",
    "    )\n",
    "\n",
    "\n",
    "    # Train model\n",
    "    trainer = pl.Trainer(max_epochs=100,devices =[1], callbacks=[checkpoint_callback, early_stopping_callback], logger=wandb_logger)\n",
    "\n",
    "    trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "\n",
    "\n",
    "    # close wandb logger\n",
    "    wandb.finish()\n",
    "\n",
    "else:\n",
    "    #reload state_dict of best model\n",
    "    model.load_state_dict(torch.load(f\".checkpoints/{output_name}/best_model.ckpt\")[\"state_dict\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "21a234cd",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningGRUDecoder_Diphonemes_DualHead(\n",
       "  (inputLayerNonlinearity): Softsign()\n",
       "  (unfolder): Unfold(kernel_size=(14, 1), dilation=1, padding=0, stride=4)\n",
       "  (mfcc_unfolder): Unfold(kernel_size=(4, 1), dilation=1, padding=0, stride=4)\n",
       "  (gaussianSmoother): GaussianSmoothing()\n",
       "  (gru_decoder): GRU(3584, 1024, num_layers=5, batch_first=True, dropout=0.4, bidirectional=True)\n",
       "  (fc_decoder_diphone_out): Linear(in_features=2048, out_features=1601, bias=True)\n",
       "  (fc_decoder_out): Linear(in_features=2048, out_features=1601, bias=True)\n",
       "  (mfcc_decoder): Linear(in_features=2048, out_features=56, bias=True)\n",
       "  (ctc_loss): CTCLoss()\n",
       "  (diphone_ctc_loss): CTCLoss()\n",
       "  (l1oss): L1Loss()\n",
       ")"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "device = \"cuda:1\"\n",
    "model.to(device)\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "73ab00d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_ctc_output(logits):\n",
    "    \"\"\"\n",
    "    Converts model logits to predicted phoneme sequences.\n",
    "    - Removes repeated phonemes.\n",
    "    - Removes blank tokens (0).\n",
    "    \"\"\"\n",
    "\n",
    "    predictions = torch.argmax(logits, dim=-1)  # Get most probable phoneme indices\n",
    "    predictions = [torch.unique_consecutive(seq[seq != 0]).cpu().numpy() for seq in predictions]  # Remove blanks\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bc5d35c3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 0/14 [00:00<?, ?it/s]/data/XXXXXX/speech_decoding_BCI/augmentations.py:91: UserWarning: Using padding='same' with even kernel lengths and odd dilation may require a zero-padded copy of the input be created (Triggered internally at ../aten/src/ATen/native/Convolution.cpp:1036.)\n",
      "  return self.conv(input, weight=self.weight, groups=self.groups, padding=\"same\")\n",
      "100%|██████████| 14/14 [00:22<00:00,  1.63s/it]\n"
     ]
    }
   ],
   "source": [
    "## predit all teh test set \n",
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_loader):\n",
    "        X = batch[\"neural_feats\"]\n",
    "        y = batch[\"phone_seq\"]\n",
    "        X_len = batch[\"neural_time_bins\"]\n",
    "        y_len = batch[\"phone_seq_len\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence\"]\n",
    "        \n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        diphone_logits,pred, mfcc_pred  = model(X,days)\n",
    "        # pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        # print(logits.shape)\n",
    "        # decoded = decoder(pred)\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.strideLen), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pp)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "d56e425e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.21076115657868477)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(cer_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "93227fc1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted Phonemes: ['R', 'EY', 'S', 'SIL', 'P', 'OY', 'CH', 'S', 'SIL', 'EH', 'V', 'ER', 'SIL', 'S', 'AE', 'N', 'D', 'SIL', 'D', 'IH', 'S', 'AH', 'AY', 'F', 'SIL']\n",
      "True Phonemes: ['R', 'IH', 'CH', 'SIL', 'P', 'ER', 'CH', 'AH', 'S', 'T', 'SIL', 'S', 'EH', 'V', 'R', 'AH', 'L', 'SIL', 'S', 'AY', 'N', 'D', 'SIL', 'L', 'IH', 'TH', 'AH', 'G', 'R', 'AE', 'F', 'S', 'SIL']\n",
      "True Sentence: Rich purchased several signed lithographs.\n"
     ]
    }
   ],
   "source": [
    "idx = 1\n",
    "print(\"Predicted Phonemes:\", idsToPhonemes(pred_phonemes[idx]))\n",
    "print(\"True Phonemes:\", idsToPhonemes(true_phonemes[idx]))\n",
    "print(\"True Sentence:\", true_sentences[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "eb7c4eae",
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_accuracy(preds, targets):\n",
    "    \n",
    "\n",
    "    accs= []\n",
    "    for pred, target in zip(preds, targets):\n",
    "        \n",
    "        #truncate to the length of the shortest sequence\n",
    "        min_len = min(len(pred), len(target))\n",
    "\n",
    "\n",
    "        pred = pred[:min_len]\n",
    "        target = target[:min_len]\n",
    "\n",
    "        equal_inference = (pred == target)\n",
    "        acc = np.sum(equal_inference)/ len(pred)\n",
    "        accs.append(acc)\n",
    "\n",
    "    return np.mean(accs)\n",
    "   \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "7c0fa6d4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "overall_acc 0.4814017630628827\n",
      "Range of Accuracy per day 0.238823447647527 0.6495838938091436\n",
      "Range of CER per day 0.1527034001120558 0.3731003181831515\n",
      "Average lenght diff: 0.6795454545454546 +- 2.303069053701105\n",
      "Range of diff lenghts per day: -0.275 - 2.65\n"
     ]
    }
   ],
   "source": [
    "overall_acc = compute_accuracy(pred_phonemes, true_phonemes)\n",
    "print(\"overall_acc\", overall_acc)\n",
    "# day_indices\n",
    "day_indices_flat = day_indices\n",
    "\n",
    "#compute accuracy per day by selecting indices of the same day\n",
    "day_accs = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    acc = compute_accuracy([pred_phonemes[idx] for idx in indices], [true_phonemes[idx] for idx in indices])\n",
    "    day_accs.append(acc)\n",
    "\n",
    "day_accs\n",
    "print(\"Range of Accuracy per day\", min(day_accs), max(day_accs))\n",
    "cer_list_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    cer_list_per_day.append(np.array(cer_list)[indices].mean())\n",
    "\n",
    "cer_list_per_day\n",
    "\n",
    "print(\"Range of CER per day\", min(cer_list_per_day), max(cer_list_per_day))\n",
    "diffs = []\n",
    "for i in range(len(true_phonemes)):\n",
    "    true = true_phonemes[i]\n",
    "    pred = pred_phonemes[i]\n",
    "    diffs.append(np.array(len(true)) - np.array(len(pred)))\n",
    "\n",
    "print(f\"Average lenght diff: {np.mean(diffs)} +- {np.std(diffs)}\")\n",
    "# true_phonemes[0]\n",
    "## compute diff lenghts per day\n",
    "diffs_per_day = []\n",
    "for day_index in set(day_indices_flat):\n",
    "    indices = [idx for idx, day in enumerate(day_indices_flat) if day == day_index]\n",
    "    diffs_per_day.append(np.array(diffs)[indices].mean())\n",
    "\n",
    "diffs_per_day\n",
    "print(f\"Range of diff lenghts per day: {min(diffs_per_day)} - {max(diffs_per_day)}\")\n",
    "import os\n",
    "\n",
    "results_dir = f\"results/{output_name}/\"\n",
    "os.makedirs(results_dir, exist_ok=True)\n",
    "\n",
    "#create a dataframe with the results\n",
    "df = pd.DataFrame({\n",
    "    'True Phonemes': [idsToPhonemes(p) for p in true_phonemes],\n",
    "    'Predicted Phonemes': [idsToPhonemes(p) for p in pred_phonemes],\n",
    "    'True Sentence': true_sentences,\n",
    "    'Day Index': day_indices_flat,\n",
    "    'CER': cer_list\n",
    "})\n",
    "\n",
    "#save it \n",
    "df.to_csv(os.path.join(results_dir, \"results.csv\"), index=False)\n",
    "df.sort_values(by=[\"CER\"], ascending=True).iloc[-5:]\n",
    "#create a dataframe with all the metrics\n",
    "df_metrics = pd.DataFrame({\n",
    "    'Overall Accuracy': [overall_acc],\n",
    "    'Range of Accuracy per day': f\"{[min(day_accs), max(day_accs)]}\",\n",
    "    'Range of CER per day': f\"{[min(cer_list_per_day), max(cer_list_per_day)]}\",\n",
    "    'Average length diff': f\"{[np.mean(diffs), np.std(diffs)]}\",\n",
    "    'Range of length diff per day': f\"{[min(diffs_per_day), max(diffs_per_day)]}\"\n",
    "})\n",
    "df_metrics.to_csv(os.path.join(results_dir,\"metrics.csv\"), index=False)\n",
    "## save pred_logits\n",
    "with open(os.path.join(results_dir, \"pred_logits.pkl\"), \"wb\") as f:\n",
    "    pickle.dump(pred_logits, f)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
