{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7463b5e",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "827e0690",
   "metadata": {
    "vscode": {
     "languageId": "plaintext"
    }
   },
   "outputs": [],
   "source": [
    "class BehaviorSequenceDataset(Dataset):\n",
    "    \"\"\"\n",
    "    Dataset for loading user behavior sequences from a CSV file.\n",
    "    Expects a CSV with columns: `userid`, `behavior_id`.\n",
    "    \"\"\"\n",
    "    def __init__(self, csv_path, behavior_dim, max_seq_len=None):\n",
    "        self.df = pd.read_csv(csv_path)\n",
    "        # group by user and collect sequences\n",
    "        self.sequences = (\n",
    "            self.df.groupby('userid')['behavior_id']\n",
    "                   .apply(list)\n",
    "                   .tolist()\n",
    "        )\n",
    "        self.behavior_dim = behavior_dim\n",
    "        # Optionally enforce a maximum sequence length\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        seq = self.sequences[idx]\n",
    "        # truncate if needed\n",
    "        if self.max_seq_len and len(seq) > self.max_seq_len:\n",
    "            seq = seq[-self.max_seq_len:]\n",
    "\n",
    "        # Build one-hot history matrix [seq_len, behavior_dim]\n",
    "        hist = torch.zeros(len(seq), self.behavior_dim)\n",
    "        for i, b in enumerate(seq):\n",
    "            hist[i, b] = 1.0\n",
    "\n",
    "        # True positions (for each timestep predict its index)\n",
    "        true_positions = torch.arange(len(seq), dtype=torch.long)\n",
    "\n",
    "        # Next-behavior label is last element (for keyphrase decoder)\n",
    "        true_kp = torch.tensor(seq[-1], dtype=torch.long)\n",
    "\n",
    "        return hist, true_positions, true_kp\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"\n",
    "    Pads a batch of (hist, positions, kp) tuples to form tensors:\n",
    "      - B_hist: [batch, max_len, behavior_dim]\n",
    "      - pos:    [batch, max_len]\n",
    "      - kps:    [batch]\n",
    "    \"\"\"\n",
    "    hists, positions, kps = zip(*batch)\n",
    "    batch_size = len(hists)\n",
    "    seq_lens = [h.shape[0] for h in hists]\n",
    "    max_len = max(seq_lens)\n",
    "    beh_dim = hists[0].size(1)\n",
    "\n",
    "    # Pad history matrices with zeros\n",
    "    B_hist = torch.zeros(batch_size, max_len, beh_dim)\n",
    "    pos_tensor = torch.zeros(batch_size, max_len, dtype=torch.long)\n",
    "    kp_tensor = torch.stack(kps)\n",
    "\n",
    "    for i, (h, pos) in enumerate(zip(hists, positions)):\n",
    "        length = h.size(0)\n",
    "        B_hist[i, :length] = h\n",
    "        pos_tensor[i, :length] = pos\n",
    "\n",
    "    return B_hist, pos_tensor, kp_tensor\n",
    "\n",
    "# Usage example:\n",
    "# dataset = BehaviorSequenceDataset('behavior_sequence.csv', behavior_dim=772, max_seq_len=50)\n",
    "# loader = DataLoader(dataset, batch_size=128, shuffle=True, collate_fn=collate_fn)\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
