{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "deaaee72",
   "metadata": {},
   "source": [
    "# ArtifactGEN Exploration\n",
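    "Explore TUAR artifact annotations (labels, channels, event durations, subject-wise splits) and visualize real vs synthetic EEG windows. Update paths (e.g. `TUAR_ROOT`) as needed."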
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f1940e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# TUAR quick exploration utilities\n",
    "import os, csv, json, re\n",
    "from pathlib import Path\n",
    "from collections import Counter, defaultdict\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import tqdm\n",
    "\n",
    "# Configure your TUAR base directory here.\n",
    "# Set the TUAR_ROOT environment variable to override the default without editing the notebook.\n",
    "TUAR_ROOT = Path(os.environ.get(\"TUAR_ROOT\", \"C:/works/TUAR\")).resolve()\n",
    "print(\"TUAR root:\", TUAR_ROOT)\n",
    "\n",
    "# Setup for saving figures (path is relative to the notebook's working directory)\n",
    "figs_dir = Path('paper/figs')\n",
    "figs_dir.mkdir(parents=True, exist_ok=True)\n",
    "print(\"Figures will be saved to:\", figs_dir)\n",
    "\n",
    "# 1) List channel configuration subfolders under edf/\n",
    "edf_dir = TUAR_ROOT / \"edf\"\n",
    "if not edf_dir.is_dir():\n",
    "    print(f\"WARNING: {edf_dir} is not a directory; check TUAR_ROOT\")\n",
    "subdirs = sorted(p for p in edf_dir.glob(\"*\") if p.is_dir())\n",
    "print(\"edf subdirs:\")\n",
    "for s in subdirs:\n",
    "    print(\" -\", s.name)\n",
    "\n",
    "# 2) Enumerate EDF files and paired CSVs.\n",
    "# EDF paths are sorted so `records` -- and every `records[:N]` sample taken from it --\n",
    "# is reproducible; raw glob order depends on the filesystem.\n",
    "records = []\n",
    "for sub in subdirs:\n",
    "    for edf_path in sorted(sub.glob(\"*.edf\")):\n",
    "        stem = edf_path.stem  # e.g., aaaaaaju_s005_t000\n",
    "        parts = stem.split(\"_\")  # subject, session, ..., token\n",
    "        csv_path = edf_path.with_suffix(\".csv\")\n",
    "        seiz_csv = edf_path.with_name(stem + \"_seiz.csv\")\n",
    "        records.append({\n",
    "            \"subdir\": sub.name,\n",
    "            \"edf\": edf_path,\n",
    "            \"csv\": csv_path if csv_path.exists() else None,\n",
    "            \"seiz\": seiz_csv if seiz_csv.exists() else None,\n",
    "            \"subject\": parts[0],\n",
    "            \"session\": parts[1] if len(parts) > 1 else None,\n",
    "            \"token\": parts[-1] if len(parts) > 1 else None,\n",
    "        })\n",
    "\n",
    "print(f\"Found {len(records)} EDF segments\")\n",
    "\n",
    "# 3) Parse CSV headers for metadata and collect event rows for a sample\n",
    "header_keys = [\"version\", \"bname\", \"duration\", \"montage_file\", \"annotation_label_file\"]\n",
    "\n",
    "\n",
    "def read_annotation_csv(csv_path: \"Path | None\"):\n",
    "    \"\"\"Parse a TUAR annotation CSV into ``(meta, events)``.\n",
    "\n",
    "    Header lines look like ``# key = value``; only keys listed in ``header_keys``\n",
    "    are kept (key names are matched case-insensitively). Data rows are\n",
    "    ``channel,start_time,stop_time,label,confidence``; a ``channel,...`` column\n",
    "    header row is skipped.\n",
    "\n",
    "    Args:\n",
    "        csv_path: path to the annotation CSV, or None for records without one.\n",
    "\n",
    "    Returns:\n",
    "        meta: dict mapping every key in ``header_keys`` to its string value (None if absent).\n",
    "        events: list of dicts with keys channel, start, stop, label, confidence\n",
    "            (start/stop/confidence as floats). Rows with fewer than five fields or\n",
    "            non-numeric numbers are skipped. Both are empty when the file is missing.\n",
    "    \"\"\"\n",
    "    meta = {k: None for k in header_keys}\n",
    "    events = []\n",
    "    if csv_path is None or not csv_path.exists():\n",
    "        return meta, events\n",
    "    # The csv module documentation says files should be opened with newline=''.\n",
    "    with open(csv_path, \"r\", encoding=\"utf-8\", newline=\"\") as f:\n",
    "        for row in csv.reader(f):\n",
    "            if not row:\n",
    "                continue\n",
    "            first = row[0].strip()\n",
    "            if first.startswith(\"#\"):\n",
    "                # Header line '# key = value'; re-join because the value itself may contain commas.\n",
    "                line = \",\".join(row).strip()\n",
    "                m = re.match(r\"#\\s*([a-z_]+)\\s*=\\s*(.*)\", line, re.IGNORECASE)\n",
    "                if m:\n",
    "                    k, v = m.group(1).strip().lower(), m.group(2).strip()\n",
    "                    if k in meta:\n",
    "                        meta[k] = v\n",
    "                continue\n",
    "            if first.lower() == \"channel\":\n",
    "                # column header\n",
    "                continue\n",
    "            # Data lines: channel,start_time,stop_time,label,confidence\n",
    "            try:\n",
    "                ch, start, stop, label, conf = (field.strip() for field in row[:5])\n",
    "                events.append({\n",
    "                    \"channel\": ch,\n",
    "                    \"start\": float(start),\n",
    "                    \"stop\": float(stop),\n",
    "                    \"label\": label,\n",
    "                    \"confidence\": float(conf),\n",
    "                })\n",
    "            except ValueError:\n",
    "                # Too few fields or a non-numeric time/confidence: skip this malformed row.\n",
    "                continue\n",
    "    return meta, events\n",
    "\n",
    "# Sample the first N records for event statistics\n",
    "N = min(50, len(records))\n",
    "sample_records = records[:N]\n",
    "\n",
    "label_counts, channel_counts, montage_files = Counter(), Counter(), Counter()\n",
    "durations = []\n",
    "have_seiz = 0\n",
    "\n",
    "for r in tqdm.tqdm(sample_records, desc=\"Processing sample records\"):\n",
    "    meta, events = read_annotation_csv(r[\"csv\"])\n",
    "    montage = meta.get(\"montage_file\")\n",
    "    if montage:\n",
    "        montage_files[Path(montage).name] += 1\n",
    "    have_seiz += int(r[\"seiz\"] is not None)\n",
    "    # One pass per statistic; each Counter still sees events in file order.\n",
    "    label_counts.update(e[\"label\"] for e in events)\n",
    "    channel_counts.update(e[\"channel\"] for e in events)\n",
    "    durations.extend(e[\"stop\"] - e[\"start\"] for e in events)\n",
    "\n",
    "print(\"Sample summary (first\", N, \"files):\")\n",
    "print(\" - Unique labels:\", len(label_counts))\n",
    "print(\" - Top-10 labels:\", label_counts.most_common(10))\n",
    "print(\" - Unique channels:\", len(channel_counts))\n",
    "print(\" - Top-10 channels:\", channel_counts.most_common(10))\n",
    "print(\" - Events with durations: \", len(durations))\n",
    "if durations:\n",
    "    arr = np.array(durations)\n",
    "    mean_s, median_s, p95_s = arr.mean(), np.median(arr), np.percentile(arr, 95)\n",
    "    print(f\"   duration mean={mean_s:.2f}s, median={median_s:.2f}s, p95={p95_s:.2f}s\")\n",
    "print(\" - Montage files (top):\", montage_files.most_common(5))\n",
    "print(\" - Files with seizure CSV:\", have_seiz, \"/\", N)\n",
    "\n",
    "def plot_top_counts(counter, k, title, fname, xlabel):\n",
    "    \"\"\"Bar-plot the ``k`` most common entries of ``counter``, save to ``figs_dir / fname`` and show.\n",
    "\n",
    "    Returns the plotted ``(keys, counts)`` (two empty lists when ``counter`` is empty).\n",
    "    \"\"\"\n",
    "    keys, vals = zip(*counter.most_common(k)) if counter else ([], [])\n",
    "    fig, ax = plt.subplots(figsize=(10, 4))\n",
    "    ax.bar(keys, vals)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel(\"event count\")\n",
    "    plt.xticks(rotation=45, ha='right')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(figs_dir / fname)\n",
    "    plt.show()\n",
    "    return keys, vals\n",
    "\n",
    "# 4) Plot label histogram for the sample\n",
    "labels, counts = plot_top_counts(label_counts, 15, \"Top-15 labels (sample)\", 'top15_labels_sample.png', \"label\")\n",
    "\n",
    "# 5) Estimate per-channel event frequency\n",
    "chs, ccounts = plot_top_counts(channel_counts, 20, \"Top-20 channels with events (sample)\", 'top20_channels_sample.png', \"channel\")\n",
    "\n",
    "# 6) Sketch subject-wise split feasibility (unique subjects in sample)\n",
    "subjects = {r['subject'] for r in sample_records}\n",
    "print(\"Unique subjects in sample:\", len(subjects))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1b3857e2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Deep dive into one annotation file\n",
    "from pprint import pprint\n",
    "\n",
    "# Pick the first record whose annotation CSV exists\n",
    "one = next((r for r in records if r['csv'] is not None and r['csv'].exists()), None)\n",
    "print(\"Example file:\", one)\n",
    "if one:\n",
    "    meta, events = read_annotation_csv(one['csv'])\n",
    "else:\n",
    "    meta, events = {}, []\n",
    "print(\"Metadata:\")\n",
    "pprint(meta)\n",
    "\n",
    "# Per-label duration stats (seconds) for the example file\n",
    "by_label = defaultdict(list)\n",
    "for ev in events:\n",
    "    by_label[ev['label']].append(ev['stop'] - ev['start'])\n",
    "\n",
    "def _duration_stats(vals):\n",
    "    \"\"\"Count, mean, median, 95th percentile and max of ``vals`` (all 0 when empty).\"\"\"\n",
    "    if not vals:\n",
    "        return {'n': 0, 'mean_s': 0.0, 'p50_s': 0.0, 'p95_s': 0.0, 'max_s': 0.0}\n",
    "    return {\n",
    "        'n': len(vals),\n",
    "        'mean_s': float(np.mean(vals)),\n",
    "        'p50_s': float(np.median(vals)),\n",
    "        'p95_s': float(np.percentile(vals, 95)),\n",
    "        'max_s': float(np.max(vals)),\n",
    "    }\n",
    "\n",
    "label_stats = {lab: _duration_stats(vals) for lab, vals in by_label.items()}\n",
    "print(\"Per-label duration stats (example file):\")\n",
    "pprint(label_stats)\n",
    "\n",
    "# Overlap rate estimation per channel (example file)\n",
    "from itertools import groupby\n",
    "\n",
    "def overlap_rate(evts):\n",
    "    \"\"\"Fraction of ``evts`` that start before an earlier event has stopped.\n",
    "\n",
    "    Events are scanned in start-time order while tracking the latest stop\n",
    "    seen so far; an event whose start is strictly before that stop counts\n",
    "    as overlapping. Returns 0.0 for an empty list.\n",
    "    \"\"\"\n",
    "    if not evts:\n",
    "        return 0.0\n",
    "    latest_stop = -1.0\n",
    "    n_overlapping = 0\n",
    "    for ev in sorted(evts, key=lambda x: x['start']):\n",
    "        if ev['start'] < latest_stop:\n",
    "            n_overlapping += 1\n",
    "        latest_stop = max(latest_stop, ev['stop'])\n",
    "    return n_overlapping / len(evts)\n",
    "\n",
    "# Group events by channel (sorted by channel, then start) and score each group\n",
    "events_by_channel = sorted(events, key=lambda ev: (ev['channel'], ev['start']))\n",
    "ov_by_channel = {\n",
    "    ch: overlap_rate(list(grp))\n",
    "    for ch, grp in groupby(events_by_channel, key=lambda ev: ev['channel'])\n",
    "}\n",
    "\n",
    "print(\"Overlap rate by channel (example file, first 10):\")\n",
    "print(dict(list(ov_by_channel.items())[:10]))\n",
    "\n",
    "# Recommend a window length from the example file's p95 event duration:\n",
    "# 2 s when p95 exceeds 1 s, otherwise 1 s.\n",
    "all_durations = [ev['stop'] - ev['start'] for ev in events]\n",
    "if not all_durations:\n",
    "    print(\"No events to analyze for window length.\")\n",
    "else:\n",
    "    p95 = float(np.percentile(all_durations, 95))\n",
    "    recommended = 2.0 if p95 > 1.0 else 1.0\n",
    "    print(f\"Recommended window length ≈ {recommended}s (p95={p95:.2f}s)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09c0ce05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Map TUAR short labels to canonical classes for our project\n",
    "# Use centralized merge_map so notebook visuals match preprocessing\n",
    "from src.merge_map import remap_label\n",
    "# label_map gives display names for the 5 canonical training classes (keys are\n",
    "# remap_label() outputs). Any remapped key not listed here is plotted under its\n",
    "# own name, so the chart can contain more than these 5 classes.\n",
    "label_map = {\n",
    "    'musc': 'Muscle',\n",
    "    'eyem': 'Eye movement',\n",
    "    'elec': 'Electrode',\n",
    "    'chew': 'Chewing',\n",
    "    'shiv': 'Shiver',\n",
    "}\n",
    "\n",
    "remapped = Counter()\n",
    "# apply remap_label first, then map to display names\n",
    "for k, v in label_counts.items():\n",
    "    mk = remap_label(k)\n",
    "    remapped[label_map.get(mk, mk)] += v\n",
    "\n",
    "if remapped:\n",
    "    # Only create a figure when there is data, so no empty axes are shown.\n",
    "    labs, cnts = zip(*remapped.most_common())\n",
    "    fig, ax = plt.subplots(figsize=(8,4))\n",
    "    ax.bar(labs, cnts)\n",
    "    plt.xticks(rotation=45, ha='right')\n",
    "    ax.set_title('Remapped label distribution (sample)')\n",
    "    ax.set_ylabel('event count')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(figs_dir / 'remapped_labels_sample.png')\n",
    "    plt.show()\n",
    "else:\n",
    "    print('No labels found in sample')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ce9b905",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Duration distribution and suggested 8-channel set\n",
    "if durations:\n",
    "    plt.figure(figsize=(8,4))\n",
    "    plt.hist(durations, bins=50)\n",
    "    plt.title('Event duration distribution (sample)')\n",
    "    plt.xlabel('seconds')\n",
    "    plt.ylabel('count')\n",
    "    plt.tight_layout()\n",
    "    plt.savefig(figs_dir / 'duration_distribution_sample.png')\n",
    "    plt.show()\n",
    "\n",
    "# Suggest canonical 8-channel set present in TUAR labels.\n",
    "# Channel names are matched case-insensitively because the CSVs may spell them\n",
    "# differently (e.g. 'FP1-F7' vs 'Fp1-F7'); the spelling found in the data is reported.\n",
    "candidate_channels = ['Fp1-F7','Fp2-F8','C3-P3','C4-P4','O1-O2','T3-T5','T4-T6','Fz-Cz']\n",
    "seen_by_upper = {name.upper(): name for name in channel_counts}\n",
    "available = []\n",
    "for cand in candidate_channels:\n",
    "    if cand in channel_counts:\n",
    "        available.append(cand)  # exact spelling present\n",
    "    elif cand.upper() in seen_by_upper:\n",
    "        available.append(seen_by_upper[cand.upper()])\n",
    "print('Suggested 8-channel montage subset (present in sample):', available)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "90c7ff4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Feature engineering for visualization (per file sample): label-wise median duration and channel diversity\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "from sklearn.manifold import TSNE\n",
    "import umap\n",
    "\n",
    "# Build per-file summary vectors: for each of the sample's top-6 labels,\n",
    "# [event count, median duration (s), number of distinct channels].\n",
    "# The reference labels do not depend on the file, so compute them once.\n",
    "labels_ref = [k for k, _ in label_counts.most_common(6)]  # top-6 labels in sample\n",
    "summ_rows = []\n",
    "for r in tqdm.tqdm(sample_records, desc=\"Building summary vectors\"):\n",
    "    meta, events = read_annotation_csv(r['csv'])\n",
    "    if not events:\n",
    "        continue\n",
    "    lc = Counter(e['label'] for e in events)\n",
    "    dur_by_label = defaultdict(list)\n",
    "    ch_by_label = defaultdict(set)\n",
    "    for e in events:\n",
    "        dur_by_label[e['label']].append(e['stop'] - e['start'])\n",
    "        ch_by_label[e['label']].add(e['channel'])\n",
    "    feat = []\n",
    "    for lab in labels_ref:\n",
    "        durs = dur_by_label.get(lab, [])\n",
    "        feat.extend([\n",
    "            lc.get(lab, 0),\n",
    "            float(np.median(durs)) if durs else 0.0,\n",
    "            len(ch_by_label.get(lab, set())),\n",
    "        ])\n",
    "    summ_rows.append({\n",
    "        'subject': r['subject'],\n",
    "        'file': str(r['edf'].name),\n",
    "        'features': np.array(feat, dtype=float),\n",
    "        # events is non-empty here, so lc has at least one label\n",
    "        'dominant_label': max(lc.items(), key=lambda kv: kv[1])[0],\n",
    "    })\n",
    "\n",
    "def _scatter_embedding(emb, labels, title, fname):\n",
    "    \"\"\"Scatter a 2-D embedding coloured by label, save to ``figs_dir / fname`` and show.\n",
    "\n",
    "    Labels are iterated in sorted order so colours and legend order are stable\n",
    "    across runs (set iteration order of strings changes with hash randomization).\n",
    "    \"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(6,5))\n",
    "    for lab in sorted(set(labels)):\n",
    "        idx = [i for i, yy in enumerate(labels) if yy == lab]\n",
    "        ax.scatter(emb[idx,0], emb[idx,1], s=14, label=lab, alpha=0.7)\n",
    "    ax.set_title(title)\n",
    "    ax.legend(bbox_to_anchor=(1.05,1), loc='upper left')\n",
    "    fig.tight_layout()\n",
    "    fig.savefig(figs_dir / fname)\n",
    "    plt.show()\n",
    "\n",
    "# Heuristic minimum number of files to embed. scikit-learn's t-SNE raises unless\n",
    "# perplexity < n_samples, so perplexity (and UMAP's n_neighbors) are capped below it.\n",
    "MIN_EMBED_ROWS = 5\n",
    "if len(summ_rows) >= MIN_EMBED_ROWS:\n",
    "    X = np.stack([row['features'] for row in summ_rows])\n",
    "    y = [row['dominant_label'] for row in summ_rows]\n",
    "    Xn = StandardScaler().fit_transform(X)\n",
    "\n",
    "    # UMAP\n",
    "    reducer = umap.UMAP(n_neighbors=min(15, len(Xn) - 1), min_dist=0.1, random_state=42)\n",
    "    emb_umap = reducer.fit_transform(Xn)\n",
    "    _scatter_embedding(emb_umap, y, 'UMAP of per-file artifact summaries', 'umap_perfile_summaries.png')\n",
    "\n",
    "    # t-SNE\n",
    "    tsne = TSNE(n_components=2, learning_rate='auto', init='random', perplexity=min(30, len(Xn) - 1), random_state=42)\n",
    "    emb_tsne = tsne.fit_transform(Xn)\n",
    "    _scatter_embedding(emb_tsne, y, 't-SNE of per-file artifact summaries', 'tsne_perfile_summaries.png')\n",
    "elif summ_rows:\n",
    "    print(f'Only {len(summ_rows)} summary rows; need at least {MIN_EMBED_ROWS} to embed')\n",
    "else:\n",
    "    print('No summary rows to embed')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10ef00c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Subject/session analysis and split suggestion.\n",
    "# For every subject with annotations, keep its records and count how many of\n",
    "# its files have each label as the file's most frequent (dominant) label.\n",
    "subj_files = defaultdict(list)\n",
    "subj_label_counts = defaultdict(Counter)\n",
    "for r in tqdm.tqdm(records, desc=\"Processing all records\"):\n",
    "    meta, events = read_annotation_csv(r['csv'])\n",
    "    if events:\n",
    "        lc = Counter(e['label'] for e in events)\n",
    "        dom = max(lc.items(), key=lambda kv: kv[1])[0]\n",
    "        subject_id = r['subject']\n",
    "        subj_files[subject_id].append(r)\n",
    "        subj_label_counts[subject_id][dom] += 1\n",
    "\n",
    "n_subjects = len(subj_files)\n",
    "print('Subjects with annotations:', n_subjects)\n",
    "\n",
    "# Heuristic: subject-wise split 70/15/15 stratified by each subject's dominant label.\n",
    "# Subjects are ordered by number of annotated files (descending); ties are broken by\n",
    "# subject id so the split does not depend on the order records were enumerated in.\n",
    "subjects_sorted = sorted(subj_files.keys(), key=lambda s: (-sum(subj_label_counts[s].values()), s))\n",
    "\n",
    "def stratified_subject_split(subjects, frac=(0.7, 0.15, 0.15), label_counts_by_subject=None):\n",
    "    \"\"\"Split ``subjects`` into train/val/test, stratified by dominant label.\n",
    "\n",
    "    Each subject is bucketed by its most frequent entry in\n",
    "    ``label_counts_by_subject[s]`` ('none' if it has no counts). Within a bucket\n",
    "    of size k (kept in the order of ``subjects``), the first round(frac[0]*k)\n",
    "    subjects go to train, the next round(frac[1]*k) to val and the rest to test,\n",
    "    so every subject is assigned exactly once. frac[2] is implied by the remainder.\n",
    "\n",
    "    Args:\n",
    "        subjects: ordered iterable of subject ids.\n",
    "        frac: (train, val, test) fractions.\n",
    "        label_counts_by_subject: mapping subject -> Counter of dominant labels;\n",
    "            defaults to the notebook-level ``subj_label_counts``.\n",
    "\n",
    "    Returns:\n",
    "        (train, val, test) lists of subject ids.\n",
    "    \"\"\"\n",
    "    if label_counts_by_subject is None:\n",
    "        label_counts_by_subject = subj_label_counts\n",
    "    buckets = defaultdict(list)\n",
    "    for s in subjects:\n",
    "        counts = label_counts_by_subject.get(s)\n",
    "        dom = max(counts.items(), key=lambda kv: kv[1])[0] if counts else 'none'\n",
    "        buckets[dom].append(s)\n",
    "    train, val, test = [], [], []\n",
    "    for subs in buckets.values():\n",
    "        k = len(subs)\n",
    "        t = int(round(frac[0] * k))\n",
    "        v = int(round(frac[1] * k))\n",
    "        train += subs[:t]\n",
    "        val += subs[t:t + v]\n",
    "        test += subs[t + v:]  # remainder, so no subject is dropped\n",
    "    return train, val, test\n",
    "\n",
    "train_subj, val_subj, test_subj = stratified_subject_split(subjects_sorted)\n",
    "print('Train subjects:', len(train_subj), 'Val subjects:', len(val_subj), 'Test subjects:', len(test_subj))\n",
    "\n",
    "# Show label distributions per split (summed per-file dominant-label counts)\n",
    "split_members = {'train': train_subj, 'val': val_subj, 'test': test_subj}\n",
    "split_label_counts = {'train': Counter(), 'val': Counter(), 'test': Counter()}\n",
    "for split, subs in split_members.items():\n",
    "    for s in subs:\n",
    "        split_label_counts[split] += subj_label_counts[s]\n",
    "\n",
    "print('Split label distributions (by dominant label):')\n",
    "for k, v in split_label_counts.items():\n",
    "    print(k, dict(v.most_common()))\n",
    "\n",
    "# Save suggested CSV files in data/processed for downstream use (optional)\n",
    "import pandas as pd\n",
    "\n",
    "splits_rows = [\n",
    "    {'subject_id': s, 'split': name}\n",
    "    for name, members in split_members.items()\n",
    "    for s in members\n",
    "]\n",
    "\n",
    "splits_df = pd.DataFrame(splits_rows).drop_duplicates()\n",
    "splits_out = Path('data/processed/suggested_splits_subjectwise.csv')\n",
    "splits_out.parent.mkdir(parents=True, exist_ok=True)\n",
    "splits_df.to_csv(splits_out, index=False)\n",
    "print('Wrote suggested splits to', splits_out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29a173c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualizations for simple stratified split data spread\n",
    "\n",
    "# Compute per-split statistics for the simple split (re-reads each record's CSV)\n",
    "split_stats_simple = {'train': {}, 'val': {}, 'test': {}}\n",
    "\n",
    "for split_name, subj_list in [('train', train_subj), ('val', val_subj), ('test', test_subj)]:\n",
    "    subj_set = set(subj_list)  # O(1) membership test instead of scanning the list per record\n",
    "    split_records = [r for r in records if r['subject'] in subj_set]\n",
    "    label_counts_split = Counter()\n",
    "    channel_counts_split = Counter()\n",
    "    durations_split = []\n",
    "\n",
    "    for r in split_records:\n",
    "        meta, events = read_annotation_csv(r['csv'])\n",
    "        for e in events:\n",
    "            label_counts_split[e['label']] += 1\n",
    "            channel_counts_split[e['channel']] += 1\n",
    "            durations_split.append(e['stop'] - e['start'])\n",
    "\n",
    "    split_stats_simple[split_name] = {\n",
    "        'label_counts': label_counts_split,\n",
    "        'channel_counts': channel_counts_split,\n",
    "        'durations': durations_split,\n",
    "        'files': len(split_records),\n",
    "        'subjects': len(subj_list),\n",
    "    }\n",
    "\n",
    "# 1) Split sizes: subjects and files\n",
    "splits = ['train', 'val', 'test']\n",
    "split_colors = ['blue', 'orange', 'green']\n",
    "subjects = [split_stats_simple[s]['subjects'] for s in splits]\n",
    "files = [split_stats_simple[s]['files'] for s in splits]\n",
    "\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "ax1.bar(splits, subjects, color=split_colors)\n",
    "ax1.set_title('Number of Subjects per Split (Simple)')\n",
    "ax1.set_ylabel('Count')\n",
    "\n",
    "ax2.bar(splits, files, color=split_colors)\n",
    "ax2.set_title('Number of Files per Split (Simple)')\n",
    "ax2.set_ylabel('Count')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'split_sizes_simple.png')\n",
    "plt.show()\n",
    "\n",
    "# 2) Label distribution per split (top 10)\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
    "for i, split in enumerate(splits):\n",
    "    lc = split_stats_simple[split]['label_counts']\n",
    "    if lc:\n",
    "        labels, counts = zip(*lc.most_common(10))\n",
    "        axes[i].bar(labels, counts)\n",
    "        axes[i].set_title(f'Top 10 Labels in {split.capitalize()} (Simple)')\n",
    "        axes[i].tick_params(axis='x', rotation=45)\n",
    "    else:\n",
    "        axes[i].text(0.5, 0.5, 'No data', ha='center', va='center', transform=axes[i].transAxes)\n",
    "        axes[i].set_title(f'{split.capitalize()} Labels (Simple)')\n",
    "axes[0].set_ylabel('Event count')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'label_dist_per_split_simple.png')\n",
    "plt.show()\n",
    "\n",
    "# 3) Pie chart for total events per split.\n",
    "# Skipped when every split has zero events: there is nothing to normalize into wedges.\n",
    "total_events = [sum(split_stats_simple[s]['label_counts'].values()) for s in splits]\n",
    "if sum(total_events) > 0:\n",
    "    fig, ax = plt.subplots(figsize=(8, 8))\n",
    "    ax.pie(total_events, labels=[f'{s.capitalize()}\\n{total_events[i]}' for i, s in enumerate(splits)], autopct='%1.1f%%', startangle=90)\n",
    "    ax.set_title('Total Events Distribution Across Splits (Simple)')\n",
    "    plt.savefig(figs_dir / 'events_pie_simple.png')\n",
    "    plt.show()\n",
    "else:\n",
    "    print('No events in any split; skipping pie chart')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11e7c8b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Window-length sensitivity by label and overlap summary\n",
    "# Compute p50/p95 duration per label over the sampled files\n",
    "label_durations = defaultdict(list)\n",
    "for r in tqdm.tqdm(sample_records, desc=\"Computing label durations\"):\n",
    "    meta, events = read_annotation_csv(r['csv'])\n",
    "    for e in events:\n",
    "        label_durations[e['label']].append(e['stop'] - e['start'])\n",
    "\n",
    "wd_table = []\n",
    "for lab, vals in label_durations.items():\n",
    "    arr = np.array(vals)\n",
    "    wd_table.append({\n",
    "        'label': lab,\n",
    "        'p50': float(np.percentile(arr, 50)),\n",
    "        'p95': float(np.percentile(arr, 95)),\n",
    "        'count': int(len(arr))\n",
    "    })\n",
    "# Pass the column names explicitly: with no events, wd_table is empty and a\n",
    "# column-less DataFrame would make sort_values('p95') raise KeyError.\n",
    "wd_df = pd.DataFrame(wd_table, columns=['label', 'p50', 'p95', 'count']).sort_values('p95', ascending=False)\n",
    "print('Top-10 labels by p95 duration:')\n",
    "print(wd_df.head(10))\n",
    "\n",
    "# Overlap ratios (estimated): mean across channels in the example file\n",
    "if ov_by_channel:\n",
    "    mean_overlap = float(np.mean(list(ov_by_channel.values())))\n",
    "    print(f'Mean overlap rate across channels (example): {mean_overlap:.3f}')\n",
    "else:\n",
    "    print('No overlap info available from example')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4bd39f75",
   "metadata": {},
   "source": [
    "# Multi-label, subject-wise stratified split (70/15/15)\n",
    "We build a subject×label count matrix from TUAR annotations, then assign subjects to train/val/test to match target label proportions and avoid leakage. We:\n",
    "- Prioritize rare-label subjects first (weighted by 1/label-frequency).\n",
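    "- Greedily place each subject into the split that minimizes the weighted deviation from the desired per-split label totals, plus a small penalty that keeps per-split subject counts near 70/15/15.\n",
    "- Check val/test coverage: for labels below a minimum count (30 in val, 5 in test, or the label's total if smaller), move the best-matching subject from train into that split, for at most 5 rounds. Then report per-split label totals.\n",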
    "\n",
    "Outputs: `data/processed/suggested_splits_subjectwise_multilabel.csv`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e26a491",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build subject × label matrix and perform multi-label stratified split\n",
    "from collections import defaultdict, Counter\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from pathlib import Path\n",
    "\n",
    "# Collect counts per subject per label (from all records)\n",
    "subj_label_full = defaultdict(Counter)\n",
    "all_labels = set()\n",
    "subj_total = Counter()  # number of annotated files per subject\n",
    "\n",
    "for r in tqdm.tqdm(records, desc=\"Collecting subject-label counts\"):\n",
    "    meta, events = read_annotation_csv(r['csv'])\n",
    "    if not events:\n",
    "        continue\n",
    "    subject_id = r['subject']\n",
    "    file_label_counts = Counter(e['label'] for e in events)\n",
    "    subj_total[subject_id] += 1\n",
    "    subj_label_full[subject_id].update(file_label_counts)\n",
    "    all_labels.update(file_label_counts)\n",
    "\n",
    "subjects_all = sorted(subj_label_full)\n",
    "labels_all = sorted(all_labels)\n",
    "print(f\"Subjects considered: {len(subjects_all)} | Labels: {len(labels_all)}\")\n",
    "\n",
    "# Subject x label count matrix (rows follow subjects_all, columns follow labels_all)\n",
    "S, L = len(subjects_all), len(labels_all)\n",
    "mat = np.array(\n",
    "    [[subj_label_full[s][lab] for lab in labels_all] for s in subjects_all],\n",
    "    dtype=float,\n",
    ").reshape(S, L)\n",
    "\n",
    "label_totals = mat.sum(axis=0)  # total counts per label\n",
    "print(\"Nonzero labels:\", {labels_all[j]: int(label_totals[j]) for j in range(L) if label_totals[j] > 0})\n",
    "\n",
    "# Target proportions: per-label totals each split should receive\n",
    "frac = (0.70, 0.15, 0.15)\n",
    "train_target, val_target, test_target = (label_totals * share for share in frac)\n",
    "\n",
    "# Weights to prioritize rare labels: 1/count, rescaled so the largest weight is 1\n",
    "with np.errstate(divide='ignore', invalid='ignore'):\n",
    "    inv_freq = 1.0 / np.maximum(label_totals, 1.0)\n",
    "    peak_weight = inv_freq.max()\n",
    "    inv_freq /= peak_weight if peak_weight > 0 else 1.0\n",
    "\n",
    "# Visit subjects in order of their contribution to rare labels\n",
    "# (score = sum(counts * inv_freq)), highest first; the sort is stable for ties.\n",
    "subject_scores = sorted(\n",
    "    ((i, float((mat[i] * inv_freq).sum())) for i in range(S)),\n",
    "    key=lambda t: t[1],\n",
    "    reverse=True,\n",
    ")\n",
    "\n",
    "assign = np.full(S, -1, dtype=int)  # -1 unassigned, 0 train, 1 val, 2 test\n",
    "current = {split_id: np.zeros(L, dtype=float) for split_id in (0, 1, 2)}  # running per-label totals\n",
    "\n",
    "# Greedy assignment minimizing weighted deviation from targets\n",
    "def deviation(after, target):\n",
    "    \"\"\"Weighted L1 distance between per-label totals ``after`` and ``target``.\n",
    "\n",
    "    Each label's absolute error is weighted by ``1 + 2 * inv_freq`` so that\n",
    "    rare labels weigh more in the objective.\n",
    "    \"\"\"\n",
    "    abs_err = np.abs(after - target)\n",
    "    weights = 1.0 + 2.0 * inv_freq  # emphasize rare labels\n",
    "    return float((abs_err * weights).sum())\n",
    "\n",
    "split_targets = (train_target, val_target, test_target)\n",
    "for i, _score in subject_scores:\n",
    "    contrib = mat[i]\n",
    "    # Light subject-count balancing: how far each split's subject count would be\n",
    "    # from its frac share of the assigned subjects if this subject joined it.\n",
    "    n_assigned = (assign != -1).sum()\n",
    "    subj_target = np.array(frac) * max(n_assigned + 1, 1)\n",
    "    subj_dev = {k: abs(((assign == k).sum() + 1) - subj_target[k]) for k in (0, 1, 2)}\n",
    "    # Objective for each candidate split = total label deviation with the subject\n",
    "    # added to that split + small subject-balance penalty; the first minimum wins.\n",
    "    best_split, best_obj = None, float('inf')\n",
    "    for sidx in (0, 1, 2):\n",
    "        obj = sum(\n",
    "            deviation(current[k] + contrib if k == sidx else current[k], split_targets[k])\n",
    "            for k in (0, 1, 2)\n",
    "        ) + 0.05 * subj_dev[sidx]\n",
    "        if obj < best_obj:\n",
    "            best_obj, best_split = obj, sidx\n",
    "    assign[i] = best_split\n",
    "    current[best_split] += contrib\n",
    "\n",
    "train_idx, val_idx, test_idx = (\n",
    "    [subjects_all[i] for i in range(S) if assign[i] == k] for k in (0, 1, 2)\n",
    ")\n",
    "print(\"Initial subject counts:\", len(train_idx), len(val_idx), len(test_idx))\n",
    "\n",
    "# Coverage checks: ensure labels appear in each split when feasible\n",
    "MIN_VAL = 30   # minimum total label-counts in val when possible\n",
    "MIN_TEST = 5   # minimum total label-counts in test when possible\n",
    "\n",
    "# Row of each subject in `mat`; replaces O(S) list.index() lookups with O(1) dict lookups.\n",
    "subject_row = {s: i for i, s in enumerate(subjects_all)}\n",
    "\n",
    "def label_counts_for(split_subjects):\n",
    "    \"\"\"Per-label totals (length L, ordered like ``labels_all``) summed over ``split_subjects``.\"\"\"\n",
    "    if not split_subjects:\n",
    "        return np.zeros(L)\n",
    "    return mat[[subject_row[s] for s in split_subjects]].sum(axis=0)\n",
    "\n",
    "val_counts = label_counts_for(val_idx)\n",
    "test_counts = label_counts_for(test_idx)\n",
    "\n",
    "def try_fix_coverage(missing_mask, from_split, to_split):\n",
    "    \"\"\"Move one subject from ``from_split`` to ``to_split`` to improve missing labels.\n",
    "\n",
    "    Split names are 'train', 'val' or 'test' (any other name is treated as train).\n",
    "    The subject in the source split with the largest count on the labels flagged\n",
    "    in ``missing_mask`` (1.0 = under-covered) is moved by mutating the\n",
    "    notebook-level ``train_idx`` / ``val_idx`` / ``test_idx`` lists in place.\n",
    "\n",
    "    Returns:\n",
    "        True if a subject was moved; False if the source split is empty or no\n",
    "        subject in it has any count on the flagged labels.\n",
    "    \"\"\"\n",
    "    by_name = {'val': val_idx, 'test': test_idx}\n",
    "    src = by_name.get(from_split, train_idx)\n",
    "    dst = by_name.get(to_split, train_idx)\n",
    "    if not src:\n",
    "        return False\n",
    "    best_gain, best_s = 0.0, None\n",
    "    for s in src:\n",
    "        gain = float((mat[subject_row[s]] * missing_mask).sum())\n",
    "        if gain > best_gain:\n",
    "            best_gain, best_s = gain, s\n",
    "    if best_s is None:  # nobody contributes to the missing labels\n",
    "        return False\n",
    "    src.remove(best_s)\n",
    "    dst.append(best_s)\n",
    "    return True\n",
    "\n",
    "# Identify labels under-covered in val/test (1.0 = below threshold).\n",
    "# NOTE(review): a label with fewer than MIN_VAL total counts must appear entirely in\n",
    "# val, while test also needs min(MIN_TEST, total) of it -- unsatisfiable for both at\n",
    "# once. The repair loop below is capped at 5 rounds, so it still terminates.\n",
    "val_missing = (val_counts < np.minimum(MIN_VAL, label_totals)).astype(float)\n",
    "test_missing = (test_counts < np.minimum(MIN_TEST, label_totals)).astype(float)\n",
    "\n",
    "# Attempt limited fixes by moving from train to val/test\n",
    "for _ in range(5):\n",
    "    changed = False\n",
    "    if val_missing.any():\n",
    "        if try_fix_coverage(val_missing, 'train', 'val'):\n",
    "            changed = True\n",
    "            val_counts = label_counts_for(val_idx)\n",
    "            val_missing = (val_counts < np.minimum(MIN_VAL, label_totals)).astype(float)\n",
    "    if test_missing.any():\n",
    "        if try_fix_coverage(test_missing, 'train', 'test'):\n",
    "            changed = True\n",
    "            test_counts = label_counts_for(test_idx)\n",
    "            test_missing = (test_counts < np.minimum(MIN_TEST, label_totals)).astype(float)\n",
    "    if not changed:\n",
    "        break\n",
    "\n",
    "# Report the 10 largest label totals in each split\n",
    "print(\"Per-split label totals (top 10 labels):\")\n",
    "for name, subs in (('train', train_idx), ('val', val_idx), ('test', test_idx)):\n",
    "    counts = label_counts_for(subs)\n",
    "    pairs = sorted(((labels_all[j], int(counts[j])) for j in range(L)), key=lambda x: x[1], reverse=True)\n",
    "    print(name, dict(pairs[:10]))\n",
    "\n",
    "# Save CSV (one row per subject: train, then val, then test)\n",
    "rows = [\n",
    "    {'subject_id': s, 'split': split}\n",
    "    for split, members in (('train', train_idx), ('val', val_idx), ('test', test_idx))\n",
    "    for s in members\n",
    "]\n",
    "\n",
    "out_path = Path('data/processed/suggested_splits_subjectwise_multilabel.csv')\n",
    "out_path.parent.mkdir(parents=True, exist_ok=True)\n",
    "pd.DataFrame(rows).drop_duplicates().to_csv(out_path, index=False)\n",
    "print('Wrote multi-label stratified splits to', out_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f51cbfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualizations for multi-label stratified split data spread\n",
    "\n",
    "# Compute per-split statistics\n",
    "split_stats = {'train': {}, 'val': {}, 'test': {}}\n",
    "\n",
    "for split_name, subj_list in [('train', train_idx), ('val', val_idx), ('test', test_idx)]:\n",
    "    split_records = [r for r in records if r['subject'] in subj_list]\n",
    "    label_counts_split = Counter()\n",
    "    channel_counts_split = Counter()\n",
    "    durations_split = []\n",
    "    files_count = len(split_records)\n",
    "    subjects_count = len(set(r['subject'] for r in split_records))\n",
    "    \n",
    "    for r in split_records:\n",
    "        meta, events = read_annotation_csv(r['csv'])\n",
    "        for e in events:\n",
    "            label_counts_split[e['label']] += 1\n",
    "            channel_counts_split[e['channel']] += 1\n",
    "            durations_split.append(e['stop'] - e['start'])\n",
    "    \n",
    "    split_stats[split_name] = {\n",
    "        'label_counts': label_counts_split,\n",
    "        'channel_counts': channel_counts_split,\n",
    "        'durations': durations_split,\n",
    "        'files': files_count,\n",
    "        'subjects': subjects_count\n",
    "    }\n",
    "\n",
    "# 1) Split sizes: subjects and files\n",
    "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))\n",
    "splits = ['train', 'val', 'test']\n",
    "subjects = [split_stats[s]['subjects'] for s in splits]\n",
    "files = [split_stats[s]['files'] for s in splits]\n",
    "\n",
    "ax1.bar(splits, subjects, color=['blue', 'orange', 'green'])\n",
    "ax1.set_title('Number of Subjects per Split')\n",
    "ax1.set_ylabel('Count')\n",
    "\n",
    "ax2.bar(splits, files, color=['blue', 'orange', 'green'])\n",
    "ax2.set_title('Number of Files per Split')\n",
    "ax2.set_ylabel('Count')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'split_sizes_multilabel.png')\n",
    "plt.show()\n",
    "\n",
    "# 2) Label distribution per split (top 10)\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
    "for i, split in enumerate(splits):\n",
    "    lc = split_stats[split]['label_counts']\n",
    "    if lc:\n",
    "        labels, counts = zip(*lc.most_common(10))\n",
    "        axes[i].bar(labels, counts)\n",
    "        axes[i].set_title(f'Top 10 Labels in {split.capitalize()}')\n",
    "        axes[i].tick_params(axis='x', rotation=45)\n",
    "    else:\n",
    "        axes[i].text(0.5, 0.5, 'No data', ha='center', va='center', transform=axes[i].transAxes)\n",
    "        axes[i].set_title(f'{split.capitalize()} Labels')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'label_dist_per_split_multilabel.png')\n",
    "plt.show()\n",
    "\n",
    "# 3) Channel distribution per split (top 10)\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
    "for i, split in enumerate(splits):\n",
    "    cc = split_stats[split]['channel_counts']\n",
    "    if cc:\n",
    "        channels, counts = zip(*cc.most_common(10))\n",
    "        axes[i].bar(channels, counts)\n",
    "        axes[i].set_title(f'Top 10 Channels in {split.capitalize()}')\n",
    "        axes[i].tick_params(axis='x', rotation=45)\n",
    "    else:\n",
    "        axes[i].text(0.5, 0.5, 'No data', ha='center', va='center', transform=axes[i].transAxes)\n",
    "        axes[i].set_title(f'{split.capitalize()} Channels')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'channel_dist_per_split_multilabel.png')\n",
    "plt.show()\n",
    "\n",
    "# 4) Duration distributions per split\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5), sharey=True)\n",
    "for i, split in enumerate(splits):\n",
    "    durs = split_stats[split]['durations']\n",
    "    if durs:\n",
    "        axes[i].hist(durs, bins=50, alpha=0.7)\n",
    "        axes[i].set_title(f'Duration Distribution in {split.capitalize()}')\n",
    "        axes[i].set_xlabel('Duration (s)')\n",
    "        axes[i].set_ylabel('Count')\n",
    "    else:\n",
    "        axes[i].text(0.5, 0.5, 'No data', ha='center', va='center', transform=axes[i].transAxes)\n",
    "        axes[i].set_title(f'{split.capitalize()} Durations')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.savefig(figs_dir / 'duration_dist_per_split_multilabel.png')\n",
    "plt.show()\n",
    "\n",
    "# 5) Pie chart for total events per split\n",
    "total_events = [sum(split_stats[s]['label_counts'].values()) for s in splits]\n",
    "fig, ax = plt.subplots(figsize=(8, 8))\n",
    "ax.pie(total_events, labels=[f'{s.capitalize()}\\n{total_events[i]}' for i, s in enumerate(splits)], autopct='%1.1f%%', startangle=90)\n",
    "ax.set_title('Total Events Distribution Across Splits')\n",
    "plt.savefig(figs_dir / 'events_pie_multilabel.png')\n",
    "plt.show()\n",
    "\n",
    "# 6) Boxplot for durations per split\n",
    "dur_lists = [split_stats[s]['durations'] for s in splits if split_stats[s]['durations']]\n",
    "if dur_lists:\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    plt.boxplot(dur_lists, labels=[s for s in splits if split_stats[s]['durations']])\n",
    "    plt.title('Duration Boxplots per Split')\n",
    "    plt.ylabel('Duration (s)')\n",
    "    plt.savefig(figs_dir / 'duration_boxplot_multilabel.png')\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
