{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "c84d239c-e0fa-4745-9b32-aba6c455cfd8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import csv\n",
    "import random\n",
    "import h5py\n",
    "\n",
    "from torchmetrics import AUROC\n",
    "\n",
    "from torch_geometric.nn import GraphConv \n",
    "\n",
    "import torch.nn.functional as F\n",
    "import scipy.io as sio\n",
    "\n",
    "from torch_geometric.nn import GCNConv , ChebConv\n",
    "from torch_geometric.nn import global_mean_pool\n",
    "from torch_geometric.nn import global_max_pool\n",
    "from torch_geometric.data import Data\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "import torch.optim as optim\n",
    "from torch.optim.lr_scheduler import CosineAnnealingLR\n",
    "from torch.nn import Linear\n",
    "\n",
    "import wandb\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import time\n",
    "\n",
    "import lightning as L\n",
    "\n",
    "\n",
    "\n",
    "import math\n",
    "import torch\n",
    "from torch_geometric.utils import to_dense_adj, dense_to_sparse\n",
    "from torch_geometric.nn.conv import MessagePassing\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d7acb226-3d39-4c38-93e2-7826fd1cb343",
   "metadata": {},
   "source": [
    "For processing data we used the original DCRNN code at: https://github.com/tsy935/eeg-gnn-ssl/blob/main/utils.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "729567e4-1d2f-4784-a104-013b88389421",
   "metadata": {},
   "outputs": [],
   "source": [
    "adj = sio.loadmat('adj_mat')['adj_mat']\n",
    "ind = (adj != 0) & (adj != 1)\n",
    "edge_index = np.argwhere(ind == True).T \n",
    "\n",
    "edge_weight = np.zeros((1,edge_index.shape[1]))\n",
    "\n",
    "for i , e in enumerate(edge_index.T):  \n",
    "    \n",
    "    edge_weight[0,i] = adj [e[0] , e[1]]\n",
    "        \n",
    "\n",
    "edge_index =torch.tensor(edge_index )\n",
    "edge_weight = torch.tensor(edge_weight)\n",
    "\n",
    "pos = sio.loadmat('position.mat')['pos'].T\n",
    "# A = adj\n",
    "colors = ['w' , '#F601FF'] # first color is black, last is red\n",
    "cm = LinearSegmentedColormap.from_list(\n",
    "        \"Custom\", colors, N=200)\n",
    "\n",
    "# adj.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ac6c4ad8-f46a-4f56-b44b-9323733cf722",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(38497,)\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.data import Dataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "\n",
    "\n",
    "# This section is only for run if you have data pre processed\n",
    "\n",
    "pos = sio.loadmat('position.mat')['pos'].T\n",
    "\n",
    "with h5py.File( 'clip_data' +'.h5' , \"r\") as f:\n",
    "    \n",
    "    EEG = np.array(f[list(f.keys())[0]])\n",
    "\n",
    "        \n",
    "with h5py.File(  'label' +'.h5', \"r\") as f: \n",
    "    \n",
    "    Label = np.array(f[list(f.keys())[0]])\n",
    "    \n",
    "    \n",
    "print(Label.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "24713ce1-5bc4-4a76-a02e-7a5e39114b3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████| 38497/38497 [00:34<00:00, 1115.39it/s]\n"
     ]
    }
   ],
   "source": [
    "dataset = []\n",
    "A = torch.tensor(adj)\n",
    "batch = np.ones((1,19))\n",
    "from scipy.fft import fft\n",
    "is_fft = True\n",
    "\n",
    "for idx in tqdm(range(EEG.shape[0])):\n",
    "    \n",
    "    eeg_clip = EEG[idx,:,:,:]\n",
    "    \n",
    "    if is_fft:\n",
    "        eeg_clip = np.log(np.abs( fft(eeg_clip,axis=2)[:,:,0:100]) +1e-30) # Just real part might also be useful\n",
    "        \n",
    "    label = Label[idx]\n",
    "    \n",
    "    dataset.append( (  torch.tensor(eeg_clip).transpose(1,0) , \n",
    "                      torch.tensor((label)  , dtype=torch.long) ,  torch.tensor(pos)  )  \n",
    ")\n",
    "    \n",
    "    \n",
    "del Label , EEG  # Added new\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b4c40d44-35ea-48d2-adeb-f3c56e4e0a3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_idx = [0 , int(0.75*len(dataset))]\n",
    "test_idx = [int(0.75*len(dataset)) , len(dataset)]\n",
    "\n",
    "train_dataloader = DataLoader(dataset[train_idx[0]:train_idx[1]] , batch_size = 128   )\n",
    "test_dataloader = DataLoader(dataset[test_idx[0]:test_idx[1]] , batch_size = 16   )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "26b5f54b-fe66-4d04-825c-93362303c1de",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TimePositionalEncoding(torch.nn.Module):\n",
    "\n",
    "    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):\n",
    "        super().__init__()\n",
    "        self.dropout = torch.nn.Dropout(p=dropout)\n",
    "        \n",
    "        position = torch.arange(max_len).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))\n",
    "        pe = torch.zeros(max_len, 1, 1, d_model)\n",
    "        pe[:, 0,0, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 0,0, 1::2] = torch.cos(position * div_term)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x) :\n",
    "        \"\"\"\n",
    "        Arguments:\n",
    "            x: Tensor, shape ``[ batch_size, nodes , seq_len, embedding_dim]``\n",
    "        \"\"\"\n",
    "        x = x + self.pe[:x.size(0)]\n",
    "        return self.dropout(x) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "955ddd18-c0c9-49f3-b55b-ed734bb1fcb2",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransGraph (L.LightningModule):\n",
    "\n",
    "    def __init__(self , d_model: int, num_heads, dropout = 0.1 ):\n",
    "        \n",
    "        super().__init__()\n",
    "\n",
    "        torch.set_float32_matmul_precision('high')\n",
    "        \n",
    "        self.num_heads = num_heads\n",
    "        self.d_model = d_model\n",
    "        \n",
    "        self.dropout = torch.nn.Dropout(p = dropout)\n",
    "      \n",
    "        \n",
    "\n",
    "        self.emb = torch.nn.Linear(100,self.d_model)\n",
    "        self.time_enc = TimePositionalEncoding(self.d_model)\n",
    "        self.node_enc = torch.nn.Linear(2,self.d_model)\n",
    "        \n",
    "               \n",
    "        self.qkv1 = torch.nn.Linear(self.d_model , 3*self.d_model )\n",
    "        self.dropout = torch.nn.Dropout(p=dropout)\n",
    "        self.Matt1 = torch.nn.MultiheadAttention(self.d_model , self.num_heads , batch_first=True )\n",
    "        self.fc11 = torch.nn.Linear(self.d_model , self.d_model )\n",
    "        self.fc12 = torch.nn.Linear(self.d_model , self.d_model )\n",
    "\n",
    "        \n",
    "        self.qkv2 = torch.nn.Linear(self.d_model , 3*self.d_model )\n",
    "        self.Matt2 = torch.nn.MultiheadAttention(self.d_model , self.num_heads , batch_first=True )\n",
    "        self.fc2 = torch.nn.Linear(self.d_model , self.d_model )\n",
    "\n",
    "        self.proj = torch.nn.Linear(self.d_model , 1 )\n",
    "\n",
    "        # self.node_time_proj = torch.nn.Linear(19*12 , 1 )\n",
    "\n",
    "\n",
    "    def forward(self, data , pos ) :\n",
    "        \n",
    "        x = data.float()\n",
    "        \n",
    "        _ , num_ch , time_step , num_features = x.size()\n",
    "                \n",
    "        out =  self.time_enc( self.emb(x) )  + self.node_enc( pos.float() ).unsqueeze(2).repeat(1, 1,time_step ,1 )\n",
    "        \n",
    "        out = out.transpose(1,2).reshape(-1,num_ch* time_step,self.d_model)\n",
    "        \n",
    "        qkv = self.qkv1(out)\n",
    "        q,k,v = torch.split(qkv , self.d_model , dim=-1)\n",
    "\n",
    "        \n",
    "        attn_output , att_weight = self.Matt1(q , k , v   )\n",
    "        out = out + attn_output\n",
    "        out = out + self.fc12( F.relu(self.fc11(out)) )\n",
    "        out = self.dropout (out)\n",
    "\n",
    "\n",
    "        qkv = self.qkv2(out)\n",
    "        q,k,v = torch.split(qkv , self.d_model , dim=-1)\n",
    "        \n",
    "        attn_output , att_weight = self.Matt2(q , k , v  )\n",
    "        out = out + self.fc2( (attn_output) )\n",
    "        \n",
    "        out = self.dropout(F.relu( out) )\n",
    "\n",
    "        \n",
    "       \n",
    "\n",
    "        out = self.proj(out[:, -torch.arange(19) - 1,:])\n",
    "\n",
    "\n",
    "        return self.dropout( torch.mean(out, dim=1) ) , att_weight\n",
    "\n",
    "    def training_step(self, data ):\n",
    "\n",
    "        inputs, labels , pos  = data\n",
    "\n",
    "        out ,_  = self(inputs , pos )\n",
    "        \n",
    "        loss = F.binary_cross_entropy_with_logits (out.reshape(-1,1) , labels.type(torch.float32).reshape(-1,1)   )\n",
    "        \n",
    "        return loss\n",
    "\n",
    "    def configure_optimizers(self):\n",
    "        \n",
    "        optimizer = optim.Adam(params=self.parameters(),\n",
    "                           lr = 1e-3)\n",
    "        \n",
    "        scheduler = CosineAnnealingLR(optimizer, T_max = 1000)\n",
    "        \n",
    "        return [optimizer], [{\"scheduler\": scheduler, \"interval\": \"epoch\"}]\n",
    "        \n",
    "    \n",
    "model = TransGraph( 32 , 16 , 0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a930e467-fa2c-4010-9eac-1194f70a5ad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.manual_seed(123)  # Seed which can be removed or added in another part\n",
    "\n",
    "torch.set_float32_matmul_precision('high')\n",
    "trainer = L.Trainer(max_epochs= 30  ,devices = [0,1,2,3] ,  accelerator=\"gpu\" , precision=\"bf16-mixed\" )\n",
    "\n",
    "trainer.fit(model, train_dataloader  )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "fb72628e-964c-49ed-bcef-35eb789748b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torch.save(model.state_dict() , 'transformergraph_eeg_12s_2.pt' )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "255f5700-d44f-4eb7-83ba-6935d4b07dfe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# model.load_state_dict(torch.load('transformergraph_eeg_12s.pt'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8439e27b-a9c0-43c4-868d-4f55a6381a52",
   "metadata": {},
   "source": [
    "## For Pre-Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "93aeb650-2883-40be-8e14-ad79544011cf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(32075,)\n"
     ]
    }
   ],
   "source": [
    "from torch_geometric.data import Dataset\n",
    "from torch_geometric.loader import DataLoader\n",
    "pos = sio.loadmat('position.mat')['pos'].T\n",
    "\n",
    "with h5py.File( 'clip_data_eval' +'.h5' , \"r\") as f:\n",
    "    \n",
    "    EEG = np.array(f[list(f.keys())[0]])\n",
    "\n",
    "        \n",
    "with h5py.File(  'label_eval' +'.h5', \"r\") as f: \n",
    "    \n",
    "    Label = np.array(f[list(f.keys())[0]])\n",
    "    \n",
    "    \n",
    "print(Label.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "3994b317-cb94-437a-95d2-23baefbe6fc5",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 1714/1714 [00:42<00:00, 40.10it/s]\n"
     ]
    }
   ],
   "source": [
    "pos = sio.loadmat('position.mat')['pos'].T\n",
    "\n",
    "dataset = []\n",
    "batch = np.ones((1,19))\n",
    "from scipy.fft import fft\n",
    "is_fft = True\n",
    "\n",
    "for idx in tqdm(range(EEG.shape[0])):\n",
    "    \n",
    "    eeg_clip = EEG[idx,:,:,:]\n",
    "    \n",
    "    if is_fft:\n",
    "        eeg_clip = np.log(np.abs( fft(eeg_clip,axis=2)[:,:,0:100]) +1e-30) # Just real part might also be useful\n",
    "        \n",
    "    label = Label[idx]\n",
    "    \n",
    "    dataset.append( (  torch.tensor(eeg_clip).transpose(1,0) , \n",
    "                      torch.tensor((label)  , dtype=torch.long) ,  torch.tensor(pos) )  \n",
    ")\n",
    "    \n",
    "    \n",
    "del Label , EEG  # Added new"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "910347d3-edac-4b1a-b0c6-57850728da76",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_dataloader = DataLoader(dataset  , batch_size = 16  )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d6d4149-d006-471e-8ce7-edf7645aa0b7",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "l = []\n",
    "gt = []\n",
    "# model.to('cpu')\n",
    "model.cuda()\n",
    "for batch_num, data in enumerate(tqdm(test_dataloader)):\n",
    "\n",
    "    inputs, labels , pos = data\n",
    "\n",
    "    \n",
    "    out , _ = model(inputs.cuda()  , pos.cuda())\n",
    "    \n",
    "    l.extend ( (  (torch.sigmoid(out.reshape(-1,1)) )).to('cpu').detach().numpy()  )  \n",
    "\n",
    "    gt.extend( ( data[1].type(torch.float32).reshape(-1,1).to('cpu') ).detach().numpy() )\n",
    "\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "f1_score(gt,np.round(l), average='binary')\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fe296ec6-2f2b-48bb-bf5b-6f6082149525",
   "metadata": {},
   "outputs": [],
   "source": [
    "from imblearn.metrics import sensitivity_score , specificity_score\n",
    "sensitivity_score(gt,np.round(l), average='binary') , specificity_score(gt,np.round(l), average='binary')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "e1c44c34-651b-41c3-a3a1-6739c26fdf95",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.840244682690645"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn import metrics\n",
    "\n",
    "fpr, tpr, thresholds = metrics.roc_curve(np.ravel( np.array(gt).reshape(-1,1)) , \n",
    "                                         np.ravel( np.array(l).reshape(-1,1)   ) )\n",
    "metrics.auc(fpr,tpr)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "673b7d8e-1842-4460-8e5f-02541627e5f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from scipy.io import savemat\n",
    "import numpy as np\n",
    "\n",
    "mdic = {\"a\": a, \"label\": \"experiment\"}\n",
    "\n",
    "savemat(\"atten_w.mat\", mdic)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
