{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe2e5863",
   "metadata": {},
   "source": [
    "## Analysis of the neural ODE on the CCT and CED benchmarks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b47b62d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import enum\n",
    "import os\n",
    "import argparse\n",
    "import logging\n",
    "import time\n",
    "import numpy as np\n",
    "import numpy.random as npr\n",
    "import matplotlib\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from torchdiffeq import odeint #includes exogenous inputs\n",
    "\n",
    "class LatentODEfunc(nn.Module):\n",
    "\n",
    "    def __init__(self, latent_dim=4, nhidden=20, udim=None):\n",
    "        super(LatentODEfunc, self).__init__()\n",
    "        self.elu = nn.ELU(inplace=True)\n",
    "        udim_val = 0 if udim is None else udim\n",
    "        self.fc1 = nn.Linear(latent_dim + udim_val, nhidden)\n",
    "        self.fc2 = nn.Linear(nhidden, nhidden)\n",
    "        self.fc3 = nn.Linear(nhidden, latent_dim)\n",
    "        self.nfe = 0\n",
    "\n",
    "    def forward(self, t, x, u=None):\n",
    "        self.nfe += 1\n",
    "        if u is not None: #append the input to the state\n",
    "            x = torch.cat([x,u[:,None] if u.ndim==1 else u],dim=1)\n",
    "        out = self.fc1(x)\n",
    "        out = self.elu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = self.elu(out)\n",
    "        out = self.fc3(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "class RecognitionRNN(nn.Module):\n",
    "\n",
    "    def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1):\n",
    "        super(RecognitionRNN, self).__init__()\n",
    "        self.nhidden = nhidden\n",
    "        self.nbatch = nbatch\n",
    "        self.i2h = nn.Linear(obs_dim + nhidden, nhidden)\n",
    "        self.h2o = nn.Linear(nhidden, latent_dim * 2)\n",
    "\n",
    "    def forward(self, x, h):\n",
    "        combined = torch.cat((x, h), dim=1)\n",
    "        h = torch.tanh(self.i2h(combined))\n",
    "        out = self.h2o(h)\n",
    "        return out, h\n",
    "\n",
    "    def initHidden(self):\n",
    "        return torch.zeros(self.nbatch, self.nhidden)\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "\n",
    "    def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):\n",
    "        super(Decoder, self).__init__()\n",
    "        self.relu = nn.ReLU(inplace=True)\n",
    "        self.fc1 = nn.Linear(latent_dim, nhidden)\n",
    "        self.fc2 = nn.Linear(nhidden, obs_dim)\n",
    "\n",
    "    def forward(self, z):\n",
    "        out = self.fc1(z)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        return out\n",
    "\n",
    "\n",
    "class RunningAverageMeter(object):\n",
    "    \"\"\"Computes and stores the average and current value\"\"\"\n",
    "\n",
    "    def __init__(self, momentum=0.99):\n",
    "        self.momentum = momentum\n",
    "        self.reset()\n",
    "\n",
    "    def reset(self):\n",
    "        self.val = None\n",
    "        self.avg = 0\n",
    "\n",
    "    def update(self, val):\n",
    "        if self.val is None:\n",
    "            self.avg = val\n",
    "        else:\n",
    "            self.avg = self.avg * self.momentum + val * (1 - self.momentum)\n",
    "        self.val = val\n",
    "\n",
    "\n",
    "def log_normal_pdf(x, mean, logvar):\n",
    "    const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)\n",
    "    const = torch.log(const)\n",
    "    return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))\n",
    "\n",
    "\n",
    "def normal_kl(mu1, lv1, mu2, lv2):\n",
    "    v1 = torch.exp(lv1)\n",
    "    v2 = torch.exp(lv2)\n",
    "    lstd1 = lv1 / 2.\n",
    "    lstd2 = lv2 / 2.\n",
    "\n",
    "    kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5\n",
    "    return kl\n",
    "\n",
    "def get_train_val_test(dataset):\n",
    "    import deepSI\n",
    "    from deepSI import System_data_list\n",
    "    if dataset=='CED':\n",
    "        data_full = deepSI.datasets.CED()\n",
    "        train = System_data_list([data_i[:300] for data_i in data_full])\n",
    "        test = System_data_list([data_i[300:] for data_i in data_full])\n",
    "        val = System_data_list([t[:100] for t in test])\n",
    "    elif dataset=='CCT':\n",
    "        train, test = deepSI.datasets.Cascaded_Tanks()\n",
    "        val, test = test[:len(test)//2], test\n",
    "    return train, val, test\n",
    "\n",
    "def get_torch_data(dataset, device=torch.device('cpu')):\n",
    "    \"\"\"Load `dataset`, normalize it and return the splits as float tensors on `device`.\n",
    "\n",
    "    The normalizer is fitted on the train split only and then applied to all\n",
    "    three splits. Only the 'CCT' and 'CED' branches exist.\n",
    "\n",
    "    Returns the tuple (orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs,\n",
    "    samp_ts, u_val_trajs, val_trajs, sample_time, norm): the `orig_*` entries\n",
    "    come from the test split, the `samp_*` entries from the train split,\n",
    "    `samp_ts` is a time grid as long as the train record, `sample_time` is in\n",
    "    seconds and `norm` is the fitted normalizer.\n",
    "    \"\"\"\n",
    "    splits = get_train_val_test(dataset)\n",
    "    from deepSI.system_data import System_data_norm\n",
    "    norm = System_data_norm() #normalization\n",
    "    norm.fit(splits[0])\n",
    "    train, val, test = (norm.transform(split) for split in splits)\n",
    "\n",
    "    if dataset == 'CCT':\n",
    "        def expand(signal):\n",
    "            # add one leading and one trailing axis\n",
    "            return signal[None, :, None]\n",
    "\n",
    "        samp_trajs, u_samp_trajs = expand(train.y), expand(train.u)\n",
    "        orig_trajs, u_orig_trajs = expand(test.y), expand(test.u)\n",
    "        val_trajs, u_val_trajs = expand(val.y), expand(val.u)\n",
    "\n",
    "        sample_time = 4. #seconds\n",
    "        n_steps = len(train.y)\n",
    "\n",
    "    elif dataset == 'CED':\n",
    "        # outputs get a trailing axis; of the inputs only the val split does\n",
    "        samp_trajs = np.array([rec.y[:, None] for rec in train])\n",
    "        u_samp_trajs = np.array([rec.u for rec in train])\n",
    "        orig_trajs = np.array([rec.y[:, None] for rec in test])\n",
    "        u_orig_trajs = np.array([rec.u for rec in test])\n",
    "        val_trajs = np.array([rec.y[:, None] for rec in val])\n",
    "        u_val_trajs = np.array([rec.u[:, None] for rec in val])\n",
    "\n",
    "        sample_time = 1/50 #seconds\n",
    "        n_steps = len(train[0])\n",
    "\n",
    "    samp_ts = np.arange(n_steps) * sample_time\n",
    "\n",
    "    def to_tensor(array):\n",
    "        return torch.from_numpy(array).float().to(device)\n",
    "\n",
    "    orig_trajs = to_tensor(orig_trajs) #samples, time\n",
    "    u_orig_trajs = to_tensor(u_orig_trajs)\n",
    "    samp_trajs = to_tensor(samp_trajs)\n",
    "    u_samp_trajs = to_tensor(u_samp_trajs)\n",
    "    samp_ts = to_tensor(samp_ts)\n",
    "    u_val_trajs = to_tensor(u_val_trajs)\n",
    "    val_trajs = to_tensor(val_trajs)\n",
    "\n",
    "    return (orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs, samp_ts,\n",
    "            u_val_trajs, val_trajs, sample_time, norm)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc46c7dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ckpt-best-CCT-1.pth\n",
      "RMS= 0.28936187906794736\n",
      "ckpt-best-CCT-2.pth\n",
      "RMS= 0.31364721722852457\n",
      "ckpt-best-CCT-3.pth\n",
      "RMS= 0.2067198689884109\n",
      "ckpt-best-CCT-4.pth\n",
      "RMS= 0.6541493821647515\n",
      "ckpt-best-CCT-5.pth\n",
      "RMS= 0.3882140760448425\n",
      "ckpt-best-CCT-6.pth\n",
      "RMS= 0.23907775720886332\n",
      "ckpt-best-CCT-7.pth\n",
      "RMS= 0.2273063909022871\n",
      "ckpt-best-CCT-8.pth\n",
      "RMS= 0.2549915661111742\n",
      "ckpt-best-CCT-9.pth\n",
      "RMS= 0.3344552356823532\n",
      "ckpt-best-CCT-10.pth\n",
      "RMS= 0.30742240312827346\n",
      "ckpt-best-CCT-11.pth\n",
      "RMS= 0.2454621261155722\n",
      "ckpt-best-CCT-12.pth\n",
      "RMS= 0.31688610899904324\n",
      "ckpt-best-CCT-13.pth\n",
      "RMS= 0.3166067180697689\n",
      "ckpt-best-CCT-14.pth\n",
      "RMS= 0.4938528896709381\n",
      "ckpt-best-CCT-15.pth\n",
      "RMS= 0.22484235246757336\n",
      "ckpt-best-CCT-101.pth\n",
      "RMS= 0.23857751407904534\n",
      "ckpt-best-CCT-102.pth\n",
      "RMS= 0.9365915276210175\n",
      "ckpt-best-CCT-103.pth\n",
      "RMS= 0.30130051087175175\n",
      "ckpt-best-CCT-104.pth\n",
      "RMS= 0.3585485406642151\n",
      "ckpt-best-CCT-106.pth\n",
      "RMS= 0.17933939666348714\n",
      "ckpt-best-CCT-107.pth\n",
      "RMS= 0.2810473376241484\n",
      "ckpt-best-CCT-108.pth\n",
      "RMS= 0.2886307006668135\n",
      "ckpt-best-CCT-109.pth\n",
      "RMS= 0.2933428677752006\n",
      "ckpt-best-CCT-110.pth\n",
      "RMS= 0.2913026908201021\n",
      "\n",
      "###########################\n",
      "###########################\n",
      "RMS results CCT\n",
      "min= 0.17933939666348714 mean= 0.33256987744317107 len= 24\n",
      "\n",
      "###########################\n",
      "###########################\n",
      "ckpt-best-CED-1.pth\n",
      "RMS set1= 0.15298568506817745\n",
      "RMS set2= 0.08794612486194658\n",
      "ckpt-best-CED-2.pth\n",
      "RMS set1= 0.25002118208531054\n",
      "RMS set2= 0.21785693372146184\n",
      "ckpt-best-CED-3.pth\n",
      "RMS set1= 0.21134203469257207\n",
      "RMS set2= 0.18343509203666722\n",
      "ckpt-best-CED-4.pth\n",
      "RMS set1= 0.14682859175067475\n",
      "RMS set2= 0.10180256589445989\n",
      "ckpt-best-CED-5.pth\n",
      "RMS set1= 0.24021221845016258\n",
      "RMS set2= 0.1692937601242141\n",
      "ckpt-best-CED-6.pth\n",
      "RMS set1= 0.3156902815451979\n",
      "RMS set2= 0.3206390284840523\n",
      "ckpt-best-CED-8.pth\n",
      "RMS set1= 0.1359566389764936\n",
      "RMS set2= 0.09948638919948824\n",
      "ckpt-best-CED-15.pth\n",
      "RMS set1= 0.131404213089415\n",
      "RMS set2= 0.08634576532151655\n",
      "\n",
      "###########################\n",
      "###########################\n",
      "RMS results CED\n",
      "set 1 min= 0.131404213089415 mean= 0.1980551057072505 len= 8\n",
      "set 2 min= 0.08634576532151655 mean= 0.15835070745547583 len= 8\n",
      "\n",
      "###########################\n",
      "###########################\n",
      "ckpt-best-CED-1.pth\n",
      "RMS set1= 0.2968658029813024\n",
      "RMS set2= 0.27132587597537666\n",
      "ckpt-best-CED-2.pth\n",
      "RMS set1= 0.8890212593817004\n",
      "RMS set2= 0.8825380789181865\n",
      "ckpt-best-CED-3.pth\n",
      "RMS set1= 0.4198213828299597\n",
      "RMS set2= 0.4060465309193541\n",
      "ckpt-best-CED-4.pth\n",
      "RMS set1= 0.6817628707848025\n",
      "RMS set2= 0.7525875382994578\n",
      "ckpt-best-CED-5.pth\n",
      "RMS set1= 0.14128327315766964\n",
      "RMS set2= 0.09799061515470316\n",
      "ckpt-best-CED-6.pth\n",
      "RMS set1= 0.26084004497664576\n",
      "RMS set2= 0.4273647612412531\n",
      "ckpt-best-CED-7.pth\n",
      "RMS set1= 0.3308889260495026\n",
      "RMS set2= 0.3286468319652884\n",
      "ckpt-best-CED-8.pth\n",
      "RMS set1= 0.23475945410489285\n",
      "RMS set2= 0.23774554745040244\n",
      "ckpt-best-CED-9.pth\n",
      "RMS set1= 0.29729132955969134\n",
      "RMS set2= 0.3129484663126848\n",
      "ckpt-best-CED-100.pth\n",
      "RMS set1= 0.2847180437518755\n",
      "RMS set2= 0.299385827569138\n",
      "ckpt-best-CED-101.pth\n",
      "RMS set1= 0.29840437055083235\n",
      "RMS set2= 0.3284228898848342\n",
      "ckpt-best-CED-102.pth\n",
      "RMS set1= 0.3338712402262253\n",
      "RMS set2= 0.3329976151371837\n",
      "ckpt-best-CED-103.pth\n",
      "RMS set1= 0.3119814012559624\n",
      "RMS set2= 0.3353964844520112\n",
      "ckpt-best-CED-104.pth\n",
      "RMS set1= 0.21626522822614572\n",
      "RMS set2= 0.2619615440219105\n",
      "ckpt-best-CED-105.pth\n",
      "RMS set1= 0.4562884250876721\n",
      "RMS set2= 0.424638021072581\n",
      "ckpt-best-CED-106.pth\n",
      "RMS set1= 0.29856529910601926\n",
      "RMS set2= 0.2454821499238392\n",
      "ckpt-best-CED-107.pth\n",
      "RMS set1= 0.39564068287271026\n",
      "RMS set2= 0.3471658936223572\n",
      "\n",
      "###########################\n",
      "###########################\n",
      "RMS results CED tau_1\n",
      "set 1 min= 0.14128327315766964 mean= 0.36166288440609473 len= 17\n",
      "set 2 min= 0.09799061515470316 mean= 0.37015556893650364 len= 17\n",
      "\n",
      "###########################\n",
      "###########################\n"
     ]
    }
   ],
   "source": [
    "# Evaluate every saved checkpoint on the test split of each benchmark and\n",
    "# report the simulation RMS error, rescaled by norm.ystd.\n",
    "\n",
    "# Set to True to plot measured vs. simulated output for every checkpoint.\n",
    "show_plots = False\n",
    "\n",
    "method = 'rk4' #exogenous inputs only implemented for Euler, RK4 and midpoint.\n",
    "device = torch.device('cpu')\n",
    "train_dir = './models-neural-ode/'\n",
    "niters = 20000\n",
    "\n",
    "#given by latent_ODE:\n",
    "nhidden = 20\n",
    "rnn_nhidden = 25\n",
    "obs_dim = 1\n",
    "noise_std = 0.1\n",
    "\n",
    "for dataset, tau_1 in [('CCT',False),('CED',False),('CED',True)]:\n",
    "    (orig_trajs, u_orig_trajs, samp_trajs, u_samp_trajs, samp_ts,\n",
    "     u_val_trajs, val_trajs, sample_time, norm) = get_torch_data(dataset, device=device)\n",
    "\n",
    "    if dataset=='CCT':\n",
    "        dttau = 0.032\n",
    "        latent_dim = 2\n",
    "    elif dataset=='CED':\n",
    "        dttau = 0.12\n",
    "        latent_dim = 3\n",
    "\n",
    "    # Rescale time so that one sample spans dttau solver time units\n",
    "    # (samp_ts = k*sample_time becomes k*dttau); with tau_1 the grid is unchanged.\n",
    "    tau = 1 if tau_1 else sample_time/dttau\n",
    "    samp_ts /= tau\n",
    "\n",
    "    suffix = '-tau_1' if tau_1 else ''\n",
    "    rmslist = []\n",
    "    for I in range(1,400):\n",
    "        ckpt_name = f'ckpt-best-{dataset}-{I}{suffix}.pth'\n",
    "        try:\n",
    "            # NOTE(security): the checkpoint holds pickled module objects, so\n",
    "            # torch.load executes pickle code -- only load files you trust.\n",
    "            # On PyTorch >= 2.6 this call additionally needs weights_only=False.\n",
    "            ckpt = torch.load(os.path.join(train_dir, ckpt_name), map_location=device)\n",
    "        except FileNotFoundError:\n",
    "            continue  # run indices are sparse; skip the missing ones\n",
    "        print(ckpt_name)  # name of the file actually loaded, including any -tau_1 suffix\n",
    "        func, rec, dec = ckpt['func'], ckpt['rec'], ckpt['dec']\n",
    "\n",
    "        with torch.no_grad():  # evaluation only: no autograd graph needed\n",
    "            # run the recognition RNN backwards in time over the test record\n",
    "            h = rec.initHidden().to(device)\n",
    "            for t in reversed(range(orig_trajs.size(1))):\n",
    "                obs = orig_trajs[:, t, :]\n",
    "                rec_out, h = rec.forward(obs, h)\n",
    "            # only the posterior mean is used as initial state (no sampling)\n",
    "            z0 = rec_out[:, :latent_dim]\n",
    "\n",
    "            # forward in time and solve ode for reconstructions\n",
    "            pred_z = odeint(func, z0, samp_ts[:orig_trajs.shape[1]], u=u_orig_trajs, method=method).permute(1, 0, 2)\n",
    "            pred_x = dec(pred_z)\n",
    "        orig_trajs_p = pred_x.detach()\n",
    "\n",
    "        if show_plots:\n",
    "            for k in range(orig_trajs.shape[0]):\n",
    "                plt.figure(figsize=(12,3))\n",
    "                plt.plot(orig_trajs[k,:,0].cpu().numpy())\n",
    "                plt.plot(orig_trajs_p[k,:,0].cpu().numpy())\n",
    "                plt.show()\n",
    "\n",
    "        if dataset=='CCT':\n",
    "            rms = torch.mean((orig_trajs - orig_trajs_p)**2).item()**0.5*norm.ystd\n",
    "            rmslist.append(rms)\n",
    "            print('RMS=',rms)\n",
    "        elif dataset=='CED':\n",
    "            rms1 = torch.mean((orig_trajs[0] - orig_trajs_p[0])**2).item()**0.5*norm.ystd\n",
    "            rms2 = torch.mean((orig_trajs[1] - orig_trajs_p[1])**2).item()**0.5*norm.ystd\n",
    "            print('RMS set1=',rms1)\n",
    "            print('RMS set2=',rms2)\n",
    "            rmslist.append((rms1,rms2))\n",
    "\n",
    "    rmslist = np.array(rmslist)\n",
    "\n",
    "    print()\n",
    "    print('###########################')\n",
    "    print('###########################')\n",
    "    if len(rmslist) == 0:\n",
    "        # np.min raises on an empty array; say what happened instead\n",
    "        print(f'no checkpoints found for {dataset}{suffix} in {train_dir}')\n",
    "    elif dataset=='CED':\n",
    "        print('RMS results CED tau_1' if tau_1 else 'RMS results CED')\n",
    "        print('set 1 min=',np.min(rmslist[:,0],axis=0), 'mean=',np.mean(rmslist[:,0],axis=0), 'len=',len(rmslist))\n",
    "        print('set 2 min=',np.min(rmslist[:,1],axis=0), 'mean=',np.mean(rmslist[:,1],axis=0), 'len=',len(rmslist))\n",
    "    elif dataset=='CCT':\n",
    "        print('RMS results CCT')\n",
    "        print('min=',np.min(rmslist,axis=0), 'mean=',np.mean(rmslist,axis=0), 'len=',len(rmslist))\n",
    "    print()\n",
    "    print('###########################')\n",
    "    print('###########################')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
