{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wq7FnhDqnwzg",
        "outputId": "279eb5c1-e5bb-441e-b430-73c5e52c77d3"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Collecting braindecode\n",
            "  Using cached braindecode-0.8.1-py3-none-any.whl.metadata (8.1 kB)\n",
            "Requirement already satisfied: mne in /usr/local/lib/python3.11/dist-packages (from braindecode) (1.9.0)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from braindecode) (1.26.4)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.11/dist-packages (from braindecode) (2.2.2)\n",
            "Requirement already satisfied: scipy in /usr/local/lib/python3.11/dist-packages (from braindecode) (1.15.3)\n",
            "Requirement already satisfied: matplotlib in /usr/local/lib/python3.11/dist-packages (from braindecode) (3.10.0)\n",
            "Requirement already satisfied: h5py in /usr/local/lib/python3.11/dist-packages (from braindecode) (3.13.0)\n",
            "Collecting skorch (from braindecode)\n",
            "  Using cached skorch-1.1.0-py3-none-any.whl.metadata (11 kB)\n",
            "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (from braindecode) (2.6.0+cu124)\n",
            "Requirement already satisfied: einops in /usr/local/lib/python3.11/dist-packages (from braindecode) (0.8.1)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from braindecode) (1.5.0)\n",
            "Collecting torchinfo (from braindecode)\n",
            "  Using cached torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)\n",
            "Collecting docstring-inheritance (from braindecode)\n",
            "  Using cached docstring_inheritance-2.2.2-py3-none-any.whl.metadata (11 kB)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (1.3.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (4.58.0)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (1.4.8)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (24.2)\n",
            "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (11.2.1)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (3.2.3)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib->braindecode) (2.9.0.post0)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from mne->braindecode) (4.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from mne->braindecode) (3.1.6)\n",
            "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.11/dist-packages (from mne->braindecode) (0.4)\n",
            "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.11/dist-packages (from mne->braindecode) (1.8.2)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.11/dist-packages (from mne->braindecode) (4.67.1)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas->braindecode) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas->braindecode) (2025.2)\n",
            "Requirement already satisfied: scikit-learn>=0.22.0 in /usr/local/lib/python3.11/dist-packages (from skorch->braindecode) (1.5.2)\n",
            "Requirement already satisfied: tabulate>=0.7.7 in /usr/local/lib/python3.11/dist-packages (from skorch->braindecode) (0.9.0)\n",
            "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (3.18.0)\n",
            "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (4.13.2)\n",
            "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (3.4.2)\n",
            "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (2025.3.2)\n",
            "Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->braindecode)\n",
            "  Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->braindecode)\n",
            "  Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->braindecode)\n",
            "  Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->braindecode)\n",
            "  Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->braindecode)\n",
            "  Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->braindecode)\n",
            "  Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-curand-cu12==10.3.5.147 (from torch->braindecode)\n",
            "  Using cached nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch->braindecode)\n",
            "  Using cached nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch->braindecode)\n",
            "  Using cached nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
            "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (0.6.2)\n",
            "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (2.21.5)\n",
            "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (12.4.127)\n",
            "Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch->braindecode)\n",
            "  Using cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
            "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (3.2.0)\n",
            "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch->braindecode) (1.13.1)\n",
            "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch->braindecode) (1.3.0)\n",
            "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from pooch>=1.5->mne->braindecode) (4.3.8)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.11/dist-packages (from pooch>=1.5->mne->braindecode) (2.32.3)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib->braindecode) (1.17.0)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn>=0.22.0->skorch->braindecode) (3.6.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->mne->braindecode) (3.0.2)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->pooch>=1.5->mne->braindecode) (3.4.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->pooch>=1.5->mne->braindecode) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->pooch>=1.5->mne->braindecode) (1.26.20)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests>=2.19.0->pooch>=1.5->mne->braindecode) (2025.4.26)\n",
            "Using cached braindecode-0.8.1-py3-none-any.whl (165 kB)\n",
            "Using cached docstring_inheritance-2.2.2-py3-none-any.whl (24 kB)\n",
            "Using cached skorch-1.1.0-py3-none-any.whl (228 kB)\n",
            "Using cached nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
            "Using cached nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
            "Using cached nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
            "Using cached nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
            "Using cached nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
            "Using cached nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
            "Using cached nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
            "Using cached nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
            "Downloading nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
            "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25hUsing cached nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
            "Using cached torchinfo-1.8.0-py3-none-any.whl (23 kB)\n",
            "Installing collected packages: torchinfo, nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, docstring-inheritance, nvidia-cusparse-cu12, nvidia-cudnn-cu12, skorch, nvidia-cusolver-cu12, braindecode\n",
            "  Attempting uninstall: nvidia-nvjitlink-cu12\n",
            "    Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
            "    Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-curand-cu12\n",
            "    Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
            "    Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
            "      Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
            "  Attempting uninstall: nvidia-cufft-cu12\n",
            "    Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
            "    Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
            "      Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
            "  Attempting uninstall: nvidia-cuda-runtime-cu12\n",
            "    Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
            "    Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cuda-cupti-cu12\n",
            "    Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
            "    Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
            "      Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
            "  Attempting uninstall: nvidia-cublas-cu12\n",
            "    Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
            "    Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
            "      Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
            "  Attempting uninstall: nvidia-cusparse-cu12\n",
            "    Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
            "    Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
            "      Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
            "  Attempting uninstall: nvidia-cudnn-cu12\n",
            "    Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
            "    Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
            "      Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
            "  Attempting uninstall: nvidia-cusolver-cu12\n",
            "    Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
            "    Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
            "      Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
            "Successfully installed braindecode-0.8.1 docstring-inheritance-2.2.2 nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127 skorch-1.1.0 torchinfo-1.8.0\n",
            "Requirement already satisfied: moabb in /usr/local/lib/python3.11/dist-packages (1.2.0)\n",
            "Requirement already satisfied: PyYAML<7.0,>=6.0 in /usr/local/lib/python3.11/dist-packages (from moabb) (6.0.2)\n",
            "Requirement already satisfied: coverage<8.0.0,>=7.0.1 in /usr/local/lib/python3.11/dist-packages (from moabb) (7.8.1)\n",
            "Requirement already satisfied: edfio<0.5.0,>=0.4.2 in /usr/local/lib/python3.11/dist-packages (from moabb) (0.4.9)\n",
            "Requirement already satisfied: edflib-python<2.0.0,>=1.0.6 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.0.8)\n",
            "Requirement already satisfied: h5py<4.0.0,>=3.10.0 in /usr/local/lib/python3.11/dist-packages (from moabb) (3.13.0)\n",
            "Requirement already satisfied: matplotlib<4.0.0,>=3.6.2 in /usr/local/lib/python3.11/dist-packages (from moabb) (3.10.0)\n",
            "Requirement already satisfied: memory-profiler<0.62.0,>=0.61.0 in /usr/local/lib/python3.11/dist-packages (from moabb) (0.61.0)\n",
            "Requirement already satisfied: mne<2.0.0,>=1.7.0 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.9.0)\n",
            "Requirement already satisfied: mne-bids>=0.14 in /usr/local/lib/python3.11/dist-packages (from moabb) (0.16.0)\n",
            "Requirement already satisfied: numpy<2.0,>=1.22 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.26.4)\n",
            "Requirement already satisfied: pandas>=1.5.2 in /usr/local/lib/python3.11/dist-packages (from moabb) (2.2.2)\n",
            "Requirement already satisfied: pooch<2.0.0,>=1.6.0 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.8.2)\n",
            "Requirement already satisfied: pyriemann<0.8,>=0.7 in /usr/local/lib/python3.11/dist-packages (from moabb) (0.7)\n",
            "Requirement already satisfied: requests<3.0.0,>=2.28.1 in /usr/local/lib/python3.11/dist-packages (from moabb) (2.32.3)\n",
            "Requirement already satisfied: scikit-learn<1.6 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.5.2)\n",
            "Requirement already satisfied: scipy<2.0.0,>=1.9.3 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.15.3)\n",
            "Requirement already satisfied: seaborn<0.13.0,>=0.12.1 in /usr/local/lib/python3.11/dist-packages (from moabb) (0.12.2)\n",
            "Requirement already satisfied: tqdm<5.0.0,>=4.64.1 in /usr/local/lib/python3.11/dist-packages (from moabb) (4.67.1)\n",
            "Requirement already satisfied: urllib3<2.0.0,>=1.26.15 in /usr/local/lib/python3.11/dist-packages (from moabb) (1.26.20)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (1.3.2)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (4.58.0)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (1.4.8)\n",
            "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (24.2)\n",
            "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (11.2.1)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (3.2.3)\n",
            "Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.11/dist-packages (from matplotlib<4.0.0,>=3.6.2->moabb) (2.9.0.post0)\n",
            "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from memory-profiler<0.62.0,>=0.61.0->moabb) (5.9.5)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from mne<2.0.0,>=1.7.0->moabb) (4.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from mne<2.0.0,>=1.7.0->moabb) (3.1.6)\n",
            "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.11/dist-packages (from mne<2.0.0,>=1.7.0->moabb) (0.4)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.5.2->moabb) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas>=1.5.2->moabb) (2025.2)\n",
            "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.11/dist-packages (from pooch<2.0.0,>=1.6.0->moabb) (4.3.8)\n",
            "Requirement already satisfied: joblib in /usr/local/lib/python3.11/dist-packages (from pyriemann<0.8,>=0.7->moabb) (1.5.0)\n",
            "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0,>=2.28.1->moabb) (3.4.2)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0,>=2.28.1->moabb) (3.10)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0,>=2.28.1->moabb) (2025.4.26)\n",
            "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn<1.6->moabb) (3.6.0)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.7->matplotlib<4.0.0,>=3.6.2->moabb) (1.17.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->mne<2.0.0,>=1.7.0->moabb) (3.0.2)\n"
          ]
        }
      ],
      "source": [
        "%pip install -q braindecode==0.8.1\n",
        "%pip install -q moabb==1.2.0\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "from braindecode.datasets import MOABBDataset\n",
        "from braindecode.preprocessing import Preprocessor, exponential_moving_standardize, preprocess, create_windows_from_events\n",
        "from torch.utils.data import TensorDataset, DataLoader, ConcatDataset, random_split\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from sklearn.metrics import cohen_kappa_score\n",
        "import pandas as pd\n",
        "import random\n",
        "#CDNN/VNN\n",
        "\n",
        "# ---------------------\n",
        "# Reproducibility\n",
        "# ---------------------\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "# Usage: pick any integer\n",
        "torch.manual_seed(42)\n",
        "set_seed(42)\n",
        "\n",
        "# ---------------------\n",
        "# Utility functions\n",
        "# ---------------------\n",
        "def convert_to_tensors(windows_dataset):\n",
        "    X, y = [], []\n",
        "    for win in windows_dataset:\n",
        "        X.append(win[0])\n",
        "        y.append(win[1])\n",
        "    X = np.stack(X)\n",
        "    y = np.array(y)\n",
        "    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long)\n",
        "\n",
        "\n",
        "def compute_covariance_matrix(loader, device):\n",
        "    all_batches = []\n",
        "    for X, _ in loader:\n",
        "        all_batches.append(X.to(device))\n",
        "    X = torch.cat(all_batches, dim=0)\n",
        "    mean_t = X.mean(dim=2, keepdim=True)\n",
        "    std_t = X.std(dim=2, keepdim=True, unbiased=False) + 1e-15\n",
        "    Xz = (X - mean_t) / std_t\n",
        "    N, C, T = Xz.shape\n",
        "    X_flat = Xz.permute(0, 2, 1).reshape(N * T, C)\n",
        "    cov = (X_flat.T @ X_flat) / (N * T)\n",
        "    return cov\n",
        "\n",
        "# ---------------------\n",
        "# Model definition\n",
        "# ---------------------\n",
        "#Simple approach, add filter coefficients as needed.\n",
        "class VNNFilterBank(nn.Module):\n",
        "    def __init__(self, C, T, cov_matrix, betas, hidden_dim=128, num_classes=4, dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        self.C, self.T = C, T\n",
        "        self.L = cov_matrix\n",
        "        self.betas = betas\n",
        "        in_dim = C * T * len(betas)\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(in_dim, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Dropout(p=dropout_p),\n",
        "            nn.Linear(hidden_dim, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        B, C, T = x.shape\n",
        "\n",
        "        feats = []\n",
        "        for beta in self.betas:\n",
        "            filt = torch.matrix_exp(-beta * self.L) # Just pass self.L with beta=1 for VNN ablation\n",
        "            filt = filt / torch.trace(torch.matrix_exp(-beta * self.L))\n",
        "            y = torch.matmul(filt.unsqueeze(0).expand(B, -1, -1), xz)\n",
        "            feats.append(y.reshape(B, C * T))\n",
        "        out_feats = torch.cat(feats, dim=1)\n",
        "        return self.classifier(out_feats)\n",
        "\n",
        "# ---------------------\n",
        "# Training with validation\n",
        "# ---------------------\n",
        "def train_and_evaluate(model, train_loader, val_loader, test_loader,\n",
        "                       optimizer, criterion, num_epochs, device):\n",
        "    model.to(device)\n",
        "    best_state = None\n",
        "    best_val_loss = float('inf')\n",
        "\n",
        "    for epoch in range(1, num_epochs + 1):\n",
        "        # Training step\n",
        "        model.train()\n",
        "        for X, y in train_loader:\n",
        "            X, y = X.to(device), y.to(device)\n",
        "            optimizer.zero_grad()\n",
        "            out = model(X)\n",
        "            loss = criterion(out, y)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "\n",
        "        # Validation step\n",
        "        model.eval()\n",
        "        val_loss = 0.0\n",
        "        val_count = 0\n",
        "        with torch.no_grad():\n",
        "            for X, y in val_loader:\n",
        "                X, y = X.to(device), y.to(device)\n",
        "                out = model(X)\n",
        "                loss = criterion(out, y)\n",
        "                val_loss += loss.item() * y.size(0)\n",
        "                val_count += y.size(0)\n",
        "        val_loss /= val_count\n",
        "        print(f\"Epoch {epoch}/{num_epochs} | Val Loss: {val_loss:.4f}\")\n",
        "\n",
        "        # Checkpoint best\n",
        "        if val_loss < best_val_loss:\n",
        "            best_val_loss = val_loss\n",
        "            best_state = model.state_dict()\n",
        "\n",
        "    # Load best model\n",
        "    model.load_state_dict(best_state)\n",
        "\n",
        "    # Final test evaluation\n",
        "    model.eval()\n",
        "    preds, targs = [], []\n",
        "    with torch.no_grad():\n",
        "        for X, y in test_loader:\n",
        "            X, y = X.to(device), y.to(device)\n",
        "            p = model(X).argmax(dim=1)\n",
        "            preds.extend(p.cpu().tolist())\n",
        "            targs.extend(y.cpu().tolist())\n",
        "    final_acc = 100.0 * np.mean(np.array(preds) == np.array(targs))\n",
        "    kappa = cohen_kappa_score(targs, preds)\n",
        "    print(f\"Final Test Acc: {final_acc:.2f}% | Kappa: {kappa:.4f}\")\n",
        "    return final_acc, kappa\n",
        "\n",
        "# ---------------------\n",
        "# Main LO-SO CV with val split\n",
        "# ---------------------\n",
        "def _load_subject_tensors(subject_id, preprocessors, trial_offset):\n",
        "    \"\"\"Download, preprocess and window one BNCI2014_001 subject.\n",
        "\n",
        "    ``trial_offset`` is in seconds and converted to samples with the recording's\n",
        "    sampling rate. Returns (X_train, y_train, X_eval, y_eval) tensors for the\n",
        "    '0train' and '1test' sessions.\n",
        "    \"\"\"\n",
        "    ds = MOABBDataset('BNCI2014_001', [subject_id])\n",
        "    preprocess(ds, preprocessors, n_jobs=-1)\n",
        "    sf = ds.datasets[0].raw.info['sfreq']\n",
        "    wins = create_windows_from_events(\n",
        "        ds,\n",
        "        trial_start_offset_samples=int(trial_offset * sf),\n",
        "        trial_stop_offset_samples=0,\n",
        "        preload=True\n",
        "    )\n",
        "    parts = wins.split('session')\n",
        "    Xt, Yt = convert_to_tensors(parts['0train'])\n",
        "    Xv, Yv = convert_to_tensors(parts['1test'])\n",
        "    return Xt, Yt, Xv, Yv\n",
        "\n",
        "\n",
        "if __name__ == '__main__':\n",
        "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "    print('Device:', device)\n",
        "    num_subjects = 9\n",
        "    num_epochs = 50\n",
        "    batch_size = 64\n",
        "    betas = [0.1, 5, 15.1]\n",
        "\n",
        "    # NOTE(review): preprocessing values follow braindecode's BCI IV 2a example;\n",
        "    # confirm they match the intended experiment.\n",
        "    low_cut_hz, high_cut_hz = 4.0, 38.0\n",
        "    preprocessors = [\n",
        "        Preprocessor('pick_types', eeg=True, meg=False, stim=False),\n",
        "        Preprocessor(lambda data, factor: np.multiply(data, factor), factor=1e6),  # V -> uV\n",
        "        Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),\n",
        "        Preprocessor(exponential_moving_standardize, factor_new=1e-3, init_block_size=1000),\n",
        "    ]\n",
        "    trial_offset = -0.5  # seconds relative to each trial start event\n",
        "\n",
        "    # Load and preprocess every subject once; each LOSO fold only regroups them.\n",
        "    subject_data = {\n",
        "        sid: _load_subject_tensors(sid, preprocessors, trial_offset)\n",
        "        for sid in range(1, num_subjects + 1)\n",
        "    }\n",
        "\n",
        "    results = []\n",
        "    for test_id in range(1, num_subjects + 1):\n",
        "        set_seed(42 + test_id)\n",
        "        print(f'\\n--- Test subj {test_id} ---')\n",
        "        train_sets, test_set = [], None\n",
        "        for sid, (Xt, Yt, Xv, Yv) in subject_data.items():\n",
        "            if sid == test_id:\n",
        "                test_set = TensorDataset(torch.cat((Xt, Xv), 0), torch.cat((Yt, Yv), 0))\n",
        "            else:\n",
        "                train_sets += [TensorDataset(Xt, Yt), TensorDataset(Xv, Yv)]\n",
        "\n",
        "        full_train = ConcatDataset(train_sets)\n",
        "        # 5% validation split\n",
        "        total_train = len(full_train)\n",
        "        val_size = int(0.05 * total_train)\n",
        "        train_size = total_train - val_size\n",
        "        g = torch.Generator()\n",
        "        g.manual_seed(42 + test_id)\n",
        "        train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=g)\n",
        "\n",
        "        tr_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "        val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)\n",
        "        te_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)\n",
        "\n",
        "        # Covariance over ALL training windows in a fixed order: the shuffled,\n",
        "        # drop_last training loader would silently leave out a random subset.\n",
        "        cov_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=False)\n",
        "        cov = compute_covariance_matrix(cov_loader, device)\n",
        "        C, T = cov.shape[0], test_set.tensors[0].shape[2]\n",
        "\n",
        "        model = VNNFilterBank(C, T, cov, betas, hidden_dim=128, num_classes=4, dropout_p=0.5).to(device)\n",
        "        optimizer = optim.Adam(model.parameters(), lr=1e-4)\n",
        "        criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "        acc, kappa = train_and_evaluate(\n",
        "            model, tr_loader, val_loader, te_loader,\n",
        "            optimizer, criterion, num_epochs, device\n",
        "        )\n",
        "        results.append((test_id, acc, kappa))\n",
        "\n",
        "    df = pd.DataFrame(results, columns=['Subject', 'Acc(%)', 'Kappa'])\n",
        "    # Compute both statistics over the per-subject rows only, before the summary\n",
        "    # rows are appended (otherwise Std would include the Avg row).\n",
        "    acc_mean, acc_std = df['Acc(%)'].mean(), df['Acc(%)'].std()\n",
        "    kappa_mean, kappa_std = df['Kappa'].mean(), df['Kappa'].std()\n",
        "    df.loc['Avg'] = ['Avg', acc_mean, kappa_mean]\n",
        "    df.loc['Std'] = ['Std', acc_std, kappa_std]\n",
        "    print('\\n=== Summary ===')\n",
        "    print(df)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2EFfK-xNo3k1",
        "outputId": "a8f7ed3c-6137-447c-ec34-feac6e08769d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/moabb/datasets/download.py:56: RuntimeWarning: Setting non-standard config type: \"MNE_DATASETS_BNCI_PATH\"\n",
            "  set_config(key, get_config(\"MNE_DATA\"))\n",
            "Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01T.mat'.\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Device: cuda\n",
            "\n",
            "--- Test subj 1 ---\n",
            "MNE_DATA is not already configured. It will be set to default location in the home directory - /root/mne_data\n",
            "All datasets will be downloaded to this location, if anything is already downloaded, please move manually to this location\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/urllib3/connectionpool.py:1064: InsecureRequestWarning: Unverified HTTPS request is being made to host 'lampx.tugraz.at'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
            "  warnings.warn(\n",
            "100%|█████████████████████████████████████| 42.8M/42.8M [00:00<00:00, 4.52GB/s]\n",
            "SHA256 hash of downloaded file: 054f02e70cf9c4ada1517e9b9864f45407939c1062c6793516585c6f511d0325\n",
            "Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.\n",
            "Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01E.mat'.\n",
            "/usr/local/lib/python3.11/dist-packages/urllib3/connectionpool.py:1064: InsecureRequestWarning: Unverified HTTPS request is being made to host 'lampx.tugraz.at'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/1.26.x/advanced-usage.html#ssl-warnings\n",
            "  warnings.warn(\n",
            "100%|█████████████████████████████████████| 43.8M/43.8M [00:00<00:00, 22.8GB/s]\n",
            "SHA256 hash of downloaded file: 53d415f39c3d7b0c88b894d7b08d99bcdfe855ede63831d3691af1a45607fb62\n",
            "Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "import random\n",
        "import pandas as pd\n",
        "\n",
        "from braindecode.datasets import MOABBDataset\n",
        "from braindecode.preprocessing import (\n",
        "    Preprocessor,\n",
        "    exponential_moving_standardize,\n",
        "    preprocess,\n",
        "    create_windows_from_events,\n",
        ")\n",
        "from torch.utils.data import TensorDataset, DataLoader, ConcatDataset\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from sklearn.metrics import cohen_kappa_score\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 0) Reproducibility helpers\n",
        "# -----------------------------------------------------------------------------\n",
        "#GIN AND GAT\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(seed)\n",
        "    torch.backends.cudnn.deterministic = True\n",
        "    torch.backends.cudnn.benchmark = False\n",
        "\n",
        "\n",
        "set_seed(42)\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 1) Data utilities\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "def convert_to_tensors(windows_dataset):\n",
        "    \"\"\"Braindecode WindowsDataset → channel‑major tensors.\"\"\"\n",
        "    X, y = [], []\n",
        "    for win in windows_dataset:\n",
        "        X.append(win[0])  # [C, T]\n",
        "        y.append(win[1])\n",
        "    X = np.stack(X).astype(np.float32)  # [N, C, T]\n",
        "    y = np.asarray(y, dtype=np.int64)\n",
        "    return torch.from_numpy(X), torch.from_numpy(y)\n",
        "\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 2) Pre-processing configuration\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "low_cut_hz = 0.01        # band-pass lower edge (Hz)\n",
        "high_cut_hz = 38.0       # band-pass upper edge (Hz)\n",
        "factor_new = 1e-3        # exponential moving standardisation: weight of new samples\n",
        "init_block_size = 1000   # exponential moving standardisation: initial block (samples)\n",
        "scale_factor = 1e6       # amplitude rescaling (V -> uV)\n",
        "trial_offset = -0.5      # seconds; negative = window starts before the event\n",
        "\n",
        "preprocessors = [\n",
        "    Preprocessor(\"pick_types\", eeg=True, meg=False, stim=False),    # EEG channels only\n",
        "    Preprocessor(lambda x: x * scale_factor),                        # V -> uV\n",
        "    Preprocessor(\"filter\", l_freq=low_cut_hz, h_freq=high_cut_hz),  # band-pass\n",
        "    Preprocessor(exponential_moving_standardize,\n",
        "                 factor_new=factor_new, init_block_size=init_block_size),\n",
        "]\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 3) Normalisation helper (z‑score per trial, per channel)\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "def zscore_batch(x: torch.Tensor) -> torch.Tensor:\n",
        "    \"\"\"z‑score each trial individually across the *time* dimension.\"\"\"\n",
        "    # x: [B, C, T]\n",
        "    mean = x.mean(dim=1, keepdim=True)\n",
        "    std = x.std(dim=1, keepdim=True) + 1e-8\n",
        "    return (x - mean) / std\n",
        "\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 4) Vectorised GIN & GAT classifiers (no Python loops over T)\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "# ---------------------------------------------------------------------\n",
        "# GIN - channels are the nodes of a complete graph; aggregation runs over\n",
        "# channels independently at every time-step (not over time).\n",
        "# ---------------------------------------------------------------------\n",
        "class GINClassifier(nn.Module):\n",
        "    \"\"\"Single GIN-style sum-aggregation layer over EEG channels plus an MLP head.\n",
        "\n",
        "    Args:\n",
        "        num_channels: C, number of EEG channels (nodes per time-step).\n",
        "        input_window: T, samples per trial; the head consumes T * C features.\n",
        "        num_classes: number of output logits.\n",
        "        hidden_dim: width of the head's hidden layer.\n",
        "        dropout_p: dropout probability before and inside the head.\n",
        "\n",
        "    forward(x) takes x of shape [B, C, T] and returns logits [B, num_classes].\n",
        "    \"\"\"\n",
        "    def __init__(self, num_channels, input_window, num_classes=4,\n",
        "                 hidden_dim=128, dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        # learnable GIN epsilon weighting each node's own value\n",
        "        self.eps = nn.Parameter(torch.zeros(1))\n",
        "        # NOTE(review): mlp_node maps the whole C-vector of one time-step\n",
        "        # (it mixes channels), unlike the per-node MLP of standard GIN.\n",
        "        self.mlp_node = nn.Sequential(\n",
        "            nn.Linear(num_channels, num_channels),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(num_channels, num_channels),\n",
        "        )\n",
        "        self.act = nn.ELU()\n",
        "        self.drop = nn.Dropout(dropout_p)\n",
        "        self.head = nn.Sequential(\n",
        "            nn.Linear(input_window * num_channels, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(hidden_dim, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):  # x: [B, C, T]\n",
        "        # zscore_batch is intentionally not applied here; inputs are assumed to be\n",
        "        # standardised already by the preprocessing pipeline - confirm.\n",
        "        x = x.permute(0, 2, 1)                 # [B, T, C]\n",
        "        tot = x.sum(dim=2, keepdim=True)       # [B, T, 1] sum over all channels\n",
        "        # (1 + eps) * own value + sum of every *other* channel at this time-step\n",
        "        h_prime = (1 + self.eps) * x + (tot - x)\n",
        "        # apply mlp_node to every time-step: reshape to 2-D, run, reshape back\n",
        "        B, T, C = h_prime.shape\n",
        "        h_flat = h_prime.reshape(B * T, C)\n",
        "        out_flat = self.mlp_node(h_flat)       # [B*T, C]\n",
        "        z = out_flat.view(B, T, C)\n",
        "        z = self.act(z).reshape(B, -1)         # [B, T*C]\n",
        "        z = self.drop(z)\n",
        "        return self.head(z)                    # logits [B, num_classes]\n",
        "\n",
        "\n",
        "class GATClassifier(nn.Module):\n",
        "    \"\"\"Self-attention over EEG channels at every time-step, then an MLP head.\n",
        "\n",
        "    NOTE(review): despite the name this uses nn.MultiheadAttention (one head,\n",
        "    1-D embeddings, no attention mask, so every channel attends to all others)\n",
        "    rather than GAT's additive attention over a given adjacency.\n",
        "\n",
        "    Args:\n",
        "        num_channels: C, number of EEG channels (tokens per time-step).\n",
        "        input_window: T, samples per trial; the head consumes T * C features.\n",
        "        num_classes: number of output logits.\n",
        "        hidden_dim: width of the head's hidden layer.\n",
        "        dropout_p: dropout probability before and inside the head.\n",
        "\n",
        "    forward(x) takes x of shape [B, C, T] and returns logits [B, num_classes].\n",
        "    \"\"\"\n",
        "    def __init__(self, num_channels, input_window, num_classes=4,\n",
        "                 hidden_dim=128, dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        # each channel is a token with a 1-dimensional embedding\n",
        "        self.attn = nn.MultiheadAttention(embed_dim=1, num_heads=1,\n",
        "                                          batch_first=True)\n",
        "        self.act = nn.ELU()\n",
        "        self.drop = nn.Dropout(dropout_p)\n",
        "        self.head = nn.Sequential(\n",
        "            nn.Linear(input_window * num_channels, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(hidden_dim, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):  # x: [B, C, T]\n",
        "        # zscore_batch is intentionally not applied here; inputs are assumed to be\n",
        "        # standardised already by the preprocessing pipeline - confirm.\n",
        "        x = x.permute(0, 2, 1)  # [B, T, C]\n",
        "        B, T, C = x.shape\n",
        "        # treat each time-step as separate batch element to vectorise attention\n",
        "        x_flat = x.reshape(B * T, C, 1)        # [B*T, C, 1]: C tokens of dim 1\n",
        "        h_attn, _ = self.attn(x_flat, x_flat, x_flat)  # same shape\n",
        "        h_res = h_attn + x_flat                # residual connection\n",
        "        z = h_res.squeeze(-1).reshape(B, T, C)\n",
        "        z = self.act(z).reshape(B, -1)         # [B, T*C]\n",
        "        z = self.drop(z)\n",
        "        return self.head(z)                    # logits [B, num_classes]\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 5) AMP-enabled training/evaluation helpers\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "def _predict(model, loader, device):\n",
        "    \"\"\"Return (predictions, targets) as Python lists over every batch of `loader`.\"\"\"\n",
        "    preds, targs = [], []\n",
        "    model.eval()\n",
        "    with torch.no_grad():\n",
        "        for X, y in loader:\n",
        "            logits = model(X.to(device, non_blocking=True))\n",
        "            preds.extend(logits.argmax(dim=1).cpu().tolist())\n",
        "            targs.extend(y.cpu().tolist())\n",
        "    return preds, targs\n",
        "\n",
        "\n",
        "def _accuracy_pct(preds, targs):\n",
        "    \"\"\"Percentage of matching entries; NaN for an empty evaluation set.\"\"\"\n",
        "    if not targs:\n",
        "        return float(\"nan\")\n",
        "    return 100.0 * float(np.mean(np.asarray(preds) == np.asarray(targs)))\n",
        "\n",
        "\n",
        "def train_and_evaluate(model, train_loader, test_loader,\n",
        "                       optimizer, criterion, epochs, device, name=\"model\"):\n",
        "    \"\"\"Train `model` (mixed precision on CUDA) and report accuracy / Cohen's kappa.\n",
        "\n",
        "    The per-epoch accuracy printed on `test_loader` is for monitoring only; no\n",
        "    checkpoint is selected with it, so the returned metrics come from the\n",
        "    final epoch's weights.\n",
        "\n",
        "    Args:\n",
        "        device: torch.device or a device string such as \"cuda\" / \"cpu\".\n",
        "        name: label used in the progress printout.\n",
        "    Returns:\n",
        "        (accuracy in percent, Cohen's kappa) on `test_loader`.\n",
        "    Raises:\n",
        "        ValueError: if `test_loader` yields no samples.\n",
        "\n",
        "    NOTE: a later cell in this notebook defines another `train_and_evaluate`\n",
        "    with a different signature; re-run this cell before calling this version.\n",
        "    \"\"\"\n",
        "    device = torch.device(device)\n",
        "    use_amp = device.type == \"cuda\"\n",
        "    # torch.cuda.amp.GradScaler / torch.cuda.amp.autocast are deprecated;\n",
        "    # torch.amp.GradScaler and torch.autocast are the supported equivalents.\n",
        "    scaler = torch.amp.GradScaler(\"cuda\", enabled=use_amp)\n",
        "    model.to(device)\n",
        "\n",
        "    for ep in range(1, epochs + 1):\n",
        "        model.train()\n",
        "        for X, y in train_loader:\n",
        "            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)\n",
        "            optimizer.zero_grad()\n",
        "            with torch.autocast(device_type=device.type, enabled=use_amp):\n",
        "                loss = criterion(model(X), y)\n",
        "            scaler.scale(loss).backward()\n",
        "            scaler.step(optimizer)\n",
        "            scaler.update()\n",
        "\n",
        "        # quick test accuracy each epoch (monitoring only)\n",
        "        acc = _accuracy_pct(*_predict(model, test_loader, device))\n",
        "        print(f\"[{name}] Ep {ep}/{epochs}  acc {acc:.2f}%\")\n",
        "\n",
        "    # final metrics\n",
        "    preds, targs = _predict(model, test_loader, device)\n",
        "    if not targs:\n",
        "        raise ValueError(f\"[{name}] test_loader yielded no samples\")\n",
        "    acc = _accuracy_pct(preds, targs)\n",
        "    kappa = cohen_kappa_score(targs, preds)\n",
        "    print(f\"[{name}] Final  acc {acc:.2f}%   κ {kappa:.4f}\")\n",
        "    return acc, kappa\n",
        "\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 6) LOSO cross‑validation with accelerated DataLoaders\n",
        "# -----------------------------------------------------------------------------\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "    print(\"Device:\", device)\n",
        "\n",
        "    num_subjects = 9\n",
        "    epochs = 50\n",
        "    batch_size = 256        # ← bigger batch; fits on modern GPUs\n",
        "    num_workers = 4         # dataloader threads\n",
        "\n",
        "    models = {\n",
        "        \"GIN\": GINClassifier,\n",
        "        \"GAT\": GATClassifier,\n",
        "    }\n",
        "    results = {k: [] for k in models}\n",
        "\n",
        "    for test_id in range(1, num_subjects + 1):\n",
        "        set_seed(42 + test_id)\n",
        "        print(f\"\\n=== Subject {test_id} held out ===\")\n",
        "\n",
        "        train_sets, test_set = [], None\n",
        "        for sid in range(1, num_subjects + 1):\n",
        "            ds = MOABBDataset(\"BNCI2014_001\", [sid])\n",
        "            preprocess(ds, preprocessors, n_jobs=-1)\n",
        "            sfreq = ds.datasets[0].raw.info[\"sfreq\"]\n",
        "            off = int(trial_offset * sfreq)\n",
        "            wins = create_windows_from_events(\n",
        "                ds, trial_start_offset_samples=off,\n",
        "                trial_stop_offset_samples=0, preload=True)\n",
        "            parts = wins.split(\"session\")\n",
        "            Xt, Yt = convert_to_tensors(parts[\"0train\"])\n",
        "            Xv, Yv = convert_to_tensors(parts[\"1test\"])\n",
        "            if sid == test_id:\n",
        "                test_set = TensorDataset(torch.cat((Xt, Xv)), torch.cat((Yt, Yv)))\n",
        "            else:\n",
        "                train_sets.extend([TensorDataset(Xt, Yt), TensorDataset(Xv, Yv)])\n",
        "\n",
        "        tr_loader = DataLoader(ConcatDataset(train_sets), batch_size=batch_size,\n",
        "                               shuffle=True, drop_last=True, pin_memory=True,\n",
        "                               num_workers=num_workers, persistent_workers=True)\n",
        "        te_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False,\n",
        "                               pin_memory=True, num_workers=num_workers)\n",
        "\n",
        "        # derive channel (C) and time-points (T) dimensions correctly\n",
        "        C = Xt.shape[1]   # channels\n",
        "        T = Xt.shape[2]   # timepoints  # timepoints\n",
        "\n",
        "        for name, ctor in models.items():\n",
        "            print(f\"\\n-- {name} --\")\n",
        "            model = ctor(num_channels=C, input_window=T, num_classes=4,\n",
        "                          hidden_dim=128, dropout_p=0.5)\n",
        "            opt = optim.AdamW(model.parameters(), lr=3e-4)\n",
        "            acc, kappa = train_and_evaluate(model, tr_loader, te_loader,\n",
        "                                            opt, nn.CrossEntropyLoss(), epochs,\n",
        "                                            device, name)\n",
        "            results[name].append((test_id, acc, kappa))\n",
        "\n",
        "    print(\"\\n================ SUMMARY ================\")\n",
        "    for name in models:\n",
        "        df = pd.DataFrame(results[name], columns=[\"Subject\", \"Acc(%)\", \"Kappa\"])\n",
        "        df.loc[\"Avg\"] = [\"Avg\", df[\"Acc(%)\"].mean(), df[\"Kappa\"].mean()]\n",
        "        df.loc[\"Std\"] = [\"Std\", df[\"Acc(%)\"].std(), df[\"Kappa\"].std()]\n",
        "        print(f\"\\n{name}:\")\n",
        "\n",
        "        print(df)\n"
      ],
      "metadata": {
        "id": "ozJOVf_cpa50"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import torch\n",
        "import numpy as np\n",
        "from numpy import multiply\n",
        "from braindecode.datasets import MOABBDataset\n",
        "from braindecode.preprocessing import (Preprocessor, exponential_moving_standardize, preprocess)\n",
        "from braindecode.preprocessing import create_windows_from_events\n",
        "from torch.utils.data import TensorDataset, DataLoader, random_split\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from sklearn.metrics import cohen_kappa_score\n",
        "import pandas as pd\n",
        "import torch.nn.functional as F\n",
        "# EEGNet, subject-independent (leave-one-subject-out) on BCI IV-2a (MOABB BNCI2014_001)\n",
        "\n",
        "# --- preprocessing / windowing parameters ---\n",
        "low_cut_hz, high_cut_hz = 0.01, 38.0  # band-pass edges (Hz)\n",
        "factor_new = 1e-3                     # exponential moving standardisation: weight of new samples\n",
        "init_block_size = 1000                # exponential moving standardisation: initial block (samples)\n",
        "factor = 1e6                          # V -> uV\n",
        "trial_start_offset_seconds = -0.5     # negative = window starts before the event\n",
        "\n",
        "preprocessors = [\n",
        "    Preprocessor('pick_types', eeg=True, meg=False, stim=False),    # EEG sensors only\n",
        "    Preprocessor(lambda data: multiply(data, factor)),              # V -> uV\n",
        "    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),  # band-pass\n",
        "    Preprocessor(exponential_moving_standardize,\n",
        "                 factor_new=factor_new, init_block_size=init_block_size),\n",
        "]\n",
        "\n",
        "# Run on the GPU when one is available\n",
        "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
        "print(f\"Using device: {device}\")\n",
        "\n",
        "# Define the EEGNet model\n",
        "class EEGNet(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(EEGNet, self).__init__()\n",
        "        self.T = 1125  # Number of time steps per sample\n",
        "\n",
        "        # Layer 1\n",
        "        self.conv1 = nn.Conv2d(1, 16, (1, 64), padding=(0, 0))  # Input shape: [batch, 1, channels, time_steps]\n",
        "        self.batchnorm1 = nn.BatchNorm2d(16, affine=False)  # Batch normalization after convolution\n",
        "\n",
        "        # Layer 2\n",
        "        self.padding1 = nn.ZeroPad2d((0, 0, 16, 17))  # Zero-padding: (left, right, top, bottom)\n",
        "        self.conv2 = nn.Conv2d(16, 4, (22, 32))  # Updated in_channels to match Layer 1 output\n",
        "        self.batchnorm2 = nn.BatchNorm2d(4, affine=False)  # Batch normalization\n",
        "        self.pooling2 = nn.MaxPool2d((2, 4))  # Max pooling with kernel size (2, 4)\n",
        "\n",
        "        # Layer 3\n",
        "        self.padding2 = nn.ZeroPad2d((2, 1, 4, 3))  # Zero-padding: (left, right, top, bottom)\n",
        "        self.conv3 = nn.Conv2d(4, 4, (8, 4))  # Conv layer with filter size (8, 4)\n",
        "        self.batchnorm3 = nn.BatchNorm2d(4, affine=False)  # Batch normalization\n",
        "        self.pooling3 = nn.MaxPool2d((2, 4))  # Max pooling with kernel size (2, 4)\n",
        "\n",
        "        # Fully Connected Layer\n",
        "        # Adjust the input size after seeing the actual shape before the fully connected layer\n",
        "        self.fc1 = nn.Linear(4 * 8 * 64, 4)  # Adjusted output size for 4-class classification\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Layer 1\n",
        "        x = F.elu(self.conv1(x))  # Output shape: [batch_size, 16, 22, time_steps - 63]\n",
        "        x = self.batchnorm1(x)\n",
        "        x = F.dropout(x, 0.25)\n",
        "\n",
        "        # Layer 2\n",
        "        x = self.padding1(x)  # Padding applied\n",
        "        x = F.elu(self.conv2(x))  # Output shape: [batch_size, 4, 1, time_steps - 31 + padding]\n",
        "        x = self.batchnorm2(x)\n",
        "        x = F.dropout(x, 0.25)\n",
        "        x = self.pooling2(x)  # Output shape after pooling\n",
        "\n",
        "        # Layer 3\n",
        "        x = self.padding2(x)  # Padding applied\n",
        "        x = F.elu(self.conv3(x))  # Convolution output\n",
        "        x = self.batchnorm3(x)\n",
        "        x = F.dropout(x, 0.25)\n",
        "        x = self.pooling3(x)  # Pooling output\n",
        "\n",
        "        # Flatten before the fully connected layer\n",
        "\n",
        "        x = x.view(x.size(0), -1)  # Flatten to match the input size for the fully connected layer\n",
        "        x = self.fc1(x)  # Output through FC layer\n",
        "        x = F.softmax(x, dim=1)  # Apply softmax for multi-class classification\n",
        "        return x\n",
        "\n",
        "# Training and evaluation function with validation-based early stopping\n",
        "def train_and_evaluate(model, train_loader, valid_loader, test_loader, optimizer, criterion, num_epochs, device, patience=5):\n",
        "    \"\"\"Train `model`, early-stop on validation loss, then score it on `test_loader`.\n",
        "\n",
        "    Loader batches are [batch, channels, time_steps]; a singleton dim is\n",
        "    inserted at 1 to give [batch, 1, channels, time_steps].\n",
        "    Training stops once the validation loss has not improved for `patience`\n",
        "    consecutive epochs, and the best-validation weights are restored before\n",
        "    testing. `patience=None` or an empty `valid_loader` disables early\n",
        "    stopping, in which case the final-epoch weights are tested.\n",
        "\n",
        "    Returns:\n",
        "        (test accuracy in percent, Cohen's kappa).\n",
        "    Raises:\n",
        "        ValueError: if `test_loader` yields no samples.\n",
        "\n",
        "    NOTE: this redefines `train_and_evaluate` from the GIN/GAT cell with a\n",
        "    different signature; whichever cell ran last wins.\n",
        "    \"\"\"\n",
        "    import time\n",
        "\n",
        "    model.to(device)\n",
        "    use_early_stopping = patience is not None and len(valid_loader) > 0\n",
        "    best_valid_loss = float('inf')\n",
        "    patience_counter = 0\n",
        "    best_model_weights = None\n",
        "\n",
        "    for epoch in range(num_epochs):\n",
        "        start_time = time.time()\n",
        "\n",
        "        # Training phase\n",
        "        model.train()\n",
        "        train_loss = 0.0\n",
        "        for data, target in train_loader:\n",
        "            data = data.to(device).unsqueeze(1)  # [batch, 1, channels, time_steps]\n",
        "            target = target.to(device).long()  # CrossEntropyLoss expects int64 class indices\n",
        "            optimizer.zero_grad()\n",
        "            loss = criterion(model(data), target)\n",
        "            loss.backward()\n",
        "            optimizer.step()\n",
        "            train_loss += loss.item()\n",
        "        train_loss /= max(len(train_loader), 1)\n",
        "\n",
        "        # Validation phase\n",
        "        model.eval()\n",
        "        valid_loss = 0.0\n",
        "        with torch.no_grad():\n",
        "            for data, target in valid_loader:\n",
        "                data = data.to(device).unsqueeze(1)\n",
        "                target = target.to(device).long()\n",
        "                valid_loss += criterion(model(data), target).item()\n",
        "        valid_loss = valid_loss / len(valid_loader) if len(valid_loader) else float('nan')\n",
        "\n",
        "        epoch_time = time.time() - start_time\n",
        "        print(f\"Epoch {epoch + 1}/{num_epochs}, Training Loss: {train_loss:.4f}, Validation Loss: {valid_loss:.4f}, Time: {epoch_time:.2f} seconds\")\n",
        "\n",
        "        if not use_early_stopping:\n",
        "            continue\n",
        "        if valid_loss < best_valid_loss:\n",
        "            best_valid_loss = valid_loss\n",
        "            patience_counter = 0\n",
        "            # clone so later optimizer steps do not overwrite the snapshot\n",
        "            best_model_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}\n",
        "        else:\n",
        "            patience_counter += 1\n",
        "            if patience_counter >= patience:\n",
        "                print(f\"Early stopping at epoch {epoch + 1} (best validation loss {best_valid_loss:.4f})\")\n",
        "                break\n",
        "\n",
        "    if best_model_weights is not None:\n",
        "        model.load_state_dict(best_model_weights)\n",
        "\n",
        "    # Test phase\n",
        "    model.eval()\n",
        "    correct = 0\n",
        "    total = 0\n",
        "    all_predictions = []\n",
        "    all_targets = []\n",
        "    with torch.no_grad():\n",
        "        for data, target in test_loader:\n",
        "            data = data.to(device).unsqueeze(1)  # [batch, 1, channels, time_steps]\n",
        "            target = target.to(device)\n",
        "            predictions = model(data).argmax(dim=1)\n",
        "            correct += (predictions == target).sum().item()\n",
        "            total += target.size(0)\n",
        "            all_predictions.extend(predictions.cpu().numpy())\n",
        "            all_targets.extend(target.cpu().numpy())\n",
        "\n",
        "    if total == 0:\n",
        "        raise ValueError(\"test_loader yielded no samples\")\n",
        "    test_accuracy = correct / total * 100\n",
        "    kappa = cohen_kappa_score(all_targets, all_predictions)\n",
        "\n",
        "    print(f\"Test Accuracy: {test_accuracy:.2f}%\")\n",
        "    return test_accuracy, kappa\n",
        "\n",
        "# Leave-one-subject-out: each subject in turn is the test set, the rest train\n",
        "num_subjects = 9\n",
        "num_epochs = 200\n",
        "accuracies = []\n",
        "kappa_scores = []\n",
        "\n",
        "\n",
        "def _load_subject_windows(subject_id):\n",
        "    \"\"\"Load, preprocess and window one BNCI2014_001 subject.\n",
        "\n",
        "    Returns (x_train, y_train, x_eval, y_eval) tensors for the '0train' and\n",
        "    '1test' sessions.\n",
        "    NOTE: relies on `convert_to_tensors`, which is defined in the GIN/GAT cell\n",
        "    above, not in this cell; define it first on a fresh kernel.\n",
        "    \"\"\"\n",
        "    dataset = MOABBDataset(dataset_name=\"BNCI2014_001\", subject_ids=[subject_id])\n",
        "    preprocess(dataset, preprocessors, n_jobs=-1)\n",
        "\n",
        "    # Sampling frequency must be the same in every recording of the subject\n",
        "    sfreq = dataset.datasets[0].raw.info['sfreq']\n",
        "    assert all(ds.raw.info['sfreq'] == sfreq for ds in dataset.datasets)\n",
        "\n",
        "    windows_dataset = create_windows_from_events(\n",
        "        dataset,\n",
        "        trial_start_offset_samples=int(trial_start_offset_seconds * sfreq),\n",
        "        trial_stop_offset_samples=0,\n",
        "        preload=True,\n",
        "    )\n",
        "    splitted = windows_dataset.split('session')\n",
        "    x_train, y_train = convert_to_tensors(splitted['0train'])  # training session\n",
        "    x_eval, y_eval = convert_to_tensors(splitted['1test'])     # evaluation session\n",
        "    return x_train, y_train, x_eval, y_eval\n",
        "\n",
        "\n",
        "# Loading/preprocessing is fold-independent: do it once per subject\n",
        "# (9 loads) instead of once per (fold, subject) pair (81 loads).\n",
        "subject_data = {sid: _load_subject_windows(sid) for sid in range(1, num_subjects + 1)}\n",
        "\n",
        "for test_subject_id in range(1, num_subjects + 1):\n",
        "    print(f\"Testing subject {test_subject_id}, training on all others\")\n",
        "    # Seed per fold so the split, shuffling and weight init are reproducible.\n",
        "    fold_seed = 42 + test_subject_id\n",
        "    torch.manual_seed(fold_seed)\n",
        "\n",
        "    all_train_sets = []\n",
        "    test_set = None\n",
        "    for subject_id, (x_train_tensor, y_train_tensor, x_valid_tensor, y_valid_tensor) in subject_data.items():\n",
        "        if subject_id == test_subject_id:\n",
        "            # Held-out subject: both sessions form the test set\n",
        "            test_set = TensorDataset(torch.cat((x_train_tensor, x_valid_tensor), dim=0),\n",
        "                                     torch.cat((y_train_tensor, y_valid_tensor), dim=0))\n",
        "        else:\n",
        "            # Training subjects contribute only their '0train' session.\n",
        "            # NOTE(review): the GIN/GAT cell also trains on '1test'; confirm which\n",
        "            # protocol is intended before comparing the numbers.\n",
        "            all_train_sets.append(TensorDataset(x_train_tensor, y_train_tensor))\n",
        "\n",
        "    # Combine all training datasets\n",
        "    combined_train_dataset = torch.utils.data.ConcatDataset(all_train_sets)\n",
        "\n",
        "    # Hold out 10% of the pooled training trials as the early-stopping validation set\n",
        "    train_size = int(0.9 * len(combined_train_dataset))\n",
        "    valid_size = len(combined_train_dataset) - train_size\n",
        "    train_dataset, valid_dataset = random_split(\n",
        "        combined_train_dataset, [train_size, valid_size],\n",
        "        generator=torch.Generator().manual_seed(fold_seed))\n",
        "\n",
        "    # Create DataLoaders\n",
        "    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
        "    valid_loader = DataLoader(valid_dataset, batch_size=128)\n",
        "    test_loader = DataLoader(test_set, batch_size=128)\n",
        "\n",
        "    # Initialize model, optimizer, and loss function\n",
        "    model = EEGNet()\n",
        "    criterion = nn.CrossEntropyLoss()  # expects raw logits\n",
        "    optimizer = optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)\n",
        "\n",
        "    # Train and evaluate the model\n",
        "    accuracy, kappa = train_and_evaluate(model, train_loader, valid_loader, test_loader, optimizer, criterion, num_epochs, device)\n",
        "    accuracies.append(accuracy)\n",
        "    kappa_scores.append(kappa)\n",
        "\n",
        "# Summary statistics across subjects. ddof=1 (sample std) matches the pandas\n",
        "# .std() used for the other models' summaries in this notebook.\n",
        "average_accuracy = float(np.mean(accuracies))\n",
        "std_accuracy = float(np.std(accuracies, ddof=1))\n",
        "average_kappa = float(np.mean(kappa_scores))\n",
        "std_kappa = float(np.std(kappa_scores, ddof=1))\n",
        "\n",
        "# Create a summary table\n",
        "summary_table = pd.DataFrame({\n",
        "    'Subject': list(range(1, num_subjects + 1)),\n",
        "    'Accuracy': accuracies,\n",
        "    'Kappa Score': kappa_scores\n",
        "})\n",
        "summary_table.loc['Average'] = ['Average', average_accuracy, average_kappa]\n",
        "summary_table.loc['Std Dev'] = ['Std Dev', std_accuracy, std_kappa]\n",
        "\n",
        "print(summary_table)"
      ],
      "metadata": {
        "id": "ib2FJVmHp4fK"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import random\n",
        "import time\n",
        "\n",
        "from braindecode.datasets import MOABBDataset\n",
        "from braindecode.preprocessing import (\n",
        "    Preprocessor,\n",
        "    exponential_moving_standardize,\n",
        "    preprocess,\n",
        "    create_windows_from_events,\n",
        ")\n",
        "from mne.decoding import CSP\n",
        "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
        "from sklearn.metrics import accuracy_score\n",
        "\n",
        "\n",
        "#CSP+LDA\n",
        "# -----------------------------------------------------------------------------\n",
        "# 0) Reproducibility\n",
        "# -----------------------------------------------------------------------------\n",
        "def set_seed(seed: int):\n",
        "    random.seed(seed)\n",
        "    np.random.seed(seed)\n",
        "\n",
        "set_seed(42)\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 1) Pre-processing configuration\n",
        "# -----------------------------------------------------------------------------\n",
        "low_cut_hz = 0.01        # band-pass lower edge (Hz)\n",
        "high_cut_hz = 38.0       # band-pass upper edge (Hz)\n",
        "factor_new = 1e-3        # exponential moving standardisation: weight of new samples\n",
        "init_block_size = 1000   # exponential moving standardisation: initial block (samples)\n",
        "scale_factor = 1e6       # amplitude rescaling (V -> uV)\n",
        "trial_offset = -0.5      # seconds; converted to a sample offset for windowing\n",
        "\n",
        "preprocessors = [\n",
        "    Preprocessor(\"pick_types\", eeg=True, meg=False, stim=False),    # EEG channels only\n",
        "    Preprocessor(lambda x: x * scale_factor),                        # V -> uV\n",
        "    Preprocessor(\"filter\", l_freq=low_cut_hz, h_freq=high_cut_hz),  # band-pass\n",
        "    Preprocessor(exponential_moving_standardize,\n",
        "                 factor_new=factor_new, init_block_size=init_block_size),\n",
        "]\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 2) Helpers: window → numpy, z-score\n",
        "# -----------------------------------------------------------------------------\n",
        "def windows_to_numpy(windows_ds):\n",
        "    \"\"\"\n",
        "    Convert a Braindecode WindowsDataset into numpy arrays:\n",
        "      X shape [n_trials, n_channels, n_times],\n",
        "      y shape [n_trials].\n",
        "    Matches your original iteration to avoid unpacking errors.\n",
        "    \"\"\"\n",
        "    X_list, y_list = [], []\n",
        "    for win in windows_ds:\n",
        "        X_list.append(win[0])  # [C, T]\n",
        "        y_list.append(win[1])\n",
        "    X = np.stack(X_list)\n",
        "    y = np.array(y_list, dtype=np.int64)\n",
        "    return X, y\n",
        "\n",
        "\n",
        "# -----------------------------------------------------------------------------\n",
        "# 3) LOSO + timed CSP + LDA\n",
        "# -----------------------------------------------------------------------------\n",
        "def _load_subject_sessions(subj, filter_labels=None):\n",
        "    \"\"\"\n",
        "    Load, preprocess and window one BNCI2014_001 subject.\n",
        "\n",
        "    Returns ((X_train, y_train), (X_test, y_test)) built from the \"0train\" and\n",
        "    \"1test\" sessions. If `filter_labels` is given, only trials with those labels\n",
        "    are kept and each label is remapped to its position in `filter_labels`.\n",
        "    \"\"\"\n",
        "    ds = MOABBDataset(\"BNCI2014_001\", [subj])\n",
        "    preprocess(ds, preprocessors, n_jobs=-1)\n",
        "    sfreq = ds.datasets[0].raw.info[\"sfreq\"]\n",
        "    wins = create_windows_from_events(\n",
        "        ds,\n",
        "        trial_start_offset_samples=int(trial_offset * sfreq),\n",
        "        trial_stop_offset_samples=0,\n",
        "        preload=True,\n",
        "    )\n",
        "    parts = wins.split(\"session\")\n",
        "\n",
        "    sessions = []\n",
        "    for session_name in (\"0train\", \"1test\"):\n",
        "        X, y = windows_to_numpy(parts[session_name])\n",
        "        if filter_labels is not None:\n",
        "            keep = list(filter_labels)\n",
        "            mask = np.isin(y, keep)\n",
        "            X = X[mask]\n",
        "            # explicit dtype so an empty selection is still an integer array\n",
        "            y = np.array([keep.index(lbl) for lbl in y[mask]], dtype=np.int64)\n",
        "        sessions.append((X, y))\n",
        "    return sessions[0], sessions[1]\n",
        "\n",
        "\n",
        "def run_loso_csp_lda_timed(filter_labels=None, n_subjects=9):\n",
        "    \"\"\"\n",
        "    Leave-One-Subject-Out CV with CSP→LDA on BNCI2014_001, timing each step.\n",
        "\n",
        "    For every held-out subject, CSP and LDA are fit on the \"0train\" session of\n",
        "    all other subjects and scored on the \"1test\" session of the held-out one.\n",
        "\n",
        "    Parameters\n",
        "    ----------\n",
        "    filter_labels : sequence of int or None\n",
        "        If given (e.g. (1, 2)), keep only these labels and remap them to\n",
        "        0..len(filter_labels)-1 in the given order.\n",
        "    n_subjects : int\n",
        "        Subjects 1..n_subjects are used (BNCI2014_001 has 9).\n",
        "\n",
        "    Returns\n",
        "    -------\n",
        "    accs : list of float\n",
        "        Per-fold test accuracy in percent.\n",
        "    timings : dict[str, list[float]]\n",
        "        Per-fold wall-clock seconds for each CSP/LDA step.\n",
        "    \"\"\"\n",
        "    # Load and preprocess every subject ONCE; the folds below only re-combine\n",
        "    # these cached arrays instead of reloading all subjects for every fold\n",
        "    # (n_subjects**2 loads). Trade-off: both sessions of all subjects stay in memory.\n",
        "    sessions = {\n",
        "        subj: _load_subject_sessions(subj, filter_labels)\n",
        "        for subj in range(1, n_subjects + 1)\n",
        "    }\n",
        "\n",
        "    accs = []\n",
        "    timings = {\n",
        "        \"csp_fit_transform\": [],\n",
        "        \"csp_transform\":       [],\n",
        "        \"lda_fit\":             [],\n",
        "        \"lda_predict\":         [],\n",
        "    }\n",
        "\n",
        "    for test_subj in range(1, n_subjects + 1):\n",
        "        # Training pool: \"0train\" session of every other subject, in subject order\n",
        "        train_sets = [sessions[s][0] for s in sessions if s != test_subj]\n",
        "        X_train = np.concatenate([X for X, _ in train_sets], axis=0)\n",
        "        y_train = np.concatenate([y for _, y in train_sets], axis=0)\n",
        "        # Evaluation: \"1test\" session of the held-out subject only\n",
        "        X_test, y_test = sessions[test_subj][1]\n",
        "\n",
        "        # Initialize CSP & LDA\n",
        "        csp = CSP(n_components=min(8, X_train.shape[1]), reg=None, log=True)\n",
        "        lda = LinearDiscriminantAnalysis()\n",
        "\n",
        "        # Time CSP fit_transform on training\n",
        "        t0 = time.perf_counter()\n",
        "        X_train_feat = csp.fit_transform(X_train, y_train)\n",
        "        timings[\"csp_fit_transform\"].append(time.perf_counter() - t0)\n",
        "\n",
        "        # Time CSP transform on test\n",
        "        t0 = time.perf_counter()\n",
        "        X_test_feat = csp.transform(X_test)\n",
        "        timings[\"csp_transform\"].append(time.perf_counter() - t0)\n",
        "\n",
        "        # Time LDA fit\n",
        "        t0 = time.perf_counter()\n",
        "        lda.fit(X_train_feat, y_train)\n",
        "        timings[\"lda_fit\"].append(time.perf_counter() - t0)\n",
        "\n",
        "        # Time LDA predict\n",
        "        t0 = time.perf_counter()\n",
        "        y_pred = lda.predict(X_test_feat)\n",
        "        timings[\"lda_predict\"].append(time.perf_counter() - t0)\n",
        "\n",
        "        # Compute accuracy (percent)\n",
        "        acc = accuracy_score(y_test, y_pred) * 100\n",
        "        print(f\"[Held-out subj {test_subj}] Acc: {acc:.2f}%\")\n",
        "        accs.append(acc)\n",
        "\n",
        "    # Summary\n",
        "    print(\"\\n=== Accuracy ===\")\n",
        "    print(f\"Mean Acc = {np.mean(accs):.2f}%  Std = {np.std(accs):.2f}%\")\n",
        "\n",
        "    print(\"\\n=== Timings (per fold) ===\")\n",
        "    for step, times in timings.items():\n",
        "        arr = np.array(times)\n",
        "        print(f\"{step:20s}: {arr.tolist()}  → avg {arr.mean():.4f}s\")\n",
        "\n",
        "    return accs, timings\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    # (banner, filter_labels): full 4-class problem first, then the labels-1-and-2 subset\n",
        "    for banner, labels in (\n",
        "        (\"=== 4-class CSP + LDA ===\", None),\n",
        "        (\"\\n=== 2-class CSP + LDA (Right vs Left) ===\", (1, 2)),\n",
        "    ):\n",
        "        print(banner)\n",
        "        run_loso_csp_lda_timed(filter_labels=labels)\n"
      ],
      "metadata": {
        "id": "QnDZ7-GkqQr_"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "\n",
        "#PHYSIONET LOAD\n",
        "#DOWNLOAD PHYSIONET DATA AND REPLACE\n",
        "# (i.e. download the PhysioNet EEG Motor Movement/Imagery files and point\n",
        "#  `main_folder_path` below at the folder holding the S001..S109 directories)\n",
        "\n",
        "# 1) Mount Google Drive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# 2) Install necessary packages (if not already installed)\n",
        "!pip install mne pandas\n",
        "\n",
        "import os\n",
        "import mne\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# 3) Define the main path where all participant folders are stored\n",
        "main_folder_path = '/content/drive/My Drive/Physionet MI/files/'\n",
        "\n",
        "# 4) Generate participant IDs (001 to 109) and exclude faulty ones.\n",
        "#    NOTE(review): S088, S089, S092 and S100 are commonly excluded because their\n",
        "#    recordings differ from the others (e.g. sampling rate / run length) - confirm.\n",
        "excluded_participants = {'088', '089', '092', '100'}\n",
        "participants = [f\"{i:03d}\" for i in range(1, 110) if f\"{i:03d}\" not in excluded_participants]\n",
        "\n",
        "print(f\"✅ Loading data for {len(participants)} valid participants.\")\n",
        "\n",
        "# 5) Dictionary to store all data: {\"S001\": {\"X\": (n_trials, C, T), \"y\": (n_trials,)}, ...}\n",
        "data_dict = {}\n",
        "\n",
        "# 6) Loop through each participant\n",
        "for participant in participants:\n",
        "    participant_path = os.path.join(main_folder_path, f\"S{participant}\")\n",
        "\n",
        "    # Ensure participant folder exists\n",
        "    if not os.path.exists(participant_path):\n",
        "        print(f\"⚠️ Missing folder for S{participant}, skipping...\")\n",
        "        continue\n",
        "\n",
        "    # List EDF files in a deterministic (sorted) order\n",
        "    edf_files = sorted(f for f in os.listdir(participant_path) if f.lower().endswith('.edf'))\n",
        "\n",
        "    # Only select the motor imagery runs (R04, R08, R12)\n",
        "    imagery_runs = [f for f in edf_files if any(run in f for run in ('R04', 'R08', 'R12'))]\n",
        "\n",
        "    if not imagery_runs:\n",
        "        print(f\"⚠️ No valid files for S{participant}, skipping...\")\n",
        "        continue\n",
        "\n",
        "    print(f\"\\n📥 Processing participant S{participant} ({len(imagery_runs)} files)\")\n",
        "\n",
        "    # Lists to store participant's trials\n",
        "    X_data = []\n",
        "    y_labels = []\n",
        "\n",
        "    # 7) Process each valid file\n",
        "    for edf_file in imagery_runs:\n",
        "        edf_path = os.path.join(participant_path, edf_file)\n",
        "\n",
        "        # Load EDF data\n",
        "        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)\n",
        "\n",
        "        # Extract events\n",
        "        events, event_id = mne.events_from_annotations(raw, verbose=False)\n",
        "\n",
        "        # Dynamically find T1 and T2 event IDs\n",
        "        t1_event_id = next((code for label, code in event_id.items() if 'T1' in label), None)\n",
        "        t2_event_id = next((code for label, code in event_id.items() if 'T2' in label), None)\n",
        "\n",
        "        if t1_event_id is None or t2_event_id is None:\n",
        "            print(f\"⚠️ No T1/T2 events in {edf_file}, skipping...\")\n",
        "            continue\n",
        "\n",
        "        # Extract only T1/T2 events (row selection on the (n_events, 3) events array)\n",
        "        target_events = events[np.isin(events[:, 2], [t1_event_id, t2_event_id])]\n",
        "\n",
        "        if len(target_events) == 0:\n",
        "            print(f\"⚠️ No T1/T2 trials found in {edf_file}, skipping...\")\n",
        "            continue\n",
        "\n",
        "        print(f\"✔️ Extracting {len(target_events)} T1/T2 events from {edf_file}\")\n",
        "\n",
        "        # Epoch 0 s .. 3.1 s after each cue. MNE includes both endpoints, so at\n",
        "        # 160 Hz each trial has round(3.1 * 160) + 1 = 497 samples.\n",
        "        epochs = mne.Epochs(raw, target_events, event_id={'T1': t1_event_id, 'T2': t2_event_id},\n",
        "                            tmin=0, tmax=3.1, baseline=None, preload=True, verbose=False)\n",
        "\n",
        "        # Convert EEG data to NumPy array\n",
        "        data = epochs.get_data()  # Shape: (num_kept_trials, channels, n_times)\n",
        "        # Label from the epochs MNE actually kept (T1 -> 0, T2 -> 1). Epochs can\n",
        "        # drop events (e.g. a window running past the end of the recording), so\n",
        "        # labelling from `target_events` could misalign X and y.\n",
        "        labels = np.where(epochs.events[:, 2] == t1_event_id, 0, 1)\n",
        "\n",
        "        # Store data\n",
        "        X_data.append(data)\n",
        "        y_labels.append(labels)\n",
        "\n",
        "    # 8) Convert participant data to NumPy arrays\n",
        "    if X_data:\n",
        "        X_data = np.concatenate(X_data, axis=0)  # Shape: (num_trials, channels, n_times)\n",
        "        y_labels = np.concatenate(y_labels, axis=0)  # Shape: (num_trials,)\n",
        "\n",
        "        # Store in dictionary\n",
        "        data_dict[f\"S{participant}\"] = {\"X\": X_data, \"y\": y_labels}\n",
        "\n",
        "        print(f\"✅ Stored {X_data.shape[0]} trials for S{participant}\")\n",
        "\n",
        "# 9) Per-participant summary table\n",
        "df_list = []\n",
        "for participant, data in data_dict.items():\n",
        "    trials = data[\"X\"].shape[0]\n",
        "    df_list.append({\"Participant\": participant, \"Trials\": trials, \"Data Shape\": data[\"X\"].shape})\n",
        "\n",
        "df = pd.DataFrame(df_list)\n",
        "\n",
        "# NOTE: each value of data_dict is itself a dict, so np.savez stores it as a pickled\n",
        "# 0-d object array; read it back with np.load(save_path, allow_pickle=True)[key].item().\n",
        "save_path = \"/content/drive/My Drive/Physionet MI/all_participants_data.npz\"\n",
        "np.savez(save_path, **data_dict)\n",
        "print(f\"\\n📁 Saved all data to {save_path}\")\n",
        "\n"
      ],
      "metadata": {
        "id": "E-ZV9OyBqRBs"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#GAT PHYSIONET\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import random\n",
        "import copy\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(\"▶ Using device:\", device)\n",
        "\n",
        "SEED = 321\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "if device.type == \"cuda\":\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "############################################################################\n",
        "# 1) Dataset helpers  (shape:  (N , C , T)  throughout)\n",
        "############################################################################\n",
        "class EEGDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"\n",
        "    Expects\n",
        "      X : (num_samples , C , T)  numpy or torch.float\n",
        "      y : (num_samples,)         numpy or torch.long\n",
        "    \"\"\"\n",
        "    def __init__(self, X, y):\n",
        "        if isinstance(X, np.ndarray):\n",
        "            X = torch.tensor(X, dtype=torch.float)\n",
        "        if isinstance(y, np.ndarray):\n",
        "            y = torch.tensor(y, dtype=torch.long)\n",
        "        self.X, self.y = X, y\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.X.shape[0]\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.X[idx], self.y[idx]\n",
        "\n",
        "\n",
        "def build_dataset_for_subjects(subject_list, data_dict):\n",
        "    \"\"\"Concatenate many subjects → single (X, y).\"\"\"\n",
        "    X = np.concatenate([data_dict[s][\"X\"] for s in subject_list], axis=0)\n",
        "    y = np.concatenate([data_dict[s][\"y\"] for s in subject_list], axis=0)\n",
        "    return X, y\n",
        "\n",
        "############################################################################\n",
        "# 2) GAT layer & model on a fully‐connected graph\n",
        "############################################################################\n",
        "class GATLayer(nn.Module):\n",
        "    \"\"\"\n",
        "    One GAT layer: for each node i,\n",
        "      h_i' = ELU( ∑_j α_{ij} W h_j )\n",
        "    where α_{ij} = softmax_j(LeakyReLU(a^T [W h_i || W h_j])).\n",
        "    \"\"\"\n",
        "    def __init__(self, in_feats, out_feats, dropout=0.5, alpha=0.2, concat=True):\n",
        "        super().__init__()\n",
        "        self.in_feats = in_feats\n",
        "        self.out_feats = out_feats\n",
        "        self.dropout = dropout\n",
        "        self.concat = concat\n",
        "\n",
        "        # linear transformation\n",
        "        self.W = nn.Linear(in_feats, out_feats, bias=False)\n",
        "        # attention mechanism\n",
        "        self.a = nn.Parameter(torch.empty(size=(2*out_feats, 1)))\n",
        "        nn.init.xavier_uniform_(self.a.data, gain=1.414)\n",
        "\n",
        "        self.leakyrelu = nn.LeakyReLU(alpha)\n",
        "\n",
        "    def forward(self, h, adj):\n",
        "        \"\"\"\n",
        "        h: (B, C, in_feats)\n",
        "        adj: (C, C) with 1s where edges exist (here all ones)\n",
        "        returns: (B, C, out_feats)\n",
        "        \"\"\"\n",
        "        B, C, _ = h.size()\n",
        "        Wh = self.W(h)                     # (B, C, out_feats)\n",
        "\n",
        "        # prepare for attention: Wh_i and Wh_j\n",
        "        Wh_i = Wh.unsqueeze(2).repeat(1, 1, C, 1)   # (B, C, C, out_feats)\n",
        "        Wh_j = Wh.unsqueeze(1).repeat(1, C, 1, 1)   # (B, C, C, out_feats)\n",
        "        # compute e_ij\n",
        "        e = self.leakyrelu(\n",
        "            torch.matmul(torch.cat([Wh_i, Wh_j], dim=-1), self.a).squeeze(-1)\n",
        "        )  # (B, C, C)\n",
        "\n",
        "        # mask with adjacency: set e_ij to -inf where adj=0\n",
        "        zero_vec = -9e15 * torch.ones_like(e)\n",
        "        attention = torch.where(adj.unsqueeze(0) > 0, e, zero_vec)\n",
        "        # softmax\n",
        "        attention = F.softmax(attention, dim=-1)\n",
        "        attention = F.dropout(attention, self.dropout, training=self.training)\n",
        "\n",
        "        # aggregate\n",
        "        h_prime = torch.matmul(attention, Wh)  # (B, C, out_feats)\n",
        "\n",
        "        if self.concat:\n",
        "            return F.elu(h_prime)\n",
        "        else:\n",
        "            return h_prime\n",
        "\n",
        "class GATNetwork(nn.Module):\n",
        "    \"\"\"\n",
        "    GAT classifier on a fully connected C-node graph (self-loops included).\n",
        "\n",
        "    Each node's input feature vector is its length-T row of x; node embeddings\n",
        "    from the last GAT layer are flattened and fed to a small MLP head.\n",
        "\n",
        "    Args:\n",
        "        C, T:        number of nodes and input feature length per node.\n",
        "        hidden_dim:  embedding size produced by every GAT layer.\n",
        "        num_layers:  number of GAT layers (at least one is always built).\n",
        "        num_classes: number of output logits.\n",
        "        dropout_p:   dropout for attention weights and the MLP head.\n",
        "        alpha:       LeakyReLU slope used inside the attention.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T,\n",
        "                 hidden_dim=64,\n",
        "                 num_layers=2,\n",
        "                 num_classes=2,\n",
        "                 dropout_p=0.5,\n",
        "                 alpha=0.2):\n",
        "        super().__init__()\n",
        "        self.C, self.T = C, T\n",
        "\n",
        "        # dense all-ones adjacency; a buffer so it moves with .to(device)\n",
        "        self.register_buffer(\"A\", torch.ones(C, C))\n",
        "\n",
        "        # input width of each layer: T for the first, hidden_dim afterwards\n",
        "        in_widths = [T] + [hidden_dim] * max(num_layers - 1, 0)\n",
        "        self.layers = nn.ModuleList(\n",
        "            GATLayer(in_feats=width, out_feats=hidden_dim,\n",
        "                     dropout=dropout_p, alpha=alpha, concat=True)\n",
        "            for width in in_widths\n",
        "        )\n",
        "\n",
        "        # MLP head over the C * hidden_dim flattened node embeddings\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(C * hidden_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(128, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        h = x  # (B, C, T)\n",
        "        for gat in self.layers:\n",
        "            h = gat(h, self.A)  # (B, C, hidden_dim)\n",
        "        return self.classifier(h.view(h.shape[0], -1))  # (B, num_classes)\n",
        "\n",
        "\n",
        "############################################################################\n",
        "# 3) Training / evaluation helpers\n",
        "############################################################################\n",
        "def train_one_epoch(model, loader, optimizer, criterion):\n",
        "    \"\"\"\n",
        "    Run one optimisation pass over `loader` and return the mean training loss.\n",
        "\n",
        "    The mean is over the samples actually processed: with drop_last=True the\n",
        "    tail batch is never seen, so dividing by len(loader.dataset) would count\n",
        "    those samples too and under-report the loss.\n",
        "\n",
        "    Returns:\n",
        "        float: sample-weighted mean of `criterion` over the processed batches,\n",
        "        or 0.0 if the loader yielded no batches.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    running_loss = 0.0\n",
        "    n_seen = 0\n",
        "    for Xb, yb in loader:\n",
        "        Xb, yb = Xb.to(device), yb.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        logits = model(Xb)\n",
        "        loss = criterion(logits, yb)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # assumes criterion returns the batch mean (CrossEntropyLoss default)\n",
        "        running_loss += loss.item() * Xb.size(0)\n",
        "        n_seen += Xb.size(0)\n",
        "    return running_loss / n_seen if n_seen else 0.0\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_loss(model, loader, criterion):\n",
        "    \"\"\"Sample-weighted mean of `criterion` over `loader`, divided by len(loader.dataset) (eval mode, no grad).\"\"\"\n",
        "    model.eval()\n",
        "    weighted_sum = 0.0\n",
        "    for inputs, targets in loader:\n",
        "        inputs = inputs.to(device)\n",
        "        targets = targets.to(device)\n",
        "        batch_loss = criterion(model(inputs), targets).item()\n",
        "        weighted_sum += batch_loss * inputs.size(0)\n",
        "    return weighted_sum / len(loader.dataset)\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_accuracy(model, loader):\n",
        "    \"\"\"Fraction of samples whose argmax logit matches the label; 0.0 for an empty loader.\"\"\"\n",
        "    model.eval()\n",
        "    n_correct, n_total = 0, 0\n",
        "    for inputs, targets in loader:\n",
        "        inputs, targets = inputs.to(device), targets.to(device)\n",
        "        predictions = model(inputs).argmax(dim=1)\n",
        "        n_correct += (predictions == targets).sum().item()\n",
        "        n_total += targets.size(0)\n",
        "    if not n_total:\n",
        "        return 0.0\n",
        "    return n_correct / n_total\n",
        "\n",
        "############################################################################\n",
        "# 4) k‐fold cross‐validation driver\n",
        "############################################################################\n",
        "def run_cross_validation_gat(\n",
        "    data_dict,\n",
        "    num_folds      = 10,\n",
        "    val_split      = 0.2,\n",
        "    epochs         = 20,\n",
        "    batch_size     = 64,\n",
        "    learning_rate  = 1e-3,\n",
        "    hidden_dim     = 64,\n",
        "    num_layers     = 2,\n",
        "    num_classes    = 2,\n",
        "    dropout_p      = 0.5,\n",
        "    alpha          = 0.2\n",
        "):\n",
        "    \"\"\"\n",
        "    Subject-wise k-fold cross-validation of GATNetwork.\n",
        "\n",
        "    Participants (keys of `data_dict`) are shuffled with Python's global\n",
        "    `random` and split into `num_folds` disjoint folds whose sizes differ by\n",
        "    at most one (the first len % num_folds folds get one extra participant).\n",
        "    For each fold, a random `val_split` fraction of the remaining participants'\n",
        "    trials is held out; the weights with the lowest validation loss are then\n",
        "    scored on the fold's participants.\n",
        "\n",
        "    Returns:\n",
        "        float: mean test accuracy over folds, as a fraction in [0, 1].\n",
        "    Raises:\n",
        "        ValueError: if num_folds is not in [2, number of participants], or if\n",
        "            val_split leaves no validation trials.\n",
        "    \"\"\"\n",
        "    participants = list(data_dict.keys())\n",
        "    total_subj   = len(participants)\n",
        "    if not 2 <= num_folds <= total_subj:\n",
        "        raise ValueError(\n",
        "            f\"num_folds must be between 2 and {total_subj} (number of participants), got {num_folds}\"\n",
        "        )\n",
        "    random.shuffle(participants)\n",
        "\n",
        "    # balanced contiguous folds (KFold convention): no fold is empty and no\n",
        "    # single fold absorbs the whole remainder\n",
        "    base, extra = divmod(total_subj, num_folds)\n",
        "    folds, start = [], 0\n",
        "    for i in range(num_folds):\n",
        "        size = base + (1 if i < extra else 0)\n",
        "        folds.append(participants[start:start + size])\n",
        "        start += size\n",
        "\n",
        "    all_test_acc = []\n",
        "\n",
        "    for k, test_subj in enumerate(folds, 1):\n",
        "        print(f\"\\n=== Fold {k}/{num_folds} ===\")\n",
        "        held_out = set(test_subj)\n",
        "        train_subj = [s for s in participants if s not in held_out]\n",
        "\n",
        "        # assemble train/val (random split of the training participants' trials)\n",
        "        X_train, y_train = build_dataset_for_subjects(train_subj, data_dict)\n",
        "        full_ds = EEGDataset(X_train, y_train)\n",
        "        n_tot   = len(full_ds)\n",
        "        n_val   = int(val_split * n_tot)\n",
        "        if n_val == 0:\n",
        "            raise ValueError(f\"val_split={val_split} yields no validation trials (n={n_tot})\")\n",
        "        train_ds, val_ds = torch.utils.data.random_split(\n",
        "            full_ds, [n_tot - n_val, n_val]\n",
        "        )\n",
        "        train_ld = torch.utils.data.DataLoader(\n",
        "            train_ds, batch_size=batch_size, shuffle=True, drop_last=True\n",
        "        )\n",
        "        val_ld   = torch.utils.data.DataLoader(\n",
        "            val_ds,   batch_size=batch_size, shuffle=False\n",
        "        )\n",
        "\n",
        "        # assemble test\n",
        "        X_test, y_test = build_dataset_for_subjects(test_subj, data_dict)\n",
        "        test_ld = torch.utils.data.DataLoader(\n",
        "            EEGDataset(X_test, y_test),\n",
        "            batch_size=batch_size, shuffle=False\n",
        "        )\n",
        "\n",
        "        C, T = X_train.shape[1], X_train.shape[2]\n",
        "        model = GATNetwork(\n",
        "            C=C, T=T,\n",
        "            hidden_dim=hidden_dim,\n",
        "            num_layers=num_layers,\n",
        "            num_classes=num_classes,\n",
        "            dropout_p=dropout_p,\n",
        "            alpha=alpha\n",
        "        ).to(device)\n",
        "\n",
        "        optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
        "        criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "        # model selection: keep a copy of the weights with the lowest val loss\n",
        "        best_val_loss = float(\"inf\")\n",
        "        best_state    = None\n",
        "\n",
        "        for ep in range(1, epochs + 1):\n",
        "            tr_loss = train_one_epoch(model, train_ld, optimizer, criterion)\n",
        "            vl_loss = evaluate_loss(model, val_ld, criterion)\n",
        "            print(f\"[Fold {k}] Epoch {ep}/{epochs} | \"\n",
        "                  f\"train {tr_loss:.4f} | val {vl_loss:.4f}\")\n",
        "            if vl_loss < best_val_loss:\n",
        "                best_val_loss = vl_loss\n",
        "                best_state    = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        # test; best_state stays None only if no epoch ran or every val loss was\n",
        "        # NaN - then the final weights are evaluated instead of crashing\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state)\n",
        "        tst_acc = evaluate_accuracy(model, test_ld)\n",
        "        print(f\"[Fold {k}] best val loss {best_val_loss:.4f} | \"\n",
        "              f\"test acc {100*tst_acc:.2f}%\")\n",
        "        all_test_acc.append(tst_acc)\n",
        "\n",
        "    mean_acc = 100 * np.mean(all_test_acc)\n",
        "    std_acc  = 100 * np.std(all_test_acc)\n",
        "    print(f\"\\n✅ {num_folds}-fold average accuracy: {mean_acc:.2f}%  (±{std_acc:.2f})\")\n",
        "    return np.mean(all_test_acc)\n",
        "\n",
        "# Example usage: subject-wise 10-fold CV of the GAT on `data_dict` (built by the\n",
        "# PhysioNet loading cell). Every argument equals the function default except\n",
        "# val_split (0.1 here vs 0.2). final_acc is the mean test accuracy as a fraction.\n",
        "final_acc = run_cross_validation_gat(\n",
        "     data_dict      = data_dict,\n",
        "     num_folds      = 10,\n",
        "     val_split      = 0.1,\n",
        "     epochs         = 20,\n",
        "     batch_size     = 64,\n",
        "     learning_rate  = 1e-3,\n",
        "     hidden_dim     = 64,\n",
        "     num_layers     = 2,\n",
        "     num_classes    = 2,\n",
        "     dropout_p      = 0.5,\n",
        "     alpha          = 0.2\n",
        " )\n"
      ],
      "metadata": {
        "id": "Xh7sFyK9tgKx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "############################################################################\n",
        "# 0) Imports & reproducibility\n",
        "############################################################################\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import random\n",
        "import copy\n",
        "\n",
        "#GIN PHYSIONET\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(\"▶ Using device:\", device)\n",
        "\n",
        "SEED = 321\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "if device.type == \"cuda\":\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "############################################################################\n",
        "# 1) Dataset helpers  (shape:  (N , C , T)  throughout)\n",
        "############################################################################\n",
        "class EEGDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"\n",
        "    Expects\n",
        "      X : (num_samples , C , T)  numpy or torch.float\n",
        "      y : (num_samples,)         numpy or torch.long\n",
        "    \"\"\"\n",
        "    def __init__(self, X, y):\n",
        "        if isinstance(X, np.ndarray):\n",
        "            X = torch.tensor(X, dtype=torch.float)\n",
        "        if isinstance(y, np.ndarray):\n",
        "            y = torch.tensor(y, dtype=torch.long)\n",
        "        self.X, self.y = X, y\n",
        "\n",
        "    def __len__(self):\n",
        "        return self.X.shape[0]\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.X[idx], self.y[idx]\n",
        "\n",
        "\n",
        "def build_dataset_for_subjects(subject_list, data_dict):\n",
        "    \"\"\"Concatenate many subjects → single (X, y).\"\"\"\n",
        "    X = np.concatenate([data_dict[s][\"X\"] for s in subject_list], axis=0)\n",
        "    y = np.concatenate([data_dict[s][\"y\"] for s in subject_list], axis=0)\n",
        "    return X, y\n",
        "\n",
        "############################################################################\n",
        "# 2) GIN layer & model on a fully‐connected graph\n",
        "############################################################################\n",
        "class GINLayer(nn.Module):\n",
        "    def __init__(self, in_feats, out_feats, eps=0.0):\n",
        "        super().__init__()\n",
        "        # initialize eps so that (1+eps)=1 at start\n",
        "        self.eps = nn.Parameter(torch.tensor(eps))\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(in_feats, out_feats),\n",
        "            nn.ReLU(),\n",
        "            nn.Linear(out_feats, out_feats),\n",
        "        )\n",
        "\n",
        "    def forward(self, h, A):\n",
        "        # h: (B, C, F); A: (C, C)\n",
        "        # OPTION 1: normalize A so rows sum to 1 -> average‐neighbor aggregator\n",
        "        #     D = A.sum(dim=1, keepdim=True)            # (C,1)\n",
        "        #     A_norm = A / D                             # (C,C)\n",
        "        # OPTION 2: stick with sum but be aware of scale\n",
        "        A_norm = A / A.size(0)                         # simple global normalization\n",
        "\n",
        "        # aggregate neighbors\n",
        "        agg = torch.matmul(A_norm.unsqueeze(0), h)     # (B, C, F)\n",
        "\n",
        "        # include the identity skip properly\n",
        "        out = (1.0 + self.eps) * h  + agg                # (B, C, F)\n",
        "\n",
        "        return self.mlp(out)                           # (B, C, out_feats)\n",
        "                  # (B, C, out_feats)\n",
        "\n",
        "class GINNetwork(nn.Module):\n",
        "    \"\"\"\n",
        "    GIN classifier on a fully connected C-node graph.\n",
        "\n",
        "    Each node's input feature vector is its length-T row of x; node embeddings\n",
        "    from the last GIN layer are flattened and fed to a small MLP head.\n",
        "\n",
        "    Args:\n",
        "        C, T:        number of nodes and input feature length per node.\n",
        "        hidden_dim:  embedding size produced by every GIN layer.\n",
        "        num_layers:  number of GIN layers (at least one is always built).\n",
        "        num_classes: number of output logits.\n",
        "        dropout_p:   dropout probability in the MLP head.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T,\n",
        "                 hidden_dim=64,\n",
        "                 num_layers=3,\n",
        "                 num_classes=2,\n",
        "                 dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        self.C, self.T = C, T\n",
        "\n",
        "        # dense all-ones adjacency; a buffer so it moves with .to(device)\n",
        "        self.register_buffer(\"A\", torch.ones(C, C))\n",
        "\n",
        "        # input width of each layer: T for the first, hidden_dim afterwards\n",
        "        in_widths = [T] + [hidden_dim] * max(num_layers - 1, 0)\n",
        "        self.layers = nn.ModuleList(\n",
        "            GINLayer(in_feats=width, out_feats=hidden_dim) for width in in_widths\n",
        "        )\n",
        "\n",
        "        # MLP head over the C * hidden_dim flattened node embeddings\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(C * hidden_dim, 128),\n",
        "            nn.ReLU(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(128, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        h = x  # (B, C, T) node features\n",
        "        for gin in self.layers:\n",
        "            h = gin(h, self.A)  # (B, C, hidden_dim)\n",
        "        return self.classifier(h.view(h.shape[0], -1))  # (B, num_classes)\n",
        "\n",
        "\n",
        "############################################################################\n",
        "# 3) Training / evaluation helpers\n",
        "############################################################################\n",
        "def train_one_epoch(model, loader, optimizer, criterion):\n",
        "    \"\"\"\n",
        "    Run one optimisation pass over `loader`; return the mean training loss.\n",
        "\n",
        "    The mean is taken over the samples actually processed.  With\n",
        "    drop_last=True that is fewer than len(loader.dataset), so dividing by\n",
        "    the dataset length would under-report the loss.  Returns nan when the\n",
        "    loader yields no batches.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    running_loss = 0.0\n",
        "    n_seen = 0\n",
        "    for Xb, yb in loader:\n",
        "        Xb, yb = Xb.to(device), yb.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        logits = model(Xb)\n",
        "        loss = criterion(logits, yb)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # assumes criterion averages over the batch (CrossEntropyLoss default)\n",
        "        running_loss += loss.item() * Xb.size(0)\n",
        "        n_seen += Xb.size(0)\n",
        "    return running_loss / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_loss(model, loader, criterion):\n",
        "    \"\"\"\n",
        "    Mean per-sample loss of `model` over `loader` (eval mode, no gradients).\n",
        "\n",
        "    Returns nan for an empty loader (e.g. a val split of size 0) instead of\n",
        "    raising ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    total_loss = 0.0\n",
        "    n_seen = 0\n",
        "    for Xb, yb in loader:\n",
        "        Xb, yb = Xb.to(device), yb.to(device)\n",
        "        logits = model(Xb)\n",
        "        loss = criterion(logits, yb)\n",
        "        total_loss += loss.item() * Xb.size(0)\n",
        "        n_seen += Xb.size(0)\n",
        "    return total_loss / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_accuracy(model, loader):\n",
        "    \"\"\"Fraction of samples whose arg-max logit equals the label; 0.0 if `loader` is empty.\"\"\"\n",
        "    model.eval()\n",
        "    n_correct, n_seen = 0, 0\n",
        "    for inputs, labels in loader:\n",
        "        inputs = inputs.to(device)\n",
        "        labels = labels.to(device)\n",
        "        hits = model(inputs).argmax(dim=1).eq(labels)\n",
        "        n_correct += hits.sum().item()\n",
        "        n_seen += labels.size(0)\n",
        "    if n_seen == 0:\n",
        "        return 0.0\n",
        "    return n_correct / n_seen\n",
        "\n",
        "############################################################################\n",
        "# 4) k‐fold cross‐validation driver\n",
        "############################################################################\n",
        "def run_cross_validation_gin(\n",
        "    data_dict,\n",
        "    num_folds      = 10,\n",
        "    val_split      = 0.2,\n",
        "    epochs         = 20,\n",
        "    batch_size     = 64,\n",
        "    learning_rate  = 1e-3,\n",
        "    hidden_dim     = 64,\n",
        "    num_layers     = 3,\n",
        "    num_classes    = 2\n",
        "):\n",
        "    \"\"\"\n",
        "    Subject-wise k-fold cross-validation of GINNetwork.\n",
        "\n",
        "    The subjects (keys of `data_dict`) are shuffled with the global `random`\n",
        "    state and dealt into folds whose sizes differ by at most one; with fewer\n",
        "    subjects than `num_folds`, one fold per subject is used.  In each fold the\n",
        "    remaining subjects' windows are pooled, a random `val_split` fraction of\n",
        "    them picks the epoch with the lowest validation loss (training loss if\n",
        "    that fraction is empty), and that checkpoint is scored on the held-out\n",
        "    subjects.\n",
        "\n",
        "    Returns the mean test accuracy over folds, as a fraction in [0, 1].\n",
        "    Raises ValueError if num_folds < 2 or there are fewer than 2 subjects.\n",
        "    \"\"\"\n",
        "    if num_folds < 2:\n",
        "        raise ValueError(\"num_folds must be at least 2\")\n",
        "    participants = list(data_dict.keys())\n",
        "    random.shuffle(participants)\n",
        "    total_subj   = len(participants)\n",
        "    if total_subj < 2:\n",
        "        raise ValueError(\"cross-validation needs at least 2 subjects\")\n",
        "\n",
        "    # Balanced split: the first `extra` folds get one extra subject.  Plain\n",
        "    # floor division pushed the whole remainder into the last fold and left\n",
        "    # folds empty (-> np.concatenate error) when num_folds > total_subj.\n",
        "    n_folds = min(num_folds, total_subj)\n",
        "    base, extra = divmod(total_subj, n_folds)\n",
        "    folds, start = [], 0\n",
        "    for i in range(n_folds):\n",
        "        stop = start + base + (1 if i < extra else 0)\n",
        "        folds.append(participants[start:stop])\n",
        "        start = stop\n",
        "\n",
        "    all_test_acc = []\n",
        "\n",
        "    for k, test_subj in enumerate(folds, 1):\n",
        "        print(f\"\\n=== Fold {k}/{n_folds} ===\")\n",
        "        held_out   = set(test_subj)\n",
        "        train_subj = [s for s in participants if s not in held_out]\n",
        "\n",
        "        # build train / val sets (val = random windows of training subjects)\n",
        "        X_train, y_train = build_dataset_for_subjects(train_subj, data_dict)\n",
        "        full_ds = EEGDataset(X_train, y_train)\n",
        "        n_tot   = len(full_ds)\n",
        "        n_val   = int(val_split * n_tot)\n",
        "        train_ds, val_ds = torch.utils.data.random_split(\n",
        "            full_ds, [n_tot - n_val, n_val]\n",
        "        )\n",
        "        # drop_last only if a full batch exists; otherwise the loader would\n",
        "        # yield no batches and the model would never be updated\n",
        "        train_ld = torch.utils.data.DataLoader(\n",
        "            train_ds, batch_size=batch_size, shuffle=True,\n",
        "            drop_last=len(train_ds) >= batch_size\n",
        "        )\n",
        "        val_ld = torch.utils.data.DataLoader(\n",
        "            val_ds, batch_size=batch_size, shuffle=False\n",
        "        )\n",
        "\n",
        "        # build test set\n",
        "        X_test, y_test = build_dataset_for_subjects(test_subj, data_dict)\n",
        "        test_ld = torch.utils.data.DataLoader(\n",
        "            EEGDataset(X_test, y_test),\n",
        "            batch_size=batch_size, shuffle=False\n",
        "        )\n",
        "\n",
        "        C, T = X_train.shape[1], X_train.shape[2]\n",
        "        model = GINNetwork(\n",
        "            C=C, T=T,\n",
        "            hidden_dim=hidden_dim,\n",
        "            num_layers=num_layers,\n",
        "            num_classes=num_classes\n",
        "        ).to(device)\n",
        "\n",
        "        optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
        "        criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "        best_val_loss = float(\"inf\")\n",
        "        best_state    = None\n",
        "\n",
        "        for ep in range(1, epochs + 1):\n",
        "            tr_loss = train_one_epoch(model, train_ld, optimizer, criterion)\n",
        "            # no validation windows (val_split too small): select on train loss\n",
        "            vl_loss = (evaluate_loss(model, val_ld, criterion)\n",
        "                       if n_val > 0 else tr_loss)\n",
        "            print(f\"[Fold {k}] Epoch {ep}/{epochs} | \"\n",
        "                  f\"train {tr_loss:.4f} | val {vl_loss:.4f}\")\n",
        "            if vl_loss < best_val_loss:\n",
        "                best_val_loss = vl_loss\n",
        "                best_state    = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        # test the selected checkpoint; keep the final weights when no epoch\n",
        "        # improved on inf (epochs == 0 or NaN losses) instead of crashing\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state)\n",
        "        tst_acc = evaluate_accuracy(model, test_ld)\n",
        "        print(f\"[Fold {k}] best val loss {best_val_loss:.4f} | \"\n",
        "              f\"test acc {100*tst_acc:.2f}%\")\n",
        "        all_test_acc.append(tst_acc)\n",
        "\n",
        "    mean_acc = 100*np.mean(all_test_acc)\n",
        "    std_acc  = 100*np.std(all_test_acc)\n",
        "    print(f\"\\n✅ {n_folds}-fold average accuracy: {mean_acc:.2f}%  (±{std_acc:.2f})\")\n",
        "    return np.mean(all_test_acc)\n",
        "\n",
        "# Example usage: 10-fold subject-wise CV of a single-layer GIN\n",
        "# (10% of each fold's training windows are held out for model selection).\n",
        "final_acc = run_cross_validation_gin(\n",
        "    data_dict=data_dict,\n",
        "    num_folds=10, val_split=0.1,\n",
        "    epochs=20, batch_size=64, learning_rate=1e-3,\n",
        "    hidden_dim=64, num_layers=1, num_classes=2,\n",
        ")\n"
      ],
      "metadata": {
        "id": "t0tSu4Act65f"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#CDNN Physionet\n",
        "\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.nn.functional as F\n",
        "import numpy as np\n",
        "import random\n",
        "import copy\n",
        "from typing import Sequence, Dict, Any\n",
        "\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "print(\"▶ Using device:\", device)\n",
        "\n",
        "SEED = 321\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "if device.type == \"cuda\":\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "\n",
        "############################################################################\n",
        "# 1) Dataset helpers  (shape:  (N , C , T)  throughout)\n",
        "############################################################################\n",
        "class EEGDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"\n",
        "    Expects\n",
        "      X : (num_samples , C , T)  numpy or torch.float\n",
        "      y : (num_samples,)         numpy or torch.long\n",
        "    \"\"\"\n",
        "    def __init__(self, X, y):\n",
        "        if isinstance(X, np.ndarray):\n",
        "            X = torch.tensor(X, dtype=torch.float)\n",
        "        if isinstance(y, np.ndarray):\n",
        "            y = torch.tensor(y, dtype=torch.long)\n",
        "        self.X, self.y = X, y\n",
        "\n",
        "    def __len__(self) -> int:\n",
        "        return self.X.shape[0]\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        return self.X[idx], self.y[idx]\n",
        "\n",
        "\n",
        "def build_dataset_for_subjects(subject_list: Sequence[Any],\n",
        "                               data_dict: Dict[Any, Dict[str, np.ndarray]]):\n",
        "    \"\"\"Concatenate many subjects -> single (X , y).\"\"\"\n",
        "    X = np.concatenate([data_dict[s][\"X\"] for s in subject_list], axis=0)\n",
        "    y = np.concatenate([data_dict[s][\"y\"] for s in subject_list], axis=0)\n",
        "    return X, y\n",
        "\n",
        "############################################################################\n",
        "# 2) Global covariance utility (Σ̂ serves as the VNN's fixed graph shift operator)\n",
        "############################################################################\n",
        "@torch.no_grad()\n",
        "def compute_global_covariance(x_tensor: torch.Tensor) -> torch.Tensor:\n",
        "    \"\"\"\n",
        "    x_tensor : (N, C, T) – training windows.\n",
        "    Returns   : (C, C)   – unit-norm covariance matrix Σ̂.\n",
        "    \"\"\"\n",
        "    N, C, T = x_tensor.shape\n",
        "    # time-centre every window\n",
        "    x_centered = x_tensor - x_tensor.mean(dim=2, keepdim=True)      # (N,C,T)\n",
        "    # accumulate   Σ = Σₙ Σₜ X̃ₙ[:,t] X̃ₙ[:,t]ᵀ\n",
        "    cov = torch.einsum(\"nct,nkt->ck\", x_centered, x_centered) / (N * T)\n",
        "    return cov #/ torch.linalg.norm(cov)\n",
        "\n",
        "############################################################################\n",
        "# 3) VNN filter-bank model\n",
        "############################################################################\n",
        "class VNNFilterBank(nn.Module):\n",
        "    \"\"\"\n",
        "    VNN with several heat-kernel filters exp(−βL), each β trainable,\n",
        "    plus trainable γ_skip[k], γ_conv[k] for skip vs. conv paths.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, cov_matrix, init_betas,\n",
        "                 hidden_dim=128, num_classes=2, dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"L\", cov_matrix)   # fixed graph Laplacian\n",
        "        self.C, self.T = C, T\n",
        "\n",
        "        init_betas = torch.tensor(init_betas, dtype=torch.float)\n",
        "        self.beta_params = nn.Parameter(init_betas)\n",
        "\n",
        "        # --- skip / conv scalars per filter ---\n",
        "        nF = init_betas.numel()\n",
        "        self.gamma_conv = nn.Parameter(torch.ones(nF))\n",
        "\n",
        "        # --- classifier input dim = C*T per filter × nF filters ---\n",
        "        in_dim = C * T * nF\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(in_dim, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(hidden_dim, num_classes)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (B, C, T)\n",
        "        B, C, T = x.shape\n",
        "        feats = []\n",
        "\n",
        "        # single normalisation (optional – leave as identity for now)\n",
        "        x_norm = x\n",
        "\n",
        "        # for each filter\n",
        "        for i in range(self.beta_params.numel()):\n",
        "            # ensure β>0\n",
        "            beta = (self.beta_params[i])\n",
        "\n",
        "            # build heat-kernel filter exp(−βL)\n",
        "            filt =torch.matrix_exp(-beta * self.L)          # (C, C)\n",
        "            filt = filt / torch.trace(filt)                  # unit-trace\n",
        "\n",
        "            # skip + conv, each weighted\n",
        "                       # (B, C, T)\n",
        "            y_conv = self.gamma_conv[i] * (filt.unsqueeze(0) @ x_norm)\n",
        "            y = y_conv                              # (B, C, T)\n",
        "\n",
        "            feats.append(y.reshape(B, -1))                   # flatten\n",
        "\n",
        "        z = torch.cat(feats, dim=1)                          # (B, C*T*nF)\n",
        "        return self.classifier(z)                            # raw logits\n",
        "\n",
        "############################################################################\n",
        "# 4) Training / evaluation helpers\n",
        "############################################################################\n",
        "def train_one_epoch(model, loader, optimizer, criterion):\n",
        "    \"\"\"\n",
        "    Run one optimisation pass over `loader`; return the mean training loss.\n",
        "\n",
        "    The mean is taken over the samples actually processed.  With\n",
        "    drop_last=True that is fewer than len(loader.dataset), so dividing by\n",
        "    the dataset length would under-report the loss.  Returns nan when the\n",
        "    loader yields no batches.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    running_loss = 0.0\n",
        "    n_seen = 0\n",
        "    for Xb, yb in loader:\n",
        "        Xb, yb = Xb.to(device), yb.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        logits = model(Xb)\n",
        "        loss = criterion(logits, yb)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        # assumes criterion averages over the batch (CrossEntropyLoss default)\n",
        "        running_loss += loss.item() * Xb.size(0)\n",
        "        n_seen += Xb.size(0)\n",
        "    return running_loss / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_loss(model, loader, criterion):\n",
        "    \"\"\"\n",
        "    Mean per-sample loss of `model` over `loader` (eval mode, no gradients).\n",
        "\n",
        "    Returns nan for an empty loader (e.g. a val split of size 0) instead of\n",
        "    raising ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    total_loss = 0.0\n",
        "    n_seen = 0\n",
        "    for Xb, yb in loader:\n",
        "        Xb, yb = Xb.to(device), yb.to(device)\n",
        "        logits = model(Xb)\n",
        "        loss = criterion(logits, yb)\n",
        "        total_loss += loss.item() * Xb.size(0)\n",
        "        n_seen += Xb.size(0)\n",
        "    return total_loss / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_accuracy(model, loader):\n",
        "    \"\"\"\n",
        "    Classification accuracy of `model` on `loader` (eval mode, no gradients).\n",
        "    Returns correct / total as a float, or 0.0 when the loader is empty.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    correct, total = 0, 0\n",
        "    for batch_x, batch_y in loader:\n",
        "        batch_x, batch_y = batch_x.to(device), batch_y.to(device)\n",
        "        predicted = torch.argmax(model(batch_x), dim=1)\n",
        "        correct += (predicted == batch_y).sum().item()\n",
        "        total += batch_y.size(0)\n",
        "    return (correct / total) if total > 0 else 0.0\n",
        "\n",
        "############################################################################\n",
        "# 5) Leave-one-subject-out cross-validation driver\n",
        "############################################################################\n",
        "def run_leave_one_out_vnn(\n",
        "    data_dict: Dict[Any, Dict[str, np.ndarray]],\n",
        "    val_split      = 0.2,\n",
        "    epochs         = 20,\n",
        "    batch_size     = 64,\n",
        "    learning_rate  = 1e-3,\n",
        "    betas          = (0.1, 1.0, 10.0),\n",
        "    num_classes    = 2\n",
        "):\n",
        "    \"\"\"\n",
        "    Leave-one-subject-out CV:\n",
        "        for each subject S, train on (all − S) and test on S.\n",
        "\n",
        "    Per fold the training subjects' windows are pooled; Σ̂ for VNNFilterBank\n",
        "    is their covariance, and a random `val_split` fraction of them selects\n",
        "    the epoch with the lowest validation loss (training loss when that\n",
        "    fraction is empty).  Returns the mean test accuracy as a fraction in\n",
        "    [0, 1].  Raises ValueError if data_dict has fewer than 2 subjects.\n",
        "    \"\"\"\n",
        "    participants = sorted(list(data_dict.keys()))\n",
        "    total_subj   = len(participants)\n",
        "    if total_subj < 2:\n",
        "        raise ValueError(\"LOSO needs at least 2 subjects\")\n",
        "    print(f\"▶ Running LOSO on {total_subj} participants\")\n",
        "\n",
        "    all_test_acc = []\n",
        "\n",
        "    for idx, test_subj in enumerate(participants, 1):\n",
        "        train_subj = [s for s in participants if s != test_subj]\n",
        "        print(f\"\\n=== Fold {idx}/{total_subj} – held-out subject: {test_subj} ===\")\n",
        "        print(\"Train subjects:\", len(train_subj))\n",
        "\n",
        "        # ------------ assemble train / val / test datasets -------------\n",
        "        X_train_full, y_train_full = build_dataset_for_subjects(train_subj, data_dict)\n",
        "        full_ds   = EEGDataset(X_train_full, y_train_full)\n",
        "        n_tot     = len(full_ds)\n",
        "        n_val     = int(val_split * n_tot)\n",
        "        n_train   = n_tot - n_val\n",
        "        train_ds, val_ds = torch.utils.data.random_split(full_ds, [n_train, n_val])\n",
        "        # drop_last only if a full batch exists, else nothing would be trained\n",
        "        train_ld = torch.utils.data.DataLoader(train_ds, batch_size=batch_size,\n",
        "                                               shuffle=True,\n",
        "                                               drop_last=n_train >= batch_size)\n",
        "        val_ld   = torch.utils.data.DataLoader(val_ds,   batch_size=batch_size,\n",
        "                                               shuffle=False)\n",
        "\n",
        "        X_test, y_test = build_dataset_for_subjects([test_subj], data_dict)\n",
        "        test_ld  = torch.utils.data.DataLoader(EEGDataset(X_test, y_test),\n",
        "                                               batch_size=batch_size,\n",
        "                                               shuffle=False)\n",
        "\n",
        "        C, T = X_train_full.shape[1], X_train_full.shape[2]\n",
        "\n",
        "        cov = compute_global_covariance(\n",
        "            torch.tensor(X_train_full, dtype=torch.float, device=device)\n",
        "        )\n",
        "\n",
        "        # ------------ build / train model ------------------------------\n",
        "        model = VNNFilterBank(C=C, T=T, cov_matrix=cov,\n",
        "                              init_betas=betas, num_classes=num_classes).to(device)\n",
        "        optimiser = optim.Adam(model.parameters(), lr=learning_rate)\n",
        "        criterion = nn.CrossEntropyLoss()\n",
        "\n",
        "        best_val_loss = float(\"inf\")\n",
        "        best_state    = None\n",
        "\n",
        "        for ep in range(1, epochs + 1):\n",
        "            train_loss = train_one_epoch(model, train_ld, optimiser, criterion)\n",
        "            # no validation windows (val_split too small): select on train loss\n",
        "            val_loss   = (evaluate_loss(model, val_ld, criterion)\n",
        "                          if n_val > 0 else train_loss)\n",
        "            print(f\"[Subject {idx}] Epoch {ep}/{epochs} | \"\n",
        "                  f\"train {train_loss:.4f} | val {val_loss:.4f}\")\n",
        "\n",
        "            if val_loss < best_val_loss:\n",
        "                best_val_loss = val_loss\n",
        "                best_state    = copy.deepcopy(model.state_dict())\n",
        "\n",
        "        # ------------ test ------------------------------------------------\n",
        "        # keep the final weights if no epoch improved on inf (epochs == 0 or\n",
        "        # NaN losses) rather than crashing on load_state_dict(None)\n",
        "        if best_state is not None:\n",
        "            model.load_state_dict(best_state)\n",
        "        test_acc = evaluate_accuracy(model, test_ld)\n",
        "        print(f\"[Subject {idx}] best val loss {best_val_loss:.4f} | \"\n",
        "              f\"test acc {100*test_acc:.2f}%\")\n",
        "        all_test_acc.append(test_acc)\n",
        "\n",
        "    mean_acc = 100 * np.mean(all_test_acc)\n",
        "    std_acc  = 100 * np.std(all_test_acc)\n",
        "    print(f\"\\n✅ LOSO average accuracy over {total_subj} subjects: \"\n",
        "          f\"{mean_acc:.2f}%  (±{std_acc:.2f})\")\n",
        "    return np.mean(all_test_acc)\n",
        "\n",
        "############################################################################\n",
        "# 6) Entry-point\n",
        "############################################################################\n",
        "if __name__ == \"__main__\":\n",
        "    # `data_dict` must already be loaded into memory at this point.\n",
        "    # Two heat-kernel filters, both initialised at beta = 2.0.\n",
        "    final_acc = run_leave_one_out_vnn(\n",
        "        data_dict=data_dict,\n",
        "        betas=(2.0, 2.0),\n",
        "        num_classes=2,\n",
        "        val_split=0.1,\n",
        "        epochs=20,\n",
        "        batch_size=64,\n",
        "        learning_rate=1e-3,\n",
        "    )\n"
      ],
      "metadata": {
        "id": "nv2La5gJt7bz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#!/usr/bin/env python3\n",
        "#CSP+VNN PHYSIONET\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 0) Imports & reproducibility\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "import copy, random\n",
        "from typing import Sequence, Dict, Any\n",
        "\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
        "from sklearn.metrics import cohen_kappa_score\n",
        "\n",
        "SEED = 321\n",
        "random.seed(SEED)\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "if device.type == \"cuda\":\n",
        "    torch.cuda.manual_seed_all(SEED)\n",
        "print(\"▶ Using device:\", device)\n",
        "\n",
        "NUM_CLASSES = 2          # ← set to 4 if you keep all classes\n",
        "N_CSP_COMPONENTS = 6     # components kept per CSP projection\n",
        "HIDDEN_DIM_VNN   = 128   # MLP width in SimpleVNN\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 1) Dataset helpers\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "class EEGDataset(torch.utils.data.Dataset):\n",
        "    def __init__(self, X, y):\n",
        "        if isinstance(X, np.ndarray):\n",
        "            X = torch.tensor(X, dtype=torch.float32)\n",
        "        if isinstance(y, np.ndarray):\n",
        "            y = torch.tensor(y, dtype=torch.long)\n",
        "        self.X, self.y = X, y\n",
        "\n",
        "    def __len__(self):                return self.X.shape[0]\n",
        "    def __getitem__(self, idx):       return self.X[idx], self.y[idx]\n",
        "\n",
        "\n",
        "def build_dataset_for_subjects(subj_list: Sequence[Any],\n",
        "                               ddict: Dict[Any, Dict[str, np.ndarray]]):\n",
        "    X = np.concatenate([ddict[s][\"X\"] for s in subj_list], axis=0)\n",
        "    y = np.concatenate([ddict[s][\"y\"] for s in subj_list], axis=0)\n",
        "    return X, y\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 2) Graph-covariance  (used by SimpleVNN)\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "@torch.no_grad()\n",
        "def compute_global_covariance(x: torch.Tensor) -> torch.Tensor:\n",
        "    # x: (N,C,T) on DEVICE\n",
        "    x = x - x.mean(dim=2, keepdim=True)\n",
        "    N, C, T = x.shape\n",
        "    cov = torch.einsum(\"nct,nkt->ck\", x, x) / (N * T)\n",
        "    return cov\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 3) Common Spatial Patterns (binary)\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "def fit_csp(x: np.ndarray, y: np.ndarray, n_comp: int = 6):\n",
        "    \"\"\"\n",
        "    x : (N,C,T) – training epochs\n",
        "    y : (N,)    – binary labels {0,1}\n",
        "    returns projection matrix W (C → n_comp*2).\n",
        "    \"\"\"\n",
        "    assert len(np.unique(y)) == 2, \"CSP implemented for binary only\"\n",
        "    C = x.shape[1]\n",
        "\n",
        "    # class-wise covariance matrices\n",
        "    covs = []\n",
        "    for cls in [0, 1]:\n",
        "        x_cls = x[y == cls]\n",
        "        x_cls = x_cls - x_cls.mean(axis=2, keepdims=True)\n",
        "        cov = np.einsum(\"nct,nkt->ck\", x_cls, x_cls) / (x_cls.shape[0] * x_cls.shape[2])\n",
        "        covs.append(cov)\n",
        "    R1, R2 = covs\n",
        "    # solve generalised eigenproblem R1 w = λ (R1+R2) w\n",
        "    evals, evecs = np.linalg.eig(np.linalg.pinv(R1 + R2) @ R1)\n",
        "    idx = np.argsort(evals)[::-1]                       # descending\n",
        "    evecs = evecs[:, idx]\n",
        "\n",
        "    # select top/bottom n_comp vectors\n",
        "    W = np.concatenate([evecs[:, :n_comp], evecs[:, -n_comp:]], axis=1)  # (C, 2n)\n",
        "    return W.astype(np.float32)\n",
        "\n",
        "\n",
        "def csp_transform(x: np.ndarray, W: np.ndarray) -> np.ndarray:\n",
        "    \"\"\"\n",
        "    x : (N,C,T); W : (C,M) → returns (N,M) log-variance features.\n",
        "    \"\"\"\n",
        "    Z = np.einsum(\"ck,nkt->nct\", W.T, x)          # (N,M,T)\n",
        "    var = np.var(Z, axis=2)                       # (N,M)\n",
        "    feats = np.log(var / var.sum(axis=1, keepdims=True))\n",
        "    return feats.astype(np.float32)\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 4) Simple VNN  (single graph filter)\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "class SimpleVNN(nn.Module):\n",
        "    def __init__(self, C, T, L, hidden_dim=128, num_classes=2, dropout_p=0.5):\n",
        "        super().__init__()\n",
        "        self.register_buffer(\"L\", L)              # (C,C)\n",
        "        self.classifier = nn.Sequential(\n",
        "            nn.Linear(C * T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Dropout(dropout_p),\n",
        "            nn.Linear(hidden_dim, num_classes),\n",
        "        )\n",
        "\n",
        "    def forward(self, x):                         # x: (B,C,T)\n",
        "        y = torch.matmul(self.L.unsqueeze(0), x)  # (B,C,T)\n",
        "        return self.classifier(y.reshape(x.size(0), -1))\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 5) Training / evaluation helpers for SimpleVNN\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "def train_one_epoch(model, loader, opt, crit):\n",
        "    \"\"\"\n",
        "    One optimisation pass over `loader`; returns the mean loss over the\n",
        "    samples actually processed (drop_last=True makes that fewer than\n",
        "    len(loader.dataset)), or nan when the loader yields no batches.\n",
        "    \"\"\"\n",
        "    model.train()\n",
        "    tot, n_seen = 0.0, 0\n",
        "    for X, y in loader:\n",
        "        X, y = X.to(device), y.to(device)\n",
        "        opt.zero_grad()\n",
        "        loss = crit(model(X), y)\n",
        "        loss.backward()\n",
        "        opt.step()\n",
        "        # assumes crit averages over the batch (CrossEntropyLoss default)\n",
        "        tot += loss.item() * X.size(0)\n",
        "        n_seen += X.size(0)\n",
        "    return tot / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def eval_loss(model, loader, crit):\n",
        "    \"\"\"\n",
        "    Mean per-sample loss over `loader` in eval mode; nan for an empty loader\n",
        "    (e.g. a val split of size 0) instead of a ZeroDivisionError.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    tot, n_seen = 0.0, 0\n",
        "    for X, y in loader:\n",
        "        X, y = X.to(device), y.to(device)\n",
        "        tot += crit(model(X), y).item() * X.size(0)\n",
        "        n_seen += X.size(0)\n",
        "    return tot / n_seen if n_seen else float(\"nan\")\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def eval_accuracy(model, loader):\n",
        "    \"\"\"Share of samples whose arg-max logit matches the label (0.0 for an empty loader).\"\"\"\n",
        "    model.eval()\n",
        "    hits = seen = 0\n",
        "    for inputs, targets in loader:\n",
        "        inputs, targets = inputs.to(device), targets.to(device)\n",
        "        predicted = model(inputs).argmax(dim=1)\n",
        "        hits += (predicted == targets).sum().item()\n",
        "        seen += targets.size(0)\n",
        "    if not seen:\n",
        "        return 0.0\n",
        "    return hits / seen\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 6) Leave-one-subject-out driver (CSP+LDA & SimpleVNN)\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "def run_loso(data_dict: Dict[Any, Dict[str, np.ndarray]],\n",
        "             epochs=20, batch_size=64, lr=1e-3, val_split=0.1):\n",
        "    \"\"\"\n",
        "    Leave-one-subject-out comparison of CSP+LDA and SimpleVNN.\n",
        "\n",
        "    For every subject S: fit on all other subjects, test on S.\n",
        "      * CSP+LDA (only when NUM_CLASSES == 2): fitted on all training windows.\n",
        "      * SimpleVNN: filter = covariance of the training windows; a random\n",
        "        `val_split` fraction of them selects the best epoch (training loss\n",
        "        is used when that fraction is empty).\n",
        "    Prints per-fold results and a summary, and returns\n",
        "    {\"CSP+LDA\": [...], \"SimpleVNN\": [...]} with per-subject test accuracies\n",
        "    (CSP entries are nan when NUM_CLASSES != 2).\n",
        "    Raises ValueError if data_dict has fewer than 2 subjects.\n",
        "    \"\"\"\n",
        "\n",
        "    subjects = sorted(data_dict.keys())\n",
        "    if len(subjects) < 2:\n",
        "        raise ValueError(\"LOSO needs at least 2 subjects\")\n",
        "    res_csp, res_vnn = [], []\n",
        "\n",
        "    for idx, test_subj in enumerate(subjects, 1):\n",
        "        print(f\"\\n=== Fold {idx}/{len(subjects)}  (held-out: {test_subj}) ===\")\n",
        "\n",
        "        train_subj = [s for s in subjects if s != test_subj]\n",
        "        X_tr, y_tr = build_dataset_for_subjects(train_subj, data_dict)\n",
        "        X_te, y_te = build_dataset_for_subjects([test_subj], data_dict)\n",
        "\n",
        "        # ──── ░░░░  CSP + LDA  ░░░░ ──────────────────────────────────────\n",
        "        if NUM_CLASSES == 2:\n",
        "            W = fit_csp(X_tr, y_tr, N_CSP_COMPONENTS)\n",
        "            X_tr_csp = csp_transform(X_tr, W)\n",
        "            X_te_csp = csp_transform(X_te, W)\n",
        "\n",
        "            lda = LinearDiscriminantAnalysis()\n",
        "            lda.fit(X_tr_csp, y_tr)\n",
        "            acc_csp = lda.score(X_te_csp, y_te)\n",
        "            res_csp.append(acc_csp)\n",
        "            print(f\"CSP+LDA  acc {100*acc_csp:.2f}%\")\n",
        "        else:\n",
        "            res_csp.append(np.nan)\n",
        "\n",
        "        # ──── ░░░░  Simple VNN  ░░░░ ─────────────────────────────────────\n",
        "        full_ds = EEGDataset(X_tr, y_tr)\n",
        "        n_val   = int(val_split * len(full_ds))\n",
        "        n_train = len(full_ds) - n_val\n",
        "        tr_ds, val_ds = torch.utils.data.random_split(full_ds, [n_train, n_val])\n",
        "\n",
        "        # drop_last only if a full batch exists, else nothing would be trained\n",
        "        ld_tr  = torch.utils.data.DataLoader(tr_ds,  batch_size=batch_size,\n",
        "                                             shuffle=True,\n",
        "                                             drop_last=n_train >= batch_size)\n",
        "        ld_val = torch.utils.data.DataLoader(val_ds, batch_size=batch_size)\n",
        "        ld_te  = torch.utils.data.DataLoader(EEGDataset(X_te, y_te),\n",
        "                                             batch_size=batch_size)\n",
        "\n",
        "        C, T = X_tr.shape[1], X_tr.shape[2]\n",
        "        # float32 to match EEGDataset windows and the model weights; a float64\n",
        "        # covariance (from float64 windows) breaks the matmul in SimpleVNN\n",
        "        L = compute_global_covariance(\n",
        "            torch.tensor(X_tr, dtype=torch.float32, device=device))\n",
        "\n",
        "        vnn   = SimpleVNN(C, T, L, hidden_dim=HIDDEN_DIM_VNN,\n",
        "                          num_classes=NUM_CLASSES).to(device)\n",
        "        opt   = optim.Adam(vnn.parameters(), lr=lr)\n",
        "        crit  = nn.CrossEntropyLoss()\n",
        "\n",
        "        best_val, best_state = float(\"inf\"), None\n",
        "        for ep in range(1, epochs + 1):\n",
        "            tr_loss = train_one_epoch(vnn, ld_tr, opt, crit)\n",
        "            # no validation windows (val_split too small): select on train loss\n",
        "            val_loss = eval_loss(vnn, ld_val, crit) if n_val > 0 else tr_loss\n",
        "            if val_loss < best_val:\n",
        "                best_val, best_state = val_loss, copy.deepcopy(vnn.state_dict())\n",
        "            print(f\"VNN ep {ep:02d}/{epochs}  tr {tr_loss:.4f}  val {val_loss:.4f}\",\n",
        "                  end=\"\\r\")\n",
        "        # keep the final weights if no epoch improved on inf (epochs == 0 or\n",
        "        # NaN losses) rather than crashing on load_state_dict(None)\n",
        "        if best_state is not None:\n",
        "            vnn.load_state_dict(best_state)\n",
        "        acc_vnn = eval_accuracy(vnn, ld_te)\n",
        "        res_vnn.append(acc_vnn)\n",
        "        print(f\"SimpleVNN acc {100*acc_vnn:.2f}%   best val {best_val:.4f}\")\n",
        "\n",
        "    # ──── Summary ──────────────────────────────────────────────────────────\n",
        "    print(\"\\n================ SUMMARY ================\")\n",
        "    def pr(name, arr):\n",
        "        arr = np.asarray(arr, dtype=float)\n",
        "        print(f\"{name:9s}:  mean {100*arr.mean():5.2f}%   ±{100*arr.std():4.2f}\")\n",
        "    if NUM_CLASSES == 2:\n",
        "        pr(\"CSP+LDA\", res_csp)\n",
        "    pr(\"SimpleVNN\", res_vnn)\n",
        "    return {\"CSP+LDA\": res_csp, \"SimpleVNN\": res_vnn}\n",
        "\n",
        "\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "# 7) Entry-point\n",
        "# ─────────────────────────────────────────────────────────────────────────────\n",
        "if __name__ == \"__main__\":\n",
        "    # `data_dict` must already exist in the current namespace, e.g.\n",
        "    #   with open(\"windows.pkl\", \"rb\") as f:\n",
        "    #       data_dict = pickle.load(f)\n",
        "    # (only unpickle trusted files: pickle.load can execute arbitrary code)\n",
        "    run_loso(\n",
        "        data_dict,\n",
        "        val_split=0.1,\n",
        "        lr=1e-3,\n",
        "        batch_size=64,\n",
        "        epochs=20,\n",
        "    )\n"
      ],
      "metadata": {
        "id": "jlLzX_qIujoF"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}