{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Temporary notebook for running the k = 15-19 scRNA experiments with the new learned A (`trained_models/scrna_learned_params_v3.pt`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys, os\n",
    "\n",
    "# Make the project's `code/` directory importable. Only append it once so\n",
    "# re-running this cell does not keep adding duplicate sys.path entries.\n",
    "code_dir = os.path.join(sys.path[0], \"code\")\n",
    "if code_dir not in sys.path:\n",
    "    sys.path.append(code_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<module 'sinkhorn_cnf' from '/mnt/hdd/scarv/riemannian-metric-learning-ot/code/sinkhorn_cnf.py'>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Project modules (resolved through the `code/` entry on sys.path)\n",
    "import pytorch_models as models\n",
    "import pytorch_samplers as samplers\n",
    "import pytorch_losses as losses\n",
    "import pytorch_training as training\n",
    "import pytorch_utils as utils\n",
    "import sinkhorn_cnf as cnf\n",
    "import scrna_exper as scrna\n",
    "from geomloss import SamplesLoss\n",
    "from torchdiffeq import odeint\n",
    "\n",
    "import importlib\n",
    "import time\n",
    "import json\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm, colors\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "# Re-execute the project modules so edits to their .py files take effect\n",
    "# without a kernel restart. Reload order: models, losses, samplers, training,\n",
    "# utils, scrna, then cnf last so its module repr is the cell's displayed result.\n",
    "for _project_module in (models, losses, samplers, training, utils, scrna):\n",
    "    importlib.reload(_project_module)\n",
    "importlib.reload(cnf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create all new floating-point tensors in double precision.\n",
    "torch.set_default_dtype(torch.float64)\n",
    "\n",
    "# Preferred GPU index. Clamp it to the GPUs actually present so a single-GPU\n",
    "# host uses cuda:0 instead of failing later on a non-existent cuda:1; use the\n",
    "# CPU when CUDA is unavailable.\n",
    "GPU_INDEX = 1\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(f\"cuda:{min(GPU_INDEX, torch.cuda.device_count() - 1)}\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import WOT data\n",
    "\n",
    "data_file = \"data/scrna/schiebinger.npz\"\n",
    "# NOTE: allow_pickle=True can execute code embedded in the file; only load trusted data.\n",
    "data_dict = np.load(data_file, allow_pickle=True)\n",
    "\n",
    "# Rescale the data by a factor of 1e-3. Scale out-of-place (not with `*=`) so\n",
    "# the result does not depend on whether data_dict[...] returns a fresh array.\n",
    "all_data = data_dict[\"original_embedding\"] * 1e-3\n",
    "\n",
    "timestamps = data_dict[\"sample_labels\"]\n",
    "times = np.unique(timestamps)\n",
    "\n",
    "# Seeded generator so the random subsample below is reproducible across runs.\n",
    "SEED = 0\n",
    "N_SUBSAMPLE = 500\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "subsampled_data_list = []\n",
    "data_list = []\n",
    "# For every time point keep all observations (row-shuffled) in data_list and a\n",
    "# random subset of N_SUBSAMPLE observations in subsampled_data_list.\n",
    "for t in times:\n",
    "    data_t = all_data[timestamps == t]  # boolean mask -> new array (copy)\n",
    "    rng.shuffle(data_t)  # shuffle rows of data_t in place\n",
    "    data_list.append(data_t)\n",
    "    subsampled_data_list.append(data_t[:N_SUBSAMPLE])  # keep only the first N_SUBSAMPLE rows"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Load pretrained model\n",
    "\n",
    "space_dims = 2\n",
    "matrix_hidden_dims = 2048\n",
    "A_fname = \"trained_models/scrna_learned_params_v3.pt\" # was 'scrna_pretrained_params.pt'\n",
    "scrna_learnedA = models.PSDMatrix(space_dims, matrix_hidden_dims).to(device)\n",
    "# map_location remaps tensors saved on another device (e.g. a checkpoint written\n",
    "# from cuda:1) onto the current `device`, so loading also works on CPU-only hosts.\n",
    "scrna_learnedA.load_state_dict(torch.load(A_fname, map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate list of samplers for all time points.\n",
    "# Built from the full per-time-point arrays in `data_list`, not from the\n",
    "# 500-row `subsampled_data_list`.\n",
    "\n",
    "rho_list = samplers.generate_all_eb_samplers(data_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 0 loss: tensor(102.6561, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 100 loss: tensor(70.7911, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 200 loss: tensor(67.1092, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 300 loss: tensor(61.8781, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 400 loss: tensor(61.3156, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 500 loss: tensor(58.5950, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 600 loss: tensor(48.6366, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 700 loss: tensor(2.1899, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 800 loss: tensor(2.1129, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 900 loss: tensor(2.0990, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1000 loss: tensor(2.1875, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1100 loss: tensor(1.6906, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1200 loss: tensor(1.9296, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1300 loss: tensor(1.9720, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1400 loss: tensor(1.8340, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1500 loss: tensor(2.0182, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1600 loss: tensor(2.1870, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1700 loss: tensor(2.1062, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1800 loss: tensor(1.9725, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1900 loss: tensor(2.1687, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2000 loss: tensor(2.1145, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2100 loss: tensor(1.9566, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2200 loss: tensor(2.0611, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2300 loss: tensor(1.9436, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2400 loss: tensor(1.8600, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2500 loss: tensor(1.8929, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2600 loss: tensor(1.7056, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2700 loss: tensor(1.8383, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2800 loss: tensor(1.9378, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2900 loss: tensor(1.6029, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3000 loss: tensor(1.5508, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3100 loss: tensor(1.2077, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3200 loss: tensor(0.6697, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3300 loss: tensor(0.4747, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3400 loss: tensor(0.4798, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3500 loss: tensor(0.3145, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3600 loss: tensor(0.3315, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3700 loss: tensor(0.3441, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3800 loss: tensor(0.3091, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3900 loss: tensor(0.3107, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4000 loss: tensor(0.6377, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4100 loss: tensor(0.2884, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4200 loss: tensor(0.7719, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4300 loss: tensor(0.2451, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4400 loss: tensor(0.3446, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4500 loss: tensor(0.4316, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4600 loss: tensor(0.2039, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4700 loss: tensor(0.2581, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4800 loss: tensor(0.2537, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4900 loss: tensor(0.2917, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5000 loss: tensor(0.3122, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5100 loss: tensor(0.2130, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5200 loss: tensor(0.2974, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5300 loss: tensor(0.4166, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5400 loss: tensor(0.4740, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5500 loss: tensor(0.4850, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5600 loss: tensor(0.1813, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5700 loss: tensor(0.3428, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5800 loss: tensor(0.2917, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5900 loss: tensor(0.3554, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6000 loss: tensor(0.2651, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6100 loss: tensor(0.3035, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6200 loss: tensor(0.5310, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6300 loss: tensor(0.2482, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6400 loss: tensor(0.2892, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6500 loss: tensor(0.2760, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6600 loss: tensor(0.3284, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6700 loss: tensor(0.2562, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6800 loss: tensor(0.2645, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6900 loss: tensor(0.1299, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7000 loss: tensor(0.3043, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7100 loss: tensor(0.2075, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7200 loss: tensor(0.1979, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7300 loss: tensor(0.2988, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7400 loss: tensor(0.2135, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7500 loss: tensor(0.3360, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7600 loss: tensor(0.3257, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7700 loss: tensor(0.1475, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7800 loss: tensor(0.1243, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7900 loss: tensor(0.3347, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8000 loss: tensor(0.1607, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8100 loss: tensor(0.2638, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8200 loss: tensor(0.3579, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8300 loss: tensor(0.2548, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8400 loss: tensor(0.3667, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8500 loss: tensor(0.1475, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8600 loss: tensor(0.1211, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8700 loss: tensor(0.4776, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8800 loss: tensor(0.2091, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8900 loss: tensor(0.4904, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9000 loss: tensor(0.2477, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9100 loss: tensor(0.1247, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9200 loss: tensor(0.5713, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9300 loss: tensor(0.2338, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9400 loss: tensor(0.2636, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9500 loss: tensor(0.2100, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9600 loss: tensor(0.2835, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9700 loss: tensor(0.3558, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9800 loss: tensor(0.1858, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9900 loss: tensor(0.2224, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "{'tp 1': 1.1473785076664536, 'tp 2': 3.8730091913058953, 'tp 3': 4.149327863039666, 'tp 4': 5.0327894078875115, 'tp 5': 5.814365434886701, 'tp 6': 5.925518582256864, 'tp 7': 5.500217709254779, 'tp 8': 5.460187788168083, 'tp 9': 5.005898518894156, 'tp 10': 4.692356356077964, 'tp 11': 4.074981424620035, 'tp 12': 3.804852158812894, 'tp 13': 3.2183184186712896, 'tp 14': 2.2969485896511106, 'tp 15': 1.6968415548476745, 'tp 16': 2.126467417107925, 'tp 17': 1.6041756711649167, 'tp 18': 1.0701169573600964}\n",
      "iter 0 loss: tensor(60.1610, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 100 loss: tensor(4.0209, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 200 loss: tensor(2.7100, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 300 loss: tensor(2.1100, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 400 loss: tensor(1.7491, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 500 loss: tensor(1.6728, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 600 loss: tensor(1.3637, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 700 loss: tensor(1.4829, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 800 loss: tensor(1.4828, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 900 loss: tensor(1.3142, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1000 loss: tensor(1.5358, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1100 loss: tensor(1.4238, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1200 loss: tensor(1.0352, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1300 loss: tensor(1.2509, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1400 loss: tensor(1.4357, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1500 loss: tensor(1.4231, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1600 loss: tensor(1.2089, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1700 loss: tensor(1.0550, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1800 loss: tensor(1.4125, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1900 loss: tensor(1.2065, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2000 loss: tensor(1.2768, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2100 loss: tensor(1.1336, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2200 loss: tensor(1.1002, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2300 loss: tensor(1.2515, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2400 loss: tensor(1.2896, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2500 loss: tensor(1.2880, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2600 loss: tensor(1.3170, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2700 loss: tensor(1.7545, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2800 loss: tensor(1.1454, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2900 loss: tensor(1.2456, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3000 loss: tensor(1.0589, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3100 loss: tensor(1.0044, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3200 loss: tensor(1.1058, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3300 loss: tensor(1.2736, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3400 loss: tensor(1.2902, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3500 loss: tensor(1.2169, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3600 loss: tensor(1.0813, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3700 loss: tensor(1.4332, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3800 loss: tensor(1.1058, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3900 loss: tensor(1.2206, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4000 loss: tensor(1.0084, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4100 loss: tensor(1.1223, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4200 loss: tensor(1.1578, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4300 loss: tensor(1.0622, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4400 loss: tensor(1.0028, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4500 loss: tensor(1.2200, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4600 loss: tensor(1.0633, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4700 loss: tensor(1.3877, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4800 loss: tensor(1.0693, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4900 loss: tensor(1.2117, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5000 loss: tensor(1.0148, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5100 loss: tensor(1.1795, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5200 loss: tensor(1.5944, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5300 loss: tensor(1.3063, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5400 loss: tensor(1.3838, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5500 loss: tensor(1.1774, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5600 loss: tensor(1.0402, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5700 loss: tensor(1.0856, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5800 loss: tensor(1.1025, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5900 loss: tensor(1.1286, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6000 loss: tensor(1.0997, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6100 loss: tensor(1.0343, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6200 loss: tensor(0.9633, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6300 loss: tensor(1.0580, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6400 loss: tensor(1.4022, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6500 loss: tensor(0.9617, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6600 loss: tensor(1.0981, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6700 loss: tensor(0.9382, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6800 loss: tensor(0.8652, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6900 loss: tensor(0.8481, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7000 loss: tensor(1.1902, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7100 loss: tensor(1.4436, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7200 loss: tensor(1.0099, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7300 loss: tensor(0.9826, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7400 loss: tensor(1.2412, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7500 loss: tensor(0.9258, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7600 loss: tensor(0.8140, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7700 loss: tensor(0.8467, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7800 loss: tensor(0.8116, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7900 loss: tensor(0.7547, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8000 loss: tensor(0.8056, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8100 loss: tensor(0.7422, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8200 loss: tensor(0.7144, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8300 loss: tensor(1.0162, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8400 loss: tensor(0.8424, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8500 loss: tensor(0.6720, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8600 loss: tensor(0.8094, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8700 loss: tensor(0.8983, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8800 loss: tensor(0.6925, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 8900 loss: tensor(0.6372, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9000 loss: tensor(0.7306, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9100 loss: tensor(1.3399, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9200 loss: tensor(0.8016, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9300 loss: tensor(0.6439, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9400 loss: tensor(0.6491, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9500 loss: tensor(0.7072, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9600 loss: tensor(0.7542, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9700 loss: tensor(0.6977, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9800 loss: tensor(0.6192, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 9900 loss: tensor(0.8467, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "{'tp 1': 1.1473785076664536, 'tp 2': 3.8730091913058953, 'tp 3': 4.149327863039666, 'tp 4': 5.0327894078875115, 'tp 5': 5.814365434886701, 'tp 6': 5.925518582256864, 'tp 7': 5.500217709254779, 'tp 8': 5.460187788168083, 'tp 9': 5.005898518894156, 'tp 10': 4.692356356077964, 'tp 11': 4.074981424620035, 'tp 12': 3.804852158812894, 'tp 13': 3.2183184186712896, 'tp 14': 2.2969485896511106, 'tp 15': 1.6968415548476745, 'tp 16': 2.126467417107925, 'tp 17': 1.6041756711649167, 'tp 18': 1.0701169573600964, 'tp 20': 0.5947019018691064, 'tp 21': 1.3710584782408946, 'tp 22': 2.0998348558145215, 'tp 23': 2.768578650490901, 'tp 24': 2.987297618407023, 'tp 25': 3.1107736159069264, 'tp 26': 2.783688890411322, 'tp 27': 2.6405068272444314, 'tp 28': 2.477821226629289, 'tp 29': 2.754321438634957, 'tp 30': 3.064175885455234, 'tp 31': 3.3413190327386033, 'tp 32': 2.823991169449449, 'tp 33': 3.4091618397346743, 'tp 34': 3.342275468446287, 'tp 35': 2.300712323326003, 'tp 36': 1.5291749425663554, 'tp 37': 0.9922174034782485}\n",
      "iter 0 loss: tensor(0.2777, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 100 loss: tensor(0.0918, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 200 loss: tensor(0.0805, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 300 loss: tensor(0.2332, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 400 loss: tensor(0.1759, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 500 loss: tensor(0.0900, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 600 loss: tensor(0.1049, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 700 loss: tensor(0.1385, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 800 loss: tensor(0.1146, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 900 loss: tensor(0.0724, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1000 loss: tensor(0.0678, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1100 loss: tensor(0.1389, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1200 loss: tensor(0.0949, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1300 loss: tensor(0.0614, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1400 loss: tensor(0.1117, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1500 loss: tensor(0.2536, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1600 loss: tensor(0.1002, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1700 loss: tensor(0.1468, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1800 loss: tensor(0.0961, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 1900 loss: tensor(0.1335, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2000 loss: tensor(0.1114, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2100 loss: tensor(0.2172, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2200 loss: tensor(0.2059, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2300 loss: tensor(0.2721, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2400 loss: tensor(0.0912, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2500 loss: tensor(0.1300, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2600 loss: tensor(0.0551, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2700 loss: tensor(0.0795, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2800 loss: tensor(0.1549, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 2900 loss: tensor(0.0487, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3000 loss: tensor(0.0743, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3100 loss: tensor(0.0589, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3200 loss: tensor(0.0606, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3300 loss: tensor(0.0878, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3400 loss: tensor(0.0664, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3500 loss: tensor(0.1045, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3600 loss: tensor(0.1122, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3700 loss: tensor(0.2248, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3800 loss: tensor(0.1488, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 3900 loss: tensor(0.0312, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4000 loss: tensor(0.1458, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4100 loss: tensor(0.0678, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4200 loss: tensor(0.0515, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4300 loss: tensor(0.0780, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4400 loss: tensor(0.0813, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4500 loss: tensor(0.0371, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4600 loss: tensor(0.1441, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4700 loss: tensor(0.1203, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4800 loss: tensor(0.0614, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 4900 loss: tensor(0.1931, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5000 loss: tensor(0.0911, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5100 loss: tensor(0.0482, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5200 loss: tensor(0.0699, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5300 loss: tensor(0.0612, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5400 loss: tensor(0.0671, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5500 loss: tensor(0.1670, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5600 loss: tensor(0.0705, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5700 loss: tensor(0.0529, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5800 loss: tensor(0.3366, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 5900 loss: tensor(0.0471, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6000 loss: tensor(0.0626, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6100 loss: tensor(0.0967, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6200 loss: tensor(0.1635, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6300 loss: tensor(0.4592, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6400 loss: tensor(0.0777, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6500 loss: tensor(0.0499, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6600 loss: tensor(0.0537, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6700 loss: tensor(0.1044, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6800 loss: tensor(0.1678, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 6900 loss: tensor(0.1556, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7000 loss: tensor(0.0737, device='cuda:1', grad_fn=<AddBackward0>)\n",
      "iter 7100 loss: tensor(0.0434, device='cuda:1', grad_fn=<AddBackward0>)\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_35641/930926338.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0mW1_vals\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscrna\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrho_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjson_fname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_A\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      9\u001b[0m     \u001b[0mjson_fname_final_tps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"results/scrna_experiments/k\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"_withAv2_lambd_1em1_final_tps.json\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m     \u001b[0mW1_vals_final_tps\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mscrna\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrun_final_tp_experiment\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrho_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mjson_fname_final_tps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mk\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mA_model\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0muse_A\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
      "\u001b[0;32m/mnt/hdd/scarv/riemannian-metric-learning-ot/code/scrna_exper.py\u001b[0m in \u001b[0;36mrun_final_tp_experiment\u001b[0;34m(rho_list, json_fname, k, A_model, use_A)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0mbase_sampler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mold_targets\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m     \u001b[0mtarget_sampler_list\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mrho_list\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m     advected_samples, trained_v, losses = cnf.sinkhorn_cnf(base_sampler, # should have a .sample method\n\u001b[0m\u001b[1;32m    118\u001b[0m              \u001b[0mtarget_sampler_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# each element should have a .sample method\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m              \u001b[0mn_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;31m# number of samples drawn from each sampler in each epoch of training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/hdd/scarv/riemannian-metric-learning-ot/code/sinkhorn_cnf.py\u001b[0m in \u001b[0;36msinkhorn_cnf\u001b[0;34m(base_sampler, target_sampler_list, n_samples, step_size, space_dims, lambd, lr, weight_decay, n_epochs, A, time_varying, hidden_dims, v_model)\u001b[0m\n\u001b[1;32m    148\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    149\u001b[0m         \u001b[0;31m# Advect base_samples through velocity field v\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 150\u001b[0;31m         \u001b[0mx_t\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0modeint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbase_samples\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimes\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmethod\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# advected samples\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    151\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    152\u001b[0m         \u001b[0;31m# Compute kinetic energy loss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torchdiffeq/_impl/odeint.py\u001b[0m in \u001b[0;36modeint\u001b[0;34m(func, y0, t, rtol, atol, method, options, event_fn)\u001b[0m\n\u001b[1;32m     75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mevent_fn\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m         \u001b[0mevent_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msolution\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msolver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mintegrate_until_event\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mevent_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torchdiffeq/_impl/solvers.py\u001b[0m in \u001b[0;36mintegrate\u001b[0;34m(self, t)\u001b[0m\n\u001b[1;32m    103\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt1\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtime_grid\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    104\u001b[0m             \u001b[0mdt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mt1\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 105\u001b[0;31m             \u001b[0mdy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_step_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    106\u001b[0m             \u001b[0my1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mdy\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torchdiffeq/_impl/fixed_grid.py\u001b[0m in \u001b[0;36m_step_func\u001b[0;34m(self, func, t0, dt, t1, y0)\u001b[0m\n\u001b[1;32m     19\u001b[0m         \u001b[0mf0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mperturb\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mPerturb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNEXT\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mperturb\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mPerturb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mNONE\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m         \u001b[0my_mid\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0my0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mf0\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mhalf_dt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mdt\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt0\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mhalf_dt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_mid\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torchdiffeq/_impl/misc.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, y, perturb)\u001b[0m\n\u001b[1;32m    187\u001b[0m             \u001b[0;31m# Do nothing.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    188\u001b[0m             \u001b[0;32mpass\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 189\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbase_func\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    190\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    191\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/mnt/hdd/scarv/riemannian-metric-learning-ot/code/sinkhorn_cnf.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, t, x)\u001b[0m\n\u001b[1;32m     56\u001b[0m                     \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munsqueeze\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     57\u001b[0m                     \u001b[0mhidden\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhidden_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mout_layer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     59\u001b[0m             \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# forward pass for kinetic energy computation\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     60\u001b[0m                 \u001b[0mB\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m/\u001b[0m\u001b[0mt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# assumes that x_t has already been reshaped to (T*B, 
space_dims)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m   1100\u001b[0m         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m   1101\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1103\u001b[0m         \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1104\u001b[0m         \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m~/anaconda3/envs/labb/lib/python3.9/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mlinear\u001b[0;34m(input, weight, bias)\u001b[0m\n\u001b[1;32m   1846\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0mhas_torch_function_variadic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1847\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mhandle_torch_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1848\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_C\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1849\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1850\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Run the scRNA experiments with the learned metric A for each k in k_vals.\n",
    "# Returned W1 values are also collected in dicts keyed by k, so looping over\n",
    "# several k does not discard the results of earlier iterations.\n",
    "\n",
    "k_vals = [19]  # the notebook targets k=15-19; only k=19 is listed here\n",
    "A_model = scrna_learnedA\n",
    "use_A = True\n",
    "results_prefix = \"results/scrna_experiments/k\"\n",
    "run_tag = \"_withAv2_lambd_1em1\"\n",
    "\n",
    "W1_by_k = {}\n",
    "W1_final_tps_by_k = {}\n",
    "for k in k_vals:\n",
    "    # json_fname is passed through to scrna; presumably the output path for the run (confirm in scrna_exper.py)\n",
    "    json_fname = f\"{results_prefix}{k}{run_tag}.json\"\n",
    "    W1_vals = scrna.run_experiment(rho_list, json_fname, k, A_model, use_A)\n",
    "    W1_by_k[k] = W1_vals\n",
    "\n",
    "    json_fname_final_tps = f\"{results_prefix}{k}{run_tag}_final_tps.json\"\n",
    "    W1_vals_final_tps = scrna.run_final_tp_experiment(rho_list, json_fname_final_tps, k, A_model, use_A)\n",
    "    W1_final_tps_by_k[k] = W1_vals_final_tps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:labb]",
   "language": "python",
   "name": "conda-env-labb-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
