{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f617ecc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import torch\n",
    "import scipy.io as sio\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import sklearn.metrics\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "122b7983",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/SU_12PMd/Mihili_20140303_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.972804060568384\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.519  r2-2D= 0.5116\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.359  r2-2D= 0.3691\n",
      "./data/SU_12PMd/Chewie_20161007_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.48179326745967\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.519  r2-2D= 0.5058\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.489  r2-2D= 0.4815\n",
      "./data/SU_12PMd/Mihili_20140306_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.399065844538201\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.618  r2-2D= 0.6141\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.48  r2-2D= 0.4834\n",
      "./data/SU_12PMd/Mihili_20140218_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.73637998329521\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.574  r2-2D= 0.5539\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.478  r2-2D= 0.4629\n",
      "./data/SU_12PMd/Chewie_20161014_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.35571413923275\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.584  r2-2D= 0.5814\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.556  r2-2D= 0.554\n",
      "./data/SU_12PMd/Chewie_20161005_embed_10000itr_PMd.npz\n",
      "XY distance>> 11.363516511473463\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.477  r2-2D= 0.4481\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.433  r2-2D= 0.412\n",
      "./data/SU_12PMd/Chewie_20161021_embed_10000itr_PMd.npz\n",
      "XY distance>> 11.483921575933898\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.469  r2-2D= 0.4582\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.453  r2-2D= 0.4347\n",
      "./data/SU_12PMd/Chewie_20161006_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.795374393980183\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.496  r2-2D= 0.4851\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.49  r2-2D= 0.4779\n",
      "./data/SU_12PMd/Chewie_20160929_embed_10000itr_PMd.npz\n",
      "XY distance>> 11.98002433782431\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.5  r2-2D= 0.4982\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.495  r2-2D= 0.4934\n",
      "./data/SU_12PMd/Mihili_20140217_embed_10000itr_PMd.npz\n",
      "XY distance>> 11.719148111645575\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.51  r2-2D= 0.4827\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.402  r2-2D= 0.383\n",
      "./data/SU_12PMd/Mihili_20140304_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.843095300680805\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.637  r2-2D= 0.632\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.473  r2-2D= 0.4648\n",
      "./data/SU_12PMd/Mihili_20140307_embed_10000itr_PMd.npz\n",
      "XY distance>> 12.149602657290968\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.56  r2-2D= 0.5563\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.382  r2-2D= 0.3802\n"
     ]
    }
   ],
   "source": [
    "# Hyper-parameters for the CEBRA fits below.\n",
    "dur = 40                 # NOTE(review): not used anywhere in this notebook\n",
    "iterations = 10*1000     # CEBRA max_iterations (also used in output file names)\n",
    "batch_size = 512\n",
    "learning_rate = 0.001\n",
    "output_dimension = 3     # embedding dimensionality\n",
    "Temp_para = [0.08]       # CEBRA temperatures to sweep\n",
    "\n",
    "# Seed the global RNGs so Restart & Run All reproduces the random subsampling below;\n",
    "# seeding torch presumably also fixes CEBRA's initialisation (confirm; GPU kernels\n",
    "# may still be non-deterministic).\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "def split_data(neural, continuous_index, train_frac=0.8):\n",
    "    \"\"\"Chronologically split paired arrays into train and test segments.\n",
    "\n",
    "    The first round(len(neural) * train_frac) rows form the training set and the\n",
    "    remaining rows the test set. No shuffling is done, so the test data is the\n",
    "    contiguous tail of the recording.\n",
    "\n",
    "    Returns (neural_train, neural_test, continuous_index_train,\n",
    "    continuous_index_test).\n",
    "\n",
    "    Raises ValueError if the two arrays do not have the same number of rows.\n",
    "    \"\"\"\n",
    "    if len(neural) != len(continuous_index):\n",
    "        raise ValueError('neural and continuous_index must have the same number of rows')\n",
    "    split_idx = round(len(neural) * train_frac)\n",
    "    return (neural[:split_idx], neural[split_idx:],\n",
    "            continuous_index[:split_idx], continuous_index[split_idx:])\n",
    "\n",
    "# Reach angles (degrees, -180..+180) -> index 0..7, i.e. the angle as a multiple of\n",
    "# 45 deg in [0, 360); -180 and +180 are the same direction and both map to 4 (180 deg).\n",
    "angle_to_new_value = {-180: 4,-135: 5,-90: 6,-45: 7,0: 0,45: 1,90: 2,135: 3,180: 4}\n",
    "# Built once; an angle missing from the table raises KeyError instead of passing silently.\n",
    "vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])\n",
    "\n",
    "directory = \"./data/SU_12PMd/\"\n",
    "out_dir = './data/Fig2_SU/cebra/emb_PMd'\n",
    "os.makedirs(out_dir, exist_ok=True)  # np.savez does not create missing directories\n",
    "XY_scale = 10\n",
    "\n",
    "# Only .mat inputs are loaded; sorting fixes the processing order (os.listdir order is\n",
    "# arbitrary), so the random subsampling below is reproducible once the RNG is seeded.\n",
    "files = sorted(f for f in os.listdir(directory) if f.endswith('.mat'))\n",
    "for file in files:\n",
    "    file_load = os.path.join(directory, file)\n",
    "    print(file_load)\n",
    "    mat_contents = sio.loadmat(file_load)\n",
    "    neural = mat_contents['neural_PMd']\n",
    "    continuous_index_XY = mat_contents['continuous_index']\n",
    "    discrete_index = mat_contents['discrete_index'] ## angles range from -180deg to +180deg\n",
    "    discrete_index = 45*vectorized_map(discrete_index)  # now in [0, 360) deg\n",
    "\n",
    "    # Median pairwise label distances on a random 20% subsample of time points\n",
    "    # (self-pairs excluded). Printed for reference; the disabled NMR variant below\n",
    "    # builds its threshold from them.\n",
    "    L = neural.shape[0]\n",
    "    N_values_hist = round(L/5)\n",
    "    random_indices = np.random.choice(L, size=N_values_hist, replace=False)\n",
    "    off_diagonal = ~np.eye(N_values_hist, dtype=bool)\n",
    "    indices_X = continuous_index_XY[random_indices, 0]\n",
    "    indices_Y = continuous_index_XY[random_indices, 1]\n",
    "    # L1 (city-block) distance in the XY plane; boolean masking already yields 1-D.\n",
    "    l_dist_XY = (np.abs(indices_X[:, None] - indices_X[None, :])\n",
    "                 + np.abs(indices_Y[:, None] - indices_Y[None, :]))\n",
    "    l_dist_XY_1d = l_dist_XY[off_diagonal]\n",
    "    print('XY distance>>', np.median(l_dist_XY_1d))\n",
    "    angles = np.squeeze(discrete_index[random_indices])\n",
    "    angle_diffs = np.abs(angles[:, None] - angles[None, :])\n",
    "    # Shortest way round the circle, e.g. |0 - 315| = 315 -> 45.\n",
    "    circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)\n",
    "    l_dist_Z_1d = circular_angle_diffs[off_diagonal]\n",
    "    print('Z distance>>', np.median(l_dist_Z_1d))\n",
    "\n",
    "    continuous_index_XY = continuous_index_XY*XY_scale\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "#     XYZ_threshold = np.median(l_dist_XY_1d)*XY_scale + np.median(l_dist_Z_1d)\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "\n",
    "    for temp in range(len(Temp_para)):\n",
    "        # Labels: scaled XY in columns 0-1, angle in column 2. Rebuilt per temperature\n",
    "        # because the NMR variant appends extra columns to continuous_index_train.\n",
    "        continuous_index = np.column_stack((continuous_index_XY, discrete_index))\n",
    "\n",
    "        neural_train, neural_test, continuous_index_train, \\\n",
    "                    continuous_index_test = split_data(neural, continuous_index)\n",
    "\n",
    "        #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR\n",
    "#         L_train = neural_train.shape[0]\n",
    "#         conr_2para = np.full((L_train,), 0.001)\n",
    "#         conr_2para[:2] = [XYZ_threshold, Temp_para[temp]]\n",
    "#         continuous_index_train = np.column_stack((continuous_index_train, conr_2para))\n",
    "        #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "        try:\n",
    "            cebra_veldir_model = CEBRA(model_architecture='offset10-model',\n",
    "                       batch_size = batch_size,\n",
    "                       learning_rate = learning_rate,\n",
    "                       temperature = Temp_para[temp],\n",
    "                       output_dimension = output_dimension,\n",
    "                       max_iterations=iterations,\n",
    "                       distance='cosine',\n",
    "                       conditional='time_delta',\n",
    "                       verbose=False,\n",
    "                       time_offsets=10)\n",
    "            start_time = time.time()\n",
    "            cebra_veldir_model.fit(neural_train, continuous_index_train)\n",
    "            execution_time = np.round((time.time() - start_time), 2)\n",
    "\n",
    "            cebra_veldir_train = cebra_veldir_model.transform(neural_train)\n",
    "            cebra_veldir_test  = cebra_veldir_model.transform(neural_test)\n",
    "            train_loss = cebra_veldir_model.state_dict_['loss']\n",
    "\n",
    "            # Read-outs are fit on the training embedding only: a linear map from the\n",
    "            # full embedding (r2-3D) and one from its first 2 principal components\n",
    "            # (r2-2D), both predicting the scaled XY labels.\n",
    "            X = cebra_veldir_train\n",
    "            y = continuous_index_train[:,0:2]\n",
    "            reg_3d = LinearRegression().fit(X, y)\n",
    "            vel_train_r2 = sklearn.metrics.r2_score(y, reg_3d.predict(X))\n",
    "            pca_2d = PCA(n_components=2).fit(X)\n",
    "            reg_2d = LinearRegression().fit(pca_2d.transform(X), y)\n",
    "            vel_train_r2_pca = sklearn.metrics.r2_score(y, reg_2d.predict(pca_2d.transform(X)))\n",
    "            vel_train_r2_pca = np.round(vel_train_r2_pca, 4)\n",
    "\n",
    "            print('80% Train Data Temp=', str(Temp_para[temp]), \\\n",
    "                  ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))\n",
    "\n",
    "            # Score the held-out segment with the models fitted above (no refitting).\n",
    "            X = cebra_veldir_test\n",
    "            y = continuous_index_test[:,0:2]\n",
    "            vel_test_r2 = sklearn.metrics.r2_score(y, reg_3d.predict(X))\n",
    "            vel_test_r2_pca = sklearn.metrics.r2_score(y, reg_2d.predict(pca_2d.transform(X)))\n",
    "            vel_test_r2_pca = np.round(vel_test_r2_pca, 4)\n",
    "\n",
    "            print('20% Test  Data Temp=', str(Temp_para[temp]), \\\n",
    "                  ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))\n",
    "\n",
    "            # NOTE(review): file[:19] keeps the first 19 characters of the input name\n",
    "            # (e.g. 'Mihili_20140303_neu'); unchanged so names match existing outputs,\n",
    "            # though the prefix before '_neural_con_dis_index' may have been intended.\n",
    "            new_filename = file[:19] + \"_Temp_\"+str(Temp_para[temp])+ \\\n",
    "                \"_iterations_\"+str(iterations)+ \\\n",
    "                \"_80%train_\"+str(vel_train_r2_pca)+ \\\n",
    "                \"_20%test_\"+str(vel_test_r2_pca)+\".npz\"\n",
    "            file_save = os.path.join(out_dir, new_filename)\n",
    "            np.savez(file_save,\n",
    "                     execution_time = execution_time,\n",
    "                     temperature = Temp_para[temp],\n",
    "                     iterations = iterations, \n",
    "                     train_loss = train_loss,\n",
    "                     cebra_veldir_train=cebra_veldir_train,\n",
    "                     cebra_veldir_test=cebra_veldir_test,\n",
    "                     continuous_index_train=continuous_index_train,\n",
    "                     continuous_index_test=continuous_index_test,\n",
    "                     vel_train_r2 = vel_train_r2,\n",
    "                     vel_test_r2 = vel_test_r2,\n",
    "                     vel_train_r2_pca = vel_train_r2_pca,\n",
    "                     vel_test_r2_pca = vel_test_r2_pca)\n",
    "            print('saved>>', file_save)\n",
    "        except Exception as e:\n",
    "            # Keep going with the remaining files, but report why this fit failed.\n",
    "            print(' Temp=', str(Temp_para[temp]), ' fail:', repr(e))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a0d76b9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
