{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f617ecc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import torch\n",
    "import scipy.io as sio\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import sklearn.metrics\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "122b7983",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/LFP_12PMd/Mihili_20140218_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.43  r2-2D= 0.3364\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.38  r2-2D= -0.3511\n",
      "./data/LFP_12PMd/Mihili_20140304_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.504  r2-2D= 0.452\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.391  r2-2D= -0.3705\n",
      "./data/LFP_12PMd/Chewie_20161007_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.45  r2-2D= 0.4043\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.241  r2-2D= -0.25\n",
      "./data/LFP_12PMd/Chewie_20161007_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.513  r2-2D= 0.4816\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.436  r2-2D= 0.4439\n",
      "./data/LFP_12PMd/Chewie_20161014_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.528  r2-2D= 0.5235\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.467  r2-2D= 0.4684\n",
      "./data/LFP_12PMd/Chewie_20160929_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.483  r2-2D= 0.4759\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.405  r2-2D= 0.3954\n",
      "./data/LFP_12PMd/Mihili_20140303_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.481  r2-2D= 0.4641\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.248  r2-2D= 0.2289\n",
      "./data/LFP_12PMd/Mihili_20140303_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.387  r2-2D= 0.1563\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.323  r2-2D= -0.1069\n",
      "./data/LFP_12PMd/Chewie_20161021_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.34  r2-2D= 0.2411\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.181  r2-2D= -0.1035\n",
      "./data/LFP_12PMd/Chewie_20160929_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.485  r2-2D= 0.4806\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.433  r2-2D= 0.4356\n",
      "./data/LFP_12PMd/Chewie_20161006_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.397  r2-2D= 0.3448\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.254  r2-2D= -0.2077\n",
      "./data/LFP_12PMd/Mihili_20140307_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.51  r2-2D= 0.4861\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.317  r2-2D= 0.3212\n",
      "./data/LFP_12PMd/Chewie_20161021_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.419  r2-2D= 0.3962\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.334  r2-2D= 0.3365\n",
      "./data/LFP_12PMd/Chewie_20161006_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.491  r2-2D= 0.48\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.435  r2-2D= 0.4328\n",
      "./data/LFP_12PMd/Mihili_20140217_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.485  r2-2D= 0.446\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.327  r2-2D= 0.3151\n",
      "./data/LFP_12PMd/Mihili_20140307_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.543  r2-2D= 0.5333\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.145  r2-2D= 0.1523\n",
      "./data/LFP_12PMd/Mihili_20140217_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.356  r2-2D= 0.1431\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.231  r2-2D= -0.107\n",
      "./data/LFP_12PMd/Mihili_20140307_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.46  r2-2D= 0.4234\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.261  r2-2D= -0.225\n",
      "./data/LFP_12PMd/Mihili_20140218_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.553  r2-2D= 0.5313\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.257  r2-2D= 0.2749\n",
      "./data/LFP_12PMd/Chewie_20161006_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.471  r2-2D= 0.4511\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.361  r2-2D= 0.3495\n",
      "./data/LFP_12PMd/Mihili_20140306_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.394  r2-2D= 0.3092\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.178  r2-2D= -0.1365\n",
      "./data/LFP_12PMd/Mihili_20140217_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.478  r2-2D= 0.4531\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.127  r2-2D= 0.1497\n",
      "./data/LFP_12PMd/Chewie_20161007_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.511  r2-2D= 0.4947\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.34  r2-2D= 0.3527\n",
      "./data/LFP_12PMd/Chewie_20161021_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.449  r2-2D= 0.4337\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.375  r2-2D= 0.3504\n",
      "./data/LFP_12PMd/Chewie_20161005_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.465  r2-2D= 0.435\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.4  r2-2D= 0.3823\n",
      "./data/LFP_12PMd/Chewie_20161005_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.352  r2-2D= 0.1257\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.285  r2-2D= -0.0895\n",
      "./data/LFP_12PMd/Mihili_20140304_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.609  r2-2D= 0.606\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.198  r2-2D= 0.1993\n",
      "./data/LFP_12PMd/Chewie_20161014_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.535  r2-2D= 0.5293\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.486  r2-2D= 0.4861\n",
      "./data/LFP_12PMd/Chewie_20160929_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.415  r2-2D= 0.3849\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.159  r2-2D= -0.1347\n",
      "./data/LFP_12PMd/Mihili_20140306_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.56  r2-2D= 0.5465\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.345  r2-2D= 0.3439\n",
      "./data/LFP_12PMd/Mihili_20140306_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.628  r2-2D= 0.6245\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.248  r2-2D= 0.2484\n",
      "./data/LFP_12PMd/Chewie_20161014_Beta.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.213  r2-2D= 0.1634\n",
      "20% Test  Data Temp= 0.1  r2-3D= -0.104  r2-2D= -0.0734\n",
      "./data/LFP_12PMd/Chewie_20161005_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.47  r2-2D= 0.4366\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.358  r2-2D= 0.3488\n",
      "./data/LFP_12PMd/Mihili_20140304_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.644  r2-2D= 0.6407\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.148  r2-2D= 0.1515\n",
      "./data/LFP_12PMd/Mihili_20140303_LMP.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.494  r2-2D= 0.4744\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.153  r2-2D= 0.1631\n",
      "./data/LFP_12PMd/Mihili_20140218_Gamma.mat_embed_5000itr_PMd.npz\n",
      "80% Train Data Temp= 0.1  r2-3D= 0.476  r2-2D= 0.4334\n",
      "20% Test  Data Temp= 0.1  r2-3D= 0.28  r2-2D= 0.2586\n"
     ]
    }
   ],
   "source": [
    "dur = 40\n",
    "iterations = 5*1000\n",
    "batch_size=512\n",
    "learning_rate = 0.001\n",
    "output_dimension = 3\n",
    "# LFP_ch = \"Gamma\" ### \"LMP\" or \"Gamma\" or \"Beta\"\n",
    "# Temp_para = [0.0001, 0.001, 0.01, 0.1, 1, 10] # for hyperparameter search \n",
    "Temp_para = [0.1]\n",
    "def split_data(neural, continuous_index):\n",
    "            L = neural.shape[0]\n",
    "            split_idx = round(L*0.8) \n",
    "            neural_train = neural[:split_idx]\n",
    "            neural_test = neural[split_idx:]\n",
    "            continuous_index_train = continuous_index[:split_idx]\n",
    "            continuous_index_test = continuous_index[split_idx:]\n",
    "            return neural_train,neural_test,continuous_index_train,continuous_index_test\n",
    "        \n",
    "angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}\n",
    "\n",
    "directory = \"./data/LFP_12PMd\"\n",
    "files = os.listdir(directory)\n",
    "for file in files:\n",
    "#     if LFP_ch in file:\n",
    "    mat_contents = sio.loadmat(os.path.join(directory, file))\n",
    "    filename_parts = file.split(\"_neural_con_dis_index\")\n",
    "    new_filename = filename_parts[0] + \"_embed_\"+str(iterations)+\"itr_PMd.npz\"\n",
    "    file_save = os.path.join(directory, new_filename)\n",
    "    print(file_save)\n",
    "\n",
    "    neural = mat_contents['lfp_PMd']\n",
    "    continuous_index_XY = mat_contents['continuous_index']\n",
    "    discrete_index = mat_contents['discrete_index'] ## angles range from -180deg to +180deg\n",
    "    vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])\n",
    "    discrete_index = 45*vectorized_map(discrete_index)\n",
    "\n",
    "    L = neural.shape[0]\n",
    "    N_values_hist = round(L/5)\n",
    "    random_indices = np.random.choice(L, size=N_values_hist, replace=False)\n",
    "    indices_X = continuous_index_XY[random_indices, 0]\n",
    "    indices_Y = continuous_index_XY[random_indices, 1]\n",
    "    index_diffs_X = np.abs(indices_X[:, None] - indices_X[None, :]) \n",
    "    index_diffs_Y = np.abs(indices_Y[:, None] - indices_Y[None, :])\n",
    "    l_dist_XY = index_diffs_X + index_diffs_Y\n",
    "    l_dist_XY_1d = l_dist_XY[~np.eye(N_values_hist, dtype=bool)].flatten()\n",
    "    # print('XY distance>>', np.median(l_dist_XY_1d))\n",
    "    angles = np.squeeze(discrete_index[random_indices])\n",
    "    angle_diffs = np.abs(angles[:, None] - angles[None, :]) \n",
    "    circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)\n",
    "    l_dist_Z_1d = circular_angle_diffs[~np.eye(N_values_hist, dtype=bool)].flatten()\n",
    "    # print('Z distance>>', np.median(l_dist_Z_1d))\n",
    "\n",
    "    XY_scale = 10\n",
    "    continuous_index_XY = continuous_index_XY*XY_scale\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "#         XYZ_threshold = np.median(l_dist_XY_1d)*XY_scale + np.median(l_dist_Z_1d)\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "\n",
    "    for temp in range(len(Temp_para)): \n",
    "\n",
    "        continuous_index = np.column_stack((continuous_index_XY, discrete_index))\n",
    "\n",
    "        neural_train, neural_test, continuous_index_train, \\\n",
    "                    continuous_index_test = split_data(neural, continuous_index)\n",
    "\n",
    "        #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR\n",
    "#             L_train = neural_train.shape[0]\n",
    "#             conr_2para = np.full((L_train,), 0.001)\n",
    "#             conr_2para[:2] = [XYZ_threshold, Temp_para[temp]]\n",
    "#             continuous_index_train = np.column_stack((continuous_index_train, conr_2para))\n",
    "        #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "        try: \n",
    "            cebra_veldir_model = CEBRA(model_architecture='offset10-model',\n",
    "                       batch_size = batch_size,\n",
    "                       learning_rate = learning_rate,\n",
    "                       temperature = Temp_para[temp],\n",
    "                       output_dimension = output_dimension,\n",
    "                       max_iterations=iterations,\n",
    "                       distance='cosine',\n",
    "                       conditional='time_delta',\n",
    "                       verbose=False,\n",
    "                       time_offsets=10)\n",
    "            start_time = time.time()\n",
    "            cebra_veldir_model.fit(neural_train, continuous_index_train)\n",
    "            end_time = time.time()\n",
    "            execution_time = np.round((end_time - start_time), 2)\n",
    "\n",
    "            cebra_veldir_train = cebra_veldir_model.transform(neural_train)\n",
    "            cebra_veldir_test  = cebra_veldir_model.transform(neural_test)\n",
    "\n",
    "            train_loss = cebra_veldir_model.state_dict_['loss']\n",
    "            X = cebra_veldir_train\n",
    "            y = continuous_index_train[:,0:2]\n",
    "            reg_3d = LinearRegression().fit(X, y)       #### 1st fit ####\n",
    "            pred_vel = reg_3d.predict(X)\n",
    "            vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "            pca = PCA(n_components=2)\n",
    "            pca_2d = pca.fit(X)                         #### 2nd fit ####\n",
    "            X_2d = pca_2d.transform(X)\n",
    "            reg_2d = LinearRegression().fit(X_2d, y)    #### 3rd fit ####\n",
    "            pred_vel = reg_2d.predict(X_2d)\n",
    "            vel_train_r2_pca = sklearn.metrics.r2_score(y, pred_vel)\n",
    "            vel_train_r2_pca = np.round(vel_train_r2_pca, 4)\n",
    "\n",
    "            print('80% Train Data Temp=', str(Temp_para[temp]), \\\n",
    "                  ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))\n",
    "            ###************* use previous trained \"reg_3d & pca_2d & reg_2d\" ###***************\n",
    "            ###************* use previous trained \"reg_3d & pca_2d & reg_2d\" ###***************\n",
    "            X = cebra_veldir_test\n",
    "            y = continuous_index_test[:,0:2]\n",
    "            pred_vel = reg_3d.predict(X)\n",
    "            vel_test_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "            X_2d = pca_2d.transform(X)\n",
    "            pred_vel = reg_2d.predict(X_2d)\n",
    "            vel_test_r2_pca = sklearn.metrics.r2_score(y, pred_vel)\n",
    "            vel_test_r2_pca = np.round(vel_test_r2_pca, 4)\n",
    "\n",
    "            print('20% Test  Data Temp=', str(Temp_para[temp]), \\\n",
    "                  ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))\n",
    "\n",
    "            new_filename = file[:19] + \"_Temp_\"+str(Temp_para[temp])+ \\\n",
    "                \"_iterations_\"+str(iterations)+ \\\n",
    "                \"_80%train_\"+str(vel_train_r2_pca)+ \\\n",
    "                \"_20%test_\"+str(vel_test_r2_pca)+\".npz\"\n",
    "            # file_save = os.path.join('./data/Fig5_6_Natural/RTGridSU_emb_nmr_5000itr',new_filename)\n",
    "            # file_save = os.path.join('./data/Fig5_6_Natural/RTGridSU_emb_nmr_lr0.0001',new_filename)\n",
    "            file_save = os.path.join('./data/Fig3_4_LFP/cebra/emb_PMd',new_filename)\n",
    "            # file_save = os.path.join('./data/Fig5_6_Natural/RTGridSU_emb_nmr_lr0.005',new_filename)\n",
    "            np.savez(file_save,\n",
    "                     execution_time = execution_time,\n",
    "                     temperature = Temp_para[temp],\n",
    "                     iterations = iterations, \n",
    "                     train_loss = train_loss,\n",
    "                     cebra_veldir_train=cebra_veldir_train,\n",
    "                     cebra_veldir_test=cebra_veldir_test,\n",
    "                     continuous_index_train=continuous_index_train,\n",
    "                     continuous_index_test=continuous_index_test,\n",
    "                     vel_train_r2 = vel_train_r2,\n",
    "                     vel_test_r2 = vel_test_r2,\n",
    "                     vel_train_r2_pca = vel_train_r2_pca,\n",
    "                     vel_test_r2_pca = vel_test_r2_pca)\n",
    "        except Exception as e:\n",
    "            print(' Temp=', str(Temp_para[temp]), ' fail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c09b28a6",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
