{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "110dcbef",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<frozen importlib._bootstrap>:219: RuntimeWarning: scipy._lib.messagestream.MessageStream size changed, may indicate binary incompatibility. Expected 56 from C header, got 64 from PyObject\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "XY distance>> 105.63815932197089\n",
      "Z distance>> 90.0\n",
      "100% Data Temp= 0.045  r2-3D= 0.968704  r2-2D= 0.966347\n",
      "100% Data Temp= 0.045  r2-3D= 0.95473  r2-2D= 0.954635\n",
      "100% Data Temp= 0.045  r2-3D= 0.960515  r2-2D= 0.960341\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "import scipy.io as sio\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression,LogisticRegression\n",
    "import sklearn.metrics\n",
    "import time\n",
    "\n",
    "mat_contents=sio.loadmat(\"./data/Fig7_Handwriting/StraightLines_sigmoid_long_pos.mat\") \n",
    "dur = 100\n",
    "neural = mat_contents['neural']\n",
    "continuous_index_XY = mat_contents['location_XY']\n",
    "discrete_index = mat_contents['direction_angle'] \n",
    "\n",
    "iterations = 20 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 3\n",
    "batch_size = 512\n",
    "output_dimension = 3\n",
    "# Temp_para = [0.035, 0.04, 0.045]\n",
    "Temp_para = [0.045, 0.045, 0.045]\n",
    "# Temp_para = [0.03]\n",
    "\n",
    "L = neural.shape[0]\n",
    "N_values_hist = round(L/20)\n",
    "random_indices = np.random.choice(L, size=N_values_hist, replace=False)\n",
    "indices_X = continuous_index_XY[random_indices, 0]\n",
    "indices_Y = continuous_index_XY[random_indices, 1]\n",
    "index_diffs_X = np.abs(indices_X[:, None] - indices_X[None, :]) \n",
    "index_diffs_Y = np.abs(indices_Y[:, None] - indices_Y[None, :])\n",
    "l_dist_XY = index_diffs_X + index_diffs_Y\n",
    "l_dist_XY_1d = l_dist_XY[~np.eye(N_values_hist, dtype=bool)].flatten()\n",
    "print('XY distance>>', np.median(l_dist_XY_1d))\n",
    "angles = np.squeeze(discrete_index[random_indices])\n",
    "angle_diffs = np.abs(angles[:, None] - angles[None, :]) \n",
    "circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)\n",
    "l_dist_Z_1d = circular_angle_diffs[~np.eye(N_values_hist, dtype=bool)].flatten()\n",
    "print('Z distance>>', np.median(l_dist_Z_1d))\n",
    "continuous_index_XY = continuous_index_XY*XY_scale\n",
    "#####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "XYZ_threshold = np.median(l_dist_XY_1d)*XY_scale + np.median(l_dist_Z_1d)\n",
    "#####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "\n",
    "for temp in range(len(Temp_para)):\n",
    "    continuous_index = np.column_stack((continuous_index_XY, discrete_index))\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR\n",
    "    conr_2para = np.full((L,), 0.001)\n",
    "    conr_2para[:2] = [XYZ_threshold, Temp_para[temp]]\n",
    "    continuous_index = np.column_stack((continuous_index, conr_2para))\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "    try:\n",
    "        cebra_veldir_model = CEBRA(model_architecture='offset10-model',\n",
    "                                batch_size=batch_size,\n",
    "                                learning_rate = learning_rate,\n",
    "                                output_dimension = output_dimension ,\n",
    "                                max_iterations=iterations,\n",
    "                                distance='cosine',\n",
    "                                conditional='time_delta',\n",
    "                                device='cuda_if_available',\n",
    "                                verbose=False,\n",
    "                                time_offsets=10)\n",
    "        start_time = time.time()\n",
    "        cebra_veldir_model.fit(neural, continuous_index)\n",
    "        end_time = time.time()\n",
    "        execution_time = np.round((end_time - start_time), 2)\n",
    "        cebra_veldir = cebra_veldir_model.transform(neural) \n",
    "        train_loss = cebra_veldir_model.state_dict_['loss']\n",
    "\n",
    "        X = cebra_veldir\n",
    "        y = continuous_index[:,0:2]\n",
    "        reg = LinearRegression().fit(X, y)\n",
    "        pred_vel = reg.predict(X)\n",
    "        vel_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "        vel_r2_3d = np.round(vel_r2, 6)\n",
    "\n",
    "        pca = PCA(n_components=2)\n",
    "        embedding_2d = pca.fit_transform(cebra_veldir)\n",
    "        X = embedding_2d\n",
    "        reg = LinearRegression().fit(X, y)\n",
    "        pred_vel = reg.predict(X)\n",
    "        vel_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "        vel_r2_2d = np.round(vel_r2, 6)\n",
    "\n",
    "        print('100% Data Temp=', str(Temp_para[temp]), \\\n",
    "                      ' r2-3D=', str(vel_r2_3d), ' r2-2D=', str(vel_r2_2d))\n",
    "        new_filename = \"Temp_\"+str(Temp_para[temp])+ \\\n",
    "                    \"_iterations_\"+str(iterations)+ \\\n",
    "                    \"_3D_\"+str(vel_r2_3d)+ \\\n",
    "                    \"_2D_\"+str(vel_r2_2d)+\".npz\"\n",
    "        file_save = os.path.join('./NMR_Figs/Fig7_Handwriting',new_filename)\n",
    "        np.savez(file_save,\n",
    "                 execution_time = execution_time,\n",
    "                 temperature = Temp_para[temp],\n",
    "                 iterations = iterations, \n",
    "                 train_loss = train_loss,\n",
    "                 cebra_veldir=cebra_veldir,\n",
    "                 continuous_index=continuous_index,\n",
    "                 vel_r2_3d = vel_r2_3d,\n",
    "                 vel_r2_2d = vel_r2_2d)\n",
    "\n",
    "        fig = plt.figure(figsize=(3, 3), dpi=250)\n",
    "        change_points = np.where(discrete_index[:-1] != discrete_index[1:])[0] + 1\n",
    "        trial_boundaries = np.concatenate(([0], change_points, [len(discrete_index)]))\n",
    "        for i in range(len(trial_boundaries) - 1):\n",
    "            start_idx = trial_boundaries[i] + 5\n",
    "            end_idx = trial_boundaries[i + 1] - 1\n",
    "            trial_embeddings = embedding_2d[start_idx:end_idx, :]\n",
    "            trial_color = plt.cm.hsv(1 / 360 * discrete_index[start_idx])\n",
    "            plt.plot(trial_embeddings[:, 0], trial_embeddings[:, 1], color=trial_color, alpha=0.75, linewidth=0.5)\n",
    "            plt.axis('off')\n",
    "        plt.title('r2-2D = '+str(vel_r2_2d))\n",
    "        new_filename = \"Temp_\"+str(Temp_para[temp])+ \\\n",
    "                    \"_iterations_\"+str(iterations)+ \\\n",
    "                    \"_3D_\"+str(vel_r2_3d)+ \\\n",
    "                    \"_2D_\"+str(vel_r2_2d)+\"_raw.pdf\"\n",
    "        file_save = os.path.join('./NMR_Figs/Fig7_Handwriting',new_filename)\n",
    "        plt.savefig(file_save)\n",
    "        plt.close(fig)\n",
    "\n",
    "        fig = plt.figure(figsize=(3, 3), dpi=250)\n",
    "        val_range = slice(5, 98)\n",
    "        for i in range(16):\n",
    "            i = i + 1\n",
    "            direction_trial = (discrete_index[:, 0]//22.5 == i)\n",
    "            trial_avg = embedding_2d[direction_trial, :].reshape(-1,dur,2).mean(axis=0)\n",
    "            plt.scatter(trial_avg[val_range, 0],trial_avg[val_range, 1],color=plt.cm.hsv(1 / 16 * i),\n",
    "                       edgecolors='none',linewidth=1,alpha=1,s=2)\n",
    "            plt.plot(trial_avg[val_range, 0],trial_avg[val_range, 1],color=plt.cm.hsv(1 / 16 * i),\n",
    "                    linewidth=1, alpha=0.75)\n",
    "            plt.axis('off')\n",
    "        new_filename = \"Temp_\"+str(Temp_para[temp])+ \\\n",
    "                    \"_iterations_\"+str(iterations)+ \\\n",
    "                    \"_3D_\"+str(vel_r2_3d)+ \\\n",
    "                    \"_2D_\"+str(vel_r2_2d)+\"_average.pdf\"\n",
    "        file_save = os.path.join('./NMR_Figs/Fig7_Handwriting',new_filename)\n",
    "        plt.savefig(file_save)\n",
    "        plt.close(fig)\n",
    "    except Exception as e:\n",
    "            print(' Temp=', str(Temp_para[temp]), ' fail')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33b499bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "iterations = 10 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 3\n",
    "100% Data Temp= 0.045  r2-3D= 0.956653  r2-2D= 0.956553\n",
    "100% Data Temp= 0.045  r2-3D= 0.959142  r2-2D= 0.958717\n",
    "100% Data Temp= 0.045  r2-3D= 0.965923  r2-2D= 0.965404\n",
    "\n",
    "iterations = 5 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 3\n",
    "100% Data Temp= 0.035  r2-3D= 0.942361  r2-2D= 0.941979\n",
    "100% Data Temp= 0.04  r2-3D= 0.937881  r2-2D= 0.93692\n",
    "100% Data Temp= 0.045  r2-3D= 0.945715  r2-2D= 0.944785\n",
    "\n",
    "iterations = 5 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 3\n",
    "100% Data Temp= 0.02  r2-3D= 0.842585  r2-2D= 0.842327\n",
    "100% Data Temp= 0.03  r2-3D= 0.932551  r2-2D= 0.932234\n",
    "100% Data Temp= 0.04  r2-3D= 0.945591  r2-2D= 0.945422 ***\n",
    "100% Data Temp= 0.05  r2-3D= 0.938769  r2-2D= 0.938645\n",
    "100% Data Temp= 0.06  r2-3D= 0.94072  r2-2D= 0.940644\n",
    "\n",
    "iterations = 5 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 4\n",
    "100% Data Temp= 0.02  r2-3D= 0.906801  r2-2D= 0.906323\n",
    "100% Data Temp= 0.03  r2-3D= 0.932493  r2-2D= 0.931185\n",
    "100% Data Temp= 0.04  r2-3D= 0.942825  r2-2D= 0.941007\n",
    "100% Data Temp= 0.05  r2-3D= 0.935364  r2-2D= 0.934572\n",
    "100% Data Temp= 0.06  r2-3D= 0.936747  r2-2D= 0.936447\n",
    "\n",
    "iterations = 5 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 5\n",
    "100% Data Temp= 0.02  r2-3D= 0.93845  r2-2D= 0.938123\n",
    "100% Data Temp= 0.03  r2-3D= 0.930577  r2-2D= 0.928602\n",
    "100% Data Temp= 0.04  r2-3D= 0.89428  r2-2D= 0.891301\n",
    "100% Data Temp= 0.05  r2-3D= 0.935591  r2-2D= 0.935499\n",
    "\n",
    "iterations = 2 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 6\n",
    "100% Data Temp= 0.02  r2-3D= 0.888519  r2-2D= 0.888298\n",
    "100% Data Temp= 0.03  r2-3D= 0.906486  r2-2D= 0.905676\n",
    "100% Data Temp= 0.04  r2-3D= 0.901956  r2-2D= 0.901919\n",
    "100% Data Temp= 0.05  r2-3D= 0.892884  r2-2D= 0.892418\n",
    "\n",
    "iterations = 5 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 8\n",
    "100% Data Temp= 0.02  r2-3D= 0.934609  r2-2D= 0.934607\n",
    "100% Data Temp= 0.03  r2-3D= 0.931288  r2-2D= 0.931048"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
