{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4ce6f328",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.append(\"../src\")\n",
    "import feature.scrna_dataset as scrna_dataset\n",
    "import model.sdes as sdes\n",
    "import model.generate as generate\n",
    "import model.scrna_ae as scrna_ae\n",
    "import model.util as model_util\n",
    "import analysis.fid as fid\n",
    "import torch\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import h5py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df22d86f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define device\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = \"cuda\"\n",
    "else:\n",
    "    DEVICE = \"cpu\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dfe528ba",
   "metadata": {},
   "source": [
    "### Define the branches and create the data loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b3268176",
   "metadata": {},
   "outputs": [],
   "source": [
    "latent_space = False\n",
    "latent_dim = 200"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "951177e3",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_file = \"/data/anon/branched_diffusion/data/scrna/covid_flu/processed/covid_flu_processed_reduced_genes.h5\"\n",
    "autoencoder_path = \"/data/anon/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_processed_reduced_genes_ldvae_d%d/\" % latent_dim\n",
    "\n",
    "models_base_path = \"/home/anon/branched_diffusion/models/trained_models/scrna_covid_flu_continuous_class_extension\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5b6d3f84",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: this is currently rather inefficient; a decision-tree-style structure\n",
    "# would be better\n",
    "\n",
    "def class_time_to_branch(c, t, branch_defs):\n",
    "    \"\"\"\n",
    "    Given a class and a time (both scalars), return the\n",
    "    corresponding branch index.\n",
    "    \"\"\"\n",
    "    for i, branch_def in enumerate(branch_defs):\n",
    "        if c in branch_def[0] and t >= branch_def[1] and t <= branch_def[2]:\n",
    "            return i\n",
    "    raise ValueError(\"Undefined class and time\")\n",
    "        \n",
    "def class_time_to_branch_tensor(c, t, branch_defs):\n",
    "    \"\"\"\n",
    "    Given tensors of classes and a times, return the\n",
    "    corresponding branch indices as a tensor.\n",
    "    \"\"\"\n",
    "    return torch.tensor([\n",
    "        class_time_to_branch(c_i, t_i, branch_defs) for c_i, t_i in zip(c, t)\n",
    "    ], device=DEVICE)\n",
    "\n",
    "def class_to_class_index_tensor(c, classes):\n",
    "    \"\"\"\n",
    "    Given a tensor of classes, return the corresponding class indices\n",
    "    as a tensor.\n",
    "    \"\"\"\n",
    "    return torch.argmax(\n",
    "        (c[:, None] == torch.tensor(classes, device=c.device)).int(), dim=1\n",
    "    ).to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "426efda4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the branches\n",
    "classes_01 = [0, 1]\n",
    "branch_defs_01 = [((0, 1), 0.5795795795795796, 1), ((0,), 0, 0.5795795795795796), ((1,), 0, 0.5795795795795796)]\n",
    "\n",
    "classes_012 = [0, 1, 5]\n",
    "branch_defs_012 = [((0, 1, 5), 6.786786786786787e-01, 1), ((0, 1), 0.5795795795795796, 0.6786786786786787), ((5,), 0, 0.6786786786786787), ((0,), 0, 0.5795795795795796), ((1,), 0, 0.5795795795795796)]\n",
    "\n",
    "classes_2 = [5]\n",
    "branch_defs_2 = [((5,), 0, 0.6786786786786787)]\n",
    "\n",
    "dataset_01 = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=(autoencoder_path if latent_space else None))\n",
    "dataset_012 = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=(autoencoder_path if latent_space else None))\n",
    "dataset_2 = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=(autoencoder_path if latent_space else None))\n",
    "\n",
    "# Limit classes\n",
    "inds_01 = np.isin(dataset_01.cell_cluster, classes_01)\n",
    "dataset_01.data = dataset_01.data[inds_01]\n",
    "dataset_01.cell_cluster = dataset_01.cell_cluster[inds_01]\n",
    "inds_012 = np.isin(dataset_012.cell_cluster, classes_012)\n",
    "dataset_012.data = dataset_012.data[inds_012]\n",
    "dataset_012.cell_cluster = dataset_012.cell_cluster[inds_012]\n",
    "inds_2 = np.isin(dataset_2.cell_cluster, classes_2)\n",
    "dataset_2.data = dataset_2.data[inds_2]\n",
    "dataset_2.cell_cluster = dataset_2.cell_cluster[inds_2]\n",
    "\n",
    "data_loader_01 = torch.utils.data.DataLoader(dataset_01, batch_size=128, shuffle=True, num_workers=0)\n",
    "data_loader_012 = torch.utils.data.DataLoader(dataset_012, batch_size=128, shuffle=True, num_workers=0)\n",
    "data_loader_2 = torch.utils.data.DataLoader(dataset_2, batch_size=128, shuffle=True, num_workers=0)\n",
    "input_shape = next(iter(data_loader_01))[0].shape[1:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ab1a36e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the SDE\n",
    "sde = sdes.VariancePreservingSDE(0.1, 5, input_shape)\n",
    "\n",
    "t_limit = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "387aa61e",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"MODEL_DIR\"] = os.path.join(models_base_path, \"extension\")\n",
    "\n",
    "import model.train_continuous_model as train_continuous_model  # Import this AFTER setting environment"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f19ff6b",
   "metadata": {},
   "source": [
    "#### Train extra branch on branched model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "d8e4d145",
   "metadata": {},
   "outputs": [],
   "source": [
    "def map_branch_def(branch_def, target_branch_defs):\n",
    "    \"\"\"\n",
    "    Given a particular branch definition (i.e. a triplet), and a\n",
    "    list of branch definitions, attempts to match that branch\n",
    "    definition to the corresponding entry in the list. This\n",
    "    mapping is based on whether or not the branch would need to be\n",
    "    retrained. The query `branch_def` is matched to a target within\n",
    "    `branch_defs` if the target's class indices are all present in\n",
    "    the query, and the query time is a sub-interval of the target\n",
    "    time.\n",
    "    Arguments:\n",
    "        `branch_def`: a branch definition (i.e. triplet of class index\n",
    "            tuple, start time, and end time)\n",
    "        `target_branch_defs`: a list of branch definitions\n",
    "    Returns the index of the matched branch definition in `branch_defs`,\n",
    "    or -1 if there is no suitable match found.\n",
    "    \"\"\"\n",
    "    for i, target_branch_def in enumerate(target_branch_defs):\n",
    "        if set(branch_def[0]).issuperset(set(target_branch_def[0])) \\\n",
    "            and branch_def[1] >= target_branch_def[1] \\\n",
    "            and branch_def[2] <= target_branch_def[2]:\n",
    "            return i\n",
    "    return -1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "701982d4",
   "metadata": {},
   "outputs": [],
   "source": [
    "branched_model_1 = model_util.load_model(\n",
    "    scrna_ae.MultitaskResNet,\n",
    "    os.path.join(models_base_path, \"scrna_covid_flu_continuous_branched_2classes/1/last_ckpt.pth\")\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "a99ea13e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.62it/s]\n"
     ]
    }
   ],
   "source": [
    "# Generate the samples\n",
    "branched_samples_before = {}\n",
    "for class_to_sample in classes_01:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_branched_samples(\n",
    "        branched_model_1, sde, class_to_sample,\n",
    "        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_01),\n",
    "        sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    "    )\n",
    "    branched_samples_before[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "88cc8341",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create new model and copy over parameters\n",
    "branched_model_2 = scrna_ae.MultitaskResNet(\n",
    "    len(branch_defs_012), input_shape[0], t_limit=t_limit\n",
    ").to(DEVICE)\n",
    "\n",
    "# Figure out which branches should be copied over to which ones\n",
    "branch_map_inds = [\n",
    "    map_branch_def(bd, branch_defs_01) for bd in branch_defs_012\n",
    "]\n",
    "\n",
    "# For each submodule, copy over the weights\n",
    "# Careful: this assumes a particular kind of architecture!\n",
    "modules_1 = dict(branched_model_1.named_children())\n",
    "modules_2 = dict(branched_model_2.named_children())\n",
    "\n",
    "for module_name in [\"layers\", \"time_embedders\"]:\n",
    "    for submodule_i, submodule in enumerate(modules_1[module_name]):\n",
    "        if len(submodule) == 1:\n",
    "            branched_model_2.get_submodule(module_name)[submodule_i].load_state_dict(\n",
    "                submodule.state_dict()\n",
    "            )\n",
    "        elif len(submodule) == len(branch_defs_01):\n",
    "            target_submodule_list = branched_model_2.get_submodule(module_name)[submodule_i]\n",
    "            for target_i, source_i in enumerate(branch_map_inds):\n",
    "                if source_i != -1:\n",
    "                    target_submodule_list[target_i].load_state_dict(\n",
    "                        submodule[source_i].state_dict()\n",
    "                    )\n",
    "                else:\n",
    "                    # Copy over some other branch for a warm start\n",
    "                    # We'll manually set it for now (TODO)\n",
    "                    source_i = -1  # Last branch\n",
    "                    target_submodule_list[target_i].load_state_dict(\n",
    "                        submodule[source_i].state_dict()\n",
    "                    )\n",
    "        else:\n",
    "            raise ValueError(\"Found module list of length %d\" % len(module_list))\n",
    "\n",
    "submodule = branched_model_1.get_submodule(\"last_linears\")\n",
    "target_submodule_list = branched_model_2.get_submodule(\"last_linears\")\n",
    "for target_i, source_i in enumerate(branch_map_inds):\n",
    "    if source_i != -1:\n",
    "        target_submodule_list[target_i].load_state_dict(\n",
    "            submodule[source_i].state_dict()\n",
    "        )\n",
    "    else:\n",
    "        # Copy over some other branch for a warm start\n",
    "        # We'll manually set it for now (TODO)\n",
    "        source_i = -1  # Last branch\n",
    "        target_submodule_list[target_i].load_state_dict(\n",
    "            submodule[source_i].state_dict()\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f6840fa6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.61it/s]\n"
     ]
    }
   ],
   "source": [
    "# Generate the samples again to make sure match-up was done correctly\n",
    "branched_samples_before_2 = {}\n",
    "for class_to_sample in classes_01:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_branched_samples(\n",
    "        branched_model_2, sde, class_to_sample,\n",
    "        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),\n",
    "        sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    "    )\n",
    "    branched_samples_before_2[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "7992ca21",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING - root - Added new config entry: \"class_time_to_branch_index\"\n",
      "WARNING - root - Added new config entry: \"data_loader\"\n",
      "WARNING - root - Added new config entry: \"loss_weighting_type\"\n",
      "WARNING - root - Added new config entry: \"model\"\n",
      "WARNING - root - Added new config entry: \"sde\"\n",
      "WARNING - root - Added new config entry: \"t_limit\"\n",
      "INFO - train - Running command 'train_branched_model'\n",
      "INFO - train - Started run with ID \"1\"\n",
      "Loss: 541.89: 100%|██████████████████████████████████████████| 24/24 [00:06<00:00,  3.57it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 average Loss: 934.59\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 322.41: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 average Loss: 394.38\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 262.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 average Loss: 290.47\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 236.19: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 average Loss: 248.73\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 223.13: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 average Loss: 229.84\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 214.03: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 average Loss: 214.70\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 207.60: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 average Loss: 209.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 207.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 average Loss: 208.24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 199.02: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 average Loss: 204.64\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 219.66: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 average Loss: 205.58\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 201.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 average Loss: 204.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 200.55: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 average Loss: 202.39\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 201.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 average Loss: 200.10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 208.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 average Loss: 202.08\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 206.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 average Loss: 201.40\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 193.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16 average Loss: 199.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 194.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17 average Loss: 200.06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 195.51: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18 average Loss: 200.14\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 197.59: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19 average Loss: 199.59\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 214.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 average Loss: 199.34\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 203.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21 average Loss: 200.05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 203.20: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22 average Loss: 199.25\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 204.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23 average Loss: 198.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 197.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24 average Loss: 197.75\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 197.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25 average Loss: 196.30\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 193.27: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26 average Loss: 195.22\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 194.59: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27 average Loss: 195.88\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 194.37: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28 average Loss: 193.41\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 191.63: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29 average Loss: 194.82\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 194.82: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30 average Loss: 193.57\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 197.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 31 average Loss: 192.92\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 193.93: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 32 average Loss: 193.21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 195.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 33 average Loss: 193.24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 188.00: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.93it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 34 average Loss: 190.22\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 191.57: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 35 average Loss: 189.02\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 36 average Loss: 189.45\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 187.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 37 average Loss: 187.70\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 192.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 38 average Loss: 188.94\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 179.81: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 39 average Loss: 186.43\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 186.99: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 40 average Loss: 188.21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.00: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 41 average Loss: 184.47\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.65: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 42 average Loss: 184.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 178.03: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 43 average Loss: 182.47\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 44 average Loss: 182.72\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 182.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 45 average Loss: 183.89\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 179.27: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 46 average Loss: 181.67\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 186.57: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 47 average Loss: 181.42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 181.07: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.98it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 48 average Loss: 181.16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.92: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 49 average Loss: 180.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 184.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 50 average Loss: 179.34\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.39: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 51 average Loss: 178.82\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 177.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 52 average Loss: 180.16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 170.51: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 53 average Loss: 177.33\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 185.37: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 54 average Loss: 178.77\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 186.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 55 average Loss: 178.24\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.18: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 56 average Loss: 177.40\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 170.73: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 57 average Loss: 178.27\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 177.48: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 58 average Loss: 176.12\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 179.95: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 59 average Loss: 177.06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 183.54: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 60 average Loss: 175.83\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 172.75: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 61 average Loss: 174.99\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.54: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 62 average Loss: 173.56\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 167.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 63 average Loss: 175.57\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 171.84: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 64 average Loss: 171.87\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.66: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 65 average Loss: 175.43\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 165.39: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 66 average Loss: 172.45\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 67 average Loss: 171.16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 190.52: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 68 average Loss: 171.96\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 180.63: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 69 average Loss: 173.21\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 169.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 70 average Loss: 172.09\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 172.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 71 average Loss: 171.75\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 163.40: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 72 average Loss: 169.62\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 176.77: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 73 average Loss: 171.32\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 179.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 74 average Loss: 169.29\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 173.94: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  5.00it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 75 average Loss: 167.70\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 167.65: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 76 average Loss: 169.27\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 173.67: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 77 average Loss: 168.66\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 161.21: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 78 average Loss: 168.91\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 169.10: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 79 average Loss: 170.50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 162.25: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 80 average Loss: 168.97\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 177.53: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 81 average Loss: 169.51\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 167.69: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 82 average Loss: 167.77\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 164.09: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 83 average Loss: 166.51\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 173.80: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 84 average Loss: 168.70\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 158.83: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 85 average Loss: 167.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 170.86: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 86 average Loss: 166.36\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 165.91: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 87 average Loss: 167.18\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 171.36: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 88 average Loss: 166.90\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 165.73: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.94it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 89 average Loss: 166.48\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 167.68: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 90 average Loss: 164.47\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 157.61: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 91 average Loss: 165.10\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 165.24: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.97it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 92 average Loss: 165.88\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 166.33: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.95it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 93 average Loss: 164.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 165.82: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 94 average Loss: 166.55\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 159.49: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 95 average Loss: 164.26\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 160.49: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 96 average Loss: 164.98\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 163.95: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 97 average Loss: 164.67\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 160.04: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 98 average Loss: 163.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 170.42: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 99 average Loss: 163.36\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 173.83: 100%|██████████████████████████████████████████| 24/24 [00:04<00:00,  4.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 100 average Loss: 162.93\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - train - Completed after 0:13:26\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<sacred.run.Run at 0x2aab88d76eb0>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Fine-tune the model on the specific (new) branches only\n",
    "\n",
    "# Every shared module (a module list of length 1) is frozen. Among the\n",
    "# task-specific modules (one entry per branch in `branch_defs_012`), only\n",
    "# the branches whose entry in `branch_map_inds` is -1 stay trainable\n",
    "# (presumably -1 marks a branch with no pre-trained counterpart -- confirm\n",
    "# against the cell that builds `branch_map_inds`)\n",
    "for module_name in (\"layers\", \"time_embedders\"):\n",
    "    for submodule in branched_model_2.get_submodule(module_name):\n",
    "        if len(submodule) == 1:\n",
    "            # Shared by all branches: freeze unconditionally\n",
    "            for p in submodule.parameters():\n",
    "                p.requires_grad = False\n",
    "            continue\n",
    "        if len(submodule) != len(branch_defs_012):\n",
    "            raise ValueError(\"Found module list of length %d\" % len(submodule))\n",
    "        # One entry per branch: trainable only where the map index is -1\n",
    "        for i, branch_module in enumerate(submodule):\n",
    "            is_trainable = not (branch_map_inds[i] != -1)\n",
    "            for p in branch_module.parameters():\n",
    "                p.requires_grad = is_trainable\n",
    "\n",
    "# The final linear layers always hold one entry per branch\n",
    "submodule = branched_model_2.get_submodule(\"last_linears\")\n",
    "for i, branch_module in enumerate(submodule):\n",
    "    is_trainable = not (branch_map_inds[i] != -1)\n",
    "    for p in branch_module.parameters():\n",
    "        p.requires_grad = is_trainable\n",
    "\n",
    "# Train. The time limit is taken from the first entry of `branch_defs_2`\n",
    "# (presumably the end time of the new branch -- TODO confirm). The call is\n",
    "# kept as the cell's last expression so the run object is displayed.\n",
    "train_continuous_model.train_ex.run(\n",
    "    \"train_branched_model\",\n",
    "    config_updates=dict(\n",
    "        model=branched_model_2,\n",
    "        sde=sde,\n",
    "        data_loader=data_loader_2,\n",
    "        class_time_to_branch_index=lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),\n",
    "        num_epochs=100,\n",
    "        learning_rate=0.001,\n",
    "        t_limit=branch_defs_2[0][2],\n",
    "        loss_weighting_type=\"empirical_norm\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "441c00a3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.66it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:58<00:00,  8.62it/s]\n"
     ]
    }
   ],
   "source": [
    "# Generate samples for every class in `classes_012` from the fine-tuned\n",
    "# branched model, stored as NumPy arrays keyed by class\n",
    "branched_sampling_kwargs = dict(\n",
    "    sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    ")\n",
    "branched_samples_after = {}\n",
    "for class_to_sample in classes_012:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_branched_samples(\n",
    "        branched_model_2,\n",
    "        sde,\n",
    "        class_to_sample,\n",
    "        lambda c, t: class_time_to_branch_tensor(c, t, branch_defs_012),\n",
    "        **branched_sampling_kwargs\n",
    "    )\n",
    "    branched_samples_after[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6a78ad98",
   "metadata": {},
   "source": [
    "#### Train label-guided model with only new label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "32c18892",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the label-guided model from its saved checkpoint (the last\n",
    "# checkpoint of run 1 of the 2-class model) and move it to the device\n",
    "label_guided_ckpt_path = os.path.join(\n",
    "    models_base_path,\n",
    "    \"scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth\"\n",
    ")\n",
    "label_guided_model_1 = model_util.load_model(\n",
    "    scrna_ae.LabelGuidedResNet, label_guided_ckpt_path\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "b7ddc6f4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.76it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.77it/s]\n"
     ]
    }
   ],
   "source": [
    "# Baseline: sample every class in `classes_012` from the label-guided\n",
    "# model before it is trained any further, stored as NumPy arrays keyed\n",
    "# by class\n",
    "linear_before_sampling_kwargs = dict(\n",
    "    sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    ")\n",
    "linear_samples_before = {}\n",
    "for class_to_sample in classes_012:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_label_guided_samples(\n",
    "        label_guided_model_1,\n",
    "        sde,\n",
    "        class_to_sample,\n",
    "        lambda c: class_to_class_index_tensor(c, classes_012),\n",
    "        **linear_before_sampling_kwargs\n",
    "    )\n",
    "    linear_samples_before[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "4b9a3956",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING - root - Added new config entry: \"class_to_class_index\"\n",
      "WARNING - root - Added new config entry: \"data_loader\"\n",
      "WARNING - root - Added new config entry: \"loss_weighting_type\"\n",
      "WARNING - root - Added new config entry: \"model\"\n",
      "WARNING - root - Added new config entry: \"sde\"\n",
      "WARNING - root - Added new config entry: \"t_limit\"\n",
      "INFO - train - Running command 'train_label_guided_model'\n",
      "INFO - train - Started run with ID \"2\"\n",
      "Loss: 188.88: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.52it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 average Loss: 228.78\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 164.17: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.58it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 average Loss: 171.67\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 159.49: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 average Loss: 161.44\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 156.18: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.59it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 average Loss: 152.68\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 150.89: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 average Loss: 148.06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 140.66: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 average Loss: 146.70\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 145.87: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 average Loss: 142.78\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 140.37: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.60it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 average Loss: 140.87\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 143.83: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 average Loss: 137.89\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 137.81: 100%|██████████████████████████████████████████| 24/24 [00:02<00:00,  9.61it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 average Loss: 136.16\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - train - Completed after 0:00:52\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<sacred.run.Run at 0x2aab88dda910>"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Continue training the label-guided model on only the new label.\n",
    "# The call is kept as the cell's last expression so the run object is\n",
    "# displayed.\n",
    "train_continuous_model.train_ex.run(\n",
    "    \"train_label_guided_model\",\n",
    "    config_updates=dict(\n",
    "        model=label_guided_model_1,\n",
    "        sde=sde,\n",
    "        data_loader=data_loader_2,\n",
    "        class_to_class_index=lambda c: class_to_class_index_tensor(c, classes_012),\n",
    "        num_epochs=10,\n",
    "        learning_rate=0.001,\n",
    "        t_limit=t_limit,\n",
    "        loss_weighting_type=\"empirical_norm\"\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "3af73f38",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.80it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.79it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.78it/s]\n"
     ]
    }
   ],
   "source": [
    "# Sample every class in `classes_012` from the label-guided model after\n",
    "# it was trained on only the new label, stored as NumPy arrays keyed by\n",
    "# class\n",
    "linear_newonly_sampling_kwargs = dict(\n",
    "    sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    ")\n",
    "linear_samples_after_newonly = {}\n",
    "for class_to_sample in classes_012:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_label_guided_samples(\n",
    "        label_guided_model_1,\n",
    "        sde,\n",
    "        class_to_sample,\n",
    "        lambda c: class_to_class_index_tensor(c, classes_012),\n",
    "        **linear_newonly_sampling_kwargs\n",
    "    )\n",
    "    linear_samples_after_newonly[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e7e003f",
   "metadata": {},
   "source": [
    "#### Train label-guided model with all data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "ae80f566",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a fresh copy of the label-guided model from its saved checkpoint\n",
    "# (the last checkpoint of run 1 of the 2-class model) for this section,\n",
    "# and move it to the device\n",
    "label_guided_ckpt_path = os.path.join(\n",
    "    models_base_path,\n",
    "    \"scrna_covid_flu_continuous_labelguided_2classes/1/last_ckpt.pth\"\n",
    ")\n",
    "label_guided_model_2 = model_util.load_model(\n",
    "    scrna_ae.LabelGuidedResNet, label_guided_ckpt_path\n",
    ").to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "0fb65d2d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING - root - Added new config entry: \"class_to_class_index\"\n",
      "WARNING - root - Added new config entry: \"data_loader\"\n",
      "WARNING - root - Added new config entry: \"loss_weighting_type\"\n",
      "WARNING - root - Added new config entry: \"model\"\n",
      "WARNING - root - Added new config entry: \"sde\"\n",
      "WARNING - root - Added new config entry: \"t_limit\"\n",
      "INFO - train - Running command 'train_label_guided_model'\n",
      "INFO - train - Started run with ID \"5\"\n",
      "Loss: 94.82: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 average Loss: 84.54\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 72.43: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 average Loss: 79.13\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 81.77: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 average Loss: 78.28\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 82.70: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 average Loss: 78.02\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 84.01: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 average Loss: 77.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 91.03: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 average Loss: 76.50\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 74.29: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 average Loss: 75.05\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 76.08: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 average Loss: 76.26\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 71.18: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 average Loss: 74.69\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 68.64: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 average Loss: 75.23\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 79.51: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 average Loss: 74.38\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 87.25: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 average Loss: 74.67\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 75.56: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 average Loss: 74.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 84.54: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 average Loss: 73.26\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 91.69: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 average Loss: 73.42\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 75.05: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16 average Loss: 74.01\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 70.85: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17 average Loss: 73.73\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 75.24: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18 average Loss: 72.69\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 71.01: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19 average Loss: 72.76\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 71.26: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 average Loss: 71.95\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 86.02: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21 average Loss: 71.86\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 75.04: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22 average Loss: 72.07\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 67.58: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23 average Loss: 71.33\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 75.39: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.41it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24 average Loss: 71.99\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 76.64: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25 average Loss: 71.06\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 72.59: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26 average Loss: 71.66\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 70.42: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27 average Loss: 70.71\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 76.20: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.43it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28 average Loss: 70.91\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 83.29: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29 average Loss: 71.17\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loss: 70.21: 100%|█████████████████████████████████████████| 132/132 [00:14<00:00,  9.42it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30 average Loss: 70.25\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO - train - Completed after 0:07:55\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<sacred.run.Run at 0x2aac7a367280>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train on all data\n",
    "train_continuous_model.train_ex.run(\n",
    "    \"train_label_guided_model\",\n",
    "    config_updates={\n",
    "        \"model\": label_guided_model_2,\n",
    "        \"sde\": sde,\n",
    "        \"data_loader\": data_loader_012,\n",
    "        \"class_to_class_index\": lambda c: class_to_class_index_tensor(c, classes_012),\n",
    "        \"num_epochs\": 30,\n",
    "        \"learning_rate\": 0.001,\n",
    "        \"t_limit\": t_limit,\n",
    "        \"loss_weighting_type\": \"empirical_norm\"\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c22f8480",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:56<00:00,  8.81it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████| 500/500 [00:57<00:00,  8.73it/s]\n"
     ]
    }
   ],
   "source": [
    "linear_samples_after_all = {}\n",
    "for class_to_sample in classes_012:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    sample = generate.generate_continuous_label_guided_samples(\n",
    "        label_guided_model_2, sde, class_to_sample,\n",
    "        lambda c: class_to_class_index_tensor(c, classes_012),\n",
    "        sampler=\"pc\", t_limit=t_limit, num_samples=1000, verbose=True\n",
    "    )\n",
    "    linear_samples_after_all[class_to_sample] = sample.cpu().numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "77ff25fc",
   "metadata": {},
   "source": [
    "#### Compute FIDs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "2b853d4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: 0\n",
      "Sampling class: 1\n",
      "Sampling class: 5\n"
     ]
    }
   ],
   "source": [
    "# Sample objects from the original dataset\n",
    "true_samples = {}\n",
    "for class_to_sample in classes_012:\n",
    "    print(\"Sampling class: %s\" % class_to_sample)\n",
    "    inds = np.where(dataset_012.cell_cluster == class_to_sample)[0]\n",
    "    sample_inds = np.random.choice(inds, size=1000, replace=False)\n",
    "    true_samples[class_to_sample] = dataset_012.data[sample_inds]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "03d78307",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[34mINFO    \u001b[0m File                                                                                                      \n",
      "         \u001b[35m/data/anon/branched_diffusion/models/trained_models/scrna_vaes/covid_flu/covid_flu_proc\u001b[0m\n",
      "         \u001b[35messed_reduced_genes_ldvae_d200/\u001b[0m\u001b[95mmodel.pt\u001b[0m already downloaded                                                \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/anon/miniconda3/envs/scanpy/lib/python3.9/site-packages/scvi/data/fields/_layer_field.py:91: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want?\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "if not latent_space:\n",
    "    dataset_with_ae = scrna_dataset.SingleCellDataset(data_file, autoencoder_path=autoencoder_path)\n",
    "\n",
    "def compute_fid(gen_samples, true_samples, latent=True):\n",
    "    if latent_space:\n",
    "        if latent:\n",
    "            return fid.compute_fid(\n",
    "                gen_samples,\n",
    "                dataset_01.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()\n",
    "            )\n",
    "        else:\n",
    "            return fid.compute_fid(\n",
    "                dataset_01.decode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),\n",
    "                true_samples\n",
    "            )\n",
    "    else:\n",
    "        gen_samples[gen_samples < 0] = 0  # Generated values should never be above 0\n",
    "        if latent:\n",
    "            return fid.compute_fid(\n",
    "                dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE)).cpu().numpy(),\n",
    "                dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE)).cpu().numpy()\n",
    "            )\n",
    "        else:\n",
    "            return fid.compute_fid(\n",
    "                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(gen_samples, device=DEVICE))).cpu().numpy(),\n",
    "                dataset_with_ae.decode_batch(dataset_with_ae.encode_batch(torch.tensor(true_samples, device=DEVICE))).cpu().numpy()\n",
    "            )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "c3e69366",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "branched_before_fids = {}\n",
    "branched_before_2_fids = {}\n",
    "branched_after_fids = {}\n",
    "linear_before_fids = {}\n",
    "linear_after_newonly_fids = {}\n",
    "linear_after_all_fids = {}\n",
    "\n",
    "latent = True\n",
    "\n",
    "for c in branched_samples_before.keys():\n",
    "    branched_before_fids[c] = compute_fid(branched_samples_before[c], true_samples[c], latent)\n",
    "for c in branched_samples_before_2.keys():\n",
    "    branched_before_2_fids[c] = compute_fid(branched_samples_before_2[c], true_samples[c], latent)\n",
    "for c in branched_samples_after.keys():\n",
    "    branched_after_fids[c] = compute_fid(branched_samples_after[c], true_samples[c], latent)\n",
    "for c in linear_samples_before.keys():\n",
    "    linear_before_fids[c] = compute_fid(linear_samples_before[c], true_samples[c], latent)\n",
    "for c in linear_samples_after_newonly.keys():\n",
    "    linear_after_newonly_fids[c] = compute_fid(linear_samples_after_newonly[c], true_samples[c], latent)\n",
    "for c in linear_samples_after_all.keys():\n",
    "    linear_after_all_fids[c] = compute_fid(linear_samples_after_all[c], true_samples[c], latent)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "add8ed5c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "B-before {0: 17.91029206726665, 1: 20.768634314998813}\n",
      "B-before2 {0: 18.08620742509962, 1: 21.22780248367854}\n",
      "B-after {0: 18.08532807016608, 1: 20.537572760461792, 5: 23.560019574154172}\n",
      "L-before {0: 20.431972237481226, 1: 22.235712766699923, 5: 27.880261054677923}\n",
      "L-afterone {0: 25.624702214950076, 1: 26.794907896162005, 5: 32.10889066394775}\n",
      "L-afterall {0: 16.95639804145786, 1: 19.29993126245404, 5: 28.023053225371605}\n"
     ]
    }
   ],
   "source": [
    "print(\"B-before\", branched_before_fids)\n",
    "print(\"B-before2\", branched_before_2_fids)\n",
    "print(\"B-after\", branched_after_fids)\n",
    "print(\"L-before\", linear_before_fids)\n",
    "print(\"L-afterone\", linear_after_newonly_fids)\n",
    "print(\"L-afterall\", linear_after_all_fids)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
