{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 20001/20001 [00:08<00:00, 2423.05it/s]\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "from tqdm import trange\n",
    "\n",
    "# Parameters & Grid Setup\n",
    "L = 10.0            # half-length of the periodic domain [-L, L)\n",
    "N = 256             # number of spatial grid points\n",
    "# Grid spacing. Named dx_grid (not dx) because `dx` is redefined below as the\n",
    "# spectral first-derivative function, which would silently shadow this value.\n",
    "dx_grid = 2*L / N\n",
    "# Periodic grid: x = L is the same point as x = -L, so the right endpoint is omitted.\n",
    "x = np.linspace(-L, L, N, endpoint=False)\n",
    "\n",
    "# Angular wave numbers in numpy's FFT ordering, used for spectral differentiation.\n",
    "k = (2*np.pi) * np.fft.fftfreq(N, d=dx_grid)\n",
    "k2 = k**2\n",
    "k4 = k**4\n",
    "\n",
    "def initial_u(x):\n",
    "    \"\"\"Initial displacement u(x, 0) = 0.5 * exp(-x^2), a Gaussian bump at x = 0.\"\"\"\n",
    "    gaussian = np.exp(-x**2)\n",
    "    return 0.5 * gaussian\n",
    "\n",
    "def initial_v(x):\n",
    "    \"\"\"Initial velocity u_t(x, 0): zero everywhere, matching x in shape and dtype.\"\"\"\n",
    "    at_rest = np.zeros_like(x)\n",
    "    return at_rest\n",
    "\n",
    "# Initial state on the grid; v carries the velocity u_t.\n",
    "u = initial_u(x)\n",
    "v = initial_v(x)\n",
    "\n",
    "dt = 0.001          # RK4 time step\n",
    "tmax = 20.0         # final simulation time\n",
    "# round() rather than a bare int(): tmax/dt is a float quotient that can land just\n",
    "# below an integer (e.g. 0.3/0.1 == 2.9999999999999996), and int() would then\n",
    "# truncate away the final step.\n",
    "num_steps = int(round(tmax/dt))\n",
    "record_interval = 1  # store a snapshot every `record_interval` steps (1 = every step)\n",
    "\n",
    "# Spectral x-derivatives on the periodic grid: the n-th derivative multiplies the\n",
    "# FFT of f by (i k)^n. The multipliers depend only on the fixed wave numbers, so\n",
    "# they are built once here instead of on every call.\n",
    "_DERIVATIVE_MULTIPLIERS = {\n",
    "    1: 1j * k,\n",
    "    2: -k2,          # (i k)^2 = -k^2\n",
    "    3: (1j * k)**3,\n",
    "    4: k4,           # (i k)^4 = k^4\n",
    "}\n",
    "\n",
    "def _spectral_derivative(f, order):\n",
    "    \"\"\"Return the `order`-th x-derivative (order in 1..4) of periodic samples f.\n",
    "\n",
    "    Only the real part of the inverse FFT is returned, so f is assumed real-valued.\n",
    "    \"\"\"\n",
    "    f_hat = np.fft.fft(f)\n",
    "    return np.fft.ifft(_DERIVATIVE_MULTIPLIERS[order] * f_hat).real\n",
    "\n",
    "def dx(f):\n",
    "    \"\"\"First spectral derivative f_x.\"\"\"\n",
    "    return _spectral_derivative(f, 1)\n",
    "\n",
    "def dxx(f):\n",
    "    \"\"\"Second spectral derivative f_xx.\"\"\"\n",
    "    return _spectral_derivative(f, 2)\n",
    "\n",
    "def dxxx(f):\n",
    "    \"\"\"Third spectral derivative f_xxx.\"\"\"\n",
    "    return _spectral_derivative(f, 3)\n",
    "\n",
    "def dxxxx(f):\n",
    "    \"\"\"Fourth spectral derivative f_xxxx.\"\"\"\n",
    "    return _spectral_derivative(f, 4)\n",
    "\n",
    "def rhs(u, v):\n",
    "    \"\"\"Right-hand side of u_tt = -u_xxxx - 0.5 (u^2)_xx as a first-order system.\n",
    "\n",
    "    With v = u_t:\n",
    "        u_t = v\n",
    "        v_t = -u_xxxx - 0.5 * (u^2)_xx\n",
    "    Returns the pair (u_t, v_t).\n",
    "    \"\"\"\n",
    "    half_quadratic_xx = 0.5 * dxx(u**2)\n",
    "    fourth_derivative = dxxxx(u)\n",
    "    return v, -fourth_derivative - half_quadratic_xx\n",
    "\n",
    "def rk4_step(u, v, dt):\n",
    "    \"\"\"Advance (u, v) by one classical fourth-order Runge-Kutta step of size dt.\n",
    "\n",
    "    Each stage increment is `rhs` scaled by dt; the four increments are blended\n",
    "    with the standard 1-2-2-1 weights. Returns the new pair (u_new, v_new).\n",
    "    \"\"\"\n",
    "    def scaled_slopes(u_stage, v_stage):\n",
    "        du, dv = rhs(u_stage, v_stage)\n",
    "        return dt * du, dt * dv\n",
    "\n",
    "    s1_u, s1_v = scaled_slopes(u, v)\n",
    "    s2_u, s2_v = scaled_slopes(u + 0.5*s1_u, v + 0.5*s1_v)\n",
    "    s3_u, s3_v = scaled_slopes(u + 0.5*s2_u, v + 0.5*s2_v)\n",
    "    s4_u, s4_v = scaled_slopes(u + s3_u, v + s3_v)\n",
    "\n",
    "    u_new = u + (s1_u + 2*s2_u + 2*s3_u + s4_u)/6.0\n",
    "    v_new = v + (s1_v + 2*s2_v + 2*s3_v + s4_v)/6.0\n",
    "    return u_new, v_new\n",
    "\n",
    "def _record_snapshot(t, u, v):\n",
    "    \"\"\"Bundle the state at time t with its space, time and mixed derivatives.\n",
    "\n",
    "    Time derivatives are not finite-differenced in time; they are obtained from\n",
    "    the PDE (via `rhs`) and its time derivatives, with every x-derivative taken\n",
    "    spectrally. Reads the module-level grid `x`.\n",
    "    \"\"\"\n",
    "    # Reuse rhs so the recorded u_tt can never drift from the evolved equation.\n",
    "    u_t, u_tt = rhs(u, v)\n",
    "    # d/dt (0.5 u^2) = u u_t  =>  u_ttt = -(u_t)_xxxx - (u u_t)_xx\n",
    "    u_ttt = -dxxxx(v) - dxx(u*v)\n",
    "    # d^2/dt^2 (0.5 u^2) = u_t^2 + u u_tt\n",
    "    u_tttt = -dxxxx(u_tt) - dxx(v * v + u * u_tt)\n",
    "\n",
    "    return {\n",
    "        \"t\": t,\n",
    "        \"x\": x.copy(),\n",
    "        \"u\": u.copy(),\n",
    "        \"u_x\": dx(u),\n",
    "        \"u_xx\": dxx(u),\n",
    "        \"u_xxx\": dxxx(u),\n",
    "        \"u_xxxx\": dxxxx(u),\n",
    "        \"u_t\": u_t.copy(),\n",
    "        \"u_tt\": u_tt,\n",
    "        \"u_xt\": dx(v),\n",
    "        \"u_ttt\": u_ttt,\n",
    "        \"u_xxt\": dxx(v),\n",
    "        \"u_xtt\": dx(u_tt),\n",
    "        \"u_tttt\": u_tttt,\n",
    "        \"u_xxxt\": dxxx(v),\n",
    "        \"u_xxtt\": dxx(u_tt),\n",
    "        \"u_xttt\": dx(u_ttt),\n",
    "    }\n",
    "\n",
    "solution_data = []\n",
    "\n",
    "for step in trange(num_steps + 1):\n",
    "    if step % record_interval == 0:\n",
    "        solution_data.append(_record_snapshot(step * dt, u, v))\n",
    "    # The last iteration only records t = num_steps*dt; stepping past it would waste\n",
    "    # an RK4 step and leave (u, v) one step beyond the final snapshot.\n",
    "    if step < num_steps:\n",
    "        u, v = rk4_step(u, v, dt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "dict_keys(['t', 'x', 'u', 'u_x', 'u_xx', 'u_xxx', 'u_xxxx', 'u_t', 'u_tt', 'u_xt', 'u_ttt', 'u_xxt', 'u_xtt', 'u_tttt', 'u_xxxt', 'u_xxtt', 'u_xttt'])\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "import h5py\n",
    "\n",
    "# NOTE: the file name hard-codes dt = 1e-3; update it if dt is changed above.\n",
    "OUTPUT_PATH = \"../data/boussinesq/boussinesq_1_dt_1e-3.h5\"\n",
    "\n",
    "# Stack the per-snapshot values along a new leading time axis: \"t\" becomes (Nt,)\n",
    "# and every spatial field becomes (Nt, Nx).\n",
    "data_store = {\n",
    "    key: np.array([snapshot[key] for snapshot in solution_data])\n",
    "    for key in solution_data[0]\n",
    "}\n",
    "\n",
    "# h5py does not create missing parent directories; create them so the save does\n",
    "# not fail with an OSError after the long simulation.\n",
    "os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)\n",
    "with h5py.File(OUTPUT_PATH, \"w\") as h5_file:\n",
    "    for key, value in data_store.items():\n",
    "        h5_file.create_dataset(key, data=value)\n",
    "print(data_store.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
