{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6be5e9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import random\n",
    "\n",
    "import pysindy as ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "440a70a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fourth_order_diff(x, dt):\n",
    "    dx = np.zeros([x.shape[0], x.shape[1]])\n",
    "    dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3\n",
    "    dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3\n",
    "    dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]\n",
    "    dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0\n",
    "    dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0\n",
    "    return dx / dt\n",
    "\n",
    "def sample_trajectory(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))\n",
    "    xs = []\n",
    "    curr = np.array([x0 for i in range(batch_size)])\n",
    "    for i in range(timesteps):\n",
    "        curr_lib = library.transform(curr).reshape(10, 1, 9)\n",
    "        coef_idx = np.random.randint(0, len(coefs), batch_size)\n",
    "        curr_coefs = coefs[coef_idx]\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def sample_trajectory2(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))\n",
    "    coefs_mean, coefs_std = coefs.mean(0), coefs.std(0)\n",
    "    coefs_mean = np.array([coefs_mean for _ in range(batch_size)])\n",
    "    coefs_std = np.array([coefs_std for _ in range(batch_size)])\n",
    "    xs = []\n",
    "    curr = np.array([x0 for _ in range(batch_size)])\n",
    "    for i in range(timesteps):\n",
    "        curr_lib = library.transform(curr).reshape(10, 1, 9)\n",
    "        noise = np.random.normal(0, 1, (batch_size, coefs.shape[1], coefs.shape[2]))\n",
    "        curr_coefs = coefs_mean + coefs_std * noise\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def plot_samples(xs, samples, num_samples=4, dpi=300, figsize=None, filename=None):\n",
    "    \"\"\"Draw a single row of bare 3-D trajectory panels.\n",
    "\n",
    "    Panel 0 shows ``xs`` in red. Panel ``k`` (``k >= 1``) shows ``samples[k]``\n",
    "    in blue, so ``samples[0]`` is never drawn. Grid, panes, axis lines and\n",
    "    ticks are hidden on every panel.\n",
    "\n",
    "    When ``filename`` is given the figure is written with ``plt.savefig`` and\n",
    "    closed before ``plt.show()`` is reached.\n",
    "    \"\"\"\n",
    "    sns.set()\n",
    "\n",
    "    # Styling follows\n",
    "    # https://dawes.wordpress.com/2014/06/27/publication-ready-3d-figures-from-matplotlib/\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    # Fully transparent white, applied to the panes and the axis lines.\n",
    "    transparent = (1.0, 1.0, 1.0, 0.0)\n",
    "    for panel in range(num_samples):\n",
    "        ax = fig.add_subplot(1, num_samples, panel + 1, projection='3d')\n",
    "        if panel == 0:\n",
    "            curve, colour = xs, 'red'\n",
    "        else:\n",
    "            curve, colour = samples[panel], 'blue'\n",
    "        ax.plot(curve[:, 0], curve[:, 1], curve[:, 2], color=colour)\n",
    "\n",
    "        ax.grid(False)\n",
    "        axes3 = (ax.xaxis, ax.yaxis, ax.zaxis)\n",
    "        for axis in axes3:\n",
    "            axis.set_pane_color(transparent)\n",
    "        for axis in axes3:\n",
    "            axis.line.set_color(transparent)\n",
    "\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_zticks([])\n",
    "\n",
    "    plt.subplots_adjust(wspace=0)\n",
    "\n",
    "    if filename is not None:\n",
    "        plt.savefig(filename)\n",
    "        plt.close()\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7dd72e9",
   "metadata": {},
   "source": [
    "# Lorenz data, scale 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "746d970d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x)' = -9.860 x + 9.994 y\n",
      "(y)' = 25.551 x + 0.205 y + -0.954 x z\n",
      "(z)' = -0.240 x + 0.225 y + -2.602 z + 0.997 x y\n",
      "[[0.022411   0.01930534 0.         0.         0.         0.\n",
      "  0.         0.         0.        ]\n",
      " [0.11809155 0.0379148  0.         0.         0.         0.00312712\n",
      "  0.         0.         0.        ]\n",
      " [0.08093548 0.06984578 0.01448469 0.         0.00340899 0.\n",
      "  0.         0.         0.        ]]\n",
      "0.11809155107037621\n"
     ]
    }
   ],
   "source": [
    "# Seed both global RNGs before fitting (presumably so the ensemble fit is repeatable).\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# Sample spacing, used for the finite differences and passed to fit() as t.\n",
    "dt = 0.01\n",
    "x_train = np.load('../data/lorenz/scale-1.0/x_train.npy')\n",
    "# Derivatives are precomputed here and handed to fit() through x_dot.\n",
    "x_dot = fourth_order_diff(x_train, dt)\n",
    "x_test = np.load('../data/lorenz/scale-1.0/x_test_0.npy')\n",
    "# First test state; the sampling cell starts its rollouts from it.\n",
    "x0 = x_test[0]\n",
    "\n",
    "feature_names = ['x', 'y', 'z']\n",
    "# Build and fit an ensemble SINDy model: degree-2 polynomial library without\n",
    "# a bias term, SR3 optimizer with an l0 threshold of 0.5, 500 ensemble members.\n",
    "library = ps.PolynomialLibrary(degree=2, include_bias=False)\n",
    "optimizer = ps.SR3(\n",
    "    threshold=0.5, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    ")\n",
    "model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "# Stack the per-member coefficient matrices; axis 0 indexes ensemble members.\n",
    "ensemble_coefs = np.array(model.coef_list)\n",
    "\n",
    "model.print()\n",
    "\n",
    "# Replace the per-member list with the ensemble-mean coefficients.\n",
    "# NOTE(review): this runs after model.print(), so it does not change the\n",
    "# equations printed above; confirm whether model.optimizer.coef_ was the\n",
    "# intended target.\n",
    "model.coef_list = np.mean(ensemble_coefs, 0)\n",
    "\n",
    "# Per-coefficient standard deviation across ensemble members, then its maximum.\n",
    "the_std = ensemble_coefs.std(0)\n",
    "print(the_std)\n",
    "print(np.max(the_std))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "282ffd07",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the random draws in this cell start from a fixed RNG state.\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# 10 rollouts of 10000 Euler steps each, all starting from x0.\n",
    "samples = sample_trajectory(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "# Display-only variant (no file written), currently disabled:\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "# 5 panels at 300 dpi on a 20x20 figure, saved under ../results/.\n",
    "plot_samples(x_test, samples, 5, 300, (20, 20), \"../results/esindy_lorenz_gen_1\")\n",
    "\n",
    "# Alternative sampler (sample_trajectory2) and its plots, currently disabled:\n",
    "#samples = sample_trajectory2(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), \"esindy_lorenznc_gen2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e91e722",
   "metadata": {},
   "source": [
    "# Lorenz data, scale 5.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "eca56a9c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x)' = -9.772 x + 9.975 y\n",
      "(y)' = 23.908 x + 0.892 y + -0.922 x z\n",
      "(z)' = -0.357 x + 0.228 y + -2.325 z + 0.959 x y\n",
      "[[0.10477731 0.09234954 0.         0.         0.         0.\n",
      "  0.         0.         0.        ]\n",
      " [1.06669903 0.47814381 0.         0.         0.         0.0197099\n",
      "  0.         0.         0.        ]\n",
      " [0.27333435 0.20884593 0.08527987 0.04235572 0.02981011 0.\n",
      "  0.         0.         0.        ]]\n",
      "1.0666990288899492\n"
     ]
    }
   ],
   "source": [
    "# Seed both global RNGs before fitting (presumably so the ensemble fit is repeatable).\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# Sample spacing, used for the finite differences and passed to fit() as t.\n",
    "dt = 0.01\n",
    "x_train = np.load('../data/lorenz/scale-5.0/x_train.npy')\n",
    "# Derivatives are precomputed here and handed to fit() through x_dot.\n",
    "x_dot = fourth_order_diff(x_train, dt)\n",
    "x_test = np.load('../data/lorenz/scale-5.0/x_test_0.npy')\n",
    "# First test state; the sampling cell starts its rollouts from it.\n",
    "x0 = x_test[0]\n",
    "\n",
    "feature_names = ['x', 'y', 'z']\n",
    "# Build and fit an ensemble SINDy model: degree-2 polynomial library without\n",
    "# a bias term, SR3 optimizer with an l0 threshold of 0.5, 500 ensemble members.\n",
    "library = ps.PolynomialLibrary(degree=2, include_bias=False)\n",
    "optimizer = ps.SR3(\n",
    "    threshold=0.5, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    ")\n",
    "model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "# Stack the per-member coefficient matrices; axis 0 indexes ensemble members.\n",
    "ensemble_coefs = np.array(model.coef_list)\n",
    "\n",
    "model.print()\n",
    "\n",
    "# Replace the per-member list with the ensemble-mean coefficients.\n",
    "# NOTE(review): this runs after model.print(), so it does not change the\n",
    "# equations printed above; confirm whether model.optimizer.coef_ was the\n",
    "# intended target.\n",
    "model.coef_list = np.mean(ensemble_coefs, 0)\n",
    "\n",
    "# Per-coefficient standard deviation across ensemble members, then its maximum.\n",
    "the_std = ensemble_coefs.std(0)\n",
    "print(the_std)\n",
    "print(np.max(the_std))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "0d95ff2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the random draws in this cell start from a fixed RNG state.\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# 10 rollouts of 10000 Euler steps each, all starting from x0.\n",
    "samples = sample_trajectory(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "# Display-only variant (no file written), currently disabled:\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "# 5 panels at 300 dpi on a 20x20 figure, saved under ../results/.\n",
    "plot_samples(x_test, samples, 5, 300, (20, 20), \"../results/esindy_lorenz_gen_5\")\n",
    "\n",
    "# Alternative sampler (sample_trajectory2) and its plots, currently disabled:\n",
    "#samples = sample_trajectory2(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), \"esindy_lorenznc_gen2\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "077edfda",
   "metadata": {},
   "source": [
    "# Lorenz data, scale 10.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "236fc1f1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(x)' = -8.468 x + 9.481 y\n",
      "(y)' = 23.714 x + 1.557 y + -0.940 x z\n",
      "(z)' = 0.219 x + 0.231 y + -1.785 z + 0.949 x y\n",
      "[[0.2632121  0.22849929 0.         0.         0.         0.\n",
      "  0.         0.         0.        ]\n",
      " [0.98279896 0.47134142 0.         0.         0.         0.01683359\n",
      "  0.         0.         0.        ]\n",
      " [0.57243191 0.46640596 0.15581896 0.         0.02711568 0.\n",
      "  0.         0.         0.        ]]\n",
      "0.9827989615553067\n"
     ]
    }
   ],
   "source": [
    "# Seed both global RNGs before fitting (presumably so the ensemble fit is repeatable).\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# Sample spacing, used for the finite differences and passed to fit() as t.\n",
    "dt = 0.01\n",
    "x_train = np.load('../data/lorenz/scale-10.0/x_train.npy')\n",
    "# Derivatives are precomputed here and handed to fit() through x_dot.\n",
    "x_dot = fourth_order_diff(x_train, dt)\n",
    "x_test = np.load('../data/lorenz/scale-10.0/x_test_0.npy')\n",
    "# First test state; the sampling cell starts its rollouts from it.\n",
    "x0 = x_test[0]\n",
    "\n",
    "feature_names = ['x', 'y', 'z']\n",
    "# Build and fit an ensemble SINDy model: degree-2 polynomial library without\n",
    "# a bias term, SR3 optimizer, 500 ensemble members. The l0 threshold is 0.9\n",
    "# here, whereas the scale-1.0 and scale-5.0 fits use 0.5.\n",
    "library = ps.PolynomialLibrary(degree=2, include_bias=False)\n",
    "optimizer = ps.SR3(\n",
    "    threshold=0.9, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    ")\n",
    "model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "# Stack the per-member coefficient matrices; axis 0 indexes ensemble members.\n",
    "ensemble_coefs = np.array(model.coef_list)\n",
    "\n",
    "model.print()\n",
    "\n",
    "# Replace the per-member list with the ensemble-mean coefficients.\n",
    "# NOTE(review): this runs after model.print(), so it does not change the\n",
    "# equations printed above; confirm whether model.optimizer.coef_ was the\n",
    "# intended target.\n",
    "model.coef_list = np.mean(ensemble_coefs, 0)\n",
    "\n",
    "# Per-coefficient standard deviation across ensemble members, then its maximum.\n",
    "the_std = ensemble_coefs.std(0)\n",
    "print(the_std)\n",
    "print(np.max(the_std))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "6765fcbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-seed so the random draws in this cell start from a fixed RNG state.\n",
    "np.random.seed(862023)\n",
    "random.seed(862023)\n",
    "\n",
    "# 10 rollouts of 10000 Euler steps each, all starting from x0.\n",
    "samples = sample_trajectory(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "# Display-only variant (no file written), currently disabled:\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "# 5 panels at 300 dpi on a 20x20 figure, saved under ../results/.\n",
    "plot_samples(x_test, samples, 5, 300, (20, 20), \"../results/esindy_lorenz_gen_10\")\n",
    "\n",
    "# Alternative sampler (sample_trajectory2) and its plots, currently disabled:\n",
    "#samples = sample_trajectory2(x0, ensemble_coefs, library, 10000, dt, 10)\n",
    "\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), None)\n",
    "#plot_samples(x_test, samples, 5, 300, (20, 20), \"esindy_lorenznc_gen2\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "498c75f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bcdfa3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c40e18da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8999a18e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
