{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2ebd2668",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "label of date: ['20140217' '20140218' '20140303' '20140304' '20140306' '20140307'\n",
      " '20160929' '20161005' '20161006' '20161007' '20161014' '20161021']\n"
     ]
    }
   ],
   "source": [
    "import os, re\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import LinearRegression,LogisticRegression\n",
    "import sklearn.metrics\n",
    "from scipy.linalg import orthogonal_procrustes\n",
    "from itertools import permutations, combinations\n",
    "from sklearn.decomposition import PCA\n",
    "dur = 40\n",
    "emb_dim = 3\n",
    "N_angles = 8\n",
    "\n",
    "# directory = './data/Fig2_SU/nmr/emb_M1_lr001_itr5k_temp07/'\n",
    "# file_save = './data/Fig3_SU_Decode/nmr_M1.npz'\n",
    "# name_range = slice(33, 41)\n",
    "\n",
    "# directory = './data/Fig2_SU/ceb/emb_M1_10k/'\n",
    "# file_save = './data/Fig3_SU_Decode/ceb_M1.npz'\n",
    "# name_range = slice(33, 41)\n",
    "\n",
    "# directory = './data/Fig2_SU/piv/emb_M1_run3/'\n",
    "# file_save = './data/Fig3_SU_Decode/piv_M1.npz'\n",
    "# name_range = slice(38, 46)\n",
    "\n",
    "# directory = './data/Fig2_SU/nmr/emb_PMd/'\n",
    "# file_save = './data/Fig3_SU_Decode/nmr_PMd.npz'\n",
    "# name_range = slice(34, 42)\n",
    "\n",
    "# directory = './data/Fig2_SU/ceb/emb_PMd/'\n",
    "# file_save = './data/Fig3_SU_Decode/ceb_PMd.npz'\n",
    "# name_range = slice(34, 42)\n",
    "\n",
    "directory = './data/Fig2_SU/piv/emb_PMd_itr60/'\n",
    "file_save = './data/Fig3_SU_Decode/piv_PMd.npz'\n",
    "name_range = slice(40, 48)\n",
    "\n",
    "def get_best_R(R_all, emb_A, emb_A_8angle_align):\n",
    "    determinants = [np.linalg.det(R_all[:, :, i]) for i in range(R_all.shape[2])]\n",
    "    positive_dets = [det for det in determinants if det >= 0]\n",
    "    negative_dets = [det for det in determinants if det < 0]\n",
    "\n",
    "    if len(positive_dets)>0:\n",
    "        target_dets = positive_dets\n",
    "        differences = [abs(abs(det) - 1) for det in target_dets]\n",
    "        min_index = np.argmin(differences)\n",
    "        best_R_index_p = determinants.index(positive_dets[min_index])\n",
    "        best_R_p = R_all[:, :, best_R_index_p]\n",
    "        emb_A_whole_align_p = np.matmul(emb_A, best_R_p)\n",
    "        align_diff_p = np.sum(abs(emb_A_whole_align_p-emb_A_8angle_align))\n",
    "        ## print('diff positive detR=', align_diff_p)\n",
    "    elif len(positive_dets) == 0:\n",
    "        align_diff_p = 5000000 ### arbitory value\n",
    "        \n",
    "    if len(negative_dets)>0:\n",
    "        target_dets = negative_dets\n",
    "        differences = [abs(abs(det) - 1) for det in target_dets]\n",
    "        min_index = np.argmin(differences)\n",
    "        best_R_index_n = determinants.index(negative_dets[min_index])\n",
    "        best_R_n = R_all[:, :, best_R_index_n]\n",
    "        emb_A_whole_align_n = np.matmul(emb_A, best_R_n)\n",
    "        align_diff_n = np.sum(abs(emb_A_whole_align_n-emb_A_8angle_align))\n",
    "        ## print('diff negative detR=', align_diff_n)\n",
    "    elif len(negative_dets) == 0:\n",
    "        align_diff_n = 5000000\n",
    "        \n",
    "    if align_diff_p<align_diff_n:\n",
    "        best_R = best_R_p\n",
    "        ## print('Using positive R')\n",
    "    elif align_diff_p>align_diff_n:\n",
    "        best_R = best_R_n\n",
    "        ## print('Using negative R')\n",
    "    return best_R\n",
    "\n",
    "\n",
    "def cross_decode(file_path1, file_path2):\n",
    "    Monkey_A = np.load(file_path1)\n",
    "    XYTarget_A = np.concatenate((Monkey_A['continuous_index_train'][:, :3], \\\n",
    "                                 Monkey_A['continuous_index_test']), axis=0)\n",
    "    # print('XYTarget_A>>', XYTarget_A.shape)\n",
    "    emb_A = np.concatenate((Monkey_A['cebra_veldir_train'], Monkey_A['cebra_veldir_test']), axis=0)\n",
    "    if np.max(XYTarget_A[:, 2])>10: ### angles in 0-45-90-...315degrees\n",
    "        XYTarget_A[:, 2] = XYTarget_A[:, 2]/45\n",
    "        \n",
    "    Monkey_B = np.load(file_path2)\n",
    "    XYTarget_B = np.concatenate((Monkey_B['continuous_index_train'][:, :3], \\\n",
    "                                 Monkey_B['continuous_index_test']), axis=0)\n",
    "    # print('XYTarget_B>>', XYTarget_B.shape)\n",
    "    emb_B = np.concatenate((Monkey_B['cebra_veldir_train'], Monkey_B['cebra_veldir_test']), axis=0)\n",
    "    if np.max(XYTarget_B[:, 2])>10:\n",
    "        XYTarget_B[:, 2] = XYTarget_B[:, 2]/45\n",
    "    \n",
    "    train_trial_A = int(Monkey_A['continuous_index_train'][:, :3].shape[0]/dur)\n",
    "    test_trial_A = int(Monkey_A['continuous_index_test'].shape[0]/dur)\n",
    "    train_trial_B = int(Monkey_B['continuous_index_train'][:, :3].shape[0]/dur)\n",
    "    test_trial_B = int(Monkey_B['continuous_index_test'].shape[0]/dur)\n",
    "    \n",
    "    R_all = np.zeros((emb_dim, emb_dim, N_angles))\n",
    "    for a in range(N_angles):\n",
    "        direction_trial = (XYTarget_A[:, 2] == a)\n",
    "        trial_avg_A = emb_A[direction_trial, :].reshape(-1,dur,emb_dim).mean(axis=0)\n",
    "        direction_trial = (XYTarget_B[:, 2] == a)\n",
    "        trial_avg_B = emb_B[direction_trial, :].reshape(-1,dur,emb_dim).mean(axis=0)\n",
    "        R, sca = orthogonal_procrustes(trial_avg_A, trial_avg_B) ### both are (dur, 3emb-dim)\n",
    "        R_all[:,:, a] = R\n",
    "        det_R = np.linalg.det(R)\n",
    "    trial_arrays = []\n",
    "    for i in range(N_angles):\n",
    "        direction_trial = (XYTarget_A[:, 2] == i)\n",
    "        trial_A = emb_A[direction_trial, :].reshape(-1,dur,emb_dim)\n",
    "        trial_A = np.matmul(trial_A, R_all[:,:,i])\n",
    "        trial_arrays.append((direction_trial, trial_A))\n",
    "    emb_A_8angle_align = np.empty_like(emb_A)\n",
    "    for mask, trial_data in trial_arrays: ### loop-through 8 times=angles\n",
    "        flat_data = trial_data.reshape(-1, emb_dim) ### (n-trials*dur, 3emb-dim)\n",
    "        emb_A_8angle_align[mask, :] = flat_data \n",
    "     \n",
    "    emb_A_whole_align = np.matmul(emb_A, get_best_R(R_all, emb_A, emb_A_8angle_align))\n",
    "    \n",
    "    continuous_index_train = XYTarget_A[:train_trial_A*dur, :]\n",
    "    cebra_veldir_train = emb_A_whole_align[:train_trial_A*dur, :] ####***** three choices here *****####\n",
    "    continuous_index_test_B = XYTarget_B[-test_trial_B*dur:, :]\n",
    "    cebra_veldir_test_B = emb_B[-test_trial_B*dur:, :]\n",
    "    \n",
    "    X = cebra_veldir_train\n",
    "    y = continuous_index_train[:, 0:2]\n",
    "    reg_3d = LinearRegression().fit(X, y)       #### 1st fit ####\n",
    "    pred_vel = reg_3d.predict(X)\n",
    "    vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "    pca = PCA(n_components=2)\n",
    "    pca_2d = pca.fit(X)                         #### 2nd fit ####\n",
    "    X_2d = pca_2d.transform(X)\n",
    "    reg_2d = LinearRegression().fit(X_2d, y)    #### 3rd fit ####\n",
    "    \n",
    "    ###******** this part will use previous trained \"reg & LogisticReg\" ###********\n",
    "    ###******** this part will use previous trained \"reg & LogisticReg\" ###********\n",
    "    \n",
    "    X = cebra_veldir_test_B\n",
    "    y = continuous_index_test_B[:, 0:2]\n",
    "    pred_vel = reg_3d.predict(X)\n",
    "    vel_test_r2_3d = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "    X_2d = pca_2d.transform(X)\n",
    "    pred_vel = reg_2d.predict(X_2d)\n",
    "    vel_test_r2_2d = sklearn.metrics.r2_score(y, pred_vel)\n",
    "    # print(\"Cross vel 2d >>\", np.round(vel_test_r2_2d, 4))\n",
    "    return vel_test_r2_3d, vel_test_r2_2d\n",
    "\n",
    "def self_decode(file_path1):\n",
    "    Monkey_A = np.load(file_path1)\n",
    "    X = Monkey_A['cebra_veldir_train']\n",
    "    y = Monkey_A['continuous_index_train'][:,0:2]\n",
    "    # print('Self y>>', y.shape)\n",
    "    reg_3d = LinearRegression().fit(X, y)       #### 1st fit ####\n",
    "    pred_vel = reg_3d.predict(X)\n",
    "    vel_train_r2 = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "    pca = PCA(n_components=2)\n",
    "    pca_2d = pca.fit(X)                         #### 2nd fit ####\n",
    "    X_2d = pca_2d.transform(X)\n",
    "    reg_2d = LinearRegression().fit(X_2d, y)    #### 3rd fit ####\n",
    "   \n",
    "    ###************* use previous trained \"reg_3d & pca_2d & reg_2d\" ###***************\n",
    "    ###************* use previous trained \"reg_3d & pca_2d & reg_2d\" ###***************\n",
    "    X = Monkey_A['cebra_veldir_test']\n",
    "    y = Monkey_A['continuous_index_test'][:,0:2]\n",
    "    pred_vel = reg_3d.predict(X)\n",
    "    vel_test_r2_3d = sklearn.metrics.r2_score(y, pred_vel)\n",
    "\n",
    "    X_2d = pca_2d.transform(X)\n",
    "    pred_vel = reg_2d.predict(X_2d)\n",
    "    vel_test_r2_2d = sklearn.metrics.r2_score(y, pred_vel)\n",
    "    # print(\"Self vel 2d >>\", np.round(vel_test_r2_2d, 4))\n",
    "    return vel_test_r2_3d, vel_test_r2_2d\n",
    "            \n",
    "files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n",
    "n = len(files)\n",
    "vel_R_3D = np.zeros((n, n))\n",
    "vel_R_2D = np.zeros((n, n))\n",
    "date_subjects = []\n",
    "n_compare = 0\n",
    "\n",
    "def list_and_sort_files(directory):\n",
    "    files = [os.path.join(directory, f) for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]\n",
    "    def extract_date(filename):\n",
    "        match = re.search(r'(\\d{8})', os.path.basename(filename))\n",
    "        date = match.group(0) if match else '000000'  # Default to '000000' if no date is found\n",
    "        return int(date) if len(date) == 6 else int(date)\n",
    "    sorted_files = sorted(files, key=extract_date)\n",
    "    return sorted_files\n",
    "sorted_files=list_and_sort_files(directory)\n",
    "\n",
    "for i, file1 in enumerate(sorted_files):\n",
    "    # print(\"Reading file:\", file1[30:46])\n",
    "    for j, file2 in enumerate(sorted_files):\n",
    "        if i != j:    ### with-others\n",
    "            vel_test_3d, vel_test_2d = cross_decode(file1, file2)\n",
    "            # print('#'+str(n_compare+1)+' cross compare')\n",
    "        elif i == j:  ### with-itself\n",
    "            vel_test_3d, vel_test_2d = self_decode(file1)\n",
    "            # print('#'+str(n_compare+1)+' self compare')\n",
    "        vel_R_3D[i, j] = vel_test_3d\n",
    "        vel_R_2D[i, j] = vel_test_2d\n",
    "\n",
    "        if \"M1PMd\" in directory:\n",
    "            date = file1[-29:-23]\n",
    "            suffix = file1[-7:-5]\n",
    "            date_subjects.append(f\"{date}{suffix}\")  \n",
    "        elif \"M1PMd\" not in directory:\n",
    "            # print(file1[name_range])\n",
    "            date_subjects.append(file1[name_range])\n",
    "        n_compare = n_compare+1\n",
    "print('label of date:', np.unique(date_subjects))\n",
    "\n",
    "np.savez(file_save, date_subjects = date_subjects, vel_R_3D=vel_R_3D, vel_R_2D=vel_R_2D)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "28a597d3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:cebra] *",
   "language": "python",
   "name": "conda-env-cebra-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
