{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 575
        },
        "id": "_FMSsWFKI8_3",
        "outputId": "3c138537-4dfb-438e-f54a-11773585594c"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== Horizon=3 | Seed=124 | Model=GraphVar ===\n",
            "Ep 001 | Train MSE=0.9704 | Val MSE=1.1609\n",
            "Ep 002 | Train MSE=0.9255 | Val MSE=1.1106\n",
            "Ep 003 | Train MSE=0.8836 | Val MSE=1.0612\n",
            "Ep 004 | Train MSE=0.8421 | Val MSE=1.0119\n",
            "Ep 005 | Train MSE=0.8010 | Val MSE=0.9636\n",
            "Ep 006 | Train MSE=0.7621 | Val MSE=0.9190\n",
            "Ep 007 | Train MSE=0.7279 | Val MSE=0.8803\n",
            "Ep 008 | Train MSE=0.6998 | Val MSE=0.8503\n",
            "Ep 009 | Train MSE=0.6784 | Val MSE=0.8281\n",
            "Ep 010 | Train MSE=0.6627 | Val MSE=0.8109\n",
            "Ep 011 | Train MSE=0.6512 | Val MSE=0.7986\n",
            "Ep 012 | Train MSE=0.6427 | Val MSE=0.7893\n",
            "Ep 013 | Train MSE=0.6362 | Val MSE=0.7827\n",
            "Ep 014 | Train MSE=0.6311 | Val MSE=0.7765\n",
            "Ep 015 | Train MSE=0.6268 | Val MSE=0.7710\n",
            "Ep 016 | Train MSE=0.6231 | Val MSE=0.7677\n",
            "Ep 017 | Train MSE=0.6197 | Val MSE=0.7634\n",
            "Ep 018 | Train MSE=0.6165 | Val MSE=0.7596\n",
            "Ep 019 | Train MSE=0.6135 | Val MSE=0.7563\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-790074751.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m    720\u001b[0m                     \u001b[0;32mfor\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mloaders\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    721\u001b[0m                         \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 722\u001b[0;31m                         \u001b[0macc\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    723\u001b[0m                 \u001b[0mval_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0macc\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mloaders\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'val'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    724\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ],
      "source": [
        "\n",
        "#GRAPH BASED FORECASTING TASKS\n",
        "import os, time, random, copy\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from google.colab import drive\n",
        "\n",
        "\n",
        "# 2) Hyperparameters\n",
        "T = 6\n",
        "horizons = [3, 6, 12]\n",
        "seeds = [124,14, 124235]\n",
        "batch_size =1280\n",
        "num_epochs = 200\n",
        "learning_rate = 1e-4\n",
        "hidden_dim = 128\n",
        "train_val_split = 0.8\n",
        "EPS = 1e-5\n",
        "epsilon = 1e-5\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# 3) Load data\n",
        "#df = pd.read_csv('LorenzCoupled_coarse.csv')\n",
        "#df = df.drop(columns=['time'], errors='ignore')\n",
        "#data = df.values.astype(np.float32)\n",
        "#data = pd.read_csv('PEMS-BAY.csv').select_dtypes(include=[np.number]).dropna().values  # shape: (N_raw, C)\n",
        "\n",
        "#means = data.mean(axis=0, keepdims=True)\n",
        "#stds = data.std(axis=0, keepdims=True)\n",
        "#data_z = (data - means) / (stds + 1e-8)\n",
        "#mu=data_z.mean(axis=1,keepdims=True)\n",
        "#sigma=data_z.std(axis=1,keepdims=True)\n",
        "#data_z = (data_z - mu) / (sigma + 1e-8)\n",
        "\n",
        "#data_z=data\n",
        "#N_raw, C = data_z.shape\n",
        "#W_C_Corr =torch.abs( torch.corrcoef(torch.tensor(data_z.T, dtype=torch.float32)))\n",
        "#W_C_Stable=torch.tensor(P)\n",
        "#W=W_C_Corr\n",
        "# 4) Graph utilities\n",
        "def graph_conv(x, A):\n",
        "    # x: (B, C, T)\n",
        "    # A: (B, C, C, T)  where A[:, i, j, t] is edge j->i at time t\n",
        "    return torch.einsum('bckt,bkt->bct', A, x)  # sum over k (neighbors)\n",
        "def graph_conv_bmm(x, A):\n",
        "    \"\"\"\n",
        "    x: (B, N, T)\n",
        "    A: (B, N, N, T)\n",
        "    \"\"\"\n",
        "    B, N, T = x.shape\n",
        "    # (B*T, N, N) @ (B*T, N, 1) -> (B*T, N, 1)\n",
        "    A_bt = A.permute(0, 3, 1, 2).contiguous().view(B*T, N, N)\n",
        "    x_bt = x.permute(0, 2, 1).contiguous().view(B*T, N, 1)\n",
        "    y_bt = torch.bmm(A_bt, x_bt).view(B, T, N)      # (B, T, N)\n",
        "    return y_bt.permute(0, 2, 1).contiguous()       # (B, N, T)\n",
        "\n",
        "def graph_variate(x, fun='sqd', Zave=True, eps=1e-5, alpha=None):\n",
        "    B, C_, T_ = x.shape\n",
        "    if Zave:\n",
        "        mu = x.mean(0, keepdim=True)\n",
        "        sig = x.std(0, keepdim=True, unbiased=True)\n",
        "        x = (x - mu) / (sig + eps)\n",
        "\n",
        "    if fun == 'sqd':\n",
        "        D = x-x.mean(1,keepdim=True)\n",
        "        corr_term = torch.abs(D.unsqueeze(2) * D.unsqueeze(1))\n",
        "        Om = ((x.unsqueeze(2) - x.unsqueeze(1)).pow(2))\n",
        "\n",
        "    elif fun == 'corr':\n",
        "\n",
        "        D = x#-x.mean(1,keepdim=True)\n",
        "        corr_term = torch.abs(D.unsqueeze(2) * D.unsqueeze(1))\n",
        "        Om= corr_term\n",
        "\n",
        "    elif fun == 'rbf':\n",
        "        diff2 = (x.unsqueeze(2) - x.unsqueeze(1)).pow(2)\n",
        "        Om = torch.exp(-1 * diff2)\n",
        "\n",
        "    elif fun == 'full':\n",
        "        Om = torch.ones(B, C_, C_, T_, device=x.device)\n",
        "\n",
        "    else:\n",
        "        raise ValueError(fun)\n",
        "\n",
        "    return Om\n",
        "def renormalize_dynamic(A, eps=EPS):\n",
        "    # A: (B, C, C, T)\n",
        "    I = torch.eye(A.size(1), device=A.device)[None, :, :, None]\n",
        "    At = A + I\n",
        "    deg = At.sum(2, keepdim=True)\n",
        "    inv = deg.clamp(min=eps).pow(-0.5)\n",
        "    return inv * At * inv.transpose(1, 2)\n",
        "# ----------------- long-term stats (per-sample, over time) -----------------\n",
        "def long_term_sqd(x):\n",
        "    \"\"\"\n",
        "    Long-term average squared difference per sample (over time), per pair (i,j):\n",
        "      mean_t ( (x_i(t) - x_j(t))^2 ), after demeaning each channel over time.\n",
        "    x: (B, C, T)\n",
        "    returns: (B, C, C, 1) nonnegative\n",
        "    \"\"\"\n",
        "    x_d = x - x.mean(dim=2, keepdim=True)                         # (B, C, T)\n",
        "    diff2 = (x_d.unsqueeze(2) - x_d.unsqueeze(1)).pow(2)          # (B, C, C, T)\n",
        "    Om = diff2.mean(dim=3, keepdim=True)                           # (B, C, C, 1)\n",
        "    return Om\n",
        "\n",
        "def long_term_corr_abs(x):\n",
        "    \"\"\"\n",
        "    Long-term (over time) Pearson correlation per sample, abs-valued:\n",
        "      corr[i,j] = | <z_i, z_j> / T |, where z_i is per-channel z-score over time.\n",
        "    x: (B, C, T)\n",
        "    returns: (B, C, C, 1) in [0,1], diagonal ~1 (we zero it before renorm)\n",
        "    \"\"\"\n",
        "    B, C, T = x.shape\n",
        "    mu_t = x.mean(dim=2, keepdim=True)\n",
        "    sd_t = x.std(dim=2, keepdim=True) + EPS\n",
        "    z = (x - mu_t) / sd_t                                         # (B, C, T)\n",
        "    corr = torch.einsum('bct,bdt->bcd', z, z) / float(T)          # (B, C, C)\n",
        "    corr = corr.abs()\n",
        "    eye = torch.eye(C, device=x.device).unsqueeze(0)\n",
        "    corr = corr * (1.0 - eye)\n",
        "    return corr.unsqueeze(-1)                                      # (B, C, C, 1)\n",
        "\n",
        "\n",
        "# 5) Models — all produce (B, C, H)\n",
        "class GraphVarForecastModel(nn.Module):\n",
        "    def __init__(self, C, T, W_C, H, hidden_dim=128, trainable_W_C=False,\n",
        "                 dynamic_W: str = 'const',\n",
        "                 ltcorr_from_z: bool = True\n",
        "                 ):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.dynamic_W = dynamic_W\n",
        "        self.ltcorr_from_z = ltcorr_from_z\n",
        "\n",
        "\n",
        "        self.register_buffer('W_C', W_C)\n",
        "\n",
        "        self.g_conv   = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip   = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip_2 = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv_2 = nn.Parameter(torch.ones(T))\n",
        "\n",
        "        self.theta = nn.Linear(T, T, bias=False)\n",
        "        self.beta  = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        self.fun = 'corr'\n",
        "        self.alpha = torch.tensor(0.5)\n",
        "\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "\n",
        "    def _build_W_eff(self, x, x0):\n",
        "        \"\"\"\n",
        "        Returns W_eff with shape:\n",
        "          - (1, C, C, 1) if dynamic_W == 'const'\n",
        "          - (B, C, C, 1) if dynamic_W == 'ltcorr'\n",
        "        \"\"\"\n",
        "\n",
        "        return self.W_C.unsqueeze(0).unsqueeze(-1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True)\n",
        "        sig = x.std(1, keepdim=True)\n",
        "        x0  = (x - mu) / (sig + EPS)\n",
        "\n",
        "        W_eff = self._build_W_eff(x, x0)  # (B or 1, C, C, 1)\n",
        "\n",
        "        Om1 = graph_variate(x0, fun='corr', Zave=False, eps=EPS)           # (B, C, C, T)\n",
        "        A1  = renormalize_dynamic(Om1 * W_eff)                               # broadcast over T\n",
        "\n",
        "        h1 = self.theta(\n",
        "            self.g_skip.view(1, 1, self.T) * x0 +\n",
        "            self.g_conv.view(1, 1, self.T) * graph_conv_bmm(x0, A1)\n",
        "        )\n",
        "\n",
        "        Om2 = graph_variate(h1, fun='sqd', Zave=False, eps=EPS)             # (B, C, C, T)\n",
        "        A2  = renormalize_dynamic(Om2 * W_eff)\n",
        "\n",
        "        h2 = self.beta(\n",
        "            self.g_skip_2.view(1, 1, self.T) * h1 +\n",
        "            self.g_conv_2.view(1, 1, self.T) * graph_conv(h1, A2)\n",
        "        )\n",
        "\n",
        "\n",
        "\n",
        "        return self.mlp(h2)\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "class CPGraphSTForecastModel(nn.Module):\n",
        "    \"\"\"\n",
        "    3-layer CP-style space-time graph model (Kronecker-sum diffusion each layer):\n",
        "      y1 = theta( g_skip1 ⊙ x0 + g_conv1 ⊙ G(x0) )\n",
        "      y2 =  beta( g_skip2 ⊙ y1 + g_conv2 ⊙ G(y1) )\n",
        "      y3 = gamma( g_skip3 ⊙ y2 + g_conv3 ⊙ G(y2) )\n",
        "      out = MLP(y3)\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, W_C, H, hidden_dim=128):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "\n",
        "        self.register_buffer('W_C', W_C)\n",
        "        self.register_buffer('_I_C', torch.eye(C))\n",
        "        self.register_buffer('_I_T', torch.eye(T))\n",
        "\n",
        "        A_time = torch.zeros(T, T)\n",
        "        for i in range(T - 1):\n",
        "            A_time[i, i + 1] = 1.0\n",
        "            A_time[i + 1, i] = 1.0\n",
        "        self.register_buffer('A_time', A_time)\n",
        "\n",
        "        # Layer 1\n",
        "        self.g_skip1 = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv1 = nn.Parameter(torch.ones(T))\n",
        "        self.theta   = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        # Layer 2\n",
        "        self.g_skip2 = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv2 = nn.Parameter(torch.ones(T))\n",
        "        self.beta    = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        # Layer 3\n",
        "        self.g_skip3 = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv3 = nn.Parameter(torch.ones(T))\n",
        "        self.gamma   = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H),\n",
        "        )\n",
        "\n",
        "    def _A_norm(self, device):\n",
        "        # (CT x CT) Kronecker-sum adjacency, add I, symmetric degree norm\n",
        "        A_ts = torch.kron(self.A_time, self._I_C)   # time ⊗ I_C\n",
        "        A_st = torch.kron(self._I_T, self.W_C)      # I_T ⊗ space\n",
        "        A = (A_ts + A_st).to(device)\n",
        "        At = A + torch.eye(A.size(0), device=device)\n",
        "        inv = At.sum(1).clamp(min=EPS).pow(-0.5)\n",
        "        return inv.unsqueeze(1) * At * inv.unsqueeze(0)  # (CT, CT)\n",
        "\n",
        "    def _diffuse(self, X, A_norm):\n",
        "        # X: (B, C, T) → (B, CT) @ (CT, CT) → (B, CT) → (B, C, T)\n",
        "        B, C, T = X.shape\n",
        "        flat = X.reshape(B, C * T)\n",
        "        out = (flat @ A_norm.T).reshape(B, C, T)\n",
        "        return out\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + epsilon)\n",
        "\n",
        "        A_norm = self._A_norm(x.device)\n",
        "\n",
        "        # Layer 1\n",
        "        conv1 = self._diffuse(x0, A_norm)\n",
        "        y1 = self.theta(self.g_skip1.view(1,1,self.T) * x0 +\n",
        "                        self.g_conv1.view(1,1,self.T) * conv1)\n",
        "\n",
        "        # Layer 2\n",
        "        conv2 = self._diffuse(y1, A_norm)\n",
        "        y2 = self.beta(self.g_skip2.view(1,1,self.T) * y1 +\n",
        "                       self.g_conv2.view(1,1,self.T) * conv2)\n",
        "\n",
        "\n",
        "        return self.mlp(y2)  # (B, C, H)\n",
        "\n",
        "\n",
        "class GVARMAForecastModel(nn.Module):\n",
        "\n",
        "    def __init__(self, C, T, W_C, H, hidden_dim=128, trainable_W_C=False):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.P=1\n",
        "        self.Q=1\n",
        "        self.K=2\n",
        "        P=self.P\n",
        "        Q=self.Q\n",
        "        K=self.K\n",
        "        self.register_buffer('W_C', W_C)\n",
        "\n",
        "        W_init = W_C.detach().float()\n",
        "        if trainable_W_C:\n",
        "            self.W_C = nn.Parameter(W_init)\n",
        "        else:\n",
        "            self.register_buffer(\"W_C\", W_init)\n",
        "\n",
        "        # AR coefficients a_{p,k}, small init for stability\n",
        "        self.a = nn.Parameter(0.05 * torch.randn(P, K + 1)) if P > 0 else nn.Parameter(torch.zeros(0, K + 1))\n",
        "        # MA coefficients b_{q,k}\n",
        "        self.b = nn.Parameter(torch.zeros(Q, K + 1)) if Q > 0 else nn.Parameter(torch.zeros(0, K + 1))\n",
        "\n",
        "        self.ar_scale = nn.Parameter(torch.tensor(1.0))\n",
        "        self.ma_scale = nn.Parameter(torch.tensor(1.0))\n",
        "\n",
        "    def _norm_shift(self, device):\n",
        "        \"\"\"Symmetric degree normalization: S = D^{-1/2}(W + I)D^{-1/2}.\"\"\"\n",
        "        W = self.W_C\n",
        "\n",
        "        At = W.to(device) + torch.eye(self.C, device=device, dtype=W.dtype)\n",
        "        deg = At.sum(1).clamp(min=EPS)\n",
        "        inv = deg.pow(-0.5)\n",
        "        return inv.unsqueeze(1) * At * inv.unsqueeze(0)  # (C, C)\n",
        "\n",
        "    def _S_powers(self, S):\n",
        "        \"\"\"Return [I, S, S^2, ..., S^K].\"\"\"\n",
        "        C = S.size(0)\n",
        "        S_list = [torch.eye(C, device=S.device, dtype=S.dtype)]\n",
        "        for _ in range(self.K):\n",
        "            S_list.append(S @ S_list[-1])\n",
        "        return S_list  # length K+1\n",
        "\n",
        "    def _build_poly_mats(self, S_pows):\n",
        "        \"\"\"\n",
        "        A_mats[p] = sum_k a_{p,k} S^k, shape (P, C, C)\n",
        "        B_mats[q] = sum_k b_{q,k} S^k, shape (Q, C, C)\n",
        "        \"\"\"\n",
        "        if len(S_pows) == 1:\n",
        "            Sp = S_pows[0].unsqueeze(0)                # (1, C, C)\n",
        "        else:\n",
        "            Sp = torch.stack(S_pows, dim=0)            # (K+1, C, C)\n",
        "\n",
        "        if self.P > 0:\n",
        "            A_mats = torch.einsum('pk,kij->pij', self.a, Sp)   # (P, C, C)\n",
        "        else:\n",
        "            A_mats = torch.zeros(0, self.C, self.C, device=Sp.device, dtype=Sp.dtype)\n",
        "\n",
        "        if self.Q > 0:\n",
        "            B_mats = torch.einsum('qk,kij->qij', self.b, Sp)   # (Q, C, C)\n",
        "        else:\n",
        "            B_mats = torch.zeros(0, self.C, self.C, device=Sp.device, dtype=Sp.dtype)\n",
        "\n",
        "        return A_mats, B_mats\n",
        "\n",
        "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
        "        B, C, T = x.shape\n",
        "        assert C == self.C and T == self.T\n",
        "\n",
        "    # z-score over time\n",
        "        EPS = 1e-5\n",
        "\n",
        "        mu  = x.mean(1, keepdim=True)\n",
        "        sig = x.std(1, keepdim=True, unbiased=False)  # avoid NaNs when T==1\n",
        "        x0  = (x - mu) / (sig + EPS)\n",
        "\n",
        "\n",
        "    # graph stuff\n",
        "        S       = self._norm_shift(x.device)\n",
        "        S_pows  = self._S_powers(S)\n",
        "        A_mats, B_mats = self._build_poly_mats(S_pows)\n",
        "\n",
        "        Y_list = [x0[:, :, t] for t in range(T)]      # each (B, C)\n",
        "        E_list = []\n",
        "\n",
        "        for t in range(T):\n",
        "            yhat = x.new_zeros(B, C)\n",
        "            for p in range(1, self.P + 1):\n",
        "                tp = t - p\n",
        "                if tp >= 0:\n",
        "                    yhat = yhat + (Y_list[tp] @ A_mats[p-1].T)\n",
        "            for q in range(1, self.Q + 1):\n",
        "                tq = t - q\n",
        "                if tq >= 0:\n",
        "                    yhat = yhat + (E_list[tq] @ B_mats[q-1].T)\n",
        "            e_t = Y_list[t] - (self.ar_scale * yhat)\n",
        "            E_list.append(e_t)\n",
        "\n",
        "    # forecast y_T..y_{T+H-1} (future innovations = 0)\n",
        "        for t in range(T, T + self.H):\n",
        "            yhat = x.new_zeros(B, C)\n",
        "            for p in range(1, self.P + 1):\n",
        "                tp = t - p\n",
        "                if tp >= 0:\n",
        "                    yhat = yhat + (Y_list[tp] @ A_mats[p-1].T)\n",
        "            for q in range(1, self.Q + 1):\n",
        "                tq = t - q\n",
        "                if 0 <= tq < T:\n",
        "                    yhat = yhat + (E_list[tq] @ B_mats[q-1].T)\n",
        "            y_t = self.ar_scale * yhat\n",
        "            Y_list.append(y_t)\n",
        "\n",
        "        Yf = torch.stack(Y_list[T:T + self.H], dim=-1)  # (B, C, H)\n",
        "        return Yf\n",
        "\n",
        "class GGRNNForecastModel(nn.Module):\n",
        "    \"\"\"\n",
        "    SAME MODEL AS YOURS, but we run the EXACT SAME per-time GRNN-like\n",
        "    convolution TWICE (stacked). No extra gates; we keep the original\n",
        "    skip/conv fusion: skip = g_skip ⊙ x0, conv = g_conv ⊙ h2_seq.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, W_C, H, hidden_dim=128, trainable_W_C=False):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        if trainable_W_C:\n",
        "            self.W_C = nn.Parameter(W_C)\n",
        "        else:\n",
        "            self.register_buffer('W_C', W_C)\n",
        "        self.g_skip = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv = nn.Parameter(torch.ones(T))\n",
        "        self.mlp = nn.Sequential(nn.Linear(T, hidden_dim), nn.Tanh(),\n",
        "                                 nn.Linear(hidden_dim, H))\n",
        "\n",
        "    def _A_norm(self, x):\n",
        "        At_mat = self.W_C + torch.eye(self.C, device=x.device)\n",
        "        inv = At_mat.sum(1).clamp(min=EPS).pow(-0.5)\n",
        "        return inv.unsqueeze(1)*At_mat*inv.unsqueeze(0)\n",
        "\n",
        "    def _grnn_block(self, x_seq, A_norm):\n",
        "        \"\"\"\n",
        "        EXACT block you wrote:\n",
        "            for t in range(self.T):\n",
        "                x_t   = x0[:,:,t]\n",
        "                x_conv= x_t @ A_norm.T\n",
        "                h_conv= h   @ A_norm.T\n",
        "                z     = sigmoid(x_conv + h_conv)\n",
        "                r     = sigmoid(x_conv + h_conv)\n",
        "                h     = (1-z)*h + z*tanh(x_conv + r*h_conv)\n",
        "                seq.append(h.unsqueeze(-1))\n",
        "        \"\"\"\n",
        "        B, C, T = x_seq.shape\n",
        "        h = torch.zeros(B, C, device=x_seq.device)\n",
        "        seq = []\n",
        "        for t in range(T):\n",
        "            x_t   = x_seq[:,:,t]\n",
        "            x_conv= x_t @ A_norm.T\n",
        "            h_conv= h   @ A_norm.T\n",
        "            z = torch.sigmoid(x_conv + h_conv)\n",
        "            r = torch.sigmoid(x_conv + h_conv)\n",
        "            h = (1 - z) * h + z * torch.tanh(x_conv + r * h_conv)\n",
        "            seq.append(h.unsqueeze(-1))\n",
        "        return torch.cat(seq, dim=-1)  # (B, C, T)\n",
        "\n",
        "    def forward(self, x):\n",
        "        B = x.size(0)\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + epsilon)\n",
        "\n",
        "        A_norm = self._A_norm(x)\n",
        "        skip = self.g_skip.view(1,1,self.T) * x0\n",
        "        h1_seq = skip+ self._grnn_block(x0, A_norm)\n",
        "\n",
        "        h2_seq =h1_seq+ self._grnn_block(h1_seq, A_norm)\n",
        "\n",
        "\n",
        "        conv = self.g_conv.view(1,1,self.T) * h2_seq\n",
        "        feats =  conv\n",
        "        return self.mlp(feats)  # (B, C, H)\n",
        "\n",
        "# 6) Training and evaluation\n",
        "results = {h:{} for h in horizons}\n",
        "\n",
        "for H in horizons:\n",
        "    Ns = N_raw - T - H + 1\n",
        "\n",
        "    idx_tv = int(0.8 * Ns)\n",
        "    idx_tr = int(train_val_split * idx_tv)\n",
        "    end_train_raw_ix = idx_tr + T + H - 1\n",
        "    train_slice = data[:end_train_raw_ix + 1]\n",
        "    means = train_slice.mean(axis=0, keepdims=True)\n",
        "    stds  = train_slice.std(axis=0, keepdims=True)\n",
        "    data_z = (data - means) / (stds + 1e-8)\n",
        "\n",
        "    X = np.stack([data_z[i:i+T].T for i in range(Ns)])          # (N, C, T)\n",
        "    y = np.stack([data_z[i+T:i+T+H].T for i in range(Ns)])      # (N, C, H)\n",
        "\n",
        "    idx_tv = int(0.8 * Ns)\n",
        "    idx_tr = int(train_val_split * idx_tv)\n",
        "\n",
        "    splits = {\n",
        "        'train': (X[:idx_tr], y[:idx_tr]),\n",
        "        'val':   (X[idx_tr:idx_tv], y[idx_tr:idx_tv]),\n",
        "        'test':  (X[idx_tv:], y[idx_tv:])\n",
        "    }\n",
        "\n",
        "    W = torch.corrcoef(torch.tensor(data_z.T, dtype=torch.float32))\n",
        "\n",
        "    for seed in seeds:\n",
        "        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)\n",
        "        torch.backends.cudnn.deterministic=True; torch.backends.cudnn.benchmark=False\n",
        "        gen = torch.Generator().manual_seed(seed)\n",
        "\n",
        "        loaders = {}\n",
        "        for sp in ['train','val','test']:\n",
        "            Xi, yi = splits[sp]\n",
        "            loaders[sp] = DataLoader(\n",
        "                TensorDataset(torch.from_numpy(Xi).float(), torch.from_numpy(yi).float()),\n",
        "                batch_size=batch_size, shuffle=(sp=='train'), generator=gen if sp=='train' else None\n",
        "            )\n",
        "\n",
        "        model_configs = [\n",
        "            (\"GraphVar\", GraphVarForecastModel, W),\n",
        "            #(\"CPGraphST\", CPGraphSTForecastModel, W),\n",
        "            #(\"GVARMA\", GVARMAForecastModel, W),\n",
        "            #(\"GGRNN\", GGRNNForecastModel, W),\n",
        "        ]\n",
        "\n",
        "        for name, Model, Wc in model_configs:\n",
        "            if 'trainable_W_C' in Model.__init__.__code__.co_varnames:\n",
        "                model = Model(C, T, Wc, H, hidden_dim, trainable_W_C=False).to(device)\n",
        "            else:\n",
        "                model = Model(C, T, Wc, H, hidden_dim).to(device)\n",
        "\n",
        "            opt = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
        "            loss_fn = nn.MSELoss()\n",
        "            best_val, best_state, best_ep = float('inf'), None, 0\n",
        "            t0 = time.time()\n",
        "\n",
        "            print(f\"\\n=== Horizon={H} | Seed={seed} | Model={name} ===\")\n",
        "            for ep in range(1, num_epochs+1):\n",
        "                model.train()\n",
        "                acc = 0\n",
        "                for xb, yb in loaders['train']:\n",
        "                    xb, yb = xb.to(device), yb.to(device)\n",
        "                    opt.zero_grad()\n",
        "                    pred = model(xb)\n",
        "                    loss = loss_fn(pred, yb)\n",
        "                    loss.backward(); opt.step()\n",
        "                    acc += loss.item() * xb.size(0)\n",
        "                train_loss = acc / len(loaders['train'].dataset)\n",
        "\n",
        "                model.eval()\n",
        "                acc = 0\n",
        "                with torch.no_grad():\n",
        "                    for xb, yb in loaders['val']:\n",
        "                        xb, yb = xb.to(device), yb.to(device)\n",
        "                        acc += loss_fn(model(xb), yb).item() * xb.size(0)\n",
        "                val_loss = acc / len(loaders['val'].dataset)\n",
        "\n",
        "                print(f\"Ep {ep:03d} | Train MSE={train_loss:.4f} | Val MSE={val_loss:.4f}\")\n",
        "                if val_loss < best_val:\n",
        "                    best_val, best_state, best_ep = val_loss, copy.deepcopy(model.state_dict()), ep\n",
        "\n",
        "            model.load_state_dict(best_state)\n",
        "            acc = 0\n",
        "            with torch.no_grad():\n",
        "                for xb, yb in loaders['test']:\n",
        "                    xb, yb = xb.to(device), yb.to(device)\n",
        "                    acc += loss_fn(model(xb), yb).item() * xb.size(0)\n",
        "            test_loss = acc / len(loaders['test'].dataset)\n",
        "            duration = time.time() - t0\n",
        "\n",
        "            print(f\"--> Best Val Ep {best_ep} | Val={best_val:.4f} | Test MSE={test_loss:.4f} | Time={duration:.1f}s\")\n",
        "            results[H].setdefault(name, []).append(test_loss)\n",
        "\n",
        "# 7) Summary Table\n",
        "print(\"\\n===== Final Test MSE Summary =====\")\n",
        "for H in sorted(results):\n",
        "    print(f\"\\nHorizon {H}:\")\n",
        "    for name, vals in results[H].items():\n",
        "        mean = np.mean(vals); sd = np.std(vals, ddof=1)\n",
        "        print(f\"  {name}: {mean:.4f} ± {sd:.4f}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Sequence Models\n",
        "\n",
        "import os, time, random, copy\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "\n",
        "try:\n",
        "    from google.colab import drive  # type: ignore\n",
        "    drive.mount('/content/drive', force_remount=False)\n",
        "except Exception:\n",
        "    pass\n",
        "\n",
        "# -----------------------------\n",
        "# 1) Hyperparameters\n",
        "# -----------------------------\n",
        "T = 6\n",
        "horizons = [3, 6, 12]\n",
        "seeds = [124, 14, 124235]\n",
        "batch_size = 1280\n",
        "num_epochs = 200\n",
        "learning_rate = 1e-3\n",
        "hidden_dim = 128\n",
        "train_val_split = 0.8\n",
        "EPS = 1e-5\n",
        "epsilon = 1e-5\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "\n",
        "# -----------------------------\n",
        "# 2) Load data\n",
        "# -----------------------------\n",
        "#df = pd.read_csv('MacArthur_coarse.csv')\n",
        "#df = df.drop(columns=['time'], errors='ignore')\n",
        "#data = df.values.astype(np.float32)\n",
        "data = pd.read_csv('PEMS-BAY.csv').select_dtypes(include=[np.number]).dropna().values  # shape: (N_raw, C)\n",
        "\n",
        "#means = data.mean(axis=0, keepdims=True)\n",
        "#stds = data.std(axis=0, keepdims=True)\n",
        "#data_z = (data - means) / (stds + 1e-8)\n",
        "N_raw, C = data.shape\n",
        "\n",
        "# Static channel similarity graph (C x C) from correlations (used by GVNN only)\n",
        "#W = torch.corrcoef(torch.tensor(data_z.T, dtype=torch.float32))\n",
        "\n",
        "# -----------------------------\n",
        "# 3) Graph utilities (for GVNN)\n",
        "# -----------------------------\n",
        "def renormalize_dynamic(A, eps=EPS):\n",
        "    \"\"\"\n",
        "    A: (B, C, C, T) dynamic affinity\n",
        "    Returns symmetric renorm:  D^{-1/2} (A + I) D^{-1/2}\n",
        "    \"\"\"\n",
        "    I = torch.eye(A.size(1), device=A.device)[None, :, :, None]   # (1, C, C, 1)\n",
        "    At = A + I\n",
        "    deg = At.sum(2, keepdim=True)                                 # (B, C, 1, T)\n",
        "    inv = deg.clamp(min=eps).pow(-0.5)\n",
        "    S = inv * At * inv.transpose(1, 2)                            # symmetric renorm\n",
        "    return S\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "def graph_variate(x, fun='corr', Zave=True, eps=EPS):\n",
        "    \"\"\"\n",
        "    x: (B, C, T)\n",
        "    returns normalized dynamic adjacency Om: (B, C, C, T)\n",
        "    \"\"\"\n",
        "    B, C_, T_ = x.shape\n",
        "    if Zave:\n",
        "        mu = x.mean(1, keepdim=True)\n",
        "        sig = x.std(1, keepdim=True, unbiased=True)\n",
        "        x = (x - mu) / (sig + eps)\n",
        "\n",
        "    if fun == 'sqd':\n",
        "        D = x-x.mean(1,keepdim=True)\n",
        "        Om = (x.unsqueeze(2) - x.unsqueeze(1)).pow(2)# + (D.unsqueeze(2) * D.unsqueeze(1))\n",
        "    elif fun == 'corr':\n",
        "        D = x#-x.mean(2,keepdim=True)\n",
        "        Om = torch.abs(D.unsqueeze(2) * D.unsqueeze(1)) #+ (x.unsqueeze(2) - x.unsqueeze(1)).pow(2)\n",
        "    elif fun == 'rbf':\n",
        "        diff2 = (x.unsqueeze(2) - x.unsqueeze(1)).pow(2)\n",
        "        Om = torch.exp(-0.1 * diff2)\n",
        "    elif fun == 'full':\n",
        "        Om = torch.ones(B, C_, C_, T_, device=x.device)\n",
        "    else:\n",
        "        raise ValueError(fun)\n",
        "\n",
        "    return (Om)\n",
        "\n",
        "def graph_conv(x, Om):\n",
        "    \"\"\"\n",
        "    x:  (B, C, T)\n",
        "    Om: (B, C, C, T) dynamic normalized adjacency\n",
        "    returns: (B, C, T)\n",
        "    \"\"\"\n",
        "    Om_t = Om.permute(0, 3, 1, 2)               # (B, T, C, C)\n",
        "    sig_t = x.permute(0, 2, 1).unsqueeze(-1)    # (B, T, C, 1)\n",
        "    out = torch.matmul(Om_t, sig_t).squeeze(-1) # (B, T, C)\n",
        "    return out.permute(0, 2, 1)                 # (B, C, T)\n",
        "\n",
        "def graph_conv_bmm(x, A):\n",
        "    \"\"\"\n",
        "    x: (B, N, T)\n",
        "    A: (B, N, N, T)\n",
        "    \"\"\"\n",
        "    B, N, T = x.shape\n",
        "    # (B*T, N, N) @ (B*T, N, 1) -> (B*T, N, 1)\n",
        "    A_bt = A.permute(0, 3, 1, 2).contiguous().view(B*T, N, N)\n",
        "    x_bt = x.permute(0, 2, 1).contiguous().view(B*T, N, 1)\n",
        "    y_bt = torch.bmm(A_bt, x_bt).view(B, T, N)      # (B, T, N)\n",
        "    return y_bt.permute(0, 2, 1).contiguous()       # (B, N, T)\n",
        "\n",
        "# -----------------------------\n",
        "# 4) Models — all produce (B, C, H)\n",
        "# -----------------------------\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "\n",
        "class GVNN(nn.Module):\n",
        "    \"\"\"\n",
        "    Graph-Variate Neural Network (GVNN).\n",
        "    Positivity-only supports via softplus; optional symmetry + zero diagonal.\n",
        "    If return_features=True, returns the pre-MLP features (B, C, T) instead of final predictions.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, W_C, H, hidden_dim=128, trainable_W_C=True,\n",
        "                 fun='corr', return_features=False, pos_eps=1e-6,\n",
        "                 symmetrize=True, zero_diag=True):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.return_features = return_features\n",
        "        self.fun = fun\n",
        "        self.pos_eps = pos_eps\n",
        "        self.symmetrize = symmetrize\n",
        "        self.zero_diag = zero_diag\n",
        "\n",
        "        if trainable_W_C:\n",
        "            self.W_raw  = nn.Parameter(W_C.clone().detach())    # stage 1\n",
        "            self.W1_raw = nn.Parameter(W_C.clone().detach())    # stage 2\n",
        "        else:\n",
        "            self.register_buffer('W_raw',  W_C.clone().detach())\n",
        "            self.register_buffer('W1_raw', W_C.clone().detach())\n",
        "\n",
        "        eye = torch.eye(C, dtype=torch.bool)\n",
        "        self.register_buffer('offdiag_mask', ~eye)\n",
        "\n",
        "        self.g_conv  = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv1 = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip  = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip1 = nn.Parameter(torch.ones(T))\n",
        "\n",
        "        self.theta  = nn.Linear(T, T)\n",
        "        self.theta1 = nn.Linear(T, T)\n",
        "        self.act = nn.LeakyReLU(0.2)\n",
        "\n",
        "        self.mlp = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "\n",
        "    @staticmethod\n",
        "    def _softplus_pos(X, eps=1e-6):\n",
        "        return F.softplus(X) + eps\n",
        "\n",
        "    def _make_positive_support(self, W_raw):\n",
        "        W = self._softplus_pos(W_raw, self.pos_eps)\n",
        "\n",
        "        if self.symmetrize:\n",
        "            W = 0.5 * (W + W.transpose(0, 1))\n",
        "\n",
        "        if self.zero_diag:\n",
        "            W = W * self.offdiag_mask\n",
        "\n",
        "        return W\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (B, C, T)\n",
        "        mu = x.mean(1, keepdim=True)\n",
        "        sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + epsilon)\n",
        "\n",
        "        # --- Stage 1: corr-based operator with positive support ---\n",
        "        Wpos = self._make_positive_support(self.W_raw)  # (C, C), positive\n",
        "        Om   = graph_variate(x0, fun='sqd', Zave=False, eps=EPS)  # (B, C, C, T)\n",
        "        Om   = renormalize_dynamic(Om * Wpos.unsqueeze(0).unsqueeze(-1))\n",
        "\n",
        "        skip = self.g_skip.view(1, 1, self.T) * x0\n",
        "        conv = self.g_conv.view(1, 1, self.T) * graph_conv(x0, Om)\n",
        "        feats = self.act(self.theta(skip + conv))  # (B, C, T)\n",
        "\n",
        "        # --- Stage 2: sqd-based operator with positive support ---\n",
        "        Wpos1 = self._make_positive_support(self.W1_raw)\n",
        "        Om1   = graph_variate(feats, fun='sqd', Zave=False, eps=EPS)\n",
        "        Om1   = renormalize_dynamic(Om1 * Wpos1.unsqueeze(0).unsqueeze(-1))\n",
        "\n",
        "        final = self.act(\n",
        "            self.theta1(\n",
        "                self.g_skip1.view(1, 1, self.T) * feats +\n",
        "                self.g_conv1.view(1, 1, self.T) * graph_conv(feats, Om1)\n",
        "            )\n",
        "        )\n",
        "\n",
        "        if self.return_features:\n",
        "            return feats\n",
        "        return self.mlp(final)  # (B, C, H)\n",
        "\n",
        "class LSTMForecaster(nn.Module):\n",
        "\n",
        "    def __init__(self, C, T, H, hidden_dim=128, num_layers=2, bidir=False):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.lstm = nn.LSTM(\n",
        "            input_size=C,\n",
        "            hidden_size=hidden_dim,\n",
        "            num_layers=num_layers,\n",
        "            batch_first=True,\n",
        "            bidirectional=bidir\n",
        "        )\n",
        "        lstm_out_dim = hidden_dim * (2 if bidir else 1)\n",
        "        self.head = nn.Sequential(\n",
        "            nn.Linear(lstm_out_dim, lstm_out_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(lstm_out_dim, C * H)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (B, C, T) -> LSTM over (B, T, C)\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + epsilon)\n",
        "        x_seq = x0.permute(0, 2, 1)        # (B, T, C)\n",
        "        out, _ = self.lstm(x_seq)          # (B, T, hidden)\n",
        "        h_last = out[:, -1, :]             # (B, hidden)\n",
        "        y = self.head(h_last)              # (B, C*H)\n",
        "        return y.view(-1, self.C, self.H)  # (B, C, H)\n",
        "\n",
        "# -----------------------------\n",
        "# 5) GAT (Fully Connected, Time-as-Features)\n",
        "# -----------------------------\n",
        "\n",
        "# -----------------------------\n",
        "# 5c) Transformer (single layer, 1 head)\n",
        "# -----------------------------\n",
        "class SinusoidalPositionalEncoding(nn.Module):\n",
        "\n",
        "    def __init__(self, d_model, max_len=10000):\n",
        "        super().__init__()\n",
        "        pe = torch.zeros(max_len, d_model)\n",
        "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
        "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))\n",
        "        pe[:, 0::2] = torch.sin(position * div_term)\n",
        "        pe[:, 1::2] = torch.cos(position * div_term)\n",
        "        pe = pe.unsqueeze(0)  # (1, max_len, d_model)\n",
        "        self.register_buffer('pe', pe, persistent=False)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (B, T, d_model)\n",
        "        T = x.size(1)\n",
        "        return self.pe[:, :T, :]\n",
        "\n",
        "class TransformerForecaster(nn.Module):\n",
        "\n",
        "    def __init__(self, C, T, H, d_model=128, dim_ff=256, dropout=0.0):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.embed = nn.Linear(C, d_model, bias=False)\n",
        "        self.posenc = SinusoidalPositionalEncoding(d_model, max_len=10000)\n",
        "        self.encoder = nn.TransformerEncoder(\n",
        "            nn.TransformerEncoderLayer(\n",
        "                d_model=d_model, nhead=1, dim_feedforward=dim_ff,\n",
        "                dropout=dropout, activation='gelu', batch_first=True, norm_first=True\n",
        "            ),\n",
        "            num_layers=2\n",
        "        )\n",
        "        self.act = nn.LeakyReLU(0.2)\n",
        "        self.head = nn.Sequential(\n",
        "            nn.Linear(d_model, d_model),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(d_model, C * H)\n",
        "        )\n",
        "\n",
        "    def forward(self, x):\n",
        "        # Normalize like the others\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + epsilon)\n",
        "\n",
        "        # (B, C, T) -> (B, T, C)\n",
        "        seq = x0.permute(0, 2, 1)                 # (B, T, C)\n",
        "        z = self.embed(seq)                       # (B, T, d_model)\n",
        "        z = z + self.posenc(z)                    # add positional encoding\n",
        "        z = self.encoder(z)                       # (B, T, d_model)\n",
        "        z = self.act(z[:, -1, :])                 # take last token (horizon readout)\n",
        "        y = self.head(z).view(-1, self.C, self.H) # (B, C, H)\n",
        "        return y\n",
        "\n",
        "# -----------------------------\n",
        "# 6) Training and evaluation\n",
        "# -----------------------------\n",
        "results = {h: {} for h in horizons}\n",
        "\n",
        "for H in horizons:\n",
        "    Ns = N_raw - T - H + 1\n",
        "\n",
        "    idx_tv = int(0.8 * Ns)\n",
        "    idx_tr = int(train_val_split * idx_tv)\n",
        "    end_train_raw_ix = idx_tr + T + H - 1                      # inclusive raw index used by train samples\n",
        "    train_slice = data[:end_train_raw_ix + 1]                   # rows [0 .. end_train_raw_ix]\n",
        "    means = train_slice.mean(axis=0, keepdims=True)\n",
        "    stds  = train_slice.std(axis=0, keepdims=True)\n",
        "    data_z = (data - means) / (stds + 1e-8)\n",
        "\n",
        "    X = np.stack([data_z[i:i+T].T for i in range(Ns)])          # (N, C, T)\n",
        "    y = np.stack([data_z[i+T:i+T+H].T for i in range(Ns)])      # (N, C, H)\n",
        "\n",
        "    idx_tv = int(0.8 * Ns)\n",
        "    idx_tr = int(train_val_split * idx_tv)\n",
        "\n",
        "    splits = {\n",
        "        'train': (X[:idx_tr], y[:idx_tr]),\n",
        "        'val':   (X[idx_tr:idx_tv], y[idx_tr:idx_tv]),\n",
        "        'test':  (X[idx_tv:], y[idx_tv:])\n",
        "    }\n",
        "\n",
        "    W = torch.corrcoef(torch.tensor(data_z.T, dtype=torch.float32))\n",
        "    for seed in seeds:\n",
        "        torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)\n",
        "        torch.backends.cudnn.deterministic = True\n",
        "        torch.backends.cudnn.benchmark = False\n",
        "        gen = torch.Generator().manual_seed(seed)\n",
        "\n",
        "        loaders = {}\n",
        "        for sp in ['train', 'val', 'test']:\n",
        "            Xi, yi = splits[sp]\n",
        "            loaders[sp] = DataLoader(\n",
        "                TensorDataset(torch.from_numpy(Xi).float(), torch.from_numpy(yi).float()),\n",
        "                batch_size=batch_size,\n",
        "                shuffle=(sp == 'train'),\n",
        "                generator=gen if sp == 'train' else None\n",
        "            )\n",
        "\n",
        "        # ----- Models to train -----\n",
        "        model_constructors = [\n",
        "            (\"Transformer\", lambda: TransformerForecaster(C, T, H, d_model=hidden_dim, dim_ff=hidden_dim*2, dropout=0.0)),\n",
        "\n",
        "            (\"GVNN\",        lambda: GVNN(C, T, W, H, hidden_dim, trainable_W_C=True, return_features=False)),\n",
        "\n",
        "            (\"LSTM\",        lambda: LSTMForecaster(C, T, H, hidden_dim)),\n",
        "        ]\n",
        "\n",
        "        for name, ctor in model_constructors:\n",
        "            model = ctor().to(device)\n",
        "            opt = torch.optim.Adam(model.parameters(), lr=learning_rate)\n",
        "            loss_fn = nn.MSELoss()\n",
        "            best_val, best_state, best_ep = float('inf'), None, 0\n",
        "            t0 = time.time()\n",
        "\n",
        "            print(f\"\\n=== Horizon={H} | Seed={seed} | Model={name} ===\")\n",
        "            for ep in range(1, num_epochs + 1):\n",
        "                # Train\n",
        "                model.train()\n",
        "                acc = 0.0\n",
        "                for xb, yb in loaders['train']:\n",
        "                    xb, yb = xb.to(device), yb.to(device)\n",
        "                    opt.zero_grad()\n",
        "                    pred = model(xb)\n",
        "                    loss = loss_fn(pred, yb)\n",
        "                    loss.backward()\n",
        "                    opt.step()\n",
        "                    acc += loss.item() * xb.size(0)\n",
        "                train_loss = acc / len(loaders['train'].dataset)\n",
        "                # Validate\n",
        "                model.eval()\n",
        "                acc = 0.0\n",
        "                mae_acc = 0.0\n",
        "                smape_acc = 0.0\n",
        "                with torch.no_grad():\n",
        "                    for xb, yb in loaders['val']:\n",
        "                        xb, yb = xb.to(device), yb.to(device)\n",
        "                        out = model(xb)\n",
        "                        acc += loss_fn(out, yb).item() * xb.size(0)\n",
        "                        mae_acc += torch.mean(torch.abs(out - yb)).item() * xb.size(0)\n",
        "                        denom = (out.abs() + yb.abs()).clamp(min=1e-8)\n",
        "                        smape_acc += torch.mean(2.0 * (out - yb).abs() / denom).item() * xb.size(0)\n",
        "                val_loss = acc / len(loaders['val'].dataset)\n",
        "                val_mae = mae_acc / len(loaders['val'].dataset)\n",
        "                val_smape = smape_acc / len(loaders['val'].dataset)\n",
        "\n",
        "                print(f\"Ep {ep:03d} | Train MSE={train_loss:.4f} | Val MSE={val_loss:.4f} | Val MAE={val_mae:.4f} | Val sMAPE={val_smape:.4f}\")\n",
        "                if val_loss < best_val:\n",
        "                    best_val, best_state, best_ep = val_loss, copy.deepcopy(model.state_dict()), ep\n",
        "\n",
        "            # Test best checkpoint\n",
        "            model.load_state_dict(best_state)\n",
        "            acc = 0.0\n",
        "            mae_acc = 0.0\n",
        "            smape_acc = 0.0\n",
        "            with torch.no_grad():\n",
        "                for xb, yb in loaders['test']:\n",
        "                    xb, yb = xb.to(device), yb.to(device)\n",
        "                    out = model(xb)\n",
        "                    acc += loss_fn(out, yb).item() * xb.size(0)\n",
        "                    mae_acc += torch.mean(torch.abs(out - yb)).item() * xb.size(0)\n",
        "                    denom = (out.abs() + yb.abs()).clamp(min=1e-8)\n",
        "                    smape_acc += torch.mean(2.0 * (out - yb).abs() / denom).item() * xb.size(0)\n",
        "            test_loss = acc / len(loaders['test'].dataset)\n",
        "            test_mae = mae_acc / len(loaders['test'].dataset)\n",
        "            test_smape = smape_acc / len(loaders['test'].dataset)\n",
        "            duration = time.time() - t0\n",
        "\n",
        "            print(f\"--> Best Val Ep {best_ep} | Val={best_val:.4f} | Test MSE={test_loss:.4f} | Test MAE={test_mae:.4f} | Test sMAPE={test_smape:.4f} | Time={duration:.1f}s\")\n",
        "            results[H].setdefault(name, []).append(test_loss)\n",
        "\n",
        "\n",
        "# -----------------------------\n",
        "# 7) Summary Table\n",
        "# -----------------------------\n",
        "print(\"\\n===== Final Test MSE Summary =====\")\n",
        "for H in sorted(results):\n",
        "    print(f\"\\nHorizon {H}:\")\n",
        "    for name, vals in results[H].items():\n",
        "        mean = np.mean(vals); sd = np.std(vals, ddof=1) if len(vals) > 1 else 0.0\n",
        "        print(f\"  {name}: {mean:.4f} ± {sd:.4f}\")\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 793
        },
        "id": "e3j56_05tubf",
        "outputId": "8c9d5c27-b6ae-431e-d165-e90f35863e90"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.norm_first was True\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "=== Horizon=3 | Seed=124 | Model=Transformer ===\n",
            "Ep 001 | Train MSE=0.6837 | Val MSE=0.6633 | Val MAE=0.4295 | Val sMAPE=0.8294\n",
            "Ep 002 | Train MSE=0.3897 | Val MSE=0.5650 | Val MAE=0.3851 | Val sMAPE=0.7604\n",
            "Ep 003 | Train MSE=0.3389 | Val MSE=0.5336 | Val MAE=0.3639 | Val sMAPE=0.7110\n",
            "Ep 004 | Train MSE=0.3087 | Val MSE=0.5021 | Val MAE=0.3536 | Val sMAPE=0.6952\n",
            "Ep 005 | Train MSE=0.2870 | Val MSE=0.4799 | Val MAE=0.3474 | Val sMAPE=0.6917\n",
            "Ep 006 | Train MSE=0.2715 | Val MSE=0.4613 | Val MAE=0.3440 | Val sMAPE=0.6904\n",
            "Ep 007 | Train MSE=0.2576 | Val MSE=0.4394 | Val MAE=0.3430 | Val sMAPE=0.7023\n",
            "Ep 008 | Train MSE=0.2455 | Val MSE=0.4240 | Val MAE=0.3382 | Val sMAPE=0.6819\n",
            "Ep 009 | Train MSE=0.2340 | Val MSE=0.4086 | Val MAE=0.3350 | Val sMAPE=0.6898\n",
            "Ep 010 | Train MSE=0.2241 | Val MSE=0.3998 | Val MAE=0.3339 | Val sMAPE=0.6885\n",
            "Ep 011 | Train MSE=0.2162 | Val MSE=0.3880 | Val MAE=0.3301 | Val sMAPE=0.6803\n",
            "Ep 012 | Train MSE=0.2080 | Val MSE=0.3839 | Val MAE=0.3275 | Val sMAPE=0.6647\n",
            "Ep 013 | Train MSE=0.2015 | Val MSE=0.3758 | Val MAE=0.3272 | Val sMAPE=0.6782\n",
            "Ep 014 | Train MSE=0.1952 | Val MSE=0.3677 | Val MAE=0.3258 | Val sMAPE=0.6771\n",
            "Ep 015 | Train MSE=0.1913 | Val MSE=0.3635 | Val MAE=0.3216 | Val sMAPE=0.6666\n",
            "Ep 016 | Train MSE=0.1853 | Val MSE=0.3563 | Val MAE=0.3180 | Val sMAPE=0.6644\n",
            "Ep 017 | Train MSE=0.1807 | Val MSE=0.3583 | Val MAE=0.3196 | Val sMAPE=0.6624\n",
            "Ep 018 | Train MSE=0.1770 | Val MSE=0.3488 | Val MAE=0.3164 | Val sMAPE=0.6671\n",
            "Ep 019 | Train MSE=0.1724 | Val MSE=0.3434 | Val MAE=0.3164 | Val sMAPE=0.6677\n",
            "Ep 020 | Train MSE=0.1690 | Val MSE=0.3426 | Val MAE=0.3164 | Val sMAPE=0.6593\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "KeyboardInterrupt",
          "evalue": "",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-2524085365.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m    410\u001b[0m                     \u001b[0mxb\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    411\u001b[0m                     \u001b[0mopt\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 412\u001b[0;31m                     \u001b[0mpred\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mxb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    413\u001b[0m                     \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mloss_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpred\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0myb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    414\u001b[0m                     \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1772\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1773\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1774\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1775\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1782\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1783\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1785\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1786\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/tmp/ipython-input-2524085365.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    332\u001b[0m         \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0membed\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mseq\u001b[0m\u001b[0;34m)\u001b[0m                       \u001b[0;31m# (B, T, d_model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    333\u001b[0m         \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mz\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mposenc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m                    \u001b[0;31m# add positional encoding\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 334\u001b[0;31m         \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m                       \u001b[0;31m# (B, T, d_model)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    335\u001b[0m         \u001b[0mz\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m                 \u001b[0;31m# take last token (horizon readout)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    336\u001b[0m         \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mz\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mC\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mH\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# (B, C, H)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1772\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1773\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1774\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1775\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1782\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1783\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1785\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1786\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, src, mask, src_key_padding_mask, is_causal)\u001b[0m\n\u001b[1;32m    522\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    523\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmod\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlayers\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 524\u001b[0;31m             output = mod(\n\u001b[0m\u001b[1;32m    525\u001b[0m                 \u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    526\u001b[0m                 \u001b[0msrc_mask\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1772\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1773\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1774\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1775\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1782\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1783\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1785\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1786\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, src, src_mask, src_key_padding_mask, is_causal)\u001b[0m\n\u001b[1;32m    929\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msrc_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msrc_key_padding_mask\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_causal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mis_causal\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    930\u001b[0m             )\n\u001b[0;32m--> 931\u001b[0;31m             \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ff_block\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnorm2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    932\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    933\u001b[0m             x = self.norm1(\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py\u001b[0m in \u001b[0;36m_ff_block\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    960\u001b[0m     \u001b[0;31m# feed forward block\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    961\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_ff_block\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 962\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mactivation\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    963\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdropout2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    964\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_wrapped_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_compiled_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# type: ignore[misc]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1772\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1773\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_impl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1774\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1775\u001b[0m     \u001b[0;31m# torchrec tests the code consistency with the following code\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1782\u001b[0m                 \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_pre_hooks\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0m_global_backward_hooks\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1783\u001b[0m                 or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1784\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1785\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1786\u001b[0m         \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/torch/nn/modules/linear.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    124\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 125\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbias\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    126\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    127\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mextra_repr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mstr\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#!/usr/bin/env python\n",
        "# -*- coding: utf-8 -*-\n",
        "\n",
        "\n",
        "import time, math\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from sklearn.metrics import cohen_kappa_score\n",
        "from sklearn.model_selection import StratifiedKFold\n",
        "\n",
        "from braindecode.datasets.moabb import MOABBDataset\n",
        "from braindecode.preprocessing import (\n",
        "    preprocess, create_windows_from_events, exponential_moving_standardize, Preprocessor\n",
        ")\n",
        "\n",
        "# -----------------------\n",
        "# Hyperparams & settings\n",
        "# -----------------------\n",
        "SEED = 42\n",
        "torch.manual_seed(SEED); np.random.seed(SEED)\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False\n",
        "\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "print(\"Device:\", device)\n",
        "\n",
        "low_cut_hz = 0.01\n",
        "high_cut_hz = 20.0\n",
        "factor_new = 1e-3\n",
        "init_block_size = 1000\n",
        "factor = 1e6\n",
        "trial_start_offset_seconds = -0.5\n",
        "\n",
        "epsilon = 1e-15   # normalization\n",
        "EPS = 1e-6        # dynamic renormalization\n",
        "\n",
        "num_epochs   = 100\n",
        "batch_size   = 64\n",
        "num_classes  = 4       # BNCI2014_001 is 4-class MI\n",
        "H_embed      = 64      # per-node embedding dim for graph models\n",
        "N_SPLITS     = 5\n",
        "\n",
        "# --------------------------------\n",
        "# Braindecode/MOABB preprocessing\n",
        "# --------------------------------\n",
        "preprocessors = [\n",
        "    Preprocessor('pick_types', eeg=True, meg=False, stim=False),\n",
        "    Preprocessor(lambda data: data * factor),\n",
        "    Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),\n",
        "    Preprocessor(exponential_moving_standardize,\n",
        "                 factor_new=factor_new, init_block_size=init_block_size)\n",
        "]\n",
        "\n",
        "def convert_to_tensors(windows_dataset):\n",
        "    X, y = [], []\n",
        "    for w in windows_dataset:\n",
        "        X.append(w[0]); y.append(w[1])\n",
        "    X = np.stack(X); y = np.array(y)\n",
        "    y = y - y.min()  # 0..K-1\n",
        "    return torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.long)\n",
        "\n",
        "# -------------------------\n",
        "# Build global adjacency W\n",
        "# -------------------------\n",
        "def compute_global_adjacency(x):\n",
        "    \"\"\"\n",
        "    Compute |corr| support W from TRAIN data in a fold:\n",
        "      x: (N, C, T) → (C, C)\n",
        "    \"\"\"\n",
        "    mu = x.mean(dim=2, keepdim=True)           # (N, C, 1)\n",
        "    Z  = x - mu                                # (N, C, T)\n",
        "    N, C, T = Z.shape\n",
        "    cov = torch.einsum('nct,ndt->cd', Z, Z) / (N * T + 1e-9)  # (C,C)\n",
        "    std = torch.sqrt(torch.clamp(torch.diag(cov), min=1e-9))  # (C,)\n",
        "    corr = cov / (std.unsqueeze(1) * std.unsqueeze(0) + 1e-9)\n",
        "    corr = corr.abs()\n",
        "    corr.fill_diagonal_(0.0)\n",
        "    return corr\n",
        "\n",
        "# ----------------\n",
        "# Graph operations\n",
        "# ----------------\n",
        "def graph_variate(x, fun='sqd', Zave=True, eps=EPS):\n",
        "    \"\"\"\n",
        "    x: (B, C, T) → Om: (B, C, C, T)\n",
        "    \"\"\"\n",
        "    B, C, T = x.shape\n",
        "    x_ = x\n",
        "    if Zave:\n",
        "        mu = x_.mean(dim=0, keepdim=True)\n",
        "        sig = x_.std(dim=0, keepdim=True, unbiased=True)\n",
        "        x_ = (x_ - mu) / (sig + eps)\n",
        "\n",
        "    if fun == 'sqd':\n",
        "        Om = (x_.unsqueeze(2) - x_.unsqueeze(1)).pow(2)\n",
        "    elif fun == 'corr':\n",
        "        D = x_ - x_.mean(dim=2, keepdim=True)\n",
        "        Om = torch.abs(D.unsqueeze(2) * D.unsqueeze(1))\n",
        "    elif fun == 'full':\n",
        "        Om = torch.ones(B, C, C, T, device=x.device, dtype=x.dtype)\n",
        "    else:\n",
        "        raise ValueError(fun)\n",
        "    return Om\n",
        "\n",
        "def renormalize_dynamic(A, eps=EPS):\n",
        "    \"\"\"\n",
        "    A: (B, C, C, T) → symmetric degree normalization per time-slice.\n",
        "    \"\"\"\n",
        "    I = torch.eye(A.size(1), device=A.device).view(1, A.size(1), A.size(2), 1)\n",
        "    At = A + I\n",
        "    deg = At.sum(2, keepdim=True)\n",
        "    inv = deg.clamp(min=eps).pow(-0.5)\n",
        "    return inv * At * inv.transpose(1, 2)\n",
        "\n",
        "def graph_conv_bmm(x, A):\n",
        "    \"\"\"\n",
        "    x: (B, C, T), A: (B, C, C, T) → (B, C, T)\n",
        "    \"\"\"\n",
        "    B, C, T = x.shape\n",
        "    A_bt = A.permute(0, 3, 1, 2).contiguous().view(B*T, C, C)  # (B*T,C,C)\n",
        "    x_bt = x.permute(0, 2, 1).contiguous().view(B*T, C, 1)     # (B*T,C,1)\n",
        "    y_bt = torch.bmm(A_bt, x_bt).view(B, T, C)                 # (B,T,C)\n",
        "    return y_bt.permute(0, 2, 1).contiguous()\n",
        "\n",
        "# -------------\n",
        "# Graph-based classifiers (B, C, T) → (B, K)\n",
        "# -------------\n",
        "class GraphVarClassifier(nn.Module):\n",
        "    \"\"\" Dynamic per-time adjacency Om(x) ⊙ W_C → renorm → gated mix → MLP(T→H) → Linear(C*H→K) \"\"\"\n",
        "    def __init__(self, C, T, W_C, num_classes, H=64, hidden_dim=128, fun='corr', ZAVE=True):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.fun, self.ZAVE = fun, ZAVE\n",
        "        self.register_buffer('W_C', W_C)\n",
        "\n",
        "        self.g_conv = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip = nn.Parameter(torch.ones(T))\n",
        "        self.theta  = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        self.feat = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "        self.cls = nn.Linear(C * H, num_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + EPS)\n",
        "\n",
        "        Om = graph_variate(x0, fun='sqd', Zave=self.ZAVE, eps=EPS)\n",
        "        A  = renormalize_dynamic(Om * self.W_C.unsqueeze(0).unsqueeze(-1))\n",
        "\n",
        "        h  = self.theta(self.g_skip.view(1,1,self.T)*x0\n",
        "                        + self.g_conv.view(1,1,self.T)*graph_conv_bmm(x0, A))\n",
        "        z  = self.feat(graph_conv_bmm(x0, A))                               # (B, C, H)\n",
        "        logits = self.cls(z.reshape(z.size(0), -1))\n",
        "        return logits\n",
        "\n",
        "class CPGraphSTClassifier(nn.Module):\n",
        "    \"\"\" Kronecker-sum space-time diffusion + MLP(T→H) + Linear(C*H→K) \"\"\"\n",
        "    def __init__(self, C, T, W_C, num_classes, H=64, hidden_dim=128):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.register_buffer('W_C', W_C)\n",
        "        self.register_buffer('_I_C', torch.eye(C))\n",
        "        self.register_buffer('_I_T', torch.eye(T))\n",
        "\n",
        "        A_time = torch.zeros(T, T)\n",
        "        for i in range(T-1):\n",
        "            A_time[i, i+1] = 1.0\n",
        "            A_time[i+1, i] = 1.0\n",
        "        self.register_buffer('A_time', A_time)\n",
        "\n",
        "        self.g_skip = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv = nn.Parameter(torch.ones(T))\n",
        "        self.theta  = nn.Linear(T, T, bias=False)\n",
        "\n",
        "        self.feat = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "        self.cls = nn.Linear(C * H, num_classes)\n",
        "\n",
        "    def _A_norm(self, device):\n",
        "        # Use W_C (from corr) for spatial part and A_time for temporal part\n",
        "        A_ts = torch.kron(self.A_time, self._I_C)\n",
        "        A_st = torch.kron(self._I_T, self.W_C)\n",
        "        A = (A_ts + A_st).to(device)\n",
        "        At = A + torch.eye(A.size(0), device=device)\n",
        "        inv = At.sum(1).clamp(min=EPS).pow(-0.5)\n",
        "        return inv.unsqueeze(1) * At * inv.unsqueeze(0)\n",
        "\n",
        "    def _diffuse(self, X, A_norm):\n",
        "        B, C, T = X.shape\n",
        "        flat = X.reshape(B, C*T)\n",
        "        out = (flat @ A_norm.T).reshape(B, C, T)\n",
        "        return out\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + EPS)\n",
        "        A_norm = self._A_norm(x.device)\n",
        "        conv = self._diffuse(x0, A_norm)\n",
        "        h = self.theta(self.g_skip.view(1,1,self.T)*x0\n",
        "                       + self.g_conv.view(1,1,self.T)*conv)\n",
        "        z = self.feat(h)\n",
        "        logits = self.cls(z.reshape(z.size(0), -1))\n",
        "        return logits\n",
        "\n",
        "class GVARMAClassifier(nn.Module):\n",
        "    \"\"\" ARMA-style spatial filtering + MLP(T→H) + Linear(C*H→K) \"\"\"\n",
        "    def __init__(self, C, T, W_C, num_classes, H=64, hidden_dim=128):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.register_buffer('W_C', W_C)\n",
        "        self.g_skip = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv = nn.Parameter(torch.ones(T))\n",
        "\n",
        "        self.feat = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "        self.cls = nn.Linear(C * H, num_classes)\n",
        "\n",
        "    def _A_norm(self, device):\n",
        "        At = self.W_C + torch.eye(self.C, device=device)\n",
        "        inv = At.sum(1).clamp(min=EPS).pow(-0.5)\n",
        "        return inv.unsqueeze(1) * At * inv.unsqueeze(0)\n",
        "\n",
        "    @staticmethod\n",
        "    def _arma_pass_vec(x_in, A_norm):\n",
        "        h1 = torch.einsum('ij,bjt->bit', A_norm, x_in)\n",
        "        h2 = torch.einsum('ij,bjt->bit', A_norm, h1)\n",
        "        return h1 + h2\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + EPS)\n",
        "        A_norm = self._A_norm(x.device)\n",
        "        conv = self._arma_pass_vec(x0, A_norm)\n",
        "        h = self.g_skip.view(1,1,self.T)*x0 + self.g_conv.view(1,1,self.T)*conv\n",
        "        z = self.feat(h)\n",
        "        logits = self.cls(z.reshape(z.size(0), -1))\n",
        "        return logits\n",
        "\n",
        "class GGRNNClassifier(nn.Module):\n",
        "    \"\"\" GRNN-like gated graph conv over time + MLP(T→H) + Linear(C*H→K) \"\"\"\n",
        "    def __init__(self, C, T, W_C, num_classes, H=64, hidden_dim=128):\n",
        "        super().__init__()\n",
        "        self.C, self.T, self.H = C, T, H\n",
        "        self.register_buffer('W_C', W_C)\n",
        "        self.g_skip = nn.Parameter(torch.ones(T))\n",
        "        self.g_conv = nn.Parameter(torch.ones(T))\n",
        "\n",
        "        self.feat = nn.Sequential(\n",
        "            nn.Linear(T, hidden_dim),\n",
        "            nn.Tanh(),\n",
        "            nn.Linear(hidden_dim, H)\n",
        "        )\n",
        "        self.cls = nn.Linear(C * H, num_classes)\n",
        "\n",
        "    def _A_norm(self, device):\n",
        "        At = self.W_C + torch.eye(self.C, device=device)\n",
        "        inv = At.sum(1).clamp(min=EPS).pow(-0.5)\n",
        "        return inv.unsqueeze(1) * At * inv.unsqueeze(0)\n",
        "\n",
        "    def _block_fast(self, x_seq, A_norm):\n",
        "        B, C, T = x_seq.shape\n",
        "        x_conv = torch.einsum('ij,bjt->bit', A_norm, x_seq)\n",
        "        h = torch.zeros(B, C, device=x_seq.device, dtype=x_seq.dtype)\n",
        "        outs = []\n",
        "        A_T = A_norm.T\n",
        "        for t in range(T):\n",
        "            h_conv = h @ A_T\n",
        "            pre = x_conv[:, :, t] + h_conv\n",
        "            z = torch.sigmoid(pre)\n",
        "            r = torch.sigmoid(pre)\n",
        "            h = (1.0 - z) * h + z * torch.tanh(x_conv[:, :, t] + r * h_conv)\n",
        "            outs.append(h.unsqueeze(-1))\n",
        "        return torch.cat(outs, dim=-1)\n",
        "\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + EPS)\n",
        "        A_norm = self._A_norm(x.device)\n",
        "        h1 = self._block_fast(x0, A_norm)\n",
        "        h  = self.g_skip.view(1,1,self.T)*x0 + self.g_conv.view(1,1,self.T)*h1\n",
        "        z  = self.feat(h)\n",
        "        logits = self.cls(z.reshape(z.size(0), -1))\n",
        "        return logits\n",
        "\n",
        "\n",
        "class PositionalEncoding(nn.Module):\n",
        "    \"\"\" Standard sinusoidal positional encoding for sequences (batch_first=True). \"\"\"\n",
        "    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):\n",
        "        super().__init__()\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "        pe = torch.zeros(max_len, d_model, dtype=torch.float32)\n",
        "        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)\n",
        "        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))\n",
        "        pe[:, 0::2] = torch.sin(position * div_term)\n",
        "        pe[:, 1::2] = torch.cos(position * div_term)\n",
        "        self.register_buffer('pe', pe)  # (max_len, d_model)\n",
        "\n",
        "    def forward(self, x):\n",
        "        # x: (B, S, d_model)\n",
        "        S = x.size(1)\n",
        "        x = x + self.pe[:S, :].unsqueeze(0)\n",
        "        return self.dropout(x)\n",
        "\n",
        "class TransformerClassifier(nn.Module):\n",
        "    \"\"\"\n",
        "    Treat time as tokens: input at each timestep is the C-channel vector.\n",
        "    Project C→d_model, add positional encoding, prepend a learnable [CLS] token,\n",
        "    pass through TransformerEncoder, classify from CLS.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, num_classes, d_model=128, nhead=1, num_layers=1, dim_feedforward=256, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.proj = nn.Linear(C, d_model)\n",
        "        self.pos  = PositionalEncoding(d_model=d_model, max_len=T+1, dropout=dropout)\n",
        "        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,\n",
        "                                                   dim_feedforward=dim_feedforward, dropout=dropout,\n",
        "                                                   batch_first=True)\n",
        "        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
        "        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))\n",
        "        nn.init.trunc_normal_(self.cls_token, std=0.02)\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "        self.head = nn.Linear(d_model, num_classes)\n",
        "\n",
        "    def forward(self, x_bct):\n",
        "        # x_bct: (B, C, T) → (B, T, C)\n",
        "        x = x_bct.permute(0, 2, 1)\n",
        "        # per-time LN over channels (optional stabilization)\n",
        "        x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)\n",
        "        x = self.proj(x)  # (B, T, d_model)\n",
        "        B, T, D = x.shape\n",
        "        cls = self.cls_token.expand(B, -1, -1)  # (B,1,D)\n",
        "        x = torch.cat([cls, x], dim=1)          # (B, T+1, D)\n",
        "        x = self.pos(x)\n",
        "        h = self.encoder(x)                     # (B, T+1, D)\n",
        "        cls_out = self.norm(h[:, 0, :])         # (B, D)\n",
        "        logits = self.head(cls_out)             # (B, K)\n",
        "        return logits\n",
        "\n",
        "class LSTMClassifier(nn.Module):\n",
        "    \"\"\"\n",
        "    Bi-LSTM over time with input size C, hidden H, 2 layers, mean pooling over time (or use last).\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, num_classes, hidden=128, num_layers=2, bidirectional=True, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.lstm = nn.LSTM(input_size=C, hidden_size=hidden, num_layers=num_layers,\n",
        "                            batch_first=True, dropout=dropout, bidirectional=bidirectional)\n",
        "        out_dim = hidden * (2 if bidirectional else 1)\n",
        "        self.norm = nn.LayerNorm(out_dim)\n",
        "        self.head = nn.Linear(out_dim, num_classes)\n",
        "\n",
        "    def forward(self, x_bct):\n",
        "        # (B, C, T) → (B, T, C)\n",
        "        x = x_bct.permute(0, 2, 1)\n",
        "        # z-score along time dimension per channel vector (optional)\n",
        "        x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)\n",
        "        h, _ = self.lstm(x)               # (B, T, H*)\n",
        "        # temporal mean-pooling (alternatively: take h[:, -1, :])\n",
        "        h_pool = h.mean(dim=1)            # (B, H*)\n",
        "        h_pool = self.norm(h_pool)\n",
        "        logits = self.head(h_pool)        # (B, K)\n",
        "        return logits\n",
        "\n",
        "# -------------------------\n",
        "# Train / evaluate helpers\n",
        "# -------------------------\n",
        "def train_and_eval(model, train_loader, test_loader, epochs=40, lr=1e-3, wd=1e-4, print_every=5):\n",
        "    model.to(device)\n",
        "    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
        "    crit = nn.CrossEntropyLoss()\n",
        "\n",
        "    best = (1e9, None)\n",
        "    for ep in range(1, epochs+1):\n",
        "        t0 = time.time()\n",
        "        # train\n",
        "        model.train()\n",
        "        tr_loss = 0.0; tr_corr = 0; tr_tot = 0\n",
        "        for xb, yb in train_loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            opt.zero_grad()\n",
        "            logits = model(xb)\n",
        "            loss = crit(logits, yb)\n",
        "            loss.backward()\n",
        "            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)\n",
        "            opt.step()\n",
        "            tr_loss += loss.item() * xb.size(0)\n",
        "            tr_corr += (logits.argmax(1) == yb).sum().item()\n",
        "            tr_tot  += xb.size(0)\n",
        "        tr_loss /= max(1, tr_tot)\n",
        "        tr_acc = 100.0 * tr_corr / max(1, tr_tot)\n",
        "\n",
        "        # eval\n",
        "        model.eval()\n",
        "        te_loss = 0.0; te_corr = 0; te_tot = 0\n",
        "        preds_all, targs_all = [], []\n",
        "        with torch.no_grad():\n",
        "            for xb, yb in test_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                logits = model(xb)\n",
        "                te_loss += crit(logits, yb).item() * xb.size(0)\n",
        "                p = logits.argmax(1)\n",
        "                te_corr += (p == yb).sum().item()\n",
        "                te_tot  += yb.size(0)\n",
        "                preds_all.extend(p.cpu().tolist())\n",
        "                targs_all.extend(yb.cpu().tolist())\n",
        "        te_loss /= max(1, te_tot)\n",
        "        te_acc = 100.0 * te_corr / max(1, te_tot)\n",
        "        kappa  = cohen_kappa_score(targs_all, preds_all)\n",
        "\n",
        "        if te_loss < best[0]:\n",
        "            best = (te_loss, {k: v.cpu() for k, v in model.state_dict().items()})\n",
        "\n",
        "        if ep % print_every == 0 or ep == 1:\n",
        "            print(f\"Ep {ep:03d} | train {tr_loss:.4f}/{tr_acc:5.2f}% \"\n",
        "                  f\"| test {te_loss:.4f}/{te_acc:5.2f}% κ={kappa:.4f} | {time.time()-t0:.1f}s\")\n",
        "\n",
        "    if best[1] is not None:\n",
        "        model.load_state_dict({k: v.to(device) for k, v in best[1].items()})\n",
        "\n",
        "    # final test metrics\n",
        "    model.eval()\n",
        "    te_corr = 0; te_tot = 0; preds_all=[]; targs_all=[]\n",
        "    with torch.no_grad():\n",
        "        for xb, yb in test_loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            logits = model(xb)\n",
        "            p = logits.argmax(1)\n",
        "            te_corr += (p == yb).sum().item()\n",
        "            te_tot  += yb.size(0)\n",
        "            preds_all.extend(p.cpu().tolist()); targs_all.extend(yb.cpu().tolist())\n",
        "    te_acc = 100.0 * te_corr / max(1, te_tot)\n",
        "    kappa  = cohen_kappa_score(targs_all, preds_all)\n",
        "    return te_acc, kappa\n",
        "\n",
        "# --------------\n",
        "# MAIN PIPELINE (K-fold CV)\n",
        "# --------------\n",
        "def main():\n",
        "    # 1) Load ALL subjects and concatenate ALL windows\n",
        "    X_list, y_list = [], []\n",
        "    for sid in range(1, 10):  # BNCI2014_001 subjects 1..9\n",
        "        ds = MOABBDataset('BNCI2014_001', [sid])\n",
        "        preprocess(ds, preprocessors, n_jobs=-1)\n",
        "        sf = ds.datasets[0].raw.info['sfreq']\n",
        "        off = int(trial_start_offset_seconds * sf)\n",
        "        win = create_windows_from_events(\n",
        "            ds,\n",
        "            trial_start_offset_samples=off,\n",
        "            trial_stop_offset_samples=0,\n",
        "            preload=True\n",
        "        )\n",
        "        Xs, Ys = convert_to_tensors(win)  # (n_i, C, T); (n_i,)\n",
        "        X_list.append(Xs); y_list.append(Ys)\n",
        "\n",
        "    X_all = torch.cat(X_list, dim=0)\n",
        "    y_all = torch.cat(y_list, dim=0)\n",
        "    N, C, T = X_all.shape\n",
        "    print(\"Data:\", X_all.shape, y_all.shape)  # (N_total, C, T), (N_total,)\n",
        "\n",
        "    # 2) Stratified K-fold CV\n",
        "    skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)\n",
        "    fold_results = {name: {'acc': [], 'kappa': []} for name in\n",
        "                    ['GraphVar','CPGraphST','GVARMA','GGRNN','Transformer','LSTM']}\n",
        "\n",
        "    for fold_id, (tr_idx, te_idx) in enumerate(skf.split(np.zeros(N), y_all.numpy()), 1):\n",
        "        print(f\"\\n================ Fold {fold_id}/{N_SPLITS} ================\")\n",
        "        X_tr, y_tr = X_all[tr_idx], y_all[tr_idx]\n",
        "        X_te, y_te = X_all[te_idx], y_all[te_idx]\n",
        "\n",
        "        # 3) Build W_C from TRAINING ONLY (avoid leakage)\n",
        "        with torch.no_grad():\n",
        "            W_dense = compute_global_adjacency(X_tr).to(device)  # (C,C)\n",
        "\n",
        "        # 4) Dataloaders\n",
        "        train_loader = DataLoader(TensorDataset(X_tr, y_tr),\n",
        "                                  batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "        test_loader  = DataLoader(TensorDataset(X_te, y_te),\n",
        "                                  batch_size=batch_size, shuffle=False)\n",
        "\n",
        "        # 5) Define models (fresh per fold)\n",
        "        models = {\n",
        "            'GraphVar':   GraphVarClassifier(C=C, T=T, W_C=W_dense, num_classes=num_classes, H=H_embed, hidden_dim=128, fun='sqd', ZAVE=True),\n",
        "\n",
        "            'Transformer':TransformerClassifier(C=C, T=T, num_classes=num_classes, d_model=128, nhead=1, num_layers=1, dim_feedforward=256, dropout=0.1),\n",
        "            'LSTM':       LSTMClassifier(C=C, T=T, num_classes=num_classes, hidden=128, num_layers=2, bidirectional=False, dropout=0.1),\n",
        "        }\n",
        "\n",
        "        # 6) Train & evaluate\n",
        "        for name, model in models.items():\n",
        "            print(f\"\\n--- Fold {fold_id} | {name} ---\")\n",
        "            acc, kappa = train_and_eval(model, train_loader, test_loader,\n",
        "                                        epochs=num_epochs, lr=1e-3, wd=1e-4, print_every=5)\n",
        "            fold_results[name]['acc'].append(acc)\n",
        "            fold_results[name]['kappa'].append(kappa)\n",
        "            print(f\"{name} | Fold {fold_id}: Test Acc {acc:.2f}% | κ={kappa:.4f}\")\n",
        "\n",
        "    # 7) Summary across folds\n",
        "    print(\"\\n================ OVERALL SUMMARY (K-fold CV) ================\")\n",
        "    for name in ['GraphVar','CPGraphST','GVARMA','GGRNN','Transformer','LSTM']:\n",
        "        accs = np.array(fold_results[name]['acc'], dtype=float)\n",
        "        kaps = np.array(fold_results[name]['kappa'], dtype=float)\n",
        "        print(f\"{name:12s}  Acc: {accs.mean():6.2f}% ± {accs.std(ddof=1):.2f}   \"\n",
        "              f\"Kappa: {kaps.mean(): .4f} ± {kaps.std(ddof=1):.4f}\")\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    main()\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "UGfFNy3KQW7c",
        "outputId": "b1373e37-625a-4ca9-a61c-1312385b965b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Device: cuda\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/braindecode/preprocessing/preprocess.py:71: UserWarning: Preprocessing choices with lambda functions cannot be saved.\n",
            "  warn(\"Preprocessing choices with lambda functions cannot be saved.\")\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']\n",
            "Data: torch.Size([5184, 22, 1125]) torch.Size([5184])\n",
            "\n",
            "================ Fold 1/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 1 | GraphVar ---\n",
            "Ep 001 | train 1.1614/48.97% | test 1.0000/58.44% κ=0.4458 | 0.6s\n",
            "Ep 005 | train 0.4585/83.23% | test 1.1781/58.44% κ=0.4458 | 0.5s\n",
            "Ep 010 | train 0.0647/98.73% | test 1.9201/56.70% κ=0.4227 | 0.5s\n",
            "Ep 015 | train 0.0149/99.85% | test 2.4123/56.32% κ=0.4175 | 0.5s\n",
            "Ep 020 | train 0.0029/100.00% | test 2.5508/56.12% κ=0.4150 | 0.5s\n",
            "Ep 025 | train 0.0020/100.00% | test 2.6557/56.61% κ=0.4214 | 0.5s\n",
            "Ep 030 | train 0.0010/100.00% | test 2.6655/56.89% κ=0.4253 | 0.5s\n",
            "Ep 035 | train 0.0012/100.00% | test 2.7092/57.09% κ=0.4278 | 0.5s\n",
            "Ep 040 | train 0.1646/94.12% | test 2.9692/54.19% κ=0.3893 | 0.5s\n",
            "Ep 045 | train 0.0371/98.63% | test 3.1912/56.32% κ=0.4176 | 0.5s\n",
            "Ep 050 | train 0.0178/99.34% | test 3.5718/55.74% κ=0.4098 | 0.5s\n",
            "Ep 055 | train 0.0062/99.83% | test 3.6008/54.97% κ=0.3996 | 0.5s\n",
            "Ep 060 | train 0.0080/99.63% | test 3.7577/56.32% κ=0.4175 | 0.5s\n",
            "Ep 065 | train 0.0127/99.46% | test 3.8007/56.12% κ=0.4150 | 0.5s\n",
            "Ep 070 | train 0.0202/99.22% | test 4.0172/55.83% κ=0.4111 | 0.5s\n",
            "Ep 075 | train 0.0191/99.29% | test 3.9604/55.83% κ=0.4111 | 0.5s\n",
            "Ep 080 | train 0.0113/99.61% | test 4.1936/55.35% κ=0.4047 | 0.5s\n",
            "Ep 085 | train 0.0058/99.85% | test 4.2895/55.54% κ=0.4073 | 0.5s\n",
            "Ep 090 | train 0.0085/99.76% | test 4.2242/57.09% κ=0.4278 | 0.5s\n",
            "Ep 095 | train 0.0199/99.49% | test 4.5564/55.35% κ=0.4047 | 0.5s\n",
            "Ep 100 | train 0.0125/99.49% | test 4.7620/53.52% κ=0.3803 | 0.5s\n",
            "GraphVar | Fold 1: Test Acc 59.40% | κ=0.4587\n",
            "\n",
            "--- Fold 1 | Transformer ---\n",
            "Ep 001 | train 1.4279/25.12% | test 1.3996/25.27% κ=0.0039 | 1.4s\n",
            "Ep 005 | train 1.2737/41.16% | test 1.2808/39.44% κ=0.1925 | 1.4s\n",
            "Ep 010 | train 1.2050/45.19% | test 1.2286/42.62% κ=0.2350 | 1.4s\n",
            "Ep 015 | train 1.1882/46.34% | test 1.2364/43.11% κ=0.2415 | 1.4s\n",
            "Ep 020 | train 1.1738/46.34% | test 1.2138/44.26% κ=0.2568 | 1.4s\n",
            "Ep 025 | train 1.1674/47.63% | test 1.2802/43.39% κ=0.2453 | 1.4s\n",
            "Ep 030 | train 1.1622/46.90% | test 1.2128/44.17% κ=0.2555 | 1.4s\n",
            "Ep 035 | train 1.1418/48.02% | test 1.2532/44.65% κ=0.2621 | 1.4s\n",
            "Ep 040 | train 1.1443/48.49% | test 1.1959/46.38% κ=0.2851 | 1.4s\n",
            "Ep 045 | train 1.1335/48.88% | test 1.2483/43.39% κ=0.2450 | 1.4s\n",
            "Ep 050 | train 1.1372/49.15% | test 1.2044/45.90% κ=0.2788 | 1.4s\n",
            "Ep 055 | train 1.1196/49.32% | test 1.2067/44.36% κ=0.2582 | 1.4s\n",
            "Ep 060 | train 1.1191/49.83% | test 1.1825/46.19% κ=0.2825 | 1.4s\n",
            "Ep 065 | train 1.1063/51.42% | test 1.1988/45.71% κ=0.2761 | 1.4s\n",
            "Ep 070 | train 1.1103/49.85% | test 1.1985/46.58% κ=0.2876 | 1.4s\n",
            "Ep 075 | train 1.1061/51.54% | test 1.1759/46.96% κ=0.2929 | 1.4s\n",
            "Ep 080 | train 1.0914/51.07% | test 1.1866/48.22% κ=0.3096 | 1.4s\n",
            "Ep 085 | train 1.0903/52.12% | test 1.1598/46.96% κ=0.2929 | 1.4s\n",
            "Ep 090 | train 1.0877/52.69% | test 1.2035/46.19% κ=0.2825 | 1.4s\n",
            "Ep 095 | train 1.0742/52.56% | test 1.1705/49.28% κ=0.3237 | 1.4s\n",
            "Ep 100 | train 1.0980/51.56% | test 1.2143/46.48% κ=0.2864 | 1.4s\n",
            "Transformer | Fold 1: Test Acc 46.96% | κ=0.2929\n",
            "\n",
            "--- Fold 1 | LSTM ---\n",
            "Ep 001 | train 1.4131/25.54% | test 1.4061/26.90% κ=0.0256 | 1.5s\n",
            "Ep 005 | train 1.3474/34.13% | test 1.3719/29.80% κ=0.0636 | 1.5s\n",
            "Ep 010 | train 1.2454/41.53% | test 1.2677/39.73% κ=0.1961 | 1.5s\n",
            "Ep 015 | train 1.1583/48.00% | test 1.1693/47.44% κ=0.2992 | 1.5s\n",
            "Ep 020 | train 1.0475/54.37% | test 1.0915/51.11% κ=0.3482 | 1.5s\n",
            "Ep 025 | train 0.9662/59.35% | test 1.0864/52.56% κ=0.3674 | 1.5s\n",
            "Ep 030 | train 0.8583/63.67% | test 1.1724/49.95% κ=0.3326 | 1.5s\n",
            "Ep 035 | train 0.7658/68.26% | test 1.2210/49.76% κ=0.3299 | 1.5s\n",
            "Ep 040 | train 0.6908/71.78% | test 1.1642/52.36% κ=0.3649 | 1.5s\n",
            "Ep 045 | train 0.5667/77.64% | test 1.3962/52.75% κ=0.3699 | 1.5s\n",
            "Ep 050 | train 0.5129/80.54% | test 1.4578/53.71% κ=0.3828 | 1.5s\n",
            "Ep 055 | train 0.4274/83.37% | test 1.5601/53.62% κ=0.3815 | 1.5s\n",
            "Ep 060 | train 0.3630/86.65% | test 1.6716/52.46% κ=0.3662 | 1.5s\n",
            "Ep 065 | train 0.3354/87.38% | test 1.8944/51.98% κ=0.3597 | 1.5s\n",
            "Ep 070 | train 0.2766/89.65% | test 2.0877/53.52% κ=0.3804 | 1.5s\n",
            "Ep 075 | train 0.2276/91.55% | test 2.1723/53.13% κ=0.3751 | 1.5s\n",
            "Ep 080 | train 0.2508/90.94% | test 2.1544/51.01% κ=0.3468 | 1.5s\n",
            "Ep 085 | train 0.2055/92.41% | test 2.2441/52.75% κ=0.3700 | 1.5s\n",
            "Ep 090 | train 0.1984/92.77% | test 2.3087/52.36% κ=0.3647 | 1.5s\n",
            "Ep 095 | train 0.1631/94.46% | test 2.3206/53.52% κ=0.3802 | 1.5s\n",
            "Ep 100 | train 0.1544/94.78% | test 2.4903/53.33% κ=0.3777 | 1.5s\n",
            "LSTM | Fold 1: Test Acc 52.56% | κ=0.3674\n",
            "\n",
            "================ Fold 2/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 2 | GraphVar ---\n",
            "Ep 001 | train 1.1670/48.05% | test 0.9876/60.17% κ=0.4690 | 0.5s\n",
            "Ep 005 | train 0.4789/82.84% | test 1.1479/58.24% κ=0.4433 | 0.6s\n",
            "Ep 010 | train 0.0724/98.41% | test 1.9020/57.76% κ=0.4368 | 0.5s\n",
            "Ep 015 | train 0.0142/99.71% | test 2.3970/58.05% κ=0.4407 | 0.5s\n",
            "Ep 020 | train 0.0039/99.98% | test 2.5706/56.80% κ=0.4240 | 0.5s\n",
            "Ep 025 | train 0.0016/100.00% | test 2.6640/57.09% κ=0.4278 | 0.5s\n",
            "Ep 030 | train 0.0010/100.00% | test 2.6470/56.89% κ=0.4253 | 0.5s\n",
            "Ep 035 | train 0.0058/99.93% | test 2.8588/56.89% κ=0.4252 | 0.5s\n",
            "Ep 040 | train 0.0601/97.88% | test 3.0012/56.03% κ=0.4137 | 0.5s\n",
            "Ep 045 | train 0.0217/99.34% | test 3.3536/55.35% κ=0.4047 | 0.6s\n",
            "Ep 050 | train 0.0137/99.44% | test 3.4959/55.74% κ=0.4098 | 0.5s\n",
            "Ep 055 | train 0.0090/99.78% | test 3.6894/53.91% κ=0.3854 | 0.5s\n",
            "Ep 060 | train 0.0078/99.73% | test 3.7714/56.32% κ=0.4175 | 0.5s\n",
            "Ep 065 | train 0.0104/99.68% | test 3.9998/53.42% κ=0.3790 | 0.6s\n",
            "Ep 070 | train 0.0176/99.54% | test 3.9408/54.29% κ=0.3905 | 0.5s\n",
            "Ep 075 | train 0.0194/99.27% | test 4.2132/53.62% κ=0.3815 | 0.5s\n",
            "Ep 080 | train 0.0118/99.56% | test 4.2410/52.56% κ=0.3674 | 0.5s\n",
            "Ep 085 | train 0.0083/99.71% | test 4.3320/54.87% κ=0.3983 | 0.5s\n",
            "Ep 090 | train 0.0078/99.68% | test 4.3911/54.19% κ=0.3893 | 0.5s\n",
            "Ep 095 | train 0.0061/99.71% | test 4.4527/53.23% κ=0.3764 | 0.5s\n",
            "Ep 100 | train 0.0206/99.37% | test 4.6231/52.75% κ=0.3700 | 0.5s\n",
            "GraphVar | Fold 2: Test Acc 60.27% | κ=0.4702\n",
            "\n",
            "--- Fold 2 | Transformer ---\n",
            "Ep 001 | train 1.4139/26.25% | test 1.3880/24.98% κ=0.0000 | 1.4s\n",
            "Ep 005 | train 1.3470/33.72% | test 1.3150/38.28% κ=0.1771 | 1.4s\n",
            "Ep 010 | train 1.2350/45.19% | test 1.2761/40.89% κ=0.2121 | 1.4s\n",
            "Ep 015 | train 1.1989/46.34% | test 1.1581/48.51% κ=0.3135 | 1.4s\n",
            "Ep 020 | train 1.1556/49.44% | test 1.1654/48.51% κ=0.3134 | 1.4s\n",
            "Ep 025 | train 1.1509/49.29% | test 1.1316/48.31% κ=0.3109 | 1.4s\n",
            "Ep 030 | train 1.1477/49.83% | test 1.1442/49.66% κ=0.3289 | 1.4s\n",
            "Ep 035 | train 1.1376/51.03% | test 1.1616/49.47% κ=0.3264 | 1.4s\n",
            "Ep 040 | train 1.1214/51.90% | test 1.1610/48.31% κ=0.3109 | 1.4s\n",
            "Ep 045 | train 1.1257/51.54% | test 1.1327/51.49% κ=0.3533 | 1.4s\n",
            "Ep 050 | train 1.1235/51.76% | test 1.1082/53.62% κ=0.3816 | 1.4s\n",
            "Ep 055 | train 1.1232/52.03% | test 1.1180/51.98% κ=0.3596 | 1.4s\n",
            "Ep 060 | train 1.1131/51.37% | test 1.1610/50.43% κ=0.3391 | 1.4s\n",
            "Ep 065 | train 1.1150/52.32% | test 1.1307/51.40% κ=0.3519 | 1.4s\n",
            "Ep 070 | train 1.1198/51.61% | test 1.1143/52.84% κ=0.3713 | 1.5s\n",
            "Ep 075 | train 1.1100/52.20% | test 1.1160/51.40% κ=0.3519 | 1.4s\n",
            "Ep 080 | train 1.1024/52.37% | test 1.0999/52.46% κ=0.3661 | 1.4s\n",
            "Ep 085 | train 1.0960/52.98% | test 1.1052/52.36% κ=0.3648 | 1.4s\n",
            "Ep 090 | train 1.0886/52.81% | test 1.1143/50.82% κ=0.3442 | 1.4s\n",
            "Ep 095 | train 1.0907/53.71% | test 1.1021/51.88% κ=0.3584 | 1.4s\n",
            "Ep 100 | train 1.0796/54.20% | test 1.0923/53.71% κ=0.3828 | 1.4s\n",
            "Transformer | Fold 2: Test Acc 55.16% | κ=0.4021\n",
            "\n",
            "--- Fold 2 | LSTM ---\n",
            "Ep 001 | train 1.4175/24.71% | test 1.3917/24.98% κ=0.0000 | 1.5s\n",
            "Ep 005 | train 1.3395/34.23% | test 1.3900/31.44% κ=0.0861 | 1.4s\n",
            "Ep 010 | train 1.2106/43.16% | test 1.2100/45.71% κ=0.2763 | 1.5s\n",
            "Ep 015 | train 1.0944/51.68% | test 1.1469/48.41% κ=0.3120 | 1.5s\n",
            "Ep 020 | train 1.0848/51.10% | test 1.1123/50.63% κ=0.3418 | 1.5s\n",
            "Ep 025 | train 0.9453/59.47% | test 1.1144/53.42% κ=0.3790 | 1.5s\n",
            "Ep 030 | train 0.8775/63.48% | test 1.2096/50.24% κ=0.3365 | 1.5s\n",
            "Ep 035 | train 0.7831/67.92% | test 1.0952/55.16% κ=0.4021 | 1.5s\n",
            "Ep 040 | train 0.6978/70.83% | test 1.1828/53.71% κ=0.3827 | 1.5s\n",
            "Ep 045 | train 0.6433/74.19% | test 1.2359/53.71% κ=0.3829 | 1.5s\n",
            "Ep 050 | train 0.5511/78.32% | test 1.4357/54.10% κ=0.3881 | 1.5s\n",
            "Ep 055 | train 0.4665/81.91% | test 1.4661/52.84% κ=0.3713 | 1.5s\n",
            "Ep 060 | train 0.3867/85.18% | test 1.6765/52.75% κ=0.3700 | 1.4s\n",
            "Ep 065 | train 0.3382/86.55% | test 1.8223/52.07% κ=0.3609 | 1.5s\n",
            "Ep 070 | train 0.2876/88.89% | test 1.9615/52.84% κ=0.3714 | 1.5s\n",
            "Ep 075 | train 0.2825/89.40% | test 1.9730/52.17% κ=0.3623 | 1.5s\n",
            "Ep 080 | train 0.2123/92.02% | test 2.0652/53.13% κ=0.3752 | 1.4s\n",
            "Ep 085 | train 0.2142/91.72% | test 2.2261/53.33% κ=0.3777 | 1.5s\n",
            "Ep 090 | train 0.1644/94.31% | test 2.3236/54.10% κ=0.3880 | 1.4s\n",
            "Ep 095 | train 0.1869/93.07% | test 2.3400/50.34% κ=0.3378 | 1.5s\n",
            "Ep 100 | train 0.1819/93.63% | test 2.5987/50.72% κ=0.3430 | 1.5s\n",
            "LSTM | Fold 2: Test Acc 53.81% | κ=0.3842\n",
            "\n",
            "================ Fold 3/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 3 | GraphVar ---\n",
            "Ep 001 | train 1.1612/48.63% | test 1.0275/56.22% κ=0.4163 | 0.5s\n",
            "Ep 005 | train 0.4572/83.20% | test 1.2029/57.86% κ=0.4381 | 0.5s\n",
            "Ep 010 | train 0.0660/98.36% | test 2.0313/55.54% κ=0.4073 | 0.5s\n",
            "Ep 015 | train 0.0125/99.90% | test 2.4776/56.22% κ=0.4163 | 0.5s\n",
            "Ep 020 | train 0.0036/100.00% | test 2.6175/55.54% κ=0.4073 | 0.5s\n",
            "Ep 025 | train 0.0012/100.00% | test 2.6755/56.99% κ=0.4266 | 0.5s\n",
            "Ep 030 | train 0.0022/100.00% | test 2.7851/55.45% κ=0.4060 | 0.5s\n",
            "Ep 035 | train 0.0009/100.00% | test 2.7456/56.99% κ=0.4266 | 0.5s\n",
            "Ep 040 | train 0.0012/100.00% | test 2.7581/55.74% κ=0.4098 | 0.5s\n",
            "Ep 045 | train 0.1680/94.19% | test 3.0187/51.88% κ=0.3584 | 0.5s\n",
            "Ep 050 | train 0.0422/98.56% | test 3.5698/53.52% κ=0.3802 | 0.5s\n",
            "Ep 055 | train 0.0155/99.58% | test 3.6204/55.16% κ=0.4021 | 0.5s\n",
            "Ep 060 | train 0.0106/99.61% | test 3.9154/53.04% κ=0.3738 | 0.5s\n",
            "Ep 065 | train 0.0104/99.66% | test 3.7144/54.19% κ=0.3893 | 0.5s\n",
            "Ep 070 | train 0.0061/99.83% | test 3.8700/54.97% κ=0.3996 | 0.5s\n",
            "Ep 075 | train 0.0110/99.56% | test 4.1059/53.91% κ=0.3854 | 0.5s\n",
            "Ep 080 | train 0.0164/99.49% | test 4.2592/53.52% κ=0.3803 | 0.5s\n",
            "Ep 085 | train 0.0251/99.10% | test 4.7026/52.36% κ=0.3648 | 0.5s\n",
            "Ep 090 | train 0.0116/99.58% | test 4.5095/53.42% κ=0.3790 | 0.5s\n",
            "Ep 095 | train 0.0074/99.71% | test 4.5962/53.13% κ=0.3751 | 0.5s\n",
            "Ep 100 | train 0.0219/99.12% | test 4.8634/52.46% κ=0.3661 | 0.5s\n",
            "GraphVar | Fold 3: Test Acc 60.27% | κ=0.4703\n",
            "\n",
            "--- Fold 3 | Transformer ---\n",
            "Ep 001 | train 1.4230/26.17% | test 1.4108/25.07% κ=0.0000 | 1.4s\n",
            "Ep 005 | train 1.3589/32.10% | test 1.3234/35.68% κ=0.1427 | 1.4s\n",
            "Ep 010 | train 1.3151/36.60% | test 1.3304/34.81% κ=0.1310 | 1.4s\n",
            "Ep 015 | train 1.2771/39.28% | test 1.3128/39.15% κ=0.1885 | 1.4s\n",
            "Ep 020 | train 1.2526/41.36% | test 1.2655/40.98% κ=0.2131 | 1.4s\n",
            "Ep 025 | train 1.1926/46.46% | test 1.2732/42.72% κ=0.2365 | 1.4s\n",
            "Ep 030 | train 1.1743/47.02% | test 1.1855/46.87% κ=0.2915 | 1.4s\n",
            "Ep 035 | train 1.1690/48.46% | test 1.1992/47.06% κ=0.2942 | 1.4s\n",
            "Ep 040 | train 1.1524/48.90% | test 1.2312/46.29% κ=0.2839 | 1.4s\n",
            "Ep 045 | train 1.1334/50.73% | test 1.1543/49.86% κ=0.3314 | 1.4s\n",
            "Ep 050 | train 1.1533/48.71% | test 1.1604/48.12% κ=0.3082 | 1.4s\n",
            "Ep 055 | train 1.1271/51.25% | test 1.1488/50.24% κ=0.3366 | 1.4s\n",
            "Ep 060 | train 1.1371/50.93% | test 1.1277/50.63% κ=0.3417 | 1.4s\n",
            "Ep 065 | train 1.1129/51.12% | test 1.1308/51.78% κ=0.3571 | 1.4s\n",
            "Ep 070 | train 1.1367/50.46% | test 1.1432/50.43% κ=0.3390 | 1.4s\n",
            "Ep 075 | train 1.1326/50.49% | test 1.1831/48.12% κ=0.3082 | 1.4s\n",
            "Ep 080 | train 1.1239/51.39% | test 1.1273/51.11% κ=0.3481 | 1.4s\n",
            "Ep 085 | train 1.1315/51.90% | test 1.1470/49.86% κ=0.3313 | 1.4s\n",
            "Ep 090 | train 1.1055/51.73% | test 1.1662/50.05% κ=0.3340 | 1.4s\n",
            "Ep 095 | train 1.0916/53.22% | test 1.1279/51.98% κ=0.3596 | 1.4s\n",
            "Ep 100 | train 1.0948/52.59% | test 1.2569/48.99% κ=0.3199 | 1.4s\n",
            "Transformer | Fold 3: Test Acc 52.84% | κ=0.3713\n",
            "\n",
            "--- Fold 3 | LSTM ---\n",
            "Ep 001 | train 1.4182/25.17% | test 1.4240/25.07% κ=0.0000 | 1.5s\n",
            "Ep 005 | train 1.3124/35.79% | test 1.3291/35.68% κ=0.1426 | 1.5s\n",
            "Ep 010 | train 1.2080/43.65% | test 1.2607/39.54% κ=0.1938 | 1.5s\n",
            "Ep 015 | train 1.1421/49.12% | test 1.1293/50.05% κ=0.3340 | 1.5s\n",
            "Ep 020 | train 1.0489/53.88% | test 1.1034/53.42% κ=0.3789 | 1.5s\n",
            "Ep 025 | train 0.9942/58.45% | test 1.0698/55.16% κ=0.4021 | 1.5s\n",
            "Ep 030 | train 0.8851/64.14% | test 1.0833/54.48% κ=0.3931 | 1.5s\n",
            "Ep 035 | train 0.8185/66.33% | test 1.2167/51.59% κ=0.3546 | 1.5s\n",
            "Ep 040 | train 0.7495/69.43% | test 1.1132/52.17% κ=0.3623 | 1.5s\n",
            "Ep 045 | train 0.6715/73.75% | test 1.1781/57.67% κ=0.4355 | 1.5s\n",
            "Ep 050 | train 0.5462/78.52% | test 1.3538/54.87% κ=0.3983 | 1.5s\n",
            "Ep 055 | train 0.4628/81.98% | test 1.5152/54.19% κ=0.3893 | 1.5s\n",
            "Ep 060 | train 0.3596/86.50% | test 1.7070/54.77% κ=0.3970 | 1.5s\n",
            "Ep 065 | train 0.3869/85.08% | test 1.5919/54.19% κ=0.3893 | 1.5s\n",
            "Ep 070 | train 0.3273/86.87% | test 1.7498/54.00% κ=0.3868 | 1.5s\n",
            "Ep 075 | train 0.2676/89.99% | test 1.9057/55.64% κ=0.4085 | 1.5s\n",
            "Ep 080 | train 0.2340/91.04% | test 2.0955/52.84% κ=0.3713 | 1.5s\n",
            "Ep 085 | train 0.2070/92.68% | test 2.2406/54.58% κ=0.3944 | 1.5s\n",
            "Ep 090 | train 0.2068/92.16% | test 2.4069/53.04% κ=0.3738 | 1.5s\n",
            "Ep 095 | train 0.2051/92.14% | test 2.3853/52.94% κ=0.3725 | 1.5s\n",
            "Ep 100 | train 0.1836/93.99% | test 2.4495/54.77% κ=0.3969 | 1.5s\n",
            "LSTM | Fold 3: Test Acc 55.16% | κ=0.4021\n",
            "\n",
            "================ Fold 4/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 4 | GraphVar ---\n",
            "Ep 001 | train 1.1843/47.56% | test 0.9844/60.46% κ=0.4728 | 0.5s\n",
            "Ep 005 | train 0.4832/82.89% | test 1.1133/60.17% κ=0.4690 | 0.5s\n",
            "Ep 010 | train 0.0763/98.12% | test 1.8119/55.93% κ=0.4124 | 0.5s\n",
            "Ep 015 | train 0.0107/99.95% | test 2.1919/58.44% κ=0.4459 | 0.5s\n",
            "Ep 020 | train 0.0027/100.00% | test 2.3380/58.15% κ=0.4420 | 0.5s\n",
            "Ep 025 | train 0.0014/100.00% | test 2.4151/58.82% κ=0.4510 | 0.5s\n",
            "Ep 030 | train 0.0011/100.00% | test 2.4576/59.02% κ=0.4536 | 0.5s\n",
            "Ep 035 | train 0.0663/97.73% | test 2.7296/56.99% κ=0.4265 | 0.5s\n",
            "Ep 040 | train 0.0495/98.07% | test 2.7936/57.57% κ=0.4343 | 0.5s\n",
            "Ep 045 | train 0.0183/99.39% | test 3.0690/56.32% κ=0.4176 | 0.5s\n",
            "Ep 050 | train 0.0114/99.66% | test 3.1328/56.70% κ=0.4227 | 0.5s\n",
            "Ep 055 | train 0.0023/99.98% | test 3.1440/58.92% κ=0.4523 | 0.5s\n",
            "Ep 060 | train 0.0023/99.93% | test 3.1747/58.82% κ=0.4510 | 0.5s\n",
            "Ep 065 | train 0.0009/100.00% | test 3.1973/58.34% κ=0.4446 | 0.5s\n",
            "Ep 070 | train 0.0022/99.98% | test 3.1897/58.24% κ=0.4433 | 0.5s\n",
            "Ep 075 | train 0.0526/98.05% | test 3.5400/56.51% κ=0.4201 | 0.5s\n",
            "Ep 080 | train 0.0236/99.07% | test 3.6208/56.61% κ=0.4214 | 0.5s\n",
            "Ep 085 | train 0.0178/99.37% | test 3.7101/56.22% κ=0.4163 | 0.5s\n",
            "Ep 090 | train 0.0077/99.68% | test 3.8872/56.70% κ=0.4227 | 0.5s\n",
            "Ep 095 | train 0.0148/99.51% | test 3.9518/55.64% κ=0.4086 | 0.5s\n",
            "Ep 100 | train 0.0138/99.54% | test 3.9520/56.99% κ=0.4266 | 0.5s\n",
            "GraphVar | Fold 4: Test Acc 62.01% | κ=0.4934\n",
            "\n",
            "--- Fold 4 | Transformer ---\n",
            "Ep 001 | train 1.4164/25.78% | test 1.4038/25.07% κ=0.0000 | 1.4s\n",
            "Ep 005 | train 1.2502/42.09% | test 1.2464/44.55% κ=0.2607 | 1.4s\n",
            "Ep 010 | train 1.2020/47.07% | test 1.1956/48.89% κ=0.3186 | 1.4s\n",
            "Ep 015 | train 1.1735/48.61% | test 1.1621/52.46% κ=0.3662 | 1.4s\n",
            "Ep 020 | train 1.1450/49.90% | test 1.1639/51.49% κ=0.3534 | 1.4s\n",
            "Ep 025 | train 1.1550/49.88% | test 1.1802/51.78% κ=0.3572 | 1.4s\n",
            "Ep 030 | train 1.1216/51.37% | test 1.2798/43.49% κ=0.2467 | 1.4s\n",
            "Ep 035 | train 1.1264/50.61% | test 1.1517/53.42% κ=0.3791 | 1.4s\n",
            "Ep 040 | train 1.1334/51.05% | test 1.2168/47.35% κ=0.2980 | 1.4s\n",
            "Ep 045 | train 1.1208/51.15% | test 1.1395/52.17% κ=0.3623 | 1.4s\n",
            "Ep 050 | train 1.1173/52.00% | test 1.1328/51.88% κ=0.3585 | 1.4s\n",
            "Ep 055 | train 1.1203/51.12% | test 1.1385/50.63% κ=0.3416 | 1.4s\n",
            "Ep 060 | train 1.1110/52.32% | test 1.1262/52.75% κ=0.3700 | 1.4s\n",
            "Ep 065 | train 1.1200/52.12% | test 1.1452/50.53% κ=0.3404 | 1.4s\n",
            "Ep 070 | train 1.1058/52.05% | test 1.1346/52.75% κ=0.3700 | 1.4s\n",
            "Ep 075 | train 1.1019/51.93% | test 1.1958/48.02% κ=0.3070 | 1.4s\n",
            "Ep 080 | train 1.0913/52.98% | test 1.1565/51.49% κ=0.3533 | 1.4s\n",
            "Ep 085 | train 1.0922/52.69% | test 1.1330/52.17% κ=0.3623 | 1.4s\n",
            "Ep 090 | train 1.0914/53.42% | test 1.1267/52.84% κ=0.3712 | 1.4s\n",
            "Ep 095 | train 1.0891/52.78% | test 1.1276/51.49% κ=0.3532 | 1.4s\n",
            "Ep 100 | train 1.0807/53.00% | test 1.1284/51.98% κ=0.3597 | 1.4s\n",
            "Transformer | Fold 4: Test Acc 52.65% | κ=0.3687\n",
            "\n",
            "--- Fold 4 | LSTM ---\n",
            "Ep 001 | train 1.4114/25.63% | test 1.3894/24.98% κ=0.0000 | 1.5s\n",
            "Ep 005 | train 1.3285/35.23% | test 1.3607/34.43% κ=0.1254 | 1.5s\n",
            "Ep 010 | train 1.2248/43.14% | test 1.2455/41.37% κ=0.2184 | 1.5s\n",
            "Ep 015 | train 1.1296/49.90% | test 1.1531/46.48% κ=0.2865 | 1.5s\n",
            "Ep 020 | train 1.0286/55.52% | test 1.1804/47.64% κ=0.3016 | 1.5s\n",
            "Ep 025 | train 0.9475/59.28% | test 1.1560/48.70% κ=0.3160 | 1.5s\n",
            "Ep 030 | train 0.9016/61.50% | test 1.3430/45.52% κ=0.2736 | 1.5s\n",
            "Ep 035 | train 0.7955/66.50% | test 1.2028/50.53% κ=0.3404 | 1.5s\n",
            "Ep 040 | train 0.7036/70.80% | test 1.3332/50.82% κ=0.3442 | 1.5s\n",
            "Ep 045 | train 0.6201/75.78% | test 1.3969/50.05% κ=0.3341 | 1.5s\n",
            "Ep 050 | train 0.5430/78.74% | test 1.4064/51.49% κ=0.3533 | 1.5s\n",
            "Ep 055 | train 0.4800/81.49% | test 1.7219/50.53% κ=0.3404 | 1.5s\n",
            "Ep 060 | train 0.4108/83.67% | test 1.8252/49.57% κ=0.3276 | 1.5s\n",
            "Ep 065 | train 0.3518/86.50% | test 1.9377/51.78% κ=0.3571 | 1.5s\n",
            "Ep 070 | train 0.3163/88.60% | test 2.3201/46.87% κ=0.2915 | 1.5s\n",
            "Ep 075 | train 0.2288/91.87% | test 2.2233/49.76% κ=0.3301 | 1.5s\n",
            "Ep 080 | train 0.2521/90.67% | test 2.3105/49.37% κ=0.3250 | 1.5s\n",
            "Ep 085 | train 0.1999/92.94% | test 2.3532/48.99% κ=0.3199 | 1.5s\n",
            "Ep 090 | train 0.1722/93.58% | test 2.5702/49.86% κ=0.3314 | 1.5s\n",
            "Ep 095 | train 0.1685/93.68% | test 2.7625/48.99% κ=0.3198 | 1.5s\n",
            "Ep 100 | train 0.1646/93.92% | test 2.5816/50.14% κ=0.3353 | 1.5s\n",
            "LSTM | Fold 4: Test Acc 49.08% | κ=0.3211\n",
            "\n",
            "================ Fold 5/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 5 | GraphVar ---\n",
            "Ep 001 | train 1.1666/48.88% | test 1.0229/58.30% κ=0.4440 | 0.5s\n",
            "Ep 005 | train 0.4694/82.98% | test 1.1720/57.82% κ=0.4376 | 0.5s\n",
            "Ep 010 | train 0.0671/98.71% | test 1.8815/57.43% κ=0.4324 | 0.5s\n",
            "Ep 015 | train 0.0182/99.73% | test 2.2812/55.98% κ=0.4131 | 0.5s\n",
            "Ep 020 | train 0.0034/100.00% | test 2.4760/55.21% κ=0.4028 | 0.5s\n",
            "Ep 025 | train 0.0025/100.00% | test 2.5621/55.31% κ=0.4041 | 0.5s\n",
            "Ep 030 | train 0.0396/98.58% | test 2.9367/55.89% κ=0.4118 | 0.5s\n",
            "Ep 035 | train 0.0482/98.32% | test 3.1585/53.76% κ=0.3835 | 0.5s\n",
            "Ep 040 | train 0.0129/99.63% | test 3.3325/53.86% κ=0.3848 | 0.5s\n",
            "Ep 045 | train 0.0071/99.85% | test 3.3585/56.08% κ=0.4144 | 0.5s\n",
            "Ep 050 | train 0.0079/99.76% | test 3.5520/55.60% κ=0.4080 | 0.5s\n",
            "Ep 055 | train 0.0173/99.39% | test 3.6851/53.96% κ=0.3861 | 0.5s\n",
            "Ep 060 | train 0.0189/99.39% | test 3.8498/53.28% κ=0.3771 | 0.5s\n",
            "Ep 065 | train 0.0122/99.56% | test 3.9600/53.86% κ=0.3848 | 0.5s\n",
            "Ep 070 | train 0.0087/99.66% | test 4.2342/52.99% κ=0.3732 | 0.5s\n",
            "Ep 075 | train 0.0088/99.73% | test 4.2151/54.15% κ=0.3887 | 0.5s\n",
            "Ep 080 | train 0.0076/99.68% | test 4.3156/55.02% κ=0.4003 | 0.5s\n",
            "Ep 085 | train 0.0074/99.73% | test 4.2352/55.98% κ=0.4131 | 0.5s\n",
            "Ep 090 | train 0.0184/99.27% | test 4.4785/53.28% κ=0.3771 | 0.5s\n",
            "Ep 095 | train 0.0139/99.54% | test 4.5589/52.80% κ=0.3707 | 0.5s\n",
            "Ep 100 | train 0.0121/99.58% | test 4.4942/54.15% κ=0.3887 | 0.5s\n",
            "GraphVar | Fold 5: Test Acc 58.78% | κ=0.4505\n",
            "\n",
            "--- Fold 5 | Transformer ---\n",
            "Ep 001 | train 1.4251/25.20% | test 1.4127/25.00% κ=0.0000 | 1.4s\n",
            "Ep 005 | train 1.3423/33.52% | test 1.3442/33.59% κ=0.1145 | 1.4s\n",
            "Ep 010 | train 1.2940/38.79% | test 1.3239/36.97% κ=0.1596 | 1.4s\n",
            "Ep 015 | train 1.2581/41.82% | test 1.2848/41.12% κ=0.2149 | 1.4s\n",
            "Ep 020 | train 1.2289/44.09% | test 1.2841/41.60% κ=0.2214 | 1.4s\n",
            "Ep 025 | train 1.2084/46.61% | test 1.2645/41.12% κ=0.2149 | 1.4s\n",
            "Ep 030 | train 1.1858/46.31% | test 1.2042/47.01% κ=0.2934 | 1.4s\n",
            "Ep 035 | train 1.1525/49.12% | test 1.2002/47.01% κ=0.2934 | 1.4s\n",
            "Ep 040 | train 1.1317/50.85% | test 1.1972/48.07% κ=0.3076 | 1.4s\n",
            "Ep 045 | train 1.1143/51.22% | test 1.1782/49.42% κ=0.3256 | 1.4s\n",
            "Ep 050 | train 1.1018/51.93% | test 1.1517/50.77% κ=0.3436 | 1.4s\n",
            "Ep 055 | train 1.1061/51.73% | test 1.1649/49.81% κ=0.3308 | 1.4s\n",
            "Ep 060 | train 1.0850/52.95% | test 1.1529/52.12% κ=0.3616 | 1.4s\n",
            "Ep 065 | train 1.0887/52.73% | test 1.1469/49.61% κ=0.3282 | 1.4s\n",
            "Ep 070 | train 1.0703/54.35% | test 1.2200/47.59% κ=0.3012 | 1.4s\n",
            "Ep 075 | train 1.0859/53.56% | test 1.1440/50.48% κ=0.3398 | 1.4s\n",
            "Ep 080 | train 1.0709/55.47% | test 1.1789/49.52% κ=0.3269 | 1.4s\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "id": "c4eGU3fa-AsI",
        "outputId": "9099a6fd-b985-4bec-a595-04c4830885ab"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n",
            "Requirement already satisfied: mne in /usr/local/lib/python3.12/dist-packages (1.10.1)\n",
            "Requirement already satisfied: pandas in /usr/local/lib/python3.12/dist-packages (2.3.2)\n",
            "Requirement already satisfied: decorator in /usr/local/lib/python3.12/dist-packages (from mne) (4.4.2)\n",
            "Requirement already satisfied: jinja2 in /usr/local/lib/python3.12/dist-packages (from mne) (3.1.6)\n",
            "Requirement already satisfied: lazy-loader>=0.3 in /usr/local/lib/python3.12/dist-packages (from mne) (0.4)\n",
            "Requirement already satisfied: matplotlib>=3.7 in /usr/local/lib/python3.12/dist-packages (from mne) (3.10.0)\n",
            "Requirement already satisfied: numpy<3,>=1.25 in /usr/local/lib/python3.12/dist-packages (from mne) (1.26.4)\n",
            "Requirement already satisfied: packaging in /usr/local/lib/python3.12/dist-packages (from mne) (25.0)\n",
            "Requirement already satisfied: pooch>=1.5 in /usr/local/lib/python3.12/dist-packages (from mne) (1.8.2)\n",
            "Requirement already satisfied: scipy>=1.11 in /usr/local/lib/python3.12/dist-packages (from mne) (1.16.1)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.12/dist-packages (from mne) (4.67.1)\n",
            "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.12/dist-packages (from pandas) (2.9.0.post0)\n",
            "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.2)\n",
            "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.12/dist-packages (from pandas) (2025.2)\n",
            "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (1.3.3)\n",
            "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (0.12.1)\n",
            "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (4.59.2)\n",
            "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (1.4.9)\n",
            "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (11.3.0)\n",
            "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.12/dist-packages (from matplotlib>=3.7->mne) (3.2.3)\n",
            "Requirement already satisfied: platformdirs>=2.5.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne) (4.4.0)\n",
            "Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.12/dist-packages (from pooch>=1.5->mne) (2.32.4)\n",
            "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.12/dist-packages (from python-dateutil>=2.8.2->pandas) (1.17.0)\n",
            "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.12/dist-packages (from jinja2->mne) (3.0.2)\n",
            "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->pooch>=1.5->mne) (3.4.3)\n",
            "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->pooch>=1.5->mne) (3.10)\n",
            "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->pooch>=1.5->mne) (1.26.20)\n",
            "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.12/dist-packages (from requests>=2.19.0->pooch>=1.5->mne) (2025.8.3)\n",
            "✅ Loading data for 105 valid participants.\n",
            "\n",
            "📥 Processing participant S001 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S001R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S001R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S001R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S001\n",
            "\n",
            "📥 Processing participant S002 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S002R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S002R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S002R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S002\n",
            "\n",
            "📥 Processing participant S003 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S003R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S003R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S003R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S003\n",
            "\n",
            "📥 Processing participant S004 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S004R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S004R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S004R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S004\n",
            "\n",
            "📥 Processing participant S005 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S005R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S005R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S005R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S005\n",
            "\n",
            "📥 Processing participant S006 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S006R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S006R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S006R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S006\n",
            "\n",
            "📥 Processing participant S007 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S007R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S007R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S007R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S007\n",
            "\n",
            "📥 Processing participant S008 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S008R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S008R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S008R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S008\n",
            "\n",
            "📥 Processing participant S009 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S009R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S009R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S009R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S009\n",
            "\n",
            "📥 Processing participant S010 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S010R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S010R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S010R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S010\n",
            "\n",
            "📥 Processing participant S011 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S011R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S011R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S011R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S011\n",
            "\n",
            "📥 Processing participant S012 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S012R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S012R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S012R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S012\n",
            "\n",
            "📥 Processing participant S013 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S013R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S013R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S013R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S013\n",
            "\n",
            "📥 Processing participant S014 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S014R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S014R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S014R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S014\n",
            "\n",
            "📥 Processing participant S015 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S015R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S015R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S015R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S015\n",
            "\n",
            "📥 Processing participant S016 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S016R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S016R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S016R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S016\n",
            "\n",
            "📥 Processing participant S017 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S017R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S017R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S017R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S017\n",
            "\n",
            "📥 Processing participant S018 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S018R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S018R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S018R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S018\n",
            "\n",
            "📥 Processing participant S019 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S019R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S019R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S019R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S019\n",
            "\n",
            "📥 Processing participant S020 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S020R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S020R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S020R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S020\n",
            "\n",
            "📥 Processing participant S021 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S021R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S021R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S021R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S021\n",
            "\n",
            "📥 Processing participant S022 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S022R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S022R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S022R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S022\n",
            "\n",
            "📥 Processing participant S023 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S023R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S023R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S023R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S023\n",
            "\n",
            "📥 Processing participant S024 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S024R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S024R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S024R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S024\n",
            "\n",
            "📥 Processing participant S025 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S025R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S025R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S025R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S025\n",
            "\n",
            "📥 Processing participant S026 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S026R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S026R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S026R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S026\n",
            "\n",
            "📥 Processing participant S027 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S027R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S027R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S027R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S027\n",
            "\n",
            "📥 Processing participant S028 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S028R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S028R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S028R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S028\n",
            "\n",
            "📥 Processing participant S029 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S029R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S029R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S029R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S029\n",
            "\n",
            "📥 Processing participant S030 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S030R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S030R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S030R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S030\n",
            "\n",
            "📥 Processing participant S031 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S031R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S031R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S031R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S031\n",
            "\n",
            "📥 Processing participant S032 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S032R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S032R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S032R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S032\n",
            "\n",
            "📥 Processing participant S033 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S033R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S033R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S033R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S033\n",
            "\n",
            "📥 Processing participant S034 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S034R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S034R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S034R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S034\n",
            "\n",
            "📥 Processing participant S035 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S035R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S035R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S035R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S035\n",
            "\n",
            "📥 Processing participant S036 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S036R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S036R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S036R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S036\n",
            "\n",
            "📥 Processing participant S037 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S037R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S037R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S037R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S037\n",
            "\n",
            "📥 Processing participant S038 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S038R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S038R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S038R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S038\n",
            "\n",
            "📥 Processing participant S039 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S039R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S039R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S039R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S039\n",
            "\n",
            "📥 Processing participant S040 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S040R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S040R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S040R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S040\n",
            "\n",
            "📥 Processing participant S041 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S041R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S041R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S041R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S041\n",
            "\n",
            "📥 Processing participant S042 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S042R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S042R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S042R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S042\n",
            "\n",
            "📥 Processing participant S043 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S043R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S043R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S043R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S043\n",
            "\n",
            "📥 Processing participant S044 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S044R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S044R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S044R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S044\n",
            "\n",
            "📥 Processing participant S045 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S045R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S045R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S045R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S045\n",
            "\n",
            "📥 Processing participant S046 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S046R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S046R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S046R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S046\n",
            "\n",
            "📥 Processing participant S047 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S047R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S047R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S047R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S047\n",
            "\n",
            "📥 Processing participant S048 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S048R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S048R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S048R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S048\n",
            "\n",
            "📥 Processing participant S049 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S049R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S049R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S049R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S049\n",
            "\n",
            "📥 Processing participant S050 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S050R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S050R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S050R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S050\n",
            "\n",
            "📥 Processing participant S051 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S051R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S051R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S051R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S051\n",
            "\n",
            "📥 Processing participant S052 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S052R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S052R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S052R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S052\n",
            "\n",
            "📥 Processing participant S053 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S053R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S053R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S053R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S053\n",
            "\n",
            "📥 Processing participant S054 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S054R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S054R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S054R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S054\n",
            "\n",
            "📥 Processing participant S055 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S055R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S055R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S055R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S055\n",
            "\n",
            "📥 Processing participant S056 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S056R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S056R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S056R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S056\n",
            "\n",
            "📥 Processing participant S057 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S057R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S057R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S057R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S057\n",
            "\n",
            "📥 Processing participant S058 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S058R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S058R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S058R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S058\n",
            "\n",
            "📥 Processing participant S059 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S059R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S059R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S059R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S059\n",
            "\n",
            "📥 Processing participant S060 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S060R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S060R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S060R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S060\n",
            "\n",
            "📥 Processing participant S061 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S061R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S061R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S061R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S061\n",
            "\n",
            "📥 Processing participant S062 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S062R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S062R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S062R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S062\n",
            "\n",
            "📥 Processing participant S063 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S063R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S063R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S063R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S063\n",
            "\n",
            "📥 Processing participant S064 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S064R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S064R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S064R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S064\n",
            "\n",
            "📥 Processing participant S065 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S065R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S065R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S065R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S065\n",
            "\n",
            "📥 Processing participant S066 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S066R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S066R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S066R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S066\n",
            "\n",
            "📥 Processing participant S067 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S067R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S067R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S067R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S067\n",
            "\n",
            "📥 Processing participant S068 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S068R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S068R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S068R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S068\n",
            "\n",
            "📥 Processing participant S069 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S069R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S069R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S069R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S069\n",
            "\n",
            "📥 Processing participant S070 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S070R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S070R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S070R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S070\n",
            "\n",
            "📥 Processing participant S071 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S071R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S071R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S071R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S071\n",
            "\n",
            "📥 Processing participant S072 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S072R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S072R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S072R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S072\n",
            "\n",
            "📥 Processing participant S073 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S073R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S073R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S073R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S073\n",
            "\n",
            "📥 Processing participant S074 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S074R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S074R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S074R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S074\n",
            "\n",
            "📥 Processing participant S075 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S075R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S075R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S075R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S075\n",
            "\n",
            "📥 Processing participant S076 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S076R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S076R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S076R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S076\n",
            "\n",
            "📥 Processing participant S077 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S077R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S077R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S077R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S077\n",
            "\n",
            "📥 Processing participant S078 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S078R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S078R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S078R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S078\n",
            "\n",
            "📥 Processing participant S079 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S079R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S079R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S079R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S079\n",
            "\n",
            "📥 Processing participant S080 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S080R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S080R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S080R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S080\n",
            "\n",
            "📥 Processing participant S081 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S081R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S081R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S081R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S081\n",
            "\n",
            "📥 Processing participant S082 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S082R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S082R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S082R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S082\n",
            "\n",
            "📥 Processing participant S083 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S083R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S083R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S083R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S083\n",
            "\n",
            "📥 Processing participant S084 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S084R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S084R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S084R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S084\n",
            "\n",
            "📥 Processing participant S085 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S085R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S085R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S085R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S085\n",
            "\n",
            "📥 Processing participant S086 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S086R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S086R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S086R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S086\n",
            "\n",
            "📥 Processing participant S087 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S087R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S087R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S087R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S087\n",
            "\n",
            "📥 Processing participant S090 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S090R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S090R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S090R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S090\n",
            "\n",
            "📥 Processing participant S091 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S091R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S091R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S091R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S091\n",
            "\n",
            "📥 Processing participant S093 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S093R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S093R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S093R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S093\n",
            "\n",
            "📥 Processing participant S094 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S094R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S094R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S094R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S094\n",
            "\n",
            "📥 Processing participant S095 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S095R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S095R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S095R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S095\n",
            "\n",
            "📥 Processing participant S096 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S096R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S096R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S096R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S096\n",
            "\n",
            "📥 Processing participant S097 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S097R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S097R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S097R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S097\n",
            "\n",
            "📥 Processing participant S098 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S098R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S098R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S098R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S098\n",
            "\n",
            "📥 Processing participant S099 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S099R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S099R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S099R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S099\n",
            "\n",
            "📥 Processing participant S101 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S101R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S101R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S101R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S101\n",
            "\n",
            "📥 Processing participant S102 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S102R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S102R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S102R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S102\n",
            "\n",
            "📥 Processing participant S103 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S103R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S103R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S103R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S103\n",
            "\n",
            "📥 Processing participant S104 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S104R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 13 T1/T2 events from S104R08.edf\n",
            "Not setting metadata\n",
            "13 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 13 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S104R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 43 trials for S104\n",
            "\n",
            "📥 Processing participant S105 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S105R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S105R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S105R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S105\n",
            "\n",
            "📥 Processing participant S106 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S106R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S106R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S106R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S106\n",
            "\n",
            "📥 Processing participant S107 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S107R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S107R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S107R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S107\n",
            "\n",
            "📥 Processing participant S108 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S108R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S108R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S108R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S108\n",
            "\n",
            "📥 Processing participant S109 (3 files)\n",
            "✔️ Extracting 15 T1/T2 events from S109R04.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S109R08.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✔️ Extracting 15 T1/T2 events from S109R12.edf\n",
            "Not setting metadata\n",
            "15 matching events found\n",
            "No baseline correction applied\n",
            "0 projection items activated\n",
            "Using data from preloaded Raw for 15 events and 497 original time points ...\n",
            "0 bad epochs dropped\n",
            "✅ Stored 45 trials for S109\n"
          ]
        },
        {
          "output_type": "error",
          "ename": "FileNotFoundError",
          "evalue": "[Errno 2] No such file or directory: '/content/drive/My Drive/Physionet MI/all_participants_data.npz'",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mFileNotFoundError\u001b[0m                         Traceback (most recent call last)",
            "\u001b[0;32m/tmp/ipython-input-4233817689.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m    109\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    110\u001b[0m \u001b[0msave_path\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m\"/content/drive/My Drive/Physionet MI/all_participants_data.npz\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 111\u001b[0;31m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msavez\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0msave_path\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdata_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    112\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"\\n📁 Saved all data to {save_path}\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36msavez\u001b[0;34m(file, *args, **kwds)\u001b[0m\n\u001b[1;32m    637\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    638\u001b[0m     \"\"\"\n\u001b[0;32m--> 639\u001b[0;31m     \u001b[0m_savez\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwds\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    640\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    641\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36m_savez\u001b[0;34m(file, args, kwds, compress, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[1;32m    734\u001b[0m         \u001b[0mcompression\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzipfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mZIP_STORED\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    735\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 736\u001b[0;31m     \u001b[0mzipf\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mzipfile_factory\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"w\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcompression\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcompression\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    737\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    738\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mnamedict\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitems\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.12/dist-packages/numpy/lib/npyio.py\u001b[0m in \u001b[0;36mzipfile_factory\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m    101\u001b[0m     \u001b[0;32mimport\u001b[0m \u001b[0mzipfile\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m     \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'allowZip64'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 103\u001b[0;31m     \u001b[0;32mreturn\u001b[0m \u001b[0mzipfile\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mZipFile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    105\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/lib/python3.12/zipfile/__init__.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, file, mode, compression, allowZip64, compresslevel, strict_timestamps, metadata_encoding)\u001b[0m\n\u001b[1;32m   1334\u001b[0m             \u001b[0;32mwhile\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1335\u001b[0m                 \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1336\u001b[0;31m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfp\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mio\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfile\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfilemode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1337\u001b[0m                 \u001b[0;32mexcept\u001b[0m \u001b[0mOSError\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1338\u001b[0m                     \u001b[0;32mif\u001b[0m \u001b[0mfilemode\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodeDict\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '/content/drive/My Drive/Physionet MI/all_participants_data.npz'"
          ]
        }
      ],
      "source": [
        "\n",
        "#PHYSIONET LOAD\n",
        "#DOWNLOAD PHYSIONET DATA AND REPLACE\n",
        "\n",
        "# 1) Mount Google Drive\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "# 2) Install necessary packages (if not already installed)\n",
        "!pip install mne pandas\n",
        "\n",
        "import os\n",
        "import mne\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "\n",
        "# 3) Define the main path where all participant folders are stored\n",
        "main_folder_path = '/content/drive/My Drive/PHYSIONET MI/files/'\n",
        "\n",
        "# 4) Generate participant IDs (001 to 105) and exclude faulty ones\n",
        "excluded_participants = {'088', '089', '092', '100'}\n",
        "participants = [f\"{i:03d}\" for i in range(1, 110) if f\"{i:03d}\" not in excluded_participants]\n",
        "\n",
        "print(f\"✅ Loading data for {len(participants)} valid participants.\")\n",
        "\n",
        "# 5) Dictionary to store all data\n",
        "data_dict = {}\n",
        "\n",
        "# 6) Loop through each participant\n",
        "for participant in participants:\n",
        "    participant_path = os.path.join(main_folder_path, f\"S{participant}\")\n",
        "\n",
        "    # Ensure participant folder exists\n",
        "    if not os.path.exists(participant_path):\n",
        "        print(f\"⚠️ Missing folder for S{participant}, skipping...\")\n",
        "        continue\n",
        "\n",
        "    # List EDF files (make sure the order is correct)\n",
        "    edf_files = sorted([f for f in os.listdir(participant_path) if f.lower().endswith('.edf')])\n",
        "\n",
        "    # Only select relevant motor imagery runs (R04, R08, R12)\n",
        "    imagery_runs = [f for f in edf_files if 'R04' in f or 'R08' in f or 'R12' in f]\n",
        "\n",
        "    if not imagery_runs:\n",
        "        print(f\"⚠️ No valid files for S{participant}, skipping...\")\n",
        "        continue\n",
        "\n",
        "    print(f\"\\n📥 Processing participant S{participant} ({len(imagery_runs)} files)\")\n",
        "\n",
        "    # Lists to store participant's trials\n",
        "    X_data = []\n",
        "    y_labels = []\n",
        "\n",
        "    # 7) Process each valid file\n",
        "    for edf_file in imagery_runs:\n",
        "        edf_path = os.path.join(participant_path, edf_file)\n",
        "\n",
        "        # Load EDF data\n",
        "        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)\n",
        "\n",
        "        # Extract events\n",
        "        events, event_id = mne.events_from_annotations(raw, verbose=False)\n",
        "\n",
        "        # Dynamically find T1 and T2 event IDs\n",
        "        t1_event_id = next((code for label, code in event_id.items() if 'T1' in label), None)\n",
        "        t2_event_id = next((code for label, code in event_id.items() if 'T2' in label), None)\n",
        "\n",
        "        if t1_event_id is None or t2_event_id is None:\n",
        "            print(f\"⚠️ No T1/T2 events in {edf_file}, skipping...\")\n",
        "            continue\n",
        "\n",
        "        # Extract only relevant events\n",
        "        target_events = np.array([e for e in events if e[2] in [t1_event_id, t2_event_id]])\n",
        "\n",
        "        if len(target_events) == 0:\n",
        "            print(f\"⚠️ No T1/T2 trials found in {edf_file}, skipping...\")\n",
        "            continue\n",
        "\n",
        "        print(f\"✔️ Extracting {len(target_events)} T1/T2 events from {edf_file}\")\n",
        "\n",
        "        # Define trial duration (3.1s * 160Hz = 496 samples)\n",
        "        epochs = mne.Epochs(raw, target_events, event_id={'T1': t1_event_id, 'T2': t2_event_id},\n",
        "                             tmin=0, tmax=3.1, baseline=None, preload=True)\n",
        "\n",
        "        # Convert EEG data to NumPy array\n",
        "        data = epochs.get_data()  # Shape: (num_trials, channels, 496)\n",
        "        labels = np.array([0 if e[2] == t1_event_id else 1 for e in target_events])\n",
        "\n",
        "        # Store data\n",
        "        X_data.append(data)\n",
        "        y_labels.append(labels)\n",
        "\n",
        "    # 8) Convert participant data to NumPy arrays\n",
        "    if X_data:\n",
        "        X_data = np.concatenate(X_data, axis=0)  # Shape: (num_trials, channels, 496)\n",
        "        y_labels = np.concatenate(y_labels, axis=0)  # Shape: (num_trials,)\n",
        "\n",
        "        # Store in dictionary\n",
        "        data_dict[f\"S{participant}\"] = {\"X\": X_data, \"y\": y_labels}\n",
        "\n",
        "        print(f\"✅ Stored {X_data.shape[0]} trials for S{participant}\")\n",
        "\n",
        "# 9) Convert to a DataFrame\n",
        "df_list = []\n",
        "for participant, data in data_dict.items():\n",
        "    trials = data[\"X\"].shape[0]\n",
        "    df_list.append({\"Participant\": participant, \"Trials\": trials, \"Data Shape\": data[\"X\"].shape})\n",
        "\n",
        "df = pd.DataFrame(df_list)\n",
        "\n",
        "save_path = \"/content/drive/My Drive/Physionet MI/all_participants_data.npz\"\n",
        "np.savez(save_path, **data_dict)\n",
        "print(f\"\\n📁 Saved all data to {save_path}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#!/usr/bin/env python\n",
        "# -*- coding: utf-8 -*-\n",
        "\n",
        "import time, math\n",
        "import numpy as np\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torch.utils.data import TensorDataset, DataLoader\n",
        "from sklearn.model_selection import StratifiedKFold\n",
        "from sklearn.metrics import accuracy_score, cohen_kappa_score, f1_score\n",
        "\n",
        "# -----------------------\n",
        "# Repro & device\n",
        "# -----------------------\n",
        "SEED = 42\n",
        "np.random.seed(SEED)\n",
        "torch.manual_seed(SEED)\n",
        "torch.backends.cudnn.deterministic = True\n",
        "torch.backends.cudnn.benchmark = False\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
        "print(\"Device:\", device)\n",
        "\n",
        "# -----------------------\n",
        "# Expect data_dict in scope\n",
        "# -----------------------\n",
        "assert 'data_dict' in globals(), \"data_dict must exist: {pid: {'X':(n,C,T), 'y':(n,)}}\"\n",
        "for k,v in data_dict.items():\n",
        "    Xk, yk = v['X'], v['y']\n",
        "    assert Xk.ndim == 3 and yk.ndim == 1 and Xk.shape[0] == yk.shape[0], f\"Bad shapes for {k}\"\n",
        "    u = np.unique(yk)\n",
        "    if not np.array_equal(np.sort(u), np.array([0,1])):\n",
        "        m = {np.sort(u)[0]:0, np.sort(u)[-1]:1}\n",
        "        data_dict[k]['y'] = np.vectorize(m.get)(yk)\n",
        "\n",
        "# Pool all trials\n",
        "X_all = np.concatenate([v['X'] for v in data_dict.values()], axis=0)\n",
        "y_all = np.concatenate([v['y'] for v in data_dict.values()], axis=0)\n",
        "N, C, T = X_all.shape\n",
        "print(f\"Pooled data: N={N}, C={C}, T={T} | pos={y_all.sum()} neg={N - y_all.sum()}\")\n",
        "\n",
        "# ----------------\n",
        "# Hyperparams\n",
        "# ----------------\n",
        "EPS = 1e-6\n",
        "batch_size   = 64\n",
        "num_epochs   = 50\n",
        "LR           = 1e-3\n",
        "WEIGHT_DECAY = 1e-4\n",
        "NUM_CLASSES  = 2\n",
        "N_SPLITS     = 5\n",
        "\n",
        "# ----------------\n",
        "# Graph utils\n",
        "# ----------------\n",
        "def compute_global_adjacency(x):\n",
        "    # x: (N,C,T) torch.float32\n",
        "    mu = x.mean(dim=2, keepdim=True)\n",
        "    Z  = x - mu\n",
        "    Nn, Cc, Tt = Z.shape\n",
        "    cov = torch.einsum('nct,ndt->cd', Z, Z) / (Nn * Tt + 1e-9)\n",
        "    std = torch.sqrt(torch.clamp(torch.diag(cov), min=1e-9))\n",
        "    corr = cov / (std.unsqueeze(1) * std.unsqueeze(0) + 1e-9)\n",
        "    corr = corr.abs()\n",
        "    corr.fill_diagonal_(0.0)\n",
        "    return corr\n",
        "\n",
        "def graph_variate(x, fun='sqd', Zave=True, eps=EPS):\n",
        "    B, Cc, Tt = x.shape\n",
        "    x_ = x\n",
        "    if Zave:\n",
        "        mu = x_.mean(dim=0, keepdim=True)\n",
        "        sig = x_.std(dim=0, keepdim=True, unbiased=True)\n",
        "        x_ = (x_ - mu) / (sig + eps)\n",
        "    if fun == 'sqd':\n",
        "        Om = (x_.unsqueeze(2) - x_.unsqueeze(1)).pow(2)\n",
        "    elif fun == 'corr':\n",
        "        D = x_ - x_.mean(dim=2, keepdim=True)\n",
        "        Om = torch.abs((D.unsqueeze(2) * D.unsqueeze(1)))\n",
        "    elif fun == 'full':\n",
        "        Om = torch.ones(B, Cc, Cc, Tt, device=x.device, dtype=x.dtype)\n",
        "    else:\n",
        "        raise ValueError(fun)\n",
        "    return Om\n",
        "\n",
        "def renormalize_dynamic(A, eps=EPS):\n",
        "    I = torch.eye(A.size(1), device=A.device).view(1, A.size(1), A.size(2), 1)\n",
        "    At = A + I\n",
        "    deg = At.sum(2, keepdim=True)\n",
        "    inv = deg.clamp(min=eps).pow(-0.5)\n",
        "    return inv * At * inv.transpose(1, 2)\n",
        "\n",
        "def graph_conv_bmm(x, A):\n",
        "    B, Cc, Tt = x.shape\n",
        "    A_bt = A.permute(0, 3, 1, 2).contiguous().view(B*Tt, Cc, Cc)\n",
        "    x_bt = x.permute(0, 2, 1).contiguous().view(B*Tt, Cc, 1)\n",
        "    y_bt = torch.bmm(A_bt, x_bt).view(B, Tt, Cc)\n",
        "    return y_bt.permute(0, 2, 1).contiguous()\n",
        "\n",
        "# ----------------------\n",
        "# Backbones (B,C,T)->(B,C,T)\n",
        "# ----------------------\n",
        "class GraphVarBackbone(nn.Module):\n",
        "    def __init__(self, C, T, W_C, fun='sqd', ZAVE=False):\n",
        "        super().__init__()\n",
        "        self.C, self.T = C, T\n",
        "        self.fun, self.ZAVE = fun, ZAVE\n",
        "        self.register_buffer('W_C', W_C)\n",
        "        self.g_conv  = nn.Parameter(torch.ones(T))\n",
        "        self.g_skip  = nn.Parameter(torch.ones(T))\n",
        "        self.W_C1=nn.Parameter(W_C.clone())\n",
        "    def forward(self, x):\n",
        "        mu = x.mean(1, keepdim=True); sig = x.std(1, keepdim=True)\n",
        "        x0 = (x - mu) / (sig + EPS)\n",
        "        Om1 = graph_variate(x0, fun='sqd', Zave=self.ZAVE, eps=EPS)\n",
        "        A1  = renormalize_dynamic(Om1 * self.W_C1.unsqueeze(0).unsqueeze(-1))\n",
        "        h0  = self.g_conv * (graph_conv_bmm(x0, A1)) + self.g_skip * x0\n",
        "        Om2 = graph_variate(h0, fun='sqd', Zave=self.ZAVE, eps=EPS)\n",
        "        A2  = renormalize_dynamic(Om2 * self.W_C.unsqueeze(0).unsqueeze(-1))\n",
        "        h1  = h0 + graph_conv_bmm(h0, A2)\n",
        "        return h0\n",
        "\n",
        "\n",
        "# -----------------------------\n",
        "# MLP Head (subject-level; no time in logits)\n",
        "# -----------------------------\n",
        "\n",
        "class MLPHead(nn.Module):\n",
        "    \"\"\"\n",
        "    Per-time-step MLP (no pooling inside the feature path), but returns\n",
        "    (B, num_classes) by averaging logits over time so your existing\n",
        "    CrossEntropyLoss(y: (B,)) works.\n",
        "\n",
        "    Full temporal logits are stored in `self.last_seq_logits` with shape\n",
        "    (B, num_classes, T) for analysis/visualization.\n",
        "\n",
        "    Input:  h (B, C, T)\n",
        "    Return: logits (B, num_classes)\n",
        "    \"\"\"\n",
        "    def __init__(self, C: int, hidden: int = 256, dropout: float = 0.3, num_classes: int = 2):\n",
        "        super().__init__()\n",
        "        self.norm_t = nn.LayerNorm(C)     # normalize per time step over channels\n",
        "        self.fc1    = nn.Linear(C, hidden)\n",
        "        self.fc2    = nn.Linear(hidden, num_classes)\n",
        "        self.drop   = nn.Dropout(dropout)\n",
        "        self.last_seq_logits = None       # will hold (B, num_classes, T)\n",
        "\n",
        "    def forward(self, h_bct: torch.Tensor) -> torch.Tensor:\n",
        "        # h_bct: (B, C, T)\n",
        "        B, C, T = h_bct.shape\n",
        "        x = h_bct.permute(0, 2, 1)        # (B, T, C)\n",
        "        x = self.norm_t(x)                # LN over channels at each time step\n",
        "        x = x.reshape(B * T, C)           # shared MLP across time steps\n",
        "        x = F.relu(self.fc1(x))\n",
        "        x = self.drop(x)\n",
        "        x = self.fc2(x)                   # (B*T, num_classes)\n",
        "\n",
        "        seq_logits = x.view(B, T, -1).permute(0, 2, 1).contiguous()  # (B, num_classes, T)\n",
        "        self.last_seq_logits = seq_logits\n",
        "\n",
        "        logits = seq_logits.mean(dim=-1)  # (B, num_classes)\n",
        "        return logits\n",
        "\n",
        "class GraphBackboneMLP(nn.Module):\n",
        "    def __init__(self, backbone: nn.Module, C: int, T: int, num_classes: int = 2):\n",
        "        super().__init__()\n",
        "        self.backbone = backbone\n",
        "        self.head = MLPHead(C=C, hidden=256, dropout=0.3, num_classes=num_classes)\n",
        "    def forward(self, x_bct):  # (B, C, T)\n",
        "        h = self.backbone(x_bct)  # (B,C,T)\n",
        "        return self.head(h)       # (B,num_classes)\n",
        "\n",
        "# -----------------------------\n",
        "# EEGNet baseline (separate; no graph)\n",
        "# -----------------------------\n",
        "class EEGNetClassifier(nn.Module):\n",
        "    \"\"\"\n",
        "    Simple EEGNet-like baseline that adapts to (C,T):\n",
        "      - Temporal conv → Depthwise spatial conv across channels → Pointwise\n",
        "      - AdaptiveAvgPool → FC\n",
        "    \"\"\"\n",
        "    def __init__(self, C: int, num_classes: int = 2, temporal_kernel: int = 25):\n",
        "        super().__init__()\n",
        "        pad_t = temporal_kernel // 2\n",
        "        self.temporal = nn.Sequential(\n",
        "            nn.Conv2d(1, 8, kernel_size=(1, temporal_kernel), padding=(0, pad_t), bias=False),\n",
        "            nn.BatchNorm2d(8, affine=False),\n",
        "            nn.ELU()\n",
        "        )\n",
        "        # Depthwise spatial conv over channels (height=C), collapses channel axis to 1\n",
        "        self.spatial = nn.Sequential(\n",
        "            nn.Conv2d(8, 16, kernel_size=(C, 1), groups=8, bias=False),\n",
        "            nn.BatchNorm2d(16, affine=False),\n",
        "            nn.ELU()\n",
        "        )\n",
        "        self.pointwise = nn.Sequential(\n",
        "            nn.Conv2d(16, 32, kernel_size=(1,1), bias=False),\n",
        "            nn.BatchNorm2d(32, affine=False),\n",
        "            nn.ELU()\n",
        "        )\n",
        "        self.pool = nn.AdaptiveAvgPool2d((1,1))\n",
        "        self.fc = nn.Linear(32, num_classes)\n",
        "\n",
        "    def forward(self, x_bct):  # (B,C,T)\n",
        "        x = x_bct.unsqueeze(1)           # (B,1,C,T)\n",
        "        x = self.temporal(x)             # (B,8,C,T)\n",
        "        x = self.spatial(x)              # (B,16,1,T)\n",
        "        x = self.pointwise(x)            # (B,32,1,T)\n",
        "        x = self.pool(x)                 # (B,32,1,1)\n",
        "        x = x.flatten(1)                 # (B,32)\n",
        "        return self.fc(x)                # (B,num_classes)\n",
        "\n",
        "# -----------------------------\n",
        "# Transformer & LSTM baselines (separate; no graph)\n",
        "# -----------------------------\n",
        "class PositionalEncoding(nn.Module):\n",
        "    def __init__(self, d_model: int, max_len: int = 5000, dropout: float = 0.1):\n",
        "        super().__init__()\n",
        "        self.dropout = nn.Dropout(dropout)\n",
        "        pe = torch.zeros(max_len, d_model, dtype=torch.float32)\n",
        "        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)\n",
        "        div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model))\n",
        "        pe[:, 0::2] = torch.sin(position * div_term)\n",
        "        pe[:, 1::2] = torch.cos(position * div_term)\n",
        "        self.register_buffer('pe', pe)\n",
        "\n",
        "    def forward(self, x):  # x: (B, S, D)\n",
        "        S = x.size(1)\n",
        "        x = x + self.pe[:S, :].unsqueeze(0)\n",
        "        return self.dropout(x)\n",
        "\n",
        "class TransformerClassifier(nn.Module):\n",
        "    \"\"\"\n",
        "    Treat time as tokens: each token is a C-dim vector.\n",
        "    \"\"\"\n",
        "    def __init__(self, C, T, num_classes, d_model=128, nhead=4, num_layers=2, dim_feedforward=256, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.proj = nn.Linear(C, d_model)\n",
        "        self.pos  = PositionalEncoding(d_model=d_model, max_len=T+1, dropout=dropout)\n",
        "        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead,\n",
        "                                               dim_feedforward=dim_feedforward, dropout=dropout,\n",
        "                                               batch_first=True)\n",
        "        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)\n",
        "        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))\n",
        "        nn.init.trunc_normal_(self.cls_token, std=0.02)\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "        self.head = nn.Linear(d_model, num_classes)\n",
        "\n",
        "    def forward(self, x_bct):  # (B,C,T)\n",
        "        x = x_bct.permute(0, 2, 1)           # (B,T,C)\n",
        "        x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)\n",
        "        x = self.proj(x)                      # (B,T,D)\n",
        "        B, T, D = x.shape\n",
        "        cls = self.cls_token.expand(B, 1, D)  # (B,1,D)\n",
        "        x = torch.cat([cls, x], dim=1)        # (B,T+1,D)\n",
        "        x = self.pos(x)\n",
        "        h = self.encoder(x)                   # (B,T+1,D)\n",
        "        cls_out = self.norm(h[:, 0, :])       # (B,D)\n",
        "        return self.head(cls_out)             # (B,num_classes)\n",
        "\n",
        "class LSTMClassifier(nn.Module):\n",
        "    def __init__(self, C, num_classes, hidden=128, num_layers=2, bidirectional=True, dropout=0.1):\n",
        "        super().__init__()\n",
        "        self.lstm = nn.LSTM(input_size=C, hidden_size=hidden, num_layers=num_layers,\n",
        "                            batch_first=True, dropout=dropout, bidirectional=bidirectional)\n",
        "        out_dim = hidden * (2 if bidirectional else 1)\n",
        "        self.norm = nn.LayerNorm(out_dim)\n",
        "        self.head = nn.Linear(out_dim, num_classes)\n",
        "\n",
        "    def forward(self, x_bct):  # (B,C,T)\n",
        "        x = x_bct.permute(0, 2, 1)           # (B,T,C)\n",
        "        x = (x - x.mean(dim=1, keepdim=True)) / (x.std(dim=1, keepdim=True) + 1e-6)\n",
        "        h, _ = self.lstm(x)                   # (B,T,H*)\n",
        "        h_pool = h.mean(dim=1)                # (B,H*)\n",
        "        h_pool = self.norm(h_pool)\n",
        "        return self.head(h_pool)              # (B,num_classes)\n",
        "\n",
        "# -------------------------\n",
        "# Train / evaluate per fold\n",
        "# -------------------------\n",
        "def train_and_eval(model, train_loader, test_loader, epochs=50, lr=1e-3, wd=1e-4):\n",
        "    model.to(device)\n",
        "    opt = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)\n",
        "    crit = nn.CrossEntropyLoss()\n",
        "    best_state = None\n",
        "    best_val = float('inf')\n",
        "\n",
        "    for ep in range(1, epochs+1):\n",
        "        t0 = time.time()\n",
        "        # train\n",
        "        model.train()\n",
        "        tr_loss = 0.0; tr_corr = 0; tr_tot = 0\n",
        "        for xb, yb in train_loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            opt.zero_grad()\n",
        "            logits = model(xb)\n",
        "            loss = crit(logits, yb)\n",
        "            loss.backward()\n",
        "            nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "            opt.step()\n",
        "            tr_loss += loss.item() * xb.size(0)\n",
        "            tr_corr += (logits.argmax(1) == yb).sum().item()\n",
        "            tr_tot  += xb.size(0)\n",
        "        tr_loss /= max(1, tr_tot)\n",
        "        tr_acc = 100.0 * tr_corr / max(1, tr_tot)\n",
        "\n",
        "        # val\n",
        "        model.eval()\n",
        "        te_loss = 0.0; te_corr = 0; te_tot = 0\n",
        "        preds_all, targs_all = [], []\n",
        "        with torch.no_grad():\n",
        "            for xb, yb in test_loader:\n",
        "                xb, yb = xb.to(device), yb.to(device)\n",
        "                logits = model(xb)\n",
        "                loss = crit(logits, yb)\n",
        "                te_loss += loss.item() * xb.size(0)\n",
        "                p = logits.argmax(1)\n",
        "                te_corr += (p == yb).sum().item()\n",
        "                te_tot  += yb.size(0)\n",
        "                preds_all.extend(p.cpu().numpy().tolist())\n",
        "                targs_all.extend(yb.cpu().numpy().tolist())\n",
        "        te_loss /= max(1, te_tot)\n",
        "        te_acc = 100.0 * te_corr / max(1, te_tot)\n",
        "        f1  = f1_score(targs_all, preds_all)\n",
        "        kappa = cohen_kappa_score(targs_all, preds_all)\n",
        "\n",
        "        if te_loss < best_val:\n",
        "            best_val = te_loss\n",
        "            best_state = {k: v.cpu() for k, v in model.state_dict().items()}\n",
        "\n",
        "        if ep % 5 == 0 or ep == 1:\n",
        "            print(f\"Ep {ep:03d} | train {tr_loss:.4f}/{tr_acc:5.2f}% \"\n",
        "                  f\"| val {te_loss:.4f}/{te_acc:5.2f}% F1={f1:.4f} κ={kappa:.4f} \"\n",
        "                  f\"| {time.time()-t0:.1f}s\")\n",
        "\n",
        "    if best_state is not None:\n",
        "        model.load_state_dict({k: v.to(device) for k, v in best_state.items()})\n",
        "\n",
        "    # final on this fold\n",
        "    model.eval()\n",
        "    preds_all, targs_all = [], []\n",
        "    with torch.no_grad():\n",
        "        for xb, yb in test_loader:\n",
        "            xb, yb = xb.to(device), yb.to(device)\n",
        "            logits = model(xb)\n",
        "            p = logits.argmax(1)\n",
        "            preds_all.extend(p.cpu().numpy().tolist())\n",
        "            targs_all.extend(yb.cpu().numpy().tolist())\n",
        "    acc   = accuracy_score(targs_all, preds_all) * 100.0\n",
        "    f1    = f1_score(targs_all, preds_all)\n",
        "    kappa = cohen_kappa_score(targs_all, preds_all)\n",
        "    return acc, f1, kappa\n",
        "\n",
        "# -------------------------\n",
        "# Cross-validation pipeline\n",
        "# -------------------------\n",
        "def run_cv(X_all, y_all, n_splits=5):\n",
        "    X_all_t = torch.tensor(X_all, dtype=torch.float32)\n",
        "    y_all_t = torch.tensor(y_all.astype(int), dtype=torch.long)\n",
        "    N, C, T = X_all_t.shape\n",
        "\n",
        "    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)\n",
        "    results = {\n",
        "        name: {'acc': [], 'f1': [], 'kappa': []}\n",
        "        for name in [\n",
        "            'GraphVar+MLP','CPGraphST+MLP','GVARMA+MLP','GGRNN+MLP',\n",
        "            'EEGNet','Transformer','LSTM'\n",
        "        ]\n",
        "    }\n",
        "\n",
        "    fold = 0\n",
        "    for tr_idx, te_idx in skf.split(X_all, y_all):\n",
        "        fold += 1\n",
        "        print(f\"\\n================ Fold {fold}/{n_splits} ================\")\n",
        "        X_tr, y_tr = X_all_t[tr_idx], y_all_t[tr_idx]\n",
        "        X_te, y_te = X_all_t[te_idx], y_all_t[te_idx]\n",
        "\n",
        "        # Graph from TRAIN only\n",
        "        W_dense = compute_global_adjacency(X_tr).to(device)\n",
        "\n",
        "        # Loaders\n",
        "        train_loader = DataLoader(TensorDataset(X_tr, y_tr), batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "        test_loader  = DataLoader(TensorDataset(X_te, y_te), batch_size=batch_size, shuffle=False)\n",
        "\n",
        "        # ---- Build models ----\n",
        "        def make_graph_mlp(backbone_cls, *bb_args, **bb_kwargs):\n",
        "            backbone = backbone_cls(*bb_args, **bb_kwargs)\n",
        "            return GraphBackboneMLP(backbone=backbone, C=C, T=T, num_classes=NUM_CLASSES)\n",
        "\n",
        "        models = {\n",
        "            'GraphVar+MLP':   make_graph_mlp(GraphVarBackbone,  C=C, T=T, W_C=W_dense, fun='sqd', ZAVE=True),\n",
        "\n",
        "            'EEGNet':         EEGNetClassifier(C=C, num_classes=NUM_CLASSES),\n",
        "            'Transformer':    TransformerClassifier(C=C, T=T, num_classes=NUM_CLASSES, d_model=128, nhead=1, num_layers=1, dim_feedforward=256, dropout=0.1),\n",
        "            'LSTM':           LSTMClassifier(C=C, num_classes=NUM_CLASSES, hidden=128, num_layers=2, bidirectional=True, dropout=0.1),\n",
        "        }\n",
        "\n",
        "        # ---- Train / eval ----\n",
        "        for name, model in models.items():\n",
        "            print(f\"\\n--- Fold {fold} | {name} ---\")\n",
        "            acc, f1, kappa = train_and_eval(model, train_loader, test_loader,\n",
        "                                            epochs=num_epochs, lr=LR, wd=WEIGHT_DECAY)\n",
        "            results[name]['acc'].append(acc)\n",
        "            results[name]['f1'].append(f1)\n",
        "            results[name]['kappa'].append(kappa)\n",
        "            print(f\"{name} | Fold {fold}: Acc {acc:.2f}% | F1={f1:.4f} | κ={kappa:.4f}\")\n",
        "\n",
        "    # Summary\n",
        "    print(\"\\n================ OVERALL SUMMARY ==================\")\n",
        "    for name in [\n",
        "        'GraphVar+MLP','CPGraphST+MLP','GVARMA+MLP','GGRNN+MLP',\n",
        "        'EEGNet','Transformer','LSTM'\n",
        "    ]:\n",
        "        accs = np.array(results[name]['acc'], dtype=float)\n",
        "        f1s  = np.array(results[name]['f1'], dtype=float)\n",
        "        kaps = np.array(results[name]['kappa'], dtype=float)\n",
        "        print(f\"{name:16s} Acc: {accs.mean():6.2f}% ± {accs.std(ddof=1):.2f} | \"\n",
        "              f\"F1: {f1s.mean():.4f} ± {f1s.std(ddof=1):.4f} | \"\n",
        "              f\"κ: {kaps.mean():.4f} ± {kaps.std(ddof=1):.4f}\")\n",
        "    return results\n",
        "\n",
        "# -------------------------\n",
        "# Run CV\n",
        "# -------------------------\n",
        "results = run_cv(X_all, y_all, n_splits=N_SPLITS)\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "2sutZ5esJJbo",
        "outputId": "046e88ed-66e7-48a5-c83a-5a33b771c1ce"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Device: cuda\n",
            "Pooled data: N=4723, C=64, T=497 | pos=2341 neg=2382\n",
            "\n",
            "================ Fold 1/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 1 | GraphVar+MLP ---\n",
            "Ep 001 | train 0.6857/54.79% | val 0.6741/56.93% F1=0.6560 κ=0.1419 | 2.1s\n",
            "Ep 005 | train 0.5093/75.74% | val 0.5309/71.01% F1=0.7198 κ=0.4204 | 2.0s\n",
            "Ep 010 | train 0.4268/80.46% | val 0.4911/75.34% F1=0.7397 κ=0.5065 | 2.0s\n",
            "Ep 015 | train 0.3894/82.68% | val 0.4736/76.83% F1=0.7559 κ=0.5362 | 2.0s\n",
            "Ep 020 | train 0.3606/83.69% | val 0.4823/78.73% F1=0.8016 κ=0.5751 | 2.0s\n",
            "Ep 025 | train 0.3486/83.95% | val 0.5081/77.14% F1=0.7943 κ=0.5436 | 2.0s\n",
            "Ep 030 | train 0.3134/86.63% | val 0.4669/78.52% F1=0.7805 κ=0.5703 | 2.0s\n",
            "Ep 035 | train 0.2963/87.10% | val 0.4605/80.42% F1=0.8055 κ=0.6085 | 2.0s\n",
            "Ep 040 | train 0.2826/88.11% | val 0.4892/78.52% F1=0.7988 κ=0.5708 | 2.0s\n",
            "Ep 045 | train 0.2628/89.46% | val 0.5201/78.41% F1=0.7577 κ=0.5676 | 2.0s\n",
            "Ep 050 | train 0.2414/90.47% | val 0.5158/77.99% F1=0.7957 κ=0.5603 | 2.0s\n",
            "GraphVar+MLP | Fold 1: Acc 80.95% | F1=0.8117 | κ=0.6191\n",
            "\n",
            "--- Fold 1 | Transformer ---\n",
            "Ep 001 | train 0.6553/61.36% | val 0.5219/75.45% F1=0.7370 κ=0.5085 | 0.7s\n",
            "Ep 005 | train 0.4438/79.18% | val 0.4468/80.32% F1=0.7877 κ=0.6059 | 0.7s\n",
            "Ep 010 | train 0.4402/79.05% | val 0.4340/80.74% F1=0.8112 κ=0.6150 | 0.7s\n",
            "Ep 015 | train 0.4244/80.40% | val 0.4573/80.11% F1=0.7702 κ=0.6013 | 0.7s\n",
            "Ep 020 | train 0.4024/81.78% | val 0.4734/78.73% F1=0.7446 κ=0.5736 | 0.7s\n",
            "Ep 025 | train 0.4089/81.17% | val 0.4055/82.22% F1=0.8194 κ=0.6444 | 0.7s\n",
            "Ep 030 | train 0.4212/80.53% | val 0.4064/81.48% F1=0.8227 κ=0.6299 | 0.7s\n",
            "Ep 035 | train 0.4100/80.75% | val 0.4219/80.53% F1=0.7991 κ=0.6104 | 0.7s\n",
            "Ep 040 | train 0.3891/82.44% | val 0.4454/80.85% F1=0.8018 κ=0.6168 | 0.7s\n",
            "Ep 045 | train 0.3922/82.26% | val 0.4159/80.95% F1=0.8148 κ=0.6192 | 0.7s\n",
            "Ep 050 | train 0.3908/82.44% | val 0.4265/79.47% F1=0.8098 κ=0.5899 | 0.7s\n",
            "Transformer | Fold 1: Acc 81.69% | F1=0.8142 | κ=0.6338\n",
            "\n",
            "--- Fold 1 | LSTM ---\n",
            "Ep 001 | train 0.7754/53.02% | val 0.6088/64.02% F1=0.7059 κ=0.2828 | 1.3s\n",
            "Ep 005 | train 0.4899/75.03% | val 0.5265/74.60% F1=0.7136 κ=0.4912 | 1.2s\n",
            "Ep 010 | train 0.3735/81.99% | val 0.5570/74.07% F1=0.7380 κ=0.4814 | 1.2s\n",
            "Ep 015 | train 0.2737/88.16% | val 0.6296/74.29% F1=0.7279 κ=0.4853 | 1.2s\n",
            "Ep 020 | train 0.1771/92.64% | val 0.7646/75.56% F1=0.7442 κ=0.5108 | 1.2s\n",
            "Ep 025 | train 0.1255/94.68% | val 0.9419/73.86% F1=0.7265 κ=0.4769 | 1.2s\n",
            "Ep 030 | train 0.1060/95.95% | val 1.0495/74.81% F1=0.7302 κ=0.4958 | 1.2s\n",
            "Ep 035 | train 0.0677/97.09% | val 1.1247/75.66% F1=0.7579 κ=0.5133 | 1.2s\n",
            "Ep 040 | train 0.0622/97.59% | val 1.2738/73.23% F1=0.7265 κ=0.4644 | 1.2s\n",
            "Ep 045 | train 0.0634/97.38% | val 1.4717/74.60% F1=0.7260 κ=0.4915 | 1.2s\n",
            "Ep 050 | train 0.0361/98.70% | val 1.5530/74.07% F1=0.7374 κ=0.4814 | 1.2s\n",
            "LSTM | Fold 1: Acc 75.13% | F1=0.7470 | κ=0.5025\n",
            "\n",
            "================ Fold 2/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 2 | GraphVar+MLP ---\n",
            "Ep 001 | train 0.6891/52.81% | val 0.6727/60.42% F1=0.5266 κ=0.2061 | 2.0s\n",
            "Ep 005 | train 0.5206/74.66% | val 0.5169/75.87% F1=0.7635 κ=0.5177 | 2.0s\n",
            "Ep 010 | train 0.4341/79.56% | val 0.4706/76.93% F1=0.7588 κ=0.5383 | 2.0s\n",
            "Ep 015 | train 0.3933/82.49% | val 0.4621/77.78% F1=0.7651 κ=0.5551 | 2.0s\n",
            "Ep 020 | train 0.3683/83.58% | val 0.4412/79.37% F1=0.7919 κ=0.5873 | 2.0s\n",
            "Ep 025 | train 0.3444/85.41% | val 0.4399/79.79% F1=0.7940 κ=0.5957 | 2.0s\n",
            "Ep 030 | train 0.3236/86.36% | val 0.4347/80.53% F1=0.8030 κ=0.6105 | 2.0s\n",
            "Ep 035 | train 0.3189/86.26% | val 0.4410/78.84% F1=0.7717 κ=0.5762 | 2.0s\n",
            "Ep 040 | train 0.2854/88.35% | val 0.4276/81.16% F1=0.8142 κ=0.6234 | 2.0s\n",
            "Ep 045 | train 0.2697/88.82% | val 0.4578/78.94% F1=0.7721 κ=0.5783 | 2.0s\n",
            "Ep 050 | train 0.2557/89.57% | val 0.4498/82.54% F1=0.8287 κ=0.6510 | 2.0s\n",
            "GraphVar+MLP | Fold 2: Acc 80.74% | F1=0.8047 | κ=0.6147\n",
            "\n",
            "--- Fold 2 | Transformer ---\n",
            "Ep 001 | train 0.6399/62.95% | val 0.5581/71.01% F1=0.6327 κ=0.4178 | 0.7s\n",
            "Ep 005 | train 0.4598/78.20% | val 0.4466/80.53% F1=0.7783 κ=0.6097 | 0.7s\n",
            "Ep 010 | train 0.4427/78.68% | val 0.4239/80.11% F1=0.8054 κ=0.6023 | 0.7s\n",
            "Ep 015 | train 0.4179/80.80% | val 0.4252/80.11% F1=0.8171 κ=0.6028 | 0.7s\n",
            "Ep 020 | train 0.4178/81.01% | val 0.4280/80.42% F1=0.8126 κ=0.6088 | 0.7s\n",
            "Ep 025 | train 0.4031/81.20% | val 0.4533/78.41% F1=0.7475 κ=0.5671 | 0.7s\n",
            "Ep 030 | train 0.3996/81.51% | val 0.4228/81.48% F1=0.8168 κ=0.6297 | 0.7s\n",
            "Ep 035 | train 0.3953/81.75% | val 0.4255/80.21% F1=0.7956 κ=0.6040 | 0.7s\n",
            "Ep 040 | train 0.3867/82.18% | val 0.4679/81.06% F1=0.8018 κ=0.6209 | 0.7s\n",
            "Ep 045 | train 0.3879/82.20% | val 0.4352/79.26% F1=0.8000 κ=0.5855 | 0.7s\n",
            "Ep 050 | train 0.3953/82.12% | val 0.4170/81.06% F1=0.7987 κ=0.6208 | 0.7s\n",
            "Transformer | Fold 2: Acc 81.48% | F1=0.8152 | κ=0.6297\n",
            "\n",
            "--- Fold 2 | LSTM ---\n",
            "Ep 001 | train 0.7830/52.12% | val 0.7619/59.47% F1=0.3892 κ=0.1843 | 1.2s\n",
            "Ep 005 | train 0.4873/73.94% | val 0.5418/74.07% F1=0.7161 κ=0.4807 | 1.2s\n",
            "Ep 010 | train 0.4015/81.06% | val 0.5614/74.50% F1=0.7630 κ=0.4907 | 1.2s\n",
            "Ep 015 | train 0.3070/86.84% | val 0.6528/72.06% F1=0.6757 κ=0.4398 | 1.2s\n",
            "Ep 020 | train 0.2202/90.70% | val 0.6455/73.65% F1=0.7053 κ=0.4720 | 1.2s\n",
            "Ep 025 | train 0.1432/94.28% | val 0.8962/74.92% F1=0.7524 κ=0.4986 | 1.3s\n",
            "Ep 030 | train 0.1065/95.55% | val 1.0438/74.60% F1=0.7495 κ=0.4922 | 1.2s\n",
            "Ep 035 | train 0.1203/95.07% | val 1.1068/75.03% F1=0.7516 κ=0.5006 | 1.2s\n",
            "Ep 040 | train 0.0708/97.38% | val 1.2430/74.39% F1=0.7398 κ=0.4877 | 1.2s\n",
            "Ep 045 | train 0.0600/97.80% | val 1.4908/73.97% F1=0.7205 κ=0.4787 | 1.2s\n",
            "Ep 050 | train 0.0625/97.54% | val 1.3300/76.51% F1=0.7780 κ=0.5307 | 1.2s\n",
            "LSTM | Fold 2: Acc 73.02% | F1=0.7072 | κ=0.4596\n",
            "\n",
            "================ Fold 3/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 3 | GraphVar+MLP ---\n",
            "Ep 001 | train 0.6821/55.93% | val 0.6788/55.77% F1=0.6034 κ=0.1174 | 2.0s\n",
            "Ep 005 | train 0.5049/74.95% | val 0.5368/74.29% F1=0.7379 κ=0.4856 | 2.0s\n",
            "Ep 010 | train 0.4129/81.20% | val 0.4887/76.93% F1=0.7641 κ=0.5385 | 2.0s\n",
            "Ep 015 | train 0.3816/82.23% | val 0.4762/77.67% F1=0.7694 κ=0.5532 | 2.0s\n",
            "Ep 020 | train 0.3593/83.85% | val 0.4855/76.93% F1=0.7523 κ=0.5381 | 2.0s\n",
            "Ep 025 | train 0.3357/85.06% | val 0.4810/78.73% F1=0.7943 κ=0.5749 | 2.0s\n",
            "Ep 030 | train 0.3076/86.23% | val 0.4750/78.10% F1=0.7723 κ=0.5616 | 2.0s\n",
            "Ep 035 | train 0.2955/86.92% | val 0.5091/77.67% F1=0.7497 κ=0.5526 | 2.0s\n",
            "Ep 040 | train 0.2807/87.82% | val 0.4755/78.94% F1=0.7811 κ=0.5786 | 2.0s\n",
            "Ep 045 | train 0.2652/88.77% | val 0.4913/79.68% F1=0.8076 κ=0.5941 | 2.0s\n",
            "Ep 050 | train 0.2520/89.83% | val 0.4993/78.94% F1=0.7811 κ=0.5786 | 2.0s\n",
            "GraphVar+MLP | Fold 3: Acc 78.94% | F1=0.7853 | κ=0.5787\n",
            "\n",
            "--- Fold 3 | Transformer ---\n",
            "Ep 001 | train 0.5997/67.85% | val 0.4814/76.83% F1=0.7580 κ=0.5362 | 0.7s\n",
            "Ep 005 | train 0.4148/79.87% | val 0.4696/77.57% F1=0.7558 κ=0.5507 | 0.7s\n",
            "Ep 010 | train 0.3880/82.52% | val 0.4603/79.37% F1=0.8123 κ=0.5881 | 0.7s\n",
            "Ep 015 | train 0.3879/81.67% | val 0.4481/80.11% F1=0.8017 κ=0.6022 | 0.7s\n",
            "Ep 020 | train 0.3778/82.52% | val 0.4995/77.14% F1=0.7943 κ=0.5439 | 0.7s\n",
            "Ep 025 | train 0.3686/82.94% | val 0.4468/79.79% F1=0.8033 κ=0.5960 | 0.7s\n",
            "Ep 030 | train 0.3649/83.05% | val 0.4844/79.37% F1=0.8067 κ=0.5879 | 0.7s\n",
            "Ep 035 | train 0.3633/83.40% | val 0.4675/79.89% F1=0.8049 κ=0.5982 | 0.7s\n",
            "Ep 040 | train 0.3617/83.50% | val 0.4792/79.05% F1=0.7838 κ=0.5807 | 0.7s\n",
            "Ep 045 | train 0.3585/83.63% | val 0.4657/79.47% F1=0.7755 κ=0.5888 | 0.7s\n",
            "Ep 050 | train 0.3569/83.74% | val 0.4803/79.05% F1=0.7852 κ=0.5808 | 0.7s\n",
            "Transformer | Fold 3: Acc 79.58% | F1=0.7953 | κ=0.5916\n",
            "\n",
            "--- Fold 3 | LSTM ---\n",
            "Ep 001 | train 0.6900/60.46% | val 0.6229/64.44% F1=0.7259 κ=0.2930 | 1.2s\n",
            "Ep 005 | train 0.4833/75.32% | val 0.5434/70.90% F1=0.7428 κ=0.4195 | 1.2s\n",
            "Ep 010 | train 0.3891/80.77% | val 0.6115/71.96% F1=0.6936 κ=0.4383 | 1.2s\n",
            "Ep 015 | train 0.2784/87.21% | val 0.6906/73.02% F1=0.7059 κ=0.4595 | 1.2s\n",
            "Ep 020 | train 0.1892/92.21% | val 0.8467/73.97% F1=0.7383 κ=0.4794 | 1.2s\n",
            "Ep 025 | train 0.1262/94.60% | val 1.0042/73.76% F1=0.7449 κ=0.4755 | 1.2s\n",
            "Ep 030 | train 0.1055/95.68% | val 1.1808/72.06% F1=0.6865 κ=0.4402 | 1.2s\n",
            "Ep 035 | train 0.1148/95.37% | val 1.0262/74.39% F1=0.7515 κ=0.4882 | 1.2s\n",
            "Ep 040 | train 0.0689/97.33% | val 1.3591/72.91% F1=0.7388 κ=0.4586 | 1.2s\n",
            "Ep 045 | train 0.0391/98.25% | val 1.6186/73.97% F1=0.7338 κ=0.4792 | 1.2s\n",
            "Ep 050 | train 0.0588/98.17% | val 1.5141/72.70% F1=0.7340 κ=0.4543 | 1.2s\n",
            "LSTM | Fold 3: Acc 72.70% | F1=0.7068 | κ=0.4533\n",
            "\n",
            "================ Fold 4/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 4 | GraphVar+MLP ---\n",
            "Ep 001 | train 0.6854/55.16% | val 0.6696/60.06% F1=0.5871 κ=0.2009 | 2.0s\n",
            "Ep 005 | train 0.5335/73.65% | val 0.5135/77.22% F1=0.7691 κ=0.5444 | 2.0s\n",
            "Ep 010 | train 0.4383/79.18% | val 0.4754/76.80% F1=0.7456 κ=0.5354 | 2.0s\n",
            "Ep 015 | train 0.4039/81.04% | val 0.4484/78.81% F1=0.7942 κ=0.5765 | 2.0s\n",
            "Ep 020 | train 0.3694/83.42% | val 0.4382/79.87% F1=0.8021 κ=0.5976 | 2.0s\n",
            "Ep 025 | train 0.3486/84.48% | val 0.4396/79.77% F1=0.8065 κ=0.5957 | 2.0s\n",
            "Ep 030 | train 0.3350/85.04% | val 0.4374/80.40% F1=0.8152 κ=0.6085 | 2.0s\n",
            "Ep 035 | train 0.2978/87.63% | val 0.4433/80.93% F1=0.8073 κ=0.6186 | 2.0s\n",
            "Ep 040 | train 0.2818/88.59% | val 0.4512/80.30% F1=0.8125 κ=0.6063 | 2.0s\n",
            "Ep 045 | train 0.2761/89.04% | val 0.4532/80.40% F1=0.8063 κ=0.6082 | 2.0s\n",
            "Ep 050 | train 0.2455/90.20% | val 0.4693/79.87% F1=0.8004 κ=0.5975 | 2.0s\n",
            "GraphVar+MLP | Fold 4: Acc 80.08% | F1=0.8089 | κ=0.6020\n",
            "\n",
            "--- Fold 4 | Transformer ---\n",
            "Ep 001 | train 0.6065/65.52% | val 0.5310/73.94% F1=0.7261 κ=0.4784 | 0.7s\n",
            "Ep 005 | train 0.4363/80.08% | val 0.4198/79.98% F1=0.7805 κ=0.5990 | 0.7s\n",
            "Ep 010 | train 0.4066/81.09% | val 0.3905/82.42% F1=0.8234 κ=0.6483 | 0.7s\n",
            "Ep 015 | train 0.3983/81.33% | val 0.3912/80.83% F1=0.8181 κ=0.6169 | 0.7s\n",
            "Ep 020 | train 0.3903/81.67% | val 0.4013/81.04% F1=0.8081 κ=0.6207 | 0.7s\n",
            "Ep 025 | train 0.3780/82.39% | val 0.4091/81.25% F1=0.8277 κ=0.6256 | 0.7s\n",
            "Ep 030 | train 0.3855/81.54% | val 0.3931/81.67% F1=0.8158 κ=0.6335 | 0.7s\n",
            "Ep 035 | train 0.3736/82.73% | val 0.4166/82.10% F1=0.8238 κ=0.6421 | 0.7s\n",
            "Ep 040 | train 0.3727/82.20% | val 0.3970/81.04% F1=0.8215 κ=0.6212 | 0.7s\n",
            "Ep 045 | train 0.3743/82.65% | val 0.4319/80.30% F1=0.7983 κ=0.6058 | 0.7s\n",
            "Ep 050 | train 0.3689/82.28% | val 0.4090/80.30% F1=0.8067 κ=0.6061 | 0.7s\n",
            "Transformer | Fold 4: Acc 81.36% | F1=0.8174 | κ=0.6273\n",
            "\n",
            "--- Fold 4 | LSTM ---\n",
            "Ep 001 | train 0.6867/59.93% | val 0.5835/66.95% F1=0.7111 κ=0.3406 | 1.2s\n",
            "Ep 005 | train 0.4905/75.03% | val 0.4713/76.80% F1=0.7673 κ=0.5360 | 1.2s\n",
            "Ep 010 | train 0.3892/81.36% | val 0.4955/75.21% F1=0.7741 κ=0.5051 | 1.3s\n",
            "Ep 015 | train 0.3102/86.20% | val 0.5324/75.53% F1=0.7260 κ=0.5097 | 1.2s\n",
            "Ep 020 | train 0.1936/91.95% | val 0.6543/77.44% F1=0.7646 κ=0.5484 | 1.2s\n",
            "Ep 025 | train 0.1178/95.66% | val 0.8136/77.44% F1=0.7914 κ=0.5494 | 1.2s\n",
            "Ep 030 | train 0.0988/95.76% | val 0.8226/77.44% F1=0.7599 κ=0.5483 | 1.2s\n",
            "Ep 035 | train 0.0564/97.59% | val 1.0780/78.18% F1=0.7854 κ=0.5637 | 1.2s\n",
            "Ep 040 | train 0.0551/97.83% | val 1.1324/78.50% F1=0.7843 κ=0.5699 | 1.2s\n",
            "Ep 045 | train 0.0488/98.07% | val 1.1548/77.97% F1=0.7704 κ=0.5591 | 1.2s\n",
            "Ep 050 | train 0.0272/99.05% | val 1.3277/79.77% F1=0.7948 κ=0.5953 | 1.2s\n",
            "LSTM | Fold 4: Acc 76.80% | F1=0.7673 | κ=0.5360\n",
            "\n",
            "================ Fold 5/5 ================\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/transformer.py:392: UserWarning: enable_nested_tensor is True, but self.use_nested_tensor is False because encoder_layer.self_attn.num_heads is odd\n",
            "  warnings.warn(\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\n",
            "--- Fold 5 | GraphVar+MLP ---\n",
            "Ep 001 | train 0.6829/55.80% | val 0.6769/55.72% F1=0.3509 κ=0.1097 | 2.0s\n",
            "Ep 005 | train 0.5077/75.40% | val 0.5436/69.07% F1=0.6422 κ=0.3800 | 2.0s\n",
            "Ep 010 | train 0.4300/79.58% | val 0.4899/73.62% F1=0.7585 κ=0.4733 | 2.0s\n",
            "Ep 015 | train 0.4007/81.73% | val 0.5106/75.00% F1=0.7819 κ=0.5013 | 2.0s\n",
            "Ep 020 | train 0.3681/83.13% | val 0.4509/79.13% F1=0.7852 κ=0.5825 | 2.0s\n",
            "Ep 025 | train 0.3413/84.16% | val 0.4641/76.91% F1=0.7483 κ=0.5375 | 2.0s\n",
            "Ep 030 | train 0.3270/85.75% | val 0.4512/78.60% F1=0.7955 κ=0.5724 | 2.0s\n",
            "Ep 035 | train 0.2946/87.76% | val 0.4433/80.83% F1=0.8120 κ=0.6167 | 2.0s\n",
            "Ep 040 | train 0.2769/88.45% | val 0.4536/78.92% F1=0.7746 κ=0.5780 | 2.0s\n",
            "Ep 045 | train 0.2710/88.35% | val 0.4665/79.87% F1=0.8013 κ=0.5976 | 2.0s\n",
            "Ep 050 | train 0.2526/89.59% | val 0.4723/79.98% F1=0.7930 κ=0.5994 | 2.0s\n",
            "GraphVar+MLP | Fold 5: Acc 80.72% | F1=0.8000 | κ=0.6142\n",
            "\n",
            "--- Fold 5 | Transformer ---\n",
            "Ep 001 | train 0.5768/67.16% | val 0.4859/75.53% F1=0.7674 κ=0.5111 | 0.7s\n",
            "Ep 005 | train 0.4195/80.67% | val 0.5268/75.42% F1=0.7899 κ=0.5099 | 0.7s\n",
            "Ep 010 | train 0.4040/81.54% | val 0.4353/79.77% F1=0.8144 κ=0.5960 | 0.7s\n",
            "Ep 015 | train 0.3948/81.25% | val 0.4783/77.97% F1=0.7541 κ=0.5586 | 0.7s\n",
            "Ep 020 | train 0.3897/82.42% | val 0.4180/79.24% F1=0.7846 κ=0.5845 | 0.7s\n",
            "Ep 025 | train 0.3772/82.63% | val 0.4422/79.34% F1=0.8105 κ=0.5875 | 0.7s\n",
            "Ep 030 | train 0.3748/82.97% | val 0.4143/79.77% F1=0.7953 κ=0.5953 | 0.7s\n",
            "Ep 035 | train 0.3896/81.81% | val 0.4325/80.08% F1=0.8139 κ=0.6022 | 0.7s\n",
            "Ep 040 | train 0.3807/82.81% | val 0.4257/80.40% F1=0.8075 κ=0.6082 | 0.7s\n",
            "Ep 045 | train 0.3683/82.73% | val 0.4532/79.98% F1=0.8021 κ=0.5997 | 0.7s\n",
            "Ep 050 | train 0.3704/82.97% | val 0.4277/80.72% F1=0.8131 κ=0.6146 | 0.7s\n",
            "Transformer | Fold 5: Acc 80.61% | F1=0.8055 | κ=0.6123\n",
            "\n",
            "--- Fold 5 | LSTM ---\n",
            "Ep 001 | train 0.7114/55.46% | val 0.5734/67.80% F1=0.6667 κ=0.3556 | 1.2s\n",
            "Ep 005 | train 0.5019/75.21% | val 0.5615/70.02% F1=0.6320 κ=0.3986 | 1.2s\n",
            "Ep 010 | train 0.3969/81.36% | val 0.5402/73.31% F1=0.6934 κ=0.4650 | 1.2s\n",
            "Ep 015 | train 0.2851/87.29% | val 0.5992/76.17% F1=0.7711 κ=0.5237 | 1.2s\n",
            "Ep 020 | train 0.2081/90.81% | val 0.7314/75.74% F1=0.7680 κ=0.5152 | 1.2s\n",
            "Ep 025 | train 0.1240/94.73% | val 1.0263/73.83% F1=0.7645 κ=0.4777 | 1.2s\n",
            "Ep 030 | train 0.0918/96.37% | val 1.3502/71.50% F1=0.6650 κ=0.4287 | 1.2s\n",
            "Ep 035 | train 0.0614/97.48% | val 1.2819/74.15% F1=0.7359 κ=0.4829 | 1.2s\n",
            "Ep 040 | train 0.0555/98.04% | val 1.4802/73.83% F1=0.7430 κ=0.4769 | 1.2s\n",
            "Ep 045 | train 0.0633/97.56% | val 1.4459/74.36% F1=0.7453 κ=0.4874 | 1.2s\n",
            "Ep 050 | train 0.0522/97.78% | val 1.5961/72.25% F1=0.7337 κ=0.4454 | 1.2s\n",
            "LSTM | Fold 5: Acc 73.31% | F1=0.7110 | κ=0.4654\n",
            "\n",
            "================ OVERALL SUMMARY ==================\n",
            "GraphVar+MLP     Acc:  80.29% ± 0.82 | F1: 0.8021 ± 0.0104 | κ: 0.6058 ± 0.0164\n",
            "CPGraphST+MLP    Acc:    nan% ± nan | F1: nan ± nan | κ: nan ± nan\n",
            "GVARMA+MLP       Acc:    nan% ± nan | F1: nan ± nan | κ: nan ± nan\n",
            "GGRNN+MLP        Acc:    nan% ± nan | F1: nan ± nan | κ: nan ± nan\n",
            "EEGNet           Acc:    nan% ± nan | F1: nan ± nan | κ: nan ± nan\n",
            "Transformer      Acc:  80.94% ± 0.87 | F1: 0.8095 ± 0.0091 | κ: 0.6189 ± 0.0173\n",
            "LSTM             Acc:  74.19% ± 1.74 | F1: 0.7279 ± 0.0277 | κ: 0.4834 ± 0.0351\n"
          ]
        },
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/tmp/ipython-input-3016021800.py:530: RuntimeWarning: Mean of empty slice.\n",
            "  print(f\"{name:16s} Acc: {accs.mean():6.2f}% ± {accs.std(ddof=1):.2f} | \"\n",
            "/usr/local/lib/python3.12/dist-packages/numpy/_core/_methods.py:138: RuntimeWarning: invalid value encountered in scalar divide\n",
            "  ret = ret.dtype.type(ret / rcount)\n",
            "/usr/local/lib/python3.12/dist-packages/numpy/_core/_methods.py:218: RuntimeWarning: Degrees of freedom <= 0 for slice\n",
            "  ret = _var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,\n",
            "/usr/local/lib/python3.12/dist-packages/numpy/_core/_methods.py:175: RuntimeWarning: invalid value encountered in divide\n",
            "  arrmean = um.true_divide(arrmean, div, out=arrmean,\n",
            "/usr/local/lib/python3.12/dist-packages/numpy/_core/_methods.py:210: RuntimeWarning: invalid value encountered in scalar divide\n",
            "  ret = ret.dtype.type(ret / rcount)\n",
            "/tmp/ipython-input-3016021800.py:531: RuntimeWarning: Mean of empty slice.\n",
            "  f\"F1: {f1s.mean():.4f} ± {f1s.std(ddof=1):.4f} | \"\n",
            "/tmp/ipython-input-3016021800.py:532: RuntimeWarning: Mean of empty slice.\n",
            "  f\"κ: {kaps.mean():.4f} ± {kaps.std(ddof=1):.4f}\")\n"
          ]
        }
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "A100",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}