{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qj5AP2Jc7gXL"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gLeUAXP5j8RJ"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "tf.config.experimental.set_visible_devices([], \"GPU\")\n",
        "\n",
        "# from colabtools import adhoc_import\n",
        "import importlib\n",
        "from userdiffusion import ode_datasets, unet, samplers, train\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(unet)\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "import jax.numpy as jnp\n",
        "import numpy as np\n",
        "import jax"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RQwCeAGHkB3i"
      },
      "outputs": [],
      "source": [
        "# Build the pendulum dataset. The first `bs` trajectories are held out as a test batch\n",
        "# (`test_x`); the remaining ones form the training set. Only the first component returned\n",
        "# by `ode_datasets.unpack` (the angles) is modeled; `vs` is kept for the velocity check below.\n",
        "dt = .5\n",
        "bs = 400\n",
        "# NOTE(review): N presumably is the number of trajectories and n the number of links - confirm in ode_datasets.\n",
        "ds = ode_datasets.NPendulum(N=4000+bs,n=2,dt=dt)\n",
        "\n",
        "\n",
        "thetas,vs = ode_datasets.unpack(ds.Zs[bs:])\n",
        "test_x = ode_datasets.unpack(ds.Zs[:bs])[0]\n",
        "# Disabled experiments: standardize the angles, or swap them for pure Gaussian noise.\n",
        "#thetas /=thetas.std()\n",
        "#thetas = jax.random.normal(jax.random.PRNGKey(38),thetas.shape)\n",
        "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
        "\n",
        "# Bound method, not an iterator: calling it starts a new pass over the shuffled data in batches of `bs`.\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "22s4EuwL0j8Y"
      },
      "outputs": [],
      "source": [
        "\n",
        "## Check that the computed velocities from $v=M(q)^{-1}p$ match those from finite differences.\n",
        "# The central difference (theta[t+1]-theta[t-1])/(2 dt) exists only at interior steps,\n",
        "# so both curves use the [1:-1] slice of the time grid.\n",
        "\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "vfd = -(thetas[:,:-2]-thetas[:,2:])/(2*dt)\n",
        "plt.plot(ds.T_long[1:-1],vs[0,1:-1,0])\n",
        "plt.plot(ds.T_long[1:-1],vfd[0,:,0])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'$\\dot\\theta$')\n",
        "plt.legend(['v state','v finite diff'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8jIaqgzj1APl"
      },
      "outputs": [],
      "source": [
        "## Examine the state over time\n",
        "# Both angle components (last axis 0 and 1) of the first training trajectory.\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "plt.plot(ds.T_long,thetas[0,:,0])\n",
        "plt.plot(ds.T_long,thetas[0,:,1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'$\\theta_0$',r'$\\theta_1$'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JbFiGz6W8Qgc"
      },
      "outputs": [],
      "source": [
        "# Reference batch: the held-out test trajectories (a training batch via next(dataiter()) also works).\n",
        "x = test_x#next(dataiter())\n",
        "# NOTE(review): `t` is an unseeded random draw that no visible cell reads before it is rebound.\n",
        "t = np.random.rand(x.shape[0])\n",
        "# Score network; its output dimension matches the last axis of x.\n",
        "model = unet.UNet(unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nUUCduoEkaYh"
      },
      "outputs": [],
      "source": [
        "# Select the diffusion process via the Colab form field and train the score model.\n",
        "# `dataiter` is passed uncalled; the returned score is used later as score_fn(x_flat, t).\n",
        "#%env XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false\n",
        "difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
        "diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,\n",
        "        'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype]\n",
        "epochs = 400#@param {'type':'integer'}\n",
        "score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=1e-3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GmHkzHw-XjW5"
      },
      "outputs": [],
      "source": [
        "# Draw 30 unconditional samples with the SDE sampler from a fixed key,\n",
        "# keeping every intermediate state (traj=True).\n",
        "import jax\n",
        "key= jax.random.PRNGKey(38)\n",
        "sample_traj = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PoyYkj2uX65n"
      },
      "outputs": [],
      "source": [
        "# Later cells index this array as [sampler step, sample, time, channel].\n",
        "sample_traj.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JLhO8G2MXqpf"
      },
      "outputs": [],
      "source": [
        "# Eigen-spectrum of the sigma_t^2-scaled Jacobian of the learned score, evaluated on the first\n",
        "# sampled trajectory at every 200th step. For an exact score the Jacobian is a Hessian of\n",
        "# log p (symmetric), hence eigh; `w` holds eigenvalues in ascending order.\n",
        "xt = sample_traj[:,:1]\n",
        "N=1000\n",
        "# Midpoint diffusion-time grid, descending from ~1 to ~0 (presumably the sampler grid for nsteps=1000).\n",
        "timesteps = (.5+np.arange(N)[::-1])/N\n",
        "ts = (timesteps[:-1])\n",
        "tt = ts[::200]\n",
        "Js = jax.vmap(jax.jacfwd(score_fn))(xt.reshape(xt.shape[0],-1)[::200],tt)*(diff.sigma(tt)**2)[:,None,None]\n",
        "w,v = jnp.linalg.eigh(Js)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A7jh5WwUZPuR"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R29-bF17b-Fz"
      },
      "outputs": [],
      "source": [
        "# One curve of eigenvalues (ascending index i) per sampled diffusion time.\n",
        "for i,t in enumerate(ts[::200]):\n",
        "  plt.plot(w[i],label=str(t))\n",
        "plt.legend()\n",
        "plt.ylabel(r'$\\lambda_i$')\n",
        "plt.xlabel(r'$i$')\n",
        "#plt.ylim(1e-3,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WJVdBGgiaLsL"
      },
      "outputs": [],
      "source": [
        "# Same spectra shown as 1+lambda on a log axis (non-positive values cannot appear on a log scale).\n",
        "for i,t in enumerate(ts[::200]):\n",
        "  plt.plot(1+w[i],label=str(t))\n",
        "plt.yscale('log')\n",
        "plt.legend()\n",
        "plt.ylabel(r'$(1+\\lambda_i)$')\n",
        "plt.xlabel(r'$i$')\n",
        "#plt.ylim(1e-3,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SpPCSaX_eHEa"
      },
      "outputs": [],
      "source": [
        "from jax import grad,jit\n",
        "condition_amount = 13# @param {type:\"slider\", min:0, max:50, step:1}\n",
        "mb = x[:30,:]\n",
        "\n",
        "\n",
        "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Return a score that replaces the model score on the observed slice.\n",
        "\n",
        "  Args:\n",
        "    diffusion: process object providing noise_score(xt_slice, observed_values, t).\n",
        "    scorefn: unconditional score, called as scorefn(xt_flat, t).\n",
        "    observed_values: conditioning values, shape (b, k, c), aligned with xt[:, slc].\n",
        "    slc: slice of the time axis that is observed.\n",
        "\n",
        "  Returns:\n",
        "    conditioned_scores(xt_flat, t) returning a flat score of the same size as xt_flat.\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
        "    return combined_score.reshape(-1)\n",
        "  return conditioned_scores\n",
        "\n",
        "def inpainting_scores2(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Like inpainting_scores, plus guidance of the unobserved part toward the observations.\n",
        "\n",
        "  The guidance is the gradient of the squared error between the observed values and the\n",
        "  one-step denoised estimate xhat = (xt + sigma_t^2 * score) / scale_t on the observed slice,\n",
        "  weighted by 10 * scale_t^2 / sigma_t^2. The returned function is jit-compiled.\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def constraint(xt):\n",
        "      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
        "    # Use the `diffusion` argument (not the notebook-global `diff`) so the guidance weight\n",
        "    # always matches the process this score was built for.\n",
        "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*10*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
        "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
        "    return combined_score.reshape(-1)\n",
        "  return jit(conditioned_scores)\n",
        "\n",
        "slc = slice(condition_amount)\n",
        "conditioned_samples = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,nsteps=1000,traj=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "prXGkjZxeJ7F"
      },
      "outputs": [],
      "source": [
        "# Scaled Jacobian spectrum of the *unconditional* score along the first conditioned trajectory,\n",
        "# every 25th step, restricted to the leading `slc` coordinates of the flattened state.\n",
        "# NOTE(review): slc indexes the flattened (time*channel) axis here, so unless c == 1 it covers the\n",
        "# first condition_amount entries, not the first condition_amount timesteps - confirm intent.\n",
        "xt = conditioned_samples[:,:1]\n",
        "N=1000\n",
        "timesteps = (.5+np.arange(N)[::-1])/N\n",
        "ts = (timesteps[:-1])\n",
        "tt = ts[::25]\n",
        "Js = jax.vmap(jax.jacfwd(score_fn))(xt.reshape(xt.shape[0],-1)[::25],tt)*(diff.sigma(tt)**2)[:,None,None]\n",
        "Js = Js[:,slc,slc]\n",
        "w,v = jnp.linalg.eigh(Js)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_M4u_1CuyWns"
      },
      "outputs": [],
      "source": [
        "# Relative asymmetry ||J - J^T||_F / ||J||_F per diffusion time; an exact score has a symmetric Jacobian.\n",
        "jnp.linalg.norm(Js-jnp.swapaxes(Js,-1,-2),axis=(-1,-2))/jnp.linalg.norm(Js,axis=(-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fni9438tej54"
      },
      "outputs": [],
      "source": [
        "# Eigenvalues (ascending) of the restricted Jacobian for each diffusion time in `tt`.\n",
        "for i,t in enumerate(tt):\n",
        "  plt.plot(w[i],label=str(t))\n",
        "plt.legend()\n",
        "plt.ylabel(r'$\\lambda_i$')\n",
        "plt.xlabel(r'$i$')\n",
        "#plt.ylim(1e-3,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7nQQHy1vzKc_"
      },
      "outputs": [],
      "source": [
        "# Largest eigenvalue (eigh sorts ascending, so index -1) against the reference curve\n",
        "# 1 - 1/(1 + s^2/sigma_t^2) = s^2/(s^2 + sigma_t^2) with s^2 = 1e0, i.e. 1/(1 + sigma_t^2)\n",
        "# (presumably the prediction for unit-variance Gaussian data).\n",
        "plt.plot(tt,1+w[:,-1],label=r'$1+\\lambda_{\\max}$')\n",
        "plt.plot(tt,1-1/(1+1e0/diff.sigma(tt)**2),label=r'$1/(1+\\sigma_t^2)$')\n",
        "plt.yscale('log')\n",
        "plt.xlabel('diffusion time t')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ylU8czRQey80"
      },
      "outputs": [],
      "source": [
        "# Inverse-shifted spectra 1/(1.05 + lambda_i); the 0.05 offset presumably keeps the denominator away from 0.\n",
        "for i,t in enumerate(tt):\n",
        "  plt.plot(1/(1.05+w[i]),label=str(t))\n",
        "plt.yscale('log')\n",
        "plt.legend()\n",
        "plt.ylabel(r'$1/(1.05+\\lambda_i)$')\n",
        "plt.xlabel(r'$i$')\n",
        "#plt.ylim(1e-3,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fRn9pR8Ewzey"
      },
      "outputs": [],
      "source": [
        "# Pick up edits to the sampler / training modules without restarting the kernel.\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)\n",
        "#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Un8fU4D9ZyuI"
      },
      "outputs": [],
      "source": [
        "# Batch mean of compute_nll on the reference batch x (presumably a per-example negative log-likelihood).\n",
        "import jax\n",
        "key= jax.random.PRNGKey(38)\n",
        "samplers.compute_nll(diff,score_fn,key,x).mean()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8nc8LmPLL_z4"
      },
      "source": [
        "## Sample generation"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Kboc2IIeteh7"
      },
      "outputs": [],
      "source": [
        "# Unconditional samples three ways from the same key: SDE final states, the full SDE\n",
        "# trajectory (to inspect the denoising process), and the deterministic ODE sampler.\n",
        "stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=False)\n",
        "sample_traj = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=True)\n",
        "det_samples = samplers.ode_sample(diff,score_fn,key,x[:30].shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5rrpbT2MAWRn"
      },
      "outputs": [],
      "source": [
        "# Last channel of sample i at every 100th sampler step, starting from step 400.\n",
        "import matplotlib.pyplot as plt\n",
        "i=13 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(ds.T_long,sample_traj[400::100,i,:,-1].T,alpha=1/2)\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.ylim(-2,2)\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LerwGMP9Bla"
      },
      "outputs": [],
      "source": [
        "from jax import vmap\n",
        "# Rebuild the midpoint diffusion-time grid for the stored steps (descending from ~1).\n",
        "n=sample_traj.shape[0]+1\n",
        "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
        "# Model score at every stored step, reshaped back to the trajectory layout.\n",
        "scores = vmap(score_fn)(sample_traj.reshape(len(ts),-1),ts).reshape(sample_traj.shape)\n",
        "# One-step denoised estimate at each step: xhat = (x_t + sigma_t^2 * score) / scale_t.\n",
        "best_reconstructions = (sample_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xr1eZyO-1IWB"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "\n",
        "i=7 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "\n",
        "cmap='inferno'\n",
        "\n",
        "\n",
        "# Denoised estimates of sample i (last channel) at every 25th step from step 100, colored from\n",
        "# dark (earlier steps, larger diffusion time) to bright (later steps, smaller diffusion time).\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = best_reconstructions[100::25,i,:,-1].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# NOTE(review): this loop rebinds the slider variable `i`; later cells set their own `i`.\n",
        "for i,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[i])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.ylim(-2,2)\n",
        "# Colorbar on the reversed colormap so colors map to diffusion time (ts decreases with step index).\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2ohSLTVT9f2M"
      },
      "outputs": [],
      "source": [
        "from scipy.ndimage import correlate1d\n",
        "i=7 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Central-difference velocity along the time axis (axis 2). correlate1d with weights [-1,0,1]\n",
        "# already computes x[t+1]-x[t-1], so no extra minus sign (same convention as `vfd`).\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "print(vs.shape)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = vs[100::25,i,:,-1].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "# Loop index `k` keeps the slider value `i` intact.\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'$\\dot \\theta$')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5UZzL2APAxpy"
      },
      "outputs": [],
      "source": [
        "i=14 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "# Energy of the denoised estimates: velocity by central differences (correlate1d with [-1,0,1]\n",
        "# gives x[t+1]-x[t-1], so no extra minus sign), then momentum p = M(q) v.\n",
        "# NOTE(review): at the boundaries correlate1d (mode='reflect') repeats the edge sample, so the\n",
        "# velocity at t=0 is (x[1]-x[0])/(2 dt), i.e. half a forward difference.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "z = ode_datasets.pack(best_reconstructions,(vmap(vmap(vmap(ds.mass)))(best_reconstructions)@vs[...,None]).squeeze(-1))\n",
        "Hs = vmap(vmap(vmap(ds.hamiltonian)))(z)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = Hs[100::25,i,:].T\n",
        "ax1.plot(ds.T_long,data[:],alpha=.6,lw=2)\n",
        "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'Energy')\n",
        "# y-range: the range over time of the last plotted step's energy, padded by twice that range.\n",
        "minn,maxx = data[:,-1].min(),data[:,-1].max()\n",
        "plt.ylim(minn-2*(maxx-minn),maxx+2*(maxx-minn))\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "# Reference (green): energy of the trajectory integrated from the final sample's state at t index 0.\n",
        "z0 = z[-1,:,0]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,ds.T_long)\n",
        "ax1.plot(ds.T_long,vmap(vmap(ds.hamiltonian))(z_gts)[i],color='g',lw=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fiq7MGQm-1x-"
      },
      "outputs": [],
      "source": [
        "# Magnitude spectrum (rFFT over the time axis) of raw sampler states for sample i, every 25th\n",
        "# step from step 100; faint blue curves are spectra of every 10th reference trajectory in x.\n",
        "i=15 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "nn = sample_traj.shape[2]\n",
        "fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))#[:,:,:nn//2]\n",
        "freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[100::25,i,:,-1].T\n",
        "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
        "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# NOTE(review): this loop rebinds the slider variable `i`.\n",
        "for i,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[i])\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GSCNVbsV-zlI"
      },
      "outputs": [],
      "source": [
        "# Same spectrum plot, but for the one-step denoised estimates instead of the raw sampler states.\n",
        "i=8 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "nn = best_reconstructions.shape[2]\n",
        "fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))#[:,:,:nn//2]\n",
        "freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
        "\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = fft[100::25,i,:,-1].T\n",
        "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
        "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
        "# NOTE(review): this loop rebinds the slider variable `i`.\n",
        "for i,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[i])\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "#plt.ylim(-2,2)\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# Unlike the sibling cells, vmin/vmax are given in increasing order here (ts decreases with step index).\n",
        "norm = mpl.colors.Normalize(vmax=ts[100], vmin=ts[-25])    \n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "#cb1.ax.invert_yaxis()\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fa5_fj-mU_zw"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Reference trajectory i next to the ODE and SDE samples at the same index. The samples are\n",
        "# unconditional, so they are not expected to match GT pointwise - this compares their character.\n",
        "import matplotlib.pyplot as plt\n",
        "i=24 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(ds.T_long,x[i,:,-1])\n",
        "plt.plot(ds.T_long,det_samples[i,:,-1])\n",
        "plt.plot(ds.T_long,stoch_samples[i,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT',r'Model (ODE)', r'Model (SDE)'])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QFHO0tDBL42D"
      },
      "source": [
        "## Test the model's ability to condition on previous timesteps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cH8P02jcfKRH"
      },
      "outputs": [],
      "source": [
        "# NOTE(review): this cell repeats the inpainting definitions from an earlier cell; keep them in sync.\n",
        "from jax import grad,jit\n",
        "condition_amount = 13# @param {type:\"slider\", min:0, max:50, step:1}\n",
        "mb = x[:30,:]\n",
        "\n",
        "\n",
        "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Return a score that replaces the model score on the observed slice.\n",
        "\n",
        "  Args:\n",
        "    diffusion: process object providing noise_score(xt_slice, observed_values, t).\n",
        "    scorefn: unconditional score, called as scorefn(xt_flat, t).\n",
        "    observed_values: conditioning values, shape (b, k, c), aligned with xt[:, slc].\n",
        "    slc: slice of the time axis that is observed.\n",
        "\n",
        "  Returns:\n",
        "    conditioned_scores(xt_flat, t) returning a flat score of the same size as xt_flat.\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
        "    return combined_score.reshape(-1)\n",
        "  return conditioned_scores\n",
        "\n",
        "def inpainting_scores2(diffusion,scorefn,observed_values,slc):\n",
        "  \"\"\"Like inpainting_scores, plus guidance of the unobserved part toward the observations.\n",
        "\n",
        "  The guidance is the gradient of the squared error between the observed values and the\n",
        "  one-step denoised estimate xhat = (xt + sigma_t^2 * score) / scale_t on the observed slice,\n",
        "  weighted by 10 * scale_t^2 / sigma_t^2. The returned function is jit-compiled.\n",
        "  \"\"\"\n",
        "  b,_,c = observed_values.shape\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "\n",
        "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def constraint(xt):\n",
        "      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
        "    # Use the `diffusion` argument (not the notebook-global `diff`) so the guidance weight\n",
        "    # always matches the process this score was built for.\n",
        "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*10*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
        "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
        "    return combined_score.reshape(-1)\n",
        "  return jit(conditioned_scores)\n",
        "\n",
        "slc = slice(condition_amount)\n",
        "conditioned_samples = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,nsteps=1000,traj=True)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IUMkKWJuAy2y"
      },
      "outputs": [],
      "source": [
        "# Conditioned SDE trajectory for sample i: faint intermediate states (every 100th of the last\n",
        "# 600 steps), the final sample, the ground truth, and the observed prefix (bold).\n",
        "import matplotlib.pyplot as plt\n",
        "i=28 # @param {type:\"slider\", min:0, max:30, step:1}\n",
        "plt.plot(ds.T_long,conditioned_samples[-600::100,i,:,-1].T,zorder=0,alpha=.2)\n",
        "plt.plot(ds.T_long,conditioned_samples[-1,i,  :,-1].T,zorder=2)\n",
        "plt.plot(ds.T_long,x[i,:,-1],label='gt',alpha=1,zorder=99)\n",
        "plt.plot(ds.T_long[slc],x[i,slc,-1],label='cond',alpha=1,zorder=100,lw=3)\n",
        "\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.ylim(-3,3)\n",
        "plt.legend()\n",
        "#plt.legend([r'GT',r'Model'])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "21fgx9UbckBK"
      },
      "outputs": [],
      "source": [
        "# Deterministic (ODE) completion using the same guided inpainting score.\n",
        "conditioned_sample = samplers.ode_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xKR36dO6FgwA"
      },
      "outputs": [],
      "source": [
        "from jax import jit,vmap,random\n",
        "\n",
        "@jit\n",
        "def rel_err(z1,z2):\n",
        "  \"\"\"Relative L1 error over the last axis: |z1-z2|_1 / (|z1|_1 + |z2|_1), bounded in [0, 1].\"\"\"\n",
        "  return jnp.abs(z1-z2).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1))\n",
        "\n",
        "gt = x[:30]\n",
        "# Labels are attached to the lines directly; a label-only legend would pair labels with\n",
        "# artists in insertion order and could put a label on a fill_between band.\n",
        "for pred,name in [(conditioned_samples[-1],'SDE completion'),(conditioned_sample,'ODE completion')]:\n",
        "  # Clamp before the log so exact matches (error 0) do not produce -inf.\n",
        "  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)\n",
        "  # Geometric mean and geometric std over the batch, per timestep.\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  plt.plot(ds.T_long,rel_errs,label=name)\n",
        "  plt.fill_between(ds.T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3Xfy7NxAfQ8S"
      },
      "outputs": [],
      "source": [
        "# Ground truth, the observed prefix, and the ODE completion for sample i.\n",
        "i=2 # @param {type:\"slider\", min:0, max:29, step:1}\n",
        "plt.plot(ds.T_long,x[i,:,-1])\n",
        "plt.plot(ds.T_long[slc],x[i,slc,-1],lw=3)\n",
        "plt.plot(ds.T_long,conditioned_sample[i,:,-1])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'State')\n",
        "plt.legend([r'GT','Conditioning',r'Model'])\n",
        "plt.ylim(-3,3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mMvSdxUUG9Op"
      },
      "outputs": [],
      "source": [
        "# Request float32 matmul precision for the energy computations below (JAX may otherwise use\n",
        "# reduced-precision matmuls on accelerators).\n",
        "import jax \n",
        "jax.config.update('jax_default_matmul_precision', 'float32')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ovmzQCeoG_go"
      },
      "source": [
        "## Energy conditioning"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zYF_N85SHDZU"
      },
      "outputs": [],
      "source": [
        "\n",
        "from jax.scipy.stats.multivariate_normal import logpdf as normal_logpdf\n",
        "from jax import grad,jit,vmap\n",
        "key = jax.random.PRNGKey(38)\n",
        "mb = x[:30,:]\n",
        "b,_,c = mb.shape\n",
        "\n",
        "def H(q):\n",
        "  \"\"\"Energy along one trajectory of angles q, shape (n, c).\n",
        "\n",
        "  Velocities come from central differences, so the first and last steps are dropped;\n",
        "  the result has one entry per interior step (assuming ds.hamiltonian returns a scalar).\n",
        "  \"\"\"\n",
        "  v = -(q[:-2]-q[2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "  z = ode_datasets.pack(q[1:-1],(vmap(ds.mass)(q[1:-1])@v[...,None]).squeeze(-1))\n",
        "  Hs = vmap(ds.hamiltonian)(z)\n",
        "  return Hs\n",
        "\n",
        "def C(q):\n",
        "  \"\"\"Energy-conservation residual of q (n, c): H(q) minus its mean over time.\"\"\"\n",
        "  Hs = H(q)\n",
        "  return Hs-Hs.mean(axis=0,keepdims=True)\n",
        "\n",
        "# Data scale; not used by the guidance below, kept as a notebook-level quantity.\n",
        "data_std = thetas.std()\n",
        "\n",
        "def energy_conditioned_scores(diffusion,scorefn):\n",
        "  \"\"\"Guide `scorefn` toward trajectories whose denoised estimate conserves energy.\n",
        "\n",
        "  Adds 0.5 * grad log N(C(xhat); 0, (sigma_t/scale_t)^2 I) to the model score, where\n",
        "  xhat = (xt + sigma_t^2 * score) / scale_t is the one-step denoised estimate.\n",
        "  The returned jit-compiled function takes flattened xt for a batch shaped like the\n",
        "  global `mb`, i.e. (b, n, c).\n",
        "  \"\"\"\n",
        "  def conditioned_scores(xt,t):\n",
        "    unflat_xt = xt.reshape(b,-1,c)\n",
        "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
        "    def xhat(xt):\n",
        "      return (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "    # Residual variance; taken from the `diffusion` argument rather than the global `diff`.\n",
        "    sigma = (diffusion.sigma(t)/diffusion.scale(t))**2\n",
        "    def log_likelihood(xt):\n",
        "      qhat = xhat(xt).reshape(b,-1,c)\n",
        "      value = vmap(C)(qhat)\n",
        "      return -(value*value).sum()/(2.*sigma)\n",
        "    unobserved_score += .5*grad(log_likelihood)(xt).reshape(unflat_xt.shape)\n",
        "    return unobserved_score.reshape(-1)\n",
        "  return jit(conditioned_scores)\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SUy5Nglww80v"
      },
      "outputs": [],
      "source": [
        "# Energy-guided samples from the deterministic ODE sampler.\n",
        "# NOTE(review): `pmetric` is not defined in any visible cell; it must exist before this runs.\n",
        "energy_samples_det = samplers.ode_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape)\n",
        "print(f'ODE performance {pmetric(energy_samples_det)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-LIvv-dShrUG"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Energy-guided samples from the SDE sampler (100 steps, final states only).\n",
        "# NOTE(review): relies on `pmetric`, which is not defined in any visible cell.\n",
        "energy_samples_stoch = samplers.sde_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,nsteps=100,traj=False)\n",
        "print(f'SDE performance {pmetric(energy_samples_stoch)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fMn0DvH2EOxT"
      },
      "outputs": [],
      "source": [
        "# Re-prints the SDE metric for `energy_samples_stoch` (same value as the previous cell).\n",
        "print(f'SDE performance {pmetric(energy_samples_stoch)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G1xhyREgxxgM"
      },
      "outputs": [],
      "source": [
        "from scipy.ndimage import correlate1d\n",
        "# Full energy-guided SDE trajectory (100 steps) and the one-step denoised estimate at each step.\n",
        "energy_samples_traj = samplers.sde_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,nsteps=100,traj=True)\n",
        "n=energy_samples_traj.shape[0]+1\n",
        "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
        "scores = vmap(score_fn)(energy_samples_traj.reshape(len(ts),-1),ts).reshape(energy_samples_traj.shape)\n",
        "best_reconstructions = (energy_samples_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]\n",
        "# Central-difference velocity: correlate1d with [-1,0,1] already yields x[t+1]-x[t-1], so no\n",
        "# extra minus sign. The sign matters here because z0 below seeds the reference integration.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "z = ode_datasets.pack(best_reconstructions,(vmap(vmap(vmap(ds.mass)))(best_reconstructions)@vs[...,None]).squeeze(-1))\n",
        "Hs = vmap(vmap(vmap(ds.hamiltonian)))(z)\n",
        "# Reference: integrate from the final sample's state at time index kstart (away from the boundary).\n",
        "kstart=10\n",
        "z0 = z[-1,:,kstart]\n",
        "#print(vmap(ds.hamiltonian)(z0)[i])\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,ds.T_long[kstart:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xOBoadZ8qakU"
      },
      "outputs": [],
      "source": [
        "# Per sampler step: fraction of samples containing any NaN (any over the last two axes,\n",
        "# mean over the remaining sample axis).\n",
        "jnp.any(jnp.isnan(energy_samples_traj),(-2,-1)).mean(-1)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XdY_nzlKx20p"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "cmap='inferno'\n",
        "i=1 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "\n",
        "# Energy of trajectory i over physical time, one curve per selected sampler step\n",
        "# (Hs is indexed [sampler step, sample, physical time]); the green curve is the ground truth.\n",
        "step_sel = slice(10, None, 5)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = Hs[step_sel,i,:].T\n",
        "# Rows of data are physical time and are plotted as-is: reversing them would flip every\n",
        "# curve in time relative to ds.T_long and to the ground-truth curve.\n",
        "ax1.plot(ds.T_long,data,alpha=.6,lw=2)\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'Energy')\n",
        "minn,maxx = data[:,-1].min(),data[:,-1].max()\n",
        "plt.ylim(minn-2*(maxx-minn),maxx+2*(maxx-minn))\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# Colorbar spans exactly the diffusion times of the plotted steps; fixed indices such as\n",
        "# ts[100], ts[-25] did not match step_sel (ts[100] is past the end for a 100-step run).\n",
        "ts_plotted = ts[step_sel]\n",
        "norm = mpl.colors.Normalize(vmin=float(ts_plotted.min()), vmax=float(ts_plotted.max()))\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "\n",
        "ax1.plot(ds.T_long[kstart:],vmap(vmap(ds.hamiltonian))(z_gts)[i],color='g',lw=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEZ9EH8yFXtV"
      },
      "outputs": [],
      "source": [
        "# Inspect data_std (defined elsewhere in the notebook); xhat and Sigma below use it for the\n",
        "# limiting estimate and the ridge regularizer.\n",
        "data_std"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qsN7xoFpkcL4"
      },
      "outputs": [],
      "source": [
        "# Aliases and a time grid used by the covariance diagnostics that follow.\n",
        "diffusion=diff\n",
        "scorefn=score_fn\n",
        "N=1000\n",
        "# N midpoint times (N-0.5)/N, ..., 0.5/N in decreasing order.\n",
        "timesteps = (.5+np.arange(N)[::-1])/N\n",
        "\n",
        "\n",
        "# def xhat(xt,t):\n",
        "#   return (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "# xhat(xt, t): estimate of the clean data from noisy xt at diffusion time t, blending\n",
        "#   score_xhat    = (xt + sigma(t)^2 * score(xt, t)) / scale(t)   (Tweedie-style estimate)\n",
        "#   limiting_xhat = xt / (1 + sigma(t)^2 / data_std^2) / scale(t) (linear shrinkage of xt)\n",
        "# as score_xhat * sign(r) * |r|^tau with r = limiting_xhat / score_xhat, tau = min(3t, 1);\n",
        "# for t >= 1/3 the result equals limiting_xhat (up to rounding).\n",
        "# NOTE(review): as t -> 0 the result tends to score_xhat * sign(r), which flips the sign\n",
        "# wherever the two estimates disagree, and score_xhat == 0 divides by zero; confirm intent.\n",
        "def xhat(xt,t):\n",
        "    score_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
        "    #print('score',score_xhat)\n",
        "    limiting_xhat = (xt/(1+diffusion.sigma(t)**2/data_std**2))/diffusion.scale(t)\n",
        "    #print('limiting',limiting_xhat)\n",
        "    tau = jnp.minimum(3*t,1.0)\n",
        "    ratio = (limiting_xhat/score_xhat)\n",
        "\n",
        "    # Geometric interpolation from score_xhat (tau=0) towards limiting_xhat (tau=1).\n",
        "    blended = score_xhat*jnp.sign(ratio)*jnp.abs(ratio)**tau\n",
        "    #print('blended',blended)\n",
        "    return blended\n",
        "# Sigma(xt, t): per-row matrix DC @ DC2^T * scale(t), symmetrized, plus a small ridge.\n",
        "#   xt: (b, n, c) array; t: length-b array of diffusion times, vmapped together with xt.\n",
        "#   DC  = Jacobian of C at the denoised point xhat(xt, t).\n",
        "#   DC2 = Jacobian of x -> C(xhat(x, t)) with respect to the noisy input x.\n",
        "#   Returns a (b, m, m) stack, m presumably being the number of outputs of C.\n",
        "def Sigma(xt,t):\n",
        "  b,n,c = xt.shape\n",
        "  xt = xt.reshape(b,-1)\n",
        "  xh = vmap(xhat)(xt,t).reshape(b,n,c)\n",
        "  DC = vmap(jax.jacrev(C))(xh).reshape(b,-1,n*c)\n",
        "  # NOTE(review): xhat is called on x[None] here but on the flat x in the step-by-step\n",
        "  # diagnostic cell; confirm which input shape scorefn expects.\n",
        "  DC2 = vmap(jax.jacrev(lambda x,t: C(xhat(x[None],t)[0].reshape(n,c))))(xt,t)\n",
        "  sig = (DC@jnp.swapaxes(DC2,-1,-2))*diffusion.scale(t)[:,None,None]\n",
        "  # Ridge 1e-5 * I shrunk by 1/(1 + sigma(t)^2/data_std^2), presumably for conditioning.\n",
        "  reg = 1e-5*jnp.eye(sig.shape[-1])[None]/(1+diffusion.sigma(t)**2/data_std**2)[:,None,None]\n",
        "  #eigs,V = jnp\n",
        "  #rint(jnp.linalg.eigh(sig)[0])\n",
        "  return ((sig + jnp.swapaxes(sig,-1,-2))/2. + reg)#*(1+diffusion.sigma(t)**2/data_std**2)[:,None,None]#*(diff.scale(t)**2/diff.sigma(t)**2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rhw1eF9QkqOR"
      },
      "outputs": [],
      "source": [
        "# Evaluate Sigma at every 100th timestep (10 times for N=1000) on element 0 along axis 1\n",
        "# of sample_traj (presumably the first sample). Overwrites the global ts.\n",
        "# NOTE(review): assumes sample_traj[::100] has as many rows as ts; confirm.\n",
        "ts = (timesteps[:-1])[::100]\n",
        "S = Sigma(sample_traj[::100,0],ts)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dk832CJHFsQQ"
      },
      "outputs": [],
      "source": [
        "# Shape check of a single state.\n",
        "sample_traj[0,0].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eqx3dp_7Fd7D"
      },
      "outputs": [],
      "source": [
        "# Smoke test: xhat on one flattened state at the first time in ts (output suppressed).\n",
        "xhat(sample_traj[0,0].reshape(-1),ts[0]);"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eqf-ZF3JIX25"
      },
      "outputs": [],
      "source": [
        "# Step-by-step version of Sigma at every 200th timestep, keeping the intermediates\n",
        "# (t, xt, xh, DC, DC2, SS, sig, eigs, V) as globals for the checks in later cells.\n",
        "t  =timesteps[::200]\n",
        "xt = sample_traj[::200,0]\n",
        "b,n,c = xt.shape\n",
        "xt = xt.reshape(b,-1)\n",
        "xh = vmap(xhat)(xt,t)\n",
        "# DC: Jacobian of C at the denoised points; DC2: Jacobian of C(xhat(x,t)) w.r.t. noisy x.\n",
        "DC = vmap(jax.jacfwd(lambda x: C(x.reshape(n,c))))(xh)#.reshape(b,-1,n*c)\n",
        "DC2 = vmap(jax.jacfwd(lambda x,t: C(xhat(x,t).reshape(n,c))))(xt,t)\n",
        "# SS: Jacobian of xhat itself; by the chain rule DC2 should equal DC @ SS.\n",
        "SS = vmap(jax.jacfwd(xhat))(xt,t)\n",
        "sig = (DC@jnp.swapaxes(DC2,-1,-2))*diffusion.scale(t)[:,None,None]\n",
        "sig = (sig+jnp.swapaxes(sig,-1,-2))/2\n",
        "# Clip the spectrum from below at 1e-4 times the largest eigenvalue (eigh sorts ascending,\n",
        "# so eigs[:,-1] is the largest), then rebuild sig from the clipped eigenvalues.\n",
        "eigs,V = jnp.linalg.eigh(sig)\n",
        "cliplevel = 1e-4*eigs[:,-1]\n",
        "eigs = jnp.where(eigs\u003ccliplevel[:,None],cliplevel[:,None],eigs) \n",
        "sig = V@vmap(jnp.diag)(eigs)@jnp.swapaxes(V,-1,-2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eUzENeVjuBwP"
      },
      "outputs": [],
      "source": [
        "# Smallest (clipped) eigenvalue for the first selected timestep.\n",
        "eigs[0,0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KXOfOrhEtiz5"
      },
      "outputs": [],
      "source": [
        "# Consistency check: rebuilding from (V, eigs) reproduces sig. Needs rel_err, which is\n",
        "# defined in a later cell.\n",
        "rel_err(V@vmap(jnp.diag)(eigs)@jnp.swapaxes(V,-1,-2),sig)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aSHFkY4JE4kz"
      },
      "outputs": [],
      "source": [
        "# Eigenvalues (ascending) of each matrix returned by Sigma; overwrites the global w.\n",
        "\n",
        "w = jnp.linalg.eigh(S)[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OsuEZeMhFBDm"
      },
      "outputs": [],
      "source": [
        "# Spectrum for the first time in ts.\n",
        "w[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VDxgUbzMn0SB"
      },
      "outputs": [],
      "source": [
        "# Eigenvalue spectra of S, one curve per diffusion time in ts (log scale).\n",
        "#print(w)\n",
        "for i,ti in enumerate(ts):\n",
        "  plt.plot(w[i],label=str(ti))\n",
        "plt.legend()\n",
        "plt.ylabel(r'$\\lambda_i$')\n",
        "plt.xlabel(r'$i$')\n",
        "#plt.ylim(1e-3,20)\n",
        "#plt.ylim(-2,10)\n",
        "plt.yscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "a7iPVwaCLZv8"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3pJHwRbVuvba"
      },
      "outputs": [],
      "source": [
        "# Same plot for the clipped matrices sig at the times t; overwrites w with their eigenvalues.\n",
        "\n",
        "w = jnp.linalg.eigh(sig)[0]\n",
        "#print(w)\n",
        "for i,ti in enumerate(t):\n",
        "  plt.plot(w[i],label=str(ti))\n",
        "plt.legend()\n",
        "plt.ylabel(r'$\\lambda_i$')\n",
        "plt.xlabel(r'$i$')\n",
        "# plt.ylim(1e-3,20)\n",
        "plt.yscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NE_UuZoro78H"
      },
      "outputs": [],
      "source": [
        "# Denoised estimate for the first selected timestep.\n",
        "xh[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8ctTpbbYo4CR"
      },
      "outputs": [],
      "source": [
        "# First row of w (eigenvalues of sig if the sig-spectrum cell was the last to set w).\n",
        "w[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qt_rYIQoKiXX"
      },
      "outputs": [],
      "source": [
        "from functools import partial\n",
        "\n",
        "# Batched ord=2 norm: maps jnp.linalg.norm(ord=2) over the leading axis\n",
        "# (spectral norm when each slice is a matrix).\n",
        "vnorm = vmap(partial(jnp.linalg.norm, ord=2))\n",
        "\n",
        "def rel_err(A, B):\n",
        "  # Symmetric relative error ||A-B|| / (||A|| + ||B||) for each leading index.\n",
        "  # NOTE: a later cell redefines rel_err with an L1 norm over the last axis.\n",
        "  numerator = vnorm(A - B)\n",
        "  denominator = vnorm(A) + vnorm(B)\n",
        "  return numerator / denominator"
      ]
    },
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3SVjZDQwIh15"
      },
      "outputs": [],
      "source": [
        "sig.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U_A6iqinKeiJ"
      },
      "outputs": [],
      "source": [
        "#jnp.abs(SS@jnp.swapaxes(DC,-1,-2)-jnp.swapaxes(DC2,-1,-2)).mean()\n",
        "rel_err(SS@jnp.swapaxes(DC,-1,-2),jnp.swapaxes(DC2,-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QCbNd_e5K496"
      },
      "outputs": [],
      "source": [
        "#jnp.abs(DC@SS-DC2).mean()/jnp.abs(DC2).mean()\n",
        "rel_err(DC@SS,DC2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Om1c5arDKxE9"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jpmvjX2oIarJ"
      },
      "outputs": [],
      "source": [
        "#jnp.abs(SS-jnp.swapaxes(SS,-1,-2)).mean()/jnp.abs(SS).mean()\n",
        "rel_err(SS,jnp.swapaxes(SS,-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qc2IOWyNJHYw"
      },
      "outputs": [],
      "source": [
        "rel_err(sig,jnp.swapaxes(sig,-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yS3Mnk0YIppq"
      },
      "outputs": [],
      "source": [
        "sig2 = (DC@SS)@jnp.swapaxes(DC,-1,-2)*diffusion.scale(t)[:,None,None]\n",
        "#jnp.abs(sig2-sig).mean()/jnp.abs(sig).mean()\n",
        "rel_err(sig2,sig)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U7KOW0LAJczB"
      },
      "outputs": [],
      "source": [
        "jnp.abs(sig2-jnp.swapaxes(sig,-1,-2)).mean()/jnp.abs(sig2).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2it5zyu-KHAP"
      },
      "outputs": [],
      "source": [
        "\n",
        "jnp.abs(sig-jnp.swapaxes(sig,-1,-2)).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BjJ3juGeKGAI"
      },
      "outputs": [],
      "source": [
        "jnp.linalg.norm(jnp.eye(20),ord=2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3KecaCKiJu-V"
      },
      "outputs": [],
      "source": [
        "\n",
        "rel_err(sig2,jnp.swapaxes(sig,-1,-2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zgqw5AoUNZY9"
      },
      "outputs": [],
      "source": [
        "# Use the jax.config object directly: the `from jax.config import config` import path is\n",
        "# deprecated and removed in newer JAX releases. The `config` alias keeps older cells working.\n",
        "# NOTE: x64 is best enabled at startup, before any arrays are created.\n",
        "config = jax.config\n",
        "jax.config.update(\"jax_enable_x64\", True)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KC5tS4hvalQz"
      },
      "outputs": [],
      "source": [
        "import jax\n",
        "\n",
        "# Set JAX's default matmul precision option to float32.\n",
        "jax.config.update(\"jax_default_matmul_precision\", \"float32\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0caV5rxlJUh7"
      },
      "outputs": [],
      "source": [
        "# Forward- vs reverse-mode Jacobians of C(xhat(x,t)) should agree; prints the relative\n",
        "# mean absolute difference. Note: this redefines the global DC2.\n",
        "f = lambda x,t: C(xhat(x,t).reshape(n,c))\n",
        "DC2 = vmap(jax.jacfwd(f))(xt,t)\n",
        "DC2r = vmap(jax.jacrev(f))(xt,t)\n",
        "jnp.abs(DC2-DC2r).mean()/jnp.abs(DC2).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lfw-__QnI47K"
      },
      "outputs": [],
      "source": [
        "# Top-left 3x3 block of both Jacobians for timestep index 1.\n",
        "print('FWD',DC2[1,:3,:3])\n",
        "print('BWD',DC2r[1,:3,:3])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gsV36c6bIpFS"
      },
      "outputs": [],
      "source": [
        "# Shape check.\n",
        "DC.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "93O5MjJ4HSJP"
      },
      "outputs": [],
      "source": [
        "# Shape check.\n",
        "vmap(xhat)(xt,t).shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fImNrLOoJs-G"
      },
      "outputs": [],
      "source": [
        "# Shape check.\n",
        "xt.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Af7YFp08HJW0"
      },
      "outputs": [],
      "source": [
        "# Asymmetry of S[3]; Sigma symmetrizes its output, so this should be ~0.\n",
        "jnp.abs(S[3]-jnp.swapaxes(S[3],-1,-2)).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uEos5GhLGkdH"
      },
      "outputs": [],
      "source": [
        "# Overwrites the globals eigs and V with the eigendecomposition of S.\n",
        "eigs,V = jnp.linalg.eigh(S)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qGYRmT8mGr_q"
      },
      "outputs": [],
      "source": [
        "# Reconstruction check: V diag(eigs) V^T vs the symmetric part of S[1].\n",
        "jnp.abs(V[1]@jnp.diag(eigs[1])@V[1].T -(S[1]+S[1].T)/2).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EGFtrxufm-hm"
      },
      "outputs": [],
      "source": [
        "# Eigenvalue spectra of S, one curve per diffusion time in ts (log scale).\n",
        "# The loop variable is ti, not t, so the global timestep array t used by the\n",
        "# diagnostic cells is not overwritten with a scalar.\n",
        "w = jnp.linalg.eigh(S)[0]\n",
        "for i,ti in enumerate(ts):\n",
        "  plt.plot(w[i],label=str(ti))\n",
        "plt.legend()\n",
        "plt.ylabel(r'$\\lambda_i$')\n",
        "plt.xlabel(r'$i$')\n",
        "plt.yscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7aZQTyIGmt16"
      },
      "outputs": [],
      "source": [
        "# Leftover post-mortem debugging aid (inspects the last exception); not needed for Run All.\n",
        "%debug"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "XKRowUEwJ7ea"
      },
      "outputs": [],
      "source": [
        "# Rerun conditioned SDE sampling with 1000 steps (the earlier run used 100).\n",
        "energy_samples_stoch = samplers.sde_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,nsteps=1000,traj=False)\n",
        "print(f'SDE performance {pmetric(energy_samples_stoch)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jNE5bxk3NBr-"
      },
      "outputs": [],
      "source": [
        "# Leftover post-mortem debugging aid (inspects the last exception); not needed for Run All.\n",
        "%debug"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FM9Z3uXvtgnt"
      },
      "outputs": [],
      "source": [
        "# Rerun conditioned ODE sampling (same call as the earlier ODE cell).\n",
        "energy_samples_det = samplers.ode_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape)\n",
        "print(f'ODE performance {pmetric(energy_samples_det)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdqVaK4bOvfo"
      },
      "outputs": [],
      "source": [
        "from scipy.ndimage import correlate1d\n",
        "# Conditioned SDE sampling with traj=True; the leading axis of the result indexes sampler steps.\n",
        "energy_samples_traj = samplers.sde_sample(diff,energy_conditioned_scores(diff,score_fn),key,mb.shape,nsteps=1000,traj=True)\n",
        "# Diffusion time of each sampler step: (n-0.5)/n, ..., 1.5/n in decreasing order.\n",
        "n=energy_samples_traj.shape[0]+1\n",
        "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
        "# Denoised estimate at each step: (x_t + sigma(t)^2 * score) / scale(t).\n",
        "scores = vmap(score_fn)(energy_samples_traj.reshape(len(ts),-1),ts).reshape(energy_samples_traj.shape)\n",
        "best_reconstructions = (energy_samples_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]\n",
        "# Velocities along physical time (axis 2). correlate1d with weights [-1,0,1]/(2*dt) already\n",
        "# computes (x[i+1]-x[i-1])/(2*dt), the same central difference used in log_prediction_metric;\n",
        "# an extra leading minus sign would flip the momenta passed to ds.integrate below.\n",
        "vs = correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
        "z = ode_datasets.pack(best_reconstructions,(vmap(vmap(vmap(ds.mass)))(best_reconstructions)@vs[...,None]).squeeze(-1))\n",
        "Hs = vmap(vmap(vmap(ds.hamiltonian)))(z)\n",
        "# Ground truth: integrate the final sampler step's state from physical-time index kstart.\n",
        "kstart=10\n",
        "z0 = z[-1,:,kstart]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,ds.T_long[kstart:])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9DTppttzKL7k"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "import matplotlib as mpl\n",
        "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
        "cmap='inferno'\n",
        "i=5 #@param {type:\"slider\", min:0, max:30, step:1}\n",
        "\n",
        "# Energy of trajectory i over physical time, one curve per selected sampler step\n",
        "# (Hs is indexed [sampler step, sample, physical time]); the green curve is the ground truth.\n",
        "step_sel = slice(100, None, 25)\n",
        "fig1 = plt.figure()\n",
        "ax1 = fig1.add_subplot(111)\n",
        "data = Hs[step_sel,i,:].T\n",
        "# Rows of data are physical time and are plotted as-is: reversing them would flip every\n",
        "# curve in time relative to ds.T_long and to the ground-truth curve.\n",
        "ax1.plot(ds.T_long,data,alpha=.6,lw=2)\n",
        "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
        "for k,j in enumerate(ax1.lines):\n",
        "    j.set_color(colors[k])\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel(r'Energy')\n",
        "minn,maxx = data[:,-1].min(),data[:,-1].max()\n",
        "plt.ylim(minn-2*(maxx-minn),maxx+2*(maxx-minn))\n",
        "divider = make_axes_locatable(plt.gca())\n",
        "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
        "# Colorbar spans exactly the diffusion times of the plotted steps. ts decreases with the\n",
        "# step index, so vmin=ts[100], vmax=ts[-25] gave vmin > vmax.\n",
        "ts_plotted = ts[step_sel]\n",
        "norm = mpl.colors.Normalize(vmin=float(ts_plotted.min()), vmax=float(ts_plotted.max()))\n",
        "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
        "cb1.set_label('diffusion time (0,1)')\n",
        "plt.gcf().add_axes(ax_cb)\n",
        "\n",
        "ax1.plot(ds.T_long[kstart:],vmap(vmap(ds.hamiltonian))(z_gts)[i],color='g',lw=3)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mc0IHMpaDrMR"
      },
      "outputs": [],
      "source": [
        "@jit\n",
        "def rel_err(x,y):\n",
        "  # Symmetric L1 relative error along the last axis: sum|x-y| / (sum|x| + sum|y|).\n",
        "  numer = jnp.abs(x - y).sum(-1)\n",
        "  denom = jnp.abs(x).sum(-1) + jnp.abs(y).sum(-1)\n",
        "  return numer / denom\n",
        "\n",
        "\n",
        "kstart=10\n",
        "\n",
        "@jit\n",
        "def log_prediction_metric(qs):\n",
        "  # Mean log relative error between the rollout qs and ds.integrate started from the\n",
        "  # rollout's own state just after index kstart, over the first third of the horizon.\n",
        "  # The first point is skipped (presumably identical by construction, giving log(0)).\n",
        "  # kstart is read when jit traces the function.\n",
        "  positions = qs[kstart:]\n",
        "  dt = ds.T[1] - ds.T[0]\n",
        "  # Central differences give velocities at the interior points positions[1:-1].\n",
        "  velocities = (positions[2:] - positions[:-2]) / (2 * dt)\n",
        "  interior = positions[1:-1]\n",
        "  momenta = (vmap(ds.mass)(interior) @ velocities[..., None]).squeeze(-1)\n",
        "  z = ode_datasets.pack(interior, momenta)\n",
        "  T = ds.T_long[kstart + 1:-1]\n",
        "  z_gt = ds.integrate(z[0], T)\n",
        "  early_errs = rel_err(z, z_gt)[1:len(T) // 3]\n",
        "  return jnp.log(early_errs).mean()\n",
        "\n",
        "@jit\n",
        "def pmetric(qs):\n",
        "  # Returns (geometric mean of the per-rollout errors, exp of the standard error of\n",
        "  # their logs), i.e. a central value and a multiplicative error factor.\n",
        "  log_metric = vmap(log_prediction_metric)(qs)\n",
        "  num_rollouts = log_metric.shape[0]\n",
        "  return jnp.exp(log_metric.mean()), jnp.exp(log_metric.std() / jnp.sqrt(num_rollouts))"
      ]
    },
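    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fAyaiDRrMDd5"
      },
      "source": [
        "## Unconditional prediction quality\n",
        "\n",
        "Samples drawn without conditioning (plain `score_fn`) are scored with `pmetric` for both the ODE and SDE samplers. Their rollouts are then compared against the integrated ground truth, a 1e-3-perturbed ground truth, and random initial conditions."
      ]
    },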
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Ec8x-pzpwIUg"
      },
      "outputs": [],
      "source": [
        "# Unconditional samples (plain score_fn) from the SDE sampler (2000 steps) and the ODE\n",
        "# sampler, shaped like x[:30]; print the first pmetric entry for each.\n",
        "stoch_samples = samplers.sde_sample(\n",
        "    diff, score_fn, key, x[:30].shape, nsteps=2000, traj=False)\n",
        "det_samples = samplers.ode_sample(\n",
        "    diff, score_fn, key, x[:30].shape)\n",
        "print(f'ODE performance {pmetric(det_samples)[0]}')\n",
        "print(f'SDE performance {pmetric(stoch_samples)[0]}')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fg2JsQacJrBV"
      },
      "outputs": [],
      "source": [
        "# NOTE: key is reset here but not used in this cell; later cells that sample with key see it.\n",
        "key = random.PRNGKey(45)\n",
        "s = stoch_samples\n",
        "\n",
        "# Rebuild phase-space states from sampled positions starting at index k: central-difference\n",
        "# velocities, momenta via ds.mass, then integrate each sample's first state as ground truth.\n",
        "# NOTE(review): k = 5 here while pmetric uses kstart = 10; confirm this is intended.\n",
        "k = 5\n",
        "q = s[:,k:]\n",
        "v = -(q[:,:-2]-q[:,2:])/(2*(ds.T[1]-ds.T[0]))\n",
        "z = ode_datasets.pack(q[:,1:-1],(vmap(vmap(ds.mass))(q[:,1:-1])@v[...,None]).squeeze(-1))\n",
        "T = ds.T_long[k+1:-1]\n",
        "z0 = z[:,0]\n",
        "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
        "# Seeded generator so the perturbed-ground-truth baseline is reproducible across reruns.\n",
        "pert_rng = np.random.default_rng(0)\n",
        "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*pert_rng.standard_normal(z0.shape),T)\n",
        "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b-DH4XmZMwWX"
      },
      "outputs": [],
      "source": [
        "# Geometric mean over samples of the per-time relative state error, with a geometric-std\n",
        "# band; errors are clamped at 1e-3 before taking logs. Expects the L1 rel_err that\n",
        "# returns one value per (sample, time).\n",
        "for pred, label in zip([z,z_pert,z_random], ['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init']):\n",
        "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  # Label each line explicitly: in recent matplotlib a label-only plt.legend([...]) pairs\n",
        "  # labels with artists in creation order, which includes the fill_between bands.\n",
        "  plt.plot(T,rel_errs,label=label)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Prediction Error')\n",
        "plt.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gxOy42H5MyDV"
      },
      "outputs": [],
      "source": [
        "# Geometric mean over samples of the per-time relative energy error, with a geometric-std\n",
        "# band; errors are clamped at 1e-3 before taking logs.\n",
        "H_gts = vmap(vmap(ds.hamiltonian))(z_gts)\n",
        "for pred, label in zip([z,z_pert,z_random], ['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init']):\n",
        "  # Separate name so the global Hs (per-sampler-step energies) is not overwritten.\n",
        "  H_pred = vmap(vmap(ds.hamiltonian))(pred)\n",
        "  # Symmetric relative error |H - H_gt| / (|H| + |H_gt|), matching rel_err; dividing by\n",
        "  # |H * H_gt| is not a relative error (it scales like 1/|H|).\n",
        "  energy_errs = jnp.abs(H_pred-H_gts)/(jnp.abs(H_pred)+jnp.abs(H_gts))\n",
        "  clamped_errs = jax.lax.clamp(1e-3,energy_errs,np.inf)\n",
        "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
        "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
        "  # Label lines explicitly so the legend cannot attach labels to the fill_between bands.\n",
        "  plt.plot(T,rel_errs,label=label)\n",
        "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
        "\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time')\n",
        "plt.ylabel('Energy Error')\n",
        "plt.legend()"
      ]
    },
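    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5tQ0FJA2M94F"
      },
      "source": [
        "## Model rollouts vs. ground-truth trajectories\n",
        "\n",
        "Per-sample plots comparing the diffusion-model rollout with the integrated ground truth."
      ]
    },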
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SHKOgFXNJxY6"
      },
      "outputs": [],
      "source": [
        "# One figure per sample: first state component of ground truth, model rollout and\n",
        "# perturbed ground truth. Lines are labelled at creation; unused line handles are dropped.\n",
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  ax.plot(T,z_gts[i,:,0],label='gt')\n",
        "  ax.plot(T,z[i,:,0],label='model')\n",
        "  ax.plot(T,z_pert[i,:,0],label='pert')\n",
        "  ax.set_xlabel('Time t')\n",
        "  ax.set_ylabel(r'State')\n",
        "  ax.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KEMilLlTJyzM"
      },
      "outputs": [],
      "source": [
        "# One figure per sample: first and last components of the packed state z for ground truth\n",
        "# and model rollout. Lines are labelled at creation; unused line handles are dropped.\n",
        "# NOTE(review): z = pack(q, mass @ v), so z[..., -1] is presumably a momentum component\n",
        "# rather than a velocity; confirm before relying on the v legend labels.\n",
        "for i in range(10):\n",
        "  fig = plt.figure()\n",
        "  ax = fig.add_subplot(1, 1, 1)\n",
        "  ax.plot(T,z_gts[i,:,0],label=r'$\\theta_0$ gt')\n",
        "  ax.plot(T,z[i,:,0],label=r'$\\theta_0$ model')\n",
        "  ax.plot(T,z_gts[i,:,-1],label=r'v gt')\n",
        "  ax.plot(T,z[i,:,-1],label=r'v model')\n",
        "  ax.set_xlabel('Time t')\n",
        "  ax.set_ylabel(r'State')\n",
        "  ax.legend()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1tEXOhfcMmho"
      },
      "outputs": [],
      "source": [
        "# pmetric (geometric mean error and multiplicative standard-error factor) versus the number\n",
        "# of SDE sampler steps. The loop names n_steps/samples avoid clobbering the globals N\n",
        "# (covariance time grid) and s (rollout samples) used by other cells.\n",
        "metric_vals =[]\n",
        "metric_stds = []\n",
        "Ns = [25,50,100,200,500,1000,2000]\n",
        "for n_steps in Ns:\n",
        "  samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=n_steps)\n",
        "  mean,std = pmetric(samples)\n",
        "  metric_vals.append(mean)\n",
        "  metric_stds.append(std)\n",
        "metric_vals = np.array(metric_vals)\n",
        "metric_stds = np.array(metric_stds)\n",
        "\n",
        "plt.plot(Ns,metric_vals)\n",
        "# Multiplicative band: divide/multiply by the exp-standard-error factor from pmetric.\n",
        "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
        "plt.xlabel('Sampler steps')\n",
        "plt.ylabel('Pmetric value')\n",
        "plt.xscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Qt2vnvKFoBuo"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "D467_GYqB0Vo"
      },
      "outputs": [],
      "source": [
        "# Power-law check of the angle spectrum on a finely sampled (dt=0.002) 2-pendulum dataset.\n",
        "ds2 = ode_datasets.NPendulum(N=50,n=2,dt=.002)\n",
        "thetas2,vs2 = ode_datasets.unpack(ds2.Zs)\n",
        "test_x2 = thetas2\n",
        "freq = np.fft.rfftfreq(test_x2.shape[1],d=(ds2.T[1]-ds2.T[0]))\n",
        "ffts  =jnp.abs(np.fft.rfft(test_x2,axis=1))\n",
        "# Geometric mean over trajectories of the last angle's spectrum: average the logs, then\n",
        "# exponentiate (exponentiating before the mean would just give the arithmetic mean).\n",
        "avg_freq = jnp.exp(jnp.log(ffts[:,:,-1]).mean(0))\n",
        "# Skip the f=0 (DC) bin: it cannot appear on a log frequency axis and 1/f^k is infinite there.\n",
        "plt.plot(freq[1:],avg_freq[1:],color='blue',label=r'2 Pendulum $\\theta$');\n",
        "plt.plot(freq[1:],100/freq[1:],label=r'1/f')\n",
        "plt.plot(freq[1:],30/freq[1:]**2,label=r'1/$f^2$')\n",
        "plt.plot(freq[1:],10/freq[1:]**3,label=r'1/$f^3$')\n",
        "plt.plot(freq[1:],ffts[:,1:,-1].T,color='y',alpha=.2);\n",
        "plt.yscale('log')\n",
        "plt.xscale('log')\n",
        "plt.ylim(1e-3,1e5)\n",
        "plt.xlabel('Frequency f')\n",
        "plt.ylabel(r'Fourier spectrum')\n",
        "plt.legend()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "train_diagnostics.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
