{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6Aryrhn76oA7"
   },
   "outputs": [],
   "source": [
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "# https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "gLeUAXP5j8RJ"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
    "\n",
    "import importlib\n",
    "from userdiffusion import ode_datasets, unet, samplers, diffusion as train\n",
    "importlib.reload(ode_datasets)\n",
    "importlib.reload(unet)\n",
    "importlib.reload(samplers)\n",
    "importlib.reload(train)\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import rc\n",
    "rc('animation', html='jshtml')\n",
    "import jax.numpy as jnp\n",
    "import numpy as np\n",
    "import jax"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RQwCeAGHkB3i"
   },
   "outputs": [],
   "source": [
    "dt = 6.\n",
    "bs = 400\n",
    "ds = ode_datasets.FitzHughDataset(N=4000+bs,dt=dt,integration_time=3000)\n",
    "\n",
    "train_x = ds.Zs[bs:,:60]\n",
    "test_x = ds.Zs[:bs,:60]\n",
    "T_long =ds.T_long[:60]\n",
    "dataset = tf.data.Dataset.from_tensor_slices(train_x)\n",
    "\n",
    "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bg2BkwRS3thR"
   },
   "outputs": [],
   "source": [
    "plt.plot(T_long,train_x[:300,:,:2].sum(-1).T/2)\n",
    "plt.xlabel('Time (s)')\n",
    "plt.ylabel(r'$\\bar{x}$')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vdo3l-2L-ILC"
   },
   "outputs": [],
   "source": [
    "jnp.abs(train_x).max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JbFiGz6W8Qgc"
   },
   "outputs": [],
   "source": [
    "x = test_x#next(dataiter())\n",
    "t = np.random.rand(x.shape[0])\n",
    "model = unet.UNet(unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "uzMJBTiHDeuH"
   },
   "outputs": [],
   "source": [
    "from jax import jit,vmap\n",
    "@jit\n",
    "def rel_err(x,y):\n",
    "  return  jnp.abs(x-y).sum(-1)/(jnp.abs(x).sum(-1)+jnp.abs(y).sum(-1))\n",
    "\n",
    "\n",
    "kstart=10\n",
    "@jit\n",
    "def log_prediction_metric(qs):\n",
    "  k=kstart\n",
    "  z = q = qs[k:]\n",
    "  T = T_long[k:]\n",
    "  z_gt = ds.integrate(z[0],T)\n",
    "  return jnp.log(rel_err(z,z_gt)[1:len(T)//3]).mean()\n",
    "\n",
    "@jit\n",
    "def pmetric(qs):\n",
    "  log_metric = vmap(log_prediction_metric)(qs)\n",
    "  return jnp.exp(log_metric.mean()),jnp.exp(log_metric.std()/jnp.sqrt(log_metric.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nUUCduoEkaYh"
   },
   "outputs": [],
   "source": [
    "noisetype='White'#@param ['White','Pink','Brown']\n",
    "noise = {'White':train.Identity,'Pink':train.PinkCovariance,'Brown':train.BrownianCovariance}[noisetype]\n",
    "difftype='VE'#@param ['VP','VE','SubVP','Test']\n",
    "diff = {'VP':train.VariancePreserving,'VE':train.VarianceExploding,\n",
    "        'SubVP':train.SubVariancePreserving,'Test':train.Test}[difftype](noise)\n",
    "epochs = 2000#@param {'type':'integer'}\n",
    "score_fn = train.train_diffusion(model,dataiter,epochs,diffusion=diff,lr=3e-4)\n",
    "key= jax.random.PRNGKey(38)\n",
    "nll = samplers.compute_nll(diff,score_fn,key,x).mean()\n",
    "stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=False)\n",
    "err = pmetric(stoch_samples)[0]\n",
    "print(f\"{noise.__name__} gets NLL {nll:.3f} and err {err:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "VZQ8YlieCeEr"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nkI0FpH16jeA"
   },
   "outputs": [],
   "source": [
    "z = jnp.linspace(-3,3,100)\n",
    "plt.plot(z,jax.scipy.stats.norm.cdf(z),label='probit')\n",
    "plt.plot(z,jax.nn.sigmoid(1.6*z),label='logit')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "FhPhx-4NrYDi"
   },
   "outputs": [],
   "source": [
    "import jax \n",
    "jax.config.update('jax_default_matmul_precision', 'float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cH8P02jcfKRH"
   },
   "outputs": [],
   "source": [
    "from jax import grad,jit\n",
    "condition_amount = 13\n",
    "mb = x[:30,:]\n",
    "\n",
    "def event_constraint(x):\n",
    "    C = jnp.max(x[...,:2].mean(-1),-1)-2\n",
    "    return C\n",
    "\n",
    "def statistic(x):\n",
    "    return jnp.max(x[...,:2].mean(-1),-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aWMGSwF950vR"
   },
   "outputs": [],
   "source": [
    "sample_traj = samplers.sde_sample(diff,score_fn,key,x[:400].shape,nsteps=1000,traj=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "neMcE0Gv7rDf"
   },
   "outputs": [],
   "source": [
    "#from train import unsqueeze_like\n",
    "diffusion=diff\n",
    "scorefn =score_fn#event_scores(diff,score_fn,event_constraint)\n",
    "constraint=event_constraint\n",
    "def xhat(xt,t):\n",
    "    #print(xt.shape,t.shape)\n",
    "    tt = train.unsqueeze_like(xt,t)\n",
    "    dt = .00\n",
    "    score_xhat1 = (xt+diffusion.sigma(tt+dt)**2*scorefn(xt,t+dt))/diffusion.scale(tt+dt)\n",
    "    #score_xhat2 = (xt+diffusion.sigma(tt-dt)**2*scorefn(xt,t-dt))/diffusion.scale(tt-dt)\n",
    "    #limiting_xhat = (xt/(1+diffusion.sigma(tt)**2/data_std**2))/diffusion.scale(tt)\n",
    "    #m1 = (t+dt<=1)+0.\n",
    "    #m2 = (t-dt>=0)+0.\n",
    "    #m1,m2 = train.unsqueeze_like(xt,m1,m2)\n",
    "    return (score_xhat1)#*m1)#+score_xhat2*m2)/(m1+m2)\n",
    "def cstd(xt,t):\n",
    "    xh = xhat(xt,t)\n",
    "    C,DC = vmap(jax.value_and_grad(constraint))(xh)\n",
    "    SigmaDC = vmap(jax.grad(lambda x,t: constraint(xhat(x[None],t)[0])))(xt,t)\n",
    "    std2 = ((DC*SigmaDC).sum((-1,-2))*diffusion.scale(t))# NOTE: will not work with img inputs\n",
    "    std3 = (DC*DC).sum((-1,-2))*diffusion.scale(t)\n",
    "    std2 = jnp.sqrt(jnp.abs(std2**2))\n",
    "    std = jnp.sqrt(jnp.abs(std2)+1e-4)*(diff.sigma(t)/diff.scale(t))\n",
    "    return C,std\n",
    "\n",
    "def log_p(xt,t):\n",
    "    Cs,stds = cstd(xt,t)\n",
    "    return jax.nn.log_sigmoid(1.6*Cs/stds).sum()\n",
    "    #return jax.scipy.stats.norm.logcdf(Cs/stds).sum()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2V4BqFCKAQwE"
   },
   "outputs": [],
   "source": [
    "#jnp.where(event_constraint(sample_traj[-1])>0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kCHK6XQo8ioE"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "N=2000\n",
    "ts = (.5+np.arange(N)[::-1])[:-1:4]/N\n",
    "#xt = sample_traj[:,77]\n",
    "i = 195 #@param {type:\"slider\", min:0, max:200, step:1}\n",
    "#xt = sample_traj[:,i] #\n",
    "xt = event_samples_traj[::4,6]\n",
    "Cs,stds = cstd(xt,ts)\n",
    "#Cs,stds = log_p(event_samples_traj[:,1],ts)\n",
    "grads = grad(log_p)(xt,ts)\n",
    "xh = xhat(xt,ts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dtWPed7Mpzyr"
   },
   "outputs": [],
   "source": [
    "tt = train.unsqueeze_like(xt,ts)\n",
    "normed_scores = diffusion.sigma(tt)*scorefn(xt,ts)\n",
    "normed_scores2 = diffusion.sigma(tt)*event_scores(diff,score_fn,event_constraint)(xt,ts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1cSvNtMeqDk6"
   },
   "outputs": [],
   "source": [
    "\n",
    "plt.plot(ts,jnp.sqrt((normed_scores2**2).mean((-1,-2))),label=r'Normed Score w/ event constraint')\n",
    "plt.plot(ts,jnp.sqrt((normed_scores**2).mean((-1,-2))),label=r'Normed Score')\n",
    "plt.xlabel('t')\n",
    "plt.legend()\n",
    "plt.yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6Yd1YS_yNnnG"
   },
   "outputs": [],
   "source": [
    "xh = jnp.where(jnp.isnan(xh),jnp.zeros_like(xh),xh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cqj0cLQ9MuJ4"
   },
   "outputs": [],
   "source": [
    "xnorm = jnp.sqrt((xh*xh).mean((-1,-2)))\n",
    "plt.plot(ts,xnorm)\n",
    "plt.plot(ts,diff.scale(ts)*data_std**2/jnp.sqrt((diff.scale(ts)*data_std)**2+diff.sigma(ts)**2))\n",
    "plt.plot(ts,diff.sigma(ts)/100)\n",
    "plt.legend([r'$||\\hat{x}_t||$',r'$sa^2/\\sqrt{s^2a^2+\\sigma^2}$',r'$\\sigma_t/100$'])\n",
    "plt.xlabel('t')\n",
    "plt.yscale('log')\n",
    "plt.ylim(.5*jnp.min(xnorm),2*jnp.max(xnorm))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-h5QXFyVnCZH"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "i=4 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "\n",
    "cmap='inferno'\n",
    "\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = jnp.sqrt((xh[:-50:5,:]**2).mean(-1)).T\n",
    "ax1.plot(T_long,data[:],alpha=.6,lw=2)\n",
    "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "for i,j in enumerate(ax1.lines):\n",
    "    j.set_color(colors[i])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.ylim(-1,7)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmin=ts[0], vmax=ts[-50])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Yg6A3IjMCFns"
   },
   "outputs": [],
   "source": [
    "grad_norm = jnp.sqrt((grads**2).sum((-1,-2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "x4T6fRio889v"
   },
   "outputs": [],
   "source": [
    "plt.plot(ts,Cs,label='Cs',alpha=.5)\n",
    "plt.plot(ts,stds,label='sigma',alpha=.5)\n",
    "#plt.plot(ts,Cs/stds,label='ratio',alpha=.5)\n",
    "plt.legend()\n",
    "plt.ylim(-2,8)\n",
    "plt.xlabel('t')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "alnd5tlnIbfk"
   },
   "outputs": [],
   "source": [
    "plt.plot(ts,jnp.abs(Cs/stds))\n",
    "plt.yscale('log')\n",
    "plt.xlabel('t')\n",
    "plt.ylabel('C/std')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7XnSqcOP-ILT"
   },
   "outputs": [],
   "source": [
    "plt.plot(ts,stds)\n",
    "plt.plot(ts,diff.sigma(ts))\n",
    "plt.plot(ts,grad_norm)\n",
    "plt.plot(ts,grad_norm*diff.sigma(ts))\n",
    "plt.ylim(1e-4,1e3)\n",
    "plt.yscale('log')\n",
    "plt.xlabel('t')\n",
    "plt.legend([r'$\\sqrt{\\nabla C^T\\Sigma_t\\nabla C}$',r'$\\sigma_t$',r'$\\nabla \\log \\Phi$',r'$\\sigma_t \\nabla \\log \\Phi$'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fy9DrAAj-e1O"
   },
   "outputs": [],
   "source": [
    "plt.plot(ts,jax.scipy.stats.norm.cdf(Cs/stds),label='probit')\n",
    "plt.plot(ts,jax.nn.sigmoid(1.6*Cs/stds),label='logit')\n",
    "plt.ylabel('P(E|xt)')\n",
    "plt.xlabel('t')\n",
    "plt.yscale('log')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SICB8PSTWQhy"
   },
   "outputs": [],
   "source": [
    "data_std = test_x.std()\n",
    "def event_scores(diffusion,scorefn,constraint):\n",
    "  \"\"\" Conditions on inequality constraint C(x)>0\"\"\"\n",
    "  def xhat(xt,t):\n",
    "    tt = train.unsqueeze_like(xt,t)\n",
    "    score_xhat = (xt+diffusion.sigma(tt)**2*scorefn(xt,t))/diffusion.scale(tt)\n",
    "    return score_xhat\n",
    "\n",
    "  def conditioned_scores(xt,t):\n",
    "    b,n,c = xt.shape\n",
    "    unflat_xt = xt.reshape(b,-1,c)\n",
    "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
    "    if not hasattr(t,'shape') or not len(t.shape):\n",
    "      tt = t*jnp.ones(b)\n",
    "    else:\n",
    "      tt = t\n",
    "    def log_p(xt):\n",
    "      xh = xhat(xt,tt)\n",
    "      C,DC = vmap(jax.value_and_grad(constraint))(xh)#.reshape(b,-1,n*c)\n",
    "      SigmaDC = vmap(jax.grad(lambda x,t: constraint(xhat(x[None],t)[0])))(xt,tt)\n",
    "      std2 = ((DC*SigmaDC).sum((-1,-2))*diffusion.scale(t))# NOTE: will not work with img inputs\n",
    "      std3 = (DC*DC).sum((-1,-2))*diffusion.scale(t)\n",
    "      std2 = jnp.sqrt(jnp.abs(std2*std2))\n",
    "      std = jnp.sqrt(jnp.abs(std2)+1e-2)*(diff.sigma(t)/diff.scale(t))\n",
    "      #reg = 1e-5*jnp.eye(sig.shape[-1])[None]/(1+diffusion.sigma(t)**2/data_std**2)[:,None,None]\n",
    "      #10*diff.scale(t)**2/diff.sigma(t)**2\n",
    "      return jax.nn.log_sigmoid(1.6*C/std).sum()\n",
    "      #return jax.scipy.stats.norm.logcdf(C/std).sum()\n",
    "    unobserved_score += grad(log_p)(xt)#.reshape(unflat_xt.shape)\n",
    "    return unobserved_score\n",
    "  return jit(conditioned_scores)\n",
    "\n",
    "#event_samples = samplers.sde_sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape,nsteps=2000,traj=False)\n",
    "#event_samples_traj = samplers.sde_sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape,nsteps=2000,traj=True)\n",
    "#event_samples_det = samplers.ode_sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0V1ueO89XYW_"
   },
   "outputs": [],
   "source": [
    "event_samples_det = samplers.ode_sample(diff,event_scores(diff,score_fn,event_constraint),key,mb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "u2uMKqL1VhY2"
   },
   "outputs": [],
   "source": [
    "ds.Zs[(event_constraint(ds.Zs[:,:60])>0)].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AjKMC1cqRerm"
   },
   "outputs": [],
   "source": [
    "from scipy.ndimage import correlate1d\n",
    "inp = stoch_samples#train_x\n",
    "ode = vmap(vmap(ds.dynamics))\n",
    "for inp in [stoch_samples,train_x,event_samples,ds.Zs[(event_constraint(ds.Zs[:,:60])>0)][:,:60],stoch_samples[event_constraint(stoch_samples)>0],event_samples_det]:\n",
    "  v = correlate1d(inp,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=1)\n",
    "  F = ode(inp,None)\n",
    "  print(rel_err(F[:,2:-2],v[:,2:-2]).mean())\n",
    "  #print(F.shape)\n",
    "  plt.plot(F[0,:,-1])\n",
    "  plt.plot(v[0,:,-1])\n",
    "  plt.xlabel('timesteps')\n",
    "  plt.legend(['ODE F(z,t)','Finite diff dz/dt'])\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "L_MyAXhtE-9S"
   },
   "outputs": [],
   "source": [
    "(event_constraint(event_samples_traj[-1])>0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bgnUrOTDw56W"
   },
   "outputs": [],
   "source": [
    "(event_constraint(event_samples_det)>0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8gMOjVzMZOX6"
   },
   "outputs": [],
   "source": [
    "(event_constraint(event_samples)>0).mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3ZX9QyfiBpi3"
   },
   "outputs": [],
   "source": [
    "from userdiffusion import samplers\n",
    "\n",
    "prior_scale = 112/300\n",
    "\n",
    "logp = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1jbDys_wCxYq"
   },
   "outputs": [],
   "source": [
    "logp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Bex6ikOqRMtQ"
   },
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "def f(z):\n",
    "  N=1000\n",
    "  timesteps = (.5+np.arange(N)[::])/N\n",
    "  scores = score_fn#event_scores(diff,score_fn,event_constraint)\n",
    "  z0,_ = samplers.heun_integrate2(jit(partial(diffusion.dynamics,scores)),z,timesteps)\n",
    "  return z0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b106DePWTdB4"
   },
   "outputs": [],
   "source": [
    "z0 = f(event_samples[5:6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5HWGMqBLaf8f"
   },
   "outputs": [],
   "source": [
    "z0.std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "paeC7EpuTmbR"
   },
   "outputs": [],
   "source": [
    "J = jax.jacfwd(f)(event_samples[5:6])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-xM4OfbpUBql"
   },
   "outputs": [],
   "source": [
    "std_max = diffusion.sigma(diffusion.tmax)#*prior_scale\n",
    "logpxf = -(z0.reshape(z0.shape[0],-1)**2/std_max**2 + jnp.log(2*np.pi*std_max**2)).sum(-1)/2\n",
    "s,logdet = jnp.linalg.slogdet(J.reshape(240,240))#+1e-3*jnp.eye(240))\n",
    "logpa = logpxf+logdet\n",
    "print(logpa)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "BcqTjxdPefQY"
   },
   "outputs": [],
   "source": [
    "[562.35547]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PeCQYRJ7jh0r"
   },
   "outputs": [],
   "source": [
    "[977.17847]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1dhqL0u5kb_C"
   },
   "outputs": [],
   "source": [
    "jnp.exp(973.3657-977.17847)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "o5SpXjlMUNWW"
   },
   "outputs": [],
   "source": [
    "logpxf,logdet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bAYarUEmCmrp"
   },
   "outputs": [],
   "source": [
    "conditional_logp = samplers.logp(diff,event_scores(diff,score_fn,event_constraint),key,event_samples[5:10],prior_scale,num_probes=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GDXt9kCME3gM"
   },
   "outputs": [],
   "source": [
    "conditional_logp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "LyJYHmKkNpXb"
   },
   "outputs": [],
   "source": [
    "jnp.exp(logp-conditional_logp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cT05TeylN31I"
   },
   "outputs": [],
   "source": [
    "logp2 = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=300)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3FBSimymO83_"
   },
   "outputs": [],
   "source": [
    "logp3 = samplers.logp(diff,score_fn,key,event_samples[5:10],prior_scale,num_probes=1000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "_wBjoqgiPSFW"
   },
   "outputs": [],
   "source": [
    "logp3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "RR-evAr8N67_"
   },
   "outputs": [],
   "source": [
    "logp2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Kboc2IIeteh7"
   },
   "outputs": [],
   "source": [
    "stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:400].shape,nsteps=2000,traj=False)\n",
    "sample_traj = samplers.sde_sample(diff,score_fn,key,x[:400].shape,nsteps=2000,traj=True)\n",
    "det_samples = samplers.ode_sample(diff,score_fn,key,x[:400].shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p0ygptpzWU2Z"
   },
   "outputs": [],
   "source": [
    "true_events = (event_constraint(ds.Zs[:,:60])>0)\n",
    "model_events_ode = (event_constraint(det_samples)>0)\n",
    "model_events_sde = (event_constraint(stoch_samples)>0)\n",
    "print(f\"True event rate        {true_events.mean():.3f}+-{true_events.std()/jnp.sqrt(len(true_events)):.3f}\")\n",
    "print(f\"model event rate (ODE) {model_events_ode.mean():.3f}+-{model_events_ode.std()/jnp.sqrt(len(model_events_ode)):.3f}\")\n",
    "print(f\"model event rate (SDE) {model_events_sde.mean():.3f}+-{model_events_sde.std()/jnp.sqrt(len(model_events_sde)):.3f} (2k steps)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7NE2qsfsh1Nr"
   },
   "outputs": [],
   "source": [
    "vals = np.array(statistic(ds.Zs[:,:60]))\n",
    "vals2 = np.array(statistic(event_samples))\n",
    "vals3 = np.array(statistic(ds.Zs[(event_constraint(ds.Zs[:,:60])>0),:60]))\n",
    "vals4 = np.array(statistic(event_samples_det))\n",
    "plt.hist(vals2,bins=30,density=True,alpha=.5)\n",
    "plt.hist(vals4,bins=30,density=True,alpha=.5)\n",
    "plt.hist(vals,bins=30,density=True,alpha=.5)\n",
    "\n",
    "plt.hist(vals3,bins=30,density=True,alpha=.5)\n",
    "plt.legend(['Model x|E (SDE)','Model x|E (ODE)','Data x','Data x|E'])\n",
    "plt.xlim(0,6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "AWNG6T1uXpJB"
   },
   "outputs": [],
   "source": [
    "vals = np.array(statistic(ds.Zs[:,:60]))\n",
    "\n",
    "plt.hist(vals4,bins=30,density=True)\n",
    "plt.hist(vals,bins=30,density=True)\n",
    "vals4 = np.array(statistic(event_samples_det))\n",
    "plt.xlabel('Maximum value over trajectory')\n",
    "plt.ylabel('Frequency')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IUMkKWJuAy2y"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "true_events = ds.Zs[(event_constraint(ds.Zs[:,:60])>0),:60]\n",
    "i=28 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "#plt.plot(T_long,conditioned_samples[-600::100,i,:,0].T,zorder=0,alpha=.2)\n",
    "plt.plot(T_long,event_samples[i,  :,0].T,label='x|E model sde',zorder=2)\n",
    "plt.plot(T_long,event_samples_det[i,  :,0].T,label='x|E model ode',zorder=2)\n",
    "plt.plot(T_long,true_events[i,:,0],label='gt',alpha=1,zorder=99)\n",
    "#plt.plot(T_long[slc],x[i,slc,0],label='cond',alpha=1,zorder=100,lw=3)\n",
    "\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-3,3)\n",
    "plt.legend()\n",
    "#plt.legend([r'GT',r'Model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fRn9pR8Ewzey"
   },
   "outputs": [],
   "source": [
    "\n",
    "importlib.reload(samplers)\n",
    "importlib.reload(train)\n",
    "#samplers.probability_flow(diff,score_fn,x,1e-4,1.).std()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "y-P8uVpcWP-x"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "key= jax.random.PRNGKey(38)\n",
    "samplers.compute_nll(diff,score_fn,key,x).mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "8nc8LmPLL_z4"
   },
   "source": [
    "Sample generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5rrpbT2MAWRn"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "i=6 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "plt.plot(T_long,sample_traj[0::100,i,:,0].T,alpha=1/2)\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-5,5)\n",
    "#plt.legend([r'GT',r'Model'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-LerwGMP9Bla"
   },
   "outputs": [],
   "source": [
    "from jax import vmap\n",
    "n=sample_traj.shape[0]+1\n",
    "ts = (.5+jnp.arange(n)[::-1])[:-1]/n\n",
    "scores = vmap(score_fn)(sample_traj,ts).reshape(sample_traj.shape)\n",
    "best_reconstructions = (sample_traj+diff.sigma(ts)[:,None,None,None]**2*scores)/diff.scale(ts)[:,None,None,None]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xr1eZyO-1IWB"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import matplotlib as mpl\n",
    "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
    "\n",
    "i=4 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "\n",
    "cmap='inferno'\n",
    "\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = best_reconstructions[100::25,i,:,-1].T\n",
    "ax1.plot(T_long,data[:],alpha=.6,lw=2)\n",
    "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "for i,j in enumerate(ax1.lines):\n",
    "    j.set_color(colors[i])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2ohSLTVT9f2M"
   },
   "outputs": [],
   "source": [
    "from scipy.ndimage import correlate1d\n",
    "i=22 #@param {type:\"slider\", min:0, max:30, step:1}\n",
    "vs = -correlate1d(best_reconstructions,np.array([-1,0,1])/2/(ds.T[1]-ds.T[0]),axis=2)\n",
    "print(vs.shape)\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = vs[100::25,i,:,-1].T\n",
    "ax1.plot(T_long,data[:],alpha=.6,lw=2)\n",
    "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "for i,j in enumerate(ax1.lines):\n",
    "    j.set_color(colors[i])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'$\\dot \\theta$')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmin=ts[100], vmax=ts[-25])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fiq7MGQm-1x-"
   },
   "outputs": [],
   "source": [
    "i=15 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "nn = sample_traj.shape[2]\n",
    "fft = jnp.abs(np.fft.rfft(sample_traj,axis=2))#[:,:,:nn//2]\n",
    "freq = np.fft.rfftfreq(sample_traj.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = fft[0::25,i,:,-1].T\n",
    "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
    "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "for i,j in enumerate(ax1.lines):\n",
    "    j.set_color(colors[i])\n",
    "plt.xlabel('Frequency f')\n",
    "plt.ylabel(r'Fourier spectrum')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmin=ts[0], vmax=ts[-25])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)\n",
    "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "GSCNVbsV-zlI"
   },
   "outputs": [],
   "source": [
    "i=8 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "nn = best_reconstructions.shape[2]\n",
    "fft = jnp.abs(np.fft.rfft(best_reconstructions,axis=2))#[:,:,:nn//2]\n",
    "freq = np.fft.rfftfreq(best_reconstructions.shape[2],d=(ds.T[1]-ds.T[0]))#[:nn//2]\n",
    "\n",
    "fig1 = plt.figure()\n",
    "ax1 = fig1.add_subplot(111)\n",
    "data = fft[100::25,i,:,-1].T\n",
    "ax1.plot(freq,data[:,:],alpha=.6,lw=2)\n",
    "colors=list(mpl.cm.get_cmap(cmap)(np.linspace(0,1,len(ax1.lines))))\n",
    "#colors = [colors(i) for i in np.linspace(0, 1,len(ax1.lines))]\n",
    "for i,j in enumerate(ax1.lines):\n",
    "    j.set_color(colors[i])\n",
    "plt.xlabel('Frequency f')\n",
    "plt.ylabel(r'Fourier spectrum')\n",
    "plt.yscale('log')\n",
    "plt.xscale('log')\n",
    "#plt.ylim(-2,2)\n",
    "divider = make_axes_locatable(plt.gca())\n",
    "ax_cb = divider.new_horizontal(size=\"5%\", pad=0.05)\n",
    "norm = mpl.colors.Normalize(vmax=ts[100], vmin=ts[-25])    \n",
    "cb1 = mpl.colorbar.ColorbarBase(ax_cb, cmap=mpl.cm.get_cmap(f'{cmap}_r'), orientation='vertical',norm=norm)\n",
    "#cb1.ax.invert_yaxis()\n",
    "cb1.set_label('diffusion time (0,1)')\n",
    "plt.gcf().add_axes(ax_cb)\n",
    "ax1.plot(freq,jnp.abs(np.fft.rfft(x,axis=1))[::10,:,-1].T,color='blue',alpha=.1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Fa5_fj-mU_zw"
   },
   "outputs": [],
   "source": [
    "\n",
    "import matplotlib.pyplot as plt\n",
    "i=4 # @param {type:\"slider\", min:0, max:30, step:1}\n",
    "plt.plot(T_long,test_x[i,:,-1])\n",
    "#plt.plot(T_long,det_samples[i,:,-1])\n",
    "plt.plot(T_long,stoch_samples[i,:,-1])\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel(r'State')\n",
    "plt.legend([r'GT',r'Model (SDE)'])#r'Model (ODE)', r'Model (SDE)'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QFHO0tDBL42D"
   },
   "source": [
    "Test ability to condition model on previous timesteps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "21fgx9UbckBK"
   },
   "outputs": [],
   "source": [
    "# Deterministic (ODE) sampler whose score is conditioned on the observed timesteps mb[:,slc].\n",
    "# NOTE(review): `inpainting_scores2`, `mb` and `slc` are defined in the inpainting cell further\n",
    "# down this notebook; on a fresh kernel run that cell before this one.\n",
    "conditioned_sample = samplers.ode_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xKR36dO6FgwA"
   },
   "outputs": [],
   "source": [
    "from jax import jit,vmap,random\n",
    "\n",
    "@jit\n",
    "def rel_err(z1,z2):\n",
    "  \"\"\"Relative L1 error between two trajectories, reduced over the last (state) axis.\n",
    "\n",
    "  Computes |z1-z2|_1 / (|z1|_1 + |z2|_1), which lies in [0, 1] and is invariant to a\n",
    "  common rescaling of z1 and z2 (dividing by the product of the norms would not be).\n",
    "  \"\"\"\n",
    "  return jnp.abs(z1-z2).sum(-1)/(jnp.abs(z1).sum(-1)+jnp.abs(z2).sum(-1))\n",
    "\n",
    "# Geometric mean error over the batch, with a one-geometric-std band.\n",
    "gt = x[:30]\n",
    "for label, pred in [('SDE completion', conditioned_samples[-1]), ('ODE completion', conditioned_sample)]:\n",
    "  clamped_errs = jax.lax.clamp(1e-5,rel_err(pred,gt),np.inf)\n",
    "  log_errs = jnp.log(clamped_errs)\n",
    "  rel_errs = np.exp(log_errs.mean(0))\n",
    "  rel_stds = np.exp(log_errs.std(0))\n",
    "  # Label the line itself: legend(labels) would also assign labels to the fill_between bands.\n",
    "  plt.plot(T_long,rel_errs,label=label)\n",
    "  plt.fill_between(T_long, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3Xfy7NxAfQ8S"
   },
   "outputs": [],
   "source": [
    "i=7 # @param {type:\"slider\", min:0, max:29, step:1}\n",
    "# Trajectory i, state channel 1: ground truth, the observed window (drawn thicker) and the\n",
    "# conditioned sample.\n",
    "plt.plot(T_long,x[i,:,1],label='GT')\n",
    "plt.plot(T_long[slc],x[i,slc,1],lw=3,label='Conditioning')\n",
    "plt.plot(T_long,conditioned_sample[i,:,1],label='Model')\n",
    "plt.xlabel('Time t')\n",
    "plt.ylabel('State')\n",
    "plt.legend()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mMvSdxUUG9Op"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mc0IHMpaDrMR"
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fAyaiDRrMDd5"
   },
   "source": [
    "## Unconditional prediction quality"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Ec8x-pzpwIUg"
   },
   "outputs": [],
   "source": [
    "# Reuse unconditional samples generated earlier if they exist; otherwise generate them so the\n",
    "# cell also runs on a fresh kernel.\n",
    "if 'stoch_samples' not in globals():\n",
    "  stoch_samples = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=1000,traj=False)\n",
    "if 'det_samples' not in globals():\n",
    "  det_samples = samplers.ode_sample(diff,score_fn,key,x[:30].shape)\n",
    "print(f'ODE performance {pmetric(det_samples)[0]}')\n",
    "print(f'SDE performance {pmetric(stoch_samples)[0]}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fg2JsQacJrBV"
   },
   "outputs": [],
   "source": [
    "from jax import random, vmap\n",
    "key = random.PRNGKey(45)\n",
    "# Starting from the model sample at step k, integrate the dynamics (ds.integrate) from\n",
    "# (a) the sampled state, (b) a slightly perturbed copy of it and (c) random initial states.\n",
    "s = stoch_samples\n",
    "\n",
    "k = 5\n",
    "z = s[:,k:]\n",
    "T = T_long[k:]\n",
    "z0 = z[:,0]\n",
    "z_gts = vmap(ds.integrate,(0,None),0)(z0,T)\n",
    "# Seeded Gaussian perturbation of relative size 1e-3; fold_in derives a separate key so\n",
    "# `key` itself is left unchanged for later cells.\n",
    "pert_noise = random.normal(random.fold_in(key, 1), z0.shape)\n",
    "z_pert = vmap(ds.integrate,(0,None),0)(z0+1e-3*pert_noise*jnp.abs(z0).mean(),T)\n",
    "z_random = vmap(ds.integrate,(0,None),0)(ds.sample_initial_conditions(z0.shape[0]),T)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "b-DH4XmZMwWX"
   },
   "outputs": [],
   "source": [
    "# Geometric mean relative error vs z_gts over the batch, with a one-geometric-std band.\n",
    "for label, pred in [('Diffusion Model Rollout', z), ('1e-3 Perturbed GT', z_pert), ('Random Init', z_random)]:\n",
    "  clamped_errs = jax.lax.clamp(1e-3,rel_err(pred,z_gts),np.inf)\n",
    "  log_errs = jnp.log(clamped_errs)\n",
    "  rel_errs = np.exp(log_errs.mean(0))\n",
    "  rel_stds = np.exp(log_errs.std(0))\n",
    "  # Label the line itself: legend(labels) would also assign labels to the fill_between bands.\n",
    "  plt.plot(T,rel_errs,label=label)\n",
    "  plt.fill_between(T, rel_errs/rel_stds, rel_errs*rel_stds,alpha=.1)\n",
    "\n",
    "plt.yscale('log')\n",
    "plt.xlabel('Time')\n",
    "plt.ylabel('Prediction Error')\n",
    "plt.legend();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5tQ0FJA2M94F"
   },
   "source": [
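    "## Trajectory comparisons\n",
    "\n",
    "Per-trajectory plots of the ground-truth rollout, the diffusion-model sample and the perturbed ground truth."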
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SHKOgFXNJxY6"
   },
   "outputs": [],
   "source": [
    "# One figure per trajectory: sum of the first two state channels for GT, model and perturbed GT.\n",
    "for i in range(20):\n",
    "  fig, ax = plt.subplots()\n",
    "  for traj in (z_gts, z, z_pert):\n",
    "    ax.plot(T, traj[i,:,:2].sum(-1))\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel(r'State')\n",
    "  ax.legend(['gt','model','pert'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KEMilLlTJyzM"
   },
   "outputs": [],
   "source": [
    "# One figure per trajectory: first and last state channels, GT vs model.\n",
    "for i in range(10):\n",
    "  fig, ax = plt.subplots()\n",
    "  for channel in (0, -1):\n",
    "    ax.plot(T, z_gts[i,:,channel])\n",
    "    ax.plot(T, z[i,:,channel])\n",
    "  ax.set_xlabel('Time t')\n",
    "  ax.set_ylabel(r'State')\n",
    "  ax.legend(['gt0','model0','gt3','model3'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "1tEXOhfcMmho"
   },
   "outputs": [],
   "source": [
    "# pmetric of unconditional SDE samples as a function of the number of sampler steps.\n",
    "Ns = [25,50,100,200,500,1000,2000]\n",
    "metric_vals, metric_stds = [], []\n",
    "for N in Ns:\n",
    "  s = samplers.sde_sample(diff,score_fn,key,x[:30].shape,nsteps=N)\n",
    "  mean, std = pmetric(s)\n",
    "  metric_vals.append(mean)\n",
    "  metric_stds.append(std)\n",
    "metric_vals, metric_stds = np.array(metric_vals), np.array(metric_stds)\n",
    "\n",
    "plt.plot(Ns,metric_vals)\n",
    "# NOTE(review): the band divides/multiplies by the std, i.e. treats it as a multiplicative\n",
    "# (geometric) spread - confirm that is what pmetric returns.\n",
    "plt.fill_between(Ns, metric_vals/metric_stds, metric_vals*metric_stds,alpha=.3)\n",
    "plt.xlabel('Sampler steps')\n",
    "plt.ylabel('Pmetric value')\n",
    "plt.xscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Qt2vnvKFoBuo"
   },
   "outputs": [],
   "source": [
    "from jax import grad,jit\n",
    "# Number of leading timesteps that are observed (conditioned on); used as slice(condition_amount).\n",
    "condition_amount = 10# @param {type:\"slider\", min:0, max:50, step:1}\n",
    "mb = x[:30,:]\n",
    "# NOTE(review): data_std is not used by any visible cell - confirm it is still needed.\n",
    "data_std = x.std()\n",
    "\n",
    "def inpainting_scores(diffusion,scorefn,observed_values,slc):\n",
    "  \"\"\"Build a score function that overrides the model score on the observed slice.\n",
    "\n",
    "  Along axis 1, positions selected by `slc` get `diffusion.noise_score` of the current\n",
    "  state against `observed_values`; all other positions keep `scorefn`'s score.\n",
    "  NOTE(review): observed_values is unpacked as (batch, time, channels) - TODO confirm.\n",
    "  \"\"\"\n",
    "  batch, _, channels = observed_values.shape\n",
    "\n",
    "  def conditioned_scores(xt,t):\n",
    "    per_example = xt.reshape(batch,-1,channels)\n",
    "    clamp_score = diffusion.noise_score(per_example[:,slc],observed_values,t)\n",
    "    return scorefn(xt,t).reshape(batch,-1,channels).at[:,slc].set(clamp_score)\n",
    "\n",
    "  return conditioned_scores\n",
    "\n",
    "def inpainting_scores2(diffusion,scorefn,observed_values,slc,scale=300.):\n",
    "  \"\"\"Inpainting score function with an additional reconstruction-guidance term.\n",
    "\n",
    "  Along axis 1, positions selected by `slc` get `diffusion.noise_score` against\n",
    "  `observed_values`. The other positions get the model score minus the gradient of\n",
    "  the squared error between the one-step estimate (xt + sigma(t)^2 * score) / scale(t)\n",
    "  and the observations, weighted by scale * diffusion.scale(t)^2 / diffusion.sigma(t)^2.\n",
    "\n",
    "  Args:\n",
    "    diffusion: object providing noise_score, sigma and scale.\n",
    "    scorefn: unconditional score function (xt, t).\n",
    "    observed_values: observations, unpacked as (batch, time, channels) - TODO confirm.\n",
    "    slc: slice of axis 1 that is observed.\n",
    "    scale: guidance strength.\n",
    "  Returns:\n",
    "    A jitted function (xt, t) -> score reshaped to (batch, -1, channels).\n",
    "  \"\"\"\n",
    "  b,_,c = observed_values.shape\n",
    "  def conditioned_scores(xt,t):\n",
    "    unflat_xt = xt.reshape(b,-1,c)\n",
    "\n",
    "    observed_score = diffusion.noise_score(unflat_xt[:,slc],observed_values,t)\n",
    "    unobserved_score = scorefn(xt,t).reshape(b,-1,c)\n",
    "    def constraint(xt):\n",
    "      one_step_xhat = (xt+diffusion.sigma(t)**2*scorefn(xt,t))/diffusion.scale(t)\n",
    "      return jnp.sum((one_step_xhat.reshape(b,-1,c)[:,slc]-observed_values)**2)\n",
    "    # Use the `diffusion` argument (not the notebook-global `diff`) so the weight matches\n",
    "    # the process used by noise_score and the constraint above.\n",
    "    guidance_weight = scale*diffusion.scale(t)**2/diffusion.sigma(t)**2\n",
    "    unobserved_score -= grad(constraint)(xt).reshape(unflat_xt.shape)*guidance_weight\n",
    "    combined_score = unobserved_score.at[:,slc].set(observed_score)\n",
    "    return combined_score\n",
    "  return jit(conditioned_scores)\n",
    "\n",
    "slc = slice(condition_amount)\n",
    "# SDE sampling conditioned on mb[:,slc]. With traj=True the result is indexed as a sequence;\n",
    "# later cells use conditioned_samples[-1] as the final sample.\n",
    "conditioned_samples = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,mb[:,slc],slc),key,mb.shape,nsteps=1000,traj=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7ZmPVppPCsJ7"
   },
   "outputs": [],
   "source": [
    "# k conditioned SDE samples per trajectory: mb is tiled k times along the batch axis, so\n",
    "# rows [j*len(mb), (j+1)*len(mb)) of `expanded` are the j-th copy of mb.\n",
    "k=30\n",
    "expanded = jnp.tile(mb, (k,1,1))\n",
    "predictions = samplers.sde_sample(diff,inpainting_scores2(diff,score_fn,expanded[:,slc],slc,scale=300.),key,expanded.shape,nsteps=2000,traj=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "KwdQuaQMCvdJ"
   },
   "outputs": [],
   "source": [
    "# 10th-90th percentile band, over the k samples per trajectory, of the channel-averaged state.\n",
    "preds = predictions.reshape(k,-1,*predictions.shape[1:])\n",
    "pred_means = preds.mean(-1)\n",
    "lower = np.percentile(pred_means,10,axis=0)\n",
    "upper = np.percentile(pred_means,90,axis=0)\n",
    "max_plots = 30\n",
    "for i in range(min(mb.shape[0], max_plots)):\n",
    "  plt.plot(T_long,mb[i].mean(-1))\n",
    "  plt.fill_between(T_long,lower[i],upper[i],alpha=.3,color='y')\n",
    "  plt.xlabel('Time')\n",
    "  # The plotted quantity is mean(-1), i.e. the average over state channels.\n",
    "  plt.ylabel('State mean')\n",
    "  plt.legend(['Ground Truth','Model 10-90 percentiles'])\n",
    "  plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "tOUZBfNiCwqO"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "last_runtime": {
    "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
    "kind": "private"
   },
   "name": "fitzhugh.ipynb",
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
