{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b9e2d729",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True\n",
      "11.7\n",
      "NVIDIA RTX 6000 Ada Generation\n",
      "2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "48"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import torch\n",
    "\n",
    "# Environment sanity check: confirm that PyTorch can see the GPU(s) before training.\n",
    "print(torch.cuda.is_available())   # should be True\n",
    "print(torch.version.cuda)          # CUDA version PyTorch is built with\n",
    "if torch.cuda.is_available():\n",
    "    # Only query the device name when a CUDA device is actually present.\n",
    "    print(torch.cuda.get_device_name(0))\n",
    "print(torch.cuda.device_count())\n",
    "os.cpu_count()  # last expression: rendered as the cell's result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "fe4f5248",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "from typing import Dict, Tuple, List, Optional\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import logging\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "def load_ap1_data_from_csv(csv_filepath: str, replicate: Optional[int] = None) -> Dict[str, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Loads AP1 single-cell data from a CSV or Excel file.\n",
    "\n",
    "    Rows are grouped by condition, time and cell line (plus replicate_id when\n",
    "    `replicate` is given); every group becomes one feature matrix.\n",
    "\n",
    "    Args:\n",
    "        csv_filepath: Path to the .csv or .xlsx file (str or path-like; the\n",
    "            extension check is case-insensitive).\n",
    "        replicate: If not None, rows are additionally grouped by the\n",
    "            'replicate_id' column and the keys get a '_rep{replicate_id}'\n",
    "            suffix. Only None vs. not-None matters here: every replicate\n",
    "            found in the file is loaded.\n",
    "\n",
    "    Returns:\n",
    "        Dictionary with condition identifiers\n",
    "        ('{cell_line}_{condition}_{time without spaces}', optionally followed\n",
    "        by '_rep{replicate_id}') as keys and NaN-free feature matrices of\n",
    "        shape (n_cells, n_features) as values.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the file extension is unsupported or none of the\n",
    "            expected AP1 feature columns is present.\n",
    "        KeyError: If a column needed for grouping is missing.\n",
    "    \"\"\"\n",
    "    csv_filepath = str(csv_filepath)  # accept pathlib.Path as well as str\n",
    "    logger.info(f\"Loading data from: {csv_filepath}\")\n",
    "\n",
    "    # Load the data based on file extension (case-insensitive)\n",
    "    lowered_path = csv_filepath.lower()\n",
    "    if lowered_path.endswith('.csv'):\n",
    "        df = pd.read_csv(csv_filepath)\n",
    "    elif lowered_path.endswith('.xlsx'):\n",
    "        df = pd.read_excel(csv_filepath)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported file format. Please provide a .csv or .xlsx file.\")\n",
    "\n",
    "    # Grouping columns; replicate_id only takes part when a replicate was requested\n",
    "    group_cols = ['condition', 'time', 'cell_line']\n",
    "    if replicate is not None:\n",
    "        group_cols.append('replicate_id')\n",
    "    missing_cols = [c for c in group_cols if c not in df.columns]\n",
    "    if missing_cols:\n",
    "        raise KeyError(f\"Missing required columns: {missing_cols}\")\n",
    "\n",
    "    # Shorten the verbose treatment labels so they can be used inside dictionary keys\n",
    "    replacement_map = {\n",
    "        '0.316 uM Vemurafenib': 'Vem',\n",
    "        '0.316 uM Vem + 0.0316 uM Tram': 'Vem+Tram'\n",
    "    }\n",
    "    df['condition'] = df['condition'].replace(replacement_map)\n",
    "\n",
    "    print(df['condition'].unique())\n",
    "\n",
    "    # Define AP1 protein features (these are in log space already)\n",
    "    ap1_features = [\n",
    "        'cFOS (log a.u.)', 'p-cFOS (log a.u.)', 'FRA1 (log a.u.)', 'p-FRA1 (log a.u.)', 'FRA2 (log a.u.)',\n",
    "        'cJUN (log a.u.)', 'p-cJUN (log a.u.)', 'JUNB (log a.u.)', 'JUND (log a.u.)', 'p-ATF1 (log a.u.)',\n",
    "        'ATF2 (log a.u.)', 'p-ATF2 (log a.u.)', 'ATF3 (log a.u.)', 'ATF4 (log a.u.)', 'p-ATF4 (log a.u.)',\n",
    "        'ATF5 (log a.u.)', 'ATF6 (log a.u.)', 'MITF (log a.u.)', 'NGFR (log a.u.)', 'p-ERK (log a.u.)',\n",
    "    ]\n",
    "\n",
    "    # Check if all features exist; continue with the ones that do\n",
    "    missing_features = [f for f in ap1_features if f not in df.columns]\n",
    "    if missing_features:\n",
    "        logger.warning(f\"Missing features: {missing_features}\")\n",
    "        ap1_features = [f for f in ap1_features if f in df.columns]\n",
    "    if not ap1_features:\n",
    "        raise ValueError(\"None of the expected AP1 feature columns is present in the file.\")\n",
    "\n",
    "    logger.info(f\"Using {len(ap1_features)} AP1 features\")\n",
    "\n",
    "    # Create condition-based data dictionary (one loop serves both grouping modes)\n",
    "    data_dict = {}\n",
    "    for keys, group in df.groupby(group_cols):\n",
    "        condition, time, cell_line = keys[:3]\n",
    "\n",
    "        # str() keeps non-string time values (e.g. numeric hours) from crashing .replace\n",
    "        condition_id = f\"{cell_line}_{condition}_{str(time).replace(' ', '')}\"\n",
    "        if replicate is not None:\n",
    "            condition_id += f\"_rep{keys[3]}\"\n",
    "\n",
    "        # Extract feature matrix and remove rows with any NaN values\n",
    "        feature_matrix = group[ap1_features].values\n",
    "        valid_rows = ~np.isnan(feature_matrix).any(axis=1)\n",
    "        feature_matrix = feature_matrix[valid_rows]\n",
    "\n",
    "        if len(feature_matrix) > 0:\n",
    "            data_dict[condition_id] = feature_matrix\n",
    "            logger.info(f\"Loaded {condition_id}: {feature_matrix.shape}\")\n",
    "        else:\n",
    "            logger.warning(f\"No valid data for {condition_id}\")\n",
    "\n",
    "    return data_dict\n",
    "\n",
    "def prepare_pair_from_mat(cell_line: str,\n",
    "                          baseline_condition: str, baseline_time: str,\n",
    "                          target_condition: str, target_time: str,\n",
    "                          replicate: Optional[int] = None,\n",
    "                          data_path: str = 'mmc5.xlsx') -> Tuple[np.ndarray, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Load one baseline/target pair of feature matrices for a cell line.\n",
    "\n",
    "    Args:\n",
    "        cell_line: Cell line name as it appears in the 'cell_line' column.\n",
    "        baseline_condition, baseline_time: Condition and time of the source population.\n",
    "        target_condition, target_time: Condition and time of the target population.\n",
    "        replicate: If not None, restrict both populations to this replicate id.\n",
    "        data_path: File handed to load_ap1_data_from_csv (defaults to the\n",
    "            previously hard-coded 'mmc5.xlsx').\n",
    "\n",
    "    Returns:\n",
    "        (X_pre_raw, X_post_raw), both truncated to the size of the smaller\n",
    "        population so that they have the same number of rows.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If either population is absent from the loaded data.\n",
    "    \"\"\"\n",
    "    print(\"Cell line: \", cell_line)\n",
    "    raw_data_dict = load_ap1_data_from_csv(data_path, replicate)\n",
    "\n",
    "    # The loader strips spaces from the time label when it builds its keys,\n",
    "    # so do the same here ('24 h' and '24h' then address the same entry).\n",
    "    baseline_tag = str(baseline_time).replace(' ', '')\n",
    "    target_tag = str(target_time).replace(' ', '')\n",
    "    suffix = f\"_rep{replicate}\" if replicate is not None else \"\"\n",
    "    pre_key = f\"{cell_line}_{baseline_condition}_{baseline_tag}{suffix}\"\n",
    "    post_key = f\"{cell_line}_{target_condition}_{target_tag}{suffix}\"\n",
    "\n",
    "    missing = [key for key in (pre_key, post_key) if key not in raw_data_dict]\n",
    "    if missing:\n",
    "        raise ValueError(\n",
    "            f\"Pair not found: {pre_key}, {post_key} \"\n",
    "            f\"(missing: {missing}; available: {sorted(raw_data_dict)})\"\n",
    "        )\n",
    "\n",
    "    # Equalize N: keep the first n rows of each population\n",
    "    n = min(len(raw_data_dict[pre_key]), len(raw_data_dict[post_key]))\n",
    "    X_pre_raw = raw_data_dict[pre_key][:n]\n",
    "    X_post_raw = raw_data_dict[post_key][:n]\n",
    "    return X_pre_raw, X_post_raw\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "75d3cecc",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "import argparse\n",
    "import geomloss\n",
    "import random\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "from typing import Dict, Tuple, List, Optional\n",
    "from umap import UMAP\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.optimize import linear_sum_assignment\n",
    "from sklearn.preprocessing import StandardScaler, MinMaxScaler\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "from typing import Dict, Tuple, List\n",
    "from scipy.stats import ks_2samp\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.metrics import r2_score\n",
    "\n",
    "import gc\n",
    "gc.collect()\n",
    "\n",
    "def median_heuristic_gamma(X: np.ndarray, Y: np.ndarray,\n",
    "                           max_samples: int = 5000,\n",
    "                           seed: Optional[int] = None) -> float:\n",
    "    \"\"\"\n",
    "    Median heuristic for RBF bandwidth: gamma = 1 / median(||z_i - z_j||^2).\n",
    "\n",
    "    The median is taken over the strictly positive pairwise *squared*\n",
    "    Euclidean distances within the pooled set [X; Y].\n",
    "\n",
    "    Args:\n",
    "        X: (n_samples, n_features) first sample\n",
    "        Y: (m_samples, n_features) second sample\n",
    "        max_samples: pooled sets with more rows than this are randomly\n",
    "            subsampled (without replacement) to bound the size of the\n",
    "            pairwise distance matrix\n",
    "        seed: seed for the subsampling; None keeps the previous behaviour of\n",
    "            drawing from NumPy's global random state\n",
    "\n",
    "    Returns:\n",
    "        gamma; 1.0 when the pooled set has no positive pairwise distance\n",
    "    \"\"\"\n",
    "    Z = np.vstack([X, Y])\n",
    "    # Sample if too large for efficiency\n",
    "    if Z.shape[0] > max_samples:\n",
    "        if seed is None:\n",
    "            idx = np.random.choice(Z.shape[0], size=max_samples, replace=False)\n",
    "        else:\n",
    "            idx = np.random.default_rng(seed).choice(Z.shape[0], size=max_samples, replace=False)\n",
    "        Z = Z[idx]\n",
    "    D2 = cdist(Z, Z, metric='sqeuclidean')\n",
    "    # Use upper triangular without diagonal, ignoring exact duplicates (distance 0)\n",
    "    triu = D2[np.triu_indices_from(D2, k=1)]\n",
    "    positive = triu[triu > 0]\n",
    "    med = np.median(positive) if positive.size else 1.0\n",
    "    # The floor guards against division by a vanishing median\n",
    "    return float(1.0 / max(med, 1e-12))\n",
    "\n",
    "def mmd_distance(X: np.ndarray, Y: np.ndarray, gamma: Optional[float] = None) -> float:\n",
    "    \"\"\"\n",
    "    Unbiased MMD^2 estimator using Gaussian (RBF) kernel, sklearn backend.\n",
    "\n",
    "    NOTE(review): a later cell of this notebook defines another `mmd_distance`\n",
    "    and then imports `mmd_distance` from cellot.losses.mmd. Once that cell has\n",
    "    run, the name no longer refers to this estimator, so callers that look it\n",
    "    up by name (e.g. summarize_metrics) use the imported function instead.\n",
    "\n",
    "    Args:\n",
    "        X: (n_samples, n_features) first sample\n",
    "        Y: (m_samples, n_features) second sample\n",
    "        gamma: RBF kernel bandwidth; if None, uses median heuristic\n",
    "\n",
    "    Returns:\n",
    "        Unbiased MMD^2 value, clamped at 0 from below\n",
    "    \"\"\"\n",
    "    if gamma is None:\n",
    "        # Honour the documented contract instead of silently falling back to\n",
    "        # sklearn's own default bandwidth.\n",
    "        gamma = median_heuristic_gamma(X, Y)\n",
    "\n",
    "    n = X.shape[0]\n",
    "    m = Y.shape[0]\n",
    "\n",
    "    # Kernel matrices\n",
    "    Kxx = rbf_kernel(X, X, gamma=gamma)\n",
    "    Kyy = rbf_kernel(Y, Y, gamma=gamma)\n",
    "    Kxy = rbf_kernel(X, Y, gamma=gamma)\n",
    "\n",
    "    # Unbiased: exclude diagonal entries (k(x_i, x_i)) from the within-sample terms\n",
    "    np.fill_diagonal(Kxx, 0.0)\n",
    "    np.fill_diagonal(Kyy, 0.0)\n",
    "\n",
    "    # With a single sample there are no off-diagonal pairs; the term is set to 0\n",
    "    term_xx = Kxx.sum() / (n * (n - 1)) if n > 1 else 0.0\n",
    "    term_yy = Kyy.sum() / (m * (m - 1)) if m > 1 else 0.0\n",
    "    term_xy = 2.0 * Kxy.mean()\n",
    "\n",
    "    mmd2 = term_xx + term_yy - term_xy\n",
    "    mmd2 = max(mmd2, 0.0)  # the unbiased estimate can dip slightly below zero\n",
    "    return float(mmd2)\n",
    "\n",
    "def r2_feature_means(y_true: np.ndarray, y_pred: np.ndarray) -> float:\n",
    "    \"\"\"\n",
    "    Coefficient of determination between the per-feature mean vectors.\n",
    "\n",
    "    Both inputs are averaged over rows (axis 0); R^2 is then evaluated across\n",
    "    features, with the mean vector of y_true acting as the reference.\n",
    "\n",
    "    Returns:\n",
    "        1 - SS_res / SS_tot. When the reference means are (numerically)\n",
    "        constant, returns 1.0 if the prediction matches them and 0.0 otherwise.\n",
    "    \"\"\"\n",
    "    true_means = y_true.mean(axis=0)\n",
    "    pred_means = y_pred.mean(axis=0)\n",
    "\n",
    "    residual = float(((true_means - pred_means) ** 2).sum())\n",
    "    total = float(((true_means - true_means.mean()) ** 2).sum())\n",
    "\n",
    "    # No variance across the reference means: R^2 is undefined, so report a\n",
    "    # perfect score only for a perfect match.\n",
    "    degenerate = total <= 1e-12\n",
    "    if not degenerate:\n",
    "        return 1.0 - residual / total\n",
    "    return 1.0 if residual <= 1e-12 else 0.0\n",
    "\n",
    "def wasserstein_pointcloud(\n",
    "    X,\n",
    "    Y,\n",
    "    p: int = 2,\n",
    "    a=None,\n",
    "    b=None,\n",
    "    method: str = \"emd\",          # \"emd\" (exact) or \"sinkhorn\" (approx)\n",
    "    reg: float = 1e-1,            # Sinkhorn regularization (only used if method=\"sinkhorn\")\n",
    "    return_plan: bool = False,\n",
    "):\n",
    "    \"\"\"\n",
    "    Compute Wasserstein distance W_p between two empirical distributions supported on point sets X and Y.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    X : (n, d) array-like\n",
    "        Source points.\n",
    "    Y : (m, d) array-like\n",
    "        Target points.\n",
    "    p : int\n",
    "        Order of the Wasserstein distance (commonly 1 or 2). Must be positive.\n",
    "    a : (n,) array-like or None\n",
    "        Weights for X; if None, uniform weights. Normalized to sum to 1.\n",
    "    b : (m,) array-like or None\n",
    "        Weights for Y; if None, uniform weights. Normalized to sum to 1.\n",
    "    method : str\n",
    "        \"emd\" for exact optimal transport (requires POT),\n",
    "        \"sinkhorn\" for entropic approximation (requires POT).\n",
    "    reg : float\n",
    "        Entropic regularization strength for Sinkhorn.\n",
    "    return_plan : bool\n",
    "        If True, also return the optimal transport plan.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Wp : float\n",
    "        Wasserstein distance of order p.\n",
    "    plan : (n, m) ndarray, optional\n",
    "        Optimal transport plan (only if return_plan=True).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If X/Y are not 2D, are empty or disagree in d; if p is not positive;\n",
    "        if method is unknown; or if a weight vector has the wrong shape,\n",
    "        negative entries, or a non-positive / non-finite sum.\n",
    "    ImportError\n",
    "        If the POT library is not installed.\n",
    "    \"\"\"\n",
    "    X = np.asarray(X, dtype=np.float64)\n",
    "    Y = np.asarray(Y, dtype=np.float64)\n",
    "    if X.ndim != 2 or Y.ndim != 2:\n",
    "        raise ValueError(\"X and Y must be 2D arrays with shape (n, d) and (m, d).\")\n",
    "    if X.shape[1] != Y.shape[1]:\n",
    "        raise ValueError(f\"Dimension mismatch: X has d={X.shape[1]}, Y has d={Y.shape[1]}.\")\n",
    "    if X.shape[0] == 0 or Y.shape[0] == 0:\n",
    "        # Uniform weights 1/n are undefined for an empty point set\n",
    "        raise ValueError(\"X and Y must each contain at least one point.\")\n",
    "    if p <= 0:\n",
    "        raise ValueError(f\"p must be positive, got {p}.\")\n",
    "\n",
    "    # Validate everything cheap before the cost matrix is built\n",
    "    method = method.lower()\n",
    "    if method not in (\"emd\", \"sinkhorn\"):\n",
    "        raise ValueError('method must be either \"emd\" or \"sinkhorn\".')\n",
    "\n",
    "    def _probability_vector(weights, size, label):\n",
    "        # Uniform weights by default, otherwise the given weights rescaled to sum to 1\n",
    "        if weights is None:\n",
    "            return np.full(size, 1.0 / size, dtype=np.float64)\n",
    "        weights = np.asarray(weights, dtype=np.float64)\n",
    "        if weights.shape != (size,):\n",
    "            raise ValueError(f\"{label} must have shape ({size},), got {weights.shape}.\")\n",
    "        total = weights.sum()\n",
    "        if np.any(weights < 0) or not np.isfinite(total) or total <= 0:\n",
    "            raise ValueError(f\"{label} must be non-negative with a positive, finite sum.\")\n",
    "        return weights / total\n",
    "\n",
    "    a = _probability_vector(a, X.shape[0], \"a\")\n",
    "    b = _probability_vector(b, Y.shape[0], \"b\")\n",
    "\n",
    "    try:\n",
    "        import ot  # POT: Python Optimal Transport\n",
    "    except ImportError as e:\n",
    "        raise ImportError(\n",
    "            \"This function requires the POT library. Install with: pip install pot\"\n",
    "        ) from e\n",
    "\n",
    "    # Cost matrix: C_ij = ||x_i - y_j||^p\n",
    "    # Compute squared Euclidean via (x-y)^2 = x^2 + y^2 - 2xy for speed;\n",
    "    # the clamp removes tiny negative values caused by round-off.\n",
    "    X2 = np.sum(X * X, axis=1, keepdims=True)          # (n, 1)\n",
    "    Y2 = np.sum(Y * Y, axis=1, keepdims=True).T        # (1, m)\n",
    "    sq = np.maximum(X2 + Y2 - 2.0 * (X @ Y.T), 0.0)     # (n, m)\n",
    "    C = sq if p == 2 else sq ** (p / 2.0)\n",
    "\n",
    "    if method == \"emd\":\n",
    "        # exact OT: minimizes <P, C>\n",
    "        P = ot.emd(a, b, C)\n",
    "    else:\n",
    "        # entropic OT approximation\n",
    "        P = ot.sinkhorn(a, b, C, reg=reg)\n",
    "    cost = float(np.sum(P * C))\n",
    "\n",
    "    Wp = cost ** (1.0 / p)\n",
    "\n",
    "    if return_plan:\n",
    "        return Wp, P\n",
    "    return Wp\n",
    "\n",
    "def summarize_metrics(y_true: np.ndarray, y_pred: np.ndarray, median_gamma: float) -> dict:\n",
    "    \"\"\"\n",
    "    Compute a standard set of metrics: MMD^2 (RBF) at three bandwidths, the\n",
    "    2-Wasserstein distance, and R^2 of feature means.\n",
    "\n",
    "    Args:\n",
    "        y_true: (n, d) observed samples\n",
    "        y_pred: (m, d) predicted samples\n",
    "        median_gamma: bandwidth used for the 'mmd2_gamma_median' entry\n",
    "\n",
    "    Returns:\n",
    "        dict with keys 'mmd2_gamma_median', 'mmd2_gamma_0.5', 'mmd2_gamma_1.0',\n",
    "        'wasserstein_distance' and 'R2_feature_means'.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If no NaN-free sample is left in either input.\n",
    "    \"\"\"\n",
    "    true_ok = ~np.isnan(y_true).any(axis=1)\n",
    "    pred_ok = ~np.isnan(y_pred).any(axis=1)\n",
    "\n",
    "    if len(y_true) == len(y_pred):\n",
    "        # Equal lengths: drop a row from both arrays if it has a NaN in either one\n",
    "        mask = true_ok & pred_ok\n",
    "        if mask.sum() < len(y_true):\n",
    "            print(f\"[summarize_metrics] Dropping {len(y_true) - mask.sum()} samples with NaNs.\")\n",
    "        y_true = y_true[mask]\n",
    "        y_pred = y_pred[mask]\n",
    "    else:\n",
    "        # Different lengths cannot share one mask; the metric functions defined\n",
    "        # in this notebook accept (n, d) and (m, d) inputs, so filter each\n",
    "        # array on its own.\n",
    "        n_dropped = (len(y_true) - true_ok.sum()) + (len(y_pred) - pred_ok.sum())\n",
    "        if n_dropped:\n",
    "            print(f\"[summarize_metrics] Dropping {n_dropped} samples with NaNs.\")\n",
    "        y_true = y_true[true_ok]\n",
    "        y_pred = y_pred[pred_ok]\n",
    "\n",
    "    if len(y_true) == 0 or len(y_pred) == 0:\n",
    "        raise ValueError(\"[summarize_metrics] No NaN-free samples left to evaluate.\")\n",
    "\n",
    "    out = {}\n",
    "\n",
    "    out['mmd2_gamma_median'] = mmd_distance(y_true, y_pred, gamma=median_gamma)\n",
    "    out['mmd2_gamma_0.5'] = mmd_distance(y_true, y_pred, gamma=0.5)\n",
    "    out['mmd2_gamma_1.0'] = mmd_distance(y_true, y_pred, gamma=1.0)\n",
    "    out['wasserstein_distance'] = wasserstein_pointcloud(y_true, y_pred, p=2, method=\"emd\")\n",
    "    out['R2_feature_means'] = r2_feature_means(y_true, y_pred)\n",
    "    return out\n",
    "\n",
    "def split_train_test(X: np.ndarray, Y: np.ndarray, train_fraction: float, seed: int = 42) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:\n",
    "    \"\"\"\n",
    "    Shuffle the rows of X and Y with one shared permutation and split them.\n",
    "\n",
    "    If the inputs differ in length, both are first cut to the shorter one.\n",
    "    The training part holds max(1, int(n * train_fraction)) rows.\n",
    "\n",
    "    Returns:\n",
    "        (X_train, X_test, Y_train, Y_test)\n",
    "    \"\"\"\n",
    "    n_x, n_y = X.shape[0], Y.shape[0]\n",
    "    if n_x != n_y:\n",
    "        shortest = min(len(X), len(Y))\n",
    "        X, Y = X[:shortest], Y[:shortest]\n",
    "\n",
    "    total = X.shape[0]\n",
    "    train_size = max(1, int(total * train_fraction))\n",
    "\n",
    "    # The same shuffled order indexes X and Y, so row i of X stays with row i of Y\n",
    "    shuffled = np.random.default_rng(seed).permutation(total)\n",
    "    train_rows = shuffled[:train_size]\n",
    "    test_rows = shuffled[train_size:]\n",
    "    return X[train_rows], X[test_rows], Y[train_rows], Y[test_rows]\n",
    "\n",
    "def topk_markers(adata, drug: str, k: int = 50, rank_key: str = \"marker_genes-drug-rank\"):\n",
    "    \"\"\"\n",
    "    Return the k best-ranked marker genes of `drug` (smaller rank => stronger marker).\n",
    "\n",
    "    The ranks are read from adata.varm[rank_key]. A DataFrame column named\n",
    "    `drug` is used when present; otherwise the column position is looked up\n",
    "    in the group names stored under adata.uns[\"rank_genes_groups\"][\"names\"].\n",
    "\n",
    "    Returns:\n",
    "        (gene_ids, gene_short, idx): the selected entries of adata.var_names,\n",
    "        the matching 'gene_short_name' values (None if adata.var has no such\n",
    "        column), and the selected positions.\n",
    "    \"\"\"\n",
    "    ranks = adata.varm[rank_key]\n",
    "\n",
    "    def _column_of(name):\n",
    "        # Position of `name` among the groups recorded in rank_genes_groups\n",
    "        stored = adata.uns[\"rank_genes_groups\"][\"names\"]\n",
    "        is_structured = hasattr(stored, \"dtype\") and stored.dtype.names is not None\n",
    "        order = list(stored.dtype.names) if is_structured else list(stored.columns)\n",
    "        return order.index(name)\n",
    "\n",
    "    # --- get the rank vector for this drug ---\n",
    "    is_frame = hasattr(ranks, \"columns\") and hasattr(ranks, \"iloc\")  # pandas DataFrame\n",
    "    if is_frame and drug in ranks.columns:\n",
    "        scores = ranks[drug].to_numpy()\n",
    "    elif is_frame:\n",
    "        # fallback: interpret columns as ordered groups\n",
    "        scores = ranks.iloc[:, _column_of(drug)].to_numpy()\n",
    "    else:  # numpy array (or array-like)\n",
    "        scores = np.asarray(ranks)[:, _column_of(drug)]\n",
    "\n",
    "    top = np.argsort(scores)[:k]\n",
    "    ids = adata.var_names[top].to_list()\n",
    "    if \"gene_short_name\" in adata.var.columns:\n",
    "        short = adata.var.iloc[top][\"gene_short_name\"].to_list()\n",
    "    else:\n",
    "        short = None\n",
    "    return ids, short, top\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c533daf9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eb96ecaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import torch\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.decomposition import PCA\n",
    "from cellot.models.cellot import load_networks, compute_loss_f, compute_loss_g\n",
    "\n",
    "from sklearn.metrics.pairwise import rbf_kernel\n",
    "\n",
    "\n",
    "def mmd_distance(x, y, gamma):\n",
    "    \"\"\"\n",
    "    MMD^2 between samples x and y with an RBF kernel of bandwidth gamma.\n",
    "\n",
    "    Every kernel matrix is averaged over all of its entries, diagonal included.\n",
    "\n",
    "    NOTE(review): this redefines the `mmd_distance` of an earlier cell and is\n",
    "    itself replaced a few lines further down by the import of `mmd_distance`\n",
    "    from cellot.losses.mmd.\n",
    "    \"\"\"\n",
    "    k_xx = rbf_kernel(x, x, gamma).mean()\n",
    "    k_xy = rbf_kernel(x, y, gamma).mean()\n",
    "    k_yy = rbf_kernel(y, y, gamma).mean()\n",
    "\n",
    "    return k_xx + k_yy - 2 * k_xy\n",
    "\n",
    "def compute_mmd_loss(lhs, rhs, gammas):\n",
    "    \"\"\"Average mmd_distance(lhs, rhs, gamma) over every bandwidth in `gammas`.\"\"\"\n",
    "    per_bandwidth = []\n",
    "    for bandwidth in gammas:\n",
    "        per_bandwidth.append(mmd_distance(lhs, rhs, bandwidth))\n",
    "    return np.mean(per_bandwidth)\n",
    "\n",
    "from cellot.losses.mmd import mmd_distance\n",
    "\n",
    "def run_cellot_pair(train_pre: np.ndarray, train_post: np.ndarray,\n",
    "                    test_pre: np.ndarray, test_post: np.ndarray,\n",
    "                    layers: Optional[List[int]] = None,\n",
    "                    n_epochs: int = 5000,\n",
    "                    feature_subset: Optional[List[int]] = None,\n",
    "                    median_gamma: Optional[float] = None,\n",
    "                    device: Optional[str] = None) -> Dict:\n",
    "    \"\"\"\n",
    "    Train a CellOT model on (train_pre, train_post) and evaluate the transported\n",
    "    test_pre cells against test_post.\n",
    "\n",
    "    Args:\n",
    "        train_pre, train_post: (n, d) source / target training populations.\n",
    "        test_pre, test_post: (n, d) source / target held-out populations.\n",
    "        layers: hidden layer sizes handed to load_networks; None means [32, 32, 32].\n",
    "        n_epochs: number of outer iterations (one extra iteration is run so that\n",
    "            the last multiple of 50 is logged as well).\n",
    "        feature_subset: optional column indices applied to all four arrays.\n",
    "        median_gamma: RBF bandwidth for the logged MMD values and for\n",
    "            summarize_metrics. If None, a notebook-level variable named\n",
    "            `median_gamma` is used when it exists (the behaviour of earlier\n",
    "            revisions); otherwise it is estimated from the training data with\n",
    "            median_heuristic_gamma.\n",
    "        device: torch device string; None selects 'cuda' when available, else 'cpu'.\n",
    "\n",
    "    Returns:\n",
    "        {'y_pred': transported test_pre in the original feature space,\n",
    "         'metrics': output of summarize_metrics}\n",
    "    \"\"\"\n",
    "    if device is None:\n",
    "        device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    print(f\"VERS torch={torch.__version__} (CellOT), device={device}\", file=sys.stderr, flush=True)\n",
    "\n",
    "    # Copy so the caller's list (or a shared default) is never handed to the model config\n",
    "    hidden_units = [32, 32, 32] if layers is None else list(layers)\n",
    "\n",
    "    # Apply feature subset if specified\n",
    "    if feature_subset is not None:\n",
    "        print(f\"Using feature subset of size {len(feature_subset)}\", file=sys.stderr, flush=True)\n",
    "        train_pre = train_pre[:, feature_subset]\n",
    "        train_post = train_post[:, feature_subset]\n",
    "        test_pre = test_pre[:, feature_subset]\n",
    "        test_post = test_post[:, feature_subset]\n",
    "\n",
    "    if median_gamma is None:\n",
    "        # Backward compatibility: this function used to read a global `median_gamma`\n",
    "        median_gamma = globals().get('median_gamma')\n",
    "    if median_gamma is None:\n",
    "        median_gamma = median_heuristic_gamma(train_pre, train_post)\n",
    "\n",
    "    # Preprocess: standardize jointly and apply PCA only when there are more than 50 features\n",
    "    X_all = np.vstack([train_pre, train_post])\n",
    "    scaler = StandardScaler()\n",
    "    X_all_s = scaler.fit_transform(X_all)\n",
    "    d = X_all_s.shape[1]\n",
    "    pca_dims = min(50, d)\n",
    "    pca = None\n",
    "    if pca_dims < d:\n",
    "        pca = PCA(n_components=pca_dims, svd_solver='full', random_state=42)\n",
    "        X_all_p = pca.fit_transform(X_all_s)\n",
    "    else:\n",
    "        X_all_p = X_all_s\n",
    "    tr_pre_p = X_all_p[:len(train_pre)]\n",
    "    tr_post_p = X_all_p[len(train_pre):]\n",
    "    te_pre_p = scaler.transform(test_pre)\n",
    "    if pca is not None:\n",
    "        te_pre_p = pca.transform(te_pre_p)\n",
    "\n",
    "    # Networks - Using official CellOT configuration\n",
    "    input_dim = tr_pre_p.shape[1]\n",
    "    config = {\n",
    "        'model': {\n",
    "            'name': 'cellot',\n",
    "            'hidden_units': hidden_units,\n",
    "            'kernel_init_fxn': {'name': 'uniform', 'a': -0.01, 'b': 0.01},\n",
    "            'activation': 'relu',\n",
    "            'softplus_W_kernels': True,\n",
    "            'f': {},\n",
    "            'g': {}\n",
    "        }\n",
    "    }\n",
    "    f, g = load_networks(config, input_dim=input_dim)\n",
    "    f = f.to(device).float()\n",
    "    g = g.to(device).float()\n",
    "\n",
    "    # Data tensors\n",
    "    src = torch.tensor(tr_pre_p, dtype=torch.float32, device=device)\n",
    "    tgt = torch.tensor(tr_post_p, dtype=torch.float32, device=device)\n",
    "    te_src = torch.tensor(te_pre_p, dtype=torch.float32, device=device)\n",
    "\n",
    "    def _transport_to_original_space(points):\n",
    "        # g.transport is called on a fresh copy with requires_grad=True (the\n",
    "        # transport needs gradients w.r.t. its input); copying keeps the\n",
    "        # stored data tensors untouched. The result is mapped back through\n",
    "        # the inverse PCA (if used) and the inverse scaler.\n",
    "        inputs = points.detach().clone().requires_grad_(True)\n",
    "        moved = g.transport(inputs).detach().cpu().numpy()\n",
    "        if pca is not None:\n",
    "            moved = pca.inverse_transform(moved)\n",
    "        return scaler.inverse_transform(moved)\n",
    "\n",
    "    # Optimizers matching official config (no schedulers)\n",
    "    lr = 1e-4\n",
    "    optim_f = torch.optim.Adam(f.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=0)\n",
    "    optim_g = torch.optim.Adam(g.parameters(), lr=lr, betas=(0.5, 0.9), weight_decay=0)\n",
    "\n",
    "    total_iters = n_epochs + 1  # +1 so that iteration `n_epochs` itself is evaluated/logged\n",
    "    batch_size = 256  # Official config\n",
    "    n_inner_iters = 10  # Official config\n",
    "\n",
    "    # Training loop following official CellOT implementation\n",
    "    for epoch in range(total_iters):\n",
    "        # Back to train mode: the periodic evaluation below switches to eval mode\n",
    "        f.train(); g.train()\n",
    "        perm_t = torch.randperm(len(tgt), device=device)[:batch_size]\n",
    "        yt = tgt[perm_t]\n",
    "\n",
    "        # Multiple g updates per iteration (official implementation)\n",
    "        for _ in range(n_inner_iters):\n",
    "            perm_s = torch.randperm(len(src), device=device)[:batch_size]\n",
    "            xs = src[perm_s].detach().clone().requires_grad_(True)\n",
    "\n",
    "            optim_g.zero_grad()\n",
    "            g_loss = compute_loss_g(f, g, xs).mean()\n",
    "            g_loss.backward()\n",
    "            torch.nn.utils.clip_grad_norm_(g.parameters(), max_norm=0.5)\n",
    "            optim_g.step()\n",
    "\n",
    "        # Single f update (official implementation)\n",
    "        perm_s = torch.randperm(len(src), device=device)[:batch_size]\n",
    "        xs = src[perm_s].detach().clone().requires_grad_(True)\n",
    "\n",
    "        optim_f.zero_grad()\n",
    "        f_loss = compute_loss_f(f, g, xs, yt).mean()\n",
    "        f_loss.backward()\n",
    "        optim_f.step()\n",
    "\n",
    "        # Clamp weights for f (official implementation)\n",
    "        if hasattr(f, 'clamp_w'):\n",
    "            f.clamp_w()\n",
    "\n",
    "        # ---- Periodic monitoring: train / test MMD in the original feature space ----\n",
    "        if epoch % 50 == 0:\n",
    "            f.eval()\n",
    "            g.eval()\n",
    "\n",
    "            # Both values use the same bandwidth so that they are comparable\n",
    "            # with each other and with the final evaluation.\n",
    "            tr_pred = _transport_to_original_space(src)\n",
    "            train_mmd = mmd_distance(train_post, tr_pred, gamma=median_gamma)\n",
    "\n",
    "            te_pred = _transport_to_original_space(te_src)\n",
    "            test_mmd = mmd_distance(test_post, te_pred, gamma=median_gamma)\n",
    "\n",
    "            print(\n",
    "                f\"[CellOT] epoch={epoch} f_loss={f_loss.item():.4f} g_loss={g_loss.item():.4f} | \"\n",
    "                f\"train mmd={train_mmd:.4f} | \"\n",
    "                f\"test_mmd={test_mmd:.4f}\",\n",
    "                file=sys.stderr,\n",
    "                flush=True,\n",
    "            )\n",
    "\n",
    "    # Inference on the held-out source cells\n",
    "    f.eval(); g.eval()\n",
    "    te_tx_inv = _transport_to_original_space(te_src)\n",
    "\n",
    "    # Final evaluation\n",
    "    metrics = summarize_metrics(test_post[:len(te_tx_inv)], te_tx_inv, median_gamma)\n",
    "\n",
    "    gammas = np.logspace(1, -3, num=50)\n",
    "    mmd = compute_mmd_loss(test_post[:len(te_tx_inv)], te_tx_inv, gammas=gammas)\n",
    "    print(f\"[CellOT] Final CellOT MMD: {mmd:.4f}\", file=sys.stderr, flush=True)\n",
    "\n",
    "    return {'y_pred': te_tx_inv, 'metrics': metrics}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b2a6a182",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cell line:  COLO858\n",
      "['DMSO' 'Vem' 'Vem+Tram']\n",
      "X_pre cells: (3026, 20)\n",
      "X_post cells: (3026, 20)\n",
      "(2420, 20)\n",
      "(606, 20)\n",
      "(2420, 20)\n",
      "(606, 20)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Median heuristic gamma: 0.05162262759745905\n",
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-1.3571 g_loss=-5.0167 | train mmd=0.2892 | test_mmd=1.3294\n",
      "[CellOT] epoch=50 f_loss=-2.5143 g_loss=0.3889 | train mmd=0.5300 | test_mmd=0.3511\n",
      "[CellOT] epoch=100 f_loss=-2.7093 g_loss=3.3984 | train mmd=0.7644 | test_mmd=0.3322\n",
      "[CellOT] epoch=150 f_loss=-4.6768 g_loss=5.7451 | train mmd=0.7639 | test_mmd=0.2845\n",
      "[CellOT] epoch=200 f_loss=-6.2082 g_loss=8.3595 | train mmd=0.7725 | test_mmd=0.2526\n",
      "[CellOT] epoch=250 f_loss=-8.3589 g_loss=10.3283 | train mmd=0.7793 | test_mmd=0.2200\n",
      "[CellOT] epoch=300 f_loss=-9.2954 g_loss=12.8297 | train mmd=0.7534 | test_mmd=0.1925\n",
      "[CellOT] epoch=350 f_loss=-10.7652 g_loss=15.3954 | train mmd=0.7190 | test_mmd=0.1685\n",
      "[CellOT] epoch=400 f_loss=-11.8829 g_loss=16.8155 | train mmd=0.7151 | test_mmd=0.1418\n",
      "[CellOT] epoch=450 f_loss=-12.4099 g_loss=18.5495 | train mmd=0.7033 | test_mmd=0.1227\n",
      "[CellOT] epoch=500 f_loss=-12.7849 g_loss=21.1363 | train mmd=0.6769 | test_mmd=0.1006\n",
      "[CellOT] epoch=550 f_loss=-14.6765 g_loss=22.5376 | train mmd=0.6306 | test_mmd=0.0814\n",
      "[CellOT] epoch=600 f_loss=-12.3123 g_loss=25.5090 | train mmd=0.5872 | test_mmd=0.0625\n",
      "[CellOT] epoch=650 f_loss=-12.1184 g_loss=25.5943 | train mmd=0.5332 | test_mmd=0.0482\n",
      "[CellOT] epoch=700 f_loss=-10.6701 g_loss=28.9012 | train mmd=0.5362 | test_mmd=0.0356\n",
      "[CellOT] epoch=750 f_loss=-10.5158 g_loss=30.8831 | train mmd=0.5038 | test_mmd=0.0265\n",
      "[CellOT] epoch=800 f_loss=-6.1584 g_loss=32.6846 | train mmd=0.4605 | test_mmd=0.0183\n",
      "[CellOT] epoch=850 f_loss=-5.2701 g_loss=35.3705 | train mmd=0.4147 | test_mmd=0.0122\n",
      "[CellOT] epoch=900 f_loss=1.5462 g_loss=34.0732 | train mmd=0.3574 | test_mmd=0.0075\n",
      "[CellOT] epoch=950 f_loss=4.8897 g_loss=26.6522 | train mmd=0.1492 | test_mmd=0.0113\n",
      "[CellOT] epoch=1000 f_loss=2.0572 g_loss=22.3635 | train mmd=0.0402 | test_mmd=0.0044\n",
      "[CellOT] epoch=1050 f_loss=1.3366 g_loss=20.7850 | train mmd=0.0338 | test_mmd=0.0065\n",
      "[CellOT] epoch=1100 f_loss=0.0106 g_loss=20.3072 | train mmd=0.0301 | test_mmd=0.0031\n",
      "[CellOT] epoch=1150 f_loss=0.8958 g_loss=20.9379 | train mmd=0.0406 | test_mmd=0.0021\n",
      "[CellOT] epoch=1200 f_loss=0.5587 g_loss=21.3058 | train mmd=0.0459 | test_mmd=0.0029\n",
      "[CellOT] epoch=1250 f_loss=0.0036 g_loss=21.3006 | train mmd=0.0410 | test_mmd=0.0065\n",
      "[CellOT] epoch=1300 f_loss=0.0105 g_loss=22.0539 | train mmd=0.0386 | test_mmd=0.0040\n",
      "[CellOT] epoch=1350 f_loss=-0.2105 g_loss=22.1190 | train mmd=0.0292 | test_mmd=0.0030\n",
      "[CellOT] epoch=1400 f_loss=-0.3228 g_loss=22.0275 | train mmd=0.0301 | test_mmd=0.0021\n",
      "[CellOT] epoch=1450 f_loss=-0.7377 g_loss=22.1069 | train mmd=0.0296 | test_mmd=0.0046\n",
      "[CellOT] epoch=1500 f_loss=0.1765 g_loss=22.3272 | train mmd=0.0219 | test_mmd=0.0023\n",
      "[CellOT] epoch=1550 f_loss=0.0965 g_loss=22.2864 | train mmd=0.0174 | test_mmd=0.0020\n",
      "[CellOT] epoch=1600 f_loss=-0.1174 g_loss=21.9033 | train mmd=0.0180 | test_mmd=0.0022\n",
      "[CellOT] epoch=1650 f_loss=-0.1918 g_loss=21.4166 | train mmd=0.0172 | test_mmd=0.0022\n",
      "[CellOT] epoch=1700 f_loss=0.3476 g_loss=21.3394 | train mmd=0.0097 | test_mmd=0.0017\n",
      "[CellOT] epoch=1750 f_loss=0.1206 g_loss=21.1588 | train mmd=0.0121 | test_mmd=0.0008\n",
      "[CellOT] epoch=1800 f_loss=0.0055 g_loss=21.1490 | train mmd=0.0110 | test_mmd=0.0008\n",
      "[CellOT] epoch=1850 f_loss=-0.4549 g_loss=21.3170 | train mmd=0.0123 | test_mmd=0.0026\n",
      "[CellOT] epoch=1900 f_loss=0.5764 g_loss=21.2081 | train mmd=0.0112 | test_mmd=0.0018\n",
      "[CellOT] epoch=1950 f_loss=0.0917 g_loss=20.8859 | train mmd=0.0125 | test_mmd=0.0021\n",
      "[CellOT] epoch=2000 f_loss=0.3675 g_loss=20.7492 | train mmd=0.0074 | test_mmd=0.0009\n",
      "[CellOT] Final CellOT MMD: 0.0035\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.5174 g_loss=-5.0610 | train mmd=0.3603 | test_mmd=1.2173\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 0 metrics: {'mmd2_gamma_median': 0.0008898991539725287, 'mmd2_gamma_0.5': 0.007069895455005026, 'mmd2_gamma_1.0': 0.010359813910979315, 'wasserstein_distance': 0.8182263571257119, 'R2_feature_means': 0.9980937226362605}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.4906 g_loss=0.4710 | train mmd=0.5403 | test_mmd=0.3807\n",
      "[CellOT] epoch=100 f_loss=-2.8514 g_loss=3.4207 | train mmd=0.7766 | test_mmd=0.3310\n",
      "[CellOT] epoch=150 f_loss=-4.4655 g_loss=6.1398 | train mmd=0.7934 | test_mmd=0.2957\n",
      "[CellOT] epoch=200 f_loss=-7.6001 g_loss=8.5983 | train mmd=0.7593 | test_mmd=0.2514\n",
      "[CellOT] epoch=250 f_loss=-8.6677 g_loss=10.8540 | train mmd=0.7757 | test_mmd=0.2247\n",
      "[CellOT] epoch=300 f_loss=-9.7066 g_loss=12.3457 | train mmd=0.7256 | test_mmd=0.1922\n",
      "[CellOT] epoch=350 f_loss=-11.3115 g_loss=15.5891 | train mmd=0.7172 | test_mmd=0.1654\n",
      "[CellOT] epoch=400 f_loss=-13.1452 g_loss=18.8188 | train mmd=0.6971 | test_mmd=0.1416\n",
      "[CellOT] epoch=450 f_loss=-12.8291 g_loss=19.9814 | train mmd=0.6766 | test_mmd=0.1202\n",
      "[CellOT] epoch=500 f_loss=-14.8492 g_loss=21.6721 | train mmd=0.6507 | test_mmd=0.1006\n",
      "[CellOT] epoch=550 f_loss=-13.9256 g_loss=23.3323 | train mmd=0.6011 | test_mmd=0.0796\n",
      "[CellOT] epoch=600 f_loss=-14.8558 g_loss=24.9957 | train mmd=0.5853 | test_mmd=0.0633\n",
      "[CellOT] epoch=650 f_loss=-12.9454 g_loss=27.3428 | train mmd=0.5681 | test_mmd=0.0528\n",
      "[CellOT] epoch=700 f_loss=-12.3754 g_loss=28.5942 | train mmd=0.5304 | test_mmd=0.0394\n",
      "[CellOT] epoch=750 f_loss=-10.3630 g_loss=30.7509 | train mmd=0.4917 | test_mmd=0.0292\n",
      "[CellOT] epoch=800 f_loss=-8.1420 g_loss=32.4673 | train mmd=0.4446 | test_mmd=0.0191\n",
      "[CellOT] epoch=850 f_loss=-5.7793 g_loss=35.0198 | train mmd=0.4265 | test_mmd=0.0133\n",
      "[CellOT] epoch=900 f_loss=-0.9444 g_loss=35.8990 | train mmd=0.3949 | test_mmd=0.0098\n",
      "[CellOT] epoch=950 f_loss=3.6796 g_loss=31.3658 | train mmd=0.2780 | test_mmd=0.0065\n",
      "[CellOT] epoch=1000 f_loss=6.6701 g_loss=25.6656 | train mmd=0.1045 | test_mmd=0.0073\n",
      "[CellOT] epoch=1050 f_loss=1.5677 g_loss=25.6723 | train mmd=0.0729 | test_mmd=0.0048\n",
      "[CellOT] epoch=1100 f_loss=2.0540 g_loss=25.4642 | train mmd=0.0616 | test_mmd=0.0030\n",
      "[CellOT] epoch=1150 f_loss=1.2870 g_loss=27.1306 | train mmd=0.0426 | test_mmd=0.0024\n",
      "[CellOT] epoch=1200 f_loss=0.9565 g_loss=27.7010 | train mmd=0.0481 | test_mmd=0.0018\n",
      "[CellOT] epoch=1250 f_loss=1.2509 g_loss=28.4265 | train mmd=0.0580 | test_mmd=0.0019\n",
      "[CellOT] epoch=1300 f_loss=0.8106 g_loss=28.5297 | train mmd=0.0481 | test_mmd=0.0031\n",
      "[CellOT] epoch=1350 f_loss=0.5816 g_loss=29.5112 | train mmd=0.0560 | test_mmd=0.0020\n",
      "[CellOT] epoch=1400 f_loss=0.0617 g_loss=31.7753 | train mmd=0.0422 | test_mmd=0.0031\n",
      "[CellOT] epoch=1450 f_loss=-0.0568 g_loss=31.1360 | train mmd=0.0424 | test_mmd=0.0025\n",
      "[CellOT] epoch=1500 f_loss=-0.2116 g_loss=32.0117 | train mmd=0.0354 | test_mmd=0.0039\n",
      "[CellOT] epoch=1550 f_loss=-0.0618 g_loss=32.1091 | train mmd=0.0305 | test_mmd=0.0025\n",
      "[CellOT] epoch=1600 f_loss=-0.0160 g_loss=31.8281 | train mmd=0.0235 | test_mmd=0.0016\n",
      "[CellOT] epoch=1650 f_loss=1.0145 g_loss=32.4101 | train mmd=0.0265 | test_mmd=0.0031\n",
      "[CellOT] epoch=1700 f_loss=-0.0650 g_loss=31.9639 | train mmd=0.0260 | test_mmd=0.0017\n",
      "[CellOT] epoch=1750 f_loss=-0.1872 g_loss=32.0672 | train mmd=0.0288 | test_mmd=0.0041\n",
      "[CellOT] epoch=1800 f_loss=0.1876 g_loss=31.0125 | train mmd=0.0340 | test_mmd=0.0028\n",
      "[CellOT] epoch=1850 f_loss=0.2666 g_loss=30.1587 | train mmd=0.0374 | test_mmd=0.0008\n",
      "[CellOT] epoch=1900 f_loss=-0.3251 g_loss=29.8163 | train mmd=0.0364 | test_mmd=0.0025\n",
      "[CellOT] epoch=1950 f_loss=-0.0471 g_loss=29.9464 | train mmd=0.0256 | test_mmd=0.0032\n",
      "[CellOT] epoch=2000 f_loss=-0.0388 g_loss=29.9188 | train mmd=0.0190 | test_mmd=0.0010\n",
      "[CellOT] Final CellOT MMD: 0.0078\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=-1.8381 g_loss=-3.4163 | train mmd=0.3046 | test_mmd=1.0963\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1 metrics: {'mmd2_gamma_median': 0.001020110719267553, 'mmd2_gamma_0.5': 0.01794226207019589, 'mmd2_gamma_1.0': 0.025040865103214627, 'wasserstein_distance': 0.8619834472029366, 'R2_feature_means': 0.9985559172820524}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.6959 g_loss=0.8991 | train mmd=0.5041 | test_mmd=0.3510\n",
      "[CellOT] epoch=100 f_loss=-2.8619 g_loss=3.8002 | train mmd=0.7269 | test_mmd=0.3080\n",
      "[CellOT] epoch=150 f_loss=-5.2862 g_loss=6.2116 | train mmd=0.7570 | test_mmd=0.2771\n",
      "[CellOT] epoch=200 f_loss=-6.9514 g_loss=9.0297 | train mmd=0.7652 | test_mmd=0.2454\n",
      "[CellOT] epoch=250 f_loss=-9.2076 g_loss=11.9401 | train mmd=0.7713 | test_mmd=0.2179\n",
      "[CellOT] epoch=300 f_loss=-9.7170 g_loss=14.6516 | train mmd=0.7247 | test_mmd=0.1879\n",
      "[CellOT] epoch=350 f_loss=-11.8805 g_loss=16.8919 | train mmd=0.7517 | test_mmd=0.1606\n",
      "[CellOT] epoch=400 f_loss=-12.6275 g_loss=19.0220 | train mmd=0.6699 | test_mmd=0.1316\n",
      "[CellOT] epoch=450 f_loss=-12.9890 g_loss=19.7493 | train mmd=0.6566 | test_mmd=0.1084\n",
      "[CellOT] epoch=500 f_loss=-13.9485 g_loss=21.9765 | train mmd=0.6444 | test_mmd=0.0940\n",
      "[CellOT] epoch=550 f_loss=-12.0966 g_loss=23.7916 | train mmd=0.6255 | test_mmd=0.0821\n",
      "[CellOT] epoch=600 f_loss=-11.8459 g_loss=26.6054 | train mmd=0.5622 | test_mmd=0.0616\n",
      "[CellOT] epoch=650 f_loss=-10.4453 g_loss=29.8165 | train mmd=0.5421 | test_mmd=0.0500\n",
      "[CellOT] epoch=700 f_loss=-12.5768 g_loss=29.4027 | train mmd=0.5475 | test_mmd=0.0401\n",
      "[CellOT] epoch=750 f_loss=-10.4145 g_loss=32.5187 | train mmd=0.5397 | test_mmd=0.0295\n",
      "[CellOT] epoch=800 f_loss=-5.7293 g_loss=36.2221 | train mmd=0.5128 | test_mmd=0.0206\n",
      "[CellOT] epoch=850 f_loss=-4.6067 g_loss=38.4557 | train mmd=0.4526 | test_mmd=0.0132\n",
      "[CellOT] epoch=900 f_loss=0.9955 g_loss=37.9481 | train mmd=0.3977 | test_mmd=0.0085\n",
      "[CellOT] epoch=950 f_loss=3.3479 g_loss=36.9818 | train mmd=0.3216 | test_mmd=0.0069\n",
      "[CellOT] epoch=1000 f_loss=5.3294 g_loss=31.8534 | train mmd=0.1270 | test_mmd=0.0071\n",
      "[CellOT] epoch=1050 f_loss=2.9006 g_loss=31.1499 | train mmd=0.0235 | test_mmd=0.0024\n",
      "[CellOT] epoch=1100 f_loss=2.0586 g_loss=32.4408 | train mmd=0.0363 | test_mmd=0.0046\n",
      "[CellOT] epoch=1150 f_loss=1.4836 g_loss=33.4684 | train mmd=0.0373 | test_mmd=0.0020\n",
      "[CellOT] epoch=1200 f_loss=0.2153 g_loss=34.2184 | train mmd=0.0387 | test_mmd=0.0043\n",
      "[CellOT] epoch=1250 f_loss=1.3092 g_loss=34.4626 | train mmd=0.0445 | test_mmd=0.0030\n",
      "[CellOT] epoch=1300 f_loss=0.5080 g_loss=34.3857 | train mmd=0.0327 | test_mmd=0.0014\n",
      "[CellOT] epoch=1350 f_loss=0.7098 g_loss=35.1164 | train mmd=0.0345 | test_mmd=0.0015\n",
      "[CellOT] epoch=1400 f_loss=0.1349 g_loss=35.7418 | train mmd=0.0323 | test_mmd=0.0028\n",
      "[CellOT] epoch=1450 f_loss=0.5635 g_loss=35.7523 | train mmd=0.0355 | test_mmd=0.0035\n",
      "[CellOT] epoch=1500 f_loss=0.5390 g_loss=35.9836 | train mmd=0.0225 | test_mmd=0.0011\n",
      "[CellOT] epoch=1550 f_loss=0.2263 g_loss=36.6913 | train mmd=0.0283 | test_mmd=0.0038\n",
      "[CellOT] epoch=1600 f_loss=-0.0575 g_loss=37.0929 | train mmd=0.0236 | test_mmd=0.0019\n",
      "[CellOT] epoch=1650 f_loss=-0.0546 g_loss=36.5991 | train mmd=0.0182 | test_mmd=0.0013\n",
      "[CellOT] epoch=1700 f_loss=0.1583 g_loss=36.7283 | train mmd=0.0179 | test_mmd=0.0016\n",
      "[CellOT] epoch=1750 f_loss=-0.6076 g_loss=37.5808 | train mmd=0.0145 | test_mmd=0.0021\n",
      "[CellOT] epoch=1800 f_loss=0.3655 g_loss=37.6435 | train mmd=0.0205 | test_mmd=0.0020\n",
      "[CellOT] epoch=1850 f_loss=0.0839 g_loss=37.2206 | train mmd=0.0212 | test_mmd=0.0034\n",
      "[CellOT] epoch=1900 f_loss=0.1380 g_loss=37.6321 | train mmd=0.0184 | test_mmd=0.0026\n",
      "[CellOT] epoch=1950 f_loss=1.0177 g_loss=37.5047 | train mmd=0.0264 | test_mmd=0.0051\n",
      "[CellOT] epoch=2000 f_loss=0.7026 g_loss=37.7241 | train mmd=0.0164 | test_mmd=0.0027\n",
      "[CellOT] Final CellOT MMD: 0.0081\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.7495 g_loss=-4.0812 | train mmd=0.3220 | test_mmd=1.1629\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 2 metrics: {'mmd2_gamma_median': 0.0027355647192131016, 'mmd2_gamma_0.5': 0.018382068437655996, 'mmd2_gamma_1.0': 0.023473337056691967, 'wasserstein_distance': 0.8563874212148906, 'R2_feature_means': 0.9937967556607126}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.7409 g_loss=0.7216 | train mmd=0.4185 | test_mmd=0.3321\n",
      "[CellOT] epoch=100 f_loss=-2.5430 g_loss=3.3267 | train mmd=0.6933 | test_mmd=0.3117\n",
      "[CellOT] epoch=150 f_loss=-4.4303 g_loss=5.8628 | train mmd=0.7726 | test_mmd=0.2834\n",
      "[CellOT] epoch=200 f_loss=-6.2860 g_loss=7.9416 | train mmd=0.7607 | test_mmd=0.2465\n",
      "[CellOT] epoch=250 f_loss=-7.4473 g_loss=10.2619 | train mmd=0.7578 | test_mmd=0.2198\n",
      "[CellOT] epoch=300 f_loss=-8.9950 g_loss=12.0807 | train mmd=0.7124 | test_mmd=0.1827\n",
      "[CellOT] epoch=350 f_loss=-10.5923 g_loss=14.4168 | train mmd=0.6852 | test_mmd=0.1552\n",
      "[CellOT] epoch=400 f_loss=-10.6812 g_loss=15.4669 | train mmd=0.6562 | test_mmd=0.1322\n",
      "[CellOT] epoch=450 f_loss=-11.9417 g_loss=16.9186 | train mmd=0.6708 | test_mmd=0.1130\n",
      "[CellOT] epoch=500 f_loss=-12.0665 g_loss=20.8225 | train mmd=0.6379 | test_mmd=0.0901\n",
      "[CellOT] epoch=550 f_loss=-10.7181 g_loss=22.7183 | train mmd=0.6065 | test_mmd=0.0777\n",
      "[CellOT] epoch=600 f_loss=-11.4408 g_loss=23.8269 | train mmd=0.5781 | test_mmd=0.0620\n",
      "[CellOT] epoch=650 f_loss=-11.1607 g_loss=26.4530 | train mmd=0.5242 | test_mmd=0.0453\n",
      "[CellOT] epoch=700 f_loss=-11.1541 g_loss=27.5915 | train mmd=0.5409 | test_mmd=0.0371\n",
      "[CellOT] epoch=750 f_loss=-7.2811 g_loss=27.2416 | train mmd=0.4152 | test_mmd=0.0203\n",
      "[CellOT] epoch=800 f_loss=-5.6859 g_loss=30.6204 | train mmd=0.3768 | test_mmd=0.0151\n",
      "[CellOT] epoch=850 f_loss=-1.1662 g_loss=29.0823 | train mmd=0.2875 | test_mmd=0.0086\n",
      "[CellOT] epoch=900 f_loss=1.6119 g_loss=25.9892 | train mmd=0.1779 | test_mmd=0.0038\n",
      "[CellOT] epoch=950 f_loss=3.7248 g_loss=27.8298 | train mmd=0.1145 | test_mmd=0.0029\n",
      "[CellOT] epoch=1000 f_loss=1.1559 g_loss=31.0124 | train mmd=0.1243 | test_mmd=0.0028\n",
      "[CellOT] epoch=1050 f_loss=0.5660 g_loss=30.2879 | train mmd=0.0603 | test_mmd=0.0025\n",
      "[CellOT] epoch=1100 f_loss=0.3740 g_loss=30.9307 | train mmd=0.0379 | test_mmd=0.0016\n",
      "[CellOT] epoch=1150 f_loss=0.5002 g_loss=31.7958 | train mmd=0.0324 | test_mmd=0.0028\n",
      "[CellOT] epoch=1200 f_loss=0.3855 g_loss=32.3971 | train mmd=0.0415 | test_mmd=0.0023\n",
      "[CellOT] epoch=1250 f_loss=0.5251 g_loss=33.5741 | train mmd=0.0502 | test_mmd=0.0018\n",
      "[CellOT] epoch=1300 f_loss=-0.1059 g_loss=34.0101 | train mmd=0.0436 | test_mmd=0.0026\n",
      "[CellOT] epoch=1350 f_loss=0.8746 g_loss=34.0320 | train mmd=0.0565 | test_mmd=0.0022\n",
      "[CellOT] epoch=1400 f_loss=-0.2061 g_loss=35.1767 | train mmd=0.0322 | test_mmd=0.0034\n",
      "[CellOT] epoch=1450 f_loss=0.5436 g_loss=34.7119 | train mmd=0.0307 | test_mmd=0.0035\n",
      "[CellOT] epoch=1500 f_loss=0.2046 g_loss=34.5117 | train mmd=0.0471 | test_mmd=0.0047\n",
      "[CellOT] epoch=1550 f_loss=0.6021 g_loss=34.2241 | train mmd=0.0541 | test_mmd=0.0032\n",
      "[CellOT] epoch=1600 f_loss=-0.6244 g_loss=34.0849 | train mmd=0.0549 | test_mmd=0.0025\n",
      "[CellOT] epoch=1650 f_loss=0.6015 g_loss=33.1782 | train mmd=0.0845 | test_mmd=0.0030\n",
      "[CellOT] epoch=1700 f_loss=0.3212 g_loss=32.9644 | train mmd=0.0429 | test_mmd=0.0031\n",
      "[CellOT] epoch=1750 f_loss=0.2567 g_loss=33.9153 | train mmd=0.0339 | test_mmd=0.0015\n",
      "[CellOT] epoch=1800 f_loss=0.2690 g_loss=34.7482 | train mmd=0.0264 | test_mmd=0.0012\n",
      "[CellOT] epoch=1850 f_loss=-0.1894 g_loss=35.1404 | train mmd=0.0324 | test_mmd=0.0020\n",
      "[CellOT] epoch=1900 f_loss=0.4111 g_loss=35.0285 | train mmd=0.0368 | test_mmd=0.0019\n",
      "[CellOT] epoch=1950 f_loss=-0.6016 g_loss=35.8955 | train mmd=0.0372 | test_mmd=0.0050\n",
      "[CellOT] epoch=2000 f_loss=-0.2862 g_loss=35.6417 | train mmd=0.0326 | test_mmd=0.0025\n",
      "[CellOT] Final CellOT MMD: 0.0119\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=1.3918 g_loss=-7.2148 | train mmd=0.2864 | test_mmd=1.3941\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 3 metrics: {'mmd2_gamma_median': 0.0025033148282909146, 'mmd2_gamma_0.5': 0.027621538598427087, 'mmd2_gamma_1.0': 0.0373842368149645, 'wasserstein_distance': 0.9049259161354114, 'R2_feature_means': 0.9951682313896952}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.6134 g_loss=0.1850 | train mmd=0.4353 | test_mmd=0.3792\n",
      "[CellOT] epoch=100 f_loss=-2.7357 g_loss=3.0324 | train mmd=0.7675 | test_mmd=0.3419\n",
      "[CellOT] epoch=150 f_loss=-4.3167 g_loss=5.5132 | train mmd=0.7895 | test_mmd=0.3033\n",
      "[CellOT] epoch=200 f_loss=-6.3501 g_loss=8.1772 | train mmd=0.7607 | test_mmd=0.2609\n",
      "[CellOT] epoch=250 f_loss=-7.5294 g_loss=10.5203 | train mmd=0.8219 | test_mmd=0.2443\n",
      "[CellOT] epoch=300 f_loss=-10.2900 g_loss=11.8582 | train mmd=0.7191 | test_mmd=0.1985\n",
      "[CellOT] epoch=350 f_loss=-10.6060 g_loss=14.6116 | train mmd=0.7218 | test_mmd=0.1744\n",
      "[CellOT] epoch=400 f_loss=-12.8360 g_loss=16.4170 | train mmd=0.7252 | test_mmd=0.1487\n",
      "[CellOT] epoch=450 f_loss=-11.3018 g_loss=17.6407 | train mmd=0.6680 | test_mmd=0.1232\n",
      "[CellOT] epoch=500 f_loss=-12.8000 g_loss=20.2028 | train mmd=0.6627 | test_mmd=0.1018\n",
      "[CellOT] epoch=550 f_loss=-14.0756 g_loss=20.4469 | train mmd=0.6058 | test_mmd=0.0805\n",
      "[CellOT] epoch=600 f_loss=-13.2344 g_loss=23.8541 | train mmd=0.5746 | test_mmd=0.0631\n",
      "[CellOT] epoch=650 f_loss=-12.1144 g_loss=23.5781 | train mmd=0.5345 | test_mmd=0.0463\n",
      "[CellOT] epoch=700 f_loss=-11.3004 g_loss=24.3541 | train mmd=0.4811 | test_mmd=0.0320\n",
      "[CellOT] epoch=750 f_loss=-8.1055 g_loss=25.0881 | train mmd=0.4153 | test_mmd=0.0231\n",
      "[CellOT] epoch=800 f_loss=-4.6978 g_loss=27.3009 | train mmd=0.3675 | test_mmd=0.0148\n",
      "[CellOT] epoch=850 f_loss=-4.3194 g_loss=25.0764 | train mmd=0.3184 | test_mmd=0.0084\n",
      "[CellOT] epoch=900 f_loss=2.9426 g_loss=24.1533 | train mmd=0.2314 | test_mmd=0.0064\n",
      "[CellOT] epoch=950 f_loss=2.4109 g_loss=20.5597 | train mmd=0.1666 | test_mmd=0.0047\n",
      "[CellOT] epoch=1000 f_loss=1.8051 g_loss=19.3493 | train mmd=0.1246 | test_mmd=0.0032\n",
      "[CellOT] epoch=1050 f_loss=1.3122 g_loss=20.3736 | train mmd=0.1191 | test_mmd=0.0024\n",
      "[CellOT] epoch=1100 f_loss=1.4387 g_loss=22.4475 | train mmd=0.1257 | test_mmd=0.0044\n",
      "[CellOT] epoch=1150 f_loss=-0.6905 g_loss=24.0037 | train mmd=0.1011 | test_mmd=0.0042\n",
      "[CellOT] epoch=1200 f_loss=-0.3887 g_loss=24.7892 | train mmd=0.1220 | test_mmd=0.0082\n",
      "[CellOT] epoch=1250 f_loss=-0.0385 g_loss=25.7098 | train mmd=0.1074 | test_mmd=0.0039\n",
      "[CellOT] epoch=1300 f_loss=0.1400 g_loss=26.5430 | train mmd=0.0779 | test_mmd=0.0035\n",
      "[CellOT] epoch=1350 f_loss=-0.3296 g_loss=26.0512 | train mmd=0.0695 | test_mmd=0.0075\n",
      "[CellOT] epoch=1400 f_loss=0.2329 g_loss=25.4579 | train mmd=0.0597 | test_mmd=0.0023\n",
      "[CellOT] epoch=1450 f_loss=0.3321 g_loss=26.1401 | train mmd=0.0557 | test_mmd=0.0033\n",
      "[CellOT] epoch=1500 f_loss=0.4219 g_loss=26.1093 | train mmd=0.0403 | test_mmd=0.0016\n",
      "[CellOT] epoch=1550 f_loss=0.0904 g_loss=26.5117 | train mmd=0.0400 | test_mmd=0.0022\n",
      "[CellOT] epoch=1600 f_loss=-0.3686 g_loss=27.5417 | train mmd=0.0315 | test_mmd=0.0025\n",
      "[CellOT] epoch=1650 f_loss=-0.8262 g_loss=26.6084 | train mmd=0.0271 | test_mmd=0.0020\n",
      "[CellOT] epoch=1700 f_loss=0.8380 g_loss=26.5197 | train mmd=0.0280 | test_mmd=0.0025\n",
      "[CellOT] epoch=1750 f_loss=0.0293 g_loss=26.7991 | train mmd=0.0293 | test_mmd=0.0025\n",
      "[CellOT] epoch=1800 f_loss=0.9064 g_loss=26.6780 | train mmd=0.0299 | test_mmd=0.0021\n",
      "[CellOT] epoch=1850 f_loss=-0.2503 g_loss=27.5506 | train mmd=0.0455 | test_mmd=0.0047\n",
      "[CellOT] epoch=1900 f_loss=-0.5052 g_loss=26.3254 | train mmd=0.0633 | test_mmd=0.0040\n",
      "[CellOT] epoch=1950 f_loss=0.6538 g_loss=25.6432 | train mmd=0.0631 | test_mmd=0.0029\n",
      "[CellOT] epoch=2000 f_loss=0.4240 g_loss=25.3219 | train mmd=0.0430 | test_mmd=0.0030\n",
      "[CellOT] Final CellOT MMD: 0.0168\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 4 metrics: {'mmd2_gamma_median': 0.0030480421335439267, 'mmd2_gamma_0.5': 0.04030990915367039, 'mmd2_gamma_1.0': 0.0539875240856193, 'wasserstein_distance': 0.9330180389416185, 'R2_feature_means': 0.9952024971040033}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.4676 g_loss=-6.0173 | train mmd=0.3471 | test_mmd=1.3586\n",
      "[CellOT] epoch=50 f_loss=-2.0563 g_loss=0.7326 | train mmd=0.4177 | test_mmd=0.4109\n",
      "[CellOT] epoch=100 f_loss=-3.1300 g_loss=3.5013 | train mmd=0.7811 | test_mmd=0.3607\n",
      "[CellOT] epoch=150 f_loss=-4.8689 g_loss=6.3893 | train mmd=0.8051 | test_mmd=0.3181\n",
      "[CellOT] epoch=200 f_loss=-7.4827 g_loss=8.9539 | train mmd=0.7519 | test_mmd=0.2700\n",
      "[CellOT] epoch=250 f_loss=-9.0560 g_loss=10.9928 | train mmd=0.7363 | test_mmd=0.2329\n",
      "[CellOT] epoch=300 f_loss=-10.6156 g_loss=14.7211 | train mmd=0.7323 | test_mmd=0.2067\n",
      "[CellOT] epoch=350 f_loss=-11.9578 g_loss=17.0497 | train mmd=0.7464 | test_mmd=0.1878\n",
      "[CellOT] epoch=400 f_loss=-14.6397 g_loss=19.2456 | train mmd=0.7218 | test_mmd=0.1644\n",
      "[CellOT] epoch=450 f_loss=-13.9028 g_loss=21.0629 | train mmd=0.6802 | test_mmd=0.1415\n",
      "[CellOT] epoch=500 f_loss=-15.5455 g_loss=23.1224 | train mmd=0.6669 | test_mmd=0.1220\n",
      "[CellOT] epoch=550 f_loss=-16.7254 g_loss=27.7914 | train mmd=0.6647 | test_mmd=0.1050\n",
      "[CellOT] epoch=600 f_loss=-19.2715 g_loss=29.5915 | train mmd=0.6518 | test_mmd=0.0883\n",
      "[CellOT] epoch=650 f_loss=-15.1238 g_loss=29.7379 | train mmd=0.5797 | test_mmd=0.0701\n",
      "[CellOT] epoch=700 f_loss=-14.8660 g_loss=30.8548 | train mmd=0.5734 | test_mmd=0.0550\n",
      "[CellOT] epoch=750 f_loss=-12.3264 g_loss=32.6558 | train mmd=0.5051 | test_mmd=0.0403\n",
      "[CellOT] epoch=800 f_loss=-7.6013 g_loss=31.7240 | train mmd=0.4309 | test_mmd=0.0268\n",
      "[CellOT] epoch=850 f_loss=-6.3587 g_loss=36.9175 | train mmd=0.4288 | test_mmd=0.0198\n",
      "[CellOT] epoch=900 f_loss=-3.9579 g_loss=37.0130 | train mmd=0.4093 | test_mmd=0.0135\n",
      "[CellOT] epoch=950 f_loss=1.6381 g_loss=37.7178 | train mmd=0.3497 | test_mmd=0.0095\n",
      "[CellOT] epoch=1000 f_loss=4.0493 g_loss=32.4149 | train mmd=0.2546 | test_mmd=0.0067\n",
      "[CellOT] epoch=1050 f_loss=7.1632 g_loss=29.9042 | train mmd=0.0945 | test_mmd=0.0060\n",
      "[CellOT] epoch=1100 f_loss=2.1366 g_loss=29.7194 | train mmd=0.0382 | test_mmd=0.0017\n",
      "[CellOT] epoch=1150 f_loss=1.8863 g_loss=30.1771 | train mmd=0.0332 | test_mmd=0.0036\n",
      "[CellOT] epoch=1200 f_loss=1.1036 g_loss=30.5932 | train mmd=0.0353 | test_mmd=0.0042\n",
      "[CellOT] epoch=1250 f_loss=1.4567 g_loss=30.6544 | train mmd=0.0941 | test_mmd=0.0065\n",
      "[CellOT] epoch=1300 f_loss=1.1183 g_loss=30.8972 | train mmd=0.0562 | test_mmd=0.0071\n",
      "[CellOT] epoch=1350 f_loss=-0.2355 g_loss=31.9437 | train mmd=0.0362 | test_mmd=0.0035\n",
      "[CellOT] epoch=1400 f_loss=0.9273 g_loss=31.7684 | train mmd=0.0502 | test_mmd=0.0026\n",
      "[CellOT] epoch=1450 f_loss=0.1356 g_loss=32.2974 | train mmd=0.0277 | test_mmd=0.0023\n",
      "[CellOT] epoch=1500 f_loss=-0.1147 g_loss=32.4414 | train mmd=0.0347 | test_mmd=0.0034\n",
      "[CellOT] epoch=1550 f_loss=0.3281 g_loss=32.7087 | train mmd=0.0478 | test_mmd=0.0035\n",
      "[CellOT] epoch=1600 f_loss=0.4948 g_loss=32.8299 | train mmd=0.0345 | test_mmd=0.0026\n",
      "[CellOT] epoch=1650 f_loss=-0.0932 g_loss=32.3008 | train mmd=0.0436 | test_mmd=0.0012\n",
      "[CellOT] epoch=1700 f_loss=0.9317 g_loss=31.6051 | train mmd=0.0561 | test_mmd=0.0034\n",
      "[CellOT] epoch=1750 f_loss=0.5563 g_loss=31.1170 | train mmd=0.0404 | test_mmd=0.0025\n",
      "[CellOT] epoch=1800 f_loss=0.1824 g_loss=31.1199 | train mmd=0.0218 | test_mmd=0.0015\n",
      "[CellOT] epoch=1850 f_loss=0.4938 g_loss=30.7947 | train mmd=0.0233 | test_mmd=0.0013\n",
      "[CellOT] epoch=1900 f_loss=0.4793 g_loss=31.2244 | train mmd=0.0245 | test_mmd=0.0023\n",
      "[CellOT] epoch=1950 f_loss=0.2550 g_loss=30.9780 | train mmd=0.0227 | test_mmd=0.0017\n",
      "[CellOT] epoch=2000 f_loss=0.2025 g_loss=31.1431 | train mmd=0.0199 | test_mmd=0.0011\n",
      "[CellOT] Final CellOT MMD: 0.0080\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=-0.0134 g_loss=-3.6942 | train mmd=0.2998 | test_mmd=1.2547\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 5 metrics: {'mmd2_gamma_median': 0.0010848678117909571, 'mmd2_gamma_0.5': 0.01704851508336025, 'mmd2_gamma_1.0': 0.02512443334442649, 'wasserstein_distance': 0.9660502285000773, 'R2_feature_means': 0.9987087716823166}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-1.9723 g_loss=0.7952 | train mmd=0.4800 | test_mmd=0.3547\n",
      "[CellOT] epoch=100 f_loss=-2.9738 g_loss=3.6520 | train mmd=0.7565 | test_mmd=0.3291\n",
      "[CellOT] epoch=150 f_loss=-4.7889 g_loss=5.8078 | train mmd=0.7869 | test_mmd=0.2912\n",
      "[CellOT] epoch=200 f_loss=-6.5407 g_loss=8.6996 | train mmd=0.7782 | test_mmd=0.2532\n",
      "[CellOT] epoch=250 f_loss=-8.3612 g_loss=10.7503 | train mmd=0.7674 | test_mmd=0.2208\n",
      "[CellOT] epoch=300 f_loss=-10.2122 g_loss=13.1557 | train mmd=0.7448 | test_mmd=0.1906\n",
      "[CellOT] epoch=350 f_loss=-10.4088 g_loss=14.5139 | train mmd=0.7133 | test_mmd=0.1580\n",
      "[CellOT] epoch=400 f_loss=-12.0708 g_loss=17.3004 | train mmd=0.7020 | test_mmd=0.1361\n",
      "[CellOT] epoch=450 f_loss=-11.2149 g_loss=17.7517 | train mmd=0.6772 | test_mmd=0.1132\n",
      "[CellOT] epoch=500 f_loss=-12.3782 g_loss=21.3097 | train mmd=0.6580 | test_mmd=0.0962\n",
      "[CellOT] epoch=550 f_loss=-12.0489 g_loss=22.8088 | train mmd=0.6149 | test_mmd=0.0746\n",
      "[CellOT] epoch=600 f_loss=-12.3898 g_loss=25.7896 | train mmd=0.5801 | test_mmd=0.0597\n",
      "[CellOT] epoch=650 f_loss=-13.3617 g_loss=25.8315 | train mmd=0.5371 | test_mmd=0.0461\n",
      "[CellOT] epoch=700 f_loss=-10.0022 g_loss=27.6579 | train mmd=0.4919 | test_mmd=0.0342\n",
      "[CellOT] epoch=750 f_loss=-8.1347 g_loss=26.8771 | train mmd=0.4296 | test_mmd=0.0230\n",
      "[CellOT] epoch=800 f_loss=-4.9549 g_loss=30.5109 | train mmd=0.4006 | test_mmd=0.0163\n",
      "[CellOT] epoch=850 f_loss=-2.3798 g_loss=28.3850 | train mmd=0.3414 | test_mmd=0.0105\n",
      "[CellOT] epoch=900 f_loss=-0.0624 g_loss=29.4535 | train mmd=0.3100 | test_mmd=0.0071\n",
      "[CellOT] epoch=950 f_loss=4.6728 g_loss=27.2038 | train mmd=0.1979 | test_mmd=0.0058\n",
      "[CellOT] epoch=1000 f_loss=3.3162 g_loss=25.2398 | train mmd=0.1164 | test_mmd=0.0046\n",
      "[CellOT] epoch=1050 f_loss=1.7212 g_loss=25.6419 | train mmd=0.0486 | test_mmd=0.0028\n",
      "[CellOT] epoch=1100 f_loss=-0.6241 g_loss=25.6484 | train mmd=0.0500 | test_mmd=0.0020\n",
      "[CellOT] epoch=1150 f_loss=0.0478 g_loss=25.3833 | train mmd=0.0579 | test_mmd=0.0020\n",
      "[CellOT] epoch=1200 f_loss=0.4985 g_loss=25.1273 | train mmd=0.0471 | test_mmd=0.0011\n",
      "[CellOT] epoch=1250 f_loss=0.4258 g_loss=25.7153 | train mmd=0.0391 | test_mmd=0.0026\n",
      "[CellOT] epoch=1300 f_loss=-0.3972 g_loss=25.4418 | train mmd=0.0370 | test_mmd=0.0039\n",
      "[CellOT] epoch=1350 f_loss=0.2868 g_loss=25.7535 | train mmd=0.0329 | test_mmd=0.0016\n",
      "[CellOT] epoch=1400 f_loss=-0.4871 g_loss=26.1775 | train mmd=0.0334 | test_mmd=0.0020\n",
      "[CellOT] epoch=1450 f_loss=-1.0477 g_loss=26.7040 | train mmd=0.0320 | test_mmd=0.0019\n",
      "[CellOT] epoch=1500 f_loss=-0.8760 g_loss=26.7824 | train mmd=0.0245 | test_mmd=0.0010\n",
      "[CellOT] epoch=1550 f_loss=-2.0317 g_loss=28.5584 | train mmd=0.0263 | test_mmd=0.0037\n",
      "[CellOT] epoch=1600 f_loss=-0.1134 g_loss=26.5765 | train mmd=0.0394 | test_mmd=0.0032\n",
      "[CellOT] epoch=1650 f_loss=0.0972 g_loss=25.5662 | train mmd=0.0597 | test_mmd=0.0025\n",
      "[CellOT] epoch=1700 f_loss=0.0301 g_loss=25.2282 | train mmd=0.0698 | test_mmd=0.0037\n",
      "[CellOT] epoch=1750 f_loss=0.5154 g_loss=24.7519 | train mmd=0.0457 | test_mmd=0.0021\n",
      "[CellOT] epoch=1800 f_loss=-0.1304 g_loss=24.9056 | train mmd=0.0296 | test_mmd=0.0014\n",
      "[CellOT] epoch=1850 f_loss=-0.4681 g_loss=25.0513 | train mmd=0.0276 | test_mmd=0.0033\n",
      "[CellOT] epoch=1900 f_loss=0.0914 g_loss=24.9423 | train mmd=0.0157 | test_mmd=0.0017\n",
      "[CellOT] epoch=1950 f_loss=-0.3893 g_loss=24.6858 | train mmd=0.0198 | test_mmd=0.0015\n",
      "[CellOT] epoch=2000 f_loss=-0.7343 g_loss=24.8074 | train mmd=0.0166 | test_mmd=0.0018\n",
      "[CellOT] Final CellOT MMD: 0.0073\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=-0.4172 g_loss=-2.6770 | train mmd=0.3231 | test_mmd=1.0764\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 6 metrics: {'mmd2_gamma_median': 0.0018159263353565436, 'mmd2_gamma_0.5': 0.01790595361775371, 'mmd2_gamma_1.0': 0.022462272879465117, 'wasserstein_distance': 0.8465432504159739, 'R2_feature_means': 0.9964924044274477}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.1157 g_loss=0.8540 | train mmd=0.5570 | test_mmd=0.3528\n",
      "[CellOT] epoch=100 f_loss=-3.2830 g_loss=3.8243 | train mmd=0.7675 | test_mmd=0.3341\n",
      "[CellOT] epoch=150 f_loss=-4.8700 g_loss=6.2541 | train mmd=0.7731 | test_mmd=0.2911\n",
      "[CellOT] epoch=200 f_loss=-6.6892 g_loss=8.7775 | train mmd=0.7717 | test_mmd=0.2537\n",
      "[CellOT] epoch=250 f_loss=-8.0343 g_loss=11.5978 | train mmd=0.7483 | test_mmd=0.2178\n",
      "[CellOT] epoch=300 f_loss=-9.6268 g_loss=14.0214 | train mmd=0.7359 | test_mmd=0.1900\n",
      "[CellOT] epoch=350 f_loss=-11.5047 g_loss=16.9469 | train mmd=0.7305 | test_mmd=0.1663\n",
      "[CellOT] epoch=400 f_loss=-12.0947 g_loss=17.8561 | train mmd=0.7082 | test_mmd=0.1471\n",
      "[CellOT] epoch=450 f_loss=-14.7785 g_loss=20.3197 | train mmd=0.6832 | test_mmd=0.1263\n",
      "[CellOT] epoch=500 f_loss=-15.1269 g_loss=24.8448 | train mmd=0.6771 | test_mmd=0.1093\n",
      "[CellOT] epoch=550 f_loss=-16.5890 g_loss=26.0597 | train mmd=0.6530 | test_mmd=0.0885\n",
      "[CellOT] epoch=600 f_loss=-15.7865 g_loss=27.7711 | train mmd=0.6393 | test_mmd=0.0735\n",
      "[CellOT] epoch=650 f_loss=-12.4981 g_loss=28.8829 | train mmd=0.5600 | test_mmd=0.0560\n",
      "[CellOT] epoch=700 f_loss=-11.2397 g_loss=32.0480 | train mmd=0.5393 | test_mmd=0.0416\n",
      "[CellOT] epoch=750 f_loss=-9.3341 g_loss=34.7617 | train mmd=0.5135 | test_mmd=0.0304\n",
      "[CellOT] epoch=800 f_loss=-7.3292 g_loss=35.0311 | train mmd=0.4292 | test_mmd=0.0189\n",
      "[CellOT] epoch=850 f_loss=-4.4017 g_loss=35.6788 | train mmd=0.3619 | test_mmd=0.0114\n",
      "[CellOT] epoch=900 f_loss=-0.4855 g_loss=33.1682 | train mmd=0.2836 | test_mmd=0.0065\n",
      "[CellOT] epoch=950 f_loss=4.3261 g_loss=29.4292 | train mmd=0.1258 | test_mmd=0.0045\n",
      "[CellOT] epoch=1000 f_loss=2.8335 g_loss=31.9545 | train mmd=0.0546 | test_mmd=0.0027\n",
      "[CellOT] epoch=1050 f_loss=2.2194 g_loss=31.6787 | train mmd=0.0452 | test_mmd=0.0034\n",
      "[CellOT] epoch=1100 f_loss=0.3930 g_loss=33.2001 | train mmd=0.0279 | test_mmd=0.0023\n",
      "[CellOT] epoch=1150 f_loss=0.8311 g_loss=34.0795 | train mmd=0.0376 | test_mmd=0.0016\n",
      "[CellOT] epoch=1200 f_loss=0.5660 g_loss=34.9332 | train mmd=0.0379 | test_mmd=0.0037\n",
      "[CellOT] epoch=1250 f_loss=0.6044 g_loss=35.6381 | train mmd=0.0331 | test_mmd=0.0016\n",
      "[CellOT] epoch=1300 f_loss=0.2317 g_loss=35.9921 | train mmd=0.0596 | test_mmd=0.0033\n",
      "[CellOT] epoch=1350 f_loss=-0.4296 g_loss=36.9218 | train mmd=0.0373 | test_mmd=0.0023\n",
      "[CellOT] epoch=1400 f_loss=-0.4218 g_loss=37.0207 | train mmd=0.0298 | test_mmd=0.0033\n",
      "[CellOT] epoch=1450 f_loss=0.4127 g_loss=37.5493 | train mmd=0.0447 | test_mmd=0.0026\n",
      "[CellOT] epoch=1500 f_loss=-0.1362 g_loss=38.1435 | train mmd=0.0382 | test_mmd=0.0026\n",
      "[CellOT] epoch=1550 f_loss=-0.2880 g_loss=38.5783 | train mmd=0.0366 | test_mmd=0.0032\n",
      "[CellOT] epoch=1600 f_loss=-0.3355 g_loss=39.0607 | train mmd=0.0250 | test_mmd=0.0022\n",
      "[CellOT] epoch=1650 f_loss=0.5668 g_loss=39.5997 | train mmd=0.0314 | test_mmd=0.0068\n",
      "[CellOT] epoch=1700 f_loss=0.0939 g_loss=39.3447 | train mmd=0.0253 | test_mmd=0.0026\n",
      "[CellOT] epoch=1750 f_loss=-0.1812 g_loss=40.0541 | train mmd=0.0247 | test_mmd=0.0035\n",
      "[CellOT] epoch=1800 f_loss=0.3283 g_loss=39.6331 | train mmd=0.0309 | test_mmd=0.0037\n",
      "[CellOT] epoch=1850 f_loss=0.2662 g_loss=39.7320 | train mmd=0.0233 | test_mmd=0.0014\n",
      "[CellOT] epoch=1900 f_loss=0.1981 g_loss=40.2694 | train mmd=0.0266 | test_mmd=0.0023\n",
      "[CellOT] epoch=1950 f_loss=0.5421 g_loss=40.5532 | train mmd=0.0242 | test_mmd=0.0012\n",
      "[CellOT] epoch=2000 f_loss=0.6642 g_loss=40.7361 | train mmd=0.0244 | test_mmd=0.0018\n",
      "[CellOT] Final CellOT MMD: 0.0101\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=1.0833 g_loss=-4.7926 | train mmd=0.3158 | test_mmd=1.2949\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 7 metrics: {'mmd2_gamma_median': 0.0018170254078970771, 'mmd2_gamma_0.5': 0.02071910502896135, 'mmd2_gamma_1.0': 0.030517818180140543, 'wasserstein_distance': 0.9077170437814286, 'R2_feature_means': 0.9959520793340495}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.2892 g_loss=0.5087 | train mmd=0.5296 | test_mmd=0.3615\n",
      "[CellOT] epoch=100 f_loss=-2.9831 g_loss=3.4274 | train mmd=0.7483 | test_mmd=0.3242\n",
      "[CellOT] epoch=150 f_loss=-4.8987 g_loss=6.4117 | train mmd=0.7726 | test_mmd=0.2932\n",
      "[CellOT] epoch=200 f_loss=-6.8844 g_loss=9.1379 | train mmd=0.7726 | test_mmd=0.2531\n",
      "[CellOT] epoch=250 f_loss=-8.5420 g_loss=10.5204 | train mmd=0.7214 | test_mmd=0.2149\n",
      "[CellOT] epoch=300 f_loss=-10.3428 g_loss=14.1010 | train mmd=0.7738 | test_mmd=0.1981\n",
      "[CellOT] epoch=350 f_loss=-10.6274 g_loss=14.6018 | train mmd=0.7213 | test_mmd=0.1618\n",
      "[CellOT] epoch=400 f_loss=-11.7444 g_loss=18.5489 | train mmd=0.6985 | test_mmd=0.1371\n",
      "[CellOT] epoch=450 f_loss=-13.3784 g_loss=19.5322 | train mmd=0.6761 | test_mmd=0.1154\n",
      "[CellOT] epoch=500 f_loss=-11.1356 g_loss=20.1556 | train mmd=0.6215 | test_mmd=0.0934\n",
      "[CellOT] epoch=550 f_loss=-11.2872 g_loss=23.6440 | train mmd=0.6053 | test_mmd=0.0775\n",
      "[CellOT] epoch=600 f_loss=-12.9464 g_loss=24.8391 | train mmd=0.5739 | test_mmd=0.0625\n",
      "[CellOT] epoch=650 f_loss=-12.4916 g_loss=26.6990 | train mmd=0.5513 | test_mmd=0.0483\n",
      "[CellOT] epoch=700 f_loss=-10.1415 g_loss=29.1650 | train mmd=0.5194 | test_mmd=0.0372\n",
      "[CellOT] epoch=750 f_loss=-10.5092 g_loss=30.2678 | train mmd=0.5063 | test_mmd=0.0281\n",
      "[CellOT] epoch=800 f_loss=-7.9559 g_loss=30.5671 | train mmd=0.4631 | test_mmd=0.0187\n",
      "[CellOT] epoch=850 f_loss=-5.7237 g_loss=32.7453 | train mmd=0.3991 | test_mmd=0.0116\n",
      "[CellOT] epoch=900 f_loss=1.0205 g_loss=27.4569 | train mmd=0.2931 | test_mmd=0.0076\n",
      "[CellOT] epoch=950 f_loss=3.9563 g_loss=24.7286 | train mmd=0.1719 | test_mmd=0.0067\n",
      "[CellOT] epoch=1000 f_loss=2.4915 g_loss=25.0970 | train mmd=0.0822 | test_mmd=0.0042\n",
      "[CellOT] epoch=1050 f_loss=1.6930 g_loss=25.5911 | train mmd=0.0219 | test_mmd=0.0017\n",
      "[CellOT] epoch=1100 f_loss=1.3191 g_loss=26.9555 | train mmd=0.0165 | test_mmd=0.0028\n",
      "[CellOT] epoch=1150 f_loss=-0.4099 g_loss=28.4689 | train mmd=0.0198 | test_mmd=0.0023\n",
      "[CellOT] epoch=1200 f_loss=0.4099 g_loss=28.9007 | train mmd=0.0201 | test_mmd=0.0026\n",
      "[CellOT] epoch=1250 f_loss=0.3008 g_loss=29.4901 | train mmd=0.0181 | test_mmd=0.0024\n",
      "[CellOT] epoch=1300 f_loss=0.4866 g_loss=29.4677 | train mmd=0.0236 | test_mmd=0.0027\n",
      "[CellOT] epoch=1350 f_loss=-0.0210 g_loss=29.4777 | train mmd=0.0210 | test_mmd=0.0012\n",
      "[CellOT] epoch=1400 f_loss=0.1456 g_loss=30.3225 | train mmd=0.0300 | test_mmd=0.0012\n",
      "[CellOT] epoch=1450 f_loss=0.2569 g_loss=30.4617 | train mmd=0.0246 | test_mmd=0.0030\n",
      "[CellOT] epoch=1500 f_loss=0.4848 g_loss=29.7414 | train mmd=0.0282 | test_mmd=0.0019\n",
      "[CellOT] epoch=1550 f_loss=0.3474 g_loss=30.1412 | train mmd=0.0225 | test_mmd=0.0028\n",
      "[CellOT] epoch=1600 f_loss=-0.1349 g_loss=30.4014 | train mmd=0.0271 | test_mmd=0.0023\n",
      "[CellOT] epoch=1650 f_loss=0.5343 g_loss=30.4789 | train mmd=0.0248 | test_mmd=0.0028\n",
      "[CellOT] epoch=1700 f_loss=-0.5738 g_loss=31.0377 | train mmd=0.0218 | test_mmd=0.0009\n",
      "[CellOT] epoch=1750 f_loss=0.5682 g_loss=30.9427 | train mmd=0.0230 | test_mmd=0.0016\n",
      "[CellOT] epoch=1800 f_loss=-0.4772 g_loss=30.7337 | train mmd=0.0235 | test_mmd=0.0017\n",
      "[CellOT] epoch=1850 f_loss=0.2982 g_loss=31.0736 | train mmd=0.0191 | test_mmd=0.0007\n",
      "[CellOT] epoch=1900 f_loss=0.2499 g_loss=31.0137 | train mmd=0.0218 | test_mmd=0.0020\n",
      "[CellOT] epoch=1950 f_loss=-0.0061 g_loss=31.4985 | train mmd=0.0187 | test_mmd=0.0026\n",
      "[CellOT] epoch=2000 f_loss=0.1053 g_loss=31.5140 | train mmd=0.0145 | test_mmd=0.0023\n",
      "[CellOT] Final CellOT MMD: 0.0065\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=-1.6348 g_loss=-6.1536 | train mmd=0.3307 | test_mmd=1.4985\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 8 metrics: {'mmd2_gamma_median': 0.0022774238370297795, 'mmd2_gamma_0.5': 0.01424091953093476, 'mmd2_gamma_1.0': 0.0190023251082215, 'wasserstein_distance': 0.8918732790423525, 'R2_feature_means': 0.9942331961038458}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.2321 g_loss=0.8794 | train mmd=0.5110 | test_mmd=0.3691\n",
      "[CellOT] epoch=100 f_loss=-2.9668 g_loss=3.8101 | train mmd=0.7691 | test_mmd=0.3355\n",
      "[CellOT] epoch=150 f_loss=-4.9126 g_loss=5.9245 | train mmd=0.7562 | test_mmd=0.2899\n",
      "[CellOT] epoch=200 f_loss=-6.8247 g_loss=8.3989 | train mmd=0.7290 | test_mmd=0.2534\n",
      "[CellOT] epoch=250 f_loss=-7.9488 g_loss=10.5668 | train mmd=0.7516 | test_mmd=0.2269\n",
      "[CellOT] epoch=300 f_loss=-9.9818 g_loss=12.3380 | train mmd=0.7548 | test_mmd=0.2001\n",
      "[CellOT] epoch=350 f_loss=-10.3673 g_loss=15.1801 | train mmd=0.6951 | test_mmd=0.1682\n",
      "[CellOT] epoch=400 f_loss=-11.2173 g_loss=17.0635 | train mmd=0.7196 | test_mmd=0.1504\n",
      "[CellOT] epoch=450 f_loss=-12.3667 g_loss=18.3096 | train mmd=0.6771 | test_mmd=0.1260\n",
      "[CellOT] epoch=500 f_loss=-12.0766 g_loss=20.3262 | train mmd=0.6309 | test_mmd=0.1017\n",
      "[CellOT] epoch=550 f_loss=-13.3576 g_loss=23.7932 | train mmd=0.6053 | test_mmd=0.0832\n",
      "[CellOT] epoch=600 f_loss=-13.9132 g_loss=25.1361 | train mmd=0.5793 | test_mmd=0.0653\n",
      "[CellOT] epoch=650 f_loss=-13.2985 g_loss=26.9889 | train mmd=0.5224 | test_mmd=0.0510\n",
      "[CellOT] epoch=700 f_loss=-10.8746 g_loss=26.3105 | train mmd=0.4616 | test_mmd=0.0356\n",
      "[CellOT] epoch=750 f_loss=-8.1605 g_loss=30.1192 | train mmd=0.4556 | test_mmd=0.0264\n",
      "[CellOT] epoch=800 f_loss=-6.8763 g_loss=32.4147 | train mmd=0.4237 | test_mmd=0.0194\n",
      "[CellOT] epoch=850 f_loss=-1.8202 g_loss=32.5396 | train mmd=0.3878 | test_mmd=0.0139\n",
      "[CellOT] epoch=900 f_loss=3.0885 g_loss=28.3352 | train mmd=0.2437 | test_mmd=0.0088\n",
      "[CellOT] epoch=950 f_loss=5.1558 g_loss=22.2568 | train mmd=0.1032 | test_mmd=0.0063\n",
      "[CellOT] epoch=1000 f_loss=2.0729 g_loss=21.1119 | train mmd=0.0637 | test_mmd=0.0040\n",
      "[CellOT] epoch=1050 f_loss=1.5993 g_loss=21.3768 | train mmd=0.0306 | test_mmd=0.0023\n",
      "[CellOT] epoch=1100 f_loss=1.2239 g_loss=22.2863 | train mmd=0.0243 | test_mmd=0.0028\n",
      "[CellOT] epoch=1150 f_loss=0.7049 g_loss=22.8642 | train mmd=0.0315 | test_mmd=0.0022\n",
      "[CellOT] epoch=1200 f_loss=0.3912 g_loss=23.0150 | train mmd=0.0357 | test_mmd=0.0011\n",
      "[CellOT] epoch=1250 f_loss=-0.2240 g_loss=23.5917 | train mmd=0.0282 | test_mmd=0.0016\n",
      "[CellOT] epoch=1300 f_loss=0.9272 g_loss=23.6247 | train mmd=0.0257 | test_mmd=0.0024\n",
      "[CellOT] epoch=1350 f_loss=0.0694 g_loss=23.9164 | train mmd=0.0277 | test_mmd=0.0023\n",
      "[CellOT] epoch=1400 f_loss=0.6294 g_loss=23.8857 | train mmd=0.0299 | test_mmd=0.0016\n",
      "[CellOT] epoch=1450 f_loss=-0.0523 g_loss=23.8214 | train mmd=0.0363 | test_mmd=0.0021\n",
      "[CellOT] epoch=1500 f_loss=-1.0045 g_loss=24.8567 | train mmd=0.0301 | test_mmd=0.0061\n",
      "[CellOT] epoch=1550 f_loss=0.0169 g_loss=25.0241 | train mmd=0.0258 | test_mmd=0.0030\n",
      "[CellOT] epoch=1600 f_loss=0.0702 g_loss=24.6930 | train mmd=0.0435 | test_mmd=0.0017\n",
      "[CellOT] epoch=1650 f_loss=-0.0443 g_loss=24.4024 | train mmd=0.0369 | test_mmd=0.0019\n",
      "[CellOT] epoch=1700 f_loss=-0.4701 g_loss=24.7262 | train mmd=0.0266 | test_mmd=0.0038\n",
      "[CellOT] epoch=1750 f_loss=0.1009 g_loss=24.8350 | train mmd=0.0246 | test_mmd=0.0032\n",
      "[CellOT] epoch=1800 f_loss=-0.1615 g_loss=24.8174 | train mmd=0.0225 | test_mmd=0.0040\n",
      "[CellOT] epoch=1850 f_loss=-0.1515 g_loss=24.6555 | train mmd=0.0343 | test_mmd=0.0015\n",
      "[CellOT] epoch=1900 f_loss=0.3031 g_loss=23.4058 | train mmd=0.0588 | test_mmd=0.0032\n",
      "[CellOT] epoch=1950 f_loss=0.7822 g_loss=22.4883 | train mmd=0.0486 | test_mmd=0.0025\n",
      "[CellOT] epoch=2000 f_loss=0.2156 g_loss=21.9841 | train mmd=0.0425 | test_mmd=0.0020\n",
      "[CellOT] Final CellOT MMD: 0.0167\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.\n",
      "  warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 9 metrics: {'mmd2_gamma_median': 0.001978776996309106, 'mmd2_gamma_0.5': 0.03915400267308511, 'mmd2_gamma_1.0': 0.05330024544615042, 'wasserstein_distance': 0.9318318257802892, 'R2_feature_means': 0.9983560635644584}\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0019  0.0007\n",
      "mmd2_gamma_0.5        0.0220  0.0106\n",
      "mmd2_gamma_1.0        0.0301  0.0142\n",
      "wasserstein_distance  0.8919  0.0458\n",
      "R2_feature_means      0.9965  0.0019\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "drug = \"Vem\"\n",
    "X_pre, X_post = prepare_pair_from_mat('COLO858', 'DMSO','24h', drug, '72h')\n",
    "jfe_indices = [1, 6, 0, 5, 4, 7, 8, 2, 3, 19]  \n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre, X_post, 0.8)\n",
    "\n",
    "print(X_tr_pre.shape)\n",
    "print(X_te_pre.shape)\n",
    "print(Y_tr_post.shape)\n",
    "print(Y_te_post.shape)\n",
    "\n",
    "# Compute median heuristic gamma on training data\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n",
    "\n",
    "\n",
    "all_metrics = []\n",
    "for run in range(10):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    seed = 1234 + run\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    out = run_cellot_pair(X_tr_pre[:, jfe_indices], Y_tr_post[:, jfe_indices], X_te_pre[:, jfe_indices], Y_te_post[:, jfe_indices], n_epochs=2000)\n",
    "    metrics = summarize_metrics(out[\"y_pred\"], Y_te_post[:, jfe_indices], median_gamma)\n",
    "    print(f\"Run {run} metrics: {metrics}\")\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary\n",
    "df = pd.DataFrame(all_metrics)\n",
    "print(df.describe().T[['mean', 'std']].round(4))\n",
    "\n",
    "\n",
    "from umap import UMAP\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "source = Y_tr_post[:, jfe_indices]\n",
    "target = Y_te_post[:, jfe_indices]\n",
    "predicted = out.get('y_pred') \n",
    "\n",
    "# Instantiate UMAP\n",
    "umap_model = UMAP(n_components=2, random_state=42)\n",
    "\n",
    "all_sample_umap = umap_model.fit_transform(np.vstack([source, target]))\n",
    "source_umap = umap_model.transform(source)\n",
    "target_umap = umap_model.transform(target)\n",
    "y_pred_umap = umap_model.transform(predicted)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "# ax.scatter(source_umap[:, 0], source_umap[:, 1], s=10, alpha=0.7, label='train_post', color='C2')\n",
    "ax.scatter(target_umap[:, 0], target_umap[:, 1], s=10, alpha=0.7, label='observed treated cells', color=\"#C88131\")\n",
    "ax.scatter(y_pred_umap[:, 0], y_pred_umap[:, 1], s=10, alpha=0.7, label='predicted cells', color=\"#1F4D8D\")\n",
    "\n",
    "ax.set_title(f'{drug}')\n",
    "# ax.set_xlabel('UMAP 1')\n",
    "# ax.set_ylabel('UMAP 2')\n",
    "ax.set_aspect('equal', 'box')\n",
    "# Add a legend to distinguish the points\n",
    "ax.legend()\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "# Display the plot\n",
    "plt.savefig(f\"./plots/cellot_on_4i_drug_{drug}.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c338b528",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "31734785",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cell line:  WM902B\n",
      "['DMSO' 'Vem' 'Vem+Tram']\n",
      "X_pre cells: (5690, 20)\n",
      "X_post cells: (5690, 20)\n",
      "(4552, 20)\n",
      "(1138, 20)\n",
      "(4552, 20)\n",
      "(1138, 20)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Median heuristic gamma: 0.06456030933401641\n",
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-1.3478 g_loss=-4.2979 | train mmd=0.4075 | test_mmd=1.3562\n",
      "[CellOT] epoch=50 f_loss=-2.0507 g_loss=0.4978 | train mmd=0.5396 | test_mmd=0.2764\n",
      "[CellOT] epoch=100 f_loss=-2.4600 g_loss=2.1598 | train mmd=0.6309 | test_mmd=0.2202\n",
      "[CellOT] epoch=150 f_loss=-4.0853 g_loss=4.2834 | train mmd=0.6577 | test_mmd=0.1989\n",
      "[CellOT] epoch=200 f_loss=-5.6215 g_loss=5.8955 | train mmd=0.6331 | test_mmd=0.1690\n",
      "[CellOT] epoch=250 f_loss=-7.3941 g_loss=7.7300 | train mmd=0.6004 | test_mmd=0.1391\n",
      "[CellOT] epoch=300 f_loss=-7.2905 g_loss=8.9752 | train mmd=0.5553 | test_mmd=0.1113\n",
      "[CellOT] epoch=350 f_loss=-7.7490 g_loss=11.1106 | train mmd=0.5165 | test_mmd=0.0891\n",
      "[CellOT] epoch=400 f_loss=-8.3747 g_loss=11.5345 | train mmd=0.4686 | test_mmd=0.0674\n",
      "[CellOT] epoch=450 f_loss=-9.3412 g_loss=13.4609 | train mmd=0.4196 | test_mmd=0.0495\n",
      "[CellOT] epoch=500 f_loss=-8.7416 g_loss=14.0152 | train mmd=0.3637 | test_mmd=0.0343\n",
      "[CellOT] epoch=550 f_loss=-7.7250 g_loss=14.2650 | train mmd=0.3161 | test_mmd=0.0233\n",
      "[CellOT] epoch=600 f_loss=-8.0112 g_loss=14.2168 | train mmd=0.2781 | test_mmd=0.0152\n",
      "[CellOT] epoch=650 f_loss=-5.7695 g_loss=12.8404 | train mmd=0.2426 | test_mmd=0.0089\n",
      "[CellOT] epoch=700 f_loss=-4.0565 g_loss=14.1272 | train mmd=0.2048 | test_mmd=0.0054\n",
      "[CellOT] epoch=750 f_loss=-1.8051 g_loss=12.4069 | train mmd=0.1448 | test_mmd=0.0026\n",
      "[CellOT] epoch=800 f_loss=0.9835 g_loss=9.0374 | train mmd=0.0200 | test_mmd=0.0015\n",
      "[CellOT] epoch=850 f_loss=0.6139 g_loss=10.8747 | train mmd=0.0281 | test_mmd=0.0024\n",
      "[CellOT] epoch=900 f_loss=-0.1446 g_loss=11.5256 | train mmd=0.0313 | test_mmd=0.0015\n",
      "[CellOT] epoch=950 f_loss=-0.0421 g_loss=13.4999 | train mmd=0.0192 | test_mmd=0.0018\n",
      "[CellOT] epoch=1000 f_loss=0.4635 g_loss=14.4308 | train mmd=0.0240 | test_mmd=0.0021\n",
      "[CellOT] epoch=1050 f_loss=0.3785 g_loss=14.5862 | train mmd=0.0186 | test_mmd=0.0016\n",
      "[CellOT] epoch=1100 f_loss=0.3296 g_loss=15.2712 | train mmd=0.0145 | test_mmd=0.0011\n",
      "[CellOT] epoch=1150 f_loss=0.3178 g_loss=16.1314 | train mmd=0.0170 | test_mmd=0.0026\n",
      "[CellOT] epoch=1200 f_loss=-0.4856 g_loss=16.1006 | train mmd=0.0096 | test_mmd=0.0035\n",
      "[CellOT] epoch=1250 f_loss=-0.0293 g_loss=16.4168 | train mmd=0.0087 | test_mmd=0.0007\n",
      "[CellOT] epoch=1300 f_loss=0.1228 g_loss=16.9480 | train mmd=0.0085 | test_mmd=0.0018\n",
      "[CellOT] epoch=1350 f_loss=0.0406 g_loss=17.0079 | train mmd=0.0143 | test_mmd=0.0033\n",
      "[CellOT] epoch=1400 f_loss=0.5879 g_loss=17.1920 | train mmd=0.0071 | test_mmd=0.0008\n",
      "[CellOT] epoch=1450 f_loss=0.1740 g_loss=16.8221 | train mmd=0.0068 | test_mmd=0.0019\n",
      "[CellOT] epoch=1500 f_loss=-0.0319 g_loss=17.1364 | train mmd=0.0065 | test_mmd=0.0029\n",
      "[CellOT] epoch=1550 f_loss=0.2418 g_loss=17.4187 | train mmd=0.0064 | test_mmd=0.0026\n",
      "[CellOT] epoch=1600 f_loss=0.1388 g_loss=18.2946 | train mmd=0.0055 | test_mmd=0.0017\n",
      "[CellOT] epoch=1650 f_loss=-0.4022 g_loss=18.1245 | train mmd=0.0070 | test_mmd=0.0025\n",
      "[CellOT] epoch=1700 f_loss=-0.2011 g_loss=18.2807 | train mmd=0.0058 | test_mmd=0.0017\n",
      "[CellOT] epoch=1750 f_loss=-0.0760 g_loss=17.9417 | train mmd=0.0051 | test_mmd=0.0018\n",
      "[CellOT] epoch=1800 f_loss=0.0364 g_loss=18.5127 | train mmd=0.0069 | test_mmd=0.0016\n",
      "[CellOT] epoch=1850 f_loss=0.3058 g_loss=18.3842 | train mmd=0.0065 | test_mmd=0.0015\n",
      "[CellOT] epoch=1900 f_loss=-0.6963 g_loss=18.5002 | train mmd=0.0064 | test_mmd=0.0026\n",
      "[CellOT] epoch=1950 f_loss=-0.1852 g_loss=18.4411 | train mmd=0.0054 | test_mmd=0.0013\n",
      "[CellOT] epoch=2000 f_loss=0.0633 g_loss=18.9419 | train mmd=0.0051 | test_mmd=0.0013\n",
      "[CellOT] Final CellOT MMD: 0.0022\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 0 metrics: {'mmd2_gamma_median': 0.0012542330548115377, 'mmd2_gamma_0.5': 0.004492773892949664, 'mmd2_gamma_1.0': 0.005741941641809523, 'wasserstein_distance': 0.6023647707603519, 'R2_feature_means': 0.9984231791671069}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.2507 g_loss=-4.6816 | train mmd=0.4875 | test_mmd=1.1309\n",
      "[CellOT] epoch=50 f_loss=-1.8007 g_loss=0.1995 | train mmd=0.5078 | test_mmd=0.2856\n",
      "[CellOT] epoch=100 f_loss=-2.6468 g_loss=2.5587 | train mmd=0.6517 | test_mmd=0.2312\n",
      "[CellOT] epoch=150 f_loss=-4.1713 g_loss=4.4061 | train mmd=0.6655 | test_mmd=0.2029\n",
      "[CellOT] epoch=200 f_loss=-5.3258 g_loss=6.4887 | train mmd=0.6705 | test_mmd=0.1754\n",
      "[CellOT] epoch=250 f_loss=-6.8793 g_loss=8.5455 | train mmd=0.6269 | test_mmd=0.1470\n",
      "[CellOT] epoch=300 f_loss=-8.0911 g_loss=10.2999 | train mmd=0.5653 | test_mmd=0.1153\n",
      "[CellOT] epoch=350 f_loss=-9.1927 g_loss=11.2812 | train mmd=0.5146 | test_mmd=0.0898\n",
      "[CellOT] epoch=400 f_loss=-9.3389 g_loss=12.2078 | train mmd=0.4748 | test_mmd=0.0691\n",
      "[CellOT] epoch=450 f_loss=-9.8026 g_loss=13.8462 | train mmd=0.4543 | test_mmd=0.0559\n",
      "[CellOT] epoch=500 f_loss=-9.3554 g_loss=15.2290 | train mmd=0.3864 | test_mmd=0.0373\n",
      "[CellOT] epoch=550 f_loss=-8.7197 g_loss=15.4543 | train mmd=0.3480 | test_mmd=0.0265\n",
      "[CellOT] epoch=600 f_loss=-8.0585 g_loss=16.0201 | train mmd=0.2947 | test_mmd=0.0172\n",
      "[CellOT] epoch=650 f_loss=-7.4990 g_loss=15.0505 | train mmd=0.2503 | test_mmd=0.0101\n",
      "[CellOT] epoch=700 f_loss=-6.4157 g_loss=15.4589 | train mmd=0.1765 | test_mmd=0.0057\n",
      "[CellOT] epoch=750 f_loss=-4.9053 g_loss=13.9188 | train mmd=0.1848 | test_mmd=0.0037\n",
      "[CellOT] epoch=800 f_loss=-1.1918 g_loss=12.4662 | train mmd=0.1207 | test_mmd=0.0045\n",
      "[CellOT] epoch=850 f_loss=1.0033 g_loss=11.3476 | train mmd=0.0606 | test_mmd=0.0056\n",
      "[CellOT] epoch=900 f_loss=1.1551 g_loss=11.9759 | train mmd=0.0305 | test_mmd=0.0025\n",
      "[CellOT] epoch=950 f_loss=0.4511 g_loss=13.7249 | train mmd=0.0356 | test_mmd=0.0026\n",
      "[CellOT] epoch=1000 f_loss=0.6047 g_loss=14.9679 | train mmd=0.0431 | test_mmd=0.0043\n",
      "[CellOT] epoch=1050 f_loss=-0.3751 g_loss=16.2623 | train mmd=0.0300 | test_mmd=0.0025\n",
      "[CellOT] epoch=1100 f_loss=0.2415 g_loss=17.2593 | train mmd=0.0262 | test_mmd=0.0032\n",
      "[CellOT] epoch=1150 f_loss=0.3128 g_loss=17.6971 | train mmd=0.0325 | test_mmd=0.0027\n",
      "[CellOT] epoch=1200 f_loss=-0.0371 g_loss=18.5029 | train mmd=0.0390 | test_mmd=0.0035\n",
      "[CellOT] epoch=1250 f_loss=-0.3912 g_loss=18.5647 | train mmd=0.0270 | test_mmd=0.0027\n",
      "[CellOT] epoch=1300 f_loss=0.1150 g_loss=18.8644 | train mmd=0.0235 | test_mmd=0.0019\n",
      "[CellOT] epoch=1350 f_loss=0.2802 g_loss=18.6436 | train mmd=0.0265 | test_mmd=0.0023\n",
      "[CellOT] epoch=1400 f_loss=0.3920 g_loss=19.2127 | train mmd=0.0157 | test_mmd=0.0017\n",
      "[CellOT] epoch=1450 f_loss=0.3545 g_loss=19.6628 | train mmd=0.0133 | test_mmd=0.0015\n",
      "[CellOT] epoch=1500 f_loss=0.3451 g_loss=19.1781 | train mmd=0.0096 | test_mmd=0.0019\n",
      "[CellOT] epoch=1550 f_loss=-0.2418 g_loss=19.6162 | train mmd=0.0092 | test_mmd=0.0015\n",
      "[CellOT] epoch=1600 f_loss=-0.0897 g_loss=19.9106 | train mmd=0.0075 | test_mmd=0.0032\n",
      "[CellOT] epoch=1650 f_loss=-0.1331 g_loss=19.5798 | train mmd=0.0076 | test_mmd=0.0045\n",
      "[CellOT] epoch=1700 f_loss=0.1732 g_loss=19.4152 | train mmd=0.0060 | test_mmd=0.0018\n",
      "[CellOT] epoch=1750 f_loss=0.2214 g_loss=19.5749 | train mmd=0.0065 | test_mmd=0.0027\n",
      "[CellOT] epoch=1800 f_loss=-0.0601 g_loss=19.7936 | train mmd=0.0063 | test_mmd=0.0016\n",
      "[CellOT] epoch=1850 f_loss=-0.1204 g_loss=19.7496 | train mmd=0.0066 | test_mmd=0.0024\n",
      "[CellOT] epoch=1900 f_loss=0.0220 g_loss=19.9447 | train mmd=0.0051 | test_mmd=0.0010\n",
      "[CellOT] epoch=1950 f_loss=0.2482 g_loss=19.8904 | train mmd=0.0050 | test_mmd=0.0012\n",
      "[CellOT] epoch=2000 f_loss=0.0816 g_loss=19.7086 | train mmd=0.0043 | test_mmd=0.0018\n",
      "[CellOT] Final CellOT MMD: 0.0025\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1 metrics: {'mmd2_gamma_median': 0.0017934694296724008, 'mmd2_gamma_0.5': 0.005422420174304277, 'mmd2_gamma_1.0': 0.005738780089787787, 'wasserstein_distance': 0.6179441768913163, 'R2_feature_means': 0.9978225517719511}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-1.2638 g_loss=-3.3768 | train mmd=0.3846 | test_mmd=1.0378\n",
      "[CellOT] epoch=50 f_loss=-1.6895 g_loss=0.5251 | train mmd=0.5795 | test_mmd=0.2602\n",
      "[CellOT] epoch=100 f_loss=-3.2869 g_loss=2.4345 | train mmd=0.6550 | test_mmd=0.2236\n",
      "[CellOT] epoch=150 f_loss=-4.3113 g_loss=4.0982 | train mmd=0.6639 | test_mmd=0.2000\n",
      "[CellOT] epoch=200 f_loss=-5.5382 g_loss=5.9398 | train mmd=0.6215 | test_mmd=0.1619\n",
      "[CellOT] epoch=250 f_loss=-6.9380 g_loss=7.3939 | train mmd=0.5931 | test_mmd=0.1351\n",
      "[CellOT] epoch=300 f_loss=-7.8892 g_loss=9.0675 | train mmd=0.5432 | test_mmd=0.1091\n",
      "[CellOT] epoch=350 f_loss=-8.4878 g_loss=10.3584 | train mmd=0.5033 | test_mmd=0.0853\n",
      "[CellOT] epoch=400 f_loss=-9.1273 g_loss=10.8783 | train mmd=0.4510 | test_mmd=0.0625\n",
      "[CellOT] epoch=450 f_loss=-8.5118 g_loss=11.9605 | train mmd=0.3851 | test_mmd=0.0436\n",
      "[CellOT] epoch=500 f_loss=-8.4021 g_loss=11.5395 | train mmd=0.3343 | test_mmd=0.0290\n",
      "[CellOT] epoch=550 f_loss=-6.9587 g_loss=14.1154 | train mmd=0.2901 | test_mmd=0.0196\n",
      "[CellOT] epoch=600 f_loss=-6.0034 g_loss=13.6910 | train mmd=0.2374 | test_mmd=0.0107\n",
      "[CellOT] epoch=650 f_loss=-5.8903 g_loss=12.5364 | train mmd=0.1997 | test_mmd=0.0064\n",
      "[CellOT] epoch=700 f_loss=-1.7300 g_loss=12.1507 | train mmd=0.1564 | test_mmd=0.0044\n",
      "[CellOT] epoch=750 f_loss=-0.2214 g_loss=11.1251 | train mmd=0.0891 | test_mmd=0.0028\n",
      "[CellOT] epoch=800 f_loss=0.8384 g_loss=8.4766 | train mmd=0.0187 | test_mmd=0.0019\n",
      "[CellOT] epoch=850 f_loss=0.3977 g_loss=9.3533 | train mmd=0.0208 | test_mmd=0.0032\n",
      "[CellOT] epoch=900 f_loss=0.3922 g_loss=10.1383 | train mmd=0.0329 | test_mmd=0.0025\n",
      "[CellOT] epoch=950 f_loss=0.0617 g_loss=11.1477 | train mmd=0.0207 | test_mmd=0.0027\n",
      "[CellOT] epoch=1000 f_loss=-0.1429 g_loss=12.2135 | train mmd=0.0154 | test_mmd=0.0017\n",
      "[CellOT] epoch=1050 f_loss=-0.5528 g_loss=12.9025 | train mmd=0.0145 | test_mmd=0.0025\n",
      "[CellOT] epoch=1100 f_loss=0.7645 g_loss=13.0275 | train mmd=0.0173 | test_mmd=0.0022\n",
      "[CellOT] epoch=1150 f_loss=0.2898 g_loss=13.4251 | train mmd=0.0125 | test_mmd=0.0018\n",
      "[CellOT] epoch=1200 f_loss=0.2130 g_loss=13.4587 | train mmd=0.0177 | test_mmd=0.0035\n",
      "[CellOT] epoch=1250 f_loss=0.5519 g_loss=13.5378 | train mmd=0.0121 | test_mmd=0.0039\n",
      "[CellOT] epoch=1300 f_loss=-0.3482 g_loss=13.7284 | train mmd=0.0086 | test_mmd=0.0010\n",
      "[CellOT] epoch=1350 f_loss=0.4457 g_loss=13.5665 | train mmd=0.0121 | test_mmd=0.0020\n",
      "[CellOT] epoch=1400 f_loss=0.4274 g_loss=13.1452 | train mmd=0.0097 | test_mmd=0.0014\n",
      "[CellOT] epoch=1450 f_loss=-0.0055 g_loss=14.1168 | train mmd=0.0107 | test_mmd=0.0021\n",
      "[CellOT] epoch=1500 f_loss=-0.2793 g_loss=14.4139 | train mmd=0.0070 | test_mmd=0.0042\n",
      "[CellOT] epoch=1550 f_loss=0.2896 g_loss=14.4325 | train mmd=0.0082 | test_mmd=0.0026\n",
      "[CellOT] epoch=1600 f_loss=0.5596 g_loss=14.2711 | train mmd=0.0053 | test_mmd=0.0009\n",
      "[CellOT] epoch=1650 f_loss=-0.7801 g_loss=14.3231 | train mmd=0.0066 | test_mmd=0.0035\n",
      "[CellOT] epoch=1700 f_loss=-0.4779 g_loss=14.8814 | train mmd=0.0086 | test_mmd=0.0022\n",
      "[CellOT] epoch=1750 f_loss=0.1272 g_loss=14.2304 | train mmd=0.0077 | test_mmd=0.0009\n",
      "[CellOT] epoch=1800 f_loss=0.2242 g_loss=14.9517 | train mmd=0.0058 | test_mmd=0.0022\n",
      "[CellOT] epoch=1850 f_loss=-0.0839 g_loss=14.7134 | train mmd=0.0062 | test_mmd=0.0051\n",
      "[CellOT] epoch=1900 f_loss=-0.5214 g_loss=15.0816 | train mmd=0.0076 | test_mmd=0.0038\n",
      "[CellOT] epoch=1950 f_loss=-0.1332 g_loss=14.9526 | train mmd=0.0073 | test_mmd=0.0017\n",
      "[CellOT] epoch=2000 f_loss=-0.0919 g_loss=15.2252 | train mmd=0.0082 | test_mmd=0.0022\n",
      "[CellOT] Final CellOT MMD: 0.0032\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 2 metrics: {'mmd2_gamma_median': 0.0021964796835240996, 'mmd2_gamma_0.5': 0.007443005562874805, 'mmd2_gamma_1.0': 0.008075257632478572, 'wasserstein_distance': 0.6205717814402512, 'R2_feature_means': 0.9970761963993026}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.4161 g_loss=-3.7501 | train mmd=0.4031 | test_mmd=1.1646\n",
      "[CellOT] epoch=50 f_loss=-1.9747 g_loss=0.1136 | train mmd=0.5027 | test_mmd=0.2654\n",
      "[CellOT] epoch=100 f_loss=-2.9280 g_loss=2.2921 | train mmd=0.6811 | test_mmd=0.2332\n",
      "[CellOT] epoch=150 f_loss=-4.0763 g_loss=4.9193 | train mmd=0.6692 | test_mmd=0.2038\n",
      "[CellOT] epoch=200 f_loss=-5.4003 g_loss=6.8769 | train mmd=0.6787 | test_mmd=0.1776\n",
      "[CellOT] epoch=250 f_loss=-6.5271 g_loss=7.7030 | train mmd=0.6176 | test_mmd=0.1429\n",
      "[CellOT] epoch=300 f_loss=-7.4830 g_loss=9.2108 | train mmd=0.5795 | test_mmd=0.1170\n",
      "[CellOT] epoch=350 f_loss=-9.0135 g_loss=9.6901 | train mmd=0.5190 | test_mmd=0.0893\n",
      "[CellOT] epoch=400 f_loss=-8.8895 g_loss=12.2026 | train mmd=0.4768 | test_mmd=0.0695\n",
      "[CellOT] epoch=450 f_loss=-9.5210 g_loss=13.3489 | train mmd=0.4333 | test_mmd=0.0511\n",
      "[CellOT] epoch=500 f_loss=-9.1238 g_loss=12.9191 | train mmd=0.3887 | test_mmd=0.0367\n",
      "[CellOT] epoch=550 f_loss=-7.8871 g_loss=13.5073 | train mmd=0.3544 | test_mmd=0.0267\n",
      "[CellOT] epoch=600 f_loss=-6.7433 g_loss=13.2698 | train mmd=0.2860 | test_mmd=0.0155\n",
      "[CellOT] epoch=650 f_loss=-6.0465 g_loss=13.5745 | train mmd=0.2311 | test_mmd=0.0096\n",
      "[CellOT] epoch=700 f_loss=-3.9197 g_loss=13.6321 | train mmd=0.2004 | test_mmd=0.0052\n",
      "[CellOT] epoch=750 f_loss=-1.0051 g_loss=11.6360 | train mmd=0.1473 | test_mmd=0.0030\n",
      "[CellOT] epoch=800 f_loss=0.1980 g_loss=10.8961 | train mmd=0.0917 | test_mmd=0.0026\n",
      "[CellOT] epoch=850 f_loss=0.2042 g_loss=8.9849 | train mmd=0.0225 | test_mmd=0.0016\n",
      "[CellOT] epoch=900 f_loss=0.3842 g_loss=9.4354 | train mmd=0.0330 | test_mmd=0.0020\n",
      "[CellOT] epoch=950 f_loss=0.4329 g_loss=10.6600 | train mmd=0.0262 | test_mmd=0.0016\n",
      "[CellOT] epoch=1000 f_loss=0.2421 g_loss=11.2912 | train mmd=0.0254 | test_mmd=0.0054\n",
      "[CellOT] epoch=1050 f_loss=0.2843 g_loss=12.0839 | train mmd=0.0141 | test_mmd=0.0031\n",
      "[CellOT] epoch=1100 f_loss=0.2145 g_loss=12.7207 | train mmd=0.0143 | test_mmd=0.0016\n",
      "[CellOT] epoch=1150 f_loss=-0.1279 g_loss=12.9311 | train mmd=0.0129 | test_mmd=0.0021\n",
      "[CellOT] epoch=1200 f_loss=0.7232 g_loss=12.8101 | train mmd=0.0102 | test_mmd=0.0014\n",
      "[CellOT] epoch=1250 f_loss=0.0483 g_loss=13.2870 | train mmd=0.0107 | test_mmd=0.0011\n",
      "[CellOT] epoch=1300 f_loss=-0.3439 g_loss=13.4436 | train mmd=0.0082 | test_mmd=0.0015\n",
      "[CellOT] epoch=1350 f_loss=0.1264 g_loss=13.7546 | train mmd=0.0071 | test_mmd=0.0014\n",
      "[CellOT] epoch=1400 f_loss=0.1840 g_loss=13.9844 | train mmd=0.0081 | test_mmd=0.0033\n",
      "[CellOT] epoch=1450 f_loss=0.2859 g_loss=13.5078 | train mmd=0.0054 | test_mmd=0.0013\n",
      "[CellOT] epoch=1500 f_loss=0.0300 g_loss=13.8894 | train mmd=0.0057 | test_mmd=0.0009\n",
      "[CellOT] epoch=1550 f_loss=-0.1048 g_loss=14.2552 | train mmd=0.0059 | test_mmd=0.0037\n",
      "[CellOT] epoch=1600 f_loss=-0.1680 g_loss=14.1198 | train mmd=0.0049 | test_mmd=0.0012\n",
      "[CellOT] epoch=1650 f_loss=0.0446 g_loss=14.2603 | train mmd=0.0047 | test_mmd=0.0023\n",
      "[CellOT] epoch=1700 f_loss=0.1059 g_loss=14.1343 | train mmd=0.0044 | test_mmd=0.0008\n",
      "[CellOT] epoch=1750 f_loss=-0.2752 g_loss=14.3894 | train mmd=0.0056 | test_mmd=0.0012\n",
      "[CellOT] epoch=1800 f_loss=0.0423 g_loss=14.2262 | train mmd=0.0041 | test_mmd=0.0006\n",
      "[CellOT] epoch=1850 f_loss=-0.2299 g_loss=14.2408 | train mmd=0.0041 | test_mmd=0.0011\n",
      "[CellOT] epoch=1900 f_loss=0.0967 g_loss=14.4024 | train mmd=0.0062 | test_mmd=0.0019\n",
      "[CellOT] epoch=1950 f_loss=-0.4080 g_loss=14.2387 | train mmd=0.0043 | test_mmd=0.0020\n",
      "[CellOT] epoch=2000 f_loss=0.5581 g_loss=14.4243 | train mmd=0.0046 | test_mmd=0.0018\n",
      "[CellOT] Final CellOT MMD: 0.0025\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 3 metrics: {'mmd2_gamma_median': 0.001817827933499716, 'mmd2_gamma_0.5': 0.005047039611933424, 'mmd2_gamma_1.0': 0.005951318362085234, 'wasserstein_distance': 0.6079471027342879, 'R2_feature_means': 0.9974405647120288}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.2356 g_loss=-6.8029 | train mmd=0.3934 | test_mmd=1.3949\n",
      "[CellOT] epoch=50 f_loss=-2.4821 g_loss=-0.3500 | train mmd=0.4714 | test_mmd=0.3359\n",
      "[CellOT] epoch=100 f_loss=-2.6842 g_loss=2.1810 | train mmd=0.6575 | test_mmd=0.2379\n",
      "[CellOT] epoch=150 f_loss=-3.8595 g_loss=3.7548 | train mmd=0.6701 | test_mmd=0.2096\n",
      "[CellOT] epoch=200 f_loss=-5.7595 g_loss=6.0988 | train mmd=0.6498 | test_mmd=0.1805\n",
      "[CellOT] epoch=250 f_loss=-7.0644 g_loss=8.0169 | train mmd=0.6107 | test_mmd=0.1481\n",
      "[CellOT] epoch=300 f_loss=-7.9892 g_loss=9.7728 | train mmd=0.5634 | test_mmd=0.1182\n",
      "[CellOT] epoch=350 f_loss=-8.3128 g_loss=10.6935 | train mmd=0.5145 | test_mmd=0.0925\n",
      "[CellOT] epoch=400 f_loss=-9.1530 g_loss=12.2671 | train mmd=0.4676 | test_mmd=0.0701\n",
      "[CellOT] epoch=450 f_loss=-9.5314 g_loss=13.5709 | train mmd=0.4066 | test_mmd=0.0527\n",
      "[CellOT] epoch=500 f_loss=-9.5814 g_loss=13.4305 | train mmd=0.3565 | test_mmd=0.0378\n",
      "[CellOT] epoch=550 f_loss=-8.2230 g_loss=14.7940 | train mmd=0.3160 | test_mmd=0.0258\n",
      "[CellOT] epoch=600 f_loss=-8.2825 g_loss=15.0977 | train mmd=0.2737 | test_mmd=0.0167\n",
      "[CellOT] epoch=650 f_loss=-6.7328 g_loss=13.8325 | train mmd=0.2157 | test_mmd=0.0092\n",
      "[CellOT] epoch=700 f_loss=-5.8056 g_loss=13.4629 | train mmd=0.1839 | test_mmd=0.0052\n",
      "[CellOT] epoch=750 f_loss=-3.4923 g_loss=12.0444 | train mmd=0.1067 | test_mmd=0.0033\n",
      "[CellOT] epoch=800 f_loss=-0.0004 g_loss=10.1689 | train mmd=0.0485 | test_mmd=0.0031\n",
      "[CellOT] epoch=850 f_loss=0.6585 g_loss=9.6297 | train mmd=0.0349 | test_mmd=0.0013\n",
      "[CellOT] epoch=900 f_loss=0.3553 g_loss=10.4668 | train mmd=0.0342 | test_mmd=0.0026\n",
      "[CellOT] epoch=950 f_loss=0.3507 g_loss=11.4868 | train mmd=0.0309 | test_mmd=0.0016\n",
      "[CellOT] epoch=1000 f_loss=0.0620 g_loss=12.6248 | train mmd=0.0183 | test_mmd=0.0030\n",
      "[CellOT] epoch=1050 f_loss=-0.1884 g_loss=13.1322 | train mmd=0.0162 | test_mmd=0.0012\n",
      "[CellOT] epoch=1100 f_loss=0.2190 g_loss=13.5846 | train mmd=0.0127 | test_mmd=0.0011\n",
      "[CellOT] epoch=1150 f_loss=0.3845 g_loss=13.9129 | train mmd=0.0131 | test_mmd=0.0009\n",
      "[CellOT] epoch=1200 f_loss=-0.1294 g_loss=14.2197 | train mmd=0.0087 | test_mmd=0.0018\n",
      "[CellOT] epoch=1250 f_loss=0.0742 g_loss=14.5898 | train mmd=0.0090 | test_mmd=0.0032\n",
      "[CellOT] epoch=1300 f_loss=0.3733 g_loss=14.4673 | train mmd=0.0083 | test_mmd=0.0009\n",
      "[CellOT] epoch=1350 f_loss=0.5837 g_loss=14.5942 | train mmd=0.0063 | test_mmd=0.0010\n",
      "[CellOT] epoch=1400 f_loss=-0.1714 g_loss=14.8963 | train mmd=0.0056 | test_mmd=0.0008\n",
      "[CellOT] epoch=1450 f_loss=0.0670 g_loss=14.7206 | train mmd=0.0062 | test_mmd=0.0007\n",
      "[CellOT] epoch=1500 f_loss=0.1156 g_loss=14.4994 | train mmd=0.0042 | test_mmd=0.0009\n",
      "[CellOT] epoch=1550 f_loss=-0.2964 g_loss=14.6501 | train mmd=0.0037 | test_mmd=0.0006\n",
      "[CellOT] epoch=1600 f_loss=0.2466 g_loss=14.7033 | train mmd=0.0054 | test_mmd=0.0026\n",
      "[CellOT] epoch=1650 f_loss=0.0804 g_loss=14.4207 | train mmd=0.0037 | test_mmd=0.0009\n",
      "[CellOT] epoch=1700 f_loss=-0.2255 g_loss=14.3580 | train mmd=0.0039 | test_mmd=0.0007\n",
      "[CellOT] epoch=1750 f_loss=0.0061 g_loss=14.5253 | train mmd=0.0054 | test_mmd=0.0015\n",
      "[CellOT] epoch=1800 f_loss=0.5596 g_loss=14.9843 | train mmd=0.0069 | test_mmd=0.0015\n",
      "[CellOT] epoch=1850 f_loss=0.3060 g_loss=14.8493 | train mmd=0.0035 | test_mmd=0.0014\n",
      "[CellOT] epoch=1900 f_loss=-0.0535 g_loss=14.7472 | train mmd=0.0044 | test_mmd=0.0012\n",
      "[CellOT] epoch=1950 f_loss=-0.2076 g_loss=14.9545 | train mmd=0.0042 | test_mmd=0.0011\n",
      "[CellOT] epoch=2000 f_loss=0.0077 g_loss=14.7950 | train mmd=0.0060 | test_mmd=0.0022\n",
      "[CellOT] Final CellOT MMD: 0.0028\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 4 metrics: {'mmd2_gamma_median': 0.0021885016910447863, 'mmd2_gamma_0.5': 0.005442121128773514, 'mmd2_gamma_1.0': 0.006106373459866199, 'wasserstein_distance': 0.6150786337918093, 'R2_feature_means': 0.9965964623464285}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.5064 g_loss=-5.9721 | train mmd=0.4594 | test_mmd=1.3559\n",
      "[CellOT] epoch=50 f_loss=-2.1758 g_loss=-0.2058 | train mmd=0.4477 | test_mmd=0.3056\n",
      "[CellOT] epoch=100 f_loss=-2.7468 g_loss=2.8433 | train mmd=0.6890 | test_mmd=0.2486\n",
      "[CellOT] epoch=150 f_loss=-4.7134 g_loss=4.4206 | train mmd=0.7069 | test_mmd=0.2211\n",
      "[CellOT] epoch=200 f_loss=-6.3939 g_loss=6.5845 | train mmd=0.6455 | test_mmd=0.1812\n",
      "[CellOT] epoch=250 f_loss=-7.9323 g_loss=7.5431 | train mmd=0.6047 | test_mmd=0.1468\n",
      "[CellOT] epoch=300 f_loss=-8.8027 g_loss=10.0735 | train mmd=0.5942 | test_mmd=0.1266\n",
      "[CellOT] epoch=350 f_loss=-10.3706 g_loss=11.8534 | train mmd=0.5436 | test_mmd=0.0998\n",
      "[CellOT] epoch=400 f_loss=-9.4889 g_loss=13.4175 | train mmd=0.4969 | test_mmd=0.0771\n",
      "[CellOT] epoch=450 f_loss=-11.0181 g_loss=15.3177 | train mmd=0.4757 | test_mmd=0.0625\n",
      "[CellOT] epoch=500 f_loss=-11.0189 g_loss=14.4836 | train mmd=0.3876 | test_mmd=0.0435\n",
      "[CellOT] epoch=550 f_loss=-10.0819 g_loss=17.1984 | train mmd=0.3812 | test_mmd=0.0347\n",
      "[CellOT] epoch=600 f_loss=-11.6763 g_loss=18.0826 | train mmd=0.3485 | test_mmd=0.0243\n",
      "[CellOT] epoch=650 f_loss=-11.3142 g_loss=17.4058 | train mmd=0.2872 | test_mmd=0.0166\n",
      "[CellOT] epoch=700 f_loss=-7.0830 g_loss=20.9376 | train mmd=0.2515 | test_mmd=0.0112\n",
      "[CellOT] epoch=750 f_loss=-5.8628 g_loss=19.0439 | train mmd=0.2179 | test_mmd=0.0070\n",
      "[CellOT] epoch=800 f_loss=-1.4662 g_loss=18.9696 | train mmd=0.1555 | test_mmd=0.0043\n",
      "[CellOT] epoch=850 f_loss=1.4036 g_loss=18.5368 | train mmd=0.0581 | test_mmd=0.0052\n",
      "[CellOT] epoch=900 f_loss=0.6297 g_loss=17.4290 | train mmd=0.0353 | test_mmd=0.0023\n",
      "[CellOT] epoch=950 f_loss=0.4170 g_loss=18.2736 | train mmd=0.0287 | test_mmd=0.0013\n",
      "[CellOT] epoch=1000 f_loss=0.2409 g_loss=19.0788 | train mmd=0.0250 | test_mmd=0.0015\n",
      "[CellOT] epoch=1050 f_loss=0.0744 g_loss=20.5391 | train mmd=0.0283 | test_mmd=0.0020\n",
      "[CellOT] epoch=1100 f_loss=-0.3998 g_loss=20.7692 | train mmd=0.0180 | test_mmd=0.0012\n",
      "[CellOT] epoch=1150 f_loss=0.1031 g_loss=21.5285 | train mmd=0.0152 | test_mmd=0.0026\n",
      "[CellOT] epoch=1200 f_loss=0.0177 g_loss=21.9354 | train mmd=0.0111 | test_mmd=0.0028\n",
      "[CellOT] epoch=1250 f_loss=0.1090 g_loss=22.3233 | train mmd=0.0109 | test_mmd=0.0013\n",
      "[CellOT] epoch=1300 f_loss=0.0843 g_loss=22.4839 | train mmd=0.0136 | test_mmd=0.0015\n",
      "[CellOT] epoch=1350 f_loss=-0.1765 g_loss=22.4486 | train mmd=0.0089 | test_mmd=0.0012\n",
      "[CellOT] epoch=1400 f_loss=0.0710 g_loss=22.8317 | train mmd=0.0077 | test_mmd=0.0008\n",
      "[CellOT] epoch=1450 f_loss=-0.2063 g_loss=22.8722 | train mmd=0.0093 | test_mmd=0.0017\n",
      "[CellOT] epoch=1500 f_loss=0.1787 g_loss=22.8469 | train mmd=0.0082 | test_mmd=0.0010\n",
      "[CellOT] epoch=1550 f_loss=-0.1266 g_loss=22.8223 | train mmd=0.0074 | test_mmd=0.0012\n",
      "[CellOT] epoch=1600 f_loss=-0.0201 g_loss=23.0000 | train mmd=0.0065 | test_mmd=0.0010\n",
      "[CellOT] epoch=1650 f_loss=0.4692 g_loss=22.6172 | train mmd=0.0058 | test_mmd=0.0019\n",
      "[CellOT] epoch=1700 f_loss=-0.2225 g_loss=22.5577 | train mmd=0.0066 | test_mmd=0.0022\n",
      "[CellOT] epoch=1750 f_loss=0.2762 g_loss=22.3319 | train mmd=0.0079 | test_mmd=0.0019\n",
      "[CellOT] epoch=1800 f_loss=-0.2191 g_loss=23.0082 | train mmd=0.0067 | test_mmd=0.0022\n",
      "[CellOT] epoch=1850 f_loss=0.0134 g_loss=22.7003 | train mmd=0.0049 | test_mmd=0.0014\n",
      "[CellOT] epoch=1900 f_loss=0.0681 g_loss=22.8442 | train mmd=0.0050 | test_mmd=0.0019\n",
      "[CellOT] epoch=1950 f_loss=-0.0570 g_loss=22.5759 | train mmd=0.0049 | test_mmd=0.0009\n",
      "[CellOT] epoch=2000 f_loss=-0.0010 g_loss=22.7123 | train mmd=0.0048 | test_mmd=0.0010\n",
      "[CellOT] Final CellOT MMD: 0.0024\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 5 metrics: {'mmd2_gamma_median': 0.0010316909743250946, 'mmd2_gamma_0.5': 0.005227923367917886, 'mmd2_gamma_1.0': 0.0065241482976665655, 'wasserstein_distance': 0.6053150336773766, 'R2_feature_means': 0.9989998840272862}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=0.4910 g_loss=-4.0445 | train mmd=0.4120 | test_mmd=1.2311\n",
      "[CellOT] epoch=50 f_loss=-1.9229 g_loss=0.5200 | train mmd=0.5122 | test_mmd=0.2674\n",
      "[CellOT] epoch=100 f_loss=-2.8023 g_loss=2.5761 | train mmd=0.6584 | test_mmd=0.2230\n",
      "[CellOT] epoch=150 f_loss=-4.3752 g_loss=5.1142 | train mmd=0.6811 | test_mmd=0.2082\n",
      "[CellOT] epoch=200 f_loss=-6.1995 g_loss=6.4059 | train mmd=0.6452 | test_mmd=0.1721\n",
      "[CellOT] epoch=250 f_loss=-7.5959 g_loss=8.6431 | train mmd=0.6114 | test_mmd=0.1423\n",
      "[CellOT] epoch=300 f_loss=-9.0780 g_loss=9.9253 | train mmd=0.5733 | test_mmd=0.1149\n",
      "[CellOT] epoch=350 f_loss=-10.2099 g_loss=12.0555 | train mmd=0.5120 | test_mmd=0.0870\n",
      "[CellOT] epoch=400 f_loss=-10.1459 g_loss=13.7154 | train mmd=0.4905 | test_mmd=0.0712\n",
      "[CellOT] epoch=450 f_loss=-11.1079 g_loss=14.3968 | train mmd=0.4371 | test_mmd=0.0531\n",
      "[CellOT] epoch=500 f_loss=-9.9659 g_loss=15.5713 | train mmd=0.4053 | test_mmd=0.0408\n",
      "[CellOT] epoch=550 f_loss=-10.1939 g_loss=16.2699 | train mmd=0.3646 | test_mmd=0.0285\n",
      "[CellOT] epoch=600 f_loss=-9.2339 g_loss=15.6855 | train mmd=0.2961 | test_mmd=0.0163\n",
      "[CellOT] epoch=650 f_loss=-7.9630 g_loss=16.2657 | train mmd=0.2487 | test_mmd=0.0093\n",
      "[CellOT] epoch=700 f_loss=-4.5495 g_loss=13.8620 | train mmd=0.1741 | test_mmd=0.0046\n",
      "[CellOT] epoch=750 f_loss=-0.8534 g_loss=12.7669 | train mmd=0.1168 | test_mmd=0.0024\n",
      "[CellOT] epoch=800 f_loss=1.9250 g_loss=9.3683 | train mmd=0.0444 | test_mmd=0.0019\n",
      "[CellOT] epoch=850 f_loss=0.7879 g_loss=9.8183 | train mmd=0.0374 | test_mmd=0.0017\n",
      "[CellOT] epoch=900 f_loss=0.3297 g_loss=11.6191 | train mmd=0.0295 | test_mmd=0.0017\n",
      "[CellOT] epoch=950 f_loss=0.0758 g_loss=12.5307 | train mmd=0.0224 | test_mmd=0.0020\n",
      "[CellOT] epoch=1000 f_loss=0.2661 g_loss=13.0487 | train mmd=0.0197 | test_mmd=0.0014\n",
      "[CellOT] epoch=1050 f_loss=0.4055 g_loss=13.5657 | train mmd=0.0174 | test_mmd=0.0020\n",
      "[CellOT] epoch=1100 f_loss=0.3340 g_loss=13.9573 | train mmd=0.0157 | test_mmd=0.0014\n",
      "[CellOT] epoch=1150 f_loss=0.1172 g_loss=14.0226 | train mmd=0.0157 | test_mmd=0.0022\n",
      "[CellOT] epoch=1200 f_loss=0.1561 g_loss=14.4103 | train mmd=0.0111 | test_mmd=0.0024\n",
      "[CellOT] epoch=1250 f_loss=-0.1457 g_loss=14.4693 | train mmd=0.0077 | test_mmd=0.0012\n",
      "[CellOT] epoch=1300 f_loss=-0.0695 g_loss=14.5845 | train mmd=0.0081 | test_mmd=0.0024\n",
      "[CellOT] epoch=1350 f_loss=-0.1974 g_loss=15.4874 | train mmd=0.0076 | test_mmd=0.0012\n",
      "[CellOT] epoch=1400 f_loss=-0.1617 g_loss=14.8812 | train mmd=0.0053 | test_mmd=0.0012\n",
      "[CellOT] epoch=1450 f_loss=-0.1627 g_loss=15.3209 | train mmd=0.0063 | test_mmd=0.0024\n",
      "[CellOT] epoch=1500 f_loss=0.2095 g_loss=15.5183 | train mmd=0.0074 | test_mmd=0.0029\n",
      "[CellOT] epoch=1550 f_loss=-0.3512 g_loss=15.4496 | train mmd=0.0049 | test_mmd=0.0008\n",
      "[CellOT] epoch=1600 f_loss=0.5068 g_loss=15.7544 | train mmd=0.0043 | test_mmd=0.0010\n",
      "[CellOT] epoch=1650 f_loss=-0.2656 g_loss=16.0456 | train mmd=0.0058 | test_mmd=0.0029\n",
      "[CellOT] epoch=1700 f_loss=0.5622 g_loss=15.6932 | train mmd=0.0041 | test_mmd=0.0014\n",
      "[CellOT] epoch=1750 f_loss=0.0569 g_loss=16.5932 | train mmd=0.0043 | test_mmd=0.0011\n",
      "[CellOT] epoch=1800 f_loss=0.5684 g_loss=16.5221 | train mmd=0.0042 | test_mmd=0.0022\n",
      "[CellOT] epoch=1850 f_loss=0.2466 g_loss=16.1955 | train mmd=0.0036 | test_mmd=0.0012\n",
      "[CellOT] epoch=1900 f_loss=0.1547 g_loss=16.5419 | train mmd=0.0030 | test_mmd=0.0006\n",
      "[CellOT] epoch=1950 f_loss=-0.1632 g_loss=16.4334 | train mmd=0.0040 | test_mmd=0.0021\n",
      "[CellOT] epoch=2000 f_loss=0.2646 g_loss=16.9374 | train mmd=0.0031 | test_mmd=0.0007\n",
      "[CellOT] Final CellOT MMD: 0.0020\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 6 metrics: {'mmd2_gamma_median': 0.0007218063774721006, 'mmd2_gamma_0.5': 0.003849972260665191, 'mmd2_gamma_1.0': 0.005177639670217227, 'wasserstein_distance': 0.5966285151993378, 'R2_feature_means': 0.9992468670367688}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=0.0218 g_loss=-2.8122 | train mmd=0.4468 | test_mmd=0.9887\n",
      "[CellOT] epoch=50 f_loss=-2.5749 g_loss=0.0827 | train mmd=0.5261 | test_mmd=0.2680\n",
      "[CellOT] epoch=100 f_loss=-2.6298 g_loss=2.6157 | train mmd=0.6747 | test_mmd=0.2371\n",
      "[CellOT] epoch=150 f_loss=-4.3967 g_loss=5.1058 | train mmd=0.6637 | test_mmd=0.2063\n",
      "[CellOT] epoch=200 f_loss=-5.6245 g_loss=6.0469 | train mmd=0.6294 | test_mmd=0.1683\n",
      "[CellOT] epoch=250 f_loss=-6.6234 g_loss=7.6768 | train mmd=0.6080 | test_mmd=0.1435\n",
      "[CellOT] epoch=300 f_loss=-7.6979 g_loss=8.6253 | train mmd=0.5556 | test_mmd=0.1125\n",
      "[CellOT] epoch=350 f_loss=-9.1724 g_loss=10.2012 | train mmd=0.4900 | test_mmd=0.0843\n",
      "[CellOT] epoch=400 f_loss=-9.2470 g_loss=12.5019 | train mmd=0.4573 | test_mmd=0.0640\n",
      "[CellOT] epoch=450 f_loss=-8.1788 g_loss=11.1014 | train mmd=0.4070 | test_mmd=0.0483\n",
      "[CellOT] epoch=500 f_loss=-8.5128 g_loss=12.9476 | train mmd=0.3577 | test_mmd=0.0338\n",
      "[CellOT] epoch=550 f_loss=-6.8245 g_loss=12.9085 | train mmd=0.2905 | test_mmd=0.0206\n",
      "[CellOT] epoch=600 f_loss=-5.6998 g_loss=13.6321 | train mmd=0.2412 | test_mmd=0.0128\n",
      "[CellOT] epoch=650 f_loss=-4.9626 g_loss=13.6611 | train mmd=0.2174 | test_mmd=0.0079\n",
      "[CellOT] epoch=700 f_loss=-2.9864 g_loss=12.2572 | train mmd=0.1429 | test_mmd=0.0036\n",
      "[CellOT] epoch=750 f_loss=-1.4977 g_loss=10.6389 | train mmd=0.1125 | test_mmd=0.0025\n",
      "[CellOT] epoch=800 f_loss=1.4409 g_loss=8.9300 | train mmd=0.0161 | test_mmd=0.0018\n",
      "[CellOT] epoch=850 f_loss=0.5205 g_loss=11.6693 | train mmd=0.0184 | test_mmd=0.0012\n",
      "[CellOT] epoch=900 f_loss=0.2166 g_loss=12.5632 | train mmd=0.0205 | test_mmd=0.0014\n",
      "[CellOT] epoch=950 f_loss=0.6915 g_loss=14.5697 | train mmd=0.0179 | test_mmd=0.0054\n",
      "[CellOT] epoch=1000 f_loss=0.0531 g_loss=15.5685 | train mmd=0.0158 | test_mmd=0.0015\n",
      "[CellOT] epoch=1050 f_loss=0.1331 g_loss=16.5218 | train mmd=0.0206 | test_mmd=0.0034\n",
      "[CellOT] epoch=1100 f_loss=-0.0525 g_loss=17.2439 | train mmd=0.0110 | test_mmd=0.0030\n",
      "[CellOT] epoch=1150 f_loss=0.0458 g_loss=16.9199 | train mmd=0.0092 | test_mmd=0.0014\n",
      "[CellOT] epoch=1200 f_loss=-0.0642 g_loss=18.1313 | train mmd=0.0103 | test_mmd=0.0030\n",
      "[CellOT] epoch=1250 f_loss=0.0773 g_loss=18.1469 | train mmd=0.0110 | test_mmd=0.0014\n",
      "[CellOT] epoch=1300 f_loss=0.3495 g_loss=18.3647 | train mmd=0.0135 | test_mmd=0.0028\n",
      "[CellOT] epoch=1350 f_loss=0.3124 g_loss=18.4145 | train mmd=0.0087 | test_mmd=0.0035\n",
      "[CellOT] epoch=1400 f_loss=0.2405 g_loss=18.6922 | train mmd=0.0108 | test_mmd=0.0026\n",
      "[CellOT] epoch=1450 f_loss=0.1809 g_loss=19.1362 | train mmd=0.0059 | test_mmd=0.0012\n",
      "[CellOT] epoch=1500 f_loss=0.1981 g_loss=19.2125 | train mmd=0.0066 | test_mmd=0.0015\n",
      "[CellOT] epoch=1550 f_loss=0.0820 g_loss=19.1312 | train mmd=0.0094 | test_mmd=0.0017\n",
      "[CellOT] epoch=1600 f_loss=-0.2362 g_loss=19.5791 | train mmd=0.0051 | test_mmd=0.0012\n",
      "[CellOT] epoch=1650 f_loss=-0.2223 g_loss=19.8434 | train mmd=0.0093 | test_mmd=0.0025\n",
      "[CellOT] epoch=1700 f_loss=0.1474 g_loss=19.4972 | train mmd=0.0051 | test_mmd=0.0007\n",
      "[CellOT] epoch=1750 f_loss=-0.1809 g_loss=20.2179 | train mmd=0.0059 | test_mmd=0.0023\n",
      "[CellOT] epoch=1800 f_loss=-0.1001 g_loss=20.2043 | train mmd=0.0067 | test_mmd=0.0018\n",
      "[CellOT] epoch=1850 f_loss=-0.3152 g_loss=20.0628 | train mmd=0.0078 | test_mmd=0.0042\n",
      "[CellOT] epoch=1900 f_loss=-0.1636 g_loss=20.5288 | train mmd=0.0063 | test_mmd=0.0026\n",
      "[CellOT] epoch=1950 f_loss=0.7695 g_loss=20.1913 | train mmd=0.0050 | test_mmd=0.0011\n",
      "[CellOT] epoch=2000 f_loss=0.0124 g_loss=20.0968 | train mmd=0.0054 | test_mmd=0.0007\n",
      "[CellOT] Final CellOT MMD: 0.0021\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 7 metrics: {'mmd2_gamma_median': 0.0007374098812413798, 'mmd2_gamma_0.5': 0.004558680958698824, 'mmd2_gamma_1.0': 0.005737863146797251, 'wasserstein_distance': 0.6160920339707451, 'R2_feature_means': 0.9991150202406964}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=1.5957 g_loss=-4.4153 | train mmd=0.4190 | test_mmd=1.1334\n",
      "[CellOT] epoch=50 f_loss=-2.0146 g_loss=0.1903 | train mmd=0.4758 | test_mmd=0.2579\n",
      "[CellOT] epoch=100 f_loss=-2.9091 g_loss=2.1354 | train mmd=0.6733 | test_mmd=0.2311\n",
      "[CellOT] epoch=150 f_loss=-4.4628 g_loss=5.4215 | train mmd=0.7008 | test_mmd=0.2100\n",
      "[CellOT] epoch=200 f_loss=-6.4984 g_loss=7.0216 | train mmd=0.6532 | test_mmd=0.1708\n",
      "[CellOT] epoch=250 f_loss=-7.3479 g_loss=8.4152 | train mmd=0.6286 | test_mmd=0.1434\n",
      "[CellOT] epoch=300 f_loss=-8.3354 g_loss=11.2761 | train mmd=0.5880 | test_mmd=0.1161\n",
      "[CellOT] epoch=350 f_loss=-9.4553 g_loss=11.7949 | train mmd=0.5367 | test_mmd=0.0926\n",
      "[CellOT] epoch=400 f_loss=-10.1883 g_loss=12.9789 | train mmd=0.5073 | test_mmd=0.0754\n",
      "[CellOT] epoch=450 f_loss=-10.4643 g_loss=15.3211 | train mmd=0.4760 | test_mmd=0.0592\n",
      "[CellOT] epoch=500 f_loss=-9.4750 g_loss=15.6545 | train mmd=0.4064 | test_mmd=0.0407\n",
      "[CellOT] epoch=550 f_loss=-8.6773 g_loss=16.3277 | train mmd=0.3631 | test_mmd=0.0281\n",
      "[CellOT] epoch=600 f_loss=-9.4933 g_loss=18.0976 | train mmd=0.3322 | test_mmd=0.0206\n",
      "[CellOT] epoch=650 f_loss=-9.5836 g_loss=17.8804 | train mmd=0.2980 | test_mmd=0.0144\n",
      "[CellOT] epoch=700 f_loss=-7.6229 g_loss=18.8754 | train mmd=0.2548 | test_mmd=0.0084\n",
      "[CellOT] epoch=750 f_loss=-4.6827 g_loss=16.9509 | train mmd=0.2139 | test_mmd=0.0046\n",
      "[CellOT] epoch=800 f_loss=-0.2031 g_loss=12.1334 | train mmd=0.0824 | test_mmd=0.0024\n",
      "[CellOT] epoch=850 f_loss=1.0573 g_loss=11.3844 | train mmd=0.0256 | test_mmd=0.0019\n",
      "[CellOT] epoch=900 f_loss=0.4494 g_loss=12.4382 | train mmd=0.0291 | test_mmd=0.0017\n",
      "[CellOT] epoch=950 f_loss=0.2277 g_loss=13.7056 | train mmd=0.0186 | test_mmd=0.0016\n",
      "[CellOT] epoch=1000 f_loss=0.3319 g_loss=14.5104 | train mmd=0.0203 | test_mmd=0.0025\n",
      "[CellOT] epoch=1050 f_loss=0.5185 g_loss=15.1044 | train mmd=0.0171 | test_mmd=0.0027\n",
      "[CellOT] epoch=1100 f_loss=0.0082 g_loss=15.9551 | train mmd=0.0160 | test_mmd=0.0021\n",
      "[CellOT] epoch=1150 f_loss=0.2176 g_loss=15.9688 | train mmd=0.0111 | test_mmd=0.0024\n",
      "[CellOT] epoch=1200 f_loss=-0.1683 g_loss=16.2960 | train mmd=0.0078 | test_mmd=0.0004\n",
      "[CellOT] epoch=1250 f_loss=-0.1685 g_loss=16.7357 | train mmd=0.0142 | test_mmd=0.0033\n",
      "[CellOT] epoch=1300 f_loss=0.6550 g_loss=17.0718 | train mmd=0.0098 | test_mmd=0.0060\n",
      "[CellOT] epoch=1350 f_loss=0.1246 g_loss=17.1647 | train mmd=0.0104 | test_mmd=0.0043\n",
      "[CellOT] epoch=1400 f_loss=0.1504 g_loss=17.4762 | train mmd=0.0065 | test_mmd=0.0008\n",
      "[CellOT] epoch=1450 f_loss=-0.2180 g_loss=17.8222 | train mmd=0.0078 | test_mmd=0.0019\n",
      "[CellOT] epoch=1500 f_loss=-0.2380 g_loss=18.0112 | train mmd=0.0057 | test_mmd=0.0016\n",
      "[CellOT] epoch=1550 f_loss=-0.3376 g_loss=18.1293 | train mmd=0.0071 | test_mmd=0.0024\n",
      "[CellOT] epoch=1600 f_loss=0.2788 g_loss=17.9971 | train mmd=0.0079 | test_mmd=0.0027\n",
      "[CellOT] epoch=1650 f_loss=0.0734 g_loss=17.9706 | train mmd=0.0047 | test_mmd=0.0023\n",
      "[CellOT] epoch=1700 f_loss=0.6081 g_loss=18.0115 | train mmd=0.0053 | test_mmd=0.0013\n",
      "[CellOT] epoch=1750 f_loss=0.3866 g_loss=18.4535 | train mmd=0.0071 | test_mmd=0.0016\n",
      "[CellOT] epoch=1800 f_loss=-0.0395 g_loss=18.0266 | train mmd=0.0057 | test_mmd=0.0025\n",
      "[CellOT] epoch=1850 f_loss=-0.2953 g_loss=18.2926 | train mmd=0.0033 | test_mmd=0.0007\n",
      "[CellOT] epoch=1900 f_loss=0.1506 g_loss=17.7452 | train mmd=0.0033 | test_mmd=0.0010\n",
      "[CellOT] epoch=1950 f_loss=-0.3256 g_loss=18.5687 | train mmd=0.0061 | test_mmd=0.0020\n",
      "[CellOT] epoch=2000 f_loss=0.0241 g_loss=18.0954 | train mmd=0.0049 | test_mmd=0.0019\n",
      "[CellOT] Final CellOT MMD: 0.0028\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 8 metrics: {'mmd2_gamma_median': 0.0018728046209228744, 'mmd2_gamma_0.5': 0.005815766138460021, 'mmd2_gamma_1.0': 0.006482459634980231, 'wasserstein_distance': 0.609433181302826, 'R2_feature_means': 0.9975236818763463}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-0.7550 g_loss=-6.5383 | train mmd=0.3823 | test_mmd=1.4637\n",
      "[CellOT] epoch=50 f_loss=-1.3337 g_loss=0.3096 | train mmd=0.5685 | test_mmd=0.2876\n",
      "[CellOT] epoch=100 f_loss=-2.8117 g_loss=2.1314 | train mmd=0.6677 | test_mmd=0.2347\n",
      "[CellOT] epoch=150 f_loss=-4.6845 g_loss=4.5533 | train mmd=0.6743 | test_mmd=0.2071\n",
      "[CellOT] epoch=200 f_loss=-5.7380 g_loss=6.6481 | train mmd=0.6443 | test_mmd=0.1705\n",
      "[CellOT] epoch=250 f_loss=-7.6351 g_loss=8.0465 | train mmd=0.5957 | test_mmd=0.1389\n",
      "[CellOT] epoch=300 f_loss=-7.8210 g_loss=10.1745 | train mmd=0.5498 | test_mmd=0.1134\n",
      "[CellOT] epoch=350 f_loss=-9.1816 g_loss=11.5537 | train mmd=0.5206 | test_mmd=0.0929\n",
      "[CellOT] epoch=400 f_loss=-9.8632 g_loss=11.4419 | train mmd=0.4831 | test_mmd=0.0729\n",
      "[CellOT] epoch=450 f_loss=-10.2393 g_loss=12.9908 | train mmd=0.4373 | test_mmd=0.0538\n",
      "[CellOT] epoch=500 f_loss=-10.4462 g_loss=14.4519 | train mmd=0.4229 | test_mmd=0.0447\n",
      "[CellOT] epoch=550 f_loss=-10.0927 g_loss=14.9911 | train mmd=0.3469 | test_mmd=0.0294\n",
      "[CellOT] epoch=600 f_loss=-8.6379 g_loss=15.1081 | train mmd=0.3074 | test_mmd=0.0188\n",
      "[CellOT] epoch=650 f_loss=-7.7575 g_loss=15.4473 | train mmd=0.2481 | test_mmd=0.0110\n",
      "[CellOT] epoch=700 f_loss=-6.3514 g_loss=15.1159 | train mmd=0.2343 | test_mmd=0.0077\n",
      "[CellOT] epoch=750 f_loss=-4.0800 g_loss=13.5104 | train mmd=0.1786 | test_mmd=0.0041\n",
      "[CellOT] epoch=800 f_loss=-0.6085 g_loss=12.4024 | train mmd=0.1024 | test_mmd=0.0038\n",
      "[CellOT] epoch=850 f_loss=0.9931 g_loss=10.2812 | train mmd=0.0512 | test_mmd=0.0025\n",
      "[CellOT] epoch=900 f_loss=0.8380 g_loss=11.4249 | train mmd=0.0293 | test_mmd=0.0018\n",
      "[CellOT] epoch=950 f_loss=0.8220 g_loss=12.5235 | train mmd=0.0275 | test_mmd=0.0013\n",
      "[CellOT] epoch=1000 f_loss=0.2087 g_loss=13.3837 | train mmd=0.0307 | test_mmd=0.0018\n",
      "[CellOT] epoch=1050 f_loss=0.1740 g_loss=14.2815 | train mmd=0.0204 | test_mmd=0.0015\n",
      "[CellOT] epoch=1100 f_loss=-0.2553 g_loss=15.3285 | train mmd=0.0142 | test_mmd=0.0018\n",
      "[CellOT] epoch=1150 f_loss=-0.0284 g_loss=16.1750 | train mmd=0.0126 | test_mmd=0.0017\n",
      "[CellOT] epoch=1200 f_loss=0.4528 g_loss=16.6283 | train mmd=0.0098 | test_mmd=0.0025\n",
      "[CellOT] epoch=1250 f_loss=0.1606 g_loss=16.5288 | train mmd=0.0075 | test_mmd=0.0016\n",
      "[CellOT] epoch=1300 f_loss=-0.1429 g_loss=16.8656 | train mmd=0.0082 | test_mmd=0.0010\n",
      "[CellOT] epoch=1350 f_loss=-0.0212 g_loss=17.0815 | train mmd=0.0060 | test_mmd=0.0016\n",
      "[CellOT] epoch=1400 f_loss=-0.0700 g_loss=17.2961 | train mmd=0.0089 | test_mmd=0.0027\n",
      "[CellOT] epoch=1450 f_loss=0.5940 g_loss=17.3287 | train mmd=0.0065 | test_mmd=0.0016\n",
      "[CellOT] epoch=1500 f_loss=0.0522 g_loss=17.5278 | train mmd=0.0053 | test_mmd=0.0015\n",
      "[CellOT] epoch=1550 f_loss=-0.2928 g_loss=17.5524 | train mmd=0.0064 | test_mmd=0.0019\n",
      "[CellOT] epoch=1600 f_loss=-0.0150 g_loss=17.3493 | train mmd=0.0047 | test_mmd=0.0007\n",
      "[CellOT] epoch=1650 f_loss=-0.0347 g_loss=17.6352 | train mmd=0.0033 | test_mmd=0.0008\n",
      "[CellOT] epoch=1700 f_loss=0.6103 g_loss=18.0083 | train mmd=0.0046 | test_mmd=0.0012\n",
      "[CellOT] epoch=1750 f_loss=-0.0498 g_loss=16.9546 | train mmd=0.0060 | test_mmd=0.0051\n",
      "[CellOT] epoch=1800 f_loss=-0.2986 g_loss=18.1734 | train mmd=0.0066 | test_mmd=0.0018\n",
      "[CellOT] epoch=1850 f_loss=0.0534 g_loss=17.8749 | train mmd=0.0049 | test_mmd=0.0017\n",
      "[CellOT] epoch=1900 f_loss=-0.2008 g_loss=18.4581 | train mmd=0.0041 | test_mmd=0.0011\n",
      "[CellOT] epoch=1950 f_loss=0.0560 g_loss=17.7956 | train mmd=0.0041 | test_mmd=0.0010\n",
      "[CellOT] epoch=2000 f_loss=0.2298 g_loss=18.1747 | train mmd=0.0038 | test_mmd=0.0010\n",
      "[CellOT] Final CellOT MMD: 0.0022\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.\n",
      "  warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 9 metrics: {'mmd2_gamma_median': 0.0009923370626707673, 'mmd2_gamma_0.5': 0.004976473842692064, 'mmd2_gamma_1.0': 0.00565701149007547, 'wasserstein_distance': 0.5983066609104358, 'R2_feature_means': 0.9995492797218528}\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0015  0.0006\n",
      "mmd2_gamma_0.5        0.0052  0.0010\n",
      "mmd2_gamma_1.0        0.0061  0.0008\n",
      "wasserstein_distance  0.6090  0.0083\n",
      "R2_feature_means      0.9982  0.0010\n"
     ]
    },
    {
     "ename": "RuntimeError",
     "evalue": "can't start new thread",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mRuntimeError\u001b[0m                              Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 51\u001b[0m\n\u001b[1;32m     48\u001b[0m \u001b[38;5;66;03m# Instantiate UMAP\u001b[39;00m\n\u001b[1;32m     49\u001b[0m umap_model \u001b[38;5;241m=\u001b[39m UMAP(n_components\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m, random_state\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m42\u001b[39m)\n\u001b[0;32m---> 51\u001b[0m all_sample_umap \u001b[38;5;241m=\u001b[39m \u001b[43mumap_model\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_transform\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvstack\u001b[49m\u001b[43m(\u001b[49m\u001b[43m[\u001b[49m\u001b[43msource\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtarget\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     52\u001b[0m source_umap \u001b[38;5;241m=\u001b[39m umap_model\u001b[38;5;241m.\u001b[39mtransform(source)\n\u001b[1;32m     53\u001b[0m target_umap \u001b[38;5;241m=\u001b[39m umap_model\u001b[38;5;241m.\u001b[39mtransform(target)\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:2928\u001b[0m, in \u001b[0;36mUMAP.fit_transform\u001b[0;34m(self, X, y, force_all_finite, **kwargs)\u001b[0m\n\u001b[1;32m   2890\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mfit_transform\u001b[39m(\u001b[38;5;28mself\u001b[39m, X, y\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, force_all_finite\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m   2891\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Fit X into an embedded space and return that transformed\u001b[39;00m\n\u001b[1;32m   2892\u001b[0m \u001b[38;5;124;03m    output.\u001b[39;00m\n\u001b[1;32m   2893\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   2926\u001b[0m \u001b[38;5;124;03m        Local radii of data points in the embedding (log-transformed).\u001b[39;00m\n\u001b[1;32m   2927\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 2928\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mforce_all_finite\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2929\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtransform_mode \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124membedding\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m   2930\u001b[0m         \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_dens:\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:2635\u001b[0m, in \u001b[0;36mUMAP.fit\u001b[0;34m(self, X, y, force_all_finite, **kwargs)\u001b[0m\n\u001b[1;32m   2629\u001b[0m     nn_metric \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_input_distance_func\n\u001b[1;32m   2630\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mknn_dists \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2631\u001b[0m     (\n\u001b[1;32m   2632\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_knn_indices,\n\u001b[1;32m   2633\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_knn_dists,\n\u001b[1;32m   2634\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_knn_search_index,\n\u001b[0;32m-> 2635\u001b[0m     ) \u001b[38;5;241m=\u001b[39m \u001b[43mnearest_neighbors\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   2636\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX\u001b[49m\u001b[43m[\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2637\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_n_neighbors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2638\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnn_metric\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2639\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_metric_kwds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2640\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mangular_rp_forest\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2641\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2642\u001b[0m 
\u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlow_memory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2643\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_pynndescent\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m   2644\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2645\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2646\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2647\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   2648\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_knn_indices \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mknn_indices\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:330\u001b[0m, in \u001b[0;36mnearest_neighbors\u001b[0;34m(X, n_neighbors, metric, metric_kwds, angular, random_state, low_memory, use_pynndescent, n_jobs, verbose)\u001b[0m\n\u001b[1;32m    327\u001b[0m     n_trees \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmin\u001b[39m(\u001b[38;5;241m64\u001b[39m, \u001b[38;5;241m5\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mround\u001b[39m((X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]) \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39m \u001b[38;5;241m0.5\u001b[39m \u001b[38;5;241m/\u001b[39m \u001b[38;5;241m20.0\u001b[39m)))\n\u001b[1;32m    328\u001b[0m     n_iters \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mmax\u001b[39m(\u001b[38;5;241m5\u001b[39m, \u001b[38;5;28mint\u001b[39m(\u001b[38;5;28mround\u001b[39m(np\u001b[38;5;241m.\u001b[39mlog2(X\u001b[38;5;241m.\u001b[39mshape[\u001b[38;5;241m0\u001b[39m]))))\n\u001b[0;32m--> 330\u001b[0m     knn_search_index \u001b[38;5;241m=\u001b[39m \u001b[43mNNDescent\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    331\u001b[0m \u001b[43m        \u001b[49m\u001b[43mX\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    332\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_neighbors\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_neighbors\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    333\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmetric\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    334\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmetric_kwds\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmetric_kwds\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    335\u001b[0m \u001b[43m        \u001b[49m\u001b[43mrandom_state\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrandom_state\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    336\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mn_trees\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_trees\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    337\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_iters\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_iters\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    338\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_candidates\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m60\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m    339\u001b[0m \u001b[43m        \u001b[49m\u001b[43mlow_memory\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mlow_memory\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    340\u001b[0m \u001b[43m        \u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_jobs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    341\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    342\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcompressed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[1;32m    343\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    344\u001b[0m     knn_indices, knn_dists \u001b[38;5;241m=\u001b[39m knn_search_index\u001b[38;5;241m.\u001b[39mneighbor_graph\n\u001b[1;32m    346\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m verbose:\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pynndescent/pynndescent_.py:806\u001b[0m, in \u001b[0;36mNNDescent.__init__\u001b[0;34m(self, data, metric, metric_kwds, n_neighbors, n_trees, leaf_size, pruning_degree_multiplier, diversify_prob, n_search_trees, tree_init, init_graph, init_dist, random_state, low_memory, max_candidates, max_rptree_depth, n_iters, delta, n_jobs, compressed, parallel_batch_queries, verbose)\u001b[0m\n\u001b[1;32m    793\u001b[0m         \u001b[38;5;28mprint\u001b[39m(ts(), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mBuilding RP forest with\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28mstr\u001b[39m(n_trees), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrees\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m    794\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rp_forest \u001b[38;5;241m=\u001b[39m make_forest(\n\u001b[1;32m    795\u001b[0m         data,\n\u001b[1;32m    796\u001b[0m         n_neighbors,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    804\u001b[0m         max_depth\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmax_rptree_depth,\n\u001b[1;32m    805\u001b[0m     )\n\u001b[0;32m--> 806\u001b[0m     leaf_array \u001b[38;5;241m=\u001b[39m \u001b[43mrptree_leaf_array\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_rp_forest\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    807\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    808\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_rp_forest \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pynndescent/rp_trees.py:1436\u001b[0m, in \u001b[0;36mrptree_leaf_array\u001b[0;34m(rp_forest)\u001b[0m\n\u001b[1;32m   1434\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrptree_leaf_array\u001b[39m(rp_forest):\n\u001b[1;32m   1435\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(rp_forest) \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m-> 1436\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39mvstack(\u001b[43mrptree_leaf_array_parallel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrp_forest\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m   1437\u001b[0m     \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   1438\u001b[0m         \u001b[38;5;28;01mreturn\u001b[39;00m np\u001b[38;5;241m.\u001b[39marray([[\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m]])\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/pynndescent/rp_trees.py:1428\u001b[0m, in \u001b[0;36mrptree_leaf_array_parallel\u001b[0;34m(rp_forest)\u001b[0m\n\u001b[1;32m   1426\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21mrptree_leaf_array_parallel\u001b[39m(rp_forest):\n\u001b[1;32m   1427\u001b[0m     max_leaf_size \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mmax([rp_tree\u001b[38;5;241m.\u001b[39mleaf_size \u001b[38;5;28;01mfor\u001b[39;00m rp_tree \u001b[38;5;129;01min\u001b[39;00m rp_forest])\n\u001b[0;32m-> 1428\u001b[0m     result \u001b[38;5;241m=\u001b[39m \u001b[43mjoblib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mParallel\u001b[49m\u001b[43m(\u001b[49m\u001b[43mn_jobs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrequire\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43msharedmem\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1429\u001b[0m \u001b[43m        \u001b[49m\u001b[43mjoblib\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdelayed\u001b[49m\u001b[43m(\u001b[49m\u001b[43mget_leaves_from_tree\u001b[49m\u001b[43m)\u001b[49m\u001b[43m(\u001b[49m\u001b[43mrp_tree\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_leaf_size\u001b[49m\u001b[43m)\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrp_tree\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mrp_forest\u001b[49m\n\u001b[1;32m   1430\u001b[0m \u001b[43m    \u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1431\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m result\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/parallel.py:2070\u001b[0m, in \u001b[0;36mParallel.__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   2064\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_call_ref \u001b[38;5;241m=\u001b[39m weakref\u001b[38;5;241m.\u001b[39mref(output)\n\u001b[1;32m   2066\u001b[0m \u001b[38;5;66;03m# The first item from the output is blank, but it makes the interpreter\u001b[39;00m\n\u001b[1;32m   2067\u001b[0m \u001b[38;5;66;03m# progress until it enters the Try/Except block of the generator and\u001b[39;00m\n\u001b[1;32m   2068\u001b[0m \u001b[38;5;66;03m# reaches the first `yield` statement. This starts the asynchronous\u001b[39;00m\n\u001b[1;32m   2069\u001b[0m \u001b[38;5;66;03m# dispatch of the tasks to the workers.\u001b[39;00m\n\u001b[0;32m-> 2070\u001b[0m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43moutput\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2072\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m output \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreturn_generator \u001b[38;5;28;01melse\u001b[39;00m \u001b[38;5;28mlist\u001b[39m(output)\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/parallel.py:1675\u001b[0m, in \u001b[0;36mParallel._get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m   1673\u001b[0m detach_generator_exit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m   1674\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m-> 1675\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_start\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpre_dispatch\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1676\u001b[0m     \u001b[38;5;66;03m# first yield returns None, for internal use only. This ensures\u001b[39;00m\n\u001b[1;32m   1677\u001b[0m     \u001b[38;5;66;03m# that we enter the try/except block and start dispatching the\u001b[39;00m\n\u001b[1;32m   1678\u001b[0m     \u001b[38;5;66;03m# tasks.\u001b[39;00m\n\u001b[1;32m   1679\u001b[0m     \u001b[38;5;28;01myield\u001b[39;00m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/parallel.py:1658\u001b[0m, in \u001b[0;36mParallel._start\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m   1649\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_start\u001b[39m(\u001b[38;5;28mself\u001b[39m, iterator, pre_dispatch):\n\u001b[1;32m   1650\u001b[0m     \u001b[38;5;66;03m# Only set self._iterating to True if at least a batch\u001b[39;00m\n\u001b[1;32m   1651\u001b[0m     \u001b[38;5;66;03m# was dispatched. In particular this covers the edge\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1655\u001b[0m     \u001b[38;5;66;03m# was very quick and its callback already dispatched all the\u001b[39;00m\n\u001b[1;32m   1656\u001b[0m     \u001b[38;5;66;03m# remaining jobs.\u001b[39;00m\n\u001b[1;32m   1657\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterating \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[0;32m-> 1658\u001b[0m     \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdispatch_one_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43miterator\u001b[49m\u001b[43m)\u001b[49m:\n\u001b[1;32m   1659\u001b[0m         \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_iterating \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_original_iterator \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m   1661\u001b[0m     \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdispatch_one_batch(iterator):\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/parallel.py:1540\u001b[0m, in \u001b[0;36mParallel.dispatch_one_batch\u001b[0;34m(self, iterator)\u001b[0m\n\u001b[1;32m   1538\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m   1539\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1540\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dispatch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtasks\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1541\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/parallel.py:1437\u001b[0m, in \u001b[0;36mParallel._dispatch\u001b[0;34m(self, batch)\u001b[0m\n\u001b[1;32m   1430\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_register_new_job(batch_tracker)\n\u001b[1;32m   1432\u001b[0m \u001b[38;5;66;03m# If return_ordered is False, the batch_tracker is not stored in the\u001b[39;00m\n\u001b[1;32m   1433\u001b[0m \u001b[38;5;66;03m# jobs queue at the time of submission. Instead, it will be appended to\u001b[39;00m\n\u001b[1;32m   1434\u001b[0m \u001b[38;5;66;03m# the queue by itself as soon as the callback is triggered to be able\u001b[39;00m\n\u001b[1;32m   1435\u001b[0m \u001b[38;5;66;03m# to return the results in the order of completion.\u001b[39;00m\n\u001b[0;32m-> 1437\u001b[0m job \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_backend\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msubmit\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbatch\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcallback\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mbatch_tracker\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1438\u001b[0m batch_tracker\u001b[38;5;241m.\u001b[39mregister_job(job)\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/_parallel_backends.py:339\u001b[0m, in \u001b[0;36mPoolManagerMixin.submit\u001b[0;34m(self, func, callback)\u001b[0m\n\u001b[1;32m    335\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Schedule a func to be run\"\"\"\u001b[39;00m\n\u001b[1;32m    336\u001b[0m \u001b[38;5;66;03m# Here, we need a wrapper to avoid crashes on KeyboardInterruptErrors.\u001b[39;00m\n\u001b[1;32m    337\u001b[0m \u001b[38;5;66;03m# We also call the callback on error, to make sure the pool does not\u001b[39;00m\n\u001b[1;32m    338\u001b[0m \u001b[38;5;66;03m# wait on crashed jobs.\u001b[39;00m\n\u001b[0;32m--> 339\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_get_pool\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mapply_async(\n\u001b[1;32m    340\u001b[0m     _TracebackCapturingWrapper(func),\n\u001b[1;32m    341\u001b[0m     (),\n\u001b[1;32m    342\u001b[0m     callback\u001b[38;5;241m=\u001b[39mcallback,\n\u001b[1;32m    343\u001b[0m     error_callback\u001b[38;5;241m=\u001b[39mcallback,\n\u001b[1;32m    344\u001b[0m )\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/joblib/_parallel_backends.py:507\u001b[0m, in \u001b[0;36mThreadingBackend._get_pool\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    501\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Lazily initialize the thread pool\u001b[39;00m\n\u001b[1;32m    502\u001b[0m \n\u001b[1;32m    503\u001b[0m \u001b[38;5;124;03mThe actual pool of worker threads is only initialized at the first\u001b[39;00m\n\u001b[1;32m    504\u001b[0m \u001b[38;5;124;03mcall to apply_async.\u001b[39;00m\n\u001b[1;32m    505\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    506\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 507\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool \u001b[38;5;241m=\u001b[39m \u001b[43mThreadPool\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_n_jobs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    508\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/multiprocessing/pool.py:927\u001b[0m, in \u001b[0;36mThreadPool.__init__\u001b[0;34m(self, processes, initializer, initargs)\u001b[0m\n\u001b[1;32m    926\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, processes\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, initializer\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, initargs\u001b[38;5;241m=\u001b[39m()):\n\u001b[0;32m--> 927\u001b[0m     \u001b[43mPool\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mprocesses\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minitializer\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minitargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/multiprocessing/pool.py:212\u001b[0m, in \u001b[0;36mPool.__init__\u001b[0;34m(self, processes, initializer, initargs, maxtasksperchild, context)\u001b[0m\n\u001b[1;32m    210\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_processes \u001b[38;5;241m=\u001b[39m processes\n\u001b[1;32m    211\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 212\u001b[0m     \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_repopulate_pool\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    213\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m    214\u001b[0m     \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pool:\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/multiprocessing/pool.py:303\u001b[0m, in \u001b[0;36mPool._repopulate_pool\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    302\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21m_repopulate_pool\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[0;32m--> 303\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_repopulate_pool_static\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_ctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mProcess\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    304\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_processes\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    305\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_pool\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inqueue\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    306\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_outqueue\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_initializer\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    307\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_initargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    308\u001b[0m \u001b[43m                                        
\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_maxtasksperchild\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    309\u001b[0m \u001b[43m                                        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_wrap_exception\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/multiprocessing/pool.py:326\u001b[0m, in \u001b[0;36mPool._repopulate_pool_static\u001b[0;34m(ctx, Process, processes, pool, inqueue, outqueue, initializer, initargs, maxtasksperchild, wrap_exception)\u001b[0m\n\u001b[1;32m    324\u001b[0m w\u001b[38;5;241m.\u001b[39mname \u001b[38;5;241m=\u001b[39m w\u001b[38;5;241m.\u001b[39mname\u001b[38;5;241m.\u001b[39mreplace(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mProcess\u001b[39m\u001b[38;5;124m'\u001b[39m, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mPoolWorker\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m    325\u001b[0m w\u001b[38;5;241m.\u001b[39mdaemon \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[0;32m--> 326\u001b[0m \u001b[43mw\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstart\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    327\u001b[0m pool\u001b[38;5;241m.\u001b[39mappend(w)\n\u001b[1;32m    328\u001b[0m util\u001b[38;5;241m.\u001b[39mdebug(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124madded worker\u001b[39m\u001b[38;5;124m'\u001b[39m)\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/multiprocessing/dummy/__init__.py:51\u001b[0m, in \u001b[0;36mDummyProcess.start\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m     49\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mhasattr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_parent, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m_children\u001b[39m\u001b[38;5;124m'\u001b[39m):\n\u001b[1;32m     50\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_parent\u001b[38;5;241m.\u001b[39m_children[\u001b[38;5;28mself\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 51\u001b[0m \u001b[43mthreading\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mThread\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mstart\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/here/miniconda3/envs/scgen-env/lib/python3.9/threading.py:899\u001b[0m, in \u001b[0;36mThread.start\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m    897\u001b[0m     _limbo[\u001b[38;5;28mself\u001b[39m] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m    898\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m--> 899\u001b[0m     \u001b[43m_start_new_thread\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_bootstrap\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    900\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m:\n\u001b[1;32m    901\u001b[0m     \u001b[38;5;28;01mwith\u001b[39;00m _active_limbo_lock:\n",
      "\u001b[0;31mRuntimeError\u001b[0m: can't start new thread"
     ]
    }
   ],
   "source": [
    "drug = \"Vem\"\n",
    "X_pre, X_post = prepare_pair_from_mat('WM902B', 'DMSO','24h', drug, '72h')\n",
    "jfe_indices = [1, 6, 0, 5, 4, 7, 8, 2, 3, 19]  \n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre, X_post, 0.8)\n",
    "\n",
    "print(X_tr_pre.shape)\n",
    "print(X_te_pre.shape)\n",
    "print(Y_tr_post.shape)\n",
    "print(Y_te_post.shape)\n",
    "\n",
    "# Compute median heuristic gamma on training data\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n",
    "\n",
    "\n",
    "all_metrics = []\n",
    "for run in range(10):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    seed = 1234 + run\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "    out = run_cellot_pair(X_tr_pre[:, jfe_indices], Y_tr_post[:, jfe_indices], X_te_pre[:, jfe_indices], Y_te_post[:, jfe_indices], n_epochs=2000)\n",
    "    metrics = summarize_metrics(out[\"y_pred\"], Y_te_post[:, jfe_indices], median_gamma)\n",
    "    print(f\"Run {run} metrics: {metrics}\")\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary\n",
    "df = pd.DataFrame(all_metrics)\n",
    "print(df.describe().T[['mean', 'std']].round(4))\n",
    "\n",
    "\n",
    "from umap import UMAP\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "source = Y_tr_post[:, jfe_indices]\n",
    "target = Y_te_post[:, jfe_indices]\n",
    "predicted = out.get('y_pred') \n",
    "\n",
    "# Instantiate UMAP\n",
    "umap_model = UMAP(n_components=2, random_state=42)\n",
    "\n",
    "all_sample_umap = umap_model.fit_transform(np.vstack([source, target]))\n",
    "source_umap = umap_model.transform(source)\n",
    "target_umap = umap_model.transform(target)\n",
    "y_pred_umap = umap_model.transform(predicted)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "# ax.scatter(source_umap[:, 0], source_umap[:, 1], s=10, alpha=0.7, label='train_post', color='C2')\n",
    "ax.scatter(target_umap[:, 0], target_umap[:, 1], s=10, alpha=0.7, label='observed treated cells', color=\"#C88131\")\n",
    "ax.scatter(y_pred_umap[:, 0], y_pred_umap[:, 1], s=10, alpha=0.7, label='predicted cells', color=\"#1F4D8D\")\n",
    "\n",
    "ax.set_title(f'{drug}')\n",
    "# ax.set_xlabel('UMAP 1')\n",
    "# ax.set_ylabel('UMAP 2')\n",
    "ax.set_aspect('equal', 'box')\n",
    "# Add a legend to distinguish the points\n",
    "ax.legend()\n",
    "# Adjust layout\n",
    "plt.tight_layout()\n",
    "# Display the plot\n",
    "plt.savefig(f\"./plots/cellot_on_4i_drug_{drug}.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9747cc38",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "610e4869",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cell line:  SKMEL19\n",
      "['DMSO' 'Vem' 'Vem+Tram']\n",
      "X_pre cells: (2677, 20)\n",
      "X_post cells: (2677, 20)\n",
      "(2141, 20)\n",
      "(536, 20)\n",
      "(2141, 20)\n",
      "(536, 20)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Median heuristic gamma: 0.10487158680337198\n",
      "**************** Run: 0 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-0.3313 g_loss=-5.6712 | train mmd=0.3145 | test_mmd=0.8853\n",
      "[CellOT] epoch=50 f_loss=-2.3354 g_loss=-1.9843 | train mmd=0.1017 | test_mmd=0.0865\n",
      "[CellOT] epoch=100 f_loss=-0.4245 g_loss=-0.0805 | train mmd=0.0851 | test_mmd=0.0218\n",
      "[CellOT] epoch=150 f_loss=-0.0761 g_loss=0.6526 | train mmd=0.0387 | test_mmd=0.0036\n",
      "[CellOT] epoch=200 f_loss=-0.1050 g_loss=0.9193 | train mmd=0.0165 | test_mmd=0.0034\n",
      "[CellOT] epoch=250 f_loss=-0.5806 g_loss=1.3326 | train mmd=0.0056 | test_mmd=0.0013\n",
      "[CellOT] epoch=300 f_loss=0.3165 g_loss=1.2259 | train mmd=0.0040 | test_mmd=0.0016\n",
      "[CellOT] epoch=350 f_loss=-0.3509 g_loss=1.4420 | train mmd=0.0044 | test_mmd=0.0019\n",
      "[CellOT] epoch=400 f_loss=0.2738 g_loss=1.4090 | train mmd=0.0035 | test_mmd=0.0026\n",
      "[CellOT] epoch=450 f_loss=-0.1944 g_loss=1.7903 | train mmd=0.0042 | test_mmd=0.0023\n",
      "[CellOT] epoch=500 f_loss=0.0196 g_loss=1.7564 | train mmd=0.0026 | test_mmd=0.0008\n",
      "[CellOT] epoch=550 f_loss=0.0408 g_loss=1.5857 | train mmd=0.0026 | test_mmd=0.0023\n",
      "[CellOT] epoch=600 f_loss=0.0691 g_loss=1.3915 | train mmd=0.0021 | test_mmd=0.0009\n",
      "[CellOT] epoch=650 f_loss=0.3175 g_loss=1.6289 | train mmd=0.0044 | test_mmd=0.0033\n",
      "[CellOT] epoch=700 f_loss=-0.3790 g_loss=1.4184 | train mmd=0.0024 | test_mmd=0.0014\n",
      "[CellOT] epoch=750 f_loss=0.7727 g_loss=1.6018 | train mmd=0.0039 | test_mmd=0.0023\n",
      "[CellOT] epoch=800 f_loss=-0.0190 g_loss=1.8223 | train mmd=0.0027 | test_mmd=0.0037\n",
      "[CellOT] epoch=850 f_loss=-0.5901 g_loss=1.9130 | train mmd=0.0031 | test_mmd=0.0009\n",
      "[CellOT] epoch=900 f_loss=0.4496 g_loss=1.4667 | train mmd=0.0021 | test_mmd=0.0026\n",
      "[CellOT] epoch=950 f_loss=-0.2520 g_loss=1.7326 | train mmd=0.0034 | test_mmd=0.0016\n",
      "[CellOT] epoch=1000 f_loss=0.2575 g_loss=1.5896 | train mmd=0.0037 | test_mmd=0.0025\n",
      "[CellOT] epoch=1050 f_loss=0.4684 g_loss=1.8230 | train mmd=0.0018 | test_mmd=0.0013\n",
      "[CellOT] epoch=1100 f_loss=0.4880 g_loss=1.7685 | train mmd=0.0019 | test_mmd=0.0009\n",
      "[CellOT] epoch=1150 f_loss=0.1690 g_loss=1.2272 | train mmd=0.0016 | test_mmd=0.0016\n",
      "[CellOT] epoch=1200 f_loss=0.1828 g_loss=1.8395 | train mmd=0.0029 | test_mmd=0.0028\n",
      "[CellOT] epoch=1250 f_loss=-0.1304 g_loss=1.4751 | train mmd=0.0022 | test_mmd=0.0027\n",
      "[CellOT] epoch=1300 f_loss=0.2721 g_loss=1.5337 | train mmd=0.0021 | test_mmd=0.0013\n",
      "[CellOT] epoch=1350 f_loss=-0.1744 g_loss=1.4405 | train mmd=0.0019 | test_mmd=0.0024\n",
      "[CellOT] epoch=1400 f_loss=-0.1741 g_loss=1.4688 | train mmd=0.0024 | test_mmd=0.0040\n",
      "[CellOT] epoch=1450 f_loss=0.4312 g_loss=1.9335 | train mmd=0.0020 | test_mmd=0.0019\n",
      "[CellOT] epoch=1500 f_loss=0.3726 g_loss=1.6302 | train mmd=0.0020 | test_mmd=0.0018\n",
      "[CellOT] epoch=1550 f_loss=0.3141 g_loss=1.6062 | train mmd=0.0017 | test_mmd=0.0025\n",
      "[CellOT] epoch=1600 f_loss=0.2774 g_loss=1.2188 | train mmd=0.0016 | test_mmd=0.0014\n",
      "[CellOT] epoch=1650 f_loss=-0.0511 g_loss=1.1213 | train mmd=0.0027 | test_mmd=0.0047\n",
      "[CellOT] epoch=1700 f_loss=-0.8142 g_loss=2.2030 | train mmd=0.0025 | test_mmd=0.0016\n",
      "[CellOT] epoch=1750 f_loss=-0.3480 g_loss=1.2494 | train mmd=0.0023 | test_mmd=0.0038\n",
      "[CellOT] epoch=1800 f_loss=0.1297 g_loss=1.6637 | train mmd=0.0019 | test_mmd=0.0015\n",
      "[CellOT] epoch=1850 f_loss=0.2746 g_loss=1.4580 | train mmd=0.0017 | test_mmd=0.0009\n",
      "[CellOT] epoch=1900 f_loss=0.1872 g_loss=1.8351 | train mmd=0.0045 | test_mmd=0.0019\n",
      "[CellOT] epoch=1950 f_loss=0.0326 g_loss=1.2800 | train mmd=0.0022 | test_mmd=0.0010\n",
      "[CellOT] epoch=2000 f_loss=0.4437 g_loss=1.6541 | train mmd=0.0016 | test_mmd=0.0013\n",
      "[CellOT] Final CellOT MMD: 0.0022\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.5839 g_loss=-3.8953 | train mmd=0.3514 | test_mmd=0.7713\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 0 metrics: {'mmd2_gamma_median': 0.001320746864281297, 'mmd2_gamma_0.5': 0.004396997576266215, 'mmd2_gamma_1.0': 0.005598257084086622, 'wasserstein_distance': 0.6342140424279016, 'R2_feature_means': 0.998732808192168}\n",
      "**************** Run: 1 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.2613 g_loss=-1.6739 | train mmd=0.1287 | test_mmd=0.0949\n",
      "[CellOT] epoch=100 f_loss=-1.4777 g_loss=0.9020 | train mmd=0.2163 | test_mmd=0.0515\n",
      "[CellOT] epoch=150 f_loss=-1.5734 g_loss=1.5894 | train mmd=0.1679 | test_mmd=0.0287\n",
      "[CellOT] epoch=200 f_loss=-1.8450 g_loss=2.6636 | train mmd=0.1404 | test_mmd=0.0163\n",
      "[CellOT] epoch=250 f_loss=-1.9271 g_loss=2.5767 | train mmd=0.1135 | test_mmd=0.0086\n",
      "[CellOT] epoch=300 f_loss=-1.0887 g_loss=3.1216 | train mmd=0.0800 | test_mmd=0.0037\n",
      "[CellOT] epoch=350 f_loss=-0.2178 g_loss=3.1520 | train mmd=0.0360 | test_mmd=0.0014\n",
      "[CellOT] epoch=400 f_loss=-0.0364 g_loss=2.8555 | train mmd=0.0234 | test_mmd=0.0026\n",
      "[CellOT] epoch=450 f_loss=0.1987 g_loss=3.1782 | train mmd=0.0125 | test_mmd=0.0021\n",
      "[CellOT] epoch=500 f_loss=0.3447 g_loss=2.8525 | train mmd=0.0114 | test_mmd=0.0025\n",
      "[CellOT] epoch=550 f_loss=0.4078 g_loss=3.3672 | train mmd=0.0102 | test_mmd=0.0027\n",
      "[CellOT] epoch=600 f_loss=0.6246 g_loss=4.2524 | train mmd=0.0101 | test_mmd=0.0024\n",
      "[CellOT] epoch=650 f_loss=0.1179 g_loss=4.7396 | train mmd=0.0060 | test_mmd=0.0016\n",
      "[CellOT] epoch=700 f_loss=0.0786 g_loss=4.7561 | train mmd=0.0039 | test_mmd=0.0006\n",
      "[CellOT] epoch=750 f_loss=0.4577 g_loss=4.8541 | train mmd=0.0045 | test_mmd=0.0033\n",
      "[CellOT] epoch=800 f_loss=0.0039 g_loss=5.0243 | train mmd=0.0037 | test_mmd=0.0015\n",
      "[CellOT] epoch=850 f_loss=0.4719 g_loss=5.8298 | train mmd=0.0047 | test_mmd=0.0015\n",
      "[CellOT] epoch=900 f_loss=0.3845 g_loss=5.1632 | train mmd=0.0033 | test_mmd=0.0010\n",
      "[CellOT] epoch=950 f_loss=0.0697 g_loss=5.1036 | train mmd=0.0021 | test_mmd=0.0022\n",
      "[CellOT] epoch=1000 f_loss=-0.3038 g_loss=4.4689 | train mmd=0.0034 | test_mmd=0.0009\n",
      "[CellOT] epoch=1050 f_loss=0.4605 g_loss=4.9996 | train mmd=0.0030 | test_mmd=0.0014\n",
      "[CellOT] epoch=1100 f_loss=0.3780 g_loss=5.3758 | train mmd=0.0030 | test_mmd=0.0006\n",
      "[CellOT] epoch=1150 f_loss=-0.4935 g_loss=5.7250 | train mmd=0.0034 | test_mmd=0.0026\n",
      "[CellOT] epoch=1200 f_loss=0.6407 g_loss=5.3296 | train mmd=0.0033 | test_mmd=0.0039\n",
      "[CellOT] epoch=1250 f_loss=0.2522 g_loss=5.4851 | train mmd=0.0028 | test_mmd=0.0021\n",
      "[CellOT] epoch=1300 f_loss=0.2320 g_loss=5.8239 | train mmd=0.0037 | test_mmd=0.0029\n",
      "[CellOT] epoch=1350 f_loss=0.1121 g_loss=5.3877 | train mmd=0.0023 | test_mmd=0.0014\n",
      "[CellOT] epoch=1400 f_loss=-0.2786 g_loss=5.7575 | train mmd=0.0032 | test_mmd=0.0015\n",
      "[CellOT] epoch=1450 f_loss=-0.3177 g_loss=6.0006 | train mmd=0.0032 | test_mmd=0.0008\n",
      "[CellOT] epoch=1500 f_loss=-0.1457 g_loss=5.7370 | train mmd=0.0025 | test_mmd=0.0010\n",
      "[CellOT] epoch=1550 f_loss=-0.3294 g_loss=5.4620 | train mmd=0.0021 | test_mmd=0.0015\n",
      "[CellOT] epoch=1600 f_loss=-0.3515 g_loss=6.0432 | train mmd=0.0020 | test_mmd=0.0013\n",
      "[CellOT] epoch=1650 f_loss=-0.6294 g_loss=5.6184 | train mmd=0.0021 | test_mmd=0.0013\n",
      "[CellOT] epoch=1700 f_loss=-0.3370 g_loss=6.0218 | train mmd=0.0029 | test_mmd=0.0012\n",
      "[CellOT] epoch=1750 f_loss=0.1537 g_loss=5.8404 | train mmd=0.0022 | test_mmd=0.0019\n",
      "[CellOT] epoch=1800 f_loss=-0.5098 g_loss=5.6588 | train mmd=0.0023 | test_mmd=0.0021\n",
      "[CellOT] epoch=1850 f_loss=-0.0356 g_loss=5.5961 | train mmd=0.0025 | test_mmd=0.0019\n",
      "[CellOT] epoch=1900 f_loss=0.3533 g_loss=5.9865 | train mmd=0.0025 | test_mmd=0.0012\n",
      "[CellOT] epoch=1950 f_loss=-0.1402 g_loss=5.8121 | train mmd=0.0037 | test_mmd=0.0028\n",
      "[CellOT] epoch=2000 f_loss=-0.1933 g_loss=6.2931 | train mmd=0.0026 | test_mmd=0.0017\n",
      "[CellOT] Final CellOT MMD: 0.0029\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=-0.9197 g_loss=-4.6062 | train mmd=0.2679 | test_mmd=0.6760\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 1 metrics: {'mmd2_gamma_median': 0.0017450689078086778, 'mmd2_gamma_0.5': 0.005877143402917118, 'mmd2_gamma_1.0': 0.007637182953584953, 'wasserstein_distance': 0.6454637080038895, 'R2_feature_means': 0.9966071671262233}\n",
      "**************** Run: 2 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-1.3475 g_loss=-1.1268 | train mmd=0.1011 | test_mmd=0.0705\n",
      "[CellOT] epoch=100 f_loss=-0.9793 g_loss=-0.3184 | train mmd=0.1457 | test_mmd=0.0285\n",
      "[CellOT] epoch=150 f_loss=-0.8279 g_loss=1.5286 | train mmd=0.1321 | test_mmd=0.0179\n",
      "[CellOT] epoch=200 f_loss=-1.1608 g_loss=1.6650 | train mmd=0.1226 | test_mmd=0.0142\n",
      "[CellOT] epoch=250 f_loss=-1.8775 g_loss=2.3353 | train mmd=0.1064 | test_mmd=0.0084\n",
      "[CellOT] epoch=300 f_loss=-1.6252 g_loss=2.7505 | train mmd=0.0953 | test_mmd=0.0053\n",
      "[CellOT] epoch=350 f_loss=-1.5499 g_loss=3.7497 | train mmd=0.0771 | test_mmd=0.0031\n",
      "[CellOT] epoch=400 f_loss=-1.3819 g_loss=3.7694 | train mmd=0.0603 | test_mmd=0.0024\n",
      "[CellOT] epoch=450 f_loss=0.2433 g_loss=3.9798 | train mmd=0.0471 | test_mmd=0.0017\n",
      "[CellOT] epoch=500 f_loss=0.5864 g_loss=3.9707 | train mmd=0.0370 | test_mmd=0.0012\n",
      "[CellOT] epoch=550 f_loss=0.2424 g_loss=4.6669 | train mmd=0.0263 | test_mmd=0.0009\n",
      "[CellOT] epoch=600 f_loss=-0.1831 g_loss=5.5939 | train mmd=0.0175 | test_mmd=0.0011\n",
      "[CellOT] epoch=650 f_loss=-0.0909 g_loss=6.6424 | train mmd=0.0184 | test_mmd=0.0010\n",
      "[CellOT] epoch=700 f_loss=0.4232 g_loss=7.5161 | train mmd=0.0121 | test_mmd=0.0026\n",
      "[CellOT] epoch=750 f_loss=-0.1227 g_loss=8.1600 | train mmd=0.0112 | test_mmd=0.0014\n",
      "[CellOT] epoch=800 f_loss=0.2063 g_loss=8.3941 | train mmd=0.0100 | test_mmd=0.0033\n",
      "[CellOT] epoch=850 f_loss=0.9353 g_loss=8.3772 | train mmd=0.0053 | test_mmd=0.0020\n",
      "[CellOT] epoch=900 f_loss=0.5135 g_loss=8.1652 | train mmd=0.0037 | test_mmd=0.0016\n",
      "[CellOT] epoch=950 f_loss=0.0881 g_loss=8.5548 | train mmd=0.0047 | test_mmd=0.0017\n",
      "[CellOT] epoch=1000 f_loss=0.1743 g_loss=8.0255 | train mmd=0.0041 | test_mmd=0.0018\n",
      "[CellOT] epoch=1050 f_loss=0.4968 g_loss=8.3776 | train mmd=0.0037 | test_mmd=0.0014\n",
      "[CellOT] epoch=1100 f_loss=-0.1924 g_loss=8.4237 | train mmd=0.0033 | test_mmd=0.0037\n",
      "[CellOT] epoch=1150 f_loss=-0.0597 g_loss=8.0747 | train mmd=0.0055 | test_mmd=0.0013\n",
      "[CellOT] epoch=1200 f_loss=0.2666 g_loss=8.6721 | train mmd=0.0024 | test_mmd=0.0020\n",
      "[CellOT] epoch=1250 f_loss=0.4580 g_loss=8.3094 | train mmd=0.0028 | test_mmd=0.0014\n",
      "[CellOT] epoch=1300 f_loss=-0.0581 g_loss=8.4764 | train mmd=0.0022 | test_mmd=0.0010\n",
      "[CellOT] epoch=1350 f_loss=0.2277 g_loss=8.8970 | train mmd=0.0025 | test_mmd=0.0025\n",
      "[CellOT] epoch=1400 f_loss=-0.2884 g_loss=8.7800 | train mmd=0.0030 | test_mmd=0.0027\n",
      "[CellOT] epoch=1450 f_loss=0.0106 g_loss=9.0369 | train mmd=0.0023 | test_mmd=0.0008\n",
      "[CellOT] epoch=1500 f_loss=0.8974 g_loss=8.8395 | train mmd=0.0031 | test_mmd=0.0025\n",
      "[CellOT] epoch=1550 f_loss=0.3006 g_loss=8.7222 | train mmd=0.0036 | test_mmd=0.0051\n",
      "[CellOT] epoch=1600 f_loss=0.9575 g_loss=8.6817 | train mmd=0.0022 | test_mmd=0.0042\n",
      "[CellOT] epoch=1650 f_loss=-0.6880 g_loss=8.9314 | train mmd=0.0023 | test_mmd=0.0019\n",
      "[CellOT] epoch=1700 f_loss=-0.3796 g_loss=8.7377 | train mmd=0.0021 | test_mmd=0.0028\n",
      "[CellOT] epoch=1750 f_loss=-0.1480 g_loss=8.9829 | train mmd=0.0022 | test_mmd=0.0012\n",
      "[CellOT] epoch=1800 f_loss=0.5353 g_loss=9.2922 | train mmd=0.0029 | test_mmd=0.0014\n",
      "[CellOT] epoch=1850 f_loss=-0.2280 g_loss=9.1017 | train mmd=0.0022 | test_mmd=0.0021\n",
      "[CellOT] epoch=1900 f_loss=-0.2821 g_loss=9.2111 | train mmd=0.0017 | test_mmd=0.0007\n",
      "[CellOT] epoch=1950 f_loss=-0.4370 g_loss=9.2070 | train mmd=0.0018 | test_mmd=0.0012\n",
      "[CellOT] epoch=2000 f_loss=-0.1911 g_loss=9.0361 | train mmd=0.0025 | test_mmd=0.0010\n",
      "[CellOT] Final CellOT MMD: 0.0022\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.9513 g_loss=-4.0548 | train mmd=0.3376 | test_mmd=0.8150\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 2 metrics: {'mmd2_gamma_median': 0.0010143183094528663, 'mmd2_gamma_0.5': 0.0042329371646173675, 'mmd2_gamma_1.0': 0.005819936692502137, 'wasserstein_distance': 0.6343809590579009, 'R2_feature_means': 0.9988270674273856}\n",
      "**************** Run: 3 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-1.0186 g_loss=-1.2537 | train mmd=0.1177 | test_mmd=0.0832\n",
      "[CellOT] epoch=100 f_loss=-0.5646 g_loss=-0.1842 | train mmd=0.1451 | test_mmd=0.0344\n",
      "[CellOT] epoch=150 f_loss=-1.0791 g_loss=0.9057 | train mmd=0.1409 | test_mmd=0.0223\n",
      "[CellOT] epoch=200 f_loss=-0.7902 g_loss=1.4919 | train mmd=0.1055 | test_mmd=0.0108\n",
      "[CellOT] epoch=250 f_loss=-1.0328 g_loss=2.4114 | train mmd=0.0868 | test_mmd=0.0068\n",
      "[CellOT] epoch=300 f_loss=-0.7879 g_loss=2.2073 | train mmd=0.0706 | test_mmd=0.0039\n",
      "[CellOT] epoch=350 f_loss=-1.6230 g_loss=2.8117 | train mmd=0.0524 | test_mmd=0.0020\n",
      "[CellOT] epoch=400 f_loss=-0.8263 g_loss=2.6666 | train mmd=0.0356 | test_mmd=0.0010\n",
      "[CellOT] epoch=450 f_loss=-0.4343 g_loss=2.8543 | train mmd=0.0345 | test_mmd=0.0015\n",
      "[CellOT] epoch=500 f_loss=0.4362 g_loss=3.4379 | train mmd=0.0190 | test_mmd=0.0013\n",
      "[CellOT] epoch=550 f_loss=-0.1842 g_loss=4.2690 | train mmd=0.0113 | test_mmd=0.0013\n",
      "[CellOT] epoch=600 f_loss=0.0488 g_loss=3.7984 | train mmd=0.0066 | test_mmd=0.0012\n",
      "[CellOT] epoch=650 f_loss=-0.2063 g_loss=4.8561 | train mmd=0.0059 | test_mmd=0.0019\n",
      "[CellOT] epoch=700 f_loss=0.3990 g_loss=5.3096 | train mmd=0.0043 | test_mmd=0.0007\n",
      "[CellOT] epoch=750 f_loss=-0.0954 g_loss=5.3113 | train mmd=0.0050 | test_mmd=0.0025\n",
      "[CellOT] epoch=800 f_loss=0.5879 g_loss=5.4385 | train mmd=0.0060 | test_mmd=0.0015\n",
      "[CellOT] epoch=850 f_loss=0.0998 g_loss=5.7848 | train mmd=0.0044 | test_mmd=0.0018\n",
      "[CellOT] epoch=900 f_loss=-0.3172 g_loss=5.9117 | train mmd=0.0027 | test_mmd=0.0014\n",
      "[CellOT] epoch=950 f_loss=-0.1850 g_loss=6.1317 | train mmd=0.0024 | test_mmd=0.0012\n",
      "[CellOT] epoch=1000 f_loss=-0.1169 g_loss=6.2319 | train mmd=0.0031 | test_mmd=0.0019\n",
      "[CellOT] epoch=1050 f_loss=0.1911 g_loss=6.1281 | train mmd=0.0028 | test_mmd=0.0013\n",
      "[CellOT] epoch=1100 f_loss=-0.1949 g_loss=6.2184 | train mmd=0.0032 | test_mmd=0.0019\n",
      "[CellOT] epoch=1150 f_loss=0.4027 g_loss=6.1778 | train mmd=0.0020 | test_mmd=0.0019\n",
      "[CellOT] epoch=1200 f_loss=0.2168 g_loss=6.4473 | train mmd=0.0036 | test_mmd=0.0012\n",
      "[CellOT] epoch=1250 f_loss=-0.2756 g_loss=6.7361 | train mmd=0.0029 | test_mmd=0.0013\n",
      "[CellOT] epoch=1300 f_loss=1.1837 g_loss=6.6569 | train mmd=0.0029 | test_mmd=0.0041\n",
      "[CellOT] epoch=1350 f_loss=0.5410 g_loss=6.6853 | train mmd=0.0021 | test_mmd=0.0015\n",
      "[CellOT] epoch=1400 f_loss=0.0001 g_loss=6.7543 | train mmd=0.0020 | test_mmd=0.0018\n",
      "[CellOT] epoch=1450 f_loss=-0.0853 g_loss=7.1953 | train mmd=0.0020 | test_mmd=0.0027\n",
      "[CellOT] epoch=1500 f_loss=-0.5435 g_loss=6.8583 | train mmd=0.0035 | test_mmd=0.0019\n",
      "[CellOT] epoch=1550 f_loss=0.0656 g_loss=6.9332 | train mmd=0.0033 | test_mmd=0.0012\n",
      "[CellOT] epoch=1600 f_loss=0.0634 g_loss=6.9642 | train mmd=0.0017 | test_mmd=0.0020\n",
      "[CellOT] epoch=1650 f_loss=-0.3210 g_loss=7.6007 | train mmd=0.0046 | test_mmd=0.0037\n",
      "[CellOT] epoch=1700 f_loss=0.5408 g_loss=6.9457 | train mmd=0.0024 | test_mmd=0.0036\n",
      "[CellOT] epoch=1750 f_loss=-0.0340 g_loss=7.7613 | train mmd=0.0019 | test_mmd=0.0012\n",
      "[CellOT] epoch=1800 f_loss=-0.5273 g_loss=7.4304 | train mmd=0.0028 | test_mmd=0.0032\n",
      "[CellOT] epoch=1850 f_loss=-0.5682 g_loss=7.4744 | train mmd=0.0029 | test_mmd=0.0016\n",
      "[CellOT] epoch=1900 f_loss=0.1405 g_loss=7.7961 | train mmd=0.0036 | test_mmd=0.0018\n",
      "[CellOT] epoch=1950 f_loss=-0.3144 g_loss=7.2938 | train mmd=0.0028 | test_mmd=0.0032\n",
      "[CellOT] epoch=2000 f_loss=-0.1489 g_loss=7.6906 | train mmd=0.0018 | test_mmd=0.0017\n",
      "[CellOT] Final CellOT MMD: 0.0027\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=1.3431 g_loss=-7.4367 | train mmd=0.2573 | test_mmd=0.9504\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 3 metrics: {'mmd2_gamma_median': 0.0016791060081855491, 'mmd2_gamma_0.5': 0.005862947959947018, 'mmd2_gamma_1.0': 0.007088786873825803, 'wasserstein_distance': 0.645679458192905, 'R2_feature_means': 0.9986158645128312}\n",
      "**************** Run: 4 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.3033 g_loss=-1.3413 | train mmd=0.1523 | test_mmd=0.1448\n",
      "[CellOT] epoch=100 f_loss=-0.6692 g_loss=-0.0095 | train mmd=0.1991 | test_mmd=0.0462\n",
      "[CellOT] epoch=150 f_loss=-1.3915 g_loss=1.2715 | train mmd=0.1388 | test_mmd=0.0251\n",
      "[CellOT] epoch=200 f_loss=-2.3117 g_loss=1.3629 | train mmd=0.1164 | test_mmd=0.0134\n",
      "[CellOT] epoch=250 f_loss=-0.6531 g_loss=2.3214 | train mmd=0.0968 | test_mmd=0.0074\n",
      "[CellOT] epoch=300 f_loss=-0.9556 g_loss=2.5955 | train mmd=0.0945 | test_mmd=0.0063\n",
      "[CellOT] epoch=350 f_loss=-2.6603 g_loss=3.0563 | train mmd=0.0664 | test_mmd=0.0028\n",
      "[CellOT] epoch=400 f_loss=-1.1025 g_loss=3.8551 | train mmd=0.0496 | test_mmd=0.0018\n",
      "[CellOT] epoch=450 f_loss=-0.6014 g_loss=3.7548 | train mmd=0.0546 | test_mmd=0.0015\n",
      "[CellOT] epoch=500 f_loss=0.1192 g_loss=4.2094 | train mmd=0.0321 | test_mmd=0.0031\n",
      "[CellOT] epoch=550 f_loss=-0.0386 g_loss=4.4543 | train mmd=0.0224 | test_mmd=0.0035\n",
      "[CellOT] epoch=600 f_loss=-0.0672 g_loss=6.1349 | train mmd=0.0175 | test_mmd=0.0012\n",
      "[CellOT] epoch=650 f_loss=0.2765 g_loss=7.1872 | train mmd=0.0177 | test_mmd=0.0011\n",
      "[CellOT] epoch=700 f_loss=0.3842 g_loss=8.0678 | train mmd=0.0133 | test_mmd=0.0011\n",
      "[CellOT] epoch=750 f_loss=-0.8846 g_loss=9.3600 | train mmd=0.0111 | test_mmd=0.0011\n",
      "[CellOT] epoch=800 f_loss=0.9602 g_loss=9.7202 | train mmd=0.0111 | test_mmd=0.0010\n",
      "[CellOT] epoch=850 f_loss=0.2473 g_loss=10.6383 | train mmd=0.0080 | test_mmd=0.0012\n",
      "[CellOT] epoch=900 f_loss=0.0044 g_loss=10.5210 | train mmd=0.0056 | test_mmd=0.0023\n",
      "[CellOT] epoch=950 f_loss=0.0604 g_loss=10.7028 | train mmd=0.0052 | test_mmd=0.0010\n",
      "[CellOT] epoch=1000 f_loss=-0.4486 g_loss=10.3847 | train mmd=0.0036 | test_mmd=0.0021\n",
      "[CellOT] epoch=1050 f_loss=0.5076 g_loss=10.8886 | train mmd=0.0045 | test_mmd=0.0034\n",
      "[CellOT] epoch=1100 f_loss=-0.1156 g_loss=10.6587 | train mmd=0.0052 | test_mmd=0.0038\n",
      "[CellOT] epoch=1150 f_loss=-0.6430 g_loss=10.0169 | train mmd=0.0059 | test_mmd=0.0022\n",
      "[CellOT] epoch=1200 f_loss=-0.2638 g_loss=10.7353 | train mmd=0.0030 | test_mmd=0.0017\n",
      "[CellOT] epoch=1250 f_loss=0.0698 g_loss=10.6261 | train mmd=0.0034 | test_mmd=0.0012\n",
      "[CellOT] epoch=1300 f_loss=0.4430 g_loss=10.0592 | train mmd=0.0023 | test_mmd=0.0030\n",
      "[CellOT] epoch=1350 f_loss=-0.2150 g_loss=10.4411 | train mmd=0.0036 | test_mmd=0.0018\n",
      "[CellOT] epoch=1400 f_loss=0.4162 g_loss=10.4325 | train mmd=0.0030 | test_mmd=0.0019\n",
      "[CellOT] epoch=1450 f_loss=0.4985 g_loss=10.5002 | train mmd=0.0020 | test_mmd=0.0013\n",
      "[CellOT] epoch=1500 f_loss=0.5263 g_loss=10.2876 | train mmd=0.0019 | test_mmd=0.0012\n",
      "[CellOT] epoch=1550 f_loss=-0.0788 g_loss=10.5882 | train mmd=0.0035 | test_mmd=0.0024\n",
      "[CellOT] epoch=1600 f_loss=0.2466 g_loss=10.6939 | train mmd=0.0023 | test_mmd=0.0016\n",
      "[CellOT] epoch=1650 f_loss=-0.4731 g_loss=10.8956 | train mmd=0.0026 | test_mmd=0.0012\n",
      "[CellOT] epoch=1700 f_loss=-0.2580 g_loss=11.0652 | train mmd=0.0038 | test_mmd=0.0021\n",
      "[CellOT] epoch=1750 f_loss=0.2536 g_loss=10.5187 | train mmd=0.0025 | test_mmd=0.0041\n",
      "[CellOT] epoch=1800 f_loss=-0.1257 g_loss=10.8645 | train mmd=0.0031 | test_mmd=0.0015\n",
      "[CellOT] epoch=1850 f_loss=0.5506 g_loss=11.2908 | train mmd=0.0017 | test_mmd=0.0019\n",
      "[CellOT] epoch=1900 f_loss=0.0399 g_loss=11.0313 | train mmd=0.0041 | test_mmd=0.0027\n",
      "[CellOT] epoch=1950 f_loss=-0.8405 g_loss=10.8722 | train mmd=0.0023 | test_mmd=0.0017\n",
      "[CellOT] epoch=2000 f_loss=0.1911 g_loss=10.9006 | train mmd=0.0016 | test_mmd=0.0017\n",
      "[CellOT] Final CellOT MMD: 0.0028\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=1.5583 g_loss=-5.3132 | train mmd=0.3488 | test_mmd=0.8863\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 4 metrics: {'mmd2_gamma_median': 0.0016790039060707862, 'mmd2_gamma_0.5': 0.005877711035723543, 'mmd2_gamma_1.0': 0.0074182325506369495, 'wasserstein_distance': 0.6398828653533303, 'R2_feature_means': 0.9984428458104981}\n",
      "**************** Run: 5 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.1664 g_loss=-1.0712 | train mmd=0.1302 | test_mmd=0.1191\n",
      "[CellOT] epoch=100 f_loss=-0.9585 g_loss=0.4238 | train mmd=0.2034 | test_mmd=0.0426\n",
      "[CellOT] epoch=150 f_loss=-2.3740 g_loss=0.9971 | train mmd=0.1591 | test_mmd=0.0274\n",
      "[CellOT] epoch=200 f_loss=-1.1765 g_loss=1.8407 | train mmd=0.1451 | test_mmd=0.0159\n",
      "[CellOT] epoch=250 f_loss=-1.3751 g_loss=2.8120 | train mmd=0.1133 | test_mmd=0.0076\n",
      "[CellOT] epoch=300 f_loss=-1.0039 g_loss=3.5649 | train mmd=0.0724 | test_mmd=0.0034\n",
      "[CellOT] epoch=350 f_loss=-1.9434 g_loss=4.4924 | train mmd=0.0791 | test_mmd=0.0025\n",
      "[CellOT] epoch=400 f_loss=-0.6618 g_loss=3.5953 | train mmd=0.0501 | test_mmd=0.0014\n",
      "[CellOT] epoch=450 f_loss=-1.0454 g_loss=3.8402 | train mmd=0.0469 | test_mmd=0.0016\n",
      "[CellOT] epoch=500 f_loss=-0.7820 g_loss=4.0290 | train mmd=0.0373 | test_mmd=0.0012\n",
      "[CellOT] epoch=550 f_loss=0.4348 g_loss=4.9206 | train mmd=0.0302 | test_mmd=0.0018\n",
      "[CellOT] epoch=600 f_loss=-0.3236 g_loss=6.0151 | train mmd=0.0283 | test_mmd=0.0018\n",
      "[CellOT] epoch=650 f_loss=-0.3990 g_loss=6.8968 | train mmd=0.0165 | test_mmd=0.0021\n",
      "[CellOT] epoch=700 f_loss=0.1976 g_loss=7.1805 | train mmd=0.0167 | test_mmd=0.0010\n",
      "[CellOT] epoch=750 f_loss=0.4437 g_loss=8.4766 | train mmd=0.0190 | test_mmd=0.0012\n",
      "[CellOT] epoch=800 f_loss=-0.1672 g_loss=8.9733 | train mmd=0.0100 | test_mmd=0.0015\n",
      "[CellOT] epoch=850 f_loss=-0.1245 g_loss=9.5900 | train mmd=0.0089 | test_mmd=0.0022\n",
      "[CellOT] epoch=900 f_loss=0.2800 g_loss=9.7815 | train mmd=0.0047 | test_mmd=0.0012\n",
      "[CellOT] epoch=950 f_loss=0.1049 g_loss=9.3651 | train mmd=0.0040 | test_mmd=0.0029\n",
      "[CellOT] epoch=1000 f_loss=-0.0050 g_loss=9.1258 | train mmd=0.0054 | test_mmd=0.0011\n",
      "[CellOT] epoch=1050 f_loss=-0.0899 g_loss=9.6161 | train mmd=0.0035 | test_mmd=0.0014\n",
      "[CellOT] epoch=1100 f_loss=-0.4423 g_loss=9.3071 | train mmd=0.0024 | test_mmd=0.0016\n",
      "[CellOT] epoch=1150 f_loss=0.0523 g_loss=9.6079 | train mmd=0.0043 | test_mmd=0.0030\n",
      "[CellOT] epoch=1200 f_loss=-0.2578 g_loss=9.2556 | train mmd=0.0028 | test_mmd=0.0031\n",
      "[CellOT] epoch=1250 f_loss=-0.2903 g_loss=9.6326 | train mmd=0.0038 | test_mmd=0.0012\n",
      "[CellOT] epoch=1300 f_loss=0.3105 g_loss=9.6987 | train mmd=0.0038 | test_mmd=0.0031\n",
      "[CellOT] epoch=1350 f_loss=-0.4913 g_loss=10.0823 | train mmd=0.0029 | test_mmd=0.0018\n",
      "[CellOT] epoch=1400 f_loss=0.0329 g_loss=9.9311 | train mmd=0.0022 | test_mmd=0.0014\n",
      "[CellOT] epoch=1450 f_loss=0.2201 g_loss=9.8860 | train mmd=0.0041 | test_mmd=0.0027\n",
      "[CellOT] epoch=1500 f_loss=0.4038 g_loss=9.9033 | train mmd=0.0033 | test_mmd=0.0032\n",
      "[CellOT] epoch=1550 f_loss=0.4051 g_loss=10.3938 | train mmd=0.0024 | test_mmd=0.0024\n",
      "[CellOT] epoch=1600 f_loss=0.5328 g_loss=10.2179 | train mmd=0.0024 | test_mmd=0.0013\n",
      "[CellOT] epoch=1650 f_loss=-0.2282 g_loss=10.3636 | train mmd=0.0031 | test_mmd=0.0022\n",
      "[CellOT] epoch=1700 f_loss=-0.5176 g_loss=10.6254 | train mmd=0.0042 | test_mmd=0.0042\n",
      "[CellOT] epoch=1750 f_loss=-0.2670 g_loss=10.4101 | train mmd=0.0034 | test_mmd=0.0026\n",
      "[CellOT] epoch=1800 f_loss=0.6323 g_loss=10.6584 | train mmd=0.0024 | test_mmd=0.0014\n",
      "[CellOT] epoch=1850 f_loss=-0.0556 g_loss=10.5433 | train mmd=0.0032 | test_mmd=0.0019\n",
      "[CellOT] epoch=1900 f_loss=-0.6025 g_loss=10.3944 | train mmd=0.0029 | test_mmd=0.0039\n",
      "[CellOT] epoch=1950 f_loss=-0.0518 g_loss=10.9607 | train mmd=0.0029 | test_mmd=0.0027\n",
      "[CellOT] epoch=2000 f_loss=-0.0616 g_loss=10.9607 | train mmd=0.0032 | test_mmd=0.0021\n",
      "[CellOT] Final CellOT MMD: 0.0034\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.3858 g_loss=-4.1507 | train mmd=0.2776 | test_mmd=0.8577\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 5 metrics: {'mmd2_gamma_median': 0.0021052109270522923, 'mmd2_gamma_0.5': 0.007383934631729816, 'mmd2_gamma_1.0': 0.009410351842466635, 'wasserstein_distance': 0.6354149952682455, 'R2_feature_means': 0.9973374831951498}\n",
      "**************** Run: 6 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-2.0629 g_loss=-1.0556 | train mmd=0.1254 | test_mmd=0.0850\n",
      "[CellOT] epoch=100 f_loss=-1.1823 g_loss=-0.0068 | train mmd=0.1115 | test_mmd=0.0272\n",
      "[CellOT] epoch=150 f_loss=-0.9333 g_loss=0.7456 | train mmd=0.1220 | test_mmd=0.0162\n",
      "[CellOT] epoch=200 f_loss=-1.5165 g_loss=1.8771 | train mmd=0.0786 | test_mmd=0.0077\n",
      "[CellOT] epoch=250 f_loss=-0.4563 g_loss=2.2651 | train mmd=0.0643 | test_mmd=0.0040\n",
      "[CellOT] epoch=300 f_loss=-0.6762 g_loss=2.6582 | train mmd=0.0301 | test_mmd=0.0015\n",
      "[CellOT] epoch=350 f_loss=-0.1127 g_loss=2.4653 | train mmd=0.0303 | test_mmd=0.0009\n",
      "[CellOT] epoch=400 f_loss=0.7116 g_loss=2.5002 | train mmd=0.0153 | test_mmd=0.0012\n",
      "[CellOT] epoch=450 f_loss=-0.3335 g_loss=2.9122 | train mmd=0.0082 | test_mmd=0.0014\n",
      "[CellOT] epoch=500 f_loss=0.1166 g_loss=3.3280 | train mmd=0.0069 | test_mmd=0.0020\n",
      "[CellOT] epoch=550 f_loss=0.5857 g_loss=3.8643 | train mmd=0.0066 | test_mmd=0.0024\n",
      "[CellOT] epoch=600 f_loss=-0.0234 g_loss=4.1522 | train mmd=0.0049 | test_mmd=0.0024\n",
      "[CellOT] epoch=650 f_loss=-0.3382 g_loss=4.5439 | train mmd=0.0034 | test_mmd=0.0015\n",
      "[CellOT] epoch=700 f_loss=-0.2202 g_loss=4.7934 | train mmd=0.0045 | test_mmd=0.0019\n",
      "[CellOT] epoch=750 f_loss=-0.2189 g_loss=4.6812 | train mmd=0.0042 | test_mmd=0.0010\n",
      "[CellOT] epoch=800 f_loss=0.8008 g_loss=5.0219 | train mmd=0.0033 | test_mmd=0.0025\n",
      "[CellOT] epoch=850 f_loss=0.0436 g_loss=4.5914 | train mmd=0.0029 | test_mmd=0.0012\n",
      "[CellOT] epoch=900 f_loss=0.1365 g_loss=4.5812 | train mmd=0.0028 | test_mmd=0.0012\n",
      "[CellOT] epoch=950 f_loss=0.0734 g_loss=4.9418 | train mmd=0.0035 | test_mmd=0.0018\n",
      "[CellOT] epoch=1000 f_loss=-0.5259 g_loss=5.6201 | train mmd=0.0041 | test_mmd=0.0018\n",
      "[CellOT] epoch=1050 f_loss=0.3898 g_loss=5.3500 | train mmd=0.0035 | test_mmd=0.0014\n",
      "[CellOT] epoch=1100 f_loss=-0.0861 g_loss=5.7124 | train mmd=0.0040 | test_mmd=0.0017\n",
      "[CellOT] epoch=1150 f_loss=0.3465 g_loss=5.3646 | train mmd=0.0047 | test_mmd=0.0047\n",
      "[CellOT] epoch=1200 f_loss=-0.2598 g_loss=5.7559 | train mmd=0.0035 | test_mmd=0.0048\n",
      "[CellOT] epoch=1250 f_loss=-0.3939 g_loss=5.8471 | train mmd=0.0024 | test_mmd=0.0015\n",
      "[CellOT] epoch=1300 f_loss=-0.2387 g_loss=5.7068 | train mmd=0.0023 | test_mmd=0.0018\n",
      "[CellOT] epoch=1350 f_loss=0.0311 g_loss=5.8987 | train mmd=0.0044 | test_mmd=0.0026\n",
      "[CellOT] epoch=1400 f_loss=-0.1420 g_loss=6.5225 | train mmd=0.0026 | test_mmd=0.0021\n",
      "[CellOT] epoch=1450 f_loss=0.2794 g_loss=6.5148 | train mmd=0.0036 | test_mmd=0.0045\n",
      "[CellOT] epoch=1500 f_loss=-0.3919 g_loss=6.1803 | train mmd=0.0023 | test_mmd=0.0020\n",
      "[CellOT] epoch=1550 f_loss=-0.8711 g_loss=6.4374 | train mmd=0.0018 | test_mmd=0.0009\n",
      "[CellOT] epoch=1600 f_loss=-0.2740 g_loss=5.8757 | train mmd=0.0024 | test_mmd=0.0021\n",
      "[CellOT] epoch=1650 f_loss=-0.2136 g_loss=6.0114 | train mmd=0.0040 | test_mmd=0.0013\n",
      "[CellOT] epoch=1700 f_loss=0.4968 g_loss=5.7045 | train mmd=0.0037 | test_mmd=0.0012\n",
      "[CellOT] epoch=1750 f_loss=0.0208 g_loss=6.4701 | train mmd=0.0023 | test_mmd=0.0022\n",
      "[CellOT] epoch=1800 f_loss=-0.1275 g_loss=6.6187 | train mmd=0.0034 | test_mmd=0.0012\n",
      "[CellOT] epoch=1850 f_loss=-0.0887 g_loss=6.6738 | train mmd=0.0023 | test_mmd=0.0015\n",
      "[CellOT] epoch=1900 f_loss=-0.6047 g_loss=6.6443 | train mmd=0.0019 | test_mmd=0.0015\n",
      "[CellOT] epoch=1950 f_loss=-0.2076 g_loss=6.6430 | train mmd=0.0028 | test_mmd=0.0007\n",
      "[CellOT] epoch=2000 f_loss=0.5489 g_loss=6.4492 | train mmd=0.0026 | test_mmd=0.0023\n",
      "[CellOT] Final CellOT MMD: 0.0032\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.4383 g_loss=-3.0190 | train mmd=0.3355 | test_mmd=0.7172\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 6 metrics: {'mmd2_gamma_median': 0.0022623411285300765, 'mmd2_gamma_0.5': 0.006969220549900124, 'mmd2_gamma_1.0': 0.008522749495494975, 'wasserstein_distance': 0.65705446195221, 'R2_feature_means': 0.9976724488732909}\n",
      "**************** Run: 7 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-1.9487 g_loss=-1.1556 | train mmd=0.1133 | test_mmd=0.0854\n",
      "[CellOT] epoch=100 f_loss=-1.2310 g_loss=-0.4217 | train mmd=0.0867 | test_mmd=0.0218\n",
      "[CellOT] epoch=150 f_loss=-0.5443 g_loss=1.0482 | train mmd=0.0659 | test_mmd=0.0087\n",
      "[CellOT] epoch=200 f_loss=-1.1897 g_loss=1.6315 | train mmd=0.0773 | test_mmd=0.0080\n",
      "[CellOT] epoch=250 f_loss=-0.9949 g_loss=2.2535 | train mmd=0.0912 | test_mmd=0.0065\n",
      "[CellOT] epoch=300 f_loss=-1.1520 g_loss=2.6643 | train mmd=0.0661 | test_mmd=0.0035\n",
      "[CellOT] epoch=350 f_loss=-0.9265 g_loss=2.8519 | train mmd=0.0542 | test_mmd=0.0017\n",
      "[CellOT] epoch=400 f_loss=0.0343 g_loss=3.0998 | train mmd=0.0336 | test_mmd=0.0013\n",
      "[CellOT] epoch=450 f_loss=-0.1945 g_loss=3.4727 | train mmd=0.0212 | test_mmd=0.0016\n",
      "[CellOT] epoch=500 f_loss=-0.7593 g_loss=3.2378 | train mmd=0.0134 | test_mmd=0.0011\n",
      "[CellOT] epoch=550 f_loss=-0.2398 g_loss=4.1266 | train mmd=0.0153 | test_mmd=0.0019\n",
      "[CellOT] epoch=600 f_loss=-0.1376 g_loss=4.8263 | train mmd=0.0105 | test_mmd=0.0021\n",
      "[CellOT] epoch=650 f_loss=0.4604 g_loss=4.7574 | train mmd=0.0078 | test_mmd=0.0013\n",
      "[CellOT] epoch=700 f_loss=-0.1995 g_loss=5.2774 | train mmd=0.0071 | test_mmd=0.0022\n",
      "[CellOT] epoch=750 f_loss=-0.5050 g_loss=5.3073 | train mmd=0.0035 | test_mmd=0.0009\n",
      "[CellOT] epoch=800 f_loss=0.1875 g_loss=5.6790 | train mmd=0.0038 | test_mmd=0.0026\n",
      "[CellOT] epoch=850 f_loss=-0.1172 g_loss=5.4245 | train mmd=0.0036 | test_mmd=0.0011\n",
      "[CellOT] epoch=900 f_loss=-0.1211 g_loss=5.7061 | train mmd=0.0032 | test_mmd=0.0013\n",
      "[CellOT] epoch=950 f_loss=0.8872 g_loss=5.6730 | train mmd=0.0024 | test_mmd=0.0017\n",
      "[CellOT] epoch=1000 f_loss=0.0789 g_loss=5.5162 | train mmd=0.0031 | test_mmd=0.0015\n",
      "[CellOT] epoch=1050 f_loss=0.4939 g_loss=5.8363 | train mmd=0.0032 | test_mmd=0.0015\n",
      "[CellOT] epoch=1100 f_loss=0.2941 g_loss=6.0058 | train mmd=0.0026 | test_mmd=0.0019\n",
      "[CellOT] epoch=1150 f_loss=-0.6564 g_loss=6.1066 | train mmd=0.0015 | test_mmd=0.0013\n",
      "[CellOT] epoch=1200 f_loss=0.7691 g_loss=6.9265 | train mmd=0.0023 | test_mmd=0.0013\n",
      "[CellOT] epoch=1250 f_loss=-0.7177 g_loss=6.6795 | train mmd=0.0027 | test_mmd=0.0008\n",
      "[CellOT] epoch=1300 f_loss=1.0991 g_loss=6.9267 | train mmd=0.0022 | test_mmd=0.0013\n",
      "[CellOT] epoch=1350 f_loss=0.4293 g_loss=6.3792 | train mmd=0.0020 | test_mmd=0.0016\n",
      "[CellOT] epoch=1400 f_loss=-0.3751 g_loss=6.4398 | train mmd=0.0020 | test_mmd=0.0020\n",
      "[CellOT] epoch=1450 f_loss=-0.0955 g_loss=6.4595 | train mmd=0.0023 | test_mmd=0.0026\n",
      "[CellOT] epoch=1500 f_loss=-0.0005 g_loss=7.1522 | train mmd=0.0025 | test_mmd=0.0015\n",
      "[CellOT] epoch=1550 f_loss=-0.0694 g_loss=6.9636 | train mmd=0.0030 | test_mmd=0.0016\n",
      "[CellOT] epoch=1600 f_loss=-0.2475 g_loss=6.6369 | train mmd=0.0024 | test_mmd=0.0016\n",
      "[CellOT] epoch=1650 f_loss=0.3394 g_loss=6.3624 | train mmd=0.0023 | test_mmd=0.0009\n",
      "[CellOT] epoch=1700 f_loss=-0.2502 g_loss=6.3199 | train mmd=0.0045 | test_mmd=0.0032\n",
      "[CellOT] epoch=1750 f_loss=-0.4575 g_loss=6.6014 | train mmd=0.0023 | test_mmd=0.0016\n",
      "[CellOT] epoch=1800 f_loss=0.2923 g_loss=6.7711 | train mmd=0.0024 | test_mmd=0.0015\n",
      "[CellOT] epoch=1850 f_loss=0.0269 g_loss=6.6408 | train mmd=0.0028 | test_mmd=0.0031\n",
      "[CellOT] epoch=1900 f_loss=0.6635 g_loss=6.6432 | train mmd=0.0023 | test_mmd=0.0039\n",
      "[CellOT] epoch=1950 f_loss=0.2889 g_loss=6.8464 | train mmd=0.0033 | test_mmd=0.0037\n",
      "[CellOT] epoch=2000 f_loss=-0.3819 g_loss=6.9992 | train mmd=0.0024 | test_mmd=0.0011\n",
      "[CellOT] Final CellOT MMD: 0.0025\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n",
      "[CellOT] epoch=0 f_loss=0.5500 g_loss=-4.0519 | train mmd=0.3488 | test_mmd=0.9153\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 7 metrics: {'mmd2_gamma_median': 0.0010820052109572487, 'mmd2_gamma_0.5': 0.004844942436565636, 'mmd2_gamma_1.0': 0.006785768545482962, 'wasserstein_distance': 0.6390393432556715, 'R2_feature_means': 0.9979419462615559}\n",
      "**************** Run: 8 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=50 f_loss=-1.4129 g_loss=-1.3289 | train mmd=0.1173 | test_mmd=0.0829\n",
      "[CellOT] epoch=100 f_loss=-1.3042 g_loss=0.0087 | train mmd=0.2079 | test_mmd=0.0446\n",
      "[CellOT] epoch=150 f_loss=-0.9382 g_loss=1.4222 | train mmd=0.1638 | test_mmd=0.0305\n",
      "[CellOT] epoch=200 f_loss=-1.4448 g_loss=2.6212 | train mmd=0.1441 | test_mmd=0.0195\n",
      "[CellOT] epoch=250 f_loss=-1.7412 g_loss=3.0482 | train mmd=0.1152 | test_mmd=0.0105\n",
      "[CellOT] epoch=300 f_loss=-0.9326 g_loss=2.7844 | train mmd=0.0891 | test_mmd=0.0058\n",
      "[CellOT] epoch=350 f_loss=-1.7178 g_loss=3.8652 | train mmd=0.0702 | test_mmd=0.0039\n",
      "[CellOT] epoch=400 f_loss=-1.3702 g_loss=4.3389 | train mmd=0.0770 | test_mmd=0.0027\n",
      "[CellOT] epoch=450 f_loss=0.5533 g_loss=4.8365 | train mmd=0.0829 | test_mmd=0.0016\n",
      "[CellOT] epoch=500 f_loss=-0.2493 g_loss=4.2586 | train mmd=0.0578 | test_mmd=0.0012\n",
      "[CellOT] epoch=550 f_loss=0.5643 g_loss=5.2537 | train mmd=0.0383 | test_mmd=0.0019\n",
      "[CellOT] epoch=600 f_loss=0.6415 g_loss=7.1521 | train mmd=0.0265 | test_mmd=0.0026\n",
      "[CellOT] epoch=650 f_loss=1.0207 g_loss=8.2924 | train mmd=0.0219 | test_mmd=0.0015\n",
      "[CellOT] epoch=700 f_loss=0.2825 g_loss=8.7733 | train mmd=0.0145 | test_mmd=0.0011\n",
      "[CellOT] epoch=750 f_loss=0.3902 g_loss=9.6852 | train mmd=0.0147 | test_mmd=0.0011\n",
      "[CellOT] epoch=800 f_loss=-0.2384 g_loss=9.6408 | train mmd=0.0111 | test_mmd=0.0023\n",
      "[CellOT] epoch=850 f_loss=0.7634 g_loss=9.7353 | train mmd=0.0076 | test_mmd=0.0040\n",
      "[CellOT] epoch=900 f_loss=0.5439 g_loss=9.1790 | train mmd=0.0046 | test_mmd=0.0032\n",
      "[CellOT] epoch=950 f_loss=-0.0460 g_loss=10.2231 | train mmd=0.0054 | test_mmd=0.0015\n",
      "[CellOT] epoch=1000 f_loss=0.8075 g_loss=10.2502 | train mmd=0.0059 | test_mmd=0.0027\n",
      "[CellOT] epoch=1050 f_loss=-0.8318 g_loss=10.2246 | train mmd=0.0028 | test_mmd=0.0015\n",
      "[CellOT] epoch=1100 f_loss=0.2824 g_loss=9.8683 | train mmd=0.0038 | test_mmd=0.0020\n",
      "[CellOT] epoch=1150 f_loss=0.1648 g_loss=10.1049 | train mmd=0.0035 | test_mmd=0.0021\n",
      "[CellOT] epoch=1200 f_loss=-0.2743 g_loss=10.0564 | train mmd=0.0029 | test_mmd=0.0018\n",
      "[CellOT] epoch=1250 f_loss=-0.1585 g_loss=10.3025 | train mmd=0.0038 | test_mmd=0.0017\n",
      "[CellOT] epoch=1300 f_loss=-0.0991 g_loss=9.9479 | train mmd=0.0024 | test_mmd=0.0014\n",
      "[CellOT] epoch=1350 f_loss=0.0565 g_loss=10.2072 | train mmd=0.0038 | test_mmd=0.0033\n",
      "[CellOT] epoch=1400 f_loss=0.8395 g_loss=10.5756 | train mmd=0.0034 | test_mmd=0.0025\n",
      "[CellOT] epoch=1450 f_loss=-0.7370 g_loss=10.6675 | train mmd=0.0026 | test_mmd=0.0027\n",
      "[CellOT] epoch=1500 f_loss=-0.3975 g_loss=10.5963 | train mmd=0.0028 | test_mmd=0.0008\n",
      "[CellOT] epoch=1550 f_loss=0.9482 g_loss=10.1944 | train mmd=0.0037 | test_mmd=0.0040\n",
      "[CellOT] epoch=1600 f_loss=0.1503 g_loss=10.0949 | train mmd=0.0026 | test_mmd=0.0023\n",
      "[CellOT] epoch=1650 f_loss=0.3298 g_loss=10.6262 | train mmd=0.0034 | test_mmd=0.0016\n",
      "[CellOT] epoch=1700 f_loss=-0.1498 g_loss=10.7149 | train mmd=0.0048 | test_mmd=0.0023\n",
      "[CellOT] epoch=1750 f_loss=-0.0711 g_loss=10.7433 | train mmd=0.0030 | test_mmd=0.0011\n",
      "[CellOT] epoch=1800 f_loss=0.1516 g_loss=10.7025 | train mmd=0.0016 | test_mmd=0.0022\n",
      "[CellOT] epoch=1850 f_loss=0.1404 g_loss=10.5138 | train mmd=0.0020 | test_mmd=0.0048\n",
      "[CellOT] epoch=1900 f_loss=-0.8559 g_loss=10.6711 | train mmd=0.0024 | test_mmd=0.0012\n",
      "[CellOT] epoch=1950 f_loss=-0.5569 g_loss=10.9757 | train mmd=0.0021 | test_mmd=0.0012\n",
      "[CellOT] epoch=2000 f_loss=0.0546 g_loss=10.6780 | train mmd=0.0017 | test_mmd=0.0013\n",
      "[CellOT] Final CellOT MMD: 0.0025\n",
      "VERS torch=1.13.1+cu117 (CellOT), device=cuda\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 8 metrics: {'mmd2_gamma_median': 0.0012576819660510274, 'mmd2_gamma_0.5': 0.00492906910710611, 'mmd2_gamma_1.0': 0.00659249380527982, 'wasserstein_distance': 0.6438661963048388, 'R2_feature_means': 0.9979908670946189}\n",
      "**************** Run: 9 ****************\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[CellOT] epoch=0 f_loss=-0.5349 g_loss=-4.7464 | train mmd=0.2847 | test_mmd=0.9832\n",
      "[CellOT] epoch=50 f_loss=-1.4623 g_loss=-1.7422 | train mmd=0.1334 | test_mmd=0.1042\n",
      "[CellOT] epoch=100 f_loss=-0.7991 g_loss=0.3912 | train mmd=0.0897 | test_mmd=0.0223\n",
      "[CellOT] epoch=150 f_loss=-0.7542 g_loss=1.0103 | train mmd=0.1015 | test_mmd=0.0152\n",
      "[CellOT] epoch=200 f_loss=-0.6188 g_loss=1.3966 | train mmd=0.1034 | test_mmd=0.0105\n",
      "[CellOT] epoch=250 f_loss=-0.8771 g_loss=2.6092 | train mmd=0.0765 | test_mmd=0.0057\n",
      "[CellOT] epoch=300 f_loss=-0.4810 g_loss=2.4090 | train mmd=0.0541 | test_mmd=0.0033\n",
      "[CellOT] epoch=350 f_loss=-0.1741 g_loss=2.2137 | train mmd=0.0481 | test_mmd=0.0015\n",
      "[CellOT] epoch=400 f_loss=-0.1135 g_loss=3.1462 | train mmd=0.0295 | test_mmd=0.0016\n",
      "[CellOT] epoch=450 f_loss=0.5215 g_loss=3.2123 | train mmd=0.0294 | test_mmd=0.0011\n",
      "[CellOT] epoch=500 f_loss=0.8677 g_loss=3.1728 | train mmd=0.0263 | test_mmd=0.0025\n",
      "[CellOT] epoch=550 f_loss=-0.5122 g_loss=4.2341 | train mmd=0.0069 | test_mmd=0.0010\n",
      "[CellOT] epoch=600 f_loss=-0.1126 g_loss=5.7079 | train mmd=0.0076 | test_mmd=0.0018\n",
      "[CellOT] epoch=650 f_loss=0.1914 g_loss=5.7711 | train mmd=0.0063 | test_mmd=0.0017\n",
      "[CellOT] epoch=700 f_loss=0.7566 g_loss=6.1448 | train mmd=0.0049 | test_mmd=0.0033\n",
      "[CellOT] epoch=750 f_loss=0.2150 g_loss=5.7565 | train mmd=0.0060 | test_mmd=0.0023\n",
      "[CellOT] epoch=800 f_loss=0.3622 g_loss=6.0462 | train mmd=0.0045 | test_mmd=0.0012\n",
      "[CellOT] epoch=850 f_loss=0.0416 g_loss=6.1418 | train mmd=0.0045 | test_mmd=0.0020\n",
      "[CellOT] epoch=900 f_loss=-0.9441 g_loss=7.0804 | train mmd=0.0042 | test_mmd=0.0016\n",
      "[CellOT] epoch=950 f_loss=0.0558 g_loss=6.6675 | train mmd=0.0034 | test_mmd=0.0013\n",
      "[CellOT] epoch=1000 f_loss=-0.5263 g_loss=6.6731 | train mmd=0.0041 | test_mmd=0.0012\n",
      "[CellOT] epoch=1050 f_loss=-0.0191 g_loss=6.7995 | train mmd=0.0029 | test_mmd=0.0011\n",
      "[CellOT] epoch=1100 f_loss=0.1748 g_loss=7.0882 | train mmd=0.0036 | test_mmd=0.0008\n",
      "[CellOT] epoch=1150 f_loss=-0.5772 g_loss=6.9841 | train mmd=0.0031 | test_mmd=0.0021\n",
      "[CellOT] epoch=1200 f_loss=0.1224 g_loss=7.0336 | train mmd=0.0029 | test_mmd=0.0021\n",
      "[CellOT] epoch=1250 f_loss=-0.2673 g_loss=7.4395 | train mmd=0.0027 | test_mmd=0.0016\n",
      "[CellOT] epoch=1300 f_loss=-0.2876 g_loss=7.7285 | train mmd=0.0023 | test_mmd=0.0011\n",
      "[CellOT] epoch=1350 f_loss=-0.3328 g_loss=7.9536 | train mmd=0.0032 | test_mmd=0.0016\n",
      "[CellOT] epoch=1400 f_loss=-0.0855 g_loss=7.7202 | train mmd=0.0025 | test_mmd=0.0017\n",
      "[CellOT] epoch=1450 f_loss=0.1619 g_loss=7.0096 | train mmd=0.0030 | test_mmd=0.0027\n",
      "[CellOT] epoch=1500 f_loss=0.3617 g_loss=7.4779 | train mmd=0.0028 | test_mmd=0.0020\n",
      "[CellOT] epoch=1550 f_loss=-0.4646 g_loss=7.9670 | train mmd=0.0023 | test_mmd=0.0021\n",
      "[CellOT] epoch=1600 f_loss=-0.4398 g_loss=8.2482 | train mmd=0.0020 | test_mmd=0.0021\n",
      "[CellOT] epoch=1650 f_loss=0.2044 g_loss=7.4889 | train mmd=0.0020 | test_mmd=0.0012\n",
      "[CellOT] epoch=1700 f_loss=-0.3590 g_loss=7.9764 | train mmd=0.0018 | test_mmd=0.0011\n",
      "[CellOT] epoch=1750 f_loss=-0.4874 g_loss=7.6600 | train mmd=0.0015 | test_mmd=0.0011\n",
      "[CellOT] epoch=1800 f_loss=-1.0007 g_loss=8.0776 | train mmd=0.0020 | test_mmd=0.0019\n",
      "[CellOT] epoch=1850 f_loss=0.1680 g_loss=7.9873 | train mmd=0.0022 | test_mmd=0.0015\n",
      "[CellOT] epoch=1900 f_loss=0.2864 g_loss=8.1106 | train mmd=0.0016 | test_mmd=0.0016\n",
      "[CellOT] epoch=1950 f_loss=0.0343 g_loss=8.1125 | train mmd=0.0013 | test_mmd=0.0017\n",
      "[CellOT] epoch=2000 f_loss=0.4062 g_loss=8.2389 | train mmd=0.0016 | test_mmd=0.0021\n",
      "[CellOT] Final CellOT MMD: 0.0029\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/umap/umap_.py:1952: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.\n",
      "  warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Run 9 metrics: {'mmd2_gamma_median': 0.002116835808132045, 'mmd2_gamma_0.5': 0.006504946011012813, 'mmd2_gamma_1.0': 0.0076680169659653075, 'wasserstein_distance': 0.6421072001703861, 'R2_feature_means': 0.9982041719924265}\n",
      "                        mean     std\n",
      "mmd2_gamma_median     0.0016  0.0004\n",
      "mmd2_gamma_0.5        0.0057  0.0011\n",
      "mmd2_gamma_1.0        0.0073  0.0012\n",
      "wasserstein_distance  0.6417  0.0069\n",
      "R2_feature_means      0.9980  0.0007\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n",
      "/u/jrp5td/here/miniconda3/envs/scgen-env/lib/python3.9/site-packages/sklearn/utils/deprecation.py:151: FutureWarning: 'force_all_finite' was renamed to 'ensure_all_finite' in 1.6 and will be removed in 1.8.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "# CellOT on SKMEL19: DMSO (24h) -> drug (72h). Trains N_RUNS seeded models,\n",
    "# summarises test-set metrics, and plots a UMAP of the last run's predictions.\n",
    "import os\n",
    "import random\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from umap import UMAP\n",
    "\n",
    "# ---- Configuration ----\n",
    "drug = \"Vem\"\n",
    "# Column indices of the feature subset passed to CellOT and to the metrics.\n",
    "jfe_indices = [1, 6, 0, 5, 4, 7, 8, 2, 3, 19]\n",
    "TRAIN_FRACTION = 0.8\n",
    "N_RUNS = 10\n",
    "BASE_SEED = 1234\n",
    "N_EPOCHS = 2000\n",
    "PLOT_DIR = \"./plots\"\n",
    "\n",
    "\n",
    "def _seed_all_rngs(seed):\n",
    "    \"\"\"Seed the PyTorch (CPU + CUDA), Python and NumPy RNGs and set the cuDNN determinism flags.\"\"\"\n",
    "    torch.manual_seed(seed)\n",
    "    torch.cuda.manual_seed_all(seed)\n",
    "    random.seed(seed)\n",
    "    np.random.seed(seed)\n",
    "    torch.backends.cudnn.deterministic = True\n",
    "    torch.backends.cudnn.benchmark = False\n",
    "\n",
    "\n",
    "# ---- Data ----\n",
    "X_pre, X_post = prepare_pair_from_mat('SKMEL19', 'DMSO','24h', drug, '72h')\n",
    "\n",
    "print(\"X_pre cells:\", X_pre.shape)\n",
    "print(\"X_post cells:\", X_post.shape)\n",
    "\n",
    "X_tr_pre, X_te_pre, Y_tr_post, Y_te_post = split_train_test(X_pre, X_post, TRAIN_FRACTION)\n",
    "\n",
    "print(X_tr_pre.shape)\n",
    "print(X_te_pre.shape)\n",
    "print(Y_tr_post.shape)\n",
    "print(Y_te_post.shape)\n",
    "\n",
    "# Compute median heuristic gamma on training data.\n",
    "# NOTE(review): gamma is estimated on ALL feature columns, while the metrics\n",
    "# below are computed on the `jfe_indices` columns only -- confirm this is intended.\n",
    "median_gamma = median_heuristic_gamma(X_tr_pre, Y_tr_post)\n",
    "print(\"Median heuristic gamma:\", median_gamma)\n",
    "\n",
    "# Slice the feature subset once instead of re-indexing on every run.\n",
    "X_tr_sel = X_tr_pre[:, jfe_indices]\n",
    "Y_tr_sel = Y_tr_post[:, jfe_indices]\n",
    "X_te_sel = X_te_pre[:, jfe_indices]\n",
    "Y_te_sel = Y_te_post[:, jfe_indices]\n",
    "\n",
    "# ---- Repeated training / evaluation ----\n",
    "all_metrics = []\n",
    "for run in range(N_RUNS):\n",
    "    print(f\"**************** Run: {run} ****************\")\n",
    "    _seed_all_rngs(BASE_SEED + run)\n",
    "\n",
    "    out = run_cellot_pair(X_tr_sel, Y_tr_sel, X_te_sel, Y_te_sel, n_epochs=N_EPOCHS)\n",
    "    metrics = summarize_metrics(out[\"y_pred\"], Y_te_sel, median_gamma)\n",
    "    print(f\"Run {run} metrics: {metrics}\")\n",
    "    all_metrics.append(metrics)\n",
    "\n",
    "# Results summary: mean and std of every metric across runs.\n",
    "metrics_df = pd.DataFrame(all_metrics)\n",
    "print(metrics_df.describe().T[['mean', 'std']].round(4))\n",
    "\n",
    "# ---- UMAP of the final run ----\n",
    "# `out` still refers to the last loop iteration, so only that run is plotted.\n",
    "source = Y_tr_sel\n",
    "target = Y_te_sel\n",
    "predicted = out[\"y_pred\"]  # indexed like in the loop so a missing key fails loudly\n",
    "\n",
    "# Fit the embedding on the stacked train and test `post` matrices, then\n",
    "# project the test cells and the predictions into that same embedding.\n",
    "umap_model = UMAP(n_components=2, random_state=42)\n",
    "umap_model.fit(np.vstack([source, target]))\n",
    "target_umap = umap_model.transform(target)\n",
    "y_pred_umap = umap_model.transform(predicted)\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(4, 4))\n",
    "ax.scatter(target_umap[:, 0], target_umap[:, 1], s=10, alpha=0.7, label='observed treated cells', color=\"#C88131\")\n",
    "ax.scatter(y_pred_umap[:, 0], y_pred_umap[:, 1], s=10, alpha=0.7, label='predicted cells', color=\"#1F4D8D\")\n",
    "\n",
    "ax.set_title(f'{drug}')\n",
    "ax.set_aspect('equal', 'box')\n",
    "ax.legend()\n",
    "plt.tight_layout()\n",
    "\n",
    "# Create the output directory first so a missing folder cannot discard the run.\n",
    "os.makedirs(PLOT_DIR, exist_ok=True)\n",
    "plt.savefig(f\"{PLOT_DIR}/cellot_on_4i_drug_{drug}.png\", dpi=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc05c34",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6856ea1b",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2061ee67",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "711c8633",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3847dd9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "258f2758",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d698f206",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
