{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6fzlwhtU7Wlh"
      },
      "outputs": [],
      "source": [
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6ThikF_GpKd6"
      },
      "outputs": [],
      "source": [
        "# from colabtools import adhoc_import\n",
        "import xarray\n",
        "import functools\n",
        "import jax\n",
        "#jax.config.update('jax_platform_name', 'cpu')\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import jax_cfd.base as cfd\n",
        "import jax_cfd.spectral.utils as spectral_utils\n",
        "import jax_cfd.spectral.equations as spectral_equations\n",
        "import jax_cfd.spectral.time_stepping as spectral_stepping\n",
        "from jax_cfd.base import grids\n",
        "\n",
        "#   import jax_cfd.ml as ml\n",
        "\n",
        "#   equations = adhoc_import.Reload(equations, reset_flags=True)\n",
        "#   utils = adhoc_import.Reload(utils, reset_flags=True)\n",
        "\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax.experimental.ode import odeint\n",
        "from jax import grad, jit, vmap, jacfwd, jvp, vjp\n",
        "from jax import random\n",
        "from tqdm.auto import tqdm\n",
        "from scipy.ndimage import correlate1d\n",
        "from jax.scipy.signal import correlate\n",
        "import numpy as np"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T1QQs6hpsr5U"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xnrfyyTdjS1t"
      },
      "outputs": [],
      "source": [
        " def truncated_rfft(u):\n",
        "  \"\"\"Applies the 2/3 rule by truncating higher Fourier modes.\n",
        "\n",
        "  Args:\n",
        "    u: the real-space representation of the input signal\n",
        "\n",
        "  Returns:\n",
        "    Downsampled version of `u` in rfft-space.\n",
        "  \"\"\"\n",
        "  uhat = jnp.fft.rfft(u)\n",
        "  k, = uhat.shape\n",
        "  final_size = int(np.ceil(2 / 3 * k))# + 1\n",
        "  return 2 / 3 * uhat[:final_size]\n",
        "\n",
        "\n",
        "def padded_irfft(uhat):\n",
        "  n, = uhat.shape\n",
        "  final_shape = int(np.floor(3 / 2 * n))\n",
        "  smoothed = jnp.pad(uhat, (0, final_shape - n))\n",
        "  assert smoothed.shape == (final_shape,), \"incorrect padded shape\"\n",
        "  return (3/2) * jnp.fft.irfft(smoothed)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_soiG1in85W3"
      },
      "outputs": [],
      "source": [
        "import dataclasses\n",
        "#from jax_cfd.base import boundaries\n",
        "@dataclasses.dataclass\n",
        "class NonlinearSchrodinger(spectral_stepping.ImplicitExplicitODE):\n",
        "  \"\"\"Nonlinear schrodinger equation split in implicit and explicit parts.\n",
        "\n",
        "  The NLS equation is\n",
        "    psi_t = -i psi_xx/8 - i|psi|^2 psi/2 - psi_x/2 \n",
        "\n",
        "  Attributes:\n",
        "    grid: underlying grid of the process\n",
        "    smooth: smooth the non-linear term using the 3/2-rule\n",
        "  \"\"\"\n",
        "  grid: grids.Grid\n",
        "  smooth: bool = True\n",
        "\n",
        "  def __post_init__(self):\n",
        "    self.kx, = self.grid.rfft_axes()\n",
        "    self.two_pi_i_k = 2j * jnp.pi * self.kx\n",
        "    diffusive_term = -self.two_pi_i_k**2/8# +self.two_pi_i_k**2/2#-self.two_pi_i_k**2/8 \n",
        "    self.diffusive_term = jnp.concatenate([diffusive_term,diffusive_term])\n",
        "    self.advection_term = -jnp.concatenate([self.two_pi_i_k,self.two_pi_i_k])/2\n",
        "    self.rfft = truncated_rfft if self.smooth else jnp.fft.rfft\n",
        "    self.irfft = padded_irfft if self.smooth else jnp.fft.irfft\n",
        "\n",
        "  def mul_i(self, psi):\n",
        "    \"\"\" multiply the state by i\"\"\"\n",
        "    N = len(psi)//2\n",
        "    real, imag = psi[:N], psi[N:]\n",
        "    return jnp.concatenate([-imag, real])\n",
        "\n",
        "  def explicit_terms(self, psihat):\n",
        "    \"\"\"Non-linear parts of the equation.\"\"\"\n",
        "    N = len(psihat)//2\n",
        "    uhat,vhat = psihat[:N],psihat[N:]\n",
        "    u = self.irfft(uhat)\n",
        "    v = self.irfft(vhat)\n",
        "    psi_squared = (u**2+v**2)\n",
        "    cubic_real = self.rfft(psi_squared*u)\n",
        "    cubic_imag = self.rfft(psi_squared*v)\n",
        "    ipsi_cubed_hat = self.mul_i(jnp.concatenate([cubic_real,cubic_imag]))\n",
        "    return -ipsi_cubed_hat/2#ipsi_cubed_hat#-ipsi_cubed_hat/2#+self.advection_term*psihat\n",
        "\n",
        "  def implicit_terms(self, psihat):\n",
        "    \"\"\"Linear parts of the equation, namely `i psi_xx/2`.\"\"\"\n",
        "    return self.diffusive_term*self.mul_i(psihat)\n",
        "\n",
        "  def implicit_solve(self, psihat, time_step):\n",
        "    \"\"\"Solves for `implicit_terms`, implicitly. \n",
        "        Implements (1-idtA)^-1 = (1+idtA)/(1+dt^2A^2) where A is the invertible\n",
        "        implicit terms. Must be done via conjugate because i is a matrix.\"\"\"\n",
        "    ipsihat = self.mul_i(psihat)\n",
        "    numerator = psihat + time_step*self.diffusive_term*self.mul_i(psihat)\n",
        "    denominator = 1+(time_step*self.diffusive_term)**2\n",
        "    return numerator/denominator\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class ModifiedNonlinearSchrodinger(spectral_stepping.ImplicitExplicitODE):\n",
        "  \"\"\"Nonlinear schrodinger equation split in implicit and explicit parts.\n",
        "\n",
        "  The MNLS equation is\n",
        "    psi_t = -i psi_xx/8 - i|psi|^2 psi/2 - psi_x/2 + HOT\n",
        "\n",
        "  Attributes:\n",
        "    grid: underlying grid of the process\n",
        "    smooth: smooth the non-linear term using the 3/2-rule\n",
        "  \"\"\"\n",
        "  grid: grids.Grid\n",
        "  smooth: bool = True\n",
        "\n",
        "  def __post_init__(self):\n",
        "    self.kx, = self.grid.rfft_axes()\n",
        "    self.two_pi_i_k = 2j * jnp.pi * self.kx\n",
        "    self.doubled = jnp.concatenate([self.two_pi_i_k,self.two_pi_i_k])\n",
        "    implicit_term = -self.two_pi_i_k**2/8\n",
        "    self.implicit_term = jnp.concatenate([implicit_term,implicit_term])\n",
        "    self.rfft = truncated_rfft if self.smooth else jnp.fft.rfft\n",
        "    self.irfft = padded_irfft if self.smooth else jnp.fft.irfft\n",
        "\n",
        "  def mul_i(self, psi):\n",
        "    \"\"\" multiply the state by i\"\"\"\n",
        "    N = len(psi)//2\n",
        "    real, imag = psi[:N], psi[N:]\n",
        "    return jnp.concatenate([-imag, real])\n",
        "\n",
        "  def explicit_terms(self, psihat):\n",
        "    \"\"\"Non-linear parts of the equation,.\"\"\"\n",
        "    N = len(psihat)//2\n",
        "    uhat,vhat = psihat[:N],psihat[N:]\n",
        "    u = self.irfft(uhat)\n",
        "    v = self.irfft(vhat)\n",
        "    psi_squared = (u**2+v**2)\n",
        "    uterm = self.rfft(psi_squared*v)\n",
        "    vterm = self.rfft(-psi_squared*u)\n",
        "    cubic = jnp.concatenate([uterm,vterm])/2\n",
        "    dispersion = (psihat*self.doubled**3)/16\n",
        "    \n",
        "    dx_u = self.irfft(uhat*self.two_pi_i_k)\n",
        "    dx_v = self.irfft(vhat*self.two_pi_i_k)\n",
        "    transport_a_real = self.rfft(-(3/2)*psi_squared*dx_u)\n",
        "    transport_a_imag = self.rfft(-(3/2)*psi_squared*dx_v)\n",
        "    transport_a = jnp.concatenate([transport_a_real,transport_a_imag])\n",
        "    transport_b_real = self.rfft(-psi_squared*dx_u/4)\n",
        "    transport_b_imag = self.rfft(psi_squared*dx_v/4)\n",
        "    transport_b = jnp.concatenate([transport_b_real,transport_b_imag])\n",
        "    dx_potential = -self.rfft(psi_squared)*jnp.abs(self.kx)/2\n",
        "    potential_term_real = self.rfft(self.irfft(dx_potential)*v)\n",
        "    potential_term_imag = self.rfft(-self.irfft(dx_potential)*u)\n",
        "    potential_term = jnp.concatenate([potential_term_real,potential_term_imag])\n",
        "    return cubic+dispersion+transport_a+transport_b+potential_term\n",
        "\n",
        "  def implicit_terms(self, psihat):\n",
        "    \"\"\"Linear parts of the equation, namely `i psi_xx/2`.\"\"\"\n",
        "    return self.implicit_term*self.mul_i(psihat)\n",
        "\n",
        "  def implicit_solve(self, psihat, time_step):\n",
        "    \"\"\"Solves for `implicit_terms`, implicitly. \n",
        "        Implements (1-idtA)^-1 = (1+idtA)/(1+dt^2A^2) where A is the invertible\n",
        "        implicit terms. Must be done via conjugate because i is a matrix.\"\"\"\n",
        "    ipsihat = self.mul_i(psihat)\n",
        "    numerator = psihat + time_step*self.implicit_term*self.mul_i(psihat)\n",
        "    denominator = 1+(time_step*self.implicit_term)**2\n",
        "    return numerator/denominator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "l9IpO2CAWa5w"
      },
      "outputs": [],
      "source": [
        "jnp.fft.fftfreq(99)[int(np.ceil(99/2)):]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "G_O-XkIgZZEY"
      },
      "outputs": [],
      "source": [
        "jnp.fft.fftfreq(99)#.sum()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8-6jHliUZYIY"
      },
      "outputs": [],
      "source": [
        "int(np.ceil(99/2))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VOtNXfcJ5x5_"
      },
      "outputs": [],
      "source": [
        "fft_truncated_2x()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "udNMj0cJ56av"
      },
      "outputs": [],
      "source": [
        "z = jnp.fft.fft(u0)\n",
        "z1 = ifft_padded_2x(z)\n",
        "z2 = fft_truncated_2x(z1)\n",
        "jnp.abs(z2-z).mean()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gJ5NS4npWNMA"
      },
      "outputs": [],
      "source": [
        " def fft_truncated_2x(u):\n",
        "  \"\"\"Applies the 1/2 rule by truncating higher Fourier modes.\n",
        "\n",
        "  Args:\n",
        "    u: the (complex) input signal\n",
        "\n",
        "  Returns:\n",
        "    Downsampled version of `u` in fft-space.\n",
        "  \"\"\"\n",
        "  uhat = jnp.fft.fftshift(jnp.fft.fft(u))\n",
        "  k, = uhat.shape\n",
        "  final_size = (k+1)//2#int(np.ceil(k/2))# + 1\n",
        "  # shifted_freq = jnp.fft.fftshift(jnp.fft.fftfreq(k))\n",
        "  # #print('sfreq',shifted_freq)\n",
        "  # clipped_sfreq = shifted_freq[final_size//2:(-final_size+1)//2]\n",
        "  # print('clipped_sfreq',clipped_sfreq)\n",
        "  # #print('out_sfreq',jnp.fft.ifftshift(clipped_sfreq))\n",
        "  return jnp.fft.ifftshift(uhat[final_size//2:(-final_size+1)//2])/2\n",
        "\n",
        "\n",
        "def ifft_padded_2x(uhat):\n",
        "  n, = uhat.shape\n",
        "  final_size = n+2*(n//2)\n",
        "  added = n//2\n",
        "  smoothed = jnp.pad(jnp.fft.fftshift(uhat), (added, added))\n",
        "  assert smoothed.shape == (final_size,), \"incorrect padded shape\"\n",
        "  return 2 * jnp.fft.ifft(jnp.fft.ifftshift(smoothed))\n",
        "\n",
        "@dataclasses.dataclass\n",
        "class NLS(spectral_stepping.ImplicitExplicitODE):\n",
        "  \"\"\"Nonlinear schrodinger equation split in implicit and explicit parts.\n",
        "\n",
        "  The NLS equation is\n",
        "    psi_t = -i psi_xx/8 - i|psi|^2 psi/2 - psi_x/2 \n",
        "\n",
        "  Attributes:\n",
        "    grid: underlying grid of the process\n",
        "    smooth: smooth the non-linear term using the 3/2-rule\n",
        "  \"\"\"\n",
        "  grid: grids.Grid\n",
        "  smooth: bool = True\n",
        "\n",
        "  def __post_init__(self):\n",
        "    self.kx, = self.grid.fft_axes()\n",
        "    self.two_pi_i_k = 2j * jnp.pi * self.kx\n",
        "    self.fft = fft_truncated_2x if self.smooth else jnp.fft.fft\n",
        "    self.ifft = ifft_padded_2x if self.smooth else jnp.fft.ifft\n",
        "\n",
        "  def explicit_terms(self, psihat):\n",
        "    \"\"\"Non-linear parts of the equation.\"\"\"\n",
        "    psi = self.ifft(psihat)\n",
        "    ipsi_cubed = 1j*psi*jnp.abs(psi)**2\n",
        "    ipsi_cubed_hat = self.fft(ipsi_cubed)\n",
        "    return -ipsi_cubed_hat/2\n",
        "\n",
        "  def implicit_terms(self, psihat):\n",
        "    \"\"\"Linear parts of the equation, namely `-i psi_xx/2`.\"\"\"\n",
        "    return -1j*psihat*self.two_pi_i_k**2/8\n",
        "\n",
        "  def implicit_solve(self, psihat, time_step):\n",
        "    \"\"\"Solves for `implicit_terms`, implicitly. \n",
        "        Implements (1-idtA)^-1 = (1+idtA)/(1+dt^2A^2) where A is the invertible\n",
        "        implicit terms. Must be done via conjugate because i is a matrix.\"\"\"\n",
        "    return psihat/(1-time_step*(-1j*self.two_pi_i_k**2/8))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8-XLd72Sc7SK"
      },
      "outputs": [],
      "source": [
        "def rollout(stepfn,steps,u0,max_samples=1024):\n",
        "   multistepfn = jit(cfd.funcutils.repeated(stepfn,max(steps//max_samples,1)))\n",
        "   return cfd.funcutils.trajectory(multistepfn,max_samples,start_with_input=True)(u0)\n",
        "\n",
        "\n",
        "def solve(u0, t_final=1., max_samples=1024,dt=1e-2,L=500):\n",
        "  N = len(u0)\n",
        "  grid = grids.Grid((N,),domain=((-L/2,L/2),))\n",
        "  dx, = grid.step\n",
        "  xs, = grid.axes(offset=(0,))\n",
        "  eq = NLS(grid=grid)\n",
        "  stepfn = spectral_stepping.crank_nicolson_rk4(eq,dt)\n",
        "  #stepfn = spectral_stepping.imex_runge_kutta(eq,dt)\n",
        "  uhat0 = jnp.fft.fft(u0)\n",
        "  #print(stepfn(uhat0))\n",
        "  numsteps = int(t_final/dt)\n",
        "  steps,uhat_traj = rollout(stepfn,numsteps,uhat0,max_samples)\n",
        "  #print(uhat_traj.shape,steps[0].shape,steps[1].shape)\n",
        "  u_traj = jax.vmap(jnp.fft.ifft)(uhat_traj)\n",
        "  #timesteps = steps*dt#\n",
        "  timesteps = (jnp.arange(min(max_samples,numsteps)))*dt*max(numsteps//max_samples,1)\n",
        "  return u_traj,xs,timesteps\n",
        "\n",
        "L=500#256*np.pi\n",
        "N=2**4+1\n",
        "#x = (np.arange(N)/N)*L\n",
        "eps=.05\n",
        "sig=.1\n",
        "\n",
        "dx=L/N\n",
        "k = np.fft.fftfreq(N, d=dx)\n",
        "eps=.05\n",
        "sigma=.01 # original .01\n",
        "dk = k[1]-k[0]\n",
        "u0_spectrum = (eps**2)*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2)\n",
        "u0_phase = np.exp(2*np.pi*np.random.rand(N)*1j)\n",
        "u0 = np.fft.ifft(u0_phase*np.sqrt(2*dk*u0_spectrum),norm='forward')\n",
        "dt=1e-2\n",
        "N = len(u0)\n",
        "grid = grids.Grid((N,),domain=((-L/2,L/2),))\n",
        "dx, = grid.step\n",
        "xs, = grid.axes(offset=(0,))\n",
        "eq = NLS(grid=grid)\n",
        "stepfn = spectral_stepping.crank_nicolson_rk2(eq,dt)\n",
        "#stepfn = spectral_stepping.imex_runge_kutta(eq,dt)\n",
        "uhat0 = jnp.fft.fft(u0)\n",
        "stepfn(uhat0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wMWDydbIibs2"
      },
      "outputs": [],
      "source": [
        "\n",
        "\n",
        "# def rollout(stepfn,steps,u0,max_samples=1024):\n",
        "#   stepfn = jit(stepfn)\n",
        "#   u=u0\n",
        "#   out = []\n",
        "#   step_out = []\n",
        "#   for step in range(steps):\n",
        "#     u = stepfn(u)\n",
        "#     if steps\u003c=max_samples or not (step%(steps//max_samples)):\n",
        "#       out.append(u)\n",
        "#       step_out.append(step)\n",
        "#   return jnp.array(step_out),jnp.stack(out,axis=0)\n",
        "\n",
        "def rollout(stepfn,steps,u0,max_samples=1024):\n",
        "   multistepfn = jit(cfd.funcutils.repeated(stepfn,max(steps//max_samples,1)))\n",
        "   return cfd.funcutils.trajectory(multistepfn,max_samples)(u0)\n",
        "\n",
        "\n",
        "def solve(u0, t_final=1., max_samples=1024,dt=1e-2,L=500):\n",
        "  N = len(u0)\n",
        "  grid = grids.Grid((N,),domain=((-L/2,L/2),))\n",
        "  dx, = grid.step\n",
        "  xs, = grid.axes(offset=(0,))\n",
        "  eq = NonlinearSchrodinger(grid=grid)\n",
        "  stepfn = spectral_stepping.crank_nicolson_rk4(eq,dt)\n",
        "  #stepfn = spectral_stepping.imex_runge_kutta(eq,dt)\n",
        "  uhat0_real = jnp.fft.rfft(jnp.real(u0))\n",
        "  n = len(uhat0_real)\n",
        "  uhat0 = jnp.concatenate([uhat0_real,jnp.fft.rfft(jnp.imag(u0))])\n",
        "  numsteps = int(t_final/dt)\n",
        "  steps,uhat_traj = rollout(stepfn,numsteps,uhat0,max_samples)\n",
        "  #print(uhat_traj.shape,steps[0].shape,steps[1].shape)\n",
        "  u_traj_real = jax.vmap(jnp.fft.irfft)(uhat_traj[:,:n])\n",
        "  u_traj_imag = jax.vmap(jnp.fft.irfft)(uhat_traj[:,n:])\n",
        "  #timesteps = steps*dt#\n",
        "  timesteps = (1+jnp.arange(numsteps//max(numsteps//max_samples,1)))*dt*max(numsteps//max_samples,1)\n",
        "  return u_traj_real+1j*u_traj_imag,xs,timesteps"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BG1-Cbl_gmPC"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HnaDDx59goWx"
      },
      "source": [
        "## Random phase initial condition distribution"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "K0vCdQf0H8EB"
      },
      "outputs": [],
      "source": [
        "L=500#256*np.pi\n",
        "N=2**11\n",
        "#x = (np.arange(N)/N)*L\n",
        "eps=.05\n",
        "sig=.1\n",
        "\n",
        "dx=L/N\n",
        "k = np.fft.fftfreq(N, d=dx)\n",
        "eps=.05\n",
        "sigma=.01 # original .01\n",
        "dk = k[1]-k[0]\n",
        "u0_spectrum = (eps**2)*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2)\n",
        "u0_phase = np.exp(2*np.pi*np.random.rand(N)*1j)\n",
        "u0 = np.fft.ifft(u0_phase*np.sqrt(2*dk*u0_spectrum),norm='forward')\n",
        "#u0 = .5*(np.mean(np.abs(u0))+u0)\n",
        "plt.plot(jnp.abs(u0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "06FxzvMBJZ2q"
      },
      "outputs": [],
      "source": [
        "T = 1024#4096\n",
        "dt=1e-2\n",
        "soln,x_ds,t_ds = solve(u0,T,dt=dt,max_samples=T)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PGeiOgiS8sjE"
      },
      "outputs": [],
      "source": [
        "soln.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "KzheB3dRwu27"
      },
      "outputs": [],
      "source": [
        "t_ds.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fj95CimZMoYD"
      },
      "outputs": [],
      "source": [
        "u_ds = jnp.abs(soln)\n",
        "import xarray\n",
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    u_ds, dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=True)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YFqJ7fRk1pRP"
      },
      "outputs": [],
      "source": [
        "import scipy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "S6CjzopAarEw"
      },
      "outputs": [],
      "source": [
        "a = scipy.ndimage.zoom(soln,(1/32,1/16))\n",
        "donwsampled_u = a"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kX-cBHO5ZQoN"
      },
      "outputs": [],
      "source": [
        "a.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "H--vgzCcZUyR"
      },
      "outputs": [],
      "source": [
        "t_ds.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iVwm8c7FDhLz"
      },
      "outputs": [],
      "source": [
        "\n",
        "import xarray\n",
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    jnp.imag(donwsampled_u), dims=[\"time\", \"space\"], coords={\"time\": t_ds[::32], \"space\": x_ds[::16],\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=True)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GGdUybEn1PWn"
      },
      "outputs": [],
      "source": [
        "u_ds[::8,::16].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4a91fYm9KHfA"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uDnTI6X71L8w"
      },
      "outputs": [],
      "source": [
        "plt.plot(t_ds,jnp.max(jnp.abs(u_ds),axis=-1))\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel('Max wave height')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16eEUUWVdFAw"
      },
      "outputs": [],
      "source": [
        "plt.plot(t_ds[::32],jnp.max(jnp.abs(a),axis=-1))\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel('Max wave height')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "b6EQgeyJ-U2A"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "#tf.config.experimental.set_visible_devices([], \"GPU\")\n",
        "\n",
        "# from colabtools import adhoc_import\n",
        "import importlib\n",
        "from userdiffusion.config import get_config\n",
        "from userdiffusion.train import train_and_evaluate\n",
        "from userdiffusion import ode_datasets, unet, samplers, diffusion as train\n",
        "importlib.reload(ode_datasets)\n",
        "importlib.reload(unet)\n",
        "importlib.reload(samplers)\n",
        "importlib.reload(train)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8YBs_M_UAXps"
      },
      "outputs": [],
      "source": [
        "jnp.abs(soln).shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2-uiwACyAg_b"
      },
      "outputs": [],
      "source": [
        "test_x = train_x =  jnp.abs(u_ds[::8,::16])#.shape#jnp.abs(u_ds[::64,::32])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "p8nuMdLD03S_"
      },
      "outputs": [],
      "source": [
        "test_x[:64,].shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZZ09ZZxGAR0w"
      },
      "outputs": [],
      "source": [
        "x = test_x[None,:,:,None]#next(dataiter())\n",
        "t = np.random.rand(x.shape[0])\n",
        "model = unet.UNet(unet.unet_64_config(out_dim=x.shape[-1],base_channels=24))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HeIdKxsfKlAO"
      },
      "outputs": [],
      "source": [
        "x.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "s9wXbLQTREa8"
      },
      "outputs": [],
      "source": [
        "from absl import logging\n",
        "#logging.getLogger().setLevel(logging.INFO)\n",
        "logging.get_absl_handler().python_handler.stream = sys.stdout"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WZT6tUz3AvxY"
      },
      "outputs": [],
      "source": [
        "t = np.random.rand(x.shape[0])\n",
        "params = model.init(random.PRNGKey(42), x=x,t=t,train=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EEt6dSwcAN7Z"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices(train_x)\n",
        "dataiter = dataset.shuffle(len(dataset)).batch(bs).as_numpy_iterator"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QpAEsgJ_k7w0"
      },
      "outputs": [],
      "source": [
        "import scipy\n",
        "downsample = lambda u: u[...,::2]#scipy.signal.decimate(u,2)\n",
        "soln2,x_ds2,_ = solve(jnp.array(downsample(u0)),T,dt=dt)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8MsaZ573u1UP"
      },
      "outputs": [],
      "source": [
        "errs = jnp.sqrt(jnp.square(downsample(soln)-soln2).mean(-1)/jnp.square(soln).mean(-1))\n",
        "plt.plot(t_ds,errs)\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel('Relative error vs doubled resolution')\n",
        "plt.yscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Dl86HhRZIhZ"
      },
      "outputs": [],
      "source": [
        "# RW solution\n",
        "N=2**11\n",
        "L = 40*jnp.pi#*np.sqrt(2)\n",
        "grid = grids.Grid((N,),domain=((-L/2,L/2),))\n",
        "dx, = grid.step\n",
        "xs, = grid.axes(offset=(0,))\n",
        "zs = xs*np.sqrt(2)\n",
        "u0 = (4*zs**2-3)/(1+4*zs**2)\n",
        "\n",
        "tau = 8\n",
        "T = tau*2\n",
        "dt=3e-4\n",
        "soln,x_ds,t_ds = solve(u0,T,dt=dt,L=L)\n",
        "z_ds = x_ds*np.sqrt(2)\n",
        "tau_ds = t_ds/2\n",
        "u_ds = jnp.abs(soln)\n",
        "import xarray\n",
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    u_ds, dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=False)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6x21x4ThjeS0"
      },
      "outputs": [],
      "source": [
        "print(u_ds.shape,t_ds.shape,x_ds.shape)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_AZzwSduaeKl"
      },
      "outputs": [],
      "source": [
        "\n",
        "gt_soln = jnp.conj((1-4*(1+2j*tau_ds[:,None])/(1+4*(z_ds**2+tau_ds[:,None]**2)))*jnp.exp(1j*tau_ds[:,None]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "N_-yBvtqq5Lm"
      },
      "outputs": [],
      "source": [
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    jnp.abs(gt_soln), dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=False)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YVowOMJW5HK2"
      },
      "outputs": [],
      "source": [
        "gt_soln.shape"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0y-jzoTWhmuo"
      },
      "outputs": [],
      "source": [
        "plt.plot(t_ds,jnp.abs(soln-gt_soln).mean(-1))\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel('Psi error')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "merTjEGorbsm"
      },
      "outputs": [],
      "source": [
        "jnp.abs(soln-gt_soln).mean()\u003c1e-3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iikOK4LahoHX"
      },
      "outputs": [],
      "source": [
        "plt.plot(t_ds,jnp.abs(jnp.abs(u_ds)-jnp.abs(gt_soln)).mean(-1))\n",
        "plt.yscale('log')\n",
        "plt.xlabel('Time t')\n",
        "plt.ylabel('|Psi| error')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DFVUAaZJa88V"
      },
      "outputs": [],
      "source": [
        "plt.plot(x_ds,jnp.abs(u0))\n",
        "plt.plot(x_ds,jnp.abs(gt_soln[0]))\n",
        "plt.plot(x_ds,jnp.abs(soln[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "26rgRMpDaH-z"
      },
      "outputs": [],
      "source": [
        "import xarray\n",
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    jnp.abs(gt_soln), dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=False)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "TMtdFSyBX2PB"
      },
      "outputs": [],
      "source": [
        "wave = jnp.real(jnp.exp(1j*(x_ds-t_ds[:,None])/30))\n",
        "plt.figure(figsize=(18, 12))\n",
        "xarray.DataArray(\n",
        "    wave, dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=True)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T6zXswjdXWJX"
      },
      "outputs": [],
      "source": [
        "u_ds = jnp.abs(soln)\n",
        "import xarray\n",
        "plt.figure(figsize=(9, 6))\n",
        "xarray.DataArray(\n",
        "    u_ds, dims=[\"time\", \"space\"], coords={\"time\": t_ds[:u_ds.shape[0]], \"space\": x_ds,\n",
        "}).plot.imshow(\n",
        "    cmap=\"RdBu\", robust=True)\n",
        "plt.grid(False)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZrqG5zAgXusK"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Pgngsj3VKd7D"
      },
      "outputs": [],
      "source": [
        "plt.imshow(jnp.abs(soln))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2khclllpL_0q"
      },
      "outputs": [],
      "source": [
        "# Magnitude and real part of u0, plotted against the sample index.\n",
        "plt.plot(jnp.absolute(u0))\n",
        "plt.plot(jnp.real(u0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AbJ8kc9IUhqJ"
      },
      "outputs": [],
      "source": [
        "\n",
        "import matplotlib.pyplot as plt\n",
        "from scipy.integrate import solve_ivp\n",
        "N = 1024\n",
        "dx = 500/N\n",
        "x = jnp.arange(N)*dx\n",
        "\n",
        "eps =.5\n",
        "a=.001\n",
        "\n",
        "\n",
        "  #du3dx = correlate1d(u3,ddx,mode='wrap')\n",
        "  #rhs = d2udx2*eps**2/2+u3+1j*a*eps*du3dx\n",
        "  #return rhs/(-1j*eps) \n",
        "\n",
        "# k = np.fft.fftfreq(N, d=dx)\n",
        "# eps=2e1\n",
        "# sigma=.05\n",
        "# u0_spectrum = (eps**2)*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2)\n",
        "# u0_phase = np.exp(2*np.pi*np.random.rand(N)*1j)\n",
        "# u0 = np.fft.ifft(u0_phase*np.sqrt(2*u0_spectrum))\n",
        "\n",
        "k = np.fft.fftfreq(N, d=dx)\n",
        "eps=2.5\n",
        "sigma=.01\n",
        "u0_spectrum = (eps**2)*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2)\n",
        "u0_phase = np.exp(2*np.pi*np.random.rand(N)*1j)\n",
        "u0 = np.fft.ifft(u0_phase*np.sqrt(2*u0_spectrum))\n",
        "\n",
        "ddx = np.array([-1,0,1])/(2*dx)\n",
        "ddx2 = np.array([1,-2,1])/(dx**2)\n",
        "def nls(t,u):\n",
        "  dudx = correlate1d(u,ddx,mode='wrap')\n",
        "  d2udx2 = correlate1d(u,ddx2,mode='wrap')\n",
        "  u3 = (jnp.abs(u)**2)*u\n",
        "  return dudx/2+1j*d2udx2/8+1j*u3/2\n",
        "sol = solve_ivp(nls,(0,200),u0,rtol=1e-6,method='DOP853')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kY5_M9vCHZWg"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IlhE0BOFBmIt"
      },
      "outputs": [],
      "source": [
        "t = sol.t\n",
        "y = sol.y[:,::(len(t)//100)+1]\n",
        "t= t[::len(t)//100+1]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "63BKfZSOiQUT"
      },
      "outputs": [],
      "source": [
        "k=15\n",
        "cs = np.random.rand(k)*300+80\n",
        "ws = np.random.rand(k)*15+10\n",
        "vs = np.random.randn(k)*10\n",
        "rs = np.random.randn(k)*3\n",
        "u0 = .1*sum(jnp.exp(-((x-c)/w)**2/2-v*1j*x/w)*r for c,w,v,r in zip(cs,ws,vs,rs))\n",
        "#u0= jnp.exp(-((x-300)/30)**2/2 - 20*1j*x)*.2\n",
        "plt.plot(np.real(u0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RLaDqXVOiizG"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Integrate `mnls` from t=0 to t=200 starting at the current u0 with the BDF method.\n",
        "# NOTE(review): `mnls` is not defined in this part of the notebook; it must\n",
        "# already exist in the kernel -- confirm this survives Restart and Run All.\n",
        "sol = solve_ivp(mnls,(0,200),u0,rtol=1e-6,method='BDF')#,method='RK23',rtol=1e-3)#,t_eval=jnp.linspace(0,20,10))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZZbT4dQcjWQt"
      },
      "outputs": [],
      "source": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZpMFFrRajX21"
      },
      "outputs": [],
      "source": [
        "len(t)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CzeBfYwVlibT"
      },
      "outputs": [],
      "source": [
        "# Magnitude of column 76 of sol.y.\n",
        "# NOTE(review): the column index is hard-coded, so this fails if sol.y has\n",
        "# fewer than 77 columns.\n",
        "plt.plot(jnp.abs(sol.y[:,76]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pscD2qiujZny"
      },
      "outputs": [],
      "source": [
        "from matplotlib import rc\n",
        "rc('animation', html='jshtml')\n",
        "\n",
        "fig = plt.figure()\n",
        "ax1 = fig.add_subplot(111)\n",
        "line, = ax1.plot(x, u0, c='r', label=r'$|\\psi(x)|$')\n",
        "plt.ylim(-.3,.3)\n",
        "def init():\n",
        "    line.set_data(x, u0)\n",
        "    return [line]\n",
        "\n",
        "def animate(i):\n",
        "    line.set_data(x,jnp.real(y[:,i]))\n",
        "    return [line]\n",
        "\n",
        "from matplotlib import animation\n",
        "anim = animation.FuncAnimation(\n",
        "        fig,\n",
        "        animate,\n",
        "        frames=len(t),\n",
        "        interval=33,\n",
        "        init_func=init,\n",
        "        blit=False)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-GXDFNzwlZEa"
      },
      "outputs": [],
      "source": [
        "anim"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EWMJ0fKUUfEG"
      },
      "outputs": [],
      "source": [
        "# k expressed in units of the spacing between its first two entries.\n",
        "# Requires k to be an indexable array (e.g. from np.fft.fftfreq), not a scalar.\n",
        "plt.plot(k/(k[1]-k[0]))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1Mcpv3AL9tng"
      },
      "outputs": [],
      "source": [
        "# Domain length L, number of grid points N, spacing dx and the FFT sample\n",
        "# frequencies k for that grid.\n",
        "# NOTE(review): this cell sets `sig`, while the following cells read `sigma` --\n",
        "# confirm which name is intended.\n",
        "L=256*np.pi\n",
        "eps=.05\n",
        "sig=.1\n",
        "N=2**12\n",
        "dx=L/N\n",
        "k = np.fft.fftfreq(N, d=dx)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "asxjdG0z2YCi"
      },
      "outputs": [],
      "source": [
        "# u0RandPhase from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0\u0026subfolder_nav_tracking=1\n",
        "# Constant background of 1 plus 1000 positive- and 1000 negative-wavenumber\n",
        "# Fourier modes with Gaussian amplitudes in the mode index and independent\n",
        "# uniform random phases, all scaled by eps/sqrt(2*pi*sigma**2).\n",
        "# NOTE(review): np.random is not seeded in this cell, so u0 differs between runs.\n",
        "eps=.05\n",
        "sigma=.1\n",
        "u0 = 1*np.ones(N)+0j\n",
        "x = (np.arange(N)/N)*L\n",
        "ii = np.arange(1000)+1\n",
        "# The exponent broadcasts to shape (N, len(ii)); .sum(1) adds the modes at each grid point.\n",
        "u0 += np.exp(-(2*np.pi*ii/(L*sigma))**2+1j*(2*np.pi/L)*ii*x[:,None]+2*np.pi*np.random.rand(len(ii))*1j).sum(1)\n",
        "u0 += np.exp(-(2*np.pi*ii/(L*sigma))**2-1j*(2*np.pi/L)*ii*x[:,None]+2*np.pi*np.random.rand(len(ii))*1j).sum(1)\n",
        "u0 = u0*eps/np.sqrt(2*np.pi*sigma**2)\n",
        "#k = np.fft.fftfreq(N, d=dx)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "k3NRuRBn7Clp"
      },
      "outputs": [],
      "source": [
        "# u0GaussSpec from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0\u0026preview=u0GaussSpec.m\u0026subfolder_nav_tracking=1\n",
        "# Gaussian spectrum S(k) of width sigma with multiplicative noise of relative\n",
        "# size eps, turned into Fourier amplitudes with uniform random phases and\n",
        "# transformed back to physical space.\n",
        "eps=.05\n",
        "sigma=.1\n",
        "x = (np.arange(N)/N)*L\n",
        "dkx = 2*np.pi/L\n",
        "# NOTE(review): np.fft.fftfreq returns cycles per unit length; the factor 6\n",
        "# presumably stands in for 2*pi (cf. dkx = 2*pi/L) -- confirm.\n",
        "k = 6*np.fft.fftfreq(N, d=dx)\n",
        "S = (1+eps*np.random.randn(N))*(eps**2*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2))\n",
        "uhat = np.sqrt(2*dkx*S)*np.exp(2j*np.pi*np.random.rand(N))\n",
        "u0 = np.fft.ifft(uhat)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x3MyBu_L-Bo0"
      },
      "outputs": [],
      "source": [
        "# y0JONSWAP_Hs from https://www.dropbox.com/sh/ov0952luetkpgwr/AAAJyAo94peygtxQyb9_FmAua/%2BIC?dl=0\u0026preview=u0JONSWAP_Hs.m\u0026subfolder_nav_tracking=1\n",
        "\n",
        "k0 = 2*np.pi/200\n",
        "L = 256*np.pi/k0\n",
        "N = 2**int(np.ceil(np.log2(8*k0*L/np.pi)))\n",
        "x = (np.arange(N)/N)*L\n",
        "\n",
        "dx = L/N\n",
        "ksl = k0\n",
        "gamma=5\n",
        "Hs=9.5\n",
        "k = np.fft.fftfreq(N, d=dx)\n",
        "sig0 = .07*(k\u003c=0)+.09*(k\u003e0)\n",
        "S = (k+k0)**(-3) * np.exp(-1.5*(k0/(k+0j))**2)*gamma**np.exp(-((k+0j)-k0)**2/2/(sig0*k0)**2)\n",
        "#print(S.shape)\n",
        "S[k+k0\u003c=0]=0\n",
        "S[np.abs(k)\u003eksl]=0\n",
        "uhat  = np.sqrt(S)*np.exp(2j*np.pi*np.random.rand(len(S)))\n",
        "u = np.fft.ifft(uhat)\n",
        "Hs_0 = 4*np.std(np.real(u))\n",
        "u0 = (Hs/Hs_0)*u"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pmJCroCUcmEa"
      },
      "outputs": [],
      "source": [
        "# Grid of N points on a domain of length L.\n",
        "L=500#256*np.pi\n",
        "N=2**12\n",
        "x = (np.arange(N)/N)*L\n",
        "eps=.05\n",
        "# NOTE(review): `sig` is not read in this cell; the spectrum below uses `sigma`.\n",
        "sig=.1\n",
        "\n",
        "dx=L/N\n",
        "k = np.fft.fftfreq(N, d=dx)\n",
        "# NOTE(review): eps is assigned the same value a second time here.\n",
        "eps=.05\n",
        "sigma=.01\n",
        "dk = k[1]-k[0]\n",
        "# Gaussian spectrum of width sigma with an independent uniform random phase per mode.\n",
        "u0_spectrum = (eps**2)*np.exp(-(k/sigma)**2/2)/np.sqrt(2*np.pi*sigma**2)\n",
        "u0_phase = np.exp(2*np.pi*np.random.rand(N)*1j)\n",
        "# norm='forward' puts the 1/N factor on the forward transform, so this inverse\n",
        "# transform is an unscaled sum of the modes.\n",
        "u0 = np.fft.ifft(u0_phase*np.sqrt(2*dk*u0_spectrum),norm='forward')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B4e9oCqS5BSh"
      },
      "outputs": [],
      "source": [
        "import matplotlib.pyplot as plt\n",
        "\n",
        "# Real part of u0 multiplied by exp(1j*x), drawn over the whole domain [0, L].\n",
        "plt.figure(figsize=(10,3))\n",
        "plt.plot(x,np.real(u0*np.exp(1j*x)))\n",
        "plt.xlabel('Space')\n",
        "plt.ylim(-.3,.3)\n",
        "plt.xlim(0,L)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aNbve1WCKc2Y"
      },
      "outputs": [],
      "source": [
        "plt.plot(x,np.abs(u0))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1oz7Yl2VE0f3"
      },
      "outputs": [],
      "source": [
        "np.real(u0)[:20]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aKey0F9eFS7E"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Magnitude of the constructed Fourier amplitudes uhat against k, log y-axis.\n",
        "# NOTE(review): in top-to-bottom order `uhat` was last built in a cell that used\n",
        "# a different N than the most recent `k`, so the two lengths may not match -- confirm.\n",
        "# NOTE(review): np.fft.fftfreq gives cycles per unit length, so the 'angular\n",
        "# frequency' label presumably needs k scaled by 2*pi -- confirm.\n",
        "plt.plot(k,np.abs(uhat))\n",
        "plt.xlabel('angular frequency k')\n",
        "plt.yscale('log')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h1G70D9LfFtw"
      },
      "outputs": [],
      "source": [
        "\n",
        "# Magnitude of the discrete Fourier transform of u0 against k, log y-axis.\n",
        "plt.plot(k,np.abs(np.fft.fft(u0)))\n",
        "plt.xlabel('angular frequency k')\n",
        "plt.yscale('log')\n",
        "#plt.xlim(-2,2)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jPELuzgg1mBV"
      },
      "outputs": [],
      "source": [
        "np.fft.fft(u0)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pr5i1OhDfYq-"
      },
      "outputs": [],
      "source": [
        "# Real and imaginary parts of `out`, plotted against the sample index.\n",
        "# NOTE(review): `out` is not assigned in this part of the notebook; it must\n",
        "# already exist in the kernel -- confirm this survives Restart and Run All.\n",
        "plt.plot(np.real(out))\n",
        "plt.plot(np.imag(out))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UO6ysiGLgW8E"
      },
      "outputs": [],
      "source": [
        "# from colabtools import adhoc_import\n",
        "import importlib\n",
        "import jax_cfd.base as cfd\n",
        "import jax_cfd.spectral.utils as spectral_utils\n",
        "import jax_cfd.spectral.equations as spectral_equations\n",
        "import jax_cfd.spectral.time_stepping as spectral_stepping\n",
        "from jax_cfd.base import grids\n",
        "from jax_cfd.spectral.equations_test import EquationsTest1D"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YqlkviEv5ms-"
      },
      "outputs": [],
      "source": [
        "# Instantiate the 1-D equations test case so its test methods can be called directly.\n",
        "# NOTE(review): this rebinds `t`, which earlier cells use for the solver time samples.\n",
        "t = EquationsTest1D()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yHYNnizn5sBA"
      },
      "outputs": [],
      "source": [
        "t.test_nls_equation()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rD5TqePG5ugM"
      },
      "outputs": [],
      "source": []
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "name": "nonlinear_schrodinger.ipynb",
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
