{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EEG channel names used as the sensor ids (graph nodes).\n",
    "# List order fixes each channel's row/column index in the adjacency matrix.\n",
    "INCLUDED_CHANNELS = [\n",
    "    'EEG FP1',\n",
    "    'EEG FP2',\n",
    "    'EEG F3',\n",
    "    'EEG F4',\n",
    "    'EEG C3',\n",
    "    'EEG C4',\n",
    "    'EEG P3',\n",
    "    'EEG P4',\n",
    "    'EEG O1',\n",
    "    'EEG O2',\n",
    "    'EEG F7',\n",
    "    'EEG F8',\n",
    "    'EEG T3',\n",
    "    'EEG T4',\n",
    "    'EEG T5',\n",
    "    'EEG T6',\n",
    "    'EEG FZ',\n",
    "    'EEG CZ',\n",
    "    'EEG PZ'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load pairwise distances between sensors (columns: from, to, distance),\n",
    "# with sensor positions following the 10-20 system.\n",
    "dist_df = pd.read_csv('./distances_3d.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>from</th>\n",
       "      <th>to</th>\n",
       "      <th>distance</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>EEG FP2</td>\n",
       "      <td>0.618000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>EEG F3</td>\n",
       "      <td>0.618969</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>EEG F4</td>\n",
       "      <td>1.030322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>EEG FP1</td>\n",
       "      <td>EEG C3</td>\n",
       "      <td>1.250226</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>356</th>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>EEG T5</td>\n",
       "      <td>1.081066</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>357</th>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>EEG T6</td>\n",
       "      <td>1.081066</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>358</th>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>EEG FZ</td>\n",
       "      <td>1.414200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>359</th>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>EEG CZ</td>\n",
       "      <td>0.765363</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>360</th>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>EEG PZ</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>361 rows × 3 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        from       to  distance\n",
       "0    EEG FP1  EEG FP1  0.000000\n",
       "1    EEG FP1  EEG FP2  0.618000\n",
       "2    EEG FP1   EEG F3  0.618969\n",
       "3    EEG FP1   EEG F4  1.030322\n",
       "4    EEG FP1   EEG C3  1.250226\n",
       "..       ...      ...       ...\n",
       "356   EEG PZ   EEG T5  1.081066\n",
       "357   EEG PZ   EEG T6  1.081066\n",
       "358   EEG PZ   EEG FZ  1.414200\n",
       "359   EEG PZ   EEG CZ  0.765363\n",
       "360   EEG PZ   EEG PZ  0.000000\n",
       "\n",
       "[361 rows x 3 columns]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dist_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_adjacency_matrix(distance_df, sensor_ids, dist_k=0.9):\n",
    "    \"\"\"\n",
    "    Build a thresholded Gaussian-kernel adjacency matrix from pairwise sensor distances.\n",
    "\n",
    "    Edge weights are exp(-(d / std)^2), where std is the standard deviation of all\n",
    "    known (non-infinite) entries of the distance matrix; pairs farther apart than\n",
    "    dist_k get weight 0.\n",
    "\n",
    "    Args:\n",
    "        distance_df: data frame whose first three columns are [from, to, distance].\n",
    "            Rows naming a sensor that is not in sensor_ids are ignored.\n",
    "        sensor_ids: list of sensor ids; position in the list is the matrix index.\n",
    "        dist_k: distance cutoff for graph sparsity; entries whose distance is\n",
    "            strictly greater than dist_k are set to zero.\n",
    "    Returns:\n",
    "        adj_mx: float32 array of shape (len(sensor_ids), len(sensor_ids)).\n",
    "            Pairs absent from distance_df keep an infinite distance and get weight 0.\n",
    "        sensor_id_to_ind: dict mapping each sensor id to its row/column index.\n",
    "    \"\"\"\n",
    "    num_sensors = len(sensor_ids)\n",
    "    # Start from an infinite distance so pairs missing from distance_df end up unconnected.\n",
    "    dist_mx = np.full((num_sensors, num_sensors), np.inf, dtype=np.float32)\n",
    "    # Builds sensor id to index map.\n",
    "    sensor_id_to_ind = {sensor_id: i for i, sensor_id in enumerate(sensor_ids)}\n",
    "\n",
    "    # Fills cells in the matrix with distances.\n",
    "    for row in distance_df.values:\n",
    "        src, dst, dist = row[0], row[1], row[2]\n",
    "        if src not in sensor_id_to_ind or dst not in sensor_id_to_ind:\n",
    "            continue\n",
    "        dist_mx[sensor_id_to_ind[src], sensor_id_to_ind[dst]] = dist\n",
    "\n",
    "    # Calculates the standard deviation of the known distances as the kernel width.\n",
    "    distances = dist_mx[~np.isinf(dist_mx)]\n",
    "    std = distances.std() if distances.size else 0.0\n",
    "\n",
    "    if std == 0:\n",
    "        # Degenerate width (no known distances, or all of them identical): dividing by it\n",
    "        # would yield NaN for zero distances. Use the std -> 0+ limit of the kernel instead:\n",
    "        # weight 1 at distance 0, weight 0 everywhere else.\n",
    "        adj_mx = (dist_mx == 0).astype(np.float32)\n",
    "    else:\n",
    "        adj_mx = np.exp(-np.square(dist_mx / std))\n",
    "\n",
    "    # Sets entries whose distance exceeds the cutoff dist_k to zero for sparsity.\n",
    "    adj_mx[dist_mx > dist_k] = 0\n",
    "\n",
    "    return adj_mx, sensor_id_to_ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distance cutoff: sensor pairs farther apart than this get zero edge weight.\n",
    "thresh = 0.9\n",
    "adj_mat, sensor_id_to_ind = get_adjacency_matrix(dist_df, INCLUDED_CHANNELS, dist_k=thresh)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'EEG FP1': 0,\n",
       " 'EEG FP2': 1,\n",
       " 'EEG F3': 2,\n",
       " 'EEG F4': 3,\n",
       " 'EEG C3': 4,\n",
       " 'EEG C4': 5,\n",
       " 'EEG P3': 6,\n",
       " 'EEG P4': 7,\n",
       " 'EEG O1': 8,\n",
       " 'EEG O2': 9,\n",
       " 'EEG F7': 10,\n",
       " 'EEG F8': 11,\n",
       " 'EEG T3': 12,\n",
       " 'EEG T4': 13,\n",
       " 'EEG T5': 14,\n",
       " 'EEG T6': 15,\n",
       " 'EEG FZ': 16,\n",
       " 'EEG CZ': 17,\n",
       " 'EEG PZ': 18}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sensor_id_to_ind"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(19, 19)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "adj_mat.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.12 ('gnn_signals')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "vscode": {
   "interpreter": {
    "hash": "20d9da74d825dd6fc05ff42b529813048c7d512b4dca72bccb167942525bb9e7"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
