{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a044f8aa",
   "metadata": {},
   "source": [
    "# Quickstart: Using scDataset with the Tahoe-100M Dataset\n",
    "This tutorial demonstrates how to use scDataset to load the  Tahoe-100M single-cell dataset in h5ad format and train a simple linear classifier in PyTorch."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85ba099f",
   "metadata": {},
   "source": [
    "## Downloading the Tahoe-100M h5ad Files\n",
    "You can download the Tahoe-100M dataset in h5ad format from the [Arc Institute Google Cloud Storage](https://github.com/ArcInstitute/arc-virtual-cell-atlas/blob/main/tahoe-100M/README.md) or convert it from the [HuggingFace version](https://huggingface.co/datasets/tahoebio/Tahoe-100M). For this tutorial, download one or more `plate*_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad` files and note their paths."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b6850d2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required packages (uncomment if running in a fresh environment)\n",
    "# %pip install scipy scikit-learn tqdm torch anndata scDataset \n",
    "\n",
    "# Import libraries\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm import tqdm\n",
    "from scipy import sparse\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "import anndata as ad\n",
    "from anndata.experimental import AnnCollection\n",
    "from scdataset import scDataset, Streaming, BlockShuffling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22dccdf3",
   "metadata": {},
   "source": [
    "## 1. Load the Tahoe-100M Dataset (h5ad)\n",
    "We will use the AnnData library to load one or more Tahoe-100M h5ad files and combine them into an AnnCollection for efficient access."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ff6e71e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total cells: 5481420\n",
      "Total genes: 62710\n"
     ]
    }
   ],
   "source": [
    "# Load multiple Tahoe-100M h5ad files and create an AnnCollection\n",
    "h5ad_FILES_PATH = '/path/to/h5ad_folder'  # Update this to your local folder with the h5ad files\n",
    "h5ad_paths = [f'{h5ad_FILES_PATH}/plate{i}_filt_Vevo_Tahoe100M_WServicesFrom_ParseGigalab.h5ad' for i in range(1, 2)] # Loading only 1 plate for demonstration\n",
    "\n",
    "adatas = [ad.read_h5ad(path, backed='r') for path in h5ad_paths]\n",
    "for adata in adatas:\n",
    "    # Keep 'cell_line' column\n",
    "    adata.obs = adata.obs[['cell_line']]\n",
    "collection = AnnCollection(adatas)\n",
    "print('Total cells:', collection.n_obs)\n",
    "print('Total genes:', collection.n_vars)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1560d645",
   "metadata": {},
   "source": [
    "## 2. Cell Line Recognition Task and Stratified Split\n",
    "For this tutorial, we will train a linear classifier to recognize the cell line of each cell. We will split the dataset into train and test sets using a stratified split on the `cell_line` column."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4032df7c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train size: 4385136, Test size: 1096284\n"
     ]
    }
   ],
   "source": [
    "# Get cell_line labels from the AnnCollection\n",
    "cell_lines = collection.obs['cell_line'].values\n",
    "indices = np.arange(collection.n_obs)\n",
    "\n",
    "# Stratified train/test split on cell_line\n",
    "train_idx, test_idx = train_test_split(\n",
    "    indices, stratify=cell_lines, test_size=0.2, random_state=42)\n",
    "\n",
    "print(f'Train size: {len(train_idx)}, Test size: {len(test_idx)}')"
   ]
  },
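  {
   "cell_type": "markdown",
   "id": "5c2e9a71",
   "metadata": {},
   "source": [
    "To see what the stratified split above guarantees, here is a minimal, self-contained sketch on toy data. The label names (`line_a`, `line_b`, `line_c`), the class proportions, and the `class_fractions` helper are hypothetical and are not part of Tahoe-100M or scDataset. The sketch only illustrates that `train_test_split` with `stratify=` keeps the class proportions roughly equal across splits.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from sklearn.model_selection import train_test_split\n",
    "\n",
    "# Hypothetical toy labels: three imbalanced classes standing in for cell lines\n",
    "rng = np.random.default_rng(0)\n",
    "toy_labels = rng.choice(['line_a', 'line_b', 'line_c'], size=1000, p=[0.6, 0.3, 0.1])\n",
    "toy_indices = np.arange(toy_labels.size)\n",
    "\n",
    "# Same call pattern as in the tutorial, applied to the toy labels\n",
    "toy_train, toy_test = train_test_split(\n",
    "    toy_indices, stratify=toy_labels, test_size=0.2, random_state=42)\n",
    "\n",
    "def class_fractions(labels):\n",
    "    # Fraction of samples per class, in sorted class order\n",
    "    _, counts = np.unique(labels, return_counts=True)\n",
    "    return counts / counts.sum()\n",
    "\n",
    "print('all  :', class_fractions(toy_labels).round(3))\n",
    "print('train:', class_fractions(toy_labels[toy_train]).round(3))\n",
    "print('test :', class_fractions(toy_labels[toy_test]).round(3))\n",
    "```\n",
    "\n",
    "The three printed rows should be nearly identical, which is why rare cell lines remain represented in the test set."
   ]
  },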
  {
   "cell_type": "markdown",
   "id": "237c559e",
   "metadata": {},
   "source": [
    "### Custom fetch_transform and batch_transform for AnnCollection\n",
    "\n",
    "Before we wrap our data with scDataset, it's important to know that scDataset v0.2.0 allows you to provide custom functions for handling how data is fetched and transformed. This makes it easy to adapt to different data formats and preprocessing needs, especially for large or complex datasets like Tahoe-100M.\n",
    "\n",
    "- **fetch_transform:** Ensures that each chunk from the AnnCollection is materialized as an in-memory AnnData object. This is necessary because AnnCollection chunks are lazy by default and do not load the full X matrix until `to_adata()` is called.\n",
    "- **batch_transform:** Converts each AnnData batch into a tuple `(X, y)` suitable for PyTorch training. It densifies the X matrix if it is sparse and encodes the cell line labels as integer tensors. This makes each batch ready for direct use in a PyTorch model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f97917a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare label encoder for cell lines\n",
    "cell_line_encoder = LabelEncoder()\n",
    "cell_line_encoder.fit(collection.obs['cell_line'].values)\n",
    "\n",
    "# Define fetch_transform and batch_transform for scDataset\n",
    "def fetch_transform(batch):\n",
    "    # Materialize the AnnData batch (X matrix) in memory\n",
    "    return batch.to_adata()\n",
    "\n",
    "def batch_transform(batch, cell_line_encoder=cell_line_encoder):\n",
    "    # Convert AnnData batch to (X, y) tensors for training\n",
    "    X = batch.X.astype('float32')\n",
    "    # Densify if X is a sparse matrix\n",
    "    if sparse.issparse(X):\n",
    "        X = X.toarray()\n",
    "    X = torch.from_numpy(X)\n",
    "    y = cell_line_encoder.transform(batch.obs['cell_line'].values)\n",
    "    y = torch.from_numpy(y).long()\n",
    "    return X, y"
   ]
  },
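  {
   "cell_type": "markdown",
   "id": "d3f8b6a4",
   "metadata": {},
   "source": [
    "As a quick sanity check of the densify-and-encode logic in `batch_transform`, the sketch below applies the same steps to a hypothetical 4-cell toy batch using only NumPy, SciPy, and scikit-learn. The toy matrix, the label names, and the `toy_*` variables are assumptions made for illustration; the final `torch.from_numpy` conversion is left out to keep the example light on dependencies.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "from scipy import sparse\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# Hypothetical toy batch: 4 cells x 3 genes, stored as a sparse CSR matrix\n",
    "toy_X = sparse.csr_matrix(np.array([[0, 2, 0], [1, 0, 0], [0, 0, 3], [0, 1, 1]]))\n",
    "toy_cell_lines = np.array(['line_b', 'line_a', 'line_b', 'line_c'])\n",
    "\n",
    "# Fit the encoder once on all labels, as done for the full collection\n",
    "toy_encoder = LabelEncoder().fit(toy_cell_lines)\n",
    "\n",
    "# Same steps as batch_transform: cast to float32, densify if sparse, encode labels\n",
    "dense_X = toy_X.astype('float32')\n",
    "if sparse.issparse(dense_X):\n",
    "    dense_X = dense_X.toarray()\n",
    "toy_y = toy_encoder.transform(toy_cell_lines)\n",
    "\n",
    "print(dense_X.dtype, dense_X.shape)\n",
    "print(toy_y, toy_encoder.classes_)\n",
    "```\n",
    "\n",
    "Because `LabelEncoder` assigns integer codes in sorted class order, fitting it once on the full label set keeps the codes consistent across batches."
   ]
  },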
  {
   "cell_type": "markdown",
   "id": "eb79a39e",
   "metadata": {},
   "source": [
    "## 3. Wrap the AnnCollection with scDataset\n",
    "To efficiently train and evaluate models, we use `scDataset` v0.2.0 with sampling strategies to wrap the AnnCollection and create PyTorch DataLoaders for both train and test splits.\n",
    "\n",
    "**New in v0.2.0:** scDataset now uses a strategy-based approach for sampling. Each dataset requires a sampling strategy that defines how data is accessed and in what order.\n",
    "\n",
    "For **training**, we'll use `BlockShuffling` to randomize data while maintaining some locality for better I/O performance.\n",
    "For **evaluation**, we'll use `Streaming` for deterministic, sequential access.\n",
    "\n",
    "You can tune `batch_size` and `fetch_factor` for your hardware; for best performance, set `prefetch_factor = fetch_factor + 1` in the DataLoader. The sampling strategies control how indices are generated and whether shuffling occurs. See the scDataset documentation for more details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577d13b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up scDataset for train and test splits\n",
    "batch_size = 64\n",
    "fetch_factor = 16\n",
    "num_workers = 12\n",
    "\n",
    "# Training split with block shuffling for randomization\n",
    "train_strategy = BlockShuffling(block_size=8, indices=train_idx)\n",
    "scdata_train = scDataset(\n",
    "    data_collection=collection,\n",
    "    strategy=train_strategy,\n",
    "    batch_size=batch_size,\n",
    "    fetch_factor=fetch_factor,\n",
    "    fetch_transform=fetch_transform,\n",
    "    batch_transform=batch_transform,\n",
    ")\n",
    "\n",
    "train_loader = DataLoader(\n",
    "    scdata_train,\n",
    "    batch_size=None,\n",
    "    num_workers=num_workers,\n",
    "    prefetch_factor=fetch_factor+1,\n",
    "    persistent_workers=True,\n",
    "    pin_memory=True\n",
    ")\n",
    "\n",
    "# Test split with streaming for deterministic evaluation\n",
    "test_strategy = Streaming(indices=test_idx)\n",
    "scdata_test = scDataset(\n",
    "    data_collection=collection,\n",
    "    strategy=test_strategy,\n",
    "    batch_size=batch_size,\n",
    "    fetch_factor=fetch_factor,\n",
    "    fetch_transform=fetch_transform,\n",
    "    batch_transform=batch_transform,\n",
    ")\n",
    "\n",
    "test_loader = DataLoader(\n",
    "    scdata_test,\n",
    "    batch_size=None,\n",
    "    num_workers=num_workers,\n",
    "    prefetch_factor=fetch_factor+1,\n",
    "    persistent_workers=True,\n",
    "    pin_memory=True\n",
    ")"
   ]
  },
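  {
   "cell_type": "markdown",
   "id": "e90a47bc",
   "metadata": {},
   "source": [
    "To build intuition for `BlockShuffling` and `fetch_factor`, the following pure-Python sketch mimics the general idea: shuffle contiguous blocks of indices, read several batches' worth of cells per fetch, then shuffle within the fetch before splitting it into minibatches. This is a conceptual illustration under our own assumptions (including the function name `block_shuffled_batches`, the index sorting, and the within-fetch shuffle), not scDataset's actual implementation.\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def block_shuffled_batches(indices, block_size, batch_size, fetch_factor, seed=0):\n",
    "    # Conceptual sketch only; not scDataset's actual implementation\n",
    "    rng = random.Random(seed)\n",
    "    indices = sorted(indices)\n",
    "    # 1. Cut the sorted indices into contiguous blocks and shuffle the block order\n",
    "    blocks = [indices[i:i + block_size] for i in range(0, len(indices), block_size)]\n",
    "    rng.shuffle(blocks)\n",
    "    flat = [idx for block in blocks for idx in block]\n",
    "    # 2. Group into fetches of batch_size * fetch_factor indices (one read each)\n",
    "    fetch_size = batch_size * fetch_factor\n",
    "    for start in range(0, len(flat), fetch_size):\n",
    "        fetch = flat[start:start + fetch_size]\n",
    "        # 3. Shuffle within the fetch, then split it into minibatches\n",
    "        rng.shuffle(fetch)\n",
    "        for b in range(0, len(fetch), batch_size):\n",
    "            yield fetch[b:b + batch_size]\n",
    "\n",
    "toy_batches = list(block_shuffled_batches(range(32), block_size=4, batch_size=4, fetch_factor=2))\n",
    "for batch in toy_batches[:4]:\n",
    "    print(batch)\n",
    "```\n",
    "\n",
    "In this sketch, larger `block_size` values favor contiguous reads, while smaller values bring the order closer to a fully random shuffle."
   ]
  },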
  {
   "cell_type": "markdown",
   "id": "9d83f0f4",
   "metadata": {},
   "source": [
    "## 4. Train and Evaluate a Linear Classifier for Cell Line Recognition\n",
    "We will train a simple linear classifier to predict the cell line from gene expression. The model will be trained for one epoch on the training set and evaluated on the test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44ef2629",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   0%|          | 7/68512 [00:04<9:25:04,  2.02it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 0: loss = 3.9899\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   2%|▏         | 1050/68512 [00:22<19:27, 57.77it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 1000: loss = 0.3440\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   3%|▎         | 2042/68512 [00:40<13:59, 79.20it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 2000: loss = 0.0654\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   4%|▍         | 3070/68512 [00:57<07:00, 155.47it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 3000: loss = 0.0337\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   6%|▌         | 4019/68512 [01:15<08:11, 131.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 4000: loss = 0.1021\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   7%|▋         | 5074/68512 [01:36<17:49, 59.33it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 5000: loss = 0.4971\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:   9%|▉         | 6078/68512 [01:53<13:07, 79.33it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 6000: loss = 0.1427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  10%|█         | 7046/68512 [02:11<12:39, 80.96it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 7000: loss = 0.0129\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  12%|█▏        | 8049/68512 [02:28<09:09, 110.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 8000: loss = 0.1937\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  13%|█▎        | 9013/68512 [02:46<07:25, 133.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 9000: loss = 0.3482\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  15%|█▍        | 10053/68512 [03:06<17:54, 54.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 10000: loss = 0.0561\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  16%|█▌        | 11042/68512 [03:24<13:05, 73.19it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 11000: loss = 0.0025\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  18%|█▊        | 12058/68512 [03:42<09:05, 103.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 12000: loss = 0.0513\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  19%|█▉        | 13040/68512 [03:59<07:20, 125.81it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 13000: loss = 0.5847\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  20%|██        | 13998/68512 [04:17<08:27, 107.43it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 14000: loss = 0.1122\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  22%|██▏       | 15064/68512 [04:37<14:59, 59.45it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 15000: loss = 0.4273\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  23%|██▎       | 16057/68512 [04:55<11:03, 79.10it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 16000: loss = 0.2817\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  25%|██▍       | 17074/68512 [05:13<06:52, 124.59it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 17000: loss = 0.1180\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  26%|██▋       | 18038/68512 [05:30<06:43, 124.95it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 18000: loss = 1.0325\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  28%|██▊       | 18996/68512 [05:48<06:43, 122.75it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 19000: loss = 0.0683\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  29%|██▉       | 20063/68512 [06:08<13:23, 60.33it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 20000: loss = 0.0278\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  31%|███       | 21055/68512 [06:26<09:51, 80.28it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 21000: loss = 0.2712\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  32%|███▏      | 22042/68512 [06:44<07:30, 103.09it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 22000: loss = 0.0053\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  34%|███▎      | 23025/68512 [07:01<06:25, 118.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 23000: loss = 0.1565\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  35%|███▍      | 23977/68512 [07:19<06:59, 106.27it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 24000: loss = 0.0108\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  37%|███▋      | 25058/68512 [07:40<12:04, 59.95it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 25000: loss = 0.2106\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  38%|███▊      | 26054/68512 [07:57<08:41, 81.48it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 26000: loss = 0.0147\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  39%|███▉      | 27042/68512 [08:15<06:30, 106.25it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 27000: loss = 0.3865\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  41%|████      | 28020/68512 [08:32<05:19, 126.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 28000: loss = 0.9068\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  42%|████▏     | 29047/68512 [08:53<15:08, 43.45it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 29000: loss = 0.1300\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  44%|████▍     | 30047/68512 [09:11<10:44, 59.64it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 30000: loss = 0.4086\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  45%|████▌     | 31051/68512 [09:28<07:38, 81.75it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 31000: loss = 0.0032\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  47%|████▋     | 32061/68512 [09:46<05:28, 110.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 32000: loss = 0.2427\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  48%|████▊     | 33019/68512 [10:03<05:29, 107.56it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 33000: loss = 0.5098\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  50%|████▉     | 34043/68512 [10:24<13:14, 43.40it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 34000: loss = 1.5931\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  51%|█████     | 35070/68512 [10:42<06:53, 80.96it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 35000: loss = 0.2432\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  53%|█████▎    | 36051/68512 [11:00<06:31, 82.90it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 36000: loss = 0.6859\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  54%|█████▍    | 37051/68512 [11:17<04:08, 126.49it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 37000: loss = 0.5389\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  55%|█████▌    | 38017/68512 [11:35<04:00, 126.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 38000: loss = 0.5876\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  57%|█████▋    | 39053/68512 [11:55<08:37, 56.91it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 39000: loss = 0.1008\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  58%|█████▊    | 40072/68512 [12:13<05:55, 79.97it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 40000: loss = 0.0231\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  60%|█████▉    | 41054/68512 [12:31<04:22, 104.52it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 41000: loss = 0.0160\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  61%|██████▏   | 42033/68512 [12:48<04:20, 101.71it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 42000: loss = 0.3354\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  63%|██████▎   | 42993/68512 [13:06<03:31, 120.92it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 43000: loss = 0.0277\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  64%|██████▍   | 44048/68512 [13:27<07:16, 56.06it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 44000: loss = 0.2737\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  66%|██████▌   | 45055/68512 [13:44<04:53, 79.80it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 45000: loss = 0.0011\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  67%|██████▋   | 46062/68512 [14:02<03:09, 118.53it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 46000: loss = 0.2583\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  69%|██████▊   | 47017/68512 [14:20<03:30, 102.01it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 47000: loss = 0.2095\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  70%|███████   | 47985/68512 [14:37<03:17, 104.11it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 48000: loss = 0.5802\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  72%|███████▏  | 49064/68512 [14:58<05:15, 61.57it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 49000: loss = 0.4957\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  73%|███████▎  | 50044/68512 [15:15<03:50, 79.98it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 50000: loss = 0.5093\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  74%|███████▍  | 51041/68512 [15:33<02:48, 103.41it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 51000: loss = 0.9492\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  76%|███████▌  | 52030/68512 [15:51<02:13, 123.06it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 52000: loss = 0.1337\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  77%|███████▋  | 53035/68512 [16:12<06:23, 40.35it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 53000: loss = 0.0673\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  79%|███████▉  | 54055/68512 [16:29<03:24, 70.62it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 54000: loss = 0.4136\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  80%|████████  | 55055/68512 [16:47<02:42, 82.91it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 55000: loss = 0.4310\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  82%|████████▏ | 56052/68512 [17:05<02:02, 101.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 56000: loss = 0.0508\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  83%|████████▎ | 56999/68512 [17:22<01:47, 106.94it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 57000: loss = 0.8195\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  85%|████████▍ | 58066/68512 [17:43<03:06, 56.01it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 58000: loss = 0.6672\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  86%|████████▌ | 59073/68512 [18:01<01:58, 79.81it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 59000: loss = 0.4724\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  88%|████████▊ | 60044/68512 [18:18<01:43, 81.81it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 60000: loss = 0.7914\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  89%|████████▉ | 61031/68512 [18:36<01:13, 101.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 61000: loss = 0.0464\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  90%|█████████ | 61996/68512 [18:53<01:01, 106.63it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 62000: loss = 0.3595\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  92%|█████████▏| 63065/68512 [19:14<01:32, 59.13it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 63000: loss = 0.2254\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  94%|█████████▎| 64059/68512 [19:31<00:55, 80.02it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 64000: loss = 0.2061\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  95%|█████████▍| 65055/68512 [19:49<00:33, 102.15it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 65000: loss = 0.2344\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  96%|█████████▋| 66030/68512 [20:06<00:23, 107.32it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 66000: loss = 0.0641\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  98%|█████████▊| 66985/68512 [20:24<00:14, 107.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 67000: loss = 0.0083\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training:  99%|█████████▉| 68060/68512 [20:45<00:07, 59.74it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train batch 68000: loss = 0.9251\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training: 100%|██████████| 68512/68512 [20:52<00:00, 54.72it/s] \n",
      "Evaluating: 100%|██████████| 17130/17130 [05:17<00:00, 53.97it/s] "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Test accuracy: 97.11% (1064575/1096284)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Model definition\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "num_genes = collection.n_vars\n",
    "num_classes = len(cell_line_encoder.classes_)\n",
    "model = nn.Linear(num_genes, num_classes).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "loss_fn = nn.CrossEntropyLoss()\n",
    "\n",
    "# Training loop (1 epoch)\n",
    "model.train()\n",
    "for i, (x, y) in enumerate(tqdm(train_loader, desc='Training')):\n",
    "    x = x.to(device)\n",
    "    y = y.to(device)\n",
    "    logits = model(x)\n",
    "    loss = loss_fn(logits, y)\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    if i % 1000 == 0:\n",
    "        print(f'Train batch {i}: loss = {loss.item():.4f}')\n",
    "\n",
    "# Evaluation loop (one pass over test set)\n",
    "model.eval()\n",
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for x, y in tqdm(test_loader, desc='Evaluating'):\n",
    "        x = x.to(device)\n",
    "        y = y.to(device)\n",
    "        logits = model(x)\n",
    "        preds = torch.argmax(logits, dim=1)\n",
    "        correct += (preds == y).sum().item()\n",
    "        total += y.size(0)\n",
    "print(f'Test accuracy: {100.0 * correct / total:.2f}% ({correct}/{total})')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f503073",
   "metadata": {},
   "source": [
    "## 5. Summary and Next Steps\n",
    "\n",
    "You have now seen how to load the Tahoe-100M dataset, perform a stratified split, and train/test a linear classifier for cell line recognition using scDataset v0.2.0. \n",
    "\n",
    "**Key features demonstrated:**\n",
    "- **Strategy-based sampling**: Used `BlockShuffling` for training randomization and `Streaming` for deterministic evaluation\n",
    "- **Custom transforms**: Applied `fetch_transform` and `batch_transform` to handle AnnCollection data format\n",
    "- **Efficient data loading**: Configured `fetch_factor` and multiprocessing for optimal performance\n",
    "\n",
    "**Try exploring other sampling strategies:**\n",
    "- `ClassBalancedSampling` for handling imbalanced cell type distributions\n",
    "- `BlockWeightedSampling` for custom sample weighting\n",
    "- `MultiIndexable` for handling multi-modal data (gene expression + protein data)\n",
    "\n",
    "For more advanced usage and complete API documentation, see the scDataset source code."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "sctest",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.17"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
