{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7b6d4c-2bdc-4e8b-8849-051176df9c56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import os, sys, re\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from MAE_model_downstream import PedSleepMAE\n",
    "from utils.misc import setup_seed\n",
    "from dataloader import HDF5Dataset\n",
    "\n",
    "# Label key handed to HDF5Dataset further down in this cell.\n",
    "search_label = \"sleep_label\"\n",
    "\n",
    "# HDF5 samples are read from <parent of cwd>/PYTORCH/hdf5.\n",
    "parent_dir = os.path.dirname(os.getcwd())\n",
    "directory_path = os.path.join(parent_dir, \"PYTORCH\", \"hdf5\")\n",
    "\n",
    "# Model hyper-parameters; patch_size and mask_ratio also appear in the checkpoint path.\n",
    "patch_size = 8\n",
    "mask_ratio = 15\n",
    "emb_dim = 64\n",
    "num_head = 4\n",
    "num_layer = 3\n",
    "\n",
    "# Files are selected by the first two underscore-separated fields of their name.\n",
    "needed_patient_IDs = [\"xxxx\"]\n",
    "needed_study_IDs = [\"xxxx\"]\n",
    "\n",
    "seed = 42\n",
    "\n",
    "# Use the GPU when one is visible, otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "\n",
    "# Patches per signal; assumes each signal is 3840 samples long -- TODO confirm against the data.\n",
    "num_patches = int(3840 / patch_size)\n",
    "setup_seed(seed)\n",
    "\n",
    "def extract_sample_id(f):\n",
    "    \"\"\"Return the integer N from a name ending in \"_sample_N.hdf5\".\n",
    "\n",
    "    Names that do not match yield float(\"inf\"), which places them last when\n",
    "    this function is used as a sort key.\n",
    "    \"\"\"\n",
    "    match = re.search(r\"_sample_(\\d+)\\.hdf5$\", f)\n",
    "    if match is None:\n",
    "        return float(\"inf\")\n",
    "    return int(match.group(1))\n",
    "\n",
    "# Collect the .hdf5 files whose name starts with \"<patient>_<study>_\" for the requested IDs.\n",
    "# The filter works on the bare directory entry rather than splitting the joined path on \"/\",\n",
    "# so it does not depend on the platform's path separator.\n",
    "files = []\n",
    "for name in os.listdir(directory_path):\n",
    "    if not name.endswith(\".hdf5\"):\n",
    "        continue\n",
    "    parts = name.split(\"_\")\n",
    "    # Names without a second underscore-separated field cannot carry a study ID.\n",
    "    if len(parts) > 1 and parts[0] in needed_patient_IDs and parts[1] in needed_study_IDs:\n",
    "        files.append(os.path.join(directory_path, name))\n",
    "\n",
    "# Order by the numeric sample index embedded in the file name.\n",
    "sorted_files = sorted(files, key=extract_sample_id)\n",
    "\n",
    "if len(sorted_files) == 0:\n",
    "    print(\"Invalid ID\")\n",
    "    sys.exit(1)\n",
    "\n",
    "print(f\"Total sorted files: {len(sorted_files)}\")\n",
    "\n",
    "# Instantiate the model from the configured hyper-parameters and move it to the chosen device.\n",
    "model = PedSleepMAE(\n",
    "    batch_size=len(sorted_files),\n",
    "    patch_size=patch_size,\n",
    "    mask_ratio=mask_ratio,\n",
    "    emb_dim=emb_dim,\n",
    "    num_head=num_head,\n",
    "    num_layer=num_layer,\n",
    ").to(device)\n",
    "\n",
    "ckpt_file = f\"../savedmodels{mask_ratio}/signalmask{mask_ratio}_patch_size{patch_size}.pt\"\n",
    "# map_location lets a checkpoint that was saved on a GPU load when only the CPU is available.\n",
    "ckpt = torch.load(ckpt_file, map_location=device, weights_only=True)\n",
    "model.load_state_dict(ckpt[\"state_dict\"])\n",
    "# Inference only: put train-time-only layers (dropout, batch-norm), if any, into eval behaviour.\n",
    "model.eval()\n",
    "\n",
    "# Max-pools the last axis of a 3-D input down to length 1.\n",
    "pool = nn.AdaptiveMaxPool1d(1)\n",
    "\n",
    "# shuffle=False keeps the sorted file order in the extracted embeddings.\n",
    "dataset = HDF5Dataset(sorted_files, search_label)\n",
    "loader = DataLoader(dataset, batch_size=100, shuffle=False)\n",
    "\n",
    "save_dir = \"output_embeddings_sorted\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "embeddings, labels = [], []\n",
    "\n",
    "# Gradients are never needed for extraction, so disable them for the whole pass.\n",
    "with torch.no_grad():\n",
    "    for i, (signal, label, _) in enumerate(loader):\n",
    "        print(f\"Batch {i}\")\n",
    "        signal = signal.squeeze().float().to(device)\n",
    "        # squeeze() also removes the batch axis when a batch holds a single sample\n",
    "        # (e.g. the last batch); put it back so the 3-D unpacking below still holds.\n",
    "        if signal.dim() == 2:\n",
    "            signal = signal.unsqueeze(0)\n",
    "        # Likewise a single-sample label squeezes to a 0-d array, which extend() cannot iterate.\n",
    "        label = np.atleast_1d(label.squeeze().cpu().numpy())\n",
    "        n, _, _ = signal.shape\n",
    "\n",
    "        enc, _ = model.encoder(signal)\n",
    "        # Drop index 0 of the third axis (presumably a class token -- TODO confirm), then\n",
    "        # max-pool each patch over its embedding axis and flatten to one row per sample.\n",
    "        enc = enc[:, :, 1:, :].reshape(-1, num_patches, emb_dim)\n",
    "        pooled = pool(enc).reshape(n, -1)\n",
    "\n",
    "        embeddings.extend(pooled.cpu().numpy())\n",
    "        labels.extend(label)\n",
    "\n",
    "emb = np.vstack(embeddings)\n",
    "lab = np.array(labels)\n",
    "\n",
    "# Output names are built from the requested patient and study IDs.\n",
    "prefix = f\"{'_'.join(needed_patient_IDs)}_{'_'.join(needed_study_IDs)}\"\n",
    "np.save(os.path.join(save_dir, f\"{prefix}_embeddings.npy\"), emb)\n",
    "np.save(os.path.join(save_dir, f\"{prefix}_labels.npy\"), lab)\n",
    "\n",
    "print(f\"Saved {prefix}_embeddings.npy {emb.shape}\")\n",
    "print(f\"Saved {prefix}_labels.npy {lab.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8ae56529-b691-4edf-847f-3bd77cf51562",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import phate\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "# NOTE(review): these are bare file names in the working directory, while the extraction cell\n",
    "# writes \"<prefix>_embeddings.npy\" / \"<prefix>_labels.npy\" under output_embeddings_sorted/ --\n",
    "# confirm the files are copied or renamed before this cell runs.\n",
    "embeddings = np.load(\"embeddings.npy\")\n",
    "sleep_stages = np.load(\"labels.npy\")\n",
    "# Row index of each embedding, used for the progression colouring.\n",
    "time_indices = np.arange(len(embeddings))\n",
    "\n",
    "sleep_stage_names = {0: \"Wake\", 1: \"N1\", 2: \"N2\", 3: \"N3\", 4: \"REM\"}\n",
    "\n",
    "time_cmap = plt.get_cmap(\"turbo\")\n",
    "stage_palette = sns.color_palette(\"Set1\", 5)\n",
    "# One palette colour per sample, indexed directly by the integer label.\n",
    "stage_colors = []\n",
    "for stage in sleep_stages:\n",
    "    stage_colors.append(stage_palette[int(stage)])\n",
    "\n",
    "# 2-D projections of the same embeddings with three different methods.\n",
    "pca = PCA(n_components=2)\n",
    "pca_features = pca.fit_transform(embeddings)\n",
    "\n",
    "tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)\n",
    "tsne_features = tsne.fit_transform(embeddings)\n",
    "\n",
    "phate_operator = phate.PHATE(n_components=2, knn=5, t=12, n_pca=100, random_state=42)\n",
    "phate_features = phate_operator.fit_transform(embeddings)\n",
    "\n",
    "fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
    "\n",
    "# (title prefix, 2-D coordinates) for each column of the grid.\n",
    "projections = [\n",
    "    (\"PCA\", pca_features),\n",
    "    (\"t-SNE\", tsne_features),\n",
    "    (\"PHATE\", phate_features),\n",
    "]\n",
    "\n",
    "for col, (method, coords) in enumerate(projections):\n",
    "    xs, ys = coords[:, 0], coords[:, 1]\n",
    "\n",
    "    # Top row: colour by sample index, one colorbar per panel.\n",
    "    top_ax = axes[0, col]\n",
    "    by_time = top_ax.scatter(xs, ys, c=time_indices, cmap=time_cmap, alpha=0.8, s=10)\n",
    "    top_ax.set_title(f\"{method} (Progression Coloring)\")\n",
    "    fig.colorbar(by_time, ax=top_ax, label=\"Time Index\")\n",
    "\n",
    "    # Bottom row: colour by label.\n",
    "    bottom_ax = axes[1, col]\n",
    "    bottom_ax.scatter(xs, ys, c=stage_colors, alpha=0.8, s=10)\n",
    "    bottom_ax.set_title(f\"{method} (Categorical Labels)\")\n",
    "\n",
    "# The scatters are coloured per point, so the legend is built from proxy handles.\n",
    "legend_labels = []\n",
    "for stage_id in range(5):\n",
    "    legend_labels.append(\n",
    "        plt.Line2D([0], [0], marker='o', color='w', label=sleep_stage_names[stage_id],\n",
    "                   markersize=10, markerfacecolor=stage_palette[stage_id])\n",
    "    )\n",
    "fig.legend(handles=legend_labels, loc='upper right', title=\"Labels\", fontsize=12)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35381907-7139-4ca8-9ad3-8ed1f2aaac90",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import phate\n",
    "import umap\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "embeddings = np.load(\"embeddings.npy\")\n",
    "sleep_stages = np.load(\"labels.npy\")\n",
    "# Row index of each embedding, used for the progression colouring.\n",
    "time_indices = np.arange(len(embeddings))\n",
    "\n",
    "# Label codes in the order used for the palette and the legend: Wake, REM, N1, N2, N3.\n",
    "ordered_stages = [0, 4, 1, 2, 3]\n",
    "sleep_stage_names = {0: \"Wake\", 4: \"REM\", 1: \"N1\", 2: \"N2\", 3: \"N3\"}\n",
    "\n",
    "time_cmap = sns.color_palette(\"turbo\", as_cmap=True)\n",
    "stage_palette = sns.color_palette(\"coolwarm\", len(ordered_stages))\n",
    "# Label code -> colour, following ordered_stages rather than the numeric code.\n",
    "stage_color_map = dict(zip(ordered_stages, stage_palette))\n",
    "stage_colors = [stage_color_map[int(stage)] for stage in sleep_stages]\n",
    "\n",
    "# 2-D projections of the same embeddings with four different methods.\n",
    "pca = PCA(n_components=2).fit_transform(embeddings)\n",
    "\n",
    "tsne_model = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)\n",
    "tsne = tsne_model.fit_transform(embeddings)\n",
    "\n",
    "phate_op = phate.PHATE(n_components=2, knn=5, t=8, n_pca=100, random_state=42)\n",
    "phate_features = phate_op.fit_transform(embeddings)\n",
    "\n",
    "umap_model = umap.UMAP(n_components=2, n_neighbors=15, min_dist=0.1, random_state=42)\n",
    "umap_features = umap_model.fit_transform(embeddings)\n",
    "\n",
    "fig, axes = plt.subplots(2, 4, figsize=(22, 10))\n",
    "\n",
    "# (title prefix, 2-D coordinates) for each column of the grid.\n",
    "panels = [\n",
    "    (\"PCA\", pca),\n",
    "    (\"t-SNE\", tsne),\n",
    "    (\"PHATE\", phate_features),\n",
    "    (\"UMAP\", umap_features),\n",
    "]\n",
    "\n",
    "time_scatters = []\n",
    "for col, (method, coords) in enumerate(panels):\n",
    "    xs, ys = coords[:, 0], coords[:, 1]\n",
    "\n",
    "    # Top row: colour by sample index.\n",
    "    time_scatters.append(axes[0, col].scatter(xs, ys, c=time_indices, cmap=time_cmap, s=10, alpha=0.8))\n",
    "    axes[0, col].set_title(f\"{method} (Progression Coloring)\")\n",
    "\n",
    "    # Bottom row: colour by label.\n",
    "    axes[1, col].scatter(xs, ys, c=stage_colors, s=10, alpha=0.8)\n",
    "    axes[1, col].set_title(f\"{method} (Categorical Labels)\")\n",
    "\n",
    "# All top-row panels share the same colour values, so the PCA scatter drives one shared colorbar.\n",
    "sc0 = time_scatters[0]\n",
    "cbar_ax = fig.add_axes([0.92, 0.3, 0.015, 0.4])\n",
    "cb = fig.colorbar(sc0, cax=cbar_ax)\n",
    "cb.set_label('Time Index', fontsize=12)\n",
    "\n",
    "# Proxy handles in the same order as the palette.\n",
    "legend_labels = []\n",
    "for idx, i in enumerate(ordered_stages):\n",
    "    legend_labels.append(\n",
    "        plt.Line2D([0], [0], marker='o', color='w', label=sleep_stage_names[i],\n",
    "                   markersize=10, markerfacecolor=stage_palette[idx])\n",
    "    )\n",
    "fig.legend(handles=legend_labels, loc='lower right', title=\"Labels\", fontsize=12)\n",
    "\n",
    "# Leave the right-hand 10% of the figure free for the colorbar axes.\n",
    "plt.tight_layout(rect=[0, 0, 0.9, 1])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfb772f-658d-441b-a0ab-54fa5348e091",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sleep_stages = np.load(\"labels.npy\")\n",
    "# each index = 30 seconds → minutes\n",
    "time_indices = np.arange(len(sleep_stages)) * 0.5\n",
    "\n",
    "stage_labels = ['Wake', 'N1', 'N2', 'N3', 'REM']\n",
    "stage_ticks = [0, 1, 2, 3, 4]\n",
    "# Label codes in the order they appear on the y-axis, bottom to top.\n",
    "stage_display_order = [0, 4, 1, 2, 3]\n",
    "\n",
    "# Label code -> y position on the plot.\n",
    "display_mapping = {0: 0, 1: 2, 2: 3, 3: 4, 4: 1}\n",
    "stage_display_values = np.array([display_mapping[int(s)] for s in sleep_stages])\n",
    "\n",
    "# Tick positions and their names, both derived from the display order.\n",
    "tick_positions = [display_mapping[code] for code in stage_display_order]\n",
    "tick_names = [stage_labels[code] for code in stage_display_order]\n",
    "\n",
    "plt.figure(figsize=(15, 4))\n",
    "plt.step(time_indices, stage_display_values, where='mid', linewidth=1.8, color='darkslateblue')\n",
    "plt.yticks(tick_positions, tick_names)\n",
    "plt.xlabel(\"Time (minutes)\", fontsize=12)\n",
    "plt.ylabel(\"Stage\", fontsize=12)\n",
    "plt.title(\"Stages Over Time\", fontsize=14)\n",
    "plt.grid(axis='x', linestyle='--', alpha=0.4)\n",
    "plt.ylim(-0.5, 4.5)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "667fad1a-a286-41c9-8624-ac5e74b7289e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "sleep_stages = np.load(\"labels.npy\")\n",
    "time_indices = np.arange(len(sleep_stages)) * 0.5  # 30s intervals → minutes\n",
    "\n",
    "# Row names from bottom to top, and the label code -> row mapping that matches them.\n",
    "stage_names_ordered = ['Wake', 'REM', 'N1', 'N2', 'N3']\n",
    "stage_value_map = {0: 0, 4: 1, 1: 2, 2: 3, 3: 4}\n",
    "# Bar colour per label code.\n",
    "stage_colors = {\n",
    "    0: '#ff6f69',\n",
    "    1: '#ffeead',\n",
    "    2: '#96ceb4',\n",
    "    3: '#379683',\n",
    "    4: '#88d8b0',\n",
    "}\n",
    "\n",
    "# Resolve the row and colour of every epoch once.\n",
    "stage_codes = [int(s) for s in sleep_stages]\n",
    "mapped_y = [stage_value_map[code] for code in stage_codes]\n",
    "segment_colors = [stage_colors[code] for code in stage_codes]\n",
    "\n",
    "plt.figure(figsize=(15, 3))\n",
    "# A single hlines call draws every 0.5-minute segment, instead of one call per epoch.\n",
    "plt.hlines(y=mapped_y, xmin=time_indices, xmax=time_indices + 0.5,\n",
    "           colors=segment_colors, linewidth=8)\n",
    "\n",
    "plt.yticks(ticks=[0, 1, 2, 3, 4], labels=stage_names_ordered)\n",
    "plt.ylim(-0.5, 4.5)\n",
    "plt.xlabel(\"Time (minutes)\", fontsize=12)\n",
    "plt.title(\"Stages Over Time\", fontsize=14)\n",
    "plt.grid(axis='x', linestyle='--', alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f76a71f-322d-418a-9155-23cdaa77ca8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.metrics import silhouette_score, davies_bouldin_score, calinski_harabasz_score, homogeneity_score\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "from sklearn.feature_selection import mutual_info_classif\n",
    "from scipy.stats import spearmanr\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "embeddings = np.load(\"embeddings.npy\")\n",
    "sleep_stages = np.load(\"labels.npy\")\n",
    "time_indices = np.arange(len(embeddings))\n",
    "\n",
    "# Standardise every embedding dimension before projecting.\n",
    "scaler = StandardScaler()\n",
    "embeddings_scaled = scaler.fit_transform(embeddings)\n",
    "\n",
    "tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42)\n",
    "tsne_features = tsne.fit_transform(embeddings_scaled)\n",
    "\n",
    "# Cluster-separation metrics in the 2-D t-SNE space, with the labels as cluster assignments.\n",
    "silhouette = silhouette_score(tsne_features, sleep_stages)\n",
    "davies_bouldin = davies_bouldin_score(tsne_features, sleep_stages)\n",
    "calinski_harabasz = calinski_harabasz_score(tsne_features, sleep_stages)\n",
    "\n",
    "print(\"t-SNE Metrics:\")\n",
    "print(f\"   - Silhouette Score: {silhouette:.3f} (Higher is better)\")\n",
    "print(f\"   - Davies-Bouldin Index: {davies_bouldin:.3f} (Lower is better)\")\n",
    "print(f\"   - Calinski-Harabasz Index: {calinski_harabasz:.3f} (Higher is better)\\n\")\n",
    "\n",
    "# Rank correlation between the label code and the sample index.\n",
    "sleep_time_corr, _ = spearmanr(sleep_stages, time_indices)\n",
    "print(f\"Spearman Correlation (Stages vs. Time): {sleep_time_corr:.3f}\")\n",
    "\n",
    "# mutual_info_classif adds random noise to continuous features; seeding it makes the\n",
    "# reported value reproducible across runs, like the seeded t-SNE above.\n",
    "mi_tsne = mutual_info_classif(tsne_features, sleep_stages, discrete_features=False, random_state=42).mean()\n",
    "print(f\"Mutual Information (Higher is better): {mi_tsne:.3f}\\n\")\n",
    "\n",
    "def temporal_coherence_score(embedding_features, time_labels, k_neighbors=10):\n",
    "    \"\"\"Average variance of time labels within each point's k-nearest-neighbour set.\n",
    "\n",
    "    Args:\n",
    "        embedding_features: points on which the neighbour search is fitted and queried.\n",
    "        time_labels: array indexed by the neighbour indices (one entry per point).\n",
    "        k_neighbors: number of neighbours returned per point.\n",
    "\n",
    "    Returns:\n",
    "        Mean over all points of the variance of time_labels across that point's neighbours.\n",
    "    \"\"\"\n",
    "    knn = NearestNeighbors(n_neighbors=k_neighbors)\n",
    "    knn.fit(embedding_features)\n",
    "    # NOTE(review): the query set equals the fitted set, so each point is presumably\n",
    "    # returned as one of its own neighbours -- confirm this is intended.\n",
    "    _, neighbor_idx = knn.kneighbors(embedding_features)\n",
    "\n",
    "    per_point_variance = []\n",
    "    for row in neighbor_idx:\n",
    "        per_point_variance.append(np.var(time_labels[row]))\n",
    "    return np.mean(per_point_variance)\n",
    "\n",
    "tsne_temporal_score = temporal_coherence_score(tsne_features, time_indices)\n",
    "print(f\"Temporal Coherence Score (Lower is better): {tsne_temporal_score:.3f}\")\n",
    "\n",
    "# Cluster the 2-D projection into five groups, then score how label-homogeneous they are.\n",
    "kmeans = KMeans(n_clusters=5, random_state=42, n_init=10)\n",
    "cluster_assignments = kmeans.fit_predict(tsne_features)\n",
    "tsne_purity = homogeneity_score(sleep_stages, cluster_assignments)\n",
    "print(f\"Cluster Purity Score (Higher is better): {tsne_purity:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "394db0cd-a4c2-459d-a6c9-4c369a5a9304",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
