{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a2bb42d4",
   "metadata": {},
   "source": [
    "# ArtifactGen Class Comparison Notebook\n",
    "\n",
    "Comprehensive per-class comparison of DDPM vs WGAN-GP generative models for EEG artifact window synthesis.\n",
    "This notebook evaluates model performance across all artifact classes (e.g., chewing, muscle, eye movement, etc.).\n",
    "For each class:\n",
    "- Load real samples from that class\n",
    "- Generate synthetic samples for that class\n",
    "- Compute metrics (MMD, PSD, diversity, etc.)\n",
    "- Aggregate results for comparison\n",
    "\n",
    "> NOTE: This focuses on per-class metrics to assess how well models generalize across different artifact types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ed0a52af",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports & environment setup\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import math\n",
    "import random\n",
    "import textwrap\n",
    "import importlib\n",
    "import pathlib\n",
    "import itertools\n",
    "from pathlib import Path\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import yaml\n",
    "import scipy.signal as sps\n",
    "from scipy.spatial.distance import cdist\n",
    "from sklearn.manifold import TSNE\n",
    "try:\n",
    "    import umap\n",
    "    HAVE_UMAP = True\n",
    "except Exception:\n",
    "    HAVE_UMAP = False\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "sns.set_context('talk')\n",
    "sns.set_style('whitegrid')\n",
    "\n",
    "# Always save figures to ArtifactGen/paper/figs\n",
    "FIG_DIR = Path('C:/works/ArtifactGen/paper/figs')\n",
    "FIG_DIR.mkdir(parents=True, exist_ok=True)\n",
    "print('Figures ->', FIG_DIR.resolve())\n",
    "\n",
    "# Add project root to path so we can import src.* reliably when running from notebook folder\n",
    "ROOT = Path(__file__).parent.parent if '__file__' in globals() else Path('..')\n",
    "if (ROOT / 'src').exists():\n",
    "    sys.path.append(str(ROOT))\n",
    "\n",
    "from src.eval.generate import generate_samples  # uses models under src.models\n",
    "from src.models import UNet1D, WGANGPGenerator  # (ensures availability)\n",
    "\n",
    "RESULTS_DIR = ROOT / 'results'\n",
    "GEN_DIR = RESULTS_DIR / 'generated'\n",
    "CKPT_DIR = RESULTS_DIR / 'checkpoints'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e342e38d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Force rescale real_samples and R to [-1, 1] per channel for all analysis and plots\n",
    "def scale_to_unit(x):\n",
    "    # x: (N, C, L) or (C, L)\n",
    "    if isinstance(x, np.ndarray) and x.size > 0:\n",
    "        if x.ndim == 3:\n",
    "            min_c = x.min(axis=(0,2), keepdims=True)\n",
    "            max_c = x.max(axis=(0,2), keepdims=True)\n",
    "            denom = (max_c - min_c) + 1e-9\n",
    "            x = 2 * (x - min_c) / denom - 1\n",
    "        elif x.ndim == 2:\n",
    "            min_c = x.min(axis=1, keepdims=True)\n",
    "            max_c = x.max(axis=1, keepdims=True)\n",
    "            denom = (max_c - min_c) + 1e-9\n",
    "            x = 2 * (x - min_c) / denom - 1\n",
    "    return x\n",
    "if 'real_samples' in locals():\n",
    "    real_samples = scale_to_unit(real_samples)\n",
    "if 'R' in locals():\n",
    "    R = scale_to_unit(R)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "790e2c68",
   "metadata": {},
   "source": [
    "## Load Configs\n",
    "We extract key hyperparameters for each model to contextualize metric differences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f64ba1d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "ddpm_cfg_path = Path('../configs/ddpm_raw.yaml') if (Path.cwd().name == 'notebooks') else Path('configs/ddpm_raw.yaml')\n",
    "wgan_cfg_path = Path('../configs/wgan_raw.yaml') if (Path.cwd().name == 'notebooks') else Path('configs/wgan_raw.yaml')\n",
    "with open(ddpm_cfg_path, 'r', encoding='utf-8') as f: ddpm_cfg = yaml.safe_load(f)\n",
    "with open(wgan_cfg_path, 'r', encoding='utf-8') as f: wgan_cfg = yaml.safe_load(f)\n",
    "def cfg_summary(name, cfg):\n",
    "    m = cfg['model']; tr = cfg['training']; data = cfg['data']\n",
    "    return {\n",
    "        'model': name,\n",
    "        'channels': m.get('channels'),\n",
    "        'length': m.get('length'),\n",
    "        'num_classes': m.get('num_classes'),\n",
    "        'epochs': tr.get('epochs'),\n",
    "        'batch_size': tr.get('batch_size'),\n",
    "        'lr': tr.get('lr'),\n",
    "        'window_seconds': data.get('window_seconds'),\n",
    "        'sample_rate': data.get('sample_rate'),\n",
    "    }\n",
    "cfg_table = pd.DataFrame([cfg_summary('DDPM', ddpm_cfg), cfg_summary('WGAN-GP', wgan_cfg)])\n",
    "cfg_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca34b5c5",
   "metadata": {},
   "source": [
    "## Per-Class Evaluation\n",
    "Loop over all classes, load/generate samples, and compute metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b7da48f",
   "metadata": {},
   "outputs": [],
   "source": [
    "DO_GENERATE_DDPM = True\n",
    "DO_GENERATE_WGAN = True\n",
    "N_GEN = 256  # smaller for per-class\n",
    "REAL_LIMIT = 500  # per class\n",
    "\n",
    "def find_checkpoint(dir_: Path, pattern='*.pth'):\n",
    "    cks = sorted(dir_.glob(pattern), key=lambda p: p.stat().st_mtime)\n",
    "    return cks[-1] if cks else None\n",
    "\n",
    "ddpm_ckpt = CKPT_DIR / 'ddpm_unet_best.pth'\n",
    "if not ddpm_ckpt.exists():\n",
    "    ddpm_ckpt = find_checkpoint(CKPT_DIR)\n",
    "wgan_ckpt = CKPT_DIR / 'wgan_generator_best.pth'\n",
    "if not wgan_ckpt.exists():\n",
    "    wgan_ckpt = find_checkpoint(CKPT_DIR, pattern='*wgan_generator*.pth')\n",
    "\n",
    "print('DDPM ckpt:', ddpm_ckpt)\n",
    "print('WGAN ckpt:', wgan_ckpt)\n",
    "\n",
    "# Load class map\n",
    "VAL_ROOT = Path('../data/processed/train') if (Path.cwd().name == 'notebooks') else Path('data/processed/train')\n",
    "class_map_df = pd.read_csv(VAL_ROOT.parent / 'class_map.csv')\n",
    "class_map = list(zip(class_map_df['short'], class_map_df['display']))\n",
    "class_to_idx = {short: i for i, (short, display) in enumerate(class_map)}\n",
    "classes = class_map_df['short'].tolist()\n",
    "\n",
    "# Metric utilities (same as model_comparison)\n",
    "def channel_stats(x):\n",
    "    return x.mean(axis=(0,2)), x.std(axis=(0,2))\n",
    "\n",
    "def welch_psd(x, fs, nperseg=128):\n",
    "    if x.size==0: return np.array([]), np.array([])\n",
    "    C = x.shape[1]\n",
    "    psds = []\n",
    "    for c in range(C):\n",
    "        f, p = sps.welch(x[:,c,:], fs=fs, nperseg=min(nperseg, x.shape[-1]))\n",
    "        p = p.mean(axis=0)\n",
    "        psds.append(p)\n",
    "    return f, np.array(psds)\n",
    "\n",
    "def bandpower(f, psd, bands):\n",
    "    out = {}\n",
    "    for name, (lo, hi) in bands.items():\n",
    "        m = (f>=lo) & (f<hi)\n",
    "        if m.sum()>0:\n",
    "            out[name] = psd[:,m].sum(axis=1)\n",
    "        else:\n",
    "            out[name] = np.zeros(psd.shape[0])\n",
    "    return out\n",
    "\n",
    "def rbf_mmd(X, Y, sigmas=(10,20,40,80)):\n",
    "    if X.size==0 or Y.size==0: return np.nan\n",
    "    XX = cdist(X, X, 'euclidean')**2\n",
    "    YY = cdist(Y, Y, 'euclidean')**2\n",
    "    XY = cdist(X, Y, 'euclidean')**2\n",
    "    mmd = 0.0\n",
    "    for s in sigmas:\n",
    "        kXX = np.exp(-XX/(2*s*s))\n",
    "        kYY = np.exp(-YY/(2*s*s))\n",
    "        kXY = np.exp(-XY/(2*s*s))\n",
    "        m = kXX.mean() + kYY.mean() - 2*kXY.mean()\n",
    "        mmd += m\n",
    "    return mmd / len(sigmas)\n",
    "\n",
    "def diversity(x):\n",
    "    if x.shape[0] < 3: return np.nan\n",
    "    idx = np.random.choice(x.shape[0], size=min(64, x.shape[0]), replace=False)\n",
    "    flat = x[idx].reshape(len(idx), -1)\n",
    "    corr = np.corrcoef(flat)\n",
    "    upper = corr[np.triu_indices_from(corr, k=1)]\n",
    "    return 1 - upper.mean()\n",
    "\n",
    "def effect_size(a, b):\n",
    "    if a.size==0 or b.size==0: return np.nan\n",
    "    return (a.mean()-b.mean())/math.sqrt(0.5*(a.var()+b.var())+1e-9)\n",
    "\n",
    "FS = ddpm_cfg['data']['sample_rate'] if 'data' in ddpm_cfg else 250\n",
    "BANDS = {'delta':(0.5,4),'theta':(4,8),'alpha':(8,13),'beta':(13,30),'gamma':(30,45)}\n",
    "\n",
    "# Store results\n",
    "results = {}\n",
    "\n",
    "device = 'cuda' if (importlib.util.find_spec('torch') and __import__('torch').cuda.is_available()) else 'cpu'\n",
    "print('Device:', device)\n",
    "\n",
    "for cls in classes:\n",
    "    print(f'\\n--- Processing class: {cls} ---')\n",
    "    class_id = class_to_idx[cls]\n",
    "    display_name = class_map_df[class_map_df['short'] == cls]['display'].iloc[0]\n",
    "    class_root = VAL_ROOT / display_name\n",
    "    \n",
    "    # Load real samples for this class\n",
    "    real_samples = []\n",
    "    real_labels = []\n",
    "    extensions = {'.npy', '.npz', '.pt'}\n",
    "    if class_root.exists():\n",
    "        for f in class_root.glob('**/*'):\n",
    "            if f.suffix.lower() in extensions:\n",
    "                try:\n",
    "                    if f.suffix.lower() == '.pt':\n",
    "                        arr = __import__('torch').load(f).numpy()\n",
    "                    elif f.suffix.lower() == '.npz':\n",
    "                        npz = np.load(f)\n",
    "                        arr = npz[npz.files[0]]\n",
    "                    else:\n",
    "                        arr = np.load(f)\n",
    "                    if arr.ndim == 2:\n",
    "                        if arr.shape[0] <= 32:\n",
    "                            pass\n",
    "                        elif arr.shape[1] <= 32:\n",
    "                            arr = arr.T\n",
    "                        else:\n",
    "                            continue\n",
    "                        if arr.shape == (8, 250):\n",
    "                            real_samples.append(arr)\n",
    "                            real_labels.append(class_id)\n",
    "                except Exception:\n",
    "                    pass\n",
    "            if len(real_samples) >= REAL_LIMIT:\n",
    "                break\n",
    "    if len(real_samples) == 0:\n",
    "        print(f'No real samples for {cls}')\n",
    "        continue\n",
    "    real_samples = np.stack(real_samples)\n",
    "    real_labels = np.array(real_labels)\n",
    "    # Scale real_samples to [-1, 1] per channel immediately after loading\n",
    "    def scale_to_unit(x):\n",
    "        if isinstance(x, np.ndarray) and x.size > 0:\n",
    "            if x.ndim == 3:\n",
    "                min_c = x.min(axis=(0,2), keepdims=True)\n",
    "                max_c = x.max(axis=(0,2), keepdims=True)\n",
    "                denom = (max_c - min_c) + 1e-9\n",
    "                x = 2 * (x - min_c) / denom - 1\n",
    "            elif x.ndim == 2:\n",
    "                min_c = x.min(axis=1, keepdims=True)\n",
    "                max_c = x.max(axis=1, keepdims=True)\n",
    "                denom = (max_c - min_c) + 1e-9\n",
    "                x = 2 * (x - min_c) / denom - 1\n",
    "        return x\n",
    "    real_samples = scale_to_unit(real_samples)\n",
    "    print(f'Real samples shape: {real_samples.shape}')\n",
    "    \n",
    "    # Generate samples for this class\n",
    "    ddpm_X = np.empty((0,))\n",
    "    wgan_X = np.empty((0,))\n",
    "    \n",
    "    if DO_GENERATE_DDPM and ddpm_ckpt is not None:\n",
    "        print('Generating DDPM...')\n",
    "        ddpm_X, _ = generate_samples(ddpm_cfg, str(ddpm_ckpt), __import__('torch').device(device), 'ddpm', n=N_GEN, class_id=class_id)\n",
    "    if DO_GENERATE_WGAN and wgan_ckpt is not None:\n",
    "        print('Generating WGAN...')\n",
    "        wgan_X, _ = generate_samples(wgan_cfg, str(wgan_ckpt), __import__('torch').device(device), 'wgan_gp', n=N_GEN, class_id=class_id)\n",
    "    \n",
    "    # Compute metrics\n",
    "    R = real_samples[:512] if real_samples.shape[0] > 512 else real_samples\n",
    "    D = ddpm_X[:512] if ddpm_X.size > 0 else np.empty((0,8,250))\n",
    "    G = wgan_X[:512] if wgan_X.size > 0 else np.empty((0,8,250))\n",
    "    \n",
    "    R_flat = R.reshape(R.shape[0], -1) if R.size else np.empty((0,))\n",
    "    D_flat = D.reshape(D.shape[0], -1) if D.size else np.empty((0,))\n",
    "    G_flat = G.reshape(G.shape[0], -1) if G.size else np.empty((0,))\n",
    "    \n",
    "    # Only compute MMD if arrays are 2D and feature dimensions match\n",
    "    mmd_r_d = np.nan\n",
    "    mmd_r_g = np.nan\n",
    "    mmd_d_g = np.nan\n",
    "    if R_flat.size and D_flat.size and R_flat.ndim == 2 and D_flat.ndim == 2 and R_flat.shape[1] == D_flat.shape[1]:\n",
    "        mmd_r_d = rbf_mmd(R_flat, D_flat)\n",
    "    if R_flat.size and G_flat.size and R_flat.ndim == 2 and G_flat.ndim == 2 and R_flat.shape[1] == G_flat.shape[1]:\n",
    "        mmd_r_g = rbf_mmd(R_flat, G_flat)\n",
    "    if D_flat.size and G_flat.size and D_flat.ndim == 2 and G_flat.ndim == 2 and D_flat.shape[1] == G_flat.shape[1]:\n",
    "        mmd_d_g = rbf_mmd(D_flat, G_flat)\n",
    "    \n",
    "    fR, psdR = welch_psd(R, FS)\n",
    "    fD, psdD = welch_psd(D, FS)\n",
    "    fG, psdG = welch_psd(G, FS)\n",
    "    psdD_a = np.interp(fR, fD, psdD.mean(axis=0)) if fD.size and fR.size else np.array([])\n",
    "    psdG_a = np.interp(fR, fG, psdG.mean(axis=0)) if fG.size and fR.size else np.array([])\n",
    "    psd_err_ddpm = np.linalg.norm(psdR.mean(axis=0) - psdD_a) / (np.linalg.norm(psdR.mean(axis=0)) + 1e-9) if psdR.size and psdD_a.size else np.nan\n",
    "    psd_err_wgan = np.linalg.norm(psdR.mean(axis=0) - psdG_a) / (np.linalg.norm(psdR.mean(axis=0)) + 1e-9) if psdR.size and psdG_a.size else np.nan\n",
    "    \n",
    "    div_ddpm = diversity(D) if D.size else np.nan\n",
    "    div_wgan = diversity(G) if G.size else np.nan\n",
    "    div_real = diversity(R) if R.size else np.nan\n",
    "    \n",
    "    results[cls] = {\n",
    "        'MMD(R,DDPM)': mmd_r_d,\n",
    "        'MMD(R,WGAN)': mmd_r_g,\n",
    "        'MMD(DDPM,WGAN)': mmd_d_g,\n",
    "        'PSD L2 Error DDPM': psd_err_ddpm,\n",
    "        'PSD L2 Error WGAN': psd_err_wgan,\n",
    "        'Diversity DDPM': div_ddpm,\n",
    "        'Diversity WGAN': div_wgan,\n",
    "        'Diversity Real': div_real,\n",
    "    }\n",
    "\n",
    "# Display results\n",
    "results_df = pd.DataFrame.from_dict(results, orient='index')\n",
    "results_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b9e6e0f6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: Print DDPM/WGAN sample shapes and check for empty arrays or mismatches\n",
    "print(f'Class: {cls}')\n",
    "print('  Real samples shape:', R.shape)\n",
    "print('  DDPM samples shape:', D.shape, '| Empty:', D.size == 0)\n",
    "print('  WGAN samples shape:', G.shape, '| Empty:', G.size == 0)\n",
    "if D.size and R.size:\n",
    "    print('  DDPM/Real feature dim match:', D.reshape(D.shape[0], -1).shape[1] == R.reshape(R.shape[0], -1).shape[1])\n",
    "else:\n",
    "    print('  DDPM or Real samples missing')\n",
    "if G.size and R.size:\n",
    "    print('  WGAN/Real feature dim match:', G.reshape(G.shape[0], -1).shape[1] == R.reshape(R.shape[0], -1).shape[1])\n",
    "else:\n",
    "    print('  WGAN or Real samples missing')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "df6cfffe",
   "metadata": {},
   "source": [
    "## Assemble Generated Arrays\n",
    "Load DDPM and WGAN generated arrays (downsample / upsample if lengths mismatch)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b740fb7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define generated_files mapping: keys 'ddpm' and 'wgan' to tuple of (samples_path, labels_path)\n",
    "# Example: search for files in GEN_DIR matching expected patterns\n",
    "generated_files = {}\n",
    "for key in ['ddpm', 'wgan']:\n",
    "    sample_file = None\n",
    "    label_file = None\n",
    "    for f in GEN_DIR.glob(f'{key}*samples*.npy'):\n",
    "        sample_file = f\n",
    "        break\n",
    "    for f in GEN_DIR.glob(f'{key}*labels*.npy'):\n",
    "        label_file = f\n",
    "        break\n",
    "    if sample_file and label_file:\n",
    "        generated_files[key] = (str(sample_file), str(label_file))\n",
    "# If no files found, generated_files will be empty and load_gen will return empty arrays\n",
    "def load_gen(key):\n",
    "    if key not in generated_files: return np.empty((0,)), np.empty((0,))\n",
    "    s_path, l_path = generated_files[key]\n",
    "    X = np.load(s_path)\n",
    "    y = np.load(l_path)\n",
    "    return X, y\n",
    "ddpm_X, ddpm_y = load_gen('ddpm')\n",
    "wgan_X, wgan_y = load_gen('wgan')\n",
    "print('DDPM gen shape:', ddpm_X.shape)\n",
    "print('WGAN gen shape:', wgan_X.shape)\n",
    "# Debug: Check array existence and shape before harmonization\n",
    "print('real_samples defined:', 'real_samples' in locals())\n",
    "if 'real_samples' in locals(): print('real_samples shape:', getattr(real_samples, 'shape', 'Not ndarray'))\n",
    "print('ddpm_X defined:', 'ddpm_X' in locals())\n",
    "if 'ddpm_X' in locals(): print('ddpm_X shape:', getattr(ddpm_X, 'shape', 'Not ndarray'))\n",
    "print('wgan_X defined:', 'wgan_X' in locals())\n",
    "if 'wgan_X' in locals(): print('wgan_X shape:', getattr(wgan_X, 'shape', 'Not ndarray'))\n",
    "# Harmonize all arrays to the minimum length\n",
    "def resample_to(x, target_len):\n",
    "    if not isinstance(x, np.ndarray) or x.size == 0: return x\n",
    "    if x.shape[-1] == target_len: return x\n",
    "    return sps.resample(x, target_len, axis=-1)\n",
    "# Only include arrays that are defined and are numpy ndarrays with size > 0\n",
    "arrays = []\n",
    "for name in ['real_samples', 'ddpm_X', 'wgan_X']:\n",
    "    if name in locals():\n",
    "        arr = locals()[name]\n",
    "        if isinstance(arr, np.ndarray) and arr.size > 0:\n",
    "            arrays.append(arr)\n",
    "lengths = [arr.shape[-1] for arr in arrays]\n",
    "if lengths:\n",
    "    min_len = min(lengths)\n",
    "    if 'real_samples' in locals() and isinstance(real_samples, np.ndarray) and real_samples.size > 0:\n",
    "        real_samples = resample_to(real_samples, min_len)\n",
    "    if 'ddpm_X' in locals() and isinstance(ddpm_X, np.ndarray) and ddpm_X.size > 0:\n",
    "        ddpm_X = resample_to(ddpm_X, min_len)\n",
    "    if 'wgan_X' in locals() and isinstance(wgan_X, np.ndarray) and wgan_X.size > 0:\n",
    "        wgan_X = resample_to(wgan_X, min_len)\n",
    "    print(f\"Resampled all arrays to length {min_len}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1b83c0b8",
   "metadata": {},
   "source": [
    "## Metric Utilities\n",
    "Helper functions computing: channel-wise stats, PSD, bandpowers, RBF-MMD, coverage, diversity, and simple effect sizes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d62bbd2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def channel_stats(x):\n",
    "    # x: (N,C,L)\n",
    "    return x.mean(axis=(0,2)), x.std(axis=(0,2))\n",
    "\n",
    "def welch_psd(x, fs, nperseg=128):\n",
    "    # x: (N,C,L) -> returns freq, PSD averaged over samples\n",
    "    if x.size==0: return np.array([]), np.array([])\n",
    "    C = x.shape[1]\n",
    "    psds = []\n",
    "    for c in range(C):\n",
    "        f, p = sps.welch(x[:,c,:], fs=fs, nperseg=min(nperseg, x.shape[-1]))\n",
    "        p = p.mean(axis=0)  # average over segments if p is 2D\n",
    "        psds.append(p)\n",
    "    return f, np.array(psds)  # shape (C,F)\n",
    "\n",
    "def bandpower(f, psd, bands):\n",
    "    out = {}\n",
    "    for name, (lo, hi) in bands.items():\n",
    "        m = (f>=lo) & (f<hi)\n",
    "        if m.sum()>0:\n",
    "            out[name] = psd[:,m].sum(axis=1)\n",
    "        else:\n",
    "            out[name] = np.zeros(psd.shape[0])\n",
    "    return out\n",
    "\n",
    "def rbf_mmd(X, Y, sigmas=(10,20,40,80)):\n",
    "    # X,Y: (N,D) flattened\n",
    "    if X.size==0 or Y.size==0: return np.nan\n",
    "    XX = cdist(X, X, 'euclidean')**2\n",
    "    YY = cdist(Y, Y, 'euclidean')**2\n",
    "    XY = cdist(X, Y, 'euclidean')**2\n",
    "    mmd = 0.0\n",
    "    for s in sigmas:\n",
    "        kXX = np.exp(-XX/(2*s*s))\n",
    "        kYY = np.exp(-YY/(2*s*s))\n",
    "        kXY = np.exp(-XY/(2*s*s))\n",
    "        m = kXX.mean() + kYY.mean() - 2*kXY.mean()\n",
    "        mmd += m\n",
    "    return mmd / len(sigmas)\n",
    "\n",
    "def coverage(labels):\n",
    "    if labels.size==0: return {}\n",
    "    uniq, cnt = np.unique(labels, return_counts=True)\n",
    "    return {int(u): int(c) for u,c in zip(uniq,cnt)}\n",
    "\n",
    "def diversity(x):\n",
    "    # mean pairwise correlation across channels/time for random subset\n",
    "    if x.shape[0] < 3: return np.nan\n",
    "    idx = np.random.choice(x.shape[0], size=min(64, x.shape[0]), replace=False)\n",
    "    flat = x[idx].reshape(len(idx), -1)\n",
    "    corr = np.corrcoef(flat)\n",
    "    upper = corr[np.triu_indices_from(corr, k=1)]\n",
    "    return 1 - upper.mean()  # higher => more diversity (less correlation)\n",
    "\n",
    "def effect_size(a, b):\n",
    "    # Cohen's d between two vectors\n",
    "    if a.size==0 or b.size==0: return np.nan\n",
    "    return (a.mean()-b.mean())/math.sqrt(0.5*(a.var()+b.var())+1e-9)\n",
    "\n",
    "FS = ddpm_cfg['data']['sample_rate'] if 'data' in ddpm_cfg else 250\n",
    "BANDS = {'delta':(0.5,4),'theta':(4,8),'alpha':(8,13),'beta':(13,30),'gamma':(30,45)}"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6990008",
   "metadata": {},
   "source": [
    "## Compute Metrics\n",
    "We standardize sample counts, flatten for feature-space MMD, compute PSD & bandpower deltas, plus diversity & coverage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b64a3890",
   "metadata": {},
   "outputs": [],
   "source": [
    "# All data loading and generation is done above. Now compute metrics.\n",
    "\n",
    "# Ensure all required arrays are defined\n",
    "if 'real_labels' not in globals(): real_labels = np.array([])\n",
    "if 'ddpm_y' not in globals(): ddpm_y = np.array([])\n",
    "if 'wgan_y' not in globals(): wgan_y = np.array([])\n",
    "\n",
    "# Subsample to balance counts\n",
    "def balanced_subset(x, n=512):\n",
    "    if x.size==0: return x\n",
    "    if x.shape[0] <= n: return x\n",
    "    idx = np.random.choice(x.shape[0], n, replace=False)\n",
    "    return x[idx]\n",
    "R = balanced_subset(real_samples, 512)\n",
    "D = balanced_subset(ddpm_X, 512)\n",
    "G = balanced_subset(wgan_X, 512)\n",
    "\n",
    "# Harmonize shapes for MMD (feature dimension)\n",
    "def harmonize_shape(a, b):\n",
    "    if a.size == 0 or b.size == 0: return a, b\n",
    "    # Ensure channel and length match\n",
    "    min_channels = min(a.shape[1], b.shape[1]) if a.ndim > 2 and b.ndim > 2 else 1\n",
    "    min_length = min(a.shape[2], b.shape[2]) if a.ndim > 2 and b.ndim > 2 else 1\n",
    "    a_h = a[:, :min_channels, :min_length] if a.ndim > 2 else a\n",
    "    b_h = b[:, :min_channels, :min_length] if b.ndim > 2 else b\n",
    "    return a_h, b_h\n",
    "\n",
    "R_h, D_h = harmonize_shape(R, D)\n",
    "R_h, G_h = harmonize_shape(R, G)\n",
    "D_h, G_h2 = harmonize_shape(D, G)  # G_h2 for DDPM vs WGAN\n",
    "\n",
    "# Channel stats\n",
    "r_mu, r_sd = channel_stats(R_h) if R_h.size else (np.array([]), np.array([]))\n",
    "d_mu, d_sd = channel_stats(D_h) if D_h.size else (np.array([]), np.array([]))\n",
    "g_mu, g_sd = channel_stats(G_h) if G_h.size else (np.array([]), np.array([]))\n",
    "\n",
    "# Flatten for MMD\n",
    "R_flat = R_h.reshape(R_h.shape[0], -1) if R_h.size else np.empty((0,))\n",
    "D_flat = D_h.reshape(D_h.shape[0], -1) if D_h.size else np.empty((0,))\n",
    "G_flat = G_h.reshape(G_h.shape[0], -1) if G_h.size else np.empty((0,))\n",
    "G_flat2 = G_h2.reshape(G_h2.shape[0], -1) if G_h2.size else np.empty((0,))\n",
    "\n",
    "mmd_r_d = rbf_mmd(R_flat, D_flat) if R_flat.size and D_flat.size and R_flat.shape[1] == D_flat.shape[1] else np.nan\n",
    "mmd_r_g = rbf_mmd(R_flat, G_flat) if R_flat.size and G_flat.size and R_flat.shape[1] == G_flat.shape[1] else np.nan\n",
    "mmd_d_g = rbf_mmd(D_flat, G_flat2) if D_flat.size and G_flat2.size and D_flat.shape[1] == G_flat2.shape[1] else np.nan\n",
    "\n",
    "# PSD & bandpower\n",
    "fR, psdR = welch_psd(R_h, FS) if R_h.size else (np.array([]), np.array([]))\n",
    "fD, psdD = welch_psd(D_h, FS) if D_h.size else (np.array([]), np.array([]))\n",
    "fG, psdG = welch_psd(G_h, FS) if G_h.size else (np.array([]), np.array([]))\n",
    "def align_psd(f_ref, psd_ref, f_other, psd_other):\n",
    "    if len(f_ref)==0 or len(f_other)==0: return psd_other\n",
    "    if np.array_equal(f_ref, f_other): return psd_other\n",
    "    # simple interpolation\n",
    "    return np.vstack([np.interp(f_ref, f_other, psd_other[i, :]) for i in range(psd_other.shape[0])])\n",
    "psdD_a = align_psd(fR, psdR, fD, psdD) if len(fR)>0 and len(fD)>0 else psdD\n",
    "psdG_a = align_psd(fR, psdR, fG, psdG) if len(fR)>0 and len(fG)>0 else psdG\n",
    "# L2 difference normalized\n",
    "def psd_l2(a,b):\n",
    "    if a.size==0 or b.size==0: return np.nan\n",
    "    # Ensure shapes match before subtraction\n",
    "    if a.shape != b.shape:\n",
    "        min_shape = tuple(min(sa, sb) for sa, sb in zip(a.shape, b.shape))\n",
    "        a = a[tuple(slice(0, ms) for ms in min_shape)]\n",
    "        b = b[tuple(slice(0, ms) for ms in min_shape)]\n",
    "    return np.linalg.norm(a-b)/ (np.linalg.norm(a)+1e-9)\n",
    "psd_err_ddpm = psd_l2(psdR, psdD_a) if psdR.size and psdD_a.size else np.nan\n",
    "psd_err_wgan = psd_l2(psdR, psdG_a) if psdR.size and psdG_a.size else np.nan\n",
    "\n",
    "# Ensure frequency and PSD arrays match for bandpower\n",
    "def safe_bandpower(f, psd, bands):\n",
    "    # If frequency and PSD shapes mismatch, interpolate PSD to match frequency bins\n",
    "    if psd.size == 0 or f.size == 0: return {}\n",
    "    if psd.shape[1] != f.shape[0]:\n",
    "        # Interpolate PSD to match frequency bins\n",
    "        psd_interp = np.vstack([np.interp(f, np.linspace(f.min(), f.max(), psd.shape[1]), psd[i, :]) for i in range(psd.shape[0])])\n",
    "        return bandpower(f, psd_interp, bands)\n",
    "    else:\n",
    "        return bandpower(f, psd, bands)\n",
    "bpR = safe_bandpower(fR, psdR, BANDS) if len(fR)>0 and psdR.size else {}\n",
    "bpD = safe_bandpower(fR, psdD_a, BANDS) if len(fR)>0 and psdD_a.size else {}\n",
    "bpG = safe_bandpower(fR, psdG_a, BANDS) if len(fR)>0 and psdG_a.size else {}\n",
    "band_rows = []\n",
    "for band in BANDS.keys():\n",
    "    if band not in bpR: continue\n",
    "    row = {'band':band}\n",
    "    r = bpR[band] if band in bpR else np.zeros(1)\n",
    "    d = bpD.get(band, np.zeros_like(r))\n",
    "    g = bpG.get(band, np.zeros_like(r))\n",
    "    row['rel_err_ddpm'] = np.mean(np.abs(d-r)/(r+1e-9)) if r.size and d.size else np.nan\n",
    "    row['rel_err_wgan'] = np.mean(np.abs(g-r)/(r+1e-9)) if r.size and g.size else np.nan\n",
    "    band_rows.append(row)\n",
    "band_table = pd.DataFrame(band_rows)\n",
    "\n",
    "# Coverage & Diversity\n",
    "def coverage(labels):\n",
    "    if labels.size==0: return {}\n",
    "    uniq, cnt = np.unique(labels, return_counts=True)\n",
    "    return {int(u): int(c) for u,c in zip(uniq,cnt)}\n",
    "\n",
    "cov_real = coverage(real_labels) if real_labels.size else {}\n",
    "cov_ddpm = coverage(ddpm_y) if ddpm_y.size else {}\n",
    "cov_wgan = coverage(wgan_y) if wgan_y.size else {}\n",
    "div_real = diversity(R_h) if R_h.size else np.nan\n",
    "div_ddpm = diversity(D_h) if D_h.size else np.nan\n",
    "div_wgan = diversity(G_h) if G_h.size else np.nan\n",
    "\n",
    "metric_rows = [\n",
    "    {'metric':'MMD(R,DDPM)','ddpm':mmd_r_d,'wgan':np.nan},\n",
    "    {'metric':'MMD(R,WGAN)','ddpm':np.nan,'wgan':mmd_r_g},\n",
    "    {'metric':'MMD(DDPM,WGAN)','ddpm':mmd_d_g,'wgan':mmd_d_g},\n",
    "    {'metric':'PSD L2 Error','ddpm':psd_err_ddpm,'wgan':psd_err_wgan},\n",
    "    {'metric':'Diversity (1-mean corr)','ddpm':div_ddpm,'wgan':div_wgan, 'real':div_real},\n",
    " ]\n",
    "metrics_table = pd.DataFrame(metric_rows)\n",
    "metrics_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c562ebc0",
   "metadata": {},
   "source": [
    "### Bandpower Relative Error\n",
    "Lower is better (closer to real distribution)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c3d3000a",
   "metadata": {},
   "outputs": [],
   "source": [
    "band_table"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8504ebf",
   "metadata": {},
   "source": [
    "## Channel Mean/Std Comparison\n",
    "Effect sizes (Cohen's d) for per-channel aggregated amplitude distribution differences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "028ecc16",
   "metadata": {},
   "outputs": [],
   "source": [
    "chan_rows = []\n",
    "for i in range(len(r_mu)) if r_mu.size else []:\n",
    "    row = {'channel': i}\n",
    "    if d_mu.size: row['d_mu_diff'] = d_mu[i]-r_mu[i]; row['d_mean_effect'] = effect_size(d_mu[i:i+1], r_mu[i:i+1])\n",
    "    if g_mu.size: row['g_mu_diff'] = g_mu[i]-r_mu[i]; row['g_mean_effect'] = effect_size(g_mu[i:i+1], r_mu[i:i+1])\n",
    "    chan_rows.append(row)\n",
    "channel_table = pd.DataFrame(chan_rows)\n",
    "channel_table.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bb0f1500",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows\n",
    "Overlay real vs synthetic windows for a random class + channel."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc724c2",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows (Channel 1)\n",
    "Overlay real vs synthetic windows for channel 1.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9923a860",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_examples_multi(channels=[0,1,2,3,4], n=5, seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "    fig, axes = plt.subplots(len(channels), n, figsize=(3*n, 2*len(channels)), sharex=True)\n",
    "    if len(channels)==1: axes = [axes]\n",
    "    for ci, channel in enumerate(channels):\n",
    "        for i in range(n):\n",
    "            axi = axes[ci][i] if len(channels)>1 else axes[i]\n",
    "            if R.size: axi.plot(R[i%len(R), channel], label='Real', color='black', lw=1)\n",
    "            if D.size: axi.plot(D[i%len(D), channel], label='DDPM', alpha=0.8)\n",
    "            if G.size: axi.plot(G[i%len(G), channel], label='WGAN', alpha=0.8)\n",
    "            if i==0: axi.legend(frameon=False)\n",
    "            axi.set_ylabel(f'Ch {channel}, Sample {i}')\n",
    "        axes[ci][-1].set_xlabel('Time (samples)')\n",
    "    fig.suptitle('Example Windows (Channels 0-4)')\n",
    "    plt.tight_layout()\n",
    "    out = FIG_DIR / 'examples_windows_multi.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()\n",
    "plot_examples_multi(channels=[0,1,2,3,4], n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3ee6b600",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Debug: Plot several individual real samples for a single channel to check for diversity and flatness\n",
    "def plot_individual_real_samples(channel=0, n=10):\n",
    "    if R.size == 0:\n",
    "        print('No real samples loaded.')\n",
    "        return\n",
    "    fig, axes = plt.subplots(n, 1, figsize=(10, 2*n), sharex=True)\n",
    "    if n == 1: axes = [axes]\n",
    "    for i in range(n):\n",
    "        axi = axes[i]\n",
    "        axi.plot(R[i % len(R), channel], color='black', lw=1)\n",
    "        axi.set_ylabel(f'Sample {i}')\n",
    "    axes[-1].set_xlabel('Time (samples)')\n",
    "    fig.suptitle(f'Individual Real Samples (Channel {channel})')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "plot_individual_real_samples(channel=0, n=10)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b958019",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows (Channels 0-4 Side by Side)\n",
    "Compare real vs synthetic windows for channels 0-4 in a single figure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba4a6ae9",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_examples(cls=None, channel=0, n=5, seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "    fig, ax = plt.subplots(n, 1, figsize=(10, 2*n), sharex=True)\n",
    "    if n==1: ax=[ax]\n",
    "    for i in range(n):\n",
    "        axi = ax[i]\n",
    "        if R.size: axi.plot(R[i%len(R), channel], label='Real', color='black', lw=1)\n",
    "        if D.size: axi.plot(D[i%len(D), channel], label='DDPM', alpha=0.8)\n",
    "        if G.size: axi.plot(G[i%len(G), channel], label='WGAN', alpha=0.8)\n",
    "        if i==0: axi.legend(frameon=False)\n",
    "        axi.set_ylabel(f'Sample {i}')\n",
    "    ax[-1].set_xlabel('Time (samples)')\n",
    "    fig.suptitle('Example Windows (Channel %d)' % channel)\n",
    "    plt.tight_layout()\n",
    "    out = FIG_DIR / 'examples_windows.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()\n",
    "plot_examples(channel=4, n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7db7fc46",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows (Channel 4)\n",
    "Overlay real vs synthetic windows for channel 4."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3fa0fe1a",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_examples(channel=3, n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76b7ad3f",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows (Channel 3)\n",
    "Overlay real vs synthetic windows for channel 3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b19e274",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_examples(channel=2, n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4284b3b6",
   "metadata": {},
   "source": [
    "## Visualization: Example Windows (Channel 2)\n",
    "Overlay real vs synthetic windows for channel 2."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29816697",
   "metadata": {},
   "outputs": [],
   "source": [
    "plot_examples(channel=1, n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2d86ca93",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_examples(cls=None, channel=0, n=5, seed=0):\n",
    "    random.seed(seed); np.random.seed(seed)\n",
    "    fig, ax = plt.subplots(n, 1, figsize=(10, 2*n), sharex=True)\n",
    "    if n==1: ax=[ax]\n",
    "    for i in range(n):\n",
    "        axi = ax[i]\n",
    "        if R.size: axi.plot(R[i%len(R), channel], label='Real', color='black', lw=1)\n",
    "        if D.size: axi.plot(D[i%len(D), channel], label='DDPM', alpha=0.8)\n",
    "        if G.size: axi.plot(G[i%len(G), channel], label='WGAN', alpha=0.8)\n",
    "        if i==0: axi.legend(frameon=False)\n",
    "        axi.set_ylabel(f'Sample {i}')\n",
    "    ax[-1].set_xlabel('Time (samples)')\n",
    "    fig.suptitle('Example Windows (Channel %d)' % channel)\n",
    "    plt.tight_layout()\n",
    "    out = FIG_DIR / 'examples_windows.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()\n",
    "plot_examples(channel=0, n=5)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a751e99",
   "metadata": {},
   "source": [
    "## Visualization: PSD Comparison\n",
    "Average channel-aggregated PSD overlay with log-power scale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f96c3bf4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_psd():\n",
    "    if len(fR)==0: return\n",
    "    fig, ax = plt.subplots(1,1, figsize=(8,5))\n",
    "    ax.plot(fR, psdR.mean(axis=0), label='Real', color='black', lw=2)\n",
    "    if psdD_a.size: ax.plot(fR, psdD_a.mean(axis=0), label='DDPM')\n",
    "    if psdG_a.size: ax.plot(fR, psdG_a.mean(axis=0), label='WGAN')\n",
    "    ax.set_xscale('log'); ax.set_yscale('log')\n",
    "    ax.set_xlabel('Frequency (Hz)')\n",
    "    ax.set_ylabel('Power')\n",
    "    ax.legend(frameon=False)\n",
    "    fig.tight_layout()\n",
    "    out = FIG_DIR / 'psd_comparison.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()\n",
    "plot_psd()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f23ee514",
   "metadata": {},
   "source": [
    "## Embedding Visualization (t-SNE / UMAP)\n",
    "We project flattened windows. (For publication quality, replace with encoder feature embeddings.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faf4f455",
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_embedding(R, D, G, max_per=300):\n",
    "    data = []; labels=[]; origin=[]\n",
    "    # Find minimum feature dimension across all non-empty blocks\n",
    "    shapes = [block.shape for block in [R, D, G] if block.size > 0]\n",
    "    if not shapes: return None, None\n",
    "    min_feat = min([s[1]*s[2] if len(s) == 3 else s[-1] for s in shapes])\n",
    "    def add(block, name):\n",
    "        if block.size==0: return\n",
    "        k = min(block.shape[0], max_per)\n",
    "        idx = np.random.choice(block.shape[0], k, replace=False)\n",
    "        flat = block[idx].reshape(k, -1)\n",
    "        # Truncate to min_feat columns\n",
    "        if flat.shape[1] > min_feat:\n",
    "            flat = flat[:, :min_feat]\n",
    "        elif flat.shape[1] < min_feat:\n",
    "            # Pad with zeros if needed\n",
    "            pad = np.zeros((k, min_feat-flat.shape[1]))\n",
    "            flat = np.hstack([flat, pad])\n",
    "        data.append(flat)\n",
    "        labels.extend([name]*k)\n",
    "    add(R,'Real'); add(D,'DDPM'); add(G,'WGAN')\n",
    "    if not data: return None, None\n",
    "    X = np.vstack(data)\n",
    "    return X, labels\n",
    "Xemb, lab_emb = build_embedding(R, D, G)\n",
    "if Xemb is not None:\n",
    "    tsne = TSNE(n_components=2, init='pca', random_state=42, perplexity=min(30, max(5, Xemb.shape[0]//10)))\n",
    "    X2 = tsne.fit_transform(Xemb)\n",
    "    fig, ax = plt.subplots(figsize=(6,6))\n",
    "    df_emb = pd.DataFrame({'x':X2[:,0],'y':X2[:,1],'origin':lab_emb})\n",
    "    sns.scatterplot(data=df_emb, x='x', y='y', hue='origin', s=20, ax=ax)\n",
    "    ax.set_title('t-SNE (Raw Flattened)')\n",
    "    fig.tight_layout()\n",
    "    out = FIG_DIR / 'embedding_tsne.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()\n",
    "    if HAVE_UMAP:\n",
    "        reducer = umap.UMAP(random_state=42, n_neighbors=15, min_dist=0.2)\n",
    "        X3 = reducer.fit_transform(Xemb)\n",
    "        fig2, ax2 = plt.subplots(figsize=(6,6))\n",
    "        df_emb2 = pd.DataFrame({'x':X3[:,0],'y':X3[:,1],'origin':lab_emb})\n",
    "        sns.scatterplot(data=df_emb2, x='x', y='y', hue='origin', s=20, ax=ax2)\n",
    "        ax2.set_title('UMAP (Raw Flattened)')\n",
    "        fig2.tight_layout()\n",
    "        out2 = FIG_DIR / 'embedding_umap.png'\n",
    "        fig2.savefig(out2, dpi=200)\n",
    "        print('Saved', out2)\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1e75ad95",
   "metadata": {},
   "source": [
    "## Coverage Comparison\n",
    "Class count distribution (raw counts)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e971b57c",
   "metadata": {},
   "outputs": [],
   "source": [
    "cov_df = pd.DataFrame({'class':sorted(set(list(cov_real.keys())+list(cov_ddpm.keys())+list(cov_wgan.keys())))} )\n",
    "cov_df['real'] = cov_df['class'].map(cov_real).fillna(0).astype(int)\n",
    "cov_df['ddpm'] = cov_df['class'].map(cov_ddpm).fillna(0).astype(int)\n",
    "cov_df['wgan'] = cov_df['class'].map(cov_wgan).fillna(0).astype(int)\n",
    "cov_df"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a6c43f0",
   "metadata": {},
   "source": [
    "### Coverage Plot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11066376",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not cov_df.empty:\n",
    "    fig, ax = plt.subplots(figsize=(20,10))  # Much larger plot for maximum visibility\n",
    "    covm = cov_df.melt(id_vars='class', value_vars=['real','ddpm','wgan'], var_name='source', value_name='count')\n",
    "    sns.barplot(data=covm, x='class', y='count', hue='source', ax=ax)\n",
    "    ax.set_title('Class Coverage', fontsize=24)\n",
    "    ax.tick_params(axis='x', labelsize=16)\n",
    "    ax.tick_params(axis='y', labelsize=16)\n",
    "    ax.legend(fontsize=16)\n",
    "    fig.tight_layout()\n",
    "    out = FIG_DIR / 'coverage_bar.png'\n",
    "    fig.savefig(out, dpi=200)\n",
    "    print('Saved', out)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "760103f9",
   "metadata": {},
   "source": [
    "## Publication Tables (LaTeX Export Helpers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ac0ba5da",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_latex(df, fname):\n",
    "    out = FIG_DIR / fname\n",
    "    with open(out, 'w', encoding='utf-8') as f: f.write(df.to_latex(index=False, float_format=lambda x: f'{x:.3g}'))\n",
    "    print('Wrote', out)\n",
    "to_latex(metrics_table, 'table_metrics.tex')\n",
    "to_latex(band_table, 'table_bandpower.tex')\n",
    "to_latex(channel_table.head(12), 'table_channel_effects.tex')\n",
    "to_latex(cov_df, 'table_coverage.tex')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e86b64cb",
   "metadata": {},
   "source": [
    "## Next Steps / Extensions\n",
    "- Integrate encoder-based feature extraction (CNN or self-supervised) to compute FID/KID & Precision/Recall.\\n\n",
    "- Add functional metrics: Train classifier on real -> test synthetic (TRTS) & train synthetic -> test real (TSTR).\\n\n",
    "- Add conditional guidance strength sweeps for DDPM to analyze class fidelity vs diversity.\\n\n",
    "- Statistical tests (e.g., Wilcoxon) on per-channel or band distributions.\\n\n",
    "- Panel assembly script to combine exported PNGs into multi-part figure."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87344167",
   "metadata": {},
   "source": [
    "## Summary\n",
    "This table shows per-class metrics. Lower MMD and PSD error indicate better performance for that class."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf644c1c",
   "metadata": {},
   "source": [
    "## Per-Class Results Table\n",
    "\n",
    "Below is the summary table showing per-class metrics for DDPM and WGAN-GP models. Lower MMD and PSD error indicate better performance for that class. Use this table to compare how well each model captures the distribution of real EEG artifact windows across different classes."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
