{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AJJOEf6c7I-_"
   },
   "outputs": [],
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gLeUAXP5j8RJ"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
    "\n",
    "import importlib\n",
    "from userdiffusion import ode_datasets, unet, samplers, diffusion as train\n",
    "importlib.reload(ode_datasets)\n",
    "importlib.reload(unet)\n",
    "importlib.reload(samplers)\n",
    "importlib.reload(train)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "rc('animation', html='jshtml')\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RQwCeAGHkB3i"
   },
   "outputs": [],
   "source": [
    "dt = .1\n",
    "bs = 400         # batch size; also the number of leading trajectories held out as test_x\n",
    "n_train = 4000   # trajectories requested on top of the bs held-out ones\n",
    "seq_len = 60     # number of leading timesteps kept from every trajectory\n",
    "\n",
    "#ds = ode_datasets.NPendulum(N=n_train+bs,n=2,dt=dt)\n",
    "ds = ode_datasets.LorenzDataset(N=n_train+bs,dt=dt,integration_time=7)\n",
    "\n",
    "# The first bs trajectories are kept for evaluation, the remainder is used for training.\n",
    "test_x = ds.Zs[:bs,:seq_len]\n",
    "thetas = ds.Zs[bs:,:seq_len]\n",
    "T_long = ds.T_long[:seq_len]\n",
    "#thetas /=thetas.std()\n",
    "#thetas = jax.random.normal(jax.random.PRNGKey(38),thetas.shape)\n",
    "dataset = tf.data.Dataset.from_tensor_slices(thetas)\n",
    "\n",
    "# dataiter is left uncalled on purpose: it is a factory, and dataiter() returns a\n",
    "# new iterator over shuffled numpy batches of size bs.\n",
    "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uzMJBTiHDeuH"
   },
   "outputs": [],
   "source": [
    "from jax import jit,vmap\n",
    "@jit\n",
    "def rel_err(x,y):\n",
    "  \"\"\"Symmetric relative L1 error between x and y, reduced over the last axis.\n",
    "\n",
    "  Returns sum|x-y| / (sum|x| + sum|y|), with every sum taken over axis -1.\n",
    "  \"\"\"\n",
    "  l1_diff = jnp.sum(jnp.abs(x-y),axis=-1)\n",
    "  l1_total = jnp.sum(jnp.abs(x),axis=-1)+jnp.sum(jnp.abs(y),axis=-1)\n",
    "  return l1_diff/l1_total\n",
    "\n",
    "\n",
    "kstart=10\n",
    "@jit\n",
    "def log_prediction_metric(qs):\n",
    "  \"\"\"Mean log relative error of one trajectory against a ground-truth rollout.\n",
    "\n",
    "  The first kstart steps of qs are dropped, ds.integrate is started from the\n",
    "  first remaining state on the matching slice of T_long, and the log of rel_err\n",
    "  is averaged over steps 1 up to (but excluding) a third of the remaining length.\n",
    "  \"\"\"\n",
    "  traj = qs[kstart:]\n",
    "  times = T_long[kstart:]\n",
    "  reference = ds.integrate(traj[0],times)\n",
    "  horizon = len(times)//3\n",
    "  return jnp.mean(jnp.log(rel_err(traj,reference)[1:horizon]))\n",
    "\n",
    "@jit\n",
    "def pmetric(qs):\n",
    "  \"\"\"Aggregate log_prediction_metric over a batch of trajectories.\n",
    "\n",
    "  Returns a pair: exp of the mean of the per-trajectory log metrics, and exp of\n",
    "  their standard deviation divided by the square root of the batch size.\n",
    "  \"\"\"\n",
    "  per_traj = vmap(log_prediction_metric)(qs)\n",
    "  geo_mean = jnp.exp(per_traj.mean())\n",
    "  geo_spread = jnp.exp(per_traj.std()/jnp.sqrt(per_traj.shape[0]))\n",
    "  return geo_mean,geo_spread"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JbFiGz6W8Qgc"
   },
   "outputs": [],
   "source": [
    "# Evaluation batch: the held-out trajectories (next(dataiter()) would give a training batch instead).\n",
    "x = test_x\n",
    "# One uniform draw in [0,1) per trajectory of x.\n",
    "t = np.random.rand(len(x))\n",
    "# Network whose output width matches the last axis of x.\n",
    "unet_config = unet.unet_64_config(out_dim=x.shape[-1],base_channels=24)\n",
    "model = unet.UNet(unet_config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nUUCduoEkaYh"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "noisetype='White'#@param ['White','Pink','Brown']\n",
    "noise_choices = {\n",
    "    'White':train.Identity,\n",
    "    'Pink':train.PinkCovariance,\n",
    "    'Brown':train.BrownianCovariance,\n",
    "}\n",
    "noise = noise_choices[noisetype]\n",
    "difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
    "diffusion_choices = {\n",
    "    'VP':train.VariancePreserving,\n",
    "    'VE':train.VarianceExploding,\n",
    "    'SubVP':train.SubVariancePreserving,\n",
    "    'Test':train.Test,\n",
    "}\n",
    "# The diffusion process is instantiated with the selected noise class.\n",
    "diff = diffusion_choices[difftype](noise)\n",
    "epochs = 2000#@param {'type':'integer'}\n",
    "ic_conditioning = False#@param {'type':'boolean'}\n",
    "score_fn = train.train_diffusion(\n",
    "    model,dataiter,epochs,diffusion=diff,lr=3e-4,ic_conditioning=ic_conditioning)\n",
    "key= jax.random.PRNGKey(38)\n",
    "# First three timesteps of every test trajectory; only used when ic_conditioning is on.\n",
    "cond =test_x[:,:3]\n",
    "if ic_conditioning:\n",
    "  eval_scorefn = partial(score_fn,cond=cond)\n",
    "else:\n",
    "  eval_scorefn = score_fn\n",
    "# Held-out NLL and prediction metric of SDE samples drawn with the trained score function.\n",
    "nll = samplers.compute_nll(diff,eval_scorefn,key,test_x).mean()\n",
    "stoch_samples = samplers.sde_sample(diff,eval_scorefn,key,test_x.shape,nsteps=1000,traj=False)\n",
    "err = pmetric(stoch_samples)[0]\n",
    "print(f\"{noise.__name__} gets NLL {nll:.3f} and err {err:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "P1VdapJf4Omo"
   },
   "outputs": [],
   "source": [
    "# Draw noise of the same shape as the evaluation batch via noise.sample.\n",
    "samples = noise.sample(jax.random.PRNGKey(39),x.shape)\n",
    "# Comparison signal: the first half of every trajectory followed by its time reversal.\n",
    "# (samples2 used to be assigned a Gaussian draw and its cumulative sum first; both\n",
    "# values were overwritten before any use, so those dead assignments are dropped.)\n",
    "half_x = x[:,:x.shape[1]//2]\n",
    "samples2 = jnp.concatenate([half_x,half_x[:,::-1]],axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "TY5VdTAx4Xlz"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "i=15 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "plt.plot(T_long,samples[i,:,0].T,alpha=1/2,label='brown1')\n",
    "plt.plot(T_long,samples2[i,:,0].T,alpha=1/2,label='data')\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "zJk34SSU5r64"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "i=5 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "fourier_mag1 = jnp.abs(jnp.fft.rfft(samples[::25,:,-1],axis=1))\n",
    "fourier_mag2 = jnp.abs(jnp.fft.rfft(samples2[::25,:,-1],axis=1))\n",
    "plt.plot(fourier_mag1.T,alpha=1/5,label='brown1',color='y')\n",
    "plt.plot(fourier_mag2.T,alpha=1/5,label='data',color='brown')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "# plt.xlabel('Time t')\n",
    "# plt.ylabel(r'State')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fRn9pR8Ewzey"
   },
   "outputs": [],
   "source": [
    "importlib.reload(samplers)\n",
    "importlib.reload(train)\n",
    "#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y-P8uVpcWP-x"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "key= jax.random.PRNGKey(38)\n",
    "samplers.compute_nll(diff,score_fn,key,x).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8nc8LmPLL_z4"
   },
   "source": [
    "Sample generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kboc2IIeteh7"
   },
   "outputs": [],
   "source": [
    "stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=False)\n",
    "sample_traj = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=True)\n",
    "det_samples = samplers.ode_sample(diff,score_fn,key,x[:30].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5rrpbT2MAWRn"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "i=6 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "plt.plot(T_long,sample_traj[0::100,i,:,0].T,alpha=1/2)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-5,5)\n",
    "#plt.legend([r'GT',r'Model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-LerwGMP9Bla"
   },
   "outputs": [],
   "source": [
    "from jax import vmap\n",
    "# One diffusion time per saved sampler state; ts decreases along the leading axis.\n",
    "n=sample_traj.shape[0]+1\n",
    "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
    "# Score function evaluated on every saved state, reshaped back to the trajectory layout.\n",
    "scores = vmap(score_fn)(sample_traj,ts).reshape(sample_traj.shape)\n",
    "# Denoised estimate (x_t + sigma(t)^2 * score) / scale(t) for every saved state.\n",
    "sigma_t = diff.sigma(ts)[:,None,None,None]\n",
    "scale_t = diff.scale(ts)[:,None,None,None]\n",
    "best_reconstructions = (sample_traj+sigma_t**2*scores)/scale_t"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xr1eZyO-1IWB"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "i=4 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "\n",
    "cmap='inferno'\n",
    "\n",
    "# Denoised estimates of sample i (last state component), every 25th saved step from step 100 on.\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = best_reconstructions[100::25,i,:,-1].T\n",
    "ax1.plot(T_long,data[:],alpha=.6,lw=2)\n",
    "# plt.get_cmap is used because mpl.cm.get_cmap is no longer available in recent matplotlib.\n",
    "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "# zip (instead of enumerate into i) keeps the slider variable i intact.\n",
    "for line,color in zip(ax1.lines,colors):\n",
    "    line.set_color(color)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "# ts decreases with the step index, so the later step supplies vmin and step 100 supplies vmax.\n",
    "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2ohSLTVT9f2M"
   },
   "outputs": [],
   "source": [
    "from scipy.ndimage import correlate1d\n",
    "i=22 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "# Negated [-1,0,1]/(2*dt) stencil correlated along the time axis (axis 2) of the denoised estimates.\n",
    "# NOTE(review): correlate1d with these weights is a central difference, so the leading minus\n",
    "# sign appears to flip the derivative - confirm the intended sign.\n",
    "vs = -correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
    "print(vs.shape)\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = vs[100::25,i,:,-1].T\n",
    "ax1.plot(T_long,data[:],alpha=.6,lw=2)\n",
    "# plt.get_cmap is used because mpl.cm.get_cmap is no longer available in recent matplotlib.\n",
    "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "# zip (instead of enumerate into i) keeps the slider variable i intact.\n",
    "for line,color in zip(ax1.lines,colors):\n",
    "    line.set_color(color)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'$\\dot \\theta$')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "# ts decreases with the step index, so the later step supplies vmin and step 100 supplies vmax.\n",
    "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[100])\n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fiq7MGQm-1x-"
   },
   "outputs": [],
   "source": [
    "i=15 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "# Magnitude spectrum along the time axis of every saved sampler state.\n",
    "nn = sample_traj.shape[2]\n",
    "fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))#[:,:,:nn//2]\n",
    "freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = fft[0::25,i,:,-1].T\n",
    "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
    "# plt.get_cmap is used because mpl.cm.get_cmap is no longer available in recent matplotlib.\n",
    "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "# zip (instead of enumerate into i) keeps the slider variable i intact.\n",
    "for line,color in zip(ax1.lines,colors):\n",
    "    line.set_color(color)\n",
    "plt.xlabel('Frequency f')\n",
    "plt.ylabel(r'Fourier spectrum')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "# ts decreases with the step index, so the later step supplies vmin and step 0 supplies vmax.\n",
    "norm = mpl.colors.Normalize(vmin=ts[-25], vmax=ts[0])\n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)\n",
    "# Overlay the spectra of every 10th evaluation trajectory for reference.\n",
    "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GSCNVbsV-zlI"
   },
   "outputs": [],
   "source": [
    "i=8 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "# Magnitude spectrum along the time axis of the denoised estimates.\n",
    "nn = best_reconstructions.shape[2]\n",
    "fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))#[:,:,:nn//2]\n",
    "freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = fft[100::25,i,:,-1].T\n",
    "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
    "# plt.get_cmap is used because mpl.cm.get_cmap is no longer available in recent matplotlib.\n",
    "colors=list(plt.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "# zip (instead of enumerate into i) keeps the slider variable i intact.\n",
    "for line,color in zip(ax1.lines,colors):\n",
    "    line.set_color(color)\n",
    "plt.xlabel('Frequency f')\n",
    "plt.ylabel(r'Fourier spectrum')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmax=ts[100], vmin=ts[-25])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=plt.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)\n",
    "# Overlay the spectra of every 10th evaluation trajectory for reference.\n",
    "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fa5_fj-mU_zw"
   },
   "outputs": [],
   "source": [
    "\n",
    "import matplotlib.pyplot as plt\n",
    "i=10 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "plt.plot(T_long,x[i,:,-1])\n",
    "plt.plot(T_long,det_samples[i,:,-1])\n",
    "plt.plot(T_long,stoch_samples[i,:,-1])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.legend([r'GT',r'Model (ODE)', r'Model (SDE)'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QFHO0tDBL42D"
   },
   "source": [
    "Test the model's ability to condition on previous timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cH8P02jcfKRH"
   },
   "outputs": [],
   "source": [
    "from jax import grad,jit\n",
    "condition_amount = 10# @param {type:\"slider\", min:0, max:50, step:1}\n",
    "mb = x[:30,:]\n",
    "data_std = x.std()\n",
    "\n",
    "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
    "  \"\"\"Build a score function that overrides the scores of the observed timesteps.\n",
    "\n",
    "  Args:\n",
    "    diffusion: object providing noise_score(xt_slice, observed_values, t).\n",
    "    scorefn: callable (xt, t) giving the unconditional score.\n",
    "    observed_values: rank-3 array of observations; its first and last sizes are\n",
    "      used to reshape xt and the score.\n",
    "    slc: index or slice along axis 1 selecting the observed timesteps.\n",
    "\n",
    "  Returns:\n",
    "    A function (xt, t) returning scorefn's output, reshaped to three axes, with\n",
    "    the entries at slc replaced by diffusion.noise_score.\n",
    "  \"\"\"\n",
    "  batch,_,channels = observed_values.shape\n",
    "  def conditioned_scores(xt,t):\n",
    "    xt_3d = xt.reshape(batch,-1,channels)\n",
    "    model_score = scorefn(xt,t).reshape(batch,-1,channels)\n",
    "    known_score = diffusion.noise_score(xt_3d[:,slc],observed_values,t)\n",
    "    return model_score.at[:,slc].set(known_score)\n",
    "  return conditioned_scores\n",
    "\n",
    "def inpainting_scores2(diffusion,scorefn,observed_values,slc,scale=300.):\n",
    "  \"\"\"Like inpainting_scores, plus a guidance term on the unobserved timesteps.\n",
    "\n",
    "  The guidance term is the gradient of the squared error between the one-step\n",
    "  denoised estimate (xt + sigma(t)^2 * score) / scale(t), restricted to slc, and\n",
    "  observed_values. It is subtracted from the model score with weight\n",
    "  scale * diffusion.scale(t)^2 / diffusion.sigma(t)^2.\n",
    "\n",
    "  Args:\n",
    "    diffusion: object providing noise_score, sigma(t) and scale(t).\n",
    "    scorefn: callable (xt, t) giving the unconditional score.\n",
    "    observed_values: rank-3 array of observations; its first and last sizes are\n",
    "      used to reshape xt and the score.\n",
    "    slc: index or slice along axis 1 selecting the observed timesteps.\n",
    "    scale: multiplier of the guidance term.\n",
    "\n",
    "  Returns:\n",
    "    A jitted function (xt, t) returning the combined score with three axes.\n",
    "  \"\"\"\n",
    "  b,_,c = observed_values.shape\n",
    "  def conditioned_scores(xt,t):\n",
    "    unflat_xt = xt.reshape(b,-1,c)\n",
    "\n",
    "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
    "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
    "    def constraint(xt):\n",
    "      # Squared mismatch between the one-step denoised estimate and the observations.\n",
    "      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
    "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
    "    # The weight is taken from the diffusion argument; it previously read the\n",
    "    # notebook-global diff, which only worked when the caller passed that same object.\n",
    "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*scale*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
    "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
    "    return combined_score#.reshape(-1)\n",
    "  return jit(conditioned_scores)\n",
    "\n",
    "slc = slice(condition_amount)\n",
    "conditioned_samples = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,nsteps=1000,traj=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kWWztvtoUljD"
   },
   "outputs": [],
   "source": [
    "expanded = (mb[None]+jnp.zeros((10,1,1,1))).reshape(mb.shape[0]*10,*mb.shape[1:])#[:,slc]\n",
    "predictions = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,nsteps=2000,traj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "92hg5RnQW6jn"
   },
   "outputs": [],
   "source": [
    "z_pert = vmap(ds.integrate,(0,None),0)(mb[:,0]+1e-3*np.random.randn(*mb[:,0].shape),T_long)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3qOJfbZFYLM9"
   },
   "outputs": [],
   "source": [
    "# Regroup the 10 repeated draws per trajectory onto a leading axis.\n",
    "preds = predictions.reshape(10,-1,*predictions.shape[1:])\n",
    "# Average over the state components once, then take the 10th/90th percentiles over the draws.\n",
    "state_mean = preds.mean(-1)\n",
    "lower,upper = np.percentile(state_mean,[10,90],axis=0)\n",
    "# Show at most the first 11 trajectories, one figure each.\n",
    "for i in range(min(mb.shape[0],11)):\n",
    "  plt.plot(T_long,mb[i].mean(-1))\n",
    "  #plt.plot(T_long,z_pert[i].mean(-1))\n",
    "  plt.fill_between(T_long,lower[i],upper[i],alpha=.3,color='y')\n",
    "  #plt.yscale('log')\n",
    "  plt.xlabel('Time')\n",
    "  # The plotted quantity is the mean over state components, so the label says so.\n",
    "  plt.ylabel('State mean')\n",
    "  plt.legend(['Ground Truth','Model 10-90 percentiles'])\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nC9BdYcQVcL1"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-Ps8bGfzWhX4"
   },
   "outputs": [],
   "source": [
    "lower.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0h19RNwvR4is"
   },
   "outputs": [],
   "source": [
    "from jax import jit,vmap,random\n",
    "\n",
    "\n",
    "@jit\n",
    "def rel_err(z1,z2):\n",
    "  \"\"\"Symmetric relative L1 error of z1 vs z2 over the last axis (redefines the earlier rel_err).\"\"\"\n",
    "  numer = jnp.abs(z1-z2).sum(-1)\n",
    "  denom = jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1)\n",
    "  return jnp.abs(numer/denom)\n",
    "\n",
    "gt = x[:30]\n",
    "#for pred in [conditioned_samples[-1]]:#,conditioned_sample]:\n",
    "for scale in [10.,100.,300.,1000.,3000.]:\n",
    "  pred = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc,scale=scale),key,mb.shape,nsteps=2000,traj=False)\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,gt),np.inf)\n",
    "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
    "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
    "  plt.plot(T_long,rel_errs,label=f\"r={scale}\")\n",
    "  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.plot()\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend()#//['SDE completion','ODE completion'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IUMkKWJuAy2y"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "i=1 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "#plt.plot(T_long,conditioned_samples[-600::100,i,:,0].T,zorder=0,alpha=.2)\n",
    "plt.plot(T_long,conditioned_samples[-1,i,  :,0].T,zorder=2,label='model')\n",
    "plt.plot(T_long,x[i,:,0],label='gt',alpha=1,zorder=99)\n",
    "plt.plot(T_long[slc],x[i,slc,0],label='cond',alpha=1,zorder=100,lw=3)\n",
    "\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-3,3)\n",
    "plt.legend()\n",
    "#plt.legend([r'GT',r'Model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "21fgx9UbckBK"
   },
   "outputs": [],
   "source": [
    "conditioned_sample = samplers.ode_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xKR36dO6FgwA"
   },
   "outputs": [],
   "source": [
    "from jax import jit,vmap,random\n",
    "\n",
    "@jit\n",
    "def rel_err(z1,z2):\n",
    "  \"\"\"Symmetric relative L1 error of z1 vs z2 over the last axis.\n",
    "\n",
    "  Returns sum|z1-z2| / (sum|z1| + sum|z2|). The two norms in the denominator are\n",
    "  added, matching the other rel_err definitions in this notebook; this version\n",
    "  used to multiply them, which made the result depend on the scale of the data\n",
    "  and changed the metric for every later cell that calls rel_err.\n",
    "  \"\"\"\n",
    "  return jnp.abs((jnp.abs(z1-z2)).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1)))\n",
    "\n",
    "gt = x[:30]\n",
    "for pred in [conditioned_samples[-1]]:#,conditioned_sample]:\n",
    "  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)\n",
    "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
    "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
    "  plt.plot(T_long,rel_errs)\n",
    "  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.plot()\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylim(1e-4,1)\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend(['SDE completion'])#,'ODE completion'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3Xfy7NxAfQ8S"
   },
   "outputs": [],
   "source": [
    "i=11 # @param {type:\"slider\", min:0, max:29, step:1}\n",
    "plt.plot(T_long,x[i,:,1])\n",
    "plt.plot(T_long[slc],x[i,slc,1],lw=3)\n",
    "plt.plot(T_long,conditioned_sample[i,:,1])\n",
    "plt.plot(T_long,conditioned_samples[-1,i,:,1])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.legend([r'GT','Conditioning', 'ODE rollout','SDE rollout'])\n",
    "#plt.ylim(-3,3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mMvSdxUUG9Op"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mc0IHMpaDrMR"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fAyaiDRrMDd5"
   },
   "source": [
    "Unconditional Prediction quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ec8x-pzpwIUg"
   },
   "outputs": [],
   "source": [
    "# stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=False)\n",
    "# det_samples = samplers.ode_sample(diff,score_fn,key,x[:30].shape)\n",
    "print(f'ODE performance {pmetric(det_samples)[0]}')\n",
    "print(f'SDE performance {pmetric(stoch_samples)[0]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fg2JsQacJrBV"
   },
   "outputs": [],
   "source": [
    "from jax import random\n",
    "key = random.PRNGKey(45)\n",
    "#s=s2#,history = samplers.sde_sampler(denoiser,params,key,(32,)+data.shape[1:],nsteps=500,smin=sigma_min,smax=sigma_max)\n",
    "s = stoch_samples#energy_samples_det#stoch_samples\n",
    "\n",
    "k = 5\n",
    "z = q = s[:,k:]\n",
    "T = T_long[k:]\n",
    "z0 = z[:,0]\n",
    "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
    "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
    "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b-DH4XmZMwWX"
   },
   "outputs": [],
   "source": [
    "for pred in [z,z_pert,z_random]:\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
    "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
    "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
    "  plt.plot(T,rel_errs)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.plot()\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5tQ0FJA2M94F"
   },
   "source": [
    "Trajectory comparison: model samples vs. ground truth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "oxMDdMV1CiOd"
   },
   "outputs": [],
   "source": [
    "ds.animate(z[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FgjBdB5aDkz2"
   },
   "outputs": [],
   "source": [
    "ds.animate(z_gts[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SHKOgFXNJxY6"
   },
   "outputs": [],
   "source": [
    "for i in range(10):\n",
    "  fig = plt.figure()\n",
    "  ax = fig.add_subplot(1, 1, 1)\n",
    "  line1, = ax.plot(T,z_gts[i,:,0])\n",
    "  line2, = ax.plot(T,z[i,:,0])\n",
    "  line3, = ax.plot(T,z_pert[i,:,0])\n",
    "  plt.xlabel('Time t')\n",
    "  plt.ylabel(r'State')\n",
    "  plt.legend(['gt','model','pert'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KEMilLlTJyzM"
   },
   "outputs": [],
   "source": [
    "for i in range(20):\n",
    "  fig = plt.figure()\n",
    "  ax = fig.add_subplot(1, 1, 1)\n",
    "  line1, = ax.plot(T,z_gts[i,:,0])\n",
    "  line2, = ax.plot(T,z[i,:,0])\n",
    "  line3, = ax.plot(T,z_gts[i,:,-1])\n",
    "  line5, = ax.plot(T,z[i,:,-1])\n",
    "  plt.xlabel('Time t')\n",
    "  plt.ylabel(r'State')\n",
    "  plt.legend([r'x gt',r'x model',r'z gt', r'z model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1tEXOhfcMmho"
   },
   "outputs": [],
   "source": [
    "# Prediction metric of SDE samples as a function of the number of sampler steps.\n",
    "Ns = [25,50,100,200,500,1000,2000]\n",
    "sample_shape = x[:30].shape\n",
    "sweep = [pmetric(samplers.sde_sample(diff,score_fn,key,sample_shape,nsteps=N)) for N in Ns]\n",
    "metric_vals = np.array([center for center,spread in sweep])\n",
    "metric_stds = np.array([spread for center,spread in sweep])\n",
    "\n",
    "# pmetric returns multiplicative quantities, so the band is value/spread .. value*spread.\n",
    "band_low = metric_vals/metric_stds\n",
    "band_high = metric_vals*metric_stds\n",
    "plt.plot(Ns,metric_vals)\n",
    "plt.fill_between(Ns, band_low, band_high,alpha=.3)\n",
    "plt.xlabel('Sampler steps')\n",
    "plt.ylabel('Pmetric value')\n",
    "plt.xscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qt2vnvKFoBuo"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "x-998NV2Jqsq"
   },
   "outputs": [],
   "source": [
    "# Per-timestep error of the SDE samples, a perturbed rollout and a random rollout\n",
    "# against the held-out trajectories.\n",
    "z = stoch_samples\n",
    "# T_long is the time grid truncated to the length of test_x; the untruncated ds.T_long\n",
    "# would give rollouts whose time axis does not match z_gts whenever it is longer.\n",
    "T = T_long\n",
    "# Keep only as many test trajectories as there are samples so the batch sizes match\n",
    "# (earlier cells may regenerate stoch_samples with fewer samples than test_x holds).\n",
    "# NOTE(review): comparing samples with test_x presumably assumes the samples were\n",
    "# generated with ic_conditioning on these same trajectories - confirm.\n",
    "z_gts = test_x[:z.shape[0]]#vmap(ds.integrate,(0,None),0)(z0,T)\n",
    "z0 = z_gts[:,0]\n",
    "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*np.random.randn(*z0.shape),T)\n",
    "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n",
    "for pred in [z,z_pert,z_random]:\n",
    "  # Geometric mean and geometric spread of the clamped relative error over the batch.\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
    "  rel_errs = np.exp(jnp.log(clamped_errs).mean(0))\n",
    "  rel_stds = np.exp(jnp.log(clamped_errs).std(0))\n",
    "  plt.plot(T,rel_errs)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend(['Diffusion Model Rollout','1e-3 Perturbed GT','Random Init'])"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
    "kind": "private"
   },
   "name": "lorenz.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
