{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f6bb086f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package cmudict to /home/XXXXXX/nltk_data...\n",
      "[nltk_data]   Package cmudict is already up-to-date!\n"
     ]
    }
   ],
   "source": [
    "import pickle\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.utils.data import DataLoader\n",
    "import torch\n",
    "from dataset import SpeechSentenceDataset, idsToPhonemes, getDatasetLoaders,getDatasetLoaders_V2, getDatasetLoaders_V3, PHONE_DEF, PHONE_DEF_SIL\n",
    "import re \n",
    "from g2p_en import G2p\n",
    "import numpy as np\n",
    "from model.ctc_modelling import LightningGRUDecoder, LightningGRUDecoder_MFCC_v3\n",
    "import time\n",
    "import numpy as np\n",
    "from edit_distance import SequenceMatcher\n",
    "import tqdm\n",
    "import pytorch_lightning as pl\n",
    "import jiwer\n",
    "import nltk\n",
    "from nltk.corpus import cmudict\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from model.ctc_modelling import LightningGRUDecoder\n",
    "\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import copy\n",
    "from difflib import get_close_matches\n",
    "from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer\n",
    "import pandas as pd\n",
    "from torchaudio.models.decoder import ctc_decoder\n",
    "import string\n",
    "from config import DATASET_SM_ROBUST, DATASET_SM_ZSCORE, DATASET_FULL_TRIALS_ZSCORE, DATASET_AFTERGO_TRIALS_ZSCORE\n",
    "import torchaudio  \n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "# from model.ctc_modelling import Light\n",
    "from dataclasses import dataclass\n",
    "from augmentations import GaussianSmoothing\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig\n",
    "from transformers.modeling_outputs import BaseModelOutput, ModelOutput\n",
    "from typing import Optional, Tuple\n",
    "from torch import nn\n",
    "from peft import LoraConfig, get_peft_model\n",
    "from jiwer import wer as jiwer_wer\n",
    "\n",
    "\n",
    "# Download CMU Pronouncing Dictionary (First-time use)\n",
    "nltk.download(\"cmudict\")\n",
    "\n",
    "# Load CMUdict\n",
    "cmu_dict = cmudict.dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "eccb6984",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WER: 0.167\n"
     ]
    }
   ],
   "source": [
    "from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip\n",
    "\n",
    "# Clean-up pipeline applied to both references and hypotheses before scoring:\n",
    "# lowercase -> drop punctuation -> collapse repeated spaces -> trim the ends.\n",
    "normalize_text = Compose([ToLowerCase(), RemovePunctuation(), RemoveMultipleSpaces(), Strip()])\n",
    "\n",
    "# Toy sentence pair: after normalization the hypothesis has one extra word.\n",
    "ground_truth = [\"It's a beautiful day, isn't it?\"]\n",
    "predictions = [\"its a beautiful day isnt it right\"]\n",
    "\n",
    "# Normalize every sentence (each entry stays a plain string).\n",
    "ref_texts = list(map(normalize_text, ground_truth))\n",
    "pred_texts = list(map(normalize_text, predictions))\n",
    "\n",
    "# jiwer tokenizes on whitespace itself; the argument order is (reference, hypothesis).\n",
    "wer_score = wer(ref_texts, pred_texts)\n",
    "print(f\"WER: {wer_score:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "d96def02",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[\"It's a beautiful day, isn't it?\"]"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ground_truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d30d2d22",
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class HybridCausalLMOutput(ModelOutput):\n",
    "    \"\"\"\n",
    "    Model-output dataclass with slots for two separate loss terms and their total.\n",
    "\n",
    "    Extends the usual causal-LM output fields (`loss`, `logits`, `hidden_states`,\n",
    "    `attentions`) with `ce_loss` and `ctc_loss`. Every field defaults to `None`.\n",
    "\n",
    "    Args:\n",
    "        ce_loss (`torch.FloatTensor`, *optional*):\n",
    "            Presumably the language-model cross-entropy term -- TODO confirm against the code that fills it.\n",
    "        ctc_loss (`torch.FloatTensor`, *optional*):\n",
    "            Presumably the CTC term -- TODO confirm against the code that fills it.\n",
    "        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):\n",
    "            Language modeling loss (for next-token prediction).\n",
    "        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):\n",
    "            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).\n",
    "        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):\n",
    "            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +\n",
    "            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.\n",
    "\n",
    "            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.\n",
    "        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):\n",
    "            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n",
    "            sequence_length)`.\n",
    "\n",
    "            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n",
    "            heads.\n",
    "    \"\"\"\n",
    "    # Individual loss terms; kept separate from the combined `loss` below.\n",
    "    ce_loss : Optional[torch.FloatTensor] = None\n",
    "    ctc_loss : Optional[torch.FloatTensor] = None\n",
    "    # Standard causal-LM output fields.\n",
    "    loss: Optional[torch.FloatTensor] = None\n",
    "    logits: torch.FloatTensor = None\n",
    "    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None\n",
    "    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1f78346e",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip\n",
    "\n",
    "normalize_text = Compose([\n",
    "    ToLowerCase(),\n",
    "    RemovePunctuation(),\n",
    "    RemoveMultipleSpaces(),\n",
    "    Strip()\n",
    "])\n",
    "\n",
    "class HybridGRUDecoder(pl.LightningModule):\n",
    "    \"\"\"\n",
    "    Neural encoder coupled to a BART decoder through a learned projection.\n",
    "\n",
    "    The encoder's hidden states are projected to `lm_model_dim` and passed to\n",
    "    `facebook/bart-base` as its encoder output, so the BART decoder cross-attends\n",
    "    over neural features. The optimized loss is\n",
    "    `ce_loss_weight * lm_loss + ctc_loss_weight * ctc_loss + l1_loss`.\n",
    "\n",
    "    Args:\n",
    "        neural_encoder: pretrained encoder module. This class reads its\n",
    "            `hidden_dim`, `bidirectional`, `kernelLen`, `strideLen`,\n",
    "            `white_noise_SD`, `constant_offset_SD` and `clip_mfcc_afer_go`\n",
    "            attributes and calls its `get_neural_embedding`, `fc_decoder_out`,\n",
    "            `mfcc_decoder`, `mfcc_unfolder`, `ctc_loss` and `l1oss` members.\n",
    "        lm_model_dim: output width of the projection; it is fed to BART as\n",
    "            encoder hidden states, so it has to equal BART's model width.\n",
    "        learning_rate: Adam learning rate.\n",
    "        weight_decay: Adam weight decay.\n",
    "        ce_loss_weight: weight of the BART cross-entropy loss.\n",
    "        ctc_loss_weight: weight of the encoder CTC loss.\n",
    "        freeze_lm: if True, only BART's cross-attention parameters stay trainable.\n",
    "        freeze_encoder: if True, all encoder parameters are frozen.\n",
    "        use_lora: if True, wrap BART with LoRA adapters.\n",
    "        lora_r, lora_alpha, lora_dropout: LoRA hyper-parameters.\n",
    "        lora_target_modules: module names that receive LoRA adapters; defaults to\n",
    "            `[\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]` when None.\n",
    "    \"\"\"\n",
    "\n",
    "    # Substrings that identify cross-attention parameters. BART names them\n",
    "    # `encoder_attn` / `encoder_attn_layer_norm`; the GPT-2 style names are kept\n",
    "    # so the filter keeps working if the language model is ever swapped.\n",
    "    _CROSS_ATTENTION_TAGS = (\"encoder_attn\", \"crossattention\", \"ln_cross_attn\")\n",
    "\n",
    "    def __init__(self, neural_encoder, lm_model_dim, learning_rate=1e-4,weight_decay=1e-5,\n",
    "                 ce_loss_weight= 1.,\n",
    "            ctc_loss_weight = 0.,\n",
    "            freeze_lm= True,\n",
    "            freeze_encoder = False, \n",
    "            use_lora=True,\n",
    "            lora_r=32,\n",
    "            lora_alpha=64,\n",
    "            lora_dropout=0.2,\n",
    "            lora_target_modules=None\n",
    "            ):\n",
    "        super().__init__()\n",
    "\n",
    "        # Build the default here instead of sharing one mutable list between instances.\n",
    "        if lora_target_modules is None:\n",
    "            lora_target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"out_proj\"]\n",
    "\n",
    "        self.encoder = neural_encoder\n",
    "        self.lm_model_dim = lm_model_dim\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "        self.ce_loss_weight = ce_loss_weight\n",
    "        self.ctc_loss_weight = ctc_loss_weight\n",
    "\n",
    "        self.freeze_lm = freeze_lm\n",
    "        self.freeze_encoder = freeze_encoder\n",
    "        self.use_lora = use_lora\n",
    "        self.lora_r = lora_r\n",
    "        self.lora_alpha = lora_alpha\n",
    "        self.lora_dropout = lora_dropout\n",
    "        self.lora_target_modules = list(lora_target_modules)\n",
    "\n",
    "        # A bidirectional encoder emits forward and backward states concatenated.\n",
    "        encoder_out_dim = self.encoder.hidden_dim * 2 if self.encoder.bidirectional else self.encoder.hidden_dim\n",
    "        self.project = nn.Sequential(\n",
    "            nn.LayerNorm(encoder_out_dim),\n",
    "            nn.Linear(encoder_out_dim, lm_model_dim),\n",
    "            nn.ReLU(),\n",
    "            nn.LayerNorm(lm_model_dim),\n",
    "        )\n",
    "\n",
    "        ## LANGUAGE HEAD\n",
    "        self.language_model = BartForConditionalGeneration.from_pretrained(\"facebook/bart-base\")\n",
    "        self.tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-base\")\n",
    "\n",
    "        if self.freeze_lm:\n",
    "            # Keep only cross-attention trainable. The previous filter used GPT-2\n",
    "            # parameter names only, which never match BART, so everything was frozen.\n",
    "            for name, param in self.language_model.named_parameters():\n",
    "                param.requires_grad = any(tag in name for tag in self._CROSS_ATTENTION_TAGS)\n",
    "\n",
    "            print(\"Freezing all parameters except cross attention in the language model.\")\n",
    "\n",
    "        if self.freeze_encoder:\n",
    "            for param in self.encoder.parameters():\n",
    "                param.requires_grad = False\n",
    "            print(\"Freezing all parameters in the encoder.\")\n",
    "\n",
    "        if self.use_lora:\n",
    "            # NOTE(review): get_peft_model appears to freeze all non-LoRA base weights, which\n",
    "            # would undo the cross-attention unfreezing above when both flags are set -- confirm.\n",
    "            lora_cfg = LoraConfig(\n",
    "                r=self.lora_r,\n",
    "                lora_alpha=self.lora_alpha,\n",
    "                lora_dropout=self.lora_dropout,\n",
    "                target_modules=self.lora_target_modules,\n",
    "            )\n",
    "            self.language_model = get_peft_model(self.language_model, lora_cfg)\n",
    "            print(f\"Using LoRA with r={self.lora_r}, alpha={self.lora_alpha}, dropout={self.lora_dropout} on target modules {self.lora_target_modules}.\")\n",
    "\n",
    "    def get_neural_embedding(self, neuralInput, dayIdx):\n",
    "        \"\"\"\n",
    "        Return the encoder's hidden representation for a batch.\n",
    "\n",
    "        neuralInput: (batch_size, seq_len, input_dim)\n",
    "        dayIdx: (batch_size,)\n",
    "        \"\"\"\n",
    "        return self.encoder.get_neural_embedding(neuralInput, dayIdx)\n",
    "\n",
    "    def forward(self, neuralInput, dayIdx, sentence = None):\n",
    "        \"\"\"\n",
    "        Run the encoder heads and, when target text is given, the BART decoder.\n",
    "\n",
    "        neuralInput: (batch, time, features)\n",
    "        dayIdx: Session index\n",
    "        sentence: optional list of target strings; when provided they are\n",
    "            tokenized and used as teacher-forcing labels for BART.\n",
    "\n",
    "        Returns `(pred, mfcc_pred, lm_outputs)`: the CTC logits, the MFCC\n",
    "        prediction and the BART output (None when `sentence` is None).\n",
    "        \"\"\"\n",
    "        hid = self.get_neural_embedding(neuralInput, dayIdx)\n",
    "\n",
    "        #1. predict CTC logits\n",
    "        pred = self.encoder.fc_decoder_out(hid)\n",
    "\n",
    "        #2. Predict the MFCC\n",
    "        mfcc_pred = self.encoder.mfcc_decoder(hid)\n",
    "\n",
    "        #3. Project hid to the language model dimension\n",
    "        encoder_outputs = self.project(hid)\n",
    "\n",
    "        lm_outputs = None\n",
    "        if sentence is not None:\n",
    "            tokens = self.tokenizer(sentence, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "            # Padded positions get label -100 so the cross-entropy ignores them;\n",
    "            # previously the loss was also computed on the padding tokens.\n",
    "            labels = tokens.input_ids.masked_fill(tokens.attention_mask == 0, -100).to(self.device)\n",
    "\n",
    "            # NOTE(review): no encoder attention mask is passed, so the decoder can\n",
    "            # attend to frames that come from zero-padded neural input.\n",
    "            lm_outputs = self.language_model(\n",
    "                encoder_outputs=BaseModelOutput(last_hidden_state=encoder_outputs),\n",
    "                labels=labels,\n",
    "            )\n",
    "\n",
    "        return pred, mfcc_pred, lm_outputs\n",
    "\n",
    "    def _build_mfcc_targets(self, batch):\n",
    "        \"\"\"Pad the per-trial MFCCs, move them to the device and unfold them with the encoder's unfolder.\"\"\"\n",
    "        mfcc_list = []\n",
    "        for idx in range(len(batch[\"mfcc\"])):\n",
    "            trial_mfcc = batch[\"mfcc\"][idx]\n",
    "            if self.encoder.clip_mfcc_afer_go:\n",
    "                # Drop everything before the go cue of this trial.\n",
    "                trial_mfcc = trial_mfcc[batch[\"go_onset\"][idx]:]\n",
    "            mfcc_list.append(torch.as_tensor(trial_mfcc))\n",
    "\n",
    "        MFCC = pad_sequence(mfcc_list, batch_first=True).to(self.device)\n",
    "        return torch.permute(self.encoder.mfcc_unfolder(torch.unsqueeze(torch.permute(MFCC, (0, 2, 1)), 3)), (0, 2, 1))\n",
    "\n",
    "    def _shared_step(self, batch, augment):\n",
    "        \"\"\"\n",
    "        Forward pass plus all loss terms, shared by training and validation.\n",
    "\n",
    "        augment: when True, add white noise and a per-trial constant offset to the\n",
    "            neural features (only if the encoder's noise SDs are positive).\n",
    "\n",
    "        Returns a dict with the losses and the tensors validation needs for decoding.\n",
    "        \"\"\"\n",
    "        sentence = batch[\"sentence\"]\n",
    "        MFCC = self._build_mfcc_targets(batch)\n",
    "\n",
    "        X = batch[\"neural_feats\"].to(self.device)\n",
    "        y = batch[\"phone_seq\"].to(self.device)\n",
    "        X_len = batch[\"neural_time_bins\"].to(self.device)\n",
    "        y_len = batch[\"phone_seq_len\"].to(self.device)\n",
    "        dayIdx = batch[\"day\"].to(self.device)\n",
    "\n",
    "        if augment:\n",
    "            # Out-of-place adds: the tensor stored in `batch` is left untouched.\n",
    "            if self.encoder.white_noise_SD > 0:\n",
    "                X = X + torch.randn(X.shape, device=self.device) * self.encoder.white_noise_SD\n",
    "            if self.encoder.constant_offset_SD > 0:\n",
    "                X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.encoder.constant_offset_SD\n",
    "\n",
    "        pred, mfcc_pred, lm_outputs = self.forward(X, dayIdx, sentence = sentence)\n",
    "\n",
    "        # Number of encoder frames per trial, as used for the CTC input lengths.\n",
    "        adjusted_lens = ((X_len - self.encoder.kernelLen) / self.encoder.strideLen).to(torch.int32)\n",
    "        ctc_loss = self.encoder.ctc_loss(\n",
    "            torch.permute(pred.log_softmax(2), [1, 0, 2]),\n",
    "            y,\n",
    "            adjusted_lens,\n",
    "            y_len,\n",
    "        )\n",
    "\n",
    "        # Prediction and target can differ in length; compare the common prefix.\n",
    "        min_seq_len = min(MFCC.shape[1], mfcc_pred.shape[1])\n",
    "        l1_loss = self.encoder.l1oss(\n",
    "            mfcc_pred[:, :min_seq_len, :],\n",
    "            MFCC[:, :min_seq_len, :],\n",
    "        )\n",
    "\n",
    "        lm_loss = lm_outputs.loss\n",
    "        loss = self.ce_loss_weight * lm_loss + self.ctc_loss_weight * ctc_loss + l1_loss\n",
    "\n",
    "        return {\n",
    "            \"loss\": loss,\n",
    "            \"lm_loss\": lm_loss,\n",
    "            \"ctc_loss\": ctc_loss,\n",
    "            \"l1_loss\": l1_loss,\n",
    "            \"pred\": pred,\n",
    "            \"X\": X,\n",
    "            \"y\": y,\n",
    "            \"y_len\": y_len,\n",
    "            \"dayIdx\": dayIdx,\n",
    "            \"adjusted_lens\": adjusted_lens,\n",
    "        }\n",
    "\n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Training step - Runs forward pass, computes loss, and returns it for backprop.\n",
    "        \"\"\"\n",
    "        self.encoder.train()\n",
    "\n",
    "        out = self._shared_step(batch, augment=True)\n",
    "        batch_size = len(batch[\"sentence\"])\n",
    "\n",
    "        self.log(\"train_loss\", out[\"loss\"], prog_bar=True, on_step=True, on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"train_ce_loss\", out[\"lm_loss\"], prog_bar=True, on_step=True, on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"train_ctc_loss\", out[\"ctc_loss\"], prog_bar=True, on_step=True, on_epoch=True, batch_size=batch_size)\n",
    "        # The L1 term is part of the total loss, so track it as well.\n",
    "        self.log(\"train_l1_loss\", out[\"l1_loss\"], on_step=True, on_epoch=True, batch_size=batch_size)\n",
    "\n",
    "        return out[\"loss\"]\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Validation step - Runs forward pass, computes the losses, the phoneme error\n",
    "        rate of the greedy CTC decode and the WER of the generated text, and logs them.\n",
    "        \"\"\"\n",
    "        out = self._shared_step(batch, augment=False)\n",
    "        sentence = batch[\"sentence\"]\n",
    "        batch_size = len(sentence)\n",
    "        pred, y, y_len = out[\"pred\"], out[\"y\"], out[\"y_len\"]\n",
    "        adjusted_lens = out[\"adjusted_lens\"]\n",
    "\n",
    "        # Compute CER (Phoneme Error Rate)\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "        for i in range(pred.shape[0]):\n",
    "            # Decode only the frames counted as CTC input for this trial, so frames\n",
    "            # produced by padding are not decoded.\n",
    "            decodedSeq = torch.argmax(pred[i, : int(adjusted_lens[i]), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()  # id 0 is dropped as the blank\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "        cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "\n",
    "        generated_text = self.generate(out[\"X\"], out[\"dayIdx\"], max_length=60, num_beams=1)\n",
    "\n",
    "        if batch_idx==0:\n",
    "            print(\"Generated Text: \", generated_text[:10])\n",
    "            print(\"True Text: \", sentence[:10])\n",
    "\n",
    "        ref_texts = [normalize_text(s) for s in sentence]\n",
    "        pred_texts = [normalize_text(s) for s in generated_text]\n",
    "\n",
    "        # Compute WER\n",
    "        wer_score = wer(ref_texts, pred_texts)\n",
    "\n",
    "        # Log WER (averaged across all batches by Lightning)\n",
    "        self.log(\"val_WER\", wer_score, prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "\n",
    "        self.log(\"val_loss\", out[\"loss\"], prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"val_ce_loss\", out[\"lm_loss\"], prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"val_ctc_loss\", out[\"ctc_loss\"], prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"val_l1_loss\", out[\"l1_loss\"], on_epoch=True, batch_size=batch_size)\n",
    "        self.log(\"val_CER\", cer, prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "        return out[\"loss\"]\n",
    "\n",
    "    def generate(self, neuralInput, dayIdx, max_length=60, num_beams=1):\n",
    "        \"\"\"\n",
    "        Generate text from the model.\n",
    "\n",
    "        neuralInput: (batch_size, seq_len, input_dim)\n",
    "        dayIdx: Session index\n",
    "        max_length, num_beams: forwarded to the language model's `generate`.\n",
    "\n",
    "        Returns the decoded strings (special tokens removed). Runs in eval mode\n",
    "        without gradients and restores each sub-module's previous train/eval\n",
    "        flag afterwards.\n",
    "        \"\"\"\n",
    "        # Remember the current mode of every sub-module so that calling this\n",
    "        # in the middle of training does not leave the model stuck in eval mode.\n",
    "        previous_modes = [(module, module.training) for module in self.modules()]\n",
    "        self.eval()\n",
    "        try:\n",
    "            with torch.no_grad():\n",
    "                hid = self.get_neural_embedding(neuralInput, dayIdx)\n",
    "                encoder_outputs = BaseModelOutput(last_hidden_state=self.project(hid))\n",
    "\n",
    "                generated_ids = self.language_model.generate(encoder_outputs=encoder_outputs, max_length=max_length,num_beams=num_beams, num_return_sequences=1)\n",
    "                generated_text = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
    "        finally:\n",
    "            for module, was_training in previous_modes:\n",
    "                module.training = was_training\n",
    "\n",
    "        return generated_text\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"\n",
    "        Configures the optimizer (Adam over all parameters; no learning rate scheduler).\n",
    "        \"\"\"\n",
    "        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay, betas=(0.9, 0.99),\n",
    "                                      eps=1e-8,)\n",
    "\n",
    "        return optimizer\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d58b80b",
   "metadata": {},
   "source": [
    "## Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e1804613",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of trials:  10020\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  8800\n",
      "Number of trials:  880\n",
      "Number of days:  24\n",
      "Number of trials after filtering by indices:  880\n"
     ]
    }
   ],
   "source": [
    "train_loader, test_loader,_, loadedData = getDatasetLoaders_V3(DATASET_FULL_TRIALS_ZSCORE, 64, include_prego=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "35a25b88",
   "metadata": {},
   "outputs": [],
   "source": [
    "neural_encoder_model_weights_path = \".checkpoints/mfcc_sm_gru_ctc_LONGRUN/best_model.ckpt\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "d2703528",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resetting neural_dim based on channels\n",
      "neural_dim 256 256\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
      "  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_3187552/3666902854.py:36: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  neural_encoder.load_state_dict(torch.load(neural_encoder_model_weights_path)[\"state_dict\"])\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "\n",
    "# Encoder architecture. load_state_dict below needs these to match the checkpoint.\n",
    "nInputFeatures = 256  # channels\n",
    "nClasses = 40\n",
    "dropout = 0.4\n",
    "hidden_dim = 1024\n",
    "nlayers = 5\n",
    "stride_len = 4\n",
    "kernel_len = 32\n",
    "gaussian_smooth_width = 2\n",
    "bidirectional = True\n",
    "\n",
    "# Noise augmentation applied to the neural features during training.\n",
    "white_noise_SD = 0.8\n",
    "constant_offset_SD = 0.2\n",
    "\n",
    "# Not referenced by the cells shown here; kept for later cells.\n",
    "seq_len = 150\n",
    "max_time_series_len = 12000\n",
    "\n",
    "# Optimizer settings.\n",
    "lr_start = 3e-4\n",
    "lr_end = 0.02\n",
    "l2_decay = 1e-5\n",
    "\n",
    "\n",
    "neural_encoder = LightningGRUDecoder_MFCC_v3(\n",
    "    neural_dim=nInputFeatures,\n",
    "    n_classes=nClasses,\n",
    "    hidden_dim=hidden_dim,\n",
    "    layer_dim=nlayers,\n",
    "    strideLen=stride_len,\n",
    "    kernelLen=kernel_len,\n",
    "    gaussianSmoothWidth=gaussian_smooth_width,\n",
    "    bidirectional=bidirectional,\n",
    "    dropout=dropout,\n",
    "    white_noise_SD=white_noise_SD,\n",
    "    constant_offset_SD=constant_offset_SD,\n",
    "    weight_decay=l2_decay,\n",
    "    learning_rate=lr_start,\n",
    ")\n",
    "\n",
    "# Load onto the CPU first: without map_location the tensors are restored to the\n",
    "# device they were saved from, which fails or wastes GPU memory on other machines.\n",
    "# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.\n",
    "encoder_checkpoint = torch.load(neural_encoder_model_weights_path, map_location=\"cpu\")\n",
    "neural_encoder.load_state_dict(encoder_checkpoint[\"state_dict\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "96d03f9d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# # neural_encoder.fc_decoder_out.__dict__\n",
    "# # break\n",
    "\n",
    "# neural_encoder.to(\"cuda:0\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e8804fc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# neural_encoder.eval()\n",
    "# for batch_idx, batch in enumerate(test_loader):\n",
    "#     loss = neural_encoder.validation_step(batch, batch_idx)\n",
    "#     print(loss)\n",
    "\n",
    "#     if batch_idx==0:\n",
    "#         pred, mfcc_pred = neural_encoder.forward(batch[\"neural_feats\"].to(\"cuda:0\"), batch[\"day\"].to(\"cuda:0\"))\n",
    "\n",
    "#     break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "98a25df1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# total_edit_distance, total_seq_length = 0, 0\n",
    "# outs= []\n",
    "# for i in range(pred.shape[0]):\n",
    "#     decodedSeq = torch.argmax(pred[i, : int(batch[\"neural_time_bins\"][i] / neural_encoder.strideLen), :], dim=-1)\n",
    "#     decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "#     decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "#     outs.append(decodedSeq)\n",
    "#     trueSeq = batch[\"phone_seq\"][i][:batch[\"phone_seq_len\"][i]].cpu().numpy()\n",
    "#     matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "#     total_edit_distance += matcher.distance()\n",
    "#     total_seq_length += len(trueSeq)\n",
    "\n",
    "# cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "# cer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "cba5bff1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pred"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "b3d1ca42",
   "metadata": {},
   "outputs": [],
   "source": [
    "# print(idsToPhonemes(outs[10]))\n",
    "# print(idsToPhonemes(batch[\"phone_seq\"][10].numpy()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c167dc48",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Freezing all parameters in the encoder.\n",
      "Using LoRA with r=64, alpha=128, dropout=0.2 on target modules ['q_proj', 'k_proj', 'v_proj', 'out_proj'].\n"
     ]
    }
   ],
   "source": [
    "\n",
    "model = HybridGRUDecoder(\n",
    "    neural_encoder=neural_encoder,\n",
    "    lm_model_dim=768,\n",
    "    # optimizer settings\n",
    "    learning_rate=lr_start,\n",
    "    weight_decay=l2_decay,\n",
    "    # freezing flags\n",
    "    freeze_encoder=True,\n",
    "    freeze_lm=False,\n",
    "    # LoRA adapter settings\n",
    "    use_lora=True,\n",
    "    lora_r=64,\n",
    "    lora_alpha=128,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "d20c56e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_name = \"gru_ctc_mfcc_bart\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "567a2fb3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.19.4"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20250527_201338-kttzkgtk</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/kttzkgtk' target=\"_blank\">gru_ctc_mfcc_bart</a></strong> to <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/kttzkgtk' target=\"_blank\">https://wandb.ai/XXXXXXXXXXXXXX/ECOG_Sentence_dataset/runs/kttzkgtk</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_mfcc_bart exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2]\n",
      "\n",
      "  | Name           | Type                        | Params | Mode \n",
      "-----------------------------------------------------------------------\n",
      "0 | encoder        | LightningGRUDecoder_MFCC_v3 | 133 M  | train\n",
      "1 | project        | Sequential                  | 1.6 M  | train\n",
      "2 | language_model | PeftModel                   | 146 M  | train\n",
      "-----------------------------------------------------------------------\n",
      "8.7 M     Trainable params\n",
      "273 M     Non-trainable params\n",
      "282 M     Total params\n",
      "1,128.149 Total estimated model params size (MB)\n",
      "737       Modules in train mode\n",
      "182       Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "58376a0286524767bfa43ec81c01f727",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n",
      "Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n",
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/transformers/generation/configuration_utils.py:677: UserWarning: `num_beams` is set to 1. However, `early_stopping` is set to `True` -- this flag is only used in beam-based generation modes. You should set `num_beams>1` or unset `early_stopping`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated Text:  ['_ Era-���7 FOX��5��4��1 D��atom��( K��www��.��\"-�atomatom�atomwww�atom(��×��NYSE��and��*�', '\\' - - -,���\"�,,�, e\"�t\" -,\", familiar -\"\",� e�\"\"� -\"� e, e($,\\'($\" e-($-, -� M--,�', ' .�...��.�s..www.�www..~. Wash..×.�~.�(.�8. Wash�.—..(. Wash Wash���www�. Wash Flickr�.www Wash', ' to - - -,: (\" awoken.���:�� (�:-��.� (:�- (�� to: (� to�� for::�: (-�:: to�: for:� Dr�:', '�G, -\"���\"\"\"~\" *\"\"�\"~~~�~~\":�~��~\"~�4~~1~~4,~�️�� (�� -:��,,�', ' ( ( (\",s (ssswwwgmailruggedrelyss (gmail-ssgmail Patreon\".\",.\", (\",* Rear -s DUP ShinjisCopyrightrugged ( (twitter (~*As DUPs\",©-sThankss©-A', ',, # ( ( (, #:ob���.��($��www�� (~��~��A��6� (��B ( (�\\xa0�($,($ (($ (� /($($/ #\\xa0 (,', '.. … … …..�...www … …� …�\". ….:.. purchased.. rhetantam.. contacted�. …, …�.�\\'..::. …� Liberal�.\\'. …\\'.:', ',,,...~.,,s --, --\". -- -- --\"\\' -- --. --..\\'..( \" & \"s.. --www -- -- $..?.. &.. (.. and..www', \" …...,,.,.. gift.. and,,, -,,'',,-,, or..5..@, gift19.@..:. j,. j j response., j j …. j …\"]\n",
      "True Text:  ['Theocracy reconsidered.', 'Rich purchased several signed lithographs.', 'So rules we made, in unabashed collusion.', \"Lori's costume needed black gloves to be completely elegant.\", \"The tooth fairy forgot to come when Roger's tooth fell out.\", 'That stinging vapor was caused by chloride vaporization.', \"Before Thursday's exam, review every formula.\", 'Wildfire near Sunshine forces park closures.', \"The word means it won't boil away easily, nothing else.\", \"Would a blue feather in a man's hat make him happy all day?\"]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/XXXXXX/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c4920046c1fe4484a88030f7bbc660bd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a19dbe5e8f645c7bcd15d93c43394fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generated Text:  ['I', \"I't\", 'I...', \"I't...\", \"I't...\", \"I't...\", \"I't.\", 'I...', \"I't...\", \"I't...\"]\n",
      "True Text:  ['Theocracy reconsidered.', 'Rich purchased several signed lithographs.', 'So rules we made, in unabashed collusion.', \"Lori's costume needed black gloves to be completely elegant.\", \"The tooth fairy forgot to come when Roger's tooth fell out.\", 'That stinging vapor was caused by chloride vaporization.', \"Before Thursday's exam, review every formula.\", 'Wildfire near Sunshine forces park closures.', \"The word means it won't boil away easily, nothing else.\", \"Would a blue feather in a man's hat make him happy all day?\"]\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0, global step 138: 'val_CER' reached 0.15915 (best 0.15915), saving model to '/data/XXXXXX/speech_decoding_BCI/.checkpoints/gru_ctc_mfcc_bart/best_model-v9.ckpt' as top 5\n",
      "\n",
      "Detected KeyboardInterrupt, attempting graceful shutdown ...\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'exit' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:47\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m     46\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mlauncher\u001b[38;5;241m.\u001b[39mlaunch(trainer_fn, \u001b[38;5;241m*\u001b[39margs, trainer\u001b[38;5;241m=\u001b[39mtrainer, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m---> 47\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtrainer_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m _TunerExitException:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:575\u001b[0m, in \u001b[0;36mTrainer._fit_impl\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m    569\u001b[0m ckpt_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_checkpoint_connector\u001b[38;5;241m.\u001b[39m_select_ckpt_path(\n\u001b[1;32m    570\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mfn,\n\u001b[1;32m    571\u001b[0m     ckpt_path,\n\u001b[1;32m    572\u001b[0m     model_provided\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m    573\u001b[0m     model_connected\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m    574\u001b[0m )\n\u001b[0;32m--> 575\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mckpt_path\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    577\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstopped\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:982\u001b[0m, in \u001b[0;36mTrainer._run\u001b[0;34m(self, model, ckpt_path)\u001b[0m\n\u001b[1;32m    979\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m    980\u001b[0m \u001b[38;5;66;03m# RUN THE TRAINER\u001b[39;00m\n\u001b[1;32m    981\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[0;32m--> 982\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_run_stage\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    984\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n\u001b[1;32m    985\u001b[0m \u001b[38;5;66;03m# POST-Training CLEAN UP\u001b[39;00m\n\u001b[1;32m    986\u001b[0m \u001b[38;5;66;03m# ----------------------------\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:1026\u001b[0m, in \u001b[0;36mTrainer._run_stage\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1025\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mautograd\u001b[38;5;241m.\u001b[39mset_detect_anomaly(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_detect_anomaly):\n\u001b[0;32m-> 1026\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1027\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:216\u001b[0m, in \u001b[0;36m_FitLoop.run\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    215\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_start()\n\u001b[0;32m--> 216\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    217\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end()\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/fit_loop.py:455\u001b[0m, in \u001b[0;36m_FitLoop.advance\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    454\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_data_fetcher \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 455\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mepoch_loop\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data_fetcher\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:150\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.run\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m    149\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 150\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43madvance\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata_fetcher\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    151\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mon_advance_end(data_fetcher)\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/training_epoch_loop.py:320\u001b[0m, in \u001b[0;36m_TrainingEpochLoop.advance\u001b[0;34m(self, data_fetcher)\u001b[0m\n\u001b[1;32m    318\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mlightning_module\u001b[38;5;241m.\u001b[39mautomatic_optimization:\n\u001b[1;32m    319\u001b[0m     \u001b[38;5;66;03m# in automatic optimization, there can only be one optimizer\u001b[39;00m\n\u001b[0;32m--> 320\u001b[0m     batch_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautomatic_optimization\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizers\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    321\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:192\u001b[0m, in \u001b[0;36m_AutomaticOptimization.run\u001b[0;34m(self, optimizer, batch_idx, kwargs)\u001b[0m\n\u001b[1;32m    187\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m    188\u001b[0m \u001b[38;5;66;03m# BACKWARD PASS\u001b[39;00m\n\u001b[1;32m    189\u001b[0m \u001b[38;5;66;03m# ------------------------------\u001b[39;00m\n\u001b[1;32m    190\u001b[0m \u001b[38;5;66;03m# gradient update with accumulated gradients\u001b[39;00m\n\u001b[1;32m    191\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 192\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    194\u001b[0m result \u001b[38;5;241m=\u001b[39m closure\u001b[38;5;241m.\u001b[39mconsume_result()\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:270\u001b[0m, in \u001b[0;36m_AutomaticOptimization._optimizer_step\u001b[0;34m(self, batch_idx, train_step_and_backward_closure)\u001b[0m\n\u001b[1;32m    269\u001b[0m \u001b[38;5;66;03m# model hook\u001b[39;00m\n\u001b[0;32m--> 270\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_lightning_module_hook\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    271\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    272\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43moptimizer_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    273\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcurrent_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    274\u001b[0m \u001b[43m    \u001b[49m\u001b[43mbatch_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    275\u001b[0m \u001b[43m    \u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    276\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtrain_step_and_backward_closure\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    277\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    279\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m should_accumulate:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:171\u001b[0m, in \u001b[0;36m_call_lightning_module_hook\u001b[0;34m(trainer, hook_name, pl_module, *args, **kwargs)\u001b[0m\n\u001b[1;32m    170\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[LightningModule]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpl_module\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 171\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    173\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/core/module.py:1302\u001b[0m, in \u001b[0;36mLightningModule.optimizer_step\u001b[0;34m(self, epoch, batch_idx, optimizer, optimizer_closure)\u001b[0m\n\u001b[1;32m   1278\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124mr\u001b[39m\u001b[38;5;124;03m\"\"\"Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls\u001b[39;00m\n\u001b[1;32m   1279\u001b[0m \u001b[38;5;124;03mthe optimizer.\u001b[39;00m\n\u001b[1;32m   1280\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1300\u001b[0m \n\u001b[1;32m   1301\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m-> 1302\u001b[0m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptimizer_closure\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py:154\u001b[0m, in \u001b[0;36mLightningOptimizer.step\u001b[0;34m(self, closure, **kwargs)\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_strategy \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 154\u001b[0m step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_strategy\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_optimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    156\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_on_after_step()\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:239\u001b[0m, in \u001b[0;36mStrategy.optimizer_step\u001b[0;34m(self, optimizer, closure, model, **kwargs)\u001b[0m\n\u001b[1;32m    238\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(model, pl\u001b[38;5;241m.\u001b[39mLightningModule)\n\u001b[0;32m--> 239\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprecision_plugin\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43moptimizer_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43moptimizer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:123\u001b[0m, in \u001b[0;36mPrecision.optimizer_step\u001b[0;34m(self, optimizer, model, closure, **kwargs)\u001b[0m\n\u001b[1;32m    122\u001b[0m closure \u001b[38;5;241m=\u001b[39m partial(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_wrap_closure, model, optimizer, closure)\n\u001b[0;32m--> 123\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43moptimizer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m(\u001b[49m\u001b[43mclosure\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/torch/optim/optimizer.py:487\u001b[0m, in \u001b[0;36mOptimizer.profile_hook_step.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    483\u001b[0m             \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(\n\u001b[1;32m    484\u001b[0m                 \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m must return None or a tuple of (new_args, new_kwargs), but got \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mresult\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    485\u001b[0m             )\n\u001b[0;32m--> 487\u001b[0m out \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    488\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_optimizer_step_code()\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/torch/optim/optimizer.py:91\u001b[0m, in \u001b[0;36m_use_grad_for_differentiable.<locals>._use_grad\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m     90\u001b[0m     torch\u001b[38;5;241m.\u001b[39m_dynamo\u001b[38;5;241m.\u001b[39mgraph_break()\n\u001b[0;32m---> 91\u001b[0m     ret \u001b[38;5;241m=\u001b[39m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     92\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/torch/optim/adam.py:202\u001b[0m, in \u001b[0;36mAdam.step\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m    201\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m torch\u001b[38;5;241m.\u001b[39menable_grad():\n\u001b[0;32m--> 202\u001b[0m         loss \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    204\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m group \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mparam_groups:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/plugins/precision/precision.py:109\u001b[0m, in \u001b[0;36mPrecision._wrap_closure\u001b[0;34m(self, model, optimizer, closure)\u001b[0m\n\u001b[1;32m    102\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"This double-closure allows makes sure the ``closure`` is executed before the ``on_before_optimizer_step``\u001b[39;00m\n\u001b[1;32m    103\u001b[0m \u001b[38;5;124;03mhook is called.\u001b[39;00m\n\u001b[1;32m    104\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    107\u001b[0m \n\u001b[1;32m    108\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m--> 109\u001b[0m closure_result \u001b[38;5;241m=\u001b[39m \u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    110\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_after_closure(model, optimizer)\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:146\u001b[0m, in \u001b[0;36mClosure.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    144\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m    145\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Optional[Tensor]:\n\u001b[0;32m--> 146\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mclosure\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    147\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_result\u001b[38;5;241m.\u001b[39mloss\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/torch/utils/_contextlib.py:116\u001b[0m, in \u001b[0;36mcontext_decorator.<locals>.decorate_context\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m ctx_factory():\n\u001b[0;32m--> 116\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:131\u001b[0m, in \u001b[0;36mClosure.closure\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    128\u001b[0m \u001b[38;5;129m@override\u001b[39m\n\u001b[1;32m    129\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39menable_grad()\n\u001b[1;32m    130\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mclosure\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m ClosureResult:\n\u001b[0;32m--> 131\u001b[0m     step_output \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_step_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    133\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m step_output\u001b[38;5;241m.\u001b[39mclosure_loss \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/loops/optimization/automatic.py:319\u001b[0m, in \u001b[0;36m_AutomaticOptimization._training_step\u001b[0;34m(self, kwargs)\u001b[0m\n\u001b[1;32m    317\u001b[0m trainer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\n\u001b[0;32m--> 319\u001b[0m training_step_output \u001b[38;5;241m=\u001b[39m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_strategy_hook\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrainer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtraining_step\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    320\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39mpost_training_step()  \u001b[38;5;66;03m# unused hook - call anyway for backward compatibility\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:323\u001b[0m, in \u001b[0;36m_call_strategy_hook\u001b[0;34m(trainer, hook_name, *args, **kwargs)\u001b[0m\n\u001b[1;32m    322\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m trainer\u001b[38;5;241m.\u001b[39mprofiler\u001b[38;5;241m.\u001b[39mprofile(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m[Strategy]\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mtrainer\u001b[38;5;241m.\u001b[39mstrategy\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__class__\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mhook_name\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 323\u001b[0m     output \u001b[38;5;241m=\u001b[39m \u001b[43mfn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    325\u001b[0m \u001b[38;5;66;03m# restore current_fx when nested context\u001b[39;00m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/strategies/strategy.py:391\u001b[0m, in \u001b[0;36mStrategy.training_step\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    390\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_redirection(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodel, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlightning_module, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtraining_step\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[0;32m--> 391\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlightning_module\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "Cell \u001b[0;32mIn[24], line 158\u001b[0m, in \u001b[0;36mHybridGRUDecoder.training_step\u001b[0;34m(self, batch, batch_idx)\u001b[0m\n\u001b[1;32m    157\u001b[0m \u001b[38;5;66;03m# Forward pass\u001b[39;00m\n\u001b[0;32m--> 158\u001b[0m pred, mfcc_pred, lm_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mforward\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdayIdx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msentence\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[43msentence\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    159\u001b[0m ctc_loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoder\u001b[38;5;241m.\u001b[39mctc_loss(\n\u001b[1;32m    160\u001b[0m     torch\u001b[38;5;241m.\u001b[39mpermute(pred\u001b[38;5;241m.\u001b[39mlog_softmax(\u001b[38;5;241m2\u001b[39m), [\u001b[38;5;241m1\u001b[39m, \u001b[38;5;241m0\u001b[39m, \u001b[38;5;241m2\u001b[39m]),\n\u001b[1;32m    161\u001b[0m     y,\n\u001b[1;32m    162\u001b[0m     ((X_len \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoder\u001b[38;5;241m.\u001b[39mkernelLen) \u001b[38;5;241m/\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mencoder\u001b[38;5;241m.\u001b[39mstrideLen)\u001b[38;5;241m.\u001b[39mto(torch\u001b[38;5;241m.\u001b[39mint32),\n\u001b[1;32m    163\u001b[0m     y_len,\n\u001b[1;32m    164\u001b[0m )\n",
      "Cell \u001b[0;32mIn[24], line 106\u001b[0m, in \u001b[0;36mHybridGRUDecoder.forward\u001b[0;34m(self, neuralInput, dayIdx, sentence)\u001b[0m\n\u001b[1;32m    105\u001b[0m decoder_input_ids \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtokenizer(sentence, return_tensors\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpt\u001b[39m\u001b[38;5;124m\"\u001b[39m, padding\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\u001b[38;5;241m.\u001b[39minput_ids\n\u001b[0;32m--> 106\u001b[0m decoder_input_ids \u001b[38;5;241m=\u001b[39m \u001b[43mdecoder_input_ids\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    108\u001b[0m lm_outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mlanguage_model(encoder_outputs\u001b[38;5;241m=\u001b[39m(encoder_outputs, encoder_outputs), labels \u001b[38;5;241m=\u001b[39m decoder_input_ids)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: ",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[26], line 26\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[38;5;66;03m# Train model\u001b[39;00m\n\u001b[1;32m     24\u001b[0m trainer \u001b[38;5;241m=\u001b[39m pl\u001b[38;5;241m.\u001b[39mTrainer(max_epochs\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m100\u001b[39m,devices \u001b[38;5;241m=\u001b[39m[\u001b[38;5;241m0\u001b[39m], callbacks\u001b[38;5;241m=\u001b[39m[checkpoint_callback], logger\u001b[38;5;241m=\u001b[39mwandb_logger)\n\u001b[0;32m---> 26\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtest_loader\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     28\u001b[0m \u001b[38;5;66;03m#reload state_dict of best model\u001b[39;00m\n\u001b[1;32m     29\u001b[0m \u001b[38;5;66;03m# model.load_state_dict(torch.load(f\".checkpoints/mfcc_sm_gru_ctc/best_model.ckpt\")[\"state_dict\"])\u001b[39;00m\n\u001b[1;32m     30\u001b[0m \n\u001b[1;32m     31\u001b[0m \n\u001b[1;32m     32\u001b[0m \u001b[38;5;66;03m# close wandb logger\u001b[39;00m\n\u001b[1;32m     33\u001b[0m wandb\u001b[38;5;241m.\u001b[39mfinish()\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py:539\u001b[0m, in \u001b[0;36mTrainer.fit\u001b[0;34m(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)\u001b[0m\n\u001b[1;32m    537\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mstatus \u001b[38;5;241m=\u001b[39m TrainerStatus\u001b[38;5;241m.\u001b[39mRUNNING\n\u001b[1;32m    538\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtraining \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 539\u001b[0m \u001b[43mcall\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_and_handle_interrupt\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    540\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_fit_impl\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtrain_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval_dataloaders\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdatamodule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mckpt_path\u001b[49m\n\u001b[1;32m    541\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/pytorch_lightning/trainer/call.py:64\u001b[0m, in \u001b[0;36m_call_and_handle_interrupt\u001b[0;34m(trainer, trainer_fn, *args, **kwargs)\u001b[0m\n\u001b[1;32m     62\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(launcher, _SubprocessScriptLauncher):\n\u001b[1;32m     63\u001b[0m         launcher\u001b[38;5;241m.\u001b[39mkill(_get_sigkill_signal())\n\u001b[0;32m---> 64\u001b[0m     \u001b[43mexit\u001b[49m(\u001b[38;5;241m1\u001b[39m)\n\u001b[1;32m     66\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mBaseException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m exception:\n\u001b[1;32m     67\u001b[0m     _interrupt(trainer, exception)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'exit' is not defined"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7fbba6f35550>> (for post_run_cell), with arguments args (<ExecutionResult object at 7fbb890703d0, execution_count=26 error_before_exec=None error_in_exec=name 'exit' is not defined info=<ExecutionInfo object at 7fbb89070130, raw_cell=\"\n",
      "wandb_logger = WandbLogger(project=\"ECOG_Sentence..\" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2BXXXXXX_bci/data/XXXXXX/speech_decoding_BCI/GRU_CTC_BART.ipynb#X11sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:\n"
     ]
    },
    {
     "ename": "BrokenPipeError",
     "evalue": "[Errno 32] Broken pipe",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mBrokenPipeError\u001b[0m                           Traceback (most recent call last)",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/wandb_init.py:538\u001b[0m, in \u001b[0;36m_WandbInit._pause_backend\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    536\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mnotebook\u001b[38;5;241m.\u001b[39msave_ipynb():  \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m    537\u001b[0m     \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrun \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 538\u001b[0m     res \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlog_code\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    539\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_logger\u001b[38;5;241m.\u001b[39minfo(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msaved code: \u001b[39m\u001b[38;5;132;01m%s\u001b[39;00m\u001b[38;5;124m\"\u001b[39m, res)  \u001b[38;5;66;03m# type: ignore\u001b[39;00m\n\u001b[1;32m    540\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbackend\u001b[38;5;241m.\u001b[39minterface \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/wandb_run.py:393\u001b[0m, in \u001b[0;36m_run_decorator._noop_on_finish.<locals>.decorator_fn.<locals>.wrapper_fn\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    390\u001b[0m \u001b[38;5;129m@functools\u001b[39m\u001b[38;5;241m.\u001b[39mwraps(func)\n\u001b[1;32m    391\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mwrapper_fn\u001b[39m(\u001b[38;5;28mself\u001b[39m: \u001b[38;5;28mtype\u001b[39m[Run], \u001b[38;5;241m*\u001b[39margs: Any, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m    392\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mgetattr\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m_is_finished\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[0;32m--> 393\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    395\u001b[0m     default_message \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m    396\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mRun (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mid\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m) is finished. The call to `\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfunc\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m` will be ignored. 
\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    397\u001b[0m         \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mPlease make sure that you are using an active run.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m    398\u001b[0m     )\n\u001b[1;32m    399\u001b[0m     resolved_message \u001b[38;5;241m=\u001b[39m message \u001b[38;5;129;01mor\u001b[39;00m default_message\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/wandb_run.py:383\u001b[0m, in \u001b[0;36m_run_decorator._attach.<locals>.wrapper\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    381\u001b[0m         \u001b[38;5;28;01mraise\u001b[39;00m e\n\u001b[1;32m    382\u001b[0m     \u001b[38;5;28mcls\u001b[39m\u001b[38;5;241m.\u001b[39m_is_attaching \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[0;32m--> 383\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/wandb_run.py:1098\u001b[0m, in \u001b[0;36mRun.log_code\u001b[0;34m(self, root, name, include_fn, exclude_fn)\u001b[0m\n\u001b[1;32m   1093\u001b[0m     wandb\u001b[38;5;241m.\u001b[39mtermwarn(\n\u001b[1;32m   1094\u001b[0m         \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mNo relevant files were detected in the specified directory. No code will be logged to your run.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m   1095\u001b[0m     )\n\u001b[1;32m   1096\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m-> 1098\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_log_artifact\u001b[49m\u001b[43m(\u001b[49m\u001b[43mart\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/wandb_run.py:3285\u001b[0m, in \u001b[0;36mRun._log_artifact\u001b[0;34m(self, artifact_or_path, name, type, aliases, tags, distributed_id, finalize, is_user_created, use_after_commit)\u001b[0m\n\u001b[1;32m   3283\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backend \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backend\u001b[38;5;241m.\u001b[39minterface:\n\u001b[1;32m   3284\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_settings\u001b[38;5;241m.\u001b[39m_offline:\n\u001b[0;32m-> 3285\u001b[0m         future \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_backend\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43minterface\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcommunicate_artifact\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   3286\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3287\u001b[0m \u001b[43m            \u001b[49m\u001b[43martifact\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3288\u001b[0m \u001b[43m            \u001b[49m\u001b[43maliases\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3289\u001b[0m \u001b[43m            \u001b[49m\u001b[43mtags\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3290\u001b[0m \u001b[43m            \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3291\u001b[0m \u001b[43m            \u001b[49m\u001b[43mfinalize\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfinalize\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3292\u001b[0m \u001b[43m            
\u001b[49m\u001b[43mis_user_created\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_user_created\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3293\u001b[0m \u001b[43m            \u001b[49m\u001b[43muse_after_commit\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_after_commit\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   3294\u001b[0m \u001b[43m        \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3295\u001b[0m         artifact\u001b[38;5;241m.\u001b[39m_set_save_future(future, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_public_api()\u001b[38;5;241m.\u001b[39mclient)\n\u001b[1;32m   3296\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/interface/interface.py:592\u001b[0m, in \u001b[0;36mInterfaceBase.communicate_artifact\u001b[0;34m(self, run, artifact, aliases, tags, history_step, is_user_created, use_after_commit, finalize)\u001b[0m\n\u001b[1;32m    590\u001b[0m     log_artifact\u001b[38;5;241m.\u001b[39mhistory_step \u001b[38;5;241m=\u001b[39m history_step\n\u001b[1;32m    591\u001b[0m log_artifact\u001b[38;5;241m.\u001b[39mstaging_dir \u001b[38;5;241m=\u001b[39m get_staging_dir()\n\u001b[0;32m--> 592\u001b[0m resp \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_communicate_artifact\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlog_artifact\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    593\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m resp\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/interface/interface_shared.py:415\u001b[0m, in \u001b[0;36mInterfaceShared._communicate_artifact\u001b[0;34m(self, log_artifact)\u001b[0m\n\u001b[1;32m    413\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_communicate_artifact\u001b[39m(\u001b[38;5;28mself\u001b[39m, log_artifact: pb\u001b[38;5;241m.\u001b[39mLogArtifactRequest) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Any:\n\u001b[1;32m    414\u001b[0m     rec \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_request(log_artifact\u001b[38;5;241m=\u001b[39mlog_artifact)\n\u001b[0;32m--> 415\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_communicate_async\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrec\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/interface/interface_sock.py:56\u001b[0m, in \u001b[0;36mInterfaceSock._communicate_async\u001b[0;34m(self, rec, local)\u001b[0m\n\u001b[1;32m     54\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process_check \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_process\u001b[38;5;241m.\u001b[39mis_alive():\n\u001b[1;32m     55\u001b[0m     \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mThe wandb backend process has shutdown\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m---> 56\u001b[0m future \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_router\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend_and_receive\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrec\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlocal\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlocal\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m future\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/interface/router.py:92\u001b[0m, in \u001b[0;36mMessageRouter.send_and_receive\u001b[0;34m(self, rec, local)\u001b[0m\n\u001b[1;32m     89\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[1;32m     90\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pending_reqs[rec\u001b[38;5;241m.\u001b[39muuid] \u001b[38;5;241m=\u001b[39m future\n\u001b[0;32m---> 92\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_message\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrec\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     94\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m future\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/interface/router_sock.py:36\u001b[0m, in \u001b[0;36mMessageSockRouter._send_message\u001b[0;34m(self, record)\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_send_message\u001b[39m(\u001b[38;5;28mself\u001b[39m, record: \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mpb.Record\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m---> 36\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock_client\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend_record_communicate\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrecord\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/lib/sock_client.py:217\u001b[0m, in \u001b[0;36mSockClient.send_record_communicate\u001b[0;34m(self, record)\u001b[0m\n\u001b[1;32m    215\u001b[0m server_req \u001b[38;5;241m=\u001b[39m spb\u001b[38;5;241m.\u001b[39mServerRequest()\n\u001b[1;32m    216\u001b[0m server_req\u001b[38;5;241m.\u001b[39mrecord_communicate\u001b[38;5;241m.\u001b[39mCopyFrom(record)\n\u001b[0;32m--> 217\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend_server_request\u001b[49m\u001b[43m(\u001b[49m\u001b[43mserver_req\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/lib/sock_client.py:154\u001b[0m, in \u001b[0;36mSockClient.send_server_request\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m    153\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msend_server_request\u001b[39m(\u001b[38;5;28mself\u001b[39m, msg: Any) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 154\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_send_message\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmsg\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/lib/sock_client.py:151\u001b[0m, in \u001b[0;36mSockClient._send_message\u001b[0;34m(self, msg)\u001b[0m\n\u001b[1;32m    149\u001b[0m header \u001b[38;5;241m=\u001b[39m struct\u001b[38;5;241m.\u001b[39mpack(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m<BI\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mord\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mW\u001b[39m\u001b[38;5;124m\"\u001b[39m), raw_size)\n\u001b[1;32m    150\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_lock:\n\u001b[0;32m--> 151\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sendall_with_error_handle\u001b[49m\u001b[43m(\u001b[49m\u001b[43mheader\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m+\u001b[39;49m\u001b[43m \u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/anaconda3/envs/evo/lib/python3.9/site-packages/wandb/sdk/lib/sock_client.py:130\u001b[0m, in \u001b[0;36mSockClient._sendall_with_error_handle\u001b[0;34m(self, data)\u001b[0m\n\u001b[1;32m    128\u001b[0m start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mmonotonic()\n\u001b[1;32m    129\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 130\u001b[0m     sent \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_sock\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msend\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    131\u001b[0m     \u001b[38;5;66;03m# sent equal to 0 indicates a closed socket\u001b[39;00m\n\u001b[1;32m    132\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m sent \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n",
      "\u001b[0;31mBrokenPipeError\u001b[0m: [Errno 32] Broken pipe"
     ]
    }
   ],
   "source": [
    "\n",
    "wandb_logger = WandbLogger(project=\"ECOG_Sentence_dataset\", name=f\"{output_name}\",\n",
    "                            reinit=True)\n",
    "\n",
    "# Define ModelCheckpoint to save the best model based on validation loss\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_CER\",  # Ensure your validation step logs \"val_loss\"\n",
    "    mode=\"min\",          # Save the model with the lowest validation loss\n",
    "    save_top_k=5,        # Keep only the best model\n",
    "    dirpath=f\".checkpoints/{output_name}/\",  # Directory to save checkpoints\n",
    "    filename=f\"best_model\",  # Model filename\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "# Define EarlyStopping callback with patience of 3 epochs\n",
    "early_stopping_callback = EarlyStopping(\n",
    "    monitor=\"val_CER\",\n",
    "    patience=10,   # Stop training if no improvement in 3 epochs\n",
    "    mode=\"min\",\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "\n",
    "# Train model\n",
    "trainer = pl.Trainer(max_epochs=100,devices =[0], callbacks=[checkpoint_callback], logger=wandb_logger)\n",
    "\n",
    "trainer.fit(model, train_loader, test_loader)\n",
    "\n",
    "#reload state_dict of best model\n",
    "# model.load_state_dict(torch.load(f\".checkpoints/mfcc_sm_gru_ctc/best_model.ckpt\")[\"state_dict\"])\n",
    "\n",
    "\n",
    "# close wandb logger\n",
    "wandb.finish()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "90ac5b5a",
   "metadata": {},
   "outputs": [],
   "source": [
    "#save model weights\n",
    "torch.save(model.state_dict(), f\".checkpoints/{output_name}/final_model_weights.pth\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "evo",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
