{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c2a81fa-79ae-46ee-8c8f-8e56b0e68454",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "# Locate the embedding and label .npy files under data_dir, sorted by name.\n",
    "data_dir = \"./output_embeddings\"\n",
    "embedding_files = sorted(glob(os.path.join(data_dir, \"*_embeddings.npy\")))\n",
    "label_files = sorted(glob(os.path.join(data_dir, \"*_labels.npy\")))\n",
    "\n",
    "print(f\"Found {len(embedding_files)} embedding files\")\n",
    "print(f\"Found {len(label_files)} label files\")\n",
    "\n",
    "# Preview the first few entries of both sorted lists side by side. The lists\n",
    "# are paired purely by position, so this is only a visual spot check.\n",
    "for i, embedding_path in enumerate(embedding_files[:5]):\n",
    "    # Guard the lookup: with fewer label files than embedding files,\n",
    "    # label_files[i] would raise IndexError.\n",
    "    label_name = os.path.basename(label_files[i]) if i < len(label_files) else \"(none)\"\n",
    "    print(f\"{i+1}. Embedding: {os.path.basename(embedding_path)}\")\n",
    "    print(f\"   Label    : {label_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6967ae-e880-490f-abce-7ed13ba77fb5",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "# Gather both file lists from the same folder, each in sorted order.\n",
    "data_dir = \"./output_embeddings\"\n",
    "embedding_files = glob(os.path.join(data_dir, \"*_embeddings.npy\"))\n",
    "embedding_files.sort()\n",
    "label_files = glob(os.path.join(data_dir, \"*_labels.npy\"))\n",
    "label_files.sort()\n",
    "\n",
    "def extract_id(filename, suffix):\n",
    "    base = os.path.basename(filename)\n",
    "    return base.replace(suffix, \"\")\n",
    "\n",
    "# Index every file path by its identifier (file name minus the type suffix).\n",
    "embedding_ids = dict((extract_id(path, \"_embeddings.npy\"), path) for path in embedding_files)\n",
    "label_ids = dict((extract_id(path, \"_labels.npy\"), path) for path in label_files)\n",
    "\n",
    "embedding_keys = set(embedding_ids)\n",
    "label_keys = set(label_ids)\n",
    "\n",
    "# Identifiers present on one side only, sorted for stable output.\n",
    "missing_labels = sorted(embedding_keys.difference(label_keys))\n",
    "missing_embeddings = sorted(label_keys.difference(embedding_keys))\n",
    "\n",
    "print(f\"Total embeddings: {len(embedding_files)}\")\n",
    "print(f\"Total labels    : {len(label_files)}\")\n",
    "print(f\"Missing labels  : {len(missing_labels)}\")\n",
    "if len(missing_labels) > 0:\n",
    "    print(\"Embeddings without labels:\")\n",
    "    for k in missing_labels:\n",
    "        print(f\" - {k}_embeddings.npy\")\n",
    "\n",
    "if len(missing_embeddings) > 0:\n",
    "    print(\"Labels without embeddings:\")\n",
    "    for k in missing_embeddings:\n",
    "        print(f\" - {k}_labels.npy\")\n",
    "\n",
    "def parse_pair(identifier):\n",
    "    \"\"\"Split an identifier of the form INT_INT[_...] into a pair of ints.\n",
    "\n",
    "    Only the first two underscore-separated fields are used.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if there are fewer than two fields, or either of the\n",
    "            first two fields is not an integer.\n",
    "    \"\"\"\n",
    "    parts = identifier.split(\"_\")\n",
    "    if len(parts) < 2:\n",
    "        # Without this check the failure is a bare IndexError from parts[1].\n",
    "        raise ValueError(f\"expected an INT_INT identifier, got {identifier!r}\")\n",
    "    return (int(parts[0]), int(parts[1]))\n",
    "\n",
    "# Parse every identifier that has both files. Names that do not follow the\n",
    "# numeric pattern are reported and skipped instead of aborting the cell.\n",
    "valid_pairs = []\n",
    "skipped_ids = []\n",
    "for k in sorted(embedding_keys & label_keys):\n",
    "    try:\n",
    "        valid_pairs.append(parse_pair(k))\n",
    "    except ValueError:\n",
    "        skipped_ids.append(k)\n",
    "# Order numerically; sorting the raw strings places \"10_1\" before \"2_1\".\n",
    "valid_pairs.sort()\n",
    "\n",
    "print(f\"\\nValid pairs: {len(valid_pairs)}\")\n",
    "if skipped_ids:\n",
    "    print(f\"Skipped non-numeric identifiers: {skipped_ids}\")\n",
    "\n",
    "print(\"Sample:\")\n",
    "for pair in valid_pairs[:5]:\n",
    "    print(pair)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e03bf765-c3b0-4d1f-8fdb-0c3b12b69f74",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "data_dir = \"./output_embeddings\"\n",
    "file_to_delete = os.path.join(data_dir, \"sample_embeddings.npy\")\n",
    "\n",
    "# Remove the file when it is present; either way, record a status message.\n",
    "if not os.path.exists(file_to_delete):\n",
    "    result = f\"File not found: {file_to_delete}\"\n",
    "else:\n",
    "    os.remove(file_to_delete)\n",
    "    result = f\"Deleted file: {file_to_delete}\"\n",
    "\n",
    "# A bare final expression makes the notebook display the message.\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "18a2282e-fa3f-4a03-8be8-9cc9465008c6",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "directory_path = \"./hdf5_data\"\n",
    "all_hdf5_files = [entry for entry in os.listdir(directory_path) if entry.endswith(\".hdf5\")]\n",
    "\n",
    "# Collect the distinct (first field, second field) combinations found in the\n",
    "# underscore-separated file names.\n",
    "unique_pairs = set()\n",
    "for fname in all_hdf5_files:\n",
    "    parts = fname.split(\"_\")\n",
    "    # Names with fewer than three fields are ignored.\n",
    "    if len(parts) < 3:\n",
    "        continue\n",
    "    subj_id, sess_id = parts[:2]\n",
    "    unique_pairs.add((subj_id, sess_id))\n",
    "\n",
    "print(\"Max unique pairs:\", len(unique_pairs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6480b014-cb93-4421-90b1-b25399f8fc38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "base_dir = \"./output_embeddings\"\n",
    "\n",
    "# Every embedding file defines one identifier: its name minus the suffix.\n",
    "embedding_files = glob(os.path.join(base_dir, \"*_embeddings.npy\"))\n",
    "ids = [os.path.basename(path).replace(\"_embeddings.npy\", \"\") for path in embedding_files]\n",
    "\n",
    "# NOTE(review): other cells in this notebook look for\n",
    "# *_point_features_normalized.npy, *_time_features_normalized.npy and\n",
    "# *_ehr_features.npy; confirm the suffixes checked below are the current ones.\n",
    "counts = dict.fromkeys((\"total_ids\", \"traj\", \"time_feature\", \"ehr_feature\"), 0)\n",
    "counts[\"total_ids\"] = len(ids)\n",
    "\n",
    "# For each identifier, add 1 to a tally when the companion file exists.\n",
    "for sid in ids:\n",
    "    counts[\"traj\"] += int(os.path.exists(os.path.join(base_dir, f\"{sid}_traj.npy\")))\n",
    "    counts[\"time_feature\"] += int(os.path.exists(os.path.join(base_dir, f\"{sid}_time_feature.npy\")))\n",
    "    counts[\"ehr_feature\"] += int(os.path.exists(os.path.join(base_dir, f\"{sid}_ehr_feature.npy\")))\n",
    "\n",
    "print(f\"Total pairs with embeddings: {counts['total_ids']}\")\n",
    "print(f\"With trajectories: {counts['traj']}\")\n",
    "print(f\"With time features: {counts['time_feature']}\")\n",
    "print(f\"With EHR features: {counts['ehr_feature']}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a7b57e57-d909-4d82-95bf-24492f53a7f2",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "def get_ids(base_dir, suffix):\n",
    "    files = glob(os.path.join(base_dir, f\"*_{suffix}.npy\"))\n",
    "    return sorted([os.path.basename(f).replace(f\"_{suffix}.npy\", \"\") for f in files])\n",
    "\n",
    "base_dir = \"./output_embeddings\"\n",
    "\n",
    "# One identifier set per file type; get_ids is queried in the order listed.\n",
    "emb_ids, traj_ids, time_ids, ehr_ids = (\n",
    "    set(get_ids(base_dir, suffix))\n",
    "    for suffix in (\n",
    "        \"embeddings\",\n",
    "        \"point_features_normalized\",\n",
    "        \"time_features_normalized\",\n",
    "        \"ehr_features\",\n",
    "    )\n",
    ")\n",
    "\n",
    "# Identifiers for which all four file types exist.\n",
    "common_ids = set.intersection(emb_ids, traj_ids, time_ids, ehr_ids)\n",
    "\n",
    "# Embedding identifiers lacking each of the other file types.\n",
    "missing_traj = sorted(emb_ids.difference(traj_ids))\n",
    "missing_time = sorted(emb_ids.difference(time_ids))\n",
    "missing_ehr  = sorted(emb_ids.difference(ehr_ids))\n",
    "\n",
    "print(f\"Complete sets (all 4): {len(common_ids)}\")\n",
    "print(f\"Missing trajectory: {len(missing_traj)}\")\n",
    "print(f\"Missing time features: {len(missing_time)}\")\n",
    "print(f\"Missing EHR features: {len(missing_ehr)}\")\n",
    "\n",
    "print(\"\\nSample missing time features:\", missing_time[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82b7c78b-dab6-44bd-83be-e04b6c499abb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "\n",
    "def get_ids(base_dir, suffix):\n",
    "    files = glob(os.path.join(base_dir, f\"*_{suffix}.npy\"))\n",
    "    return {os.path.basename(f).replace(f\"_{suffix}.npy\", \"\") for f in files}\n",
    "\n",
    "base_dir = \"./output_embeddings\"\n",
    "\n",
    "# Identifier sets, one per file type.\n",
    "emb_ids  = get_ids(base_dir, \"embeddings\")\n",
    "traj_ids = get_ids(base_dir, \"point_features_normalized\")\n",
    "time_ids = get_ids(base_dir, \"time_features_normalized\")\n",
    "ehr_ids  = get_ids(base_dir, \"ehr_features\")\n",
    "\n",
    "# Identifiers that appear in all four sets.\n",
    "common_ids = emb_ids.intersection(traj_ids).intersection(time_ids).intersection(ehr_ids)\n",
    "\n",
    "# Embedding identifiers absent from each of the other sets, sorted.\n",
    "missing_traj = sorted(i for i in emb_ids if i not in traj_ids)\n",
    "missing_time = sorted(i for i in emb_ids if i not in time_ids)\n",
    "missing_ehr  = sorted(i for i in emb_ids if i not in ehr_ids)\n",
    "\n",
    "print(f\"Complete sets (all 4): {len(common_ids)}\")\n",
    "print(f\"Missing trajectory: {len(missing_traj)}\")\n",
    "print(f\"Missing time: {len(missing_time)}\")\n",
    "print(f\"Missing EHR: {len(missing_ehr)}\")\n",
    "\n",
    "# Show a sample only when something is actually missing.\n",
    "if len(missing_time) > 0:\n",
    "    print(\"\\nSample missing time:\", missing_time[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "126d195c-fffe-4b62-8e26-faaacff4dd7b",
   "metadata": {
    "scrolled": true,
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "output_dir = \"./output_embeddings\"\n",
    "\n",
    "# Directory entries in sorted order.\n",
    "all_files = os.listdir(output_dir)\n",
    "all_files.sort()\n",
    "\n",
    "print(f\"Total files in {output_dir}: {len(all_files)}\")\n",
    "# One entry per line; nothing is printed for an empty directory.\n",
    "if all_files:\n",
    "    print(\"\\n\".join(all_files))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
