{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4ce6f328",
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import sys\n",
    "sys.path.append(\"../src\")\n",
    "import model.sdes as sdes\n",
    "import model.generate as generate\n",
    "import model.table_dnn as table_dnn\n",
    "import model.util as model_util\n",
    "import torch\n",
    "import numpy as np\n",
    "import scipy.stats\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "df22d86f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define device\n",
    "if torch.cuda.is_available():\n",
    "    DEVICE = \"cuda\"\n",
    "else:\n",
    "    DEVICE = \"cpu\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "51fbc829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: this is currently rather inefficient; a decision-tree-style structure\n",
    "# would be better\n",
    "\n",
    "def class_time_to_branch(c, t):\n",
    "    \"\"\"\n",
    "    Look up the branch that covers class `c` at time `t` (both scalars).\n",
    "    Scans the global `branch_defs` in order and returns the index of the\n",
    "    first entry whose class tuple contains `c` and whose time interval\n",
    "    (inclusive at both ends) contains `t`. Because both ends are inclusive,\n",
    "    a time lying exactly on a shared boundary resolves to whichever matching\n",
    "    entry is listed first. Raises ValueError if no entry matches.\n",
    "    \"\"\"\n",
    "    for branch_index, branch_def in enumerate(branch_defs):\n",
    "        branch_classes = branch_def[0]\n",
    "        if c not in branch_classes:\n",
    "            continue\n",
    "        if branch_def[1] <= t <= branch_def[2]:\n",
    "            return branch_index\n",
    "    raise ValueError(\"Undefined class and time\")\n",
    "        \n",
    "def class_time_to_branch_tensor(c, t):\n",
    "    \"\"\"\n",
    "    Given a tensor of classes and a tensor of times, return the\n",
    "    corresponding branch indices as a tensor placed on `DEVICE`. The two\n",
    "    inputs are paired element by element and each pair is resolved with\n",
    "    `class_time_to_branch`.\n",
    "    \"\"\"\n",
    "    branch_indices = []\n",
    "    for class_i, time_i in zip(c, t):\n",
    "        branch_indices.append(class_time_to_branch(class_i, time_i))\n",
    "    return torch.tensor(branch_indices, device=DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "767236f1",
   "metadata": {},
   "outputs": [],
   "source": [
    "models_base_path = \"/data/anon/branched_diffusion/models/trained_models/\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5f55c8b",
   "metadata": {},
   "source": [
    "### Generating multiple letter classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f1c6de95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the branches\n",
    "letters = list(\"ABCDEFGHIJKLMNOPQRSTUVWXYZ\")\n",
    "class_to_letter = dict(enumerate(letters))\n",
    "letter_to_class = {v : k for k, v in class_to_letter.items()}\n",
    "\n",
    "classes = [letter_to_class[l] for l in letters]\n",
    "branch_defs = [(('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'), 0.5235235235235235, 1), (('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'X', 'Y', 'Z'), 0.5165165165165165, 0.5235235235235235), (('B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'X', 'Y', 'Z'), 0.5115115115115115, 0.5165165165165165), (('B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'X', 'Y', 'Z'), 0.4944944944944945, 0.5115115115115115), (('I', 'J'), 0.4794794794794795, 0.4944944944944945), (('B', 'C', 'D', 'E', 'F', 'G', 'H', 'K', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'X', 'Y', 'Z'), 0.4724724724724725, 0.4944944944944945), (('B', 'C', 'D', 'E', 'G', 'H', 'K', 'M', 'N', 'O', 'Q', 'R', 'S', 'U', 'X', 'Z'), 0.45645645645645644, 0.4724724724724725), (('F', 'P', 'T', 'V', 'Y'), 0.4364364364364364, 0.4724724724724725), (('B', 'C', 'D', 'E', 'G', 'H', 'K', 'N', 'O', 'Q', 'R', 'S', 'U', 'X', 'Z'), 0.4174174174174174, 0.45645645645645644), (('B', 'C', 'D', 'E', 'G', 'H', 'K', 'N', 'O', 'Q', 'R', 'S', 'X', 'Z'), 0.4134134134134134, 0.4174174174174174), (('B', 'D', 'G', 'H', 'K', 'N', 'O', 'Q', 'R', 'S', 'X', 'Z'), 0.4094094094094094, 0.4134134134134134), (('F', 'T', 'V', 'Y'), 0.4024024024024024, 0.4364364364364364), (('B', 'D', 'G', 'H', 'K', 'O', 'Q', 'R', 'S', 'X', 'Z'), 0.3863863863863864, 0.4094094094094094), (('B', 'G', 'H', 'K', 'O', 'Q', 'R', 'S', 'X', 'Z'), 0.3813813813813814, 0.3863863863863864), (('B', 'G', 'H', 'O', 'Q', 'R', 'S', 'X', 'Z'), 0.3733733733733734, 0.3813813813813814), (('F', 'T', 'Y'), 0.36036036036036034, 0.4024024024024024), (('C', 'E'), 0.3563563563563564, 0.4134134134134134), (('T', 'Y'), 0.3533533533533533, 0.36036036036036034), (('B', 'R', 'S', 'X', 'Z'), 0.35135135135135137, 0.3733733733733734), (('G', 
'H', 'O', 'Q'), 0.34134134134134136, 0.3733733733733734), (('B', 'S', 'X', 'Z'), 0.32232232232232233, 0.35135135135135137), (('B', 'S', 'X'), 0.27627627627627627, 0.32232232232232233), (('G', 'H', 'O'), 0.26426426426426425, 0.34134134134134136), (('G', 'O'), 0.25725725725725723, 0.26426426426426425), (('S', 'X'), 0.15615615615615616, 0.27627627627627627), (('W',), 0, 0.5235235235235235), (('A',), 0, 0.5165165165165165), (('L',), 0, 0.5115115115115115), (('J',), 0, 0.4794794794794795), (('I',), 0, 0.4794794794794795), (('M',), 0, 0.45645645645645644), (('P',), 0, 0.4364364364364364), (('U',), 0, 0.4174174174174174), (('N',), 0, 0.4094094094094094), (('V',), 0, 0.4024024024024024), (('D',), 0, 0.3863863863863864), (('K',), 0, 0.3813813813813814), (('F',), 0, 0.36036036036036034), (('E',), 0, 0.3563563563563564), (('C',), 0, 0.3563563563563564), (('Y',), 0, 0.3533533533533533), (('T',), 0, 0.3533533533533533), (('R',), 0, 0.35135135135135137), (('Q',), 0, 0.34134134134134136), (('Z',), 0, 0.32232232232232233), (('B',), 0, 0.27627627627627627), (('H',), 0, 0.26426426426426425), (('G',), 0, 0.25725725725725723), (('O',), 0, 0.25725725725725723), (('S',), 0, 0.15615615615615616), (('X',), 0, 0.15615615615615616)]\n",
    "\n",
    "input_shape = (16,)\n",
    "\n",
    "# Re-express every branch's letters as integer class indices, keeping the\n",
    "# branch's two time bounds unchanged\n",
    "branch_defs = [\n",
    "    (tuple(letter_to_class[letter] for letter in branch_def[0]), branch_def[1], branch_def[2])\n",
    "    for branch_def in branch_defs\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "efedd4e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the SDE and import the model\n",
    "t_limit = 1\n",
    "sde = sdes.VariancePreservingSDE(0.1, 20, input_shape)\n",
    "\n",
    "# Load the trained checkpoint, then move the network onto DEVICE\n",
    "model = model_util.load_model(\n",
    "    table_dnn.MultitaskTabularNet,\n",
    "    os.path.join(models_base_path, \"letters_continuous_allletters/1/epoch_100_ckpt.pth\")\n",
    ")\n",
    "model = model.to(DEVICE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c1e50cd6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling class: A\n",
      "Sampling class: B\n",
      "Sampling class: C\n",
      "Sampling class: D\n",
      "Sampling class: E\n",
      "Sampling class: F\n",
      "Sampling class: G\n",
      "Sampling class: H\n",
      "Sampling class: I\n",
      "Sampling class: J\n",
      "Sampling class: K\n",
      "Sampling class: L\n",
      "Sampling class: M\n",
      "Sampling class: N\n",
      "Sampling class: O\n",
      "Sampling class: P\n",
      "Sampling class: Q\n",
      "Sampling class: R\n",
      "Sampling class: S\n",
      "Sampling class: T\n",
      "Sampling class: U\n",
      "Sampling class: V\n",
      "Sampling class: W\n",
      "Sampling class: X\n",
      "Sampling class: Y\n",
      "Sampling class: Z\n",
      "Total time taken: 115s\n"
     ]
    }
   ],
   "source": [
    "# Sample each class individually without taking advantage of branches\n",
    "letter_samples = {}\n",
    "\n",
    "# Drain any GPU work that is already queued so it is not charged to this benchmark\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.synchronize()\n",
    "# perf_counter is a monotonic clock intended for measuring elapsed durations\n",
    "time_a = time.perf_counter()\n",
    "for class_to_sample in classes:\n",
    "    print(\"Sampling class: %s\" % class_to_letter[class_to_sample])\n",
    "    letter_samples[class_to_sample] = generate.generate_continuous_branched_samples(\n",
    "        model, sde, class_to_sample, class_time_to_branch_tensor,\n",
    "        sampler=\"pc\", t_limit=t_limit, num_steps=1000\n",
    "    )\n",
    "# CUDA kernels launch asynchronously; wait for them before stopping the clock\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.synchronize()\n",
    "time_b = time.perf_counter()\n",
    "linear_time = time_b - time_a\n",
    "print(\"Total time taken: %ds\" % linear_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "0de8f9d2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sampling branch 1/51\n",
      "Sampling branch 2/51\n",
      "Sampling branch 3/51\n",
      "Sampling branch 4/51\n",
      "Sampling branch 5/51\n",
      "Sampling branch 6/51\n",
      "Sampling branch 7/51\n",
      "Sampling branch 8/51\n",
      "Sampling branch 9/51\n",
      "Sampling branch 10/51\n",
      "Sampling branch 11/51\n",
      "Sampling branch 12/51\n",
      "Sampling branch 13/51\n",
      "Sampling branch 14/51\n",
      "Sampling branch 15/51\n",
      "Sampling branch 16/51\n",
      "Sampling branch 17/51\n",
      "Sampling branch 18/51\n",
      "Sampling branch 19/51\n",
      "Sampling branch 20/51\n",
      "Sampling branch 21/51\n",
      "Sampling branch 22/51\n",
      "Sampling branch 23/51\n",
      "Sampling branch 24/51\n",
      "Sampling branch 25/51\n",
      "Sampling branch 26/51\n",
      "Sampling branch 27/51\n",
      "Sampling branch 28/51\n",
      "Sampling branch 29/51\n",
      "Sampling branch 30/51\n",
      "Sampling branch 31/51\n",
      "Sampling branch 32/51\n",
      "Sampling branch 33/51\n",
      "Sampling branch 34/51\n",
      "Sampling branch 35/51\n",
      "Sampling branch 36/51\n",
      "Sampling branch 37/51\n",
      "Sampling branch 38/51\n",
      "Sampling branch 39/51\n",
      "Sampling branch 40/51\n",
      "Sampling branch 41/51\n",
      "Sampling branch 42/51\n",
      "Sampling branch 43/51\n",
      "Sampling branch 44/51\n",
      "Sampling branch 45/51\n",
      "Sampling branch 46/51\n",
      "Sampling branch 47/51\n",
      "Sampling branch 48/51\n",
      "Sampling branch 49/51\n",
      "Sampling branch 50/51\n",
      "Sampling branch 51/51\n",
      "Total time taken: 69s\n"
     ]
    }
   ],
   "source": [
    "# Sample each class by taking advantage of branches\n",
    "\n",
    "# Sort the branches by starting time point (in reverse order), and generate along those\n",
    "# branches, caching results; this guarantees that we will always find a cached batch\n",
    "# (other than the first one)\n",
    "cache = {}  # class index -> (start time of the latest branch sampled for it, samples)\n",
    "sorted_branch_defs = sorted(branch_defs, key=(lambda t: -t[1]))\n",
    "\n",
    "# Drain any GPU work that is already queued so it is not charged to this benchmark\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.synchronize()\n",
    "# perf_counter is a monotonic clock intended for measuring elapsed durations\n",
    "time_a = time.perf_counter()\n",
    "for i, branch_def in enumerate(sorted_branch_defs):\n",
    "    print(\"Sampling branch %d/%d\" % (i + 1, len(sorted_branch_defs)))\n",
    "    branch_classes, branch_start, branch_end = branch_def\n",
    "\n",
    "    # The first branch has nothing cached yet, so it is sampled without\n",
    "    # `initial_samples`; every later branch resumes from the batch cached by\n",
    "    # the branch whose start time equals this branch's end time\n",
    "    sampling_kwargs = {}\n",
    "    if i > 0:\n",
    "        cached_time, cached_samples = cache[branch_classes[0]]\n",
    "        assert cached_time == branch_end\n",
    "        sampling_kwargs[\"initial_samples\"] = cached_samples\n",
    "\n",
    "    samples = generate.generate_continuous_branched_samples(\n",
    "        # The class is arbitrary within the branch; use the first one\n",
    "        model, sde, branch_classes[0], class_time_to_branch_tensor,\n",
    "        sampler=\"pc\", t_limit=branch_end, t_start=branch_start,\n",
    "        num_steps=int(1000 * (branch_end - branch_start)),\n",
    "        **sampling_kwargs\n",
    "    )\n",
    "    # Every class on this branch shares the same batch from here on\n",
    "    for class_i in branch_classes:\n",
    "        cache[class_i] = (branch_start, samples)\n",
    "\n",
    "# CUDA kernels launch asynchronously; wait for them before stopping the clock\n",
    "if DEVICE == \"cuda\":\n",
    "    torch.cuda.synchronize()\n",
    "time_b = time.perf_counter()\n",
    "branched_time = time_b - time_a\n",
    "print(\"Total time taken: %ds\" % branched_time)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "26111439",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAHgCAYAAABwys7SAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ0klEQVR4nO3df7xkd13f8ffHbEAgAolZ8ggJbZAGaOABAdZUQBEMvyxCQpUSH6KrpKZaBNGiBq0PqBQJoFSLgkZElopg+KGkQYGwEFAwwIaEhBAQKhFWIlmQICgFEz79Y866w3rv7s2de3e/uff5fDzuY86cOTPne2d39rXnzMw51d0BAMbwDYd7AADAPsIMAAMRZgAYiDADwECEGQAGIswAMJAtB1ugql6e5HuSXN/d957mvTDJY5N8Ncn/TfIj3X3DdNszk5yd5KYkT+vutxxsHccee2yfdNJJq/wVAOCW57LLLvtsd2/df34d7HvMVfWQJF9K8sq5MD8yydu7+8aqen6SdPfPVdUpSV6d5LQkd07ytiR37+6bDrSObdu29a5du1bxawHALVNVXdbd2/aff9Bd2d39riR/t9+8t3b3jdPVS5OcOE2fkeQ13f2V7v5Eko9nFmkAYAXW4j3mJyf502n6hCSfmrtt9zQPAFiBhcJcVb+Q5MYkr9o7a4nFltxXXlXnVNWuqtq1Z8+eRYYBABvGqsNcVdsz+1DYD/S+N6p3J7nL3GInJvn0Uvfv7vO7e1t3b9u69V+89w0Am9KqwlxVj07yc0ke193/OHfThUnOqqpbV9Vdk5yc5H2LDxMANoeVfF3q1UkemuTYqtqd5FlJnpnk1kkurqokubS7f6y7r66qC5J8OLNd3E852CeyAYB9Dvp1qUPB16UA2GxW/XUpAODQEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGctCzS91SnXTumw73EGDNXHveYw73EIBDxBYzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIAcNc1W9vKqur6oPzc07pqourqqPTZdHz932zKr6eFV9tKoetV4DB4CNaCVbzK9I8uj95p2bZGd3n5xk53Q9VXVKkrOS3Gu6z0uq6og1Gy0AbHAHDXN3vyvJ3+03+4wkO6bpHUnOnJv/mu7+Snd/IsnHk5y2NkMFgI1vte8xH9fd1yXJdHmnaf4JST41t9zuaR4AsAJr/eGvWmJeL7lg1TlVtauqdu3Zs2eNhwEAt0yrDfNnqur4JJkur5/m705yl7nlTkzy6aUeoLvP7+5t3b1t69atqxwGAGwsqw3zhUm2T9Pbk7xxbv5ZVXXrqrprkpOTvG+xIQLA5rHlYAtU1auTPDTJsVW1O8mzkpyX5IKqOjvJJ5M8IUm6++qquiDJh5PcmOQp3X3TOo0dADacg4a5u79/mZtOX2b55yZ57iKDAoDNypG/AGAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDAD
wECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgSwU5qr6qaq6uqo+VFWvrqpvrKpjquriqvrYdHn0Wg0WADa6VYe5qk5I8rQk27r73kmOSHJWknOT7Ozuk5PsnK4DACuw6K7sLUluU1Vbktw2yaeTnJFkx3T7jiRnLrgOANg0Vh3m7v6bJL+S5JNJrkvyhe5+a5Ljuvu6aZnrktxpLQYKAJvBIruyj85s6/iuSe6c5HZV9aSbcf9zqmpXVe3as2fPaocBABvKIruyH57kE929p7v/KckbkjwoyWeq6vgkmS6vX+rO3X1+d2/r7m1bt25dYBgAsHEsEuZPJvm2qrptVVWS05Nck+TCJNunZbYneeNiQwSAzWPLau/Y3e+tqtcl+UCSG5NcnuT8JEcluaCqzs4s3k9Yi4ECwGaw6jAnSXc/K8mz9pv9lcy2ngGAm8mRvwBgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADGTL4R4AsDGddO6bDvcQYM1ce95jDtm6bDEDwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgC4W5qu5YVa+rqo9U1TVV9cCqOqaqLq6qj02XR6/VYAFgo1t0i/nXk7y5u++Z5L5JrklybpKd3X1ykp3TdQBgBVYd5qq6fZKHJPndJOnur3b3DUnOSLJjWmxHkjMXGyIAbB6LbDF/S5I9SX6vqi6vqpdV1e2SHNfd1yXJdHmnpe5cVedU1a6q2rVnz54FhgEAG8ciYd6S5P5JXtrd90vyD7kZu627+/zu3tbd27Zu3brAMABg41gkzLuT7O7u907XX5dZqD9TVccnyXR5/WJDBIDNY9Vh7u6/TfKpqrrHNOv0JB9OcmGS7dO87UneuNAIAWAT2bLg/Z+a5FVVdaskf5XkRzKL/QVVdXaSTyZ5woLrAIBNY6Ewd/cVSbYtcdPpizwuAGxWjvwFAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEsHOaqOqKqLq+qi6brx1TVxVX1seny6MWHCQCbw1psMf9kkmvmrp+bZGd3n5xk53QdAFiBhcJcVScmeUySl83NPiPJjml6R5IzF1kHAGwmi24x/1qSn03ytbl5x3X3dUkyXd5pwXUAwKax6jBX1fckub67L1vl/c+pql1VtWvPnj2rHQYAbCiLbDE/OMnjquraJK9J8l1V9ftJPlNVxyfJdHn9Unfu7vO7e1t3b9u6desCwwCAjWPVYe7uZ3b3id19UpKzkry9u5+U5MIk26fFtid548KjBIBNYj2+x3xekkdU1ceSPGK6DgCswJa1eJDuviTJJdP0
55KcvhaPCwCbjSN/AcBAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAA1l1mKvqLlX1jqq6pqqurqqfnOYfU1UXV9XHpsuj1264ALCxLbLFfGOS/9rd/zbJtyV5SlWdkuTcJDu7++QkO6frAMAKrDrM3X1dd39gmv5ikmuSnJDkjCQ7psV2JDlzwTECwKaxJu8xV9VJSe6X5L1Jjuvu65JZvJPcaS3WAQCbwcJhrqqjkrw+ydO7++9vxv3OqapdVbVrz549iw4DADaEhcJcVUdmFuVXdfcbptmfqarjp9uPT3L9Uvft7vO7e1t3b9u6desiwwCADWORT2VXkt9Nck13v2jupguTbJ+mtyd54+qHBwCby5YF7vvgJD+Y5KqqumKa9/NJzktyQVWdneSTSZ6w0AgBYBNZdZi7+8+T1DI3n77axwWAzcyRvwBgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIEIMwAMRJgBYCDCDAADEWYAGIgwA8BAhBkABiLMADAQYQaAgQgzAAxEmAFgIMIMAAMRZgAYiDADwECEGQAGIswAMBBhBoCBCDMADESYAWAgwgwAAxFmABiIMAPAQIQZAAYizAAwEGEGgIGsW5ir6tFV9dGq+nhVnbte6wGAjWRdwlxVRyT5zSTfneSUJN9fVaesx7oAYCNZry3m05J8vLv/qru/muQ1Sc5Yp3UBwIaxXmE+Icmn5q7vnuYBAAewZZ0et5aY11+3QNU5Sc6Zrn6pqj66TmNhfR2b5LOHexAbXT3/cI+AgXkNHgLr9Br810vNXK8w705yl7nrJyb59PwC3X1+kvPXaf0cIlW1q7u3He5xwGblNbjxrNeu7PcnObmq7lpVt0pyVpIL12ldALBhrMsWc3ffWFU/keQtSY5I8vLuvno91gUAG8l67cpOd/9Jkj9Zr8dnGN6OgMPLa3CDqe4++FIAwCHhkJwAMBBhJlX1pSXm/VhV/dDhGA9sNFV1U1VdUVUfrKoPVNWD1nl9J1XVhxZ8jH/x7wKHxrq9x8wtW3f/1no+flVVZm+lfG091wOD+HJ3n5okVfWoJM9L8p3zC1TVEd1902EYG4OxxcySqurZVfWMafqSqnp+Vb2vqv6yqr5jmn9EVb2wqt5fVVdW1X+e5h9VVTunLYOrquqMaf5JVXVNVb0kyQfy9d91h83i9kk+nyRV9dCqekdV/UGSq6Z5f1xVl1XV1dOBmDLN/1JVPXfa6r60qo6b5h9XVX80zf/g3Nb4EVX1O9PjvLWqbjMtf7eqevO0jj+rqntO8+9aVX8xvZ6fcwifD/YjzKzUlu4+LcnTkzxrmnd2ki9097cm+dYkP1pVd03y/5I8vrvvn+RhSX512kJOknskeWV336+7//qQ/gZw+Nxm2pX9kSQvSzIfvtOS/EJ37z3Rz5O7+wFJtiV5WlV98zT/dkku7e77JnlXkh+d5v+vJO+c5t8/yd6vpp6c5De7+15Jbkjy
vdP885M8dVrHM5K8ZJr/60leOr2e/3aNfm9Wwa5sVuoN0+VlSU6aph+Z5D5V9X3T9Ttk9o/B7iS/XFUPSfK1zI6Tfty0zF9396WHZMQwjvld2Q9M8sqquvd02/u6+xNzyz6tqh4/Td8ls9fU55J8NclF0/zLkjximv6uJD+UJNOu8C9U1dFJPtHdV8wtf1JVHZXkQUleu+//yrn1dPng7Iv3/07iQLCHiTCzUl+ZLm/Kvr83ldn/vN8yv2BV/XCSrUke0N3/VFXXJvnG6eZ/WP+hwri6+y+q6tjMXiPJ3Guiqh6a5OFJHtjd/1hVl2Tfa+efet/3W+dfh8v5ytz0TUluk9le0hv2/idhqeGt7LdgPdmVzSLekuTHq+rIJKmqu1fV7TLbcr5+ivLDssyB2mEzmt7TPSKzreD93SHJ56co3zPJt63gIXcm+fHpsY+oqtsvt2B3/32ST1TVE6blq6ruO9387swOn5wkP7CiX4Z1IcwkyW2ravfcz0+v8H4vS/LhJB+Yvprx25n9L/5VSbZV1a7MXuAfWZdRwy3H3veYr0jyh0m2L/MJ7Dcn2VJVV2b2PvRK3vb5ySQPq6qrMttlfa+DLP8DSc6uqg9m9n70GXOP85Sqen9m/0HgMHHkLwAYiC1mABiIMAPAQIQZAAYizAAwEGEGgIEIM+umqr5571dEqupvq+pvpukvTcfLXuv1nVlVpxx8yTFV1alV9e9Xcb83V9UNVXXRAZa5pKq2rfVYFjkDUVXduqreNv2deOJqH+cg61hkfNdOBwJZc1V1v6p62QFu31pVb16PdTM+YWbddPfnuvvU6ShDv5Xkf07Xj+ru/7IOqzwzydBhrqoDHa3p1CQ3O8xJXpjkB1c1oOWdmtWN5ea4X5Ijp78Tf7jO6xrNzyd58XI3dveeJNdV1YMP3ZAYhTBzyE1n1Llomn52Ve2Yzn5zbVX9h6p6Qc3OSvXmuaOKPaCq3jmdEectVXX8fo/5oCSPS/LCaQvsbtNW36U1O/PVH03HD95/LHeblnl/Vf3S/BZWVf1M7Ttz1n+f5u09Q9bNOWvPK6rqRVX1jiTPr6rTquo9VXX5dHmPqrpVkl9K8sS9W5BVdbuqevk0hstrOkvX/rp7Z5IvruCpf9K0vg9V1WnT2FY6lqOq6vemP5crq2rvMZVTS5zxaL/n+JianTHpymmZ+1TVnZL8fpJT9/55LfHnstRz+diqeu803rfVvjMsLTK+Ze87t8y/OONTzY6y9Yrp+byqqn5qmv+0qvrw9FivWeKxvinJfbr7g9P176x9e5Yun25Pkj+OI3BtTt3tx8+6/yR5dpJnTNMPTXLR3Pw/T3Jkkvsm+cck3z3d9keZbQUfmeQ9SbZO85+Y5OVLrOMVSb5v7vqVSb5zmv6lJL+2xH0uSvL90/SPJfnSNP3IzM7CU5n9B/aiJA/J7AQeNyY5dVrugiRPmqZ3Jjl5mv53Sd4+N66LkhwxXb99ZmfrSmbHRX79NP3DSX5jbmy/PPfYd0zyl0lut8zz+8/P6TK3X5Lkd6bphyT50M0cy/Pnn78kR0+XneSx0/QLkvy3Jdb94iTPmqa/K8kVBxvzAZ7Lo7PvwEj/KcmvrsH4lrvvtUmOnaaPmS5vk+RDSb45yQOSXDx3vztOl59Ocuv5efut72F7n+fp+v9J8uBp+qi5P48Tklx1uF+7fg79j5NYMII/7dlxta/K7BjCe99buyqzEN4jyb2TXFyzM+IckeS6Az1gVd0hs38U3znN2pHktUss+sDM4p8kf5DkV6bpR04/l0/Xj8rsLD+fzM0/a0+SvLb3HYLxDkl2VNXJmYXjyGV+jUcmeVxN58XO7GQG/yrJNcssfzCvTpLufldV3b6q7pjkm1Y4lodn33GU092fnyaXO+PRvG/PdNai7n57zT57sOwhHw/yXJ6Y5A9rtsfkVkn2npVpkfEtd995S53x6aNJvqWqXpzkTUneOt1+ZZJXVdUfZ7bVu7/jk+yZu/7u
JC+qqlcleUN3757mX5/kzkvcnw1OmBnBV5Kku79WVfNn0PlaZn9HK8nV3f3AQzimSvK87v7tr5tZdVJWd9ae+bNqPSfJO7r78dPjXXKAMXxvd3/05g5+Gfsff7dv5liWOn7vSs54VEvMO9CxgA/0XL44yYu6+8KanYnp2Ws0vmXHU8uc8am7P1+zE0A8KslTkvzHJE9O8pjM9ko8LskvVtW9uvvGuYf8cvadMSrdfV5VvSmz9/QvraqHd/dHpmW+vNy42Li8x8wtwUeTbK3ZeWxTVUdW1VIH6v9iZluA6e4vJPl8VX3HdNsPJnnnEve5NPvOQXvW3Py3JHnytPWWqjphel90SX3gs/bs7w5J/maa/uGlxj83hqfWtNlYVfdbbv0r9MTpcb49yRem52ilY3lrkp/Ye6WWeL/+AN6V6b3SKXKfnZ6vJR3kuZwf7/Y1Gt/B7rvkGZ9q9ontb+ju1yf5xST3r6pvSHKX7n5Hkp/N7C2Io/Z7vGuS/Ju59d2tu6/q7ucn2ZXkntNNd89stzmbjDAzvO7+apLvy+yDUx9MckVmuzr395okPzN9gOZumf3D/cKanann1MzeZ97f05P8dFW9L7NdjF+Y1vnWzHZt/8W0i/11+fpQLWW5s/bs7wVJnldV785st/xe70hySu37+tBzMtu1fGXNzt71nKUerKr+LLPd9KfX7Oxgj1pmvZ+vqvdk9gn5s2/mWP5HkqOnDzp9MLP3SVfq2ZmdbezKJOfl64O6nOWey2dntov7z5J8dm75RcZ3sPsud8anE5JcUrMzRr0iyTMzew5/f/o7c3lm30S4Yf7Bpq3hO8x9yOvpc+v+cpI/neY/LLNd5Gwyzi7FplZVt03y5e7uqjorsw+CLRdUWBPTJ7i/2N0H+i7zu5Kcscx73mxg3mNms3tAkt+YdhffkNl7hLDeXprkCcvdWFVbM3svXZQ3IVvMADAQ7zEDwECEGQAGIswAMBBhBoCBCDMADESYAWAg/x/CrdXT5tSG2AAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 576x576 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Compare the time taken by the two sampling strategies\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "labels = [\"Linear\", \"Branched\"]\n",
    "times = [linear_time, branched_time]\n",
    "ax.bar(labels, times)\n",
    "# The bar heights are the timings, so the time label belongs on the y-axis\n",
    "ax.set_ylabel(\"Time to generate 1 batch of each class (s)\")\n",
    "ax.set_xlabel(\"Sampling strategy\")\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
