{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b9e2d729",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MPS available: False\n",
      "CUDA version: 11.7\n",
      "MPS not available — running on CPU instead.\n",
      "CPU cores: 48\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "import argparse\n",
    "import torch\n",
    "from scipy.sparse import issparse\n",
    "import numpy as np\n",
    "import anndata as ad\n",
    "import pandas as pd\n",
    "\n",
    "print(\"MPS available:\", torch.backends.mps.is_available())\n",
    "print(\"CUDA version:\", torch.version.cuda)\n",
    "\n",
    "# Show MPS device info safely\n",
    "if torch.backends.mps.is_available():\n",
    "    print(\"MPS backend is active on macOS (Metal).\")\n",
    "    print(\"Device: Apple GPU via Metal Performance Shaders (MPS).\")\n",
    "else:\n",
    "    print(\"MPS not available — running on CPU instead.\")\n",
    "\n",
    "print(\"CPU cores:\", os.cpu_count())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe4f5248",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "AnnData object with n_obs × n_vars = 762039 × 1000\n",
      "    obs: 'size_factor', 'cell_type', 'replicate', 'dose', 'drug_code', 'pathway_level_1', 'pathway_level_2', 'product_name', 'target', 'pathway', 'drug', 'drug-dose', 'drug_code-dose', 'n_genes'\n",
      "    var: 'gene_short_name', 'n_cells', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'\n",
      "    uns: 'hvg', 'pca', 'rank_genes_groups'\n",
      "    obsm: 'X_pca'\n",
      "    varm: 'PCs', 'marker_genes-drug-rank', 'marker_genes-drug-score'\n",
      "Shape: (762039, 1000)\n",
      "Observation columns: ['size_factor', 'cell_type', 'replicate', 'dose', 'drug_code', 'pathway_level_1', 'pathway_level_2', 'product_name', 'target', 'pathway']\n",
      "Feature columns: ['ENSG00000243620.1', 'ENSG00000271503.5', 'ENSG00000259124.1', 'ENSG00000121101.15', 'ENSG00000160963.13', 'ENSG00000135346.8', 'ENSG00000143839.14', 'ENSG00000100867.14', 'ENSG00000140986.7', 'ENSG00000230666.5']\n",
      "Unique conditions: ['tak_901', 'ag_490', 'abexinostat', 'alisertib', 'busulfan', ..., 'ac480', 'cediranib', 'thiotepa', 'cerdulatinib', 'zileuton']\n",
      "Length: 189\n",
      "Categories (189, object): ['2_methoxyestradiol', '_jq1', 'a_366', 'abt_737', ..., 'xav_939', 'ym155', 'zm_447439', 'zileuton']\n",
      "Number of cells per condition:\n",
      "control              17565\n",
      "ellagic_acid          6255\n",
      "divalproex_sodium     6198\n",
      "ruxolitinib           6139\n",
      "mc1568                6120\n",
      "                     ...  \n",
      "alvespimycin_hcl      2085\n",
      "patupilone            1798\n",
      "flavopiridol_hcl      1703\n",
      "epothilone_a          1401\n",
      "ym155                  989\n",
      "Name: drug, Length: 189, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "import scanpy as sc\n",
    "# Load your dataset\n",
    "adata_sc = sc.read_h5ad(\"./scrna-sciplex3/hvg.h5ad\")\n",
    "\n",
    "# Basic overview\n",
    "print(adata_sc)\n",
    "print(\"Shape:\", adata_sc.shape)\n",
    "\n",
    "# View column names (metadata about each cell)\n",
    "print(\"Observation columns:\", adata_sc.obs.columns.tolist()[:10])\n",
    "print(\"Feature columns:\", adata_sc.var_names[:10].tolist())\n",
    "\n",
    "# How many drugs (conditions)?\n",
    "print(\"Unique conditions:\", adata_sc.obs['drug'].unique())\n",
    "print(\"Number of cells per condition:\")\n",
    "print(adata_sc.obs['drug'].value_counts())\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4b97042f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "import argparse\n",
    "import geomloss\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from typing import Dict, Tuple, List, Optional\n",
    "from umap import UMAP\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "from typing import Dict, Tuple, List\n",
    "from scipy.stats import ks_2samp\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "import gc\n",
    "gc.collect()\n",
    "\n",
    "def median_heuristic_gamma(X: np.ndarray, Y: np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Median heuristic for RBF bandwidth: gamma = 1 / median(||x - y||^2).\n",
    "    Uses the median of pairwise distances in the pooled set.\n",
    "    \"\"\"\n",
    "    Z = np.vstack([X, Y])\n",
    "    # Sample if too large for efficiency\n",
    "    max_samples = 5000\n",
    "    if Z.shape[0] > max_samples:\n",
    "        idx = np.random.choice(Z.shape[0], size=max_samples, replace=False)\n",
    "        Z = Z[idx]\n",
    "    D2 = cdist(Z, Z, metric='sqeuclidean')\n",
    "    # Use upper triangular without diagonal\n",
    "    triu = D2[np.triu_indices_from(D2, k=1)]\n",
    "    med = np.median(triu[triu > 0]) if np.any(triu > 0) else 1.0\n",
    "    return 1.0 / max(med, 1e-12)\n",
    "\n",
    "def mmd_distance(X: np.ndarray, Y: np.ndarray, gamma: float) -> float:\n",
    "    \"\"\"\n",
    "    Unbiased MMD^2 estimator using Gaussian (RBF) kernel, sklearn backend.\n",
    "\n",
    "    Args:\n",
    "        X: (n_samples, n_features) first sample\n",
    "        Y: (m_samples, n_features) second sample\n",
    "        gamma: RBF kernel bandwidth; if None, uses median heuristic\n",
    "\n",
    "    Returns:\n",
    "        Unbiased MMD^2 value\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    m = Y.shape[0]\n",
    "\n",
    "    # Kernel matrices\n",
    "    Kxx = rbf_kernel(X, X, gamma=gamma)\n",
    "    Kyy = rbf_kernel(Y, Y, gamma=gamma)\n",
    "    Kxy = rbf_kernel(X, Y, gamma=gamma)\n",
    "\n",
    "    # Unbiased: exclude diagonal entries\n",
    "    np.fill_diagonal(Kxx, 0.0)\n",
    "    np.fill_diagonal(Kyy, 0.0)\n",
    "\n",
    "    term_xx = Kxx.sum() / (n * (n - 1)) if n > 1 else 0.0\n",
    "    term_yy = Kyy.sum() / (m * (m - 1)) if m > 1 else 0.0\n",
    "    term_xy = 2.0 * Kxy.mean()\n",
    "\n",
    "    mmd2 = term_xx + term_yy - term_xy\n",
    "    mmd2 = max(mmd2, 0.0)  # Numerical stability\n",
    "    return float(mmd2)\n",
    "\n",
    "def r2_feature_means(y_true: np.ndarray, y_pred: np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    R^2 computed across features between mean vectors of y_true and y_pred.\n",
    "    \"\"\"\n",
    "    mu_true = y_true.mean(axis=0)\n",
    "    mu_pred = y_pred.mean(axis=0)\n",
    "    ss_res = float(np.sum((mu_true - mu_pred) ** 2))\n",
    "    ss_tot = float(np.sum((mu_true - mu_true.mean()) ** 2))\n",
    "    if ss_tot <= 1e-12:\n",
    "        return 1.0 if ss_res <= 1e-12 else 0.0\n",
    "    return 1.0 - ss_res / ss_tot\n",
    "\n",
    "def wasserstein_pointcloud(\n",
    "    X,\n",
    "    Y,\n",
    "    p: int = 2,\n",
    "    a=None,\n",
    "    b=None,\n",
    "    method: str = \"emd\",          # \"emd\" (exact) or \"sinkhorn\" (approx)\n",
    "    reg: float = 1e-1,            # Sinkhorn regularization (only used if method=\"sinkhorn\")\n",
    "    return_plan: bool = False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute Wasserstein distance W_p between two empirical distributions supported on point sets X and Y.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (n, d) array-like\n",
    "        Source points.\n",
    "    Y : (m, d) array-like\n",
    "        Target points.\n",
    "    p : int\n",
    "        Order of the Wasserstein distance (commonly 1 or 2).\n",
    "    a : (n,) array-like or None\n",
    "        Weights for X; if None, uniform weights.\n",
    "    b : (m,) array-like or None\n",
    "        Weights for Y; if None, uniform weights.\n",
    "    method : str\n",
    "        \"emd\" for exact optimal transport (requires POT),\n",
    "        \"sinkhorn\" for entropic approximation (requires POT).\n",
    "    reg : float\n",
    "        Entropic regularization strength for Sinkhorn.\n",
    "    return_plan : bool\n",
    "        If True, also return the optimal transport plan.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Wp : float\n",
    "        Wasserstein distance of order p.\n",
    "    plan : (n, m) ndarray, optional\n",
    "        Optimal transport plan (only if return_plan=True).\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, dtype=np.float64)\n",
    "    Y = np.asarray(Y, dtype=np.float64)\n",
    "    if X.ndim != 2 or Y.ndim != 2:\n",
    "        raise ValueError(\"X and Y must be 2D arrays with shape (n, d) and (m, d).\")\n",
    "    if X.shape[1] != Y.shape[1]:\n",
    "        raise ValueError(f\"Dimension mismatch: X has d={X.shape[1]}, Y has d={Y.shape[1]}.\")\n",
    "\n",
    "    n, d = X.shape\n",
    "    m, _ = Y.shape\n",
    "\n",
    "    if a is None:\n",
    "        a = np.full(n, 1.0 / n, dtype=np.float64)\n",
    "    else:\n",
    "        a = np.asarray(a, dtype=np.float64)\n",
    "        a = a / a.sum()\n",
    "\n",
    "    if b is None:\n",
    "        b = np.full(m, 1.0 / m, dtype=np.float64)\n",
    "    else:\n",
    "        b = np.asarray(b, dtype=np.float64)\n",
    "        b = b / b.sum()\n",
    "\n",
    "    # Cost matrix: C_ij = ||x_i - y_j||^p\n",
    "    # Compute squared Euclidean via (x-y)^2 = x^2 + y^2 - 2xy for speed\n",
    "    X2 = np.sum(X * X, axis=1, keepdims=True)          # (n, 1)\n",
    "    Y2 = np.sum(Y * Y, axis=1, keepdims=True).T        # (1, m)\n",
    "    sq = np.maximum(X2 + Y2 - 2.0 * (X @ Y.T), 0.0)     # (n, m)\n",
    "    if p == 2:\n",
    "        C = sq\n",
    "    else:\n",
    "        C = sq ** (p / 2.0)\n",
    "\n",
    "    try:\n",
    "        import ot  # POT: Python Optimal Transport\n",
    "    except ImportError as e:\n",
    "        raise ImportError(\n",
    "            \"This function requires the POT library. Install with: pip install pot\"\n",
    "        ) from e\n",
    "\n",
    "    method = method.lower()\n",
    "    if method == \"emd\":\n",
    "        # exact OT: minimizes <P, C>\n",
    "        P = ot.emd(a, b, C)\n",
    "        cost = float(np.sum(P * C))\n",
    "    elif method == \"sinkhorn\":\n",
    "        # entropic OT approximation\n",
    "        P = ot.sinkhorn(a, b, C, reg=reg)\n",
    "        cost = float(np.sum(P * C))\n",
    "    else:\n",
    "        raise ValueError('method must be either \"emd\" or \"sinkhorn\".')\n",
    "\n",
    "    Wp = cost ** (1.0 / p)\n",
    "\n",
    "    if return_plan:\n",
    "        return Wp, P\n",
    "    return Wp\n",
    "\n",
    "def summarize_metrics(y_true: np.ndarray, y_pred: np.ndarray, median_gamma: float) -> dict:\n",
    "    \"\"\"\n",
    "    Compute a standard set of metrics: MMD^2 (RBF), R^2 of feature means, median KS across features, and Wasserstein distance.\n",
    "    \"\"\"\n",
    "    # Drop any samples that contain NaNs in either true or pred\n",
    "    mask = (~np.isnan(y_true).any(axis=1)) & (~np.isnan(y_pred).any(axis=1))\n",
    "    if mask.sum() < len(y_true):\n",
    "        print(f\"[summarize_metrics] Dropping {len(y_true) - mask.sum()} samples with NaNs.\")\n",
    "    \n",
    "    y_true = y_true[mask]\n",
    "    y_pred = y_pred[mask]\n",
    "\n",
    "    out = {}\n",
    "\n",
    "    out['mmd2_gamma_median'] = mmd_distance(y_true, y_pred, gamma=median_gamma)\n",
    "    out['mmd2_gamma_0.5'] = mmd_distance(y_true, y_pred, gamma=0.5)\n",
    "    out['mmd2_gamma_1.0'] = mmd_distance(y_true, y_pred, gamma=1.0)\n",
    "    out['wasserstein_distance'] = wasserstein_pointcloud(y_true, y_pred, p=2, method=\"emd\")\n",
    "    out['R2_feature_means'] = r2_feature_means(y_true, y_pred)\n",
    "    return out\n",
    "\n",
    "def split_train_test(X: np.ndarray, Y: np.ndarray, train_fraction: float, seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n",
    "    if X.shape[0] != Y.shape[0]:\n",
    "        min_len = min(len(X), len(Y))\n",
    "        X = X[:min_len]\n",
    "        Y = Y[:min_len]\n",
    "\n",
    "    n = X.shape[0]\n",
    "    n_train = max(1, int(n * train_fraction))\n",
    "    rng = np.random.default_rng(seed)\n",
    "    idx = rng.permutation(n)\n",
    "    tr_idx, te_idx = idx[:n_train], idx[n_train:]\n",
    "    return X[tr_idx], X[te_idx], Y[tr_idx], Y[te_idx]\n",
    "\n",
    "def topk_markers(adata, drug: str, k: int = 50, rank_key: str = \"marker_genes-drug-rank\"):\n",
    "    R = adata.varm[rank_key]\n",
    "\n",
    "    # --- get the rank vector for this drug ---\n",
    "    if hasattr(R, \"columns\") and hasattr(R, \"iloc\"):  # pandas DataFrame\n",
    "        if drug in R.columns:\n",
    "            r = R[drug].to_numpy()\n",
    "        else:\n",
    "            # fallback: interpret columns as ordered groups; try to map via rank_genes_groups names\n",
    "            names = adata.uns[\"rank_genes_groups\"][\"names\"]\n",
    "            groups = list(names.dtype.names) if (hasattr(names, \"dtype\") and names.dtype.names is not None) else list(names.columns)\n",
    "            r = R.iloc[:, groups.index(drug)].to_numpy()\n",
    "    else:  # numpy array (or array-like)\n",
    "        names = adata.uns[\"rank_genes_groups\"][\"names\"]\n",
    "        groups = list(names.dtype.names) if (hasattr(names, \"dtype\") and names.dtype.names is not None) else list(names.columns)\n",
    "        r = np.asarray(R)[:, groups.index(drug)]\n",
    "\n",
    "    # smaller rank => stronger marker\n",
    "    idx = np.argsort(r)[:k]\n",
    "    gene_ids = adata.var_names[idx].to_list()\n",
    "    gene_short = (adata.var.iloc[idx][\"gene_short_name\"].to_list()\n",
    "                  if \"gene_short_name\" in adata.var.columns else None)\n",
    "    return gene_ids, gene_short, idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2d497048",
   "metadata": {},
   "outputs": [],
   "source": [
    "def SCGEN(\n",
    "    X_tr_pre, Y_tr_post, X_te_pre, Y_te_post,\n",
    "    max_epochs=1000,\n",
    "    batch_size=64,\n",
    "    early_stopping=True,\n",
    "    early_stopping_patience=50,\n",
    "    condition_key=\"condition\",\n",
    "    ctrl_label=\"control\",\n",
    "    stim_label=\"treated\",\n",
    "    cell_type_key=\"cell_type\",\n",
    "    cell_type_label=\"cell\",\n",
    "    n_hidden=256,\n",
    "    n_latent=100,\n",
    "    n_layers=2,\n",
    "    dropout_rate=0.2,\n",
    "    accelerator=\"auto\",   # \"auto\" | \"cpu\" | \"gpu\" | \"mps\"\n",
    "    learning_rate=5e-5,\n",
    "    seed=12345,\n",
    "    verbose=True,\n",
    "    metrics_fn=None,      # e.g., summarize_metrics(y_true, y_pred)\n",
    "    top_feature_subset= None,\n",
    "):\n",
    "    import numpy as np\n",
    "    import pandas as pd\n",
    "    import random\n",
    "\n",
    "    try:\n",
    "        import anndata as ad\n",
    "        import torch\n",
    "        import scgen\n",
    "        try:\n",
    "            import scvi\n",
    "        except Exception:\n",
    "            scvi = None\n",
    "    except Exception as e:\n",
    "        msg = str(e)\n",
    "        hint = None\n",
    "        if \"scvi._compat\" in msg:\n",
    "            hint = (\n",
    "                \"Likely scGen (PyPI) vs scvi-tools mismatch. \"\n",
    "                \"Try: pip uninstall -y scgen && pip install git+https://github.com/theislab/scgen.git\"\n",
    "            )\n",
    "        return {\"error\": f\"Import failed: {e}\", \"hint\": hint}\n",
    "\n",
    "    # ----------- Validate shapes -----------\n",
    "    X_tr_pre = np.asarray(X_tr_pre); Y_tr_post = np.asarray(Y_tr_post)\n",
    "    X_te_pre = np.asarray(X_te_pre); Y_te_post = np.asarray(Y_te_post)\n",
    "\n",
    "    if any(a.ndim != 2 for a in [X_tr_pre, Y_tr_post, X_te_pre, Y_te_post]):\n",
    "        return {\"error\": \"All inputs must be 2D arrays (n_cells, n_features).\"}\n",
    "    d = X_tr_pre.shape[1]\n",
    "    if any(a.shape[1] != d for a in [Y_tr_post, X_te_pre, Y_te_post]):\n",
    "        return {\"error\": \"Feature dimension mismatch across inputs.\"}\n",
    "\n",
    "    # ----------- Seeds -----------\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "    try:\n",
    "        torch.manual_seed(seed)\n",
    "        if torch.cuda.is_available():\n",
    "            torch.cuda.manual_seed_all(seed)\n",
    "    except Exception:\n",
    "        pass\n",
    "    if scvi is not None:\n",
    "        try:\n",
    "            scvi.settings.seed = seed\n",
    "        except Exception:\n",
    "            pass\n",
    "\n",
    "    # ----------- Accelerator selection -----------\n",
    "    def _auto_accel():\n",
    "        if torch.cuda.is_available():\n",
    "            return \"cuda\"\n",
    "        if hasattr(torch.backends, \"mps\") and torch.backends.mps.is_available():\n",
    "            return \"mps\"\n",
    "        return \"cpu\"\n",
    "    accelerator_used = _auto_accel() if accelerator == \"auto\" else accelerator\n",
    "\n",
    "    # ----------- Build training AnnData -----------\n",
    "    try:\n",
    "        X_tr_all = np.vstack([X_tr_pre, Y_tr_post]).astype(np.float32, copy=False)\n",
    "        cond_tr = np.array([ctrl_label] * len(X_tr_pre) + [stim_label] * len(Y_tr_post), dtype=object)\n",
    "        ctype_tr = np.array([cell_type_label] * len(X_tr_all), dtype=object)\n",
    "\n",
    "        ad_tr = ad.AnnData(\n",
    "            X_tr_all,\n",
    "            obs=pd.DataFrame({condition_key: cond_tr, cell_type_key: ctype_tr})\n",
    "        )\n",
    "        # avoid \"Observation names are not unique\"\n",
    "        ad_tr.obs_names = [f\"tr_{i}\" for i in range(ad_tr.n_obs)]\n",
    "        ad_tr.var_names = [f\"g_{j}\" for j in range(ad_tr.n_vars)]\n",
    "\n",
    "        scgen.SCGEN.setup_anndata(ad_tr, batch_key=condition_key, labels_key=cell_type_key)\n",
    "\n",
    "        model = scgen.SCGEN(\n",
    "            ad_tr,\n",
    "            n_hidden=n_hidden,\n",
    "            n_latent=n_latent,\n",
    "            n_layers=n_layers,\n",
    "            dropout_rate=dropout_rate,\n",
    "        )\n",
    "\n",
    "        train_kwargs = dict(\n",
    "            max_epochs=int(max_epochs),\n",
    "            batch_size=int(batch_size),\n",
    "            early_stopping=bool(early_stopping),\n",
    "            early_stopping_patience=int(early_stopping_patience),\n",
    "            enable_progress_bar=bool(verbose),\n",
    "            plan_kwargs={\"lr\": float(learning_rate)},\n",
    "        )\n",
    "\n",
    "\n",
    "        model.train(**train_kwargs)\n",
    "\n",
    "    except Exception as e:\n",
    "        return {\"error\": f\"scGen training failed: {e}\"}\n",
    "\n",
    "    # ----------- Predict treated for test controls -----------\n",
    "\n",
    "    ad_te = ad.AnnData(\n",
    "        X_te_pre.astype(np.float32, copy=False),\n",
    "        obs=pd.DataFrame({\n",
    "            condition_key: np.array([ctrl_label] * len(X_te_pre), dtype=object),\n",
    "            cell_type_key: np.array([cell_type_label] * len(X_te_pre), dtype=object),\n",
    "        })\n",
    "    )\n",
    "    ad_te.obs_names = [f\"te_{i}\" for i in range(ad_te.n_obs)]\n",
    "    ad_te.var_names = ad_tr.var_names.copy()  # ensure exact match\n",
    "\n",
    "    # IMPORTANT: pass ONLY adata_to_predict OR celltype_to_predict (not both). :contentReference[oaicite:3]{index=3}\n",
    "    pred_adata, delta = model.predict(\n",
    "        ctrl_key=ctrl_label,\n",
    "        stim_key=stim_label,\n",
    "        adata_to_predict=ad_te,\n",
    "    )\n",
    "\n",
    "    y_pred = np.asarray(pred_adata.X)\n",
    "\n",
    "    return {\n",
    "        \"y_pred\": y_pred,\n",
    "        \"delta\": delta,\n",
    "        \"model\": model,\n",
    "        \"adata_train\": ad_tr,\n",
    "        \"accelerator_used\": accelerator_used,\n",
    "    }\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "7e11a61c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_pre cells: (17565, 1000)\n",
      "X_post cells: (3277, 1000)\n",
      "(2621, 1000)\n",
      "(656, 1000)\n",
      "(2621, 1000)\n",
      "(656, 1000)\n",
      "Median heuristic gamma: 0.0022819381995053574\n"
     ]
    }
   ],
   "source": [
    "drug = \"trametinib\"\n",
    "X_pre = adata_sc[adata_sc.obs[\"drug\"] == \"control\"].copy().to_df()\n",
    "X_post  = adata_sc[adata_sc.obs[\"drug\"] == drug].copy().to_df()\n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "top_genes_ids, top_genes_short, top_genes_idx = topk_markers(adata_sc, drug, k=100)\n",
    "\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre.values, X_post.values, 0.8)\n",
    "\n",
    "print(X_tr_pre.shape)\n",
    "print(X_te_pre.shape)\n",
    "print(Y_tr_post.shape)\n",
    "print(Y_te_post.shape)\n",
    "\n",
    "\n",
    "# Compute median heuristic gamma on training data\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b2a6a182",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pytorch_lightning/utilities/imports.py:22: UserWarning: pkg_resources is deprecated as an API. See https://setuptools.pypa.io/en/latest/pkg_resources.html. The pkg_resources package is slated for removal as early as 2025-11-30. Refrain from using this package or pin to Setuptools<81.\n",
      "  import pkg_resources\n",
      "Global seed set to 0\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:53: LightningDeprecationWarning: pytorch_lightning.utilities.warnings.rank_zero_deprecation has been deprecated in v1.6 and will be removed in v1.8. Use the equivalent function from the pytorch_lightning.utilities.rank_zero module instead.\n",
      "  new_rank_zero_deprecation(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pytorch_lightning/utilities/warnings.py:58: LightningDeprecationWarning: The `pytorch_lightning.loggers.base.rank_zero_experiment` is deprecated in v1.7 and will be removed in v1.9. Please use `pytorch_lightning.loggers.logger.rank_zero_experiment` instead.\n",
      "  return new_rank_zero_deprecation(*args, **kwargs)\n",
      "Global seed set to 1234\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:05<01:42,  9.28it/s, loss=2.03, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 58.538. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1235\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.007178729269851303, 'mmd2_gamma_0.5': 0.0001868965798021635, 'mmd2_gamma_1.0': 0.00011129798702495092, 'wasserstein_distance': 6.985626268083184, 'R2_feature_means': 0.5457381494170639}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:11<03:27,  4.56it/s, loss=1.97, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 56.453. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1236\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0047188137390388185, 'mmd2_gamma_0.5': 0.0001846521954046341, 'mmd2_gamma_1.0': 0.00010360968797296853, 'wasserstein_distance': 6.742930905710108, 'R2_feature_means': 0.6446185686977057}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:37,  9.74it/s, loss=1.92, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 56.909. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1237\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004173577970420617, 'mmd2_gamma_0.5': 0.00019756758748588378, 'mmd2_gamma_1.0': 0.00011270277494056333, 'wasserstein_distance': 6.607497451304653, 'R2_feature_means': 0.6998553332801385}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:10<03:18,  4.79it/s, loss=1.95, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 62.707. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1238\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0038278843565950904, 'mmd2_gamma_0.5': 0.00020841862107608025, 'mmd2_gamma_1.0': 0.0001383787414920703, 'wasserstein_distance': 6.697291616290857, 'R2_feature_means': 0.759096740436977}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▊                                                                                                                                                                                                          | 51/1000 [00:05<01:41,  9.36it/s, loss=2, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 58.167. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1239\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004795815562473882, 'mmd2_gamma_0.5': 0.00020799117470367344, 'mmd2_gamma_1.0': 0.00012692916455571223, 'wasserstein_distance': 6.71239886593147, 'R2_feature_means': 0.657689323858536}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:06<01:55,  8.18it/s, loss=1.94, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 53.493. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1240\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0033694648716040554, 'mmd2_gamma_0.5': 0.0002258175134557289, 'mmd2_gamma_1.0': 0.00015782186144102766, 'wasserstein_distance': 6.569540010814972, 'R2_feature_means': 0.7946270407407584}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:05<01:41,  9.36it/s, loss=1.93, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 56.379. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1241\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004707698408306982, 'mmd2_gamma_0.5': 0.00018365277210250506, 'mmd2_gamma_1.0': 0.00010626050571283251, 'wasserstein_distance': 6.697203585448582, 'R2_feature_means': 0.6407262811478182}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:36,  9.84it/s, loss=1.91, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 57.695. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1242\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.005326260292558782, 'mmd2_gamma_0.5': 0.0002151046770283568, 'mmd2_gamma_1.0': 0.00014145401287116448, 'wasserstein_distance': 6.77764288584333, 'R2_feature_means': 0.6665500119827215}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:11<03:38,  4.34it/s, loss=1.95, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 57.150. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "Metrics: {'mmd2_gamma_median': 0.005552639150681493, 'mmd2_gamma_0.5': 0.00014691163920009314, 'mmd2_gamma_1.0': 8.151732842594239e-05, 'wasserstein_distance': 6.697836638807177, 'R2_feature_means': 0.6285045488811316}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1243\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:05<01:42,  9.22it/s, loss=1.95, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 60.970. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "Metrics: {'mmd2_gamma_median': 0.005411204961486504, 'mmd2_gamma_0.5': 0.00018162958840270817, 'mmd2_gamma_1.0': 0.00012288149587622857, 'wasserstein_distance': 6.839163804927917, 'R2_feature_means': 0.5724881752103128}\n",
      "=== Metrics Summary over Runs for top 100 genes ===\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0049  0.0011\n",
      "mmd2_gamma_0.5        0.0002  0.0000\n",
      "mmd2_gamma_1.0        0.0001  0.0000\n",
      "wasserstein_distance  6.7327  0.1173\n",
      "R2_feature_means      0.6610  0.0759\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n"
     ]
    }
   ],
   "source": [
    "# Multiple runs for robustness\n",
    "all_metrics = []\n",
    "for run in range(10):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    seed = 1234 + run\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    out = SCGEN(X_tr_pre[:, top_genes_idx], Y_tr_post[:, top_genes_idx], X_te_pre[:, top_genes_idx], Y_te_post[:, top_genes_idx], max_epochs=1000,early_stopping_patience=50,learning_rate=1e-3, batch_size= 256, early_stopping=20, dropout_rate=0, seed=seed)\n",
    "    metrics = summarize_metrics(out[\"y_pred\"][:, :], Y_te_post[:, top_genes_idx], median_gamma)\n",
    "    print(\"Metrics:\", metrics)\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary\n",
    "df = pd.DataFrame(all_metrics)\n",
    "print(\"=== Metrics Summary over Runs for top 100 genes ===\")\n",
    "print(df.describe().T[['mean', 'std']].round(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a06d14a5",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4d17c06",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "9d031247",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_pre cells: (17565, 1000)\n",
      "X_post cells: (3541, 1000)\n",
      "(2832, 1000)\n",
      "(709, 1000)\n",
      "(2832, 1000)\n",
      "(709, 1000)\n",
      "Median heuristic gamma: 0.0022144700418584803\n"
     ]
    }
   ],
   "source": [
    "drug = \"givinostat\"\n",
    "X_pre = adata_sc[adata_sc.obs[\"drug\"] == \"control\"].copy().to_df()\n",
    "X_post  = adata_sc[adata_sc.obs[\"drug\"] == drug].copy().to_df()\n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "top_genes_ids, top_genes_short, top_genes_idx = topk_markers(adata_sc, drug, k=100)\n",
    "\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre.values, X_post.values, 0.8)\n",
    "\n",
    "print(X_tr_pre.shape)\n",
    "print(X_te_pre.shape)\n",
    "print(Y_tr_post.shape)\n",
    "print(Y_te_post.shape)\n",
    "\n",
    "\n",
    "# Compute median heuristic gamma on training data\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "822b9837",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1234\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:12<03:48,  4.15it/s, loss=2.13, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 66.223. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1235\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.006313466447630756, 'mmd2_gamma_0.5': 0.0003691429917613674, 'mmd2_gamma_1.0': 0.0001747245901470099, 'wasserstein_distance': 7.534902539956839, 'R2_feature_means': 0.8026611892731065}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:04<01:23, 11.30it/s, loss=2.06, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 67.438. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1236\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004663395514790558, 'mmd2_gamma_0.5': 0.000359719027614816, 'mmd2_gamma_1.0': 0.00017151946775517663, 'wasserstein_distance': 7.380270689787843, 'R2_feature_means': 0.860467307522756}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:45,  8.99it/s, loss=2.08, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 66.525. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1237\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.006980319464012785, 'mmd2_gamma_0.5': 0.00041610557033485373, 'mmd2_gamma_1.0': 0.00018827332024264802, 'wasserstein_distance': 7.648671190433756, 'R2_feature_means': 0.7692648210854474}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:06<01:55,  8.21it/s, loss=2.04, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 68.106. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1238\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004851224895651818, 'mmd2_gamma_0.5': 0.00037470371796055277, 'mmd2_gamma_1.0': 0.00017604734193948208, 'wasserstein_distance': 7.474743400960504, 'R2_feature_means': 0.846642165547775}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:05<01:48,  8.71it/s, loss=2.11, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 64.038. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1239\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004407899319865738, 'mmd2_gamma_0.5': 0.00036076560370369617, 'mmd2_gamma_1.0': 0.00017331953281213438, 'wasserstein_distance': 7.432771916337902, 'R2_feature_means': 0.8534023356363849}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▊                                                                                                                                                                                                        | 51/1000 [00:04<01:27, 10.88it/s, loss=2.1, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 64.045. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1240\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0061898199275296495, 'mmd2_gamma_0.5': 0.00035566485083692815, 'mmd2_gamma_1.0': 0.00017108019600011414, 'wasserstein_distance': 7.519769181974495, 'R2_feature_means': 0.818496027809167}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:48,  8.77it/s, loss=2.13, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 63.163. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1241\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0055403603352182085, 'mmd2_gamma_0.5': 0.00034416812584166126, 'mmd2_gamma_1.0': 0.00016361235961306143, 'wasserstein_distance': 7.54134069342653, 'R2_feature_means': 0.804210689611502}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:49,  8.68it/s, loss=2.12, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 65.321. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1242\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0038562006694771167, 'mmd2_gamma_0.5': 0.00034013512940862464, 'mmd2_gamma_1.0': 0.00016386140971805376, 'wasserstein_distance': 7.3244523983155165, 'R2_feature_means': 0.8850607502390836}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52/1000:   5%|██████████▉                                                                                                                                                                                                       | 52/1000 [00:05<01:42,  9.22it/s, loss=2.16, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 65.661. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1243\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.006372802771999808, 'mmd2_gamma_0.5': 0.00034521305617024774, 'mmd2_gamma_1.0': 0.00016882490760494507, 'wasserstein_distance': 7.49713472030701, 'R2_feature_means': 0.8277694693738168}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▊                                                                                                                                                                                                        | 51/1000 [00:05<01:46,  8.87it/s, loss=2.1, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 63.642. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "Metrics: {'mmd2_gamma_median': 0.005975560137400082, 'mmd2_gamma_0.5': 0.00040123210785285844, 'mmd2_gamma_1.0': 0.00018048788729635093, 'wasserstein_distance': 7.5188287049999865, 'R2_feature_means': 0.7931495874851591}\n",
      "=== Metrics Summary over Runs for top 100 genes ===\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0055  0.0010\n",
      "mmd2_gamma_0.5        0.0004  0.0000\n",
      "mmd2_gamma_1.0        0.0002  0.0000\n",
      "wasserstein_distance  7.4873  0.0910\n",
      "R2_feature_means      0.8261  0.0353\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n"
     ]
    }
   ],
   "source": [
    "# Repeat training and evaluation over several seeds to gauge run-to-run variability.\n",
    "# Each run retrains SCGEN on the columns selected by `top_genes_idx` and scores the\n",
    "# prediction against the same columns of the held-out `Y_te_post`.\n",
    "# NOTE(review): `early_stopping=20` is passed together with `early_stopping_patience=50`;\n",
    "# the saved training log reports a patience of 50 records, so confirm what\n",
    "# `early_stopping` is expected to be in SCGEN's signature.\n",
    "all_metrics = []\n",
    "for run in range(10):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    seed = 1234 + run\n",
    "    # Seed the Python, NumPy and PyTorch (CPU + CUDA) generators for this run.\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    out = SCGEN(\n",
    "        X_tr_pre[:, top_genes_idx],\n",
    "        Y_tr_post[:, top_genes_idx],\n",
    "        X_te_pre[:, top_genes_idx],\n",
    "        Y_te_post[:, top_genes_idx],\n",
    "        max_epochs=1000,\n",
    "        early_stopping_patience=50,\n",
    "        learning_rate=1e-3,\n",
    "        batch_size=256,\n",
    "        early_stopping=20,\n",
    "        dropout_rate=0,\n",
    "        seed=seed,\n",
    "    )\n",
    "    # All columns of the prediction are scored against the selected columns of the test target.\n",
    "    metrics = summarize_metrics(out[\"y_pred\"], Y_te_post[:, top_genes_idx], median_gamma)\n",
    "    print(\"Metrics:\", metrics)\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary: one row per run, reported as mean and standard deviation per metric.\n",
    "df = pd.DataFrame(all_metrics)\n",
    "# Derive the gene count from the selection itself so the header cannot drift from `k`.\n",
    "print(f\"=== Metrics Summary over Runs for top {len(top_genes_idx)} genes ===\")\n",
    "print(df.describe().T[['mean', 'std']].round(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fe1d90d",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c367fda4",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8815e20",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "39f9f54b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X_pre cells: (17565, 1000)\n",
      "X_post cells: (4505, 1000)\n",
      "(3604, 1000)\n",
      "(901, 1000)\n",
      "(3604, 1000)\n",
      "(901, 1000)\n",
      "Median heuristic gamma: 0.0021463110496127745\n"
     ]
    }
   ],
   "source": [
    "# Experiment setup for one perturbation: cells labelled \"control\" form the \"pre\"\n",
    "# population and cells labelled with `drug` form the \"post\" population.\n",
    "drug = \"abexinostat\"\n",
    "_drug_labels = adata_sc.obs[\"drug\"]\n",
    "X_pre = adata_sc[_drug_labels == \"control\"].copy().to_df()\n",
    "X_post = adata_sc[_drug_labels == drug].copy().to_df()\n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "# Marker-gene selection for this drug (k=100 requested); three values are returned.\n",
    "top_genes_ids, top_genes_short, top_genes_idx = topk_markers(adata_sc, drug, k=100)\n",
    "\n",
    "# 0.8 is forwarded to split_train_test -- presumably the training fraction; verify against the helper.\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre.values, X_post.values, 0.8)\n",
    "\n",
    "# Report the shape of each split, in the order they were returned.\n",
    "for _split in (X_tr_pre, X_te_pre, Y_tr_post, Y_te_post):\n",
    "    print(_split.shape)\n",
    "\n",
    "\n",
    "# Compute median heuristic gamma on training data\n",
    "# NOTE(review): gamma is derived from every gene column here; where the metrics are\n",
    "# evaluated on the `top_genes_idx` columns only, confirm this bandwidth is the intended one.\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7540bbd9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1234\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:04,  7.60it/s, loss=1.83, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 61.081. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1235\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004099922455375937, 'mmd2_gamma_0.5': 3.936145949170255e-05, 'mmd2_gamma_1.0': 2.4719942106737896e-05, 'wasserstein_distance': 7.556115953106611, 'R2_feature_means': 0.8267798770414407}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:04,  7.60it/s, loss=1.81, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 62.285. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1236\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.003990773078559995, 'mmd2_gamma_0.5': 3.8787764125923384e-05, 'mmd2_gamma_1.0': 2.2708037345213588e-05, 'wasserstein_distance': 7.560411632713481, 'R2_feature_means': 0.8312673397984465}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:04,  7.65it/s, loss=1.78, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 60.248. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1237\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0043537781267664055, 'mmd2_gamma_0.5': 3.800818719274806e-05, 'mmd2_gamma_1.0': 2.20867801755453e-05, 'wasserstein_distance': 7.5691237454982385, 'R2_feature_means': 0.8259455067975542}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<01:57,  8.05it/s, loss=1.78, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 59.682. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1238\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.003650091787877674, 'mmd2_gamma_0.5': 3.7261375627051945e-05, 'mmd2_gamma_1.0': 2.0973287825647343e-05, 'wasserstein_distance': 7.458084182809892, 'R2_feature_means': 0.8786476079474157}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:00,  7.90it/s, loss=1.84, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 58.333. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1239\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.00359194513754435, 'mmd2_gamma_0.5': 4.141258928031315e-05, 'mmd2_gamma_1.0': 2.325942501146805e-05, 'wasserstein_distance': 7.454792863043423, 'R2_feature_means': 0.8658794410182207}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<01:54,  8.26it/s, loss=1.87, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 60.800. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1240\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.003603802998625394, 'mmd2_gamma_0.5': 3.7587837640740466e-05, 'mmd2_gamma_1.0': 2.0495448114068224e-05, 'wasserstein_distance': 7.524643609558601, 'R2_feature_means': 0.859209134111456}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:05<01:48,  8.72it/s, loss=1.86, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 60.352. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1241\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.004879579490556285, 'mmd2_gamma_0.5': 4.034248703319939e-05, 'mmd2_gamma_1.0': 2.2198579269021188e-05, 'wasserstein_distance': 7.667293406474566, 'R2_feature_means': 0.7750152994431767}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<01:57,  8.10it/s, loss=1.84, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 61.351. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1242\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0055146547327620254, 'mmd2_gamma_0.5': 4.043435410243556e-05, 'mmd2_gamma_1.0': 2.2626546432869483e-05, 'wasserstein_distance': 7.533998972834932, 'R2_feature_means': 0.7802233334326238}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:07<02:11,  7.22it/s, loss=1.85, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 63.211. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1243\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0043786293987810865, 'mmd2_gamma_0.5': 3.499790119894403e-05, 'mmd2_gamma_1.0': 2.0828849805826214e-05, 'wasserstein_distance': 7.5412773076501605, 'R2_feature_means': 0.8363971273065126}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:07,  7.43it/s, loss=1.86, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 59.337. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0035178527539190263, 'mmd2_gamma_0.5': 3.809052969061726e-05, 'mmd2_gamma_1.0': 2.2511289173081983e-05, 'wasserstein_distance': 7.514659877139623, 'R2_feature_means': 0.8712491723685304}\n",
      "=== Metrics Summary over Runs for top 100 genes ===\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0042  0.0006\n",
      "mmd2_gamma_0.5        0.0000  0.0000\n",
      "mmd2_gamma_1.0        0.0000  0.0000\n",
      "wasserstein_distance  7.5380  0.0602\n",
      "R2_feature_means      0.8351  0.0358\n"
     ]
    }
   ],
   "source": [
    "# Multiple runs for robustness\n",
    "all_metrics = []\n",
    "for run in range(10):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    seed = 1234 + run\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    out = SCGEN(X_tr_pre[:, top_genes_idx], Y_tr_post[:, top_genes_idx], X_te_pre[:, top_genes_idx], Y_te_post[:, top_genes_idx], max_epochs=1000,early_stopping_patience=50,learning_rate=1e-3, batch_size= 256, early_stopping=20, dropout_rate=0, seed=seed)\n",
    "    metrics = summarize_metrics(out[\"y_pred\"][:, :], Y_te_post[:, top_genes_idx], median_gamma)\n",
    "    print(\"Metrics:\", metrics)\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary\n",
    "df = pd.DataFrame(all_metrics)\n",
    "print(\"=== Metrics Summary over Runs for top 100 genes ===\")\n",
    "print(df.describe().T[['mean', 'std']].round(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "34b28923",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Global seed set to 1234\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:07<02:19,  6.81it/s, loss=56.5, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 288.650. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1235\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.009299218881319904, 'mmd2_gamma_0.5': 0.00019875310976055153, 'mmd2_gamma_1.0': 7.138180949756421e-06, 'wasserstein_distance': 6.758180188797397, 'R2_feature_means': 0.760500614586567}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:07<02:15,  7.00it/s, loss=56.1, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 285.936. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1236\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.007165860392953638, 'mmd2_gamma_0.5': 0.00020338266182881725, 'mmd2_gamma_1.0': 8.133351980598193e-06, 'wasserstein_distance': 6.685411696962444, 'R2_feature_means': 0.8635144115542632}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:07<02:11,  7.21it/s, loss=56.2, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 269.998. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1237\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.010312317320621212, 'mmd2_gamma_0.5': 0.00021870026349531213, 'mmd2_gamma_1.0': 9.149838921228902e-06, 'wasserstein_distance': 6.823148336054127, 'R2_feature_means': 0.7320781911427089}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:09,  7.32it/s, loss=55.9, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 272.835. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1238\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.00998779606592759, 'mmd2_gamma_0.5': 0.00022943744107738648, 'mmd2_gamma_1.0': 1.02012825372399e-05, 'wasserstein_distance': 6.811279681799353, 'R2_feature_means': 0.6922548593934514}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:03,  7.69it/s, loss=56.1, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 256.241. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1239\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.0083891093485422, 'mmd2_gamma_0.5': 0.00021230850774868216, 'mmd2_gamma_1.0': 8.980747249262955e-06, 'wasserstein_distance': 6.777777413215647, 'R2_feature_means': 0.7978011768214603}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<01:56,  8.17it/s, loss=56.5, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 292.999. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1240\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.008275433363463591, 'mmd2_gamma_0.5': 0.00018862068196988312, 'mmd2_gamma_1.0': 7.370020892575665e-06, 'wasserstein_distance': 6.8032411586042265, 'R2_feature_means': 0.7669175491719943}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:03,  7.66it/s, loss=55.3, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 260.314. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1241\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.009084417456865612, 'mmd2_gamma_0.5': 0.0002864084873594836, 'mmd2_gamma_1.0': 1.1849176107848295e-05, 'wasserstein_distance': 6.84570545778712, 'R2_feature_means': 0.758891083316906}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:07<02:19,  6.82it/s, loss=55.5, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 295.559. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1242\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.007774773943599156, 'mmd2_gamma_0.5': 0.00020465328429133535, 'mmd2_gamma_1.0': 8.063604989936494e-06, 'wasserstein_distance': 6.716625536328823, 'R2_feature_means': 0.8028028593332961}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▋                                                                                                                                                                                                       | 51/1000 [00:06<02:04,  7.64it/s, loss=56.5, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 279.298. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n",
      "Global seed set to 1243\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.011166420744414873, 'mmd2_gamma_0.5': 0.00027186881533112876, 'mmd2_gamma_1.0': 1.2589257222048309e-05, 'wasserstein_distance': 6.903050438915976, 'R2_feature_means': 0.5857998375991802}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "GPU available: True (cuda), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n",
      "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51/1000:   5%|██████████▊                                                                                                                                                                                                         | 51/1000 [00:06<02:08,  7.39it/s, loss=56, v_num=1]\n",
      "Monitored metric elbo_validation did not improve in the last 50 records. Best score: 278.714. Signaling Trainer to stop.\n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Received view of anndata, making copy.                                                                    \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n",
      "\u001b[34mINFO    \u001b[0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:121: ImplicitModificationWarning: Transforming to str index.\n",
      "  warnings.warn(\"Transforming to str index.\", ImplicitModificationWarning)\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/anndata/_core/anndata.py:1828: UserWarning: Observation names are not unique. To make them unique, call `.obs_names_make_unique`.\n",
      "  utils.warn_names_duplicates(\"obs\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Metrics: {'mmd2_gamma_median': 0.009560534549209354, 'mmd2_gamma_0.5': 0.00023838204644722148, 'mmd2_gamma_1.0': 9.782957914069552e-06, 'wasserstein_distance': 6.779318405322399, 'R2_feature_means': 0.7534238301195895}\n",
      "=== Metrics Summary over Runs for top 100 genes ===\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0091  0.0012\n",
      "mmd2_gamma_0.5        0.0002  0.0000\n",
      "mmd2_gamma_1.0        0.0000  0.0000\n",
      "wasserstein_distance  6.7904  0.0625\n",
      "R2_feature_means      0.7514  0.0738\n"
     ]
    }
   ],
   "source": [
    "# Repeat the train/evaluate cycle under several seeds to gauge run-to-run variance.\n",
    "# NOTE(review): early_stopping=20 is forwarded together with\n",
    "# early_stopping_patience=50; SCGEN is defined in another cell, so confirm\n",
    "# whether a boolean flag was intended here.\n",
    "scgen_kwargs = dict(\n",
    "    max_epochs=1000,\n",
    "    early_stopping_patience=50,\n",
    "    learning_rate=1e-3,\n",
    "    batch_size=256,\n",
    "    early_stopping=20,\n",
    "    dropout_rate=0,\n",
    ")\n",
    "\n",
    "all_metrics = []\n",
    "for run, seed in enumerate(range(1234, 1234 + 10)):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "\n",
    "    # Seed torch (CPU and all CUDA devices), random and NumPy with this run's seed.\n",
    "    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, random.seed, np.random.seed):\n",
    "        seed_fn(seed)\n",
    "    # Request deterministic cuDNN behaviour before every run.\n",
    "    cudnn = torch.backends.cudnn\n",
    "    cudnn.deterministic = True\n",
    "    cudnn.benchmark = False\n",
    "\n",
    "    # Train on the (pre, post) training pair and predict for the test inputs.\n",
    "    train_x, train_y = X_tr_pre[:, ], Y_tr_post[:, ]\n",
    "    test_x, test_y = X_te_pre[:, ], Y_te_post[:, ]\n",
    "    out = SCGEN(train_x, train_y, test_x, test_y, **scgen_kwargs, seed=seed)\n",
    "\n",
    "    # Score predictions against the held-out targets on the top_genes_idx columns only.\n",
    "    y_pred_top = out[\"y_pred\"][:, top_genes_idx]\n",
    "    y_true_top = Y_te_post[:, top_genes_idx]\n",
    "    metrics = summarize_metrics(y_pred_top, y_true_top, median_gamma)\n",
    "    print(\"Metrics:\", metrics)\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Aggregate: mean and standard deviation of each metric across the runs.\n",
    "df = pd.DataFrame(all_metrics)\n",
    "print(\"=== Metrics Summary over Runs for top 100 genes ===\")\n",
    "summary = df.describe().loc[[\"mean\", \"std\"]].T\n",
    "print(summary.round(4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2f39ad9f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
