{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6be5e9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import random\n",
    "\n",
    "import pysindy as ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "440a70a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fourth_order_diff(x, dt):\n",
    "    \"\"\"Finite-difference time derivative of samples ``x`` spaced ``dt`` apart.\n",
    "\n",
    "    Interior points (indices 2 .. n-3) use the 4th-order central stencil;\n",
    "    the first two and last two points use 3rd-order one-sided stencils, so\n",
    "    the boundary rows are one order less accurate than the interior.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array_like\n",
    "        Samples along axis 0 (time); any trailing shape is allowed, e.g.\n",
    "        ``(n_samples,)`` or ``(n_samples, n_features)``.\n",
    "    dt : float\n",
    "        Uniform spacing between consecutive samples.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        float64 array with the same shape as ``x``.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``x`` has fewer than 5 samples along axis 0.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    if x.ndim == 0 or x.shape[0] < 5:\n",
    "        raise ValueError(\"fourth_order_diff needs at least 5 samples along axis 0\")\n",
    "    # Same shape as x (any rank); only axis 0 is differentiated.\n",
    "    dx = np.zeros(x.shape)\n",
    "    # 3rd-order forward stencil (-11/6, 3, -3/2, 1/3) for the first two rows.\n",
    "    dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3\n",
    "    dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3\n",
    "    # 4th-order central stencil, offsets -2..+2 weighted (1/12, -2/3, 0, 2/3, -1/12).\n",
    "    dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]\n",
    "    # 3rd-order backward stencil (mirror of the forward one) for the last two rows.\n",
    "    dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0\n",
    "    dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0\n",
    "    return dx / dt\n",
    "\n",
    "def sample_trajectory(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Roll out ``batch_size`` forward-Euler trajectories from ``x0``.\n",
    "\n",
    "    At every step each trajectory independently picks one ensemble member\n",
    "    uniformly at random (global ``np.random`` state) and advances with that\n",
    "    member's coefficients: ``x_next = x + dt * (library(x) @ coefs_member.T)``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : array_like, shape (n_state,)\n",
    "        Initial state shared by every trajectory.\n",
    "    coefs : array_like, shape (n_models, n_state, n_features)\n",
    "        One coefficient matrix per ensemble member (presumably\n",
    "        ``np.array(model.coef_list)`` as saved in this notebook).\n",
    "    library : object\n",
    "        Fitted feature library; ``library.transform`` must map a\n",
    "        ``(batch_size, n_state)`` array to ``(batch_size, n_features)``.\n",
    "    timesteps : int\n",
    "        Number of Euler steps.\n",
    "    dt : float\n",
    "        Step size.\n",
    "    batch_size : int\n",
    "        Number of independent trajectories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (batch_size, timesteps, n_state)\n",
    "        State after each step; ``x0`` itself is not included.\n",
    "    \"\"\"\n",
    "    # (n_models, n_state, n_features) -> (n_models, n_features, n_state) so a\n",
    "    # (1, n_features) feature row can be matmul'ed with each member.\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))\n",
    "    curr = np.array([x0 for _ in range(batch_size)])\n",
    "    xs = []\n",
    "    for _ in range(timesteps):\n",
    "        # (batch_size, 1, n_features): sized from batch_size and the library\n",
    "        # output instead of fixed numbers, so any batch size / library works.\n",
    "        curr_lib = np.asarray(library.transform(curr)).reshape(batch_size, 1, -1)\n",
    "        coef_idx = np.random.randint(0, len(coefs), batch_size)\n",
    "        curr_coefs = coefs[coef_idx]  # (batch_size, n_features, n_state)\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)  # (timesteps, batch_size, n_state)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def sample_trajectory2(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Roll out ``batch_size`` forward-Euler trajectories with Gaussian coefficients.\n",
    "\n",
    "    Instead of picking ensemble members, every step draws fresh coefficients\n",
    "    for each trajectory element-wise from ``N(mean, std**2)``, where ``mean``\n",
    "    and ``std`` are taken over the ensemble axis of ``coefs`` (global\n",
    "    ``np.random`` state).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : array_like, shape (n_state,)\n",
    "        Initial state shared by every trajectory.\n",
    "    coefs : array_like, shape (n_models, n_state, n_features)\n",
    "        One coefficient matrix per ensemble member (presumably\n",
    "        ``np.array(model.coef_list)`` as saved in this notebook).\n",
    "    library : object\n",
    "        Fitted feature library; ``library.transform`` must map a\n",
    "        ``(batch_size, n_state)`` array to ``(batch_size, n_features)``.\n",
    "    timesteps : int\n",
    "        Number of Euler steps.\n",
    "    dt : float\n",
    "        Step size.\n",
    "    batch_size : int\n",
    "        Number of independent trajectories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (batch_size, timesteps, n_state)\n",
    "        State after each step; ``x0`` itself is not included.\n",
    "    \"\"\"\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))  # (n_models, n_features, n_state)\n",
    "    # Per-coefficient ensemble statistics, shape (n_features, n_state); they\n",
    "    # broadcast against the (batch_size, n_features, n_state) noise below.\n",
    "    coefs_mean, coefs_std = coefs.mean(0), coefs.std(0)\n",
    "    curr = np.array([x0 for _ in range(batch_size)])\n",
    "    xs = []\n",
    "    for _ in range(timesteps):\n",
    "        # (batch_size, 1, n_features), sized from the inputs rather than fixed.\n",
    "        curr_lib = np.asarray(library.transform(curr)).reshape(batch_size, 1, -1)\n",
    "        noise = np.random.normal(0, 1, (batch_size, coefs.shape[1], coefs.shape[2]))\n",
    "        curr_coefs = coefs_mean + coefs_std * noise\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)  # (timesteps, batch_size, n_state)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def plot_samples(xs, samples, num_samples=4, dpi=300, figsize=None, filename=None):\n",
    "    \"\"\"Plot a row of ``num_samples`` minimal 3-D line plots.\n",
    "\n",
    "    Panel 0 shows the reference trajectory ``xs`` in red; panel ``i > 0`` shows\n",
    "    ``samples[i]`` in blue. Panes, grid, axis lines and ticks are hidden so only\n",
    "    the curves remain.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xs : numpy.ndarray\n",
    "        Reference trajectory; columns 0, 1, 2 are drawn as x, y, z.\n",
    "    samples : sequence of numpy.ndarray\n",
    "        Sampled trajectories laid out like ``xs``; entries 1 .. num_samples-1\n",
    "        are drawn.\n",
    "    num_samples : int\n",
    "        Total number of panels, including the reference panel.\n",
    "    dpi : int\n",
    "        Figure resolution.\n",
    "    figsize : tuple of float, optional\n",
    "        Figure size in inches; matplotlib's default when ``None``.\n",
    "    filename : str, optional\n",
    "        If given, save the figure there instead of showing it.\n",
    "    \"\"\"\n",
    "    sns.set()\n",
    "\n",
    "    # Styling after\n",
    "    # https://dawes.wordpress.com/2014/06/27/publication-ready-3d-figures-from-matplotlib/\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "    transparent = (1.0, 1.0, 1.0, 0.0)\n",
    "\n",
    "    for i in range(num_samples):\n",
    "        ax = fig.add_subplot(1, num_samples, i + 1, projection='3d')\n",
    "        # NOTE(review): panel i draws samples[i], so samples[0] is never shown;\n",
    "        # confirm whether samples[i - 1] was intended.\n",
    "        traj, color = (xs, 'red') if i == 0 else (samples[i], 'blue')\n",
    "        ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=color)\n",
    "\n",
    "        ax.grid(False)\n",
    "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "            axis.set_pane_color(transparent)\n",
    "            axis.line.set_color(transparent)\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_zticks([])\n",
    "\n",
    "    fig.subplots_adjust(wspace=0)\n",
    "\n",
    "    if filename is not None:\n",
    "        fig.savefig(filename)\n",
    "    else:\n",
    "        plt.show()\n",
    "    # Close this figure explicitly so repeated calls do not accumulate figures.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45bd0554",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory where the fitting cells below save their coefficient arrays.\n",
    "runs = \"../runs/rossler_rmse/\"\n",
    "# exist_ok=True replaces the separate isdir() check (no check-then-create race).\n",
    "os.makedirs(runs, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7dd72e9",
   "metadata": {},
   "source": [
    "# Scale 1.0: ensemble-SINDy fits (seeds 650-659)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "746d970d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "(x)' = -1.001 y + -0.998 z\n",
      "(y)' = 1.002 x + 0.213 y\n",
      "(z)' = -5.561 z + 0.991 x z + 0.007 y z\n",
      "\n",
      "1\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = -0.014 1 + 0.982 x + 0.227 y + 0.044 z + -0.017 y z\n",
      "(z)' = 0.127 1 + -5.278 z + 0.941 x z + -0.019 y z\n",
      "\n",
      "2\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 0.966 x\n",
      "(z)' = 0.144 1 + -5.742 z + 1.026 x z + 0.046 y z\n",
      "\n",
      "3\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 1.011 x + 0.210 y\n",
      "(z)' = 0.126 1 + -5.510 z + 0.997 x z\n",
      "\n",
      "4\n",
      "(x)' = -1.001 y + -0.997 z\n",
      "(y)' = 1.022 x + 0.221 y\n",
      "(z)' = 0.188 1 + -5.685 z + 1.000 x z\n",
      "\n",
      "5\n",
      "(x)' = -1.001 y + -0.997 z\n",
      "(y)' = 1.007 x + 0.198 y\n",
      "(z)' = -5.728 z + 1.030 x z\n",
      "\n",
      "6\n",
      "(x)' = -1.001 y + -0.998 z\n",
      "(y)' = 0.967 x\n",
      "(z)' = 0.189 1 + -5.590 z + 0.991 x z\n",
      "\n",
      "7\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 0.956 x\n",
      "(z)' = 0.153 1 + -5.604 z + 0.984 x z\n",
      "\n",
      "8\n",
      "(x)' = -1.001 y + -0.998 z\n",
      "(y)' = -0.011 1 + 1.008 x + 0.224 y\n",
      "(z)' = 0.131 1 + -5.403 z + 0.961 x z\n",
      "\n",
      "9\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 0.998 x + 0.215 y\n",
      "(z)' = -5.625 z + 0.969 x z\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Shared settings for the ten scale-1.0 training trajectories.\n",
    "dt = 0.01\n",
    "feature_names = ['x', 'y', 'z']\n",
    "data_dir = '../data/rossler_rmse/scale-1.0/'\n",
    "\n",
    "for run_idx in range(10):\n",
    "    print(run_idx)\n",
    "    # Seed both global RNGs per trajectory (650..659).\n",
    "    seed = 650 + run_idx\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    x_train = np.load(f'{data_dir}x_train{run_idx}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)\n",
    "\n",
    "    # Fresh library/optimizer objects for every trajectory, so nothing built\n",
    "    # for one run is reused by the next.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=True)\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.19, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    "    )\n",
    "    model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "    # Ensemble fit (n_models=500); model.coef_list is saved as one stacked array.\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    model.print()\n",
    "    print()\n",
    "    coefs = np.array(model.coef_list)\n",
    "    np.save(f\"{runs}esindy_1_{run_idx}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16530a7e",
   "metadata": {},
   "source": [
    "# Scale 5.0: ensemble-SINDy fits (seeds 750-759)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4eb7f5be",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "(x)' = -1.001 y + -0.997 z\n",
      "(y)' = 0.955 x + 0.053 z\n",
      "(z)' = 0.238 1 + -8.299 z + 1.254 x z + 0.656 y z + -0.072 x y z\n",
      "\n",
      "1\n",
      "(x)' = -1.001 y + -0.993 z\n",
      "(y)' = -0.115 1 + 0.965 x + -0.007 z\n",
      "(z)' = -0.394 1 + -0.064 x + -5.109 z + 0.991 x z\n",
      "\n",
      "2\n",
      "(x)' = -1.001 y + -0.996 z\n",
      "(y)' = 0.998 x + 0.204 y\n",
      "(z)' = -6.161 z + 1.109 x z + -0.009 y z\n",
      "\n",
      "3\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 0.962 x\n",
      "(z)' = -7.204 z + 1.236 x z + 0.167 y z\n",
      "\n",
      "4\n",
      "(x)' = -1.001 y + -0.995 z\n",
      "(y)' = 0.911 x\n",
      "(z)' = -0.453 1 + -0.172 x + -4.464 z + 0.940 x z\n",
      "\n",
      "5\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 0.937 x\n",
      "(z)' = 0.055 1 + -6.665 z + 1.185 x z + 0.429 y z + -0.041 z^2\n",
      "\n",
      "6\n",
      "(x)' = -1.001 y + -0.998 z\n",
      "(y)' = 1.002 x + 0.226 y\n",
      "(z)' = 0.089 1 + -5.675 z + 1.144 x z + -0.050 y z + -0.069 z^2\n",
      "\n",
      "7\n",
      "(x)' = -1.001 y + -0.996 z\n",
      "(y)' = 0.991 x\n",
      "(z)' = 0.165 1 + -5.871 z + 0.923 x z + 0.098 y z + 0.041 z^2\n",
      "\n",
      "8\n",
      "(x)' = -1.002 y + -0.995 z\n",
      "(y)' = 0.969 x\n",
      "(z)' = 0.046 1 + -5.916 z + 1.088 x z + 0.040 y z\n",
      "\n",
      "9\n",
      "(x)' = -1.001 y + -0.998 z\n",
      "(y)' = 0.997 x + 0.223 y + -0.001 z\n",
      "(z)' = 0.404 1 + -7.384 z + 1.100 x z + 0.250 y z\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Shared settings for the ten scale-5.0 training trajectories.\n",
    "dt = 0.01\n",
    "feature_names = ['x', 'y', 'z']\n",
    "data_dir = '../data/rossler_rmse/scale-5.0/'\n",
    "\n",
    "for run_idx in range(10):\n",
    "    print(run_idx)\n",
    "    # Seed both global RNGs per trajectory (750..759).\n",
    "    seed = 750 + run_idx\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    x_train = np.load(f'{data_dir}x_train{run_idx}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)\n",
    "\n",
    "    # Fresh library/optimizer objects for every trajectory, so nothing built\n",
    "    # for one run is reused by the next.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=True)\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.19, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    "    )\n",
    "    model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "    # Ensemble fit (n_models=500); model.coef_list is saved as one stacked array.\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    model.print()\n",
    "    print()\n",
    "    coefs = np.array(model.coef_list)\n",
    "    np.save(f\"{runs}esindy_5_{run_idx}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c41020db",
   "metadata": {},
   "source": [
    "# Scale 10.0: ensemble-SINDy fits (seeds 850-859)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8a775ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "(x)' = -1.001 y + -0.996 z\n",
      "(y)' = 0.970 x\n",
      "(z)' = 0.346 1 + -0.045 x + 0.068 y + -6.525 z + 1.219 x z + -0.677 y z + -0.020 z^2 + 0.097 x y z\n",
      "\n",
      "1\n",
      "(x)' = -1.001 y + -1.001 z\n",
      "(y)' = 0.997 x + 0.224 y\n",
      "(z)' = -0.692 1 + -0.131 y + -4.976 z + 0.930 x z + 0.203 y z\n",
      "\n",
      "2\n",
      "(x)' = -1.000 y + -1.001 z\n",
      "(y)' = 0.951 x\n",
      "(z)' = -0.409 1 + 0.091 x + -0.007 y + -4.152 z + 0.889 x z + -0.041 y z\n",
      "\n",
      "3\n",
      "(x)' = -1.001 y + -0.999 z\n",
      "(y)' = 1.007 x + 0.207 y\n",
      "(z)' = -0.227 x + -5.558 z + 1.094 x z + -0.195 y z\n",
      "\n",
      "4\n",
      "(x)' = -1.001 y + -1.001 z\n",
      "(y)' = 0.992 x + 0.195 y\n",
      "(z)' = -1.262 1 + 1.506 z + -4.761 y z + -0.075 z^2 + 0.454 x y z + 0.610 y^2 z\n",
      "\n",
      "5\n",
      "(x)' = -1.001 y + -0.993 z\n",
      "(y)' = -0.036 1 + 0.982 x\n",
      "(z)' = -0.021 x + -3.149 z + 0.833 x z + -0.283 y z\n",
      "\n",
      "6\n",
      "(x)' = -1.002 y + -0.990 z\n",
      "(y)' = 0.968 x\n",
      "(z)' = 0.409 1 + 0.071 x + -7.805 z + 1.087 x z + 0.273 y z + 0.063 z^2\n",
      "\n",
      "7\n",
      "(x)' = -1.001 y + -1.007 z\n",
      "(y)' = 1.006 x + 0.185 y\n",
      "(z)' = 1.683 1 + 0.378 x + -17.371 z + 1.705 x z + 4.890 y z + 0.162 z^2 + -0.408 x y z + -0.317 y^2 z\n",
      "\n",
      "8\n",
      "(x)' = -1.001 y + -0.994 z\n",
      "(y)' = 1.010 x + 0.227 y\n",
      "(z)' = -0.023 1 + -5.452 z + 0.942 x z + -0.227 y z + 0.015 z^2\n",
      "\n",
      "9\n",
      "(x)' = -1.001 y + -1.000 z\n",
      "(y)' = 0.970 x + 0.201 y\n",
      "(z)' = -0.096 1 + 0.013 x + 0.001 y + -6.374 z + 0.948 x z + -0.057 y z + 0.053 z^2\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Shared settings for the ten scale-10.0 training trajectories.\n",
    "dt = 0.01\n",
    "feature_names = ['x', 'y', 'z']\n",
    "data_dir = '../data/rossler_rmse/scale-10.0/'\n",
    "\n",
    "for run_idx in range(10):\n",
    "    print(run_idx)\n",
    "    # Seed both global RNGs per trajectory (850..859).\n",
    "    seed = 850 + run_idx\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    x_train = np.load(f'{data_dir}x_train{run_idx}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)\n",
    "\n",
    "    # Fresh library/optimizer objects for every trajectory, so nothing built\n",
    "    # for one run is reused by the next.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=True)\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.19, thresholder=\"l0\", max_iter=1000, normalize_columns=False, tol=1e-1\n",
    "    )\n",
    "    model = ps.SINDy(feature_names=feature_names, feature_library=library, optimizer=optimizer)\n",
    "    # Ensemble fit (n_models=500); model.coef_list is saved as one stacked array.\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    model.print()\n",
    "    print()\n",
    "    coefs = np.array(model.coef_list)\n",
    "    np.save(f\"{runs}esindy_10_{run_idx}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c40e18da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8999a18e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
