{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6be5e9ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "import numpy as np\n",
    "import seaborn as sns\n",
    "import random\n",
    "\n",
    "import pysindy as ps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "440a70a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def fourth_order_diff(x, dt):\n",
    "    \"\"\"Differentiate uniformly sampled data along axis 0 by finite differences.\n",
    "\n",
    "    Interior samples use the five-point central stencil; the first two and the\n",
    "    last two samples use four-point one-sided stencils.\n",
    "\n",
    "    Args:\n",
    "        x: array-like with the sampled axis first. Any number of trailing\n",
    "            dimensions is accepted, including none (a 1-D series).\n",
    "        dt: spacing between consecutive samples.\n",
    "\n",
    "    Returns:\n",
    "        float64 array with the same shape as x holding the estimated derivative.\n",
    "\n",
    "    Raises:\n",
    "        IndexError: if x has fewer than 5 samples along axis 0.\n",
    "    \"\"\"\n",
    "    x = np.asarray(x)\n",
    "    # Sized from the full shape (not shape[0], shape[1]) so 1-D and N-D input work.\n",
    "    dx = np.zeros(x.shape)\n",
    "    # One-sided forward stencils for the two leading samples.\n",
    "    dx[0] = (-11.0 / 6) * x[0] + 3 * x[1] - 1.5 * x[2] + x[3] / 3\n",
    "    dx[1] = (-11.0 / 6) * x[1] + 3 * x[2] - 1.5 * x[3] + x[4] / 3\n",
    "    # Central stencil everywhere two neighbours exist on both sides.\n",
    "    dx[2:-2] = (-1.0 / 12) * x[4:] + (2.0 / 3) * x[3:-1] - (2.0 / 3) * x[1:-3] + (1.0 / 12) * x[:-4]\n",
    "    # One-sided backward stencils for the two trailing samples.\n",
    "    dx[-2] = (11.0 / 6) * x[-2] - 3.0 * x[-3] + 1.5 * x[-4] - x[-5] / 3.0\n",
    "    dx[-1] = (11.0 / 6) * x[-1] - 3.0 * x[-2] + 1.5 * x[-3] - x[-4] / 3.0\n",
    "    return dx / dt\n",
    "\n",
    "def sample_trajectory(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Roll out batch_size Euler trajectories, redrawing coefficients every step.\n",
    "\n",
    "    At each step every trajectory independently picks one coefficient matrix\n",
    "    uniformly at random (np.random.randint) from coefs.\n",
    "\n",
    "    Args:\n",
    "        x0: initial state, shape (x_dim,).\n",
    "        coefs: coefficient samples shaped (num_samples, x_dim, library_dim);\n",
    "            the last two axes are swapped internally for the matmul.\n",
    "        library: object whose transform(states) returns one row of library\n",
    "            terms per state.\n",
    "        timesteps: number of Euler steps.\n",
    "        dt: step size.\n",
    "        batch_size: number of trajectories simulated in parallel.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (batch_size, timesteps, x_dim) with the state after each\n",
    "        step; x0 itself is not included.\n",
    "    \"\"\"\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))\n",
    "    xs = []\n",
    "    curr = np.array([x0] * batch_size)\n",
    "    for _ in range(timesteps):\n",
    "        # Shape follows batch_size and the library width instead of the former\n",
    "        # hard-coded (10, 1, 9), which broke for any other batch or library.\n",
    "        curr_lib = library.transform(curr).reshape(batch_size, 1, -1)\n",
    "        coef_idx = np.random.randint(0, len(coefs), batch_size)\n",
    "        curr_coefs = coefs[coef_idx]\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)\n",
    "    # (timesteps, batch, x_dim) -> (batch, timesteps, x_dim)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def sample_trajectory2(x0, coefs, library, timesteps, dt, batch_size):\n",
    "    \"\"\"Roll out batch_size Euler trajectories with Gaussian-resampled coefficients.\n",
    "\n",
    "    The element-wise mean and standard deviation of coefs are computed once;\n",
    "    at every step each trajectory uses mean + std * N(0, 1) noise as its\n",
    "    coefficient matrix.\n",
    "\n",
    "    Args:\n",
    "        x0: initial state, shape (x_dim,).\n",
    "        coefs: coefficient samples shaped (num_samples, x_dim, library_dim);\n",
    "            the last two axes are swapped internally for the matmul.\n",
    "        library: object whose transform(states) returns one row of library\n",
    "            terms per state.\n",
    "        timesteps: number of Euler steps.\n",
    "        dt: step size.\n",
    "        batch_size: number of trajectories simulated in parallel.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (batch_size, timesteps, x_dim) with the state after each\n",
    "        step; x0 itself is not included.\n",
    "    \"\"\"\n",
    "    coefs = np.transpose(coefs, (0, 2, 1))\n",
    "    # (library_dim, x_dim) statistics; broadcasting against the per-trajectory\n",
    "    # noise replaces the explicit batch_size-fold copies.\n",
    "    coefs_mean, coefs_std = coefs.mean(0), coefs.std(0)\n",
    "    xs = []\n",
    "    curr = np.array([x0] * batch_size)\n",
    "    for _ in range(timesteps):\n",
    "        # Shape follows batch_size and the library width instead of the former\n",
    "        # hard-coded (10, 1, 9), which broke for any other batch or library.\n",
    "        curr_lib = library.transform(curr).reshape(batch_size, 1, -1)\n",
    "        noise = np.random.normal(0, 1, (batch_size, coefs.shape[1], coefs.shape[2]))\n",
    "        curr_coefs = coefs_mean + coefs_std * noise\n",
    "        dx = np.matmul(curr_lib, curr_coefs).squeeze(1)\n",
    "        curr = curr + dx * dt\n",
    "        xs.append(curr)\n",
    "    xs = np.array(xs)\n",
    "    # (timesteps, batch, x_dim) -> (batch, timesteps, x_dim)\n",
    "    return np.transpose(xs, (1, 0, 2))\n",
    "\n",
    "def plot_samples(xs, samples, num_samples=4, dpi=300, figsize=None, filename=None):\n",
    "    \"\"\"Draw a single row of bare 3-D trajectory panels.\n",
    "\n",
    "    Panel 0 shows xs in red; panel i (i >= 1) shows samples[i] in blue, so\n",
    "    samples[0] is never drawn.\n",
    "\n",
    "    Args:\n",
    "        xs: reference trajectory, read as xs[:, 0], xs[:, 1], xs[:, 2].\n",
    "        samples: indexable collection of trajectories read the same way.\n",
    "        num_samples: total number of panels, reference panel included.\n",
    "        dpi: figure resolution.\n",
    "        figsize: optional figure size; matplotlib's default when None.\n",
    "        filename: when given, the figure is saved there and closed before\n",
    "            plt.show() is reached.\n",
    "    \"\"\"\n",
    "    sns.set()\n",
    "\n",
    "    # Panel styling follows\n",
    "    # https://dawes.wordpress.com/2014/06/27/publication-ready-3d-figures-from-matplotlib/\n",
    "    figure_options = {\"dpi\": dpi}\n",
    "    if figsize is not None:\n",
    "        figure_options[\"figsize\"] = figsize\n",
    "    fig = plt.figure(**figure_options)\n",
    "    fig.tight_layout()\n",
    "\n",
    "    # Fully transparent white: hides the panes and the axis lines.\n",
    "    transparent = (1.0, 1.0, 1.0, 0.0)\n",
    "    for panel in range(num_samples):\n",
    "        ax = fig.add_subplot(1, num_samples, panel + 1, projection='3d')\n",
    "        if panel == 0:\n",
    "            trajectory, colour = xs, 'red'\n",
    "        else:\n",
    "            trajectory, colour = samples[panel], 'blue'\n",
    "        ax.plot(trajectory[:, 0], trajectory[:, 1], trajectory[:, 2], color=colour)\n",
    "\n",
    "        ax.grid(False)\n",
    "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "            axis.set_pane_color(transparent)\n",
    "        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):\n",
    "            axis.line.set_color(transparent)\n",
    "\n",
    "        ax.set_xticks([])\n",
    "        ax.set_yticks([])\n",
    "        ax.set_zticks([])\n",
    "\n",
    "    plt.subplots_adjust(wspace=0)\n",
    "\n",
    "    if filename is not None:\n",
    "        plt.savefig(filename)\n",
    "        plt.close()\n",
    "    plt.show()\n",
    "    plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "50ea76ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import random\n",
    "import numpy as np\n",
    "\n",
    "from HyperSINDy import Net\n",
    "from baseline import Trainer\n",
    "from library_utils import Library\n",
    "from Datasets import SyntheticDataset\n",
    "from other import init_weights, set_random_seed\n",
    "\n",
    "\"\"\"\n",
    "Train HyperSINDy on lorenz\n",
    "\"\"\"\n",
    "\n",
    "def pipeline(library, trainset, epochs, lr,\n",
    "             lmda_init, lmda_max, lmda_max_epoch, lmda_spike, lmda_spike_epoch,\n",
    "             beta_init, beta_max, beta_max_epoch, beta_spike, beta_spike_epoch,\n",
    "             adam_reg, gamma_factor, batch_size,\n",
    "             thresh_interval, eval_interval, hard_thresh,\n",
    "             run_name, runs, noise_dim, hidden_dim, stat_size, device,\n",
    "             num_hidden, batch_norm):\n",
    "    \"\"\"Build a Net on the requested CUDA device and train it with Trainer.\n",
    "\n",
    "    run_name is printed first. Trainer receives the two paths\n",
    "    runs + run_name and runs + \"cp_\" + run_name + \".pt\" as its third and\n",
    "    fourth positional arguments; the beta / lmda schedules and the remaining\n",
    "    optimisation settings are forwarded to it as keyword arguments.\n",
    "    Nothing is returned.\n",
    "    \"\"\"\n",
    "    print(run_name)\n",
    "\n",
    "    # Select the GPU, then keep using the index torch reports back.\n",
    "    torch.cuda.set_device(device=device)\n",
    "    device = torch.cuda.current_device()\n",
    "\n",
    "    net = Net(library, noise_dim=noise_dim, hidden_dim=hidden_dim,\n",
    "              statistic_batch_size=stat_size, num_hidden=num_hidden,\n",
    "              batch_norm=batch_norm)\n",
    "    net = net.to(device)\n",
    "    net.apply(init_weights)\n",
    "\n",
    "    beta_schedule = dict(beta_init=beta_init, beta_max=beta_max,\n",
    "                         beta_max_epoch=beta_max_epoch,\n",
    "                         beta_spike=beta_spike, beta_spike_epoch=beta_spike_epoch)\n",
    "    lmda_schedule = dict(lmda_init=lmda_init, lmda_max=lmda_max,\n",
    "                         lmda_max_epoch=lmda_max_epoch,\n",
    "                         lmda_spike=lmda_spike, lmda_spike_epoch=lmda_spike_epoch)\n",
    "    optimisation = dict(learning_rate=lr, adam_reg=adam_reg, gamma_factor=gamma_factor,\n",
    "                        epochs=epochs, batch_size=batch_size, device=device,\n",
    "                        hard_threshold=hard_thresh, threshold_interval=thresh_interval,\n",
    "                        eval_interval=eval_interval)\n",
    "\n",
    "    trainer = Trainer(net, library, runs + run_name, runs + \"cp_\" + run_name + \".pt\",\n",
    "                      **beta_schedule, **lmda_schedule, **optimisation)\n",
    "    trainer.train(trainset)\n",
    "\n",
    "\n",
    "def load_data(library, data_folder, dataset, t, dt, model, end):\n",
    "    \"\"\"Load the training array for a dataset and wrap it in a SyntheticDataset.\n",
    "\n",
    "    Reads data_folder + dataset + \"/x_train<end>.npy\". The t argument only\n",
    "    acts as a switch: when it is not None it is replaced by the contents of\n",
    "    data_folder + dataset + \"/x_ts.npy\"; otherwise None is passed through.\n",
    "    \"\"\"\n",
    "    dataset_dir = data_folder + dataset\n",
    "    x = np.load(dataset_dir + \"/x_train\" + str(end) + \".npy\")\n",
    "    t = None if t is None else np.load(dataset_dir + \"/x_ts.npy\")\n",
    "    return SyntheticDataset(x=x, t=t, library=library, dataset=dataset, dt=dt, model=model)\n",
    "\n",
    "# Globals\n",
    "# Notebook-wide configuration read by the cells below.\n",
    "data_folder = \"../data/\"\n",
    "model = \"HyperSINDy\"\n",
    "dt = 0.01\n",
    "hidden_dim = 64\n",
    "stat_size = 250\n",
    "# Read as a global inside load_model (it is not one of its parameters).\n",
    "num_hidden = 5\n",
    "x_dim = 3\n",
    "adam_reg = 1e-2\n",
    "gamma_factor = 0.999\n",
    "poly_order = 3\n",
    "include_constant = True\n",
    "# Reassigned to \"../runs/rossler_rmse/\" in the next cell.\n",
    "runs = \"../runs/lorenz/\"\n",
    "library = Library(n=x_dim, poly_order=poly_order, include_constant=include_constant)\n",
    "t = None\n",
    "# Replaced by 28345790 in a later cell, before coefficients are sampled.\n",
    "seed = 0\n",
    "\n",
    "def load_model(device, z_dim, poly_order, include_constant,\n",
    "               noise_dim, hidden_dim, stat_size, batch_size,\n",
    "               cp_path):\n",
    "    \"\"\"Rebuild a Net, restore its weights from a checkpoint and put it in eval mode.\n",
    "\n",
    "    Args:\n",
    "        device: CUDA device index handed to torch.cuda.set_device.\n",
    "        z_dim, poly_order, include_constant: forwarded to Library.\n",
    "        noise_dim, hidden_dim, stat_size: forwarded to Net (stat_size as\n",
    "            statistic_batch_size).\n",
    "        batch_size: used for the get_masked_coefficients call made before\n",
    "            the state dict is loaded.\n",
    "        cp_path: checkpoint file; its 'model' entry is the state dict.\n",
    "\n",
    "    Returns:\n",
    "        (net, library, device), where device is torch.cuda.current_device().\n",
    "    \"\"\"\n",
    "    torch.cuda.set_device(device=device)\n",
    "    device = torch.cuda.current_device()\n",
    "\n",
    "    library = Library(n=z_dim, poly_order=poly_order, include_constant=include_constant)\n",
    "\n",
    "    # NOTE: num_hidden is not a parameter here; it comes from the notebook globals.\n",
    "    net = Net(library, noise_dim=noise_dim, hidden_dim=hidden_dim,\n",
    "              statistic_batch_size=stat_size, num_hidden=num_hidden)\n",
    "    net = net.to(device)\n",
    "    # Runs before load_state_dict; presumably it sets up state the checkpoint\n",
    "    # expects to find -- TODO confirm against Net.\n",
    "    net.get_masked_coefficients(batch_size=batch_size, device=device)\n",
    "\n",
    "    checkpoint = torch.load(cp_path, map_location=f\"cuda:{device}\")\n",
    "    net.load_state_dict(checkpoint['model'])\n",
    "    net.to(device)\n",
    "    net = net.eval()\n",
    "\n",
    "    return net, library, device\n",
    "\n",
    "# CUDA device index handed to torch.cuda.set_device inside load_model.\n",
    "device = 1\n",
    "noise_dim = 6\n",
    "# Passed to load_model as its batch_size argument.\n",
    "batch_size = 500"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "45bd0554",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Run directory the following cells read esindy_*.npy and cp_*.pt files from.\n",
    "runs = \"../runs/rossler_rmse/\"\n",
    "# exist_ok replaces the separate isdir check and is safe to re-run.\n",
    "os.makedirs(runs, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "bb6d87a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# E-SINDy coefficient arrays, ten runs per file tag (\"1\", \"5\", \"10\").\n",
    "# Each file is runs + \"esindy_<tag>_<run>.npy\"; its last two axes are swapped on load.\n",
    "e1, e2, e3 = (\n",
    "    [np.load(runs + \"esindy_\" + tag + \"_\" + str(run) + \".npy\").transpose(0, 2, 1)\n",
    "     for run in range(10)]\n",
    "    for tag in (\"1\", \"5\", \"10\")\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "61ae3171",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trained HyperSINDy networks, ten runs per checkpoint group (cp_1_, cp_2_, cp_3_).\n",
    "# load_model returns (net, library, device); only the network is kept.\n",
    "h1, h2, h3 = (\n",
    "    [load_model(device, x_dim, poly_order, include_constant, noise_dim,\n",
    "                hidden_dim, stat_size, batch_size,\n",
    "                runs + \"cp_\" + str(group) + \"_\" + str(run) + \".pt\")[0]\n",
    "     for run in range(10)]\n",
    "    for group in (1, 2, 3)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6aaa57c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the seed before the coefficient sampling in the next cell;\n",
    "# set_random_seed is imported from the project's `other` module.\n",
    "seed = 28345790\n",
    "set_random_seed(seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a508ee1d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Swap every loaded network for a NumPy array of 500 sampled coefficient matrices.\n",
    "# NOTE: h1..h3 are rebound from networks to arrays, so this cell cannot be re-run\n",
    "# without first re-running the cell that loads the checkpoints.\n",
    "h1, h2, h3 = (\n",
    "    [net.get_masked_coefficients(batch_size=500, device=device).detach().cpu().numpy()\n",
    "     for net in nets]\n",
    "    for nets in (h1, h2, h3)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "8b18acc2",
   "metadata": {},
   "outputs": [],
   "source": [
    "def derivative(x, y, z, a, b, c, t, scale):\n",
    "    \"\"\"Evaluate the right-hand side with freshly sampled parameters.\n",
    "\n",
    "    a and b are redrawn from a normal centred on the given value with standard\n",
    "    deviation min(1.0, scale); c is redrawn with standard deviation scale.\n",
    "    t is accepted but not used.\n",
    "\n",
    "    Returns:\n",
    "        (np.array([x_dot, y_dot, z_dot]), np.array([a, b, c])) where a, b, c\n",
    "        are the sampled values.\n",
    "    \"\"\"\n",
    "    capped_scale = min(1.0, scale)\n",
    "\n",
    "    # Draw order (a, b, c) is kept so seeded runs reproduce the same values.\n",
    "    a_drawn = np.random.normal(a, capped_scale)\n",
    "    b_drawn = np.random.normal(b, capped_scale)\n",
    "    c_drawn = np.random.normal(c, scale)\n",
    "\n",
    "    state_dot = np.array([\n",
    "        -y - z,\n",
    "        x + a_drawn * y,\n",
    "        b_drawn + z * (x - c_drawn),\n",
    "    ])\n",
    "    drawn = np.array([a_drawn, b_drawn, c_drawn])\n",
    "    return state_dot, drawn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "a6bcdfa3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth mean coefficients, one column per state equation.\n",
    "# Rows presumably follow the library order 1, x, y, z, x^2, xy, xz, ... (matches\n",
    "# the terms in derivative() with a = 0.2, b = 0.2, c = 5.7) -- confirm against Library.\n",
    "gtm = np.zeros((20, 3))\n",
    "gtm[[2, 3], 0] = -1\n",
    "gtm[[1, 2], 1] = [1, 0.2]\n",
    "gtm[[0, 3, 6], 2] = [0.2, -5.7, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "08403218",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[ 0. ,  0. ,  0.2],\n",
       "       [ 0. ,  1. ,  0. ],\n",
       "       [-1. ,  0.2,  0. ],\n",
       "       [-1. ,  0. , -5.7],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  1. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ],\n",
       "       [ 0. ,  0. ,  0. ]])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Display the ground-truth mean coefficient matrix as a sanity check.\n",
    "gtm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "76fc4484",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Ground-truth coefficient standard deviations for the three noise levels\n",
    "# (1, 5, 10). The same three entries are non-zero in each matrix:\n",
    "# [2, 1], [0, 2] and [3, 2].\n",
    "# NOTE(review): derivative() caps the std of a and b at min(1.0, scale), whereas\n",
    "# all three entries here use the uncapped value -- confirm which one matches the\n",
    "# generator that produced the data.\n",
    "gts1, gts2, gts3 = (np.zeros((20, 3)) for _ in range(3))\n",
    "gts1[[2, 0, 3], [1, 2, 2]] = 1.0\n",
    "gts2[[2, 0, 3], [1, 2, 2]] = 5.0\n",
    "gts3[[2, 0, 3], [1, 2, 2]] = 10.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "1ac9f508",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack the per-noise-level lists into single arrays; the recorded outputs of\n",
    "# the next two cells show both with shape (3, 10, 500, 20, 3).\n",
    "es = np.array([e1, e2, e3])\n",
    "hs = np.array([h1, h2, h3])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "5ea1f7cf",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 10, 500, 20, 3)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check for the stacked E-SINDy coefficients.\n",
    "es.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b6868bf4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 10, 500, 20, 3)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Shape check for the stacked HyperSINDy coefficients.\n",
    "hs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "21afabab",
   "metadata": {},
   "outputs": [],
   "source": [
    "def rmse(pred, truth):\n",
    "    \"\"\"Return the per-item error norm of pred relative to the norm of truth.\n",
    "\n",
    "    Squared differences are summed over axes 1 and 2 of pred - truth, giving\n",
    "    one value per entry along axis 0; each square root is then divided by the\n",
    "    root of the sum of squares of the whole truth array.\n",
    "    \"\"\"\n",
    "    residual = pred - truth\n",
    "    error_norm = np.sqrt(np.sum(np.square(residual), (1, 2)))\n",
    "    truth_norm = np.sqrt(np.sum(np.square(truth)))\n",
    "    return error_norm / truth_norm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "bcbaf91c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "SIGMA : 1\n",
      "RMSE of MEAN Coefficient, HYPERSINDY : 0.029 +- 0.035\n",
      "RMSE of MEAN Coefficient, ESINDY     : 0.034 +- 0.012\n",
      "RMSE of STD  Coefficient, HYPERSINDY : 0.828 +- 0.059\n",
      "RMSE of STD  Coefficient, ESINDY     : 0.919 +- 0.015\n",
      "SIGMA : 5\n",
      "RMSE of MEAN Coefficient, HYPERSINDY : 0.086 +- 0.047\n",
      "RMSE of MEAN Coefficient, ESINDY     : 0.131 +- 0.057\n",
      "RMSE of STD  Coefficient, HYPERSINDY : 0.807 +- 0.012\n",
      "RMSE of STD  Coefficient, ESINDY     : 0.883 +- 0.025\n",
      "SIGMA : 10\n",
      "RMSE of MEAN Coefficient, HYPERSINDY : 0.228 +- 0.138\n",
      "RMSE of MEAN Coefficient, ESINDY     : 0.221 +- 0.147\n",
      "RMSE of STD  Coefficient, HYPERSINDY : 0.812 +- 0.014\n",
      "RMSE of STD  Coefficient, ESINDY     : 0.903 +- 0.022\n"
     ]
    }
   ],
   "source": [
    "# For every noise level: average the sampled coefficients of each run, score the\n",
    "# per-run mean and std against the ground truth with rmse(), then report the\n",
    "# mean +- std of those scores across runs.\n",
    "sigma = [1, 5, 10]\n",
    "gt_stds = [gts1, gts2, gts3]\n",
    "for level, e, h, curr_gts in zip(sigma, es, hs, gt_stds):\n",
    "    print(\"SIGMA : \" + str(level))\n",
    "    # e and h are 10 x 500 x 20 x 3 (per the recorded shapes); axis 1 indexes\n",
    "    # the sampled coefficient matrices, so the statistics below are 10 x 20 x 3.\n",
    "    err_mean_e = rmse(e.mean(1), gtm)       # one score per run\n",
    "    err_std_e = rmse(e.std(1), curr_gts)\n",
    "    err_mean_h = rmse(h.mean(1), gtm)\n",
    "    err_std_h = rmse(h.std(1), curr_gts)\n",
    "\n",
    "    report = [\n",
    "        (\"RMSE of MEAN Coefficient, HYPERSINDY : \", err_mean_h),\n",
    "        (\"RMSE of MEAN Coefficient, ESINDY     : \", err_mean_e),\n",
    "        (\"RMSE of STD  Coefficient, HYPERSINDY : \", err_std_h),\n",
    "        (\"RMSE of STD  Coefficient, ESINDY     : \", err_std_e),\n",
    "    ]\n",
    "    for label, scores in report:\n",
    "        print(label + str(np.round(scores.mean(), 3)) + \" +- \" + str(np.round(scores.std(), 3)))\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dfd29e5",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
