{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "f617ecc5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import torch\n",
    "import scipy.io as sio\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import sklearn.metrics\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "122b7983",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "./data/SU_12PMd/Mihili_20140303_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.922589645771321\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.923  r2-2D= 0.8987\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.596  r2-2D= 0.5986\n",
      "./data/SU_12PMd/Chewie_20161007_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.953917261568805\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.928  r2-2D= 0.6213\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.826  r2-2D= 0.579\n",
      "./data/SU_12PMd/Mihili_20140306_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.93645954490035\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.906  r2-2D= 0.9026\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.633  r2-2D= 0.6315\n",
      "./data/SU_12PMd/Mihili_20140218_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.962815113496596\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.93  r2-2D= 0.9294\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.665  r2-2D= 0.6671\n",
      "./data/SU_12PMd/Chewie_20161014_embed_5000itr_PMd.npz\n",
      "XY distance>> 11.80898540773402\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.94  r2-2D= 0.9173\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.901  r2-2D= 0.8849\n",
      "./data/SU_12PMd/Chewie_20161005_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.395154003419698\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.937  r2-2D= 0.9313\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.802  r2-2D= 0.8078\n",
      "./data/SU_12PMd/Chewie_20161021_embed_5000itr_PMd.npz\n",
      "XY distance>> 11.55901397800378\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.905  r2-2D= 0.9028\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.845  r2-2D= 0.842\n",
      "./data/SU_12PMd/Chewie_20161006_embed_5000itr_PMd.npz\n",
      "XY distance>> 13.219618781222147\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.944  r2-2D= 0.6707\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.821  r2-2D= 0.6183\n",
      "./data/SU_12PMd/Chewie_20160929_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.338508184542736\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.938  r2-2D= 0.9157\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.858  r2-2D= 0.8371\n",
      "./data/SU_12PMd/Mihili_20140217_embed_5000itr_PMd.npz\n",
      "XY distance>> 11.970214785006945\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.932  r2-2D= 0.8779\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.631  r2-2D= 0.6057\n",
      "./data/SU_12PMd/Mihili_20140304_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.634467380663594\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.929  r2-2D= 0.9008\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.727  r2-2D= 0.7124\n",
      "./data/SU_12PMd/Mihili_20140307_embed_5000itr_PMd.npz\n",
      "XY distance>> 12.588874058972547\n",
      "Z distance>> 90.0\n",
      "80% Train Data Temp= 0.08  r2-3D= 0.92  r2-2D= 0.8942\n",
      "20% Test  Data Temp= 0.08  r2-3D= 0.618  r2-2D= 0.6148\n"
     ]
    }
   ],
   "source": [
    "# Training configuration\n",
    "batch_size = 512\n",
    "learning_rate = 0.001\n",
    "iterations = 5000            # CEBRA max_iterations; also part of the output file names\n",
    "output_dimension = 3         # dimensionality of the learned embedding\n",
    "dur = 40                     # not referenced anywhere in this notebook\n",
    "# Temperatures to sweep. A wider sweep used earlier:\n",
    "# [0.02, 0.03, 0.04, 0.05, 0.055, 0.06, 0.065, 0.07, 0.075, 0.08, 0.09, 0.1]\n",
    "Temp_para = [0.08]\n",
    "def split_data(neural, continuous_index, train_fraction=0.8):\n",
    "    \"\"\"Split neural data and labels into a leading train part and a trailing test part.\n",
    "\n",
    "    The first ``round(n_rows * train_fraction)`` rows (axis 0) form the training set and\n",
    "    the remaining rows the test set. Rows are not shuffled, so if axis 0 is time\n",
    "    (assumed here, not checked) the test set is the end of the recording.\n",
    "\n",
    "    Args:\n",
    "        neural: array whose rows (axis 0) are samples.\n",
    "        continuous_index: label array row-aligned with ``neural``.\n",
    "        train_fraction: fraction of rows assigned to training (default 0.8,\n",
    "            the value previously hard-coded here).\n",
    "\n",
    "    Returns:\n",
    "        (neural_train, neural_test, continuous_index_train, continuous_index_test)\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``train_fraction`` is outside [0, 1].\n",
    "    \"\"\"\n",
    "    if not 0.0 <= train_fraction <= 1.0:\n",
    "        raise ValueError('train_fraction must be in [0, 1]')\n",
    "    split_idx = round(neural.shape[0] * train_fraction)\n",
    "    return (neural[:split_idx], neural[split_idx:],\n",
    "            continuous_index[:split_idx], continuous_index[split_idx:])\n",
    "        \n",
    "# discrete_index holds angles from -180 to +180 deg; only the 45-deg multiples listed\n",
    "# here are accepted (anything else raises KeyError). 180 and -180 share bin 4, and\n",
    "# multiplying the bin by 45 below re-expresses every angle in [0, 315] deg.\n",
    "angle_to_new_value = {-180: 4, -135: 5, -90: 6, -45: 7, 0: 0, 45: 1, 90: 2, 135: 3, 180: 4}\n",
    "vectorized_map = np.vectorize(lambda x: angle_to_new_value[x])\n",
    "\n",
    "SEED = 0        # seeds label subsampling and model training (GPU kernels may still be nondeterministic)\n",
    "XY_scale = 10   # multiplies the XY labels before they are stacked with the angle label\n",
    "out_dir = './data/Fig2_SU/nmr/emb_PMd'\n",
    "os.makedirs(out_dir, exist_ok=True)  # np.savez does not create missing folders\n",
    "\n",
    "\n",
    "def median_offdiag(pairwise):\n",
    "    \"\"\"Median of the off-diagonal entries of a square pairwise-difference matrix.\"\"\"\n",
    "    return np.median(pairwise[~np.eye(pairwise.shape[0], dtype=bool)])\n",
    "\n",
    "\n",
    "directory = './data/SU_12PMd/'\n",
    "# Only the .mat inputs are loaded (any other file, e.g. an .npz, would make\n",
    "# sio.loadmat fail); sorting makes the processing order independent of the filesystem.\n",
    "files = sorted(f for f in os.listdir(directory) if f.endswith('.mat'))\n",
    "for file in files:\n",
    "    src_path = os.path.join(directory, file)\n",
    "    print(src_path)\n",
    "    mat_contents = sio.loadmat(src_path)\n",
    "\n",
    "    neural = mat_contents['neural_PMd']\n",
    "    continuous_index_XY = mat_contents['continuous_index']\n",
    "    # angles range from -180deg to +180deg; re-expressed as 45 * bin (0..315 deg)\n",
    "    discrete_index = 45 * vectorized_map(mat_contents['discrete_index'])\n",
    "\n",
    "    # Median label distances over a random 1/5 subsample of rows (seeded per file);\n",
    "    # they are combined into the NMR distance threshold below.\n",
    "    rng = np.random.default_rng(SEED)\n",
    "    L = neural.shape[0]\n",
    "    N_values_hist = round(L / 5)\n",
    "    random_indices = rng.choice(L, size=N_values_hist, replace=False)\n",
    "    indices_X = continuous_index_XY[random_indices, 0]\n",
    "    indices_Y = continuous_index_XY[random_indices, 1]\n",
    "    l_dist_XY = (np.abs(indices_X[:, None] - indices_X[None, :])\n",
    "                 + np.abs(indices_Y[:, None] - indices_Y[None, :]))  # pairwise L1 distance\n",
    "    xy_median = median_offdiag(l_dist_XY)\n",
    "    print('XY distance>>', xy_median)\n",
    "    angles = np.ravel(discrete_index[random_indices])\n",
    "    angle_diffs = np.abs(angles[:, None] - angles[None, :])\n",
    "    circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)  # wrap-around difference\n",
    "    z_median = median_offdiag(circular_angle_diffs)\n",
    "    print('Z distance>>', z_median)\n",
    "\n",
    "    continuous_index_XY = continuous_index_XY * XY_scale\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "    XYZ_threshold = xy_median * XY_scale + z_median\n",
    "\n",
    "    # Labels and the train/test split do not depend on the temperature: build them once.\n",
    "    continuous_index = np.column_stack((continuous_index_XY, discrete_index))\n",
    "    (neural_train, neural_test,\n",
    "     continuous_index_train, continuous_index_test) = split_data(neural, continuous_index)\n",
    "    y_train = continuous_index_train[:, 0:2]  # scaled XY labels used as decoding targets\n",
    "    y_test = continuous_index_test[:, 0:2]\n",
    "    L_train = neural_train.shape[0]\n",
    "\n",
    "    for temperature in Temp_para:\n",
    "        #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "        # Extra label column carrying [XYZ_threshold, temperature] in its first two rows\n",
    "        # (0.001 elsewhere); presumably read by the NMR-modified CEBRA loss - confirm.\n",
    "        # A new name keeps continuous_index_train unchanged across temperatures.\n",
    "        conr_2para = np.full((L_train,), 0.001)\n",
    "        conr_2para[:2] = [XYZ_threshold, temperature]\n",
    "        labels_train = np.column_stack((continuous_index_train, conr_2para))\n",
    "        try:\n",
    "            np.random.seed(SEED)\n",
    "            torch.manual_seed(SEED)\n",
    "            cebra_veldir_model = CEBRA(model_architecture='offset10-model',\n",
    "                                       batch_size=batch_size,\n",
    "                                       learning_rate=learning_rate,\n",
    "                                       output_dimension=output_dimension,\n",
    "                                       max_iterations=iterations,\n",
    "                                       distance='cosine',\n",
    "                                       conditional='time_delta',\n",
    "                                       verbose=False,\n",
    "                                       time_offsets=10)\n",
    "            start_time = time.time()\n",
    "            cebra_veldir_model.fit(neural_train, labels_train)\n",
    "            execution_time = np.round(time.time() - start_time, 2)\n",
    "\n",
    "            cebra_veldir_train = cebra_veldir_model.transform(neural_train)\n",
    "            cebra_veldir_test = cebra_veldir_model.transform(neural_test)\n",
    "            train_loss = cebra_veldir_model.state_dict_['loss']\n",
    "\n",
    "            # Decoders (3-D linear, PCA to 2-D, 2-D linear) are fitted on the training embedding only...\n",
    "            reg_3d = LinearRegression().fit(cebra_veldir_train, y_train)\n",
    "            vel_train_r2 = sklearn.metrics.r2_score(y_train, reg_3d.predict(cebra_veldir_train))\n",
    "            pca_2d = PCA(n_components=2).fit(cebra_veldir_train)\n",
    "            X_2d_train = pca_2d.transform(cebra_veldir_train)\n",
    "            reg_2d = LinearRegression().fit(X_2d_train, y_train)\n",
    "            vel_train_r2_pca = np.round(sklearn.metrics.r2_score(y_train, reg_2d.predict(X_2d_train)), 4)\n",
    "            print('80% Train Data Temp=', str(temperature),\n",
    "                  ' r2-3D=', str(np.round(vel_train_r2, 3)), ' r2-2D=', str(vel_train_r2_pca))\n",
    "\n",
    "            # ...and applied unchanged to the held-out embedding.\n",
    "            vel_test_r2 = sklearn.metrics.r2_score(y_test, reg_3d.predict(cebra_veldir_test))\n",
    "            X_2d_test = pca_2d.transform(cebra_veldir_test)\n",
    "            vel_test_r2_pca = np.round(sklearn.metrics.r2_score(y_test, reg_2d.predict(X_2d_test)), 4)\n",
    "            print('20% Test  Data Temp=', str(temperature),\n",
    "                  ' r2-3D=', str(np.round(vel_test_r2, 3)), ' r2-2D=', str(vel_test_r2_pca))\n",
    "\n",
    "            # NOTE(review): file[:19] keeps the historical output naming. For 15-character\n",
    "            # session names ('<session>_neural_con_dis_index...') it yields '<session>_neu' - confirm intent.\n",
    "            new_filename = (file[:19] + '_Temp_' + str(temperature)\n",
    "                            + '_iterations_' + str(iterations)\n",
    "                            + '_80%train_' + str(vel_train_r2_pca)\n",
    "                            + '_20%test_' + str(vel_test_r2_pca) + '.npz')\n",
    "            file_save = os.path.join(out_dir, new_filename)\n",
    "            np.savez(file_save,\n",
    "                     execution_time=execution_time,\n",
    "                     temperature=temperature,\n",
    "                     iterations=iterations,\n",
    "                     train_loss=train_loss,\n",
    "                     cebra_veldir_train=cebra_veldir_train,\n",
    "                     cebra_veldir_test=cebra_veldir_test,\n",
    "                     continuous_index_train=labels_train,\n",
    "                     continuous_index_test=continuous_index_test,\n",
    "                     vel_train_r2=vel_train_r2,\n",
    "                     vel_test_r2=vel_test_r2,\n",
    "                     vel_train_r2_pca=vel_train_r2_pca,\n",
    "                     vel_test_r2_pca=vel_test_r2_pca)\n",
    "        except Exception as e:\n",
    "            # Best effort: continue with the next temperature/file, but report why this one failed.\n",
    "            print(' Temp=', str(temperature), ' fail:', repr(e))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
