{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba7b6d4c-2bdc-4e8b-8849-051176df9c56",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os, sys, re, warnings\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from collections import defaultdict\n",
    "\n",
    "from MAE_model_downstream import PedSleepMAE\n",
    "from utils.misc import setup_seed\n",
    "from dataloader import HDF5Dataset\n",
    "\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "# config\n",
    "search_label = \"apnea_label\"\n",
    "data_dir = \"./hdf5_data\"\n",
    "checkpoint_file = \"./checkpoints/mae_checkpoint.pt\"\n",
    "\n",
    "patch_size = 8\n",
    "mask_ratio = 15\n",
    "emb_dim = 64\n",
    "num_head = 4\n",
    "num_layer = 3\n",
    "patient_ids = [\"pid1\"]\n",
    "study_ids = [\"sid1\"]\n",
    "seed = 42\n",
    "\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "num_patches = int(3840 / patch_size)\n",
    "\n",
    "setup_seed(seed)\n",
    "\n",
    "# helpers\n",
    "def extract_sample_id(fname):\n",
    "    m = re.search(r\"_sample_(\\d+)\\.hdf5$\", fname)\n",
    "    return int(m.group(1)) if m else float(\"inf\")\n",
    "\n",
    "# gather files\n",
    "files = [\n",
    "    os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(\".hdf5\")\n",
    "]\n",
    "files = [\n",
    "    f for f in files\n",
    "    if f.split(\"/\")[-1].split(\"_\")[0] in patient_ids\n",
    "    and f.split(\"/\")[-1].split(\"_\")[1] in study_ids\n",
    "]\n",
    "files = sorted(files, key=extract_sample_id)\n",
    "\n",
    "if len(files) == 0:\n",
    "    print(\"no matching patient/session\"); sys.exit(1)\n",
    "\n",
    "print(f\"total sorted files: {len(files)}\")\n",
    "print(\"first few:\", files[:5])\n",
    "\n",
    "# load model\n",
    "model = PedSleepMAE(\n",
    "    batch_size=len(files),\n",
    "    patch_size=patch_size,\n",
    "    mask_ratio=mask_ratio,\n",
    "    emb_dim=emb_dim,\n",
    "    num_head=num_head,\n",
    "    num_layer=num_layer,\n",
    ").to(device)\n",
    "\n",
    "ckpt = torch.load(checkpoint_file, map_location=device, weights_only=True)\n",
    "model.load_state_dict(ckpt[\"state_dict\"])\n",
    "\n",
    "print(\"device:\", device)\n",
    "\n",
    "# dataloader\n",
    "dataset = HDF5Dataset(files, search_label)\n",
    "loader = DataLoader(dataset, batch_size=100, shuffle=False)\n",
    "\n",
    "# run\n",
    "pool = nn.AdaptiveMaxPool1d(1)\n",
    "emb_list = []\n",
    "\n",
    "for b, (signal, _, _) in enumerate(loader):\n",
    "    print(\"batch\", b)\n",
    "    with torch.no_grad():\n",
    "        sig = signal.squeeze().float().to(device)\n",
    "        n, _, _ = sig.shape\n",
    "        feats, _ = model.encoder(sig)\n",
    "        feats = feats[:, :, 1:, :]\n",
    "        feats = feats.reshape(-1, num_patches, emb_dim)\n",
    "        pooled = pool(feats).reshape(n, -1)\n",
    "        emb_list.extend(pooled.cpu().numpy())\n",
    "\n",
    "emb_array = np.vstack(emb_list)\n",
    "\n",
    "prefix = f\"{'_'.join(patient_ids)}_{'_'.join(study_ids)}\"\n",
    "os.makedirs(\"./output_embeddings_sorted\", exist_ok=True)\n",
    "np.save(f\"./output_embeddings_sorted/{prefix}_embeddings.npy\", emb_array)\n",
    "\n",
    "print(\"saved:\", f\"./output_embeddings_sorted/{prefix}_embeddings.npy\", emb_array.shape)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55933c5f-a811-4ae2-8e92-411b4c7d4d8a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import phate\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "embeddings = np.load(\"./output_embeddings_sorted/example_embeddings.npy\")\n",
    "time_idx = np.arange(len(embeddings))\n",
    "\n",
    "def plot_2D_methods(embeddings, time_idx):\n",
    "    print(\"running PCA, t-SNE, PHATE...\")\n",
    "\n",
    "    pca = PCA(n_components=2).fit_transform(embeddings)\n",
    "    tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42).fit_transform(embeddings)\n",
    "    phate_op = phate.PHATE(n_components=2, knn=5, t=12, random_state=42).fit_transform(embeddings)\n",
    "\n",
    "    cmap = plt.get_cmap(\"turbo\")\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "\n",
    "    sc1 = axes[0].scatter(pca[:,0], pca[:,1], c=time_idx, cmap=cmap, s=10, alpha=0.8)\n",
    "    axes[0].set_title(\"PCA\"); fig.colorbar(sc1, ax=axes[0], label=\"time\")\n",
    "\n",
    "    sc2 = axes[1].scatter(tsne[:,0], tsne[:,1], c=time_idx, cmap=cmap, s=10, alpha=0.8)\n",
    "    axes[1].set_title(\"t-SNE\"); fig.colorbar(sc2, ax=axes[1], label=\"time\")\n",
    "\n",
    "    sc3 = axes[2].scatter(phate_op[:,0], phate_op[:,1], c=time_idx, cmap=cmap, s=10, alpha=0.8)\n",
    "    axes[2].set_title(\"PHATE\"); fig.colorbar(sc3, ax=axes[2], label=\"time\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_2D_methods(embeddings, time_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "828fbdea-94d9-4692-a4be-28ad5a67051d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "import phate\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "def plot_3D_methods(embeddings, time_idx):\n",
    "    print(\"running PCA, t-SNE, PHATE in 3D...\")\n",
    "\n",
    "    pca = PCA(n_components=3).fit_transform(embeddings)\n",
    "    tsne = TSNE(n_components=3, perplexity=30, learning_rate=200, random_state=42).fit_transform(embeddings)\n",
    "    phate_op = phate.PHATE(n_components=3, knn=5, t=12, random_state=42).fit_transform(embeddings)\n",
    "\n",
    "    cmap = plt.get_cmap(\"turbo\")\n",
    "    fig = plt.figure(figsize=(18, 6))\n",
    "\n",
    "    ax1 = fig.add_subplot(131, projection=\"3d\")\n",
    "    s1 = ax1.scatter(pca[:,0], pca[:,1], pca[:,2], c=time_idx, cmap=cmap, s=5, alpha=0.7)\n",
    "    ax1.set_title(\"PCA 3D\"); fig.colorbar(s1, ax=ax1, label=\"time\")\n",
    "\n",
    "    ax2 = fig.add_subplot(132, projection=\"3d\")\n",
    "    s2 = ax2.scatter(tsne[:,0], tsne[:,1], tsne[:,2], c=time_idx, cmap=cmap, s=5, alpha=0.7)\n",
    "    ax2.set_title(\"t-SNE 3D\"); fig.colorbar(s2, ax=ax2, label=\"time\")\n",
    "\n",
    "    ax3 = fig.add_subplot(133, projection=\"3d\")\n",
    "    s3 = ax3.scatter(phate_op[:,0], phate_op[:,1], phate_op[:,2], c=time_idx, cmap=cmap, s=5, alpha=0.7)\n",
    "    ax3.set_title(\"PHATE 3D\"); fig.colorbar(s3, ax=ax3, label=\"time\")\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_3D_methods(embeddings, time_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f78d9f9-1a30-4d18-a983-438a5ac98562",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import phate\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "\n",
    "embeddings = np.load(\"./output_embeddings_sorted/example_embeddings.npy\")\n",
    "time_idx = np.arange(len(embeddings))\n",
    "\n",
    "def group_time_idx(time_idx):\n",
    "    tmin = time_idx * 0.5\n",
    "    bins = []\n",
    "    for t in tmin:\n",
    "        if t < 60: bins.append(\"0-60\")\n",
    "        elif t < 120: bins.append(\"60-120\")\n",
    "        elif t < 180: bins.append(\"120-180\")\n",
    "        elif t < 240: bins.append(\"180-240\")\n",
    "        elif t < 300: bins.append(\"240-300\")\n",
    "        else: bins.append(\"300+\")\n",
    "    return bins\n",
    "\n",
    "def plot_2D_methods(emb, time_idx):\n",
    "    time_bins = group_time_idx(time_idx)\n",
    "    labels = [\"0-60\",\"60-120\",\"120-180\",\"180-240\",\"240-300\",\"300+\"]\n",
    "    palette = dict(zip(labels, sns.color_palette(\"Spectral\", len(labels))))\n",
    "    colors = [palette[l] for l in time_bins]\n",
    "\n",
    "    pca = PCA(n_components=2).fit_transform(emb)\n",
    "    tsne = TSNE(n_components=2, perplexity=30, learning_rate=200, random_state=42).fit_transform(emb)\n",
    "    phate_op = phate.PHATE(n_components=2, knn=5, t=20, n_pca=100, random_state=42).fit_transform(emb)\n",
    "\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18,6))\n",
    "    axes[0].scatter(pca[:,0], pca[:,1], c=colors, s=10, alpha=0.8); axes[0].set_title(\"PCA\")\n",
    "    axes[1].scatter(tsne[:,0], tsne[:,1], c=colors, s=10, alpha=0.8); axes[1].set_title(\"t-SNE\")\n",
    "    axes[2].scatter(phate_op[:,0], phate_op[:,1], c=colors, s=10, alpha=0.8); axes[2].set_title(\"PHATE\")\n",
    "\n",
    "    handles = [plt.Line2D([0],[0], marker=\"o\", color=\"w\", label=l, markersize=10, markerfacecolor=c)\n",
    "               for l,c in palette.items()]\n",
    "    fig.legend(handles=handles, loc=\"upper right\", title=\"time (min)\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "plot_2D_methods(embeddings, time_idx)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
