{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6c0d40bb-cdf6-4a59-a233-55fff970045c",
   "metadata": {},
   "source": [
    "# Compare performance of continuous-time models\n",
    "\n",
    "Please install `easy_tpp` and place [the original ODE-RNN implementation](https://github.com/YuliaRubanova/latent_ode/tree/master) in the `lib` folder. The code cells below also import `hotpp` and move every model and input to a CUDA device."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9e45cd6e-8456-4bf0-bceb-98ab23d986b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "import torch\n",
    "from easy_tpp.model.torch_model.torch_baselayer import DNN\n",
    "from easy_tpp.utils import rk4_step_method\n",
    "from easy_tpp.model.torch_model.torch_nhp import ContTimeLSTMCell as EasyContTimeLSTMCell\n",
    "from easy_tpp.model.torch_model.torch_ode_tpp import NeuralODE as EasyODEGRUCell\n",
    "from hotpp.data import PaddedBatch\n",
    "from hotpp.nn.encoder.rnn.ctlstm import ContTimeLSTM as HotContTimeLSTM\n",
    "from hotpp.nn.encoder.rnn.ode import ODEGRU as HotODEGRU\n",
    "from lib.ode_func import ODEFunc\n",
    "from lib.ode_rnn import ODE_RNN as LatODEGRU\n",
    "from lib.diffeq_solver import DiffeqSolver\n",
    "from lib.utils import create_net\n",
    "\n",
    "def measure(model, n_trials=100):\n",
    "    model.cuda()\n",
    "    model.eval()\n",
    "    x = torch.randn(64, 100, 64).cuda()\n",
    "    dt = torch.rand(64, 100).cuda()\n",
    "    torch.cuda.synchronize()\n",
    "    start = time.time()\n",
    "    with torch.no_grad():\n",
    "        for _ in range(n_trials):\n",
    "            out = model(x, dt)\n",
    "    torch.cuda.synchronize()\n",
    "    print(\"Forward\", (time.time() - start) / n_trials)\n",
    "    start = time.time()\n",
    "    for _ in range(n_trials):\n",
    "        model(x, dt)[0].mean().backward()\n",
    "    torch.cuda.synchronize()\n",
    "    print(\"FW / BW\", (time.time() - start) / n_trials)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37b73a11-f594-48a9-8475-608486c5d382",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "OUR CT-LSTM\n",
      "Forward 0.0024916505813598632\n",
      "FW / BW 0.027592058181762694\n",
      "EesyTPP CT-LSTM\n",
      "Forward 0.015844099521636963\n",
      "FW / BW 0.05194777488708496\n"
     ]
    }
   ],
   "source": [
    "class EasyContTimeLSTM(torch.nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super().__init__()\n",
    "        self.hidden_size = hidden_size\n",
    "        self.cell = EasyContTimeLSTMCell(hidden_size)\n",
    "\n",
    "    def forward(self, x, dt):\n",
    "        all_hiddens = []\n",
    "        all_outputs = []\n",
    "        all_cells = []\n",
    "        all_cell_bars = []\n",
    "        all_decays = []\n",
    "\n",
    "        h_t, c_t, c_bar_i = torch.zeros(len(x),\n",
    "                                        3 * self.hidden_size,\n",
    "                                        device=x.device).chunk(3, dim=1)\n",
    "        b, l, d = x.shape\n",
    "        for i in range(l):\n",
    "            cell_i, c_bar_i, decay_i, output_i = self.cell(x[:, i], h_t, c_t, c_bar_i)\n",
    "            c_t, h_t = self.cell.decay(cell_i, c_bar_i, decay_i, output_i, dt[:, i:i + 1])\n",
    "            all_outputs.append(output_i)\n",
    "            all_decays.append(decay_i)\n",
    "            all_cells.append(cell_i)\n",
    "            all_cell_bars.append(c_bar_i)\n",
    "            all_hiddens.append(h_t)\n",
    "        cell_stack = torch.stack(all_cells, dim=1)\n",
    "        cell_bar_stack = torch.stack(all_cell_bars, dim=1)\n",
    "        decay_stack = torch.stack(all_decays, dim=1)\n",
    "        output_stack = torch.stack(all_outputs, dim=1)\n",
    "\n",
    "        # [batch_size, max_seq_length, hidden_dim]\n",
    "        hiddens_stack = torch.stack(all_hiddens, dim=1)\n",
    "\n",
    "        # [batch_size, max_seq_length, 4, hidden_dim]\n",
    "        decay_states_stack = torch.stack((cell_stack,\n",
    "                                          cell_bar_stack,\n",
    "                                          decay_stack,\n",
    "                                          output_stack),\n",
    "                                         dim=2)\n",
    "\n",
    "        return hiddens_stack, decay_states_stack\n",
    "\n",
    "class ContTimeLSTM(torch.nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super().__init__()\n",
    "        self.model = HotContTimeLSTM(hidden_size, hidden_size)\n",
    "\n",
    "    def forward(self, x, dt):\n",
    "        state = self.model.init_state.repeat(len(x), 1)\n",
    "        return self.model._forward_loop(self.model.cell.preprocess(x), dt, state)\n",
    "\n",
    "# Time the hotpp-based wrapper first, then the EasyTPP-based one (hidden size 64).\n",
    "print(\"OUR CT-LSTM\")\n",
    "measure(ContTimeLSTM(64))\n",
    "\n",
    "print(\"EasyTPP CT-LSTM\")\n",
    "measure(EasyContTimeLSTM(64))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "b449d5c8-abe1-4c36-bd64-15fdaf2a8b63",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Our ODE-RNN\n",
      "Forward 0.009947264194488525\n",
      "FW / BW 0.04754711389541626\n",
      "Orig ODE-RNN\n",
      "Forward 0.08002629280090331\n",
      "FW / BW 0.16193607568740845\n"
     ]
    }
   ],
   "source": [
    "class OrigODEGRU(LatODEGRU):\n",
    "    \"\"\"Adapter exposing the ODE-RNN from `lib` through a `model(x, dt)` call.\n",
    "\n",
    "    Args:\n",
    "        hidden_size: Single size used for every dimension argument of the\n",
    "            ODE function, the solver and the parent constructor.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, hidden_size):\n",
    "        device = \"cuda\"\n",
    "        rec_ode_func = ODEFunc(\n",
    "            input_dim=hidden_size,\n",
    "            latent_dim=hidden_size,\n",
    "            ode_func_net=create_net(hidden_size, hidden_size, n_layers=0, n_units=hidden_size),\n",
    "            device=device,\n",
    "        ).to(device)\n",
    "\n",
    "        solver = DiffeqSolver(\n",
    "            hidden_size, rec_ode_func, \"rk4\", hidden_size,\n",
    "            odeint_rtol=1e-3, odeint_atol=1e-4, device=device,\n",
    "        )\n",
    "\n",
    "        super().__init__(\n",
    "            hidden_size,\n",
    "            hidden_size,\n",
    "            n_gru_units=hidden_size,\n",
    "            z0_diffeq_solver=solver,\n",
    "        )\n",
    "\n",
    "    def forward(self, x, dt):\n",
    "        \"\"\"Call `get_reconstruction` with an all-True mask and return its first result.\"\"\"\n",
    "        # Boolean mask with the shape and device of x, True everywhere.\n",
    "        mask = torch.ones_like(x, dtype=torch.bool)\n",
    "        # NOTE(review): only the first row of `dt` is used, and it is passed both\n",
    "        # as the first and as the third argument -- confirm that\n",
    "        # get_reconstruction accepts these values as its time grid.\n",
    "        time_grid = dt[0]\n",
    "        reconstruction, _ = self.get_reconstruction(time_grid, x, time_grid, mask=mask, mode=None)\n",
    "        return reconstruction\n",
    "\n",
    "\n",
    "class ODEGRU(torch.nn.Module):\n",
    "    def __init__(self, hidden_size):\n",
    "        super().__init__()\n",
    "        self.model = HotODEGRU(hidden_size, hidden_size, diff_hidden_size=hidden_size)\n",
    "\n",
    "    def forward(self, x, dt):\n",
    "        state = self.model.init_state.repeat(len(x), 1)\n",
    "        return self.model._forward_loop(x, dt, state)\n",
    "\n",
    "# Time both ODE-RNN wrappers in turn; each model is built right before it is measured.\n",
    "for title, model_cls in ((\"Our ODE-RNN\", ODEGRU), (\"Orig ODE-RNN\", OrigODEGRU)):\n",
    "    print(title)\n",
    "    measure(model_cls(64))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
