{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, glob, warnings, math, random\n",
    "# NOTE: every RuntimeWarning is silenced for the whole notebook, including NumPy\n",
    "# divide-by-zero / invalid-value warnings (e.g. NaN results from empty reductions).\n",
    "warnings.filterwarnings(\"ignore\", category=RuntimeWarning)\n",
    "\n",
    "# Backend selection: `xp` is jax.numpy when JAX imports and reports a device,\n",
    "# otherwise plain NumPy. USE_JAX records which backend is active and is the\n",
    "# default for the `use_jax=` switches of the helpers defined below.\n",
    "USE_JAX = True\n",
    "try:\n",
    "    import jax\n",
    "    import jax.numpy as xp\n",
    "    from jax import jit\n",
    "    JAX_BACKEND = jax.devices()[0].platform\n",
    "    print(f\"[Backend] JAX on {JAX_BACKEND} ({jax.devices()[0]})\")\n",
    "except Exception:\n",
    "    # Deliberate best-effort: any JAX import/device failure falls back to NumPy.\n",
    "    USE_JAX = False\n",
    "    import numpy as xp\n",
    "    print(\"[Backend] Using NumPy (CPU). Enable TPU for speed if desired.\")\n",
    "\n",
    "import numpy as np\n",
    "import scipy.io as sio\n",
    "import pandas as pd\n",
    "from scipy import stats\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Echo versions for reproducibility\n",
    "print(\"[Versions] numpy\", np.__version__)\n",
    "if USE_JAX:\n",
    "    import jaxlib\n",
    "    print(\"[Versions] jax\", jax.__version__, \"| jaxlib\", jaxlib.__version__)\n",
    "\n",
    "# Config\n",
    "# Candidate dataset roots; the first one that exists is searched recursively for .mat files.\n",
    "DATA_ROOTS = [\n",
    "    \"/kaggle/input/ssvep-sandiego\",\n",
    "    \"/kaggle/input\",\n",
    "]\n",
    "# Sampling rate used when a .mat file does not store one (see _extract_fs).\n",
    "FS_FALLBACK = 250            \n",
    "# NOTE(review): BANDPASS is not referenced in the visible code; filtering uses FB_SUBBANDS.\n",
    "BANDPASS = (5.0, 45.0)\n",
    "# Notch centre frequency: rFFT bins within +/-0.5 of it are zeroed. None disables the notch.\n",
    "NOTCH = None                 \n",
    "\n",
    "# Filter-bank pass-bands (lo, hi) shared by the TRCA and FBCCA branches.\n",
    "# Sub-band m (0-based) is weighted 1 / (m + 1) ** FB_WEIGHTS_POW when scores are fused.\n",
    "FB_SUBBANDS = [(6,14), (14,22), (22,30), (30,38), (38,46)]\n",
    "FB_WEIGHTS_POW = 1.25\n",
    "\n",
    "# Ridge added to the diagonal of the TRCA within-trial covariance.\n",
    "TRCA_REG = 1e-4\n",
    "# Small constant guarding divisions / normalisations against zero.\n",
    "EPS = 1e-8\n",
    "\n",
    "# Seeds the legacy global NumPy RNG and Python's `random` module.\n",
    "RNG_SEED = 2025\n",
    "np.random.seed(RNG_SEED); random.seed(RNG_SEED)\n",
    "\n",
    "# Output directory (created if missing); presumably used by the saving code further down.\n",
    "OUTDIR = \"/kaggle/working\"\n",
    "os.makedirs(OUTDIR, exist_ok=True)\n",
    "\n",
    "\n",
    "def _to_numpy_numeric(x):\n",
    "    \"\"\"Return `x` itself if it is a non-empty, non-object ndarray with ndim >= 2, else None.\"\"\"\n",
    "    if not isinstance(x, np.ndarray):\n",
    "        return None\n",
    "    usable = x.dtype != object and x.ndim >= 2 and x.size > 0\n",
    "    return x if usable else None\n",
    "\n",
    "def _maybe_cell_to_list(x):\n",
    "    \"\"\"Flatten a MATLAB cell (object-dtype ndarray) into a list of its ndarray entries.\n",
    "\n",
    "    Entries are visited in np.ndindex order; non-ndarray entries are dropped.\n",
    "    Returns None when `x` is not an object-dtype ndarray.\n",
    "    \"\"\"\n",
    "    if not (isinstance(x, np.ndarray) and x.dtype == object):\n",
    "        return None\n",
    "    return [x[idx] for idx in np.ndindex(x.shape) if isinstance(x[idx], np.ndarray)]\n",
    "\n",
    "def _find_first_key(mat, candidates):\n",
    "    \"\"\"Return the first name in `candidates` that is a key of `mat`, or None if none match.\"\"\"\n",
    "    return next((name for name in candidates if name in mat), None)\n",
    "\n",
    "def _extract_fs(mat, default=FS_FALLBACK):\n",
    "    \"\"\"Return the sampling rate stored in `mat` as a float, or `float(default)`.\n",
    "\n",
    "    The first matching key among common sampling-rate names is used. Its value is\n",
    "    accepted when it converts to exactly one finite, positive number. This covers\n",
    "    Python/NumPy scalars as well as 0-d and 1x1 arrays (e.g. a .mat loaded without\n",
    "    squeeze_me). Anything else (non-numeric strings, structs, multi-element arrays,\n",
    "    0, negative or NaN values) falls back to `default` instead of raising.\n",
    "    \"\"\"\n",
    "    k = _find_first_key(mat, [\"fs\",\"Fs\",\"srate\",\"sfreq\",\"fsamp\",\"SamplingRate\",\"sampling_rate\"])\n",
    "    if k is not None:\n",
    "        try:\n",
    "            arr = np.asarray(mat[k], dtype=float)\n",
    "        except (TypeError, ValueError):\n",
    "            arr = None  # not numeric (e.g. a struct or free-text string)\n",
    "        if arr is not None and arr.size == 1:\n",
    "            fs = float(arr.reshape(-1)[0])\n",
    "            if np.isfinite(fs) and fs > 0:\n",
    "                return fs\n",
    "    return float(default)\n",
    "\n",
    "def _extract_freqs(mat):\n",
    "    \"\"\"Return stimulus frequencies from `mat` as a flat float array, or None.\n",
    "\n",
    "    Only the first matching key is tried; its value must convert to floats and hold\n",
    "    between 2 and 128 entries. Conversion failures are deliberately treated as 'absent'.\n",
    "    \"\"\"\n",
    "    key = _find_first_key(mat, [\"freqs\",\"FREQS\",\"stimulus_frequencies\",\"sfreqs\",\"Freq\",\"frequency\",\"Frequencies\"])\n",
    "    if key is None:\n",
    "        return None\n",
    "    try:\n",
    "        flat = np.array(mat[key], dtype=float).reshape(-1)\n",
    "    except Exception:\n",
    "        return None\n",
    "    return flat if 2 <= flat.size <= 128 else None\n",
    "\n",
    "def _extract_labels(mat):\n",
    "    \"\"\"Return the first label-like entry of `mat` flattened to 1-D, or None if absent/unconvertible.\"\"\"\n",
    "    key = _find_first_key(mat, [\"trainLabel\",\"testLabel\",\"label\",\"labels\",\"y\",\"Y\",\"stimulus\",\"class\",\"classes\"])\n",
    "    if key is None:\n",
    "        return None\n",
    "    try:\n",
    "        return np.array(mat[key]).reshape(-1)\n",
    "    except Exception:\n",
    "        return None\n",
    "\n",
    "def _as_class_trial_ch_samp_from_3d(X3, labels=None, freqs=None):\n",
    "    \"\"\"Reshape a 3-D EEG array into a (class, trial, channel, sample) tensor.\n",
    "\n",
    "    Axis roles are guessed from sizes: the longest axis is samples; the first other\n",
    "    axis whose size is a common montage size (8, 9, 16, 32, 64) is channels, else the\n",
    "    larger remaining axis; the leftover axis is trials.\n",
    "\n",
    "    Trials are grouped into classes by, in order of preference:\n",
    "      1. `labels`, if it has one entry per trial (classes ordered as np.unique(labels),\n",
    "         every class truncated to the smallest class's trial count);\n",
    "      2. `freqs`, if it has >= 2 entries and divides the trial count; trials are\n",
    "         assumed to be stored class-by-class in contiguous runs;\n",
    "      3. the first of 4, 12, 40 classes that divides the trial count (same\n",
    "         contiguous-run assumption).\n",
    "\n",
    "    Raises RuntimeError when none of these applies.\n",
    "    \"\"\"\n",
    "    shape = X3.shape\n",
    "    samp_ax = int(np.argmax(shape))\n",
    "    ch_ax = next((ax for ax, size in enumerate(shape)\n",
    "                  if size in (8,9,16,32,64) and ax != samp_ax), None)\n",
    "    if ch_ax is None:\n",
    "        ch_ax = next(int(ax) for ax in np.argsort(shape)[::-1] if ax != samp_ax)\n",
    "    (trial_ax,) = {0, 1, 2} - {samp_ax, ch_ax}\n",
    "    X_tcs = np.moveaxis(X3, (trial_ax, ch_ax, samp_ax), (0,1,2))\n",
    "    n_trials = X_tcs.shape[0]\n",
    "\n",
    "    def split_contiguous(n_classes):\n",
    "        # Consecutive runs of n_trials // n_classes trials form one class each.\n",
    "        per = n_trials // n_classes\n",
    "        return np.stack([X_tcs[i*per:(i+1)*per] for i in range(n_classes)], axis=0)\n",
    "\n",
    "    if labels is not None and labels.size == n_trials:\n",
    "        uniq = np.unique(labels)\n",
    "        slot = {lab: i for i, lab in enumerate(uniq)}\n",
    "        buckets = [[] for _ in uniq]\n",
    "        for trial, lab in zip(X_tcs, labels):\n",
    "            buckets[slot[lab]].append(trial)\n",
    "        per_class = [np.stack(b, axis=0) for b in buckets]\n",
    "        n_keep = min(p.shape[0] for p in per_class)\n",
    "        return np.stack([p[:n_keep] for p in per_class], axis=0)\n",
    "\n",
    "    if freqs is not None and freqs.size >= 2 and n_trials % freqs.size == 0:\n",
    "        return split_contiguous(freqs.size)\n",
    "    for n_classes in (4, 12, 40):\n",
    "        if n_trials % n_classes == 0:\n",
    "            return split_contiguous(n_classes)\n",
    "    raise RuntimeError(\"Cannot infer classes without labels/freqs for a 3-D tensor.\")\n",
    "\n",
    "def _from_cell_per_class(cell_list):\n",
    "    \"\"\"Stack a list of per-class 3-D arrays into a (class, trial, channel, sample) tensor.\n",
    "\n",
    "    Each entry's axes are permuted to the first order (tried in a fixed sequence) that\n",
    "    puts a common montage size (8, 9, 16, 32, 64) in the middle and >= 64 samples last;\n",
    "    if none fits, axes (2, 0, 1) are used. Entries that are not ndarrays or have fewer\n",
    "    than 3 dims are skipped. All classes are truncated to the smallest trial count and\n",
    "    written into an array with the first class's dtype.\n",
    "\n",
    "    Raises RuntimeError when no usable entry remains.\n",
    "    \"\"\"\n",
    "    orders = ((0,1,2), (0,2,1), (1,2,0), (1,0,2), (2,0,1), (2,1,0))\n",
    "\n",
    "    def orient(arr):\n",
    "        for order in orders:\n",
    "            cand = np.transpose(arr, order)\n",
    "            if cand.shape[0] >= 1 and cand.shape[1] in (8,9,16,32,64) and cand.shape[2] >= 64:\n",
    "                return cand\n",
    "        return np.transpose(arr, (2,0,1))\n",
    "\n",
    "    oriented = [orient(a) for a in cell_list if isinstance(a, np.ndarray) and a.ndim >= 3]\n",
    "    if not oriented:\n",
    "        raise RuntimeError(\"Empty/invalid cell array.\")\n",
    "    n_keep = min(t.shape[0] for t in oriented)\n",
    "    first = oriented[0]\n",
    "    out = np.zeros((len(oriented), n_keep, first.shape[1], first.shape[2]), dtype=first.dtype)\n",
    "    for idx, tensor in enumerate(oriented):\n",
    "        out[idx] = tensor[:n_keep]\n",
    "    return out\n",
    "\n",
    "def _parse_one_mat(mat):\n",
    "    \"\"\"Find an EEG tensor in a loaded .mat dict and return it as (class, trial, ch, sample).\n",
    "\n",
    "    Strategies are tried in order and the first hit wins:\n",
    "      1. any numeric 4-D array with more than 10000 elements;\n",
    "      2. a MATLAB cell (object array) under a known EEG key, one entry per class;\n",
    "      3. a numeric 3-D array under a known EEG key, split into classes via labels/freqs;\n",
    "      4. every numeric 3-D array in the file, one array per class.\n",
    "\n",
    "    Returns (data4 as float64, fs, freqs), or (None, None, None) if nothing matches.\n",
    "    \"\"\"\n",
    "    fs = _extract_fs(mat, default=FS_FALLBACK)\n",
    "    freqs = _extract_freqs(mat)\n",
    "    labels = _extract_labels(mat)\n",
    "    n_freqs = None if freqs is None else int(freqs.size)\n",
    "\n",
    "    for _, v in mat.items():\n",
    "        vnp = _to_numpy_numeric(v)\n",
    "        if vnp is not None and vnp.ndim == 4 and vnp.size > 10000:\n",
    "            dims = list(vnp.shape)\n",
    "            samp_ax = int(np.argmax(dims))\n",
    "            ch_ax = None\n",
    "            for ax, d in enumerate(dims):\n",
    "                if d in (8,9,16,32,64) and ax != samp_ax: ch_ax = ax; break\n",
    "            if ch_ax is None:\n",
    "                rem = [0,1,2,3]; rem.remove(samp_ax)\n",
    "                ch_ax = max(rem, key=lambda ax: dims[ax])\n",
    "            rem = [0,1,2,3]; rem.remove(samp_ax); rem.remove(ch_ax)\n",
    "            a1, a2 = rem\n",
    "            d1, d2 = dims[a1], dims[a2]\n",
    "            if n_freqs is not None and (d1 == n_freqs) != (d2 == n_freqs):\n",
    "                # Exactly one remaining axis matches the stimulus-frequency count, so it is\n",
    "                # the class axis (needed e.g. for 12 targets x 15 blocks, where the\n",
    "                # 'shorter axis = trials' heuristic below would swap the two).\n",
    "                class_ax, trial_ax = (a1, a2) if d1 == n_freqs else (a2, a1)\n",
    "            else:\n",
    "                # Heuristic: the shorter of the two remaining axes holds trials.\n",
    "                trial_ax, class_ax = (a1, a2) if d1 < d2 else (a2, a1)\n",
    "            data4 = np.moveaxis(vnp, (class_ax, trial_ax, ch_ax, samp_ax), (0,1,2,3))\n",
    "            return data4.astype(np.float64), fs, freqs\n",
    "\n",
    "    for key in [\"data\",\"EEG\",\"trainEEG\",\"testEEG\",\"X\",\"eeg\",\"signals\",\"Data\"]:\n",
    "        if key in mat and isinstance(mat[key], np.ndarray) and mat[key].dtype == object:\n",
    "            cell_list = _maybe_cell_to_list(mat[key])\n",
    "            if cell_list:\n",
    "                data4 = _from_cell_per_class(cell_list)\n",
    "                return data4.astype(np.float64), fs, freqs\n",
    "\n",
    "    for key in [\"EEG\",\"trainEEG\",\"testEEG\",\"data\",\"X\",\"eeg\",\"signals\",\"Data\"]:\n",
    "        if key in mat:\n",
    "            vnp = _to_numpy_numeric(mat[key])\n",
    "            if vnp is not None and vnp.ndim == 3:\n",
    "                data4 = _as_class_trial_ch_samp_from_3d(vnp, labels=labels, freqs=freqs)\n",
    "                return data4.astype(np.float64), fs, freqs\n",
    "\n",
    "    class_arrays = []\n",
    "    for _, v in mat.items():\n",
    "        vnp = _to_numpy_numeric(v)\n",
    "        if vnp is not None and vnp.ndim == 3:\n",
    "            class_arrays.append(vnp)\n",
    "    if class_arrays:\n",
    "        normed = []\n",
    "        for arr in class_arrays:\n",
    "            if arr.shape[1] in (8,9,16,32,64): A = arr\n",
    "            else: A = np.transpose(arr, (2,0,1))\n",
    "            normed.append(A)\n",
    "        minT = min(a.shape[0] for a in normed)\n",
    "        normed = [a[:minT] for a in normed]\n",
    "        C = len(normed); T = minT; Ch = normed[0].shape[1]; S = normed[0].shape[2]\n",
    "        data4 = np.zeros((C, T, Ch, S), dtype=normed[0].dtype)\n",
    "        for c in range(C): data4[c] = normed[c]\n",
    "        return data4.astype(np.float64), fs, freqs\n",
    "\n",
    "    return None, None, None\n",
    "\n",
    "def load_subject_pair(train_path, test_path=None):\n",
    "    \"\"\"Load one subject's train .mat (and optional test twin) as a (class, trial, ch, sample) tensor.\n",
    "\n",
    "    When the test file parses to a tensor with the same class, channel and sample\n",
    "    dimensions, both are truncated to the smaller trial count and concatenated along\n",
    "    the trial axis; otherwise the test file is ignored with a message.\n",
    "\n",
    "    The sampling rate comes from the train file if it stores one, else from the test\n",
    "    file (which itself falls back to FS_FALLBACK). Frequencies prefer the train file.\n",
    "\n",
    "    Raises RuntimeError if the train file holds no usable tensor.\n",
    "    \"\"\"\n",
    "    mat_tr = sio.loadmat(train_path, squeeze_me=True, struct_as_record=False)\n",
    "    Xtr, fs, freqs = _parse_one_mat(mat_tr)\n",
    "    if Xtr is None: raise RuntimeError(f\"No usable EEG tensor in {os.path.basename(train_path)}\")\n",
    "\n",
    "    if test_path is not None and os.path.exists(test_path):\n",
    "        mat_te = sio.loadmat(test_path, squeeze_me=True, struct_as_record=False)\n",
    "        Xte, fs2, freqs2 = _parse_one_mat(mat_te)\n",
    "        # Compare class/channel/sample dims only: axis 1 (trials) is the one being\n",
    "        # concatenated, so it may differ, while a sample-length mismatch must not pass.\n",
    "        if Xte is not None and (Xte.shape[0], Xte.shape[2], Xte.shape[3]) == (Xtr.shape[0], Xtr.shape[2], Xtr.shape[3]):\n",
    "            T_min = min(Xtr.shape[1], Xte.shape[1])\n",
    "            X = np.concatenate([Xtr[:, :T_min], Xte[:, :T_min]], axis=1)\n",
    "            # _parse_one_mat always reports a rate (FS_FALLBACK when none is stored), so probe\n",
    "            # the train file directly and only let the test file decide when train has none.\n",
    "            if math.isnan(_extract_fs(mat_tr, default=float(\"nan\"))):\n",
    "                fs = fs2\n",
    "            if freqs is None: freqs = freqs2\n",
    "            return X, fs, freqs\n",
    "        print(f\"[Loader] Ignoring {os.path.basename(test_path)}: no tensor compatible with the train file\")\n",
    "    return Xtr, fs, freqs\n",
    "\n",
    "def load_all_subjects():\n",
    "    \"\"\"Discover and load every subject under the first existing DATA_ROOTS entry.\n",
    "\n",
    "    If any .mat basename contains both 'train' and 'eeg', each such file is paired with\n",
    "    the test file whose lower-cased basename equals it with 'train' replaced by 'test';\n",
    "    otherwise every .mat file is parsed on its own. Unparseable files are skipped with a\n",
    "    message.\n",
    "\n",
    "    Returns (X, fs, freqs): X is (subject, class, trial, ch, sample) with every subject\n",
    "    truncated to the smallest trial count, fs is the median sampling rate rounded to an\n",
    "    int, and freqs is the first frequency vector found (or None).\n",
    "\n",
    "    Raises FileNotFoundError when no root / .mat file exists, and RuntimeError when no\n",
    "    subject parses or subjects disagree on class, channel or sample counts.\n",
    "    \"\"\"\n",
    "    root = None\n",
    "    for r in DATA_ROOTS:\n",
    "        if os.path.exists(r): root = r; break\n",
    "    if root is None: raise FileNotFoundError(\"Attach the dataset under /kaggle/input first.\")\n",
    "    mats = sorted(glob.glob(os.path.join(root, \"**\", \"*.mat\"), recursive=True))\n",
    "    if not mats: raise FileNotFoundError(\"No .mat files found.\")\n",
    "\n",
    "    train_files = [p for p in mats if \"train\" in os.path.basename(p).lower() and \"eeg\" in os.path.basename(p).lower()]\n",
    "    test_files  = [p for p in mats if \"test\"  in os.path.basename(p).lower() and \"eeg\" in os.path.basename(p).lower()]\n",
    "    def twin_for(p, pool):\n",
    "        # Test file whose lower-cased basename is p's with 'train' -> 'test', or None.\n",
    "        bn = os.path.basename(p)\n",
    "        cand = bn.lower().replace(\"train\", \"test\")\n",
    "        for q in pool:\n",
    "            if os.path.basename(q).lower() == cand: return q\n",
    "        return None\n",
    "\n",
    "    subjects, fs_list = [], []\n",
    "    global_freqs = None\n",
    "\n",
    "    if train_files:\n",
    "        for tr in sorted(train_files):\n",
    "            te = twin_for(tr, test_files)\n",
    "            try:\n",
    "                Xsub, fs, freqs = load_subject_pair(tr, te)\n",
    "                subjects.append(Xsub); fs_list.append(fs)\n",
    "                if global_freqs is None and freqs is not None: global_freqs = freqs\n",
    "            except Exception as e:\n",
    "                print(f\"[Loader] Skipping {os.path.basename(tr)}: {e}\")\n",
    "    else:\n",
    "        for m in mats:\n",
    "            try:\n",
    "                mat = sio.loadmat(m, squeeze_me=True, struct_as_record=False)\n",
    "                Xsub, fs, freqs = _parse_one_mat(mat)\n",
    "                if Xsub is not None:\n",
    "                    subjects.append(Xsub); fs_list.append(fs)\n",
    "                    if global_freqs is None and freqs is not None: global_freqs = freqs\n",
    "                else:\n",
    "                    print(f\"[Loader] Skipping {os.path.basename(m)}: no usable tensor\")\n",
    "            except Exception as e:\n",
    "                print(f\"[Loader] Skipping {os.path.basename(m)}: {e}\")\n",
    "\n",
    "    if len(subjects) == 0:\n",
    "        raise RuntimeError(\"No usable subjects parsed after robust loader.\")\n",
    "\n",
    "    C0, _, Ch0, S0 = subjects[0].shape\n",
    "    # Explicit check instead of `assert` (which is stripped under `python -O`), so a\n",
    "    # mismatched file yields a readable error rather than an opaque np.stack failure.\n",
    "    mismatched = [(i, x.shape) for i, x in enumerate(subjects)\n",
    "                  if (x.shape[0], x.shape[2], x.shape[3]) != (C0, Ch0, S0)]\n",
    "    if mismatched:\n",
    "        raise RuntimeError(f\"Subjects disagree with subject 0 on (classes, channels, samples) = {(C0, Ch0, S0)}: {mismatched}\")\n",
    "    minT = min(x.shape[1] for x in subjects)\n",
    "    subjects = [x[:, :minT] for x in subjects]\n",
    "\n",
    "    fs = int(round(float(np.median(fs_list)))) if fs_list else FS_FALLBACK\n",
    "    X = np.stack(subjects, axis=0)  \n",
    "    return X, fs, global_freqs\n",
    "\n",
    "def _fft_bandpass_np(x, fs, lo, hi, notch=None):\n",
    "    \"\"\"NumPy brick-wall FFT filter along the last axis.\n",
    "\n",
    "    Keeps rFFT bins with lo <= f <= hi, additionally zeroes bins within 0.5 of\n",
    "    `notch` when it is not None, and returns the real inverse transform with the\n",
    "    original length.\n",
    "    \"\"\"\n",
    "    n = x.shape[-1]\n",
    "    spectrum = np.fft.rfft(x, axis=-1)\n",
    "    bin_freqs = np.fft.rfftfreq(n, d=1.0/fs)\n",
    "    keep = (bin_freqs >= lo) & (bin_freqs <= hi)\n",
    "    if notch is not None:\n",
    "        half_width = 0.5\n",
    "        keep &= ~((bin_freqs >= (notch - half_width)) & (bin_freqs <= (notch + half_width)))\n",
    "    return np.fft.irfft(np.where(keep, spectrum, 0), n=n, axis=-1)\n",
    "\n",
    "if USE_JAX:\n",
    "    def _fft_bandpass_jax_impl(x, fs, lo, hi, notch):\n",
    "        \"\"\"JAX twin of _fft_bandpass_np: zero rFFT bins outside [lo, hi] (and within\n",
    "        +/-0.5 of `notch` when it is a positive number), then invert along the last axis.\"\"\"\n",
    "        n = x.shape[-1]\n",
    "        X = xp.fft.rfft(x, axis=-1)\n",
    "        freqs = xp.fft.rfftfreq(n, d=1.0/fs)\n",
    "        mask = (freqs >= lo) & (freqs <= hi)\n",
    "        Y = xp.where(mask, X, 0.0 + 0.0j)\n",
    "        # `notch` is a static argument (see the jit call below), so this Python branch is\n",
    "        # resolved at trace time. If it were traced, `notch > 0.0` would raise a\n",
    "        # ConcretizationTypeError as soon as NOTCH is set to a number.\n",
    "        if notch is not None and notch > 0.0:\n",
    "            notch_bw = 0.5\n",
    "            nmask = (freqs >= (notch - notch_bw)) & (freqs <= (notch + notch_bw))\n",
    "            Y = xp.where(nmask, 0.0 + 0.0j, Y)\n",
    "        return xp.fft.irfft(Y, n=n, axis=-1)\n",
    "\n",
    "    # Recompiles per distinct input shape/dtype and notch value; fs, lo and hi stay traced.\n",
    "    _fft_bandpass_jax = jax.jit(_fft_bandpass_jax_impl, static_argnames=(\"notch\",))\n",
    "\n",
    "def fft_bandpass(x, fs, lo, hi, notch=None):\n",
    "    \"\"\"FFT band-pass (optional notch) along the last axis, dispatched to the active backend.\"\"\"\n",
    "    if not USE_JAX:\n",
    "        return _fft_bandpass_np(np.asarray(x), fs, lo, hi, notch)\n",
    "    return _fft_bandpass_jax(xp.asarray(x), fs, lo, hi, notch)\n",
    "\n",
    "#  TRCA (spatial filter & scoring)\n",
    "def trca_spatial_filter(X_trials, reg=TRCA_REG, use_jax=USE_JAX):\n",
    "    \"\"\"TRCA spatial filter maximising inter-trial covariance.\n",
    "\n",
    "    Solves the generalised symmetric eigenproblem  S w = lambda Q w  with\n",
    "      Q = sum_t X_t X_t^T + reg * I      (within-trial covariance, ridge-regularised)\n",
    "      S = sum_{t != u} X_t X_u^T         (cross-trial covariance, symmetrised)\n",
    "    by whitening with Q^{-1/2}: the top eigenvector v of Q^{-1/2} S Q^{-1/2} maps back\n",
    "    to w = Q^{-1/2} v.\n",
    "\n",
    "    Args:\n",
    "        X_trials: array of shape (trials, channels, samples).\n",
    "        reg: ridge added to the diagonal of Q.\n",
    "        use_jax: compute with jax.numpy (`xp`) instead of NumPy.\n",
    "    Returns:\n",
    "        Unit-norm filter of shape (channels,); its sign is arbitrary.\n",
    "    \"\"\"\n",
    "    lib = xp if use_jax else np\n",
    "    X_trials = lib.asarray(X_trials)\n",
    "    n_ch = X_trials.shape[1]\n",
    "    XiXiT = lib.einsum(\"tcs,tks->ck\", X_trials, X_trials)\n",
    "    Q = XiXiT + reg * lib.eye(n_ch)\n",
    "    Xsum = lib.sum(X_trials, axis=0)\n",
    "    # sum_t X_t (Xsum - X_t)^T == Xsum Xsum^T - sum_t X_t X_t^T, without a Python loop.\n",
    "    Snum = Xsum @ Xsum.T - XiXiT\n",
    "    Snum = 0.5 * (Snum + Snum.T)\n",
    "    q_evals, q_evecs = lib.linalg.eigh(0.5 * (Q + Q.T))\n",
    "    q_evals = lib.maximum(q_evals, EPS)  # Q is PSD + reg*I; clamp round-off negatives\n",
    "    Q_isqrt = (q_evecs / lib.sqrt(q_evals)) @ q_evecs.T\n",
    "    # eigh only reads one triangle and assumes symmetry, so pinv(Q) @ Snum (not symmetric)\n",
    "    # cannot be passed to it; the whitened form below is symmetric by construction.\n",
    "    M = Q_isqrt @ Snum @ Q_isqrt\n",
    "    evals, evecs = lib.linalg.eigh(0.5 * (M + M.T))\n",
    "    w = Q_isqrt @ evecs[:, lib.argmax(evals)]\n",
    "    w = w / (lib.linalg.norm(w) + EPS)\n",
    "    return w\n",
    "\n",
    "def trca_template(X_trials):\n",
    "    \"\"\"Average over the leading (trial) axis to obtain the class template.\"\"\"\n",
    "    template = X_trials.mean(axis=0)\n",
    "    return template\n",
    "\n",
    "def corr_pearson(a, b):\n",
    "    \"\"\"Pearson correlation over all elements of `a` and `b`; EPS keeps the denominator non-zero.\"\"\"\n",
    "    a_c = a - a.mean()\n",
    "    b_c = b - b.mean()\n",
    "    denom = xp.linalg.norm(a_c) * xp.linalg.norm(b_c) + EPS\n",
    "    return (a_c * b_c).sum() / denom\n",
    "\n",
    "def trca_score_epoch(x_epoch, w, tpl):\n",
    "    \"\"\"Correlate the spatially filtered epoch with the spatially filtered template (Python float).\"\"\"\n",
    "    projected_epoch, projected_tpl = w @ x_epoch, w @ tpl\n",
    "    return float(corr_pearson(projected_epoch, projected_tpl))\n",
    "\n",
    "#  CORAL (unsupervised domain alignment)\n",
    "def coral_fit(Xs, Xt, eps=1e-6, use_jax=USE_JAX):\n",
    "    \"\"\"Fit a CORAL affine map that aligns source features to target statistics.\n",
    "\n",
    "    Returns (A, b) with A = Cs^{-1/2} Ct^{1/2} and b = mu_t - mu_s @ A, so that\n",
    "    `Xs @ A + b` (see coral_apply) has mean mu_t and covariance Ct up to the `eps`\n",
    "    regularisation added to both covariance diagonals.\n",
    "\n",
    "    Args:\n",
    "        Xs: source samples, shape (n_source, d).\n",
    "        Xt: target samples, shape (n_target, d).\n",
    "        eps: diagonal regulariser; also the floor applied to covariance eigenvalues.\n",
    "        use_jax: compute with jax.numpy (`xp`) instead of NumPy.\n",
    "    \"\"\"\n",
    "    lib = xp if use_jax else np\n",
    "\n",
    "    def moments(X):\n",
    "        # Mean and eps-regularised unbiased covariance; max(., 1) avoids 0-division for n == 1.\n",
    "        mu = lib.mean(X, axis=0)\n",
    "        X0 = X - mu\n",
    "        cov = (X0.T @ X0) / max(X0.shape[0] - 1, 1)\n",
    "        return mu, cov + eps * lib.eye(cov.shape[0])\n",
    "\n",
    "    def sym_power(cov, power):\n",
    "        # cov ** power via eigendecomposition. Round-off (more likely in float32) can push\n",
    "        # tiny eigenvalues below zero, which would make sqrt / 1/sqrt NaN; clamp to eps.\n",
    "        evals, evecs = lib.linalg.eigh(cov)\n",
    "        evals = lib.maximum(evals, eps)\n",
    "        return (evecs * evals ** power) @ evecs.T\n",
    "\n",
    "    mu_s, Cs = moments(Xs)\n",
    "    mu_t, Ct = moments(Xt)\n",
    "    A = sym_power(Cs, -0.5) @ sym_power(Ct, 0.5)\n",
    "    b = mu_t - mu_s @ A\n",
    "    return A, b\n",
    "\n",
    "def coral_apply(X, A, b):\n",
    "    \"\"\"Apply the affine CORAL map from coral_fit to the rows of X: X @ A + b.\"\"\"\n",
    "    transformed = X @ A\n",
    "    return transformed + b\n",
    "\n",
    "#  FBCCA helpers (TPU-safe CCA: symmetric eig via eigh)\n",
    "def _build_ref_bank(class_freqs, fs, nsamp, harmonics=3, use_jax=USE_JAX):\n",
    "    \"\"\"Build one sine/cosine reference matrix per stimulus frequency for CCA.\n",
    "\n",
    "    Returns a list with one (nsamp, 2 * harmonics) matrix per entry of `class_freqs`;\n",
    "    columns alternate sin, cos for harmonics h = 1..harmonics.\n",
    "    \"\"\"\n",
    "    lib = xp if use_jax else np\n",
    "    t = lib.arange(nsamp) / fs\n",
    "\n",
    "    def references(f):\n",
    "        cols = []\n",
    "        for h in range(1, harmonics + 1):\n",
    "            phase = 2 * lib.pi * h * f * t\n",
    "            cols.append(lib.sin(phase))\n",
    "            cols.append(lib.cos(phase))\n",
    "        return lib.stack(cols, axis=1)\n",
    "\n",
    "    return [references(f) for f in class_freqs]\n",
    "\n",
    "def _max_cca_corr(Xs, Ys, eps=1e-8, use_jax=USE_JAX):\n",
    "    \"\"\"Largest canonical correlation between the columns of Xs and Ys.\n",
    "\n",
    "    Both inputs are (samples, features) and are column-centred first. With\n",
    "    Kx = Sxx^{-1/2}, Ky = Syy^{-1/2} and T = Kx Sxy Ky, the squared canonical\n",
    "    correlations are the eigenvalues of the symmetric PSD matrix T T^T, so a\n",
    "    symmetric eigensolver (`eigh`, TPU-friendly) applies directly.\n",
    "\n",
    "    Args:\n",
    "        Xs: (samples, p) array.\n",
    "        Ys: (samples, q) array, e.g. a reference matrix from _build_ref_bank.\n",
    "        eps: ridge added to both auto-covariances; also their eigenvalue floor.\n",
    "        use_jax: compute with jax.numpy (`xp`) instead of NumPy.\n",
    "    Returns:\n",
    "        Non-negative Python float.\n",
    "    \"\"\"\n",
    "    lib = xp if use_jax else np\n",
    "\n",
    "    def inv_sqrt(S):\n",
    "        # S^{-1/2} of a symmetric PSD matrix via eigendecomposition.\n",
    "        evals, evecs = lib.linalg.eigh(S)\n",
    "        evals = lib.maximum(evals, eps)\n",
    "        return (evecs / lib.sqrt(evals)) @ evecs.T\n",
    "\n",
    "    X = Xs - lib.mean(Xs, axis=0)\n",
    "    Y = Ys - lib.mean(Ys, axis=0)\n",
    "    Sxx = X.T @ X + eps * lib.eye(X.shape[1])\n",
    "    Syy = Y.T @ Y + eps * lib.eye(Y.shape[1])\n",
    "    Sxy = X.T @ Y\n",
    "    # Sxx^{-1} Sxy Syy^{-1} Syx is not symmetric, and averaging it with its transpose\n",
    "    # changes its eigenvalues; whitening both sides gives a symmetric matrix instead.\n",
    "    T = inv_sqrt(Sxx) @ Sxy @ inv_sqrt(Syy)\n",
    "    M = T @ T.T\n",
    "    ev = lib.linalg.eigh(0.5 * (M + M.T))[0]\n",
    "    rho2 = float(lib.max(ev))\n",
    "    return float(np.sqrt(max(rho2, 0.0)))\n",
    "\n",
    "def _estimate_class_freqs_from_training(train_by_c, fs, fmin=6.0, fmax=20.0):\n",
    "    \"\"\"Estimate each class's stimulus frequency from its training data.\n",
    "\n",
    "    For class c, all arrays in train_by_c[c] are concatenated along axis 0, averaged over\n",
    "    axis 0 and then over the next axis, and the rFFT power peak inside [fmin, fmax] is\n",
    "    taken. Classes with no rFFT bin in that band get 10.0. Returns a float array indexed\n",
    "    by class 0..len(train_by_c)-1.\n",
    "    \"\"\"\n",
    "    peaks = []\n",
    "    for c in range(len(train_by_c)):\n",
    "        trials = np.concatenate(train_by_c[c], axis=0)\n",
    "        grand_avg = trials.mean(axis=0).mean(axis=0)\n",
    "        bin_freqs = np.fft.rfftfreq(grand_avg.shape[0], d=1.0/fs)\n",
    "        power = np.abs(np.fft.rfft(grand_avg)) ** 2\n",
    "        in_band = (bin_freqs >= fmin) & (bin_freqs <= fmax)\n",
    "        if not np.any(in_band):\n",
    "            peaks.append(10.0)\n",
    "            continue\n",
    "        peaks.append(float(bin_freqs[in_band][np.argmax(power[in_band])]))\n",
    "    return np.array(peaks, dtype=float)\n",
    "\n",
    "def _fbcca_score_epoch(x_epoch, cand_c, ref_bank, fs, subbands, use_jax=USE_JAX):\n",
    "    \"\"\"Per-sub-band CCA scores of one epoch against candidate class `cand_c`.\n",
    "\n",
    "    For every (lo, hi) in `subbands` the epoch is band-passed (with the global NOTCH),\n",
    "    transposed (presumably to (samples, channels); x_epoch looks like (channels, samples))\n",
    "    and scored with _max_cca_corr against ref_bank[cand_c]. Returns an unweighted list\n",
    "    of floats, one per sub-band.\n",
    "    \"\"\"\n",
    "    return [\n",
    "        _max_cca_corr(fft_bandpass(x_epoch, fs, lo, hi, NOTCH).T, ref_bank[cand_c], use_jax=use_jax)\n",
    "        for (lo, hi) in subbands\n",
    "    ]\n",
    "\n",
    "#  Utilities\n",
    "def itr_bits_per_min(P, N, Tsec):\n",
    "    \"\"\"Wolpaw information transfer rate in bits/min.\n",
    "\n",
    "    Args:\n",
    "        P: classification accuracy as a fraction in [0, 1].\n",
    "        N: number of classes.\n",
    "        Tsec: seconds per selection.\n",
    "    Returns:\n",
    "        0.0 for N <= 1 or accuracy at/below chance (P <= 1/N), where the formula is not\n",
    "        meaningful; log2(N) * 60 / Tsec for perfect accuracy (P >= 1), the limit of the\n",
    "        formula, which cannot be evaluated there because of log2(0); otherwise\n",
    "        (log2 N + P log2 P + (1-P) log2((1-P)/(N-1))) * 60 / Tsec.\n",
    "    \"\"\"\n",
    "    if N <= 1 or P <= 1.0 / N:\n",
    "        return 0.0\n",
    "    bits = math.log2(N)\n",
    "    if P < 1.0:\n",
    "        bits += P*math.log2(P) + (1-P)*math.log2((1-P)/(N-1))\n",
    "    return float(bits * (60.0/Tsec))\n",
    "\n",
    "def plot_confmat(cm, classes, title, path):\n",
    "    \"\"\"Save a confusion-matrix heatmap with per-cell counts to `path` (150 dpi), then close it.\n",
    "\n",
    "    Rows of `cm` are drawn against the 'True' axis and columns against 'Predicted',\n",
    "    both labelled with `classes`.\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(6,5))\n",
    "    ax.imshow(cm, interpolation='nearest')\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"Predicted\")\n",
    "    ax.set_ylabel(\"True\")\n",
    "    positions = np.arange(len(classes))\n",
    "    ax.set_xticks(positions)\n",
    "    ax.set_xticklabels(classes, rotation=45, ha='right')\n",
    "    ax.set_yticks(positions)\n",
    "    ax.set_yticklabels(classes)\n",
    "    n_rows, n_cols = cm.shape[0], cm.shape[1]\n",
    "    for i, j in ((r, c) for r in range(n_rows) for c in range(n_cols)):\n",
    "        ax.text(j, i, f\"{cm[i,j]}\", ha=\"center\", va=\"center\")\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(path, bbox_inches='tight', dpi=150)\n",
    "    plt.close(fig)\n",
    "\n",
    "def bar_with_ci(values, labels, mean_label, title, path):\n",
    "    \"\"\"Save a bar chart of per-item accuracies with a dashed mean line and a 95% CI label.\n",
    "\n",
    "    The CI half-width is 1.96 * std(ddof=1) / sqrt(n). With fewer than two values the\n",
    "    sample std is undefined, so a zero-width CI is reported instead of 'nan'. The\n",
    "    figure is always closed, even if saving fails.\n",
    "    \"\"\"\n",
    "    vals = np.asarray(values, dtype=float)\n",
    "    fig, ax = plt.subplots(figsize=(8,4))\n",
    "    try:\n",
    "        x = np.arange(len(vals))\n",
    "        ax.bar(x, vals)\n",
    "        ax.set_xticks(x)\n",
    "        ax.set_xticklabels(labels, rotation=45, ha='right')\n",
    "        ax.set_ylabel(\"Accuracy (%)\")\n",
    "        ax.set_title(title)\n",
    "        mu = vals.mean()\n",
    "        # ddof=1 needs at least two samples; otherwise std would be NaN.\n",
    "        ci95 = 1.96 * vals.std(ddof=1) / np.sqrt(len(vals)) if len(vals) >= 2 else 0.0\n",
    "        ax.axhline(mu, linestyle='--')\n",
    "        ax.text(len(vals)-0.5, mu+0.5, f\"{mean_label}: {mu:.2f} ± {ci95:.2f}\")\n",
    "        fig.tight_layout()\n",
    "        fig.savefig(path, bbox_inches='tight', dpi=150)\n",
    "    finally:\n",
    "        plt.close(fig)\n",
    "\n",
    "#  Core evaluation (LOSO) with extras\n",
    "def evaluate_loso(\n",
    "    X_all, fs, freqs=None, harmonics=3,\n",
    "    use_coral=True,\n",
    "    window_seconds=None,          \n",
    "    adaptive_fusion=True,          \n",
    "    csv_prefix=\"ssvep_loso_pub\"\n",
    "):\n",
    "    \"\"\"\n",
    "    Returns: dict with summary paths and per-subject arrays\n",
    "    Saves:\n",
    "      - summary CSV\n",
    "      - per-subject confusion matrices (png)\n",
    "      - accuracy bar plots + CI (png)\n",
    "    \"\"\"\n",
    "    Subj, C, T, Ch, S_full = X_all.shape\n",
    "    if window_seconds is None:\n",
    "        S_use = S_full\n",
    "    else:\n",
    "        S_use = int(round(window_seconds * fs))\n",
    "        S_use = max(64, min(S_use, S_full))\n",
    "\n",
    "    M = len(FB_SUBBANDS)\n",
    "    fb_w = xp.asarray([1.0/((m+1)**FB_WEIGHTS_POW) for m in range(M)])\n",
    "\n",
    "    subjects = [f\"S{i+1:02d}\" for i in range(Subj)]\n",
    "    classes  = [f\"{i}\" for i in range(C)]\n",
    "\n",
    "    rows_summary = []\n",
    "    acc_tr_list, acc_fb_list, acc_en_list = [], [], []\n",
    "    confmats = {\"TRCA\": [], \"FBCCA\": [], \"ENS\": []}\n",
    "\n",
    "    for s_te in range(Subj):\n",
    "        idx_tr_np = np.array([i for i in range(Subj) if i != s_te], dtype=int)\n",
    "        X_tr = xp.take(X_all, xp.asarray(idx_tr_np), axis=0) if USE_JAX else X_all[idx_tr_np]\n",
    "        X_te = X_all[s_te]\n",
    "\n",
    "        if S_use != S_full:\n",
    "            X_tr_np = np.array(X_tr)[..., :S_use] \n",
    "            X_te_np = np.array(X_te)[..., :S_use]  \n",
    "            X_tr = xp.asarray(X_tr_np) if USE_JAX else X_tr_np\n",
    "            X_te = X_te_np\n",
    "        else:\n",
    "            X_te = np.array(X_te)\n",
    "\n",
    "        S = S_use\n",
    "\n",
    "        train_by_c = {c: [] for c in range(C)}\n",
    "        for s in range(X_tr.shape[0]):\n",
    "            for c in range(C):\n",
    "                train_by_c[c].append(np.array(X_tr[s, c]))\n",
    "\n",
    "        W = [[None for _ in range(M)] for __ in range(C)]\n",
    "        TPL = [[None for _ in range(M)] for __ in range(C)]\n",
    "        for c in range(C):\n",
    "            Xc_all = np.concatenate(train_by_c[c], axis=0)   # (T_all, Ch, S)\n",
    "            for m, (lo, hi) in enumerate(FB_SUBBANDS):\n",
    "                Xc_f = fft_bandpass(Xc_all, fs, lo, hi, NOTCH)\n",
    "                if USE_JAX: Xc_f = xp.asarray(Xc_f)\n",
    "                w = trca_spatial_filter(Xc_f, reg=TRCA_REG, use_jax=USE_JAX)\n",
    "                tpl = trca_template(Xc_f)\n",
    "                W[c][m] = w; TPL[c][m] = tpl\n",
    "\n",
    "        src_X = []\n",
    "        for c_true in range(C):\n",
    "            Xc_src = np.concatenate(train_by_c[c_true], axis=0)\n",
    "            for t_idx in range(Xc_src.shape[0]):\n",
    "                x = Xc_src[t_idx]\n",
    "                vec = []\n",
    "                for m in range(M):\n",
    "                    w, tpl = W[c_true][m], TPL[c_true][m]\n",
    "                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                    if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                    vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                for c_other in range(C):\n",
    "                    if c_other == c_true: continue\n",
    "                    for m in range(M):\n",
    "                        w, tpl = W[c_other][m], TPL[c_other][m]\n",
    "                        x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                        if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                        vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                src_X.append(vec)\n",
    "        src_X = xp.asarray(src_X)\n",
    "\n",
    "        te_pairs = [(c, t) for c in range(C) for t in range(T)]\n",
    "        te_feats = []\n",
    "        for (c_label, t_idx) in te_pairs:\n",
    "            x = np.array(X_te[c_label, t_idx])\n",
    "            vec = []\n",
    "            for m in range(M):\n",
    "                w, tpl = W[c_label][m], TPL[c_label][m]\n",
    "                x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "            for c_other in range(C):\n",
    "                if c_other == c_label: continue\n",
    "                for m in range(M):\n",
    "                    w, tpl = W[c_other][m], TPL[c_other][m]\n",
    "                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                    if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                    vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "            te_feats.append(vec)\n",
    "        te_feats = xp.asarray(te_feats)\n",
    "\n",
    "        if use_coral:\n",
    "            A, b = coral_fit(src_X, te_feats, eps=1e-6, use_jax=USE_JAX)\n",
    "        else:\n",
    "            d = src_X.shape[1]\n",
    "            A = xp.eye(d); b = xp.zeros((d,))\n",
    "\n",
    "        if freqs is not None and len(freqs) == C:\n",
    "            class_freqs = np.array(freqs, dtype=float)\n",
    "        else:\n",
    "            class_freqs = _estimate_class_freqs_from_training(train_by_c, fs)\n",
    "        ref_bank = _build_ref_bank(class_freqs, fs, S, harmonics=harmonics, use_jax=USE_JAX)\n",
    "\n",
    "        w_trca, w_fb = 1.0, 1.0\n",
    "        if adaptive_fusion:\n",
    "            acc_pairs = []\n",
    "            for s_val_i in range(X_tr.shape[0]):\n",
    "                X_val = np.array(X_tr[s_val_i])\n",
    "                y_true_v, y_tr_v, y_fb_v = [], [], []\n",
    "                for c in range(C):\n",
    "                    for t_idx in range(T):\n",
    "                        x = X_val[c, t_idx]\n",
    "                        tr_scores = []\n",
    "                        for c_hat in range(C):\n",
    "                            vec = []\n",
    "                            for m in range(M):\n",
    "                                w, tpl = W[c_hat][m], TPL[c_hat][m]\n",
    "                                x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                                if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                                vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                            for c_other in range(C):\n",
    "                                if c_other == c_hat: continue\n",
    "                                for m in range(M):\n",
    "                                    w, tpl = W[c_other][m], TPL[c_other][m]\n",
    "                                    x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                                    if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                                    vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                            v = xp.asarray(vec)[None,:]\n",
    "                            v_al = coral_apply(v, A, b)[0]\n",
    "                            v_block = v_al[:len(FB_SUBBANDS)]\n",
    "                            tr_scores.append(float((v_block * fb_w).sum()))\n",
    "                        fb_scores = []\n",
    "                        weights_np = 1.0/((np.arange(M)+1)**FB_WEIGHTS_POW)\n",
    "                        for c_hat in range(C):\n",
    "                            sb_scores = _fbcca_score_epoch(x, c_hat, ref_bank, fs, FB_SUBBANDS, use_jax=USE_JAX)\n",
    "                            fb_scores.append(float(np.sum(np.asarray(sb_scores)*weights_np)))\n",
    "                        y_true_v.append(c)\n",
    "                        y_tr_v.append(int(np.argmax(tr_scores)))\n",
    "                        y_fb_v.append(int(np.argmax(fb_scores)))\n",
    "                acc_pairs.append( (np.mean(np.array(y_tr_v)==np.array(y_true_v)),\n",
    "                                   np.mean(np.array(y_fb_v)==np.array(y_true_v))) )\n",
    "            mean_tr, mean_fb = np.mean([a for a,b in acc_pairs]), np.mean([b for a,b in acc_pairs])\n",
    "            w_trca = float(mean_tr + 1e-3); w_fb = float(mean_fb + 1e-3)\n",
    "\n",
    "        y_true, y_trca, y_fbcca, y_ens = [], [], [], []\n",
    "        cm_tr = np.zeros((C,C), dtype=int)\n",
    "        cm_fb = np.zeros((C,C), dtype=int)\n",
    "        cm_en = np.zeros((C,C), dtype=int)\n",
    "\n",
    "        for c in range(C):\n",
    "            for t_idx in range(T):\n",
    "                x = np.array(X_te[c, t_idx])\n",
    "                trca_scores = []\n",
    "                for c_hat in range(C):\n",
    "                    vec = []\n",
    "                    for m in range(M):\n",
    "                        w, tpl = W[c_hat][m], TPL[c_hat][m]\n",
    "                        x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                        if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                        vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                    for c_other in range(C):\n",
    "                        if c_other == c_hat: continue\n",
    "                        for m in range(M):\n",
    "                            w, tpl = W[c_other][m], TPL[c_other][m]\n",
    "                            x_f = fft_bandpass(x, fs, *FB_SUBBANDS[m], NOTCH)\n",
    "                            if USE_JAX: x_f = xp.asarray(x_f)\n",
    "                            vec.append(trca_score_epoch(x_f, w, tpl))\n",
    "                    v = xp.asarray(vec)[None,:]\n",
    "                    v_al = coral_apply(v, A, b)[0]\n",
    "                    v_block = v_al[:len(FB_SUBBANDS)]\n",
    "                    trca_scores.append(float((v_block * fb_w).sum()))\n",
    "                # FBCCA\n",
    "                fbcca_scores = []\n",
    "                weights_np = 1.0/((np.arange(M)+1)**FB_WEIGHTS_POW)\n",
    "                for c_hat in range(C):\n",
    "                    sb_scores = _fbcca_score_epoch(x, c_hat, ref_bank, fs, FB_SUBBANDS, use_jax=USE_JAX)\n",
    "                    fbcca_scores.append(float(np.sum(np.asarray(sb_scores)*weights_np)))\n",
    "                # Ensemble (z-norm per head across candidates; adaptive weight)\n",
    "                tr = np.asarray(trca_scores); fb = np.asarray(fbcca_scores)\n",
    "                tr_z = (tr - tr.mean()) / (tr.std() + 1e-6)\n",
    "                fb_z = (fb - fb.mean()) / (fb.std() + 1e-6)\n",
    "                ens = w_trca*tr_z + w_fb*fb_z\n",
    "\n",
    "                yt = c; pr_tr = int(np.argmax(tr)); pr_fb = int(np.argmax(fb)); pr_en = int(np.argmax(ens))\n",
    "                y_true.append(yt); y_trca.append(pr_tr); y_fbcca.append(pr_fb); y_ens.append(pr_en)\n",
    "                cm_tr[yt, pr_tr] += 1; cm_fb[yt, pr_fb] += 1; cm_en[yt, pr_en] += 1\n",
    "\n",
    "        acc_tr = 100.0 * float((np.asarray(y_trca) == np.asarray(y_true)).mean())\n",
    "        acc_fb = 100.0 * float((np.asarray(y_fbcca) == np.asarray(y_true)).mean())\n",
    "        acc_en = 100.0 * float((np.asarray(y_ens)  == np.asarray(y_true)).mean())\n",
    "        itr_en = itr_bits_per_min(acc_en/100.0, C, S/fs)\n",
    "\n",
    "        print(f\"[LOSO] {subjects[s_te]}: TRCA(CORAL={use_coral})={acc_tr:.2f}% | FBCCA={acc_fb:.2f}% | ENS(adapt)={acc_en:.2f}% | ITR={itr_en:.2f}\")\n",
    "\n",
    "        rows_summary.append({\n",
    "            \"subject\": subjects[s_te],\n",
    "            \"use_coral\": use_coral,\n",
    "            \"window_sec\": S/fs,\n",
    "            \"acc_trca\": round(acc_tr,2),\n",
    "            \"acc_fbcca\": round(acc_fb,2),\n",
    "            \"acc_ensemble\": round(acc_en,2),\n",
    "            \"itr_ensemble\": round(itr_en,2),\n",
    "            \"w_trca\": round(w_trca,4),\n",
    "            \"w_fbcca\": round(w_fb,4),\n",
    "        })\n",
    "\n",
    "        confmats[\"TRCA\"].append(cm_tr)\n",
    "        confmats[\"FBCCA\"].append(cm_fb)\n",
    "        confmats[\"ENS\"].append(cm_en)\n",
    "\n",
    "        plot_confmat(cm_tr, classes, f\"{subjects[s_te]} TRCA (CORAL={use_coral})\", os.path.join(OUTDIR, f\"{csv_prefix}_{subjects[s_te]}_cm_trca.png\"))\n",
    "        plot_confmat(cm_fb, classes, f\"{subjects[s_te]} FBCCA\", os.path.join(OUTDIR, f\"{csv_prefix}_{subjects[s_te]}_cm_fbcca.png\"))\n",
    "        plot_confmat(cm_en, classes, f\"{subjects[s_te]} Ensemble\", os.path.join(OUTDIR, f\"{csv_prefix}_{subjects[s_te]}_cm_ens.png\"))\n",
    "\n",
    "        acc_tr_list.append(acc_tr); acc_fb_list.append(acc_fb); acc_en_list.append(acc_en)\n",
    "\n",
    "    df_sum = pd.DataFrame(rows_summary)\n",
    "    sum_path = os.path.join(OUTDIR, f\"{csv_prefix}_summary.csv\")\n",
    "    df_sum.to_csv(sum_path, index=False)\n",
    "    print(\"[Saved]\", sum_path)\n",
    "\n",
    "    bar_with_ci(acc_tr_list, subjects, \"TRCA mean±95%CI\",\n",
    "                f\"TRCA Accuracies (CORAL={use_coral}, win={S/fs:.2f}s)\",\n",
    "                os.path.join(OUTDIR, f\"{csv_prefix}_bars_trca.png\"))\n",
    "    bar_with_ci(acc_fb_list, subjects, \"FBCCA mean±95%CI\",\n",
    "                f\"FBCCA Accuracies (win={S/fs:.2f}s)\",\n",
    "                os.path.join(OUTDIR, f\"{csv_prefix}_bars_fbcca.png\"))\n",
    "    bar_with_ci(acc_en_list, subjects, \"Ensemble mean±95%CI\",\n",
    "                f\"Ensemble Accuracies (adapt={adaptive_fusion}, win={S/fs:.2f}s)\",\n",
    "                os.path.join(OUTDIR, f\"{csv_prefix}_bars_ensemble.png\"))\n",
    "\n",
    "    return {\n",
    "        \"summary_path\": sum_path,\n",
    "        \"acc_tr\": np.array(acc_tr_list),\n",
    "        \"acc_fb\": np.array(acc_fb_list),\n",
    "        \"acc_en\": np.array(acc_en_list),\n",
    "        \"confmats\": confmats,\n",
    "        \"subjects\": subjects,\n",
    "        \"C\": C,\n",
    "        \"T\": T,\n",
    "        \"fs\": fs,\n",
    "        \"S\": S,\n",
    "    }\n",
    "\n",
    "#  Ablations and Window-Length Study + Paired Stats\n",
    "# (1) full-window LOSO with and without CORAL, (2) a window-length sweep with\n",
    "# CORAL on, saved as CSV + accuracy/ITR plots, (3) paired t-tests across subjects.\n",
    "if __name__ == \"__main__\":\n",
    "    X_np, fs, freqs = load_all_subjects()\n",
    "    print(f\"[Load] X shape = {X_np.shape} | fs={fs} | freqs={None if freqs is None else freqs.tolist()}\")\n",
    "\n",
    "    # Broadband pre-filter of every X_np[s, c] slice, in place on the freshly loaded array.\n",
    "    for s in range(X_np.shape[0]):\n",
    "        for c in range(X_np.shape[1]):\n",
    "            X_np[s, c] = fft_bandpass(X_np[s, c], fs, *BANDPASS, NOTCH)\n",
    "\n",
    "    X_dev = xp.asarray(X_np) if USE_JAX else X_np\n",
    "\n",
    "    res_full = evaluate_loso(\n",
    "        X_dev, fs, freqs, harmonics=3,\n",
    "        use_coral=True, window_seconds=None, adaptive_fusion=True,\n",
    "        csv_prefix=\"ssvep_pub_full_coral_adapt\"\n",
    "    )\n",
    "\n",
    "    res_nocoral = evaluate_loso(\n",
    "        X_dev, fs, freqs, harmonics=3,\n",
    "        use_coral=False, window_seconds=None, adaptive_fusion=True,\n",
    "        csv_prefix=\"ssvep_pub_full_nocoral_adapt\"\n",
    "    )\n",
    "\n",
    "    # Window sweep: fixed windows strictly shorter than the recorded epoch, plus the\n",
    "    # full epoch once. A requested window at/after the epoch length is not a real\n",
    "    # shorter window, and would otherwise duplicate (or exceed) the full-epoch run.\n",
    "    full_sec = X_np.shape[-1] / fs\n",
    "    window_list = sorted({w for w in (1.0, 2.0, 3.0) if w < full_sec} | {full_sec})\n",
    "    sweep_rows = []\n",
    "    for wsec in window_list:\n",
    "        res_w = evaluate_loso(\n",
    "            X_dev, fs, freqs, harmonics=3,\n",
    "            use_coral=True, window_seconds=wsec, adaptive_fusion=True,\n",
    "            csv_prefix=f\"ssvep_pub_win{wsec:.1f}s\"\n",
    "        )\n",
    "        # Effective window actually classified (S samples at fs): the same duration\n",
    "        # evaluate_loso uses for its per-subject ITR, rather than the requested wsec.\n",
    "        win_eff = res_w[\"S\"] / res_w[\"fs\"]\n",
    "        sweep_rows.append({\n",
    "            \"window_sec\": win_eff,\n",
    "            \"acc_tr_mean\": float(np.mean(res_w[\"acc_tr\"])),\n",
    "            \"acc_fb_mean\": float(np.mean(res_w[\"acc_fb\"])),\n",
    "            \"acc_en_mean\": float(np.mean(res_w[\"acc_en\"])),\n",
    "            \"itr_en_mean\": float(np.mean([itr_bits_per_min(a/100.0, res_w[\"C\"], win_eff) for a in res_w[\"acc_en\"]]))\n",
    "        })\n",
    "    df_sweep = pd.DataFrame(sweep_rows)\n",
    "    sweep_path = os.path.join(OUTDIR, \"ssvep_pub_window_sweep.csv\")\n",
    "    df_sweep.to_csv(sweep_path, index=False)\n",
    "    print(\"[Saved]\", sweep_path)\n",
    "\n",
    "    fig1, ax1 = plt.subplots(figsize=(6,4))\n",
    "    ax1.plot(df_sweep[\"window_sec\"], df_sweep[\"acc_en_mean\"], marker='o')\n",
    "    ax1.set(xlabel=\"Window length (s)\", ylabel=\"Ensemble Accuracy (%)\", title=\"Ensemble Accuracy vs Window\")\n",
    "    fig1.savefig(os.path.join(OUTDIR, \"ssvep_pub_acc_vs_window.png\"), bbox_inches='tight', dpi=150)\n",
    "    plt.close(fig1)\n",
    "\n",
    "    fig2, ax2 = plt.subplots(figsize=(6,4))\n",
    "    ax2.plot(df_sweep[\"window_sec\"], df_sweep[\"itr_en_mean\"], marker='o')\n",
    "    ax2.set(xlabel=\"Window length (s)\", ylabel=\"ITR (bits/min)\", title=\"ITR vs Window\")\n",
    "    fig2.savefig(os.path.join(OUTDIR, \"ssvep_pub_itr_vs_window.png\"), bbox_inches='tight', dpi=150)\n",
    "    plt.close(fig2)\n",
    "\n",
    "    def paired_t(a, b):\n",
    "        \"\"\"Paired (related-samples) t-test of per-subject scores a vs b; returns (t, p) as floats.\"\"\"\n",
    "        t, p = stats.ttest_rel(a, b)\n",
    "        return float(t), float(p)\n",
    "\n",
    "    t_tr_fb, p_tr_fb = paired_t(res_full[\"acc_tr\"], res_full[\"acc_fb\"])\n",
    "    t_fb_en, p_fb_en = paired_t(res_full[\"acc_fb\"], res_full[\"acc_en\"])\n",
    "    t_tr_en, p_tr_en = paired_t(res_full[\"acc_tr\"], res_full[\"acc_en\"])\n",
    "\n",
    "    # Both runs iterate the same subjects from the same X_dev, so the pairing is per subject.\n",
    "    t_en_coral, p_en_coral = paired_t(res_full[\"acc_en\"], res_nocoral[\"acc_en\"])\n",
    "\n",
    "    stats_rows = [\n",
    "        {\"comparison\": \"TRCA vs FBCCA (full, CORAL)\", \"t\": t_tr_fb, \"p\": p_tr_fb},\n",
    "        {\"comparison\": \"FBCCA vs ENS (full, CORAL)\", \"t\": t_fb_en, \"p\": p_fb_en},\n",
    "        {\"comparison\": \"TRCA vs ENS (full, CORAL)\", \"t\": t_tr_en, \"p\": p_tr_en},\n",
    "        {\"comparison\": \"ENS (CORAL) vs ENS (noCORAL)\", \"t\": t_en_coral, \"p\": p_en_coral},\n",
    "    ]\n",
    "    df_stats = pd.DataFrame(stats_rows)\n",
    "    stats_path = os.path.join(OUTDIR, \"ssvep_pub_paired_stats.csv\")\n",
    "    df_stats.to_csv(stats_path, index=False)\n",
    "    print(\"[Saved]\", stats_path)\n",
    "\n",
    "    print(\"\\n=== FINAL SUMMARY (for manuscript) ===\")\n",
    "    print(\"Full window (baseline):\")\n",
    "    print(f\"  Mean TRCA  = {np.mean(res_full['acc_tr']):.2f}%\")\n",
    "    print(f\"  Mean FBCCA = {np.mean(res_full['acc_fb']):.2f}%\")\n",
    "    print(f\"  Mean ENS   = {np.mean(res_full['acc_en']):.2f}%\")\n",
    "    print(\"Paired t-tests (p-values):\")\n",
    "    for r in stats_rows:\n",
    "        print(f\"  {r['comparison']}: t={r['t']:.3f}, p={r['p']:.3g}\")\n",
    "    print(\"\\nWindow sweep (CSV & PNGs saved).\")\n",
    "    print(\"Figures/CSVs in:\", OUTDIR)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
