{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5edb7c20",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install transformers sentencepiece"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bfec417e",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0d29e0d9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random\n",
    "import math\n",
    "from tqdm import tqdm\n",
    "from ast import literal_eval\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import math\n",
    "import time\n",
    "from torch.utils.data import DataLoader\n",
    "import torch.optim as optim\n",
    "from transformers import T5ForConditionalGeneration, T5Tokenizer\n",
    "import os\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers.modeling_outputs import BaseModelOutput\n",
    "import os\n",
    "from typing import List, Tuple\n",
    "from glob import glob\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "d4bcc80b",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n"
     ]
    }
   ],
   "source": [
    "# ---------------------\n",
    "# Load T5-Base Model\n",
    "# ---------------------\n",
    "summarizer_model = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n",
    "tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c9be291f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Loaded nid2body with 113762 items\n",
      "🧾 Sample NID: N10000\n",
      "📝 Headline: Only FIVE internationals allowed, count em, FIVE! So first off we should say, per our usual Atlanta United lineup predictions, this will be wrong. Why will it be wrong? Well, aside from the obvious, we still don't have a ton of data points from Frank de Boer in how he prefers to rotate his team for \n"
     ]
    }
   ],
   "source": [
    "# Load nid2body from pickle\n",
    "with open(\"nid2body.pkl\", \"rb\") as f:\n",
    "    nid2body = pickle.load(f)\n",
    "\n",
    "# Debug print\n",
    "print(f\"✅ Loaded nid2body with {len(nid2body)} items\")\n",
    "sample_nid = list(nid2body.keys())[0]\n",
    "print(f\"🧾 Sample NID: {sample_nid}\\n📝 Headline: {nid2body[sample_nid][:300]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "08fb8c2d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ Loaded sid2sum with 135001 items\n",
      "🧾 Sample SID: S-1\n",
      "📝 Summary: The officer reportedly also pointed his gun at Harper and her children.\n"
     ]
    }
   ],
   "source": [
    "# Load sid2sum from pickle\n",
    "with open(\"sid2sum.pkl\", \"rb\") as f:\n",
    "    sid2sum = pickle.load(f)\n",
    "\n",
    "# Debug print\n",
    "print(f\"✅ Loaded sid2sum with {len(sid2sum)} items\")\n",
    "sample_sid = list(sid2sum.keys())[0]\n",
    "print(f\"🧾 Sample SID: {sample_sid}\\n📝 Summary: {sid2sum[sample_sid][:300]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9e60dcba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# === Device and Precision Setup ===\n",
    "torch.set_default_dtype(torch.float32)\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "hidden_dim = 768\n",
    "\n",
    "# === Utility Functions ===\n",
    "def get_embedding(key, table, dim):\n",
    "    if key not in table:\n",
    "        table[key] = torch.nn.Parameter(torch.randn(dim, dtype=torch.float32, device=device) * 0.01, requires_grad=True)\n",
    "    return table[key]\n",
    "\n",
    "\n",
    "# === Load Embeddings ===\n",
    "with open(\"summary_T5.pkl\", \"rb\") as f:\n",
    "    summary_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}\n",
    "with open(\"newsbody_T5.pkl\", \"rb\") as f:\n",
    "    newsbody_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}\n",
    "with open(\"headline_T5.pkl\", \"rb\") as f:\n",
    "    headline_embed = {k: torch.tensor(v, dtype=torch.float32, device=device) for k, v in pickle.load(f).items()}\n",
    "\n",
    "embed_tables = {\n",
    "    'summary': summary_embed,\n",
    "    'newsbody': newsbody_embed,\n",
    "    'headline': headline_embed\n",
    "}\n",
    "\n",
    "# === Load Dataset ===\n",
    "lookup_df = pd.read_csv(\"w2p_engage_list.csv\").set_index('EdgeID')\n",
    "train_df = pd.read_csv(\"train_df_gold_only.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "5427da34",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 38417/38417 [00:03<00:00, 11439.69it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0    [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...\n",
      "1    [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...\n",
      "2                                               [E151]\n",
      "3    [E151, E152, E153, E154, E155, E156, E157, E15...\n",
      "4    [E151, E152, E153, E154, E155, E156, E157, E15...\n",
      "Name: EHist, dtype: object\n",
      "Max length + 1: 384\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "tqdm.pandas()  # enable progress_apply\n",
    "\n",
    "# Convert all Bhist strings to actual lists with progress bar\n",
    "train_df['EHist'] = train_df['EHist'].progress_apply(literal_eval)\n",
    "\n",
    "# Check a few entries\n",
    "print(train_df['EHist'].head())\n",
    "\n",
    "# Compute max length\n",
    "max_len = max(len(h) for h in train_df[\"EHist\"])\n",
    "max_len_plus_one = max_len + 1\n",
    "print(\"Max length + 1:\", max_len_plus_one)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b204f75e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Max length: 383\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 38417/38417 [00:00<00:00, 158023.31it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['E1', 'E2', 'E3', 'E4', 'E5', 'E6', 'E7', 'E8', 'E9', 'E10', 'E11', 'E12', 'E13', 'E14', 'E15', 'E16', 'E17', 'E18', 'E19', 'E20', 'E21', 'E22', 'E23', 'E24', 'E25', 'E26', 'E27', 'E28', 'E29', 'E30', 'E31', 'E32', 'E33', 'E34', 'E35', 'E36', 'E37', 'E38', 'E39', 'E40', 'E41', 'E42', 'E43', 'E44', 'E45', 'E46', 'E47', 'E48', 'E49', 'E50', 'E51', 'E52', 'E53', 'E54', 'E55', 'E56', 'E57', 'E58', 'E59', 'E60', 'E61', 'E62', 'E63', 'E64', 'E65', 'E66', 'E67', 'E68', 'E69', 'E70', 'E71', 'E72', 'E73', 'E74', 'E75', 'E76', 'E77', 'E78', 'E79', 'E80', 'E81', 'E82', 'E83', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', 
'<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']\n",
      "383\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "PAD_TOKEN = \"<PAD>\"   # or 0 if you want integer padding later\n",
    "max_len = max(len(h) for h in train_df[\"EHist\"])\n",
    "print(\"Max length:\", max_len)\n",
    "\n",
    "# Compute original lengths before padding\n",
    "train_df[\"EHist_len\"] = train_df[\"EHist\"].apply(len)\n",
    "\n",
    "# Then pad as before\n",
    "PAD_TOKEN = \"<PAD>\"\n",
    "train_df[\"EHist_padded\"] = train_df[\"EHist\"].progress_apply(\n",
    "    lambda h: h + [PAD_TOKEN] * (max_len - len(h))\n",
    ")\n",
    "\n",
    "\n",
    "# Check shapes\n",
    "print(train_df[\"EHist_padded\"].iloc[0])\n",
    "print(len(train_df[\"EHist_padded\"].iloc[0]))  # should equal max_len"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "70f912c0",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Unnamed: 0</th>\n",
       "      <th>UserID</th>\n",
       "      <th>EHist</th>\n",
       "      <th>EPos</th>\n",
       "      <th>EHist_len</th>\n",
       "      <th>EHist_padded</th>\n",
       "      <th>Tail</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>2</td>\n",
       "      <td>U10000_1</td>\n",
       "      <td>[E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...</td>\n",
       "      <td>E84</td>\n",
       "      <td>83</td>\n",
       "      <td>[E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...</td>\n",
       "      <td>S-1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>3</td>\n",
       "      <td>U10000_2</td>\n",
       "      <td>[E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...</td>\n",
       "      <td>E133</td>\n",
       "      <td>132</td>\n",
       "      <td>[E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...</td>\n",
       "      <td>S-2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>11</td>\n",
       "      <td>U100006_1</td>\n",
       "      <td>[E151]</td>\n",
       "      <td>E152</td>\n",
       "      <td>1</td>\n",
       "      <td>[E151, &lt;PAD&gt;, &lt;PAD&gt;, &lt;PAD&gt;, &lt;PAD&gt;, &lt;PAD&gt;, &lt;PAD...</td>\n",
       "      <td>S-3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>12</td>\n",
       "      <td>U100006_2</td>\n",
       "      <td>[E151, E152, E153, E154, E155, E156, E157, E15...</td>\n",
       "      <td>E168</td>\n",
       "      <td>17</td>\n",
       "      <td>[E151, E152, E153, E154, E155, E156, E157, E15...</td>\n",
       "      <td>S-4</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>13</td>\n",
       "      <td>U100006_3</td>\n",
       "      <td>[E151, E152, E153, E154, E155, E156, E157, E15...</td>\n",
       "      <td>E230</td>\n",
       "      <td>79</td>\n",
       "      <td>[E151, E152, E153, E154, E155, E156, E157, E15...</td>\n",
       "      <td>S-5</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38412</th>\n",
       "      <td>157388</td>\n",
       "      <td>U20775_1</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>E1512050</td>\n",
       "      <td>40</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>S-38704</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38413</th>\n",
       "      <td>157389</td>\n",
       "      <td>U20775_2</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>E1512064</td>\n",
       "      <td>54</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>S-38705</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38414</th>\n",
       "      <td>157390</td>\n",
       "      <td>U20775_3</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>E1512112</td>\n",
       "      <td>102</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>S-38706</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38415</th>\n",
       "      <td>157391</td>\n",
       "      <td>U20775_4</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>E1512115</td>\n",
       "      <td>105</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>S-38707</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>38416</th>\n",
       "      <td>157392</td>\n",
       "      <td>U20775_5</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>E1512133</td>\n",
       "      <td>123</td>\n",
       "      <td>[E1512010, E1512011, E1512012, E1512013, E1512...</td>\n",
       "      <td>S-38708</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>38417 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "       Unnamed: 0     UserID  \\\n",
       "0               2   U10000_1   \n",
       "1               3   U10000_2   \n",
       "2              11  U100006_1   \n",
       "3              12  U100006_2   \n",
       "4              13  U100006_3   \n",
       "...           ...        ...   \n",
       "38412      157388   U20775_1   \n",
       "38413      157389   U20775_2   \n",
       "38414      157390   U20775_3   \n",
       "38415      157391   U20775_4   \n",
       "38416      157392   U20775_5   \n",
       "\n",
       "                                                   EHist      EPos  EHist_len  \\\n",
       "0      [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...       E84         83   \n",
       "1      [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...      E133        132   \n",
       "2                                                 [E151]      E152          1   \n",
       "3      [E151, E152, E153, E154, E155, E156, E157, E15...      E168         17   \n",
       "4      [E151, E152, E153, E154, E155, E156, E157, E15...      E230         79   \n",
       "...                                                  ...       ...        ...   \n",
       "38412  [E1512010, E1512011, E1512012, E1512013, E1512...  E1512050         40   \n",
       "38413  [E1512010, E1512011, E1512012, E1512013, E1512...  E1512064         54   \n",
       "38414  [E1512010, E1512011, E1512012, E1512013, E1512...  E1512112        102   \n",
       "38415  [E1512010, E1512011, E1512012, E1512013, E1512...  E1512115        105   \n",
       "38416  [E1512010, E1512011, E1512012, E1512013, E1512...  E1512133        123   \n",
       "\n",
       "                                            EHist_padded     Tail  \n",
       "0      [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...      S-1  \n",
       "1      [E1, E2, E3, E4, E5, E6, E7, E8, E9, E10, E11,...      S-2  \n",
       "2      [E151, <PAD>, <PAD>, <PAD>, <PAD>, <PAD>, <PAD...      S-3  \n",
       "3      [E151, E152, E153, E154, E155, E156, E157, E15...      S-4  \n",
       "4      [E151, E152, E153, E154, E155, E156, E157, E15...      S-5  \n",
       "...                                                  ...      ...  \n",
       "38412  [E1512010, E1512011, E1512012, E1512013, E1512...  S-38704  \n",
       "38413  [E1512010, E1512011, E1512012, E1512013, E1512...  S-38705  \n",
       "38414  [E1512010, E1512011, E1512012, E1512013, E1512...  S-38706  \n",
       "38415  [E1512010, E1512011, E1512012, E1512013, E1512...  S-38707  \n",
       "38416  [E1512010, E1512011, E1512012, E1512013, E1512...  S-38708  \n",
       "\n",
       "[38417 rows x 7 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d5378608",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ======================================================\n",
    "# BehaviorEncoder (uses padded Bhist + Bhist_len)\n",
    "# ======================================================\n",
    "class BehaviorEncoder(nn.Module):\n",
    "    def __init__(self, hidden_dim, device, max_len, debug=False):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.device = device\n",
    "        self.debug = debug\n",
    "        self.max_len = max_len\n",
    "\n",
    "        # === Base action vectors ===\n",
    "        self.e_clk = nn.Parameter(torch.tensor([1., 0., 0., 0.], device=device))\n",
    "        self.e_skp = nn.Parameter(torch.tensor([0., 1., 0., 0.], device=device))\n",
    "        self.e_gensumm = nn.Parameter(torch.tensor([0., 0., 1., 0.], device=device))\n",
    "        self.e_sumgen = nn.Parameter(torch.tensor([0., 0., 0., 1.], device=device))\n",
    "\n",
    "        # Action-specific transforms\n",
    "        self.W_clk = nn.Linear(4, hidden_dim, bias=False)\n",
    "        self.W_skp = nn.Linear(4, hidden_dim, bias=False)\n",
    "        self.W_gensumm = nn.Linear(4, hidden_dim, bias=False)\n",
    "        self.W_sumgen = nn.Linear(4, hidden_dim, bias=False)\n",
    "\n",
    "        # State transforms\n",
    "        self.W_pull = nn.Linear(1, hidden_dim, bias=False)\n",
    "        self.W_s = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "        self.W_d = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "\n",
    "        # Fusion\n",
    "        self.Wh = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "        self.Wc = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "        self.Wz = nn.Linear(hidden_dim, 3, bias=False)\n",
    "        self.b_z = nn.Parameter(torch.zeros(3, device=device))\n",
    "        self.W_emb = nn.Linear(hidden_dim, hidden_dim, bias=True)\n",
    "        self.b_emb = nn.Parameter(torch.zeros(hidden_dim, device=device))  # kept as in your design\n",
    "\n",
    "        # Rotation/translation\n",
    "        self.W_angle = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "        self.W_theta = nn.Linear(hidden_dim, 1, bias=False)\n",
    "        self.W_h = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "        self.W_m = nn.Linear(hidden_dim, 1, bias=False)\n",
    "\n",
    "        # Scalars\n",
    "        self.alpha = nn.Parameter(torch.tensor(0.5, device=device))\n",
    "        self.beta = nn.Parameter(torch.tensor(0.5, device=device))\n",
    "\n",
    "        # Classifier head over timesteps (0...max_len+1)\n",
    "        self.classifier = nn.Linear(hidden_dim, max_len + 2)\n",
    "\n",
    "        # Next-step prediction head\n",
    "        self.W_next = nn.Linear(hidden_dim, hidden_dim, bias=False)\n",
    "\n",
    "    def softmin_pool(self, a, b):\n",
    "        return -self.alpha * torch.log(torch.exp(a / self.alpha) +\n",
    "                                       torch.exp(b / self.alpha) + 1e-9)\n",
    "\n",
    "    def forward(self, Bhist, Bhist_len, Bpos, lookup_df, embed_tables):\n",
    "        \"\"\"\n",
    "        Inputs:\n",
    "          - Bhist: padded list/sequence of EIDs (length == self.max_len)\n",
    "          - Bhist_len: original unpadded length (int)\n",
    "          - Bpos: target EID (scalar id)\n",
    "          - lookup_df: DataFrame indexed by EID with columns 'Tail','Head','Relation'\n",
    "          - embed_tables: dict with keys 'newsbody','summary','headline' mapping item-id -> tensor(hidden_dim,)\n",
    "        Outputs:\n",
    "          - eprime_last: final contextualized embedding (for last real token)\n",
    "          - eprime_next: predicted next-step embedding (W_next(eprime_last))\n",
    "          - logits_pos: classifier logits for the next-step (vector of size max_len+2)\n",
    "          - ctx_enc_loss: accumulated stepwise classification loss (sum over real steps)\n",
    "          - total_loss: weighted total loss (0.2 * step-loss + 0.8 * next-pos CE) as you had\n",
    "        Notes:\n",
    "          - This function explicitly iterates only over the first Bhist_len entries (ignoring PADs).\n",
    "        \"\"\"\n",
    "        # ensure Bhist is indexable (list/1-D tensor). If tensor, bring to cpu/int list for lookup indexing\n",
    "        if isinstance(Bhist, torch.Tensor):\n",
    "            Bhist_list = Bhist.detach().cpu().tolist()\n",
    "        else:\n",
    "            Bhist_list = list(Bhist)\n",
    "\n",
    "        # initialize losses on device\n",
    "        ctx_enc_loss = torch.tensor(0., dtype=torch.float32, device=self.device)\n",
    "\n",
    "        # PASS 1: build per-step raw embeddings E_seq for real tokens only\n",
    "        E_seq = []\n",
    "        h_clk = torch.zeros(self.hidden_dim, device=self.device)\n",
    "        h_skp = torch.zeros(self.hidden_dim, device=self.device)\n",
    "        h = torch.zeros(self.hidden_dim, device=self.device)\n",
    "\n",
    "        for t in range(int(Bhist_len)):  # iterate only real tokens\n",
    "            b_id = Bhist_list[t]\n",
    "            # skip if missing in lookup\n",
    "            if b_id not in lookup_df.index:\n",
    "                continue\n",
    "\n",
    "            row = lookup_df.loc[b_id]\n",
    "            tail_id, rel = row['Tail'], row['Relation']\n",
    "\n",
    "            # fetch embeddings (fallback zeros)\n",
    "            d_i = embed_tables['newsbody'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))\n",
    "            s_i = embed_tables['summary'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))\n",
    "            d_i_title = embed_tables['headline'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))\n",
    "\n",
    "            # initialize at t==0\n",
    "            if t == 0:\n",
    "                head_emb = embed_tables['headline'].get(tail_id, torch.zeros(self.hidden_dim, device=self.device))\n",
    "                h_clk, h_skp = head_emb.clone(), head_emb.clone()\n",
    "                # keep original math: W_s(head_emb) * h_clk + (1-W_s(head_emb)) * h_skp\n",
    "                h = torch.tanh(self.W_s(head_emb) * h_clk + (1 - (self.W_s(head_emb))) * h_skp)\n",
    "\n",
    "            # relation-specific context\n",
    "            if rel == \"click\":\n",
    "                c_i = torch.tanh(self.W_clk.weight @ self.e_clk * h) * d_i\n",
    "            elif rel == \"skip\":\n",
    "                d_ip1 = torch.zeros_like(d_i)\n",
    "                if (t + 1) < Bhist_len and Bhist_list[t+1] in lookup_df.index:\n",
    "                    next_row = lookup_df.loc[Bhist_list[t+1]]\n",
    "                    d_ip1 = embed_tables['newsbody'].get(next_row['Head'], torch.zeros_like(d_i))\n",
    "                pull_term = self.W_pull(torch.tensor([[torch.dot(h_clk, d_ip1) + torch.dot(h_skp, d_i)]],\n",
    "                                                     device=self.device)).squeeze(0)\n",
    "                c_i = -torch.tanh(self.W_skp.weight @ self.e_skp + d_i + pull_term * h) * d_i\n",
    "            elif rel == \"gen_summ\":\n",
    "                c_i = torch.tanh(self.W_gensumm.weight @ self.e_gensumm * h) * d_i_title\n",
    "            elif rel == \"summ_gen\":\n",
    "                gate_summgen = torch.tanh(self.W_s(self.W_sumgen.weight @ self.e_sumgen))\n",
    "                c_i = self.softmin_pool(gate_summgen * s_i, (1 - gate_summgen) * d_i)\n",
    "                h_clk = h_clk + self.W_d((torch.ones_like(d_i_title) - d_i_title) * s_i)\n",
    "            else:\n",
    "                c_i = d_i\n",
    "\n",
    "            # hidden update\n",
    "            z_i = torch.tanh(self.Wh(h) + self.Wc(c_i))\n",
    "            p_i = torch.softmax(self.Wz(z_i) + self.b_z, dim=-1)\n",
    "            m_i = p_i[0] * 0.1 + p_i[1] * 0.5 + p_i[2] * 0.9\n",
    "            if rel == \"click\":\n",
    "                h_clk = h_clk + m_i * c_i\n",
    "            elif rel == \"skip\":\n",
    "                h_skp = h_skp * (1 - m_i) + c_i\n",
    "            h = (self.beta * h_clk + (1 - self.beta) * h_skp)\n",
    "\n",
    "            e_i = torch.tanh(self.W_emb(c_i))\n",
    "            E_seq.append(e_i)\n",
    "\n",
    "        # if E_seq is empty (no valid history), return zeros to avoid crash\n",
    "        if len(E_seq) == 0:\n",
    "            zero_e = torch.zeros(self.hidden_dim, device=self.device)\n",
    "            eprime_last = zero_e\n",
    "            eprime_next = self.W_next(eprime_last)\n",
    "            logits_pos = self.classifier(eprime_next.unsqueeze(0))\n",
    "            ctx_enc_loss = torch.tensor(0., device=self.device)\n",
    "            total_loss = torch.tensor(0., device=self.device)\n",
    "            return eprime_last, eprime_next, logits_pos, ctx_enc_loss, total_loss\n",
    "\n",
    "        # PASS 2: contextualize E_seq -> Eprime_seq (only for real tokens)\n",
    "        Eprime_seq = []\n",
    "        eps = 1e-6\n",
    "        for i, e_i in enumerate(E_seq):\n",
    "            if i == 0:\n",
    "                e_prime = e_i\n",
    "            else:\n",
    "                e_prev, e_prime_prev = E_seq[i-1], Eprime_seq[-1]\n",
    "                theta_i = math.pi * torch.tanh(self.W_theta(torch.sigmoid(self.W_angle(e_prime_prev))))\n",
    "                m_i = F.softplus(self.W_m(self.W_h(e_prime_prev)))\n",
    "\n",
    "                # compute orthonormal direction and rotation (as in your original)\n",
    "                v_i = (e_i - e_prime_prev) / (e_i - e_prime_prev).norm(p=2).clamp(min=eps)\n",
    "                u_prev = e_prev / e_prime_prev.norm(p=2).clamp(min=eps)\n",
    "                o_i = (v_i - torch.dot(v_i, u_prev) * u_prev)\n",
    "                o_i = o_i / o_i.norm(p=2).clamp(min=eps)\n",
    "\n",
    "                e_prime = e_prime_prev + m_i * (torch.cos(theta_i) * u_prev + torch.sin(theta_i) * o_i).squeeze()\n",
    "                e_prime = torch.tanh(e_prime) + e_i\n",
    "\n",
    "            Eprime_seq.append(e_prime)\n",
    "\n",
    "            # per-step classification loss only for real steps\n",
    "            target_step = torch.tensor([i], device=self.device)\n",
    "            logits_step = self.classifier(e_prime.unsqueeze(0))\n",
    "            ctx_enc_loss = ctx_enc_loss + F.cross_entropy(logits_step, target_step)\n",
    "\n",
    "        # Final prediction on Bpos (next-step)\n",
    "        eprime_last = Eprime_seq[-1]\n",
    "        eprime_next = self.W_next(eprime_last)\n",
    "        logits_pos = self.classifier(eprime_next.unsqueeze(0))\n",
    "\n",
    "        # target for next-position is Bhist_len (0-based indexing consistent with your design)\n",
    "        target_pos = torch.tensor([int(Bhist_len)], device=self.device)\n",
    "        total_loss = 0.2 * (ctx_enc_loss / max(1, float(Bhist_len))) + 0.8 * F.cross_entropy(logits_pos, target_pos)\n",
    "\n",
    "        return eprime_last, eprime_next, logits_pos, ctx_enc_loss, total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "5aa1ceed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ======================================================\n",
    "# BehaviorInverseDecoderPredict\n",
    "# ======================================================\n",
    "class BehaviorInverseDecoderPredict(nn.Module):\n",
    "    \"\"\"\n",
    "    Inverse mapping that takes a predicted e' (embedding for Bpos) and the\n",
    "    headline embedding for Bpos, and returns:\n",
    "      - c'_pos (approx pseudo-content)\n",
    "      - s_hat_pos (approx summary)\n",
    "    Inputs may be single vectors (hidden_dim,) or batches (B, hidden_dim);\n",
    "    each vector is mapped and L2-normalized independently.\n",
    "    \"\"\"\n",
    "    def __init__(self, hidden_dim, device, debug=False):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.device = device\n",
    "        self.debug = debug\n",
    "\n",
    "        # Learnable pseudo-inverse / disentanglers\n",
    "        self.W_emb_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)   # approx W_emb^+\n",
    "        self.W1_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)      # remove head contribution\n",
    "        self.W2_pinv = nn.Linear(hidden_dim, hidden_dim, bias=False)      # map residual -> summary\n",
    "\n",
    "    def _show(self, name, tensor, maxlen=6):\n",
    "        \"\"\"Debug-print a scalar, the first `maxlen` values of a tensor, or any other value; no-op unless self.debug.\"\"\"\n",
    "        if not self.debug:\n",
    "            return\n",
    "        if not isinstance(tensor, torch.Tensor):\n",
    "            print(f\"{name}: {tensor}\")\n",
    "        elif tensor.ndim == 0:\n",
    "            print(f\"{name}: {tensor.item():.6f}\")\n",
    "        else:\n",
    "            flat = tensor.detach().cpu().numpy().flatten()\n",
    "            vals = \", \".join(f\"{x:.6f}\" for x in flat[:maxlen])\n",
    "            if len(flat) > maxlen:\n",
    "                vals += \", ...\"\n",
    "            print(f\"{name} (shape={tuple(tensor.shape)}): [{vals}]\")\n",
    "\n",
    "    @staticmethod\n",
    "    def atanh_safe(x, eps=1e-6):\n",
    "        \"\"\"atanh(x) with x clamped to [-1+eps, 1-eps] so the result stays finite.\"\"\"\n",
    "        x = x.clamp(-1+eps, 1-eps)\n",
    "        return 0.5 * torch.log((1+x) / (1-x))\n",
    "\n",
    "    def forward(self, eprime_pos, b_emb, h_pos):\n",
    "        \"\"\"\n",
    "        eprime_pos: predicted e' embedding for Bpos, (hidden_dim,) or (B, hidden_dim) (assumed already on device)\n",
    "        b_emb: encoder bias b_emb, (hidden_dim,) -- broadcast over the batch\n",
    "        h_pos: head/headline embedding for Bpos, (hidden_dim,) or (B, hidden_dim)\n",
    "        Returns: c_prime_pos, s_hat_pos, each L2-normalized along the last (feature) dim\n",
    "        \"\"\"\n",
    "        # 1) invert embedding nonlinearity: atanh(e') - b_emb\n",
    "        x = self.atanh_safe(eprime_pos) - b_emb\n",
    "        c_prime_pos = self.W_emb_pinv(x)         # approx c'_pos\n",
    "        # dim=-1 normalizes each vector's features; for 1-D input this equals dim=0,\n",
    "        # for a batch it avoids normalizing across the batch axis\n",
    "        c_prime_pos = F.normalize(c_prime_pos, p=2, dim=-1)\n",
    "\n",
    "        # 2) subtract head contribution and map to summary\n",
    "        residual = c_prime_pos - self.W1_pinv(h_pos)\n",
    "        s_hat_pos = self.W2_pinv(residual)\n",
    "        s_hat_pos = F.normalize(s_hat_pos, p=2, dim=-1)\n",
    "\n",
    "        return c_prime_pos, s_hat_pos\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "e250e288",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from transformers.modeling_outputs import BaseModelOutput\n",
    "\n",
    "class PersonalizedT5Summarizer(nn.Module):\n",
    "    \"\"\"\n",
    "    T5 summarizer personalized with a user's behavior history.\n",
    "\n",
    "    TRAIN (TF) returns:\n",
    "      { 'skip', 'total_loss', 'enc_loss', 'gen_loss', 'lm_logits',\n",
    "        'doc_text', 'gold_summary', 'pred_summary_tf' }\n",
    "      (a skipped sample additionally carries 'reason')\n",
    "    EVAL (AR) returns:\n",
    "      { 'pred_summary_ar', 'enc_loss', 'doc_text', 'gold_summary' }\n",
    "\n",
    "    Notes:\n",
    "    - Supports two input modes:\n",
    "      (A) Pre-tokenized (recommended for speed):\n",
    "          pass enc_input_ids, dec_input_ids, dec_labels (all [B,T] or [1,T])\n",
    "      (B) Legacy (fallback): pass Bpos + tokenizer + nid2body/sid2sum, and it will tokenize inside.\n",
    "    - Teacher forcing uses the standard T5 layout: dec_labels are the gold ids and\n",
    "      dec_input_ids are those ids shifted right behind the decoder start token\n",
    "      (see _build_decoder_io). Greedy decoding starts from the same start token, so\n",
    "      the first gold token is trained in the context it is generated in.\n",
    "    - Encoder padding is masked in the T5 encoder, in the context-injection softmax\n",
    "      and in the decoder cross-attention.\n",
    "    - Decoder states are projected to logits as T5ForConditionalGeneration does\n",
    "      (see _lm_logits).\n",
    "    - NOTE(review): checkpoints fine-tuned with a different decoder layout, padding\n",
    "      mask or logit projection are not interchangeable with this class; retrain.\n",
    "    - Greedy preview during training is OFF by default; enable every N steps with preview_stride>0.\n",
    "    \"\"\"\n",
    "\n",
    "    # mode names accepted by forward()\n",
    "    _TF_MODES = (\"train\", \"tf\", \"teacher\", \"teacher_forcing\")\n",
    "    _AR_MODES = (\"eval\", \"test\", \"ar\", \"greedy\", \"beam\")\n",
    "\n",
    "    def __init__(self, hidden_dim, t5_model, behavior_encoder, inverse_decoder, device,\n",
    "                 learnable_ctx=True, unfreeze_last_decoder_blocks=10):\n",
    "        super().__init__()\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.t5 = t5_model                    # don't force eval() here\n",
    "        self.behavior_encoder = behavior_encoder\n",
    "        self.inverse_decoder = inverse_decoder\n",
    "        self.device = device\n",
    "\n",
    "        self.t5_hidden_dim = self.t5.config.d_model\n",
    "\n",
    "        # light projection layers for context injection\n",
    "        self.e_proj = nn.Linear(hidden_dim, self.t5_hidden_dim)\n",
    "        self.q_proj = nn.Linear(self.t5_hidden_dim, self.t5_hidden_dim)\n",
    "        self.k_proj = nn.Linear(hidden_dim, self.t5_hidden_dim)\n",
    "        self.v_proj = nn.Linear(hidden_dim, self.t5_hidden_dim)\n",
    "\n",
    "        # raw context scale; _inject_context uses min(softplus(raw), 10)\n",
    "        if learnable_ctx:\n",
    "            self.ctx_scale_raw = nn.Parameter(torch.tensor(0.1))\n",
    "        else:\n",
    "            self.register_buffer(\"ctx_scale_raw\", torch.tensor(0.1), persistent=False)\n",
    "\n",
    "        # === Freeze all, then selectively unfreeze ===\n",
    "        for p in self.t5.parameters():\n",
    "            p.requires_grad = False\n",
    "\n",
    "        # Unfreeze last N decoder blocks\n",
    "        if hasattr(self.t5, \"decoder\") and hasattr(self.t5.decoder, \"block\"):\n",
    "            if unfreeze_last_decoder_blocks is not None and unfreeze_last_decoder_blocks > 0:\n",
    "                for layer in self.t5.decoder.block[-unfreeze_last_decoder_blocks:]:\n",
    "                    for p in layer.parameters():\n",
    "                        p.requires_grad = True\n",
    "\n",
    "        # Unfreeze final decoder layer norm if present\n",
    "        if hasattr(self.t5, \"decoder\") and hasattr(self.t5.decoder, \"final_layer_norm\"):\n",
    "            self.t5.decoder.final_layer_norm.weight.requires_grad = True\n",
    "\n",
    "        # T5 weight tying: shared <-> lm_head must both be trainable\n",
    "        if hasattr(self.t5, \"shared\"):\n",
    "            self.t5.shared.weight.requires_grad = True\n",
    "        if hasattr(self.t5, \"lm_head\"):\n",
    "            self.t5.lm_head.weight.requires_grad = True\n",
    "\n",
    "    # ---------- helpers ----------\n",
    "    @staticmethod\n",
    "    def _special_ids(config, tokenizer):\n",
    "        \"\"\"\n",
    "        Resolve (start_id, eos_id) for decoding.\n",
    "\n",
    "        start_id: config.decoder_start_token_id, else tokenizer.pad_token_id,\n",
    "                  else tokenizer.eos_token_id (1 if the tokenizer has no such attribute).\n",
    "        eos_id:   tokenizer.eos_token_id, or None if absent.\n",
    "        \"\"\"\n",
    "        start_id = getattr(config, \"decoder_start_token_id\", None)\n",
    "        if start_id is None:\n",
    "            start_id = getattr(tokenizer, \"pad_token_id\", None)\n",
    "        if start_id is None:\n",
    "            start_id = getattr(tokenizer, \"eos_token_id\", 1)\n",
    "        return start_id, getattr(tokenizer, \"eos_token_id\", None)\n",
    "\n",
    "    @staticmethod\n",
    "    def _build_decoder_io(gold_tokens, start_id):\n",
    "        \"\"\"\n",
    "        Teacher-forcing decoder inputs/labels from gold ids [B, T].\n",
    "\n",
    "        Returns (dec_input_ids, dec_labels), both [B, T]:\n",
    "          dec_labels    = gold_tokens\n",
    "          dec_input_ids = start_id followed by gold_tokens[:, :-1]  (shift right)\n",
    "        so step t predicts gold[t] from start + gold[:t], exactly as greedy decoding does.\n",
    "        \"\"\"\n",
    "        start_col = torch.full_like(gold_tokens[:, :1], start_id)\n",
    "        dec_input_ids = torch.cat([start_col, gold_tokens[:, :-1]], dim=1)\n",
    "        return dec_input_ids, gold_tokens\n",
    "\n",
    "    def _behavior_encode(self, Bhist, Bhist_len, Bpos, lookup_df, embed_tables):\n",
    "        \"\"\"\n",
    "        Run the behavior encoder, then the inverse decoder, for Bpos.\n",
    "\n",
    "        Returns (e_last, s_hat, enc_loss): the encoder's first output, the inverse\n",
    "        decoder's summary estimate, and the encoder's fifth output (its loss).\n",
    "        A missing headline embedding or `b_emb` attribute falls back to zeros.\n",
    "        \"\"\"\n",
    "        e_last, e_next, _, _, enc_loss = self.behavior_encoder(\n",
    "            Bhist, Bhist_len, Bpos, lookup_df, embed_tables\n",
    "        )\n",
    "        head_id = lookup_df.loc[Bpos]['Head']\n",
    "        head_emb = embed_tables['headline'].get(\n",
    "            head_id, torch.zeros(self.hidden_dim, device=self.device)\n",
    "        )\n",
    "        b_emb = getattr(\n",
    "            self.behavior_encoder, \"b_emb\",\n",
    "            torch.zeros(self.hidden_dim, device=self.device)\n",
    "        )\n",
    "        _, s_hat = self.inverse_decoder(e_next, b_emb, head_emb)\n",
    "        return e_last, s_hat, enc_loss\n",
    "\n",
    "    def _encode_query(self, Bpos, lookup_df, nid2body, tokenizer, max_len, return_mask=False):\n",
    "        \"\"\"\n",
    "        Legacy path: tokenize \"Generate headline: <body of Head>\" and run the T5 encoder.\n",
    "\n",
    "        Padding positions are masked inside the encoder. Returns\n",
    "        (base_enc [1, L, d], doc_text, tail_id), plus the [1, L] attention mask when\n",
    "        return_mask=True.\n",
    "        \"\"\"\n",
    "        row = lookup_df.loc[Bpos]\n",
    "        head_id = row['Head']\n",
    "        tail_id = row['Tail']\n",
    "        doc_text = nid2body.get(head_id, \"\")\n",
    "        query_text = \"Generate headline: \" + doc_text\n",
    "        # legacy tokenization (slow) -- use only if pre-tokenized inputs are not given\n",
    "        enc = tokenizer(\n",
    "            query_text,\n",
    "            return_tensors=\"pt\",\n",
    "            truncation=True,\n",
    "            max_length=max_len,\n",
    "            padding=\"max_length\"\n",
    "        )\n",
    "        query_tokens = enc.input_ids.to(self.device)\n",
    "        attn_mask = enc.attention_mask.to(self.device)\n",
    "        base_enc = self.t5.encoder(input_ids=query_tokens, attention_mask=attn_mask).last_hidden_state\n",
    "        if return_mask:\n",
    "            return base_enc, doc_text, tail_id, attn_mask\n",
    "        return base_enc, doc_text, tail_id\n",
    "\n",
    "    def _encode_query_with_ids(self, enc_input_ids, attention_mask=None):\n",
    "        \"\"\"Fast path: run the T5 encoder on pre-tokenized ids [B,L] (or [1,L]); returns [B, L, d].\"\"\"\n",
    "        enc_outs = self.t5.encoder(input_ids=enc_input_ids, attention_mask=attention_mask)\n",
    "        return enc_outs.last_hidden_state\n",
    "\n",
    "    def _lm_logits(self, hidden):\n",
    "        \"\"\"\n",
    "        Project decoder states to vocabulary logits like T5ForConditionalGeneration.\n",
    "\n",
    "        With tied input/output embeddings, HF T5 rescales decoder outputs by\n",
    "        d_model ** -0.5 before lm_head; skipping it inflates logits by sqrt(d_model).\n",
    "        \"\"\"\n",
    "        if getattr(self.t5.config, \"tie_word_embeddings\", True):\n",
    "            hidden = hidden * (self.t5_hidden_dim ** -0.5)\n",
    "        return self.t5.lm_head(hidden)\n",
    "\n",
    "    def _inject_context(self, base_enc, e_last, s_hat, enc_mask=None):\n",
    "        \"\"\"\n",
    "        Personalize encoder states with the user context.\n",
    "\n",
    "        base_enc: [B, L, d_t5] encoder states; e_last and s_hat: [h] or [B, h].\n",
    "        enc_mask: optional [B, L] (or [1, L]) mask, 1 = real token; padded positions\n",
    "                  receive no share of the injected context.\n",
    "        Returns [B, L, d_t5]; falls back to the gated states if the result is not finite.\n",
    "        \"\"\"\n",
    "        # gate by user embedding\n",
    "        e_last_batch = e_last.unsqueeze(0) if e_last.dim() == 1 else e_last\n",
    "        gate = torch.sigmoid(self.e_proj(e_last_batch) / 10.0)       # [B, d]\n",
    "        gate = gate.unsqueeze(1).expand(-1, base_enc.size(1), -1)     # [B, L, d]\n",
    "        gated_enc = base_enc * gate\n",
    "\n",
    "        # single-key attention from s_hat: the softmax runs over the L positions (dim=1),\n",
    "        # so V is spread across positions with weights that sum to 1\n",
    "        Q = self.q_proj(gated_enc)                                        # [B, L, d]\n",
    "        s_hat_batch = s_hat.unsqueeze(0) if s_hat.dim() == 1 else s_hat  # [B, h]\n",
    "        K = self.k_proj(s_hat_batch).unsqueeze(1)                         # [B, 1, d]\n",
    "        V = self.v_proj(s_hat_batch).unsqueeze(1)                         # [B, 1, d]\n",
    "\n",
    "        attn_scores = torch.matmul(Q / (self.t5_hidden_dim ** 0.5), K.transpose(-1, -2))  # [B, L, 1]\n",
    "        attn_scores = torch.clamp(attn_scores, -10, 10)\n",
    "        if enc_mask is not None:\n",
    "            # large finite negative (not -inf) so an all-padding row cannot produce NaN\n",
    "            attn_scores = attn_scores.masked_fill(enc_mask.unsqueeze(-1) == 0, -1e4)\n",
    "        attn_wts = F.softmax(attn_scores, dim=1)\n",
    "        ctx = torch.matmul(attn_wts, V)                                   # [B, L, d]\n",
    "\n",
    "        # cap the scale to avoid exploding context injection\n",
    "        ctx_scale = torch.clamp(F.softplus(self.ctx_scale_raw), max=10.0)\n",
    "        personalized_enc = gated_enc + ctx_scale * ctx\n",
    "\n",
    "        # safety: if NaNs arise, fall back to gated_enc\n",
    "        if not torch.isfinite(personalized_enc).all():\n",
    "            personalized_enc = gated_enc\n",
    "        return personalized_enc\n",
    "\n",
    "    def _greedy_decode(self, personalized_enc, enc_mask, tokenizer, max_steps):\n",
    "        \"\"\"\n",
    "        Greedy-decode the first batch element of personalized_enc.\n",
    "\n",
    "        Starts from the decoder start token, stops after max_steps new tokens or at eos,\n",
    "        and returns the decoded text with special tokens skipped. Runs without grad.\n",
    "        \"\"\"\n",
    "        start_id, eos_id = self._special_ids(self.t5.config, tokenizer)\n",
    "        decoder_input = torch.tensor([[start_id]], device=self.device)  # [1, 1]\n",
    "        with torch.no_grad():\n",
    "            for _ in range(max_steps):\n",
    "                dec_out = self.t5.decoder(\n",
    "                    input_ids=decoder_input,\n",
    "                    encoder_hidden_states=personalized_enc[:1],\n",
    "                    encoder_attention_mask=enc_mask[:1]\n",
    "                )\n",
    "                step_logits = self._lm_logits(dec_out.last_hidden_state)[:, -1, :]  # [1, V]\n",
    "                step_logits = torch.clamp(step_logits, -50, 50)\n",
    "                next_id = torch.argmax(step_logits, dim=-1, keepdim=True)          # [1, 1]\n",
    "                decoder_input = torch.cat([decoder_input, next_id], dim=1)\n",
    "                if eos_id is not None and next_id.item() == eos_id:\n",
    "                    break\n",
    "        return tokenizer.decode(decoder_input[0, 1:], skip_special_tokens=True)\n",
    "\n",
    "    # ---------- public API ----------\n",
    "    def forward(\n",
    "        self,\n",
    "        Bhist,\n",
    "        Bhist_len,\n",
    "        Bpos,\n",
    "        lookup_df,\n",
    "        embed_tables,\n",
    "        sid2sum,\n",
    "        nid2body,\n",
    "        tokenizer,\n",
    "        max_len=50,\n",
    "        mode=None,\n",
    "        gen_kwargs=None,\n",
    "        # fast path (pre-tokenized inputs)\n",
    "        enc_input_ids=None,   # [B,L] LongTensor on device\n",
    "        dec_input_ids=None,   # [B,T] LongTensor on device: start token + gold[:, :-1] (see _build_decoder_io)\n",
    "        dec_labels=None,      # [B,T] LongTensor on device: the gold ids\n",
    "        # gated preview to avoid step-time explosion\n",
    "        preview_stride: int = 0,\n",
    "        global_step: int | None = None,\n",
    "    ):\n",
    "        \"\"\"\n",
    "        One personalized summarization step.\n",
    "\n",
    "        If enc_input_ids/dec_input_ids/dec_labels are provided, the model uses them directly\n",
    "        (encoder padding = positions equal to tokenizer.pad_token_id). Otherwise it\n",
    "        tokenizes internally (legacy behavior, slower).\n",
    "\n",
    "        mode: \"train\"/\"tf\"/\"teacher\"/\"teacher_forcing\" -> teacher forcing;\n",
    "              \"eval\"/\"test\"/\"ar\"/\"greedy\"/\"beam\" -> greedy decoding of the first batch\n",
    "              element (\"beam\" is greedy too). Default: \"train\" if self.training else \"ar\".\n",
    "        gen_kwargs: accepted for API compatibility; currently unused.\n",
    "        Raises ValueError for an unknown mode.\n",
    "        \"\"\"\n",
    "        # Resolve the mode first so a bad value fails before any expensive work\n",
    "        if mode is None:\n",
    "            mode = \"train\" if self.training else \"ar\"\n",
    "        mode = mode.lower()\n",
    "        if mode not in self._TF_MODES + self._AR_MODES:\n",
    "            raise ValueError(\n",
    "                f\"Unknown mode {mode!r}; expected one of {self._TF_MODES + self._AR_MODES}\"\n",
    "            )\n",
    "\n",
    "        # Make sure decoder is in train() if we're training\n",
    "        if self.training and hasattr(self.t5, \"decoder\"):\n",
    "            self.t5.decoder.train()\n",
    "\n",
    "        # Behavior/user context\n",
    "        e_last, s_hat, enc_loss = self._behavior_encode(\n",
    "            Bhist, Bhist_len, Bpos, lookup_df, embed_tables\n",
    "        )\n",
    "\n",
    "        pad_id = getattr(tokenizer, \"pad_token_id\", None)\n",
    "\n",
    "        # Encoder path\n",
    "        doc_text = \"\"\n",
    "        if enc_input_ids is not None:\n",
    "            # batched fast path; doc_text stays blank (pass it from the caller if you log it)\n",
    "            enc_attn_mask = None if pad_id is None else (enc_input_ids != pad_id).long()\n",
    "            base_enc = self._encode_query_with_ids(enc_input_ids, attention_mask=enc_attn_mask)  # [B,L,d]\n",
    "            tail_id = lookup_df.loc[Bpos]['Tail']\n",
    "        else:\n",
    "            # legacy path (tokenize per-sample)\n",
    "            base_enc, doc_text, tail_id, enc_attn_mask = self._encode_query(\n",
    "                Bpos, lookup_df, nid2body, tokenizer, max_len, return_mask=True\n",
    "            )\n",
    "        gold_summary_text = sid2sum.get(tail_id, \"\")\n",
    "\n",
    "        # Personalize. The T5 encoder returns [B,L,d]; a bare [L,d] tensor gets a batch dim.\n",
    "        if base_enc.dim() == 2:\n",
    "            base_enc = base_enc.unsqueeze(0)\n",
    "        if enc_attn_mask is None:\n",
    "            enc_attn_mask = torch.ones(base_enc.size()[:2], device=self.device, dtype=torch.long)\n",
    "        elif enc_attn_mask.dim() == 1:\n",
    "            enc_attn_mask = enc_attn_mask.unsqueeze(0)\n",
    "        personalized_enc = self._inject_context(base_enc, e_last, s_hat, enc_mask=enc_attn_mask)\n",
    "        # a batched user context can widen a single-document encoding; match the mask\n",
    "        if enc_attn_mask.size(0) != personalized_enc.size(0):\n",
    "            enc_attn_mask = enc_attn_mask.expand(personalized_enc.size(0), -1)\n",
    "\n",
    "        # -------------------------------\n",
    "        # Teacher Forcing (training)\n",
    "        # -------------------------------\n",
    "        if mode in self._TF_MODES:\n",
    "            if enc_input_ids is None:\n",
    "                # Guard only in legacy tokenization path (missing, non-string or blank gold)\n",
    "                if not isinstance(gold_summary_text, str) or gold_summary_text.strip() == \"\":\n",
    "                    return {\n",
    "                        \"skip\": True,\n",
    "                        \"reason\": \"empty_gold\",\n",
    "                        \"total_loss\": torch.tensor(0.0, device=self.device),\n",
    "                        \"enc_loss\": enc_loss,\n",
    "                        \"gen_loss\": torch.tensor(0.0, device=self.device),\n",
    "                        \"lm_logits\": None,\n",
    "                        \"doc_text\": doc_text,\n",
    "                        \"gold_summary\": \"\",\n",
    "                        \"pred_summary_tf\": \"\",\n",
    "                    }\n",
    "\n",
    "                gold_tokens = tokenizer(\n",
    "                    gold_summary_text,\n",
    "                    return_tensors=\"pt\",\n",
    "                    truncation=True,\n",
    "                    max_length=max_len,\n",
    "                    padding=\"max_length\"\n",
    "                ).input_ids.to(self.device)  # [1, L]\n",
    "                start_id, _ = self._special_ids(self.t5.config, tokenizer)\n",
    "                dec_input_ids, dec_labels = self._build_decoder_io(gold_tokens, start_id)  # [1, L] each\n",
    "            else:\n",
    "                # fast path: assume tensors already provided\n",
    "                assert dec_input_ids is not None and dec_labels is not None, (\n",
    "                    \"When using enc_input_ids, you must also pass dec_input_ids and dec_labels\"\n",
    "                )\n",
    "\n",
    "            if pad_id is None:\n",
    "                ignore_index = -100\n",
    "                dec_attn_mask = torch.ones_like(dec_input_ids)\n",
    "            else:\n",
    "                ignore_index = pad_id\n",
    "                dec_attn_mask = (dec_input_ids != pad_id).long()  # [B, T]\n",
    "                # T5's decoder start token equals the pad id; keep position 0 visible\n",
    "                if dec_attn_mask.size(1) > 0:\n",
    "                    dec_attn_mask[:, 0] = 1\n",
    "\n",
    "            # Decoder forward\n",
    "            outputs = self.t5.decoder(\n",
    "                input_ids=dec_input_ids,\n",
    "                attention_mask=dec_attn_mask,\n",
    "                encoder_hidden_states=personalized_enc,\n",
    "                encoder_attention_mask=enc_attn_mask\n",
    "            )\n",
    "            lm_logits = self._lm_logits(outputs.last_hidden_state)  # [B, T, V]\n",
    "\n",
    "            # Safety (no hard clamp so grads flow)\n",
    "            if not torch.isfinite(lm_logits).all():\n",
    "                lm_logits = torch.nan_to_num(lm_logits, nan=0.0, posinf=1e4, neginf=-1e4)\n",
    "\n",
    "            gen_loss = F.cross_entropy(\n",
    "                lm_logits.reshape(-1, lm_logits.size(-1)),\n",
    "                dec_labels.reshape(-1),\n",
    "                ignore_index=ignore_index,\n",
    "            )\n",
    "\n",
    "            # Weighted sum; enc_loss is NOT detached, so it also trains the behavior encoder\n",
    "            total_loss = 0.4 * gen_loss + 0.6 * enc_loss\n",
    "\n",
    "            # Optional: cheap greedy preview ONLY every preview_stride steps\n",
    "            pred_summary_tf = \"\"\n",
    "            if preview_stride and global_step is not None and global_step % preview_stride == 0:\n",
    "                pred_summary_tf = self._greedy_decode(\n",
    "                    personalized_enc, enc_attn_mask, tokenizer, max(1, dec_labels.size(1))\n",
    "                )\n",
    "\n",
    "            return {\n",
    "                \"skip\": False,\n",
    "                \"total_loss\": total_loss,\n",
    "                \"enc_loss\": enc_loss,\n",
    "                \"gen_loss\": gen_loss,\n",
    "                \"lm_logits\": lm_logits,\n",
    "                \"doc_text\": doc_text,\n",
    "                \"gold_summary\": gold_summary_text,\n",
    "                \"pred_summary_tf\": pred_summary_tf,\n",
    "            }\n",
    "\n",
    "        # -------------------------------\n",
    "        # Autoregressive (evaluation): mode is one of _AR_MODES here\n",
    "        # -------------------------------\n",
    "        # generate up to max_len-1 new tokens for the first batch element\n",
    "        pred_summary_ar = self._greedy_decode(personalized_enc, enc_attn_mask, tokenizer, max_len - 1)\n",
    "\n",
    "        return {\n",
    "            \"pred_summary_ar\": pred_summary_ar,\n",
    "            \"enc_loss\": enc_loss,\n",
    "            \"doc_text\": doc_text,\n",
    "            \"gold_summary\": gold_summary_text,\n",
    "        }\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "1d9eb67c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🔎 Using checkpoint: checkpoints/epoch_5.pt\n",
      "   epoch=5 | avg_total=2.014425294696192 | enc=2.9589636535152124 | gen=0.5976175675267787\n",
      "ℹ️ load_state: missing=0 unexpected=0\n",
      "🔹 Sampled batch size: 10\n",
      "\n",
      "✅ Teacher-Forcing (TRAIN) Results\n",
      "\n",
      "--- Sample 0 ---\n",
      "Doc:  A tabloid claims Jennifer Lopez won't marry Alex Rodriguez unless he signs a strict prenup with a cheating clause, and it's causing a major fight between the couple. The story is completely untrue. Go\n",
      "Gold: Early retirement can be more than just a daydream for those long Tuesday afternoons at work.\n",
      "TF Pred: The prince and his wife post  topically: How they came together\n",
      "Loss=3.0175 | Enc=3.4529 | Gen=2.3643\n",
      "Gold IDs: [8840, 6576, 54, 36, 72, 145, 131, 3, 9, 239, 26066, 21, 273, 307, 2818, 3742, 7, 44, 161, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [3957, 16, 36, 1137, 145, 3, 3, 9, 710, 26066, 21, 128, 113, 307, 1379, 7, 5, 161, 5, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 3, 3, 1, 1, 3, 3, 1, 1, 3, 1, 3, 1, 1, 3, 3]\n",
      "Gold decoded: Early retirement can be more than just a daydream for those long Tuesday afternoons at work.\n",
      "Pred decoded: signs in be cause than  a shortdream for some who long mornings. work.\n",
      "\n",
      "--- Sample 1 ---\n",
      "Doc:  Photos of milestones in the relationship between the United States and Iran. Supertanker Grace 1 off the coast of Gibraltar on July 6. Iran demanded on July 5 that Britain immediately release the tank\n",
      "Gold: The family claims they weren't aware their daughter had taken the doll until they were in their car and on their way to a nearby apartment complex where their babysitter lived.\n",
      "TF Pred: In the United States, you don'\n",
      "Loss=1.6134 | Enc=2.5167 | Gen=0.2586\n",
      "Gold IDs: [37, 384, 3213, 79, 9355, 31, 17, 2718, 70, 3062, 141, 1026, 8, 14295, 552, 79, 130, 16, 70, 443, 11, 30, 70, 194, 12, 3, 9, 4676, 4579, 1561, 213, 70, 1871, 27734, 4114, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [4279, 3213, 79, 9355, 31, 17, 2718, 70, 3062, 141, 1026, 8, 14295, 552, 79, 130, 16, 70, 443, 11, 30, 70, 194, 12, 3, 9, 4676, 4579, 1561, 213, 70, 1871, 27734, 4114, 5, 1, 1, 1, 3, 3, 312, 970, 1, 1, 1, 1, 17265, 1, 1]\n",
      "Gold decoded: The family claims they weren't aware their daughter had taken the doll until they were in their car and on their way to a nearby apartment complex where their babysitter lived.\n",
      "Pred decoded: shoot claims they weren't aware their daughter had taken the doll until they were in their car and on their way to a nearby apartment complex where their babysitter lived.   Le Duheir\n",
      "\n",
      "--- Sample 2 ---\n",
      "Doc:  Video by CNN WASHINGTON, June 27 (Reuters) - In a major blow to election reformers, the U.S. Supreme Court on Thursday rejected efforts to rein in the contentious practice of manipulating electoral di\n",
      "Gold: \"These three guys were threatening each other for a long time,\" Keith Castro said.\n",
      "TF Pred: \n",
      "Loss=1.7942 | Enc=2.9267 | Gen=0.0955\n",
      "Gold IDs: [96, 20347, 386, 3413, 130, 3, 14390, 284, 119, 21, 3, 9, 307, 97, 976, 17017, 28728, 243, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [1326, 386, 3413, 130, 3, 14390, 284, 119, 21, 3, 9, 307, 97, 976, 17017, 28728, 243, 5, 1, 1, 3, 3, 3, 32099, 3, 3, 3, 32099, 32099, 32099, 0, 32099, 32099, 3, 32099, 32099, 32099, 3, 32099, 32099, 3, 32099, 3, 32099, 32099, 32099, 3, 32099, 32099]\n",
      "Gold decoded: \"These three guys were threatening each other for a long time,\" Keith Castro said.\n",
      "Pred decoded: We three guys were threatening each other for a long time,\" Keith Castro said.\n",
      "\n",
      "--- Sample 3 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: \"The heart is struggling and straining to deliver the oxygen to your body.\"\n",
      "TF Pred: \n",
      "Loss=2.8334 | Enc=4.5369 | Gen=0.2781\n",
      "Gold IDs: [96, 634, 842, 19, 8335, 11, 6035, 53, 12, 2156, 8, 11035, 12, 39, 643, 535, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [196, 842, 19, 8335, 11, 6035, 53, 12, 2156, 8, 11035, 12, 39, 643, 535, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 270, 3, 3, 3, 3, 3]\n",
      "Gold decoded: \"The heart is struggling and straining to deliver the oxygen to your body.\"\n",
      "Pred decoded: I heart is struggling and straining to deliver the oxygen to your body.\"                           here\n",
      "\n",
      "--- Sample 4 ---\n",
      "Doc:  The second night of the first 2020 Democratic presidential debate continued with 10 candidates , including three of the top polling leaders in former Vice President Joe Biden, Sen. Bernie Sanders of V\n",
      "Gold: June 26, 2019 The 60-year-old pop star has not publicly responded to Carter's sentiment.\n",
      "TF Pred: The second night of the first 2020 Democratic presidential debate continued with 10 candidates , including three of the top polling leaders in former Vice President Joe Biden, Sen. Bernie Sanders of V\n",
      "Loss=1.7272 | Enc=2.8479 | Gen=0.0460\n",
      "Gold IDs: [1515, 13597, 1360, 37, 1640, 18, 1201, 18, 1490, 2783, 2213, 65, 59, 11652, 11318, 12, 17080, 31, 7, 6493, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [16047, 1360, 37, 1640, 18, 1201, 18, 1490, 2783, 2213, 65, 59, 11652, 11318, 12, 17080, 31, 7, 6493, 5, 1, 1, 325, 1, 1, 1, 325, 32099, 96, 32099, 96, 1, 96, 32099, 96, 96, 96, 32099, 32099, 1798, 96, 32099, 96, 32099, 96, 96, 96, 32099, 1]\n",
      "Gold decoded: June 26, 2019 The 60-year-old pop star has not publicly responded to Carter's sentiment.\n",
      "Pred decoded: 20, 2019 The 60-year-old pop star has not publicly responded to Carter's sentiment. La La \" \" \" \" \" \" former \" \" \" \" \"\n",
      "\n",
      "--- Sample 5 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: Dave W on TripAdvisor MAINE: THE WHITE BARN INN RESTAURANT Kennebunk Don't let the barn fool you: This restaurant has an AAA Five Diamond Award and five stars from Forbes.\n",
      "TF Pred: \n",
      "Loss=1.5858 | Enc=2.5909 | Gen=0.0781\n",
      "Gold IDs: [8545, 549, 30, 16993, 188, 26, 24680, 4800, 9730, 10, 1853, 3, 15313, 14871, 272, 24947, 27, 17235, 3, 12200, 3221, 5905, 9156, 16267, 115, 6513, 1008, 31, 17, 752, 8, 13754, 18554, 25, 10, 100, 2062, 65, 46, 22656, 9528, 10834, 3677, 11, 874, 4811, 45, 24852, 5, 1]\n",
      "Pred IDs: [549, 30, 16993, 188, 26, 24680, 4800, 9730, 10, 1853, 3, 15313, 14871, 272, 24947, 27, 17235, 3, 12200, 3221, 5905, 9156, 16267, 115, 6513, 1008, 31, 17, 752, 8, 13754, 18554, 25, 10, 100, 2062, 65, 46, 1237, 9528, 10834, 3677, 11, 874, 4811, 45, 24852, 5, 1]\n",
      "Gold decoded: Dave W on TripAdvisor MAINE: THE WHITE BARN INN RESTAURANT Kennebunk Don't let the barn fool you: This restaurant has an AAA Five Diamond Award and five stars from Forbes.\n",
      "Pred decoded: W on TripAdvisor MAINE: THE WHITE BARN INN RESTAURANT Kennebunk Don't let the barn fool you: This restaurant has an amazing Five Diamond Award and five stars from Forbes.\n",
      "\n",
      "--- Sample 6 ---\n",
      "Doc:  SANTA MARIA, Calif. (AP)   A third body has been found in the rubble of a burned-out mobile home in California, bringing to five the number of dead in a shooting and fire that began during an argument\n",
      "Gold: Click through the gallery above to find out when you should just say no.\n",
      "TF Pred: \n",
      "Loss=1.9857 | Enc=3.2929 | Gen=0.0248\n",
      "Gold IDs: [2001, 190, 8, 4865, 756, 12, 253, 91, 116, 25, 225, 131, 497, 150, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [190, 8, 4865, 756, 12, 253, 91, 116, 25, 225, 131, 497, 150, 5, 1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 3, 3, 3, 3, 5, 3, 3, 3, 3, 3, 3]\n",
      "Gold decoded: Click through the gallery above to find out when you should just say no.\n",
      "Pred decoded: through the gallery above to find out when you should just say no.                         .\n",
      "\n",
      "--- Sample 7 ---\n",
      "Doc:  What's an award show without a host? Well, as we learned at the most recent Oscars, sometimes it can be all right. However, a great host makes watching a bunch of famous people get awards that much mo\n",
      "Gold: She made a big deal about changing her dress a whole bunch of times during the night.\n",
      "TF Pred: \n",
      "Loss=2.3542 | Enc=3.3109 | Gen=0.9193\n",
      "Gold IDs: [451, 263, 3, 9, 600, 1154, 81, 2839, 160, 3270, 3, 9, 829, 7292, 13, 648, 383, 8, 706, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [47, 3, 9, 600, 1154, 81, 2839, 160, 3270, 3, 9, 7292, 7292, 13, 151, 383, 8, 706, 5, 1, 1, 3, 3, 3, 3, 1, 3, 5, 3, 3, 1, 3, 3, 3, 3, 3, 3, 3, 3, 22841, 3, 3, 3, 0, 0, 0, 3, 3, 3]\n",
      "Gold decoded: She made a big deal about changing her dress a whole bunch of times during the night.\n",
      "Pred decoded: was a big deal about changing her dress a bunch bunch of people during the night.     .          kleid\n",
      "\n",
      "--- Sample 8 ---\n",
      "Doc:  E. Jean Carroll tore through the doors of the Fifth Avenue entrance of Bergdorf Goodman, her heart racing. Ms. Carroll, a journalist and the host of the \"Ask E. Jean\" television show at the time, had \n",
      "Gold: \"I actually thought this was going to be a 'Travels With Charley,'\" Ms. Carroll said in an interview.\n",
      "TF Pred: \n",
      "Loss=2.6578 | Enc=3.0533 | Gen=2.0646\n",
      "Gold IDs: [96, 196, 700, 816, 48, 47, 352, 12, 36, 3, 9, 3, 31, 9402, 4911, 7, 438, 7435, 1306, 6, 31, 121, 283, 7, 5, 26508, 243, 16, 46, 2772, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [20347, 17, 141, 6, 47, 8, 12, 36, 17, 9, 600, 31, 2394, 4911, 7, 31, 20655, 21538, 31, 31, 121, 17839, 7, 5, 26508, 1219, 16, 3, 2772, 5, 1, 1, 283, 3, 283, 1, 283, 283, 283, 283, 1, 1, 283, 1, 3, 1, 1, 1, 1]\n",
      "Gold decoded: \"I actually thought this was going to be a 'Travels With Charley,'\" Ms. Carroll said in an interview.\n",
      "Pred decoded: Theset had, was the to beta big'90vels' Sandylotte''\" Kirks. Carroll told in  interview. M  M M M M M M\n",
      "\n",
      "--- Sample 9 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: Don't want to spend a fortune?\n",
      "TF Pred: NBC is launching a new look at why the NBA is bringing back its \"Bitch Better Goats\n",
      "Loss=1.7216 | Enc=2.8572 | Gen=0.0182\n",
      "Gold IDs: [1008, 31, 17, 241, 12, 1492, 3, 9, 13462, 58, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [31, 17, 241, 12, 1492, 3, 9, 13462, 58, 1, 1, 3, 3, 3, 3, 2160, 2935, 2935, 3, 3, 3, 2935, 1, 2935, 1152, 3, 1, 11287, 1152, 1, 3, 3, 12, 3, 3, 3, 1152, 3, 1152, 3, 2935, 2935, 2935, 2935, 3, 1, 3, 3, 25687]\n",
      "Gold decoded: Don't want to spend a fortune?\n",
      "Pred decoded: 't want to spend a fortune?    bricracra   cracraber  Worber   to   ber ber cracracracra   travel\n",
      "\n",
      "🧭 Encoder Next-Step Prediction (Top-10) & Metrics\n",
      "\n",
      "— Sample 0 — target_pos=102 | rank=17 | AUC=0.958 | MRR=0.059 | nDCG@5=0.000\n",
      "     # 1: pos=107  p=0.0969\n",
      "     # 2: pos=104  p=0.0828\n",
      "     # 3: pos=105  p=0.0794\n",
      "     # 4: pos=106  p=0.0630\n",
      "     # 5: pos=108  p=0.0476\n",
      "     # 6: pos=103  p=0.0445\n",
      "     # 7: pos= 97  p=0.0440\n",
      "     # 8: pos=114  p=0.0435\n",
      "     # 9: pos=101  p=0.0423\n",
      "     #10: pos= 99  p=0.0420\n",
      "\n",
      "— Sample 1 — target_pos=144 | rank=2 | AUC=0.997 | MRR=0.500 | nDCG@5=0.631\n",
      "     # 1: pos=149  p=0.1879\n",
      "  ✅ # 2: pos=144  p=0.1023\n",
      "     # 3: pos=141  p=0.0943\n",
      "     # 4: pos=146  p=0.0788\n",
      "     # 5: pos=148  p=0.0595\n",
      "     # 6: pos=142  p=0.0570\n",
      "     # 7: pos=147  p=0.0546\n",
      "     # 8: pos=143  p=0.0440\n",
      "     # 9: pos=138  p=0.0438\n",
      "     #10: pos=136  p=0.0427\n",
      "\n",
      "— Sample 2 — target_pos=70 | rank=9 | AUC=0.979 | MRR=0.111 | nDCG@5=0.000\n",
      "     # 1: pos= 76  p=0.1182\n",
      "     # 2: pos= 75  p=0.0906\n",
      "     # 3: pos= 78  p=0.0886\n",
      "     # 4: pos= 74  p=0.0797\n",
      "     # 5: pos= 73  p=0.0791\n",
      "     # 6: pos= 71  p=0.0747\n",
      "     # 7: pos= 72  p=0.0654\n",
      "     # 8: pos= 69  p=0.0542\n",
      "  ✅ # 9: pos= 70  p=0.0541\n",
      "     #10: pos= 77  p=0.0477\n",
      "\n",
      "— Sample 3 — target_pos=102 | rank=32 | AUC=0.919 | MRR=0.031 | nDCG@5=0.000\n",
      "     # 1: pos=113  p=0.0791\n",
      "     # 2: pos=112  p=0.0708\n",
      "     # 3: pos=107  p=0.0623\n",
      "     # 4: pos=114  p=0.0609\n",
      "     # 5: pos=111  p=0.0549\n",
      "     # 6: pos=123  p=0.0506\n",
      "     # 7: pos=118  p=0.0338\n",
      "     # 8: pos=116  p=0.0263\n",
      "     # 9: pos=128  p=0.0257\n",
      "     #10: pos=103  p=0.0254\n",
      "\n",
      "— Sample 4 — target_pos=24 | rank=7 | AUC=0.984 | MRR=0.143 | nDCG@5=0.000\n",
      "     # 1: pos= 27  p=0.2245\n",
      "     # 2: pos= 22  p=0.1877\n",
      "     # 3: pos= 26  p=0.1371\n",
      "     # 4: pos= 23  p=0.1347\n",
      "     # 5: pos= 25  p=0.0891\n",
      "     # 6: pos= 21  p=0.0758\n",
      "  ✅ # 7: pos= 24  p=0.0476\n",
      "     # 8: pos= 28  p=0.0228\n",
      "     # 9: pos= 20  p=0.0164\n",
      "     #10: pos= 29  p=0.0133\n",
      "\n",
      "— Sample 5 — target_pos=107 | rank=3 | AUC=0.995 | MRR=0.333 | nDCG@5=0.500\n",
      "     # 1: pos=114  p=0.1101\n",
      "     # 2: pos=113  p=0.0906\n",
      "  ✅ # 3: pos=107  p=0.0869\n",
      "     # 4: pos=111  p=0.0646\n",
      "     # 5: pos=106  p=0.0527\n",
      "     # 6: pos=112  p=0.0519\n",
      "     # 7: pos=109  p=0.0473\n",
      "     # 8: pos=104  p=0.0471\n",
      "     # 9: pos=115  p=0.0458\n",
      "     #10: pos=118  p=0.0396\n",
      "\n",
      "— Sample 6 — target_pos=68 | rank=12 | AUC=0.971 | MRR=0.083 | nDCG@5=0.000\n",
      "     # 1: pos= 59  p=0.1094\n",
      "     # 2: pos= 61  p=0.0973\n",
      "     # 3: pos= 62  p=0.0890\n",
      "     # 4: pos= 66  p=0.0858\n",
      "     # 5: pos= 63  p=0.0856\n",
      "     # 6: pos= 60  p=0.0732\n",
      "     # 7: pos= 57  p=0.0555\n",
      "     # 8: pos= 65  p=0.0470\n",
      "     # 9: pos= 67  p=0.0402\n",
      "     #10: pos= 69  p=0.0372\n",
      "\n",
      "— Sample 7 — target_pos=138 | rank=13 | AUC=0.969 | MRR=0.077 | nDCG@5=0.000\n",
      "     # 1: pos=144  p=0.1098\n",
      "     # 2: pos=149  p=0.0964\n",
      "     # 3: pos=141  p=0.0878\n",
      "     # 4: pos=142  p=0.0565\n",
      "     # 5: pos=136  p=0.0557\n",
      "     # 6: pos=137  p=0.0550\n",
      "     # 7: pos=143  p=0.0439\n",
      "     # 8: pos=148  p=0.0430\n",
      "     # 9: pos=146  p=0.0398\n",
      "     #10: pos=130  p=0.0397\n",
      "\n",
      "— Sample 8 — target_pos=140 | rank=6 | AUC=0.987 | MRR=0.167 | nDCG@5=0.000\n",
      "     # 1: pos=149  p=0.0886\n",
      "     # 2: pos=144  p=0.0886\n",
      "     # 3: pos=141  p=0.0778\n",
      "     # 4: pos=137  p=0.0736\n",
      "     # 5: pos=146  p=0.0520\n",
      "  ✅ # 6: pos=140  p=0.0508\n",
      "     # 7: pos=138  p=0.0502\n",
      "     # 8: pos=139  p=0.0460\n",
      "     # 9: pos=142  p=0.0446\n",
      "     #10: pos=143  p=0.0418\n",
      "\n",
      "— Sample 9 — target_pos=62 | rank=6 | AUC=0.987 | MRR=0.167 | nDCG@5=0.000\n",
      "     # 1: pos= 66  p=0.0863\n",
      "     # 2: pos= 67  p=0.0858\n",
      "     # 3: pos= 69  p=0.0750\n",
      "     # 4: pos= 63  p=0.0666\n",
      "     # 5: pos= 71  p=0.0634\n",
      "  ✅ # 6: pos= 62  p=0.0586\n",
      "     # 7: pos= 70  p=0.0564\n",
      "     # 8: pos= 61  p=0.0475\n",
      "     # 9: pos= 59  p=0.0431\n",
      "     #10: pos= 73  p=0.0406\n",
      "\n",
      "📊 Batch Metrics (encoder next-step)\n",
      "  mean-rank = 10.70\n",
      "  AUC       = 0.975\n",
      "  MRR       = 0.167\n",
      "  nDCG@5    = 0.113\n",
      "\n",
      "✅ Autoregressive (EVAL) Results\n",
      "\n",
      "--- Sample 0 ---\n",
      "Doc:  A tabloid claims Jennifer Lopez won't marry Alex Rodriguez unless he signs a strict prenup with a cheating clause, and it's causing a major fight between the couple. The story is completely untrue. Go\n",
      "Gold: Early retirement can be more than just a daydream for those long Tuesday afternoons at work.\n",
      "AR Pred: Jennifer Lopez won't marry Alex Rodriguez unless he signs a strict prenup with a cheating clause, and it's causing a major fight between Jennifer Lopez and Alex Rodriguez\n",
      "EncLoss=3.4529\n",
      "Gold IDs: [8840, 6576, 54, 36, 72, 145, 131, 3, 9, 239, 26066, 21, 273, 307, 2818, 3742, 7, 44, 161, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 13560, 25692, 751, 31, 17, 20111, 5104, 27326, 3, 3227, 3, 88, 3957, 3, 9, 6926, 554, 29, 413, 28, 3, 9, 15009, 53, 14442, 6, 11, 34, 31, 7, 3, 5885, 3, 9, 779, 2870, 344, 13560, 25692, 11, 5104, 27326, 1]\n",
      "Gold decoded: Early retirement can be more than just a daydream for those long Tuesday afternoons at work.\n",
      "Pred decoded: Jennifer Lopez won't marry Alex Rodriguez unless he signs a strict prenup with a cheating clause, and it's causing a major fight between Jennifer Lopez and Alex Rodriguez\n",
      "\n",
      "--- Sample 1 ---\n",
      "Doc:  Photos of milestones in the relationship between the United States and Iran. Supertanker Grace 1 off the coast of Gibraltar on July 6. Iran demanded on July 5 that Britain immediately release the tank\n",
      "Gold: The family claims they weren't aware their daughter had taken the doll until they were in their car and on their way to a nearby apartment complex where their babysitter lived.\n",
      "AR Pred: The alleged gunman and others were taken to hospitals, Edwards said.\n",
      "EncLoss=2.5167\n",
      "Gold IDs: [37, 384, 3213, 79, 9355, 31, 17, 2718, 70, 3062, 141, 1026, 8, 14295, 552, 79, 130, 16, 70, 443, 11, 30, 70, 194, 12, 3, 9, 4676, 4579, 1561, 213, 70, 1871, 27734, 4114, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 37, 3, 12554, 4740, 348, 11, 717, 130, 1026, 12, 9612, 6, 8200, 7, 243, 5, 1]\n",
      "Gold decoded: The family claims they weren't aware their daughter had taken the doll until they were in their car and on their way to a nearby apartment complex where their babysitter lived.\n",
      "Pred decoded: The alleged gunman and others were taken to hospitals, Edwards said.\n",
      "\n",
      "--- Sample 2 ---\n",
      "Doc:  Video by CNN WASHINGTON, June 27 (Reuters) - In a major blow to election reformers, the U.S. Supreme Court on Thursday rejected efforts to rein in the contentious practice of manipulating electoral di\n",
      "Gold: \"These three guys were threatening each other for a long time,\" Keith Castro said.\n",
      "AR Pred: This was an unprovoked attack on a U.S. surveillance asset in international airspace.\"\n",
      "EncLoss=2.9267\n",
      "Gold IDs: [96, 20347, 386, 3413, 130, 3, 14390, 284, 119, 21, 3, 9, 307, 97, 976, 17017, 28728, 243, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 100, 47, 46, 73, 1409, 1621, 5100, 3211, 30, 3, 9, 412, 5, 134, 5, 12305, 7000, 16, 1038, 799, 6633, 535, 1]\n",
      "Gold decoded: \"These three guys were threatening each other for a long time,\" Keith Castro said.\n",
      "Pred decoded: This was an unprovoked attack on a U.S. surveillance asset in international airspace.\"\n",
      "\n",
      "--- Sample 3 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: \"The heart is struggling and straining to deliver the oxygen to your body.\"\n",
      "AR Pred: The Best Restraunt in Each of the 50 States\n",
      "EncLoss=4.5369\n",
      "Gold IDs: [96, 634, 842, 19, 8335, 11, 6035, 53, 12, 2156, 8, 11035, 12, 39, 643, 535, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 37, 1648, 419, 3109, 202, 17, 16, 1698, 13, 8, 943, 1323, 1]\n",
      "Gold decoded: \"The heart is struggling and straining to deliver the oxygen to your body.\"\n",
      "Pred decoded: The Best Restraunt in Each of the 50 States\n",
      "\n",
      "--- Sample 4 ---\n",
      "Doc:  The second night of the first 2020 Democratic presidential debate continued with 10 candidates , including three of the top polling leaders in former Vice President Joe Biden, Sen. Bernie Sanders of V\n",
      "Gold: June 26, 2019 The 60-year-old pop star has not publicly responded to Carter's sentiment.\n",
      "AR Pred: The debate was a chance for him to get a chance to talk about it.\n",
      "EncLoss=2.8479\n",
      "Gold IDs: [1515, 13597, 1360, 37, 1640, 18, 1201, 18, 1490, 2783, 2213, 65, 59, 11652, 11318, 12, 17080, 31, 7, 6493, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 37, 5054, 47, 3, 9, 1253, 21, 376, 12, 129, 3, 9, 1253, 12, 1350, 81, 34, 5, 1]\n",
      "Gold decoded: June 26, 2019 The 60-year-old pop star has not publicly responded to Carter's sentiment.\n",
      "Pred decoded: The debate was a chance for him to get a chance to talk about it.\n",
      "\n",
      "--- Sample 5 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: Dave W on TripAdvisor MAINE: THE WHITE BARN INN RESTAURANT Kennebunk Don't let the barn fool you: This restaurant has an AAA Five Diamond Award and five stars from Forbes.\n",
      "AR Pred: The Best Restraunt in Each of the 50 States\n",
      "EncLoss=2.5909\n",
      "Gold IDs: [8545, 549, 30, 16993, 188, 26, 24680, 4800, 9730, 10, 1853, 3, 15313, 14871, 272, 24947, 27, 17235, 3, 12200, 3221, 5905, 9156, 16267, 115, 6513, 1008, 31, 17, 752, 8, 13754, 18554, 25, 10, 100, 2062, 65, 46, 22656, 9528, 10834, 3677, 11, 874, 4811, 45, 24852, 5, 1]\n",
      "Pred IDs: [32099, 37, 1648, 419, 3109, 202, 17, 16, 1698, 13, 8, 943, 1323, 1]\n",
      "Gold decoded: Dave W on TripAdvisor MAINE: THE WHITE BARN INN RESTAURANT Kennebunk Don't let the barn fool you: This restaurant has an AAA Five Diamond Award and five stars from Forbes.\n",
      "Pred decoded: The Best Restraunt in Each of the 50 States\n",
      "\n",
      "--- Sample 6 ---\n",
      "Doc:  SANTA MARIA, Calif. (AP)   A third body has been found in the rubble of a burned-out mobile home in California, bringing to five the number of dead in a shooting and fire that began during an argument\n",
      "Gold: Click through the gallery above to find out when you should just say no.\n",
      "AR Pred: The shooter has not been identified.\n",
      "EncLoss=3.2929\n",
      "Gold IDs: [2001, 190, 8, 4865, 756, 12, 253, 91, 116, 25, 225, 131, 497, 150, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 37, 4279, 49, 65, 59, 118, 4313, 5, 1]\n",
      "Gold decoded: Click through the gallery above to find out when you should just say no.\n",
      "Pred decoded: The shooter has not been identified.\n",
      "\n",
      "--- Sample 7 ---\n",
      "Doc:  What's an award show without a host? Well, as we learned at the most recent Oscars, sometimes it can be all right. However, a great host makes watching a bunch of famous people get awards that much mo\n",
      "Gold: She made a big deal about changing her dress a whole bunch of times during the night.\n",
      "AR Pred: But if you have a lot of money, you can get a lot of good ideas.\n",
      "EncLoss=3.3109\n",
      "Gold IDs: [451, 263, 3, 9, 600, 1154, 81, 2839, 160, 3270, 3, 9, 829, 7292, 13, 648, 383, 8, 706, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 299, 3, 99, 25, 43, 3, 9, 418, 13, 540, 6, 25, 54, 129, 3, 9, 418, 13, 207, 912, 5, 1]\n",
      "Gold decoded: She made a big deal about changing her dress a whole bunch of times during the night.\n",
      "Pred decoded: But if you have a lot of money, you can get a lot of good ideas.\n",
      "\n",
      "--- Sample 8 ---\n",
      "Doc:  E. Jean Carroll tore through the doors of the Fifth Avenue entrance of Bergdorf Goodman, her heart racing. Ms. Carroll, a journalist and the host of the \"Ask E. Jean\" television show at the time, had \n",
      "Gold: \"I actually thought this was going to be a 'Travels With Charley,'\" Ms. Carroll said in an interview.\n",
      "AR Pred: Ms. Carroll said Ms. Carroll was laughing at first as she described an encounter she said she had just had in a Bergdorf's dressing room with Donald J. Trump that began as cheeky banter.\n",
      "EncLoss=3.0533\n",
      "Gold IDs: [96, 196, 700, 816, 48, 47, 352, 12, 36, 3, 9, 3, 31, 9402, 4911, 7, 438, 7435, 1306, 6, 31, 121, 283, 7, 5, 26508, 243, 16, 46, 2772, 5, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 283, 7, 5, 26508, 243, 283, 7, 5, 26508, 47, 20692, 44, 166, 38, 255, 3028, 46, 6326, 255, 243, 255, 141, 131, 141, 16, 3, 9, 5581, 8716, 31, 7, 9847, 562, 28, 7459, 446, 5, 2523, 24, 1553, 38, 18312, 63, 4514, 449, 5, 1]\n",
      "Gold decoded: \"I actually thought this was going to be a 'Travels With Charley,'\" Ms. Carroll said in an interview.\n",
      "Pred decoded: Ms. Carroll said Ms. Carroll was laughing at first as she described an encounter she said she had just had in a Bergdorf's dressing room with Donald J. Trump that began as cheeky banter.\n",
      "\n",
      "--- Sample 9 ---\n",
      "Doc:  Have a special occasion on the horizon? Or do you just feel like having an over-the-top meal that you'll remember forever? From a humble town in Indiana to tony venues in New York and Chicago, these 5\n",
      "Gold: Don't want to spend a fortune?\n",
      "AR Pred: The Best Restraunt in Each of the 50 States\n",
      "EncLoss=2.8572\n",
      "Gold IDs: [1008, 31, 17, 241, 12, 1492, 3, 9, 13462, 58, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n",
      "Pred IDs: [32099, 37, 1648, 419, 3109, 202, 17, 16, 1698, 13, 8, 943, 1323, 1]\n",
      "Gold decoded: Don't want to spend a fortune?\n",
      "Pred decoded: The Best Restraunt in Each of the 50 States\n"
     ]
    }
   ],
   "source": [
    "# =========================\n",
    "# Forward pass + Encoder Top-10 & Metrics (AUC/MRR/nDCG@5)\n",
    "# =========================\n",
    "import os, re, math, torch, numpy as np\n",
    "import torch.nn.functional as F\n",
    "from transformers.modeling_outputs import BaseModelOutput\n",
    "\n",
    "# ---------- helper: latest ckpt ----------\n",
    "def find_latest_ckpt(ckpt_dir=\"checkpoints\"):\n",
    "    paths = [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)\n",
    "             if f.startswith(\"epoch_\") and f.endswith(\".pt\")]\n",
    "    if not paths:\n",
    "        raise FileNotFoundError(f\"No checkpoints found in {ckpt_dir}\")\n",
    "    def _ep(p):\n",
    "        m = re.search(r\"epoch_(\\d+)\\.pt$\", os.path.basename(p))\n",
    "        return int(m.group(1)) if m else -1\n",
    "    paths.sort(key=lambda p: (_ep(p), os.path.getmtime(p)))\n",
    "    return paths[-1]\n",
    "\n",
    "# ---------- 0) Load checkpoint & build model ----------\n",
    "latest_ckpt = find_latest_ckpt(\"checkpoints\")\n",
    "ckpt = torch.load(latest_ckpt, map_location=device)\n",
    "print(f\"🔎 Using checkpoint: {latest_ckpt}\")\n",
    "print(f\"   epoch={ckpt.get('epoch')} | avg_total={ckpt.get('avg_total_loss')} | enc={ckpt.get('avg_enc_loss')} | gen={ckpt.get('avg_gen_loss')}\")\n",
    "\n",
    "behavior_encoder = BehaviorEncoder(hidden_dim, device=device, max_len=max_len).to(device)\n",
    "inverse_decoder  = BehaviorInverseDecoderPredict(hidden_dim, device=device).to(device)\n",
    "\n",
    "personalized_model = PersonalizedT5Summarizer(\n",
    "    hidden_dim, summarizer_model, behavior_encoder, inverse_decoder, device,\n",
    "    learnable_ctx=True, unfreeze_last_decoder_blocks=15\n",
    ").to(device)\n",
    "\n",
    "missing, unexpected = personalized_model.load_state_dict(ckpt[\"model_state\"], strict=False)\n",
    "print(f\"ℹ️ load_state: missing={len(missing)} unexpected={len(unexpected)}\")\n",
    "\n",
    "# ---------- 1) Sample 10 rows ----------\n",
    "batch_rows = train_df.sample(10, random_state=42)\n",
    "Bhist_batch = batch_rows[\"EHist_padded\"].tolist()\n",
    "Bhist_lens  = batch_rows[\"EHist_len\"].tolist()\n",
    "Bpos_batch  = batch_rows[\"EPos\"].tolist()\n",
    "print(f\"🔹 Sampled batch size: {len(Bhist_batch)}\")\n",
    "\n",
    "# ---------- 2) TRAIN pass (TF) + preview & token IDs (unchanged) ----------\n",
    "personalized_model.train()\n",
    "tf_total_losses, tf_enc_losses, tf_gen_losses = [], [], []\n",
    "tf_pred_texts, tf_doc_texts, tf_gold_texts = [], [], []\n",
    "tf_gold_ids, tf_pred_ids = [], []\n",
    "\n",
    "for i in range(len(Bhist_batch)):\n",
    "    Bhist     = Bhist_batch[i]\n",
    "    Bhist_len = int(Bhist_lens[i])\n",
    "    Bpos      = Bpos_batch[i]\n",
    "\n",
    "    out = personalized_model(\n",
    "        Bhist, Bhist_len, Bpos,\n",
    "        lookup_df, embed_tables, sid2sum, nid2body, tokenizer,\n",
    "        max_len=50, mode=\"train\",\n",
    "        preview_stride=1, global_step=i\n",
    "    )\n",
    "\n",
    "    tf_total_losses.append(out[\"total_loss\"].item())\n",
    "    tf_enc_losses.append(out[\"enc_loss\"].item())\n",
    "    tf_gen_losses.append(out[\"gen_loss\"].item())\n",
    "    tf_pred_texts.append(out[\"pred_summary_tf\"])\n",
    "    tf_doc_texts.append(out[\"doc_text\"])\n",
    "    tf_gold_texts.append(out[\"gold_summary\"])\n",
    "\n",
    "    if out[\"lm_logits\"] is not None:\n",
    "        pred_ids = torch.argmax(out[\"lm_logits\"], dim=-1)  # [B, T]\n",
    "        tf_pred_ids.append(pred_ids[0].tolist())\n",
    "        gold_tokens = tokenizer(\n",
    "            out[\"gold_summary\"],\n",
    "            return_tensors=\"pt\",\n",
    "            truncation=True,\n",
    "            max_length=50,\n",
    "            padding=\"max_length\"\n",
    "        ).input_ids.to(device)\n",
    "        tf_gold_ids.append(gold_tokens[0].tolist())\n",
    "    else:\n",
    "        tf_pred_ids.append([])\n",
    "        tf_gold_ids.append([])\n",
    "\n",
    "print(\"\\n✅ Teacher-Forcing (TRAIN) Results\")\n",
    "for i in range(len(Bhist_batch)):\n",
    "    print(f\"\\n--- Sample {i} ---\")\n",
    "    print(f\"Doc:  {tf_doc_texts[i][:200]}\")\n",
    "    print(f\"Gold: {tf_gold_texts[i][:200]}\")\n",
    "    print(f\"TF Pred: {tf_pred_texts[i][:200]}\")\n",
    "    print(f\"Loss={tf_total_losses[i]:.4f} | Enc={tf_enc_losses[i]:.4f} | Gen={tf_gen_losses[i]:.4f}\")\n",
    "    print(f\"Gold IDs: {tf_gold_ids[i]}\")\n",
    "    print(f\"Pred IDs: {tf_pred_ids[i]}\")\n",
    "    if tf_gold_ids[i]:\n",
    "        print(\"Gold decoded:\", tokenizer.decode(tf_gold_ids[i], skip_special_tokens=True))\n",
    "    if tf_pred_ids[i]:\n",
    "        print(\"Pred decoded:\", tokenizer.decode(tf_pred_ids[i], skip_special_tokens=True))\n",
    "\n",
    "# ---------- 2.5) BehaviorEncoder Top-10 & Metrics ----------\n",
    "print(\"\\n🧭 Encoder Next-Step Prediction (Top-10) & Metrics\")\n",
    "all_ranks, all_auc, all_mrr, all_ndcg5 = [], [], [], []\n",
    "\n",
    "def safe_softmax(x):\n",
    "    if not torch.isfinite(x).all():\n",
    "        x = torch.nan_to_num(x, nan=-1e4, posinf=1e4, neginf=-1e4)\n",
    "    return F.softmax(x, dim=-1)\n",
    "\n",
    "for i in range(len(Bhist_batch)):\n",
    "    Bhist     = Bhist_batch[i]\n",
    "    Bhist_len = int(Bhist_lens[i])   # ground-truth next position (index of next step)\n",
    "    Bpos      = Bpos_batch[i]\n",
    "\n",
    "    with torch.no_grad():\n",
    "        _, e_next, logits_pos, _, _ = behavior_encoder(\n",
    "            Bhist, Bhist_len, Bpos, lookup_df, embed_tables\n",
    "        )  # logits_pos: [1, C]\n",
    "        probs = safe_softmax(logits_pos.squeeze(0))  # [C]\n",
    "        C = probs.numel()\n",
    "\n",
    "        # top-10 predictions\n",
    "        K = min(10, C)\n",
    "        vals, idxs = torch.topk(probs, k=K)\n",
    "        vals = vals.detach().cpu().numpy()\n",
    "        idxs = idxs.detach().cpu().numpy()\n",
    "\n",
    "        # rank of the true target\n",
    "        sorted_all = torch.argsort(probs, descending=True).detach().cpu().numpy().tolist()\n",
    "        if Bhist_len < 0 or Bhist_len >= C:\n",
    "            rank = C  # invalid target -> worst\n",
    "        else:\n",
    "            rank = sorted_all.index(Bhist_len) + 1  # 1-based rank\n",
    "\n",
    "        # Metrics\n",
    "        auc_i = ((C - rank) / (C - 1)) if C > 1 else 1.0\n",
    "        mrr_i = 1.0 / rank\n",
    "        ndcg5_i = (1.0 / math.log2(1 + rank)) if rank <= 5 else 0.0\n",
    "\n",
    "        all_ranks.append(rank)\n",
    "        all_auc.append(auc_i)\n",
    "        all_mrr.append(mrr_i)\n",
    "        all_ndcg5.append(ndcg5_i)\n",
    "\n",
    "    # Pretty print\n",
    "    print(f\"\\n— Sample {i} — target_pos={Bhist_len} | rank={rank} | AUC={auc_i:.3f} | MRR={mrr_i:.3f} | nDCG@5={ndcg5_i:.3f}\")\n",
    "    for k in range(K):\n",
    "        marker = \"✅\" if idxs[k] == Bhist_len else \"  \"\n",
    "        print(f\"  {marker} #{k+1:>2}: pos={int(idxs[k]):>3}  p={vals[k]:.4f}\")\n",
    "\n",
    "def _mean(x):\n",
    "    return float(np.mean(x)) if len(x)>0 else float('nan')\n",
    "\n",
    "print(\"\\n📊 Batch Metrics (encoder next-step)\")\n",
    "print(f\"  mean-rank = {_mean(all_ranks):.2f}\")\n",
    "print(f\"  AUC       = {_mean(all_auc):.3f}\")\n",
    "print(f\"  MRR       = {_mean(all_mrr):.3f}\")\n",
    "print(f\"  nDCG@5    = {_mean(all_ndcg5):.3f}\")\n",
    "\n",
    "# ---------- 3) EVAL pass (AR) — simple greedy w/out KV cache ----------\n",
    "personalized_model.eval()\n",
    "ar_pred_texts, ar_doc_texts, ar_gold_texts, ar_enc_losses = [], [], [], []\n",
    "ar_gold_ids, ar_pred_ids = [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i in range(len(Bhist_batch)):\n",
    "        Bhist     = Bhist_batch[i]\n",
    "        Bhist_len = int(Bhist_lens[i])\n",
    "        Bpos      = Bpos_batch[i]\n",
    "\n",
    "        # behavior + encoder enc\n",
    "        e_last, s_hat, enc_loss = personalized_model._behavior_encode(\n",
    "            Bhist, Bhist_len, Bpos, lookup_df, embed_tables\n",
    "        )\n",
    "        base_enc, doc_text, tail_id = personalized_model._encode_query(\n",
    "            Bpos, lookup_df, nid2body, tokenizer, max_len=50\n",
    "        )\n",
    "        if base_enc.dim() == 2:\n",
    "            base_enc = base_enc.unsqueeze(0)\n",
    "\n",
    "        personalized_enc = personalized_model._inject_context(base_enc, e_last, s_hat)\n",
    "        if not torch.isfinite(personalized_enc).all():\n",
    "            personalized_enc = torch.nan_to_num(personalized_enc, nan=0.0, posinf=0.0, neginf=0.0)\n",
    "\n",
    "        # encoder attention mask derived from same tokens\n",
    "        query_tokens = tokenizer(\n",
    "            \"Generate headline: \" + doc_text,\n",
    "            return_tensors=\"pt\", truncation=True,\n",
    "            max_length=50, padding=\"max_length\"\n",
    "        ).input_ids.to(device)\n",
    "        enc_attn_mask = (query_tokens != tokenizer.pad_token_id).long()\n",
    "\n",
    "        # greedy decode (no cache; simplest + robust across HF versions)\n",
    "        pad_id = tokenizer.pad_token_id\n",
    "        eos_id = tokenizer.eos_token_id\n",
    "        start_id = personalized_model.t5.config.decoder_start_token_id or pad_id\n",
    "\n",
    "        dec = torch.tensor([[start_id]], device=device, dtype=torch.long)  # [1,1]\n",
    "        max_new, min_new = 49, 5\n",
    "        for _ in range(max_new):\n",
    "            dec_out = personalized_model.t5.decoder(\n",
    "                input_ids=dec,\n",
    "                encoder_hidden_states=personalized_enc,\n",
    "                encoder_attention_mask=enc_attn_mask,\n",
    "                use_cache=False,\n",
    "                return_dict=True,\n",
    "            )\n",
    "            logits = personalized_model.t5.lm_head(dec_out.last_hidden_state)  # [1,T,V]\n",
    "            next_logits = logits[:, -1, :]\n",
    "            # block PAD; delay EOS\n",
    "            if pad_id is not None:\n",
    "                next_logits[:, pad_id] = -1e9\n",
    "            if eos_id is not None and dec.size(1) - 1 < min_new:\n",
    "                next_logits[:, eos_id] = -1e9\n",
    "            next_id = torch.argmax(next_logits, dim=-1, keepdim=True)\n",
    "            dec = torch.cat([dec, next_id], dim=1)\n",
    "            if eos_id is not None and next_id.item() == eos_id:\n",
    "                break\n",
    "\n",
    "        pred_tokens = dec[0, 1:]\n",
    "        pred_summary_ar = tokenizer.decode(pred_tokens, skip_special_tokens=True)\n",
    "\n",
    "        # gold ids\n",
    "        gold_text = sid2sum.get(tail_id, \"\")\n",
    "        gold_ids = tokenizer(\n",
    "            gold_text,\n",
    "            return_tensors=\"pt\",\n",
    "            truncation=True,\n",
    "            max_length=50,\n",
    "            padding=\"max_length\"\n",
    "        ).input_ids[0]\n",
    "\n",
    "        ar_pred_texts.append(pred_summary_ar)\n",
    "        ar_doc_texts.append(doc_text)\n",
    "        ar_gold_texts.append(gold_text)\n",
    "        ar_enc_losses.append(enc_loss.item() if torch.is_tensor(enc_loss) else float(enc_loss))\n",
    "        ar_gold_ids.append(gold_ids.tolist())\n",
    "        ar_pred_ids.append(pred_tokens.tolist())\n",
    "\n",
    "print(\"\\n✅ Autoregressive (EVAL) Results\")\n",
    "for i in range(len(Bhist_batch)):\n",
    "    print(f\"\\n--- Sample {i} ---\")\n",
    "    print(f\"Doc:  {ar_doc_texts[i][:200]}\")\n",
    "    print(f\"Gold: {ar_gold_texts[i][:200]}\")\n",
    "    print(f\"AR Pred: {ar_pred_texts[i][:200]}\")\n",
    "    print(f\"EncLoss={ar_enc_losses[i]:.4f}\")\n",
    "    print(f\"Gold IDs: {ar_gold_ids[i]}\")\n",
    "    print(f\"Pred IDs: {ar_pred_ids[i]}\")\n",
    "    if ar_gold_ids[i]:\n",
    "        print(\"Gold decoded:\", tokenizer.decode(ar_gold_ids[i], skip_special_tokens=True))\n",
    "    if ar_pred_ids[i]:\n",
    "        print(\"Pred decoded:\", tokenizer.decode(ar_pred_ids[i], skip_special_tokens=True))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "aced447d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "📦 Total params: 232,068,503\n",
      "🟢 Trainable:   147,114,263\n",
      "🧊 Frozen:      84,954,240\n",
      "\n",
      "Top trainable parameter groups:\n",
      "  • t5.shared.weight                                             24,674,304\n",
      "  • t5.decoder.block.0                                           9,439,872\n",
      "  • t5.decoder.block.1                                           9,439,488\n",
      "  • t5.decoder.block.2                                           9,439,488\n",
      "  • t5.decoder.block.3                                           9,439,488\n",
      "  • t5.decoder.block.4                                           9,439,488\n",
      "  • t5.decoder.block.5                                           9,439,488\n",
      "  • t5.decoder.block.6                                           9,439,488\n",
      "  • t5.decoder.block.7                                           9,439,488\n",
      "  • t5.decoder.block.8                                           9,439,488\n",
      "  • t5.decoder.block.9                                           9,439,488\n",
      "  • t5.decoder.block.10                                          9,439,488\n",
      "  • t5.decoder.block.11                                          9,439,488\n",
      "  • behavior_encoder.W_s.weight                                  589,824\n",
      "  • behavior_encoder.W_d.weight                                  589,824\n",
      "  • behavior_encoder.Wh.weight                                   589,824\n",
      "  • behavior_encoder.Wc.weight                                   589,824\n",
      "  • behavior_encoder.W_emb.weight                                589,824\n",
      "  • behavior_encoder.W_angle.weight                              589,824\n",
      "  • behavior_encoder.W_h.weight                                  589,824\n",
      "  • behavior_encoder.W_next.weight                               589,824\n",
      "  • inverse_decoder.W_emb_pinv.weight                            589,824\n",
      "  • inverse_decoder.W1_pinv.weight                               589,824\n",
      "  • inverse_decoder.W2_pinv.weight                               589,824\n",
      "  • e_proj.weight                                                589,824\n",
      "  • q_proj.weight                                                589,824\n",
      "  • k_proj.weight                                                589,824\n",
      "  • v_proj.weight                                                589,824\n",
      "  • behavior_encoder.classifier.weight                           295,680\n",
      "  • behavior_encoder.W_clk.weight                                3,072\n",
      "  • behavior_encoder.W_skp.weight                                3,072\n",
      "  • behavior_encoder.W_gensumm.weight                            3,072\n",
      "  • behavior_encoder.W_sumgen.weight                             3,072\n",
      "  • behavior_encoder.Wz.weight                                   2,304\n",
      "  • t5.decoder.final_layer_norm.weight                           768\n",
      "  • behavior_encoder.b_emb                                       768\n",
      "  • behavior_encoder.W_pull.weight                               768\n",
      "  • behavior_encoder.W_emb.bias                                  768\n",
      "  • behavior_encoder.W_theta.weight                              768\n",
      "  • behavior_encoder.W_m.weight                                  768\n",
      "\n",
      "Quick spot-check on a few important params:\n",
      "  • t5.shared.weight                                             requires_grad=True\n",
      "  • t5.decoder.final_layer_norm.weight                           requires_grad=True\n"
     ]
    }
   ],
   "source": [
    "# ---- Param report helpers ----\n",
    "from collections import defaultdict\n",
    "import math\n",
    "\n",
    "def count_params(module, only_trainable=False):\n",
    "    return sum(p.numel() for p in module.parameters() if (p.requires_grad or not only_trainable))\n",
    "\n",
    "def param_table_by_prefix(model, topn=25):\n",
    "    buckets = defaultdict(int)\n",
    "    for n, p in model.named_parameters():\n",
    "        if not p.requires_grad:\n",
    "            continue\n",
    "        prefix = \".\".join(n.split(\".\")[:4])\n",
    "        buckets[prefix] += p.numel()\n",
    "    return sorted(buckets.items(), key=lambda x: x[1], reverse=True)[:topn]\n",
    "\n",
    "def print_param_report(model, topn=40):\n",
    "    \"\"\"Print total / trainable / frozen parameter counts and the largest trainable groups.\n",
    "\n",
    "    Args:\n",
    "        model: torch.nn.Module to summarize.\n",
    "        topn: number of prefix groups (from param_table_by_prefix) to list; default 40.\n",
    "    \"\"\"\n",
    "    # both counts go through the same helper so they are computed consistently\n",
    "    total = count_params(model, only_trainable=False)\n",
    "    trainable = count_params(model, only_trainable=True)\n",
    "    frozen = total - trainable\n",
    "    print(f\"📦 Total params: {total:,}\")\n",
    "    print(f\"🟢 Trainable:   {trainable:,}\")\n",
    "    print(f\"🧊 Frozen:      {frozen:,}\")\n",
    "\n",
    "    # show breakdown of where the trainable params live\n",
    "    rows = param_table_by_prefix(model, topn=topn)\n",
    "    print(\"\\nTop trainable parameter groups:\")\n",
    "    for k, v in rows:\n",
    "        print(f\"  • {k:60s} {v:,}\")\n",
    "\n",
    "print_param_report(personalized_model)\n",
    "\n",
    "# Sanity spot-check: parameters that must be trainable.\n",
    "# (How many trailing decoder blocks are unfrozen is set in the model's __init__;\n",
    "#  those blocks show up in the grouped report above.)\n",
    "must_be_trainable = [\n",
    "    \"t5.shared.weight\",\n",
    "    \"t5.lm_head.weight\",\n",
    "    \"t5.decoder.final_layer_norm.weight\",\n",
    "]\n",
    "print(\"\\nQuick spot-check on a few important params:\")\n",
    "named_params = dict(personalized_model.named_parameters())\n",
    "for key in must_be_trainable:\n",
    "    p = named_params.get(key)\n",
    "    if p is None:\n",
    "        # named_parameters() de-duplicates shared tensors, so a weight tied to another\n",
    "        # parameter is listed only under its first name (possibly why lm_head is absent).\n",
    "        print(f\"  • {key:60s} not found (tied/shared or absent)\")\n",
    "    else:\n",
    "        flag = \"\" if p.requires_grad else \"  ⚠️ expected trainable\"\n",
    "        print(f\"  • {key:60s} requires_grad={p.requires_grad}{flag}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "77e4e685",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ DataLoader ready: 38417 samples, batch_size=8\n"
     ]
    }
   ],
   "source": [
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, df):\n",
    "        # df must contain: EHist_padded, EHist_len, EPos\n",
    "        self.df = df.reset_index(drop=True)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.df)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        row = self.df.iloc[idx]\n",
    "        return {\n",
    "            \"Bhist\": row[\"EHist_padded\"],\n",
    "            \"Bhist_len\": int(row[\"EHist_len\"]),\n",
    "            \"Bpos\": row[\"EPos\"],\n",
    "        }\n",
    "\n",
    "def collate_simple(batch):\n",
    "    # keep as lists; model forward works per-sample\n",
    "    Bhist = [b[\"Bhist\"] for b in batch]\n",
    "    Blen  = [b[\"Bhist_len\"] for b in batch]\n",
    "    Bpos  = [b[\"Bpos\"] for b in batch]\n",
    "    return Bhist, Blen, Bpos\n",
    "\n",
    "# Dataset + loader. collate_simple keeps samples as Python lists because the\n",
    "# model forward is called once per sample.\n",
    "train_ds = NewsDataset(train_df)\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    collate_fn=collate_simple,\n",
    "    batch_size=8,  # tune as your GPU allows\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    "    pin_memory=False,\n",
    ")\n",
    "\n",
    "print(f\"✅ DataLoader ready: {len(train_ds)} samples, batch_size={train_loader.batch_size}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "95903ccc",
   "metadata": {},
   "outputs": [],
   "source": [
    "def grad_health_report(model, limit_print=30):\n",
    "    issues = {\"none\": [], \"all_zero\": [], \"nan_inf\": []}\n",
    "    for name, p in model.named_parameters():\n",
    "        if not p.requires_grad:\n",
    "            continue\n",
    "        g = p.grad\n",
    "        if g is None:\n",
    "            issues[\"none\"].append(name)\n",
    "            continue\n",
    "        if not torch.isfinite(g).all():\n",
    "            issues[\"nan_inf\"].append(name)\n",
    "            continue\n",
    "        if torch.count_nonzero(g).item() == 0:\n",
    "            issues[\"all_zero\"].append(name)\n",
    "\n",
    "    total_bad = sum(len(v) for v in issues.values())\n",
    "    if total_bad == 0:\n",
    "        print(\"👍 gradients look healthy on all trainable params.\")\n",
    "        return True\n",
    "\n",
    "    print(f\"⚠️  {total_bad} gradient issues:\")\n",
    "    for k in [\"none\", \"all_zero\", \"nan_inf\"]:\n",
    "        if issues[k]:\n",
    "            shown = issues[k][:limit_print]\n",
    "            more = len(issues[k]) - len(shown)\n",
    "            print(f\"  • {k}: {len(issues[k])} {'(showing first ' + str(len(shown)) + ')' if more >= 0 else ''}\")\n",
    "            for n in shown:\n",
    "                print(f\"     - {n}\")\n",
    "            if more > 0:\n",
    "                print(f\"     ... and {more} more\")\n",
    "    return False\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2097668b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import os, re, torch\n",
    "# from tqdm import tqdm\n",
    "# import torch.optim as optim\n",
    "\n",
    "# def find_latest_ckpt(ckpt_dir=\"checkpoints\"):\n",
    "#     \"\"\"Return path of latest checkpoint in a folder, or None if none exist.\"\"\"\n",
    "#     if not os.path.isdir(ckpt_dir):\n",
    "#         return None\n",
    "#     paths = [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)\n",
    "#              if f.startswith(\"epoch_\") and f.endswith(\".pt\")]\n",
    "#     if not paths:\n",
    "#         return None\n",
    "#     def _ep(p):\n",
    "#         m = re.search(r\"epoch_(\\d+)\\.pt$\", os.path.basename(p))\n",
    "#         return int(m.group(1)) if m else -1\n",
    "#     return max(paths, key=lambda p: _ep(p))\n",
    "\n",
    "# # --- optimizer over trainable params only ---\n",
    "# trainable_params = [p for p in personalized_model.parameters() if p.requires_grad]\n",
    "# optimizer = optim.AdamW(trainable_params, lr=2e-4, weight_decay=0.01)\n",
    "\n",
    "# epochs = 5          # total training epochs (set as you like)\n",
    "# max_gen_len = 50\n",
    "# clip_norm = 1.0\n",
    "# os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "\n",
    "# # ==== RESUME OR START FRESH ====\n",
    "# start_epoch = 1\n",
    "# latest_ckpt = find_latest_ckpt(\"checkpoints\")\n",
    "# if latest_ckpt:\n",
    "#     ckpt = torch.load(latest_ckpt, map_location=device)\n",
    "#     print(f\"🔎 Found checkpoint: {latest_ckpt}\")\n",
    "    \n",
    "#     # check param counts\n",
    "#     ckpt_params = len(ckpt[\"model_state\"])\n",
    "#     model_params = len(personalized_model.state_dict())\n",
    "#     print(f\"📊 Checkpoint params={ckpt_params}, Current model params={model_params}\")\n",
    "    \n",
    "#     # load weights\n",
    "#     personalized_model.load_state_dict(ckpt[\"model_state\"], strict=False)\n",
    "#     optimizer.load_state_dict(ckpt[\"optimizer_state\"])\n",
    "#     personalized_model.to(device)\n",
    "    \n",
    "#     last_epoch = ckpt[\"epoch\"]\n",
    "#     print(f\"✅ Resuming from epoch {last_epoch}, \"\n",
    "#           f\"last losses: tot={ckpt.get('avg_total_loss'):.4f}, \"\n",
    "#           f\"enc={ckpt.get('avg_enc_loss'):.4f}, \"\n",
    "#           f\"gen={ckpt.get('avg_gen_loss'):.4f}\")\n",
    "#     start_epoch = last_epoch + 1\n",
    "# else:\n",
    "#     print(\"⚠️ No checkpoint found, starting fresh training.\")\n",
    "#     personalized_model.to(device)\n",
    "\n",
    "# # ==== TRAINING LOOP ====\n",
    "# for ep in range(start_epoch, epochs + 1):\n",
    "#     personalized_model.train()\n",
    "#     pbar = tqdm(train_loader, desc=f\"Epoch {ep}/{epochs}\", dynamic_ncols=True)\n",
    "#     running_tot, running_enc, running_gen = 0.0, 0.0, 0.0\n",
    "#     seen = 0\n",
    "\n",
    "#     for Bhist_batch, Blen_batch, Bpos_batch in pbar:\n",
    "#         optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "#         batch_total, batch_enc, batch_gen = 0.0, 0.0, 0.0\n",
    "#         valid_in_batch = 0\n",
    "\n",
    "#         for i in range(len(Bhist_batch)):\n",
    "#             out = personalized_model(\n",
    "#                 Bhist_batch[i], int(Blen_batch[i]), Bpos_batch[i],\n",
    "#                 lookup_df, embed_tables, sid2sum, nid2body, tokenizer,\n",
    "#                 max_len=max_gen_len, mode=\"train\"\n",
    "#             )\n",
    "#             if out[\"skip\"]:\n",
    "#                 continue\n",
    "\n",
    "#             total_loss = out[\"total_loss\"]\n",
    "#             gen_loss   = out[\"gen_loss\"]\n",
    "#             enc_loss   = out[\"enc_loss\"]\n",
    "\n",
    "#             batch_total += float(total_loss.detach().item())\n",
    "#             batch_enc   += float(enc_loss.detach().item()) if torch.is_tensor(enc_loss) else float(enc_loss)\n",
    "#             batch_gen   += float(gen_loss.detach().item())\n",
    "#             valid_in_batch += 1\n",
    "\n",
    "#             total_loss.backward()\n",
    "\n",
    "#         if valid_in_batch == 0:\n",
    "#             pbar.set_postfix_str(\"skip batch\")\n",
    "#             continue\n",
    "\n",
    "#         torch.nn.utils.clip_grad_norm_(trainable_params, clip_norm)\n",
    "#         optimizer.step()\n",
    "\n",
    "#         ok = all(\n",
    "#             torch.isfinite(p.grad).all()\n",
    "#             for _, p in personalized_model.named_parameters()\n",
    "#             if p.requires_grad and p.grad is not None\n",
    "#         )\n",
    "\n",
    "#         running_tot += batch_total\n",
    "#         running_enc += batch_enc\n",
    "#         running_gen += batch_gen\n",
    "#         seen        += valid_in_batch\n",
    "\n",
    "#         avg_tot = running_tot / max(1, seen)\n",
    "#         avg_enc = running_enc / max(1, seen)\n",
    "#         avg_gen = running_gen / max(1, seen)\n",
    "\n",
    "#         pbar.set_postfix({\n",
    "#             \"enc\": f\"{avg_enc:.3f}\",\n",
    "#             \"dec\": f\"{avg_gen:.3f}\",\n",
    "#             \"tot\": f\"{avg_tot:.3f}\",\n",
    "#             \"grad\": \"👍\" if ok else \"⚠️\"\n",
    "#         })\n",
    "\n",
    "#     # ---- SAVE CHECKPOINT AT END OF EPOCH ----\n",
    "#     ckpt_path = f\"checkpoints/epoch_{ep}.pt\"\n",
    "#     torch.save({\n",
    "#         \"epoch\": ep,\n",
    "#         \"model_state\": personalized_model.state_dict(),\n",
    "#         \"optimizer_state\": optimizer.state_dict(),\n",
    "#         \"avg_total_loss\": avg_tot,\n",
    "#         \"avg_enc_loss\": avg_enc,\n",
    "#         \"avg_gen_loss\": avg_gen,\n",
    "#     }, ckpt_path)\n",
    "#     print(f\"💾 Saved checkpoint: {ckpt_path}\")\n",
    "\n",
    "#     # optional grad health report\n",
    "#     print(\"\\nEpoch grad health summary:\")\n",
    "#     grad_health_report(personalized_model, limit_print=40)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "8b1c3593",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ DataLoader: 38417 samples, batch_size=8\n",
      "📦 Total params: 232,068,503\n",
      "🟢 Trainable:   5,033,110\n",
      "🧊 Frozen:      227,035,393\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🔎 Found checkpoint: checkpoints/epoch_5.pt\n",
      "⚠️ Skipping optimizer_state load: loaded state dict contains a parameter group that doesn't match the size of optimizer's group\n",
      "✅ Resuming from epoch 5; prev avg_enc_loss=2.9589636535152124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[ENC-ONLY] Epoch 6/8:   0%|          | 4/4803 [00:08<2:41:14,  2.02s/it, enc=2.852]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[20], line 136\u001b[0m\n\u001b[1;32m    132\u001b[0m e_last, e_next, logits_pos, ctx_enc_loss, enc_total_loss \u001b[38;5;241m=\u001b[39m personalized_model\u001b[38;5;241m.\u001b[39mbehavior_encoder(\n\u001b[1;32m    133\u001b[0m     Bhist_batch[i], \u001b[38;5;28mint\u001b[39m(Blen_batch[i]), Bpos_batch[i], lookup_df, embed_tables\n\u001b[1;32m    134\u001b[0m )\n\u001b[1;32m    135\u001b[0m loss \u001b[38;5;241m=\u001b[39m enc_total_loss\n\u001b[0;32m--> 136\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    137\u001b[0m batch_enc \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(loss\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m    138\u001b[0m valid \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/_tensor.py:648\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    638\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    639\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    640\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    641\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    646\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    647\u001b[0m     )\n\u001b[0;32m--> 648\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    649\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    650\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/__init__.py:353\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    348\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    350\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    351\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 353\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    354\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    355\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    356\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    357\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    358\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    359\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    360\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    361\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/graph.py:824\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    822\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m    823\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 824\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    825\u001b[0m \u001b[43m        \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    826\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    827\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    828\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# =========================\n",
    "# Encoder-only training cell (freezes everything else)\n",
    "# =========================\n",
    "import os, re, torch\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ---------- helpers ----------\n",
    "def print_param_report(model):\n",
    "    total = sum(p.numel() for p in model.parameters())\n",
    "    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"📦 Total params: {total:,}\")\n",
    "    print(f\"🟢 Trainable:   {trainable:,}\")\n",
    "    print(f\"🧊 Frozen:      {total-trainable:,}\")\n",
    "\n",
    "def grad_health_report(model, limit_print=30):\n",
    "    issues = {\"none\": [], \"all_zero\": [], \"nan_inf\": []}\n",
    "    for name, p in model.named_parameters():\n",
    "        if not p.requires_grad:\n",
    "            continue\n",
    "        g = p.grad\n",
    "        if g is None:\n",
    "            issues[\"none\"].append(name); continue\n",
    "        if not torch.isfinite(g).all():\n",
    "            issues[\"nan_inf\"].append(name); continue\n",
    "        if torch.count_nonzero(g).item() == 0:\n",
    "            issues[\"all_zero\"].append(name)\n",
    "    total_bad = sum(len(v) for v in issues.values())\n",
    "    if total_bad == 0:\n",
    "        print(\"👍 gradients look healthy on all trainable params.\")\n",
    "        return True\n",
    "    print(f\"⚠️  {total_bad} gradient issues:\")\n",
    "    for k in [\"none\", \"all_zero\", \"nan_inf\"]:\n",
    "        if issues[k]:\n",
    "            shown = issues[k][:limit_print]\n",
    "            more = len(issues[k]) - len(shown)\n",
    "            print(f\"  • {k}: {len(issues[k])}\")\n",
    "            for n in shown: print(f\"     - {n}\")\n",
    "            if more > 0: print(f\"     ... and {more} more\")\n",
    "    return False\n",
    "\n",
    "def find_latest_ckpt(ckpt_dir=\"checkpoints\"):\n",
    "    if not os.path.isdir(ckpt_dir): return None\n",
    "    paths = [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)\n",
    "             if f.startswith(\"epoch_\") and f.endswith(\".pt\")]\n",
    "    if not paths: return None\n",
    "    def _ep(p):\n",
    "        m = re.search(r\"epoch_(\\d+)\\.pt$\", os.path.basename(p))\n",
    "        return int(m.group(1)) if m else -1\n",
    "    return max(paths, key=lambda p: _ep(p))\n",
    "\n",
    "# ---------- dataset/dataloader ----------\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, df):\n",
    "        # df must contain: EHist_padded, EHist_len, EPos\n",
    "        self.df = df.reset_index(drop=True)\n",
    "    def __len__(self): return len(self.df)\n",
    "    def __getitem__(self, idx):\n",
    "        r = self.df.iloc[idx]\n",
    "        return {\"Bhist\": r[\"EHist_padded\"], \"Bhist_len\": int(r[\"EHist_len\"]), \"Bpos\": r[\"EPos\"]}\n",
    "\n",
    "def collate_simple(batch):\n",
    "    return (\n",
    "        [b[\"Bhist\"] for b in batch],\n",
    "        [b[\"Bhist_len\"] for b in batch],\n",
    "        [b[\"Bpos\"] for b in batch],\n",
    "    )\n",
    "\n",
    "# One sample per DataFrame row; collate_simple keeps Python lists because the\n",
    "# behavior encoder below is called once per sample.\n",
    "train_ds = NewsDataset(train_df)\n",
    "train_loader = DataLoader(\n",
    "    train_ds,\n",
    "    batch_size=8,\n",
    "    shuffle=True,\n",
    "    num_workers=0,\n",
    "    pin_memory=False,\n",
    "    collate_fn=collate_simple,\n",
    ")\n",
    "print(f\"✅ DataLoader: {len(train_ds)} samples, batch_size={train_loader.batch_size}\")\n",
    "\n",
    "# ---------- build model (fresh) ----------\n",
    "# Rebuilds encoder, inverse decoder and wrapper from scratch; trained weights are\n",
    "# restored by the resume step below if a checkpoint is found. hidden_dim, max_len,\n",
    "# device and summarizer_model come from earlier cells.\n",
    "behavior_encoder = BehaviorEncoder(hidden_dim, device=device, max_len=max_len).to(device)\n",
    "inverse_decoder  = BehaviorInverseDecoderPredict(hidden_dim, device=device).to(device)\n",
    "\n",
    "personalized_model = PersonalizedT5Summarizer(\n",
    "    hidden_dim, summarizer_model, behavior_encoder, inverse_decoder, device,\n",
    "    learnable_ctx=True, unfreeze_last_decoder_blocks=0  # <- do NOT unfreeze any T5 decoder blocks\n",
    ").to(device)\n",
    "\n",
    "# ---------- freeze everything except behavior_encoder ----------\n",
    "# Sets requires_grad on every registered parameter, overriding flags left by earlier cells.\n",
    "# NOTE(review): named_parameters() lists a shared tensor only under its first name;\n",
    "# if a behavior_encoder weight is also registered under another submodule that is\n",
    "# visited first, it would be frozen here. Confirm no such sharing exists.\n",
    "for name, p in personalized_model.named_parameters():\n",
    "    p.requires_grad = name.startswith(\"behavior_encoder\")\n",
    "\n",
    "print_param_report(personalized_model)\n",
    "\n",
    "# ---------- optimizer over encoder params only ----------\n",
    "# NOTE(review): lr / weight_decay are hard-coded here; consider moving them to a config cell.\n",
    "enc_params = [p for p in personalized_model.behavior_encoder.parameters() if p.requires_grad]\n",
    "assert len(enc_params) > 0, \"No trainable params found in behavior_encoder!\"\n",
    "optimizer = optim.AdamW(enc_params, lr=2e-4, weight_decay=0.01)\n",
    "\n",
    "# ---------- (optional) resume ----------\n",
    "# NOTE(review): the (commented-out) joint-training cell above saves to the same\n",
    "# checkpoints/epoch_<N>.pt names, so checkpoints from different training regimes\n",
    "# share this directory and can overwrite each other.\n",
    "os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "start_epoch = 1\n",
    "latest = find_latest_ckpt(\"checkpoints\")\n",
    "if latest:\n",
    "    ckpt = torch.load(latest, map_location=device)\n",
    "    print(f\"🔎 Found checkpoint: {latest}\")\n",
    "    # strict=False tolerates missing/unexpected keys but NOT shape mismatches\n",
    "    # (those still raise RuntimeError). Report skipped keys instead of ignoring them.\n",
    "    load_result = personalized_model.load_state_dict(ckpt[\"model_state\"], strict=False)\n",
    "    if load_result.missing_keys or load_result.unexpected_keys:\n",
    "        print(f\"⚠️ Partial model load: {len(load_result.missing_keys)} missing, \"\n",
    "              f\"{len(load_result.unexpected_keys)} unexpected keys \"\n",
    "              f\"(e.g. {(load_result.missing_keys + load_result.unexpected_keys)[:3]})\")\n",
    "    # Optimizer state fails to load when the optimized param set changed (e.g. a\n",
    "    # checkpoint from a run that trained different params); keep the fresh state then.\n",
    "    if \"optimizer_state\" in ckpt:\n",
    "        try:\n",
    "            optimizer.load_state_dict(ckpt[\"optimizer_state\"])\n",
    "        except Exception as e:\n",
    "            print(f\"⚠️ Skipping optimizer_state load: {e}\")\n",
    "    last_epoch = int(ckpt.get(\"epoch\", 0))\n",
    "    print(f\"✅ Resuming from epoch {last_epoch}; prev avg_enc_loss={ckpt.get('avg_enc_loss')}\")\n",
    "    start_epoch = last_epoch + 1\n",
    "else:\n",
    "    print(\"▶️ Starting fresh encoder-only training.\")\n",
    "\n",
    "# ---------- train (encoder-only) ----------\n",
    "epochs = 8\n",
    "clip_norm = 1.0\n",
    "\n",
    "for ep in range(start_epoch, epochs + 1):\n",
    "    personalized_model.train()\n",
    "    pbar = tqdm(train_loader, desc=f\"[ENC-ONLY] Epoch {ep}/{epochs}\", dynamic_ncols=True)\n",
    "    running_enc, seen = 0.0, 0\n",
    "    # Reset per epoch: otherwise an epoch with no completed batch would raise NameError\n",
    "    # (first epoch run here) or save the previous epoch's stale average.\n",
    "    avg_enc = None\n",
    "\n",
    "    for Bhist_batch, Blen_batch, Bpos_batch in pbar:\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "        batch_enc, valid = 0.0, 0\n",
    "        for i in range(len(Bhist_batch)):\n",
    "            # Forward the behavior encoder ONLY; use its total loss\n",
    "            e_last, e_next, logits_pos, ctx_enc_loss, enc_total_loss = personalized_model.behavior_encoder(\n",
    "                Bhist_batch[i], int(Blen_batch[i]), Bpos_batch[i], lookup_df, embed_tables\n",
    "            )\n",
    "            loss = enc_total_loss\n",
    "            # Per-sample backward: gradients are SUMMED over the batch (not averaged),\n",
    "            # so clip_grad_norm_ below acts on the summed gradient.\n",
    "            loss.backward()\n",
    "            batch_enc += float(loss.detach().item())\n",
    "            valid += 1\n",
    "\n",
    "        if valid == 0:  # only reachable for an empty batch\n",
    "            pbar.set_postfix_str(\"skip batch\")\n",
    "            continue\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(enc_params, clip_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "        running_enc += batch_enc\n",
    "        seen += valid\n",
    "        avg_enc = running_enc / max(1, seen)\n",
    "        pbar.set_postfix({\"enc\": f\"{avg_enc:.3f}\"})\n",
    "\n",
    "    # ---- save checkpoint ----\n",
    "    ckpt_path = f\"checkpoints/epoch_{ep}.pt\"\n",
    "    torch.save({\n",
    "        \"epoch\": ep,\n",
    "        \"model_state\": personalized_model.state_dict(),   # includes encoder weights\n",
    "        \"optimizer_state\": optimizer.state_dict(),\n",
    "        \"avg_total_loss\": None,\n",
    "        \"avg_enc_loss\": avg_enc,  # None if no batch completed this epoch\n",
    "        \"avg_gen_loss\": None,\n",
    "    }, ckpt_path)\n",
    "    print(f\"💾 Saved encoder-only checkpoint: {ckpt_path}\")\n",
    "\n",
    "    # grads inspected here come from the last processed batch (zero_grad runs per batch)\n",
    "    print(\"\\nEpoch grad health summary (encoder only):\")\n",
    "    grad_health_report(personalized_model, limit_print=40)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "cf3817a6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "✅ DataLoader: 38417 samples, batch_size=8\n",
      "📦 Total params: 232,068,503\n",
      "🟢 Trainable:   142,081,153\n",
      "🧊 Frozen:      89,987,350\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🔎 Found checkpoint: checkpoints/epoch_5.pt\n",
      "⚠️ Skipping optimizer_state load: loaded state dict contains a parameter group that doesn't match the size of optimizer's group\n",
      "✅ Resuming from epoch 5; prev avg_gen_loss=0.5976175675267787 | prev avg_enc_loss=2.9589636535152124\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[INV+SUM] Epoch 6/8:   0%|          | 10/4803 [00:13<1:47:08,  1.34s/it, enc=3.088, gen=0.535, tot=2.067]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[21], line 169\u001b[0m\n\u001b[1;32m    166\u001b[0m enc_loss \u001b[38;5;241m=\u001b[39m out[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menc_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mif\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mis_tensor(out[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menc_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m]) \u001b[38;5;28;01melse\u001b[39;00m torch\u001b[38;5;241m.\u001b[39mtensor(\u001b[38;5;28mfloat\u001b[39m(out[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124menc_loss\u001b[39m\u001b[38;5;124m\"\u001b[39m]), device\u001b[38;5;241m=\u001b[39mdevice)\n\u001b[1;32m    167\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0.4\u001b[39m \u001b[38;5;241m*\u001b[39m gen_loss \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m0.6\u001b[39m \u001b[38;5;241m*\u001b[39m enc_loss\u001b[38;5;241m.\u001b[39mdetach()                 \u001b[38;5;66;03m# detach => no grads into encoder\u001b[39;00m\n\u001b[0;32m--> 169\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    170\u001b[0m batch_tot \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(loss\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mitem())\n\u001b[1;32m    171\u001b[0m batch_enc \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mfloat\u001b[39m(enc_loss\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mitem())\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/_tensor.py:648\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    638\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    639\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    640\u001b[0m         Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m    641\u001b[0m         (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    646\u001b[0m         inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m    647\u001b[0m     )\n\u001b[0;32m--> 648\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    649\u001b[0m \u001b[43m    \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m    650\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/__init__.py:353\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    348\u001b[0m     retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m    350\u001b[0m \u001b[38;5;66;03m# The reason we repeat the same comment below is that\u001b[39;00m\n\u001b[1;32m    351\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    352\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 353\u001b[0m \u001b[43m_engine_run_backward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    354\u001b[0m \u001b[43m    \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    355\u001b[0m \u001b[43m    \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    356\u001b[0m \u001b[43m    \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    357\u001b[0m \u001b[43m    \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    358\u001b[0m \u001b[43m    \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    359\u001b[0m \u001b[43m    \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    360\u001b[0m \u001b[43m    \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    361\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/autograd/graph.py:824\u001b[0m, in \u001b[0;36m_engine_run_backward\u001b[0;34m(t_outputs, *args, **kwargs)\u001b[0m\n\u001b[1;32m    822\u001b[0m     unregister_hooks \u001b[38;5;241m=\u001b[39m _register_logging_hooks_on_whole_graph(t_outputs)\n\u001b[1;32m    823\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 824\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m  \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    825\u001b[0m \u001b[43m        \u001b[49m\u001b[43mt_outputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\n\u001b[1;32m    826\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# Calls into the C++ engine to run the backward pass\u001b[39;00m\n\u001b[1;32m    827\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m    828\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m attach_logging_hooks:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# =========================\n",
    "# Train ONLY: inverse_decoder + personalized summarizer (T5 decoder + injectors)\n",
    "# =========================\n",
    "import os, re, torch, math\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "\n",
    "# ---------- helpers ----------\n",
    "def print_param_report(model):\n",
    "    total = sum(p.numel() for p in model.parameters())\n",
    "    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "    print(f\"📦 Total params: {total:,}\")\n",
    "    print(f\"🟢 Trainable:   {trainable:,}\")\n",
    "    print(f\"🧊 Frozen:      {total-trainable:,}\")\n",
    "\n",
    "def grad_health_report(model, limit_print=30):\n",
    "    issues = {\"none\": [], \"all_zero\": [], \"nan_inf\": []}\n",
    "    for name, p in model.named_parameters():\n",
    "        if not p.requires_grad:\n",
    "            continue\n",
    "        g = p.grad\n",
    "        if g is None: issues[\"none\"].append(name); continue\n",
    "        if not torch.isfinite(g).all(): issues[\"nan_inf\"].append(name); continue\n",
    "        if torch.count_nonzero(g).item() == 0: issues[\"all_zero\"].append(name)\n",
    "    total_bad = sum(len(v) for v in issues.values())\n",
    "    if total_bad == 0:\n",
    "        print(\"👍 gradients look healthy on all trainable params.\")\n",
    "        return True\n",
    "    print(f\"⚠️  {total_bad} gradient issues:\")\n",
    "    for k in [\"none\", \"all_zero\", \"nan_inf\"]:\n",
    "        if issues[k]:\n",
    "            shown = issues[k][:limit_print]\n",
    "            more = len(issues[k]) - len(shown)\n",
    "            print(f\"  • {k}: {len(issues[k])}\")\n",
    "            for n in shown: print(f\"     - {n}\")\n",
    "            if more > 0: print(f\"     ... and {more} more\")\n",
    "    return False\n",
    "\n",
    "def find_latest_ckpt(ckpt_dir=\"checkpoints\"):\n",
    "    if not os.path.isdir(ckpt_dir): return None\n",
    "    paths = [os.path.join(ckpt_dir, f) for f in os.listdir(ckpt_dir)\n",
    "             if f.startswith(\"epoch_\") and f.endswith(\".pt\")]\n",
    "    if not paths: return None\n",
    "    def _ep(p):\n",
    "        m = re.search(r\"epoch_(\\d+)\\.pt$\", os.path.basename(p))\n",
    "        return int(m.group(1)) if m else -1\n",
    "    return max(paths, key=lambda p: _ep(p))\n",
    "\n",
    "# ---------- dataset/dataloader ----------\n",
    "class NewsDataset(Dataset):\n",
    "    def __init__(self, df):\n",
    "        # df must contain: EHist_padded, EHist_len, EPos\n",
    "        self.df = df.reset_index(drop=True)\n",
    "    def __len__(self): return len(self.df)\n",
    "    def __getitem__(self, idx):\n",
    "        r = self.df.iloc[idx]\n",
    "        return {\"Bhist\": r[\"EHist_padded\"], \"Bhist_len\": int(r[\"EHist_len\"]), \"Bpos\": r[\"EPos\"]}\n",
    "\n",
    "def collate_simple(batch):\n",
    "    return (\n",
    "        [b[\"Bhist\"] for b in batch],\n",
    "        [b[\"Bhist_len\"] for b in batch],\n",
    "        [b[\"Bpos\"] for b in batch],\n",
    "    )\n",
    "\n",
    "train_ds = NewsDataset(train_df)\n",
    "train_loader = DataLoader(\n",
    "    train_ds, batch_size=8, shuffle=True, num_workers=0, pin_memory=False, collate_fn=collate_simple\n",
    ")\n",
    "print(f\"✅ DataLoader: {len(train_ds)} samples, batch_size={train_loader.batch_size}\")\n",
    "\n",
    "# ---------- build model (fresh) ----------\n",
    "# Unfreeze some last T5 decoder blocks so they can learn (tune N as you like)\n",
    "UNFREEZE_LAST_DECODER_BLOCKS = 12\n",
    "\n",
    "behavior_encoder = BehaviorEncoder(hidden_dim, device=device, max_len=max_len).to(device)\n",
    "inverse_decoder  = BehaviorInverseDecoderPredict(hidden_dim, device=device).to(device)\n",
    "\n",
    "personalized_model = PersonalizedT5Summarizer(\n",
    "    hidden_dim, summarizer_model, behavior_encoder, inverse_decoder, device,\n",
    "    learnable_ctx=True, unfreeze_last_decoder_blocks=UNFREEZE_LAST_DECODER_BLOCKS\n",
    ").to(device)\n",
    "\n",
    "# ---------- freeze everything, then selectively enable trainable parts ----------\n",
    "for name, p in personalized_model.named_parameters():\n",
    "    p.requires_grad = False\n",
    "\n",
    "def allow(name: str) -> bool:\n",
    "    # train inverse_decoder\n",
    "    if name.startswith(\"inverse_decoder.\"): return True\n",
    "    # train context-injection projections\n",
    "    if name.startswith(\"e_proj.\") or name.startswith(\"q_proj.\") or name.startswith(\"k_proj.\") or name.startswith(\"v_proj.\"): return True\n",
    "    if name == \"ctx_scale_raw\": return True\n",
    "    # train T5 decoder last N blocks (already unfrozen in __init__, but we enforce here)\n",
    "    if name.startswith(\"t5.decoder.block.\"): return True\n",
    "    # train decoder final layer norm\n",
    "    if name.startswith(\"t5.decoder.final_layer_norm.\"): return True\n",
    "    # train tied embeddings + lm_head (T5 weight tying)\n",
    "    if name.startswith(\"t5.shared.\") or name.startswith(\"t5.lm_head.\"): return True\n",
    "    return False\n",
    "\n",
    "for name, p in personalized_model.named_parameters():\n",
    "    if allow(name):\n",
    "        p.requires_grad = True\n",
    "\n",
    "# explicitly ensure encoder is frozen\n",
    "for name, p in personalized_model.named_parameters():\n",
    "    if name.startswith(\"behavior_encoder.\"):\n",
    "        p.requires_grad = False\n",
    "\n",
    "print_param_report(personalized_model)\n",
    "\n",
    "# ---------- optimizer over trainable params (inverse+summarizer only) ----------\n",
    "trainable_params = [p for p in personalized_model.parameters() if p.requires_grad]\n",
    "assert len(trainable_params) > 0, \"No trainable params selected!\"\n",
    "optimizer = optim.AdamW(trainable_params, lr=1.5e-4, weight_decay=0.01)  # tune LR as needed\n",
    "\n",
    "# ---------- (optional) resume ----------\n",
    "os.makedirs(\"checkpoints\", exist_ok=True)\n",
    "start_epoch = 1\n",
    "latest = find_latest_ckpt(\"checkpoints\")\n",
    "if latest:\n",
    "    ckpt = torch.load(latest, map_location=device)\n",
    "    print(f\"🔎 Found checkpoint: {latest}\")\n",
    "    # Load weights permissively (shape changes tolerated)\n",
    "    personalized_model.load_state_dict(ckpt[\"model_state\"], strict=False)\n",
    "    # Optimizer might not match param groups if you previously trained different parts\n",
    "    if \"optimizer_state\" in ckpt:\n",
    "        try:\n",
    "            optimizer.load_state_dict(ckpt[\"optimizer_state\"])\n",
    "        except Exception as e:\n",
    "            print(f\"⚠️ Skipping optimizer_state load: {e}\")\n",
    "    last_epoch = int(ckpt.get(\"epoch\", 0))\n",
    "    print(f\"✅ Resuming from epoch {last_epoch}; prev avg_gen_loss={ckpt.get('avg_gen_loss')} | prev avg_enc_loss={ckpt.get('avg_enc_loss')}\")\n",
    "    start_epoch = last_epoch + 1\n",
    "else:\n",
    "    print(\"▶️ Starting fresh inverse+summarizer training.\")\n",
    "\n",
    "# ---------- train loop (inverse + summarizer only) ----------\n",
    "epochs_more = 3                  # how many NEW epochs to run from start_epoch\n",
    "end_epoch   = start_epoch + epochs_more - 1\n",
    "clip_norm   = 1.0\n",
    "max_gen_len = 50\n",
    "\n",
    "for ep in range(start_epoch, end_epoch + 1):\n",
    "    personalized_model.train()\n",
    "    pbar = tqdm(train_loader, desc=f\"[INV+SUM] Epoch {ep}/{end_epoch}\", dynamic_ncols=True)\n",
    "    running_tot, running_enc, running_gen, seen = 0.0, 0.0, 0.0, 0\n",
    "\n",
    "    for Bhist_batch, Blen_batch, Bpos_batch in pbar:\n",
    "        optimizer.zero_grad(set_to_none=True)\n",
    "\n",
    "        batch_tot, batch_enc, batch_gen, valid = 0.0, 0.0, 0.0, 0\n",
    "        for i in range(len(Bhist_batch)):\n",
    "            out = personalized_model(\n",
    "                Bhist_batch[i], int(Blen_batch[i]), Bpos_batch[i],\n",
    "                lookup_df, embed_tables, sid2sum, nid2body, tokenizer,\n",
    "                max_len=max_gen_len, mode=\"train\"\n",
    "            )\n",
    "            if out.get(\"skip\", False):\n",
    "                continue\n",
    "\n",
    "            # Use gen loss for training; include enc loss only as a detached term\n",
    "            gen_loss = out[\"gen_loss\"]                                      # drives summarizer/inverse\n",
    "            enc_loss = out[\"enc_loss\"] if torch.is_tensor(out[\"enc_loss\"]) else torch.tensor(float(out[\"enc_loss\"]), device=device)\n",
    "            loss = 0.4 * gen_loss + 0.6 * enc_loss.detach()                 # detach => no grads into encoder\n",
    "\n",
    "            loss.backward()\n",
    "            batch_tot += float(loss.detach().item())\n",
    "            batch_enc += float(enc_loss.detach().item())\n",
    "            batch_gen += float(gen_loss.detach().item())\n",
    "            valid += 1\n",
    "\n",
    "        if valid == 0:\n",
    "            pbar.set_postfix_str(\"skip batch\"); continue\n",
    "\n",
    "        torch.nn.utils.clip_grad_norm_(trainable_params, clip_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "        running_tot += batch_tot\n",
    "        running_enc += batch_enc\n",
    "        running_gen += batch_gen\n",
    "        seen        += valid\n",
    "\n",
    "        pbar.set_postfix({\n",
    "            \"enc\": f\"{running_enc / seen:.3f}\",\n",
    "            \"gen\": f\"{running_gen / seen:.3f}\",\n",
    "            \"tot\": f\"{running_tot / seen:.3f}\",\n",
    "        })\n",
    "\n",
    "    # ---- save checkpoint ----\n",
    "    ckpt_path = f\"checkpoints/epoch_{ep}.pt\"\n",
    "    torch.save({\n",
    "        \"epoch\": ep,\n",
    "        \"model_state\": personalized_model.state_dict(),   # includes inverse + summarizer weights\n",
    "        \"optimizer_state\": optimizer.state_dict(),\n",
    "        \"avg_total_loss\": running_tot / max(1, seen),\n",
    "        \"avg_enc_loss\":   running_enc / max(1, seen),\n",
    "        \"avg_gen_loss\":   running_gen / max(1, seen),\n",
    "    }, ckpt_path)\n",
    "    print(f\"💾 Saved inverse+summarizer checkpoint: {ckpt_path}\")\n",
    "\n",
    "    print(\"\\nEpoch grad health summary (inverse+summarizer):\")\n",
    "    grad_health_report(personalized_model, limit_print=40)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ff2209d0",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
