{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "bdf11d0b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append(\"../\")\n",
    "import os\n",
    "os.environ[\"CUDA_DEVICE_ORDER\"] = \"PCI_BUS_ID\"\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import numpy as np\n",
    "\n",
    "from HyperSINDy import Net\n",
    "from baseline import Trainer\n",
    "from library_utils import Library\n",
    "from Datasets import SyntheticDataset\n",
    "from other import init_weights, set_random_seed\n",
    "\n",
    "from exp_utils import get_equations, log_equations\n",
    "\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from tabulate import tabulate\n",
    "\n",
    "sns.set()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "a6c03784",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model(device, z_dim, poly_order, include_constant,\n",
    "               noise_dim, hidden_dim, stat_size, batch_size,\n",
    "               num_hidden, batch_norm, cp_path):\n",
    "    \"\"\"Build a HyperSINDy ``Net`` on a CUDA device and restore its weights.\n",
    "\n",
    "    Args:\n",
    "        device: CUDA device index passed to ``torch.cuda.set_device``.\n",
    "        z_dim, poly_order, include_constant: forwarded to ``Library``.\n",
    "        noise_dim, hidden_dim, stat_size, num_hidden, batch_norm: forwarded to\n",
    "            ``Net`` (``stat_size`` becomes ``statistic_batch_size``).\n",
    "        batch_size: accepted for call-site compatibility; not used here.\n",
    "        cp_path: checkpoint path whose ``'model'`` entry is the state dict.\n",
    "\n",
    "    Returns:\n",
    "        ``(net, library, device)`` where ``device`` is ``torch.cuda.current_device()``.\n",
    "    \"\"\"\n",
    "    torch.cuda.set_device(device=device)\n",
    "    gpu = torch.cuda.current_device()\n",
    "\n",
    "    lib = Library(n=z_dim, poly_order=poly_order, include_constant=include_constant)\n",
    "    hyper_net = Net(lib, noise_dim=noise_dim, hidden_dim=hidden_dim,\n",
    "                    statistic_batch_size=stat_size,\n",
    "                    num_hidden=num_hidden,\n",
    "                    batch_norm=batch_norm).to(gpu)\n",
    "\n",
    "    # NOTE: torch.load unpickles the file -- only point this at trusted checkpoints.\n",
    "    checkpoint = torch.load(cp_path, map_location=\"cuda:\" + str(gpu))\n",
    "    hyper_net.load_state_dict(checkpoint['model'])\n",
    "    hyper_net.to(gpu)\n",
    "\n",
    "    return hyper_net, lib, gpu\n",
    "\n",
    "def gather_data(coefs_mean, coefs_std, feature_names, nonzero, z_dim):\n",
    "    \"\"\"Collect each equation's nonzero terms with rounded mean/std.\n",
    "\n",
    "    Args:\n",
    "        coefs_mean, coefs_std: arrays indexable as ``[equation][term]``.\n",
    "        feature_names: library term names, aligned with the term index.\n",
    "        nonzero: boolean mask indexable as ``[equation][term]``.\n",
    "        z_dim: number of equations to collect.\n",
    "\n",
    "    Returns:\n",
    "        One list per equation of ``(term_name, mean, std)`` tuples, with mean and\n",
    "        std rounded to 2 decimals; terms masked out by ``nonzero`` are omitted.\n",
    "    \"\"\"\n",
    "    means = np.round(coefs_mean, 2)\n",
    "    stds = np.round(coefs_std, 2)\n",
    "    return [\n",
    "        [(name, means[eq][k], stds[eq][k])\n",
    "         for k, name in enumerate(feature_names)\n",
    "         if nonzero[eq][k]]\n",
    "        for eq in range(z_dim)\n",
    "    ]\n",
    "\n",
    "def get_coef_stats(net, batch_size=1000, device=2):\n",
    "    \"\"\"Sample masked coefficients from ``net`` and summarize them over samples.\n",
    "\n",
    "    Args:\n",
    "        net: model exposing ``get_masked_coefficients(batch_size=..., device=...)``.\n",
    "        batch_size: number of coefficient samples to draw.\n",
    "        device: device forwarded to ``get_masked_coefficients``.\n",
    "\n",
    "    Returns:\n",
    "        ``(coefs_mean, coefs_std, nonzero)`` indexed ``[axis2, axis1]`` of the raw\n",
    "        samples (presumably ``[equation, term]``); ``nonzero`` is True where the\n",
    "        sampled mean is not exactly 0.\n",
    "    \"\"\"\n",
    "    # Honor batch_size (it was previously ignored in favour of a hard-coded 1000).\n",
    "    coefs = net.get_masked_coefficients(batch_size=batch_size, device=device).detach().cpu().numpy()\n",
    "    # Reverse axes so samples are last; assumes coefs is (samples, terms, equations) -- TODO confirm.\n",
    "    coefs_t = np.transpose(coefs, (2, 1, 0))\n",
    "    coefs_mean = coefs_t.mean(axis=2)\n",
    "    coefs_std = coefs_t.std(axis=2)\n",
    "    nonzero = coefs_mean != 0\n",
    "    return coefs_mean, coefs_std, nonzero\n",
    "\n",
    "def _lorenz96_ground_truth(z_dim, forcing=8):\n",
    "    \"\"\"Return the true Lorenz-96 terms as ``{\"dx<i>\": [(coef, term_name), ...]}``.\n",
    "\n",
    "    Encodes dx_i = F - x_i + x_{i+1} x_{i-1} - x_{i-2} x_{i-1} with cyclic indices,\n",
    "    F = ``forcing``; the constant term is named by the empty string.\n",
    "    \"\"\"\n",
    "    names = [\"x\" + str(i + 1) for i in range(z_dim)]\n",
    "    return {\n",
    "        \"dx\" + str(i + 1): [\n",
    "            (forcing, \"\"),\n",
    "            (-1, names[i]),\n",
    "            (1, names[(i + 1) % z_dim] + names[i - 1]),\n",
    "            (-1, names[i - 2] + names[i - 1]),\n",
    "        ]\n",
    "        for i in range(z_dim)\n",
    "    }\n",
    "\n",
    "\n",
    "def _canonical_term(name):\n",
    "    \"\"\"Order-independent key for a monomial name, so x2x10 matches x10x2.\"\"\"\n",
    "    prefix, *factors = str(name).split(\"x\")\n",
    "    return prefix, tuple(sorted(factors))\n",
    "\n",
    "\n",
    "def build_table(net, batch_size=1000, device=2, print_file=False,\n",
    "                print_fancy=True, print_latex=False, seed=None):\n",
    "    \"\"\"Tabulate learned vs. true Lorenz-96 coefficients for ``net``.\n",
    "\n",
    "    Each learned nonzero term gets a row with its true coefficient (0 when the\n",
    "    term is not in the ground truth) and its sampled mean/std. Ground-truth terms\n",
    "    the model did not learn are appended with mean and std 0.\n",
    "\n",
    "    Args:\n",
    "        net: trained model exposing ``library`` and ``z_dim``.\n",
    "        batch_size: number of coefficient samples used for the statistics.\n",
    "        device: device forwarded to ``get_coef_stats``.\n",
    "        print_file: write ``table.txt`` and ``table_latex.txt`` to the working dir.\n",
    "        print_fancy: print the table in ``fancy_outline`` format.\n",
    "        print_latex: print the table in LaTeX format.\n",
    "        seed: random seed; ``None`` uses the notebook-level ``SEED``.\n",
    "\n",
    "    Returns:\n",
    "        The table rows ``[EQUATION, TERM, TRUE, MEAN, STD]``.\n",
    "    \"\"\"\n",
    "    set_random_seed(SEED if seed is None else seed)\n",
    "    coefs_mean, coefs_std, nonzero = get_coef_stats(net, batch_size, device)\n",
    "    feature_names = net.library.get_feature_names()\n",
    "\n",
    "    # Use the model's own dimension rather than the notebook-global z_dim.\n",
    "    n_eqs = net.z_dim\n",
    "    data = gather_data(coefs_mean, coefs_std, feature_names, nonzero, n_eqs)\n",
    "    gts = _lorenz96_ground_truth(n_eqs)\n",
    "\n",
    "    table = []\n",
    "    for i, eq in enumerate(data):\n",
    "        eq_start = \"dx\" + str(i + 1)\n",
    "        true_eq = gts[eq_start]\n",
    "        true_coefs = {_canonical_term(name): coef for coef, name in true_eq}\n",
    "        label = eq_start  # only the first row of each equation carries its name\n",
    "        learned = set()\n",
    "        for name, mean, std in eq:\n",
    "            key = _canonical_term(name)\n",
    "            learned.add(key)\n",
    "            table.append([label, name, str(true_coefs.get(key, 0)), mean, std])\n",
    "            label = \"\"\n",
    "        # Ground-truth terms the learned model missed.\n",
    "        for coef, name in true_eq:\n",
    "            if _canonical_term(name) not in learned:\n",
    "                table.append([label, name, str(coef), 0, 0])\n",
    "                label = \"\"\n",
    "\n",
    "    headers = [\"EQUATION\", \"TERM\", \"TRUE\", \"MEAN\", \"STD\"]\n",
    "    fancy_table = tabulate(table, headers, tablefmt=\"fancy_outline\")\n",
    "    latex_table = tabulate(table, headers, tablefmt=\"latex\")\n",
    "    if print_fancy:\n",
    "        print(fancy_table)\n",
    "    if print_latex:\n",
    "        print(latex_table)\n",
    "    if print_file:\n",
    "        with open(\"table.txt\", \"w\") as f:\n",
    "            print(fancy_table, file=f)\n",
    "        with open(\"table_latex.txt\", \"w\") as f:\n",
    "            print(latex_table, file=f)\n",
    "    return table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4da775ae",
   "metadata": {},
   "source": [
    "# Load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "976bca4a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Global seed; build_table and the get_equations call below use it for seeding.\n",
    "SEED = 5_281_998"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da5db36a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Paths\n",
    "data_folder = \"../data/\"\n",
    "runs = \"../runs/lorenz96\"\n",
    "\n",
    "# System and sparse-regression library\n",
    "model = \"HyperSINDy\"\n",
    "z_dim = 10\n",
    "dt = 0.01\n",
    "poly_order = 3\n",
    "include_constant = True\n",
    "\n",
    "# Network architecture (should match the checkpoint loaded below)\n",
    "noise_dim = 20\n",
    "hidden_dim = 128\n",
    "num_hidden = 5\n",
    "stat_size = 250\n",
    "batch_norm = False\n",
    "\n",
    "# Presumably training hyperparameters; not used elsewhere in this notebook\n",
    "adam_reg = 1e-2\n",
    "gamma_factor = 0.999\n",
    "\n",
    "# CUDA device index\n",
    "device = 2\n",
    "\n",
    "# Rebound by the load_model call below\n",
    "library = Library(n=z_dim, poly_order=poly_order, include_constant=include_constant)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4ab8f0a4",
   "metadata": {},
   "outputs": [],
   "source": [
    "net1, library, device = load_model(\n",
    "    device=device, z_dim=z_dim, poly_order=poly_order,\n",
    "    include_constant=include_constant, noise_dim=noise_dim,\n",
    "    hidden_dim=hidden_dim, stat_size=stat_size,\n",
    "    batch_size=stat_size,  # not used by load_model; kept to match its signature\n",
    "    num_hidden=num_hidden, batch_norm=batch_norm,\n",
    "    cp_path=runs + \"/cp_1.pt\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eff94f4b",
   "metadata": {},
   "source": [
    "# Table"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7e08799e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# One entry per checkpoint; append further get_equations results to compare runs.\n",
    "eq1 = get_equations(net1, library, model, device, seed=SEED)\n",
    "all_eqs = []\n",
    "all_eqs.append(eq1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c021e53",
   "metadata": {},
   "source": [
    "# Print equations in a LaTeX-friendly format"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "cadd22a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Lorenz-96 ground truth: dx_i = 8 - x_i + x_{i+1} x_{i-1} - x_{i-2} x_{i-1} (cyclic indices).\n",
    "eq_starts = [\"dx\" + str(k) for k in range(1, z_dim + 1)]\n",
    "terms = np.array([\"x\" + str(k) for k in range(1, z_dim + 1)])\n",
    "gts = {\n",
    "    start: [\n",
    "        (8, \"\"),\n",
    "        (-1, terms[i]),\n",
    "        (1, terms[(i + 1) % z_dim] + terms[i - 1]),\n",
    "        (-1, terms[i - 2] + terms[i - 1]),\n",
    "    ]\n",
    "    for i, start in enumerate(eq_starts)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "99a10603",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'dx1': [(8, ''), (-1, 'x1'), (1, 'x2x10'), (-1, 'x9x10')],\n",
       " 'dx2': [(8, ''), (-1, 'x2'), (1, 'x3x1'), (-1, 'x10x1')],\n",
       " 'dx3': [(8, ''), (-1, 'x3'), (1, 'x4x2'), (-1, 'x1x2')],\n",
       " 'dx4': [(8, ''), (-1, 'x4'), (1, 'x5x3'), (-1, 'x2x3')],\n",
       " 'dx5': [(8, ''), (-1, 'x5'), (1, 'x6x4'), (-1, 'x3x4')],\n",
       " 'dx6': [(8, ''), (-1, 'x6'), (1, 'x7x5'), (-1, 'x4x5')],\n",
       " 'dx7': [(8, ''), (-1, 'x7'), (1, 'x8x6'), (-1, 'x5x6')],\n",
       " 'dx8': [(8, ''), (-1, 'x8'), (1, 'x9x7'), (-1, 'x6x7')],\n",
       " 'dx9': [(8, ''), (-1, 'x9'), (1, 'x10x8'), (-1, 'x7x8')],\n",
       " 'dx10': [(8, ''), (-1, 'x10'), (1, 'x1x9'), (-1, 'x8x9')]}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Ground-truth terms per equation, as (coefficient, term name) pairs; \"\" is the constant.\n",
    "gts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5c0c273e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _latex_term(term):\n",
    "    \"\"\"Render a term such as ``0.99x2x10`` as ``0.99x_{2}x_{10}`` (coefficient kept as-is).\"\"\"\n",
    "    coef, *factors = term.split(\"x\")\n",
    "    return coef + \"\".join(\"x_{\" + idx + \"}\" for idx in factors)\n",
    "\n",
    "\n",
    "def reformat(eqs, gts, eq_starts, filename=None):\n",
    "    \"\"\"Print equation strings in a LaTeX-friendly form.\n",
    "\n",
    "    Each entry of ``eqs`` is assumed to be space-separated tokens such as\n",
    "    ``dx1 = 6.75 + -0.75x1 + 0.99x2x10`` (terms joined by ``+``); the section\n",
    "    markers ``MEAN`` and ``STD`` are skipped. A ``+ -c`` pair is rendered as\n",
    "    ``- c``. Every term is kept, including a variable term right after ``=``\n",
    "    and a bare constant after ``+``.\n",
    "\n",
    "    Args:\n",
    "        eqs: iterable of equation strings and section markers.\n",
    "        gts, eq_starts: accepted for backward compatibility; not used.\n",
    "        filename: optional open text file (anything ``print(file=...)`` accepts);\n",
    "            every printed line is also written there.\n",
    "    \"\"\"\n",
    "    for eq in eqs:\n",
    "        if eq in (\"MEAN\", \"STD\"):\n",
    "            continue\n",
    "\n",
    "        result = \"\"\n",
    "        after_plus = False\n",
    "        for token in eq.split(\" \"):\n",
    "            if token == \"+\":\n",
    "                # The sign is emitted together with the following term.\n",
    "                after_plus = True\n",
    "                continue\n",
    "            if token.startswith(\"dx\"):\n",
    "                # Raw string avoids the invalid-escape warning for backslash-d.\n",
    "                result += r\"\\dot{x}_{\" + token[2:] + \"}\"\n",
    "            elif token == \"=\":\n",
    "                result += \" = \"\n",
    "            elif token:\n",
    "                if after_plus:\n",
    "                    if token.startswith(\"-\"):\n",
    "                        result += \"- \"\n",
    "                        token = token[1:]\n",
    "                    else:\n",
    "                        result += \"+ \"\n",
    "                result += _latex_term(token) + \" \"\n",
    "            after_plus = False\n",
    "\n",
    "        print(result)\n",
    "        if filename is not None:\n",
    "            print(result, file=filename)\n",
    "    print()\n",
    "    if filename is not None:\n",
    "        print(file=filename)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3915bc1f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\\dot{x}_{1} = 6.75 - 0.75x_{1} + 0.99x_{2}x_{10} - 0.99x_{9}x_{10} \n",
      "\\dot{x}_{2} = 8.46 - 0.74x_{2} + 0.98x_{1}x_{3} - 0.99x_{1}x_{10} \n",
      "\\dot{x}_{3} = 7.59 - 0.71x_{3} - 0.99x_{1}x_{2} + 0.98x_{2}x_{4} \n",
      "\\dot{x}_{4} = 7.89 - 0.77x_{4} - 0.97x_{2}x_{3} + 0.97x_{3}x_{5} \n",
      "\\dot{x}_{5} = 6.7 - 0.81x_{5} - 0.97x_{3}x_{4} + 0.97x_{4}x_{6} \n",
      "\\dot{x}_{6} = 7.86 - 0.76x_{6} - 0.99x_{4}x_{5} + 0.99x_{5}x_{7} \n",
      "\\dot{x}_{7} = 7.31 - 0.75x_{7} - 0.99x_{5}x_{6} + 0.98x_{6}x_{8} \n",
      "\\dot{x}_{8} = 7.56 - 0.76x_{8} - 0.98x_{6}x_{7} + 0.98x_{7}x_{9} \n",
      "\\dot{x}_{9} = 7.45 - 0.74x_{9} - 0.98x_{7}x_{8} + 0.99x_{8}x_{10} \n",
      "\\dot{x}_{10} = 6.49 - 0.75x_{10} + 0.99x_{1}x_{9} - 0.99x_{8}x_{9} \n",
      "\\dot{x}_{1} = 8.44 + 0.04x_{1} + 0.02x_{2}x_{10} + 0.01x_{9}x_{10} \n",
      "\\dot{x}_{2} = 7.98 + 0.05x_{2} + 0.02x_{1}x_{3} + 0.02x_{1}x_{10} \n",
      "\\dot{x}_{3} = 8.1 + 0.05x_{3} + 0.01x_{1}x_{2} + 0.02x_{2}x_{4} \n",
      "\\dot{x}_{4} = 7.72 + 0.04x_{4} + 0.01x_{2}x_{3} + 0.01x_{3}x_{5} \n",
      "\\dot{x}_{5} = 7.65 + 0.04x_{5} + 0.02x_{3}x_{4} + 0.01x_{4}x_{6} \n",
      "\\dot{x}_{6} = 8.47 + 0.05x_{6} + 0.03x_{4}x_{5} + 0.01x_{5}x_{7} \n",
      "\\dot{x}_{7} = 8.09 + 0.06x_{7} + 0.01x_{5}x_{6} + 0.02x_{6}x_{8} \n",
      "\\dot{x}_{8} = 7.89 + 0.03x_{8} + 0.02x_{6}x_{7} + 0.01x_{7}x_{9} \n",
      "\\dot{x}_{9} = 7.97 + 0.04x_{9} + 0.01x_{7}x_{8} + 0.02x_{8}x_{10} \n",
      "\\dot{x}_{10} = 7.83 + 0.03x_{10} + 0.01x_{1}x_{9} + 0.01x_{8}x_{9} \n",
      "\n"
     ]
    }
   ],
   "source": [
    "results_path = \"../results/lorenz96.txt\"\n",
    "# A fresh checkout may not have ../results yet; open() would then fail.\n",
    "os.makedirs(os.path.dirname(results_path), exist_ok=True)\n",
    "with open(results_path, \"w\") as f:\n",
    "    for curr_eqs in all_eqs:\n",
    "        reformat(curr_eqs, gts, eq_starts, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e775c741",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b442a079",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
