{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading the package"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook walks you through a simple use case of Neuroprobe and the evaluation of the logistic regression baseline. It can be easily adapted to evaluate any foundation model of neural activity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected braintreebank data at: /PATH_TO_BTBANK/braintreebank/braintreebank/\n",
      "Sampling rate: 2048 Hz\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "# NOTE: Change this to your own path, or set the ROOT_DIR_BRAINTREEBANK environment variable before starting Jupyter.\n",
    "# setdefault keeps an already-defined ROOT_DIR_BRAINTREEBANK instead of overwriting it with the placeholder below.\n",
    "# Set it before importing neuroprobe, whose config reads this variable.\n",
    "os.environ.setdefault('ROOT_DIR_BRAINTREEBANK', '/PATH_TO_BTBANK/braintreebank/braintreebank/')\n",
    "\n",
    "import torch\n",
    "import neuroprobe.config as neuroprobe_config\n",
    "\n",
    "# Make sure the config ROOT_DIR is set correctly\n",
    "print(\"Expected braintreebank data at:\", neuroprobe_config.ROOT_DIR)\n",
    "print(\"Sampling rate:\", neuroprobe_config.SAMPLING_RATE, \"Hz\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The BrainTreebankSubject Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loaded subject 1\n",
      "Electrode labels (first 10): ['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7', 'F3aOFa8', 'F3aOFa9', 'F3aOFa10', 'F3aOFa11', 'F3aOFa12', 'F3aOFa13']\n",
      "\n",
      "Electrode coordinates (MNI space) of the first 10 electrodes:\n",
      "tensor([[ 76.0103, -49.9502, -25.1740],\n",
      "        [ 75.4765, -50.8993, -22.9590],\n",
      "        [ 81.5899, -49.6018, -13.3198],\n",
      "        [ 81.3702, -47.3542,  -6.0947],\n",
      "        [ 83.1155, -43.3788,   0.5507],\n",
      "        [ 79.8622, -41.9135,   3.7532],\n",
      "        [ 79.1331, -41.2117,   4.8066],\n",
      "        [ 67.2942, -28.6666,  14.7228],\n",
      "        [ 68.9201, -28.2619,  15.2759],\n",
      "        [ 77.7627, -21.2303,  19.6270]])\n"
     ]
    }
   ],
   "source": [
    "from neuroprobe import BrainTreebankSubject\n",
    "\n",
    "subject_id = 1\n",
    "\n",
    "# cache=True keeps the neural data of loaded trials in RAM (see load_neural_data below), which makes loading faster -- use it if you have enough memory!\n",
    "subject = BrainTreebankSubject(\n",
    "    subject_id,\n",
    "    allow_corrupted=False,\n",
    "    cache=True,\n",
    "    dtype=torch.float32,\n",
    ")\n",
    "print(f\"Loaded subject {subject_id}\")\n",
    "\n",
    "# electrode_labels is a list of electrode label strings\n",
    "first_ten_labels = subject.electrode_labels[:10]\n",
    "print(f\"Electrode labels (first 10): {first_ten_labels}\")\n",
    "\n",
    "print(\"\\nElectrode coordinates (MNI space) of the first 10 electrodes:\")\n",
    "first_ten_coordinates = subject.get_electrode_coordinates()[:10] # L, P, I coordinates, one row per electrode\n",
    "print(first_ten_coordinates)\n",
    "\n",
    "# Optionally, subset the electrodes to a specific set of electrodes. NOTE: you should not do this if you are using the neuroprobe as a standardized benchmark.\n",
    "# subject.set_electrode_subset(['F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'F3aOFa7']) # if you change this line when using cache=True, you need to clear the cache after: subject.clear_neural_data_cache()\n",
    "# print(\"Electrode labels after subsetting:\", subject.electrode_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Loading the electrode data from a specific trial:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "All neural data shape:\n",
      "torch.Size([129, 21401009])\n",
      "\n",
      "First 50 samples of the first electrode (data is in microvolts):\n",
      "tensor([-24.7234, -25.7867, -25.5209, -28.9769, -33.2303, -34.5595, -38.8130,\n",
      "        -36.4204, -33.7620, -33.2303, -26.8501, -21.7991, -17.8115, -25.7867,\n",
      "        -27.3818, -17.8115, -13.2921,  -5.3169,   2.9243,   1.8609,   7.4436,\n",
      "         14.8872,  14.3555,  15.4189,  16.4822,  16.4822,  16.2164,  20.2040,\n",
      "         19.6724,  15.4189,  18.6090,  17.8115,   4.7852,  -4.7852, -16.2164,\n",
      "        -24.7234, -26.3184, -34.8254, -36.6863, -39.3447, -45.4591, -45.7249,\n",
      "        -48.3834, -48.3834, -45.1933, -40.6739, -40.1422, -49.7126, -53.7002,\n",
      "        -44.6616])\n"
     ]
    }
   ],
   "source": [
    "trial_id = 1\n",
    "\n",
    "# Window into the trial's neural data array, given as sample indices. Leaving both as None loads the whole trial.\n",
    "window_from = None  # index from which to start loading the data\n",
    "window_to = None  # index at which to stop loading the data\n",
    "\n",
    "subject.load_neural_data(trial_id)\n",
    "all_neural_data = subject.get_all_electrode_data(\n",
    "    trial_id,\n",
    "    window_from=window_from,\n",
    "    window_to=window_to,\n",
    ")\n",
    "\n",
    "# Shape is (n_electrodes, n_samples). To get the data for a specific electrode, use subject.get_electrode_data(trial_id, electrode_label)\n",
    "print(\"All neural data shape:\")\n",
    "print(all_neural_data.shape)\n",
    "\n",
    "print(\"\\nFirst 50 samples of the first electrode (data is in microvolts):\")\n",
    "first_electrode_head = all_neural_data[0, :50]\n",
    "print(first_electrode_head)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The BrainTreebankSubjectTrialBenchmarkDataset Class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "NOTE: In the dataset below, there will be fewer electrodes than in the full subject data. This is because the Neuroprobe benchmark only uses a subset of the electrodes for standardized and quick benchmarking. The electrode labels below are subset to the `neuroprobe_config.NEUROPROBE_LITE_ELECTRODES` list.\n",
    "\n",
    "Accordingly, when using the `BrainTreebankSubjectTrialBenchmarkDataset` with `lite=True` (which is the default Neuroprobe benchmark option), make sure that you use the `dataset.electrode_labels` and `dataset.electrode_coordinates` properties, which give the electrode labels and the electrode coordinates in MNI space, respectively, in the exact order that the `dataset` will output the data tensors in."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Items in the dataset: 3500 \n",
      "\n",
      "The first item: (shape = torch.Size([120, 2048]))\n",
      "tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],\n",
      "        [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],\n",
      "        [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],\n",
      "        ...,\n",
      "        [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],\n",
      "        [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],\n",
      "        [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]])\n",
      "label = 1\n",
      "\n",
      "Electrode labels in the data above in the following order (120 electrodes): ['T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'T1cIf10', 'T1cIf11', 'T1cIf12', 'T1cIf13', 'T1cIf14', 'T1cIf15', 'T1cIf16', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T3aHb9', 'T3aHb10', 'T1cIf1', 'T1cIf2', 'T1cIf3', 'T1cIf4', 'T1cIf5', 'T1cIf6', 'T1cIf7', 'T1cIf8', 'T2bHa7', 'T2bHa8', 'T2bHa9', 'T2bHa10', 'T2bHa11', 'T2bHa12', 'T2bHa13', 'T2bHa14', 'T3bOT8', 'T3bOT9', 'T3bOT10', 'F3cId1', 'F3cId2', 'F3cId3', 'F3cId4', 'F3cId5', 'F3cId6', 'F3cId7', 'F3cId8', 'F3cId9', 'T2c4', 'T2c5', 'T2c6', 'T2c7', 'T2c8', 'F3bIaOFb1', 'F3bIaOFb2', 'F3bIaOFb3', 'F3bIaOFb4', 'F3bIaOFb5', 'F3bIaOFb6', 'F3bIaOFb7', 'F3bIaOFb8', 'F3bIaOFb9', 'F3bIaOFb10', 'F3bIaOFb11', 'F3bIaOFb12', 'F3bIaOFb13', 'F3bIaOFb14', 'F3bIaOFb15', 'F3bIaOFb16', 'T2d1', 'T2d2', 'T2d3', 'T2d4', 'T2d5', 'T2d6', 'F3aOFa7', 'F3aOFa8', 'F3aOFa9', 'F3aOFa10', 'F3aOFa11', 'F3aOFa12', 'F3aOFa13', 'F3aOFa14', 'F3aOFa15', 'F3aOFa16', 'F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'T3bOT1', 'T3bOT2', 'T3bOT3', 'T3bOT4', 'T3bOT5', 'T3bOT6', 'T2bHa3', 'T2bHa4', 'T2bHa5', 'F3dIe1', 'F3dIe2', 'F3dIe3', 'F3dIe4', 'F3dIe5', 'F3dIe6', 'F3dIe7', 'F3dIe8', 'F3dIe9', 'F3dIe10', 'T2aA1', 'T2aA2', 'T2aA3', 'T2aA4', 'T2aA5', 'T2aA6', 'T2aA7', 'T2aA8']\n",
      "Electrode coordinates in the data above in the following order (120 electrodes): tensor([[ 3.5688e+01, -3.2862e+01,  1.4370e+01],\n",
      "        [ 3.3857e+01, -3.1721e+01,  1.4702e+01],\n",
      "        [ 2.1214e+01, -3.9701e+01,  2.7060e+01],\n",
      "        [ 1.8357e+01, -3.7100e+01,  2.8049e+01],\n",
      "        [ 1.6504e+01, -3.6198e+01,  2.9402e+01],\n",
      "        [ 9.2748e+00, -3.6987e+01,  3.6375e+01],\n",
      "        [ 1.1030e+01, -3.6524e+01,  3.5046e+01],\n",
      "        [ 5.6288e+00, -3.0839e+01,  3.5297e+01],\n",
      "        [-1.5418e+01, -2.5516e+01,  3.9162e+01],\n",
      "        [-1.4611e+01, -2.5276e+01,  3.9169e+01],\n",
      "        [-1.3832e+01, -2.5003e+01,  3.9149e+01],\n",
      "        [-1.4268e+01, -2.4497e+01,  3.9120e+01],\n",
      "        [-1.2808e+01, -2.3856e+01,  3.8992e+01],\n",
      "        [-1.6940e+01, -2.0067e+01,  3.8905e+01],\n",
      "        [-1.6138e+01, -1.9288e+01,  3.8801e+01],\n",
      "        [ 4.3981e+01, -5.2199e+01,  5.0444e+00],\n",
      "        [ 4.4119e+01, -4.9762e+01,  8.0147e+00],\n",
      "        [ 4.3524e+01, -4.4496e+01,  1.1088e+01],\n",
      "        [ 2.9769e+01, -4.9792e+01,  2.3592e+01],\n",
      "        [ 2.8218e+01, -4.8540e+01,  2.4914e+01],\n",
      "        [ 2.6664e+01, -4.6319e+01,  2.5555e+01],\n",
      "        [ 2.2501e+01, -4.6543e+01,  3.0526e+01],\n",
      "        [ 2.1519e+01, -4.5360e+01,  3.0920e+01],\n",
      "        [-1.9956e+01, -5.1310e+01,  3.8305e+01],\n",
      "        [-2.8050e+01, -4.4495e+01,  3.9209e+01],\n",
      "        [ 9.7728e+00, -1.7981e+01,  1.9746e+01],\n",
      "        [ 8.8209e+00, -1.8369e+01,  2.0713e+01],\n",
      "        [ 9.1638e+00, -1.8820e+01,  2.0726e+01],\n",
      "        [ 7.5831e+00, -1.8791e+01,  2.2079e+01],\n",
      "        [ 7.5900e+00, -1.9737e+01,  2.2818e+01],\n",
      "        [ 8.0179e+00, -2.0710e+01,  2.3214e+01],\n",
      "        [-1.6155e+01, -2.7714e+01,  3.9218e+01],\n",
      "        [-1.5899e+01, -2.7392e+01,  3.9218e+01],\n",
      "        [-5.2294e+00, -6.7200e+01,  1.7792e+01],\n",
      "        [ 2.6117e+00, -6.5365e+01,  3.1598e+01],\n",
      "        [ 3.3227e+00, -6.4008e+01,  3.3221e+01],\n",
      "        [ 5.7811e-01, -6.1947e+01,  3.4764e+01],\n",
      "        [-1.1291e+00, -6.1682e+01,  3.4763e+01],\n",
      "        [-5.5462e+00, -6.1975e+01,  3.3464e+01],\n",
      "        [-4.5639e+00, -5.8576e+01,  3.6672e+01],\n",
      "        [-4.7606e+00, -5.6935e+01,  3.7647e+01],\n",
      "        [-5.0306e+01, -4.5854e+01,  2.8588e+01],\n",
      "        [-5.2744e+01, -4.3943e+01,  2.9051e+01],\n",
      "        [-4.8433e+01, -3.7171e+01,  3.5046e+01],\n",
      "        [ 2.5868e+01, -1.8515e+01,  1.3860e+01],\n",
      "        [ 2.4671e+01, -1.9236e+01,  1.4057e+01],\n",
      "        [ 2.4778e+01, -2.0001e+01,  1.4121e+01],\n",
      "        [ 2.4855e+01, -7.3436e+00,  2.3682e+01],\n",
      "        [ 2.0462e+01, -2.3742e+00,  3.2575e+01],\n",
      "        [ 2.2050e+01, -2.9270e+00,  3.1934e+01],\n",
      "        [ 2.3861e+01, -2.2128e+00,  3.2741e+01],\n",
      "        [ 2.5427e+01,  2.4566e+00,  3.5344e+01],\n",
      "        [ 2.4621e+01,  2.9320e+00,  3.5772e+01],\n",
      "        [-3.7809e+01, -2.0112e+01,  3.5483e+01],\n",
      "        [-4.3560e+01, -1.9838e+01,  3.4603e+01],\n",
      "        [-4.2958e+01, -2.0811e+01,  3.4940e+01],\n",
      "        [-3.8015e+01, -2.6971e+01,  3.7116e+01],\n",
      "        [-4.0045e+01, -2.9522e+01,  3.7371e+01],\n",
      "        [ 6.1675e+01, -4.9603e+01, -2.7883e+01],\n",
      "        [ 6.0438e+01, -5.0313e+01, -2.6863e+01],\n",
      "        [ 5.7864e+01, -5.1754e+01, -2.3795e+01],\n",
      "        [ 5.6942e+01, -5.2614e+01, -1.9569e+01],\n",
      "        [ 6.2888e+01, -5.1523e+01, -1.0141e+01],\n",
      "        [ 6.4392e+01, -5.0827e+01, -8.1029e+00],\n",
      "        [ 7.0470e+01, -4.7858e+01, -2.8855e+00],\n",
      "        [ 5.3645e+01, -4.5186e+01,  5.8706e+00],\n",
      "        [ 5.9259e+01, -3.7708e+01,  9.3889e+00],\n",
      "        [ 4.6932e+01, -3.5280e+01,  1.1702e+01],\n",
      "        [ 4.4175e+01, -3.3280e+01,  1.2400e+01],\n",
      "        [ 5.0203e+01, -1.2528e+01,  2.1968e+01],\n",
      "        [ 5.5504e+01, -9.1545e+00,  2.5024e+01],\n",
      "        [ 4.3970e+01, -5.6686e+00,  2.7234e+01],\n",
      "        [ 4.3975e+01, -4.7420e+00,  2.7625e+01],\n",
      "        [ 4.3948e+01, -3.7537e+00,  2.7930e+01],\n",
      "        [-5.8000e+01, -5.6650e+00,  2.8337e+01],\n",
      "        [-5.7414e+01, -5.6700e+00,  2.8533e+01],\n",
      "        [-5.5427e+01, -3.4217e+00,  2.8904e+01],\n",
      "        [-5.3208e+01, -3.5895e+00,  2.9620e+01],\n",
      "        [-5.2123e+01, -3.6946e+00,  2.9962e+01],\n",
      "        [-5.0881e+01, -4.9191e+00,  3.0437e+01],\n",
      "        [ 8.1370e+01, -4.7354e+01, -6.0947e+00],\n",
      "        [ 8.3115e+01, -4.3379e+01,  5.5073e-01],\n",
      "        [ 7.9862e+01, -4.1914e+01,  3.7532e+00],\n",
      "        [ 7.9133e+01, -4.1212e+01,  4.8066e+00],\n",
      "        [ 6.7294e+01, -2.8667e+01,  1.4723e+01],\n",
      "        [ 6.8920e+01, -2.8262e+01,  1.5276e+01],\n",
      "        [ 7.7763e+01, -2.1230e+01,  1.9627e+01],\n",
      "        [ 7.5848e+01, -2.0794e+01,  2.0100e+01],\n",
      "        [ 7.3083e+01, -1.7761e+01,  2.1541e+01],\n",
      "        [ 7.2167e+01, -1.6950e+01,  2.1905e+01],\n",
      "        [ 7.6010e+01, -4.9950e+01, -2.5174e+01],\n",
      "        [ 7.5476e+01, -5.0899e+01, -2.2959e+01],\n",
      "        [ 8.1590e+01, -4.9602e+01, -1.3320e+01],\n",
      "        [-3.2831e+01, -5.3966e+01,  7.2357e+00],\n",
      "        [-3.4091e+01, -5.3775e+01,  7.7509e+00],\n",
      "        [-3.0457e+01, -5.5838e+01,  1.0462e+01],\n",
      "        [-3.0019e+01, -5.6338e+01,  1.1748e+01],\n",
      "        [-3.0071e+01, -5.6513e+01,  1.2553e+01],\n",
      "        [-4.9939e+01, -4.9410e+01,  2.2477e+01],\n",
      "        [ 9.2548e+00, -5.9425e+01, -6.0441e-01],\n",
      "        [ 1.0768e+01, -5.8612e+01, -1.3432e+00],\n",
      "        [ 1.0489e+01, -5.7181e+01, -2.2552e+00],\n",
      "        [ 1.2787e+01, -1.6326e+01,  1.7247e+01],\n",
      "        [ 1.3337e+01, -1.6393e+01,  1.7007e+01],\n",
      "        [ 1.5250e+01, -1.6987e+01,  1.6323e+01],\n",
      "        [ 1.5826e+01, -1.7065e+01,  1.6113e+01],\n",
      "        [ 1.6446e+01, -1.6144e+01,  1.5754e+01],\n",
      "        [ 7.3291e+00, -1.6885e+00,  2.9599e+01],\n",
      "        [ 7.9934e+00, -1.7086e+00,  2.9877e+01],\n",
      "        [ 1.0769e+01, -2.2268e+00,  3.0185e+01],\n",
      "        [ 9.2705e+00, -1.2622e+00,  3.1261e+01],\n",
      "        [ 8.5029e+00,  5.5362e-02,  3.3005e+01],\n",
      "        [ 2.1483e+01, -6.4978e+01,  2.2774e+00],\n",
      "        [ 2.2974e+01, -6.4407e+01,  2.0725e+00],\n",
      "        [ 2.5924e+01, -6.3072e+01,  1.6925e+00],\n",
      "        [ 2.7156e+01, -6.2745e+01,  1.7896e+00],\n",
      "        [ 3.0966e+01, -6.0486e+01,  1.2967e+00],\n",
      "        [ 3.2188e+01, -5.9466e+01,  8.3429e-01],\n",
      "        [ 3.7181e+01, -5.4710e+01,  1.4268e+01],\n",
      "        [ 1.8107e+01, -5.3561e+01,  3.5146e+01]])\n"
     ]
    }
   ],
   "source": [
    "from neuroprobe import BrainTreebankSubjectTrialBenchmarkDataset\n",
    "\n",
    "# Options for eval_name (from the Neuroprobe paper): neuroprobe_config.EVAL_NAMES\n",
    "#   frame_brightness, global_flow, local_flow, face_num, volume, pitch, delta_volume,\n",
    "#   speech, onset, gpt2_surprisal, word_length, word_gap, word_index, word_head_pos, word_part_speech, speaker\n",
    "eval_name = \"volume\"\n",
    "\n",
    "# if True, the dataset will output the indices of the samples in the neural data in a tuple: (index_from, index_to);\n",
    "# if False, the dataset will output the neural data directly\n",
    "output_indices = False\n",
    "\n",
    "start_neural_data_before_word_onset = 0 # the number of samples to start the neural data before each word onset\n",
    "end_neural_data_after_word_onset = neuroprobe_config.SAMPLING_RATE * 1 # the number of samples to end the neural data after each word onset -- here we use 1 second\n",
    "\n",
    "dataset = BrainTreebankSubjectTrialBenchmarkDataset(subject, trial_id, dtype=torch.float32, eval_name=eval_name, output_indices=output_indices,\n",
    "                                                    start_neural_data_before_word_onset=start_neural_data_before_word_onset, end_neural_data_after_word_onset=end_neural_data_after_word_onset,\n",
    "                                                    lite=True) # the default is Neuroprobe Lite for standardized and quick benchmarking. Set lite=False to access the Full dataset.\n",
    "# P.S. The dataset also has an option (not set here) to allow partial caching of the neural data, i.e. caching only the part needed for this particular dataset. Better disabled when doing multiple evals back to back, but better enabled when doing a single eval.\n",
    "\n",
    "data_electrode_labels = dataset.electrode_labels # NOTE: this is different from the subject.electrode_labels! Neuroprobe uses a special subset of electrodes in this exact order.\n",
    "data_electrode_coordinates = dataset.electrode_coordinates\n",
    "\n",
    "# Fetch the first item once instead of indexing dataset[0] three times (each access builds the item again)\n",
    "first_item = dataset[0]\n",
    "\n",
    "print(\"Items in the dataset:\", len(dataset), \"\\n\")\n",
    "print(f\"The first item: (shape = {first_item[0].shape})\", first_item[0], f\"label = {first_item[1]}\", sep=\"\\n\")\n",
    "print(\"\")\n",
    "print(f\"Electrode labels in the data above in the following order ({len(data_electrode_labels)} electrodes):\", data_electrode_labels)\n",
    "print(f\"Electrode coordinates in the data above in the following order ({len(data_electrode_coordinates)} electrodes):\", data_electrode_coordinates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'data': tensor([[ 32.9645,  27.3818,  20.4699,  ...,   2.1267,  -0.5317,   6.3802],\n",
      "        [ 67.2582,  62.2072,  56.0928,  ...,  44.9274,  42.2690,  49.1809],\n",
      "        [120.9584, 115.3757, 106.6029,  ...,  58.2195,  53.9661,  59.8146],\n",
      "        ...,\n",
      "        [ 15.1530,   7.4436,   3.7218,  ...,  -2.1267,  -5.8485,   2.9243],\n",
      "        [ 26.3184,  19.4065,  14.8872,  ...,  16.4822,  14.3555,  19.9382],\n",
      "        [-13.0263, -17.5456, -23.9258,  ...,   3.4560,   2.6584,  11.1654]]), 'label': 1, 'electrode_labels': ['T1bIc1', 'T1bIc2', 'T1bIc3', 'T1bIc4', 'T1bIc5', 'T1bIc6', 'T1bIc7', 'T1bIc8', 'T1cIf10', 'T1cIf11', 'T1cIf12', 'T1cIf13', 'T1cIf14', 'T1cIf15', 'T1cIf16', 'T1aIb1', 'T1aIb2', 'T1aIb3', 'T1aIb4', 'T1aIb5', 'T1aIb6', 'T1aIb7', 'T1aIb8', 'T3aHb9', 'T3aHb10', 'T1cIf1', 'T1cIf2', 'T1cIf3', 'T1cIf4', 'T1cIf5', 'T1cIf6', 'T1cIf7', 'T1cIf8', 'T2bHa7', 'T2bHa8', 'T2bHa9', 'T2bHa10', 'T2bHa11', 'T2bHa12', 'T2bHa13', 'T2bHa14', 'T3bOT8', 'T3bOT9', 'T3bOT10', 'F3cId1', 'F3cId2', 'F3cId3', 'F3cId4', 'F3cId5', 'F3cId6', 'F3cId7', 'F3cId8', 'F3cId9', 'T2c4', 'T2c5', 'T2c6', 'T2c7', 'T2c8', 'F3bIaOFb1', 'F3bIaOFb2', 'F3bIaOFb3', 'F3bIaOFb4', 'F3bIaOFb5', 'F3bIaOFb6', 'F3bIaOFb7', 'F3bIaOFb8', 'F3bIaOFb9', 'F3bIaOFb10', 'F3bIaOFb11', 'F3bIaOFb12', 'F3bIaOFb13', 'F3bIaOFb14', 'F3bIaOFb15', 'F3bIaOFb16', 'T2d1', 'T2d2', 'T2d3', 'T2d4', 'T2d5', 'T2d6', 'F3aOFa7', 'F3aOFa8', 'F3aOFa9', 'F3aOFa10', 'F3aOFa11', 'F3aOFa12', 'F3aOFa13', 'F3aOFa14', 'F3aOFa15', 'F3aOFa16', 'F3aOFa2', 'F3aOFa3', 'F3aOFa4', 'T3bOT1', 'T3bOT2', 'T3bOT3', 'T3bOT4', 'T3bOT5', 'T3bOT6', 'T2bHa3', 'T2bHa4', 'T2bHa5', 'F3dIe1', 'F3dIe2', 'F3dIe3', 'F3dIe4', 'F3dIe5', 'F3dIe6', 'F3dIe7', 'F3dIe8', 'F3dIe9', 'F3dIe10', 'T2aA1', 'T2aA2', 'T2aA3', 'T2aA4', 'T2aA5', 'T2aA6', 'T2aA7', 'T2aA8'], 'electrode_coordinates': tensor([[ 3.5688e+01, -3.2862e+01,  1.4370e+01],\n",
      "        [ 3.3857e+01, -3.1721e+01,  1.4702e+01],\n",
      "        [ 2.1214e+01, -3.9701e+01,  2.7060e+01],\n",
      "        [ 1.8357e+01, -3.7100e+01,  2.8049e+01],\n",
      "        [ 1.6504e+01, -3.6198e+01,  2.9402e+01],\n",
      "        [ 9.2748e+00, -3.6987e+01,  3.6375e+01],\n",
      "        [ 1.1030e+01, -3.6524e+01,  3.5046e+01],\n",
      "        [ 5.6288e+00, -3.0839e+01,  3.5297e+01],\n",
      "        [-1.5418e+01, -2.5516e+01,  3.9162e+01],\n",
      "        [-1.4611e+01, -2.5276e+01,  3.9169e+01],\n",
      "        [-1.3832e+01, -2.5003e+01,  3.9149e+01],\n",
      "        [-1.4268e+01, -2.4497e+01,  3.9120e+01],\n",
      "        [-1.2808e+01, -2.3856e+01,  3.8992e+01],\n",
      "        [-1.6940e+01, -2.0067e+01,  3.8905e+01],\n",
      "        [-1.6138e+01, -1.9288e+01,  3.8801e+01],\n",
      "        [ 4.3981e+01, -5.2199e+01,  5.0444e+00],\n",
      "        [ 4.4119e+01, -4.9762e+01,  8.0147e+00],\n",
      "        [ 4.3524e+01, -4.4496e+01,  1.1088e+01],\n",
      "        [ 2.9769e+01, -4.9792e+01,  2.3592e+01],\n",
      "        [ 2.8218e+01, -4.8540e+01,  2.4914e+01],\n",
      "        [ 2.6664e+01, -4.6319e+01,  2.5555e+01],\n",
      "        [ 2.2501e+01, -4.6543e+01,  3.0526e+01],\n",
      "        [ 2.1519e+01, -4.5360e+01,  3.0920e+01],\n",
      "        [-1.9956e+01, -5.1310e+01,  3.8305e+01],\n",
      "        [-2.8050e+01, -4.4495e+01,  3.9209e+01],\n",
      "        [ 9.7728e+00, -1.7981e+01,  1.9746e+01],\n",
      "        [ 8.8209e+00, -1.8369e+01,  2.0713e+01],\n",
      "        [ 9.1638e+00, -1.8820e+01,  2.0726e+01],\n",
      "        [ 7.5831e+00, -1.8791e+01,  2.2079e+01],\n",
      "        [ 7.5900e+00, -1.9737e+01,  2.2818e+01],\n",
      "        [ 8.0179e+00, -2.0710e+01,  2.3214e+01],\n",
      "        [-1.6155e+01, -2.7714e+01,  3.9218e+01],\n",
      "        [-1.5899e+01, -2.7392e+01,  3.9218e+01],\n",
      "        [-5.2294e+00, -6.7200e+01,  1.7792e+01],\n",
      "        [ 2.6117e+00, -6.5365e+01,  3.1598e+01],\n",
      "        [ 3.3227e+00, -6.4008e+01,  3.3221e+01],\n",
      "        [ 5.7811e-01, -6.1947e+01,  3.4764e+01],\n",
      "        [-1.1291e+00, -6.1682e+01,  3.4763e+01],\n",
      "        [-5.5462e+00, -6.1975e+01,  3.3464e+01],\n",
      "        [-4.5639e+00, -5.8576e+01,  3.6672e+01],\n",
      "        [-4.7606e+00, -5.6935e+01,  3.7647e+01],\n",
      "        [-5.0306e+01, -4.5854e+01,  2.8588e+01],\n",
      "        [-5.2744e+01, -4.3943e+01,  2.9051e+01],\n",
      "        [-4.8433e+01, -3.7171e+01,  3.5046e+01],\n",
      "        [ 2.5868e+01, -1.8515e+01,  1.3860e+01],\n",
      "        [ 2.4671e+01, -1.9236e+01,  1.4057e+01],\n",
      "        [ 2.4778e+01, -2.0001e+01,  1.4121e+01],\n",
      "        [ 2.4855e+01, -7.3436e+00,  2.3682e+01],\n",
      "        [ 2.0462e+01, -2.3742e+00,  3.2575e+01],\n",
      "        [ 2.2050e+01, -2.9270e+00,  3.1934e+01],\n",
      "        [ 2.3861e+01, -2.2128e+00,  3.2741e+01],\n",
      "        [ 2.5427e+01,  2.4566e+00,  3.5344e+01],\n",
      "        [ 2.4621e+01,  2.9320e+00,  3.5772e+01],\n",
      "        [-3.7809e+01, -2.0112e+01,  3.5483e+01],\n",
      "        [-4.3560e+01, -1.9838e+01,  3.4603e+01],\n",
      "        [-4.2958e+01, -2.0811e+01,  3.4940e+01],\n",
      "        [-3.8015e+01, -2.6971e+01,  3.7116e+01],\n",
      "        [-4.0045e+01, -2.9522e+01,  3.7371e+01],\n",
      "        [ 6.1675e+01, -4.9603e+01, -2.7883e+01],\n",
      "        [ 6.0438e+01, -5.0313e+01, -2.6863e+01],\n",
      "        [ 5.7864e+01, -5.1754e+01, -2.3795e+01],\n",
      "        [ 5.6942e+01, -5.2614e+01, -1.9569e+01],\n",
      "        [ 6.2888e+01, -5.1523e+01, -1.0141e+01],\n",
      "        [ 6.4392e+01, -5.0827e+01, -8.1029e+00],\n",
      "        [ 7.0470e+01, -4.7858e+01, -2.8855e+00],\n",
      "        [ 5.3645e+01, -4.5186e+01,  5.8706e+00],\n",
      "        [ 5.9259e+01, -3.7708e+01,  9.3889e+00],\n",
      "        [ 4.6932e+01, -3.5280e+01,  1.1702e+01],\n",
      "        [ 4.4175e+01, -3.3280e+01,  1.2400e+01],\n",
      "        [ 5.0203e+01, -1.2528e+01,  2.1968e+01],\n",
      "        [ 5.5504e+01, -9.1545e+00,  2.5024e+01],\n",
      "        [ 4.3970e+01, -5.6686e+00,  2.7234e+01],\n",
      "        [ 4.3975e+01, -4.7420e+00,  2.7625e+01],\n",
      "        [ 4.3948e+01, -3.7537e+00,  2.7930e+01],\n",
      "        [-5.8000e+01, -5.6650e+00,  2.8337e+01],\n",
      "        [-5.7414e+01, -5.6700e+00,  2.8533e+01],\n",
      "        [-5.5427e+01, -3.4217e+00,  2.8904e+01],\n",
      "        [-5.3208e+01, -3.5895e+00,  2.9620e+01],\n",
      "        [-5.2123e+01, -3.6946e+00,  2.9962e+01],\n",
      "        [-5.0881e+01, -4.9191e+00,  3.0437e+01],\n",
      "        [ 8.1370e+01, -4.7354e+01, -6.0947e+00],\n",
      "        [ 8.3115e+01, -4.3379e+01,  5.5073e-01],\n",
      "        [ 7.9862e+01, -4.1914e+01,  3.7532e+00],\n",
      "        [ 7.9133e+01, -4.1212e+01,  4.8066e+00],\n",
      "        [ 6.7294e+01, -2.8667e+01,  1.4723e+01],\n",
      "        [ 6.8920e+01, -2.8262e+01,  1.5276e+01],\n",
      "        [ 7.7763e+01, -2.1230e+01,  1.9627e+01],\n",
      "        [ 7.5848e+01, -2.0794e+01,  2.0100e+01],\n",
      "        [ 7.3083e+01, -1.7761e+01,  2.1541e+01],\n",
      "        [ 7.2167e+01, -1.6950e+01,  2.1905e+01],\n",
      "        [ 7.6010e+01, -4.9950e+01, -2.5174e+01],\n",
      "        [ 7.5476e+01, -5.0899e+01, -2.2959e+01],\n",
      "        [ 8.1590e+01, -4.9602e+01, -1.3320e+01],\n",
      "        [-3.2831e+01, -5.3966e+01,  7.2357e+00],\n",
      "        [-3.4091e+01, -5.3775e+01,  7.7509e+00],\n",
      "        [-3.0457e+01, -5.5838e+01,  1.0462e+01],\n",
      "        [-3.0019e+01, -5.6338e+01,  1.1748e+01],\n",
      "        [-3.0071e+01, -5.6513e+01,  1.2553e+01],\n",
      "        [-4.9939e+01, -4.9410e+01,  2.2477e+01],\n",
      "        [ 9.2548e+00, -5.9425e+01, -6.0441e-01],\n",
      "        [ 1.0768e+01, -5.8612e+01, -1.3432e+00],\n",
      "        [ 1.0489e+01, -5.7181e+01, -2.2552e+00],\n",
      "        [ 1.2787e+01, -1.6326e+01,  1.7247e+01],\n",
      "        [ 1.3337e+01, -1.6393e+01,  1.7007e+01],\n",
      "        [ 1.5250e+01, -1.6987e+01,  1.6323e+01],\n",
      "        [ 1.5826e+01, -1.7065e+01,  1.6113e+01],\n",
      "        [ 1.6446e+01, -1.6144e+01,  1.5754e+01],\n",
      "        [ 7.3291e+00, -1.6885e+00,  2.9599e+01],\n",
      "        [ 7.9934e+00, -1.7086e+00,  2.9877e+01],\n",
      "        [ 1.0769e+01, -2.2268e+00,  3.0185e+01],\n",
      "        [ 9.2705e+00, -1.2622e+00,  3.1261e+01],\n",
      "        [ 8.5029e+00,  5.5362e-02,  3.3005e+01],\n",
      "        [ 2.1483e+01, -6.4978e+01,  2.2774e+00],\n",
      "        [ 2.2974e+01, -6.4407e+01,  2.0725e+00],\n",
      "        [ 2.5924e+01, -6.3072e+01,  1.6925e+00],\n",
      "        [ 2.7156e+01, -6.2745e+01,  1.7896e+00],\n",
      "        [ 3.0966e+01, -6.0486e+01,  1.2967e+00],\n",
      "        [ 3.2188e+01, -5.9466e+01,  8.3429e-01],\n",
      "        [ 3.7181e+01, -5.4710e+01,  1.4268e+01],\n",
      "        [ 1.8107e+01, -5.3561e+01,  3.5146e+01]]), 'metadata': {'subject_identifier': 'btbank1', 'trial_id': 1, 'sampling_rate': 2048}}\n"
     ]
    }
   ],
   "source": [
    "# Optionally, you can set output_dict=True to get the data as a dictionary with a bunch of metadata.\n",
    "dataset.output_dict = True\n",
    "try:\n",
    "    print(dataset[0])\n",
    "finally:\n",
    "    dataset.output_dict = False # set it back, even if loading the item fails, so later cells get the default tuple output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "((443667, 445715), 1)\n"
     ]
    }
   ],
   "source": [
    "# Also, you can request only the indices into the neural data array, instead of the actual data.\n",
    "# NOTE: These are the indices into the data as in the raw h5 files in the braintreebank dataset.\n",
    "\n",
    "dataset.output_indices = True\n",
    "try:\n",
    "    print(dataset[0]) # Data format: ((index_from, index_to), label)\n",
    "finally:\n",
    "    dataset.output_indices = False # set it back, even if loading the item fails"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train/Test Splits\n",
    "\n",
    "In this example, we generate train/test splits for the within-session evaluation.\n",
    "\n",
    "All options: `generate_splits_within_session`, `generate_splits_cross_session`, `generate_splits_cross_subject`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(folds) = k_folds = 2\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'train_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21fa0>,\n",
       "  'test_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21250>},\n",
       " {'train_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21430>,\n",
       "  'test_dataset': <torch.utils.data.dataset.Subset at 0x14c75cf21400>}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import neuroprobe.train_test_splits as neuroprobe_train_test_splits\n",
    "\n",
    "# The keyword arguments after dtype are the dataset parameters, reusing the values from the dataset cell above.\n",
    "folds = neuroprobe_train_test_splits.generate_splits_within_session(\n",
    "    subject,\n",
    "    trial_id,\n",
    "    eval_name,\n",
    "    dtype=torch.float32,\n",
    "    output_indices=output_indices,\n",
    "    start_neural_data_before_word_onset=start_neural_data_before_word_onset,\n",
    "    end_neural_data_after_word_onset=end_neural_data_after_word_onset,\n",
    "    lite=True,\n",
    ")\n",
    "print(\"len(folds) = k_folds =\", len(folds))\n",
    "folds"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example: Logistic Regression on SS_SM (Within-Session Split)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fold 1 of 2\n",
      "\t Train accuracy: 1.000 | Test accuracy: 0.598\n",
      "Fold 2 of 2\n",
      "\t Train accuracy: 1.000 | Test accuracy: 0.570\n"
     ]
    }
   ],
   "source": [
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def dataset_to_numpy(torch_dataset):\n",
    "    \"\"\"Convert a dataset of (data, label) items into (X, y) numpy arrays for scikit-learn.\n",
    "\n",
    "    Every item is read exactly once; each row of X is one item's data tensor, flattened.\n",
    "    \"\"\"\n",
    "    features, labels = [], []\n",
    "    for item in torch_dataset:\n",
    "        features.append(item[0].flatten())\n",
    "        labels.append(item[1])\n",
    "    return np.array(features), np.array(labels)\n",
    "\n",
    "\n",
    "for fold_idx, fold in enumerate(folds):\n",
    "    print(f\"Fold {fold_idx+1} of {len(folds)}\")\n",
    "    train_dataset = fold[\"train_dataset\"]\n",
    "    test_dataset = fold[\"test_dataset\"]\n",
    "\n",
    "    # Convert PyTorch datasets to numpy arrays for scikit-learn.\n",
    "    # Iterating each dataset once (instead of separately for X and y) halves the number of item loads.\n",
    "    X_train, y_train = dataset_to_numpy(train_dataset)\n",
    "    X_test, y_test = dataset_to_numpy(test_dataset)\n",
    "\n",
    "    # Standardize the data (fit on the training split only, so no test statistics leak into training)\n",
    "    scaler = StandardScaler()\n",
    "    X_train = scaler.fit_transform(X_train)\n",
    "    X_test = scaler.transform(X_test)\n",
    "\n",
    "    # Train logistic regression\n",
    "    clf = LogisticRegression(random_state=42, max_iter=1000, tol=1e-3)\n",
    "    clf.fit(X_train, y_train)\n",
    "\n",
    "    # Evaluate model\n",
    "    train_score = clf.score(X_train, y_train)\n",
    "    test_score = clf.score(X_test, y_test)\n",
    "    print(f\"\\t Train accuracy: {train_score:.3f} | Test accuracy: {test_score:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Neuroprobe",
   "language": "python",
   "name": ".venv"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
