{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25724f40",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import numpy as np\n",
    "import h5py\n",
    "import os\n",
    "import glob\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "import gc\n",
    "import time\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3b7dd19c",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# Set random seeds\n",
    "#torch.manual_seed(31)\n",
    "#np.random.seed(31)\n",
    "#torch.manual_seed(50)\n",
    "#np.random.seed(50)\n",
    "#torch.manual_seed(25)\n",
    "#np.random.seed(25)\n",
    "#torch.manual_seed(76)\n",
    "#np.random.seed(76)\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deb3b204",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class BehaviorEncoder(nn.Module):\n",
    "    def __init__(\n",
    "        self,\n",
    "        hidden_dim=384,\n",
    "        relation_dim=4,\n",
    "        behavior_dim=772,\n",
    "        attn_heads=4,\n",
    "        num_positions=10000,\n",
    "        num_negatives=20,\n",
    "        dropout=0.1\n",
    "    ):\n",
    "        super().__init__()\n",
    "        # dims\n",
    "        self.hidden_dim = hidden_dim\n",
    "        self.relation_dim = relation_dim\n",
    "        self.behavior_dim = behavior_dim\n",
    "        self.num_positions = num_positions\n",
    "        self.num_negatives = num_negatives\n",
    "\n",
    "        # projections for triplet parts\n",
    "        self.W_hd = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.W_rel = nn.Linear(relation_dim, hidden_dim)\n",
    "        self.W_tl = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.W_b  = nn.Linear(hidden_dim, hidden_dim)\n",
    "        self.W_proj = nn.Linear(hidden_dim, behavior_dim)\n",
    "\n",
    "        # self-attention on EMA history\n",
    "        self.self_attn = nn.MultiheadAttention(behavior_dim, attn_heads, batch_first=True)\n",
    "        self.W_fuse = nn.Linear(2 * behavior_dim, behavior_dim)\n",
    "\n",
    "        # negative sampling for position prediction\n",
    "        self.pos_embed = nn.Embedding(num_positions, behavior_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def negative_sampling_loss(self, vec, true_idx):\n",
    "        # positive sample\n",
    "        pos_vec = self.pos_embed(true_idx)\n",
    "        pos_score = torch.sum(vec * pos_vec, dim=-1)\n",
    "        pos_loss = -torch.log(torch.sigmoid(pos_score) + 1e-9)\n",
    "        # sample negatives\n",
    "        neg_idx = torch.randint(\n",
    "            0, self.num_positions, (self.num_negatives,), device=vec.device\n",
    "        )\n",
    "        # ensure no true index in negatives\n",
    "        mask = neg_idx == true_idx\n",
    "        while mask.any():\n",
    "            neg_idx[mask] = torch.randint(\n",
    "                0, self.num_positions, (mask.sum().item(),), device=vec.device\n",
    "            )\n",
    "            mask = neg_idx == true_idx\n",
    "        neg_vecs = self.pos_embed(neg_idx)\n",
    "        neg_scores = torch.matmul(neg_vecs, vec.unsqueeze(-1)).squeeze(-1)\n",
    "        neg_loss = -torch.sum(torch.log(1 - torch.sigmoid(neg_scores) + 1e-9))\n",
    "        return pos_loss + neg_loss\n",
    "\n",
    "    def forward(self, B_hist, true_positions=None):\n",
    "        batch, seq_len, _ = B_hist.shape\n",
    "        device = B_hist.device\n",
    "\n",
    "        h_state = torch.zeros(batch, self.hidden_dim, device=device)\n",
    "        total_local = torch.tensor(0.0, device=device)\n",
    "        total_nll   = torch.tensor(0.0, device=device)\n",
    "        B_mega_list = []\n",
    "\n",
    "        for b in range(batch):\n",
    "            ema_prev = None\n",
    "            history = []\n",
    "            for t in range(seq_len):\n",
    "                x = B_hist[b, t]\n",
    "                if torch.all(x == 0):\n",
    "                    history.append(ema_prev if ema_prev is not None else x)\n",
    "                    continue\n",
    "\n",
    "                # split H, R, T\n",
    "                H = x[:self.hidden_dim]\n",
    "                R = x[self.hidden_dim:self.hidden_dim + self.relation_dim]\n",
    "                T = x[self.hidden_dim + self.relation_dim:]\n",
    "                # encode parts\n",
    "                h_h = torch.tanh(self.W_hd(H))\n",
    "                h_r = torch.tanh(self.W_rel(R))\n",
    "                h_t = torch.tanh(self.W_tl(T))\n",
    "                # behavior rep\n",
    "                b_mid = torch.tanh(self.W_b(h_t))\n",
    "                b_t   = self.W_proj(b_mid)\n",
    "\n",
    "                # EMA update\n",
    "                if ema_prev is None:\n",
    "                    ema = b_t\n",
    "                else:\n",
    "                    seq_pair = torch.stack([ema_prev, b_t], dim=0).unsqueeze(0)\n",
    "                    attn_out, _ = self.self_attn(seq_pair, seq_pair, seq_pair)\n",
    "                    cur_att = attn_out[0, 1]\n",
    "                    gate = torch.sigmoid(\n",
    "                        self.W_fuse(torch.cat([ema_prev, cur_att], dim=-1))\n",
    "                    )\n",
    "                    ema = gate * cur_att + (1 - gate) * ema_prev\n",
    "                ema = self.dropout(ema)\n",
    "                history.append(ema)\n",
    "                ema_prev = ema\n",
    "\n",
    "                # local RMSD (assume self.rmsd exists)\n",
    "                total_local += self.rmsd(ema, x)\n",
    "\n",
    "                # negative sampling loss\n",
    "                if true_positions is not None:\n",
    "                    true_idx = true_positions[b, t]\n",
    "                    total_nll += self.negative_sampling_loss(ema, true_idx)\n",
    "\n",
    "            B_mega_list.append(torch.stack(history))\n",
    "            # hidden update placeholder\n",
    "            h_state[b] = torch.zeros(self.hidden_dim, device=device)\n",
    "\n",
    "        B_mega = torch.stack(B_mega_list)\n",
    "        L_enc = total_local + total_nll\n",
    "\n",
    "        # inference softmax scores if no true_positions\n",
    "        if true_positions is None:\n",
    "            # B_mega: [batch, seq_len, behavior_dim]\n",
    "            logits = torch.matmul(B_mega, self.pos_embed.weight.t())\n",
    "            log_probs = F.log_softmax(logits, dim=-1)\n",
    "            return h_state, L_enc, B_mega, log_probs\n",
    "\n",
    "        return h_state, L_enc, B_mega\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3da5107",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "class BehaviorPredictor(nn.Module):\n",
    "    def __init__(self, behavior_dim=772, dec_dim=384, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.pred_head = nn.Linear(behavior_dim, dec_dim)\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.activation = nn.GELU()\n",
    "\n",
    "    def forward(self, B_final):\n",
    "        x = self.dropout(B_final)\n",
    "        return self.activation(self.pred_head(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d6d37d9a",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "n_kp = extract_vocab()\n",
    "class KeyphraseDecoder(nn.Module):\n",
    "    def __init__(self, dec_dim=384, vocab_size=n_kp, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.mlp = nn.Sequential(nn.Dropout(dropout), nn.Linear(dec_dim, vocab_size))\n",
    "        self.log_softmax = nn.LogSoftmax(dim=-1)\n",
    "\n",
    "    def forward(self, dec_emb):\n",
    "        logits = self.mlp(dec_emb)\n",
    "        return self.log_softmax(logits)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fb2d504",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "class BehaviorModel(nn.Module):\n",
    "    \"\"\"Behavior encoder -> behavior predictor -> keyphrase decoder.\n",
    "\n",
    "    The keyphrase prediction is made from the encoder state at the last\n",
    "    history step.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_dim=384, relation_dim=4, behavior_dim=772,\n",
    "                 num_positions=100, vocab_size=10000):\n",
    "        super().__init__()\n",
    "        self.encoder = BehaviorEncoder(hidden_dim, relation_dim, behavior_dim,\n",
    "                                       attn_heads=4, num_positions=num_positions)\n",
    "        self.decoder = BehaviorPredictor(behavior_dim)\n",
    "        self.kp_decoder = KeyphraseDecoder(self.decoder.pred_head.out_features, vocab_size)\n",
    "\n",
    "    def forward(self, B_hist, true_positions=None, true_kps=None):\n",
    "        \"\"\"Run the full model.\n",
    "\n",
    "        Args:\n",
    "            B_hist: behavior history passed to the encoder.\n",
    "            true_positions: optional target positions for the encoder loss.\n",
    "            true_kps: optional keyphrase class indices for the NLL loss.\n",
    "\n",
    "        Returns:\n",
    "            ``(L_enc, L_kp, logp_kp)``: encoder loss, keyphrase NLL (a 0-d zero\n",
    "            tensor when ``true_kps`` is None) and keyphrase log-probabilities.\n",
    "        \"\"\"\n",
    "        # The encoder returns an extra log_probs element when true_positions is\n",
    "        # None; keep only the first three so inference does not fail to unpack.\n",
    "        _h_enc, L_enc, B_mega = self.encoder(B_hist, true_positions)[:3]\n",
    "        B_final = B_mega[:, -1]\n",
    "        dec_emb = self.decoder(B_final)\n",
    "        logp_kp = self.kp_decoder(dec_emb)\n",
    "        if true_kps is None:\n",
    "            # a tensor rather than a float so callers can always call .item()\n",
    "            L_kp = logp_kp.new_zeros(())\n",
    "        else:\n",
    "            L_kp = nn.functional.nll_loss(logp_kp, true_kps, reduction='mean')\n",
    "        return L_enc, L_kp, logp_kp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "973595ed",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "# ===== Training Loop =====\n",
    "\n",
    "def train(model, dataloader, epochs=5, lr=1e-4, device='cuda'):\n",
    "    model.to(device)\n",
    "    optimizer = optim.AdamW(model.parameters(), lr=lr)\n",
    "    history = {'loss': [], 'enc_loss': [], 'kp_loss': []}\n",
    "\n",
    "    for epoch in range(1, epochs+1):\n",
    "        model.train()\n",
    "        epoch_enc, epoch_kp, epoch_total = 0.0, 0.0, 0.0\n",
    "        for B_hist, pos, kps in tqdm(dataloader, desc=f\"Epoch {epoch}\"):\n",
    "            B_hist = B_hist.to(device)\n",
    "            pos = pos.to(device)\n",
    "            kps = kps.to(device)\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            L_enc, L_kp, logp_kp = model(B_hist, pos, kps)\n",
    "            loss = 0.4 * L_enc + 0.6 * L_kp\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            epoch_enc += L_enc.item()\n",
    "            epoch_kp += L_kp.item()\n",
    "            epoch_total += loss.item()\n",
    "\n",
    "        avg_enc = epoch_enc / len(dataloader)\n",
    "        avg_kp = epoch_kp / len(dataloader)\n",
    "        avg_loss = epoch_total / len(dataloader)\n",
    "        history['enc_loss'].append(avg_enc)\n",
    "        history['kp_loss'].append(avg_kp)\n",
    "        history['loss'].append(avg_loss)\n",
    "        print(f\"Epoch {epoch}: total={avg_loss:.4f}, enc={avg_enc:.4f}, kp={avg_kp:.4f}\")\n",
    "\n",
    "    return history\n",
    "\n",
    "# Example usage:\n",
    "# from torch.utils.data import DataLoader\n",
    "# dataset = MyBehaviorDataset()\n",
    "# dataloader = DataLoader(dataset, batch_size=128, shuffle=True)\n",
    "# model = BehaviorModel()\n",
    "# history = train(model, dataloader, epochs=10)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
