{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "110dcbef",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "100% Data Temp= 0.045  r2-3D= 0.965712  r2-2D= 0.954928\n",
      "100% Data Temp= 0.045  r2-3D= 0.959714  r2-2D= 0.9506\n",
      "100% Data Temp= 0.045  r2-3D= 0.969608  r2-2D= 0.959536\n",
      "100% Data Temp= 0.045  r2-3D= 0.971782  r2-2D= 0.967086\n",
      "100% Data Temp= 0.045  r2-3D= 0.968603  r2-2D= 0.955486\n"
     ]
    }
   ],
   "source": [
    "import sys\n",
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import joblib as jl\n",
    "import cebra.datasets\n",
    "from cebra import CEBRA\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.linear_model import LinearRegression,LogisticRegression\n",
    "import sklearn.metrics\n",
    "import time\n",
    "### Loaded data has already been convolved with a 40 ms Gaussian kernel\n",
    "monkey_pos = cebra.datasets.init('area2-bump-pos-active')\n",
    "monkey_target = cebra.datasets.init('area2-bump-target-active')\n",
    "dur = 600  # samples per trial; embeddings are reshaped to (trials, dur, 2) for trial averaging\n",
    "iterations = 20 * 1000\n",
    "learning_rate = 0.001\n",
    "XY_scale = 50  # multiplier applied to the XY position labels before training\n",
    "batch_size = 512\n",
    "output_dimension = 3\n",
    "Temp_para = [0.045, 0.045, 0.045, 0.045, 0.045]  # one training run per entry\n",
    "SEED = 42  # makes the label-distance subsample (and hence XYZ_threshold) reproducible\n",
    "\n",
    "neural = monkey_pos.neural\n",
    "continuous_index_XY = monkey_pos.continuous_index.numpy()\n",
    "discrete_index_raw = monkey_target.discrete_index.numpy().reshape(-1, 1)\n",
    "# Direction indices (presumably 0..7) mapped to angles in degrees, in steps of 45.\n",
    "discrete_index = discrete_index_raw * 45\n",
    "L = neural.shape[0]\n",
    "\n",
    "# Typical label distances are estimated on a random 1/20 subsample of time points,\n",
    "# presumably to keep the pairwise (N x N) distance matrices small.\n",
    "N_values_hist = round(L / 20)\n",
    "rng = np.random.default_rng(SEED)\n",
    "random_indices = rng.choice(L, size=N_values_hist, replace=False)\n",
    "off_diagonal = ~np.eye(N_values_hist, dtype=bool)  # exclude self-pairs (distance 0)\n",
    "\n",
    "# Pairwise L1 (Manhattan) distance between unscaled XY positions.\n",
    "indices_X = continuous_index_XY[random_indices, 0]\n",
    "indices_Y = continuous_index_XY[random_indices, 1]\n",
    "l_dist_XY = (np.abs(indices_X[:, None] - indices_X[None, :])\n",
    "             + np.abs(indices_Y[:, None] - indices_Y[None, :]))\n",
    "l_dist_XY_1d = l_dist_XY[off_diagonal]\n",
    "\n",
    "# Pairwise circular angle difference in degrees, folded into [0, 180].\n",
    "angles = np.squeeze(discrete_index[random_indices])\n",
    "angle_diffs = np.abs(angles[:, None] - angles[None, :])\n",
    "circular_angle_diffs = np.minimum(angle_diffs, 360 - angle_diffs)\n",
    "l_dist_Z_1d = circular_angle_diffs[off_diagonal]\n",
    "\n",
    "continuous_index_XY = continuous_index_XY * XY_scale\n",
    "#####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "# Threshold = median scaled-XY distance + median angular distance (both in label units).\n",
    "XYZ_threshold = np.median(l_dist_XY_1d) * XY_scale + np.median(l_dist_Z_1d)\n",
    "#####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "\n",
    "# np.savez / plt.savefig do not create missing folders; without this every run\n",
    "# fails at its first save and the except-branch below only reports 'fail'.\n",
    "out_dir = './NMR_Figs/Fig1_S1_Active'\n",
    "os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "# Temperature-independent pieces, computed once instead of in every iteration:\n",
    "# labels = scaled XY position + direction angle, and trial boundaries = positions\n",
    "# where the direction label changes.\n",
    "continuous_index_base = np.column_stack((continuous_index_XY, discrete_index))\n",
    "change_points = np.where(discrete_index[:-1, 0] != discrete_index[1:, 0])[0] + 1\n",
    "trial_boundaries = np.concatenate(([0], change_points, [len(discrete_index)]))\n",
    "\n",
    "\n",
    "def linear_r2(features, targets):\n",
    "    \"\"\"Fit least squares features -> targets; return the in-sample R^2 rounded to 6 decimals.\"\"\"\n",
    "    reg = LinearRegression().fit(features, targets)\n",
    "    return np.round(sklearn.metrics.r2_score(targets, reg.predict(features)), 6)\n",
    "\n",
    "\n",
    "for temp_value in Temp_para:\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>>  Only for NMR\n",
    "    # Extra label column carrying the run parameters: entry 0 = XYZ_threshold,\n",
    "    # entry 1 = temperature, all other entries 0.001.\n",
    "    # NOTE(review): presumably parsed by the NMR-modified CEBRA loss; confirm.\n",
    "    conr_2para = np.full((L,), 0.001)\n",
    "    conr_2para[:2] = [XYZ_threshold, temp_value]\n",
    "    continuous_index = np.column_stack((continuous_index_base, conr_2para))\n",
    "    #####>>>>>>>>>>>>>>>>>>>>>>>> Only for NMR\n",
    "    try:\n",
    "        cebra_veldir_model = CEBRA(model_architecture='offset10-model',\n",
    "                                   batch_size=batch_size,\n",
    "                                   learning_rate=learning_rate,\n",
    "                                   output_dimension=output_dimension,\n",
    "                                   max_iterations=iterations,\n",
    "                                   distance='cosine',\n",
    "                                   conditional='time_delta',\n",
    "                                   device='cuda_if_available',\n",
    "                                   verbose=False,\n",
    "                                   time_offsets=10)\n",
    "        start_time = time.time()\n",
    "        cebra_veldir_model.fit(neural, continuous_index)\n",
    "        execution_time = np.round(time.time() - start_time, 2)\n",
    "        cebra_veldir = cebra_veldir_model.transform(neural)\n",
    "        train_loss = cebra_veldir_model.state_dict_['loss']\n",
    "\n",
    "        # Linear decodability of the (scaled) XY labels from the full embedding\n",
    "        # and from its first two principal components.\n",
    "        y = continuous_index[:, 0:2]\n",
    "        vel_r2_3d = linear_r2(cebra_veldir, y)\n",
    "        embedding_2d = PCA(n_components=2).fit_transform(cebra_veldir)\n",
    "        vel_r2_2d = linear_r2(embedding_2d, y)\n",
    "\n",
    "        print('100% Data Temp=', str(temp_value),\n",
    "              ' r2-3D=', str(vel_r2_3d), ' r2-2D=', str(vel_r2_2d))\n",
    "        # Common stem for the .npz and both PDFs of this run; !s keeps the original str() formatting.\n",
    "        file_stem = (f'Temp_{temp_value!s}_iterations_{iterations!s}'\n",
    "                     f'_3D_{vel_r2_3d!s}_2D_{vel_r2_2d!s}')\n",
    "        np.savez(os.path.join(out_dir, file_stem + '.npz'),\n",
    "                 execution_time=execution_time,\n",
    "                 temperature=temp_value,\n",
    "                 iterations=iterations,\n",
    "                 train_loss=train_loss,\n",
    "                 cebra_veldir=cebra_veldir,\n",
    "                 continuous_index=continuous_index,\n",
    "                 vel_r2_3d=vel_r2_3d,\n",
    "                 vel_r2_2d=vel_r2_2d)\n",
    "\n",
    "        # Raw 2D trajectories, one line per trial, coloured by the direction label.\n",
    "        fig = plt.figure(figsize=(3, 3), dpi=150)\n",
    "        for start_idx, end_idx in zip(trial_boundaries[:-1], trial_boundaries[1:]):\n",
    "            trial_color = plt.cm.hsv(1 / 360 * discrete_index[start_idx, 0])\n",
    "            plt.plot(embedding_2d[start_idx:end_idx, 0], embedding_2d[start_idx:end_idx, 1],\n",
    "                     color=trial_color, alpha=0.75, linewidth=0.5)\n",
    "        plt.axis('off')\n",
    "        plt.title('r2-2D = ' + str(vel_r2_2d))\n",
    "        plt.savefig(os.path.join(out_dir, file_stem + '_raw.pdf'))\n",
    "        plt.close(fig)\n",
    "\n",
    "        # Trial-averaged 2D trajectory per direction (assumes every trial has dur samples).\n",
    "        fig = plt.figure(figsize=(3, 3), dpi=150)\n",
    "        for i in range(8):\n",
    "            direction_trial = (discrete_index_raw[:, 0] == i)\n",
    "            trial_avg = embedding_2d[direction_trial, :].reshape(-1, dur, 2).mean(axis=0)\n",
    "            plt.scatter(trial_avg[:, 0], trial_avg[:, 1], color=plt.cm.hsv(1 / 8 * i),\n",
    "                        edgecolors='none', linewidth=1, alpha=0.5, s=2)\n",
    "        plt.axis('off')\n",
    "        plt.savefig(os.path.join(out_dir, file_stem + '_average.pdf'))\n",
    "        plt.close(fig)\n",
    "    except Exception as e:\n",
    "        # Keep sweeping the remaining temperatures, but say why this run failed and\n",
    "        # release any figure the failed run left open.\n",
    "        plt.close('all')\n",
    "        print(' Temp=', str(temp_value), ' fail:', f'{type(e).__name__}: {e}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "vscode": {
   "interpreter": {
    "hash": "dc327929684d2c13e929b2699e1b37518dbb61b921da51c352c926069002ee0e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
