{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VFidSWlB6CIX"
   },
   "outputs": [],
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Lb4FT9_yTz6Z"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# from colabtools import adhoc_import\n",
    "import importlib\n",
    "from userdiffusion import ode_datasets, unet, samplers\n",
    "importlib.reload(ode_datasets)\n",
    "importlib.reload(unet)\n",
    "importlib.reload(samplers)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "rc('animation', html='jshtml')\n",
    "import jax.numpy as jnp\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Joy4WVQ5Bqrw"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from jax import devices,device_count\n",
    "device_count()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "iyArKlOsUVs5"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tf.executing_eagerly()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i4P7SzmOsEvF"
   },
   "source": [
    "# Generate the Trajectories"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "eDIphxAwsKUE"
   },
   "source": [
    "## N-Link Pendulum"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "jgkNTEf-fhAR"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "E1217 17:18:03.310604   34593 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR                                                                                                                                            | 0/25 [00:00<?, ?it/s]\n",
      "E1217 17:18:03.310698   34593 cuda_dnn.cc:538] Memory usage: 25155010560 bytes free, 25438126080 bytes total.\n",
      "E1217 17:18:03.311269   34593 cuda_dnn.cc:534] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR\n",
      "E1217 17:18:03.311314   34593 cuda_dnn.cc:538] Memory usage: 25155010560 bytes free, 25438126080 bytes total.\n",
      "  0%|                                                                                                                                                                                                                                                | 0/25 [00:00<?, ?it/s]\n"
     ]
    },
    {
     "ename": "XlaRuntimeError",
     "evalue": "FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mXlaRuntimeError\u001b[0m                           Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[4], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m dt \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m.1\u001b[39m\n\u001b[0;32m----> 2\u001b[0m ds \u001b[38;5;241m=\u001b[39m \u001b[43mode_datasets\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mNPendulum\u001b[49m\u001b[43m(\u001b[49m\u001b[43mN\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m2000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mn\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mdt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m      3\u001b[0m thetas,vs \u001b[38;5;241m=\u001b[39m ode_datasets\u001b[38;5;241m.\u001b[39munpack(ds\u001b[38;5;241m.\u001b[39mZs)\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/ode_datasets.py:341\u001b[0m, in \u001b[0;36mNPendulum.__init__\u001b[0;34m(self, n, dt, *args, **kwargs)\u001b[0m\n\u001b[1;32m    330\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"NPendulum constructor.\u001b[39;00m\n\u001b[1;32m    331\u001b[0m \n\u001b[1;32m    332\u001b[0m \u001b[38;5;124;03mUses additional arguments over base class.\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    338\u001b[0m \u001b[38;5;124;03m  **kwargs: ODEDataset kwargs\u001b[39;00m\n\u001b[1;32m    339\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m    340\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn \u001b[38;5;241m=\u001b[39m n\n\u001b[0;32m--> 341\u001b[0m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/ode_datasets.py:271\u001b[0m, in \u001b[0;36mHamiltonianDataset.__init__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    270\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m--> 271\u001b[0m   \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__init__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    272\u001b[0m   \u001b[38;5;66;03m# convert the momentum into velocity\u001b[39;00m\n\u001b[1;32m    273\u001b[0m   qs, ps \u001b[38;5;241m=\u001b[39m unpack(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mZs)\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/ode_datasets.py:77\u001b[0m, in \u001b[0;36mODEDataset.__init__\u001b[0;34m(self, N, chunk_len, dt, integration_time)\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Constructor for the ODE dataset.\u001b[39;00m\n\u001b[1;32m     66\u001b[0m \n\u001b[1;32m     67\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     74\u001b[0m \u001b[38;5;124;03m      randomly sampled\u001b[39;00m\n\u001b[1;32m     75\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m     76\u001b[0m \u001b[38;5;28msuper\u001b[39m()\u001b[38;5;241m.\u001b[39m\u001b[38;5;21m__init__\u001b[39m()\n\u001b[0;32m---> 77\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mZs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mgenerate_trajectory_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mN\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdt\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintegration_time\u001b[49m\u001b[43m)\u001b[49m  \u001b[38;5;66;03m# pylint: disable=invalid-name\u001b[39;00m\n\u001b[1;32m     78\u001b[0m T \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(jnp\u001b[38;5;241m.\u001b[39marange(\u001b[38;5;241m0\u001b[39m, integration_time, dt))  \u001b[38;5;66;03m# pylint: disable=invalid-name\u001b[39;00m\n\u001b[1;32m     79\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mT_long \u001b[38;5;241m=\u001b[39m T[T \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mburnin_time]  \u001b[38;5;66;03m# pylint: disable=invalid-name\u001b[39;00m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/src/userdiffusion/ode_datasets.py:117\u001b[0m, in \u001b[0;36mODEDataset.generate_trajectory_data\u001b[0;34m(self, trajectories, dt, integration_time, bs)\u001b[0m\n\u001b[1;32m    115\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m _ \u001b[38;5;129;01min\u001b[39;00m tqdm(\u001b[38;5;28mrange\u001b[39m(\u001b[38;5;241m0\u001b[39m, trajectories, bs \u001b[38;5;241m*\u001b[39m k)):\n\u001b[1;32m    116\u001b[0m   z0s \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msample_initial_conditions(bs \u001b[38;5;241m*\u001b[39m k)\n\u001b[0;32m--> 117\u001b[0m   ts \u001b[38;5;241m=\u001b[39m \u001b[43mjnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mintegration_time\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdt\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    118\u001b[0m   new_zs \u001b[38;5;241m=\u001b[39m jintegrate(z0s, ts)\n\u001b[1;32m    119\u001b[0m   new_zs \u001b[38;5;241m=\u001b[39m new_zs[:, ts \u001b[38;5;241m>\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mburnin_time]\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6518\u001b[0m, in \u001b[0;36marange\u001b[0;34m(start, stop, step, dtype, device)\u001b[0m\n\u001b[1;32m   6448\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Create an array of evenly-spaced values.\u001b[39;00m\n\u001b[1;32m   6449\u001b[0m \n\u001b[1;32m   6450\u001b[0m \u001b[38;5;124;03mJAX implementation of :func:`numpy.arange`, implemented in terms of\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   6514\u001b[0m \u001b[38;5;124;03m  - :func:`jax.lax.iota`: directly generate integer sequences in XLA.\u001b[39;00m\n\u001b[1;32m   6515\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m   6516\u001b[0m \u001b[38;5;66;03m# TODO(vfdev-5): optimize putting the array directly on the device specified\u001b[39;00m\n\u001b[1;32m   6517\u001b[0m \u001b[38;5;66;03m# instead of putting it on default device and then on the specific device\u001b[39;00m\n\u001b[0;32m-> 6518\u001b[0m output \u001b[38;5;241m=\u001b[39m \u001b[43m_arange\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstop\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   6519\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m device \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   6520\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m jax\u001b[38;5;241m.\u001b[39mdevice_put(output, device\u001b[38;5;241m=\u001b[39mdevice)\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:6562\u001b[0m, in \u001b[0;36m_arange\u001b[0;34m(start, stop, step, dtype)\u001b[0m\n\u001b[1;32m   6560\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m step \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m start \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mand\u001b[39;00m stop \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   6561\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m lax\u001b[38;5;241m.\u001b[39miota(dtype, np\u001b[38;5;241m.\u001b[39mceil(stop)\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;28mint\u001b[39m))  \u001b[38;5;66;03m# type: ignore[arg-type]\u001b[39;00m\n\u001b[0;32m-> 6562\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43marray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43marange\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstop\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstop\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstep\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mstep\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdtype\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py:5426\u001b[0m, in \u001b[0;36marray\u001b[0;34m(object, dtype, copy, order, ndmin, device)\u001b[0m\n\u001b[1;32m   5424\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m   5425\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mUnexpected input type for array: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(\u001b[38;5;28mobject\u001b[39m)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m-> 5426\u001b[0m out_array: Array \u001b[38;5;241m=\u001b[39m \u001b[43mlax_internal\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_convert_element_type\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   5427\u001b[0m \u001b[43m    \u001b[49m\u001b[43mout\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweak_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msharding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   5428\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m ndmin \u001b[38;5;241m>\u001b[39m ndim(out_array):\n\u001b[1;32m   5429\u001b[0m   out_array \u001b[38;5;241m=\u001b[39m lax\u001b[38;5;241m.\u001b[39mexpand_dims(out_array, \u001b[38;5;28mrange\u001b[39m(ndmin \u001b[38;5;241m-\u001b[39m ndim(out_array)))\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/lax/lax.py:587\u001b[0m, in \u001b[0;36m_convert_element_type\u001b[0;34m(operand, new_dtype, weak_type, sharding)\u001b[0m\n\u001b[1;32m    585\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m operand\n\u001b[1;32m    586\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 587\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mconvert_element_type_p\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    588\u001b[0m \u001b[43m      \u001b[49m\u001b[43moperand\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnew_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mbool\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mweak_type\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    589\u001b[0m \u001b[43m      \u001b[49m\u001b[43msharding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharding\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/lax/lax.py:2981\u001b[0m, in \u001b[0;36m_convert_element_type_bind\u001b[0;34m(operand, new_dtype, weak_type, sharding)\u001b[0m\n\u001b[1;32m   2980\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_convert_element_type_bind\u001b[39m(operand, \u001b[38;5;241m*\u001b[39m, new_dtype, weak_type, sharding):\n\u001b[0;32m-> 2981\u001b[0m   operand \u001b[38;5;241m=\u001b[39m \u001b[43mcore\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mPrimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind\u001b[49m\u001b[43m(\u001b[49m\u001b[43mconvert_element_type_p\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43moperand\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2982\u001b[0m \u001b[43m                                \u001b[49m\u001b[43mnew_dtype\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mnew_dtype\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mweak_type\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mweak_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   2983\u001b[0m \u001b[43m                                \u001b[49m\u001b[43msharding\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msharding\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   2984\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m sharding \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m   2985\u001b[0m     operand \u001b[38;5;241m=\u001b[39m pjit\u001b[38;5;241m.\u001b[39mwith_sharding_constraint(operand, sharding)\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/core.py:438\u001b[0m, in \u001b[0;36mPrimitive.bind\u001b[0;34m(self, *args, **params)\u001b[0m\n\u001b[1;32m    435\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams):\n\u001b[1;32m    436\u001b[0m   \u001b[38;5;28;01massert\u001b[39;00m (\u001b[38;5;129;01mnot\u001b[39;00m config\u001b[38;5;241m.\u001b[39menable_checks\u001b[38;5;241m.\u001b[39mvalue \u001b[38;5;129;01mor\u001b[39;00m\n\u001b[1;32m    437\u001b[0m           \u001b[38;5;28mall\u001b[39m(\u001b[38;5;28misinstance\u001b[39m(arg, Tracer) \u001b[38;5;129;01mor\u001b[39;00m valid_jaxtype(arg) \u001b[38;5;28;01mfor\u001b[39;00m arg \u001b[38;5;129;01min\u001b[39;00m args)), args\n\u001b[0;32m--> 438\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbind_with_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfind_top_trace\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/core.py:442\u001b[0m, in \u001b[0;36mPrimitive.bind_with_trace\u001b[0;34m(self, trace, args, params)\u001b[0m\n\u001b[1;32m    440\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mbind_with_trace\u001b[39m(\u001b[38;5;28mself\u001b[39m, trace, args, params):\n\u001b[1;32m    441\u001b[0m   \u001b[38;5;28;01mwith\u001b[39;00m pop_level(trace\u001b[38;5;241m.\u001b[39mlevel):\n\u001b[0;32m--> 442\u001b[0m     out \u001b[38;5;241m=\u001b[39m \u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_primitive\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mmap\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrace\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfull_raise\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    443\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mmap\u001b[39m(full_lower, out) \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmultiple_results \u001b[38;5;28;01melse\u001b[39;00m full_lower(out)\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/core.py:955\u001b[0m, in \u001b[0;36mEvalTrace.process_primitive\u001b[0;34m(self, primitive, tracers, params)\u001b[0m\n\u001b[1;32m    953\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m call_impl_with_key_reuse_checks(primitive, primitive\u001b[38;5;241m.\u001b[39mimpl, \u001b[38;5;241m*\u001b[39mtracers, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mparams)\n\u001b[1;32m    954\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m--> 955\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mprimitive\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mimpl\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mtracers\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mparams\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/dispatch.py:91\u001b[0m, in \u001b[0;36mapply_primitive\u001b[0;34m(prim, *args, **params)\u001b[0m\n\u001b[1;32m     89\u001b[0m prev \u001b[38;5;241m=\u001b[39m lib\u001b[38;5;241m.\u001b[39mjax_jit\u001b[38;5;241m.\u001b[39mswap_thread_local_state_disable_jit(\u001b[38;5;28;01mFalse\u001b[39;00m)\n\u001b[1;32m     90\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[0;32m---> 91\u001b[0m   outs \u001b[38;5;241m=\u001b[39m \u001b[43mfun\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     92\u001b[0m \u001b[38;5;28;01mfinally\u001b[39;00m:\n\u001b[1;32m     93\u001b[0m   lib\u001b[38;5;241m.\u001b[39mjax_jit\u001b[38;5;241m.\u001b[39mswap_thread_local_state_disable_jit(prev)\n",
      "    \u001b[0;31m[... skipping hidden 16 frame]\u001b[0m\n",
      "File \u001b[0;32m~/workspace/GitHub/pmlr-v202-finzi23a/.venv/lib/python3.10/site-packages/jax/_src/compiler.py:267\u001b[0m, in \u001b[0;36mbackend_compile\u001b[0;34m(backend, module, options, host_callbacks)\u001b[0m\n\u001b[1;32m    261\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m backend\u001b[38;5;241m.\u001b[39mcompile(\n\u001b[1;32m    262\u001b[0m         built_c, compile_options\u001b[38;5;241m=\u001b[39moptions, host_callbacks\u001b[38;5;241m=\u001b[39mhost_callbacks\n\u001b[1;32m    263\u001b[0m     )\n\u001b[1;32m    264\u001b[0m   \u001b[38;5;66;03m# Some backends don't have `host_callbacks` option yet\u001b[39;00m\n\u001b[1;32m    265\u001b[0m   \u001b[38;5;66;03m# TODO(sharadmv): remove this fallback when all backends allow `compile`\u001b[39;00m\n\u001b[1;32m    266\u001b[0m   \u001b[38;5;66;03m# to take in `host_callbacks`\u001b[39;00m\n\u001b[0;32m--> 267\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mbackend\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompile\u001b[49m\u001b[43m(\u001b[49m\u001b[43mbuilt_c\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcompile_options\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43moptions\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    268\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m xc\u001b[38;5;241m.\u001b[39mXlaRuntimeError \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m    269\u001b[0m   \u001b[38;5;28;01mfor\u001b[39;00m error_handler \u001b[38;5;129;01min\u001b[39;00m _XLA_RUNTIME_ERROR_HANDLERS:\n",
      "\u001b[0;31mXlaRuntimeError\u001b[0m: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details."
     ]
    }
   ],
   "source": [
    "dt = .1\n",
    "ds = ode_datasets.NPendulum(N=2000,n=1,dt=dt)\n",
    "thetas,vs = ode_datasets.unpack(ds.Zs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N8HsUM03yuk8"
   },
   "outputs": [],
   "source": [
    "# for i in range(20):\n",
    "#   fig = plt.figure()\n",
    "#   ax = fig.add_subplot(1, 1, 1)\n",
    "#   line1, = ax.plot(ds.T_long,thetas[i,:,0])\n",
    "#   line2, = ax.plot(ds.T_long,thetas[i,:,1])#\n",
    "#   #line2, = ax.plot(ds.T_long,jnp.cos(thetas[i,:,1])+jnp.cos(thetas[i,:,0]))\n",
    "#   plt.xlabel('Time t')\n",
    "#   plt.ylabel(r'State')\n",
    "#   plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tD3PUWBVQLX-"
   },
   "outputs": [],
   "source": [
    "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
    "data_std = thetas.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "D9LK0AFrd8-J"
   },
   "outputs": [],
   "source": [
    "# Largest pairwise Euclidean distance among the first 400 trajectories, divided by\n",
    "# sqrt(number of entries per trajectory) so the value is on a per-entry scale.\n",
    "# Assumes thetas is 3-D (trajectory, time, channel) -- TODO confirm.\n",
    "jnp.sqrt(((thetas[None,:400]-thetas[:400,None])**2).sum((-1,-2))).max()/jnp.sqrt(np.prod(thetas.shape[1:]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "brLISQLhQVXz"
   },
   "outputs": [],
   "source": [
    "# Batch size for the tf.data pipeline below.\n",
    "bs = 400\n",
    "# NOTE: dataiter is the bound method itself (no call parentheses), so each\n",
    "# dataiter() call starts a new pass over the shuffled, batched dataset and\n",
    "# yields NumPy batches. Later cells rely on this (next(dataiter()), for ... in dataiter()).\n",
    "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8SJ3stKFsOR1"
   },
   "outputs": [],
   "source": [
    "from matplotlib import rc\n",
    "rc('animation', html='jshtml')\n",
    "#ds.animate()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Y1hHzbmWZrvS"
   },
   "source": [
    "## Diffusion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vXD6VL3jTfpO"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import jax.numpy as jnp\n",
    "from jax import random\n",
    "import jax\n",
    "import flax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SMN8fY2UthvT"
   },
   "outputs": [],
   "source": [
    "# Pull one batch so the network can be initialised with concrete input shapes.\n",
    "x = next(dataiter())\n",
    "# One random value in [0, 1) per example, used only as the `t` input for init\n",
    "# (unseeded NumPy RNG, so the values differ between runs).\n",
    "t = np.random.rand(x.shape[0])\n",
    "# The output width is set to the size of the trailing axis of the data.\n",
    "model = unet.UNet(unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))\n",
    "# Build the parameters with a fixed PRNG key and train=False.\n",
    "params = model.init(random.PRNGKey(42), x=x,t=t,train=False)\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pajn72Q-uOps"
   },
   "outputs": [],
   "source": [
    "def count_params(params):\n",
    "  if isinstance(params, jax.numpy.ndarray):\n",
    "    return np.prod(params.shape)\n",
    "  elif isinstance(params,(dict,flax.core.frozen_dict.FrozenDict)):\n",
    "    return sum([count_params(v) for v in params.values()])\n",
    "  else:\n",
    "    assert False, type(params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pkRzOnhCuQE7"
   },
   "outputs": [],
   "source": [
    "count_params(params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3HyQzRHSr_qh"
   },
   "source": [
    "Initialize the UNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZpPYuYEluyJo"
   },
   "outputs": [],
   "source": [
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "U0zOA0KXvdbJ"
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "import optax\n",
    "from jax import jit\n",
    "import pandas as pd\n",
    "importlib.reload(samplers)\n",
    "\n",
    "#sigma_min = 1e-3#2e-4#2e-3\n",
    "#sigma_max = 1#100\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZFhSBePQuzjs"
   },
   "outputs": [],
   "source": [
    "# Mesh and mesh_utils are not imported by any earlier cell; import them here so\n",
    "# this cell does not raise NameError on a fresh kernel.\n",
    "from jax.experimental import mesh_utils\n",
    "from jax.sharding import Mesh\n",
    "\n",
    "# NOTE(review): update_fn, ema_params, opt_state, jloss and eval_metrics are not\n",
    "# defined in any cell above -- the cell that builds them appears to be missing and\n",
    "# has to be restored before this loop can run after Restart & Run All.\n",
    "key = random.PRNGKey(38)\n",
    "# One-dimensional device mesh over device_count() devices; its only axis is named 'data'.\n",
    "with Mesh(mesh_utils.create_device_mesh((device_count(),)), ('data',)):\n",
    "  for epoch in tqdm(range(601)):\n",
    "    # One update_fn call per batch; it returns the new params, EMA params,\n",
    "    # optimiser state, PRNG key and the loss of that batch.\n",
    "    for data in dataiter():\n",
    "      params,ema_params,opt_state,key,loss_val = update_fn(params,ema_params,opt_state,key,data)\n",
    "    if epoch % 5 == 0:\n",
    "      # Every 5 epochs: loss of the EMA weights on the last batch of the epoch.\n",
    "      ema_loss = jloss(ema_params,data,key)\n",
    "      message = f'Loss epoch {epoch}: {loss_val:.3f} Ema {ema_loss:.3f}'\n",
    "      # if not epoch % 30:\n",
    "      #   val = pmetric(samplers.sde_sampler(denoiser,params,key,(512,)+data.shape[1:],500)[0])[0]\n",
    "      #   #message += f'     Precision: {}'\n",
    "      print(message)\n",
    "    if epoch % 200 == 0:\n",
    "      # Every 200 epochs: run eval_metrics with the EMA weights.\n",
    "      print(eval_metrics(dataiter,ema_params,key))\n",
    "\n",
    "# Later cells read `params`, so point it at the EMA weights.\n",
    "params=ema_params"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gCU1NmjKqLM-"
   },
   "outputs": [],
   "source": [
    "mb = data[:30]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z4v7WbBbQsPL"
   },
   "outputs": [],
   "source": [
    "importlib.reload(samplers)\n",
    "# Jitted inference wrapper: expands the noise level sigma to one value per example\n",
    "# (x.shape[0]) and calls `denoised` with train=False.\n",
    "# NOTE(review): `denoised` is not defined in any cell above -- confirm where it comes from.\n",
    "denoiser = jit(lambda params,x,sigma: denoised(params,x,jnp.ones(x.shape[0])*sigma,train=False))  \n",
    "def conditioning_scores(observed_values,s=.2):\n",
    "  \"\"\"Build the gradient function of a quadratic (Gaussian-style) observation penalty.\n",
    "\n",
    "  The returned callable maps x to the gradient with respect to x of\n",
    "  -sum((x.reshape(b, -1, c)[:, :n1] - observed_values)**2) / (2 * s**2),\n",
    "  so only the first n1 positions along axis 1 receive a non-zero gradient.\n",
    "\n",
    "  Args:\n",
    "    observed_values: array of shape (b, n1, c) with the values to condition on.\n",
    "    s: scale of the penalty; it enters as 1 / (2 * s**2). Defaults to 0.2.\n",
    "\n",
    "  Returns:\n",
    "    A function x -> gradient array with the same shape as x.\n",
    "  \"\"\"\n",
    "  b,n1,c = observed_values.shape\n",
    "  return jax.grad(lambda x: -jnp.sum((x.reshape(b,-1,c)[:,:n1]-observed_values)**2)/(2*s**2))\n",
    "#conditioning_scores(mb[:,:20]),\n",
    "\n",
    "  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "iqgGfIaO-POu"
   },
   "outputs": [],
   "source": [
    "importlib.reload(samplers)\n",
    "# Time at which the model sample and the noised ground truth are compared.\n",
    "# NOTE: this rebinds `t`, which earlier held the per-example array used for model.init.\n",
    "t=.001\n",
    "# NOTE(review): t_max is not defined in any cell above -- define it before running this cell.\n",
    "z = samplers.ode_sample(denoiser,params,key,mb.shape,t,t_max)#,conditioning_scores(mb[:,:50]))\n",
    "# Ground-truth batch scaled by s(t) plus Gaussian noise with std s(t)*sigma(t)\n",
    "# (unseeded NumPy RNG, so the curve differs between runs).\n",
    "noised_x = mb*samplers.s(t)+np.random.randn(*mb.shape)*(samplers.s(t)*samplers.sigma(t))\n",
    "import matplotlib.pyplot as plt\n",
    "# Index of the trajectory to plot; only channel 0 is drawn.\n",
    "i=2\n",
    "plt.plot(ds.T_long,mb[i,:,0])\n",
    "plt.plot(ds.T_long,noised_x[i,:,0])\n",
    "plt.plot(ds.T_long,z[i,:,0])\n",
    "\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "# Legend entries follow the order of the three plot calls above.\n",
    "plt.legend([r'GT','GT noised xt',r'Model xt'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OUsObb_pFBZh"
   },
   "outputs": [],
   "source": [
    "importlib.reload(samplers)\n",
    "nll = samplers.compute_nll(denoiser,params,key,data[:400])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mDbkYQtcFGwa"
   },
   "outputs": [],
   "source": [
    "nll.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kktv27hE2jZN"
   },
   "outputs": [],
   "source": [
    "importlib.reload(samplers)\n",
    "from jax import grad\n",
    "def score(x,t):\n",
    "  \"\"\"Score estimate at a flattened state x and time t, built from the denoiser.\n",
    "\n",
    "  Computes (D(x/s(t); sigma(t)) - x/s(t)) / (s(t) * sigma(t)**2), where D is\n",
    "  `denoiser` applied with the global `params`, and s / sigma come from `samplers`.\n",
    "\n",
    "  Args:\n",
    "    x: flat array with mb.size entries (it is reshaped to mb.shape for the network).\n",
    "    t: time value passed to samplers.s and samplers.sigma (presumably a scalar -- confirm).\n",
    "\n",
    "  Returns:\n",
    "    Flat array with one entry per entry of x.\n",
    "  \"\"\"\n",
    "  return (denoiser(params,x.reshape(mb.shape)/samplers.s(t),samplers.sigma(t)).reshape(-1)-x/samplers.s(t))/(samplers.s(t)*samplers.sigma(t)**2)\n",
    "# Drift in score form: s'(t)/s(t) * x - s(t)**2 * sigma'(t) * sigma(t) * score(x, t),\n",
    "# with s' and sigma' obtained by differentiating samplers.s / samplers.sigma with jax.grad.\n",
    "dynamics = lambda t,x: grad(samplers.s)(t)*x/samplers.s(t)-(samplers.s(t)**2)*grad(samplers.sigma)(t)*score(x,t).reshape(-1)*samplers.sigma(t)\n",
    "# The same drift with score() substituted and simplified, written directly in terms of the denoiser:\n",
    "# (s'/s + sigma'/sigma) * x - (sigma'/sigma) * s * D(x/s; sigma). It should agree with `dynamics`.\n",
    "dynamics2 = lambda t,x: (grad(samplers.s)(t)/samplers.s(t)+grad(samplers.sigma)(t)/samplers.sigma(t))*x - (grad(samplers.sigma)(t)/samplers.sigma(t))*samplers.s(t)*denoiser(params,x.reshape(mb.shape)/samplers.s(t),samplers.sigma(t)).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "rsdaFI5jDm36"
   },
   "outputs": [],
   "source": [
    "dynamics(.99,xt.reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Pvgo3RXfDrw0"
   },
   "outputs": [],
   "source": [
    "# Evaluate the denoiser-based right-hand side at t=0.99 on the flattened xt,\n",
    "# for comparison with dynamics(.99, ...).\n",
    "# NOTE(review): the only visible assignment of `xt` is in a later cell; on a\n",
    "# fresh kernel that cell has to run first -- confirm the intended order.\n",
    "dynamics2(.99,xt.reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "CYvRKZEq2uHn"
   },
   "outputs": [],
   "source": [
    "# Gaussian noise shaped like mb, scaled by s(t_max)*sigma(t_max).\n",
    "# The draw uses NumPy's global RNG, so it changes on every run.\n",
    "xt = np.random.randn(*mb.shape)*samplers.s(t_max)*samplers.sigma(t_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hwq8vbg93nBw"
   },
   "outputs": [],
   "source": [
    "# Shape check for the noise sample xt (same shape as mb by construction).\n",
    "xt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tt-V5k-991-h"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A8649rV_40u5"
   },
   "outputs": [],
   "source": [
    "# Largest absolute entry of samplers.score evaluated on the flattened\n",
    "# minibatch at t=1.\n",
    "jnp.max(jnp.abs(samplers.score(denoiser,params,mb.shape)(mb.reshape(-1),1.)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WlBiNZIC5EP6"
   },
   "outputs": [],
   "source": [
    "# Flattened denoiser output at t=1: the input is mb divided by s(t) and the\n",
    "# noise-level argument is sigma(t). Rebinds the notebook-global `t`.\n",
    "t=1.\n",
    "denoiser(params,mb/samplers.s(t),samplers.sigma(t)).reshape(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "E-XlJWv15M8-"
   },
   "outputs": [],
   "source": [
    "# Same call as the previous cell, but the last argument is t itself.\n",
    "# NOTE(review): the previous cell passes samplers.sigma(t) in this position;\n",
    "# confirm that passing the raw time here is intentional.\n",
    "denoiser(params,mb/samplers.s(t),t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JVVhWgkO5bBD"
   },
   "outputs": [],
   "source": [
    "# Call `denoised` (defined outside this section) on the unscaled minibatch,\n",
    "# with one sigma(t) value per row of mb and train=False.\n",
    "denoised(params,mb,jnp.ones(mb.shape[0])*samplers.sigma(t),train=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "sbReawQH3Gty"
   },
   "outputs": [],
   "source": [
    "# Score-based right-hand side evaluated at t=1 on the flattened minibatch.\n",
    "dynamics(1.,mb.reshape(-1))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UnGckHT03PtR"
   },
   "outputs": [],
   "source": [
    "# Reciprocal of the scale s at the final time t_max.\n",
    "1/samplers.s(t_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lyXySA2M_0g6"
   },
   "outputs": [],
   "source": [
    "# z = jax.random.normal(key,(64,)+input_data.shape[1:])\n",
    "# y = denoiser(z,.1)\n",
    "# import numpy as np\n",
    "# perm = np.random.permutation(z.shape[0])\n",
    "# y2 = denoiser(z[perm],.1)[np.argsort(perm)]\n",
    "# print(jnp.linalg.norm(y-y2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ioj3N1vbRHdR"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# Ground truth against the sampled trajectory for one example, first component.\n",
    "i = 5\n",
    "for series, name in [(mb, 'GT'), (z, 'Model')]:\n",
    "  plt.plot(ds.T_long, series[i,:,0], label=name)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel('State')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LhlxPaa5ZijM"
   },
   "outputs": [],
   "source": [
    "# Take the first batch produced by dataiter() and create a PRNG key from the\n",
    "# fixed seed 26; both names are reused by the cells below.\n",
    "data = next(dataiter())\n",
    "key = random.PRNGKey(26)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "EE5N_rBEUDiM"
   },
   "outputs": [],
   "source": [
    "# Run samplers.compute_nll on the first 400 rows of data, passing sigma_min /\n",
    "# sigma_max as smin / smax and num_probes=1.\n",
    "nll = samplers.compute_nll(denoiser,params,key,data[:400],smin=sigma_min,smax=sigma_max,num_probes=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9LAyp0c2678J"
   },
   "outputs": [],
   "source": [
    "# Average of the values returned by samplers.compute_nll in the previous cell.\n",
    "nll.mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "q_0MJE3qdlo7"
   },
   "outputs": [],
   "source": [
    "# Standard error of the mean: std along axis 0 divided by sqrt(len(nll)).\n",
    "nll.std(0)/jnp.sqrt(len(nll))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gR_lt_xWI0AW"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Apply samplers.forward_process2 to the whole batch with the sigma_min /\n",
    "# sigma_max bounds; the result is inspected in the cells below.\n",
    "noised_data = samplers.forward_process2(denoiser,params,key,data,smin=sigma_min,smax=sigma_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gzv1gVO7NWOD"
   },
   "outputs": [],
   "source": [
    "# Shape of the forward_process2 output.\n",
    "noised_data.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bUaxGULjTe2I"
   },
   "outputs": [],
   "source": [
    "# Standard deviation over all entries of the forward_process2 output.\n",
    "noised_data.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "DblhkquiJK5n"
   },
   "outputs": [],
   "source": [
    "# Standard deviation over all entries of the forward_process2 output.\n",
    "# NOTE(review): identical to the previous cell -- candidate for removal.\n",
    "noised_data.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qO0uH4NmJ4zt"
   },
   "outputs": [],
   "source": [
    "# Build a 30-step schedule with samplers.timesteps and print the sum of the\n",
    "# consecutive differences, which telescopes to T[-1]-T[0].\n",
    "# Rebinds the notebook-global `T`, which a later cell assigns again.\n",
    "T = samplers.timesteps(30,sigma_min,sigma_max)\n",
    "print(np.sum(T[1:]-T[:-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2mtArA_7vtyS"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Fixed seed so the sampler run is repeatable.\n",
    "key = random.PRNGKey(45)\n",
    "\n",
    "# Run samplers.sde_sampler for 1000 steps on a batch of 32; the call returns\n",
    "# the sample `s` together with `history`, which a later cell plots in strides.\n",
    "s,history = samplers.sde_sampler(denoiser,params,key,(32,)+data.shape[1:],nsteps=1000,smin=sigma_min,smax=sigma_max)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aIModqvIW-xg"
   },
   "outputs": [],
   "source": [
    "#stochastic_sampler(params,key,(128,)+input_data.shape[1:],N=2000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fEU1jVUmDPlb"
   },
   "outputs": [],
   "source": [
    "# Draw 64 samples with the ODE sampler, using the first half of a split of `key`.\n",
    "# This overwrites the `s` returned by the sde_sampler cell above, so the plots\n",
    "# that follow show these samples.\n",
    "s = samplers.ode_sample(denoiser,params,random.split(key)[0],(64,)+data.shape[1:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GJX9cnJ3lcGJ"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aE452iDGuwxc"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "# First and last state component of trajectory 2 in `thetas`.\n",
    "for column, name in [(0, r'$\\theta_0$'), (-1, r'$\\theta_1$')]:\n",
    "  plt.plot(ds.T_long, thetas[2,:,column], label=name)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel('State')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yOm-f1H4uxok"
   },
   "outputs": [],
   "source": [
    "# Overlay every 200th entry of the sampler history (faint) with the sample `s`\n",
    "# (opaque) for trajectory 1, last state component.\n",
    "for i,h in enumerate(history[::200]):\n",
    "  plt.plot(ds.T_long,h[1,:,-1],label=str(i),alpha=1/3)\n",
    "# Give the sample its own label: reusing the leaked loop index `i` duplicated\n",
    "# the last snapshot's legend entry and raised NameError for an empty history.\n",
    "plt.plot(ds.T_long,s[1,:,-1],label='sample')\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.legend()\n",
    "plt.ylim((-3,3))\n",
    "#plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PJ0ojxLm2VuD"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from ipywidgets import interact\n",
    "\n",
    "\n",
    "\n",
    "# @interact(i=(0,s.shape[0]-1))\n",
    "# def plot(i=1):\n",
    "#   fig = plt.figure()\n",
    "#   ax = fig.add_subplot(1, 1, 1)\n",
    "#   line1, = ax.plot(ds.T_long,s[i,:,0])\n",
    "#   line2, = ax.plot(ds.T_long,s[i,:,1])\n",
    "#   plt.xlabel('Time t')\n",
    "#   plt.ylabel(r'State')\n",
    "#   plt.legend([r'$\\theta_0$',r'$\\theta_1$'])\n",
    "  #plt.ylim(-2,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uWsm4HAMFccX"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K7PSODHy8abI"
   },
   "outputs": [],
   "source": [
    "# One figure per sample: first and last state component over time.\n",
    "for i in range(2):\n",
    "  fig, ax = plt.subplots()\n",
    "  line1, = ax.plot(ds.T_long,s[i,:,0])\n",
    "  line2, = ax.plot(ds.T_long,s[i,:,-1])\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel('State')\n",
    "  ax.legend([r'$\\theta_0$',r'$\\theta_1$'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HJ2dVfECLvj1"
   },
   "outputs": [],
   "source": [
    "# One figure per sample for the first ten samples: first and last state\n",
    "# component over time.\n",
    "for i in range(10):\n",
    "  fig, ax = plt.subplots()\n",
    "  line1, = ax.plot(ds.T_long,s[i,:,0])\n",
    "  line2, = ax.plot(ds.T_long,s[i,:,-1])\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel('State')\n",
    "  ax.legend([r'$\\theta_0$',r'$\\theta_1$'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S9cmzf_ZHMrR"
   },
   "outputs": [],
   "source": [
    "key = random.PRNGKey(45)\n",
    "#s=s2#,history = samplers.sde_sampler(denoiser,params,key,(32,)+data.shape[1:],nsteps=500,smin=sigma_min,smax=sigma_max)\n",
    "\n",
    "\n",
    "# Drop the first k time steps of the samples.\n",
    "k = 5\n",
    "q = s[:,k:]\n",
    "# Central finite difference along axis 1; the result lines up with the interior\n",
    "# points q[:,1:-1].\n",
    "# NOTE(review): the spacing comes from ds.T while the plots use ds.T_long --\n",
    "# confirm both grids share the same step.\n",
    "v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))\n",
    "# Pack the interior q together with mass(q) @ v (presumably the momentum --\n",
    "# TODO confirm against ode_datasets.pack).\n",
    "z = ode_datasets.pack(q[:,1:-1],(vmap(vmap(ds.mass))(q[:,1:-1])@v[...,None]).squeeze(-1))\n",
    "# Time stamps matching z: k steps dropped at the start plus one at each end.\n",
    "T = ds.T_long[k+1:-1]\n",
    "# Integrate the dataset dynamics over T from three sets of initial conditions:\n",
    "# the first state of z, that state plus 1e-3-scaled Gaussian noise (NumPy global\n",
    "# RNG, unseeded), and fresh draws from ds.sample_initial_conditions.\n",
    "z0 = z[:,0]\n",
    "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
    "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
    "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n",
    "\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JqbHZCCgoilB"
   },
   "outputs": [],
   "source": [
    "# Shape of the samples after dropping the first k time steps.\n",
    "q.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GjgmAsigXkC_"
   },
   "outputs": [],
   "source": [
    "# Per-sample comparison of the first state component: ground-truth rollout,\n",
    "# model trajectory and perturbed rollout.\n",
    "for i in range(10):\n",
    "  fig, ax = plt.subplots()\n",
    "  line1, = ax.plot(T,z_gts[i,:,0])\n",
    "  line2, = ax.plot(T,z[i,:,0])\n",
    "  line3, = ax.plot(T,z_pert[i,:,0])\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel('State')\n",
    "  ax.legend(['gt','model','pert'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ujIbtIHyxMaw"
   },
   "outputs": [],
   "source": [
    "# Per-sample comparison of the first and last packed component, ground truth\n",
    "# against model.\n",
    "for i in range(10):\n",
    "  fig, ax = plt.subplots()\n",
    "  line1, = ax.plot(T,z_gts[i,:,0])\n",
    "  line2, = ax.plot(T,z[i,:,0])\n",
    "  line3, = ax.plot(T,z_gts[i,:,-1])\n",
    "  line5, = ax.plot(T,z[i,:,-1])\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel('State')\n",
    "  ax.legend([r'$\\theta_0$ gt',r'$\\theta_0$ model',r'v gt', r'v model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "HjnKwVZdORrg"
   },
   "outputs": [],
   "source": [
    "# Evaluate pmetric on the current samples; the sweep cell further down unpacks\n",
    "# its result as a (mean, std) pair.\n",
    "pmetric(s)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "divghYKhaDhw"
   },
   "outputs": [],
   "source": [
    "# Geometric mean and spread (over axis 0) of the clamped relative error, per\n",
    "# time step, for each kind of rollout.\n",
    "for pred,name in [(z,'Diffusion Model Rollout'),(z_pert,'1e-3 Perturbed GT'),(z_random,'Random Init')]:\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
    "  log_errs = jnp.log(clamped_errs)\n",
    "  rel_errs = np.exp(log_errs.mean(0))\n",
    "  rel_stds = np.exp(log_errs.std(0))\n",
    "  # Label the line itself: a positional legend list is matched to artists in\n",
    "  # creation order, so (depending on the Matplotlib version) the fill_between\n",
    "  # bands can take labels that were meant for the lines.\n",
    "  plt.plot(T,rel_errs,label=name)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WDm9CtIbTEIs"
   },
   "outputs": [],
   "source": [
    "# Geometric mean and spread (over axis 0) of the clamped relative error, per\n",
    "# time step, for each kind of rollout.\n",
    "# NOTE(review): this cell repeats the previous one -- candidate for removal.\n",
    "for pred,name in [(z,'Diffusion Model Rollout'),(z_pert,'1e-3 Perturbed GT'),(z_random,'Random Init')]:\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
    "  log_errs = jnp.log(clamped_errs)\n",
    "  rel_errs = np.exp(log_errs.mean(0))\n",
    "  rel_stds = np.exp(log_errs.std(0))\n",
    "  # Label the line itself: a positional legend list is matched to artists in\n",
    "  # creation order, so (depending on the Matplotlib version) the fill_between\n",
    "  # bands can take labels that were meant for the lines.\n",
    "  plt.plot(T,rel_errs,label=name)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "UM_G0V6QMnjL"
   },
   "outputs": [],
   "source": [
    "# Energy error of each rollout against the ground-truth rollout, summarised\n",
    "# per time step by a geometric mean and spread over axis 0.\n",
    "# `Hs` is left holding the energies of the last rollout in the list, which the\n",
    "# next cell plots.\n",
    "H_gts = vmap(vmap(ds.hamiltonian))(z_gts)\n",
    "for pred,name in [(z,'Diffusion Model Rollout'),(z_pert,'1e-3 Perturbed GT'),(z_random,'Random Init')]:\n",
    "  Hs = vmap(vmap(ds.hamiltonian))(pred)\n",
    "  clamped_errs = jax.lax.clamp(1e-3,jnp.abs(Hs-H_gts)/jnp.abs(Hs*H_gts),np.inf)\n",
    "  log_errs = jnp.log(clamped_errs)\n",
    "  rel_errs = np.exp(log_errs.mean(0))\n",
    "  rel_stds = np.exp(log_errs.std(0))\n",
    "  # Label the line itself: a positional legend list is matched to artists in\n",
    "  # creation order, so (depending on the Matplotlib version) the fill_between\n",
    "  # bands can take labels that were meant for the lines.\n",
    "  plt.plot(T,rel_errs,label=name)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Energy Error')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Am0N9YRAFlxS"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Energy drift |H(t)-H(t0)| for every trajectory in `Hs` (left over from the\n",
    "# energy-error loop in the previous cell).\n",
    "# Plot against T, the grid the rollouts were integrated on: ds.T_long[1:-1] is\n",
    "# k samples longer than H, which makes plt.plot fail with a shape mismatch.\n",
    "for H in Hs:\n",
    "  plt.plot(T,jnp.abs(H-H[0]))\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Energy Error')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xjtnSnyqqDGD"
   },
   "outputs": [],
   "source": [
    "# Sweep the number of SDE sampler steps and record pmetric for each setting.\n",
    "metric_vals =[]\n",
    "metric_stds = []\n",
    "Ns = [25,50,100,200,500,1000,2000]\n",
    "for N in Ns:\n",
    "  # The same `key` is used for every N. Note this rebinds the notebook-global `s`.\n",
    "  s,_ = samplers.sde_sampler(denoiser,params,key,(256,)+data.shape[1:],nsteps=N,smin=sigma_min,smax=sigma_max)\n",
    "  mean,std = pmetric(s)\n",
    "  metric_vals.append(mean)\n",
    "  metric_stds.append(std)\n",
    "metric_vals = np.array(metric_vals)\n",
    "metric_stds = np.array(metric_stds)\n",
    "plt.plot(Ns,metric_vals)\n",
    "# The band is multiplicative (value / std .. value * std).\n",
    "# NOTE(review): this assumes pmetric returns a geometric spread -- confirm.\n",
    "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
    "plt.xlabel('Sampler steps')\n",
    "plt.ylabel('Pmetric value')\n",
    "plt.xscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vPNbx_4GZOI4"
   },
   "outputs": [],
   "source": [
    "# Re-draw the sweep result from the previous cell without re-running the sampler.\n",
    "ax = plt.gca()\n",
    "ax.plot(Ns,metric_vals)\n",
    "ax.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
    "ax.set_xlabel('Sampler steps')\n",
    "ax.set_ylabel('Pmetric value')\n",
    "ax.set_xscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XKmrvn60ex2J"
   },
   "outputs": [],
   "source": [
    "# Draw a fresh batch, noise it with a fixed PRNG key, and compute the weighted\n",
    "# squared error of `denoised` evaluated with ema_params.\n",
    "data = next(dataiter())\n",
    "key = random.PRNGKey(26)\n",
    "noised_x,sigma = noise_input(data,key)\n",
    "# Weight (sigma^2 + data_std^2) / (sigma * data_std)^2, broadcast over the two\n",
    "# trailing axes of the batch.\n",
    "weighting = (sigma**2+data_std**2)/(sigma*data_std)**2\n",
    "# Averaging over the last two axes leaves one loss value per batch entry.\n",
    "losses = jnp.mean(((denoised(ema_params,noised_x,sigma)-data)**2)*weighting[:,None,None],axis=(-1,-2))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FBpfHonffl_H"
   },
   "outputs": [],
   "source": [
    "\n",
    "# Weighted loss of each example against its noise level, on log-log axes.\n",
    "# The label is attached to the scatter itself: the old slice\n",
    "# ['loss values','sigma sample pdf'][:1:-1] evaluates to [], so the legend\n",
    "# was empty.\n",
    "plt.scatter(sigma,losses,label='loss values')\n",
    "#plt.plot(np.sort(sigma),jax.scipy.stats.norm.pdf(np.log(np.sort(sigma)),mu,std),color='y')\n",
    "#plt.hline(1e-1)\n",
    "#plt.scatter(sigma,weighting)\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "plt.ylabel('weighted loss')\n",
    "plt.xlabel(r'$\\sigma$')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hRLWfpQ_29Ad"
   },
   "outputs": [],
   "source": [
    "# White-noise test signal (NumPy global RNG, unseeded).\n",
    "x = np.random.randn(256)\n",
    "\n",
    "# Build a list of progressively wider smoothing kernels: each new kernel is the\n",
    "# previous one convolved with itself, so a length-L kernel becomes 2L-1.\n",
    "binomial = [np.array([1., 1.])/2]\n",
    "for _ in range(int(np.floor(np.log2(len(x))))):\n",
    "  sqr = np.convolve(binomial[-1],binomial[-1])\n",
    "  #binomial[-1] /= sqr[sqr.shape[0]//2+1]\n",
    "  #binomial.append(sqr/sqr.sum())\n",
    "  # NOTE(review): sqr has odd length, so its centre is index shape[0]//2; the\n",
    "  # normaliser below reads one element to the right of it -- confirm intended.\n",
    "  binomial.append(sqr/sqr[sqr.shape[0]//2+1])\n",
    "  \n",
    "# Prepend the identity kernel and drop the widest one.\n",
    "binomial = [np.array([1.])]+binomial[:-1]\n",
    "def blur(z):\n",
    "  # Convolve z with the last (widest) kernel of the notebook-global `binomial`\n",
    "  # list in mode='same'. The list is looked up at call time, so rebinding\n",
    "  # `binomial` in a later cell changes what this function does.\n",
    "  return jnp.convolve(binomial[-1],z,mode='same')\n",
    "\n",
    "#vblur = vmap(vmap(blur,0,0),2,2)\n",
    "\n",
    "def vblur(z):\n",
    "  s = jnp.fft.rfft(z,axis=1)\n",
    "  f = 1+jnp.abs(jnp.fft.fftfreq(z.shape[1])[:s.shape[1]])*s.shape[1]\n",
    "  scaled = s/f[None,:,None]**.5\n",
    "  scaled = scaled/jnp.mean(jnp.abs(scaled),axis=1,keepdims=True)\n",
    "  noise = jnp.fft.irfft(scaled,axis=1)\n",
    "  return noise\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qrka-iN7h5Ke"
   },
   "outputs": [],
   "source": [
    "# White-noise test signal of length 300 (NumPy global RNG, unseeded).\n",
    "# Rebinds `x` and `binomial` from the previous cell.\n",
    "x = np.random.randn(300)\n",
    "\n",
    "# Same kernel construction as the previous cell (repeated self-convolution),\n",
    "# but here every kernel is normalised to sum to one.\n",
    "binomial = [np.array([1., 1.])/2]\n",
    "for _ in range(int(np.floor(np.log2(len(x))))):\n",
    "  sqr = np.convolve(binomial[-1],binomial[-1])\n",
    "  #binomial[-1] /= sqr[sqr.shape[0]//2+1]\n",
    "  binomial.append(sqr/sqr.sum())\n",
    "  #binomial.append(sqr/sqr[sqr.shape[0]//2+1])\n",
    "  \n",
    "# Prepend the identity kernel and drop the widest one.\n",
    "binomial = [np.array([1.])]+binomial[:-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZLaW91Crq9l8"
   },
   "outputs": [],
   "source": [
    "# Smooth x with every kernel in `binomial`, then add a scaled running sum of x\n",
    "# as the last entry.\n",
    "running_sum = jnp.cumsum(x)/np.sqrt(len(x))\n",
    "blurred = [jax.scipy.signal.convolve(x,kernel,mode='same') for kernel in binomial]+[running_sum]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pmeV_A3dsrIv"
   },
   "outputs": [],
   "source": [
    "# Plot every entry of `blurred`, labelled 2**i by list position.\n",
    "# The loop leaves `i` and `bx` bound at notebook level after it finishes.\n",
    "for i,bx in enumerate(blurred):\n",
    "  plt.plot(bx,label=str(2**i))\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z5C-esHKsuDN"
   },
   "outputs": [],
   "source": [
    "# Magnitude spectrum of every blurred signal (first half of the FFT bins),\n",
    "# with 1/f^2 and 1/f reference curves on log-log axes.\n",
    "freq = np.fft.fftfreq(x.shape[0])[:x.shape[0]//2]*x.shape[0]\n",
    "for i,bx in enumerate(blurred):\n",
    "  plt.plot(freq, jnp.abs(np.fft.fft(bx)[:x.shape[0]//2]),label=str(2**i))\n",
    "\n",
    "# Skip the zero-frequency bin for the reference curves: dividing by it only\n",
    "# produced a divide-by-zero warning and a point the log axis cannot show.\n",
    "nonzero_freq = freq[1:]\n",
    "plt.plot(nonzero_freq,1/nonzero_freq**2,label='brown')\n",
    "plt.plot(nonzero_freq,1/nonzero_freq,label='pink')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "d9o2ahEjyG55"
   },
   "outputs": [],
   "source": [
    "#plt.plot(freq, jnp.abs(np.fft.fft(vblur(x[None,:,None])[0,:,0])[:x.shape[0]//2]),label=str(2**i))\n",
    "# Reference 1/f^2 and 1/f curves on log-log axes, using `freq` from the previous\n",
    "# cell. The zero-frequency bin is skipped: dividing by it only produced a\n",
    "# divide-by-zero warning and a point the log axis cannot show.\n",
    "plt.plot(freq[1:],1/freq[1:]**2,label='brown')\n",
    "plt.plot(freq[1:],1/freq[1:],label='pink')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Hbj5vWgi9veU"
   },
   "outputs": [],
   "source": [
    "# Apply vblur to x (given singleton leading and trailing axes) and plot the result.\n",
    "plt.plot(vblur(x[None,:,None])[0,:,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FJttK8iqs8Td"
   },
   "outputs": [],
   "source": [
    "# Mean square of every blurred signal. The comprehension has to bind `bx`:\n",
    "# iterating as `x` re-evaluated the stale `bx` left behind by an earlier loop,\n",
    "# so every entry came out identical.\n",
    "[(bx**2).mean() for bx in blurred]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Sj1izTwsxjw5"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
    "kind": "private"
   },
   "name": "double_pendulum_diffusion.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
