{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import torch\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "############################################\n",
    "# Added for GluNet package\n",
    "############################################\n",
    "from trainer_glunet import LatentODEWrapper\n",
    "from eval_glunet import test\n",
    "sys.path.insert(1, '../..')\n",
    "os.chdir('../..')\n",
    "# utils for darts\n",
    "from utils.darts_dataset import *\n",
    "from utils.darts_processing import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--------------------------------\n",
      "Loading column definition...\n",
      "Checking column definition...\n",
      "Loading data...\n",
      "Dropping columns / rows...\n",
      "Checking for NA values...\n",
      "Setting data types...\n",
      "Dropping columns / rows...\n",
      "Encoding data...\n",
      "\tUpdated column definition:\n",
      "\t\tid: REAL_VALUED (ID)\n",
      "\t\ttime: DATE (TIME)\n",
      "\t\tgl: REAL_VALUED (TARGET)\n",
      "\t\tgender: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tage: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tBMI: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tglycaemia: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tHbA1c: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tfollow.up: REAL_VALUED (STATIC_INPUT)\n",
      "\t\tT2DM: REAL_VALUED (STATIC_INPUT)\n",
      "\t\ttime_year: REAL_VALUED (KNOWN_INPUT)\n",
      "\t\ttime_month: REAL_VALUED (KNOWN_INPUT)\n",
      "\t\ttime_day: REAL_VALUED (KNOWN_INPUT)\n",
      "\t\ttime_hour: REAL_VALUED (KNOWN_INPUT)\n",
      "\t\ttime_minute: REAL_VALUED (KNOWN_INPUT)\n",
      "Interpolating data...\n",
      "\tDropped segments: 63\n",
      "\tExtracted segments: 205\n",
      "\tInterpolated values: 241\n",
      "\tPercent of values interpolated: 0.22%\n",
      "Splitting data...\n",
      "\tTrain: 72275 (45.89%)\n",
      "\tVal: 35713 (22.68%)\n",
      "\tTest: 38253 (24.29%)\n",
      "\tTest OOD: 11242 (7.14%)\n",
      "Scaling data...\n",
      "\tNo scaling applied\n",
      "Data formatting complete.\n",
      "--------------------------------\n"
     ]
    }
   ],
   "source": [
    "# define data\n",
    "formatter, series, scalers = load_data(seed=0, \n",
    "                                       study_file=None, \n",
    "                                       dataset='colas', \n",
    "                                       use_covs=True, \n",
    "                                       cov_type='dual',\n",
    "                                       use_static_covs=True)\n",
    "\n",
    "dataset_train = SamplingDatasetDual(series['train']['target'],\n",
    "                                    series['train']['future'],\n",
    "                                    output_chunk_length=12,\n",
    "                                    input_chunk_length=48,\n",
    "                                    use_static_covariates=True,\n",
    "                                    max_samples_per_ts=100,\n",
    "                                    )\n",
    "dataset_val = SamplingDatasetDual(series['val']['target'],\n",
    "                                    series['val']['future'],   \n",
    "                                    output_chunk_length=12,\n",
    "                                    input_chunk_length=48,\n",
    "                                    use_static_covariates=True,)\n",
    "dataset_test = SamplingDatasetInferenceDual(target_series=series['test']['target'],\n",
    "                                            covariates=series['test']['future'],\n",
    "                                            input_chunk_length=48,\n",
    "                                            output_chunk_length=12,\n",
    "                                            use_static_covariates=True,\n",
    "                                            array_output_only=True)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_wrapper = LatentODEWrapper(device = 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/grads/m/mrsergazinov/.conda/envs/glunet/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:296: UserWarning: t is not on the same device as y0. Coercing to y0.device.\n",
      "  warnings.warn(\"t is not on the same device as y0. Coercing to y0.device.\")\n"
     ]
    }
   ],
   "source": [
    "# get torch summary writer and initialize at the directory: ./output/tensorboard_latentode_colas/\n",
    "logger = SummaryWriter(log_dir='./output/tensorboard_latentode_colas/')\n",
    "\n",
    "model_wrapper.fit(train_dataset=dataset_train,\n",
    "                    val_dataset=dataset_val,\n",
    "                    learning_rate= 1e-3,\n",
    "                    batch_size= 128,\n",
    "                    epochs= 2,\n",
    "                    num_samples= 10,\n",
    "                    device= 'cuda',\n",
    "                    model_path='./output/tensorboard_latentode_colas/model.ckpt',\n",
    "                    trial= None,\n",
    "                    logger= logger)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_wrapper.load(model_path='./output/tensorboard_latentode_colas/model.ckpt',\n",
    "                    device= 'cuda')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/grads/m/mrsergazinov/.conda/envs/glunet/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:296: UserWarning: t is not on the same device as y0. Coercing to y0.device.\n",
      "  warnings.warn(\"t is not on the same device as y0. Coercing to y0.device.\")\n"
     ]
    }
   ],
   "source": [
    "predictions = model_wrapper.predict(dataset_test,\n",
    "                                    num_samples=10,\n",
    "                                    batch_size=128,\n",
    "                                    device='cuda')\n",
    "trues = np.array([dataset_test.evalsample(i).values() for i in range(len(dataset_test))])\n",
    "trues = (trues - scalers['target'].min_) / scalers['target'].scale_\n",
    "obsrv_std = 0.01 / scalers['target'].scale_\n",
    "predictions = (predictions - scalers['target'].min_) / scalers['target'].scale_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "errors, _, _ = test(trues, predictions, obsrv_std)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([65.999405,  7.235167], dtype=float32)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.median(errors, axis=0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Legacy (below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/grads/m/mrsergazinov/.conda/envs/glunet/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:296: UserWarning: t is not on the same device as y0. Coercing to y0.device.\n",
      "  warnings.warn(\"t is not on the same device as y0. Coercing to y0.device.\")\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "dataloader_val = torch.utils.data.DataLoader(dataset_val, batch_size=128, shuffle=False)\n",
    "batch = iter(dataloader_val).next()\n",
    "device = 'cuda'\n",
    "inp_len, out_len = 48, 12\n",
    "batch_dict = {\n",
    "            'observed_data': batch[0].to(device),\n",
    "            'observed_tp': torch.arange(0, inp_len).to(device) / 12,\n",
    "            'data_to_predict': batch[-1].to(device),\n",
    "            'tp_to_predict': torch.arange(inp_len, inp_len+out_len).to(device) / 12,\n",
    "            'observed_mask': torch.ones(batch[0].shape).to(device),\n",
    "            'mask_predicted_data': None,\n",
    "            'labels': None,\n",
    "            'mode': 'extrap',\n",
    "\t    }\n",
    "pred_y, info = model_wrapper.model.get_reconstruction(batch_dict[\"tp_to_predict\"], \n",
    "                                                        batch_dict[\"observed_data\"], \n",
    "                                                        batch_dict[\"observed_tp\"], \n",
    "                                                        mask = batch_dict[\"observed_mask\"], \n",
    "                                                        n_traj_samples = 100,\n",
    "                                                        mode = batch_dict[\"mode\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred_y = pred_y.detach().cpu().numpy()\n",
    "y = batch_dict['data_to_predict'].detach().cpu().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, axs = plt.subplots(4, 4, figsize=(40, 20))\n",
    "for i in range(16):\n",
    "    ax = axs[i // 4, i % 4]\n",
    "    df = pd.DataFrame(pred_y[:, i, :, 0].transpose(1, 0))\n",
    "    df = pd.melt(df.reset_index(), id_vars=['index'], value_vars=df.columns)\n",
    "    df['type'] = 'pred'\n",
    "    df2 = pd.DataFrame(y[i, :, 0])\n",
    "    df2 = pd.melt(df2.reset_index(), id_vars=['index'], value_vars=df2.columns)\n",
    "    df2['type'] = 'true'\n",
    "    df = pd.concat([df, df2])\n",
    "    sns.lineplot(x=\"index\", y=\"value\", hue='type', data=df, ax=ax)\n",
    "fig.savefig(\"pred_vs_true.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define params\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "input_dim = 1\n",
    "classif_per_tp = False\n",
    "n_labels = 1\n",
    "class args_parser:\n",
    "    z0_encoder = 'odernn'\n",
    "    latents = 10\n",
    "    rec_dims = 20\n",
    "    rec_layers = 1\n",
    "    gen_layers = 1\n",
    "    units = 100\n",
    "    gru_units = 100\n",
    "    \n",
    "    extrap = True    \n",
    "    poisson = False\n",
    "    classif = False\n",
    "    linear_classif = False\n",
    "    dataset = 'glunet'\n",
    "args = args_parser()\n",
    "\n",
    "# define model\n",
    "obsrv_std = 0.01\n",
    "obsrv_std = torch.Tensor([obsrv_std]).to(device)\n",
    "z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))\n",
    "model = create_LatentODE_model(args, input_dim, z0_prior, obsrv_std, device, \n",
    "                                classif_per_tp = classif_per_tp,\n",
    "                                n_labels = n_labels,)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "./\n",
      "./\n",
      "./\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[1;32m/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb Cell 5\u001b[0m in \u001b[0;36m<cell line: 12>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=23'>24</a>\u001b[0m \tbatch_dict \u001b[39m=\u001b[39m {\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=24'>25</a>\u001b[0m \t\t\u001b[39m'\u001b[39m\u001b[39mobserved_data\u001b[39m\u001b[39m'\u001b[39m: batch[\u001b[39m0\u001b[39m]\u001b[39m.\u001b[39mto(device),\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=25'>26</a>\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mobserved_tp\u001b[39m\u001b[39m'\u001b[39m: torch\u001b[39m.\u001b[39marange(\u001b[39m0\u001b[39m, inp_len)\u001b[39m.\u001b[39mto(device) \u001b[39m/\u001b[39m \u001b[39m12\u001b[39m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=31'>32</a>\u001b[0m         \u001b[39m'\u001b[39m\u001b[39mmode\u001b[39m\u001b[39m'\u001b[39m: \u001b[39m'\u001b[39m\u001b[39mextrap\u001b[39m\u001b[39m'\u001b[39m\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=32'>33</a>\u001b[0m \t}\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=33'>34</a>\u001b[0m \ttrain_res \u001b[39m=\u001b[39m model\u001b[39m.\u001b[39mcompute_all_losses(batch_dict, n_traj_samples \u001b[39m=\u001b[39m \u001b[39m3\u001b[39m, kl_coef \u001b[39m=\u001b[39m kl_coef)\n\u001b[0;32m---> <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=34'>35</a>\u001b[0m \ttrain_res[\u001b[39m\"\u001b[39;49m\u001b[39mloss\u001b[39;49m\u001b[39m\"\u001b[39;49m]\u001b[39m.\u001b[39;49mbackward()\n\u001b[1;32m     <a href='vscode-notebook-cell://ssh-remote%2Becewkgsw223b01.engr.tamu.edu/home/grads/m/mrsergazinov/GluNet/lib/latent_ode/test.ipynb#W4sdnNjb2RlLXJlbW90ZQ%3D%3D?line=35'>36</a>\u001b[0m \toptimizer\u001b[39m.\u001b[39mstep()\n",
      "File \u001b[0;32m~/.conda/envs/glunet/lib/python3.10/site-packages/torch/_tensor.py:396\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m    387\u001b[0m \u001b[39mif\u001b[39;00m has_torch_function_unary(\u001b[39mself\u001b[39m):\n\u001b[1;32m    388\u001b[0m     \u001b[39mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m    389\u001b[0m         Tensor\u001b[39m.\u001b[39mbackward,\n\u001b[1;32m    390\u001b[0m         (\u001b[39mself\u001b[39m,),\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    394\u001b[0m         create_graph\u001b[39m=\u001b[39mcreate_graph,\n\u001b[1;32m    395\u001b[0m         inputs\u001b[39m=\u001b[39minputs)\n\u001b[0;32m--> 396\u001b[0m torch\u001b[39m.\u001b[39;49mautograd\u001b[39m.\u001b[39;49mbackward(\u001b[39mself\u001b[39;49m, gradient, retain_graph, create_graph, inputs\u001b[39m=\u001b[39;49minputs)\n",
      "File \u001b[0;32m~/.conda/envs/glunet/lib/python3.10/site-packages/torch/autograd/__init__.py:173\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m    168\u001b[0m     retain_graph \u001b[39m=\u001b[39m create_graph\n\u001b[1;32m    170\u001b[0m \u001b[39m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m    171\u001b[0m \u001b[39m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m    172\u001b[0m \u001b[39m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 173\u001b[0m Variable\u001b[39m.\u001b[39;49m_execution_engine\u001b[39m.\u001b[39;49mrun_backward(  \u001b[39m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m    174\u001b[0m     tensors, grad_tensors_, retain_graph, create_graph, inputs,\n\u001b[1;32m    175\u001b[0m     allow_unreachable\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m, accumulate_grad\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m)\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# log_path = \"logs/\" + file_name + \"_\" + str(experimentID) + \".log\"\n",
    "# if not os.path.exists(\"logs/\"):\n",
    "# \tutils.makedirs(\"logs/\")\n",
    "log_path = './output/latentode_colas.log'\n",
    "logger = latentode_utils.get_logger(logpath=log_path, filepath='./')\n",
    "\n",
    "optimizer = optim.Adamax(model.parameters(), lr=1e-3)\n",
    "\n",
    "num_batches = 100\n",
    "niters = 100\n",
    "\n",
    "for itr in range(1, num_batches * (niters + 1)):\n",
    "\toptimizer.zero_grad()\n",
    "\tlatentode_utils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = 1e-3 / 10)\n",
    "\n",
    "\twait_until_kl_inc = 10\n",
    "\tif itr // num_batches < wait_until_kl_inc:\n",
    "\t\tkl_coef = 0.\n",
    "\telse:\n",
    "\t\tkl_coef = (1-0.99** (itr // num_batches - wait_until_kl_inc))\n",
    "\n",
    "\tbatch = train_loader.__next__()\n",
    "\tinp_len, out_len = batch[0].shape[0], batch[0].shape[1], batch[-1].shape[1]\n",
    "\tbatch_dict = {\n",
    "\t\t'observed_data': batch[0].to(device),\n",
    "        'observed_tp': torch.arange(0, inp_len).to(device) / 12,\n",
    "        'data_to_predict': batch[-1].to(device),\n",
    "        'tp_to_predict': torch.arange(inp_len, inp_len+out_len).to(device) / 12,\n",
    "        'observed_mask': torch.ones(batch[0].shape).to(device),\n",
    "        'mask_predicted_data': None,\n",
    "        'labels': None,\n",
    "        'mode': 'extrap'\n",
    "\t}\n",
    "\ttrain_res = model.compute_all_losses(batch_dict, n_traj_samples = 3, kl_coef = kl_coef)\n",
    "\ttrain_res[\"loss\"].backward()\n",
    "\toptimizer.step()\n",
    "\n",
    "\tif itr % num_batches == 0:\n",
    "\t\tfor itr_val in range(1, 100):\n",
    "\t\t\tbatch_val = val_loader.__next__()\n",
    "\t\t\tinp_len, out_len = batch_val[0].shape[0], batch_val[0].shape[1], batch_val[-1].shape[1]\n",
    "\t\t\tbatch_dict_val = {\n",
    "\t\t\t\t'observed_data': batch_val[0].to(device),\n",
    "\t\t\t\t'observed_tp': torch.arange(0, inp_len).to(device) / 12,\n",
    "\t\t\t\t'data_to_predict': batch_val[-1].to(device),\n",
    "\t\t\t\t'tp_to_predict': torch.arange(inp_len, inp_len+out_len).to(device) / 12,\n",
    "\t\t\t\t'observed_mask': torch.ones(batch_val[0].shape).to(device),\n",
    "\t\t\t\t'mask_predicted_data': None,\n",
    "\t\t\t\t'labels': None,\n",
    "\t\t\t\t'mode': 'extrap'\n",
    "\t\t\t}\n",
    "\t\t\tval_res = model.compute_all_losses(batch_dict_val, n_traj_samples = 3, kl_coef = kl_coef)\n",
    "\t\t\tif itr_val == 1:\n",
    "\t\t\t\tloss_val = val_res[\"loss\"]\n",
    "\t\t\telse:\n",
    "\t\t\t\tloss_val += val_res[\"loss\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'loss': tensor(30.1813, device='cuda:0', grad_fn=<MeanBackward0>),\n",
       " 'likelihood': tensor(-31.3523, device='cuda:0'),\n",
       " 'mse': tensor(0.0070, device='cuda:0'),\n",
       " 'pois_likelihood': tensor(0., device='cuda:0'),\n",
       " 'ce_loss': tensor(0., device='cuda:0'),\n",
       " 'kl_first_p': tensor(3.2290, device='cuda:0'),\n",
       " 'std_first_p': tensor(0.0311, device='cuda:0')}"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "batch = val_loader.__next__()\n",
    "inp_len, out_len = batch[0].shape[1], batch[-1].shape[1]\n",
    "batch_dict = {\n",
    "    'observed_data': batch[0].to(device),\n",
    "    'observed_tp': torch.arange(0, inp_len).to(device) / 12,\n",
    "    'data_to_predict': batch[-1].to(device),\n",
    "    'tp_to_predict': torch.arange(inp_len, inp_len+out_len).to(device) / 12,\n",
    "    'observed_mask': torch.ones(batch[0].shape).to(device),\n",
    "    'mask_predicted_data': None,\n",
    "    'labels': None,\n",
    "    'mode': 'extrap'\n",
    "}\n",
    "\n",
    "pred_y, info = model.get_reconstruction(batch_dict[\"tp_to_predict\"], \n",
    "                                        batch_dict[\"observed_data\"], \n",
    "                                        batch_dict[\"observed_tp\"], \n",
    "                                        mask = batch_dict[\"observed_mask\"], \n",
    "                                        n_traj_samples = 100,\n",
    "                                        mode = batch_dict[\"mode\"])\n",
    "pred_y = pred_y.detach().cpu().numpy()\n",
    "y = batch_dict['data_to_predict'].detach().cpu().numpy()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_y.shape[1]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_loader = torch.utils.data.DataLoader(dataset_train, \n",
    "                                            batch_size=32,\n",
    "                                            shuffle=True,\n",
    "                                            drop_last=True)\n",
    "val_loader = torch.utils.data.DataLoader(dataset_val,\n",
    "                                         batch_size=32,\n",
    "                                         shuffle=True,\n",
    "                                         drop_last=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'observed_data', 'observed_tp', 'data_to_predict', 'tp_to_predict', 'observed_mask', 'mask_predicted_data', 'labels', 'mode'\n",
    "device = 'cuda'\n",
    "for batch in train_loader:\n",
    "    batch_dict = {\n",
    "        'observed_data': batch[0].to(device),\n",
    "        'observed_tp': torch.arange(0, batch[0].shape[1]).unsqueeze(0).repeat(batch[0].shape[0], 1).to(device) / 12,\n",
    "        'data_to_predict': batch[-1].to(device),\n",
    "        'tp_to_predict': torch.arange(batch[0].shape[1], batch[0].shape[1] + batch[-1].shape[1]).unsqueeze(0).repeat(batch[-1].shape[0], 1).to(device) / 12,\n",
    "        'observed_mask': torch.ones(batch[0].shape).to(device),\n",
    "        'mask_predicted_data': None,\n",
    "        'labels': None,\n",
    "        'mode': 'extrap'\n",
    "    }\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class args_parser:\n",
    "    dataset = 'periodic'\n",
    "    n = 100\n",
    "    niters = 300\n",
    "    lr = 1e-2\n",
    "    batch_size = 50\n",
    "    viz = True\n",
    "    save = 'experiments/'\n",
    "    load = None\n",
    "    random_seed = 1991\n",
    "    sample_tp = None\n",
    "    cut_tp = None\n",
    "    quantization = 0.1\n",
    "    latent_ode = True\n",
    "    z0_encoder = 'odernn'\n",
    "    latents = 10\n",
    "    rec_dims = 20\n",
    "    rec_layers = 1\n",
    "    gen_layers = 1\n",
    "    units = 100\n",
    "    gru_units = 100\n",
    "    \n",
    "    extrap = True    \n",
    "    poisson = False\n",
    "    classif = False\n",
    "    linear_classif = False\n",
    "    \n",
    "    timepoints = 100\n",
    "    max_t = 5.\n",
    "    noise_weight = 0.01\n",
    "\n",
    "args = args_parser()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create data\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "data_obj = parse_datasets(args, device)\n",
    "\n",
    "input_dim = data_obj[\"input_dim\"]\n",
    "classif_per_tp = False\n",
    "n_labels = 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dict = latentode_utils.get_next_batch(data_obj[\"train_dataloader\"])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([50])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "batch_dict['observed_tp'].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the model\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "obsrv_std = 0.01\n",
    "obsrv_std = torch.Tensor([obsrv_std]).to(device)\n",
    "z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))\n",
    "\n",
    "model = create_LatentODE_model(args, 1, z0_prior, obsrv_std, device, \n",
    "                                classif_per_tp = False,\n",
    "                                n_labels = 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt_path = '/home/grads/m/mrsergazinov/latent_ode/experiments/experiment_30327.ckpt'\n",
    "utils.get_ckpt_model(ckpt_path, model, device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dict = utils.get_next_batch(data_obj[\"train_dataloader\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dict.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_dict['observed_data'].shape "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_style(\"whitegrid\")\n",
    "\n",
    "batch_dict = utils.get_next_batch(data_obj[\"train_dataloader\"])\n",
    "\n",
    "pred_y, info = model.get_reconstruction(batch_dict[\"tp_to_predict\"], \n",
    "                                        batch_dict[\"observed_data\"], \n",
    "                                        batch_dict[\"observed_tp\"], \n",
    "                                        mask = batch_dict[\"observed_mask\"], \n",
    "                                        n_traj_samples = 100,\n",
    "                                        mode = batch_dict[\"mode\"])\n",
    "pred_y = pred_y.detach().cpu().numpy()\n",
    "y = batch_dict['data_to_predict'].detach().cpu().numpy()\n",
    "\n",
    "fig, axs = plt.subplots(5, pred_y.shape[1] // 5, figsize=(20, 10))\n",
    "for i in range(pred_y.shape[1]):\n",
    "    ax = axs[i % 5, i // 5]\n",
    "    df = pd.DataFrame(pred_y[:, i, :, 0].transpose(1, 0))\n",
    "    df = pd.melt(df.reset_index(), id_vars=['index'], value_vars=df.columns)\n",
    "    df['type'] = 'pred'\n",
    "    df2 = pd.DataFrame(y[i, :, 0])\n",
    "    df2 = pd.melt(df2.reset_index(), id_vars=['index'], value_vars=df2.columns)\n",
    "    df2['type'] = 'true'\n",
    "    df = pd.concat([df, df2])\n",
    "    sns.lineplot(x=\"index\", y=\"value\", hue='type', data=df, ax=ax)\n",
    "fig.savefig(\"pred_vs_true.png\")\n",
    "\n",
    "# df_pred_y_batch0 = pd.DataFrame(pred_y[:, 0, :, 0].transpose(1, 0))\n",
    "# df_pred_y_batch0 = pd.melt(df_pred_y_batch0.reset_index(), id_vars=['index'], value_vars=df_pred_y_batch0.columns)\n",
    "\n",
    "# # plot the data\n",
    "# fig, ax = plt.subplots(figsize=(10, 6))\n",
    "# sns.lineplot(x=\"index\", y=\"value\", data=df_pred_y_batch0, ax=ax)\n",
    "# # save the figure\n",
    "# fig.savefig(\"pred_y_batch0.png\")\n",
    "\n",
    "# # plot batch_dict['data_to_predict'][0, :, 0].detach().cpu().numpy() and save\n",
    "# df_y_batch0 = pd.DataFrame()\n",
    "# df_y_batch0[\"value\"] = batch_dict['data_to_predict'][0, :, 0].detach().cpu().numpy()\n",
    "# df_y_batch0[\"index\"] = np.arange(df_y_batch0.shape[0])\n",
    "# fig, ax = plt.subplots(figsize=(10, 6))\n",
    "# sns.lineplot(x=\"index\", y=\"value\", data=df_y_batch0, ax=ax)\n",
    "# fig.savefig(\"y_batch0.png\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_path = \"logs/\" + file_name + \"_\" + str(experimentID) + \".log\"\n",
    "if not os.path.exists(\"logs/\"):\n",
    "\tutils.makedirs(\"logs/\")\n",
    "logger = utils.get_logger(logpath=log_path, filepath='./')\n",
    "\n",
    "optimizer = optim.Adamax(model.parameters(), lr=args.lr)\n",
    "\n",
    "num_batches = data_obj[\"n_train_batches\"]\n",
    "\n",
    "for itr in range(1, num_batches * (args.niters + 1)):\n",
    "\toptimizer.zero_grad()\n",
    "\tutils.update_learning_rate(optimizer, decay_rate = 0.999, lowest = args.lr / 10)\n",
    "\n",
    "\twait_until_kl_inc = 10\n",
    "\tif itr // num_batches < wait_until_kl_inc:\n",
    "\t\tkl_coef = 0.\n",
    "\telse:\n",
    "\t\tkl_coef = (1-0.99** (itr // num_batches - wait_until_kl_inc))\n",
    "\n",
    "\tbatch_dict = utils.get_next_batch(data_obj[\"train_dataloader\"])\n",
    "\ttrain_res = model.compute_all_losses(batch_dict, n_traj_samples = 3, kl_coef = kl_coef)\n",
    "\ttrain_res[\"loss\"].backward()\n",
    "\toptimizer.step()\n",
    "\n",
    "\tn_iters_to_viz = 1\n",
    "\tif itr % (n_iters_to_viz * num_batches) == 0:\n",
    "\t\twith torch.no_grad():\n",
    "\n",
    "\t\t\ttest_res = compute_loss_all_batches(model, \n",
    "\t\t\t\tdata_obj[\"test_dataloader\"], args,\n",
    "\t\t\t\tn_batches = data_obj[\"n_test_batches\"],\n",
    "\t\t\t\texperimentID = experimentID,\n",
    "\t\t\t\tdevice = device,\n",
    "\t\t\t\tn_traj_samples = 3, kl_coef = kl_coef)\n",
    "\n",
    "\t\t\tmessage = 'Epoch {:04d} [Test seq (cond on sampled tp)] | Loss {:.6f} | Likelihood {:.6f} | KL fp {:.4f} | FP STD {:.4f}|'.format(\n",
    "\t\t\t\titr//num_batches, \n",
    "\t\t\t\ttest_res[\"loss\"].detach(), test_res[\"likelihood\"].detach(), \n",
    "\t\t\t\ttest_res[\"kl_first_p\"], test_res[\"std_first_p\"])\n",
    "\t\t\n",
    "\t\t\tlogger.info(\"Experiment \" + str(experimentID))\n",
    "\t\t\tlogger.info(message)\n",
    "\t\t\tlogger.info(\"KL coef: {}\".format(kl_coef))\n",
    "\t\t\tlogger.info(\"Train loss (one batch): {}\".format(train_res[\"loss\"].detach()))\n",
    "\t\t\tlogger.info(\"Train CE loss (one batch): {}\".format(train_res[\"ce_loss\"].detach()))\n",
    "\t\t\t\n",
    "\t\t\tif \"auc\" in test_res:\n",
    "\t\t\t\tlogger.info(\"Classification AUC (TEST): {:.4f}\".format(test_res[\"auc\"]))\n",
    "\n",
    "\t\t\tif \"mse\" in test_res:\n",
    "\t\t\t\tlogger.info(\"Test MSE: {:.4f}\".format(test_res[\"mse\"]))\n",
    "\n",
    "\t\t\tif \"accuracy\" in train_res:\n",
    "\t\t\t\tlogger.info(\"Classification accuracy (TRAIN): {:.4f}\".format(train_res[\"accuracy\"]))\n",
    "\n",
    "\t\t\tif \"accuracy\" in test_res:\n",
    "\t\t\t\tlogger.info(\"Classification accuracy (TEST): {:.4f}\".format(test_res[\"accuracy\"]))\n",
    "\n",
    "\t\t\tif \"pois_likelihood\" in test_res:\n",
    "\t\t\t\tlogger.info(\"Poisson likelihood: {}\".format(test_res[\"pois_likelihood\"]))\n",
    "\n",
    "\t\t\tif \"ce_loss\" in test_res:\n",
    "\t\t\t\tlogger.info(\"CE loss: {}\".format(test_res[\"ce_loss\"]))\n",
    "\n",
    "\t\ttorch.save({\n",
    "\t\t\t'args': args,\n",
    "\t\t\t'state_dict': model.state_dict(),\n",
    "\t\t}, ckpt_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
