{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "16a5485d-6420-4ed2-9140-a6efcae990ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "from collections import defaultdict\n",
    "from dataloader import HDF5Dataset\n",
    "from utils.misc import setup_seed\n",
    "from MAE_model_downstream import PedSleepMAE\n",
    "import re\n",
    "\n",
    "# ========== CONFIG ==========\n",
    "seed = 42\n",
    "setup_seed(seed)\n",
    "# Fall back to CPU when no GPU is visible.\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# anonymized directories\n",
    "directory_path = \"/data/hdf5_files\"\n",
    "save_dir = \"./output_embeddings\"\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "# Label names passed to HDF5Dataset; one .npy file per label is written per pair.\n",
    "search_labels = ['apnea_label', 'sleep_label', 'desat_label', 'eeg_label', 'hypop_label']\n",
    "# Model hyper-parameters; patch_size and mask_ratio also select the checkpoint file name below.\n",
    "patch_size = 8\n",
    "mask_ratio = 15\n",
    "emb_dim = 64\n",
    "num_head = 4\n",
    "num_layer = 3\n",
    "batch_size = 50\n",
    "\n",
    "# pair_range slices the sorted (subject, session) keys; it is used only while target_pairs is falsy.\n",
    "pair_range = slice(0, 3)\n",
    "target_pairs = None  # use pair_range\n",
    "\n",
    "# ========== Load Model ==========\n",
    "model = PedSleepMAE(batch_size=batch_size, patch_size=patch_size, mask_ratio=mask_ratio,\n",
    "                    emb_dim=emb_dim, num_head=num_head, num_layer=num_layer).to(device)\n",
    "checkpoint_path = f\"./checkpoints/signalmask{mask_ratio}_patch{patch_size}.pt\"\n",
    "# map_location remaps the stored tensors onto `device`, so a checkpoint written on a GPU\n",
    "# still loads when this notebook runs on a CPU-only host.\n",
    "checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)\n",
    "model.load_state_dict(checkpoint['state_dict'])\n",
    "# Inference only: switch off training-time behaviour such as dropout.\n",
    "model.eval()\n",
    "\n",
    "# ========== Group files by (subject_id, session_id) ==========\n",
    "def extract_sample_id(filename):\n",
    "    \"\"\"Return the integer that follows ``_sample_`` in an ``.hdf5`` file name, or ``inf`` if absent.\"\"\"\n",
    "    found = re.search(r\"_sample_(\\d+)\\.hdf5$\", filename)\n",
    "    if found is None:\n",
    "        # Names without a sample index sort after every indexed name.\n",
    "        return float(\"inf\")\n",
    "    return int(found.group(1))\n",
    "\n",
    "grouped_files = defaultdict(list)\n",
    "for fname in os.listdir(directory_path):\n",
    "    parts = fname.split(\"_\")\n",
    "    # Keep only .hdf5 names with at least four \"_\"-separated fields; the first two form the grouping key.\n",
    "    if fname.endswith(\".hdf5\") and len(parts) >= 4:\n",
    "        grouped_files[(parts[0], parts[1])].append(os.path.join(directory_path, fname))\n",
    "\n",
    "# Order every group's paths by the numeric sample id in the file name.\n",
    "for paths in grouped_files.values():\n",
    "    paths.sort(key=extract_sample_id)\n",
    "\n",
    "# Keys are sorted numerically, so both key fields must parse as integers.\n",
    "all_keys = sorted(grouped_files, key=lambda pair: (int(pair[0]), int(pair[1])))\n",
    "# A non-empty target_pairs takes precedence; otherwise the pair_range slice is used.\n",
    "selected_keys = [key for key in all_keys if key in target_pairs] if target_pairs else all_keys[pair_range]\n",
    "\n",
    "print(f\"Total valid pairs found: {len(all_keys)}\")\n",
    "print(f\"Selected {len(selected_keys)} pairs to extract:\\n{selected_keys}\")\n",
    "\n",
    "pool = nn.AdaptiveMaxPool1d(1)\n",
    "\n",
    "# ========== Extraction ==========\n",
    "# For each selected (subject, session) pair: encode every batch, pool it, and save embeddings plus labels.\n",
    "for idx, (subject_id, session_id) in enumerate(selected_keys):\n",
    "    pid = f\"{subject_id}_{session_id}\"\n",
    "    print(f\"\\nProcessing Pair {idx + 1}/{len(selected_keys)}: {pid}\")\n",
    "    session_files = grouped_files[(subject_id, session_id)]\n",
    "    dataset = HDF5Dataset(session_files, search_labels)\n",
    "    # shuffle=False keeps rows in the order of the sorted file list.\n",
    "    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "    all_embeddings = []\n",
    "    all_labels = {label: [] for label in search_labels}\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for signals, label_dict, ids in loader:\n",
    "            n_in_batch = signals.size(0)\n",
    "            signals = signals.squeeze().float().to(device)\n",
    "            # squeeze() drops every size-1 axis, including the batch axis when the final\n",
    "            # batch holds a single sample; restore it so the encoder input keeps its rank.\n",
    "            if n_in_batch == 1:\n",
    "                signals = signals.unsqueeze(0)\n",
    "            encoded, _ = model.encoder(signals)\n",
    "            # Drop index 0 along axis 2, then flatten to (batch, -1, emb_dim).\n",
    "            encoded = encoded[:, :, 1:, :].reshape(signals.size(0), -1, emb_dim)\n",
    "            # NOTE(review): AdaptiveMaxPool1d(1) reduces the last axis (emb_dim after the reshape),\n",
    "            # leaving one value per token - confirm this is the intended pooling axis.\n",
    "            pooled = pool(encoded).squeeze(dim=2).cpu().numpy()\n",
    "            all_embeddings.append(pooled)\n",
    "            for label_name in search_labels:\n",
    "                all_labels[label_name].append(label_dict[label_name].cpu().numpy())\n",
    "\n",
    "    if not all_embeddings:\n",
    "        # The loader yielded no batches; np.vstack([]) would raise, so skip this pair.\n",
    "        print(f\"  Skipped {pid}: no samples loaded.\")\n",
    "        continue\n",
    "\n",
    "    emb_array = np.vstack(all_embeddings)\n",
    "    np.save(os.path.join(save_dir, f\"{pid}_embeddings.npy\"), emb_array)\n",
    "\n",
    "    for label_name in search_labels:\n",
    "        label_array = np.concatenate(all_labels[label_name], axis=0)\n",
    "        np.save(os.path.join(save_dir, f\"{pid}_{label_name}.npy\"), label_array)\n",
    "\n",
    "    print(f\"  Saved: {pid}_embeddings.npy and label files.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b40ec21-1dda-47ca-b71c-36849bc40de1",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import re\n",
    "from collections import defaultdict\n",
    "\n",
    "# Folder scanned for .hdf5 files, and how many pairs the preview prints.\n",
    "directory_path = \"./hdf5_data\"\n",
    "preview_n = 30\n",
    "\n",
    "def extract_sample_id(filename):\n",
    "    \"\"\"Sort key: the integer following ``_sample_`` in an ``.hdf5`` name, or ``inf`` when absent.\"\"\"\n",
    "    hit = re.search(r\"_sample_(\\d+)\\.hdf5$\", filename)\n",
    "    return float(\"inf\") if not hit else int(hit.group(1))\n",
    "\n",
    "# Group file names by their first two \"_\"-separated fields.\n",
    "grouped_files = defaultdict(list)\n",
    "for fname in os.listdir(directory_path):\n",
    "    parts = fname.split(\"_\")\n",
    "    if fname.endswith(\".hdf5\") and len(parts) >= 4:\n",
    "        grouped_files[(parts[0], parts[1])].append(fname)\n",
    "\n",
    "# Sort each group by the numeric sample id.\n",
    "for names in grouped_files.values():\n",
    "    names.sort(key=extract_sample_id)\n",
    "\n",
    "all_keys = sorted(grouped_files, key=lambda pair: (int(pair[0]), int(pair[1])))\n",
    "\n",
    "print(f\"\\nPreviewing first {preview_n} pairs:\")\n",
    "for rank, (sub_id, sess_id) in enumerate(all_keys[:preview_n], start=1):\n",
    "    print(f\"{rank:2d}. Subject: {sub_id}, Session: {sess_id}\")\n",
    "\n",
    "print(f\"\\nTotal pairs found: {len(all_keys)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "db5b8635-9107-4818-8f24-2a97411b59aa",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Preview the available (subject, session) pairs\n",
    "import os\n",
    "import re\n",
    "from collections import defaultdict\n",
    "\n",
    "directory_path = \"./hdf5_data\"  # folder scanned for .hdf5 files\n",
    "preview_n = 30  # number of pairs printed by the preview\n",
    "\n",
    "def extract_sample_id(filename):\n",
    "    \"\"\"Parse the sample index out of ``..._sample_<n>.hdf5``; names without one map to ``inf``.\"\"\"\n",
    "    m = re.search(r\"_sample_(\\d+)\\.hdf5$\", filename)\n",
    "    if m:\n",
    "        return int(m.group(1))\n",
    "    return float(\"inf\")\n",
    "\n",
    "grouped_files = defaultdict(list)\n",
    "hdf5_names = [name for name in os.listdir(directory_path) if name.endswith(\".hdf5\")]\n",
    "for fname in hdf5_names:\n",
    "    parts = fname.split(\"_\")\n",
    "    if len(parts) < 4:\n",
    "        # Too few \"_\"-separated fields to carry both key fields.\n",
    "        continue\n",
    "    grouped_files[(parts[0], parts[1])].append(fname)\n",
    "\n",
    "# Sort each group's file names by numeric sample id.\n",
    "for key in list(grouped_files):\n",
    "    grouped_files[key].sort(key=extract_sample_id)\n",
    "\n",
    "all_keys = list(grouped_files)\n",
    "all_keys.sort(key=lambda pair: (int(pair[0]), int(pair[1])))\n",
    "\n",
    "print(f\"\\nPreviewing first {preview_n} pairs:\")\n",
    "for position, pair in enumerate(all_keys[:preview_n]):\n",
    "    sub_id, sess_id = pair\n",
    "    print(f\"{position + 1:2d}. Subject: {sub_id}, Session: {sess_id}\")\n",
    "\n",
    "print(f\"\\nTotal pairs found: {len(all_keys)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3bb2e8b-a762-43a4-9ac3-2f1d11190197",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import os\n",
    "\n",
    "pair_id = \"sub01_sess01\"  \n",
    "embedding_dir = \"./output_embeddings\"\n",
    "\n",
    "# The embeddings are loaded only for their length: points are coloured by row index.\n",
    "embeddings = np.load(os.path.join(embedding_dir, f\"{pair_id}_embeddings.npy\"))\n",
    "time_indices = np.arange(len(embeddings))\n",
    "\n",
    "tsne_path = os.path.join(embedding_dir, f\"{pair_id}_tsne_traj.npy\")\n",
    "phate_path = os.path.join(embedding_dir, f\"{pair_id}_phate_traj.npy\")\n",
    "umap_path = os.path.join(embedding_dir, f\"{pair_id}_umap_traj.npy\")\n",
    "\n",
    "# Each trajectory is optional; a missing file leaves the variable as None.\n",
    "tsne = np.load(tsne_path) if os.path.exists(tsne_path) else None\n",
    "phate = np.load(phate_path) if os.path.exists(phate_path) else None\n",
    "umap = np.load(umap_path) if os.path.exists(umap_path) else None\n",
    "\n",
    "# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap works on old and new releases.\n",
    "cmap = plt.get_cmap(\"turbo\")\n",
    "\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
    "\n",
    "panels = [\n",
    "    (\"t-SNE Trajectory\", tsne),\n",
    "    (\"PHATE Trajectory\", phate),\n",
    "    (\"UMAP Trajectory\", umap),\n",
    "]\n",
    "for ax, (panel_title, traj) in zip(axes, panels):\n",
    "    if traj is None:\n",
    "        # No saved trajectory for this method: hide the axes instead of showing an empty frame.\n",
    "        ax.axis(\"off\")\n",
    "        continue\n",
    "    scatter = ax.scatter(traj[:, 0], traj[:, 1], c=time_indices, cmap=cmap, s=10)\n",
    "    ax.set_title(panel_title)\n",
    "    fig.colorbar(scatter, ax=ax, label=\"Time Index\")\n",
    "\n",
    "plt.suptitle(\"Trajectory Visualization\", fontsize=16)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "278235fe-6ebb-4196-b212-41b810e22d14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from glob import glob\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "feature_files = sorted(glob(os.path.join(feature_dir, \"*_time_feature.npy\")))\n",
    "\n",
    "# Collect every 10-element feature vector together with its id (file name minus the suffix).\n",
    "features = []\n",
    "ids = []\n",
    "for path in feature_files:\n",
    "    sid = os.path.basename(path).replace(\"_time_feature.npy\", \"\")\n",
    "    try:\n",
    "        feat = np.load(path)\n",
    "        if feat.shape[0] != 10:\n",
    "            print(f\"Skipping {sid}: invalid length {feat.shape}\")\n",
    "            continue\n",
    "        features.append(feat)\n",
    "        ids.append(sid)\n",
    "    except Exception as err:\n",
    "        # Best effort: report the failure and carry on with the next file.\n",
    "        print(f\"Error loading {path}: {err}\")\n",
    "\n",
    "# Column names, in the order the values are stored in each vector.\n",
    "columns = [\n",
    "    \"total_distance\",\n",
    "    \"avg_distance\",\n",
    "    \"var_distance\",\n",
    "    \"max_distance\",\n",
    "    \"num_clusters\",\n",
    "    \"avg_cluster_duration\",\n",
    "    \"entropy_dir_change\",\n",
    "    \"mean_angle\",\n",
    "    \"straightness\",\n",
    "    \"tortuosity\",\n",
    "]\n",
    "\n",
    "df = pd.DataFrame(data=features, index=ids, columns=columns)\n",
    "df.index.name = \"id\"\n",
    "print(f\"Loaded {len(df)} feature files\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "450b0747-334e-4640-83b9-ce0511c28c12",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# One histogram per column of `df`, arranged on a 3x4 grid.\n",
    "hist_options = {\"bins\": 30, \"figsize\": (15, 10), \"layout\": (3, 4)}\n",
    "df.hist(**hist_options)\n",
    "plt.suptitle(\"Distribution of PHATE Time Features\")\n",
    "# Reserve a margin at the top of the figure for the suptitle.\n",
    "plt.tight_layout(rect=(0, 0.03, 1, 0.95))\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9cc4f5ab-8d36-4fe5-9130-6fea3ead13b8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Annotated correlation heatmap over the columns of `df`.\n",
    "feature_corr = df.corr()\n",
    "plt.figure(figsize=(10, 8))\n",
    "sns.heatmap(feature_corr, cmap=\"coolwarm\", annot=True, fmt=\".2f\")\n",
    "plt.title(\"Correlation Between PHATE Trajectory Features\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2dca2057-3484-4b1e-8e3c-e50c4e0e8ad7",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Try Multiple DBSCAN Settings\n",
    "import os\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.cluster import DBSCAN\n",
    "\n",
    "phate_file = \"./output_embeddings/sub01_sess01_phate_traj.npy\"\n",
    "phate_array = np.load(phate_file)\n",
    "\n",
    "def try_dbscan(arr, eps_list, min_samples_list):\n",
    "    \"\"\"Fit DBSCAN on ``arr`` for every (eps, min_samples) combination; print the counts and plot each labelling.\"\"\"\n",
    "    print(f\"Trajectory shape: {arr.shape}\")\n",
    "    for eps in eps_list:\n",
    "        for min_samples in min_samples_list:\n",
    "            labels = DBSCAN(eps=eps, min_samples=min_samples).fit(arr).labels_\n",
    "            noise_mask = labels == -1\n",
    "            n_noise = np.sum(noise_mask)\n",
    "            # Label -1 is reported as noise, so it is not counted as a cluster.\n",
    "            n_clusters = len(set(labels)) - int(noise_mask.any())\n",
    "            print(f\"[eps={eps:.1f}, min_samples={min_samples}] --> clusters: {n_clusters}, noise points: {n_noise}\")\n",
    "            plot_clusters(arr, labels, eps, min_samples)\n",
    "\n",
    "def plot_clusters(arr, labels, eps, min_samples):\n",
    "    \"\"\"Scatter the first two columns of ``arr`` coloured by ``labels``; the title records the DBSCAN settings.\"\"\"\n",
    "    fig, ax = plt.subplots(figsize=(6, 5))\n",
    "    ax.scatter(arr[:, 0], arr[:, 1], c=labels, cmap='tab10', s=15)\n",
    "    ax.set_title(f\"DBSCAN Clusters (eps={eps}, min_samples={min_samples})\")\n",
    "    ax.set_xlabel(\"PHATE 1\")\n",
    "    ax.set_ylabel(\"PHATE 2\")\n",
    "    ax.grid(True)\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "# Parameter grid: every eps value is combined with every min_samples value.\n",
    "eps_values = [0.5, 1.0, 1.5, 2.0, 3.0]\n",
    "min_samples_values = [3, 5, 10]\n",
    "\n",
    "try_dbscan(arr=phate_array, eps_list=eps_values, min_samples_list=min_samples_values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1ddd0b6-427b-4491-a79c-dfa01c0f98ee",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "## Run code to get time features\n",
    "import os\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from scipy.stats import entropy\n",
    "from numpy.linalg import norm\n",
    "import ruptures as rpt\n",
    "import math\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "\n",
    "def compute_trajectory_features(arr, max_bkps=5):\n",
    "    \"\"\"Summarise one trajectory as a 10-element feature vector.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arr : array of trajectory points, one row per time step.\n",
    "    max_bkps : currently unused; the body never reads it.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    np.ndarray with, in order: total, mean, variance and maximum step\n",
    "    length; number of change-point segments and their mean length;\n",
    "    entropy and mean of the turning angles; straightness\n",
    "    (displacement / path length) and tortuosity (path length / displacement).\n",
    "    \"\"\"\n",
    "    steps = np.diff(arr, axis=0)\n",
    "    distances = norm(steps, axis=1)\n",
    "    total_distance = np.sum(distances)\n",
    "    avg_distance = np.mean(distances)\n",
    "    var_distance = np.var(distances)\n",
    "    max_distance = np.max(distances)\n",
    "\n",
    "    # Turning angle between consecutive steps; the epsilon keeps zero-length steps from dividing by zero.\n",
    "    angles = [\n",
    "        math.acos(np.clip(np.dot(prev, cur) / (norm(prev) * norm(cur) + 1e-8), -1, 1))\n",
    "        for prev, cur in zip(steps[:-1], steps[1:])\n",
    "    ]\n",
    "    hist, _ = np.histogram(angles, bins=20, range=(0, np.pi), density=True)\n",
    "    entropy_dir_change = entropy(hist + 1e-8)\n",
    "    mean_angle = np.mean(angles)\n",
    "\n",
    "    # Kernel change-point detection on the raw points; each returned breakpoint closes one segment.\n",
    "    bkps = rpt.KernelCPD(kernel=\"linear\").fit(arr).predict(pen=0.1)\n",
    "    num_segments = len(bkps)\n",
    "    segment_lengths = np.diff([0] + bkps)\n",
    "    if len(segment_lengths) > 0:\n",
    "        avg_segment_duration = np.mean(segment_lengths)\n",
    "    else:\n",
    "        avg_segment_duration = 0\n",
    "\n",
    "    displacement = norm(arr[-1] - arr[0])\n",
    "    straightness = displacement / (total_distance + 1e-8)\n",
    "    tortuosity = total_distance / (displacement + 1e-8)\n",
    "\n",
    "    feature_values = [\n",
    "        total_distance, avg_distance, var_distance, max_distance,\n",
    "        num_segments, avg_segment_duration,\n",
    "        entropy_dir_change, mean_angle,\n",
    "        straightness, tortuosity,\n",
    "    ]\n",
    "    return np.array(feature_values)\n",
    "\n",
    "traj_files = sorted(glob(os.path.join(feature_dir, \"*_phate_traj.npy\")))\n",
    "\n",
    "# Compute and save one feature vector per saved PHATE trajectory.\n",
    "for traj_path in traj_files:\n",
    "    sid = os.path.basename(traj_path).replace(\"_phate_traj.npy\", \"\")\n",
    "    print(f\"Processing: {sid}\")\n",
    "    arr = np.load(traj_path)\n",
    "    if len(arr) >= 3:\n",
    "        feats = compute_trajectory_features(arr, max_bkps=5)\n",
    "        outfile = os.path.join(feature_dir, f\"{sid}_time_feature.npy\")\n",
    "        np.save(outfile, feats)\n",
    "        print(f\"  Saved: {outfile}\")\n",
    "    else:\n",
    "        # Trajectories with fewer than three points are not processed.\n",
    "        print(f\"  Skipped {sid} (too few points)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bd82d751-b0b7-4502-96b0-a874abd84819",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Histogram grid (3 rows x 4 columns) over every column of `df`.\n",
    "df.hist(\n",
    "    bins=30,\n",
    "    figsize=(15, 10),\n",
    "    layout=(3, 4),\n",
    ")\n",
    "hist_title = \"Distribution of PHATE Time Features\"\n",
    "plt.suptitle(hist_title)\n",
    "# rect keeps the subplots clear of the suptitle.\n",
    "plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b4380e4-1c71-4324-a15d-f33d5f31f252",
   "metadata": {},
   "outputs": [],
   "source": [
    "import seaborn as sns\n",
    "\n",
    "# Pairwise correlation of the columns of `df`, drawn as an annotated heatmap.\n",
    "plt.figure(figsize=(7, 5))\n",
    "heatmap_options = {\"annot\": True, \"cmap\": \"coolwarm\", \"fmt\": \".2f\"}\n",
    "sns.heatmap(df.corr(), **heatmap_options)\n",
    "plt.title(\"Correlation Between PHATE Trajectory Features\")\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6fe6e42c-cae8-46ea-817e-d858626a5e14",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "feature_files = sorted(glob(os.path.join(feature_dir, \"*_time_feature.npy\")))\n",
    "# Only vectors of this length are normalised; anything else is reported and skipped.\n",
    "expected_len = 6\n",
    "\n",
    "features, ids = [], []\n",
    "for file in feature_files:\n",
    "    sid = os.path.basename(file).replace(\"_time_feature.npy\", \"\")\n",
    "    feat = np.load(file)\n",
    "    # Require a 1-D vector so a scalar or 2-D file cannot break the stacking below.\n",
    "    if feat.ndim == 1 and feat.shape[0] == expected_len:\n",
    "        features.append(feat)\n",
    "        ids.append(sid)\n",
    "    else:\n",
    "        print(f\"Skipped {sid} due to invalid shape {feat.shape}\")\n",
    "\n",
    "features = np.array(features)\n",
    "\n",
    "if len(ids) == 0:\n",
    "    # StandardScaler rejects an empty 1-D array, so report and skip instead of raising.\n",
    "    print(\"No valid feature vectors found; nothing to normalize.\")\n",
    "else:\n",
    "    # Standardise each feature column across all sessions.\n",
    "    scaler = StandardScaler()\n",
    "    normalized = scaler.fit_transform(features)\n",
    "\n",
    "    for i, sid in enumerate(ids):\n",
    "        out_path = os.path.join(feature_dir, f\"{sid}_time_feature_normalized.npy\")\n",
    "        np.save(out_path, normalized[i])\n",
    "        print(f\"Saved: {out_path}\")\n",
    "\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d216be79-de39-4641-a3a5-6f85746363f0",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "## add new point based feature\n",
    "import os\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from scipy.stats import entropy\n",
    "from numpy.linalg import norm\n",
    "import ruptures as rpt\n",
    "import math\n",
    "\n",
    "# CONFIG\n",
    "feature_dir = \"./output_embeddings\"\n",
    "pair_range = slice(0, 1000)\n",
    "\n",
    "# FUNCTIONS\n",
    "def compute_trajectory_features(arr, max_bkps=5):\n",
    "    \"\"\"Compute session-level and per-point features for one trajectory.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    arr : array of trajectory points, one row per time step.\n",
    "    max_bkps : currently unused; the body never reads it.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    session_features : np.ndarray of length 6 -\n",
    "        avg_distance, max_distance, num_segments, entropy_dir_change,\n",
    "        mean_angle, tortuosity.\n",
    "    point_features : np.ndarray of shape (len(arr), 6) with columns\n",
    "        delta_distance, cumulative_distance, angle_change, curvature,\n",
    "        dist_to_start, segment_id.\n",
    "    \"\"\"\n",
    "    directions = np.diff(arr, axis=0)\n",
    "    distances = norm(directions, axis=1)\n",
    "    avg_distance = np.mean(distances)\n",
    "    max_distance = np.max(distances)\n",
    "\n",
    "    # Turning angle between consecutive steps; the epsilon guards zero-length steps.\n",
    "    angles = []\n",
    "    for i in range(1, len(directions)):\n",
    "        v1, v2 = directions[i - 1], directions[i]\n",
    "        cosine = np.dot(v1, v2) / (norm(v1) * norm(v2) + 1e-8)\n",
    "        angles.append(math.acos(np.clip(cosine, -1, 1)))\n",
    "    hist, _ = np.histogram(angles, bins=20, range=(0, np.pi), density=True)\n",
    "    entropy_dir_change = entropy(hist + 1e-8)\n",
    "    mean_angle = np.mean(angles)\n",
    "\n",
    "    model = rpt.KernelCPD(kernel=\"linear\").fit(arr)\n",
    "    bkps = model.predict(pen=0.1)\n",
    "    num_segments = len(bkps)\n",
    "\n",
    "    tortuosity = np.sum(distances) / (norm(arr[-1] - arr[0]) + 1e-8)\n",
    "\n",
    "    # The first point has no incoming step, so its delta distance is 0.\n",
    "    delta_distances = np.insert(distances, 0, 0.0)\n",
    "    cumulative_distances = np.cumsum(delta_distances)\n",
    "    dist_to_start = norm(arr - arr[0], axis=1)\n",
    "\n",
    "    # Interior points only; the two end points keep 0 for both measures.\n",
    "    angle_changes = np.zeros(len(arr))\n",
    "    curvatures = np.zeros(len(arr))\n",
    "    for i in range(1, len(arr) - 1):\n",
    "        a = arr[i] - arr[i - 1]\n",
    "        b = arr[i + 1] - arr[i]\n",
    "        cosine = np.dot(a, b) / (norm(a) * norm(b) + 1e-8)\n",
    "        angle_changes[i] = math.acos(np.clip(cosine, -1, 1))\n",
    "        # Ratio of the two-step path length to the chord that skips point i.\n",
    "        curvatures[i] = (norm(a) + norm(b)) / (norm(arr[i + 1] - arr[i - 1]) + 1e-8)\n",
    "\n",
    "    # Label each point with the index of the segment it falls in. Only [seg_start, bkp) is\n",
    "    # written per breakpoint; assigning [:bkp] would overwrite every earlier segment.\n",
    "    # Assumes bkps is ascending and ends at len(arr) - TODO confirm against the ruptures docs.\n",
    "    segment_ids = np.zeros(len(arr), dtype=int)\n",
    "    seg_start = 0\n",
    "    for seg_index, bkp in enumerate(bkps):\n",
    "        segment_ids[seg_start:bkp] = seg_index\n",
    "        seg_start = bkp\n",
    "\n",
    "    session_features = np.array([\n",
    "        avg_distance, max_distance,\n",
    "        num_segments, entropy_dir_change, mean_angle,\n",
    "        tortuosity\n",
    "    ])\n",
    "\n",
    "    point_features = np.stack([\n",
    "        delta_distances,\n",
    "        cumulative_distances,\n",
    "        angle_changes,\n",
    "        curvatures,\n",
    "        dist_to_start,\n",
    "        segment_ids\n",
    "    ], axis=1)\n",
    "\n",
    "    return session_features, point_features\n",
    "\n",
    "# MAIN\n",
    "def _numeric_pair_key(pair_id):\n",
    "    \"\"\"Sort key: the first two \"_\"-separated fields of ``pair_id`` as integers.\"\"\"\n",
    "    fields = pair_id.split(\"_\")\n",
    "    return (int(fields[0]), int(fields[1]))\n",
    "\n",
    "traj_files = sorted(glob(os.path.join(feature_dir, \"*_phate_traj.npy\")))\n",
    "pair_ids = [os.path.basename(f).replace(\"_phate_traj.npy\", \"\") for f in traj_files]\n",
    "pair_ids = sorted(pair_ids, key=_numeric_pair_key)\n",
    "selected_ids = pair_ids[pair_range]\n",
    "\n",
    "# Session-level vectors are also kept in memory for the plots that follow.\n",
    "summary_features = []\n",
    "for idx, sid in enumerate(selected_ids):\n",
    "    arr = np.load(os.path.join(feature_dir, f\"{sid}_phate_traj.npy\"))\n",
    "    if len(arr) >= 3:\n",
    "        print(f\"[{idx}] Processing {sid}\")\n",
    "        session_feats, point_feats = compute_trajectory_features(arr)\n",
    "        np.save(os.path.join(feature_dir, f\"{sid}_time_feature.npy\"), session_feats)\n",
    "        np.save(os.path.join(feature_dir, f\"{sid}_point_features.npy\"), point_feats)\n",
    "        summary_features.append(session_feats)\n",
    "    else:\n",
    "        print(f\"[{idx}] Skipped {sid} (too few points)\")\n",
    "\n",
    "print(f\"Finished processing {len(summary_features)} sessions.\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2082dcf3-f9c0-4c52-9eb1-1495c6adceb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Correlation and distribution plots\n",
    "session_columns = [\n",
    "    \"avg_distance\",\n",
    "    \"max_distance\",\n",
    "    \"num_segments\",\n",
    "    \"entropy_dir_change\",\n",
    "    \"mean_angle\",\n",
    "    \"tortuosity\",\n",
    "]\n",
    "df = pd.DataFrame(data=summary_features, columns=session_columns)\n",
    "\n",
    "# Heatmap of pairwise correlations, colour scale centred on 0.\n",
    "session_corr = df.corr()\n",
    "plt.figure(figsize=(10, 8))\n",
    "sns.heatmap(session_corr, annot=True, cmap=\"coolwarm\", center=0)\n",
    "plt.title(\"Correlation Between Updated PHATE Trajectory Features\")\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "# One histogram per feature on a 3x3 grid.\n",
    "df.hist(bins=20, figsize=(12, 10), layout=(3, 3))\n",
    "plt.suptitle(\"Distribution of Updated PHATE Time Features\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c96cbf9-e34b-4620-a5df-ad6358620b05",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from glob import glob\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "\n",
    "point_files = sorted(glob(os.path.join(feature_dir, \"*_point_features.npy\")))\n",
    "# Keep only 2-D arrays with exactly six feature columns.\n",
    "loaded_arrays = (np.load(path) for path in point_files)\n",
    "all_point_features = [a for a in loaded_arrays if a.ndim == 2 and a.shape[1] == 6]\n",
    "\n",
    "# Pool the points of every session into one table.\n",
    "all_data = np.vstack(all_point_features)\n",
    "columns = [\n",
    "    \"delta_distance\", \"cumulative_distance\", \"angle_change\",\n",
    "    \"local_curvature\", \"dist_to_start\", \"segment_id\",\n",
    "]\n",
    "df = pd.DataFrame(all_data, columns=columns)\n",
    "\n",
    "fig, axes = plt.subplots(3, 3, figsize=(15, 12))\n",
    "axes = axes.flatten()\n",
    "\n",
    "for ax, col in zip(axes, columns):\n",
    "    sns.histplot(df[col], bins=50, kde=False, ax=ax)\n",
    "    ax.set_title(f\"Distribution of {col}\")\n",
    "\n",
    "# Hide the grid cells that have no feature to show.\n",
    "for spare_ax in axes[len(columns):]:\n",
    "    spare_ax.axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "corr = df.corr()\n",
    "plt.figure(figsize=(5, 4))\n",
    "sns.heatmap(corr, annot=True, fmt=\".2f\", cmap=\"coolwarm\", square=True)\n",
    "plt.title(\"Correlation Between Point Features\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e4cada-e7f1-4750-bf54-0aafcd40a575",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from glob import glob\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "\n",
    "time_files = sorted(glob(os.path.join(feature_dir, \"*_time_feature.npy\")))\n",
    "all_time_features = []\n",
    "\n",
    "for time_path in time_files:\n",
    "    vec = np.load(time_path)\n",
    "    # Anything that is not a 1-D vector of six values is ignored.\n",
    "    if vec.ndim != 1 or vec.shape[0] != 6:\n",
    "        continue\n",
    "    all_time_features.append(vec)\n",
    "\n",
    "all_time_data = np.vstack(all_time_features)\n",
    "columns = [\n",
    "    \"avg_distance\", \"max_distance\", \"num_segments\",\n",
    "    \"entropy_dir_change\", \"mean_angle\", \"tortuosity\",\n",
    "]\n",
    "df_time = pd.DataFrame(all_time_data, columns=columns)\n",
    "\n",
    "fig, axes = plt.subplots(3, 3, figsize=(15, 12))\n",
    "axes = axes.flatten()\n",
    "\n",
    "n_features = len(columns)\n",
    "for panel in range(n_features):\n",
    "    col = columns[panel]\n",
    "    sns.histplot(df_time[col], bins=50, kde=False, ax=axes[panel], edgecolor=None)\n",
    "    axes[panel].set_title(f\"Distribution of {col}\")\n",
    "\n",
    "# Switch off the unused cells of the 3x3 grid.\n",
    "for panel in range(n_features, len(axes)):\n",
    "    axes[panel].axis(\"off\")\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()\n",
    "\n",
    "corr_time = df_time.corr()\n",
    "plt.figure(figsize=(5, 4))\n",
    "sns.heatmap(corr_time, annot=True, fmt=\".2f\", cmap=\"coolwarm\", square=True)\n",
    "plt.title(\"Correlation Between Features\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09341891-a94b-46ac-84a9-161b4826351e",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import numpy as np\n",
    "from glob import glob\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "feature_dir = \"./output_embeddings\"\n",
    "feature_files = sorted(glob(os.path.join(feature_dir, \"*_point_features.npy\")))\n",
    "\n",
    "all_features = []\n",
    "ids = []\n",
    "lengths = []\n",
    "\n",
    "for point_path in feature_files:\n",
    "    sid = os.path.basename(point_path).replace(\"_point_features.npy\", \"\")\n",
    "    feat = np.load(point_path)  # shape: (T, 6)\n",
    "    if feat.ndim != 2 or feat.shape[1] != 6:\n",
    "        print(f\"Skipped {sid} due to invalid shape {feat.shape}\")\n",
    "        continue\n",
    "    all_features.append(feat)\n",
    "    ids.append(sid)\n",
    "    lengths.append(feat.shape[0])\n",
    "\n",
    "# Fit one scaler on the points of every session stacked together.\n",
    "stacked = np.vstack(all_features)\n",
    "\n",
    "scaler = StandardScaler()\n",
    "normalized_stacked = scaler.fit_transform(stacked)\n",
    "\n",
    "# Cut the stacked matrix back into per-session blocks using the recorded row counts.\n",
    "row_ends = np.cumsum(lengths)\n",
    "row_starts = row_ends - np.asarray(lengths)\n",
    "for sid, row_start, row_end in zip(ids, row_starts, row_ends):\n",
    "    out_path = os.path.join(feature_dir, f\"{sid}_point_features_normalized.npy\")\n",
    "    np.save(out_path, normalized_stacked[row_start:row_end])\n",
    "    print(f\"Saved: {out_path}\")\n",
    "\n",
    "print(\"Done.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
