{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6be5e9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import random\n",
    "\n",
    "import pysindy as ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "440a70a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fourth_order_diff(x, dt):\n",
    "    \"\"\"Differentiate uniformly sampled data along axis 0 with finite differences.\n",
    "\n",
    "    Interior points use the fourth-order central stencil\n",
    "    (-x[i+2] + 8*x[i+1] - 8*x[i-1] + x[i-2]) / (12*dt). The first two and the\n",
    "    last two points use one-sided four-point (third-order) stencils.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : array_like\n",
    "        Samples with time along axis 0, e.g. shape (n_samples, n_states).\n",
    "        Any trailing shape is accepted, including 1-D input.\n",
    "    dt : float\n",
    "        Uniform sample spacing.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray\n",
    "        Float array with the same shape as ``x`` approximating dx/dt.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If ``x`` has fewer than 5 samples along axis 0 (the stencils need 5).\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    if x.ndim == 0 or x.shape[0] < 5:\n",
    "        raise ValueError(\"fourth_order_diff needs at least 5 samples along axis 0\")\n",
    "    # Allocate from x.shape (not a fixed 2-D shape) so 1-D and N-D inputs work.\n",
    "    dx = np.zeros(x.shape)\n",
    "    # Left boundary: one-sided forward stencils.\n",
    "    dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3\n",
    "    dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3\n",
    "    # Interior: fourth-order central differences, vectorized over i = 2 .. n-3.\n",
    "    dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]\n",
    "    # Right boundary: one-sided backward stencils.\n",
    "    dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0\n",
    "    dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0\n",
    "    return dx / dt\n",
    "\n",
    "def sample_trajectory(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Simulate a batch of forward-Euler trajectories from an ensemble of models.\n",
    "\n",
    "    At every step each trajectory independently picks one ensemble member\n",
    "    (uniformly, with replacement, via NumPy's global RNG) and advances\n",
    "    ``x <- x + dt * Theta(x) @ Xi_k.T``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : array_like, shape (n_states,)\n",
    "        Initial condition shared by all trajectories.\n",
    "    coefs : array_like, shape (n_models, n_states, n_features)\n",
    "        Stacked ensemble coefficients, e.g. ``np.array(model.coef_list)``.\n",
    "    library : object\n",
    "        Fitted feature library whose ``transform`` maps a\n",
    "        (batch_size, n_states) array to (batch_size, n_features).\n",
    "    timesteps : int\n",
    "        Number of Euler steps; ``x0`` itself is not included in the output.\n",
    "    dt : float\n",
    "        Step size.\n",
    "    batch_size : int\n",
    "        Number of trajectories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (batch_size, timesteps, n_states)\n",
    "    \"\"\"\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))  # -> (n_models, n_features, n_states)\n",
    "    curr = np.repeat(np.asarray(x0)[None], batch_size, axis=0)\n",
    "    xs = []\n",
    "    for _ in range(timesteps):\n",
    "        # Infer the feature count from the library output so any batch size\n",
    "        # and any library degree work (no hard-coded batch/feature sizes).\n",
    "        curr_lib = np.asarray(library.transform(curr)).reshape(batch_size, 1, -1)\n",
    "        coef_idx = np.random.randint(0, len(coefs), batch_size)\n",
    "        curr_coefs = coefs[coef_idx]  # (batch_size, n_features, n_states)\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    if not xs:\n",
    "        return np.empty((batch_size, 0, curr.shape[-1]))\n",
    "    return np.stack(xs, axis=1)\n",
    "\n",
    "def sample_trajectory2(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Simulate forward-Euler trajectories with Gaussian-perturbed coefficients.\n",
    "\n",
    "    The element-wise mean and standard deviation of the ensemble coefficients\n",
    "    are computed once. At every step each trajectory draws every coefficient\n",
    "    independently from ``N(mean, std**2)`` (NumPy's global RNG) and advances\n",
    "    ``x <- x + dt * Theta(x) @ Xi``.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x0 : array_like, shape (n_states,)\n",
    "        Initial condition shared by all trajectories.\n",
    "    coefs : array_like, shape (n_models, n_states, n_features)\n",
    "        Stacked ensemble coefficients, e.g. ``np.array(model.coef_list)``.\n",
    "    library : object\n",
    "        Fitted feature library whose ``transform`` maps a\n",
    "        (batch_size, n_states) array to (batch_size, n_features).\n",
    "    timesteps : int\n",
    "        Number of Euler steps; ``x0`` itself is not included in the output.\n",
    "    dt : float\n",
    "        Step size.\n",
    "    batch_size : int\n",
    "        Number of trajectories.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    numpy.ndarray, shape (batch_size, timesteps, n_states)\n",
    "    \"\"\"\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))  # -> (n_models, n_features, n_states)\n",
    "    # Element-wise ensemble statistics, shape (n_features, n_states); they\n",
    "    # broadcast against the per-trajectory noise below.\n",
    "    coefs_mean, coefs_std = coefs.mean(0), coefs.std(0)\n",
    "    curr = np.repeat(np.asarray(x0)[None], batch_size, axis=0)\n",
    "    xs = []\n",
    "    for _ in range(timesteps):\n",
    "        # Infer the feature count from the library output so any batch size\n",
    "        # and any library degree work (no hard-coded batch/feature sizes).\n",
    "        curr_lib = np.asarray(library.transform(curr)).reshape(batch_size, 1, -1)\n",
    "        noise = np.random.normal(0, 1, (batch_size, coefs.shape[1], coefs.shape[2]))\n",
    "        curr_coefs = coefs_mean + coefs_std * noise  # (batch_size, n_features, n_states)\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    if not xs:\n",
    "        return np.empty((batch_size, 0, curr.shape[-1]))\n",
    "    return np.stack(xs, axis=1)\n",
    "\n",
    "def plot_samples(xs, samples, num_samples=4, dpi=300, figsize=None, filename=None):\n",
    "    \"\"\"Plot the reference trajectory and sampled trajectories as a row of 3-D panels.\n",
    "\n",
    "    Panel 0 shows ``xs`` in red; panel ``i`` (``i >= 1``) shows ``samples[i]`` in\n",
    "    blue, so ``samples[0]`` is never drawn (NOTE(review): confirm intended).\n",
    "    Only the first three columns of each trajectory are used. Grid, panes,\n",
    "    axis lines and ticks are hidden.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    xs : numpy.ndarray\n",
    "        Reference trajectory, indexed as ``xs[:, k]`` for ``k`` in 0..2.\n",
    "    samples : sequence of numpy.ndarray\n",
    "        Sampled trajectories, each indexed like ``xs``.\n",
    "    num_samples : int\n",
    "        Number of panels, including the reference panel.\n",
    "    dpi : int\n",
    "        Figure resolution.\n",
    "    figsize : tuple or None\n",
    "        Figure size in inches; ``None`` keeps Matplotlib's default.\n",
    "    filename : str or None\n",
    "        If given, the figure is saved there and closed instead of shown.\n",
    "\n",
    "    Notes\n",
    "    -----\n",
    "    Calls ``sns.set()``, which changes plotting style globally.\n",
    "    \"\"\"\n",
    "    sns.set()\n",
    "\n",
    "    # https://dawes.wordpress.com/2014/06/27/publication-ready-3d-figures-from-matplotlib/\n",
    "    # figsize=None means \"use the rcParams default\", same as omitting it.\n",
    "    fig = plt.figure(figsize=figsize, dpi=dpi)\n",
    "    # NOTE(review): runs before any axes exist, so presumably has no effect.\n",
    "    fig.tight_layout()\n",
    "    invisible = (1.0, 1.0, 1.0, 0.0)  # fully transparent white\n",
    "    for panel in range(num_samples):\n",
    "        ax = fig.add_subplot(1, num_samples, panel + 1, projection='3d')\n",
    "        traj, color = (xs, 'red') if panel == 0 else (samples[panel], 'blue')\n",
    "        ax.plot(traj[:, 0], traj[:, 1], traj[:, 2], color=color)\n",
    "\n",
    "        ax.grid(False)\n",
    "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "            axis.set_pane_color(invisible)\n",
    "            axis.line.set_color(invisible)\n",
    "        for set_ticks in (ax.set_xticks, ax.set_yticks, ax.set_zticks):\n",
    "            set_ticks([])\n",
    "\n",
    "    plt.subplots_adjust(wspace=0)\n",
    "\n",
    "    if filename is not None:\n",
    "        plt.savefig(filename)\n",
    "        plt.close()\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "45bd0554",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Output directory for the saved ensemble coefficient arrays (created if missing).\n",
    "runs = \"../runs/lorenz_rmse/\"\n",
    "os.makedirs(runs, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e7dd72e9",
   "metadata": {},
   "source": [
    "# 1: E-SINDy ensembles on `../data/lorenz_rmse/scale-1.0`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "746d970d",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "5\n",
      "6\n",
      "7\n",
      "8\n",
      "9\n"
     ]
    }
   ],
   "source": [
    "# Fit one 500-model E-SINDy ensemble per scale-1.0 training trajectory and save\n",
    "# the stacked coefficients (n_models, n_states, n_features) under `runs`.\n",
    "for i in range(10):\n",
    "    print(i)\n",
    "    # Re-seed NumPy's global RNG and Python's `random` for every run.\n",
    "    seed = 700 + i\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    dt = 0.01\n",
    "    x_train = np.load(f'../data/lorenz_rmse/scale-1.0/x_train{i}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)  # derivatives are passed explicitly to fit\n",
    "\n",
    "    feature_names = ['x', 'y', 'z']\n",
    "    # Cubic polynomial library without a constant term; library and optimizer\n",
    "    # are rebuilt on every iteration.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=False)\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.5,\n",
    "        thresholder=\"l0\",\n",
    "        max_iter=1000,\n",
    "        normalize_columns=False,\n",
    "        tol=1e-1,\n",
    "    )\n",
    "    model = ps.SINDy(\n",
    "        feature_names=feature_names,\n",
    "        feature_library=library,\n",
    "        optimizer=optimizer,\n",
    "    )\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    coefs = np.array(model.coef_list)\n",
    "    np.save(f\"{runs}esindy_1_{i}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "16530a7e",
   "metadata": {},
   "source": [
    "# 5: E-SINDy ensembles on `../data/lorenz_rmse/scale-5.0`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4eb7f5be",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[[-9.42556909e+00  9.72660211e+00 -9.62927251e-03  0.00000000e+00\n",
      "  -1.02142481e-05  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.49714400e+01  4.67978820e-01 -1.44116763e-04  0.00000000e+00\n",
      "   0.00000000e+00 -9.45492120e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 9.47261676e-01 -4.51678253e-01 -2.47126112e+00  2.75176562e-02\n",
      "   9.70792349e-01 -1.27182559e-02 -1.96591511e-02  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "1\n",
      "[[-9.63574326e+00  9.86427030e+00  1.43509883e-06  0.00000000e+00\n",
      "   3.99563451e-06  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.45526636e+01  5.69044487e-01 -9.71047211e-05  0.00000000e+00\n",
      "  -6.45404004e-05 -9.33066315e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.27742142e-01 -1.44004711e-01 -2.47210801e+00  5.25521395e-03\n",
      "   9.75983263e-01 -2.91019375e-03 -5.08568520e-03  8.93554940e-05\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "2\n",
      "[[-9.63415578e+00  9.92705560e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.51771643e+01  4.54117632e-01 -4.34339319e-06  0.00000000e+00\n",
      "   1.00303821e-04 -9.53158426e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 4.12319293e-01 -2.19245528e-01 -2.56239247e+00  1.12875952e-01\n",
      "   9.22108832e-01  1.97637981e-03 -1.46223840e-02  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "3\n",
      "[[-9.45752340e+00  9.74934809e+00  1.17428941e-03  0.00000000e+00\n",
      "  -2.70656561e-05  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.39241541e+01  7.95243862e-01  7.02299822e-05  0.00000000e+00\n",
      "   0.00000000e+00 -9.15936544e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.47177122e-01  4.48584150e-02 -2.40305503e+00  2.94655126e-02\n",
      "   9.58325785e-01 -4.90673383e-03 -1.31538958e-02 -2.67707090e-04\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "4\n",
      "[[-9.63718381e+00  9.95033159e+00  1.04188057e-04  1.90731855e-04\n",
      "  -8.95368335e-05 -1.64941974e-03  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.48528506e+01  4.93339729e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -9.44791012e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.77257045e-01 -3.23948229e-01 -2.41873878e+00  8.84198476e-02\n",
      "   9.17792535e-01  1.66010245e-03 -4.37230901e-03  8.84073021e-05\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "5\n",
      "[[-9.69636173e+00  9.93870169e+00 -2.24555059e-03  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.49300217e+01  5.24881846e-01 -4.90740345e-04  0.00000000e+00\n",
      "   0.00000000e+00 -9.45403658e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-3.01730780e-01 -7.70639309e-03 -2.61359485e+00  7.21109798e-02\n",
      "   9.59705809e-01 -2.02251725e-03 -8.54366065e-04 -3.13868220e-03\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "6\n",
      "[[-9.24689912e+00  9.52177855e+00  5.45193618e-04  8.43788674e-05\n",
      "  -2.19104449e-04  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.48682178e+01  4.04830854e-01  2.64493978e-03  4.52864317e-05\n",
      "  -4.19866604e-05 -9.39085532e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.20149607e-02 -9.35831452e-02 -2.48802574e+00  1.51222503e-01\n",
      "   8.79207606e-01  1.25208211e-02  0.00000000e+00 -1.18496754e-04\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "7\n",
      "[[-9.68858118e+00  9.93142817e+00  0.00000000e+00  0.00000000e+00\n",
      "  -1.57300402e-05  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.41723957e+01  7.56853929e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -9.25125263e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.22821683e-01 -1.76400461e-01 -2.46304475e+00  1.05995774e-02\n",
      "   1.05380951e+00 -2.24833923e-03 -5.13129592e-02  2.78269331e-04\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "8\n",
      "[[-9.74049754e+00  9.98834021e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.45236553e+01  5.37332217e-01  2.74194289e-05  0.00000000e+00\n",
      "   0.00000000e+00 -9.35095184e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 9.14453895e-01 -5.37529104e-01 -2.52826741e+00  9.43313226e-02\n",
      "   9.39692517e-01  2.65085171e-03 -1.50841501e-02  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "9\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[-9.51510235e+00  9.77699408e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  1.31710947e-04  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.52458505e+01  4.83410192e-01  7.87877360e-04  0.00000000e+00\n",
      "   0.00000000e+00 -9.53045631e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 1.01308484e-02 -6.51347813e-02 -2.50354047e+00  5.30234979e-02\n",
      "   9.79182151e-01 -3.78346436e-03 -3.27287958e-02  1.95193791e-04\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fit one 500-model E-SINDy ensemble per scale-5.0 training trajectory, print the\n",
    "# ensemble-mean coefficients, and save the stacked coefficients under `runs`.\n",
    "for i in range(10):\n",
    "    print(i)\n",
    "    # Re-seed NumPy's global RNG and Python's `random` for every run.\n",
    "    seed = 700 + i\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    dt = 0.01\n",
    "    x_train = np.load(f'../data/lorenz_rmse/scale-5.0/x_train{i}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)  # derivatives are passed explicitly to fit\n",
    "\n",
    "    feature_names = ['x', 'y', 'z']\n",
    "    # Cubic polynomial library without a constant term; library and optimizer\n",
    "    # are rebuilt on every iteration.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=False)\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.5,\n",
    "        thresholder=\"l0\",\n",
    "        max_iter=1000,\n",
    "        normalize_columns=False,\n",
    "        tol=1e-1,\n",
    "    )\n",
    "    model = ps.SINDy(\n",
    "        feature_names=feature_names,\n",
    "        feature_library=library,\n",
    "        optimizer=optimizer,\n",
    "    )\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    coefs = np.array(model.coef_list)\n",
    "    print(coefs.mean(0))  # ensemble-mean coefficient matrix\n",
    "    print('')\n",
    "    np.save(f\"{runs}esindy_5_{i}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c41020db",
   "metadata": {},
   "source": [
    "# 10: E-SINDy ensembles on `../data/lorenz_rmse/scale-10.0`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b8a775ed",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "[[-8.31580903  9.10945166  0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]\n",
      " [22.88980517  0.86458642  0.          0.          0.         -0.89905904\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]\n",
      " [-1.18087005  0.7464302  -1.93146807 -0.06048827  0.99595913  0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]]\n",
      "\n",
      "1\n",
      "[[-8.78244543e+00  9.50962610e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 1.24877440e+00  1.36207588e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -2.93729355e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-1.39024206e-01  5.99145901e-01 -1.62076734e+00 -7.89969830e-02\n",
      "   9.51801380e-01  3.14798482e-03  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "2\n",
      "[[-9.24428121e+00  9.87298474e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-1.25109823e+01  7.07411215e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -4.71695867e-02  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.30748660e-01 -3.15728015e-01 -2.15902882e+00 -7.40126370e-02\n",
      "   1.01397128e+00 -3.87595237e-03  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "3\n",
      "[[-8.53724116e+00  9.33779059e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 2.42875332e+01  7.16165793e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -9.29428838e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-1.81469168e-02 -2.40539420e-01 -1.60042532e+00 -7.64402569e-02\n",
      "   9.34619136e-01 -5.96044743e-04  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "4\n",
      "[[-8.57482639e+00  9.37755688e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.14521387e+00  1.62786063e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -4.01535544e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 4.65471585e-01 -1.11895741e-01 -2.02354651e+00 -1.00864237e-02\n",
      "   9.42395855e-01 -1.56401640e-04  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "5\n",
      "[[-8.96923578  9.67131067  0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]\n",
      " [20.46632784  1.11569     0.          0.          0.         -0.83356315\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]\n",
      " [ 0.57504064 -0.43226402 -1.76709733 -0.08894874  0.97019143  0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.          0.          0.          0.          0.          0.\n",
      "   0.        ]]\n",
      "\n",
      "6\n",
      "[[-9.00724989e+00  9.70825095e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 5.29773162e+00  7.08874798e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -3.90284451e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-5.39228130e-01 -2.34916229e-02 -1.90278784e+00 -1.03745619e-01\n",
      "   9.96764370e-01  2.73774552e-04  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "7\n",
      "[[-8.63823644e+00  9.49922363e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-7.15676341e+00  3.07596228e+00 -1.35794994e-04  0.00000000e+00\n",
      "   0.00000000e+00 -1.00979874e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-7.87037749e-02 -2.89402009e-01 -2.12893269e+00 -2.61563326e-02\n",
      "   9.67242851e-01  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "8\n",
      "[[-8.74029707e+00  9.48130895e+00 -1.25797663e-04  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [-7.03811970e+00  6.67517736e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -1.90819107e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 6.56010354e-01 -1.66699369e-01 -2.07344796e+00  3.75409293e-02\n",
      "   8.75598882e-01  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n",
      "9\n",
      "[[-8.46090741e+00  9.25700303e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 1.40249482e+01  4.12451054e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00 -6.29131036e-01  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]\n",
      " [ 9.49873001e-01 -2.38520137e-01 -1.84857114e+00 -6.32223380e-02\n",
      "   1.01457206e+00  3.08111599e-04  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00\n",
      "   0.00000000e+00  0.00000000e+00  0.00000000e+00]]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Fit one 500-model E-SINDy ensemble per scale-10.0 training trajectory, print the\n",
    "# ensemble-mean coefficients, and save the stacked coefficients under `runs`.\n",
    "for i in range(10):\n",
    "    print(i)\n",
    "    # Re-seed NumPy's global RNG and Python's `random` for every run.\n",
    "    seed = 700 + i\n",
    "    np.random.seed(seed)\n",
    "    random.seed(seed)\n",
    "\n",
    "    dt = 0.01\n",
    "    x_train = np.load(f'../data/lorenz_rmse/scale-10.0/x_train{i}.npy')\n",
    "    x_dot = fourth_order_diff(x_train, dt)  # derivatives are passed explicitly to fit\n",
    "\n",
    "    feature_names = ['x', 'y', 'z']\n",
    "    # Cubic polynomial library without a constant term; library and optimizer\n",
    "    # are rebuilt on every iteration.\n",
    "    library = ps.PolynomialLibrary(degree=3, include_bias=False)\n",
    "    # Larger sparsity threshold (0.9) than the 0.5 used for scales 1.0 and 5.0.\n",
    "    optimizer = ps.SR3(\n",
    "        threshold=0.9,\n",
    "        thresholder=\"l0\",\n",
    "        max_iter=1000,\n",
    "        normalize_columns=False,\n",
    "        tol=1e-1,\n",
    "    )\n",
    "    model = ps.SINDy(\n",
    "        feature_names=feature_names,\n",
    "        feature_library=library,\n",
    "        optimizer=optimizer,\n",
    "    )\n",
    "    model.fit(x_train, x_dot=x_dot, t=dt, ensemble=True, quiet=True, n_models=500)\n",
    "    coefs = np.array(model.coef_list)\n",
    "    print(coefs.mean(0))  # ensemble-mean coefficient matrix\n",
    "    print('')\n",
    "    np.save(f\"{runs}esindy_10_{i}.npy\", coefs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "498c75f9",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6bcdfa3",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c40e18da",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8999a18e",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
