{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "05d88eb0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rnn_crossdataset import CrossGRUDecoder, CrossGRUDecoderMLPProjection\n",
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "sys.path.append(\"../../\")\n",
    "from model_training.augmented_dataset import AugmentedNeuralTextDataset, BaseNeuralTextDataset, idsToPhonemes, collate_fn_flexible, PHONE_DEF\n",
    "import pickle\n",
    "from config import TRAIN_CARD_DATASET, TRAIN_WILLET_DATASET, TEST_CARD_DATASET, TEST_WILLET_DATASET, VAL_CARD_DATASET, COMPETITION_WILLET_DATASET\n",
    "import torch\n",
    "from torch.utils.data import ConcatDataset\n",
    "import os\n",
    "from data_augmentations import gauss_smooth\n",
    "\n",
    "import pytorch_lightning as pl\n",
    "from pytorch_lightning.loggers import WandbLogger\n",
    "import wandb\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping\n",
    "import tqdm\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "570b8793",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _load_pickle(path):\n",
    "    \"\"\"Open ``path`` in binary mode and return the unpickled object.\"\"\"\n",
    "    # pickle can run arbitrary code while loading: only point this at trusted files.\n",
    "    with open(path, \"rb\") as handle:\n",
    "        return pickle.load(handle)\n",
    "\n",
    "\n",
    "# Card splits: train / test / validation\n",
    "card_train_data = _load_pickle(TRAIN_CARD_DATASET)\n",
    "card_test_data = _load_pickle(TEST_CARD_DATASET)\n",
    "card_val_data = _load_pickle(VAL_CARD_DATASET)\n",
    "\n",
    "# Willet splits: train / test / competition\n",
    "willet_train_data = _load_pickle(TRAIN_WILLET_DATASET)\n",
    "willet_test_data = _load_pickle(TEST_WILLET_DATASET)\n",
    "willet_competition_data = _load_pickle(COMPETITION_WILLET_DATASET)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b43a0534",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Card corpus: wrap each loaded split; only the test split is built with eval_type=\"test\".\n",
    "train_card_dataset = BaseNeuralTextDataset(\n",
    "    card_train_data,\n",
    "    source_dataset=\"card\",\n",
    ")\n",
    "test_card_dataset = BaseNeuralTextDataset(\n",
    "    card_test_data,\n",
    "    eval_type=\"test\",\n",
    "    source_dataset=\"card\",\n",
    ")\n",
    "val_card_dataset = BaseNeuralTextDataset(\n",
    "    card_val_data,\n",
    "    source_dataset=\"card\",\n",
    ")\n",
    "\n",
    "# Willet corpus: only the competition split is built with eval_type=\"test\".\n",
    "train_willet_dataset = BaseNeuralTextDataset(\n",
    "    willet_train_data,\n",
    "    source_dataset=\"willet\",\n",
    ")\n",
    "test_willet_dataset = BaseNeuralTextDataset(\n",
    "    willet_test_data,\n",
    "    source_dataset=\"willet\",\n",
    ")\n",
    "competition_willet_dataset = BaseNeuralTextDataset(\n",
    "    willet_competition_data,\n",
    "    eval_type=\"test\",\n",
    "    source_dataset=\"willet\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7a71b29a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pool the two corpora with torch.utils.data.ConcatDataset.\n",
    "# Training set: card train + willet train.\n",
    "combined_train_dataset = ConcatDataset(\n",
    "    [train_card_dataset, train_willet_dataset]\n",
    ")\n",
    "# Validation set: card validation + willet test.\n",
    "combined_val_dataset = ConcatDataset(\n",
    "    [val_card_dataset, test_willet_dataset]\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "baa88825",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "DataLoaders created.\n"
     ]
    }
   ],
   "source": [
    "# Keyword arguments shared by every loader built in this cell.\n",
    "_loader_kwargs = dict(batch_size=64, collate_fn=collate_fn_flexible)\n",
    "\n",
    "# Combined loaders: only the training loader shuffles.\n",
    "train_loader = torch.utils.data.DataLoader(combined_train_dataset, shuffle=True, **_loader_kwargs)\n",
    "val_loader = torch.utils.data.DataLoader(combined_val_dataset, shuffle=False, **_loader_kwargs)\n",
    "\n",
    "# Per-corpus evaluation loaders, all in dataset order.\n",
    "val_card_loader = torch.utils.data.DataLoader(val_card_dataset, shuffle=False, **_loader_kwargs)\n",
    "test_card_loader = torch.utils.data.DataLoader(test_card_dataset, shuffle=False, **_loader_kwargs)\n",
    "test_willet_loader = torch.utils.data.DataLoader(test_willet_dataset, shuffle=False, **_loader_kwargs)\n",
    "competition_willet_loader = torch.utils.data.DataLoader(competition_willet_dataset, shuffle=False, **_loader_kwargs)\n",
    "\n",
    "print(\"\\nDataLoaders created.\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3f58b841",
   "metadata": {},
   "source": [
    "## load pretrained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7c1f21e7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "_IncompatibleKeys(missing_keys=['willet_day_weights.0', 'willet_day_weights.1', 'willet_day_weights.2', 'willet_day_weights.3', 'willet_day_weights.4', 'willet_day_weights.5', 'willet_day_weights.6', 'willet_day_weights.7', 'willet_day_weights.8', 'willet_day_weights.9', 'willet_day_weights.10', 'willet_day_weights.11', 'willet_day_weights.12', 'willet_day_weights.13', 'willet_day_weights.14', 'willet_day_weights.15', 'willet_day_weights.16', 'willet_day_weights.17', 'willet_day_weights.18', 'willet_day_weights.19', 'willet_day_weights.20', 'willet_day_weights.21', 'willet_day_weights.22', 'willet_day_weights.23', 'willet_day_biases.0', 'willet_day_biases.1', 'willet_day_biases.2', 'willet_day_biases.3', 'willet_day_biases.4', 'willet_day_biases.5', 'willet_day_biases.6', 'willet_day_biases.7', 'willet_day_biases.8', 'willet_day_biases.9', 'willet_day_biases.10', 'willet_day_biases.11', 'willet_day_biases.12', 'willet_day_biases.13', 'willet_day_biases.14', 'willet_day_biases.15', 'willet_day_biases.16', 'willet_day_biases.17', 'willet_day_biases.18', 'willet_day_biases.19', 'willet_day_biases.20', 'willet_day_biases.21', 'willet_day_biases.22', 'willet_day_biases.23'], unexpected_keys=[])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_path = \"/data/data/matteo/nejm-brain-to-text/data/t15_pretrained_rnn_baseline\"\n",
    "\n",
    "# Build the cross-dataset decoder that will receive the pretrained weights.\n",
    "model = CrossGRUDecoder(\n",
    "    neural_dim=512,\n",
    "    n_units=768,\n",
    "    n_days=45,\n",
    "    n_classes=41,\n",
    "    rnn_dropout=0.4,\n",
    "    input_dropout=0.2,\n",
    "    n_layers=5,\n",
    "    patch_size=14,\n",
    "    patch_stride=4,\n",
    "    n_willet_days=24,\n",
    ")\n",
    "\n",
    "# Read the checkpoint onto the CPU.\n",
    "# NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted checkpoints.\n",
    "checkpoint = torch.load(\n",
    "    os.path.join(model_path, 'checkpoint/best_checkpoint'),\n",
    "    weights_only=False,\n",
    "    map_location='cpu',\n",
    ")\n",
    "\n",
    "# Remove the \"module.\" (DataParallel) and \"_orig_mod.\" (typically torch.compile) fragments\n",
    "# from every key. Both replacements are applied to the same key while a fresh dict is built:\n",
    "# popping the original key twice raises KeyError as soon as the first rename changes the key.\n",
    "checkpoint['model_state_dict'] = {\n",
    "    key.replace(\"module.\", \"\").replace(\"_orig_mod.\", \"\"): value\n",
    "    for key, value in checkpoint['model_state_dict'].items()\n",
    "}\n",
    "\n",
    "# strict=False tolerates parameters present on only one side (here the willet-specific ones).\n",
    "# The returned report of missing/unexpected keys is shown as the cell output.\n",
    "model.load_state_dict(checkpoint['model_state_dict'], strict=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ee7382b2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CrossGRUDecoder(\n",
       "  (day_layer_activation): Softsign()\n",
       "  (day_weights): ParameterList(\n",
       "      (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (24): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (25): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (26): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (27): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (28): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (29): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (30): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (31): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (32): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (33): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (34): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (35): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (36): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (37): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (38): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (39): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (40): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (41): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (42): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (43): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (44): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "  )\n",
       "  (day_biases): ParameterList(\n",
       "      (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (24): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (25): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (26): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (27): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (28): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (29): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (30): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (31): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (32): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (33): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (34): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (35): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (36): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (37): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (38): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (39): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (40): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (41): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (42): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (43): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (44): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "  )\n",
       "  (willet_day_weights): ParameterList(\n",
       "      (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "      (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "  )\n",
       "  (willet_day_biases): ParameterList(\n",
       "      (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "      (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "  )\n",
       "  (day_layer_dropout): Dropout(p=0.2, inplace=False)\n",
       "  (gru): GRU(7168, 768, num_layers=5, batch_first=True, dropout=0.4)\n",
       "  (out): Linear(in_features=768, out_features=41, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Prefer the first GPU, but fall back to the CPU so the notebook still runs without CUDA.\n",
    "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
    "\n",
    "model.to(device)\n",
    "# Switch to evaluation mode; as the last expression, the module repr is displayed.\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "83841954",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = next(iter(train_loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "08dfbe63",
   "metadata": {},
   "outputs": [],
   "source": [
    "## get all items in batch with source_dataset equal to \"card\"\n",
    "card_items = [i for i, src in enumerate(batch[\"source_dataset\"]) if src == \"card\"]\n",
    "willet_items = [i for i, src in enumerate(batch[\"source_dataset\"]) if src == \"willet\"]\n",
    "\n",
    "\n",
    "neural_features_card = batch[\"neural_features\"][card_items]\n",
    "neural_features_willet = batch[\"neural_features\"][willet_items]\n",
    "\n",
    "day_card = batch[\"day\"][card_items]\n",
    "day_willet = batch[\"day\"][willet_items]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d7e16435",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "30"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(card_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "caee43ac",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([30, 1375, 512]), torch.Size([30]))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "neural_features_card.shape, day_card.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "88eb2107",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check forward pass only: disable autograd so no graph is kept in memory.\n",
    "with torch.no_grad():\n",
    "    out = model(\n",
    "        batch[\"neural_features\"].to(device),\n",
    "        batch[\"day\"].to(device),\n",
    "        card_indices=card_items,\n",
    "        willet_indices=willet_items,\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8181ec94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(40, device='cuda:0')"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out.argmax(dim=-1).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e40e9472",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['AY', 'SIL', 'W', 'UH', 'D', 'SIL', 'R', 'AE', 'DH', 'ER', 'SIL', 'W', 'EY', 'T', 'SIL', 'F', 'AO', 'R', 'SIL', 'AH', 'SIL', 'M', 'AH', 'N', 'TH', 'SIL']\n",
      "I would rather wait for a month.\n",
      "card\n"
     ]
    }
   ],
   "source": [
    "idx = 5\n",
    "\n",
    "# Per-frame argmax with consecutive repeats collapsed, rendered as phoneme symbols.\n",
    "print(\n",
    "    idsToPhonemes(\n",
    "        torch.unique_consecutive(out[idx].argmax(dim=-1)).cpu().numpy()\n",
    "    )\n",
    ")\n",
    "# Reference sentence and source corpus of the same sample.\n",
    "print(batch[\"sentence_label\"][idx])\n",
    "print(batch[\"source_dataset\"][idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "69c9b95f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['neural_features', 'neural_lengths', 'n_time_steps', 'neural_mask', 'seq_class_ids', 'seq_lengths', 'seq_mask', 'day', 'sentence_label', 'transcriptions', 'source_dataset'])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "8ddd9f5c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 379,  257,  956,  489,  173,  853,  385,  723,  887, 1308,  247,  223,\n",
       "        1347,  529,  264,  360,  322, 1333,  449,  189,  710,  484,  565,  226,\n",
       "        1358,  399,  230, 1375,  706,  180,  311,  218,  367,  848,  749, 1336,\n",
       "         257,  236,  250,  360, 1187,  717,  706,  707,  300, 1098,  396,  566,\n",
       "         239,  249, 1305, 1121,  866,  178,  262,  225,  238,  745,  259,  392,\n",
       "         717, 1072,  252,  701])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"neural_lengths\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "f20446ae",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 379,  257,  956,  489,  173,  853,  385,  723,  887, 1308,  247,  223,\n",
       "        1347,  529,  264,  360,  322, 1333,  449,  189,  710,  484,  565,  226,\n",
       "        1358,  399,  230, 1375,  706,  180,  311,  218,  367,  848,  749, 1336,\n",
       "         257,  236,  250,  360, 1187,  717,  706,  707,  300, 1098,  396,  566,\n",
       "         239,  249, 1305, 1121,  866,  178,  262,  225,  238,  745,  259,  392,\n",
       "         717, 1072,  252,  701])"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch[\"n_time_steps\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "88f3e9a5",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.nn.utils.rnn import pad_sequence\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from edit_distance import SequenceMatcher\n",
    "\n",
    "class LightningCrossGRUDecoder_V2(pl.LightningModule):\n",
    "    def __init__(\n",
    "        self,\n",
    "        neural_dim,\n",
    "        n_units=768,\n",
    "        n_classes=41,\n",
    "        n_days=45,\n",
    "        rnn_dropout=0.4,\n",
    "        input_dropout=0.2,\n",
    "        n_layers=5,\n",
    "        patch_stride=4,\n",
    "        patch_size=14,\n",
    "        n_willet_days=24,\n",
    "        pretrained_card_path=None,\n",
    "        learning_rate=3e-4,\n",
    "        white_noise_SD=0.8,\n",
    "        constant_offset_SD=0.2,\n",
    "        weight_decay=1e-5,\n",
    "        smoothing=True,\n",
    "        smoothing_kernel=100,\n",
    "        smoothing_std=2,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Wrap a CrossGRUDecoderMLPProjection in a LightningModule.\n",
    "\n",
    "        Args:\n",
    "            neural_dim, n_units, n_classes, n_days, n_willet_days, rnn_dropout,\n",
    "            input_dropout, n_layers, patch_size, patch_stride:\n",
    "                forwarded unchanged to CrossGRUDecoderMLPProjection.\n",
    "            pretrained_card_path: optional directory that contains\n",
    "                'checkpoint/best_checkpoint'. When given, the checkpoint's\n",
    "                'model_state_dict' is loaded into the wrapped model with strict=False.\n",
    "            learning_rate, weight_decay: stored as attributes of the same name.\n",
    "            white_noise_SD: scale of the Gaussian noise added to every input element\n",
    "                in training_step (skipped when not > 0).\n",
    "            constant_offset_SD: scale of the Gaussian offset added in training_step,\n",
    "                one value per sample and feature (skipped when not > 0).\n",
    "            smoothing: whether training_step passes the inputs through gauss_smooth.\n",
    "            smoothing_kernel, smoothing_std: handed to gauss_smooth as\n",
    "                smooth_kernel_size and smooth_kernel_std.\n",
    "        \"\"\"\n",
    "        super().__init__()\n",
    "\n",
    "        # Record the constructor arguments in self.hparams.\n",
    "        self.save_hyperparameters()\n",
    "        self.model = CrossGRUDecoderMLPProjection(\n",
    "            neural_dim=neural_dim,\n",
    "            n_units=n_units,\n",
    "            n_classes=n_classes,\n",
    "            n_days=n_days,\n",
    "            n_willet_days=n_willet_days,\n",
    "            rnn_dropout=rnn_dropout,\n",
    "            input_dropout=input_dropout,\n",
    "            n_layers=n_layers,\n",
    "            patch_size=patch_size,\n",
    "            patch_stride=patch_stride,\n",
    "        )\n",
    "\n",
    "        if pretrained_card_path is not None:\n",
    "            # NOTE: weights_only=False unpickles arbitrary Python objects -- only load trusted checkpoints.\n",
    "            checkpoint = torch.load(\n",
    "                os.path.join(pretrained_card_path, 'checkpoint/best_checkpoint'),\n",
    "                weights_only=False,\n",
    "                map_location='cpu',\n",
    "            )\n",
    "            # Remove the \"module.\" (DataParallel) and \"_orig_mod.\" (typically torch.compile)\n",
    "            # fragments from every key. Both replacements are applied to the same key while a\n",
    "            # fresh dict is built: popping the original key twice raises KeyError as soon as\n",
    "            # the first rename changes the key.\n",
    "            pretrained_state = {\n",
    "                key.replace(\"module.\", \"\").replace(\"_orig_mod.\", \"\"): value\n",
    "                for key, value in checkpoint['model_state_dict'].items()\n",
    "            }\n",
    "            # strict=False: parameters present on only one side are left untouched.\n",
    "            self.model.load_state_dict(pretrained_state, strict=False)\n",
    "            print(f\"Loaded pretrained weights from {pretrained_card_path}\")\n",
    "\n",
    "        # Optimisation settings\n",
    "        self.learning_rate = learning_rate\n",
    "        self.weight_decay = weight_decay\n",
    "\n",
    "        # Input augmentation and smoothing settings read by training_step\n",
    "        self.white_noise_SD = white_noise_SD\n",
    "        self.constant_offset_SD = constant_offset_SD\n",
    "        self.smoothing = smoothing\n",
    "        self.smoothing_kernel = smoothing_kernel\n",
    "        self.smoothing_std = smoothing_std\n",
    "\n",
    "        # Kept on the module to derive output lengths for the CTC loss\n",
    "        self.patch_size = patch_size\n",
    "        self.patch_stride = patch_stride\n",
    "\n",
    "        # CTC loss with class 0 as blank; infinite losses are zeroed\n",
    "        self.ctc_loss = torch.nn.CTCLoss(blank=0, reduction=\"mean\", zero_infinity=True)\n",
    "\n",
    "    def freeze_card_parameters(self):\n",
    "        \"\"\"Turn off gradients for every model parameter whose name does not contain \"willet\".\"\"\"\n",
    "        for param_name, tensor in self.model.named_parameters():\n",
    "            if \"willet\" in param_name:\n",
    "                # willet parameters keep their current requires_grad setting\n",
    "                continue\n",
    "            tensor.requires_grad = False\n",
    "        print(\"Froze card parameters.\")\n",
    "\n",
    "    def unfreeze_all_parameters(self):\n",
    "        \"\"\"Turn gradients back on for every parameter of the wrapped model.\"\"\"\n",
    "        for tensor in self.model.parameters():\n",
    "            tensor.requires_grad_(True)\n",
    "        print(\"Unfroze all parameters.\")\n",
    "\n",
    "\n",
    "\n",
    "    def get_neural_embeddings(\n",
    "        self,\n",
    "        x: torch.Tensor,\n",
    "        day_idx: torch.Tensor,\n",
    "        *,\n",
    "        card_indices=None,\n",
    "        willet_indices=None,\n",
    "        states: torch.Tensor = None,\n",
    "        return_state: bool = False,\n",
    "    ) -> torch.Tensor:\n",
    "        \"\"\"\n",
    "        Run the wrapped model with return_state=True and return only its state output.\n",
    "\n",
    "        Args:\n",
    "            x: neural input, described by the original author as (batch, time, features).\n",
    "            day_idx: session index for each sample.\n",
    "            card_indices: forwarded to the model unchanged.\n",
    "            willet_indices: forwarded to the model unchanged.\n",
    "            states: forwarded to the model unchanged.\n",
    "            return_state: accepted but never read; the model is always called\n",
    "                with return_state=True.\n",
    "\n",
    "        Returns:\n",
    "            The second element of the pair returned by the model.\n",
    "        \"\"\"\n",
    "        # The first element of the pair (the logits) is not needed here.\n",
    "        _logits, hidden_states = self.model(\n",
    "            x,\n",
    "            day_idx,\n",
    "            card_indices=card_indices,\n",
    "            willet_indices=willet_indices,\n",
    "            states=states,\n",
    "            return_state=True,\n",
    "        )\n",
    "        return hidden_states\n",
    "\n",
    "    def forward(\n",
    "        self,\n",
    "        x: torch.Tensor,\n",
    "        day_idx: torch.Tensor,\n",
    "        *,\n",
    "        card_indices=None,\n",
    "        willet_indices=None,\n",
    "        states: torch.Tensor = None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        Delegate to the wrapped model with return_state=False and return its output.\n",
    "\n",
    "        Args:\n",
    "            x: neural input, described by the original author as (batch, time, features).\n",
    "            day_idx: session index for each sample.\n",
    "            card_indices: forwarded to the model unchanged.\n",
    "            willet_indices: forwarded to the model unchanged.\n",
    "            states: forwarded to the model unchanged.\n",
    "\n",
    "        Returns:\n",
    "            Whatever the wrapped model returns for return_state=False (the logits).\n",
    "        \"\"\"\n",
    "        return self.model(\n",
    "            x,\n",
    "            day_idx,\n",
    "            card_indices=card_indices,\n",
    "            willet_indices=willet_indices,\n",
    "            states=states,\n",
    "            return_state=False,\n",
    "        )\n",
    "    \n",
    "    def training_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Run one training step: augment the inputs, forward them, and return the CTC loss.\n",
    "\n",
    "        Args:\n",
    "            batch: mapping that provides \"neural_features\", \"seq_class_ids\",\n",
    "                \"n_time_steps\", \"seq_lengths\", \"day\" and \"source_dataset\".\n",
    "            batch_idx: index of the batch (not used here).\n",
    "\n",
    "        Returns:\n",
    "            The CTC loss of this batch, also logged as \"train_loss\".\n",
    "        \"\"\"\n",
    "        X = batch[\"neural_features\"].to(self.device)\n",
    "        y = batch[\"seq_class_ids\"].to(self.device)\n",
    "        X_len = batch[\"n_time_steps\"].to(self.device)\n",
    "        y_len = batch[\"seq_lengths\"].to(self.device)\n",
    "        dayIdx = batch[\"day\"].to(self.device)\n",
    "\n",
    "        # Noise augmentation, written out-of-place: `.to()` hands back the very same tensor\n",
    "        # when it already lives on self.device, so `+=` would overwrite the data held by `batch`.\n",
    "        if self.white_noise_SD > 0:\n",
    "            # independent Gaussian noise on every element\n",
    "            X = X + torch.randn(X.shape, device=self.device) * self.white_noise_SD\n",
    "        if self.constant_offset_SD > 0:\n",
    "            # one offset per sample and feature, broadcast along dim 1\n",
    "            X = X + torch.randn([X.shape[0], 1, X.shape[2]], device=self.device) * self.constant_offset_SD\n",
    "\n",
    "        if self.smoothing:\n",
    "            X = gauss_smooth(\n",
    "                inputs=X,\n",
    "                device=self.device,\n",
    "                smooth_kernel_std=self.smoothing_std,\n",
    "                smooth_kernel_size=self.smoothing_kernel,\n",
    "            )\n",
    "\n",
    "        # Batch positions of each source corpus, passed to the model as card/willet indices.\n",
    "        card_items = [i for i, src in enumerate(batch[\"source_dataset\"]) if src == \"card\"]\n",
    "        willet_items = [i for i, src in enumerate(batch[\"source_dataset\"]) if src == \"willet\"]\n",
    "\n",
    "        pred = self.forward(X, dayIdx, card_indices=card_items, willet_indices=willet_items)\n",
    "\n",
    "        # Input lengths for the CTC loss: truncated (X_len - patch_size) / patch_stride, plus one.\n",
    "        # Presumably the number of frames the model emits after patching -- confirm against the model.\n",
    "        pred_len = ((X_len - self.patch_size) / self.patch_stride).to(torch.int32) + 1\n",
    "\n",
    "        loss = self.ctc_loss(\n",
    "            torch.permute(pred.log_softmax(2), [1, 0, 2]),  # swap the first two dims of pred\n",
    "            y,\n",
    "            pred_len,\n",
    "            y_len,\n",
    "        )\n",
    "\n",
    "        # batch_size is passed explicitly: the batch is a dict that mixes tensors with\n",
    "        # non-tensor entries, so Lightning's automatic inference may pick the wrong value.\n",
    "        self.log(\n",
    "            \"train_loss\",\n",
    "            loss,\n",
    "            prog_bar=True,\n",
    "            on_step=True,\n",
    "            on_epoch=True,\n",
    "            batch_size=X.shape[0],\n",
    "        )\n",
    "        return loss\n",
    "\n",
    "    def validation_step(self, batch, batch_idx):\n",
    "        \"\"\"\n",
    "        Validation step - computes the CTC loss and a greedy-decoded error rate per source dataset.\n",
    "\n",
    "        Args:\n",
    "            batch: dict with tensor entries \"neural_features\", \"seq_class_ids\",\n",
    "                \"n_time_steps\", \"seq_lengths\", \"day\" and a per-sample\n",
    "                \"source_dataset\" sequence of \"card\" / \"willet\" tags.\n",
    "            batch_idx: index of the batch (unused).\n",
    "\n",
    "        Returns:\n",
    "            The CTC loss over the whole batch (also logged as \"val_loss\").\n",
    "            \"val_CER_card\" / \"val_CER_willet\" are logged only when the batch\n",
    "            contains samples of that source.\n",
    "        \"\"\"\n",
    "        X = batch[\"neural_features\"].to(self.device)\n",
    "        y = batch[\"seq_class_ids\"].to(self.device)\n",
    "        X_len = batch[\"n_time_steps\"].to(self.device)\n",
    "        y_len = batch[\"seq_lengths\"].to(self.device)\n",
    "        dayIdx = batch[\"day\"].to(self.device)\n",
    "        batch_size = X.shape[0]\n",
    "\n",
    "        # NOTE(review): training_step optionally applies gauss_smooth, validation does not -\n",
    "        # confirm this train/validation difference is intended.\n",
    "\n",
    "        # Batch positions of the samples coming from each source dataset\n",
    "        sources = batch[\"source_dataset\"]\n",
    "        card_items = [i for i, src in enumerate(sources) if src == \"card\"]\n",
    "        willet_items = [i for i, src in enumerate(sources) if src == \"willet\"]\n",
    "\n",
    "        # Forward pass\n",
    "        pred = self.forward(X, dayIdx, card_indices=card_items, willet_indices=willet_items)\n",
    "\n",
    "        # Per-sample number of output frames after patching. The same value is used for the\n",
    "        # CTC loss and for greedy decoding, so decoding never reads frames produced from padding.\n",
    "        output_lengths = ((X_len - self.patch_size) / self.patch_stride).to(torch.int32) + 1\n",
    "\n",
    "        # Compute CTC Loss\n",
    "        loss = self.ctc_loss(\n",
    "            torch.permute(pred.log_softmax(2), [1, 0, 2]),\n",
    "            y,\n",
    "            output_lengths,\n",
    "            y_len,\n",
    "        )\n",
    "\n",
    "        # Error rate (edit distance / reference length), computed separately per source dataset\n",
    "        for metric_name, items in ((\"val_CER_card\", card_items), (\"val_CER_willet\", willet_items)):\n",
    "            if not items:\n",
    "                continue\n",
    "\n",
    "            total_edit_distance, total_seq_length = 0, 0\n",
    "            for i in items:\n",
    "                n_frames = int(output_lengths[i])\n",
    "                # Greedy decode: best class per frame, merge repeats, drop label 0\n",
    "                decodedSeq = torch.argmax(pred[i, :n_frames, :], dim=-1)\n",
    "                decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "                decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "                trueSeq = y[i][: int(y_len[i])].cpu().numpy()\n",
    "                matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "                total_edit_distance += matcher.distance()\n",
    "                total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            # Weight the epoch average by the number of samples of this source in the batch\n",
    "            self.log(metric_name, cer, prog_bar=True, on_epoch=True, batch_size=len(items))\n",
    "\n",
    "        self.log(\"val_loss\", loss, prog_bar=True, on_epoch=True, batch_size=batch_size)\n",
    "        return loss\n",
    "    \n",
    "    def on_after_backward(self):\n",
    "        \"\"\"\n",
    "        Log the overall L2 norm of the gradients as \"train_grad_norm\" (per step only).\n",
    "\n",
    "        Parameters whose `.grad` is None are left out of the norm.\n",
    "        \"\"\"\n",
    "        # This hook runs BEFORE Lightning's gradient clipping\n",
    "        per_param_norms = []\n",
    "        for param in self.parameters():\n",
    "            if param.grad is None:\n",
    "                continue\n",
    "            per_param_norms.append(param.grad.detach().data.norm(2))\n",
    "\n",
    "        # The L2 norm of the per-parameter L2 norms is the L2 norm over all gradient entries\n",
    "        total_norm = torch.norm(torch.stack(per_param_norms), 2)\n",
    "        self.log(\"train_grad_norm\", total_norm.item(), on_step=True, on_epoch=False, prog_bar=False)\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \"\"\"\n",
    "        Build the Adam optimizer and a ReduceLROnPlateau scheduler that monitors \"val_loss\".\n",
    "\n",
    "        Returns:\n",
    "            dict with the keys \"optimizer\", \"lr_scheduler\" and \"monitor\".\n",
    "        \"\"\"\n",
    "        adam = torch.optim.Adam(\n",
    "            self.parameters(),\n",
    "            lr=self.learning_rate,\n",
    "            weight_decay=self.weight_decay,\n",
    "            betas=(0.9, 0.99),\n",
    "            eps=1e-8,  # eps was 0.1\n",
    "        )\n",
    "\n",
    "        # Multiply the learning rate by 0.5 when the monitored value stops decreasing, never below 1e-6\n",
    "        plateau_scheduler = ReduceLROnPlateau(\n",
    "            adam,\n",
    "            mode=\"min\",\n",
    "            factor=0.5,\n",
    "            patience=3,\n",
    "            cooldown=2,\n",
    "            min_lr=1e-6,\n",
    "        )\n",
    "        # Disabled alternative: WarmupCosineAnnealingLR(optimizer, warmup_steps=self.warmup_steps,\n",
    "        #                                               total_steps=self.total_steps, eta_min=1e-6)\n",
    "\n",
    "        return {\"optimizer\": adam, \"lr_scheduler\": plateau_scheduler, \"monitor\": \"val_loss\"}\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "4e30694b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded pretrained weights from /data/data/matteo/nejm-brain-to-text/data/t15_pretrained_rnn_baseline\n",
      "Froze card parameters.\n"
     ]
    }
   ],
   "source": [
    "# Hyperparameters of the cross-dataset GRU decoder, gathered in one place\n",
    "decoder_hparams = dict(\n",
    "    neural_dim=512,\n",
    "    n_units=768,\n",
    "    n_days=45,\n",
    "    n_classes=41,\n",
    "    rnn_dropout=0.4,\n",
    "    input_dropout=0.2,\n",
    "    n_layers=5,\n",
    "    patch_size=14,\n",
    "    patch_stride=4,\n",
    "    n_willet_days=24,\n",
    "    pretrained_card_path=\"/data/data/matteo/nejm-brain-to-text/data/t15_pretrained_rnn_baseline\",\n",
    ")\n",
    "\n",
    "model = LightningCrossGRUDecoder_V2(**decoder_hparams)\n",
    "\n",
    "# Freeze the card parameters (the call reports \"Froze card parameters.\")\n",
    "model.freeze_card_parameters()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "455c3e70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Identifier of this run; in the recorded output of the training cell it appears as the\n",
    "# W&B run name and as the sub-directory of `.checkpoints/` that receives the checkpoints.\n",
    "output_name = \"CrossDatasetGRU_warmup\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4c0323aa",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n",
      "You are using a CUDA device ('NVIDIA H100 80GB HBM3') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mmatteoferrante\u001b[0m to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Using a boolean value for 'reinit' is deprecated. Use 'return_previous' or 'finish_previous' instead.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.21.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20250902_172407-pnl3af5m</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/matteoferrante/B2TXT25/runs/pnl3af5m' target=\"_blank\">CrossDatasetGRU_warmup</a></strong> to <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/matteoferrante/B2TXT25/runs/pnl3af5m' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25/runs/pnl3af5m</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:701: Checkpoint directory /data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
      "\n",
      "  | Name     | Type                         | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | model    | CrossGRUDecoderMLPProjection | 51.0 M | train\n",
      "1 | ctc_loss | CTCLoss                      | 0      | train\n",
      "------------------------------------------------------------------\n",
      "6.7 M     Trainable params\n",
      "44.3 M    Non-trainable params\n",
      "51.0 M    Total params\n",
      "204.057   Total estimated model params size (MB)\n",
      "15        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9e2e6a3c963e4491bb21ddd353ad1887",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=223` in the `DataLoader` to improve performance.\n",
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 64. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=223` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b79b819773ee4f5ea6538e7949599667",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 37. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2471066548574c3db1fd7f959bbbf3dd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/utilities/data.py:79: Trying to infer the `batch_size` from an ambiguous collection. The batch size we found is 2. To avoid any miscalculations, use `self.log(..., batch_size=batch_size)`.\n",
      "Epoch 0, global step 264: 'val_CER_card' reached 0.11933 (best 0.11933), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-v2.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ecca7007c684310881b188aaa094b46",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1, global step 528: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "52825e158b1541c888b789857b32c45e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2, global step 792: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a5591688a3184d4e8b8a92468653100b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3, global step 1056: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f40f38dd2ae4e45a8be48f109fb65dc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4, global step 1320: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "155f7f0c57bc4a6cbeae1cd2380e5514",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5, global step 1584: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "edaea367a5c94462b448380aab112669",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6, global step 1848: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3509b20afd2d4640933fa53c5423a413",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7, global step 2112: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e34924120ca54dd6800a32aa29991ea3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8, global step 2376: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7af8cd4e7104b1ba14a0a241a1256b4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 2640: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7df32cb39bd4f34b2f538c680b0de23",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10, global step 2904: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e43e98fb78bb41e8875ea068f05c463a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11, global step 3168: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab867f28709746f9bb19a1bdef2860e9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12, global step 3432: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f03c866b56744ce0b348d7a182ca2ce6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13, global step 3696: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6442dabba36e4f25ab87c8672ccc55ef",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14, global step 3960: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a575fd87daab4ea28c4368c2b48f3a33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15, global step 4224: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16833c0b0936433f89fe7d5ad9f5f6fc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16, global step 4488: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "593aa0136e534fd0aeec0e2c18c7ec52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17, global step 4752: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a18b17e2d8a9444e841dd892397d74e3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18, global step 5016: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "718e754a8aed4c27bdec9a928d579646",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 5280: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0845277ea38b494991e5ff1328997f9c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20, global step 5544: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00914d2240c745d19619af9eeb793099",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21, global step 5808: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a6d98320c5004a2a9b003ea08c1f3f8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22, global step 6072: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3de31bca39494354b7592348c7853f78",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23, global step 6336: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "626c79353ec54ca7ae692dc2f7e1b3fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24, global step 6600: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ac8a6bae37e047879e46d7b0e2897e8b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25, global step 6864: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "65b0df14a7b14616b972756bc405221e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26, global step 7128: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0824d9a4db2948acb7e551e08f1194d5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27, global step 7392: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2fa3185a53ee403c95211fef38a0ca20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28, global step 7656: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6108bf523220431ebd132aa217b9032f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 7920: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f860a51e53ae4e849b96e45af99d6e32",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30, global step 8184: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d6249d513e04277852c81be35213d8e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31, global step 8448: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95aaa54c30224169b7c841d32e815f63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32, global step 8712: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5028f8bc528b4951b6f1096774c66899",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33, global step 8976: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d2f7905427f4e4fb30c64ce642c1780",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34, global step 9240: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "79b184ab53ce442da2e00a7c3c57a6db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35, global step 9504: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f34dd95bff240a889881201779a1b4d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36, global step 9768: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca85672637134a23844ebd1db7e724de",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37, global step 10032: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd021fcf1c384d67a17ff4ba113a18d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38, global step 10296: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "094b355e59964ec68e9d6a3eb5849b20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 10560: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0d57834d0a4e4c15b3ae49d9bbe803ed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40, global step 10824: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cf7d028ad65540ecbd1f60fc6b78600d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41, global step 11088: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62ebd479a28848e69acca37852b0fb52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42, global step 11352: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1bce4a5b4d54ef1930d583dcb8c4d84",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43, global step 11616: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22e7882f0adc4192a1615fab77047515",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44, global step 11880: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "84338c42bae14e3b974d2b5156c7a928",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45, global step 12144: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "14edede767174e2d9630c5f23e00221b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46, global step 12408: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ab792dbe077c44e0b0c2e97fb55e674d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47, global step 12672: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c253c25d51ff4522a7ef6fa6bb03353f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48, global step 12936: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8806cc2f48a243bea034a4ff26678e3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 13200: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c0149da0cbf46b084f9f67f4781b016",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50, global step 13464: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cec3c1c8b4a14e06bc9d12975919c02e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51, global step 13728: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4af8227ea07e4c7e8e58a6119bb4176a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52, global step 13992: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4dd089eb39e9426e9d85d6a8c3ed4cba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53, global step 14256: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4671e53f38b846bf8bf9af77cb33b138",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54, global step 14520: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4f0a4eda161b4f18a0c9bd1573a07c7f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55, global step 14784: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c0559843ebde4d61b0baab478c587984",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56, global step 15048: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1249f761868f4660be0eb2d2388caeb9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57, global step 15312: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d86a3ad6f1fd47889774ef47d54ea473",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58, global step 15576: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aebc087b6f9a47fd865a761d85b965d0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 15840: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bf3aad2f3a0d43ba94975bbf32df5d03",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60, global step 16104: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2358bbe0bde247ddad76e6a66b73076e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61, global step 16368: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3323fc3eb750461194a737dd4d8f2dda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62, global step 16632: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "03bf2f79ca4649709abc5012326de041",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63, global step 16896: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17c2e4bf36914df99b0bbeefd43a7005",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64, global step 17160: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37a63ece2ede47b9927ca38c9e0e1c12",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65, global step 17424: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0fe7d25951e14cd7ac6a501d8901426d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66, global step 17688: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "94fb93f647734d91afb1cb5fe2f35d62",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67, global step 17952: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "23b4936ea1384e598c82c8974b6bd96e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68, global step 18216: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "336a02314b5245509f226a00af28bdc5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 18480: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "44406368998c4525955d7a836ae52169",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70, global step 18744: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2574a902d6b4e83876bf6e03ddd2192",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 71, global step 19008: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02162a4c90de404482429b44a7dd8b1c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 72, global step 19272: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "126e4f4c6a154bffbb4e3bbab5e47056",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 73, global step 19536: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86980152afe7458b9e0185b29158b8a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 74, global step 19800: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c2e08798907422f853218a24a4be013",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 75, global step 20064: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e7ed18b2c3994551b19472c0197d4a44",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 76, global step 20328: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a24e0ca38564209a550035e8eec04bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 77, global step 20592: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2adbcecb02d40feb2a4448c8d688b6b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 78, global step 20856: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c70a51db6977473cb3dafa6452ea6a91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 21120: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6547add8c9f74914803b3eadce52ed20",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 80, global step 21384: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ded05d490e304b0d9aede43f0a28bab8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 81, global step 21648: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3424b76d781d49868c93867db9d823c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 82, global step 21912: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e128ef0562ba46678da901b65fcf1132",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 83, global step 22176: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a00fb00ca7e54547937a30572bdaab95",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 84, global step 22440: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e625e0f51ac847d78f14b74f3e67a1ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 85, global step 22704: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "784a543df3f048249f1c7318d1764d8a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 86, global step 22968: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1b5f66e5f13406e9fb76cec68d561e4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 87, global step 23232: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8495279612246a3afc37f2edb393aff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 88, global step 23496: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2f986b36796b4a1492069b5bef55dea7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 89, global step 23760: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "57f0b08c142a4f459a45c8036064f0f5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 90, global step 24024: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9fee4ac1704f4e32b19c799ebc17ce55",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 91, global step 24288: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fca21939a7e6447096959eddc2985fb1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 92, global step 24552: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "467e4d14d276434081d17d5c217c4d07",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 93, global step 24816: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0eca6450bb064430874817b6acdd340d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 94, global step 25080: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fcf2cd0f2f0349c49084a68d7133b301",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 95, global step 25344: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "77ae800603be450dbad2f3c823ca1999",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 96, global step 25608: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1f215867d8074b848d969f8d16409679",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 97, global step 25872: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c2e4b8796ef44c8a620770534dc8bb3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 98, global step 26136: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17aa8452b29a436683377e14d9f3b8d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 99, global step 26400: 'val_CER_card' was not in top 1\n",
      "`Trainer.fit` stopped: `max_epochs=100` reached.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
      "\n",
      "  | Name     | Type                         | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | model    | CrossGRUDecoderMLPProjection | 51.0 M | train\n",
      "1 | ctc_loss | CTCLoss                      | 0      | train\n",
      "------------------------------------------------------------------\n",
      "51.0 M    Trainable params\n",
      "0         Non-trainable params\n",
      "51.0 M    Total params\n",
      "204.057   Total estimated model params size (MB)\n",
      "15        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warmup training complete.\n",
      "Unfroze all parameters.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d60b6d04ecce4c8695a3549d8ce5b7ad",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=100` reached.\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▁▁▁▂▂▂▂▂▃▄▄▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇████</td></tr><tr><td>train_grad_norm</td><td>█▇▄▅▂▁▁▂▁▆▃▅▂▃▃▅▄▂▃▄▂▄▄▅▄▄▇▄▃▆▃▂▅▅▃▄▃▄▃▆</td></tr><tr><td>train_loss_epoch</td><td>█▇▇▇▆▅▅▄▄▄▄▃▃▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_step</td><td>▇▇█▅▅▆▇▅▅▂▆▃▄▄▃▃▂▂▄▃▃▃▂▁▂▄▃▃▃▅▃▃▄▂▂▃▂▄▂▂</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████</td></tr><tr><td>val_CER_card</td><td>▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_CER_willet</td><td>████▇▇▆▆▅▄▄▃▃▃▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss</td><td>█▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>99</td></tr><tr><td>train_grad_norm</td><td>14.94493</td></tr><tr><td>train_loss_epoch</td><td>1.31673</td></tr><tr><td>train_loss_step</td><td>1.08189</td></tr><tr><td>trainer/global_step</td><td>26399</td></tr><tr><td>val_CER_card</td><td>0.11933</td></tr><tr><td>val_CER_willet</td><td>0.57377</td></tr><tr><td>val_loss</td><td>1.68425</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">CrossDatasetGRU_warmup</strong> at: <a href='https://wandb.ai/matteoferrante/B2TXT25/runs/pnl3af5m' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25/runs/pnl3af5m</a><br> View project at: <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25</a><br>Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250902_172407-pnl3af5m/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb_logger = WandbLogger(project=\"B2TXT25\", name=f\"{output_name}\",\n",
    "                            reinit=True)\n",
    "\n",
    "# Keep only the checkpoint with the lowest validation CER on the card dataset.\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_CER_card\",  # the validation step must log a metric with this exact name\n",
    "    mode=\"min\",          # lower CER is better\n",
    "    save_top_k=1,        # keep only the best model\n",
    "    dirpath=f\".checkpoints/{output_name}/\",  # directory to save checkpoints\n",
    "    filename=\"best_model\",  # checkpoint filename (no placeholders, so a plain string)\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "# Early stopping on the same metric, with a patience of 5 validation checks without improvement.\n",
    "# NOTE(review): this callback is NOT passed to the Trainer below, so it currently has no\n",
    "# effect; add it to `callbacks` if early stopping is actually wanted for the warmup run.\n",
    "early_stopping_callback = EarlyStopping(\n",
    "    monitor=\"val_CER_card\",\n",
    "    patience=5,\n",
    "    mode=\"min\",\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "\n",
    "# Train model\n",
    "trainer = pl.Trainer(max_epochs=100,devices =[0], callbacks=[checkpoint_callback], logger=wandb_logger)\n",
    "\n",
    "try:\n",
    "    trainer.fit(model, train_loader, val_loader)\n",
    "\n",
    "    print(\"Warmup training complete.\")\n",
    "\n",
    "    #reload state_dict of best model\n",
    "    # model.load_state_dict(torch.load(f\".checkpoints/mfcc_sm_gru_ctc/best_model.ckpt\")[\"state_dict\"])\n",
    "finally:\n",
    "    # Always close the wandb run, even if training fails or is interrupted,\n",
    "    # so the run is not left open when the next cell starts a new one.\n",
    "    wandb.finish()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "bacc0ae3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Unfroze all parameters.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.21.1"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>./wandb/run-20250902_184958-rw8uy7vc</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href='https://wandb.ai/matteoferrante/B2TXT25/runs/rw8uy7vc' target=\"_blank\">CrossDatasetGRU_warmup-finetune</a></strong> to <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/developer-guide' target=\"_blank\">docs</a>)<br>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View project at <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run at <a href='https://wandb.ai/matteoferrante/B2TXT25/runs/rw8uy7vc' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25/runs/rw8uy7vc</a>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:701: Checkpoint directory /data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup exists and is not empty.\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n",
      "\n",
      "  | Name     | Type                         | Params | Mode \n",
      "------------------------------------------------------------------\n",
      "0 | model    | CrossGRUDecoderMLPProjection | 51.0 M | train\n",
      "1 | ctc_loss | CTCLoss                      | 0      | train\n",
      "------------------------------------------------------------------\n",
      "51.0 M    Trainable params\n",
      "0         Non-trainable params\n",
      "51.0 M    Total params\n",
      "204.057   Total estimated model params size (MB)\n",
      "15        Modules in train mode\n",
      "0         Modules in eval mode\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16af7c9676c74e50be9c62462d54a077",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Sanity Checking: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=223` in the `DataLoader` to improve performance.\n",
      "/home/matteo/anaconda3/envs/b2txt25/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=223` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a51bf2803503483c89fc7477190f8eed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Training: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "32cd9ab49a8f40018a35bc069aba2efd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 0, global step 264: 'val_CER_card' reached 0.16727 (best 0.16727), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "960dbff41bc4489d93e8bf9ad1a7936a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 1, global step 528: 'val_CER_card' reached 0.16217 (best 0.16217), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "986d9890f4c24653bef975e50d78c9a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 2, global step 792: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "16c4da92bba640ba9ec4f9e686254c64",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 3, global step 1056: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8c0403277b774db09c1d7b974ec55a93",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 4, global step 1320: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "559e3417fdc844a8b810b1ceab610caa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 5, global step 1584: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d12b7a94f99d40ec9198a9d791075174",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 6, global step 1848: 'val_CER_card' reached 0.15642 (best 0.15642), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ea4707b9090647f48df0f5ce7f1e4c63",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 7, global step 2112: 'val_CER_card' reached 0.15231 (best 0.15231), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1fb2e4021f094abaaff061c9a74a3b3b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 8, global step 2376: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1695c48601ce4e56a0574d6355179c8a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 9, global step 2640: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8277c227be79485cb22b7642295eaae6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 10, global step 2904: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "963c296895eb4bd9ba043e3b27d949a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 11, global step 3168: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a00fc6e7792844238cbe070ac9f44a52",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 12, global step 3432: 'val_CER_card' reached 0.15201 (best 0.15201), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4764bd9a2c7c4e818a0159d07f4b9aa2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 13, global step 3696: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87beb39ece624799acad5b81e65231d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 14, global step 3960: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59cfdb7c3a3247b9b9053c1747d36cab",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 15, global step 4224: 'val_CER_card' reached 0.15095 (best 0.15095), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1a824c01f3af48279361690b564a3338",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 16, global step 4488: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33840f7e2d054d08bf92e90346308955",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 17, global step 4752: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c871b9093ae4f3fa98ed30bd9d24d33",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 18, global step 5016: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5c94f6f463474410b90ed44606542b54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 19, global step 5280: 'val_CER_card' reached 0.14955 (best 0.14955), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c82f079bf7784a1a89eef50ab3a22e6a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 20, global step 5544: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62aed6bcb57f4bc0a377e99940bcbe24",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 21, global step 5808: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90ae49de28714f4cbe0aa5cf047870cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 22, global step 6072: 'val_CER_card' reached 0.14133 (best 0.14133), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cd51894dd158431196b0c11bc181820e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 23, global step 6336: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "63f26511bb594814a2a1328bcc6f0c96",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 24, global step 6600: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a368b402e71c4fd584c663d489e026fb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 25, global step 6864: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb8a4a669da547d085213537ad032990",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 26, global step 7128: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c1fb4f1cd4ce4169a8c7867055457f16",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 27, global step 7392: 'val_CER_card' reached 0.14017 (best 0.14017), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1270e8918bd94390b93f4524a96be3e2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 28, global step 7656: 'val_CER_card' reached 0.13606 (best 0.13606), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6df15cdde27746cba1d4c85090cdacce",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 29, global step 7920: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9344d42ea74b4ee69c40d10a35fdb63a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 30, global step 8184: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0f8c35a9d224e809862e744a2d1c022",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 31, global step 8448: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "88b8e2ecfe5448a3ba08b9f7fb3df74a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 32, global step 8712: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b8e203992a2c4a8580a4781c12a305b5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 33, global step 8976: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc00b8e5eb3a43a8b8da967815ab2235",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 34, global step 9240: 'val_CER_card' reached 0.13558 (best 0.13558), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d225325bd0a44c6d843706af7908525b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 35, global step 9504: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "59b4677e07cd4554ab887864d2e90ca4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 36, global step 9768: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "954a84d0d4464a8da1edb574f52ef796",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 37, global step 10032: 'val_CER_card' reached 0.13512 (best 0.13512), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "68539b4f46f04a5ba4a06c7553035c11",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 38, global step 10296: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c38f7c228bf2490187d555d2e15cf2e6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 39, global step 10560: 'val_CER_card' reached 0.13412 (best 0.13412), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0dd3ee7ecebe49bab7d752b3a3430dfb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 40, global step 10824: 'val_CER_card' reached 0.13215 (best 0.13215), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "99fe20db7d414604a02538bb87468a93",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 41, global step 11088: 'val_CER_card' reached 0.13160 (best 0.13160), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "17aecfa62f134d759daf98f3cd0602a6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 42, global step 11352: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6747cfc41c2848fbb6a7bd2ac54c7c58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 43, global step 11616: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "92f4069c55cf4b17a63917991419e155",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 44, global step 11880: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36064f3642224b58b6621406bfe9fa67",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 45, global step 12144: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a31e5b4c4be6411e9a1fb27f423504b8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 46, global step 12408: 'val_CER_card' reached 0.13153 (best 0.13153), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2a915372a9f547238e8b26ea0ce380ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 47, global step 12672: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "70fbf8921bc5482ab0e1115caa19f449",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 48, global step 12936: 'val_CER_card' reached 0.13035 (best 0.13035), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ca59e1158c0a483c936a7e9adb756501",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 49, global step 13200: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "02e0e8df78914ae49d82744d9e983202",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 50, global step 13464: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aab30e6753cc43da977fe723b57dfbb0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 51, global step 13728: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4ae18c09cdbb436eb949ffa0d4228011",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 52, global step 13992: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5811d93a39e94d07b49792538f259079",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 53, global step 14256: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "223889eb49a44efb9016cb2b6a7e3094",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 54, global step 14520: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "118af7adf00b4661a05c892cbaa1c542",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 55, global step 14784: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd7349ed45e64c22a21ec3bbb6b6819c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 56, global step 15048: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2fd4dff4bed848b085c11d1ac4a234c3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 57, global step 15312: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c9a8db96a6414c4eb7b36949f0503116",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 58, global step 15576: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7f71a83232f949d8a988db6ee440aada",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 59, global step 15840: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "60a566d48b2c4d30af4caa377294a265",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 60, global step 16104: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b8e17971cd04d4aaf94b76c4ff6cd37",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 61, global step 16368: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "35ab61bff6f841baade9dee796b1cc35",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 62, global step 16632: 'val_CER_card' reached 0.13033 (best 0.13033), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "499407b13e5a4e4883d21f4b000711bc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 63, global step 16896: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9805f11ca346451eb77d435bee7158cf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 64, global step 17160: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c498b653ad21497faa6332872ffee7d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 65, global step 17424: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4c730d86832746eca8b3c366f972dfdc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 66, global step 17688: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a8f9ddc2808454a97c91a57dcafd7ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 67, global step 17952: 'val_CER_card' reached 0.13025 (best 0.13025), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5d3a141de6874c03a4514c8157d5b433",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 68, global step 18216: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6121b47c9e6044858cf87540d919afed",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 69, global step 18480: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "702b64eca8764c03a0c262af716577ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 70, global step 18744: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e4efed2f2e874e20abeb2f2b4e7920d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 71, global step 19008: 'val_CER_card' reached 0.13020 (best 0.13020), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3e8efa6428124cae9982fb0a0f1df6fe",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 72, global step 19272: 'val_CER_card' reached 0.13018 (best 0.13018), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e1d09e8452c549bd9d4a6484ff06bdee",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 73, global step 19536: 'val_CER_card' reached 0.13015 (best 0.13015), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d15c0684fda43dda5cc20cda87a4fd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 74, global step 19800: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6bdca13d9d304c308d3ca4201c430228",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 75, global step 20064: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f572f29ef966436d9a94471b38398004",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 76, global step 20328: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2b8c90a1c48e46d492a912361b71f59d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 77, global step 20592: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a88af98b06334060ac72f8822d028705",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 78, global step 20856: 'val_CER_card' reached 0.13015 (best 0.13015), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9a4d94507b2547cdb1e1810a1d956658",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 79, global step 21120: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c7ac3854c82d40209b2bbd9ce79b5a61",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 80, global step 21384: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5971964a5a1e49528a3ab8501252017b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 81, global step 21648: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "76cb86ac7fc949d88051a3d90c780995",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 82, global step 21912: 'val_CER_card' reached 0.13010 (best 0.13010), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a0fd8ddf6ef54a98861ee8dfcc880ad6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 83, global step 22176: 'val_CER_card' reached 0.12986 (best 0.12986), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c969e53142d542bdb010ffd7c065b730",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 84, global step 22440: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2d292fa4fdb24d89b3b0f0b517694055",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 85, global step 22704: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0bc4b7112bba4b2d9a0ea7bdd8da019a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 86, global step 22968: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2223664dbd8f40b18a32c7f4f80e2f91",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 87, global step 23232: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d388cb95dc44c8c9323b0f8db8ad13e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 88, global step 23496: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6cca05659b948dfbe6c451a1d3e108c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 89, global step 23760: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6edaa225e6e04217adf49846f05d754f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 90, global step 24024: 'val_CER_card' reached 0.12985 (best 0.12985), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a19fac3af1ca4b57af370cfc4249d6d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 91, global step 24288: 'val_CER_card' reached 0.12984 (best 0.12984), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d5d9c576c9db47eb97697c73e3c3a0a2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 92, global step 24552: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "496dcf4e2dc74c3eb12869d5eeb2d326",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 93, global step 24816: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0fbfed10f4bf467f8c0271f62627ece9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 94, global step 25080: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "96c71cac54ab4fa7a950aa8ef5f4938d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 95, global step 25344: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ba913cbd1aa450897095797dbe3c895",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 96, global step 25608: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7e2808b1f5844404a968eed5d63a8a6d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 97, global step 25872: 'val_CER_card' reached 0.12980 (best 0.12980), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33d0412830334e28a3d99aad6e372fe2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 98, global step 26136: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "484257deecc94e46b51b04597a2df2ff",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 99, global step 26400: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b2655949fb26467b92074c0d3699a4d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 100, global step 26664: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "182b2a7ea3374dc38bed26c44459ffec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 101, global step 26928: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "73de428b4c264d309087e1c90d099a17",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 102, global step 27192: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47f9e57eab7c446fa145a7d43830a4eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 103, global step 27456: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e967d67e7314a17847802446a9faf68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 104, global step 27720: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83a5d8ec63f34073b60e3fb384e13f9d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 105, global step 27984: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "686b300165094445aa1a0a6b33316da6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 106, global step 28248: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7192026ef55641c2836f87de54938292",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 107, global step 28512: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "36c333b346bd4476a0b6b7bc220e29a4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 108, global step 28776: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "904c9ae645864736ba9fdb2d5e18cf67",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 109, global step 29040: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d382420f67e74aa38b32c52649213af2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 110, global step 29304: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9d6898ce5f68483f990f2da5c7bd71a8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 111, global step 29568: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04dd9fc740744952a11260aac1201ac7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 112, global step 29832: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1e83816289614a76b67628a7bf4ebe76",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 113, global step 30096: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95c82b1adb7b43c6a33c9d2bbfdf884a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 114, global step 30360: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1c2581966d1d46b0a57ac910a928864a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 115, global step 30624: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3391721ca184f2c82e35fd84cf098db",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 116, global step 30888: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "471239242787442db2cb28f1c3963c90",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 117, global step 31152: 'val_CER_card' reached 0.12974 (best 0.12974), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "4e0733e4af1441c2a17d5c06cdd6c4b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 118, global step 31416: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3b498408cb8749b79c4d161f8deae80c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 119, global step 31680: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c32cb6e7f02f4afa8fe7439df800ef10",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 120, global step 31944: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9aaa1d9f5b1843e2a82f57a522105b15",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 121, global step 32208: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7212717942f84935826f0ba6d8810e5e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 122, global step 32472: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "745c920729f54671bb8a1d8f3ad997ac",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 123, global step 32736: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "2e1dba12c6314717aef46d8a1105a3a5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 124, global step 33000: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "87c7429d79404f5e8282450d2f152dbc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 125, global step 33264: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26504b59bb09489fb3f1b20989626253",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 126, global step 33528: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2910461e0b249bb9bebbe407ed7624f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 127, global step 33792: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "09943aaabb964ef7b8b37dfc17416892",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 128, global step 34056: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00ed3cfc387a45ad96e8e08fd35a8d54",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 129, global step 34320: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8183c904b4c94b04b45762dd25ece762",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 130, global step 34584: 'val_CER_card' reached 0.12969 (best 0.12969), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "95d227be78c149bd9c56f1833cede2c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 131, global step 34848: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cb35223cfac645cfb1f57611da036cda",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 132, global step 35112: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b42980db5c4349de94270ca3784b9d83",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 133, global step 35376: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b7f99f5803e14b87bf6b98ea4759f2c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 134, global step 35640: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a330be18db254af6b8270daf267a4840",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 135, global step 35904: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b0d3869508ae4f87959490feb2c0ea2a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 136, global step 36168: 'val_CER_card' reached 0.12957 (best 0.12957), saving model to '/data/data/matteo/nejm-brain-to-text/model_training/cross_dataset_gru/.checkpoints/CrossDatasetGRU_warmup/best_model-finetuned.ckpt' as top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "9ae1f8e7b63b4187be763db38d2e65ba",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 137, global step 36432: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "82f5cccbe80d4655880b457e2de70084",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 138, global step 36696: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df38ee69d99549f8a25cd8f6be18effd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 139, global step 36960: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c11d4b58af9d4bfbbb083f32b6b52741",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 140, global step 37224: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fc6ae82d3d554963b1a2302e1d38a039",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 141, global step 37488: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62bf05963bc34be2969289f85d1fe9eb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 142, global step 37752: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6a96d285a02f40e1b701cb6c537fcd3f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 143, global step 38016: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "50e718e84e31451d903981f73bd47423",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 144, global step 38280: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aaeb3e05b1174a0f866f0b7d302b6afd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 145, global step 38544: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ffea0938c90c4b87a9517423058707e8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 146, global step 38808: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "429a9d0e174d4affb49672576f794501",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 147, global step 39072: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "24b79da73b2b4135817b39d0dc2ecf3a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 148, global step 39336: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c17f674535e14f7eb84506ff86e3041a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 149, global step 39600: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e44334ed359b4678a31304e46464bc02",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 150, global step 39864: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ec0e3e6566e4999b5dc98086ac28556",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 151, global step 40128: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "da8ed2a127224f0ab13e9ba87a41728b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 152, global step 40392: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e8dde9eca5764d2a982750eb90aba0c9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 153, global step 40656: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "13b44a10060b482fa7922f273ee797e0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 154, global step 40920: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ebe394c6dcfc4fae868319f4f710da68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 155, global step 41184: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "931647063ba441fabdd75c2939e3e7fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 156, global step 41448: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "86cb1e2453314d7fbcd3b8ff6ccb515a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 157, global step 41712: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b206243005404edbb9ef10d06d2de79c",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 158, global step 41976: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "43021c227b114f7bbbec784fa60b5def",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 159, global step 42240: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "df664c1ed18649e28d1e915d5ae75909",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 160, global step 42504: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e593f726ef5440a595d6eb4e9abf7869",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 161, global step 42768: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ccf12d5762dd409f8c4e84aea050266b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 162, global step 43032: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bb51872860504d2681a012cb403ce7c1",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 163, global step 43296: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "849d3b5eab924cb2841f916396ff05a9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 164, global step 43560: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37b520b49e7748fbaa6677b9b737f9b2",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 165, global step 43824: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a05324dab1b34eb8a6ac90a9f72e0d29",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 166, global step 44088: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7e2d54da9b1448b89eb0e93ecbff6b50",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 167, global step 44352: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f2da53e2384d466c924a09578f432d5b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 168, global step 44616: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbbc4c1547be4d7c9d3cf20f1918d642",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 169, global step 44880: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "47715e178d634190bee037c934cee648",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 170, global step 45144: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8d3fc5f5dc3144a2b60e1f7a0e1eb58b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 171, global step 45408: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5108b3be1e7141e39e0328faea81fcb5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 172, global step 45672: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "04f5194a003c45dba33a9fa396920e34",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 173, global step 45936: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d82828837f9140d89156b19625d14fdd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 174, global step 46200: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a39a49a5d1fc41aba4ece7f1da80103e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 175, global step 46464: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "62b4de753647481e8549466bec36c118",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 176, global step 46728: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f6939eae058346c08bd1e230003be7f8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 177, global step 46992: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49379ef477c84b439366bfea458a938b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 178, global step 47256: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "964a68556b6c44138ab0e835e53c4874",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 179, global step 47520: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "caf728dcc2344e03855e6e8307f0e2a0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 180, global step 47784: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f899440b39c24b36922ba80f9d01af2e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 181, global step 48048: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dd187e6378dd4b19aa6673dd4336e47a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 182, global step 48312: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f4b5e688fb3d47ed8e306ff0ea742890",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 183, global step 48576: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "fd6b26b4d7344d93b5685a3652c06129",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 184, global step 48840: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a21c4796e2c84e869b830efa22469d3d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 185, global step 49104: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26198e349def40d59569d54a4148046b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 186, global step 49368: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d66c9840e2774a96bec817c5cb0f632a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 187, global step 49632: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e297de561e614bba83a2acdb56562723",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 188, global step 49896: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "00736c5a72354fe4a598e28d283722a3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 189, global step 50160: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "222b7b2406ca41e5a1ac267031532186",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 190, global step 50424: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "33916a0a521644349ce47c06f7e20b92",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 191, global step 50688: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bdff755dd41e46a483d9a806eced5017",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 192, global step 50952: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "de5dc27d0d544e5ba08c3cda3753f422",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 193, global step 51216: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a4946e45b9bf4409bc772c4c4a13c88e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 194, global step 51480: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1d2692148cdc4664a98ff404b6a57567",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 195, global step 51744: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c84170c14f43430398fa24a5437f4295",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 196, global step 52008: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e2e86ea1986049b791d3730868e46bc3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 197, global step 52272: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d1675a7c8cbe4906afe7d3f96b755883",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 198, global step 52536: 'val_CER_card' was not in top 1\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bfc52a83b428499e97ba9096c2ab7e1a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Validation: |          | 0/? [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch 199, global step 52800: 'val_CER_card' was not in top 1\n",
      "`Trainer.fit` stopped: `max_epochs=200` reached.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Actual training complete.\n"
     ]
    },
    {
     "data": {
      "text/html": [],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<br>    <style><br>        .wandb-row {<br>            display: flex;<br>            flex-direction: row;<br>            flex-wrap: wrap;<br>            justify-content: flex-start;<br>            width: 100%;<br>        }<br>        .wandb-col {<br>            display: flex;<br>            flex-direction: column;<br>            flex-basis: 100%;<br>            flex: 1;<br>            padding: 10px;<br>        }<br>    </style><br><div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>▁▂▂▂▂▂▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇███</td></tr><tr><td>train_grad_norm</td><td>▂▂▂▂▂▄▂▁▁▂▃▂▂▂▁▁▂▁▂▂▁▁▁▂▁▁▂▁▂█▇▂▃▆▁▂▁▁▁▂</td></tr><tr><td>train_loss_epoch</td><td>█▇▆▅▅▅▅▄▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>train_loss_step</td><td>█▃▃▂▃▂▂▂▂▂▂▁▁▂▁▁▂▂▂▂▁▂▂▂▁▂▂▂▂▂▂▂▁▂▁▂▁▁▂▂</td></tr><tr><td>trainer/global_step</td><td>▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▅▅▅▆▆▆▆▆▆▆▆▆▆▇▇▇▇▇█████</td></tr><tr><td>val_CER_card</td><td>██▅▇▇▅▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_CER_willet</td><td>█▆▅▅▅▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr><tr><td>val_loss</td><td>█▆▅▄▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>epoch</td><td>199</td></tr><tr><td>train_grad_norm</td><td>1.35964</td></tr><tr><td>train_loss_epoch</td><td>0.2098</td></tr><tr><td>train_loss_step</td><td>0.22261</td></tr><tr><td>trainer/global_step</td><td>52799</td></tr><tr><td>val_CER_card</td><td>0.13016</td></tr><tr><td>val_CER_willet</td><td>0.23028</td></tr><tr><td>val_loss</td><td>0.76887</td></tr></table><br/></div></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       " View run <strong style=\"color:#cdcd00\">CrossDatasetGRU_warmup-finetune</strong> at: <a href='https://wandb.ai/matteoferrante/B2TXT25/runs/rw8uy7vc' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25/runs/rw8uy7vc</a><br> View project at: <a href='https://wandb.ai/matteoferrante/B2TXT25' target=\"_blank\">https://wandb.ai/matteoferrante/B2TXT25</a><br>Synced 6 W&B file(s), 0 media file(s), 3 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20250902_184958-rw8uy7vc/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Fine-tuning stage: log to W&B under the run name \"<output_name>-finetune\".\n",
    "wandb_logger = WandbLogger(project=\"B2TXT25\", name=f\"{output_name}-finetune\",\n",
    "                            reinit=True)\n",
    "\n",
    "# Unfreeze the model before the fine-tuning run\n",
    "model.unfreeze_all_parameters()\n",
    "\n",
    "# Keep only the single best checkpoint, ranked by the lowest \"val_CER_card\".\n",
    "checkpoint_callback = ModelCheckpoint(\n",
    "    monitor=\"val_CER_card\",  # the validation step must log \"val_CER_card\"\n",
    "    mode=\"min\",          # lower val_CER_card is better\n",
    "    save_top_k=1,        # Keep only the best model\n",
    "    dirpath=f\".checkpoints/{output_name}/\",  # Directory to save checkpoints (relative, dot-prefixed)\n",
    "    filename=\"best_model-finetuned\",  # Model filename\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "# EarlyStopping on \"val_CER_card\" with a patience of 5 validation checks.\n",
    "# NOTE(review): this callback is NOT in the Trainer's `callbacks` list below, so it is\n",
    "# inactive and training always runs for the full `max_epochs`. Add it to `callbacks`\n",
    "# to actually enable early stopping.\n",
    "early_stopping_callback = EarlyStopping(\n",
    "    monitor=\"val_CER_card\",\n",
    "    patience=5,   # stop after 5 validation checks without improvement\n",
    "    mode=\"min\",\n",
    "    verbose=True\n",
    ")\n",
    "\n",
    "trainer = pl.Trainer(\n",
    "    max_epochs=200,\n",
    "    devices=[0],\n",
    "    callbacks=[checkpoint_callback],\n",
    "    logger=wandb_logger,\n",
    "    gradient_clip_val=10.0,   # clip global grad norm\n",
    ")\n",
    "\n",
    "trainer.fit(model, train_loader, val_loader)\n",
    "\n",
    "print(\"Actual training complete.\")\n",
    "# close wandb logger\n",
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4f97fb34",
   "metadata": {},
   "source": [
    "## TODO: Fix warmup + finetuning, change LR and possibly use scheduler"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f5a52ef5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningCrossGRUDecoder_V2(\n",
       "  (model): CrossGRUDecoderMLPProjection(\n",
       "    (day_layer_activation): Softsign()\n",
       "    (day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (day_layer_dropout): Dropout(p=0.2, inplace=False)\n",
       "    (willet_projection): Sequential(\n",
       "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
       "      (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (gru): GRU(7168, 768, num_layers=5, batch_first=True, dropout=0.4)\n",
       "    (out): Linear(in_features=768, out_features=41, bias=True)\n",
       "  )\n",
       "  (ctc_loss): CTCLoss()\n",
       ")"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "6b7ec5e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warmup trained\n"
     ]
    },
    {
     "ename": "SyntaxError",
     "evalue": "'break' outside loop (125874167.py, line 2)",
     "output_type": "error",
     "traceback": [
      "\u001b[0;36m  Cell \u001b[0;32mIn[26], line 2\u001b[0;36m\u001b[0m\n\u001b[0;31m    break\u001b[0m\n\u001b[0m    ^\u001b[0m\n\u001b[0;31mSyntaxError\u001b[0m\u001b[0;31m:\u001b[0m 'break' outside loop\n"
     ]
    }
   ],
   "source": [
    "# Intentional barrier: stops a \"Run All\" at this point so that the cells below\n",
    "# are only executed deliberately. An explicit exception is used instead of a bare\n",
    "# top-level `break`, which is a SyntaxError and makes the cell impossible to\n",
    "# compile when the notebook is exported to a script or linted.\n",
    "print(\"Warmup trained\")\n",
    "raise RuntimeError(\"Intentional stop after warmup training; run the following cells manually.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "3a67240c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningCrossGRUDecoder_V2(\n",
       "  (model): CrossGRUDecoderMLPProjection(\n",
       "    (day_layer_activation): Softsign()\n",
       "    (day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (day_layer_dropout): Dropout(p=0.2, inplace=False)\n",
       "    (willet_projection): Sequential(\n",
       "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
       "      (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (gru): GRU(7168, 768, num_layers=5, batch_first=True, dropout=0.4)\n",
       "    (out): Linear(in_features=768, out_features=41, bias=True)\n",
       "  )\n",
       "  (ctc_loss): CTCLoss()\n",
       ")"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load the fine-tuned checkpoint on the CPU first; the model is moved to the GPU below.\n",
    "# NOTE: weights_only=False unpickles arbitrary objects, so only load checkpoints you trust.\n",
    "checkpoint = torch.load(os.path.join(f\".checkpoints/{output_name}/best_model-finetuned.ckpt\"), weights_only=False, map_location='cpu')\n",
    "# Strip wrapper prefixes from the parameter names in a single pass:\n",
    "#   \"module.\"    -> added when the model was saved wrapped in DataParallel\n",
    "#   \"_orig_mod.\" -> added when the model was saved wrapped by torch.compile\n",
    "# The earlier loop popped every key twice, which raises KeyError for any key\n",
    "# that really contains \"module.\" (the first pop already moved it).\n",
    "checkpoint['state_dict'] = {\n",
    "    key.replace(\"module.\", \"\").replace(\"_orig_mod.\", \"\"): value\n",
    "    for key, value in checkpoint['state_dict'].items()\n",
    "}\n",
    "# strict=False tolerates missing/unexpected keys, so a failed rename would go unnoticed here.\n",
    "model.load_state_dict(checkpoint['state_dict'], strict=False)\n",
    "model.to(\"cuda:0\")\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "6335163e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_ctc_output(logits):\n",
    "    \"\"\"\n",
    "    Greedy CTC decoding of model logits into predicted phoneme id sequences.\n",
    "\n",
    "    For each sequence the most probable class per frame is taken, consecutive\n",
    "    repeats are collapsed, and only afterwards blank tokens (0) are dropped.\n",
    "    The order matters: a blank between two identical phonemes marks a genuine\n",
    "    repetition, so removing blanks first would wrongly merge them into one.\n",
    "\n",
    "    Args:\n",
    "        logits: tensor whose last dimension indexes the classes; it is iterated\n",
    "            along its first dimension (presumably (batch, time, classes) - confirm\n",
    "            against the caller).\n",
    "\n",
    "    Returns:\n",
    "        A list with one 1-D numpy array of phoneme ids per sequence.\n",
    "    \"\"\"\n",
    "\n",
    "    predictions = torch.argmax(logits, dim=-1)  # Get most probable phoneme indices\n",
    "    decoded = []\n",
    "    for seq in predictions:\n",
    "        collapsed = torch.unique_consecutive(seq)  # Merge repeated frames first\n",
    "        decoded.append(collapsed[collapsed != 0].cpu().numpy())  # Then remove blanks\n",
    "    return decoded"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "51eb3db2",
   "metadata": {},
   "source": [
    "## Card"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fa0481b5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "LightningCrossGRUDecoder_V2(\n",
       "  (model): CrossGRUDecoderMLPProjection(\n",
       "    (day_layer_activation): Softsign()\n",
       "    (day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (24): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (25): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (26): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (27): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (28): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (29): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (30): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (31): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (32): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (33): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (34): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (35): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (36): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (37): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (38): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (39): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (40): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (41): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (42): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (43): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (44): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_weights): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 512x512 (cuda:0)]\n",
       "    )\n",
       "    (willet_day_biases): ParameterList(\n",
       "        (0): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (1): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (2): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (3): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (4): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (5): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (6): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (7): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (8): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (9): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (10): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (11): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (12): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (13): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (14): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (15): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (16): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (17): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (18): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (19): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (20): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (21): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (22): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "        (23): Parameter containing: [torch.float32 of size 1x512 (cuda:0)]\n",
       "    )\n",
       "    (day_layer_dropout): Dropout(p=0.2, inplace=False)\n",
       "    (willet_projection): Sequential(\n",
       "      (0): Linear(in_features=256, out_features=512, bias=True)\n",
       "      (1): ReLU()\n",
       "      (2): Linear(in_features=512, out_features=512, bias=True)\n",
       "      (3): LayerNorm((512,), eps=1e-05, elementwise_affine=True)\n",
       "    )\n",
       "    (gru): GRU(7168, 768, num_layers=5, batch_first=True, dropout=0.4)\n",
       "    (out): Linear(in_features=768, out_features=41, bias=True)\n",
       "  )\n",
       "  (ctc_loss): CTCLoss()\n",
       ")"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# model.to(\"cuda:0\")\n",
    "# model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "5867a473",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 23/23 [00:03<00:00,  6.87it/s]\n"
     ]
    }
   ],
   "source": [
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(val_card_loader):\n",
    "        X = batch[\"neural_features\"]\n",
    "        y = batch[\"seq_class_ids\"]\n",
    "        X_len = batch[\"n_time_steps\"]\n",
    "        y_len = batch[\"seq_lengths\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence_label\"]\n",
    "\n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        logits = model(X,days, card_indices=np.arange(X.shape[0]))\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(X_len[i] / model.patch_stride), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "\n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        pred_phonemes.extend(pp)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "f28d9ea0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "np.float64(0.11543502005955386)"
      ]
     },
     "execution_count": 51,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.mean(cer_list)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8b6ce329",
   "metadata": {},
   "source": [
    "## Willet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "cd2e446d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def ctc_greedy_with_tail_cut(logits, input_lengths, blank_id=0, blank_thresh=0.98, patience=10):\n",
    "    \"\"\"\n",
    "    logits: (B, T', C) BEFORE softmax/log_softmax\n",
    "    input_lengths: (B,) int T' per sample\n",
    "    Returns list[list[int]] predictions.\n",
    "    \"\"\"\n",
    "    import torch\n",
    "    B, T, C = logits.shape\n",
    "    logp = logits.log_softmax(dim=-1)\n",
    "    p_blank = logp[..., blank_id].exp()  # (B, T')\n",
    "    ids = logp.argmax(dim=-1)           # (B, T')\n",
    "\n",
    "    outs = []\n",
    "    for b in range(B):\n",
    "        T_b = int(input_lengths[b])\n",
    "        run_blanks = 0\n",
    "        hyp = []\n",
    "        prev = blank_id\n",
    "        for t in range(T_b):\n",
    "            # early stop if we’ve been confidently blank for a while\n",
    "            if p_blank[b, t].item() >= blank_thresh:\n",
    "                run_blanks += 1\n",
    "                if run_blanks >= patience:\n",
    "                    break\n",
    "            else:\n",
    "                run_blanks = 0\n",
    "\n",
    "            p = int(ids[b, t])\n",
    "            if p != blank_id and p != prev:\n",
    "                hyp.append(p)\n",
    "            prev = p\n",
    "        outs.append(hyp)\n",
    "    return outs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "id": "9da58709",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ctc_greedy_with_tail_cut(pred, X_len)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "id": "7de6586b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 14/14 [00:01<00:00, 11.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Willet test CER: 0.24081580347830553\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "pred_phonemes = []\n",
    "pred_logits = []\n",
    "true_phonemes = []\n",
    "true_sentences = []\n",
    "day_indices = []\n",
    "cer_list = []\n",
    "seq_lens = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm.tqdm(test_willet_loader):\n",
    "        X = batch[\"neural_features\"]\n",
    "        y = batch[\"seq_class_ids\"]\n",
    "        X_len = batch[\"n_time_steps\"]\n",
    "        y_len = batch[\"seq_lengths\"]\n",
    "        days = batch[\"day\"]\n",
    "        transcriptions = batch[\"sentence_label\"]\n",
    "\n",
    "        # Move data to device\n",
    "        X = X.to(device)\n",
    "        y = y.to(device)\n",
    "\n",
    "        days = days.to(device)\n",
    "        X_len = X_len.to(device)\n",
    "        y_len = y_len.to(device)\n",
    "\n",
    "        seq_lens.extend(y_len.cpu().numpy().tolist())\n",
    "\n",
    "        logits = model(X,days, card_indices = None, willet_indices=np.arange(X.shape[0]))\n",
    "        pred = torch.nn.functional.log_softmax(logits, dim=-1).cpu()\n",
    "        pred_logits.append(pred)\n",
    "\n",
    "        total_edit_distance, total_seq_length = 0, 0\n",
    "\n",
    "        if model.patch_size and model.patch_size > 0:\n",
    "            Ti = torch.div(X_len - model.patch_size, model.patch_stride, rounding_mode='floor').clamp_min(0) + 1\n",
    "        else:\n",
    "            Ti = X_len\n",
    "        Ti = Ti.to(torch.int64)  # (B,)\n",
    "\n",
    "        for i in range(pred.shape[0]):\n",
    "            decodedSeq = torch.argmax(pred[i, : int(Ti[i]), :], dim=-1)\n",
    "            decodedSeq = torch.unique_consecutive(decodedSeq, dim=-1)\n",
    "            decodedSeq = decodedSeq[decodedSeq != 0].cpu().numpy()\n",
    "            \n",
    "            trueSeq = y[i][:y_len[i]].cpu().numpy()\n",
    "            matcher = SequenceMatcher(a=trueSeq.tolist(), b=decodedSeq.tolist())\n",
    "            total_edit_distance += matcher.distance()\n",
    "            total_seq_length += len(trueSeq)\n",
    "\n",
    "            cer = total_edit_distance / total_seq_length if total_seq_length > 0 else 1.0\n",
    "            cer_list.append(cer)\n",
    "            \n",
    "        pp = decode_ctc_output(pred)\n",
    "\n",
    "        # pred_phonemes.extend([p[:y_len[i]] for p in pp])\n",
    "        pred_phonemes.extend(pp)\n",
    "        true_phonemes.extend([y[i][:y_len[i]].cpu().numpy() for i in range(len(y))])\n",
    "        # true_phonemes.extend(y.cpu().numpy())\n",
    "        true_sentences.extend(transcriptions)\n",
    "        day_indices.extend(days.cpu().numpy())\n",
    "\n",
    "print(\"Willet test CER:\", np.mean(cer_list))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "id": "4e899d90",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "880"
      ]
     },
     "execution_count": 86,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(seq_lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "id": "6b6c15cd",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True: ['DH', 'IH', 'S', 'SIL', 'IH', 'Z', 'SIL', 'Y', 'UW', 'N', 'AH', 'F', 'AY', 'IH', 'NG', 'SIL', 'V', 'OY', 'S', 'SIL', 'W', 'IY', 'SIL', 'N', 'IY', 'D', 'SIL']\n",
      "Pred: ['DH', 'IH', 'S', 'SIL', 'IH', 'Z', 'SIL', 'EH', 'N', 'IY', 'EY', 'AH', 'NG', 'SIL', 'W', 'AA', 'S', 'SIL', 'DH', 'IY', 'R', 'SIL', 'N', 'IY', 'D', 'SIL', 'S']\n",
      "Sentence: This is unifying voice we need.\n"
     ]
    }
   ],
   "source": [
    "idx = 120\n",
    "\n",
    "print(\"True:\", idsToPhonemes(true_phonemes[idx]))\n",
    "print(\"Pred:\", idsToPhonemes(pred_phonemes[idx][:seq_lens[idx]]))\n",
    "print(\"Sentence:\", true_sentences[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "40be90ee",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Edit distance: 14\n"
     ]
    }
   ],
   "source": [
    "truseq = idsToPhonemes(true_phonemes[idx])\n",
    "predseq = idsToPhonemes(pred_phonemes[idx])\n",
    "matcher = SequenceMatcher(a=truseq, b=predseq)\n",
    "print(\"Edit distance:\", matcher.distance())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2bcb1683",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "b2txt25",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
