{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f617ecc5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import scipy.io as sio\n",
    "import sklearn.metrics\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import time\n",
    "sys.path.insert(0, './third_party/pivae')\n",
    "import pivae_code.datasets, pivae_code.conv_pi_vae, pivae_code.pi_vae\n",
    "\n",
    "train_percent = 0.60\n",
    "valid_percent = 0.20\n",
    "test_percent = 0.20\n",
    "embed_dimension = 3\n",
    "batch_size = 200\n",
    "np.random.seed(2024)\n",
    "iterations = 5\n",
    "output_dimension = 3\n",
    "learning_rate = 0.001\n",
    "dur = 45"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8ff486f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "def dataset_2D_to_3D(dataset_2D):\n",
    "    # data = train_set.neural.numpy()  # [time_bins, neurons]\n",
    "    time_bins, neurons = dataset_2D.shape\n",
    "    receptive_field_size = 10  # Total bins in receptive field\n",
    "    half_window = receptive_field_size // 2\n",
    "    dataset_3D = np.zeros((time_bins, neurons, receptive_field_size))\n",
    "    for t in range(time_bins):\n",
    "        for n in range(neurons):\n",
    "            # Calculate the indices for the receptive field window\n",
    "            start_idx = max(0, t - half_window)\n",
    "            end_idx = min(time_bins, t + half_window + 1)\n",
    "\n",
    "            # Slice the window for neuron n\n",
    "            window = dataset_2D[start_idx:end_idx, n]\n",
    "\n",
    "            # Calculate where to place the window in the receptive field dimension\n",
    "            # Adjusting indices to fit exactly within the receptive field slots\n",
    "            window_start = half_window - (t - start_idx)\n",
    "            window_end = window_start + (end_idx - start_idx)\n",
    "\n",
    "            # Ensure the window fits exactly into the new_data array\n",
    "            window_start = max(0, window_start)\n",
    "            window_end = min(receptive_field_size, window_end)\n",
    "            \n",
    "            dataset_3D[t, n, window_start:window_end] = window[:window_end - window_start]\n",
    "    return dataset_3D\n",
    "\n",
    "def to_batch_list(x, y, batch_size):\n",
    "    x = x.squeeze()\n",
    "    ### print(x.shape) ### (6885/1390/1903, 120, 10)\n",
    "    if len(x.shape) == 3:\n",
    "        x = x.transpose(0,2,1) \n",
    "        print(x.shape) ### (6885/1390/1903, 10, 120)\n",
    "    x_batch_list = np.array_split(x, int(len(x) / batch_size))\n",
    "    print(int(len(x) / batch_size)) ### 6885/1390/1903 divided by batch-size===34/6/9\n",
    "    y_batch_list = np.array_split(y, int(len(y) / batch_size))\n",
    "    return x_batch_list, y_batch_list\n",
    "\n",
    "def custom_data_generator(x_all, u_one_hot):\n",
    "    while True:\n",
    "        for ii in range(len(x_all)):\n",
    "            #print(x_all[ii].shape)\n",
    "            #print(u_one_hot[ii].shape)\n",
    "            yield ([x_all[ii], u_one_hot[ii]], None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "122b7983",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "indy_20161220_02_ready.mat\n",
      "continuous_index= [[11.66075943  2.54052289 71.56505118]\n",
      " [14.04047975  2.13837654 71.56505118]\n",
      " [16.41364575  2.20493911 71.56505118]\n",
      " [19.20918971  3.3448685  71.56505118]\n",
      " [23.33711103  6.17369316 71.56505118]\n",
      " [29.22037201 11.2746592  71.56505118]\n",
      " [36.83051776 19.44177959 71.56505118]\n",
      " [45.80602526 31.9581189  71.56505118]\n",
      " [55.18370808 49.81638144 71.56505118]\n",
      " [63.9002107  72.94612531 71.56505118]]\n",
      "(8910, 10, 119)\n",
      "44\n",
      "(2970, 10, 119)\n",
      "14\n",
      "(3003, 10, 119)\n",
      "15\n",
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_4 (InputLayer)            (None, 10, 119)      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_6 (InputLayer)            (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 35128       input_4[0][0]                    \n",
      "                                                                 input_6[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 119)      2922780     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 2,957,908\n",
      "Trainable params: 2,957,908\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/5\n",
      "44/44 [==============================] - 3s 73ms/step - loss: 343.4109 - val_loss: 236.4885\n",
      "Epoch 2/5\n",
      "44/44 [==============================] - 2s 48ms/step - loss: 255.5410 - val_loss: 229.7783\n",
      "Epoch 3/5\n",
      "44/44 [==============================] - 2s 48ms/step - loss: 254.2380 - val_loss: 230.1168\n",
      "Epoch 4/5\n",
      "44/44 [==============================] - 2s 48ms/step - loss: 253.8972 - val_loss: 229.1396\n",
      "Epoch 5/5\n",
      "44/44 [==============================] - 2s 48ms/step - loss: 253.9861 - val_loss: 228.4721\n",
      "loss_stable= [343.60386855 255.54777331 254.24458079 253.9032991  253.99187343]\n",
      "80% Train LR= 0.001  r2-3D= 0.104  r2-2D= 0.0974\n",
      "20% Test  LR= 0.001  r2-3D= 0.095  r2-2D= 0.0907\n",
      "indy_20170131_02_ready.mat\n",
      "continuous_index= [[-9.46054407e-01  3.71722887e-01  2.84036243e+02]\n",
      " [-8.61403257e-01  3.46002116e-01  2.84036243e+02]\n",
      " [-9.42632443e-01  7.73821675e-02  2.84036243e+02]\n",
      " [-1.62416779e+00 -2.56842477e-01  2.84036243e+02]\n",
      " [-2.98887250e+00 -6.91058258e-01  2.84036243e+02]\n",
      " [-4.63720760e+00 -1.62085769e+00  2.84036243e+02]\n",
      " [-5.86930479e+00 -3.88055498e+00  2.84036243e+02]\n",
      " [-5.99669946e+00 -8.53438675e+00  2.84036243e+02]\n",
      " [-4.07709830e+00 -1.68086381e+01  2.84036243e+02]\n",
      " [ 1.34339092e+00 -3.02153888e+01  2.84036243e+02]]\n",
      "(13500, 10, 142)\n",
      "67\n",
      "(4500, 10, 142)\n",
      "22\n",
      "(4544, 10, 142)\n",
      "22\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output encoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to encoder.\n",
      "  warnings.warn(\n",
      "/home/marmoset/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_utils.py:816: UserWarning: Output decoder missing from loss dictionary. We assume this was done on purpose. The fit and evaluate APIs will not be expecting any data to be passed to decoder.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"vae\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_7 (InputLayer)            (None, 10, 142)      0                                            \n",
      "__________________________________________________________________________________________________\n",
      "input_9 (InputLayer)            (None, 3)            0                                            \n",
      "__________________________________________________________________________________________________\n",
      "encoder (Model)                 [(None, 3), (None, 3 38072       input_7[0][0]                    \n",
      "                                                                 input_9[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "decoder (Model)                 (None, 10, 142)      4168048     encoder[1][2]                    \n",
      "==================================================================================================\n",
      "Total params: 4,206,120\n",
      "Trainable params: 4,206,120\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n",
      "None\n",
      "Epoch 1/5\n",
      "67/67 [==============================] - 5s 81ms/step - loss: 436.7980 - val_loss: 317.0359\n",
      "Epoch 2/5\n",
      "67/67 [==============================] - 4s 65ms/step - loss: 354.7268 - val_loss: 315.9783\n",
      "Epoch 3/5\n",
      "67/67 [==============================] - 5s 68ms/step - loss: 354.5506 - val_loss: 315.9163\n",
      "Epoch 4/5\n",
      "67/67 [==============================] - 4s 66ms/step - loss: 353.5831 - val_loss: 315.7714\n",
      "Epoch 5/5\n",
      " 2/67 [..............................] - ETA: 4s - loss: 387.0146"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[6], line 58\u001b[0m\n\u001b[1;32m     47\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:    \n\u001b[1;32m     48\u001b[0m     conv_pivae \u001b[38;5;241m=\u001b[39m pivae_code\u001b[38;5;241m.\u001b[39mconv_pi_vae\u001b[38;5;241m.\u001b[39mconv_vae_mdl(\n\u001b[1;32m     49\u001b[0m         dim_x \u001b[38;5;241m=\u001b[39m N_neurons,\n\u001b[1;32m     50\u001b[0m         dim_z \u001b[38;5;241m=\u001b[39m embed_dimension,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     56\u001b[0m         disc\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m,\n\u001b[1;32m     57\u001b[0m         learning_rate\u001b[38;5;241m=\u001b[39mlearning_rate)      \n\u001b[0;32m---> 58\u001b[0m     s_n \u001b[38;5;241m=\u001b[39m \u001b[43mconv_pivae\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_generator\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m     59\u001b[0m \u001b[43m        \u001b[49m\u001b[43mtrain_loader\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m### will call \"def custom_data_generator\" \u001b[39;49;00m\n\u001b[1;32m     60\u001b[0m \u001b[43m        \u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mtrain_x\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m### 34\u001b[39;49;00m\n\u001b[1;32m     61\u001b[0m \u001b[43m        \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43miterations\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m### iterations\u001b[39;49;00m\n\u001b[1;32m     62\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[1;32m     63\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m 
\u001b[49m\u001b[43mvalid_loader\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     64\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalidation_steps\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mvalid_x\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     66\u001b[0m     start_time \u001b[38;5;241m=\u001b[39m time\u001b[38;5;241m.\u001b[39mtime()\n\u001b[1;32m     67\u001b[0m     X \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mconcatenate(train_x) \u001b[38;5;66;03m### (Xbins, 10=5ms-offset+5ms-offset, Xneurons)\u001b[39;00m\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/legacy/interfaces.py:91\u001b[0m, in \u001b[0;36mgenerate_legacy_interface.<locals>.legacy_support.<locals>.wrapper\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m     88\u001b[0m     signature \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m)`\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m     89\u001b[0m     warnings\u001b[38;5;241m.\u001b[39mwarn(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mUpdate your `\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m object_name \u001b[38;5;241m+\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m` call to the \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m\n\u001b[1;32m     90\u001b[0m                   \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mKeras 2 API: \u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;241m+\u001b[39m signature, stacklevel\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[0;32m---> 91\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training.py:1718\u001b[0m, in \u001b[0;36mModel.fit_generator\u001b[0;34m(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m   1583\u001b[0m \u001b[38;5;129m@interfaces\u001b[39m\u001b[38;5;241m.\u001b[39mlegacy_generator_methods_support\n\u001b[1;32m   1584\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mfit_generator\u001b[39m(\u001b[38;5;28mself\u001b[39m, generator,\n\u001b[1;32m   1585\u001b[0m                   steps_per_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1596\u001b[0m                   shuffle\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m,\n\u001b[1;32m   1597\u001b[0m                   initial_epoch\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m):\n\u001b[1;32m   1598\u001b[0m \u001b[38;5;250m    \u001b[39m\u001b[38;5;124;03m\"\"\"Trains the model on data generated batch-by-batch by a Python generator\u001b[39;00m\n\u001b[1;32m   1599\u001b[0m \u001b[38;5;124;03m    (or an instance of `Sequence`).\u001b[39;00m\n\u001b[1;32m   1600\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1716\u001b[0m \u001b[38;5;124;03m    ```\u001b[39;00m\n\u001b[1;32m   1717\u001b[0m \u001b[38;5;124;03m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1718\u001b[0m     \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtraining_generator\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfit_generator\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1719\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgenerator\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1720\u001b[0m \u001b[43m        
\u001b[49m\u001b[43msteps_per_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msteps_per_epoch\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1721\u001b[0m \u001b[43m        \u001b[49m\u001b[43mepochs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mepochs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1722\u001b[0m \u001b[43m        \u001b[49m\u001b[43mverbose\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbose\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1723\u001b[0m \u001b[43m        \u001b[49m\u001b[43mcallbacks\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcallbacks\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1724\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalidation_data\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_data\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1725\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalidation_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_steps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1726\u001b[0m \u001b[43m        \u001b[49m\u001b[43mvalidation_freq\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mvalidation_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1727\u001b[0m \u001b[43m        \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1728\u001b[0m \u001b[43m        \u001b[49m\u001b[43mmax_queue_size\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmax_queue_size\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1729\u001b[0m \u001b[43m        \u001b[49m\u001b[43mworkers\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mworkers\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1730\u001b[0m \u001b[43m        \u001b[49m\u001b[43muse_multiprocessing\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_multiprocessing\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1731\u001b[0m \u001b[43m        
\u001b[49m\u001b[43mshuffle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mshuffle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m   1732\u001b[0m \u001b[43m        \u001b[49m\u001b[43minitial_epoch\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minitial_epoch\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training_generator.py:217\u001b[0m, in \u001b[0;36mfit_generator\u001b[0;34m(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, validation_freq, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)\u001b[0m\n\u001b[1;32m    214\u001b[0m batch_logs \u001b[38;5;241m=\u001b[39m {\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mbatch\u001b[39m\u001b[38;5;124m'\u001b[39m: batch_index, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124msize\u001b[39m\u001b[38;5;124m'\u001b[39m: batch_size}\n\u001b[1;32m    215\u001b[0m callbacks\u001b[38;5;241m.\u001b[39mon_batch_begin(batch_index, batch_logs)\n\u001b[0;32m--> 217\u001b[0m outs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_on_batch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    218\u001b[0m \u001b[43m                            \u001b[49m\u001b[43msample_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43msample_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    219\u001b[0m \u001b[43m                            \u001b[49m\u001b[43mclass_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mclass_weight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    220\u001b[0m \u001b[43m                            \u001b[49m\u001b[43mreset_metrics\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[1;32m    222\u001b[0m outs \u001b[38;5;241m=\u001b[39m to_list(outs)\n\u001b[1;32m    223\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m l, o \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(out_labels, outs):\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/keras/engine/training.py:1514\u001b[0m, in \u001b[0;36mModel.train_on_batch\u001b[0;34m(self, x, y, sample_weight, class_weight, reset_metrics)\u001b[0m\n\u001b[1;32m   1512\u001b[0m     ins \u001b[38;5;241m=\u001b[39m x \u001b[38;5;241m+\u001b[39m y \u001b[38;5;241m+\u001b[39m sample_weights\n\u001b[1;32m   1513\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_make_train_function()\n\u001b[0;32m-> 1514\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mins\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   1516\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m reset_metrics:\n\u001b[1;32m   1517\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreset_metrics()\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/keras/backend.py:3792\u001b[0m, in \u001b[0;36mEagerExecutionFunction.__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m   3790\u001b[0m     value \u001b[38;5;241m=\u001b[39m math_ops\u001b[38;5;241m.\u001b[39mcast(value, tensor\u001b[38;5;241m.\u001b[39mdtype)\n\u001b[1;32m   3791\u001b[0m   converted_inputs\u001b[38;5;241m.\u001b[39mappend(value)\n\u001b[0;32m-> 3792\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_graph_fn\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconverted_inputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m   3794\u001b[0m \u001b[38;5;66;03m# EagerTensor.numpy() will often make a copy to ensure memory safety.\u001b[39;00m\n\u001b[1;32m   3795\u001b[0m \u001b[38;5;66;03m# However in this case `outputs` is not directly returned, so it is always\u001b[39;00m\n\u001b[1;32m   3796\u001b[0m \u001b[38;5;66;03m# safe to reuse the underlying buffer without checking. In such a case the\u001b[39;00m\n\u001b[1;32m   3797\u001b[0m \u001b[38;5;66;03m# private numpy conversion method is preferred to guarantee performance.\u001b[39;00m\n\u001b[1;32m   3798\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m nest\u001b[38;5;241m.\u001b[39mpack_sequence_as(\n\u001b[1;32m   3799\u001b[0m     \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_outputs_structure,\n\u001b[1;32m   3800\u001b[0m     [x\u001b[38;5;241m.\u001b[39m_numpy() \u001b[38;5;28;01mfor\u001b[39;00m x \u001b[38;5;129;01min\u001b[39;00m outputs],  \u001b[38;5;66;03m# pylint: disable=protected-access\u001b[39;00m\n\u001b[1;32m   3801\u001b[0m     expand_composites\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/eager/function.py:1605\u001b[0m, in \u001b[0;36mConcreteFunction.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1582\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__call__\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m   1583\u001b[0m \u001b[38;5;250m  \u001b[39m\u001b[38;5;124;03m\"\"\"Executes the wrapped function.\u001b[39;00m\n\u001b[1;32m   1584\u001b[0m \n\u001b[1;32m   1585\u001b[0m \u001b[38;5;124;03m  Args:\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   1603\u001b[0m \u001b[38;5;124;03m    TypeError: For invalid positional/keyword argument combinations.\u001b[39;00m\n\u001b[1;32m   1604\u001b[0m \u001b[38;5;124;03m  \"\"\"\u001b[39;00m\n\u001b[0;32m-> 1605\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/eager/function.py:1645\u001b[0m, in \u001b[0;36mConcreteFunction._call_impl\u001b[0;34m(self, args, kwargs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1642\u001b[0m       \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGot two values for keyword \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(unused_key))\n\u001b[1;32m   1643\u001b[0m   \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mKeyword arguments \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m unknown. Expected \u001b[39m\u001b[38;5;132;01m{}\u001b[39;00m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;241m.\u001b[39mformat(\n\u001b[1;32m   1644\u001b[0m       \u001b[38;5;28mlist\u001b[39m(kwargs\u001b[38;5;241m.\u001b[39mkeys()), \u001b[38;5;28mlist\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_arg_keywords)))\n\u001b[0;32m-> 1645\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_call_flat\u001b[49m\u001b[43m(\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcaptured_inputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcancellation_manager\u001b[49m\u001b[43m)\u001b[49m\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/eager/function.py:1745\u001b[0m, in \u001b[0;36mConcreteFunction._call_flat\u001b[0;34m(self, args, captured_inputs, cancellation_manager)\u001b[0m\n\u001b[1;32m   1740\u001b[0m possible_gradient_type \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m   1741\u001b[0m     pywrap_tfe\u001b[38;5;241m.\u001b[39mTFE_Py_TapeSetPossibleGradientTypes(args))\n\u001b[1;32m   1742\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (possible_gradient_type \u001b[38;5;241m==\u001b[39m _POSSIBLE_GRADIENT_TYPES_NONE\n\u001b[1;32m   1743\u001b[0m     \u001b[38;5;129;01mand\u001b[39;00m executing_eagerly):\n\u001b[1;32m   1744\u001b[0m   \u001b[38;5;66;03m# No tape is watching; skip to running the function.\u001b[39;00m\n\u001b[0;32m-> 1745\u001b[0m   \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_build_call_outputs(\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_inference_function\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcall\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m   1746\u001b[0m \u001b[43m      \u001b[49m\u001b[43mctx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcancellation_manager\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcancellation_manager\u001b[49m\u001b[43m)\u001b[49m)\n\u001b[1;32m   1747\u001b[0m forward_backward \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_select_forward_and_backward_functions(\n\u001b[1;32m   1748\u001b[0m     args,\n\u001b[1;32m   1749\u001b[0m     possible_gradient_type,\n\u001b[1;32m   1750\u001b[0m     executing_eagerly)\n\u001b[1;32m   1751\u001b[0m forward_function, args_with_tangents \u001b[38;5;241m=\u001b[39m forward_backward\u001b[38;5;241m.\u001b[39mforward()\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/eager/function.py:593\u001b[0m, in \u001b[0;36m_EagerDefinedFunction.call\u001b[0;34m(self, ctx, args, cancellation_manager)\u001b[0m\n\u001b[1;32m    591\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m _InterpolateFunctionError(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m    592\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m cancellation_manager \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 593\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m \u001b[43mexecute\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexecute\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m    594\u001b[0m \u001b[43m        \u001b[49m\u001b[38;5;28;43mstr\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msignature\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    595\u001b[0m \u001b[43m        \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_num_outputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    596\u001b[0m \u001b[43m        \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    597\u001b[0m \u001b[43m        \u001b[49m\u001b[43mattrs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m    598\u001b[0m \u001b[43m        \u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mctx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m    599\u001b[0m   \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m    600\u001b[0m     outputs \u001b[38;5;241m=\u001b[39m execute\u001b[38;5;241m.\u001b[39mexecute_with_cancellation(\n\u001b[1;32m    601\u001b[0m         
\u001b[38;5;28mstr\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msignature\u001b[38;5;241m.\u001b[39mname),\n\u001b[1;32m    602\u001b[0m         num_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_outputs,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    605\u001b[0m         ctx\u001b[38;5;241m=\u001b[39mctx,\n\u001b[1;32m    606\u001b[0m         cancellation_manager\u001b[38;5;241m=\u001b[39mcancellation_manager)\n",
      "File \u001b[0;32m~/miniconda3/envs/cebra/lib/python3.8/site-packages/tensorflow/python/eager/execute.py:59\u001b[0m, in \u001b[0;36mquick_execute\u001b[0;34m(op_name, num_outputs, inputs, attrs, ctx, name)\u001b[0m\n\u001b[1;32m     57\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n\u001b[1;32m     58\u001b[0m   ctx\u001b[38;5;241m.\u001b[39mensure_initialized()\n\u001b[0;32m---> 59\u001b[0m   tensors \u001b[38;5;241m=\u001b[39m \u001b[43mpywrap_tfe\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mTFE_Py_Execute\u001b[49m\u001b[43m(\u001b[49m\u001b[43mctx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_handle\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdevice_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mop_name\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m     60\u001b[0m \u001b[43m                                      \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mattrs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnum_outputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m     61\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m core\u001b[38;5;241m.\u001b[39m_NotOkStatusException \u001b[38;5;28;01mas\u001b[39;00m e:\n\u001b[1;32m     62\u001b[0m   \u001b[38;5;28;01mif\u001b[39;00m name \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Fit one conv-piVAE per recording, then measure how well a linear readout of\n",
    "# the learned embedding predicts XY velocity on the train and test splits.\n",
    "directory = './data/Fig5_6_Natural/RTGridSU'\n",
    "output_dir = './data/Fig5_6_Natural/RTGridSU_emb_pivae'\n",
    "loss_dir = os.path.join(output_dir, 'loss')\n",
    "# Create the output folders up front so savefig/savez do not fail on a fresh checkout.\n",
    "os.makedirs(loss_dir, exist_ok=True)\n",
    "\n",
    "# os.listdir returns entries in arbitrary order; sort them so that the seeded\n",
    "# np.random draws below are consumed by the same file on every run, and keep\n",
    "# only MATLAB files so stray entries do not crash sio.loadmat.\n",
    "files = sorted(f for f in os.listdir(directory) if f.lower().endswith('.mat'))\n",
    "\n",
    "for file in files:\n",
    "    print(file)\n",
    "    mat_contents = sio.loadmat(os.path.join(directory, file))\n",
    "    neural = mat_contents['neural_whole']\n",
    "    continuous_index_XY = mat_contents['vel_xy_whole']\n",
    "    discrete_index = mat_contents['angle_deg_whole']\n",
    "\n",
    "    # Estimate the typical pairwise L1 distance between XY labels from a random\n",
    "    # subsample of one fifth of the time bins.\n",
    "    L = neural.shape[0]\n",
    "    N_values_hist = round(L/5)\n",
    "    random_indices = np.random.choice(L, size=N_values_hist, replace=False)\n",
    "    indices_X = continuous_index_XY[random_indices, 0]\n",
    "    indices_Y = continuous_index_XY[random_indices, 1]\n",
    "    index_diffs_X = np.abs(indices_X[:, None] - indices_X[None, :])\n",
    "    index_diffs_Y = np.abs(indices_Y[:, None] - indices_Y[None, :])\n",
    "    l_dist_XY = index_diffs_X + index_diffs_Y\n",
    "    # Drop the zero-valued diagonal (distance of a sample to itself).\n",
    "    l_dist_XY_1d = l_dist_XY[~np.eye(N_values_hist, dtype=bool)].flatten()\n",
    "\n",
    "    # Scale factor = 135 / median distance, rounded up to the next multiple of 100.\n",
    "    # NOTE(review): this evaluates to 0 when the median is 270 or more, which\n",
    "    # would zero out the XY labels -- confirm the data never reaches that range.\n",
    "    XY_scale = round(135/np.median(l_dist_XY_1d))\n",
    "    XY_scale = round(np.ceil(XY_scale/100)*100)\n",
    "    continuous_index_XY = continuous_index_XY*XY_scale\n",
    "    continuous_index = np.column_stack((continuous_index_XY, discrete_index))\n",
    "    print('continuous_index=', continuous_index[:10])\n",
    "\n",
    "    # Chronological train/valid/test split; both boundaries are floored to a\n",
    "    # whole multiple of `dur` bins.\n",
    "    N_bins, N_neurons = neural.shape\n",
    "    train_end = int(N_bins * train_percent) // dur * dur\n",
    "    valid_end = (train_end + int(N_bins * valid_percent)) // dur * dur\n",
    "    train_neural = neural[:train_end, :]\n",
    "    Y_train = continuous_index[:train_end, :]\n",
    "    valid_neural = neural[train_end:valid_end, :]\n",
    "    Y_valid = continuous_index[train_end:valid_end, :]\n",
    "    test_neural = neural[valid_end:, :]\n",
    "    Y_test = continuous_index[valid_end:, :]\n",
    "\n",
    "    # Add the receptive-field window axis expected by the convolutional model.\n",
    "    X_train = dataset_2D_to_3D(train_neural)\n",
    "    X_valid = dataset_2D_to_3D(valid_neural)\n",
    "    X_test = dataset_2D_to_3D(test_neural)\n",
    "\n",
    "    train_x, train_u = to_batch_list(X_train, Y_train, batch_size)\n",
    "    train_loader = custom_data_generator(train_x, train_u)\n",
    "\n",
    "    valid_x, valid_u = to_batch_list(X_valid, Y_valid, batch_size)\n",
    "    valid_loader = custom_data_generator(valid_x, valid_u)\n",
    "\n",
    "    test_x, test_u = to_batch_list(X_test, Y_test, batch_size)\n",
    "    test_loader = custom_data_generator(test_x, test_u)\n",
    "\n",
    "    # Best effort per file: a failure in training or evaluation is reported and\n",
    "    # the loop moves on to the next recording.\n",
    "    try:\n",
    "        conv_pivae = pivae_code.conv_pi_vae.conv_vae_mdl(\n",
    "            dim_x=N_neurons,\n",
    "            dim_z=embed_dimension,\n",
    "            dim_u=3,\n",
    "            time_window=10,\n",
    "            gen_nodes=60,\n",
    "            n_blk=2,\n",
    "            mdl='poisson',\n",
    "            disc=False,\n",
    "            learning_rate=learning_rate)\n",
    "        s_n = conv_pivae.fit_generator(\n",
    "            train_loader,\n",
    "            steps_per_epoch=len(train_x),\n",
    "            epochs=iterations,\n",
    "            verbose=1,\n",
    "            validation_data=valid_loader,\n",
    "            validation_steps=len(valid_x))\n",
    "\n",
    "        # Time only the inference passes over the train and test batches.\n",
    "        start_time = time.time()\n",
    "        X = np.concatenate(train_x)\n",
    "        labels = np.concatenate(train_u)\n",
    "        outputs_train = conv_pivae.predict([X, labels])\n",
    "        X = np.concatenate(test_x)\n",
    "        labels = np.concatenate(test_u)\n",
    "        outputs_test = conv_pivae.predict([X, labels])\n",
    "        end_time = time.time()\n",
    "        execution_time = np.round((end_time - start_time), 2)\n",
    "        # Model outputs, per the original author's note: post_mean, post_log_var,\n",
    "        # z_sample, fire_rate, lam_mean, lam_log_var, z_mean, z_log_var.\n",
    "        # Output 0 is used as the embedding. The `cebra_` prefix is kept because\n",
    "        # the same names are the keys of the saved .npz archive.\n",
    "        cebra_veldir_train = outputs_train[0]\n",
    "        cebra_veldir_test = outputs_test[0]\n",
    "\n",
    "        # Loss curves; the title reports the mean of the last (up to) 10 epochs.\n",
    "        val_loss = s_n.history['val_loss'][:]\n",
    "        loss = np.array(s_n.history['loss'])\n",
    "        loss_stable = loss[-10:]\n",
    "        fig = plt.figure(figsize=(6,5))\n",
    "        try:\n",
    "            ax = plt.subplot(111)\n",
    "            plt.plot(val_loss, c='deepskyblue', label='val-loss')\n",
    "            plt.plot(loss, c='blue', label='loss')\n",
    "            ax.spines['top'].set_visible(False)\n",
    "            ax.spines['right'].set_visible(False)\n",
    "            ax.set_xlabel('Iterations')\n",
    "            ax.set_ylabel('piVAE Loss')\n",
    "            plt.legend(bbox_to_anchor=(0.5,0.3), frameon=False)\n",
    "            plt.title('itr='+str(iterations)+' loss='+str(int(np.mean(loss_stable))))\n",
    "            filename = file[:16] + '_loss.pdf'\n",
    "            file_save = os.path.join(loss_dir, filename)\n",
    "            plt.savefig(file_save)\n",
    "        finally:\n",
    "            # Always release the figure, even if plotting or saving fails.\n",
    "            plt.close(fig)\n",
    "\n",
    "        # Fit the linear readouts on the training embedding only.\n",
    "        X = cebra_veldir_train\n",
    "        y = Y_train[:,0:2]\n",
    "        reg_3d = LinearRegression().fit(X, y)       # 1st fit: full embedding -> XY\n",
    "        pred_vel = reg_3d.predict(X)\n",
    "        vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "        pca_2d = PCA(n_components=2).fit(X)         # 2nd fit: PCA to 2 components\n",
    "        X_2d = pca_2d.transform(X)\n",
    "        reg_2d = LinearRegression().fit(X_2d, y)    # 3rd fit: 2-D projection -> XY\n",
    "        pred_vel = reg_2d.predict(X_2d)\n",
    "        vel_train_r2_pca = sklearn.metrics.r2_score(y, pred_vel)\n",
    "        vel_train_r2_pca = np.round(vel_train_r2_pca, 4)\n",
    "\n",
    "        # NOTE(review): the '80%'/'20%' labels here and in the file name below do\n",
    "        # not match train_percent = 0.60 from the config cell; they are left\n",
    "        # unchanged because other code may parse these file names -- confirm.\n",
    "        print('80% Train LR=', str(learning_rate),\n",
    "              ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))\n",
    "\n",
    "        # Evaluate on the test split with the already-fitted reg_3d, pca_2d and\n",
    "        # reg_2d (nothing is refitted on test data).\n",
    "        X = cebra_veldir_test\n",
    "        y = Y_test[:,0:2]\n",
    "        pred_vel = reg_3d.predict(X)\n",
    "        vel_test_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "        X_2d = pca_2d.transform(X)\n",
    "        pred_vel = reg_2d.predict(X_2d)\n",
    "        vel_test_r2_pca = sklearn.metrics.r2_score(y, pred_vel)\n",
    "        vel_test_r2_pca = np.round(vel_test_r2_pca, 4)\n",
    "\n",
    "        print('20% Test  LR=', str(learning_rate),\n",
    "              ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))\n",
    "\n",
    "        new_filename = (file[:16] + '_LR_' + str(learning_rate)\n",
    "                        + '_iterations_' + str(iterations)\n",
    "                        + '_80%train_' + str(vel_train_r2_pca)\n",
    "                        + '_20%test_' + str(vel_test_r2_pca) + '.npz')\n",
    "        file_save = os.path.join(output_dir, new_filename)\n",
    "        np.savez(file_save,\n",
    "                 execution_time=execution_time,\n",
    "                 learning_rate=learning_rate,\n",
    "                 iterations=iterations,\n",
    "                 train_loss=loss,\n",
    "                 cebra_veldir_train=cebra_veldir_train,\n",
    "                 cebra_veldir_test=cebra_veldir_test,\n",
    "                 continuous_index_train=Y_train,\n",
    "                 continuous_index_test=Y_test,\n",
    "                 vel_train_r2=vel_train_r2,\n",
    "                 vel_test_r2=vel_test_r2,\n",
    "                 vel_train_r2_pca=vel_train_r2_pca,\n",
    "                 vel_test_r2_pca=vel_test_r2_pca)\n",
    "    except Exception as e:\n",
    "        # Report what actually went wrong instead of silently discarding it.\n",
    "        print(' LR=', str(learning_rate), ' fail:', file, repr(e))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb1afb43",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
