{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9913d975-1792-4a06-b18c-59cc78cf4b01",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using device: cuda\n"
     ]
    }
   ],
   "source": [
    "# In[1]: imports & device setup\n",
    "import h5py\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import random\n",
    "from pathlib import Path\n",
    "from typing import List, Tuple, Dict\n",
    "from tqdm import tqdm\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9304fb9b-1254-4818-8b1d-1b90bcb8d707",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Linear(in_features=1920, out_features=5, bias=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model = nn.Linear(1920, 5).to(device)\n",
    "model.load_state_dict(torch.load('ckpt_lr0.0005_wd0.0_maxf10.0000.pt', map_location=device)['model_state_dict'])\n",
    "model.eval()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2f44522c-67fa-4f7b-b33b-8b10451902cf",
   "metadata": {},
   "outputs": [],
   "source": [
    "initial_dataset_file = h5py.File('/home/jovyan/evo2_experiments/sequences.hdf5', \"r\")\n",
    "dataset_file = h5py.File('/home/jovyan/evo2_experiments/sequences_exon_gene_level_embeddings.h5', \"r\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "6c91c42d-062b-4813-b05b-6b5ca260bfc0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 980/980 [12:01<00:00,  1.36it/s]  \n"
     ]
    }
   ],
   "source": [
    "with h5py.File('sequences_exon_gene_level_evo2_predictions.h5', 'w') as f:\n",
    "    for i in tqdm(range(len(list(dataset_file.keys())))):\n",
    "        sample_name = \"transcript_\" + str(i)\n",
    "        with torch.no_grad():\n",
    "            transcript_preds = model(torch.tensor(np.array(dataset_file[sample_name]['transcript_embeddings'])[:]).to('cuda').to(dtype=torch.float32)).detach().cpu().numpy()\n",
    "            gene_preds = model(torch.tensor(np.array(dataset_file[sample_name]['gene_embeddings'])[:]).to('cuda').to(dtype=torch.float32)).detach().cpu().numpy()\n",
    "\n",
    "        assert len(initial_dataset_file[sample_name].attrs['transcript_seq']) == transcript_preds.shape[1]\n",
    "        assert len(initial_dataset_file[sample_name].attrs['gene_seq']) == gene_preds.shape[1]\n",
    "        assert len(transcript_preds.shape) == 3\n",
    "        assert len(gene_preds.shape) == 3\n",
    "\n",
    "        group = f.create_group(sample_name)\n",
    "        group.create_dataset('transcript_preds', data=transcript_preds, compression='gzip', compression_opts=4) # 4 is default\n",
    "        group.create_dataset('gene_preds', data=gene_preds, compression='gzip', compression_opts=4) # 4 is default\n",
    "        group.attrs['ID'] = initial_dataset_file[sample_name].attrs['ID']\n",
    "        group.attrs['Parent'] = initial_dataset_file[sample_name].attrs['Parent']\n",
    "        group.attrs['chromosome'] = initial_dataset_file[sample_name].attrs['chromosome']\n",
    "        group.attrs['gene_seq'] = initial_dataset_file[sample_name].attrs['gene_seq']\n",
    "        group.attrs['strand'] = initial_dataset_file[sample_name].attrs['strand']\n",
    "        group.attrs['transcript_seq'] = initial_dataset_file[sample_name].attrs['transcript_seq']\n",
    "        group.attrs['type'] = initial_dataset_file[sample_name].attrs['type']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37c2aa8f-5537-4c00-905a-06d3bd9fb69c",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0e74c6b4-6498-40bb-b093-32c01290e61a",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "03fe72b7-e7cf-4527-92c8-cd732fbbf18b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# In[2]: metric‐computation functions (as provided, with minimal adaptation)\n",
    "def find_segments_ones(array: np.ndarray) -> List[Tuple[int,int]]:\n",
    "    ones_idx = np.where(array == 1)[0]\n",
    "    if len(ones_idx) == 0:\n",
    "        return []\n",
    "    split_idx = np.where(np.diff(ones_idx) > 1)[0] + 1\n",
    "    split_ones_idx = np.split(ones_idx, split_idx)\n",
    "    return [(seg[0], seg[-1] + 1) for seg in split_ones_idx]\n",
    "\n",
    "def exon_level(threshold: float, y_labels: np.ndarray, p_labels: np.ndarray, metrics: Dict[str,int]):\n",
    "    y_segs = find_segments_ones((y_labels >= threshold).astype(int))\n",
    "    p_segs = find_segments_ones((p_labels >= threshold).astype(int))\n",
    "    y_set = set(y_segs)\n",
    "    p_set = set(p_segs)\n",
    "    metrics[f'TP_{threshold}'] += len(y_set & p_set)\n",
    "    metrics[f'FP_{threshold}'] += len(p_set - y_set)\n",
    "    metrics[f'FN_{threshold}'] += len(y_set - p_set)\n",
    "\n",
    "def compute_exon_metrics(\n",
    "    all_y: List[np.ndarray],\n",
    "    all_p: List[np.ndarray],\n",
    ") -> Dict[str, float]:\n",
    "    # initialize counters\n",
    "    thresholds = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]\n",
    "    counts = {f'{m}_{t}': 0 for t in thresholds for m in ('TP','FP','FN')}\n",
    "    # accumulate over transcripts\n",
    "    y_all = np.concatenate(all_y, axis=0)  # shape (sum_L, 5)\n",
    "    p_all = np.concatenate(all_p, axis=0)  # shape (sum_L, 5)\n",
    "    for t in thresholds:\n",
    "        exon_level(t, y_all[:,1], p_all[:,1], counts)\n",
    "    # compute precision/recall/f1\n",
    "    metrics = {}\n",
    "    for t in thresholds:\n",
    "        tp = counts[f'TP_{t}']; fp = counts[f'FP_{t}']; fn = counts[f'FN_{t}']\n",
    "        rec = tp/(tp+fn) if (tp+fn)>0 else 0.0\n",
    "        prec = tp/(tp+fp) if (tp+fp)>0 else 0.0\n",
    "        f1 = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0\n",
    "        metrics[f'precision_exon_level_{t}'] = prec\n",
    "        metrics[f'recall_exon_level_{t}']    = rec\n",
    "        metrics[f'f1_exon_level_{t}']        = f1\n",
    "    metrics['max_f1_exon_level'] = max(metrics[f'f1_exon_level_{t}'] for t in thresholds)\n",
    "    return metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fdfc83d5-0441-4caa-8466-08457640594e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "→ number of training transcripts:   32385\n",
      "→ number of validation transcripts: 2955\n"
     ]
    }
   ],
   "source": [
    "# In[3]: open HDF5 files and collect transcript‐keys\n",
    "# adjust paths as needed\n",
    "train_labels_path = '/home/jovyan/shares/SR003.nfs2/mane_no_intergenic_combined/mane_transcript_train_dataset_max_exon_cds.hdf5'\n",
    "val_labels_path   = '/home/jovyan/shares/SR003.nfs2/mane_no_intergenic_combined/mane_transcript_val_dataset_max_exon_cds.hdf5'\n",
    "train_emb_path    = './mane_transcript_train_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5'\n",
    "val_emb_path      = './mane_transcript_val_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5'\n",
    "\n",
    "train_lbl_f = h5py.File(train_labels_path, 'r')\n",
    "val_lbl_f   = h5py.File(val_labels_path,   'r')\n",
    "train_emb_f = h5py.File(train_emb_path,    'r')\n",
    "val_emb_f   = h5py.File(val_emb_path,      'r')\n",
    "\n",
    "# count number of transcripts in each dataset\n",
    "train_keys = list(train_lbl_f.keys())\n",
    "val_keys   = list(val_lbl_f.keys())\n",
    "n_train = len(train_keys)\n",
    "n_val   = len(val_keys)\n",
    "\n",
    "# sanity‐check that embedding files have the same transcripts\n",
    "assert set(train_keys) == set(train_emb_f.keys()), \"Train embed file keys mismatch!\"\n",
    "assert set(val_keys)   == set(val_emb_f.keys()),   \"Val embed file keys mismatch!\"\n",
    "\n",
    "print(f\"→ number of training transcripts:   {n_train}\")\n",
    "print(f\"→ number of validation transcripts: {n_val}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7ed11d16-5060-423c-a3c8-780da45a24c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['transcript_0',\n",
       " 'transcript_1',\n",
       " 'transcript_10',\n",
       " 'transcript_100',\n",
       " 'transcript_1000',\n",
       " 'transcript_10000',\n",
       " 'transcript_10001',\n",
       " 'transcript_10002',\n",
       " 'transcript_10003',\n",
       " 'transcript_10004']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_keys[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "458ca2cf-62b4-4857-a161-fcd814f23423",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(5183, 5)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(train_lbl_f['transcript_1311']['labels_atcg'])[:].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae709e44-4ce4-4059-a22a-57dac18fee96",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 5183, 1920)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array(train_emb_f['transcript_1311']['embeddings'])[:].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "20cb07ad-4be2-4731-88db-4114f4a1bab3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/jovyan/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# In[4]: hyperparameters & model/optimizer setup\n",
    "# --- hyperparameters you can tweak ---\n",
    "lr               = 5e-5      # learning rate\n",
    "weight_decay     = 1e-4      # AdamW weight decay\n",
    "num_iterations   = 500000     # total gradient steps\n",
    "val_period       = 1000       # run validation every N iterations\n",
    "val_sample_size  = 1000      # number of transcripts to sample at validation\n",
    "max_seq_length   = 32000      # e.g. 20000 to restrict letters per transcript, or None\n",
    "# ------------------------------------\n",
    "\n",
    "model = nn.Linear(1920, 5).to(device)\n",
    "optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)\n",
    "criterion = nn.BCEWithLogitsLoss()\n",
    "best_val_metric = -1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e1fc9cf7-2fc8-4517-9828-2918f9e28c4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  1%|          | 2559/500000 [07:17<23:36:05,  5.85it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[9], line 5\u001b[0m\n\u001b[1;32m      2\u001b[0m key \u001b[38;5;241m=\u001b[39m random\u001b[38;5;241m.\u001b[39mchoice(train_keys)\n\u001b[1;32m      3\u001b[0m \u001b[38;5;66;03m# print(key)\u001b[39;00m\n\u001b[1;32m      4\u001b[0m \u001b[38;5;66;03m# y_np = np.array(train_lbl_f[key]['labels_atcg'])[:, :] \u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m x_np \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtrain_emb_f\u001b[49m\u001b[43m[\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m]\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43membeddings\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m[\u001b[38;5;241m0\u001b[39m, :, :]\n",
      "File \u001b[0;32mh5py/_objects.pyx:54\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32mh5py/_objects.pyx:55\u001b[0m, in \u001b[0;36mh5py._objects.with_phil.wrapper\u001b[0;34m()\u001b[0m\n",
      "File \u001b[0;32m~/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/h5py/_hl/dataset.py:1063\u001b[0m, in \u001b[0;36mDataset.__array__\u001b[0;34m(self, dtype)\u001b[0m\n\u001b[1;32m   1060\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msize \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m   1061\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m arr\n\u001b[0;32m-> 1063\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread_direct\u001b[49m\u001b[43m(\u001b[49m\u001b[43marr\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1064\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m arr\n",
      "File \u001b[0;32m~/dnalm/my_saved_conda_envs/gena/lib/python3.9/site-packages/h5py/_hl/dataset.py:1024\u001b[0m, in \u001b[0;36mDataset.read_direct\u001b[0;34m(self, dest, source_sel, dest_sel)\u001b[0m\n\u001b[1;32m   1021\u001b[0m     dest_sel \u001b[38;5;241m=\u001b[39m sel\u001b[38;5;241m.\u001b[39mselect(dest\u001b[38;5;241m.\u001b[39mshape, dest_sel)\n\u001b[1;32m   1023\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m mspace \u001b[38;5;129;01min\u001b[39;00m dest_sel\u001b[38;5;241m.\u001b[39mbroadcast(source_sel\u001b[38;5;241m.\u001b[39marray_shape):\n\u001b[0;32m-> 1024\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mid\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mread\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfspace\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdest\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdxpl\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dxpl\u001b[49m\u001b[43m)\u001b[49m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "for it in tqdm(range(1, num_iterations+1)):\n",
    "    key = random.choice(train_keys)\n",
    "    # print(key)\n",
    "    # y_np = np.array(train_lbl_f[key]['labels_atcg'])[:, :] \n",
    "    x_np = np.array(train_emb_f[key]['embeddings'])[0, :, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "46679071-0df8-4e85-bfa3-099911147814",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 1/500000 [00:00<36:20:33,  3.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1998, 1920]) torch.Size([1998, 5])\n",
      "torch.Size([1259, 1920]) torch.Size([1259, 5])\n",
      "torch.Size([3281, 1920]) torch.Size([3281, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 4/500000 [00:00<25:02:08,  5.55it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 5/500000 [00:01<33:10:31,  4.19it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([2640, 1920]) torch.Size([2640, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 9/500000 [00:01<22:31:49,  6.16it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([1711, 1920]) torch.Size([1711, 5])\n",
      "torch.Size([4638, 1920]) torch.Size([4638, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 10/500000 [00:02<30:51:32,  4.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 11/500000 [00:02<37:58:42,  3.66it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 14/500000 [00:03<28:18:18,  4.91it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([5691, 1920]) torch.Size([5691, 5])\n",
      "torch.Size([4794, 1920]) torch.Size([4794, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 17/500000 [00:03<18:57:52,  7.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3904, 1920]) torch.Size([3904, 5])\n",
      "torch.Size([942, 1920]) torch.Size([942, 5])\n",
      "torch.Size([4975, 1920]) torch.Size([4975, 5])\n",
      "torch.Size([11099, 1920]) torch.Size([11099, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 19/500000 [00:03<22:50:39,  6.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([23770, 1920]) torch.Size([23770, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 22/500000 [00:04<20:52:57,  6.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([25413, 1920]) torch.Size([25413, 5])\n",
      "torch.Size([1019, 1920]) torch.Size([1019, 5])\n",
      "torch.Size([7150, 1920]) torch.Size([7150, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 23/500000 [00:04<28:13:08,  4.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4301, 1920]) torch.Size([4301, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 27/500000 [00:05<22:58:37,  6.04it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([25810, 1920]) torch.Size([25810, 5])\n",
      "torch.Size([3913, 1920]) torch.Size([3913, 5])\n",
      "torch.Size([8049, 1920]) torch.Size([8049, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 29/500000 [00:05<26:43:13,  5.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([31554, 1920]) torch.Size([31554, 5])\n",
      "torch.Size([7123, 1920]) torch.Size([7123, 5])\n",
      "torch.Size([4558, 1920]) torch.Size([4558, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 31/500000 [00:06<24:02:40,  5.78it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([18069, 1920]) torch.Size([18069, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 32/500000 [00:06<27:35:50,  5.03it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([22819, 1920]) torch.Size([22819, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 33/500000 [00:06<34:38:54,  4.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([6037, 1920]) torch.Size([6037, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 35/500000 [00:07<34:10:53,  4.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 36/500000 [00:07<39:39:34,  3.50it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 37/500000 [00:07<42:33:58,  3.26it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([29711, 1920]) torch.Size([29711, 5])\n",
      "torch.Size([1637, 1920]) torch.Size([1637, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 39/500000 [00:08<37:45:06,  3.68it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 40/500000 [00:08<40:30:14,  3.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([26863, 1920]) torch.Size([26863, 5])\n",
      "torch.Size([776, 1920]) torch.Size([776, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 43/500000 [00:09<34:16:44,  4.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([14033, 1920]) torch.Size([14033, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 44/500000 [00:09<31:55:51,  4.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([12382, 1920]) torch.Size([12382, 5])\n",
      "torch.Size([4101, 1920]) torch.Size([4101, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 46/500000 [00:10<32:05:11,  4.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 47/500000 [00:10<37:13:58,  3.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 48/500000 [00:10<41:54:54,  3.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 50/500000 [00:11<37:37:40,  3.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([9299, 1920]) torch.Size([9299, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 51/500000 [00:11<35:24:07,  3.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16693, 1920]) torch.Size([16693, 5])\n",
      "torch.Size([1469, 1920]) torch.Size([1469, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 54/500000 [00:12<29:32:19,  4.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([10773, 1920]) torch.Size([10773, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 55/500000 [00:12<36:13:31,  3.83it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4757, 1920]) torch.Size([4757, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 57/500000 [00:13<34:40:19,  4.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4406, 1920]) torch.Size([4406, 5])\n",
      "torch.Size([412, 1920]) torch.Size([412, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 60/500000 [00:13<28:42:30,  4.84it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 62/500000 [00:14<31:32:06,  4.40it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([11519, 1920]) torch.Size([11519, 5])\n",
      "torch.Size([1907, 1920]) torch.Size([1907, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 64/500000 [00:14<25:47:19,  5.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([14494, 1920]) torch.Size([14494, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 65/500000 [00:14<32:09:53,  4.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 66/500000 [00:15<37:15:30,  3.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 67/500000 [00:15<36:39:44,  3.79it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([18149, 1920]) torch.Size([18149, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 68/500000 [00:15<41:34:05,  3.34it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 69/500000 [00:16<44:54:16,  3.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 70/500000 [00:16<47:51:34,  2.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 72/500000 [00:17<42:26:03,  3.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([12120, 1920]) torch.Size([12120, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 74/500000 [00:17<32:57:59,  4.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([11529, 1920]) torch.Size([11529, 5])\n",
      "torch.Size([12482, 1920]) torch.Size([12482, 5])\n",
      "torch.Size([5634, 1920]) torch.Size([5634, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 76/500000 [00:17<27:53:14,  4.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([17474, 1920]) torch.Size([17474, 5])\n",
      "torch.Size([2274, 1920]) torch.Size([2274, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 78/500000 [00:18<27:24:21,  5.07it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([24629, 1920]) torch.Size([24629, 5])\n",
      "torch.Size([4126, 1920]) torch.Size([4126, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 80/500000 [00:18<29:52:25,  4.65it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4859, 1920]) torch.Size([4859, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 82/500000 [00:18<27:04:40,  5.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([18054, 1920]) torch.Size([18054, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 83/500000 [00:19<33:05:37,  4.20it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([1040, 1920]) torch.Size([1040, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 85/500000 [00:19<32:08:25,  4.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 86/500000 [00:20<37:11:20,  3.73it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 87/500000 [00:20<41:30:21,  3.35it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 90/500000 [00:21<29:34:12,  4.70it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4544, 1920]) torch.Size([4544, 5])\n",
      "torch.Size([1347, 1920]) torch.Size([1347, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 92/500000 [00:21<27:54:21,  4.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([23970, 1920]) torch.Size([23970, 5])\n",
      "torch.Size([7994, 1920]) torch.Size([7994, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 93/500000 [00:21<35:36:15,  3.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 96/500000 [00:22<27:30:12,  5.05it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([4703, 1920]) torch.Size([4703, 5])\n",
      "torch.Size([5298, 1920]) torch.Size([5298, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 97/500000 [00:22<27:48:35,  4.99it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([15138, 1920]) torch.Size([15138, 5])\n",
      "torch.Size([616, 1920]) torch.Size([616, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 99/500000 [00:23<28:30:56,  4.87it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 100/500000 [00:23<30:26:13,  4.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([20269, 1920]) torch.Size([20269, 5])\n",
      "[train] iter    100 — loss 0.6945\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 102/500000 [00:23<31:14:04,  4.45it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([7118, 1920]) torch.Size([7118, 5])\n",
      "torch.Size([6901, 1920]) torch.Size([6901, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 106/500000 [00:24<21:38:54,  6.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([16379, 1920]) torch.Size([16379, 5])\n",
      "torch.Size([2194, 1920]) torch.Size([2194, 5])\n",
      "torch.Size([10634, 1920]) torch.Size([10634, 5])\n",
      "torch.Size([6629, 1920]) torch.Size([6629, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 108/500000 [00:24<25:50:24,  5.37it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 109/500000 [00:25<32:21:52,  4.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([31792, 1920]) torch.Size([31792, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 111/500000 [00:25<32:23:54,  4.29it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([8500, 1920]) torch.Size([8500, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 112/500000 [00:26<38:36:07,  3.60it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 113/500000 [00:26<43:13:17,  3.21it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 115/500000 [00:27<38:20:20,  3.62it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([30735, 1920]) torch.Size([30735, 5])\n",
      "torch.Size([9996, 1920]) torch.Size([9996, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 116/500000 [00:27<31:26:33,  4.42it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([7222, 1920]) torch.Size([7222, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 118/500000 [00:27<32:51:19,  4.23it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([10392, 1920]) torch.Size([10392, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 119/500000 [00:28<34:54:32,  3.98it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([22191, 1920]) torch.Size([22191, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 120/500000 [00:28<41:41:24,  3.33it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 122/500000 [00:29<40:45:11,  3.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([15359, 1920]) torch.Size([15359, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 124/500000 [00:29<35:41:22,  3.89it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n",
      "torch.Size([7047, 1920]) torch.Size([7047, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 125/500000 [00:29<32:13:50,  4.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([12574, 1920]) torch.Size([12574, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 126/500000 [00:30<33:57:02,  4.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([21810, 1920]) torch.Size([21810, 5])\n",
      "torch.Size([4919, 1920]) torch.Size([4919, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 128/500000 [00:30<33:35:05,  4.13it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([32000, 1920]) torch.Size([32000, 5])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  0%|          | 129/500000 [00:30<38:35:05,  3.60it/s]"
     ]
    }
   ],
   "source": [
    "# In[5]: training + validation loop\n",
    "for it in tqdm(range(1, num_iterations+1)):\n",
    "    # sample one training transcript\n",
    "    key = random.choice(train_keys)\n",
    "    # print(key)\n",
    "    y_np = np.array(train_lbl_f[key]['labels_atcg'])[:, :] \n",
    "    x_np = np.array(train_emb_f[key]['embeddings'])[0, :, :] \n",
    "    # optional length cap\n",
    "    if max_seq_length and x_np.shape[0] >= max_seq_length:\n",
    "        x_np = x_np[:max_seq_length]\n",
    "        y_np = y_np[:max_seq_length]\n",
    "    x = torch.from_numpy(x_np).float().to(device)\n",
    "    y = torch.from_numpy(y_np).float().to(device)\n",
    "\n",
    "    print(x.shape, y.shape)\n",
    "\n",
    "    optimizer.zero_grad()\n",
    "    logits = model(x)            \n",
    "    loss = criterion(logits, y)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    if it % 100 == 0:\n",
    "        print(f\"[train] iter {it:6d} — loss {loss.item():.4f}\")\n",
    "\n",
    "    # validation\n",
    "    if it % val_period == 0:\n",
    "        sample = random.sample(val_keys, min(val_sample_size, n_val))\n",
    "        all_y, all_p = [], []\n",
    "        model.eval()\n",
    "        with torch.no_grad():\n",
    "            for k in sample:\n",
    "                yv = val_lbl_f[k]['labels_atcg'][:]\n",
    "                xv = val_emb_f[k]['embeddings'][:]\n",
    "                if max_seq_length and xv.shape[0] > max_seq_length:\n",
    "                    xv = xv[:max_seq_length]\n",
    "                    yv = yv[:max_seq_length]\n",
    "                xv_t = torch.from_numpy(xv).float().to(device)\n",
    "                lv_logits = model(xv_t)\n",
    "                lv_probs  = torch.sigmoid(lv_logits).cpu().numpy()[0, :, :]\n",
    "                all_p.append(lv_probs)\n",
    "                all_y.append(yv)\n",
    "        model.train()\n",
    "\n",
    "        metrics = compute_exon_metrics(all_y, all_p)\n",
    "        mv = metrics['max_f1_exon_level']\n",
    "        print(f\"[val]   iter {it:6d} — max_f1_exon_level {mv:.4f}\")\n",
    "\n",
    "        # --- save metrics to text file ---\n",
    "        with open('validation_metrics.txt', 'a') as f:\n",
    "            # header for this iteration\n",
    "            f.write(f\"Iteration: {it}\\n\")\n",
    "            # write out each metric on its own line\n",
    "            for name, val in metrics.items():\n",
    "                f.write(f\"{name}: {val:.6f}\\n\")\n",
    "            f.write(\"\\n\")  # blank line between records\n",
    "\n",
    "        if mv > best_val_metric:\n",
    "            best_val_metric = mv\n",
    "            ckpt = (\n",
    "                f\"ckpt_lr{lr}_wd{weight_decay}\"\n",
    "                f\"_maxf1{mv:.4f}.pt\"\n",
    "            )\n",
    "            torch.save({\n",
    "                'iteration': it,\n",
    "                'model_state_dict': model.state_dict(),\n",
    "                'optimizer_state_dict': optimizer.state_dict(),\n",
    "                'best_val_metric': best_val_metric,\n",
    "                'hyperparams': {\n",
    "                    'lr': lr,\n",
    "                    'weight_decay': weight_decay,\n",
    "                    'max_seq_length': max_seq_length\n",
    "                }\n",
    "            }, ckpt)\n",
    "            print(f\" *** saved checkpoint: {ckpt}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1be75a5-26b8-41d2-935e-f076d23fceee",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "gena_ipynb",
   "language": "python",
   "name": "gena_ipynb"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
