{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2961053-64be-4bbd-90d4-d2862967ffdb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import autograd\n",
    "# NOTE: `np` throughout this notebook is autograd's numpy wrapper, not plain numpy.\n",
    "import autograd.numpy as np\n",
    "import torchdiffeq\n",
    "from torchdiffeq import odeint\n",
    "import geomloss\n",
    "from tqdm import tqdm\n",
    "import importlib\n",
    "import math\n",
    "import torchsde\n",
    "# Run on the first GPU when one is available, otherwise on the CPU.\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Tensors created without an explicit dtype default to float32.\n",
    "torch.set_default_dtype(torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a74a4a1f-939a-42bd-8055-791301f0de71",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "N = 500\n",
    "T = 5  # only used in the data file name below; recomputed from the data further down\n",
    "dim = 2\n",
    "betamax = 5.0\n",
    "seed = 2\n",
    "reg = 'vf'\n",
    "# Seed the global RNGs (Python, numpy, torch) so the notebook's random draws\n",
    "# (np.random jitter / permutations, torch.randperm, ...) are reproducible on Restart & Run All.\n",
    "import random\n",
    "random.seed(seed)\n",
    "np.random.seed(seed)\n",
    "torch.manual_seed(seed)\n",
    "# weights_only=False unpickles arbitrary Python objects: only load files from a trusted source.\n",
    "data = torch.load(f\"sim_twowell_N_{N}_T_{T}_dim_{dim}_D_0.25_beta_{betamax}.pkl\", weights_only = False)\n",
    "data_largeT = torch.load(f\"sim_twowell_N_{N}_T_25_dim_{dim}_D_0.25_beta_0.0.pkl\", weights_only = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b407cb7-0bbc-427e-979a-c65a30f199f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Keep only the first coordinate of each sample so the concept figures are one-dimensional.\n",
    "dim = 1\n",
    "data['x'] = data['x'][:, :dim]\n",
    "data_largeT['x'] = data_largeT['x'][:, :dim]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c316b2a-65ec-4be9-bf64-1c3b50a70134",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Re-derive the number of timepoints and the dimension from the loaded data.\n",
    "T = len(np.unique(data['t_idx']))\n",
    "dim = data[\"x\"].shape[1]\n",
    "ts = torch.linspace(0, data['t_final'], T)  # observation times, evenly spaced on [0, t_final]\n",
    "\n",
    "# Raw snapshots (first coordinate vs time) with a small horizontal jitter so points do not overlap.\n",
    "jitter = np.random.randn(len(data['t_idx']))*0.01\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "plt.scatter(ts[data['t_idx']] + jitter, data['x'][:, 0], alpha = 0.05, s = 1, color = \"red\", rasterized = True)\n",
    "plt.xlabel(\"$t$\")\n",
    "plt.ylabel(\"$x_0$\")\n",
    "plt.title(\"Data\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f8ccc741-b076-43f6-986f-4d84fdd71078",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Display the working dimension: 1 once data['x'] has been restricted to its first coordinate.\n",
    "dim"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41bec35d-942a-4c8b-b045-cdda023fad02",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "# Make the project modules in ../../src (models, utils, train) importable.\n",
    "sys.path.append(\"../../src\")\n",
    "import models, utils, train\n",
    "from torch import optim\n",
    "# Reload so edits to models.py are picked up without restarting the kernel.\n",
    "importlib.reload(models)\n",
    "# Score fitting: a time-dependent score network trained via train.train_denoising_score on the per-timepoint samples.\n",
    "# NOTE(review): indexing with `[mask, None]` gives (n_i, 1, dim) tensors, unlike the (n_i, dim) X built for\n",
    "# PFI/UPFI training further down - confirm train_denoising_score expects this extra singleton axis.\n",
    "X = [torch.tensor(data[\"x\"][data[\"t_idx\"] == i, None], device = device, dtype = torch.float32) for i in np.sort(np.unique(data[\"t_idx\"]))]\n",
    "s = models.NCScoreFunc(d = dim, hidden_sizes = [64, 64, 64], activation = torch.nn.ReLU, time_dependent = True).to(device)\n",
    "s_opt = optim.AdamW(s.parameters(), lr = 1e-2)\n",
    "# Five noise levels exp(0), exp(-0.5), ..., exp(-2) in decreasing order; sigmas[-1] is the smallest.\n",
    "sigmas = torch.linspace(0, -2, 5).exp().to(device)\n",
    "# 256 is presumably the batch size - confirm against train.train_denoising_score.\n",
    "s_trace = train.train_denoising_score(s, s_opt, sigmas, {'X' : X, 't' : ts.to(device)}, \n",
    "                                      256,\n",
    "                                      options = {'iters' : 10_000,\n",
    "                                                 'print_iter' : 1000,\n",
    "                                                 'checkpoint_iter' : None,\n",
    "                                                 'checkpoint_file' : f\"params_NCScoreFunc_\",\n",
    "                                                 'save_final' : False,\n",
    "                                                 'save_file' : f\"params_NCScoreFunc_\",\n",
    "                                                 'outdir' : None},\n",
    "                                      sample_batch_options = {'replacement' : True, })"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4177dd81-beb9-45af-824e-ff4d64f2a4ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Learned score s(t, x) at the smallest noise level, drawn as vertical arrows on a (t, x) grid over the data.\n",
    "# The time axis spans the observed range [0, t_final], the same range as `ts` used for training.\n",
    "_ts = torch.linspace(0, data['t_final'], 15)\n",
    "_xs = torch.linspace(-1.5, 1.5, 15)\n",
    "# Distinct names so the per-timepoint sample list `X` is not overwritten by the mesh.\n",
    "t_grid, x_grid = np.meshgrid(_ts, _xs)\n",
    "s.cpu()\n",
    "plt.figure(figsize = (3, 3))\n",
    "with torch.no_grad():\n",
    "    # One column per time point; assumes s returns shape (len(_xs), 1) for each call - TODO confirm.\n",
    "    _s = torch.hstack([s(_t, _xs[:, None], sigmas[-1].cpu()) for _t in _ts])\n",
    "# Move the score network back so later cells do not depend on re-running the training cells after this one.\n",
    "s.to(device)\n",
    "_s = torch.clamp(_s, -15, 15)  # cap arrow lengths for display\n",
    "plt.quiver(t_grid, x_grid, torch.zeros_like(_s), _s, _s, cmap = \"RdBu_r\", scale_units = 'y')\n",
    "plt.scatter(ts[data['t_idx']] + np.random.randn(len(data['t_idx']))*0.025, data['x'][:, 0], alpha = 0.1, s = 1, c = 'k', rasterized = True)\n",
    "plt.xlabel(\"$t$\"); plt.ylabel(\"$x$\"); plt.axis(\"off\")\n",
    "plt.savefig(\"../../figures/concept_score.pdf\", dpi = 600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7703d072-8cd1-43b5-b74b-275bcfe6ce5b",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-timepoint samples as (n_i, dim) tensors; `:` (not `None`) avoids a spurious singleton axis\n",
    "# and matches the X used for PFI/UPFI training further down.\n",
    "X = [torch.tensor(data[\"x\"][data[\"t_idx\"] == i, :], device = device, dtype = torch.float32) for i in np.sort(np.unique(data[\"t_idx\"]))]\n",
    "# Relative mass at each timepoint: sample count divided by the count at the first timepoint.\n",
    "m_ratios = torch.tensor([x.shape[0] / X[0].shape[0] for x in X], dtype = torch.float32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dc02a8de-06ec-4d14-8170-d5fdd036071e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from sklearn.mixture import GaussianMixture\n",
    "from matplotlib.collections import LineCollection\n",
    "from mpl_toolkits.mplot3d.art3d import Line3DCollection\n",
    "from matplotlib.colors import ListedColormap\n",
    "import matplotlib.cm  as cm\n",
    "\n",
    "# Concept figure: per-timepoint densities (2-component GMM fits) stacked along t, sample paths\n",
    "# from the long-horizon simulation, and a grey bar per timepoint scaled by its relative mass.\n",
    "fig = plt.figure(figsize=(5, 5))\n",
    "ax = fig.add_subplot(111, projection='3d')\n",
    "x = np.linspace(-1.3, 1.5, 100)\n",
    "\n",
    "gmms = [GaussianMixture(n_components=2, random_state=42) for _ in range(T)]\n",
    "# plt.get_cmap(name, lut) works across matplotlib versions; matplotlib.cm.get_cmap was removed in 3.9.\n",
    "cmap = plt.get_cmap('Blues', T)\n",
    "for i in range(T):\n",
    "    gmms[i].fit(data['x'][data['t_idx'] == i, 0][:, None])\n",
    "    y = np.full_like(x, i / (T-1))  # each density curve sits at depth y = i / (T-1)\n",
    "    z = np.exp(gmms[i].score_samples(x[:, None]))\n",
    "    ax.plot(x, y, z, 'k-', linewidth=2, label='Curve')\n",
    "    # Fill between the density curve and z = 0 with a two-row surface.\n",
    "    xx = np.vstack([x, x])\n",
    "    yy = np.vstack([y, y])\n",
    "    zz = np.vstack([z, np.zeros_like(x)])  # top row: density values, bottom row: zeros\n",
    "    ax.plot_surface(xx, yy, zz, color = cmap(i), alpha=0.7, \n",
    "                    linewidth=0, antialiased=True, shade = False)\n",
    "\n",
    "# Draw the path subset from the array that is actually indexed (data_largeT, not data).\n",
    "n_paths = data_largeT['x_paths'].shape[0]\n",
    "path_t = torch.linspace(0, 1, data_largeT['x_paths'].shape[1])\n",
    "for i in np.random.permutation(n_paths)[:100]:\n",
    "    ax.plot(data_largeT['x_paths'][i, :, 0], path_t, alpha = 0.1, c = 'k')\n",
    "\n",
    "# Red line along x (slightly in front of t = 0), each segment shaded by tanh(2x).\n",
    "norm = plt.Normalize(-1, 1)\n",
    "y = np.zeros_like(x) - 0.1\n",
    "z = np.zeros_like(x)\n",
    "points = np.array([x, y, z]).T.reshape(-1, 1, 3)\n",
    "segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
    "lc = Line3DCollection(segments, cmap='Reds', norm = norm)\n",
    "lc.set_array(np.tanh(2*x))\n",
    "lc.set_linewidth(5)\n",
    "ax.add_collection3d(lc)\n",
    "\n",
    "# Hide the 3D panes and their edges for a clean schematic.\n",
    "ax.xaxis.pane.fill = False\n",
    "ax.yaxis.pane.fill = False\n",
    "ax.zaxis.pane.fill = False\n",
    "\n",
    "ax.xaxis.pane.set_edgecolor('none')\n",
    "ax.yaxis.pane.set_edgecolor('none')\n",
    "ax.zaxis.pane.set_edgecolor('none')\n",
    "\n",
    "# Thin grey bars (dx = 0) at x = -1.5, one per timepoint, with height 0.1 * relative mass.\n",
    "ax.bar3d(np.zeros_like(m_ratios)-1.5, ts - 0.075, np.zeros_like(m_ratios), 0, 0.15, m_ratios * 0.1, shade = False, color = 'lightgrey')\n",
    "\n",
    "ax.set_axis_off()\n",
    "\n",
    "ax.set_xlabel('$x$')\n",
    "ax.set_ylabel('$t$')\n",
    "ax.view_init(elev=30, azim=-35)\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/concept_distributions_paths.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a623fa1-d31d-41f6-8dbe-dd3363305e0f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Per-timepoint samples: one (n_i, dim) float32 tensor on `device` per time index, in increasing order.\n",
    "X = []\n",
    "for t_index in np.unique(data[\"t_idx\"]):  # np.unique already returns sorted values\n",
    "    mask = data[\"t_idx\"] == t_index\n",
    "    X.append(torch.tensor(data[\"x\"][mask, :], device = device, dtype = torch.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbfb749d-6a13-4843-948b-96242df2fc98",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "importlib.reload(models)\n",
    "importlib.reload(train)\n",
    "# Diffusivity D = 0.5**2 = 0.25, matching the D_0.25 in the simulation file names.\n",
    "D = 0.5**2\n",
    "num_iter = 1000\n",
    "# Ensure the score network is on `device` (the score-plotting cell calls s.cpu()).\n",
    "s.to(device)\n",
    "# PFI: time-dependent velocity field; train.train_pfi also receives the score network s and its noise levels.\n",
    "v_pfi = models.VectorField(dim, hidden_sizes = [64, 64, 64], time_dependent=True).to(device)\n",
    "opt_pfi = optim.Adam(v_pfi.parameters(), lr = 1e-2)\n",
    "# 'reg_kind' uses the `reg` setting from the configuration cell rather than a duplicated literal.\n",
    "trace_pfi = train.train_pfi(v_pfi, opt_pfi, s, sigmas,\n",
    "                            {'X' : X, 't' : ts.to(device)},\n",
    "                            {'D' : D, 'reg' : 0.001},\n",
    "                            256,\n",
    "                            options = {'iters' : num_iter, 'print_iter' : 100, 'reg_kind' : reg, 'checkpoint_iter' : None, 'checkpoint_file' : \"v_pfi\",\n",
    "                                       'save_final' : False, 'save_file' : 'v_pfi',\n",
    "                                       'anneal_sigma_iters' : None, 'outdir' : './', \n",
    "                                      'teacher_forcing_iter' : num_iter}, \n",
    "                            sample_batch_options = {'replacement' : True,  'add_noise' : False,}, \n",
    "                            samplesloss_options={'loss' : 'sinkhorn'}, \n",
    "                            odeint_options={'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (2*T)}})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b70f579b-3a37-41a7-b19f-70f5e1aad97a",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# UPFI: ODEFlowGrowth model; kwargs_v / kwargs_g configure its two sub-networks (presumably velocity and growth).\n",
    "v_upfi = models.ODEFlowGrowth(dim, kwargs_v = {'hidden_sizes' : [64, 64, 64]}, kwargs_g = {'hidden_sizes' : [64, ]}).to(device);\n",
    "opt_upfi = optim.Adam(v_upfi.parameters(), lr = 1e-2)\n",
    "alpha_wfr = 0.1\n",
    "reg_wfr = 0.001\n",
    "# m_ratios is already a tensor: copy it with detach().clone() rather than torch.tensor(tensor), which warns.\n",
    "trace_upfi = train.train_upfi(v_upfi, opt_upfi, s, sigmas,\n",
    "                              {'X' : X, 't' : ts.to(device), 'm_ratios' : m_ratios.detach().clone().to(device),},\n",
    "                              {'D' : D, 'alpha_wfr' : alpha_wfr, 'reg_wfr' : reg_wfr, },\n",
    "                              256,\n",
    "                              options = {'iters' : num_iter, 'print_iter' : 100, 'reg_kind' : reg, 'checkpoint_iter' : None, 'checkpoint_file' : \"v_upfi\",\n",
    "                                       'save_final' : False, 'save_file' : 'v_upfi',\n",
    "                                       'anneal_sigma_iters' : None, 'outdir' : './', \n",
    "                                      'teacher_forcing_iter' : num_iter},\n",
    "                              sample_batch_options = {'replacement' : True,  'add_noise' : False}, \n",
    "                              samplesloss_options={'loss' : 'sinkhorn', 'reach' : 5.0}, \n",
    "                              odeint_options={'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (2*T)}})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fd52a36-0c0a-4e09-b6d2-1a7c9bd96500",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Traces returned by train.train_pfi / train.train_upfi, on a log scale.\n",
    "fig, axes = plt.subplots(1, 2)\n",
    "for ax, trace, title in zip(axes, [trace_pfi, trace_upfi], ['PFI', 'UPFI']):\n",
    "    ax.plot(trace)\n",
    "    ax.set_yscale('log')\n",
    "    ax.set_title(title)\n",
    "    ax.set_ylabel('Loss')\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8607b687-0054-417f-b621-1788eb5b2147",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Relative mass per timepoint (sample count divided by the count at the first timepoint).\n",
    "fig_mass, ax_mass = plt.subplots(figsize = (2, 2))\n",
    "timepoints = list(range(T))\n",
    "mass_ticks = [1, 5, 10, 15]\n",
    "ax_mass.bar(timepoints, m_ratios)\n",
    "ax_mass.set_xticks(timepoints)\n",
    "ax_mass.set_xticklabels(timepoints)\n",
    "ax_mass.set_yticks(mass_ticks)\n",
    "ax_mass.set_yticklabels(mass_ticks)\n",
    "ax_mass.set_xlabel(\"Timepoint\")\n",
    "ax_mass.set_ylabel(\"Relative mass\")\n",
    "fig_mass.tight_layout()\n",
    "fig_mass.savefig(\"../../figures/concept_mass_v_time.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3293df48-d283-4211-9cbb-02df35a0c574",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Rebuild the per-timepoint sample list X as (n_i, dim) float32 tensors on `device`.\n",
    "X = [\n",
    "    torch.tensor(data[\"x\"][data[\"t_idx\"] == k, :], device = device, dtype = torch.float32)\n",
    "    for k in np.unique(data[\"t_idx\"])  # np.unique already returns the values sorted\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "351241d1-4311-422b-b543-2bb5146e5c0d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Fixed-step Euler integration with step size t_final / (2T).\n",
    "odeint_options = {'method' : 'euler', 'options' : {'step_size' : data['t_final'] / (2*T)}}\n",
    "# Initial states from the first timepoint (64 is presumably the batch size); for UPFI, column 0 is the mass coordinate.\n",
    "x0_mass = utils.sample_batch_upfi(X, m_ratios.to(device), 64)[0]\n",
    "x0 = x0_mass[:, 1:]  # spatial coordinates only, for the PFI model\n",
    "# Probability-flow ODE drifts v - (D/2) * score; the UPFI mass coordinate gets no score correction.\n",
    "F_ode_upfi = lambda t, x: v_upfi(t, x) - (D/2)*torch.hstack([torch.zeros_like(x[:, :1]), s(t, x[:, 1:], sigmas[-1]), ])\n",
    "F_ode_pfi = lambda t, x: v_pfi(t, x) - (D/2)*s(t, x, sigmas[-1])\n",
    "# Output grid over the observed time range [0, t_final], the same range as `ts`.\n",
    "_ts = torch.linspace(0, data['t_final'], 100)\n",
    "with torch.no_grad():\n",
    "    xs_t_pfi = odeint(F_ode_pfi, x0, _ts, **odeint_options).cpu()\n",
    "    xs_t_upfi = odeint(F_ode_upfi, x0_mass.to(device), _ts, **odeint_options).cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fe499231-ed29-4616-947c-b7fdd34b288d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Stochastic counterparts of the two models: drift v(t, x) with noise scale sigma = sqrt(D).\n",
    "sde_pfi = models.SDE(lambda t, x: v_pfi(t, x), sigma = D**0.5)\n",
    "# UPFI: zero noise on the first (mass) coordinate, sqrt(D) on each of the dim spatial coordinates.\n",
    "sde_upfi = models.SDE(lambda t, x: v_upfi(t, x), sigma = torch.cat([torch.tensor([0, ]), torch.full((dim, ), D**0.5)]).to(device))\n",
    "with torch.no_grad():\n",
    "    # Euler scheme on the same output grid `_ts` as the ODE trajectories.\n",
    "    xs_t_pfi_ = torchsde.sdeint(sde_pfi, x0, _ts, method = \"euler\").cpu()\n",
    "    xs_t_upfi_ = torchsde.sdeint(sde_upfi, x0_mass, _ts, method = \"euler\").cpu()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "138e3380-1e9e-4f2c-8640-42c3a6e7eb49",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from matplotlib.collections import LineCollection\n",
    "\n",
    "# def plot_line_variable_width(x, y, w, ax, **kwargs):\n",
    "#     points = np.array([x, y]).T.reshape(-1, 1, 2)\n",
    "#     segments = np.concatenate([points[:-1], points[1:]], axis=1)\n",
    "#     lc = LineCollection(segments, linewidths=w, **kwargs)\n",
    "#     ax.add_collection(lc)\n",
    "\n",
    "def plot_line_variable_width(xs, ys, widths, ax=None, color='b', xlim=None, ylim=None,\n",
    "                **kwargs):\n",
    "    if not (len(xs) == len(ys) == len(widths)):\n",
    "        raise ValueError('xs, ys, and widths must have identical lengths')\n",
    "    segmentx, segmenty = [xs[0]], [ys[0]]\n",
    "    current_width = widths[0]\n",
    "    for ii, (x, y, width) in enumerate(zip(xs, ys, widths)):\n",
    "        segmentx.append(x)\n",
    "        segmenty.append(y)\n",
    "        if (width != current_width) or (ii == (len(xs) - 1)):\n",
    "            ax.plot(segmentx, segmenty, linewidth=current_width, color=color[ii] if isinstance(color, list) else color,\n",
    "                    **kwargs)\n",
    "            segmentx, segmenty = [x], [y]\n",
    "            current_width = width"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "314be463-f771-4075-88f0-92af243c9e04",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# UPFI trajectories of the spatial coordinate (column 1): probability-flow ODE (left) and SDE (right).\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "for panel, paths in enumerate([xs_t_upfi, xs_t_upfi_], start = 1):\n",
    "    plt.subplot(1, 2, panel)\n",
    "    plt.plot(_ts, paths[..., 1], c = 'k', alpha = 0.15)\n",
    "    plt.ylim(-1.5, 1.5)\n",
    "    plt.axis(\"off\")\n",
    "plt.tight_layout()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f98402f-0779-42ed-af24-ccdf4f1144f0",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# PFI trajectories: probability-flow ODE (left) and SDE (right), saved as a concept figure.\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "pfi_panels = (xs_t_pfi, xs_t_pfi_)\n",
    "for k in range(len(pfi_panels)):\n",
    "    plt.subplot(1, 2, k + 1)\n",
    "    plt.plot(_ts, pfi_panels[k][..., 0], c = 'k', alpha = 0.15)\n",
    "    plt.ylim(-1.5, 1.5)\n",
    "    plt.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/concept_PFI_paths.pdf\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcbb8421-7c92-46f8-8c68-a93590a79b48",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# UPFI trajectories (ODE left, SDE right) with line width 25 * min(exp(column 0), 0.5).\n",
    "plt.figure(figsize = (5, 2.5))\n",
    "for panel, paths in enumerate([xs_t_upfi, xs_t_upfi_], start = 1):\n",
    "    ax = plt.subplot(1, 2, panel)\n",
    "    # Iterate over however many trajectories were simulated instead of a hard-coded 64.\n",
    "    for i in range(paths.shape[1]):\n",
    "        widths = 25*torch.clamp_max(paths[:, i, 0].exp(), 0.5)\n",
    "        plot_line_variable_width(_ts, paths[:, i, 1], widths, ax = ax, alpha = 0.15, rasterized = True, color = 'k')\n",
    "    plt.ylim(-1.5, 1.5)\n",
    "    plt.axis(\"off\")\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/concept_UPFI_paths.pdf\", dpi = 600)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "24daaac5-a8f5-488d-a7f5-80dd627affe1",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import matplotlib\n",
    "# _T evenly spaced RGBA colors from the 'Blues' colormap, one per point of each ODE segment drawn below.\n",
    "_T = 10\n",
    "cmap = matplotlib.colormaps['Blues']\n",
    "colors = list(cmap(np.linspace(0, 1, _T)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "56c6a422-b844-4ea6-8451-a7ab6779dd64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schematic of the UPFI ODE between consecutive snapshots: grey = data, red = sampled starting points,\n",
    "# blue = trajectories (line width and end-marker size grow with exp(2 * mass coordinate)).\n",
    "plt.figure(figsize = (2.5, 2.5))\n",
    "# Shift timepoint i right by 0.1 * i so propagated points and the next snapshot are drawn side by side.\n",
    "t_offset = torch.cumsum(torch.full_like(ts, 0.1), dim = 0)\n",
    "t_offset = t_offset - t_offset[0]\n",
    "t_shifted = ts + t_offset\n",
    "ax = plt.gca()\n",
    "for i in range(T):\n",
    "    y = X[i][:, 0].cpu()[:500]\n",
    "    plt.scatter(t_shifted[i].expand(len(y)), y, alpha = 0.3, c = 'lightgrey', edgecolor = 'k', rasterized = True, s = 5)\n",
    "    idx = torch.randperm(len(y))[:32]\n",
    "    plt.scatter(t_shifted[i].expand(len(idx)), y[idx], alpha = 0.5, c = 'r', s = 50, zorder = 100, edgecolor = 'k')\n",
    "    if i < T-1:\n",
    "        _t = torch.linspace(t_shifted[i], ts[i+1] + t_offset[i], _T)\n",
    "        # utils.pad_zeros_upfi presumably prepends a zero mass coordinate - confirm in utils.\n",
    "        with torch.no_grad():  # inference only; avoids building an autograd graph\n",
    "            _x = odeint(F_ode_upfi, utils.pad_zeros_upfi(y[idx, None]).to(device), _t, **odeint_options).cpu()\n",
    "        for j in range(_x.shape[1]):\n",
    "            plot_line_variable_width(_t, _x[:, j, 1], (2*_x[:, j, 0]).exp(), ax = ax, alpha = 1, rasterized = True, color = colors)\n",
    "        # Pass the RGBA color via `color=`: as `c=` it is ambiguous when it has as many entries as there are points.\n",
    "        plt.scatter((ts[i+1] + t_offset[i]).expand(len(idx)), _x[-1, :, 1], s = 10*(2*_x[-1, :, 0]).exp(), alpha = 1, color = colors[-1], zorder = 100, edgecolor = 'k')\n",
    "plt.axis('off')\n",
    "plt.tight_layout()\n",
    "plt.savefig(\"../../figures/concept_ode_integration.pdf\", dpi = 300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2e6fe2d-8399-4535-b94b-e7396cc6ee98",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "169219cc-e562-46fc-952e-b76ff6264f2c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "upfi",
   "language": "python",
   "name": "upfi"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.21"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
