{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "1222c627",
   "metadata": {
    "id": "1222c627"
   },
   "source": [
    "# Exploring Diverse Solutions for Underdetermined Problems\n",
    "\n",
    "This notebook accompanies the workshop paper \"Exploring Diverse Solutions for Underdetermined Problems\".\n",
    "\n",
    "The goal of this notebook is to illustrate the nearest-neighbor diversity loss on finite-dimensional vector spaces and on function spaces.\n",
    "The notebook consists of three parts:\n",
    "1) **Horseshoe**: We start with examples in a finite-dimensional vector space and look at variants and properties of the nearest-neighbor and Leinster diversity losses.\n",
    "2) **Flat parametric curve**: We then show how the nearest-neighbor diversity loss acts on simple parametric curves in the flat plane.\n",
    "3) **Parametric curve on manifold**: Lastly, we apply the nearest-neighbor diversity loss to points and to a parametric curve on a manifold, the unit sphere."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f597c855",
   "metadata": {
    "id": "f597c855"
   },
   "outputs": [],
   "source": [
    "## GLOBAL IMPORTS\n",
    "import math\n",
    "from functools import partial\n",
    "from tqdm import trange\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import relu\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# utilities import\n",
    "from util import tensor_product_xz\n",
    "from model_defs import Net, ConditionalNet\n",
    "from sampling_primitives import sample_bbox, get_meshgrid_in_domain2d, get_meshgrid_in_domain3d\n",
    "from horse_shoe import horse_shoe_sdf, get_horse_shoe_bounds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "66d82b0d",
   "metadata": {
    "id": "66d82b0d"
   },
   "outputs": [],
   "source": [
    "## SET SEEDS\n",
    "np.random.seed(42)\n",
    "\n",
    "torch.manual_seed(42)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(42)\n",
    "torch.backends.cudnn.deterministic = True\n",
    "torch.backends.cudnn.benchmark = False"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d022685",
   "metadata": {
    "id": "2d022685",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## 1. Horseshoe\n",
    "\n",
    "In this example we show the nearest-neighbor diversity loss on a finite-dimensional vector space, concretely $\\mathbb{R}^2$.\n",
    "\n",
    "We introduce a design region, or envelope, in which the points should lie.\n",
    "For this example, the design region is a horseshoe, defined implicitly by a signed distance function (SDF).\n",
    "An SDF defines a shape by assigning each point in the domain a scalar value whose magnitude equals the distance to the boundary of the shape. To distinguish inside from outside, points inside the shape have negative values.\n",
    "\n",
    "More formally, let $\\phi(x)$ be a signed distance function. Then the region $\\Omega$ enclosed by the shape is defined as:\n",
    "\n",
    "$$\n",
    "\\Omega = \\{ x \\in \\mathbb{R}^n \\mid \\phi(x) < 0 \\}\n",
    "$$\n",
    "\n",
    "The boundary of the shape is the zero level set, since the distance to the boundary is 0 there:\n",
    "\n",
    "$$\n",
    "\\partial \\Omega = \\{ x \\in \\mathbb{R}^n \\mid \\phi(x) = 0 \\}\n",
    "$$\n",
    "\n"
   ]
  },
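  {
   "cell_type": "markdown",
   "id": "b7c1e2a0",
   "metadata": {},
   "source": [
    "As a minimal illustration of these definitions, the sketch below uses a disk instead of the horseshoe. The horseshoe SDF lives in the `horse_shoe` module and is not reproduced here. For a disk of radius $r$ centered at the origin, the SDF is $\\phi(x) = \\|x\\| - r$. The function name `circle_sdf` is our own, hypothetical choice.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def circle_sdf(x, radius=1.0):\n",
    "    # Signed distance to a circle centered at the origin:\n",
    "    # negative inside, zero on the boundary, positive outside.\n",
    "    return x.norm(dim=-1) - radius\n",
    "\n",
    "sample_pts = torch.tensor([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])\n",
    "phi = circle_sdf(sample_pts)  # values -1 (inside), 0 (boundary), 1 (outside)\n",
    "inside = phi < 0              # membership in Omega = {x | phi(x) < 0}\n",
    "print(phi, inside)\n",
    "```"
   ]
  },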
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1a617f3",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 424
    },
    "id": "a1a617f3",
    "outputId": "87f61d70-4b91-427e-85ab-7ba2ea5396ed"
   },
   "outputs": [],
   "source": [
    "horse_shoe_bounds = get_horse_shoe_bounds()\n",
    "X0_horse_shoe, X1_horse_shoe, pts = get_meshgrid_in_domain2d(horse_shoe_bounds)\n",
    "sdf_horse_shoe = horse_shoe_sdf(pts)\n",
    "sdf_horse_shoe = sdf_horse_shoe.reshape(X0_horse_shoe.shape).detach()\n",
    "\n",
    "im = plt.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=50)\n",
    "plt.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k')\n",
    "plt.axis('scaled')\n",
    "plt.colorbar(im)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5eb077df",
   "metadata": {
    "id": "5eb077df"
   },
   "source": [
    "Next, we define the nearest-neighbor diversity loss, following Berzins et al.\n",
    "This loss pushes each element of a set away from its nearest neighbor.\n",
    "\n",
    "To highlight some important aspects of this loss, we contrast it with another diversity loss, the Leinster diversity. It converts distances into similarities and therefore pushes each element away from all others.\n",
    "We will see later that this fails to spread the points evenly within the horseshoe and instead pushes them towards its boundary."
   ]
  },
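  {
   "cell_type": "markdown",
   "id": "c3d9a4f1",
   "metadata": {},
   "source": [
    "For reference, the two losses implemented in the next cell can be written as follows. Let $D_{ij}$ be the distance between elements $i$ and $j$ of a set of $n$ elements, and use uniform weights $p_i = 1/n$:\n",
    "\n",
    "$$\n",
    "\\delta_p = \\left( \\frac{1}{n} \\sum_{i=1}^{n} \\min_{j \\neq i} D_{ij}^{\\,p} \\right)^{1/p},\n",
    "\\qquad\n",
    "D_q^Z = \\left( \\sum_{i=1}^{n} p_i \\, (Zp)_i^{\\,q-1} \\right)^{1/(1-q)},\n",
    "$$\n",
    "\n",
    "with similarity matrix $Z_{ij} = e^{-D_{ij}}$ and $(Zp)_i = \\sum_{j} Z_{ij} \\, p_j$. The code returns $-\\delta_p$ and $-D_q^Z$, so minimizing the loss maximizes diversity."
   ]
  },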
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "118737c5",
   "metadata": {
    "id": "118737c5"
   },
   "outputs": [],
   "source": [
    "def nearest_neighbor_diversity(D, p=.5):\n",
    "    \"\"\"\n",
    "    Takes the pairwise dissimilarity matrix D and computes a diversity loss by repelling each element from its nearest neighbor.\n",
    "    The power p should be <= 1 so that d**p is concave in the distance d.\n",
    "    \"\"\"\n",
    "    # Mask the diagonal (distance of each element to itself); create the mask on D's device\n",
    "    mask = torch.eye(D.size(0), dtype=torch.bool, device=D.device)\n",
    "    D = D.masked_fill(mask, float('inf'))\n",
    "    nearest_neighbor_d, _ = D.min(dim=1)\n",
    "    # Negative power mean of the nearest-neighbor distances\n",
    "    return -(nearest_neighbor_d.pow(p)).mean().pow(1/p)\n",
    "\n",
    "def leinster_diversity(D, q=.5):\n",
    "    \"\"\"\n",
    "    Takes the pairwise dissimilarity matrix D and computes the Leinster diversity loss.\n",
    "    We follow Equation 6.5 in Leinster's book (https://arxiv.org/pdf/2012.02113).\n",
    "    The sampling probability p is assumed to be uniform, p = 1/n * one-vector.\n",
    "    \"\"\"\n",
    "    assert D.ndim == 2 and D.shape[0] == D.shape[1], \"D must be a square matrix\"\n",
    "    assert q > 0 and q != 1 and math.isfinite(q), \"valid range violated, see Leinster, p. 175\"\n",
    "\n",
    "    # convert to similarity matrix using formula on bottom of p. 172\n",
    "    Z = torch.exp(-D)\n",
    "\n",
    "    # (Zp)_i = sum_j Z_ij * p_j with uniform p_j = 1/n, i.e. the row mean of Z\n",
    "    Zp = Z.mean(dim=-1)\n",
    "    term_sum = torch.pow(Zp, q-1).mean()\n",
    "    div = term_sum ** (1 / (1-q))\n",
    "    return -div"
   ]
  },
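  {
   "cell_type": "markdown",
   "id": "d4e8b5a2",
   "metadata": {},
   "source": [
    "To build intuition, the sketch below evaluates a standalone copy of the nearest-neighbor loss on two made-up 1D point sets. The first has four evenly spaced points. The second has four points, two of which are clustered. The evenly spaced set should obtain the lower (more negative) loss, that is, the higher diversity.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def nearest_neighbor_diversity(D, p=0.5):\n",
    "    # Same computation as the cell above: mask the diagonal, take each\n",
    "    # element's nearest-neighbor distance, return minus their power mean.\n",
    "    mask = torch.eye(D.size(0), dtype=torch.bool, device=D.device)\n",
    "    D = D.masked_fill(mask, float('inf'))\n",
    "    nn_d, _ = D.min(dim=1)\n",
    "    return -(nn_d.pow(p)).mean().pow(1 / p)\n",
    "\n",
    "even = torch.tensor([[0.0], [1.0], [2.0], [3.0]])\n",
    "clustered = torch.tensor([[0.0], [0.1], [2.0], [3.0]])\n",
    "\n",
    "loss_even = nearest_neighbor_diversity(torch.cdist(even, even))\n",
    "loss_clustered = nearest_neighbor_diversity(torch.cdist(clustered, clustered))\n",
    "# All nearest-neighbor distances in `even` equal 1, so loss_even is -1;\n",
    "# the two close points in `clustered` raise its loss to about -0.43.\n",
    "print(loss_even.item(), loss_clustered.item())\n",
    "```"
   ]
  },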
  {
   "cell_type": "markdown",
   "id": "986c9637",
   "metadata": {
    "id": "986c9637"
   },
   "source": [
    "To keep the points within the design region, we introduce a design region loss, following Berzins et al.\n",
    "This loss acts as a counterforce to the diversity loss and constrains the points to stay within the horseshoe."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d17f454b",
   "metadata": {
    "id": "d17f454b"
   },
   "outputs": [],
   "source": [
    "def design_region_loss(y, sdf_func=horse_shoe_sdf):\n",
    "    \"\"\"Transform the SDF of a design region into an objective: 0 inside, squared distance outside.\"\"\"\n",
    "    loss_per_sample = sdf_func(y).relu().square()\n",
    "    return loss_per_sample.mean()\n"
   ]
  },
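  {
   "cell_type": "markdown",
   "id": "e5f7c6b3",
   "metadata": {},
   "source": [
    "Here is a quick sanity check of this construction, with a unit-circle SDF standing in for `horse_shoe_sdf`. The stand-in `circle_sdf` is our assumption and not part of the notebook's modules. Points inside contribute nothing, while a point at distance 1 outside the circle contributes $1^2 = 1$ before averaging.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def circle_sdf(y, radius=1.0):\n",
    "    # Hypothetical stand-in SDF: circle of the given radius centered at the origin.\n",
    "    return y.norm(dim=-1) - radius\n",
    "\n",
    "def design_region_loss(y, sdf_func=circle_sdf):\n",
    "    # 0 inside the region (sdf < 0), squared distance to the boundary outside.\n",
    "    return sdf_func(y).relu().square().mean()\n",
    "\n",
    "y_inside = torch.tensor([[0.5, 0.0], [0.0, -0.3]])\n",
    "y_mixed = torch.tensor([[0.5, 0.0], [2.0, 0.0]])\n",
    "print(design_region_loss(y_inside).item())  # 0.0\n",
    "print(design_region_loss(y_mixed).item())   # (0 + 1**2) / 2 = 0.5\n",
    "```\n",
    "\n",
    "The same hinge-squared pattern, `relu(r - envelope_radius).square()`, reappears as the envelope loss in Section 2."
   ]
  },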
  {
   "cell_type": "markdown",
   "id": "68916075",
   "metadata": {
    "id": "68916075",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Finite point set"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5a21a0e",
   "metadata": {
    "id": "a5a21a0e"
   },
   "source": [
    "Next, we directly optimize a finite point set under different variants of the nearest-neighbor and Leinster diversity losses.\n",
    "Importantly, we ablate different exponents $p$ and $q$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f12ae6bc",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "f12ae6bc",
    "outputId": "af66c1e3-2ed7-4d57-ab33-500fee87e27c"
   },
   "outputs": [],
   "source": [
    "def train_horse_shoe(div_loss_fn, lambda_div, n_points=1000, n_iter=1000):\n",
    "\n",
    "    x = sample_bbox(horse_shoe_bounds, N=n_points).detach()\n",
    "    x.requires_grad_(True)\n",
    "    opt = torch.optim.Adam([x], lr=1e-2)\n",
    "\n",
    "    for i in (pbar := trange(n_iter)):\n",
    "        opt.zero_grad()\n",
    "        loss_obj = design_region_loss(x)\n",
    "        pairwise_dist_mat = torch.cdist(x, x)\n",
    "        loss_div = lambda_div*div_loss_fn(pairwise_dist_mat)\n",
    "        loss = loss_obj + loss_div\n",
    "        loss.backward()\n",
    "        opt.step()\n",
    "\n",
    "        pbar.set_description(f\"{loss_obj.item():.2e}, {loss_div.item():.2e}\")\n",
    "\n",
    "    return x.detach()\n",
    "\n",
    "lambda_div = 1e-2\n",
    "n_points = 1000\n",
    "x_nn_05 = train_horse_shoe(partial(nearest_neighbor_diversity, p=0.5), lambda_div=lambda_div, n_points=n_points)\n",
    "x_nn_1 = train_horse_shoe(partial(nearest_neighbor_diversity, p=1), lambda_div=lambda_div, n_points=n_points)\n",
    "x_nn_2 = train_horse_shoe(partial(nearest_neighbor_diversity, p=2), lambda_div=lambda_div, n_points=n_points)\n",
    "\n",
    "# Leinster's formula excludes q = 1, so q = 0.99 stands in for it\n",
    "x_hill_05 = train_horse_shoe(partial(leinster_diversity, q=0.5), lambda_div=lambda_div, n_points=n_points)\n",
    "x_hill_1 = train_horse_shoe(partial(leinster_diversity, q=0.99), lambda_div=lambda_div, n_points=n_points)\n",
    "x_hill_2 = train_horse_shoe(partial(leinster_diversity, q=2), lambda_div=lambda_div, n_points=n_points)\n",
    "\n"
   ]
  },
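  {
   "cell_type": "markdown",
   "id": "f6a8d7c4",
   "metadata": {},
   "source": [
    "Leinster's formula is undefined at $q = 1$, where the exponent $1/(1-q)$ diverges. This is why the cell above uses $q = 0.99$ as a stand-in. The order-1 diversity is defined as the limit $D_1^Z = \\prod_i (Zp)_i^{-p_i} = \\exp\\left(-\\sum_i p_i \\log (Zp)_i\\right)$. The sketch below uses arbitrary random points and checks numerically that $D_q^Z$ approaches this limit as $q \\to 1$. The helper names are our own.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def leinster_diversity_value(D, q):\n",
    "    # Positive diversity D_q^Z with uniform weights p_i = 1/n (valid for q != 1).\n",
    "    Zp = torch.exp(-D).mean(dim=-1)\n",
    "    return Zp.pow(q - 1).mean() ** (1 / (1 - q))\n",
    "\n",
    "def leinster_diversity_order1(D):\n",
    "    # Limit q -> 1: exp(-sum_i p_i log (Zp)_i) with p_i = 1/n.\n",
    "    Zp = torch.exp(-D).mean(dim=-1)\n",
    "    return torch.exp(-Zp.log().mean())\n",
    "\n",
    "torch.manual_seed(0)\n",
    "x = torch.rand(50, 2, dtype=torch.float64)  # arbitrary points in the unit square\n",
    "D = torch.cdist(x, x)\n",
    "for q in [0.5, 0.9, 0.99, 0.999]:\n",
    "    print(q, leinster_diversity_value(D, q).item())\n",
    "print('q -> 1 limit:', leinster_diversity_order1(D).item())\n",
    "```"
   ]
  },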
  {
   "cell_type": "markdown",
   "id": "5e3580b8",
   "metadata": {
    "id": "5e3580b8"
   },
   "source": [
    "The following cell visualizes the results. There are two things to learn.\n",
    "First, the Leinster loss with similarity $Z_{ij} = e^{-D_{ij}}$ repels each point from all other points. Although points that are farther away are repelled with smaller weight, the loss is dominated by the largest distances. This pushes the points to the outer boundary of the domain.\n",
    "Second, the nearest-neighbor loss should be concave in the distance, which is achieved with a power $p \\le 1$. This way, smaller distances receive larger gradients than larger distances.\n",
    "\n",
    "Under the nearest-neighbor loss, the points behave similarly to gas molecules in a room: they are repelled only by their neighbors and can therefore spread more evenly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3683b55d",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 679
    },
    "id": "3683b55d",
    "outputId": "9ba18ae2-cad1-42f2-d18f-b17614835c21"
   },
   "outputs": [],
   "source": [
    "\n",
    "fig, axs = plt.subplots(2, 3, figsize=(12, 8))  # 2 rows, 3 columns\n",
    "axs = axs.ravel()\n",
    "\n",
    "point_sets = [x_nn_05, x_nn_1, x_nn_2, x_hill_05, x_hill_1, x_hill_2]\n",
    "titles = [\n",
    "    r'$\\delta_p$, p=0.5',\n",
    "    r'$\\delta_p$, p=1',\n",
    "    r'$\\delta_p$, p=2',\n",
    "    r'$D_q^Z$, q=0.5',\n",
    "    r'$D_q^Z$, q=0.99',\n",
    "    r'$D_q^Z$, q=2',\n",
    "    ]\n",
    "\n",
    "for i, (ax, pts) in enumerate(zip(axs, point_sets)):\n",
    "    ax.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k')\n",
    "    ax.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[sdf_horse_shoe.min(), 0, sdf_horse_shoe.max()], colors=['#0000ff', 'white'], alpha=0.5)\n",
    "    ax.scatter(*pts.T.detach(), c='r', label='y', marker='.', s=2)\n",
    "    ax.axis('off')\n",
    "    ax.axis('scaled')\n",
    "    ax.set_title(titles[i], fontsize=16)  # Set the title for each subplot\n",
    "\n",
    "plt.subplots_adjust(hspace=0.1, wspace=-0.1)\n",
    "plt.show()\n"
   ]
  },
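  {
   "cell_type": "markdown",
   "id": "a7b9e8d5",
   "metadata": {},
   "source": [
    "The role of the exponent shows up in the derivative of a single term $r^p$ with respect to a distance $r$: $\\frac{d}{dr} r^p = p\\, r^{p-1}$. The outer power $1/p$ in $\\delta_p$ only multiplies all gradients by a common factor. For $p < 1$ the derivative grows as $r \\to 0$, so the closest pairs receive the strongest push. For $p = 1$ all distances receive the same push. For $p > 1$ large distances dominate. The sketch checks this with autograd at two example distances of our choosing.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "r = torch.tensor([0.1, 1.0], requires_grad=True)  # a small and a large distance\n",
    "grads = {}\n",
    "for p in [0.5, 1.0, 2.0]:\n",
    "    # Autograd of sum(r**p) gives the elementwise derivative p * r**(p - 1)\n",
    "    (g,) = torch.autograd.grad(r.pow(p).sum(), r)\n",
    "    grads[p] = g\n",
    "    print(f'p = {p}: gradient at r = 0.1 is {g[0].item():.3f}, at r = 1.0 is {g[1].item():.3f}')\n",
    "```"
   ]
  },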
  {
   "cell_type": "markdown",
   "id": "4e8bd5c4",
   "metadata": {
    "id": "4e8bd5c4",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Neural Curve"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b21ace58",
   "metadata": {
    "id": "b21ace58"
   },
   "source": [
    "Next, we define a simple MLP that takes a scalar input $x$ (sampled from $[-1, 1]$ during training) and outputs 2D points. We optimize it to map its inputs into the horseshoe while increasing the diversity of its outputs, so that the resulting curve spreads over the design region.\n",
    "\n",
    "To run the experiment without diversity, set `lambda_div = 0` or remove the diversity term from the loss in the training loop."
   ]
  },
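  {
   "cell_type": "markdown",
   "id": "b8caf9e6",
   "metadata": {},
   "source": [
    "`Net` is imported from `model_defs`, which is not shown in this notebook. Purely as an assumption about what such a model might look like, a plain MLP built from `layer_widths=[nx, 40, 40, ny]` could be written with `torch.nn` as below. The activation (here `Softplus`) and the name `make_mlp` are our choices and need not match `model_defs.Net`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "def make_mlp(layer_widths):\n",
    "    # Hypothetical stand-in for model_defs.Net: Linear layers with a Softplus\n",
    "    # activation between them and no activation after the last layer.\n",
    "    layers = []\n",
    "    n_layers = len(layer_widths) - 1\n",
    "    for i in range(n_layers):\n",
    "        layers.append(nn.Linear(layer_widths[i], layer_widths[i + 1]))\n",
    "        if i < n_layers - 1:\n",
    "            layers.append(nn.Softplus())\n",
    "    return nn.Sequential(*layers)\n",
    "\n",
    "mlp = make_mlp([1, 40, 40, 2])\n",
    "x = torch.rand(100, 1) * 2 - 1  # scalar inputs sampled from [-1, 1]\n",
    "y = mlp(x)                      # one 2D point per input\n",
    "print(y.shape)                  # torch.Size([100, 2])\n",
    "```"
   ]
  },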
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aad2a092",
   "metadata": {
    "id": "aad2a092"
   },
   "outputs": [],
   "source": [
    "ny = 2\n",
    "nx = 1\n",
    "x_range = [-1, 1]\n",
    "n_samples = 100\n",
    "n_iter = 10000\n",
    "\n",
    "lambda_design = 10\n",
    "lambda_div = 1\n",
    "\n",
    "model2 = Net(layer_widths=[nx, 40, 40, ny])\n",
    "opt = torch.optim.Adam(model2.parameters(), lr=3e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6547e10f",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6547e10f",
    "outputId": "1896d0ac-c804-40fd-c704-bfeb8025e1ad"
   },
   "outputs": [],
   "source": [
    "## Logging\n",
    "loss_keys = ['design region', 'diversity']\n",
    "loss_over_iters = {key: {} for key in loss_keys}\n",
    "\n",
    "\n",
    "for i in (pbar := trange(n_iter)):\n",
    "    opt.zero_grad()\n",
    "\n",
    "    ## DESIGN REGION ##\n",
    "    x = torch.rand(size=(n_samples, 1)) * (x_range[1] - x_range[0]) + x_range[0]\n",
    "    x = torch.cat([x, torch.tensor([[0.0], [1.0]])])\n",
    "    y = model2(x)\n",
    "    loss_design = design_region_loss(y, sdf_func=horse_shoe_sdf)\n",
    "\n",
    "    ## DIVERSITY ##\n",
    "    D = torch.cdist(y, y)\n",
    "    diversity = nearest_neighbor_diversity(D, p=0.5)\n",
    "\n",
    "    loss_over_iters['design region'][i] = loss_design.item()\n",
    "    loss_over_iters['diversity'][i] = -diversity.item()\n",
    "\n",
    "    loss = lambda_design * loss_design + diversity * lambda_div\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "    pbar.set_description(f\"design region: {loss_design.item():.2e}, diversity: {diversity.item():.2f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec99d795",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 430
    },
    "id": "ec99d795",
    "outputId": "05ee58aa-c021-4a0a-9219-35f688b85e18"
   },
   "outputs": [],
   "source": [
    "plt.plot(list(loss_over_iters['design region'].keys()), list(loss_over_iters['design region'].values()), label='design region')\n",
    "plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')\n",
    "plt.semilogy()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b41ed27",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 554
    },
    "id": "1b41ed27",
    "outputId": "cb6cf03e-725a-4963-d2b8-e5a65503f958"
   },
   "outputs": [],
   "source": [
    "xs = [torch.linspace(x_range[0], x_range[1], 10*n_samples) for i in range(nx)]\n",
    "Xs = torch.meshgrid(*xs, indexing='ij')\n",
    "x = torch.vstack([X.flatten() for X in Xs]).T\n",
    "y = model2(x)\n",
    "points = y.detach().cpu().numpy()\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "\n",
    "ax = plt.gca()\n",
    "ax.contour(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[0], colors='k', linewidths=2)\n",
    "ax.contourf(X0_horse_shoe, X1_horse_shoe, sdf_horse_shoe, levels=[sdf_horse_shoe.min(), 0, sdf_horse_shoe.max()], colors=['#0000ff', 'white'], alpha=0.5)\n",
    "ax.plot(points[:, 0], points[:, 1], c='r', linewidth=3)\n",
    "ax.axis('off')\n",
    "ax.axis('scaled')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f803250e",
   "metadata": {
    "id": "f803250e",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## 2. Parametric function in 1D\n",
    "\n",
    "In this example, we apply the nearest-neighbor diversity loss to elements of a function space.\n",
    "The goal is to find a diverse set of curves that connect the points (-0.8, 0) and (0.8, 0).\n",
    "Again, all curves should stay within a design region (envelope), which in this case is a disk of radius 1.\n",
    "\n",
    "\n",
    "A neural network takes two variables, x and z, as input:\n",
    "- x is the curve parameter, i.e. the position along the parameterized curve (often denoted $t$ in mathematics). Here $x \\in [0, 1]$, with the end points attached at $x = 0$ and $x = 1$.\n",
    "- z is a latent code that distinguishes different curves; below, $z \\in [0, 1]^5$. Since z is continuous, the network can parameterize infinitely many curves.\n",
    "\n",
    "\n",
    "We start by defining a pairwise distance between curves."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d1d487dd",
   "metadata": {
    "id": "d1d487dd"
   },
   "outputs": [],
   "source": [
    "def pairwise_dist_curves(y):\n",
    "    \"\"\"\n",
    "    y: [num_curves, num_points, 2] tensor, all curves sampled at the same parameter values x\n",
    "    pairwise_dist[i, j] is the mean Euclidean distance between corresponding points of curves i and j\n",
    "    \"\"\"\n",
    "    pairwise_dist = (y.unsqueeze(0) - y.unsqueeze(1)).norm(dim=-1).mean(-1)\n",
    "    return pairwise_dist"
   ]
  },
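  {
   "cell_type": "markdown",
   "id": "c9dba0f7",
   "metadata": {},
   "source": [
    "`pairwise_dist_curves` compares curves pointwise. Both curves must be sampled at the same parameter values $x_k$, and the distance is the mean Euclidean distance between corresponding points, $D_{ij} = \\frac{1}{m}\\sum_{k=1}^{m} \\|y_i(x_k) - y_j(x_k)\\|$. Here is a small check with three made-up horizontal segments:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def pairwise_dist_curves(y):\n",
    "    # y: [num_curves, num_points, 2]; broadcasting yields [C, C, P, 2] differences,\n",
    "    # the norm over coordinates gives [C, C, P], the mean over points gives [C, C].\n",
    "    return (y.unsqueeze(0) - y.unsqueeze(1)).norm(dim=-1).mean(-1)\n",
    "\n",
    "t = torch.linspace(-1, 1, 20)\n",
    "heights = [0.0, 0.5, 2.0]\n",
    "# Three horizontal segments y = h, all sampled at the same parameter values t.\n",
    "curves = torch.stack([torch.stack([t, torch.full_like(t, h)], dim=-1) for h in heights])\n",
    "D = pairwise_dist_curves(curves)  # [3, 3], entries |h_i - h_j|\n",
    "print(D)\n",
    "```"
   ]
  },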
  {
   "cell_type": "markdown",
   "id": "cc93c36c",
   "metadata": {
    "id": "cc93c36c"
   },
   "source": [
    "Next, we train the network, sampling different curve parameters x and different latent codes z in each iteration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "577bbc4c",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 447
    },
    "id": "577bbc4c",
    "outputId": "4c4a246e-17d9-489b-e87b-f9a3e90aa0ed"
   },
   "outputs": [],
   "source": [
    "nx = 1\n",
    "ny = 2\n",
    "nz = 5\n",
    "\n",
    "n_iter = 3000\n",
    "lambda_div = 2\n",
    "lambda_envelope = 1\n",
    "lambda_interface = 10\n",
    "envelope_radius = 1\n",
    "n_latents_per_iter = 100\n",
    "n_points_per_iter = 50\n",
    "layer_widths = [nx+nz, 40, 40, ny]\n",
    "\n",
    "## End points\n",
    "x_endp = torch.tensor([[0.0], [1.0]])\n",
    "y_endp = torch.tensor([[-.8, 0],[.8,0]])\n",
    "\n",
    "# define model\n",
    "model2 = ConditionalNet(layer_widths=layer_widths)\n",
    "\n",
    "## Optimizer\n",
    "opt = torch.optim.Adam(model2.parameters(), lr=1e-2)\n",
    "\n",
    "# logging\n",
    "loss_over_iters = {}\n",
    "loss_over_iters['interface'] = {}\n",
    "loss_over_iters['envelope'] = {}\n",
    "loss_over_iters['diversity'] = {}\n",
    "\n",
    "\n",
    "## Train\n",
    "for i in (pbar := trange(n_iter)):\n",
    "    opt.zero_grad()\n",
    "\n",
    "    z = torch.rand(n_latents_per_iter, nz) ## number of codes, bz\n",
    "\n",
    "    ## INTERFACE\n",
    "    x_tp, z_tp = tensor_product_xz(x_endp, z)  # repeats and interleaves x and z\n",
    "    y = model2(x_tp, z_tp).reshape(len(z), len(x_endp), ny) ## [bz, bx, ny]\n",
    "    loss_interface = (y - y_endp[None,:,:]).square().sum() / len(y)\n",
    "\n",
    "    # FORWARD for envelope and diversity\n",
    "    x = torch.rand(n_points_per_iter, nx)\n",
    "    x_tp, z_tp = tensor_product_xz(x, z)\n",
    "    y = model2(x_tp, z_tp).reshape(len(z), len(x), ny) ## [bz, bx, ny]\n",
    "\n",
    "    ## ENVELOPE\n",
    "    r = y.norm(dim=2)\n",
    "    loss_envelope = relu(r-envelope_radius).square().sum()/nz\n",
    "\n",
    "    ## DIVERSITY\n",
    "    distances = pairwise_dist_curves(y)\n",
    "    loss_diversity = nearest_neighbor_diversity(distances)\n",
    "\n",
    "    # logging\n",
    "    loss_over_iters['interface'][i] = loss_interface.item()\n",
    "    loss_over_iters['envelope' ][i] = loss_envelope.item()\n",
    "    loss_over_iters['diversity'][i] = -loss_diversity.item()\n",
    "\n",
    "    # total\n",
    "    loss = lambda_interface*loss_interface + lambda_envelope*loss_envelope + lambda_div*loss_diversity\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "    pbar.set_description(f\"interface: {loss_interface.item():.2e}, envelope: {loss_envelope.item():.2e}, diversity: {loss_diversity.item():.2e}\")\n",
    "\n",
    "\n",
    "plt.plot(list(loss_over_iters['interface'].keys()), list(loss_over_iters['interface'].values()), label='interface')\n",
    "plt.plot(list(loss_over_iters['envelope' ].keys()), list(loss_over_iters['envelope'].values()),  label='envelope')\n",
    "plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')\n",
    "plt.semilogy()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
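  {
   "cell_type": "markdown",
   "id": "d0ecb1a8",
   "metadata": {},
   "source": [
    "The training loop above relies on `tensor_product_xz` from the `util` module, which is not shown. Judging from its comment and the subsequent `reshape(len(z), len(x), ny)`, it pairs every latent code with every curve parameter, with z varying slowest. The following is a hypothetical reconstruction under that assumption, not the actual helper. The name `tensor_product_xz_sketch` is ours.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def tensor_product_xz_sketch(x, z):\n",
    "    # Hypothetical: row k pairs z[k // len(x)] with x[k % len(x)],\n",
    "    # so model outputs can be reshaped to [len(z), len(x), ...].\n",
    "    x_tp = x.repeat(len(z), 1)\n",
    "    z_tp = z.repeat_interleave(len(x), dim=0)\n",
    "    return x_tp, z_tp\n",
    "\n",
    "x = torch.tensor([[0.0], [1.0]])                                # 2 curve parameters\n",
    "z = torch.tensor([[10.0, 11.0], [20.0, 21.0], [30.0, 31.0]])  # 3 latent codes\n",
    "x_tp, z_tp = tensor_product_xz_sketch(x, z)\n",
    "print(x_tp.squeeze(-1).tolist())  # [0.0, 1.0, 0.0, 1.0, 0.0, 1.0]\n",
    "print(z_tp[:, 0].tolist())        # [10.0, 10.0, 20.0, 20.0, 30.0, 30.0]\n",
    "```"
   ]
  },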
  {
   "cell_type": "markdown",
   "id": "4acd365d",
   "metadata": {
    "id": "4acd365d"
   },
   "source": [
    "The following cell plots a single parametric curve with a randomly sampled z."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4349e08",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 653
    },
    "id": "d4349e08",
    "outputId": "8f0487cc-459d-4e58-9130-67aa644e8d43"
   },
   "outputs": [],
   "source": [
    "n_curves_to_plot = 1\n",
    "x_resolution = 400\n",
    "display_envelope = False\n",
    "\n",
    "x = torch.linspace(0, 1, x_resolution)[:,None]\n",
    "z = torch.rand(n_curves_to_plot, nz) ## number of codes, bz\n",
    "\n",
    "x_tp, z_tp = tensor_product_xz(x, z)\n",
    "y = model2(x_tp, z_tp).detach().reshape(len(z), len(x), ny)\n",
    "y = y.detach().numpy()\n",
    "\n",
    "import matplotlib.patches as patches\n",
    "fig, ax = plt.subplots(figsize=(8,8))\n",
    "ax.plot(*y.T, c='k', alpha=0.8)\n",
    "ax.scatter(*y_endp.T, c='k')\n",
    "ax.axis('equal')\n",
    "ax.axis('off')\n",
    "if display_envelope:\n",
    "    circle = patches.Circle((0, 0), radius=envelope_radius, fill=False, color='lightblue', linewidth=2)\n",
    "    ax.add_patch(circle)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8da77c59",
   "metadata": {
    "id": "8da77c59"
   },
   "source": [
    "Now we plot multiple curves whose latent codes lie on a straight line in z space.\n",
    "As the plot shows, the diversity loss pushes the curves apart."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ce25ef7a",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 725
    },
    "id": "ce25ef7a",
    "outputId": "e0e372d2-9292-4388-cb48-c3d37a12967b"
   },
   "outputs": [],
   "source": [
    "steps = 10\n",
    "x_resolution = 400\n",
    "interpolate_random_endpoints = False\n",
    "display_envelope = False\n",
    "plot_use_alpha = True\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(8, 8))\n",
    "\n",
    "if interpolate_random_endpoints:\n",
    "    z0 = torch.rand(1,nz) ## latent code start\n",
    "    z1 = torch.rand(1,nz) ## latent code stop\n",
    "    print(z0-z1)\n",
    "else:\n",
    "    z0 = torch.tensor([[1.0,]*nz]) ## latent code start\n",
    "    z1 = torch.tensor([[0.0,]*nz]) ## latent code stop\n",
    "\n",
    "dz = (z1-z0)/steps\n",
    "x = torch.linspace(0, 1, x_resolution)[:,None]\n",
    "if plot_use_alpha:\n",
    "    alphas = np.hstack([np.linspace(.2, 1.0, steps//2), np.linspace(1.0, .2, steps//2)])\n",
    "else:\n",
    "    cmap = 'winter'\n",
    "    colormap = plt.get_cmap(cmap)\n",
    "    colors = [colormap(i) for i in np.linspace(0, 1, steps)]\n",
    "\n",
    "for i in trange(steps):\n",
    "    z = z0 + dz*i\n",
    "    x_tp, z_tp = tensor_product_xz(x, z)\n",
    "    y = model2(x_tp, z_tp).detach().reshape(len(z), len(x), ny).numpy()\n",
    "    if plot_use_alpha:\n",
    "        ax.plot(*y.T, c='k', alpha=alphas[i])\n",
    "    else:\n",
    "        ax.plot(*y.T, c=colors[i])\n",
    "\n",
    "\n",
    "ax.scatter(*y_endp.T, c='k')\n",
    "ax.axis('equal')\n",
    "ax.axis('off')\n",
    "if display_envelope:\n",
    "    circle = patches.Circle((0, 0), radius=envelope_radius, fill=False, color='lightblue', linewidth=2)\n",
    "    ax.add_patch(circle)\n",
    "fig.patch.set_facecolor('white')\n",
    "plt.show()"
   ]
  },
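  {
   "cell_type": "markdown",
   "id": "e1fdc2b9",
   "metadata": {},
   "source": [
    "In the cell above, the latent code moves from `z0` towards `z1` in increments of `dz = (z1 - z0) / steps` for `i = 0, ..., steps - 1`. The last plotted curve therefore stops one increment short of `z1`. If both end codes should be included, a variant (our suggestion, not the original code) builds all codes at once with `torch.linspace`:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "nz, steps = 5, 10\n",
    "z0 = torch.ones(1, nz)   # latent code start\n",
    "z1 = torch.zeros(1, nz)  # latent code stop\n",
    "\n",
    "# Interpolation weights 0, 1/(steps-1), ..., 1, so both end codes are included.\n",
    "w = torch.linspace(0, 1, steps)[:, None]\n",
    "zs = (1 - w) * z0 + w * z1  # [steps, nz], one latent code per row\n",
    "print(zs[:, 0].tolist())\n",
    "```\n",
    "\n",
    "Each row `zs[i:i+1]` can then take the place of `z` in the plotting loop."
   ]
  },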
  {
   "cell_type": "markdown",
   "id": "b2db6a62",
   "metadata": {
    "id": "b2db6a62",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "## 3. Curves on a sphere\n",
    "\n",
    "In this example we want to find diverse points on a sphere. First, we solve this task with a finite set of points, similar to the horseshoe example.\n",
    "\n",
    "Next, we train a neural network that parameterizes a curve spreading over the manifold.\n",
    "Concretely, we learn a neural network with parameters $\\theta$ that represents a function $f_\\theta: \\mathbb{R} \\to \\mathbb{R}^3$ whose image should lie on the sphere.\n",
    "\n",
    "\n",
    "Again, we start by introducing a pairwise distance function, here between points on the sphere.\n",
    "Since we only use it to find nearest neighbors, which we assume to be close to each other, we use the Euclidean distance as a local approximation of the geodesic distance."
   ]
  },
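  {
   "cell_type": "markdown",
   "id": "f20ed3ca",
   "metadata": {},
   "source": [
    "Consider two points on a sphere of radius $R$ separated by the angle $\\theta$. Their geodesic (great-circle) distance is $R\\theta$, and their Euclidean (chordal) distance is $2R\\sin(\\theta/2)$. The ratio does not depend on $R$ and tends to 1 as $\\theta \\to 0$. Since $2R\\sin(\\theta/2)$ is increasing on $[0, \\pi]$, both distances also rank neighbors identically, so the nearest neighbor itself is the same. A quick numerical check (the helper name is ours):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def chord_over_geodesic(theta, R=1.0):\n",
    "    # Ratio of Euclidean (chordal) to geodesic distance for two points\n",
    "    # on a sphere of radius R separated by the angle theta (in radians).\n",
    "    chord = 2 * R * math.sin(theta / 2)\n",
    "    geodesic = R * theta\n",
    "    return chord / geodesic\n",
    "\n",
    "for theta in [0.01, 0.1, 0.5, 1.0, math.pi]:\n",
    "    print(f'theta = {theta:.2f}: chord / geodesic = {chord_over_geodesic(theta):.4f}')\n",
    "```"
   ]
  },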
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bafd1e62",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "bafd1e62",
    "outputId": "cb7a4ca8-8d29-41e5-993e-46e686a92624",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "### Needed to enable k3d in Colab. In a local environment, the jupyter nbextension commands (and the google.colab import in the next cell) are not needed\n",
    "!pip install k3d\n",
    "!jupyter nbextension install --py --user k3d\n",
    "!jupyter nbextension enable --py --user k3d\n",
    "\n",
    "import k3d"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cf3096b1-308d-42f8-b639-f9964293510a",
   "metadata": {},
   "outputs": [],
   "source": [
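    "# Colab-only setup: enable third-party widgets so k3d plots can render\n",
    "try:\n",
    "    from google.colab import output\n",
    "    output.enable_custom_widget_manager()\n",
    "except ImportError:\n",
    "    pass  # not running in Colab\n",
    "\n",
    "k3d.switch_to_text_protocol()"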
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "42bdbd84",
   "metadata": {
    "id": "42bdbd84"
   },
   "outputs": [],
   "source": [
    "def pairwise_dist_points_on_sphere(points):\n",
    "    \"\"\"\n",
    "    points: [num_points, 3] tensor\n",
    "    pairwise_dist[i, j] contains the Euclidean distance between points i and j\n",
    "    \"\"\"\n",
    "    pairwise_dist = torch.norm(points[:, None] - points, dim=2, p=2)\n",
    "    return pairwise_dist\n",
    "\n",
    "\n",
    "def pairwise_dist_curves_on_sphere(y):\n",
    "    \"\"\"\n",
    "    y: [num_points, 3] tensor of points sampled along a single curve\n",
    "    pairwise_dist[i, j] contains the Euclidean distance between curve points i and j\n",
    "    \"\"\"\n",
    "    pairwise_dist = torch.norm(y[:, None] - y, dim=2, p=2)\n",
    "    return pairwise_dist"
   ]
  },
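  {
   "cell_type": "markdown",
   "id": "a31fe4db",
   "metadata": {},
   "source": [
    "Both functions compute an all-pairs Euclidean distance matrix from an `[n, 3]` tensor. As a cross-check, `torch.cdist` returns the same matrix. Its default `compute_mode` may switch to a faster matrix-multiplication formula for larger inputs, which is slightly less precise, so the sketch requests the direct computation explicitly. The random points are arbitrary.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def pairwise_dist_points_on_sphere(points):\n",
    "    # Broadcast [n, 1, 3] - [n, 3] -> [n, n, 3], then take the norm over coordinates.\n",
    "    return torch.norm(points[:, None] - points, dim=2, p=2)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "# Arbitrary points on the unit sphere.\n",
    "pts_demo = torch.nn.functional.normalize(torch.randn(200, 3), dim=-1)\n",
    "D_broadcast = pairwise_dist_points_on_sphere(pts_demo)\n",
    "# Direct (non-matrix-multiplication) mode for a like-for-like comparison.\n",
    "D_cdist = torch.cdist(pts_demo, pts_demo, compute_mode='donot_use_mm_for_euclid_dist')\n",
    "print((D_broadcast - D_cdist).abs().max().item())\n",
    "```"
   ]
  },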
  {
   "cell_type": "markdown",
   "id": "a8798d69",
   "metadata": {
    "id": "a8798d69"
   },
   "source": [
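    "We define the sphericality loss as the squared distance of a point to the sphere, $(\\|y\\| - R)^2$, averaged over all points.\n",
    "Note that this is a soft constraint and only leads to approximate solutions: the points lie close to, but not exactly on, the sphere."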
   ]
  },
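  {
   "cell_type": "markdown",
   "id": "b42a05ec",
   "metadata": {},
   "source": [
    "In the training loops below, this loss is computed inline from `r = y.norm(dim=-1)` as `(r - sphere_radius).square().mean()`. Unlike the design region loss, it penalizes points both inside and outside the sphere. Here it is as a standalone sketch (the function name `sphericality_loss` is ours):\n",
    "\n",
    "```python\n",
    "import torch\n",
    "\n",
    "def sphericality_loss(y, radius=1.0):\n",
    "    # Mean squared distance of each point to the sphere of the given radius.\n",
    "    r = y.norm(dim=-1)\n",
    "    return (r - radius).square().mean()\n",
    "\n",
    "y = torch.tensor([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 0.5]])\n",
    "# Distances to the unit sphere are 0, 1 and 0.5, so the loss is (0 + 1 + 0.25) / 3.\n",
    "print(sphericality_loss(y).item())\n",
    "```"
   ]
  },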
  {
   "cell_type": "markdown",
   "id": "9d0e67c8",
   "metadata": {
    "id": "9d0e67c8",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Finite Set of Points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2757c48f",
   "metadata": {
    "id": "2757c48f"
   },
   "outputs": [],
   "source": [
    "ny = 3\n",
    "nx = 1\n",
    "x_range = [-1., 1.]\n",
    "n_samples = 500\n",
    "n_iter = 1000\n",
    "sphere_radius = 1\n",
    "\n",
    "lambda_spherical = 100\n",
    "lambda_div = 1\n",
    "\n",
    "box_bounds = torch.tensor([[x_range[0], x_range[1]], [x_range[0], x_range[1]], [x_range[0], x_range[1]]])  # [3, 2] tensor for 3D box bounds\n",
    "pts = sample_bbox(box_bounds, N=n_samples)\n",
    "pts.requires_grad_(True)\n",
    "opt = torch.optim.Adam([pts], lr=3e-3)\n",
    "\n",
    "## Logging\n",
    "loss_keys = ['sphericality', 'diversity']\n",
    "loss_over_iters = {key: {} for key in loss_keys}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "452f7246",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 447
    },
    "id": "452f7246",
    "outputId": "83276f70-90ab-400f-c8d1-e4acfa53e8ff"
   },
   "outputs": [],
   "source": [
    "for i in (pbar := trange(n_iter)):\n",
    "    opt.zero_grad()\n",
    "\n",
    "    ## SPHERICALITY ##\n",
    "    r = pts.norm(dim=-1)\n",
    "    loss_sphericality = (r-sphere_radius).square().mean()\n",
    "\n",
    "    ## DIVERSITY ##\n",
    "    D = pairwise_dist_points_on_sphere(pts)\n",
    "    diversity = nearest_neighbor_diversity(D)\n",
    "\n",
    "    loss_over_iters['sphericality'][i] = loss_sphericality.item()\n",
    "    loss_over_iters['diversity'][i] = - diversity.item()\n",
    "\n",
    "    loss = lambda_spherical * loss_sphericality + diversity * lambda_div\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "\n",
    "    pbar.set_description(f\"sphericality: {loss_sphericality.item():.2e}, diversity: {diversity.item():.2f}\")\n",
    "\n",
    "\n",
    "plt.plot(list(loss_over_iters['sphericality'].keys()), list(loss_over_iters['sphericality'].values()), label='sphericality')\n",
    "plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')\n",
    "plt.semilogy()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "40c49117",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "d8f747934bbb47949e63523cfb6170dd",
      "67ec8b04b6bc4e4dac35595aec5080e9",
      "a1fc19349ea54f3eb9ec1aa6a360a2a3",
      "95ff438c045c47d8a802db3feb2b032d"
     ]
    },
    "id": "40c49117",
    "outputId": "2dcb8ca8-26e6-4dc2-9aa7-a9810405ff80"
   },
   "outputs": [],
   "source": [
    "### Caution: k3d can be unreliable in Colab and sometimes shows only empty plots\n",
    "import k3d\n",
    "plot = k3d.plot(height=1000)\n",
    "plot += k3d.points(pts.detach().cpu().numpy(), point_size=0.03, color=0xff0000)\n",
    "plot += k3d.points([0,0,0], point_size=2, shader=\"mesh\", mesh_detail=10, color=0x0000ff, opacity=0.5)\n",
    "\n",
    "plot.display()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cbf7703f",
   "metadata": {
    "id": "cbf7703f",
    "jp-MarkdownHeadingCollapsed": true
   },
   "source": [
    "### Neural curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc2ee378",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 454
    },
    "id": "dc2ee378",
    "outputId": "5c0484c8-6b7c-4213-a272-4972e104227f"
   },
   "outputs": [],
   "source": [
    "ny = 3\n",
    "nx = 1\n",
    "x_range = [-1, 1]\n",
    "n_samples = 100\n",
    "n_iter = 2500\n",
    "sphere_radius = 1\n",
    "\n",
    "lambda_spherical = 1\n",
    "lambda_div = 1\n",
    "torch.manual_seed(0)\n",
    "\n",
    "model3 = Net(layer_widths=[nx, 40, 40, ny])\n",
    "opt = torch.optim.Adam(model3.parameters(), lr=1e-3)\n",
    "\n",
    "## Logging\n",
    "loss_keys = ['sphericality', 'diversity']\n",
    "loss_over_iters = {key: {} for key in loss_keys}\n",
    "\n",
    "\n",
    "for i in (pbar := trange(n_iter)):\n",
    "    opt.zero_grad()\n",
    "\n",
    "    ## SPHERICALITY ##\n",
    "    x = torch.rand(size=(n_samples, 1)) * (x_range[1] - x_range[0]) + x_range[0]\n",
    "    y = model3(x)\n",
    "    r = y.norm(dim=-1)\n",
    "    loss_sphericality = (r-sphere_radius).square().mean()\n",
    "\n",
    "    ## DIVERSITY ##\n",
    "    D = pairwise_dist_curves_on_sphere(y)\n",
    "    diversity = nearest_neighbor_diversity(D)\n",
    "\n",
    "    loss_over_iters['sphericality'][i] = loss_sphericality.item()\n",
    "    loss_over_iters['diversity'][i] = - diversity.item()\n",
    "\n",
    "    loss = lambda_spherical * loss_sphericality + diversity * lambda_div\n",
    "    loss.backward()\n",
    "    opt.step()\n",
    "\n",
    "\n",
    "    pbar.set_description(f\"sphericality: {loss_sphericality.item():.2e}, diversity: {diversity.item():.2f}\")\n",
    "\n",
    "\n",
    "plt.plot(list(loss_over_iters['sphericality'].keys()), list(loss_over_iters['sphericality'].values()), label='sphericality')\n",
    "plt.plot(list(loss_over_iters['diversity'].keys()), list(loss_over_iters['diversity'].values()), label='diversity')\n",
    "plt.semilogy()\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dfc2438",
   "metadata": {
    "id": "8dfc2438"
   },
   "source": [
    "A 3D visualization demonstrates that the learned curve spreads over the sphere.\n",
    "Feel free to decrease `lambda_div`, re-run the training cell, and compare the result."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65394560",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "993091014b36457a89035ed1964813aa",
      "3eb443527a3a4223b65fad1560beb3d2",
      "587894c1862f433a985d35f47e3e2795",
      "3ad1a84dd0eb4fd384f4c01f674ba52f"
     ]
    },
    "id": "65394560",
    "outputId": "64f535be-7a08-4024-dc53-a421f2391fe9"
   },
   "outputs": [],
   "source": [
    "# Caution: k3d rendering in Colab is unreliable; the plot below may sometimes appear empty.\n",
    "n_samples = 10000\n",
    "\n",
    "xs = [torch.linspace(x_range[0], x_range[1], n_samples) for _ in range(nx)]\n",
    "Xs = torch.meshgrid(*xs, indexing=\"ij\")\n",
    "x = torch.vstack([X.flatten() for X in Xs]).T\n",
    "y = model3(x).detach().cpu().numpy()\n",
    "\n",
    "plot2 = k3d.plot(height=1000)\n",
    "plot2 += k3d.points(y, point_size=0.03, color=0xff0000)\n",
    "plot2 += k3d.points([0,0,0], point_size=2, shader=\"mesh\", mesh_detail=10, color=0x0000ff, opacity=0.5)\n",
    "plot2.display()"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
