{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f617ecc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n",
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import scipy.io as sio\n",
    "import sklearn.metrics\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import time\n",
    "sys.path.insert(0, './third_party/pivae')\n",
    "import pivae_code.datasets, pivae_code.conv_pi_vae, pivae_code.pi_vae\n",
    "\n",
    "train_percent = 0.60\n",
    "valid_percent = 0.20\n",
    "test_percent = 0.20\n",
    "embed_dimension = 3\n",
    "batch_size = 200\n",
    "np.random.seed(2024)\n",
    "iterations = 30\n",
    "output_dimension = 3\n",
    "learning_rate = 0.001\n",
    "dur = 40"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8ff486f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_2D_to_3D(dataset_2D):\n",
    "    # data = train_set.neural.numpy()  # [time_bins, neurons]\n",
    "    time_bins, neurons = dataset_2D.shape\n",
    "    receptive_field_size = 10  # Total bins in receptive field\n",
    "    half_window = receptive_field_size // 2\n",
    "    dataset_3D = np.zeros((time_bins, neurons, receptive_field_size))\n",
    "    for t in range(time_bins):\n",
    "        for n in range(neurons):\n",
    "            # Calculate the indices for the receptive field window\n",
    "            start_idx = max(0, t - half_window)\n",
    "            end_idx = min(time_bins, t + half_window + 1)\n",
    "\n",
    "            # Slice the window for neuron n\n",
    "            window = dataset_2D[start_idx:end_idx, n]\n",
    "\n",
    "            # Calculate where to place the window in the receptive field dimension\n",
    "            # Adjusting indices to fit exactly within the receptive field slots\n",
    "            window_start = half_window - (t - start_idx)\n",
    "            window_end = window_start + (end_idx - start_idx)\n",
    "\n",
    "            # Ensure the window fits exactly into the new_data array\n",
    "            window_start = max(0, window_start)\n",
    "            window_end = min(receptive_field_size, window_end)\n",
    "            \n",
    "            dataset_3D[t, n, window_start:window_end] = window[:window_end - window_start]\n",
    "    return dataset_3D\n",
    "\n",
    "def to_batch_list(x, y, batch_size):\n",
    "    x = x.squeeze()\n",
    "    ### print(x.shape) ### (6885/1390/1903, 120, 10)\n",
    "    if len(x.shape) == 3:\n",
    "        x = x.transpose(0,2,1) \n",
    "        print(x.shape) ### (6885/1390/1903, 10, 120)\n",
    "    x_batch_list = np.array_split(x, int(len(x) / batch_size))\n",
    "    print(int(len(x) / batch_size)) ### 6885/1390/1903 divided by batch-size===34/6/9\n",
    "    y_batch_list = np.array_split(y, int(len(y) / batch_size))\n",
    "    return x_batch_list, y_batch_list\n",
    "\n",
    "def custom_data_generator(x_all, u_one_hot):\n",
    "    while True:\n",
    "        for ii in range(len(x_all)):\n",
    "            #print(x_all[ii].shape)\n",
    "            #print(u_one_hot[ii].shape)\n",
    "            yield ([x_all[ii], u_one_hot[ii]], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "122b7983",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mihili_20140303_neural_con_dis_index.mat\n",
      "(4960, 10, 52)\n",
      "24\n",
      "(1640, 10, 52)\n",
      "8\n",
      "(1720, 10, 52)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2024-09-28 23:21:37.059922: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1\n",
      "2024-09-28 23:21:37.060225: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1561] Found device 0 with properties: \n",
      "pciBusID: 0000:17:00.0 name: NVIDIA RTX A5000 computeCapability: 8.6\n",
      "coreClock: 1.695GHz coreCount: 64 deviceMemorySize: 23.68GiB deviceMemoryBandwidth: 715.34GiB/s\n",
      "2024-09-28 23:21:37.060378: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1561] Found device 1 with properties: \n",
      "pciBusID: 0000:65:00.0 name: NVIDIA RTX A5000 computeCapability: 8.6\n",
      "coreClock: 1.695GHz coreCount: 64 deviceMemorySize: 23.67GiB deviceMemoryBandwidth: 715.34GiB/s\n",
      "2024-09-28 23:21:37.060586: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcudart.so.10.1'; dlerror: libcudart.so.10.1: cannot open shared object file: No such file or directory\n",
      "2024-09-28 23:21:37.060633: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcublas.so.10'; dlerror: libcublas.so.10: cannot open shared object file: No such file or directory\n",
      "2024-09-28 23:21:37.060660: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10\n",
      "2024-09-28 23:21:37.060678: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so.10\n",
      "2024-09-28 23:21:37.060711: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcusolver.so.10'; dlerror: libcusolver.so.10: cannot open shared object file: No such file or directory\n",
      "2024-09-28 23:21:37.060744: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcusparse.so.10'; dlerror: libcusparse.so.10: cannot open shared object file: No such file or directory\n",
      "2024-09-28 23:21:37.060777: W tensorflow/stream_executor/platform/default/dso_loader.cc:55] Could not load dynamic library 'libcudnn.so.7'; dlerror: libcudnn.so.7: cannot open shared object file: No such file or directory\n",
      "2024-09-28 23:21:37.060781: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1598] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.\n",
      "Skipping registering GPU devices...\n",
      "2024-09-28 23:21:37.060924: I tensorflow/core/platform/cpu_feature_guard.cc:143] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 AVX512F FMA\n",
      "2024-09-28 23:21:37.073169: I tensorflow/core/platform/profile_utils/cpu_utils.cc:102] CPU Frequency: 4099895000 Hz\n",
      "2024-09-28 23:21:37.073707: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7c44e0000b70 initialized for platform Host (this does not guarantee that XLA will be used). Devices:\n",
      "2024-09-28 23:21:37.073738: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version\n",
      "2024-09-28 23:21:37.075808: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1102] Device interconnect StreamExecutor with strength 1 edge matrix:\n",
      "2024-09-28 23:21:37.075831: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1108]      \n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            (None, 10, 52)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_3 (InputLayer)            (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 26552       input_1[0][0]                    \n",
      "                                                                 input_3[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 52)       561073      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 587,625\n",
      "Trainable params: 587,625\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "24/24 [==============================] - 2s 73ms/step - loss: 340.8502 - val_loss: 280.5778\n",
      "Epoch 2/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 294.1462 - val_loss: 275.7045\n",
      "Epoch 3/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 289.2836 - val_loss: 272.2307\n",
      "Epoch 4/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 285.1106 - val_loss: 269.4312\n",
      "Epoch 5/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 282.2206 - val_loss: 268.2137\n",
      "Epoch 6/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 280.7293 - val_loss: 267.2141\n",
      "Epoch 7/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 279.6524 - val_loss: 266.9448\n",
      "Epoch 8/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 278.8973 - val_loss: 266.3832\n",
      "Epoch 9/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 278.5037 - val_loss: 265.9495\n",
      "Epoch 10/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 277.8974 - val_loss: 265.2920\n",
      "Epoch 11/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 277.4114 - val_loss: 265.1194\n",
      "Epoch 12/30\n",
      "24/24 [==============================] - 0s 21ms/step - loss: 277.1659 - val_loss: 264.6302\n",
      "Epoch 13/30\n",
      "24/24 [==============================] - 0s 20ms/step - loss: 276.4993 - val_loss: 264.1677\n",
      "Epoch 14/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 276.0829 - val_loss: 263.6472\n",
      "Epoch 15/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 275.5631 - val_loss: 262.9081\n",
      "Epoch 16/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 275.0049 - val_loss: 262.1812\n",
      "Epoch 17/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 274.5492 - val_loss: 261.6141\n",
      "Epoch 18/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 273.9175 - val_loss: 261.4175\n",
      "Epoch 19/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 273.3565 - val_loss: 261.0345\n",
      "Epoch 20/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 273.0234 - val_loss: 260.5263\n",
      "Epoch 21/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 272.8430 - val_loss: 260.1158\n",
      "Epoch 22/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 272.4196 - val_loss: 260.2961\n",
      "Epoch 23/30\n",
      "24/24 [==============================] - 0s 20ms/step - loss: 271.9117 - val_loss: 259.6956\n",
      "Epoch 24/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 271.6609 - val_loss: 259.3927\n",
      "Epoch 25/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 271.2548 - val_loss: 259.7564\n",
      "Epoch 26/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 271.1545 - val_loss: 258.8534\n",
      "Epoch 27/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 270.7125 - val_loss: 258.8961\n",
      "Epoch 28/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 270.5169 - val_loss: 258.8687\n",
      "Epoch 29/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 270.2190 - val_loss: 258.5956\n",
      "Epoch 30/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 270.0452 - val_loss: 258.6232\n",
      "80% Train LR= 0.001  r2-3D= 0.735  r2-2D= 0.6643\n",
      "20% Test  LR= 0.001  r2-3D= 0.727  r2-2D= 0.6643\n",
      "Chewie_20161007_neural_con_dis_index.mat\n",
      "(4000, 10, 70)\n",
      "20\n",
      "(1320, 10, 70)\n",
      "6\n",
      "(1400, 10, 70)\n",
      "7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_4 (InputLayer)            (None, 10, 70)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_6 (InputLayer)            (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 28856       input_4[0][0]                    \n",
      "                                                                 input_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 70)       1015168     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,044,024\n",
      "Trainable params: 1,044,024\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "20/20 [==============================] - 2s 79ms/step - loss: 509.1984 - val_loss: 456.5847\n",
      "Epoch 2/30\n",
      "20/20 [==============================] - 0s 24ms/step - loss: 459.5318 - val_loss: 447.8467\n",
      "Epoch 3/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 455.4457 - val_loss: 445.7658\n",
      "Epoch 4/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 453.9036 - val_loss: 443.0821\n",
      "Epoch 5/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 452.1447 - val_loss: 440.4364\n",
      "Epoch 6/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 448.2412 - val_loss: 436.1971\n",
      "Epoch 7/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 445.9436 - val_loss: 434.3904\n",
      "Epoch 8/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 444.3732 - val_loss: 435.3487\n",
      "Epoch 9/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 443.1647 - val_loss: 433.5012\n",
      "Epoch 10/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 442.3306 - val_loss: 433.0740\n",
      "Epoch 11/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 441.2107 - val_loss: 432.0532\n",
      "Epoch 12/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 440.1737 - val_loss: 431.0114\n",
      "Epoch 13/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 439.4467 - val_loss: 429.2687\n",
      "Epoch 14/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 438.2310 - val_loss: 428.3387\n",
      "Epoch 15/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 437.5628 - val_loss: 428.4729\n",
      "Epoch 16/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 437.1804 - val_loss: 427.7363\n",
      "Epoch 17/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 436.8860 - val_loss: 427.1593\n",
      "Epoch 18/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 436.5662 - val_loss: 426.7782\n",
      "Epoch 19/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 435.0294 - val_loss: 425.1551\n",
      "Epoch 20/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 434.3635 - val_loss: 424.3416\n",
      "Epoch 21/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 433.9106 - val_loss: 424.2174\n",
      "Epoch 22/30\n",
      "20/20 [==============================] - 0s 22ms/step - loss: 433.7300 - val_loss: 424.0701\n",
      "Epoch 23/30\n",
      "20/20 [==============================] - 0s 22ms/step - loss: 433.1719 - val_loss: 423.8669\n",
      "Epoch 24/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 432.9557 - val_loss: 423.9435\n",
      "Epoch 25/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 432.8063 - val_loss: 423.2924\n",
      "Epoch 26/30\n",
      "20/20 [==============================] - 0s 22ms/step - loss: 432.6429 - val_loss: 423.3448\n",
      "Epoch 27/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 433.1160 - val_loss: 423.9859\n",
      "Epoch 28/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 432.3389 - val_loss: 423.1404\n",
      "Epoch 29/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 431.8551 - val_loss: 422.8136\n",
      "Epoch 30/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 431.8900 - val_loss: 423.3041\n",
      "80% Train LR= 0.001  r2-3D= 0.607  r2-2D= 0.581\n",
      "20% Test  LR= 0.001  r2-3D= 0.598  r2-2D= 0.5615\n",
      "Mihili_20140306_neural_con_dis_index.mat\n",
      "(5200, 10, 43)\n",
      "26\n",
      "(1720, 10, 43)\n",
      "8\n",
      "(1760, 10, 43)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_7 (InputLayer)            (None, 10, 43)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_9 (InputLayer)            (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 25400       input_7[0][0]                    \n",
      "                                                                 input_9[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 43)       382100      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 407,500\n",
      "Trainable params: 407,500\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "26/26 [==============================] - 2s 68ms/step - loss: 211.9564 - val_loss: 143.7982\n",
      "Epoch 2/30\n",
      "26/26 [==============================] - 1s 25ms/step - loss: 144.9195 - val_loss: 137.2298\n",
      "Epoch 3/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 141.9686 - val_loss: 136.4602\n",
      "Epoch 4/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 141.5531 - val_loss: 136.2949\n",
      "Epoch 5/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 141.1989 - val_loss: 136.1590\n",
      "Epoch 6/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 140.7164 - val_loss: 135.8214\n",
      "Epoch 7/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 140.3506 - val_loss: 135.3839\n",
      "Epoch 8/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 139.7324 - val_loss: 134.7940\n",
      "Epoch 9/30\n",
      "26/26 [==============================] - 1s 26ms/step - loss: 139.2351 - val_loss: 134.2696\n",
      "Epoch 10/30\n",
      "26/26 [==============================] - 1s 20ms/step - loss: 138.9943 - val_loss: 133.9754\n",
      "Epoch 11/30\n",
      "26/26 [==============================] - 1s 22ms/step - loss: 138.8046 - val_loss: 133.7803\n",
      "Epoch 12/30\n",
      "26/26 [==============================] - 1s 21ms/step - loss: 138.6028 - val_loss: 133.7852\n",
      "Epoch 13/30\n",
      "26/26 [==============================] - 1s 21ms/step - loss: 138.4839 - val_loss: 133.6521\n",
      "Epoch 14/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 138.3508 - val_loss: 133.6066\n",
      "Epoch 15/30\n",
      "26/26 [==============================] - 1s 20ms/step - loss: 138.2143 - val_loss: 133.4815\n",
      "Epoch 16/30\n",
      "26/26 [==============================] - 0s 19ms/step - loss: 138.0705 - val_loss: 133.4369\n",
      "Epoch 17/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 137.9941 - val_loss: 133.1860\n",
      "Epoch 18/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 137.8440 - val_loss: 133.3446\n",
      "Epoch 19/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 137.6792 - val_loss: 133.2408\n",
      "Epoch 20/30\n",
      "26/26 [==============================] - 1s 27ms/step - loss: 137.5517 - val_loss: 133.1978\n",
      "Epoch 21/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 137.5024 - val_loss: 133.2937\n",
      "Epoch 22/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 137.3709 - val_loss: 133.2252\n",
      "Epoch 23/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 137.3156 - val_loss: 133.1683\n",
      "Epoch 24/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 137.2124 - val_loss: 133.1729\n",
      "Epoch 25/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 137.1635 - val_loss: 133.3441\n",
      "Epoch 26/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 137.1431 - val_loss: 133.2551\n",
      "Epoch 27/30\n",
      "26/26 [==============================] - 1s 26ms/step - loss: 137.0474 - val_loss: 133.2190\n",
      "Epoch 28/30\n",
      "26/26 [==============================] - 1s 28ms/step - loss: 137.0179 - val_loss: 133.3282\n",
      "Epoch 29/30\n",
      "26/26 [==============================] - 1s 26ms/step - loss: 136.9765 - val_loss: 133.4245\n",
      "Epoch 30/30\n",
      "26/26 [==============================] - 1s 29ms/step - loss: 136.9456 - val_loss: 133.5136\n",
      "80% Train LR= 0.001  r2-3D= 0.652  r2-2D= 0.6359\n",
      "20% Test  LR= 0.001  r2-3D= 0.666  r2-2D= 0.6526\n",
      "Mihili_20140218_neural_con_dis_index.mat\n",
      "(5400, 10, 38)\n",
      "27\n",
      "(1800, 10, 38)\n",
      "9\n",
      "(1800, 10, 38)\n",
      "9\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_10 (InputLayer)           (None, 10, 38)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_12 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 24760       input_10[0][0]                   \n",
      "                                                                 input_12[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 38)       300288      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 325,048\n",
      "Trainable params: 325,048\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "27/27 [==============================] - 2s 68ms/step - loss: 242.0309 - val_loss: 193.7683\n",
      "Epoch 2/30\n",
      "27/27 [==============================] - 1s 25ms/step - loss: 196.1057 - val_loss: 188.3077\n",
      "Epoch 3/30\n",
      "27/27 [==============================] - 1s 23ms/step - loss: 193.3108 - val_loss: 187.4796\n",
      "Epoch 4/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 192.3970 - val_loss: 186.3179\n",
      "Epoch 5/30\n",
      "27/27 [==============================] - 1s 27ms/step - loss: 190.9481 - val_loss: 183.6638\n",
      "Epoch 6/30\n",
      "27/27 [==============================] - 1s 19ms/step - loss: 188.0635 - val_loss: 181.2261\n",
      "Epoch 7/30\n",
      "27/27 [==============================] - 1s 20ms/step - loss: 185.9508 - val_loss: 180.0222\n",
      "Epoch 8/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 184.7092 - val_loss: 178.9177\n",
      "Epoch 9/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 183.9869 - val_loss: 178.7024\n",
      "Epoch 10/30\n",
      "27/27 [==============================] - 1s 19ms/step - loss: 183.5352 - val_loss: 178.3321\n",
      "Epoch 11/30\n",
      "27/27 [==============================] - 1s 20ms/step - loss: 182.9658 - val_loss: 178.4333\n",
      "Epoch 12/30\n",
      "27/27 [==============================] - 1s 21ms/step - loss: 182.7088 - val_loss: 178.2079\n",
      "Epoch 13/30\n",
      "27/27 [==============================] - 1s 19ms/step - loss: 182.3324 - val_loss: 178.1794\n",
      "Epoch 14/30\n",
      "27/27 [==============================] - 1s 23ms/step - loss: 182.0693 - val_loss: 178.0921\n",
      "Epoch 15/30\n",
      "27/27 [==============================] - 1s 25ms/step - loss: 181.7977 - val_loss: 177.4324\n",
      "Epoch 16/30\n",
      "27/27 [==============================] - 1s 25ms/step - loss: 181.4664 - val_loss: 177.1941\n",
      "Epoch 17/30\n",
      "27/27 [==============================] - 1s 20ms/step - loss: 181.0890 - val_loss: 176.8227\n",
      "Epoch 18/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 180.8486 - val_loss: 177.1811\n",
      "Epoch 19/30\n",
      "27/27 [==============================] - 1s 21ms/step - loss: 180.8372 - val_loss: 176.9361\n",
      "Epoch 20/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 180.4909 - val_loss: 176.7003\n",
      "Epoch 21/30\n",
      "27/27 [==============================] - 1s 28ms/step - loss: 180.4079 - val_loss: 176.5674\n",
      "Epoch 22/30\n",
      "27/27 [==============================] - 0s 18ms/step - loss: 180.2229 - val_loss: 176.5403\n",
      "Epoch 23/30\n",
      "27/27 [==============================] - 0s 18ms/step - loss: 180.0841 - val_loss: 176.1704\n",
      "Epoch 24/30\n",
      "27/27 [==============================] - 1s 19ms/step - loss: 179.9004 - val_loss: 176.0097\n",
      "Epoch 25/30\n",
      "27/27 [==============================] - 1s 19ms/step - loss: 179.7289 - val_loss: 176.2950\n",
      "Epoch 26/30\n",
      "27/27 [==============================] - 0s 18ms/step - loss: 179.7940 - val_loss: 176.4930\n",
      "Epoch 27/30\n",
      "27/27 [==============================] - 1s 22ms/step - loss: 179.7159 - val_loss: 176.1816\n",
      "Epoch 28/30\n",
      "27/27 [==============================] - 1s 24ms/step - loss: 179.4767 - val_loss: 175.8645\n",
      "Epoch 29/30\n",
      "27/27 [==============================] - 1s 35ms/step - loss: 179.3428 - val_loss: 175.6662\n",
      "Epoch 30/30\n",
      "27/27 [==============================] - 1s 37ms/step - loss: 179.1690 - val_loss: 175.8036\n",
      "80% Train LR= 0.001  r2-3D= 0.555  r2-2D= 0.5496\n",
      "20% Test  LR= 0.001  r2-3D= 0.547  r2-2D= 0.5444\n",
      "Chewie_20150319_neural_con_dis_index.mat\n",
      "(24600, 10, 72)\n",
      "123\n",
      "(8200, 10, 72)\n",
      "41\n",
      "(8240, 10, 72)\n",
      "41\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_13 (InputLayer)           (None, 10, 72)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_15 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 29112       input_13[0][0]                   \n",
      "                                                                 input_15[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 72)       1073873     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,102,985\n",
      "Trainable params: 1,102,985\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "123/123 [==============================] - 4s 34ms/step - loss: 485.4338 - val_loss: 492.6667\n",
      "Epoch 2/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 472.9664 - val_loss: 483.9709\n",
      "Epoch 3/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 466.7204 - val_loss: 482.2994\n",
      "Epoch 4/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 464.4972 - val_loss: 480.0728\n",
      "Epoch 5/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 463.1376 - val_loss: 479.3752\n",
      "Epoch 6/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 462.1404 - val_loss: 478.3223\n",
      "Epoch 7/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 461.7710 - val_loss: 477.7016\n",
      "Epoch 8/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 460.9703 - val_loss: 477.1130\n",
      "Epoch 9/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 460.4211 - val_loss: 476.7625\n",
      "Epoch 10/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 460.0913 - val_loss: 476.8478\n",
      "Epoch 11/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 459.7634 - val_loss: 476.2652\n",
      "Epoch 12/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 459.4853 - val_loss: 475.8038\n",
      "Epoch 13/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 458.8670 - val_loss: 475.6623\n",
      "Epoch 14/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 458.7979 - val_loss: 475.9157\n",
      "Epoch 15/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 458.4905 - val_loss: 475.7699\n",
      "Epoch 16/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.7056 - val_loss: 475.6452\n",
      "Epoch 17/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.8887 - val_loss: 476.2317\n",
      "Epoch 18/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.6960 - val_loss: 475.7155\n",
      "Epoch 19/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.5380 - val_loss: 475.5696\n",
      "Epoch 20/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.4196 - val_loss: 475.7440\n",
      "Epoch 21/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 458.0706 - val_loss: 475.8653\n",
      "Epoch 22/30\n",
      "123/123 [==============================] - 3s 26ms/step - loss: 458.1081 - val_loss: 476.0979\n",
      "Epoch 23/30\n",
      "123/123 [==============================] - 3s 24ms/step - loss: 458.0140 - val_loss: 475.5320\n",
      "Epoch 24/30\n",
      "123/123 [==============================] - 3s 24ms/step - loss: 458.1081 - val_loss: 476.5894\n",
      "Epoch 25/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 457.8933 - val_loss: 475.8920\n",
      "Epoch 26/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 457.8658 - val_loss: 476.2831\n",
      "Epoch 27/30\n",
      "123/123 [==============================] - 3s 25ms/step - loss: 457.5440 - val_loss: 476.2704\n",
      "Epoch 28/30\n",
      "123/123 [==============================] - 3s 25ms/step - loss: 457.5248 - val_loss: 475.2183\n",
      "Epoch 29/30\n",
      "123/123 [==============================] - 3s 22ms/step - loss: 457.4936 - val_loss: 475.4597\n",
      "Epoch 30/30\n",
      "123/123 [==============================] - 3s 23ms/step - loss: 457.6657 - val_loss: 475.6938\n",
      "80% Train LR= 0.001  r2-3D= 0.399  r2-2D= 0.1481\n",
      "20% Test  LR= 0.001  r2-3D= 0.395  r2-2D= 0.128\n",
      "Chewie_20150629_neural_con_dis_index.mat\n",
      "(4120, 10, 49)\n",
      "20\n",
      "(1360, 10, 49)\n",
      "6\n",
      "(1400, 10, 49)\n",
      "7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_16 (InputLayer)           (None, 10, 49)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_18 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 26168       input_16[0][0]                   \n",
      "                                                                 input_18[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 49)       496055      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 522,223\n",
      "Trainable params: 522,223\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "20/20 [==============================] - 2s 78ms/step - loss: 351.8375 - val_loss: 303.7609\n",
      "Epoch 2/30\n",
      "20/20 [==============================] - 1s 26ms/step - loss: 314.4512 - val_loss: 296.4455\n",
      "Epoch 3/30\n",
      "20/20 [==============================] - 1s 28ms/step - loss: 310.8125 - val_loss: 294.2620\n",
      "Epoch 4/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 309.3679 - val_loss: 292.9599\n",
      "Epoch 5/30\n",
      "20/20 [==============================] - 0s 21ms/step - loss: 308.3982 - val_loss: 292.3051\n",
      "Epoch 6/30\n",
      "20/20 [==============================] - 1s 26ms/step - loss: 307.6872 - val_loss: 291.3272\n",
      "Epoch 7/30\n",
      "20/20 [==============================] - 1s 28ms/step - loss: 307.1649 - val_loss: 291.3142\n",
      "Epoch 8/30\n",
      "20/20 [==============================] - 1s 29ms/step - loss: 306.6677 - val_loss: 290.8365\n",
      "Epoch 9/30\n",
      "20/20 [==============================] - 1s 31ms/step - loss: 306.2660 - val_loss: 290.4280\n",
      "Epoch 10/30\n",
      "20/20 [==============================] - 1s 30ms/step - loss: 305.8014 - val_loss: 289.9120\n",
      "Epoch 11/30\n",
      "20/20 [==============================] - 1s 29ms/step - loss: 305.3109 - val_loss: 289.3806\n",
      "Epoch 12/30\n",
      "20/20 [==============================] - 1s 31ms/step - loss: 304.6145 - val_loss: 288.4020\n",
      "Epoch 13/30\n",
      "20/20 [==============================] - 1s 29ms/step - loss: 304.0901 - val_loss: 287.7396\n",
      "Epoch 14/30\n",
      "20/20 [==============================] - 1s 31ms/step - loss: 303.7292 - val_loss: 287.0956\n",
      "Epoch 15/30\n",
      "20/20 [==============================] - 1s 27ms/step - loss: 303.1305 - val_loss: 286.8669\n",
      "Epoch 16/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 302.8523 - val_loss: 286.9388\n",
      "Epoch 17/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 302.4058 - val_loss: 286.7431\n",
      "Epoch 18/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 302.1807 - val_loss: 286.5066\n",
      "Epoch 19/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 302.0832 - val_loss: 286.1382\n",
      "Epoch 20/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 301.7479 - val_loss: 286.0337\n",
      "Epoch 21/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 301.4078 - val_loss: 285.6189\n",
      "Epoch 22/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 301.2691 - val_loss: 285.1928\n",
      "Epoch 23/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 301.1058 - val_loss: 284.5581\n",
      "Epoch 24/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 300.9583 - val_loss: 284.4868\n",
      "Epoch 25/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 301.0379 - val_loss: 284.3906\n",
      "Epoch 26/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 300.9930 - val_loss: 284.7173\n",
      "Epoch 27/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 301.0574 - val_loss: 284.4496\n",
      "Epoch 28/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 300.6570 - val_loss: 283.9000\n",
      "Epoch 29/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 299.9146 - val_loss: 283.7434\n",
      "Epoch 30/30\n",
      "20/20 [==============================] - 0s 19ms/step - loss: 299.5076 - val_loss: 283.8085\n",
      "80% Train LR= 0.001  r2-3D= 0.645  r2-2D= 0.6162\n",
      "20% Test  LR= 0.001  r2-3D= 0.602  r2-2D= 0.574\n",
      "Chewie_20161014_neural_con_dis_index.mat\n",
      "(17760, 10, 88)\n",
      "88\n",
      "(5920, 10, 88)\n",
      "29\n",
      "(5920, 10, 88)\n",
      "29\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_19 (InputLayer)           (None, 10, 88)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_21 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 31160       input_19[0][0]                   \n",
      "                                                                 input_21[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 88)       1602913     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,634,073\n",
      "Trainable params: 1,634,073\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "88/88 [==============================] - 4s 43ms/step - loss: 556.7380 - val_loss: 526.2131\n",
      "Epoch 2/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 531.8386 - val_loss: 517.2676\n",
      "Epoch 3/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 525.4936 - val_loss: 512.1588\n",
      "Epoch 4/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 522.0873 - val_loss: 510.1337\n",
      "Epoch 5/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 520.1547 - val_loss: 508.7078\n",
      "Epoch 6/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 518.8114 - val_loss: 508.4574\n",
      "Epoch 7/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 517.5212 - val_loss: 507.6764\n",
      "Epoch 8/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 516.3010 - val_loss: 507.0687\n",
      "Epoch 9/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 515.6039 - val_loss: 506.8064\n",
      "Epoch 10/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 514.9491 - val_loss: 506.7090\n",
      "Epoch 11/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 514.4758 - val_loss: 506.1352\n",
      "Epoch 12/30\n",
      "88/88 [==============================] - 2s 27ms/step - loss: 513.7352 - val_loss: 504.7100\n",
      "Epoch 13/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 513.1320 - val_loss: 505.4231\n",
      "Epoch 14/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 512.7445 - val_loss: 504.6395\n",
      "Epoch 15/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 512.5827 - val_loss: 504.8209\n",
      "Epoch 16/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 512.4214 - val_loss: 504.0807\n",
      "Epoch 17/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 511.9268 - val_loss: 503.7593\n",
      "Epoch 18/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 511.5275 - val_loss: 503.7663\n",
      "Epoch 19/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 511.3146 - val_loss: 503.3054\n",
      "Epoch 20/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 511.1490 - val_loss: 503.6688\n",
      "Epoch 21/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 510.8458 - val_loss: 503.6984\n",
      "Epoch 22/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 510.8995 - val_loss: 502.7043\n",
      "Epoch 23/30\n",
      "88/88 [==============================] - 3s 32ms/step - loss: 510.5918 - val_loss: 503.3678\n",
      "Epoch 24/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 510.3552 - val_loss: 503.4792\n",
      "Epoch 25/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 510.3380 - val_loss: 503.5833\n",
      "Epoch 26/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 510.0071 - val_loss: 502.8969\n",
      "Epoch 27/30\n",
      "88/88 [==============================] - 3s 29ms/step - loss: 509.8469 - val_loss: 503.4214\n",
      "Epoch 28/30\n",
      "88/88 [==============================] - 3s 32ms/step - loss: 509.7911 - val_loss: 502.7094\n",
      "Epoch 29/30\n",
      "88/88 [==============================] - 2s 28ms/step - loss: 509.4954 - val_loss: 501.8621\n",
      "Epoch 30/30\n",
      "88/88 [==============================] - 2s 27ms/step - loss: 509.5396 - val_loss: 502.1087\n",
      "80% Train LR= 0.001  r2-3D= 0.396  r2-2D= 0.2743\n",
      "20% Test  LR= 0.001  r2-3D= 0.372  r2-2D= 0.2504\n",
      "Chewie_20150313_neural_con_dis_index.mat\n",
      "(24880, 10, 86)\n",
      "124\n",
      "(8280, 10, 86)\n",
      "41\n",
      "(8360, 10, 86)\n",
      "41\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_22 (InputLayer)           (None, 10, 86)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_24 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 30904       input_22[0][0]                   \n",
      "                                                                 input_24[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 86)       1531008     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,561,912\n",
      "Trainable params: 1,561,912\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "124/124 [==============================] - 4s 36ms/step - loss: 537.4827 - val_loss: 522.8871\n",
      "Epoch 2/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 516.7570 - val_loss: 514.6785\n",
      "Epoch 3/30\n",
      "124/124 [==============================] - 4s 28ms/step - loss: 512.6414 - val_loss: 512.1890\n",
      "Epoch 4/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 510.8287 - val_loss: 510.7997\n",
      "Epoch 5/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 509.5867 - val_loss: 509.7428\n",
      "Epoch 6/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 508.5614 - val_loss: 509.1609\n",
      "Epoch 7/30\n",
      "124/124 [==============================] - 4s 29ms/step - loss: 507.7749 - val_loss: 508.2520\n",
      "Epoch 8/30\n",
      "124/124 [==============================] - 4s 30ms/step - loss: 507.2714 - val_loss: 507.7680\n",
      "Epoch 9/30\n",
      "124/124 [==============================] - 4s 28ms/step - loss: 506.7096 - val_loss: 507.1288\n",
      "Epoch 10/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 506.2629 - val_loss: 507.5726\n",
      "Epoch 11/30\n",
      "124/124 [==============================] - 3s 26ms/step - loss: 505.9142 - val_loss: 507.1282\n",
      "Epoch 12/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 505.6310 - val_loss: 506.8398\n",
      "Epoch 13/30\n",
      "124/124 [==============================] - 3s 26ms/step - loss: 505.3570 - val_loss: 506.4266\n",
      "Epoch 14/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 505.0716 - val_loss: 506.7587\n",
      "Epoch 15/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 504.8442 - val_loss: 506.1258\n",
      "Epoch 16/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 504.7211 - val_loss: 506.1875\n",
      "Epoch 17/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 504.4987 - val_loss: 505.5794\n",
      "Epoch 18/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 504.3258 - val_loss: 506.0383\n",
      "Epoch 19/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 504.2479 - val_loss: 505.5466\n",
      "Epoch 20/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 504.1031 - val_loss: 505.2625\n",
      "Epoch 21/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 503.9779 - val_loss: 505.5456\n",
      "Epoch 22/30\n",
      "124/124 [==============================] - 4s 30ms/step - loss: 503.7850 - val_loss: 505.8421\n",
      "Epoch 23/30\n",
      "124/124 [==============================] - 4s 29ms/step - loss: 503.6926 - val_loss: 504.4779\n",
      "Epoch 24/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 503.5515 - val_loss: 505.5856\n",
      "Epoch 25/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 503.4285 - val_loss: 504.7471\n",
      "Epoch 26/30\n",
      "124/124 [==============================] - 4s 29ms/step - loss: 503.3088 - val_loss: 504.6695\n",
      "Epoch 27/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 503.2376 - val_loss: 505.3358\n",
      "Epoch 28/30\n",
      "124/124 [==============================] - 3s 28ms/step - loss: 503.2374 - val_loss: 505.1407\n",
      "Epoch 29/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 503.1660 - val_loss: 504.4849\n",
      "Epoch 30/30\n",
      "124/124 [==============================] - 3s 27ms/step - loss: 503.0853 - val_loss: 504.8730\n",
      "80% Train LR= 0.001  r2-3D= 0.338  r2-2D= 0.134\n",
      "20% Test  LR= 0.001  r2-3D= 0.348  r2-2D= 0.1308\n",
      "Chewie_20161005_neural_con_dis_index.mat\n",
      "(4840, 10, 82)\n",
      "24\n",
      "(1600, 10, 82)\n",
      "8\n",
      "(1640, 10, 82)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_25 (InputLayer)           (None, 10, 82)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_27 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 30392       input_25[0][0]                   \n",
      "                                                                 input_27[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 82)       1392148     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,422,540\n",
      "Trainable params: 1,422,540\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "24/24 [==============================] - 2s 76ms/step - loss: 590.2281 - val_loss: 500.2172\n",
      "Epoch 2/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 519.6820 - val_loss: 486.5146\n",
      "Epoch 3/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 512.5834 - val_loss: 483.9651\n",
      "Epoch 4/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 510.4149 - val_loss: 482.7147\n",
      "Epoch 5/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 508.1341 - val_loss: 481.0869\n",
      "Epoch 6/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 505.2724 - val_loss: 479.4126\n",
      "Epoch 7/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 504.2005 - val_loss: 479.4449\n",
      "Epoch 8/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 503.0457 - val_loss: 478.4508\n",
      "Epoch 9/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 502.0969 - val_loss: 477.9217\n",
      "Epoch 10/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 501.1949 - val_loss: 477.3165\n",
      "Epoch 11/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 500.6868 - val_loss: 477.5689\n",
      "Epoch 12/30\n",
      "24/24 [==============================] - 1s 29ms/step - loss: 499.8111 - val_loss: 476.9132\n",
      "Epoch 13/30\n",
      "24/24 [==============================] - 1s 33ms/step - loss: 499.3260 - val_loss: 476.2262\n",
      "Epoch 14/30\n",
      "24/24 [==============================] - 1s 31ms/step - loss: 498.9091 - val_loss: 476.3788\n",
      "Epoch 15/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 497.8704 - val_loss: 474.7920\n",
      "Epoch 16/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 497.0328 - val_loss: 474.1286\n",
      "Epoch 17/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 496.0286 - val_loss: 473.5971\n",
      "Epoch 18/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 494.6025 - val_loss: 472.0879\n",
      "Epoch 19/30\n",
      "24/24 [==============================] - 1s 29ms/step - loss: 493.1187 - val_loss: 471.0836\n",
      "Epoch 20/30\n",
      "24/24 [==============================] - 1s 29ms/step - loss: 492.3573 - val_loss: 470.6121\n",
      "Epoch 21/30\n",
      "24/24 [==============================] - 1s 29ms/step - loss: 491.5685 - val_loss: 470.4684\n",
      "Epoch 22/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 490.9565 - val_loss: 470.1326\n",
      "Epoch 23/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 489.9256 - val_loss: 469.3407\n",
      "Epoch 24/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 488.3753 - val_loss: 467.0282\n",
      "Epoch 25/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 487.4357 - val_loss: 466.6822\n",
      "Epoch 26/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 486.4643 - val_loss: 465.5466\n",
      "Epoch 27/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 485.7255 - val_loss: 464.8790\n",
      "Epoch 28/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 485.2511 - val_loss: 463.9666\n",
      "Epoch 29/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 484.5987 - val_loss: 463.6575\n",
      "Epoch 30/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 483.9426 - val_loss: 462.7241\n",
      "80% Train LR= 0.001  r2-3D= 0.385  r2-2D= 0.3069\n",
      "20% Test  LR= 0.001  r2-3D= 0.352  r2-2D= 0.2735\n",
      "Chewie_20161021_neural_con_dis_index.mat\n",
      "(6840, 10, 84)\n",
      "34\n",
      "(2280, 10, 84)\n",
      "11\n",
      "(2320, 10, 84)\n",
      "11\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_28 (InputLayer)           (None, 10, 84)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_30 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 30648       input_28[0][0]                   \n",
      "                                                                 input_30[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 84)       1460753     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,491,401\n",
      "Trainable params: 1,491,401\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "34/34 [==============================] - 2s 59ms/step - loss: 579.8974 - val_loss: 545.9285\n",
      "Epoch 2/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 551.9757 - val_loss: 541.6871\n",
      "Epoch 3/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 547.7850 - val_loss: 538.9152\n",
      "Epoch 4/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 544.4451 - val_loss: 536.5101\n",
      "Epoch 5/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 540.3178 - val_loss: 532.3030\n",
      "Epoch 6/30\n",
      "34/34 [==============================] - 1s 26ms/step - loss: 535.4183 - val_loss: 530.3420\n",
      "Epoch 7/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 532.1387 - val_loss: 528.4607\n",
      "Epoch 8/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 529.8469 - val_loss: 527.9081\n",
      "Epoch 9/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 528.1393 - val_loss: 526.2427\n",
      "Epoch 10/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 526.5449 - val_loss: 526.1694\n",
      "Epoch 11/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 525.2602 - val_loss: 525.4552\n",
      "Epoch 12/30\n",
      "34/34 [==============================] - 1s 28ms/step - loss: 524.1948 - val_loss: 524.1669\n",
      "Epoch 13/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 523.1128 - val_loss: 522.7721\n",
      "Epoch 14/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 521.7453 - val_loss: 522.3577\n",
      "Epoch 15/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 520.8712 - val_loss: 522.2493\n",
      "Epoch 16/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 520.2436 - val_loss: 521.5012\n",
      "Epoch 17/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 519.6076 - val_loss: 521.3041\n",
      "Epoch 18/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 519.2846 - val_loss: 521.2408\n",
      "Epoch 19/30\n",
      "34/34 [==============================] - 1s 28ms/step - loss: 519.2664 - val_loss: 520.8884\n",
      "Epoch 20/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 518.7918 - val_loss: 521.1227\n",
      "Epoch 21/30\n",
      "34/34 [==============================] - 1s 34ms/step - loss: 518.5191 - val_loss: 519.6395\n",
      "Epoch 22/30\n",
      "34/34 [==============================] - 1s 34ms/step - loss: 518.4129 - val_loss: 519.4513\n",
      "Epoch 23/30\n",
      "34/34 [==============================] - 1s 33ms/step - loss: 517.5547 - val_loss: 519.3767\n",
      "Epoch 24/30\n",
      "34/34 [==============================] - 1s 34ms/step - loss: 517.1452 - val_loss: 519.7625\n",
      "Epoch 25/30\n",
      "34/34 [==============================] - 1s 34ms/step - loss: 516.8610 - val_loss: 519.4854\n",
      "Epoch 26/30\n",
      "34/34 [==============================] - 1s 33ms/step - loss: 516.6931 - val_loss: 518.6984\n",
      "Epoch 27/30\n",
      "34/34 [==============================] - 1s 33ms/step - loss: 516.4112 - val_loss: 519.3644\n",
      "Epoch 28/30\n",
      "34/34 [==============================] - 1s 33ms/step - loss: 516.4921 - val_loss: 519.4769\n",
      "Epoch 29/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 516.3753 - val_loss: 519.6166\n",
      "Epoch 30/30\n",
      "34/34 [==============================] - 1s 27ms/step - loss: 516.4316 - val_loss: 517.9043\n",
      "80% Train LR= 0.001  r2-3D= 0.436  r2-2D= 0.1929\n",
      "20% Test  LR= 0.001  r2-3D= 0.412  r2-2D= 0.1435\n",
      "Chewie_20161006_neural_con_dis_index.mat\n",
      "(5000, 10, 63)\n",
      "25\n",
      "(1640, 10, 63)\n",
      "8\n",
      "(1720, 10, 63)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_31 (InputLayer)           (None, 10, 63)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_33 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 27960       input_31[0][0]                   \n",
      "                                                                 input_33[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 63)       819700      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 847,660\n",
      "Trainable params: 847,660\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "25/25 [==============================] - 2s 69ms/step - loss: 448.5253 - val_loss: 389.1006\n",
      "Epoch 2/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 437.1294 - val_loss: 388.1016\n",
      "Epoch 3/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 436.1916 - val_loss: 388.8814\n",
      "Epoch 4/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 427.2653 - val_loss: 388.5619\n",
      "Epoch 5/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 423.8463 - val_loss: 387.0998\n",
      "Epoch 6/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 421.8775 - val_loss: 384.5617\n",
      "Epoch 7/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 424.5729 - val_loss: 385.8512\n",
      "Epoch 8/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 421.4153 - val_loss: 384.5440\n",
      "Epoch 9/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 421.5339 - val_loss: 384.2408\n",
      "Epoch 10/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 421.5277 - val_loss: 382.8440\n",
      "Epoch 11/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 420.8726 - val_loss: 382.8991\n",
      "Epoch 12/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 421.0507 - val_loss: 382.6378\n",
      "Epoch 13/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 422.9441 - val_loss: 382.5994\n",
      "Epoch 14/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 421.8952 - val_loss: 382.9541\n",
      "Epoch 15/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 422.0354 - val_loss: 381.0152\n",
      "Epoch 16/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 420.3804 - val_loss: 382.6252\n",
      "Epoch 17/30\n",
      "25/25 [==============================] - 1s 22ms/step - loss: 420.7468 - val_loss: 382.7586\n",
      "Epoch 18/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 418.9021 - val_loss: 382.5311\n",
      "Epoch 19/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 419.8858 - val_loss: 383.5986\n",
      "Epoch 20/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 416.1656 - val_loss: 381.9489\n",
      "Epoch 21/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 416.3785 - val_loss: 381.9070\n",
      "Epoch 22/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 414.9818 - val_loss: 381.7342\n",
      "Epoch 23/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 414.3063 - val_loss: 380.7409\n",
      "Epoch 24/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 414.2998 - val_loss: 380.0343\n",
      "Epoch 25/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 415.1585 - val_loss: 380.7875\n",
      "Epoch 26/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 412.7251 - val_loss: 379.8399\n",
      "Epoch 27/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 411.9490 - val_loss: 379.7953\n",
      "Epoch 28/30\n",
      "25/25 [==============================] - 1s 20ms/step - loss: 410.9298 - val_loss: 379.4107\n",
      "Epoch 29/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 409.3702 - val_loss: 378.0297\n",
      "Epoch 30/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 410.6149 - val_loss: 378.5804\n",
      "80% Train LR= 0.001  r2-3D= 0.424  r2-2D= 0.415\n",
      "20% Test  LR= 0.001  r2-3D= 0.419  r2-2D= 0.4093\n",
      "Chewie_20160929_neural_con_dis_index.mat\n",
      "(4960, 10, 74)\n",
      "24\n",
      "(1640, 10, 74)\n",
      "8\n",
      "(1720, 10, 74)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_34 (InputLayer)           (None, 10, 74)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_36 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 29368       input_34[0][0]                   \n",
      "                                                                 input_36[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 74)       1134228     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 1,163,596\n",
      "Trainable params: 1,163,596\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "24/24 [==============================] - 2s 74ms/step - loss: 492.5696 - val_loss: 444.0892\n",
      "Epoch 2/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 448.8489 - val_loss: 439.1057\n",
      "Epoch 3/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 445.5302 - val_loss: 436.9328\n",
      "Epoch 4/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 442.5430 - val_loss: 434.7142\n",
      "Epoch 5/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 440.2401 - val_loss: 434.8022\n",
      "Epoch 6/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 438.8594 - val_loss: 432.1127\n",
      "Epoch 7/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 437.2716 - val_loss: 431.6194\n",
      "Epoch 8/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 436.6185 - val_loss: 431.1040\n",
      "Epoch 9/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 435.2640 - val_loss: 430.0988\n",
      "Epoch 10/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 433.7394 - val_loss: 428.9130\n",
      "Epoch 11/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 432.2148 - val_loss: 427.3442\n",
      "Epoch 12/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 430.7244 - val_loss: 427.0627\n",
      "Epoch 13/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 429.7197 - val_loss: 426.2592\n",
      "Epoch 14/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 428.7883 - val_loss: 426.0439\n",
      "Epoch 15/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 427.9029 - val_loss: 426.1921\n",
      "Epoch 16/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 427.7752 - val_loss: 427.2111\n",
      "Epoch 17/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 427.7466 - val_loss: 423.9206\n",
      "Epoch 18/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 426.0385 - val_loss: 423.6142\n",
      "Epoch 19/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 425.3694 - val_loss: 422.2229\n",
      "Epoch 20/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 424.4623 - val_loss: 422.3896\n",
      "Epoch 21/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 423.7269 - val_loss: 422.2558\n",
      "Epoch 22/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 423.2851 - val_loss: 422.1916\n",
      "Epoch 23/30\n",
      "24/24 [==============================] - 1s 26ms/step - loss: 422.8055 - val_loss: 422.0358\n",
      "Epoch 24/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 422.2896 - val_loss: 421.3748\n",
      "Epoch 25/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 422.0616 - val_loss: 420.6380\n",
      "Epoch 26/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 421.3281 - val_loss: 420.5336\n",
      "Epoch 27/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 420.7381 - val_loss: 420.1704\n",
      "Epoch 28/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 420.3839 - val_loss: 419.7432\n",
      "Epoch 29/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 419.9710 - val_loss: 420.3567\n",
      "Epoch 30/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 420.0133 - val_loss: 420.1438\n",
      "80% Train LR= 0.001  r2-3D= 0.495  r2-2D= 0.3537\n",
      "20% Test  LR= 0.001  r2-3D= 0.453  r2-2D= 0.3016\n",
      "Chewie_20150630_neural_con_dis_index.mat\n",
      "(4120, 10, 44)\n",
      "20\n",
      "(1360, 10, 44)\n",
      "6\n",
      "(1400, 10, 44)\n",
      "7\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_37 (InputLayer)           (None, 10, 44)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_39 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 25528       input_37[0][0]                   \n",
      "                                                                 input_39[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 44)       402153      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 427,681\n",
      "Trainable params: 427,681\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "20/20 [==============================] - 2s 92ms/step - loss: 310.5220 - val_loss: 269.3843\n",
      "Epoch 2/30\n",
      "20/20 [==============================] - 0s 21ms/step - loss: 266.1642 - val_loss: 260.9313\n",
      "Epoch 3/30\n",
      "20/20 [==============================] - 0s 23ms/step - loss: 261.2746 - val_loss: 259.8181\n",
      "Epoch 4/30\n",
      "20/20 [==============================] - 0s 22ms/step - loss: 260.5398 - val_loss: 259.2640\n",
      "Epoch 5/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 259.8620 - val_loss: 257.9477\n",
      "Epoch 6/30\n",
      "20/20 [==============================] - 0s 19ms/step - loss: 258.7646 - val_loss: 257.2289\n",
      "Epoch 7/30\n",
      "20/20 [==============================] - 0s 20ms/step - loss: 257.3597 - val_loss: 255.3651\n",
      "Epoch 8/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 256.1404 - val_loss: 254.3174\n",
      "Epoch 9/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 255.1367 - val_loss: 253.1804\n",
      "Epoch 10/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 254.4164 - val_loss: 252.8547\n",
      "Epoch 11/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 253.7458 - val_loss: 251.9612\n",
      "Epoch 12/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 253.2028 - val_loss: 251.7955\n",
      "Epoch 13/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 252.4706 - val_loss: 251.1923\n",
      "Epoch 14/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 251.8516 - val_loss: 250.6353\n",
      "Epoch 15/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 251.4999 - val_loss: 250.3101\n",
      "Epoch 16/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 251.2583 - val_loss: 250.0399\n",
      "Epoch 17/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 250.8806 - val_loss: 249.8756\n",
      "Epoch 18/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 250.6845 - val_loss: 249.5065\n",
      "Epoch 19/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 250.2680 - val_loss: 249.0565\n",
      "Epoch 20/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 250.0615 - val_loss: 249.1259\n",
      "Epoch 21/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 249.9172 - val_loss: 248.7712\n",
      "Epoch 22/30\n",
      "20/20 [==============================] - 0s 20ms/step - loss: 249.7824 - val_loss: 248.1266\n",
      "Epoch 23/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 249.5136 - val_loss: 247.8418\n",
      "Epoch 24/30\n",
      "20/20 [==============================] - 0s 18ms/step - loss: 249.4636 - val_loss: 247.8079\n",
      "Epoch 25/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 249.1365 - val_loss: 247.7066\n",
      "Epoch 26/30\n",
      "20/20 [==============================] - 0s 17ms/step - loss: 248.8872 - val_loss: 247.1418\n",
      "Epoch 27/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 248.8126 - val_loss: 247.4650\n",
      "Epoch 28/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 248.7504 - val_loss: 247.3764\n",
      "Epoch 29/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 248.6360 - val_loss: 246.8428\n",
      "Epoch 30/30\n",
      "20/20 [==============================] - 0s 16ms/step - loss: 248.4276 - val_loss: 246.9264\n",
      "80% Train LR= 0.001  r2-3D= 0.604  r2-2D= 0.5358\n",
      "20% Test  LR= 0.001  r2-3D= 0.613  r2-2D= 0.5789\n",
      "Mihili_20140217_neural_con_dis_index.mat\n",
      "(4920, 10, 44)\n",
      "24\n",
      "(1640, 10, 44)\n",
      "8\n",
      "(1680, 10, 44)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_40 (InputLayer)           (None, 10, 44)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_42 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 25528       input_40[0][0]                   \n",
      "                                                                 input_42[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 44)       402153      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 427,681\n",
      "Trainable params: 427,681\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "24/24 [==============================] - 2s 68ms/step - loss: 272.7402 - val_loss: 215.1210\n",
      "Epoch 2/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 222.8957 - val_loss: 209.5319\n",
      "Epoch 3/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 220.0951 - val_loss: 207.8011\n",
      "Epoch 4/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 218.2039 - val_loss: 204.5340\n",
      "Epoch 5/30\n",
      "24/24 [==============================] - 1s 33ms/step - loss: 215.2683 - val_loss: 201.4487\n",
      "Epoch 6/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 212.9391 - val_loss: 200.0948\n",
      "Epoch 7/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 211.3824 - val_loss: 198.7865\n",
      "Epoch 8/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 210.0541 - val_loss: 198.1734\n",
      "Epoch 9/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 209.2828 - val_loss: 197.4572\n",
      "Epoch 10/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 208.5183 - val_loss: 197.1709\n",
      "Epoch 11/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 208.1105 - val_loss: 196.8407\n",
      "Epoch 12/30\n",
      "24/24 [==============================] - 0s 16ms/step - loss: 207.6892 - val_loss: 196.6110\n",
      "Epoch 13/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 207.2231 - val_loss: 195.9481\n",
      "Epoch 14/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 206.9325 - val_loss: 195.7384\n",
      "Epoch 15/30\n",
      "24/24 [==============================] - 1s 30ms/step - loss: 206.6739 - val_loss: 195.4615\n",
      "Epoch 16/30\n",
      "24/24 [==============================] - 1s 34ms/step - loss: 206.4923 - val_loss: 195.4052\n",
      "Epoch 17/30\n",
      "24/24 [==============================] - 1s 36ms/step - loss: 206.3034 - val_loss: 195.0299\n",
      "Epoch 18/30\n",
      "24/24 [==============================] - 1s 37ms/step - loss: 206.0487 - val_loss: 195.1214\n",
      "Epoch 19/30\n",
      "24/24 [==============================] - 1s 34ms/step - loss: 205.9431 - val_loss: 194.8073\n",
      "Epoch 20/30\n",
      "24/24 [==============================] - 1s 27ms/step - loss: 205.6799 - val_loss: 194.7020\n",
      "Epoch 21/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 205.4514 - val_loss: 194.5626\n",
      "Epoch 22/30\n",
      "24/24 [==============================] - 1s 31ms/step - loss: 205.3422 - val_loss: 194.6303\n",
      "Epoch 23/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 205.1348 - val_loss: 194.4320\n",
      "Epoch 24/30\n",
      "24/24 [==============================] - 1s 31ms/step - loss: 204.9138 - val_loss: 194.2096\n",
      "Epoch 25/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 204.7411 - val_loss: 194.1354\n",
      "Epoch 26/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 204.6653 - val_loss: 193.8747\n",
      "Epoch 27/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 204.4242 - val_loss: 193.1642\n",
      "Epoch 28/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 204.1743 - val_loss: 193.3156\n",
      "Epoch 29/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 204.0164 - val_loss: 193.0730\n",
      "Epoch 30/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 203.9268 - val_loss: 192.9541\n",
      "80% Train LR= 0.001  r2-3D= 0.602  r2-2D= 0.5915\n",
      "20% Test  LR= 0.001  r2-3D= 0.582  r2-2D= 0.5673\n",
      "Mihili_20140304_neural_con_dis_index.mat\n",
      "(4840, 10, 39)\n",
      "24\n",
      "(1600, 10, 39)\n",
      "8\n",
      "(1680, 10, 39)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_43 (InputLayer)           (None, 10, 39)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_45 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 24888       input_43[0][0]                   \n",
      "                                                                 input_45[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 39)       314380      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 339,268\n",
      "Trainable params: 339,268\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "24/24 [==============================] - 2s 70ms/step - loss: 295.0005 - val_loss: 260.3261\n",
      "Epoch 2/30\n",
      "24/24 [==============================] - 1s 24ms/step - loss: 262.9257 - val_loss: 254.4745\n",
      "Epoch 3/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 257.1887 - val_loss: 248.8632\n",
      "Epoch 4/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 253.5442 - val_loss: 247.3663\n",
      "Epoch 5/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 252.3719 - val_loss: 246.6754\n",
      "Epoch 6/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 251.7336 - val_loss: 246.1905\n",
      "Epoch 7/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 251.2301 - val_loss: 245.6424\n",
      "Epoch 8/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 250.7854 - val_loss: 245.1499\n",
      "Epoch 9/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 250.4150 - val_loss: 244.8848\n",
      "Epoch 10/30\n",
      "24/24 [==============================] - 0s 20ms/step - loss: 250.0546 - val_loss: 244.6226\n",
      "Epoch 11/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 249.6263 - val_loss: 243.9500\n",
      "Epoch 12/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 249.1790 - val_loss: 243.4462\n",
      "Epoch 13/30\n",
      "24/24 [==============================] - 0s 19ms/step - loss: 248.8107 - val_loss: 243.0948\n",
      "Epoch 14/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 248.3218 - val_loss: 242.8459\n",
      "Epoch 15/30\n",
      "24/24 [==============================] - 0s 17ms/step - loss: 247.8514 - val_loss: 242.4012\n",
      "Epoch 16/30\n",
      "24/24 [==============================] - 0s 16ms/step - loss: 247.5162 - val_loss: 242.3254\n",
      "Epoch 17/30\n",
      "24/24 [==============================] - 0s 16ms/step - loss: 247.1803 - val_loss: 241.9458\n",
      "Epoch 18/30\n",
      "24/24 [==============================] - 0s 16ms/step - loss: 246.7098 - val_loss: 242.3461\n",
      "Epoch 19/30\n",
      "24/24 [==============================] - 0s 16ms/step - loss: 246.5414 - val_loss: 242.4270\n",
      "Epoch 20/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 246.3116 - val_loss: 242.2014\n",
      "Epoch 21/30\n",
      "24/24 [==============================] - 1s 22ms/step - loss: 246.1534 - val_loss: 241.9883\n",
      "Epoch 22/30\n",
      "24/24 [==============================] - 0s 20ms/step - loss: 246.0417 - val_loss: 240.8995\n",
      "Epoch 23/30\n",
      "24/24 [==============================] - 1s 25ms/step - loss: 245.7173 - val_loss: 240.7796\n",
      "Epoch 24/30\n",
      "24/24 [==============================] - 1s 30ms/step - loss: 245.6306 - val_loss: 240.6922\n",
      "Epoch 25/30\n",
      "24/24 [==============================] - 1s 28ms/step - loss: 245.3887 - val_loss: 241.0239\n",
      "Epoch 26/30\n",
      "24/24 [==============================] - 1s 35ms/step - loss: 245.3570 - val_loss: 240.8828\n",
      "Epoch 27/30\n",
      "24/24 [==============================] - 0s 21ms/step - loss: 245.2439 - val_loss: 241.3835\n",
      "Epoch 28/30\n",
      "24/24 [==============================] - 1s 23ms/step - loss: 245.1238 - val_loss: 240.7741\n",
      "Epoch 29/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 244.9328 - val_loss: 240.9253\n",
      "Epoch 30/30\n",
      "24/24 [==============================] - 0s 18ms/step - loss: 244.7732 - val_loss: 240.7397\n",
      "80% Train LR= 0.001  r2-3D= 0.627  r2-2D= 0.5316\n",
      "20% Test  LR= 0.001  r2-3D= 0.606  r2-2D= 0.5242\n",
      "Mihili_20140307_neural_con_dis_index.mat\n",
      "(5160, 10, 26)\n",
      "25\n",
      "(1720, 10, 26)\n",
      "8\n",
      "(1760, 10, 26)\n",
      "8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_46 (InputLayer)           (None, 10, 26)       0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_48 (InputLayer)           (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 23224       input_46[0][0]                   \n",
      "                                                                 input_48[0][0]                   \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 26)       141108      encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 164,332\n",
      "Trainable params: 164,332\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/30\n",
      "25/25 [==============================] - 2s 73ms/step - loss: 186.7978 - val_loss: 162.4359\n",
      "Epoch 2/30\n",
      "25/25 [==============================] - 0s 19ms/step - loss: 158.5133 - val_loss: 157.9717\n",
      "Epoch 3/30\n",
      "25/25 [==============================] - 0s 15ms/step - loss: 156.2616 - val_loss: 156.4338\n",
      "Epoch 4/30\n",
      "25/25 [==============================] - 0s 16ms/step - loss: 155.1115 - val_loss: 155.7022\n",
      "Epoch 5/30\n",
      "25/25 [==============================] - 0s 16ms/step - loss: 153.5815 - val_loss: 153.5266\n",
      "Epoch 6/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 151.4808 - val_loss: 150.7382\n",
      "Epoch 7/30\n",
      "25/25 [==============================] - 0s 16ms/step - loss: 149.8157 - val_loss: 149.7992\n",
      "Epoch 8/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 148.5634 - val_loss: 148.2168\n",
      "Epoch 9/30\n",
      "25/25 [==============================] - 0s 18ms/step - loss: 147.6160 - val_loss: 147.5276\n",
      "Epoch 10/30\n",
      "25/25 [==============================] - 0s 18ms/step - loss: 147.2000 - val_loss: 147.1548\n",
      "Epoch 11/30\n",
      "25/25 [==============================] - 0s 18ms/step - loss: 146.8866 - val_loss: 146.7032\n",
      "Epoch 12/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 146.6785 - val_loss: 146.5787\n",
      "Epoch 13/30\n",
      "25/25 [==============================] - 1s 25ms/step - loss: 146.4858 - val_loss: 146.4477\n",
      "Epoch 14/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 146.3946 - val_loss: 146.2839\n",
      "Epoch 15/30\n",
      "25/25 [==============================] - 1s 26ms/step - loss: 146.2534 - val_loss: 146.3238\n",
      "Epoch 16/30\n",
      "25/25 [==============================] - 1s 33ms/step - loss: 146.1685 - val_loss: 145.9673\n",
      "Epoch 17/30\n",
      "25/25 [==============================] - 1s 32ms/step - loss: 146.0675 - val_loss: 145.9799\n",
      "Epoch 18/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 145.9334 - val_loss: 145.9897\n",
      "Epoch 19/30\n",
      "25/25 [==============================] - 1s 23ms/step - loss: 145.7138 - val_loss: 145.7462\n",
      "Epoch 20/30\n",
      "25/25 [==============================] - 1s 26ms/step - loss: 145.6317 - val_loss: 145.6696\n",
      "Epoch 21/30\n",
      "25/25 [==============================] - 1s 25ms/step - loss: 145.4354 - val_loss: 145.5676\n",
      "Epoch 22/30\n",
      "25/25 [==============================] - 1s 27ms/step - loss: 145.3202 - val_loss: 145.7143\n",
      "Epoch 23/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 145.1352 - val_loss: 145.5419\n",
      "Epoch 24/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 145.0611 - val_loss: 145.5199\n",
      "Epoch 25/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 144.9376 - val_loss: 145.6143\n",
      "Epoch 26/30\n",
      "25/25 [==============================] - 0s 16ms/step - loss: 144.7953 - val_loss: 145.5899\n",
      "Epoch 27/30\n",
      "25/25 [==============================] - 0s 16ms/step - loss: 144.7014 - val_loss: 145.3417\n",
      "Epoch 28/30\n",
      "25/25 [==============================] - 0s 20ms/step - loss: 144.6394 - val_loss: 145.1922\n",
      "Epoch 29/30\n",
      "25/25 [==============================] - 0s 17ms/step - loss: 144.5911 - val_loss: 145.2689\n",
      "Epoch 30/30\n",
      "25/25 [==============================] - 1s 21ms/step - loss: 144.5931 - val_loss: 145.2108\n",
      "80% Train LR= 0.001  r2-3D= 0.575  r2-2D= 0.3214\n",
      "20% Test  LR= 0.001  r2-3D= 0.581  r2-2D= 0.3268\n"
     ]
    }
   ],
   "source": [
    "# For each session: load M1 activity + labels, train a conv-piVAE, score linear decoding of XY\n",
    "# from its embedding, and save the embeddings, scores and loss curve.\n",
    "directory = './data/SU_16M1'\n",
    "emb_dir = './data/Fig2_SU/piv/emb_M1'\n",
    "loss_dir = os.path.join(emb_dir, 'loss')\n",
    "# np.savez / savefig do not create folders; without this a missing folder only shows up as ' fail'.\n",
    "os.makedirs(loss_dir, exist_ok=True)\n",
    "# Sorted for a reproducible session order (the global NumPy seed is shared across sessions);\n",
    "# only .mat files can be read by sio.loadmat.\n",
    "files = sorted(f for f in os.listdir(directory) if f.endswith('.mat'))\n",
    "# Reach angle (deg) -> 0..7; multiplied by 45 below, so angles land in [0, 360) and -180/180 coincide.\n",
    "angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}\n",
    "vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])\n",
    "XY_scale = 10  # multiplier applied to the X/Y label columns\n",
    "# Printed split tags follow the configured fractions (decoders are fit on the train split only).\n",
    "train_tag = f'{train_percent:.0%} Train'\n",
    "test_tag = f'{test_percent:.0%} Test '\n",
    "for file in files:\n",
    "    print(file)\n",
    "    mat_contents = sio.loadmat(os.path.join(directory, file))\n",
    "    neural = mat_contents['neural_M1']\n",
    "    discrete_index = 45*vectorized_map(mat_contents['discrete_index'])\n",
    "    # Label columns: scaled X, scaled Y, direction (degrees).\n",
    "    continuous_index = np.column_stack((mat_contents['continuous_index']*XY_scale, discrete_index))\n",
    "    N_bins, N_neurons = neural.shape\n",
    "    # Train/valid boundaries are rounded down to whole multiples of `dur` bins; test keeps the rest.\n",
    "    train_end = int(N_bins * train_percent) // dur * dur\n",
    "    valid_end = (train_end + int(N_bins * valid_percent)) // dur * dur\n",
    "    train_neural, Y_train = neural[:train_end, :], continuous_index[:train_end, :]\n",
    "    valid_neural, Y_valid = neural[train_end:valid_end, :], continuous_index[train_end:valid_end, :]\n",
    "    test_neural, Y_test = neural[valid_end:, :], continuous_index[valid_end:, :]\n",
    "    X_train = dataset_2D_to_3D(train_neural)\n",
    "    X_valid = dataset_2D_to_3D(valid_neural)\n",
    "    X_test = dataset_2D_to_3D(test_neural)\n",
    "\n",
    "    train_x, train_u = to_batch_list(X_train, Y_train, batch_size)\n",
    "    train_loader = custom_data_generator(train_x, train_u)\n",
    "    valid_x, valid_u = to_batch_list(X_valid, Y_valid, batch_size)\n",
    "    valid_loader = custom_data_generator(valid_x, valid_u)\n",
    "    # Test batches are only embedded with predict(), so no generator is built for them.\n",
    "    test_x, test_u = to_batch_list(X_test, Y_test, batch_size)\n",
    "    try:\n",
    "        conv_pivae = pivae_code.conv_pi_vae.conv_vae_mdl(\n",
    "            dim_x=N_neurons,\n",
    "            dim_z=embed_dimension,\n",
    "            dim_u=continuous_index.shape[1],  # X, Y, direction\n",
    "            time_window=10,  # NOTE(review): should match receptive_field_size in dataset_2D_to_3D\n",
    "            gen_nodes=60,\n",
    "            n_blk=2,\n",
    "            mdl='poisson',\n",
    "            disc=False,\n",
    "            learning_rate=learning_rate)\n",
    "        s_n = conv_pivae.fit_generator(\n",
    "            train_loader,\n",
    "            steps_per_epoch=len(train_x),\n",
    "            epochs=iterations,\n",
    "            verbose=1,\n",
    "            validation_data=valid_loader,\n",
    "            validation_steps=len(valid_x))\n",
    "\n",
    "        # execution_time covers embedding the train and test sets only, not training.\n",
    "        start_time = time.time()\n",
    "        outputs_train = conv_pivae.predict([np.concatenate(train_x), np.concatenate(train_u)])\n",
    "        outputs_test = conv_pivae.predict([np.concatenate(test_x), np.concatenate(test_u)])\n",
    "        execution_time = np.round(time.time() - start_time, 2)\n",
    "        # Outputs: post_mean, post_log_var, z_sample, fire_rate, lam_mean, lam_log_var, z_mean, z_log_var\n",
    "        cebra_veldir_train = outputs_train[0]\n",
    "        cebra_veldir_test = outputs_test[0]\n",
    "\n",
    "        # Fit the decoders on the training embedding only; reuse them unchanged on the test embedding.\n",
    "        y_train = Y_train[:, 0:2]\n",
    "        reg_3d = LinearRegression().fit(cebra_veldir_train, y_train)\n",
    "        vel_train_r2 = sklearn.metrics.r2_score(y_train, reg_3d.predict(cebra_veldir_train))\n",
    "        pca_2d = PCA(n_components=2).fit(cebra_veldir_train)\n",
    "        X_2d_train = pca_2d.transform(cebra_veldir_train)\n",
    "        reg_2d = LinearRegression().fit(X_2d_train, y_train)\n",
    "        vel_train_r2_pca = np.round(sklearn.metrics.r2_score(y_train, reg_2d.predict(X_2d_train)), 4)\n",
    "        print(train_tag + ' LR=', str(learning_rate),\n",
    "              ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))\n",
    "\n",
    "        y_test = Y_test[:, 0:2]\n",
    "        vel_test_r2 = sklearn.metrics.r2_score(y_test, reg_3d.predict(cebra_veldir_test))\n",
    "        X_2d_test = pca_2d.transform(cebra_veldir_test)\n",
    "        vel_test_r2_pca = np.round(sklearn.metrics.r2_score(y_test, reg_2d.predict(X_2d_test)), 4)\n",
    "        print(test_tag + ' LR=', str(learning_rate),\n",
    "              ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))\n",
    "\n",
    "        # NOTE(review): the '_80%train_' tag is kept so existing result files and their readers still\n",
    "        # match, even though the decoders are fit on the train_percent split.\n",
    "        new_filename = (file[:16] + '_LR_' + str(learning_rate)\n",
    "                        + '_iterations_' + str(iterations)\n",
    "                        + '_80%train_' + str(vel_train_r2_pca)\n",
    "                        + '_20%test_' + str(vel_test_r2_pca) + '.npz')\n",
    "        np.savez(os.path.join(emb_dir, new_filename),\n",
    "                 execution_time=execution_time,\n",
    "                 learning_rate=learning_rate,\n",
    "                 iterations=iterations,\n",
    "                 cebra_veldir_train=cebra_veldir_train,\n",
    "                 cebra_veldir_test=cebra_veldir_test,\n",
    "                 continuous_index_train=Y_train,\n",
    "                 continuous_index_test=Y_test,\n",
    "                 vel_train_r2=vel_train_r2,\n",
    "                 vel_test_r2=vel_test_r2,\n",
    "                 vel_train_r2_pca=vel_train_r2_pca,\n",
    "                 vel_test_r2_pca=vel_test_r2_pca)\n",
    "\n",
    "        val_loss = s_n.history['val_loss']\n",
    "        loss = np.array(s_n.history['loss'])\n",
    "        loss_stable = loss[-10:]  # last 10 epochs, summarized in the title\n",
    "        fig, ax = plt.subplots(figsize=(6,5))\n",
    "        try:\n",
    "            ax.plot(val_loss, c='deepskyblue', label='val-loss')\n",
    "            ax.plot(loss, c='blue', label='loss')\n",
    "            ax.spines['top'].set_visible(False)\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.set_xlabel('Iterations')\n",
    "            ax.set_ylabel('piVAE Loss')\n",
    "            ax.legend(bbox_to_anchor=(0.5,0.3), frameon=False)\n",
    "            ax.set_title('itr='+str(iterations)+' loss='+str(int(np.mean(loss_stable))))\n",
    "            fig.savefig(os.path.join(loss_dir, file[:16]+'_loss.pdf'))\n",
    "        finally:\n",
    "            plt.close(fig)  # release the figure even if saving fails\n",
    "    except Exception as e:\n",
    "        # Best effort per session: report the cause and continue with the next file.\n",
    "        print(' LR=', str(learning_rate), ' fail:', repr(e))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb1afb43",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
