{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected braintreebank data at: /PATH_TO_BTBANK/braintreebank/braintreebank/\n",
      "Sampling rate: 2048 Hz\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "\n",
    "# neuroprobe.config presumably reads ROOT_DIR_BRAINTREEBANK when it is imported, so set it before importing.\n",
    "# setdefault keeps a value already exported in the environment instead of silently overwriting it.\n",
    "os.environ.setdefault('ROOT_DIR_BRAINTREEBANK', '/PATH_TO_BTBANK/braintreebank/braintreebank/') # fallback: change this, or export ROOT_DIR_BRAINTREEBANK before starting the kernel\n",
    "\n",
    "import torch\n",
    "import neuroprobe.config as neuroprobe_config\n",
    "\n",
    "# Make sure the config ROOT_DIR is set correctly\n",
    "print(\"Expected braintreebank data at:\", neuroprobe_config.ROOT_DIR)\n",
    "print(\"Sampling rate:\", neuroprobe_config.SAMPLING_RATE, \"Hz\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking for NaN electrode coordinates across all subjects (1-10):\n",
      "\n",
      "Subject 1: 1 electrodes with NaN coordinates:\n",
      "  - F3cId10\n",
      "\n",
      "Subject 2: All electrodes have valid coordinates\n",
      "\n",
      "Subject 3: 12 electrodes with NaN coordinates:\n",
      "  - F3c9\n",
      "  - F3c10\n",
      "  - T1aIc1\n",
      "  - T1aIc2\n",
      "  - P2a10\n",
      "  - O1aIb2\n",
      "  - O1aIb3\n",
      "  - O1aIb4\n",
      "  - O1aIb5\n",
      "  - O1aIb6\n",
      "  - O1aIb7\n",
      "  - O1aIb8\n",
      "\n",
      "Subject 4: 2 electrodes with NaN coordinates:\n",
      "  - LT1aIb10\n",
      "  - LF3bIa12\n",
      "\n",
      "Subject 5: All electrodes have valid coordinates\n",
      "\n",
      "Subject 6: All electrodes have valid coordinates\n",
      "\n",
      "Subject 7: 2 electrodes with NaN coordinates:\n",
      "  - LF3aOFa16\n",
      "  - LF1cCb12\n",
      "\n",
      "Subject 8: 2 electrodes with NaN coordinates:\n",
      "  - F2bCb6\n",
      "  - F2bCb14\n",
      "\n",
      "Subject 9: 3 electrodes with NaN coordinates:\n",
      "  - P2a6\n",
      "  - P2a7\n",
      "  - P2a8\n",
      "\n",
      "Subject 10: 2 electrodes with NaN coordinates:\n",
      "  - T1aIa4\n",
      "  - P2cCc5\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from neuroprobe.braintreebank_subject import BrainTreebankSubject\n",
    "\n",
    "# Check for NaN electrode coordinates across all subjects.\n",
    "# allow_missing_coordinates=True keeps electrodes without coordinates, so they show up as NaN rows.\n",
    "print(\"Checking for NaN electrode coordinates across all subjects (1-10):\\n\")\n",
    "\n",
    "# \"sub_<id>\" -> labels of electrodes with at least one NaN coordinate (exported as JSON in the next cell)\n",
    "nan_electrodes_all_list = {}\n",
    "# subject_id -> error message, so a subject that fails to load is reported instead of silently missing from the export\n",
    "failed_subjects = {}\n",
    "\n",
    "for subject_id in range(1, 11):\n",
    "    try:\n",
    "        # Create subject without subsetting electrodes\n",
    "        temp_subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=False, dtype=torch.float32, allow_missing_coordinates=True)\n",
    "\n",
    "        # Get electrode coordinates (expected: one row per electrode label, checked below)\n",
    "        coords = temp_subject.get_electrode_coordinates()\n",
    "        labels = temp_subject.electrode_labels\n",
    "\n",
    "        # An electrode is flagged if any of its coordinate components is NaN\n",
    "        nan_mask = torch.isnan(coords).any(dim=1).tolist()\n",
    "        if len(nan_mask) != len(labels):\n",
    "            raise ValueError(f\"{len(nan_mask)} coordinate rows but {len(labels)} electrode labels\")\n",
    "        nan_electrodes = [label for label, is_nan in zip(labels, nan_mask) if is_nan]\n",
    "        nan_electrodes_all_list[\"sub_\"+str(subject_id)] = nan_electrodes\n",
    "\n",
    "        if nan_electrodes:\n",
    "            print(f\"Subject {subject_id}: {len(nan_electrodes)} electrodes with NaN coordinates:\")\n",
    "            for electrode in nan_electrodes:\n",
    "                print(f\"  - {electrode}\")\n",
    "        else:\n",
    "            print(f\"Subject {subject_id}: All electrodes have valid coordinates\")\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best-effort: keep checking the remaining subjects, but remember the failure\n",
    "        failed_subjects[subject_id] = str(e)\n",
    "        print(f\"Subject {subject_id}: Error loading subject - {e}\")\n",
    "\n",
    "    print()\n",
    "\n",
    "if failed_subjects:\n",
    "    print(f\"WARNING: subjects {sorted(failed_subjects)} failed and are missing from nan_electrodes_all_list\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"sub_1\": [\"F3cId10\"], \"sub_2\": [], \"sub_3\": [\"F3c9\", \"F3c10\", \"T1aIc1\", \"T1aIc2\", \"P2a10\", \"O1aIb2\", \"O1aIb3\", \"O1aIb4\", \"O1aIb5\", \"O1aIb6\", \"O1aIb7\", \"O1aIb8\"], \"sub_4\": [\"LT1aIb10\", \"LF3bIa12\"], \"sub_5\": [], \"sub_6\": [], \"sub_7\": [\"LF3aOFa16\", \"LF1cCb12\"], \"sub_8\": [\"F2bCb6\", \"F2bCb14\"], \"sub_9\": [\"P2a6\", \"P2a7\", \"P2a8\"], \"sub_10\": [\"T1aIa4\", \"P2cCc5\"]}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "print(json.dumps(nan_electrodes_all_list))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now run the same check with `allow_missing_coordinates=False` (the default). As the output below shows, no electrode with NaN coordinates remains under this setting."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking for NaN electrode coordinates across all subjects (1-10):\n",
      "\n",
      "Subject 1: All electrodes have valid coordinates\n",
      "\n",
      "Subject 2: All electrodes have valid coordinates\n",
      "\n",
      "Subject 3: All electrodes have valid coordinates\n",
      "\n",
      "Subject 4: All electrodes have valid coordinates\n",
      "\n",
      "Subject 5: All electrodes have valid coordinates\n",
      "\n",
      "Subject 6: All electrodes have valid coordinates\n",
      "\n",
      "Subject 7: All electrodes have valid coordinates\n",
      "\n",
      "Subject 8: All electrodes have valid coordinates\n",
      "\n",
      "Subject 9: All electrodes have valid coordinates\n",
      "\n",
      "Subject 10: All electrodes have valid coordinates\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from neuroprobe.braintreebank_subject import BrainTreebankSubject\n",
    "\n",
    "# Same check, but with allow_missing_coordinates=False (the default).\n",
    "print(\"Checking for NaN electrode coordinates across all subjects (1-10):\\n\")\n",
    "\n",
    "# Separate name so the result of the allow_missing_coordinates=True run (nan_electrodes_all_list) is not overwritten\n",
    "nan_electrodes_default = {}\n",
    "# subject_id -> error message, so a subject that fails to load is reported at the end\n",
    "failed_subjects = {}\n",
    "\n",
    "for subject_id in range(1, 11):\n",
    "    try:\n",
    "        # Create subject without subsetting electrodes\n",
    "        temp_subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=False, dtype=torch.float32, allow_missing_coordinates=False)\n",
    "\n",
    "        # Get electrode coordinates (expected: one row per electrode label, checked below)\n",
    "        coords = temp_subject.get_electrode_coordinates()\n",
    "        labels = temp_subject.electrode_labels\n",
    "\n",
    "        # An electrode is flagged if any of its coordinate components is NaN\n",
    "        nan_mask = torch.isnan(coords).any(dim=1).tolist()\n",
    "        if len(nan_mask) != len(labels):\n",
    "            raise ValueError(f\"{len(nan_mask)} coordinate rows but {len(labels)} electrode labels\")\n",
    "        nan_electrodes = [label for label, is_nan in zip(labels, nan_mask) if is_nan]\n",
    "        nan_electrodes_default[\"sub_\"+str(subject_id)] = nan_electrodes\n",
    "\n",
    "        if nan_electrodes:\n",
    "            print(f\"Subject {subject_id}: {len(nan_electrodes)} electrodes with NaN coordinates:\")\n",
    "            for electrode in nan_electrodes:\n",
    "                print(f\"  - {electrode}\")\n",
    "        else:\n",
    "            print(f\"Subject {subject_id}: All electrodes have valid coordinates\")\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best-effort: keep checking the remaining subjects, but remember the failure\n",
    "        failed_subjects[subject_id] = str(e)\n",
    "        print(f\"Subject {subject_id}: Error loading subject - {e}\")\n",
    "\n",
    "    print()\n",
    "\n",
    "if failed_subjects:\n",
    "    print(f\"WARNING: subjects {sorted(failed_subjects)} failed and were not checked\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
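    "Finally, double-check that every electrode in the Neuroprobe benchmark's LITE subset (`NEUROPROBE_LITE_ELECTRODES`) has valid coordinates. Subjects are loaded with `allow_missing_coordinates=True`, so any LITE electrode lacking coordinates would show up as NaN."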
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Checking for NaN electrode coordinates across all subjects (1-10):\n",
      "\n",
      "Subject 1: All electrodes have valid coordinates\n",
      "\n",
      "Subject 2: All electrodes have valid coordinates\n",
      "\n",
      "Subject 3: All electrodes have valid coordinates\n",
      "\n",
      "Subject 4: All electrodes have valid coordinates\n",
      "\n",
      "Subject 5: All electrodes have valid coordinates\n",
      "\n",
      "Subject 6: All electrodes have valid coordinates\n",
      "\n",
      "Subject 7: All electrodes have valid coordinates\n",
      "\n",
      "Subject 8: All electrodes have valid coordinates\n",
      "\n",
      "Subject 9: All electrodes have valid coordinates\n",
      "\n",
      "Subject 10: All electrodes have valid coordinates\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from neuroprobe.config import NEUROPROBE_LITE_ELECTRODES\n",
    "from neuroprobe.braintreebank_subject import BrainTreebankSubject\n",
    "\n",
    "# Check for NaN electrode coordinates within the Neuroprobe LITE electrode subset of each subject\n",
    "print(\"Checking for NaN electrode coordinates across all subjects (1-10):\\n\")\n",
    "\n",
    "# Separate name so the result of the full-electrode run (nan_electrodes_all_list) is not overwritten\n",
    "nan_electrodes_lite = {}\n",
    "# subject_id -> error message, so a subject that fails to load is reported at the end\n",
    "failed_subjects = {}\n",
    "\n",
    "for subject_id in range(1, 11):\n",
    "    try:\n",
    "        # Load every electrode (keeping those without coordinates), then restrict to the LITE subset\n",
    "        temp_subject = BrainTreebankSubject(subject_id, allow_corrupted=False, cache=False, dtype=torch.float32, allow_missing_coordinates=True)\n",
    "        temp_subject.set_electrode_subset(NEUROPROBE_LITE_ELECTRODES[temp_subject.subject_identifier])\n",
    "\n",
    "        # Get electrode coordinates (expected: one row per electrode label, checked below)\n",
    "        coords = temp_subject.get_electrode_coordinates()\n",
    "        labels = temp_subject.electrode_labels\n",
    "\n",
    "        # An electrode is flagged if any of its coordinate components is NaN\n",
    "        nan_mask = torch.isnan(coords).any(dim=1).tolist()\n",
    "        if len(nan_mask) != len(labels):\n",
    "            raise ValueError(f\"{len(nan_mask)} coordinate rows but {len(labels)} electrode labels\")\n",
    "        nan_electrodes = [label for label, is_nan in zip(labels, nan_mask) if is_nan]\n",
    "        nan_electrodes_lite[\"sub_\"+str(subject_id)] = nan_electrodes\n",
    "\n",
    "        if nan_electrodes:\n",
    "            print(f\"Subject {subject_id}: {len(nan_electrodes)} electrodes with NaN coordinates:\")\n",
    "            for electrode in nan_electrodes:\n",
    "                print(f\"  - {electrode}\")\n",
    "        else:\n",
    "            print(f\"Subject {subject_id}: All electrodes have valid coordinates\")\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best-effort: keep checking the remaining subjects, but remember the failure\n",
    "        failed_subjects[subject_id] = str(e)\n",
    "        print(f\"Subject {subject_id}: Error loading subject - {e}\")\n",
    "\n",
    "    print()\n",
    "\n",
    "if failed_subjects:\n",
    "    print(f\"WARNING: subjects {sorted(failed_subjects)} failed and were not checked\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "BFM Local (.venv)",
   "language": "python",
   "name": "bfm_local"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
