{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b06ae58-10a2-4537-ab4c-8f3abc1bf933",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Exhaustive set of imports\n",
    "\n",
    "from collections import Counter\n",
    "from collections import defaultdict\n",
    "from glob import glob\n",
    "from matplotlib import rcParams\n",
    "from matplotlib import rcParamsDefault\n",
    "from matplotlib.colors import ListedColormap\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from pathlib import Path\n",
    "from scipy.interpolate import CubicSpline\n",
    "from scipy.ndimage import affine_transform\n",
    "from scipy.ndimage import binary_erosion\n",
    "from scipy.ndimage import zoom\n",
    "from scipy.optimize import minimize\n",
    "from scipy.spatial import cKDTree\n",
    "from skimage import measure\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.ensemble import RandomForestClassifier, StackingClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.manifold import MDS\n",
    "from sklearn.metrics import accuracy_score, confusion_matrix\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.neighbors import KNeighborsClassifier\n",
    "from sklearn.preprocessing import MinMaxScaler\n",
    "from sklearn.preprocessing import StandardScaler, LabelEncoder\n",
    "from sklearn.svm import SVC\n",
    "from torch.optim.lr_scheduler import ReduceLROnPlateau\n",
    "from torch.utils.data import DataLoader, random_split, Dataset\n",
    "from tqdm import tqdm\n",
    "from xgboost import XGBClassifier\n",
    "import json\n",
    "import matplotlib.cm as cm\n",
    "import matplotlib.gridspec as gridspec\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import os\n",
    "import pandas as pd\n",
    "import plotly.graph_objects as go\n",
    "import random\n",
    "import scipy as sp\n",
    "import seaborn as sns\n",
    "import shutil\n",
    "import sys\n",
    "import tifffile\n",
    "import time\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import umap\n",
    "import umap.umap_ as umap\n",
    "import umap.umap_ as umap_module"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c03898e1-044e-4c5a-93ae-355a5778861f",
   "metadata": {},
   "source": [
    "<p style=\"height: 100px;\"></p>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "74982941-10e8-4aba-897d-3b0daa403fbd",
   "metadata": {},
   "source": [
    "## Pre-Processing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "011ff744-7f7b-4b42-8821-350899407625",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Must run once to create and save required data / metadata\n",
    "#\n",
    "# Walks `base_folder` for two-channel '.ome.tif' segmentations, resizes both\n",
    "# channels to TARGET_SHAPE, binarizes them, saves one '{CellId}.npy' mask per\n",
    "# cell and writes one metadata record per cell to `metadata_save_path`.\n",
    "\n",
    "\n",
    "TARGET_SHAPE = (64, 64, 64)\n",
    "\n",
    "base_folder = 'source_data/crop_seg'\n",
    "metadata_path = 'source_data/metadata.csv'\n",
    "mask_save_folder = 'result_data/masks_64'\n",
    "metadata_save_path = 'result_data/cell_data_64.json'\n",
    "\n",
    "\n",
    "def _to_native(value):\n",
    "    \"\"\"Turn NumPy scalars into built-in Python scalars (json.dump rejects e.g. np.int64).\"\"\"\n",
    "    return value.item() if isinstance(value, np.generic) else value\n",
    "\n",
    "\n",
    "os.makedirs(mask_save_folder, exist_ok=True)\n",
    "df = pd.read_csv(metadata_path)\n",
    "\n",
    "cell_data = []\n",
    "n_processed = 0\n",
    "\n",
    "for root, dirs, files in os.walk(base_folder):\n",
    "    for file in files:\n",
    "        if not file.endswith('.ome.tif'):\n",
    "            continue\n",
    "\n",
    "        image_path = os.path.join(root, file)\n",
    "\n",
    "        try:\n",
    "            im = tifffile.imread(image_path)\n",
    "\n",
    "            if im.ndim != 4:\n",
    "                print(f\"Skipping {file}: expected a 4D array, got shape {im.shape}\")\n",
    "                continue\n",
    "\n",
    "            # Handle possible channel axis arrangement\n",
    "            if im.shape[0] > 2:\n",
    "                im = im.swapaxes(1, 0)\n",
    "\n",
    "            if im.shape[0] != 2:\n",
    "                print(f\"Skipping {file}: wrong number of channels ({im.shape[0]})\")\n",
    "                continue\n",
    "\n",
    "            # Look the file up in the metadata first so unmatched files are skipped\n",
    "            # before any resizing. The lookup key is built with '/' separators, so\n",
    "            # OS-specific separators are normalized.\n",
    "            relative_path = os.path.relpath(image_path, base_folder).replace(os.sep, '/')\n",
    "            seg_file_name = 'crop_seg/' + relative_path\n",
    "\n",
    "            matches = df[df['crop_seg'] == seg_file_name]\n",
    "            if matches.empty:\n",
    "                print(f\"Warning: No metadata entry found for {seg_file_name}\")\n",
    "                continue\n",
    "\n",
    "            row = matches.iloc[0]\n",
    "            cell_id = _to_native(row['CellId'])\n",
    "            label = _to_native(row['label'])\n",
    "\n",
    "            nucleus_shape = im[0].shape  # (z, y, x)\n",
    "            cell_shape = im[1].shape     # (z, y, x)\n",
    "\n",
    "            mask_path = os.path.join(mask_save_folder, f'{cell_id}.npy')\n",
    "            if not os.path.exists(mask_path):\n",
    "                # Resize only when the mask still has to be written (existing masks are kept)\n",
    "                resized = np.zeros((2, *TARGET_SHAPE), dtype=np.float32)\n",
    "                for ch in range(2):\n",
    "                    factors = [target / size for target, size in zip(TARGET_SHAPE, im[ch].shape)]\n",
    "                    resized[ch] = zoom(im[ch], factors, order=1)\n",
    "\n",
    "                binary_mask = (resized > 0).astype(np.float32)\n",
    "                np.save(mask_path, binary_mask)\n",
    "\n",
    "            cell_entry = {\n",
    "                'id': cell_id,\n",
    "                'label': label,\n",
    "                'seg_file': seg_file_name,\n",
    "                'mask_path': mask_path,\n",
    "                'size_nucleus_1': int(nucleus_shape[0]),\n",
    "                'size_nucleus_2': int(nucleus_shape[1]),\n",
    "                'size_nucleus_3': int(nucleus_shape[2]),\n",
    "                'size_cell_1': int(cell_shape[0]),\n",
    "                'size_cell_2': int(cell_shape[1]),\n",
    "                'size_cell_3': int(cell_shape[2]),\n",
    "            }\n",
    "\n",
    "            cell_data.append(cell_entry)\n",
    "            n_processed += 1\n",
    "            if n_processed % 100 == 0:\n",
    "                print(f\"Processed {n_processed} images\")\n",
    "\n",
    "        except Exception as e:\n",
    "            # Best effort: report the file and carry on with the remaining images\n",
    "            print(f\"Error processing {file}: {e}\")\n",
    "\n",
    "with open(metadata_save_path, 'w') as f:\n",
    "    json.dump(cell_data, f, indent=2)\n",
    "\n",
    "print(f\"Finished processing {n_processed} images. Metadata saved to: {metadata_save_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08eba8f3-c3ef-4366-a1c7-24c4e9c1d36c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# ICP-based rigid alignment of 3D binary masks (align on cell body, apply same transform to nucleus)\n",
    "# - Uses scipy.spatial.cKDTree and scipy.ndimage.affine_transform\n",
    "# - Input npy files should be (2, Z, Y, X) with binary-like masks\n",
    "# - Choose reference_file or it will use the first file in the folder as template\n",
    "\n",
    "\n",
    "# Folder holding the masks to align, and the folder the aligned copies are written to\n",
    "input_dir = Path(\"result_data/masks_64\")\n",
    "output_dir = Path(\"result_data/masks_64_rotated\")\n",
    "output_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# File name (inside input_dir) of the alignment target; None selects the first file in sorted order\n",
    "reference_file = None   # eg \"example.npy\"\n",
    "# Upper bound on the number of points drawn from each point cloud for ICP\n",
    "max_icp_points = 5000  \n",
    "# Maximum number of ICP iterations per mask\n",
    "icp_max_iter = 50\n",
    "# ICP stops once the mean nearest-neighbour distance changes by less than this between iterations\n",
    "icp_tol = 1e-4\n",
    "binarize_output = True  # if True, threshold at threshold_val before saving to keep binary masks\n",
    "# Cut-off applied to the interpolated volumes when binarize_output is True\n",
    "threshold_val = 0.5\n",
    "\n",
    "\n",
    "def kabsch(src: np.ndarray, tgt: np.ndarray):\n",
    "    \"\"\"Compute rotation R and translation t using the Kabsch algorithm.\"\"\"\n",
    "    assert src.shape == tgt.shape and src.shape[1] == 3\n",
    "    centroid_src = np.mean(src, axis=0)\n",
    "    centroid_tgt = np.mean(tgt, axis=0)\n",
    "    src_centered = src - centroid_src\n",
    "    tgt_centered = tgt - centroid_tgt\n",
    "    H = src_centered.T @ tgt_centered\n",
    "    U, S, Vt = np.linalg.svd(H)\n",
    "    R = Vt.T @ U.T\n",
    "    \n",
    "    if np.linalg.det(R) < 0:\n",
    "        Vt[-1, :] *= -1\n",
    "        R = Vt.T @ U.T\n",
    "    t = centroid_tgt - R @ centroid_src\n",
    "    return R, t\n",
    "\n",
    "    \n",
    "def icp_rigid(src_pts, tgt_pts, max_iterations=50, tol=1e-5, max_points=None, random_state=None):\n",
    "    \"\"\"Simple rigid ICP: returns R, t such that approximately tgt ≈ R @ src + t.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    src_pts, tgt_pts : (N, 3) and (M, 3) arrays of point coordinates.\n",
    "    max_iterations : maximum number of match-and-fit rounds.\n",
    "    tol : stop once the mean nearest-neighbour distance changes by less than\n",
    "        this between two consecutive rounds.\n",
    "    max_points : if given, at most this many source points (drawn without\n",
    "        replacement) are used for the fit.\n",
    "    random_state : optional seed for that subsampling; None draws from NumPy's\n",
    "        global random state.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    (R_total, t_total, final_error) where final_error is the mean\n",
    "    nearest-neighbour distance of the (sub-sampled) source points after the\n",
    "    returned transform has been applied.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError if either point set has fewer than 3 points.\n",
    "    \"\"\"\n",
    "    if src_pts.shape[0] < 3 or tgt_pts.shape[0] < 3:\n",
    "        raise ValueError(\"Not enough points for ICP.\")\n",
    "\n",
    "    if max_points is not None and src_pts.shape[0] > max_points:\n",
    "        sampler = np.random if random_state is None else np.random.default_rng(random_state)\n",
    "        idx = sampler.choice(src_pts.shape[0], size=max_points, replace=False)\n",
    "        src_sub = src_pts[idx].astype(np.float64)\n",
    "    else:\n",
    "        src_sub = src_pts.astype(np.float64)\n",
    "\n",
    "    tgt = tgt_pts.astype(np.float64)\n",
    "    tgt_kdtree = cKDTree(tgt)\n",
    "\n",
    "    R_total = np.eye(3, dtype=np.float64)\n",
    "    t_total = np.zeros(3, dtype=np.float64)\n",
    "\n",
    "    src_trans = src_sub.copy()\n",
    "    prev_error = np.inf\n",
    "\n",
    "    for _ in range(max_iterations):\n",
    "        dists, idxs = tgt_kdtree.query(src_trans, k=1)\n",
    "        matched_tgt = tgt[idxs]\n",
    "\n",
    "        # Compose the incremental fit with the transform accumulated so far\n",
    "        R_i, t_i = kabsch(src_trans, matched_tgt)\n",
    "        R_total = R_i @ R_total\n",
    "        t_total = (R_i @ t_total) + t_i\n",
    "        src_trans = (R_total @ src_sub.T).T + t_total\n",
    "\n",
    "        mean_error = np.mean(dists)\n",
    "        if np.abs(prev_error - mean_error) < tol:\n",
    "            break\n",
    "        prev_error = mean_error\n",
    "\n",
    "    # Report the residual of the transform that is actually returned; the value\n",
    "    # tracked inside the loop is measured before each update and lags behind.\n",
    "    final_dists, _ = tgt_kdtree.query(src_trans, k=1)\n",
    "    return R_total, t_total, float(np.mean(final_dists))\n",
    "\n",
    "    \n",
    "def mask_to_points(mask, min_pts=10):\n",
    "    \"\"\"Convert 3D mask to (N,3) array of coordinates of non-zero voxels\"\"\"\n",
    "    coords = np.argwhere(mask > 0.5)\n",
    "    if coords.shape[0] < min_pts:\n",
    "        return None\n",
    "    return coords.astype(np.float64)\n",
    "\n",
    "\n",
    "def apply_rigid_to_volume(vol, R, t, order=1, cval=0.0):\n",
    "    \"\"\"Apply rigid transform\"\"\"\n",
    "    invR = np.linalg.inv(R)\n",
    "    offset = -invR @ t\n",
    "    transformed = affine_transform(vol, invR, offset=offset, order=order, mode='constant', cval=cval)\n",
    "    return transformed\n",
    "\n",
    "\n",
    "# Choose reference template (target)\n",
    "icp_seed = 42  # fixed seed so the random point subsampling used for ICP is repeatable\n",
    "np.random.seed(icp_seed)\n",
    "\n",
    "files = sorted(input_dir.glob(\"*.npy\"))\n",
    "if not files:\n",
    "    raise RuntimeError(f\"No .npy files found in {input_dir}\")\n",
    "\n",
    "if reference_file is None:\n",
    "    ref_path = files[0]\n",
    "else:\n",
    "    ref_path = input_dir / reference_file\n",
    "    if not ref_path.exists():\n",
    "        raise FileNotFoundError(f\"Reference file {ref_path} not found in {input_dir}\")\n",
    "\n",
    "print(f\"Reference (target) file: {ref_path}\")\n",
    "\n",
    "ref_arr = np.load(ref_path)\n",
    "if ref_arr.ndim != 4 or ref_arr.shape[0] < 2:\n",
    "    raise ValueError(f\"Reference array has unexpected shape: {ref_arr.shape}\")\n",
    "\n",
    "ref_cell = ref_arr[1].astype(np.float32)  # Use cell body (channel 1) of reference\n",
    "ref_pts = mask_to_points(ref_cell, min_pts=10)\n",
    "if ref_pts is None:\n",
    "    raise ValueError(\"Reference cell body has too few points for ICP.\")\n",
    "\n",
    "# Cap the size of the target cloud that every mask is matched against\n",
    "if ref_pts.shape[0] > max_icp_points:\n",
    "    sel = np.random.choice(ref_pts.shape[0], max_icp_points, replace=False)\n",
    "    ref_pts_sub = ref_pts[sel]\n",
    "else:\n",
    "    ref_pts_sub = ref_pts\n",
    "\n",
    "\n",
    "errors = []\n",
    "for file_path in tqdm(files, desc=\"Aligning (ICP)\"):\n",
    "    try:\n",
    "        arr = np.load(file_path)\n",
    "        if arr.ndim != 4 or arr.shape[0] < 2:\n",
    "            raise ValueError(f\"Unexpected array shape {arr.shape} in {file_path.name}\")\n",
    "\n",
    "        nucleus = arr[0].astype(np.float32)\n",
    "        cell_body = arr[1].astype(np.float32)\n",
    "\n",
    "        src_pts = mask_to_points(cell_body, min_pts=10)\n",
    "        if src_pts is None:\n",
    "            raise ValueError(\"Too few cell-body voxels for ICP; skipping.\")\n",
    "\n",
    "        # icp_rigid draws its own subsample of the source points (max_points),\n",
    "        # so the full point set is handed over without pre-sampling it here.\n",
    "        R, t, final_err = icp_rigid(src_pts, ref_pts_sub,\n",
    "                                    max_iterations=icp_max_iter,\n",
    "                                    tol=icp_tol,\n",
    "                                    max_points=max_icp_points)\n",
    "\n",
    "        # The transform is estimated on the cell body and applied to both channels\n",
    "        transformed_cell = apply_rigid_to_volume(cell_body, R, t, order=1, cval=0.0)\n",
    "        transformed_nucleus = apply_rigid_to_volume(nucleus, R, t, order=1, cval=0.0)\n",
    "\n",
    "        if binarize_output:\n",
    "            transformed_cell = (transformed_cell >= threshold_val).astype(np.uint8)\n",
    "            transformed_nucleus = (transformed_nucleus >= threshold_val).astype(np.uint8)\n",
    "\n",
    "        # Keep the channel order of the input: nucleus first, cell body second\n",
    "        aligned = np.stack([transformed_nucleus, transformed_cell], axis=0)\n",
    "\n",
    "        out_path = output_dir / file_path.name\n",
    "        np.save(out_path, aligned)\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: remember the failure and continue with the next mask\n",
    "        errors.append((file_path.name, str(e)))\n",
    "        print(f\"[WARN] {file_path.name}: {e}\")\n",
    "\n",
    "\n",
    "print(f\"Done. Saved aligned files to: {output_dir}\")\n",
    "if errors:\n",
    "    print(f\"Completed with {len(errors)} warnings/errors. Example: {errors[:5]}\")\n",
    "else:\n",
    "    print(\"No errors.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "19f5101b-c635-442c-a7c6-fd6bed47939f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Ensure all voxel values at 0 or 1 after rotation based on threshold\n",
    "# NOTE(review): `dirr` currently points at the un-rotated masks; switch to the\n",
    "# commented path below to post-process the rotated masks instead.\n",
    "\n",
    "\n",
    "# dirr = \"result_data/masks_64_rotated\"\n",
    "dirr = \"result_data/masks_64\"\n",
    "\n",
    "threshold = 0.2\n",
    "\n",
    "npy_files = sorted(f for f in os.listdir(dirr) if f.endswith(\".npy\"))\n",
    "\n",
    "print(f\"Thresholding {len(npy_files)} files...\")\n",
    "\n",
    "n_rewritten = 0\n",
    "n_failed = 0\n",
    "\n",
    "# A progress bar replaces one printed line per file\n",
    "for filename in tqdm(npy_files, desc=\"Thresholding\"):\n",
    "    file_path = os.path.join(dirr, filename)\n",
    "\n",
    "    try:\n",
    "        arr = np.load(file_path)\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading {filename}: {e}\")\n",
    "        n_failed += 1\n",
    "        continue\n",
    "\n",
    "    thresholded_arr = (arr >= threshold).astype(np.uint8)\n",
    "\n",
    "    # Files are overwritten in place, so skip the write when nothing would change\n",
    "    if arr.dtype == np.uint8 and np.array_equal(arr, thresholded_arr):\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        np.save(file_path, thresholded_arr)\n",
    "        n_rewritten += 1\n",
    "    except Exception as e:\n",
    "        print(f\"Error saving {filename}: {e}\")\n",
    "        n_failed += 1\n",
    "\n",
    "print(f\"Done. Rewrote {n_rewritten} of {len(npy_files)} files ({n_failed} failed).\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e577a79c-4e73-488d-a861-7001c6b81c70",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display objects before and after rotation for sanity check\n",
    "\n",
    "\n",
    "original_dir = \"result_data/masks_64\"\n",
    "rotated_dir = \"result_data/masks_64_rotated\"\n",
    "\n",
    "# Only masks present in both folders can be compared; a mask that could not be\n",
    "# aligned has no counterpart in rotated_dir.\n",
    "all_files = sorted(\n",
    "    f for f in os.listdir(original_dir)\n",
    "    if f.endswith('.npy') and os.path.isfile(os.path.join(rotated_dir, f))\n",
    ")\n",
    "if not all_files:\n",
    "    raise RuntimeError(f\"No .npy files present in both {original_dir} and {rotated_dir}\")\n",
    "sample_files = random.sample(all_files, 1)\n",
    "\n",
    "\n",
    "def plot_3d_volume(vol, title):\n",
    "    \"\"\"Show a 3D array (unpacked as z, y, x) as a semi-transparent Plotly volume titled `title`.\"\"\"\n",
    "    depth, height, width = vol.shape\n",
    "    grid_x, grid_y, grid_z = np.mgrid[0:width, 0:height, 0:depth]\n",
    "\n",
    "    # Reorder the axes to (x, y, z) so the flattened values line up with the grid\n",
    "    vol_xyz = np.transpose(vol, (2, 1, 0))\n",
    "\n",
    "    volume_trace = go.Volume(\n",
    "        x=grid_x.flatten(),\n",
    "        y=grid_y.flatten(),\n",
    "        z=grid_z.flatten(),\n",
    "        value=vol_xyz.flatten(),\n",
    "        isomin=0.5,\n",
    "        isomax=1.0,\n",
    "        opacity=0.1,\n",
    "        surface_count=15,\n",
    "        colorscale='Viridis'\n",
    "    )\n",
    "    fig = go.Figure(data=volume_trace)\n",
    "\n",
    "    # Axis titles only; tick labels are hidden on all three axes\n",
    "    scene_layout = dict(\n",
    "        xaxis_title='X', yaxis_title='Y', zaxis_title='Z',\n",
    "        xaxis=dict(showticklabels=False),\n",
    "        yaxis=dict(showticklabels=False),\n",
    "        zaxis=dict(showticklabels=False),\n",
    "    )\n",
    "    fig.update_layout(scene=scene_layout, title=title, width=500, height=500)\n",
    "    fig.show()\n",
    "\n",
    "\n",
    "for file in sample_files:\n",
    "    print(f\"Processing file: {file}\")\n",
    "\n",
    "    # Load the same file name from the original and the rotated folder\n",
    "    original, rotated = (\n",
    "        np.load(os.path.join(folder, file)) for folder in (original_dir, rotated_dir)\n",
    "    )\n",
    "\n",
    "    print(f\"Original shape: {original.shape}\")\n",
    "    print(f\"Rotated shape: {rotated.shape}\")\n",
    "\n",
    "    # Channel 0 = nucleus, channel 1 = cell body\n",
    "    original_nucleus, rotated_nucleus = original[0], rotated[0]\n",
    "\n",
    "    for caption, volume in ((\"Original\", original_nucleus), (\"Rotated\", rotated_nucleus)):\n",
    "        plot_3d_volume(volume, f\"{caption} - {file}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35471208-3994-4fa8-a80d-acc70c303cec",
   "metadata": {},
   "source": [
    "<p style=\"height: 200px;\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f2dbee42-a9a2-4359-ae67-c1bdc88ef0a6",
   "metadata": {},
   "source": [
    "## Quantile-Embedding Creation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "932436cd-23f5-4d4a-a24a-1565245a6687",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Returns coordinates of surface voxels\n",
    "\n",
    "\n",
    "def extract_surface_points(volume):\n",
    "    \"\"\"Extracts surface points from a 3D volume.\"\"\"\n",
    "    eroded = binary_erosion(volume)\n",
    "    surface = volume & ~eroded  # Keep only the boundary points\n",
    "    surface_points = np.argwhere(surface)\n",
    "    return surface_points"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d4003b4-14fd-4d94-9737-86c26d7e8fb4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Populates data and data_cell with extracted surface points\n",
    "\n",
    "\n",
    "df = pd.read_csv('source_data/metadata.csv')\n",
    "\n",
    "np.random.seed(42)\n",
    "cell_types = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half', 'blob', 'dead', 'wrong']\n",
    "\n",
    "import open3d as o3d  # NOTE(review): not used in this cell; kept in case other cells rely on `o3d`\n",
    "\n",
    "directory = 'source_data/crop_seg'\n",
    "\n",
    "# Sub-sampling strides along the three spatial axes; coarser grids speed up\n",
    "# the Wasserstein distance computation\n",
    "nucleus_stride = (4, 10, 10)\n",
    "cell_stride = (8, 20, 20)\n",
    "\n",
    "# Result buffers start with N slots and are extended whenever more images are found\n",
    "N = 5767\n",
    "data = [0] * N        # nucleus surface points, one entry per image\n",
    "data_full = [0] * N   # sub-sampled two-channel image, one entry per image\n",
    "data_cell = [0] * N   # cell-body surface points, one entry per image\n",
    "labels = [0] * N\n",
    "file_names = [0] * N\n",
    "counter = 0\n",
    "\n",
    "\n",
    "for root, dirs, files in os.walk(directory):\n",
    "    for file in files:\n",
    "        if not file.endswith('.ome.tif'):\n",
    "            continue\n",
    "\n",
    "        image_path = os.path.join(root, file)\n",
    "        im = tifffile.imread(image_path)\n",
    "        if im.shape[0] > 2:\n",
    "            im = im.swapaxes(1, 0)\n",
    "\n",
    "        # Validate the raw image before any sub-sampling, so images that are not\n",
    "        # two-channel are skipped just as the pre-processing cell skips them.\n",
    "        if im.ndim != 4 or im.shape[0] != 2:\n",
    "            print(f\"Skipping {file}: unexpected image shape {im.shape}\")\n",
    "            continue\n",
    "\n",
    "        name = 'crop_seg/' + os.path.relpath(image_path, directory).replace(os.sep, '/')\n",
    "        matches = np.flatnonzero((df['crop_seg'] == name).to_numpy())\n",
    "        if matches.size == 0:\n",
    "            print(f\"Warning: No metadata entry found for {name}\")\n",
    "            continue\n",
    "\n",
    "        if counter >= len(labels):\n",
    "            for buffer in (data, data_full, data_cell, labels, file_names):\n",
    "                buffer.extend([0] * N)\n",
    "\n",
    "        nz, ny, nx = nucleus_stride\n",
    "        cz, cy, cx = cell_stride\n",
    "        im_cell = im[1, ::cz, ::cy, ::cx] > 0\n",
    "        # .copy() detaches the small sub-sampled array from the full-resolution image\n",
    "        im = im[:, ::nz, ::ny, ::nx].copy()\n",
    "\n",
    "        data_full[counter] = im\n",
    "        data[counter] = extract_surface_points(im[0] > 0)\n",
    "        data_cell[counter] = extract_surface_points(im_cell)\n",
    "        labels[counter] = df['label'].iloc[matches[0]]\n",
    "        file_names[counter] = image_path\n",
    "\n",
    "        counter += 1\n",
    "        if counter % 500 == 0:\n",
    "            print(f\"Processed {counter} images\")\n",
    "\n",
    "print(f\"Finished: {counter} images loaded from {directory}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bdeb3236-93ab-4358-b1c9-fdb55bd60d3c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sanity check\n",
    "# Converts every point cloud to a float array (needed by the distance\n",
    "# computation) and prints a size summary instead of one line per shape.\n",
    "\n",
    "\n",
    "N_shapes = counter\n",
    "for i in range(N_shapes):\n",
    "    data_cell[i] = np.array(data_cell[i], dtype=float)\n",
    "    data[i] = np.array(data[i], dtype=float)\n",
    "\n",
    "for part_name, clouds in ((\"cell\", data_cell), (\"nucleus\", data)):\n",
    "    n_points = np.array([clouds[i].shape[0] for i in range(N_shapes)], dtype=int)\n",
    "    if n_points.size == 0:\n",
    "        print(f\"{part_name}: no shapes loaded\")\n",
    "        continue\n",
    "    # Empty clouds matter: they produce infinite quantile vectors in the embedding\n",
    "    print(f\"{part_name}: {n_points.size} shapes, surface points per shape \"\n",
    "          f\"min/median/max = {n_points.min()}/{int(np.median(n_points))}/{n_points.max()}, \"\n",
    "          f\"empty shapes: {int((n_points == 0).sum())}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e64c4ae9-6325-4b14-ab28-77ce20be97da",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Computes quantiles of Wasserstein distances\n",
    "\n",
    "\n",
    "def compute_WassKernel(data,n_pt=10000, metric='Euclidean',normalize=False, return_distance=False):\n",
    "    n = len(data)\n",
    "    quantiles_all = [0]*len(data)\n",
    "    print(n_pt)\n",
    "    for i in range(n):\n",
    "        if data[i].shape[0] == 0:\n",
    "            quantiles_all[i] = np.inf + np.zeros( (n_pt))\n",
    "        else:\n",
    "            \n",
    "            C1 = sp.spatial.distance.cdist(data[i], data[i])\n",
    "            \n",
    "            if normalize:\n",
    "                C1 = C1/np.median(C1)\n",
    "                \n",
    "            quantiles_all[i] = np.quantile(C1.ravel(), np.linspace(0,1,n_pt,endpoint=True))\n",
    "            \n",
    "    if return_distance == False:       \n",
    "        return sp.spatial.distance.squareform(sp.spatial.distance.pdist(quantiles_all))/n_pt\n",
    "    else:\n",
    "        return sp.spatial.distance.squareform(sp.spatial.distance.pdist(quantiles_all))/n_pt, quantiles_all"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0401969e-a4aa-49d4-9900-b3987c373c82",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute the quantile embedding on the entire dataset and time each part separately\n",
    "\n",
    "\n",
    "start = time.time()\n",
    "K_W_cell, embed_W_cell = compute_WassKernel(data_cell[0:N_shapes], n_pt=100, metric='Euclidean', normalize=False, return_distance=True)\n",
    "print(f\"Cell-body embedding: {time.time() - start:.1f} s\")\n",
    "\n",
    "# Restart the clock so the second figure is not cumulative\n",
    "start = time.time()\n",
    "K_W, embed_W = compute_WassKernel(data[0:N_shapes], n_pt=100, metric='Euclidean', normalize=False, return_distance=True)\n",
    "print(f\"Nucleus embedding: {time.time() - start:.1f} s\")\n",
    "\n",
    "# Optional: persist the results (left disabled)\n",
    "# np.save('result_data/All_3DShape_W.npy',K_W_cell)\n",
    "# np.save('result_data/All_3DShape_W_embed.npy',embed_W_cell)\n",
    "\n",
    "# np.save('result_data/All_3DShape_Nucleus_W.npy',K_W)\n",
    "# np.save('result_data/All_3DShape_Nucleus_W_embed.npy',embed_W)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ead536f8-77d2-40bd-a959-0b69aa3e9306",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize UMAP of embeddings across labels\n",
    "\n",
    "\n",
    "out_path = \"result_data/figures\"\n",
    "os.makedirs(out_path, exist_ok=True)\n",
    "\n",
    "class_names = cell_types\n",
    "labels = labels[0:N_shapes]\n",
    "\n",
    "plt.rcParams[\"axes.prop_cycle\"] = rcParamsDefault[\"axes.prop_cycle\"]\n",
    "\n",
    "outlier_markers = ['x', '^', 's']\n",
    "viridis = plt.get_cmap('viridis', 6)\n",
    "continuous_colors = [viridis(i) for i in range(6)]\n",
    "\n",
    "outlier_colors = ['#e41a1c', '#000000', '#f0027f']  # red, black, magenta\n",
    "color_list = continuous_colors + outlier_colors\n",
    "\n",
    "\n",
    "def _scatter_by_class(ax, coords, title, point_labels, names):\n",
    "    \"\"\"Draw one scatter series per class name on `ax`.\n",
    "\n",
    "    coords: (n, 2) array of 2-D coordinates; point_labels: n labels, one per row.\n",
    "    The last len(outlier_markers) entries of `names` get their own marker shapes.\n",
    "    \"\"\"\n",
    "    point_labels = np.asarray(point_labels)\n",
    "    n_regular = len(names) - len(outlier_markers)\n",
    "    for i, name in enumerate(names):\n",
    "        members = np.flatnonzero(point_labels == name)\n",
    "        extra = {} if i < n_regular else {'marker': outlier_markers[i - n_regular]}\n",
    "        ax.scatter(coords[members, 0], coords[members, 1], s=30, label=name,\n",
    "                   color=color_list[i], **extra)\n",
    "    ax.set_title(title)\n",
    "    ax.set_xlabel(\"UMAP 1\")\n",
    "    ax.set_ylabel(\"UMAP 2\")\n",
    "\n",
    "\n",
    "fig = plt.figure(figsize=(18, 5))\n",
    "\n",
    "# Combined chart: nucleus and cell quantiles side by side\n",
    "fit = umap.UMAP(n_components=2, random_state=42)\n",
    "mapper_w = fit.fit_transform(np.hstack([embed_W, embed_W_cell]))\n",
    "ax0 = fig.add_subplot(131)\n",
    "_scatter_by_class(ax0, mapper_w, 'Combined Wasserstein', labels, class_names)\n",
    "\n",
    "# Nucleus Only Chart\n",
    "fit = umap.UMAP(n_components=2, random_state=52)\n",
    "mapper_ws = fit.fit_transform(embed_W)\n",
    "ax1 = fig.add_subplot(132)\n",
    "_scatter_by_class(ax1, mapper_ws, 'Nucleus only', labels, class_names)\n",
    "\n",
    "# Cell Only Chart (re-fits the estimator created for the nucleus chart)\n",
    "mapper_cell = fit.fit_transform(embed_W_cell)\n",
    "ax2 = fig.add_subplot(133)\n",
    "_scatter_by_class(ax2, mapper_cell, 'Cell only', labels, class_names)\n",
    "\n",
    "# One shared legend: keep the first handle seen for every class label\n",
    "unique = dict()\n",
    "for ax in [ax0, ax1, ax2]:\n",
    "    for handle, text in zip(*ax.get_legend_handles_labels()):\n",
    "        unique.setdefault(text, handle)\n",
    "\n",
    "fig.legend(unique.values(), unique.keys(), loc='center left', bbox_to_anchor=(1.01, 0.5), borderaxespad=0.)\n",
    "plt.subplots_adjust(right=0.97)\n",
    "plt.savefig(os.path.join(out_path, \"WASS_UMAP.png\"), format=\"png\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "15113398-ad80-43ad-a491-aeae3e8f7a1d",
   "metadata": {},
   "source": [
    "<p style=\"height: 200px;\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9836fe3a-1933-4229-9c9a-fd9eeecab4fc",
   "metadata": {},
   "source": [
    "## Quantile-Embeddings for Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f6ce6b-61ac-4ace-a045-744089bcfe68",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Predict label probabilities based on embedding using basic ML classifier\n",
    "# - predict every cell label using cross-training, label as \"bad_prediction\" if confidently predicts incorrectly\n",
    "# - can be used to filter out potential mislabels, or to assist expansion of labels to new unlabelled data\n",
    "\n",
    "\n",
    "# Features: nucleus quantiles, cell quantiles and their absolute difference, standardized per column\n",
    "nucleus_quantiles = np.array(embed_W)\n",
    "cell_quantiles = np.array(embed_W_cell)\n",
    "X_diff = np.abs(nucleus_quantiles - cell_quantiles)\n",
    "scaler = StandardScaler()\n",
    "X = scaler.fit_transform(np.hstack([embed_W, embed_W_cell, X_diff]))\n",
    "\n",
    "# String labels and their integer encoding (class_names[k] is the label encoded as k)\n",
    "y = np.array(labels)\n",
    "le = LabelEncoder()\n",
    "y_int = le.fit_transform(y)\n",
    "class_names = le.classes_.tolist()\n",
    "\n",
    "# NOTE(review): this rebinds `data`, which earlier cells use for the nucleus surface points\n",
    "json_path = \"result_data/cell_data_64.json\"\n",
    "with open(json_path, \"r\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "id_to_entry = dict((entry[\"id\"], entry) for entry in data)\n",
    "\n",
    "# Define classifier\n",
    "def build_classifier(seed):\n",
    "    \"\"\"Build the stacked ensemble used for label prediction.\n",
    "\n",
    "    Base learners: random forest, k-nearest neighbours, RBF SVM and XGBoost; a\n",
    "    random forest combines their outputs (3-fold internal cross-validation).\n",
    "    `seed` is passed to every component that accepts a random state.\n",
    "    \"\"\"\n",
    "    base_learners = [\n",
    "        ('rf', RandomForestClassifier(n_estimators=50, random_state=seed)),\n",
    "        ('knn', KNeighborsClassifier(n_neighbors=15, weights='distance')),\n",
    "        # probability=True makes SVC fit an internal, randomized calibration step, so it is seeded too\n",
    "        ('svm', SVC(kernel='rbf', gamma='scale', C=50, probability=True, random_state=seed)),\n",
    "        ('xgb', XGBClassifier(random_state=seed, eval_metric='mlogloss', n_estimators=50, subsample=0.8))\n",
    "    ]\n",
    "    meta_learner = RandomForestClassifier(n_estimators=50, random_state=seed)\n",
    "    return StackingClassifier(estimators=base_learners, final_estimator=meta_learner, cv=3)\n",
    "\n",
    "# 50/50 split (train/test)\n",
    "\n",
    "# A prediction is flagged as bad when the winning class differs from the true class\n",
    "# and is at least this many times more probable than the true class\n",
    "BAD_PREDICTION_RATIO = 5\n",
    "\n",
    "# Predictions are written back to the JSON entries by row position, so the entries\n",
    "# must line up one-to-one with the rows of X; fail loudly instead of mislabelling cells.\n",
    "if len(data) != len(X):\n",
    "    raise RuntimeError(f\"JSON contains {len(data)} entries but X has {len(X)} rows; \"\n",
    "                       \"cannot match predictions to cells.\")\n",
    "misaligned = [row for row, entry in enumerate(data) if str(entry.get(\"label\")) != str(y[row])]\n",
    "if misaligned:\n",
    "    raise RuntimeError(f\"{len(misaligned)} JSON entries disagree with the labels \"\n",
    "                       f\"(first at row {misaligned[0]}); JSON order does not match the embedding order.\")\n",
    "\n",
    "np.random.seed(42)\n",
    "indices = np.arange(len(X))\n",
    "np.random.shuffle(indices)\n",
    "mid = len(indices) // 2\n",
    "split1, split2 = indices[:mid], indices[mid:]\n",
    "\n",
    "predictions = {}\n",
    "\n",
    "# Train on one half, predict the other, then swap, so every cell gets an out-of-sample prediction\n",
    "for train_idx, test_idx in [(split1, split2), (split2, split1)]:\n",
    "    clf = build_classifier(seed=0)\n",
    "    clf.fit(X[train_idx], y_int[train_idx])\n",
    "    probs = clf.predict_proba(X[test_idx])\n",
    "\n",
    "    # Columns of `probs` follow clf.classes_, which can lack a class that is\n",
    "    # absent from this training half; translate columns to encoded class ids.\n",
    "    column_of_class = {int(cls): col for col, cls in enumerate(clf.classes_)}\n",
    "    best_columns = np.argmax(probs, axis=1)\n",
    "    preds = clf.classes_[best_columns]\n",
    "\n",
    "    for i, idx in enumerate(test_idx):\n",
    "        prob_vector = probs[i]\n",
    "        pred_class_idx = int(preds[i])\n",
    "        true_class_idx = int(y_int[idx])\n",
    "\n",
    "        highest_prob = float(prob_vector[best_columns[i]])\n",
    "        highest_label = class_names[pred_class_idx]\n",
    "\n",
    "        true_column = column_of_class.get(true_class_idx)\n",
    "        true_prob = float(prob_vector[true_column]) if true_column is not None else 0.0\n",
    "        true_label = class_names[true_class_idx]\n",
    "\n",
    "        bad = bool((pred_class_idx != true_class_idx)\n",
    "                   and (highest_prob >= BAD_PREDICTION_RATIO * true_prob))\n",
    "\n",
    "        predictions[int(idx)] = {\n",
    "            \"highest_prediction\": (round(highest_prob, 4), highest_label),\n",
    "            \"true_prediction\": (round(true_prob, 4), true_label),\n",
    "            \"bad_prediction\": bad\n",
    "        }\n",
    "\n",
    "\n",
    "for idx, pred in predictions.items():\n",
    "    cell_id = data[idx][\"id\"]\n",
    "    id_to_entry[cell_id].update(pred)\n",
    "\n",
    "with open(json_path, \"w\") as f:\n",
    "    json.dump(list(id_to_entry.values()), f, indent=2)\n",
    "\n",
    "bad_counts = Counter(\n",
    "    pred[\"true_prediction\"][1] for pred in predictions.values() if pred[\"bad_prediction\"]\n",
    ")\n",
    "\n",
    "print(\"Bad prediction counts by true class:\")\n",
    "total = 0\n",
    "for cls in class_names:\n",
    "    print(f\"{cls}: {bad_counts[cls]}\")\n",
    "    total += bad_counts[cls]\n",
    "\n",
    "print(f\"Total: {total} / {len(predictions)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8807b461-272d-4e0a-9162-9ca25b624527",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# A further classifier to expand set of labelled data, trained only on data with previous \"bad_prediction\" = false\n",
    "# - classifier accuracy is gauged on test data after training on data with \"bad_prediction\" = false\n",
    "# - note that the classifier includes dead, blob, and wrong labels since expanded data may also contain these, however this decreases \n",
    "#   performance quality of the classifier due to the inclusion of these additional labels with minimal training instances.\n",
    "# The nature of this discrete classification along a continuous process is also inherently difficult at borders\n",
    "\n",
    "\n",
    "json_path = \"result_data/cell_data_64.json\"\n",
    "n_repeats = 5\n",
    "random_seed_base = 42\n",
    "\n",
    "with open(json_path, \"r\") as f:\n",
    "    data = json.load(f)\n",
    "\n",
    "# Row count shared by X, y and the JSON entries: N_shapes wins when an earlier cell defined it,\n",
    "# otherwise the shortest of the three containers is used.\n",
    "if 'N_shapes' in globals():\n",
    "    n_samples = int(globals()['N_shapes'])\n",
    "else:\n",
    "    n_samples = min(len(X), len(y), len(data))\n",
    "\n",
    "if len(data) < n_samples:\n",
    "    raise RuntimeError(f\"JSON contains {len(data)} entries but expected at least {n_samples}.\")\n",
    "\n",
    "# X and y need at least n_samples rows as well; otherwise the slices below come out shorter than\n",
    "# keep_mask and the boolean indexing fails with an opaque shape error.\n",
    "if len(X) < n_samples or len(y) < n_samples:\n",
    "    raise RuntimeError(f\"X has {len(X)} rows and y has {len(y)} labels but expected at least {n_samples} of each.\")\n",
    "\n",
    "print(f\"Using first n_samples = {n_samples} entries (must correspond to rows of X and labels y).\")\n",
    "\n",
    "X = np.asarray(X)[:n_samples]\n",
    "y = np.asarray(y)[:n_samples]\n",
    "\n",
    "# Entries that carry no \"bad_prediction\" field are treated as bad and therefore dropped.\n",
    "keep_mask = np.array([not entry.get(\"bad_prediction\", True) for entry in data[:n_samples]], dtype=bool)\n",
    "\n",
    "num_total = len(keep_mask)\n",
    "num_kept = keep_mask.sum()\n",
    "print(f\"Total considered samples: {num_total}; samples with bad_prediction==False (kept): {num_kept}\")\n",
    "\n",
    "if num_kept < 10:\n",
    "    raise RuntimeError(\"Too few samples passed the bad_prediction==False filter. Aborting.\")\n",
    "\n",
    "X_filtered = X[keep_mask]\n",
    "y_filtered = y[keep_mask]\n",
    "\n",
    "class_counts = Counter(y_filtered)\n",
    "print(\"Class distribution (kept):\")\n",
    "for cls, cnt in class_counts.items():\n",
    "    print(f\"  {cls}: {cnt}\")\n",
    "\n",
    "# NOTE(review): the scaler is fit on every kept sample, i.e. before accuracy_stack_and_plot2 makes\n",
    "# its train/test split, so the test rows contribute to the scaling statistics.\n",
    "scaler = StandardScaler()\n",
    "X_scaled = scaler.fit_transform(X_filtered)\n",
    "\n",
    "\n",
    "def accuracy_stack_and_plot2(X_in, y_in, seed, do_plot):\n",
    "    \"\"\"Train the stacking ensemble on a stratified 80/20 split and return its test accuracy.\n",
    "\n",
    "    The confusion matrix and the accuracy are printed on every call. When ``do_plot`` is true the\n",
    "    raw counts and a normalised version (each cell divided by sqrt(row_i total * row_j total))\n",
    "    are additionally drawn as heatmaps.\n",
    "    \"\"\"\n",
    "    X_train, X_test, y_train, y_test = train_test_split(\n",
    "        X_in, y_in, test_size=0.2, stratify=y_in, random_state=seed\n",
    "    )\n",
    "\n",
    "    # Level-0 estimators; every estimator that accepts a seed receives the same one.\n",
    "    forest = RandomForestClassifier(n_estimators=50, random_state=seed)\n",
    "    neighbours = KNeighborsClassifier(n_neighbors=15, weights='distance')\n",
    "    svm_rbf = SVC(kernel='rbf', gamma='scale', C=50, probability=True, random_state=seed)\n",
    "    boosted = XGBClassifier(random_state=seed, eval_metric='mlogloss', n_estimators=50, subsample=0.8, use_label_encoder=False)\n",
    "\n",
    "    stack = StackingClassifier(\n",
    "        estimators=[('rf', forest), ('knn', neighbours), ('svm', svm_rbf), ('xgb', boosted)],\n",
    "        final_estimator=RandomForestClassifier(n_estimators=50, random_state=seed),\n",
    "        cv=3,\n",
    "    )\n",
    "    stack.fit(X_train, y_train)\n",
    "\n",
    "    y_pred = stack.predict(X_test)\n",
    "    accuracy = accuracy_score(y_test, y_pred)\n",
    "\n",
    "    # Every label present in either split, sorted, so both matrix axes share one fixed order.\n",
    "    labels_seen = sorted(np.unique(np.concatenate([y_train, y_test])))\n",
    "    counts = confusion_matrix(y_test, y_pred, labels=labels_seen)\n",
    "    counts_df = pd.DataFrame(counts, index=labels_seen, columns=labels_seen)\n",
    "\n",
    "    print(\"\")\n",
    "    print(\"Confusion Matrix:\")\n",
    "    print(counts_df)\n",
    "    print(f\"\\nAccuracy: {accuracy:.4f}\\n\")\n",
    "\n",
    "    if not do_plot:\n",
    "        return accuracy\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    sns.heatmap(counts_df, annot=True, fmt='d', cmap=\"Blues\", cbar=True)\n",
    "    plt.title(\"Confusion Matrix (Stacking Classifier)\")\n",
    "    plt.xlabel(\"Predicted\")\n",
    "    plt.ylabel(\"Actual\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    # Normalized confusion: cell (i, j) is divided by sqrt(row_i total * row_j total);\n",
    "    # cells whose divisor is zero are left at 0.\n",
    "    row_totals = counts.sum(axis=1)\n",
    "    scale = np.sqrt(np.outer(row_totals, row_totals))\n",
    "    normalized = np.divide(counts, scale, out=np.zeros(counts.shape, dtype=float), where=scale > 0)\n",
    "\n",
    "    normalized_df = pd.DataFrame(\n",
    "        normalized,\n",
    "        index=[f\"True_{cls}\" for cls in labels_seen],\n",
    "        columns=[f\"Pred_{cls}\" for cls in labels_seen],\n",
    "    )\n",
    "\n",
    "    plt.figure(figsize=(10, 8))\n",
    "    sns.heatmap(normalized_df, annot=True, fmt=\".5f\", cmap=\"viridis\", cbar=True)\n",
    "    plt.title(\"Normalized Confusion Matrix (Row Frequency Normalization)\")\n",
    "    plt.xlabel(\"Predicted\")\n",
    "    plt.ylabel(\"Actual\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    return accuracy\n",
    "\n",
    "\n",
    "# Repeat the train/evaluate cycle n_repeats times, each with its own seed (random_seed_base + i),\n",
    "# and collect the test accuracies. Heatmaps are drawn for the final repeat only.\n",
    "accuracy_stack_witness = []\n",
    "for i in range(n_repeats):\n",
    "    seed = random_seed_base + i\n",
    "    do_plot = (i == n_repeats - 1)\n",
    "    print(f\"Run {i+1}/{n_repeats} (seed={seed})\")\n",
    "    # A copy of the labels is passed, so the callee never receives y_filtered itself.\n",
    "    acc = accuracy_stack_and_plot2(X_scaled, y_filtered.copy(), seed=seed, do_plot=do_plot)\n",
    "    accuracy_stack_witness.append(acc)\n",
    "    print(\"Accuracies so far:\", accuracy_stack_witness)\n",
    "\n",
    "print(\"\\nMean Stacking Accuracy:\", np.mean(accuracy_stack_witness))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "58de5a35-fce5-4873-8d1c-168c92cf4147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a roughly balanced dataset by phase labels, filtering out mislabels/invalid stages.\n",
    "# Copy both original and rotated masks into balanced folders and write balanced JSON containing both paths\n",
    "# Note this rotation is unnecessary for our pre-aligned dataset, but is vital for other unaligned 3D datasets\n",
    "\n",
    "\n",
    "random.seed(42)\n",
    "\n",
    "original_metadata_path = 'result_data/cell_data_64.json'\n",
    "balanced_mask_dir = 'result_data/masks_64_balanced'\n",
    "balanced_metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "rotated_balanced_dir = 'result_data/masks_64_rotated_balanced'\n",
    "\n",
    "os.makedirs(balanced_mask_dir, exist_ok=True)\n",
    "os.makedirs(rotated_balanced_dir, exist_ok=True)\n",
    "\n",
    "excluded_labels = {\"blob\", \"dead\", \"wrong\"}\n",
    "# How many samples beyond the rarest label's count any other label may contribute.\n",
    "max_extra_per_label = 100\n",
    "\n",
    "\n",
    "def _copy_if_present(src, dst, description, entry_id, warnings):\n",
    "    \"\"\"Copy ``src`` to ``dst`` and return True on success.\n",
    "\n",
    "    A missing source or a failed copy appends a message to ``warnings`` and returns False\n",
    "    instead of raising, so one bad file does not abort the whole run.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(src):\n",
    "        warnings.append(f\"{description} source not found: {src} (id {entry_id})\")\n",
    "        return False\n",
    "    try:\n",
    "        shutil.copy(src, dst)\n",
    "    except Exception as e:\n",
    "        warnings.append(f\"Error copying {description.lower()} {src} -> {dst}: {e}\")\n",
    "        return False\n",
    "    return True\n",
    "\n",
    "\n",
    "with open(original_metadata_path, 'r') as f:\n",
    "    all_entries = json.load(f)\n",
    "\n",
    "# Keep only entries with an allowed label whose \"bad_prediction\" field is exactly False.\n",
    "valid_data_by_label = defaultdict(list)\n",
    "for entry in all_entries:\n",
    "    label = entry.get(\"label\")\n",
    "    if label in excluded_labels:\n",
    "        continue\n",
    "    if entry.get(\"bad_prediction\") is False:\n",
    "        valid_data_by_label[label].append(entry)\n",
    "\n",
    "label_counts = {label: len(lst) for label, lst in valid_data_by_label.items()}\n",
    "if not label_counts:\n",
    "    raise RuntimeError(\"No entries found with bad_prediction == False after excluding labels. Aborting.\")\n",
    "\n",
    "print(\"Good-prediction counts per label (after excluding blob/dead/wrong):\")\n",
    "for lbl, cnt in label_counts.items():\n",
    "    print(f\"  {lbl}: {cnt}\")\n",
    "\n",
    "# Determine the minimum available per label\n",
    "min_count = min(label_counts.values())\n",
    "print(f\"\\nMinimum available per-label (good predictions): {min_count}\")\n",
    "\n",
    "balanced_metadata = []\n",
    "copy_warnings = []\n",
    "for label, entries in valid_data_by_label.items():\n",
    "    available = len(entries)\n",
    "    # The rarest label contributes everything it has; any other label is capped at\n",
    "    # min_count + max_extra_per_label. Every label here holds at least one entry, so\n",
    "    # sample_count is never zero.\n",
    "    sample_count = min(min_count + max_extra_per_label, available)\n",
    "\n",
    "    sampled = random.sample(entries, sample_count)\n",
    "    print(f\"Sampling {len(sampled)} / {available} for label {label}\")\n",
    "\n",
    "    for entry in sampled:\n",
    "        entry_id = entry.get('id')\n",
    "        src_mask = entry.get(\"mask_path\")\n",
    "        if not src_mask:\n",
    "            copy_warnings.append(f\"Missing mask_path for id {entry_id}\")\n",
    "            continue\n",
    "\n",
    "        dst_mask = os.path.join(balanced_mask_dir, os.path.basename(src_mask))\n",
    "        if not _copy_if_present(src_mask, dst_mask, \"Mask\", entry_id, copy_warnings):\n",
    "            # Without a copied mask the entry would point at a file that does not exist, so it is\n",
    "            # skipped exactly like an entry that has no mask_path at all.\n",
    "            continue\n",
    "\n",
    "        # Fall back to the conventional rotated location when the entry does not name one.\n",
    "        rotated_src = entry.get(\"rotated_mask_path\")\n",
    "        if not rotated_src:\n",
    "            rotated_src = src_mask.replace('masks_64', 'masks_64_rotated')\n",
    "\n",
    "        dst_rotated = os.path.join(rotated_balanced_dir, os.path.basename(rotated_src))\n",
    "        # A missing rotated mask is only warned about; the entry is still kept.\n",
    "        _copy_if_present(rotated_src, dst_rotated, \"Rotated mask\", entry_id, copy_warnings)\n",
    "\n",
    "        new_entry = entry.copy()\n",
    "        new_entry[\"mask_path\"] = dst_mask\n",
    "        new_entry[\"rotated_mask_path\"] = dst_rotated\n",
    "\n",
    "        balanced_metadata.append(new_entry)\n",
    "\n",
    "with open(balanced_metadata_path, 'w') as f:\n",
    "    json.dump(balanced_metadata, f, indent=2)\n",
    "\n",
    "\n",
    "print(f\"\\nBalanced dataset created with {len(balanced_metadata)} total images.\")\n",
    "print(f\"Balanced masks copied to: {balanced_mask_dir}\")\n",
    "print(f\"Balanced rotated masks copied to: {rotated_balanced_dir}\")\n",
    "print(f\"Balanced metadata saved to: {balanced_metadata_path}\")\n",
    "\n",
    "if copy_warnings:\n",
    "    print(\"\\n Warnings encountered during copying:\")\n",
    "    for w in copy_warnings:\n",
    "        print(\"  -\", w)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "54035eaf-6423-42d7-aa19-da27a97c6074",
   "metadata": {},
   "source": [
    "<p style=\"height: 200px;\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58e7476a-337b-48e0-b8b6-ac3c8a30b86c",
   "metadata": {},
   "source": [
    "## VAE Architecture"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dd532f7d-de94-43b1-aed8-84cdd5b81c2f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Required for retrieval\n",
    "\n",
    "\n",
    "class CellMaskDataset(Dataset):\n",
    "    \"\"\"Dataset of ``.npy`` mask files found in one directory, optionally filtered by a metadata JSON.\n",
    "\n",
    "    Each item is ``(tensor, path)``: the loaded array converted to a float32 tensor, and the\n",
    "    path of the file it was read from.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mask_dir, metadata_path=None, path_key=\"mask_path\"):\n",
    "        \"\"\"\n",
    "        Args:\n",
    "            mask_dir: Directory searched (non-recursively) for ``*.npy`` files.\n",
    "            metadata_path: Optional JSON file holding a list of entries. When given, only files\n",
    "                whose path is listed under ``path_key`` in some entry are kept.\n",
    "            path_key: Entry field that names the mask file, e.g. \"mask_path\" (default) or\n",
    "                \"rotated_mask_path\". Entries without a value for it are ignored.\n",
    "        \"\"\"\n",
    "        self.paths = sorted(glob(os.path.join(mask_dir, '*.npy')))\n",
    "\n",
    "        if metadata_path:\n",
    "            with open(metadata_path, 'r') as f:\n",
    "                metadata = json.load(f)\n",
    "            # Compare normalised paths so that differently spelled but equivalent paths\n",
    "            # (\"./dir/x.npy\" vs \"dir/x.npy\") still match instead of silently emptying the dataset.\n",
    "            valid_paths = {\n",
    "                os.path.normpath(entry[path_key])\n",
    "                for entry in metadata\n",
    "                if entry.get(path_key)\n",
    "            }\n",
    "            self.paths = [p for p in self.paths if os.path.normpath(p) in valid_paths]\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.paths)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        arr = np.load(self.paths[idx])\n",
    "        return torch.tensor(arr, dtype=torch.float32), self.paths[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "964469aa-67a3-4959-8ea9-18514e90ac78",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Functions specifying encoding and decoding details in VAE\n",
    "\n",
    "\n",
    "class EncoderBranch(nn.Module):\n",
    "    \"\"\"Single-channel 3D convolutional encoder producing a latent mean and log-variance.\n",
    "\n",
    "    Two strided Conv3d layers (each followed by ReLU) feed a flatten step and two parallel\n",
    "    linear heads, ``fc_mu`` and ``fc_logvar``, both of width ``latent_dim``.\n",
    "\n",
    "    Args:\n",
    "        input_shape: Shape of one sample without the batch axis. It is only used to push a\n",
    "            zero tensor through the convolutions and measure the flattened feature size.\n",
    "            conv1 is built with one input channel, so presumably input_shape[0] must be 1.\n",
    "        latent_dim: Width of each of the two output heads.\n",
    "        base_channels: Output channels of conv1; conv2 outputs twice as many.\n",
    "        kernel_size: Kernel size of both convolutions.\n",
    "        stride: Stride of both convolutions.\n",
    "        padding: Accepted but not used; both convolutions pad with ``stride - 1``.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, input_shape=(1, 64, 64, 64), latent_dim=16, base_channels=8,\n",
    "                 kernel_size=3, stride=2, padding=1):\n",
    "        super().__init__()\n",
    "        self.conv1 = nn.Conv3d(1, base_channels, kernel_size, stride=stride, padding=stride-1)\n",
    "        self.conv2 = nn.Conv3d(base_channels, base_channels * 2, kernel_size, stride=stride, padding=stride-1)\n",
    "\n",
    "        # Dry run on a zero tensor of batch size 1: the element count of the result is the\n",
    "        # input width the linear heads need.\n",
    "        with torch.no_grad():\n",
    "            dummy_input = torch.zeros(1, *input_shape)\n",
    "            x = F.relu(self.conv1(dummy_input))\n",
    "            x = F.relu(self.conv2(x))\n",
    "            self.flattened_size = x.numel()\n",
    "\n",
    "        self.flatten = nn.Flatten()\n",
    "        self.fc_mu = nn.Linear(self.flattened_size, latent_dim)\n",
    "        self.fc_logvar = nn.Linear(self.flattened_size, latent_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(mu, logvar)`` for a batch ``x``.\"\"\"\n",
    "        x = F.relu(self.conv1(x))\n",
    "        x = F.relu(self.conv2(x))\n",
    "        x = self.flatten(x)\n",
    "        return self.fc_mu(x), self.fc_logvar(x)\n",
    "\n",
    "\n",
    "\n",
    "class Decoder(nn.Module):\n",
    "    \"\"\"Map a latent vector to a volume with ``output_shape[0]`` channels and values in (0, 1).\n",
    "\n",
    "    A linear layer expands the latent vector, the result is reshaped to the feature-map shape\n",
    "    an encoder branch with the same settings would produce, and two transposed convolutions\n",
    "    upsample it; the final activation is a sigmoid.\n",
    "\n",
    "    Args:\n",
    "        output_shape: Shape of one reconstructed sample without the batch axis.\n",
    "        latent_dim: Width of the latent vector passed to ``forward``.\n",
    "        base_channels: Output channels of deconv1; its input has twice as many.\n",
    "        kernel_size: Kernel size of both transposed convolutions.\n",
    "        stride: Stride of both transposed convolutions.\n",
    "        padding: Accepted but not used; ``stride - 1`` is used for both padding and\n",
    "            output_padding.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, output_shape=(2, 64, 64, 64), latent_dim=32, base_channels=8,\n",
    "                 kernel_size=3, stride=2, padding=1):\n",
    "        super().__init__()\n",
    "        self.base_channels = base_channels\n",
    "        self.output_shape = output_shape\n",
    "\n",
    "        # Throw-away convolutions (locals, not attributes) with the encoder's settings, applied\n",
    "        # to channel 0 of a zero tensor, only to learn the feature-map shape to unflatten into.\n",
    "        with torch.no_grad():\n",
    "            dummy = torch.zeros(1, *output_shape)\n",
    "            conv1 = nn.Conv3d(1, base_channels, kernel_size, stride=stride, padding=stride-1)\n",
    "            conv2 = nn.Conv3d(base_channels, base_channels * 2, kernel_size, stride=stride, padding=stride-1)\n",
    "            x = F.relu(conv1(dummy[:, 0:1]))\n",
    "            x = F.relu(conv2(x))\n",
    "            self.unflatten_shape = x.shape[1:]\n",
    "            self.linear_input_size = x.numel()\n",
    "\n",
    "        self.fc = nn.Linear(latent_dim, self.linear_input_size)\n",
    "        self.unflatten = nn.Unflatten(1, self.unflatten_shape)\n",
    "\n",
    "        self.deconv1 = nn.ConvTranspose3d(base_channels * 2, base_channels, kernel_size,\n",
    "                                          stride=stride, padding=stride-1, output_padding=stride-1)\n",
    "        self.deconv2 = nn.ConvTranspose3d(base_channels, output_shape[0], kernel_size,\n",
    "                                          stride=stride, padding=stride-1, output_padding=stride-1)\n",
    "\n",
    "    def forward(self, z):\n",
    "        \"\"\"Decode latent batch ``z`` into a reconstruction.\"\"\"\n",
    "        x = F.relu(self.fc(z))\n",
    "        x = self.unflatten(x)\n",
    "        x = F.relu(self.deconv1(x))\n",
    "        return torch.sigmoid(self.deconv2(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "239481e9-80b2-46ae-aa1a-1e92dc529ced",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defines how to encode using dual-branch (easily extendible to more independent branches)\n",
    "\n",
    "\n",
    "class DualBranchVAE(nn.Module):\n",
    "    \"\"\"VAE with two independent single-channel encoders and one shared decoder.\n",
    "\n",
    "    Channel 0 of the input goes to ``branch_nucleus`` and channel 1 to ``branch_cell``. Each\n",
    "    branch yields ``latent_dim // 2`` means and log-variances, which are concatenated into the\n",
    "    latent statistics. The decoder is built for a latent vector of width ``latent_dim``, so\n",
    "    ``latent_dim`` needs to be even for the two widths to agree.\n",
    "\n",
    "    Args:\n",
    "        latent_dim: Total latent width (both branches together).\n",
    "        input_shape: Shape of one sample without the batch axis; the first entry is the\n",
    "            channel count and also the number of decoder output channels.\n",
    "        base_channels, kernel_size, stride: Forwarded unchanged to both branches and the decoder.\n",
    "        padding: Accepted but not forwarded; the sub-modules are given ``stride - 1`` instead.\n",
    "    \"\"\"\n",
    "    \n",
    "    def __init__(self, latent_dim=32, input_shape=(2, 64, 64, 64), base_channels=8,\n",
    "                 kernel_size=3, stride=2, padding=1):\n",
    "        super().__init__()\n",
    "        self.latent_dim = latent_dim\n",
    "        self.branch_nucleus = EncoderBranch(input_shape=(1, *input_shape[1:]),\n",
    "                                            latent_dim=latent_dim // 2,\n",
    "                                            base_channels=base_channels,\n",
    "                                            kernel_size=kernel_size,\n",
    "                                            stride=stride,\n",
    "                                            padding=stride-1)\n",
    "        self.branch_cell = EncoderBranch(input_shape=(1, *input_shape[1:]),\n",
    "                                         latent_dim=latent_dim // 2,\n",
    "                                         base_channels=base_channels,\n",
    "                                         kernel_size=kernel_size,\n",
    "                                         stride=stride,\n",
    "                                         padding=stride-1)\n",
    "        self.decoder = Decoder(output_shape=input_shape,\n",
    "                               latent_dim=latent_dim,\n",
    "                               base_channels=base_channels,\n",
    "                               kernel_size=kernel_size,\n",
    "                               stride=stride,\n",
    "                               padding=stride-1)\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Encode each channel with its own branch and return the concatenated ``(mu, logvar)``.\"\"\"\n",
    "        nucleus = x[:, 0:1]  # shape (B, 1, D, H, W)\n",
    "        cell = x[:, 1:2]\n",
    "        mu_n, logvar_n = self.branch_nucleus(nucleus)\n",
    "        mu_c, logvar_c = self.branch_cell(cell)\n",
    "        # Nucleus statistics occupy the first half of the latent vector, cell statistics the second.\n",
    "        mu = torch.cat([mu_n, mu_c], dim=1)\n",
    "        logvar = torch.cat([logvar_n, logvar_c], dim=1)\n",
    "        return mu, logvar\n",
    "\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"Draw ``z = mu + eps * std`` with ``eps`` standard normal and ``std = exp(0.5 * logvar)``.\"\"\"\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu + eps * std\n",
    "\n",
    "    def decode(self, z):\n",
    "        \"\"\"Run the shared decoder on latent batch ``z``.\"\"\"\n",
    "        return self.decoder(z)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Return ``(reconstruction, mu, logvar)``; a fresh latent sample is drawn on every call.\"\"\"\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        return self.decode(z), mu, logvar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "29a8cd77-4184-4a68-b18b-b05da00a8411",
   "metadata": {},
   "outputs": [],
   "source": [
    "def loss_function(recon_x, x, mu, logvar, beta=1.0):\n",
    "    \"\"\"\n",
    "    Beta-weighted VAE loss: summed binary cross-entropy plus KL divergence.\n",
    "\n",
    "    - recon_x: Reconstruction; clamped to [1e-10, 1 - 1e-10] before use\n",
    "    - x: Target of the same shape as recon_x; clamped to the same range\n",
    "    - mu, logvar: Mean and log-variance of latent space (for KL divergence)\n",
    "    - beta: The weight for KL divergence term (controls regularization)\n",
    "    \n",
    "    Returns:\n",
    "    - total_loss: Reconstruction loss plus beta times the KL loss.\n",
    "    - bce_loss: Binary cross-entropy loss (reconstruction), summed over all elements\n",
    "    - kl_loss: KL divergence loss (latent regularization), summed over all elements\n",
    "    \"\"\"\n",
    "\n",
    "    # The clamps are the only range handling needed: once they have run no value can lie\n",
    "    # outside [0, 1], so a separate bounds check after them could never trigger.\n",
    "    eps = 1e-10\n",
    "    recon_x = torch.clamp(recon_x, min=eps, max=1 - eps)\n",
    "    x = torch.clamp(x, min=eps, max=1 - eps)\n",
    "\n",
    "    bce_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')\n",
    "    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "    return bce_loss + beta * kl_loss, bce_loss, kl_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "912cedfa-c4f9-4143-8eb4-cf3d3b389e68",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defines training using the set of best HPs from previous grid-search results. To recreate HP-search, enter full list\n",
    "# of candidate values in sample_hyperparams() below\n",
    "\n",
    "\n",
    "def run_training(config, dataset_dir, metadata_path, input_shape, latent_dim, device,\n",
    "                 model_save_path='result_data/dual_branch_vae_64_balanced.pth'):\n",
    "    \"\"\"Train a DualBranchVAE on an 80/20 split and save the weights of its best epoch.\n",
    "\n",
    "    Training runs for at most 20 epochs and stops early once the validation reconstruction\n",
    "    loss has failed to improve for 2 consecutive epochs. The learning rate is halved by\n",
    "    ReduceLROnPlateau on the same metric.\n",
    "\n",
    "    Args:\n",
    "        config: Dict with the keys 'lr', 'beta', 'stride' and 'base_channels'.\n",
    "        dataset_dir: Directory holding the ``.npy`` masks.\n",
    "        metadata_path: Metadata JSON used to filter the masks (see CellMaskDataset).\n",
    "        input_shape: Shape of one sample without the batch axis, forwarded to the model.\n",
    "        latent_dim: Total latent width of the model.\n",
    "        device: Torch device to train on.\n",
    "        model_save_path: Where the best weights are written. For the rotated dataset use e.g.\n",
    "            'result_data/dual_branch_vae_64_rotated_balanced.pth'.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(best_val_loss, best_epoch, config, model, best_mu, best_logvar, train_loader,\n",
    "        val_loader, train_set, val_set)``. ``model`` carries the best-epoch weights;\n",
    "        ``best_mu`` / ``best_logvar`` are numpy arrays taken from the last validation batch of\n",
    "        that epoch only.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the dataset contains no samples.\n",
    "        RuntimeError: If no epoch produced a usable validation loss (e.g. it was NaN).\n",
    "    \"\"\"\n",
    "    dataset = CellMaskDataset(dataset_dir, metadata_path)\n",
    "    if len(dataset) == 0:\n",
    "        raise ValueError(f\"No valid samples found in {dataset_dir}\")\n",
    "\n",
    "    train_size = int(0.8 * len(dataset))\n",
    "    val_size = len(dataset) - train_size\n",
    "    train_set, val_set = random_split(dataset, [train_size, val_size])\n",
    "\n",
    "    train_loader = DataLoader(train_set, batch_size=2, shuffle=True)\n",
    "    val_loader = DataLoader(val_set, batch_size=2, shuffle=True)\n",
    "\n",
    "    model = DualBranchVAE(\n",
    "        latent_dim=latent_dim,\n",
    "        input_shape=input_shape,\n",
    "        base_channels=config['base_channels'],\n",
    "        stride=config['stride']\n",
    "    ).to(device)\n",
    "\n",
    "    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'])\n",
    "    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)\n",
    "\n",
    "    best_val_loss = float('inf')\n",
    "    best_model_state = None\n",
    "    best_mu = None\n",
    "    best_logvar = None\n",
    "    best_epoch = 0\n",
    "    epochs_no_improve = 0\n",
    "    early_stop_patience = 2\n",
    "\n",
    "    for epoch in range(20):\n",
    "\n",
    "        print(f\"epoch number: {epoch}\")\n",
    "\n",
    "        model.train()\n",
    "        for batch, _ in train_loader:\n",
    "            batch = batch.to(device)\n",
    "            optimizer.zero_grad()\n",
    "            recon, mu, logvar = model(batch)\n",
    "            loss, _, _ = loss_function(recon, batch, mu, logvar, beta=config['beta'])\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        print(\"Epoch trained\")\n",
    "\n",
    "        # Model selection uses the reconstruction term only, averaged per validation sample.\n",
    "        model.eval()\n",
    "        total_val_recon_loss = 0\n",
    "        with torch.no_grad():\n",
    "            for batch, _ in val_loader:\n",
    "                batch = batch.to(device)\n",
    "                recon, mu, logvar = model(batch)\n",
    "                _, recon_loss, _ = loss_function(recon, batch, mu, logvar, beta=config['beta'])\n",
    "                total_val_recon_loss += recon_loss.item()\n",
    "\n",
    "        avg_val_loss = total_val_recon_loss / len(val_loader.dataset)\n",
    "        scheduler.step(avg_val_loss)\n",
    "\n",
    "        print(f\"Validation reconstruction loss: {avg_val_loss}\")\n",
    "\n",
    "        if avg_val_loss < best_val_loss:\n",
    "            best_val_loss = avg_val_loss\n",
    "            # state_dict() hands back references to the live parameters, which later epochs\n",
    "            # keep updating; clone every tensor so this really is a snapshot of the best epoch.\n",
    "            best_model_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}\n",
    "            best_mu = mu.cpu().numpy()\n",
    "            best_logvar = logvar.cpu().numpy()\n",
    "            best_epoch = epoch + 1\n",
    "            epochs_no_improve = 0\n",
    "        else:\n",
    "            epochs_no_improve += 1\n",
    "\n",
    "        if epochs_no_improve >= early_stop_patience:\n",
    "            break\n",
    "\n",
    "    if best_model_state is None:\n",
    "        raise RuntimeError(\"Validation loss never improved (it may be NaN); there is no model state to save.\")\n",
    "\n",
    "    model.load_state_dict(best_model_state)\n",
    "\n",
    "    torch.save(model.state_dict(), model_save_path)\n",
    "    print(f\"Model saved to: {model_save_path}\")\n",
    "\n",
    "    return best_val_loss, best_epoch, config, model, best_mu, best_logvar, train_loader, val_loader, train_set, val_set\n",
    "\n",
    "\n",
    "def sample_hyperparams(latent_dim, n_samples=1):\n",
    "    \"\"\"Build the hyperparameter grid and return randomly chosen combinations from it.\n",
    "\n",
    "    Args:\n",
    "        latent_dim: Latent width; the candidate base channel count is ``latent_dim // 4``.\n",
    "        n_samples: How many distinct combinations to draw (capped at the grid size). The\n",
    "            default of 1 returns a single combination; pass ``None`` to get the whole grid in\n",
    "            order, which is what an exhaustive grid search needs.\n",
    "\n",
    "    Returns:\n",
    "        List of dicts with the keys 'lr', 'beta', 'stride' and 'base_channels'.\n",
    "    \"\"\"\n",
    "    lrs = [5e-4]  # If running grid-search, adjust to all candidate values for the 4 HPs below\n",
    "    betas = [0.5]\n",
    "    strides = [2]\n",
    "    base_channels_options = [latent_dim // 4]\n",
    "\n",
    "    combos = [\n",
    "        {'lr': lr, 'beta': beta, 'stride': stride, 'base_channels': base}\n",
    "        for lr in lrs\n",
    "        for beta in betas\n",
    "        for stride in strides\n",
    "        for base in base_channels_options\n",
    "    ]\n",
    "\n",
    "    if n_samples is None:\n",
    "        return combos\n",
    "    return random.sample(combos, min(n_samples, len(combos)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7806d6f8-553f-4680-a1a0-9f6dbb6581c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run training on a 80/20 split of balanced dataset\n",
    "\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "# One dict per dataset variant to train on.\n",
    "input_configs = [\n",
    "    {\n",
    "        'input_shape': (2, 64, 64, 64),\n",
    "        'mask_dir': 'result_data/masks_64_balanced',   # result_data/masks_64_rotated_balanced\n",
    "        'metadata_path': 'result_data/cell_data_64_balanced.json'\n",
    "    }\n",
    "]\n",
    "\n",
    "latent_dims = [64]  # If running grid-search, adjust to all candidate values [16, 32, 64]\n",
    "final_results = []  # best trial per (dataset variant, latent_dim) pair\n",
    "\n",
    "for config_dict in input_configs:\n",
    "    input_shape = config_dict['input_shape']\n",
    "    dataset_dir = config_dict['mask_dir']\n",
    "    metadata_path = config_dict['metadata_path']\n",
    "\n",
    "    for latent_dim in latent_dims:\n",
    "        trials = []\n",
    "        best_loss = float('inf')\n",
    "        best_trial = None\n",
    "\n",
    "        print(f\"\\nStarting search for shape {input_shape}, latent dimension {latent_dim}\")\n",
    "\n",
    "        configs = list(sample_hyperparams(latent_dim))\n",
    "        total_configs = len(configs)\n",
    "\n",
    "        for idx, config in enumerate(configs, start=1):\n",
    "\n",
    "            print(f\"\\n[{idx}/{total_configs}] Testing config: {config}\")\n",
    "            \n",
    "            # model, the loaders and the splits are rebound on every run, so after the loops\n",
    "            # they refer to the most recent run, not necessarily the best one. The evaluation\n",
    "            # cell further down reads model, val_loader and val_set from here.\n",
    "            try:\n",
    "                val_loss, num_epochs, used_config, model, mu, logvar, train_loader, val_loader, train_set, val_set = run_training(\n",
    "                    config=config,\n",
    "                    dataset_dir=dataset_dir,\n",
    "                    metadata_path=metadata_path,\n",
    "                    input_shape=input_shape,\n",
    "                    latent_dim=latent_dim,\n",
    "                    device=device\n",
    "                )\n",
    "            except ValueError as e:\n",
    "                # Only ValueError (e.g. an empty dataset) is treated as skippable.\n",
    "                print(f\"Skipping config due to data error: {e}\")\n",
    "                continue\n",
    "\n",
    "            print(f\"Completed config {idx}/{total_configs} with val loss: {val_loss:.4f}\")\n",
    "\n",
    "            trials.append({\n",
    "                'config': used_config,\n",
    "                'val_loss': val_loss\n",
    "            })\n",
    "\n",
    "            if val_loss < best_loss:\n",
    "                best_loss = val_loss\n",
    "                # 'tried_configs' holds the trials list itself (not a copy), so trials appended\n",
    "                # later also show up in it. mu / logvar are the numpy arrays returned by\n",
    "                # run_training, hence .tolist().\n",
    "                best_trial = {\n",
    "                    'combination': {\n",
    "                        'input_shape': input_shape,\n",
    "                        'latent_dim': latent_dim\n",
    "                    },\n",
    "                    'best_config': used_config,\n",
    "                    'val_loss': val_loss,\n",
    "                    'epochs_ran': num_epochs,\n",
    "                    'tried_configs': trials,\n",
    "                    'mu': mu.tolist(),\n",
    "                    'logvar': logvar.tolist()\n",
    "                }\n",
    "\n",
    "        if best_trial:\n",
    "            final_results.append(best_trial)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0e2a2b8b-b72e-49b2-ba31-e39541f154f5",
   "metadata": {},
   "source": [
    "<p style=\"height: 200px;\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98e25a54-9455-4c2e-a6ce-3fe23a11400b",
   "metadata": {},
   "source": [
    "## VAE Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca34301d-f09b-4a33-9306-27b8cb956f04",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the model again if required, evaluate it on the held-out split and save per-sample metrics.\n",
    "# Note: test_loader below is the validation split, i.e. the same data run_training() used for\n",
    "# early stopping and model selection, not an independent test set.\n",
    "\n",
    "\n",
    "# Loading model (only needed in a fresh kernel; the hyperparameters must match the ones used in\n",
    "# training: latent_dims = [64] and base_channels = latent_dim // 4 = 16). A new random_split\n",
    "# will generally not reproduce the training-time split unless the RNG is seeded identically.\n",
    "# dataset = CellMaskDataset('result_data/masks_64_balanced', 'result_data/cell_data_64_balanced.json')\n",
    "# train_size = int(0.8 * len(dataset))\n",
    "# val_size = len(dataset) - train_size\n",
    "# train_set, val_set = random_split(dataset, [train_size, val_size])\n",
    "# train_loader = DataLoader(train_set, batch_size=2, shuffle=True)\n",
    "# val_loader = DataLoader(val_set, batch_size=2, shuffle=True)\n",
    "\n",
    "# model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "# model.load_state_dict(torch.load('result_data/dual_branch_vae_64_balanced.pth'))\n",
    "\n",
    "\n",
    "test_loader = val_loader\n",
    "test_set = val_set\n",
    "\n",
    "model.eval()\n",
    "# Per-sample results, keyed by the mask file name without its .npy extension.\n",
    "latent_vectors = {}\n",
    "logvar_vectors = {}\n",
    "recon_losses = {}\n",
    "voxel_accuracies = {}\n",
    "dice_scores = {}\n",
    "\n",
    "\n",
    "def dice_coefficient(pred, target):\n",
    "    \"\"\"Dice overlap 2|A∩B| / (|A|+|B|) of two binary integer tensors; 1.0 when both are empty.\"\"\"\n",
    "    combined_size = pred.sum().item() + target.sum().item()\n",
    "    if combined_size == 0:\n",
    "        # Nothing predicted and nothing expected counts as a perfect match.\n",
    "        return 1.0\n",
    "    overlap = (pred & target).sum().item()\n",
    "    return (2.0 * overlap) / combined_size\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    for batch, path in tqdm(test_loader):\n",
    "        batch = batch.to(device)\n",
    "        recon, mu, logvar = model(batch)\n",
    "\n",
    "        # Clamp once on the device, then bring everything back to the CPU for per-sample work.\n",
    "        recon = torch.clamp(recon, min=1e-10, max=1 - 1e-10).cpu()\n",
    "        batch = torch.clamp(batch, min=1e-10, max=1 - 1e-10).cpu()\n",
    "        mu, logvar = mu.cpu(), logvar.cpu()\n",
    "\n",
    "        for i in range(batch.size(0)):\n",
    "            sample_id = os.path.basename(path[i]).replace('.npy', '')\n",
    "\n",
    "            # Already clamped above; clamping again would change nothing.\n",
    "            recon_sample, batch_sample = recon[i], batch[i]\n",
    "\n",
    "            # Threshold both volumes at 0.5 for the voxel-wise comparisons.\n",
    "            recon_binary = (recon_sample >= 0.5).int()\n",
    "            batch_binary = (batch_sample >= 0.5).int()\n",
    "\n",
    "            latent_vectors[sample_id] = mu[i].tolist()\n",
    "            logvar_vectors[sample_id] = logvar[i].tolist()\n",
    "            recon_losses[sample_id] = F.binary_cross_entropy(recon_sample, batch_sample, reduction='sum').item()\n",
    "            voxel_accuracies[sample_id] = (recon_binary == batch_binary).sum().item() / batch_sample.numel()\n",
    "            dice_scores[sample_id] = dice_coefficient(recon_binary, batch_binary)\n",
    "\n",
    "\n",
    "# Persist every result dictionary to its own JSON file.\n",
    "for out_path, payload in (\n",
    "    ('result_data/latent_space_test_64_balanced.json', latent_vectors),\n",
    "    ('result_data/logvars_test_64_balanced.json', logvar_vectors),\n",
    "    ('result_data/reconstruction_losses_test_64_balanced.json', recon_losses),\n",
    "    ('result_data/voxel_accuracies_test_64_balanced.json', voxel_accuracies),\n",
    "    ('result_data/dice_scores_test_64_balanced.json', dice_scores),\n",
    "):\n",
    "    with open(out_path, 'w') as f:\n",
    "        json.dump(payload, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6752255c-d512-43ec-97c0-39287c6be88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Observe voxel-level reconstruction accuracies by phase\n",
    "\n",
    "\n",
    "with open('result_data/voxel_accuracies_test_64_balanced.json', 'r') as f:\n",
    "    voxel_accuracies = json.load(f)\n",
    "\n",
    "with open('result_data/dice_scores_test_64_balanced.json', 'r') as f:\n",
    "    dice_scores = json.load(f)\n",
    "\n",
    "with open('result_data/cell_data_64_balanced.json', 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "id_to_label = {entry[\"id\"]: entry[\"label\"] for entry in metadata}\n",
    "valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']\n",
    "\n",
    "acc_values = list(voxel_accuracies.values())\n",
    "dice_values = list(dice_scores.values())\n",
    "\n",
    "print(\"Overall Metrics:\")\n",
    "print(f\"   Binary Voxel Accuracy — Mean: {np.mean(acc_values) * 100:.2f}%, Std: {np.std(acc_values) * 100:.2f}%\")\n",
    "print(f\"   DICE Score            — Mean: {np.mean(dice_values) * 100:.2f}%, Std: {np.std(dice_values) * 100:.2f}%\")\n",
    "\n",
    "label_to_accs = defaultdict(list)\n",
    "label_to_dices = defaultdict(list)\n",
    "\n",
    "for id_ in voxel_accuracies.keys():\n",
    "    label = id_to_label.get(id_)\n",
    "    if label in valid_labels:\n",
    "        label_to_accs[label].append(voxel_accuracies[id_])\n",
    "        label_to_dices[label].append(dice_scores[id_])\n",
    "\n",
    "print(\"Metrics by Label:\")\n",
    "for label in valid_labels:\n",
    "    accs = label_to_accs[label]\n",
    "    dices = label_to_dices[label]\n",
    "\n",
    "    if accs:\n",
    "        print(f\"  {label:12s} — \"\n",
    "              f\"Acc: {np.mean(accs) * 100:.2f}% ± {np.std(accs) * 100:.2f}%, \"\n",
    "              f\"DICE: {np.mean(dices) * 100:.2f}% ± {np.std(dices) * 100:.2f}%, \"\n",
    "        )\n",
    "    else:\n",
    "        print(f\"{label:12s} — No samples found\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa49ae00-9481-489e-9a8a-456ff68fa18e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Generate images from latent space distribution, visualize central 2D cuts of structures\n",
    "\n",
    "\n",
    "with open('result_data/latent_space_test_64_balanced.json', 'r') as f:   # latent_space_test_64_rotated_balanced.json\n",
    "    mu_dict = json.load(f)\n",
    "\n",
    "with open('result_data/logvars_test_64_balanced.json', 'r') as f:   # logvars_test_64_rotated_balanced.json\n",
    "    logvar_dict = json.load(f)\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "ids = list(mu_dict.keys())\n",
    "mu_array = np.array([mu_dict[i] for i in ids])\n",
    "logvar_array = np.array([logvar_dict[i] for i in ids])\n",
    "\n",
    "num_samples = 5\n",
    "sample_indices = random.sample(range(len(mu_array)), num_samples)\n",
    "\n",
    "samples = []\n",
    "for i in sample_indices:\n",
    "    mu = np.array(mu_array[i])\n",
    "    logvar = np.array(logvar_array[i])\n",
    "    std = np.exp(0.5 * logvar)\n",
    "    epsilon = np.random.randn(*mu.shape)\n",
    "    z = mu + std * epsilon\n",
    "    samples.append(z)\n",
    "\n",
    "samples_tensor = torch.tensor(samples, dtype=torch.float32).to(device)\n",
    "\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    generated = model.decoder(samples_tensor)\n",
    "    generated = (generated >= 0.2).int()\n",
    "\n",
    "for i in range(num_samples):\n",
    "    volume = generated[i].cpu().numpy()\n",
    "    fig, axs = plt.subplots(1, 2, figsize=(8, 4))\n",
    "    axs[0].imshow(volume[0, 32, :, :], cmap='gray')\n",
    "    axs[0].set_title(\"Channel 1 (Nucleus)\")\n",
    "    axs[1].imshow(volume[1, 32, :, :], cmap='gray')\n",
    "    axs[1].set_title(\"Channel 2 (Cell)\")\n",
    "    plt.suptitle(f\"Sampled from Posterior {i+1}\")\n",
    "    plt.tight_layout()\n",
    "    # plt.savefig(f\"result_data/figures/POSTERIOR_SAMPLE_{i}.png\", format=\"png\", bbox_inches=\"tight\")\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3e0bb263-e2a7-4992-b70c-04cd12221511",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Visualize reconstructed images from latent space, before and after\n",
    "\n",
    "\n",
    "def plot_vae_reconstructions_per_label(model, latent_file_path, metadata_path, dataset_dir, n_per_label=3, slice_idx=32):\n",
    "\n",
    "    with open(latent_file_path, 'r') as f:\n",
    "        latent_vectors = json.load(f)\n",
    "\n",
    "    with open(metadata_path, 'r') as f:\n",
    "        metadata = json.load(f)\n",
    "    id_to_label = {entry[\"id\"]: entry[\"label\"] for entry in metadata}\n",
    "\n",
    "    label_to_ids = defaultdict(list)\n",
    "    for id_ in latent_vectors:\n",
    "        if id_ in id_to_label:\n",
    "            label_to_ids[id_to_label[id_]].append(id_)\n",
    "\n",
    "    valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']\n",
    "\n",
    "    device = model.device if hasattr(model, 'device') else 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    model.eval()\n",
    "\n",
    "    j=0\n",
    "\n",
    "    for label in valid_labels:\n",
    "        selected_ids = random.sample(label_to_ids[label], k=min(n_per_label, len(label_to_ids[label])))\n",
    "        samples = [latent_vectors[id_] for id_ in selected_ids]\n",
    "        samples_tensor = torch.tensor(samples, dtype=torch.float32).to(device)\n",
    "\n",
    "        j+=1\n",
    "\n",
    "        with torch.no_grad():\n",
    "            reconstructions = model.decoder(samples_tensor)\n",
    "\n",
    "        for i, id_ in enumerate(selected_ids):\n",
    "            recon = reconstructions[i].cpu().numpy()\n",
    "            recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "            recon_binary = (recon >= 0.2).astype(np.uint8)\n",
    "\n",
    "            try:\n",
    "                mask_path = os.path.join(dataset_dir, f\"{id_}.npy\")\n",
    "                original = np.load(mask_path)\n",
    "            except Exception as e:\n",
    "                print(f\"Could not load original for ID {id_}: {e}\")\n",
    "                continue\n",
    "\n",
    "            fig, axs = plt.subplots(2, 2, figsize=(8, 6))\n",
    "\n",
    "            axs[0, 0].imshow(original[0, slice_idx, :, :], cmap='gray')\n",
    "            axs[0, 0].set_title(\"Original Channel 1\")\n",
    "            axs[0, 0].axis('off')\n",
    "\n",
    "            axs[0, 1].imshow(original[1, slice_idx, :, :], cmap='gray')\n",
    "            axs[0, 1].set_title(\"Original Channel 2\")\n",
    "            axs[0, 1].axis('off')\n",
    "\n",
    "            axs[1, 0].imshow(recon_binary[0, slice_idx, :, :], cmap='gray')\n",
    "            axs[1, 0].set_title(\"Reconstructed Channel 1\")\n",
    "            axs[1, 0].axis('off')\n",
    "\n",
    "            axs[1, 1].imshow(recon_binary[1, slice_idx, :, :], cmap='gray')\n",
    "            axs[1, 1].set_title(\"Reconstructed Channel 2\")\n",
    "            axs[1, 1].axis('off')\n",
    "\n",
    "            plt.suptitle(f\"Label: {label} | ID: {id_}\")\n",
    "            plt.tight_layout(rect=[0, 0.03, 1, 0.95])\n",
    "            # plt.savefig(f\"result_data/figures/RECONSTRUCTION_{i}{j}.png\", format=\"png\", bbox_inches=\"tight\")\n",
    "            plt.show()\n",
    "\n",
    "\n",
    "plot_vae_reconstructions_per_label(\n",
    "    model=model,\n",
    "    latent_file_path='result_data/latent_space_test_64_balanced.json',   # latent_space_test_64_rotated_balanced.json\n",
    "    metadata_path='result_data/cell_data_64_balanced.json',\n",
    "    dataset_dir='result_data/masks_64_balanced',  # masks_64_rotated_balanced\n",
    "    n_per_label=3,\n",
    "    slice_idx=32\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6ac2b2c2-1be1-445e-8870-64fc59f95b91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute latent space of full balanced dataset, save original 64D embedding and 70D augmented (with scale factors).\n",
    "# scaling factors are themselves scaled before augmented, such they contribute 2x variance of normal dimension for curve fitting\n",
    "\n",
    "\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "out_latent_path = 'result_data/latent_space_full_64_balanced.json'\n",
    "out_augmented_path = 'result_data/augmented_latent_space_full_64_balanced.json'\n",
    "\n",
    "LATENT_DIM = 64\n",
    "BASE_CHANNELS = 16\n",
    "STRIDE = 2\n",
    "BATCH_SIZE = 8\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "entries = []\n",
    "for entry in metadata:\n",
    "    sample_id = entry.get('id')\n",
    "    mask_path = entry.get('mask_path')\n",
    "    required_sizes = ['size_nucleus_1', 'size_nucleus_2', 'size_nucleus_3','size_cell_1', 'size_cell_2', 'size_cell_3']\n",
    "    \n",
    "    has_sizes = all(k in entry for k in required_sizes)\n",
    "    if not sample_id or not mask_path:\n",
    "        print(f\"Skipping entry missing id or mask_path: {entry}\")\n",
    "        continue\n",
    "    if not os.path.exists(mask_path):\n",
    "        print(f\"Mask missing for {sample_id}: {mask_path} — skipping\")\n",
    "        continue\n",
    "    if not has_sizes:\n",
    "        print(f\"Size fields missing for {sample_id} — skipping\")\n",
    "        continue\n",
    "    entries.append(entry)\n",
    "\n",
    "N = len(entries)\n",
    "print(f\"Found {N} metadata entries with masks+sizes to process.\")\n",
    "if N == 0:\n",
    "    raise RuntimeError(\"No valid entries found. Check metadata/mask paths/size fields.\")\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "try:\n",
    "    model = DualBranchVAE(latent_dim=LATENT_DIM, base_channels=BASE_CHANNELS, stride=STRIDE).to(device)\n",
    "except NameError as e:\n",
    "    raise RuntimeError(\"DualBranchVAE is not defined in the current notebook. \"\n",
    "                       \"Define/import the model class before running this cell.\") from e\n",
    "\n",
    "state = torch.load(model_path, map_location=device)\n",
    "model.load_state_dict(state)\n",
    "model.eval()\n",
    "print(\"Model loaded and set to eval()\")\n",
    "\n",
    "\n",
    "# Helper to load a mask and convert to tensor\n",
    "def load_mask_tensor(path):\n",
    "    arr = np.load(path)\n",
    "    if not isinstance(arr, np.ndarray):\n",
    "        raise RuntimeError(f\"Loaded object not numpy array from {path}\")\n",
    "    arr = arr.astype(np.float32)\n",
    "    if arr.ndim != 4 or arr.shape[0] != 2:\n",
    "        raise RuntimeError(f\"Unexpected mask shape {arr.shape} for {path}. Expected (2, D, H, W).\")\n",
    "    t = torch.from_numpy(arr).unsqueeze(0)\n",
    "    return t\n",
    "\n",
    "\n",
    "id_list = []\n",
    "embeddings = []\n",
    "sizes_list = []\n",
    "\n",
    "batch_ids = []\n",
    "batch_tensors = []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for entry in tqdm(entries, desc=\"Preparing and computing latents\"):\n",
    "        sid = entry['id']\n",
    "        mask_path = entry['mask_path']\n",
    "\n",
    "        try:\n",
    "            t = load_mask_tensor(mask_path)\n",
    "        except Exception as e:\n",
    "            print(f\"Error loading {mask_path}: {e} — skipping\")\n",
    "            continue\n",
    "\n",
    "        batch_ids.append(sid)\n",
    "        batch_tensors.append(t)\n",
    "\n",
    "        sizes = np.array([\n",
    "            entry['size_nucleus_1'],\n",
    "            entry['size_nucleus_2'],\n",
    "            entry['size_nucleus_3'],\n",
    "            entry['size_cell_1'],\n",
    "            entry['size_cell_2'],\n",
    "            entry['size_cell_3']\n",
    "        ], dtype=float)\n",
    "        sizes_list.append(sizes)\n",
    "\n",
    "        if len(batch_tensors) >= BATCH_SIZE or len(batch_tensors) + len(embeddings) >= N:\n",
    "            batch_stack = torch.cat(batch_tensors, dim=0).to(device)\n",
    "            try:\n",
    "                _, mu, _ = model(batch_stack)\n",
    "            except Exception as e:\n",
    "                if hasattr(model, 'encode'):\n",
    "                    mu, _ = model.encode(batch_stack)\n",
    "                else:\n",
    "                    raise\n",
    "\n",
    "            mu_np = mu.cpu().numpy()\n",
    "            for i, sid in enumerate(batch_ids):\n",
    "                embeddings.append(mu_np[i].astype(float))\n",
    "                id_list.append(sid)\n",
    "\n",
    "            batch_ids = []\n",
    "            batch_tensors = []\n",
    "\n",
    "\n",
    "embeddings = np.vstack(embeddings)\n",
    "sizes_arr = np.vstack(sizes_list[:len(embeddings)])\n",
    "ids_processed = id_list\n",
    "M = embeddings.shape[0]\n",
    "print(f\"Computed latents for {M} samples (expected <= {N}).\")\n",
    "\n",
    "latent_map = {ids_processed[i]: embeddings[i].tolist() for i in range(M)}\n",
    "os.makedirs(os.path.dirname(out_latent_path), exist_ok=True)\n",
    "with open(out_latent_path, 'w') as f:\n",
    "    json.dump(latent_map, f, indent=2)\n",
    "print(f\"Saved latent space to: {out_latent_path}\")\n",
    "\n",
    "var_sizes_per_dim = np.var(sizes_arr, axis=0)\n",
    "mean_var_sizes = float(np.mean(var_sizes_per_dim))\n",
    "\n",
    "var_latent_per_dim = np.var(embeddings, axis=0)\n",
    "mean_var_latent = float(np.mean(var_latent_per_dim))\n",
    "\n",
    "if mean_var_sizes == 0:\n",
    "    scale_factor = 1.0\n",
    "    print(\"Mean variance of size features is zero. Using scale_factor = 1.0\")\n",
    "else:\n",
    "    scale_factor = float(np.sqrt((2.0 * mean_var_latent) / mean_var_sizes))\n",
    "\n",
    "print(f\"Scale factor computed: {scale_factor:.6g}\")\n",
    "print(f\"Mean variance (latent dim): {mean_var_latent:.6g}\")\n",
    "print(f\"Mean variance (size dims, pre-scale): {mean_var_sizes:.6g}\")\n",
    "print(f\"Mean variance (size dims, post-scale target): {2.0 * mean_var_latent:.6g}\")\n",
    "\n",
    "\n",
    "augmented_map = {}\n",
    "for i, sid in enumerate(ids_processed):\n",
    "    emb = embeddings[i]  \n",
    "    sizes = sizes_arr[i]  \n",
    "    scaled_sizes = (sizes * scale_factor).astype(float)\n",
    "    augmented_vec = np.concatenate([emb, scaled_sizes])\n",
    "    augmented_map[sid] = augmented_vec.tolist()\n",
    "\n",
    "out_dict = {\"scale_factor\": scale_factor}\n",
    "out_dict.update(augmented_map)\n",
    "\n",
    "with open(out_augmented_path, 'w') as f:\n",
    "    json.dump(out_dict, f, indent=2)\n",
    "\n",
    "print(f\"Saved augmented latent space (with scale_factor at top) to: {out_augmented_path}\")\n",
    "print(\"Done.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f1fb25c9-aa44-42f8-8bcd-01ab767725ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Produce 10 UMAP plots (5 nucleus, 5 cell) with colorings to visualize various properties of the learned latent space\n",
    "\n",
    "\n",
    "base_folder = \"source_data/crop_seg\"\n",
    "augmented_latent_path = \"result_data/augmented_latent_space_full_64_balanced.json\"\n",
    "metadata_path = \"result_data/cell_data_64_balanced.json\"\n",
    "\n",
    "valid_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half']\n",
    "label_to_index = {lab: i for i, lab in enumerate(valid_labels)}\n",
    "cmap_labels = plt.get_cmap(\"viridis\")\n",
    "\n",
    "\n",
    "def extract_two_channels(arr):\n",
    "    \"\"\"\n",
    "    detect which axis is the channel axis (size 2) and return array shape (2, z, y, x)\n",
    "    \"\"\"\n",
    "    a = np.asarray(arr)\n",
    "    if a.ndim == 4:\n",
    "        axes_eq_2 = [ax for ax, s in enumerate(a.shape) if s == 2]\n",
    "        if len(axes_eq_2) == 1:\n",
    "            cax = axes_eq_2[0]\n",
    "            return np.moveaxis(a, cax, 0)\n",
    "        else:\n",
    "            if a.shape[0] == 2:\n",
    "                return a\n",
    "            if a.shape[-1] == 2:\n",
    "                return np.moveaxis(a, -1, 0)\n",
    "            if a.shape[1] == 2:\n",
    "                return np.moveaxis(a, 1, 0)\n",
    "            cax_guess = int(np.argmin(a.shape))\n",
    "            if a.shape[cax_guess] == 2:\n",
    "                return np.moveaxis(a, cax_guess, 0)\n",
    "    elif a.ndim == 3:\n",
    "        z, y, x = a.shape\n",
    "        return np.stack([a, a], axis=0)\n",
    "    raise ValueError(f\"can't determine channel axis (shape: {a.shape})\")\n",
    "\n",
    "\n",
    "def compute_height(mask):\n",
    "    z_nonzero = np.any(mask, axis=(1,2))\n",
    "    return float(z_nonzero.sum())\n",
    "\n",
    "def compute_volume(mask):\n",
    "    return float(np.count_nonzero(mask))\n",
    "\n",
    "def compute_major_tilt_degrees(mask):\n",
    "    coords = np.argwhere(mask)\n",
    "    if coords.shape[0] < 3:\n",
    "        return float(np.nan)\n",
    "    coords_centered = coords - coords.mean(axis=0)\n",
    "    pca = PCA(n_components=3)\n",
    "    try:\n",
    "        pca.fit(coords_centered)\n",
    "    except Exception:\n",
    "        return float(np.nan)\n",
    "    principal = pca.components_[0]\n",
    "    z_axis = np.array([1.0, 0.0, 0.0])\n",
    "    dot = np.abs(np.dot(principal, z_axis))\n",
    "    dot = np.clip(dot / (np.linalg.norm(principal) * np.linalg.norm(z_axis)), -1.0, 1.0)\n",
    "    angle_rad = np.arccos(dot)\n",
    "    angle_deg = np.degrees(angle_rad)\n",
    "    if angle_deg > 90:\n",
    "        angle_deg = 180 - angle_deg\n",
    "    return float(angle_deg)\n",
    "\n",
    "def compute_sphericity(mask):\n",
    "    V = np.count_nonzero(mask)\n",
    "    if V <= 0:\n",
    "        return float(np.nan)\n",
    "        \n",
    "    try:\n",
    "        verts, faces, normals, values = measure.marching_cubes(mask.astype(np.uint8), level=0.5)\n",
    "        tris = verts[faces]\n",
    "        vec0 = tris[:,1] - tris[:,0]\n",
    "        vec1 = tris[:,2] - tris[:,0]\n",
    "        cross_prod = np.cross(vec0, vec1)\n",
    "        tri_areas = 0.5 * np.linalg.norm(cross_prod, axis=1)\n",
    "        A = float(tri_areas.sum())\n",
    "        if A <= 0:\n",
    "            raise RuntimeError(\"zero surface area\")\n",
    "        sphericity = (np.pi ** (1.0/3.0)) * ((6.0 * V) ** (2.0/3.0)) / A\n",
    "        return float(sphericity)\n",
    "    except Exception:\n",
    "        padded = np.pad(mask.astype(np.uint8), pad_width=1, mode='constant', constant_values=0)\n",
    "        neighbors_sum = (\n",
    "            padded[2:,1:-1,1:-1] + padded[:-2,1:-1,1:-1] +\n",
    "            padded[1:-1,2:,1:-1] + padded[1:-1,:-2,1:-1] +\n",
    "            padded[1:-1,1:-1,2:] + padded[1:-1,1:-1,:-2]\n",
    "        )\n",
    "        interior = (neighbors_sum == 6) & (padded[1:-1,1:-1,1:-1]==1)\n",
    "        boundary_voxels = np.count_nonzero((padded[1:-1,1:-1,1:-1]==1) & (~interior))\n",
    "        A_approx = float(boundary_voxels)\n",
    "        if A_approx <= 0:\n",
    "            return float(np.nan)\n",
    "        sphericity = (np.pi ** (1.0/3.0)) * ((6.0 * V) ** (2.0/3.0)) / A_approx\n",
    "        return float(sphericity)\n",
    "\n",
    "\n",
    "with open(augmented_latent_path, 'r') as f:\n",
    "    aug = json.load(f)\n",
    "\n",
    "aug_map = dict(aug)\n",
    "if \"scale_factor\" in aug_map:\n",
    "    del aug_map[\"scale_factor\"]\n",
    "\n",
    "ids = list(aug_map.keys())\n",
    "vectors = np.array([np.asarray(aug_map[i], dtype=float) for i in ids])\n",
    "N = vectors.shape[0]\n",
    "print(f\"Loaded {N} augmented vectors (70-D)\")\n",
    "\n",
    "nucleus_idx = list(range(0,32)) + [64,65,66]   # first 32 latent dims + nucleus sizes\n",
    "cell_idx = list(range(32,64)) + [67,68,69]    # next 32 latent dims + cell sizes\n",
    "\n",
    "nucleus_X = vectors[:, nucleus_idx]\n",
    "cell_X = vectors[:, cell_idx]\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "id_to_meta = {entry['id']: entry for entry in metadata}\n",
    "\n",
    "label_numeric = np.full((N,), np.nan)\n",
    "nuc_height = np.full((N,), np.nan)\n",
    "nuc_volume = np.full((N,), np.nan)\n",
    "nuc_sphericity = np.full((N,), np.nan)\n",
    "nuc_tilt = np.full((N,), np.nan)\n",
    "\n",
    "cell_height = np.full((N,), np.nan)\n",
    "cell_volume = np.full((N,), np.nan)\n",
    "cell_sphericity = np.full((N,), np.nan)\n",
    "cell_tilt = np.full((N,), np.nan)\n",
    "\n",
    "\n",
    "print(\"Computing morphological metrics for nucleus & cell (this can take a while)...\")\n",
    "for i, sid in enumerate(tqdm(ids)):\n",
    "    meta = id_to_meta.get(sid)\n",
    "    if meta is None:\n",
    "        continue\n",
    "    lab = meta.get('label')\n",
    "    if lab in label_to_index:\n",
    "        label_numeric[i] = label_to_index[lab]\n",
    "    else:\n",
    "        label_numeric[i] = np.nan\n",
    "\n",
    "    seg_rel = meta.get('seg_file')\n",
    "    if seg_rel is None:\n",
    "        continue\n",
    "        \n",
    "    if seg_rel.startswith(\"crop_seg/\") or seg_rel.startswith(\"./crop_seg/\"):\n",
    "        seg_rel_path = seg_rel.split(\"crop_seg/\", 1)[1] if \"crop_seg/\" in seg_rel else seg_rel\n",
    "        image_path = os.path.join(base_folder, seg_rel_path)\n",
    "    else:\n",
    "        image_path = os.path.join(base_folder, seg_rel)\n",
    "\n",
    "    if not os.path.exists(image_path):\n",
    "        alt = os.path.join(base_folder, seg_rel)\n",
    "        if os.path.exists(alt):\n",
    "            image_path = alt\n",
    "        else:\n",
    "            print(f\"Image file not found for id {sid}: {image_path}\")\n",
    "            continue\n",
    "\n",
    "    try:\n",
    "        im = tifffile.imread(image_path)\n",
    "        chs = extract_two_channels(im)\n",
    "    except Exception as e:\n",
    "        print(f\"Could not read or parse {image_path}: {e}\")\n",
    "        continue\n",
    "\n",
    "    nuc_mask = (chs[0] > 0.5)\n",
    "    cell_mask = (chs[1] > 0.5)\n",
    "\n",
    "    nuc_height[i] = compute_height(nuc_mask)\n",
    "    nuc_volume[i] = compute_volume(nuc_mask)\n",
    "    nuc_sphericity[i] = compute_sphericity(nuc_mask)\n",
    "    nuc_tilt[i] = compute_major_tilt_degrees(nuc_mask)\n",
    "\n",
    "    cell_height[i] = compute_height(cell_mask)\n",
    "    cell_volume[i] = compute_volume(cell_mask)\n",
    "    cell_sphericity[i] = compute_sphericity(cell_mask)\n",
    "    cell_tilt[i] = compute_major_tilt_degrees(cell_mask)\n",
    "\n",
    "\n",
    "print(\"Running UMAP on nucleus and cell feature sets...\")\n",
    "umap_nuc = umap_module.UMAP(n_components=2, random_state=42)\n",
    "umap_cell = umap_module.UMAP(n_components=2, random_state=42)\n",
    "\n",
    "nuc_emb = umap_nuc.fit_transform(nucleus_X)\n",
    "cell_emb = umap_cell.fit_transform(cell_X)\n",
    "\n",
    "def normalize_for_cmap(vals):\n",
    "    vals = np.array(vals, dtype=float)\n",
    "    mask = ~np.isnan(vals)\n",
    "    v = vals.copy()\n",
    "    if np.any(mask):\n",
    "        mn, mx = np.nanmin(v), np.nanmax(v)\n",
    "        if mn == mx:\n",
    "            return np.zeros_like(v)\n",
    "        return (v - mn) / (mx - mn)\n",
    "    else:\n",
    "        return np.zeros_like(v)\n",
    "\n",
    "\n",
    "# Plotting:\n",
    "fig, axes = plt.subplots(2, 6, figsize=(24, 8))\n",
    "plt.subplots_adjust(wspace=0.6, hspace=0.3)\n",
    "\n",
    "label_norm = (label_numeric - np.nanmin(label_numeric)) / (np.nanmax(label_numeric) - np.nanmin(label_numeric) + 1e-12)\n",
    "label_colors = cmap_labels(np.clip(label_norm, 0, 1))\n",
    "\n",
    "unique_labels = sorted(set(label_to_index.keys()), key=lambda x: label_to_index[x])\n",
    "for row in [0, 1]:\n",
    "    ax = axes[row, 0]\n",
    "    ax.axis(\"off\")\n",
    "    ax.set_title(\"Label key\", fontsize=12)\n",
    "    for i, lab in enumerate(unique_labels):\n",
    "        color = cmap_labels(label_to_index[lab] / (len(unique_labels)-1 + 1e-12))\n",
    "        ax.scatter([], [], c=[color], s=40, label=lab)\n",
    "    ax.legend(loc=\"center\", fontsize=10, frameon=False)\n",
    "\n",
    "nuc_metrics = [\n",
    "    (\"True Phase Label\", label_numeric, label_colors, None),\n",
    "    (\"Height (slices)\", nuc_height, None, \"viridis\"),\n",
    "    (\"Volume (voxels)\", nuc_volume, None, \"plasma\"),\n",
    "    (\"Sphericity\", nuc_sphericity, None, \"magma\"),\n",
    "    (\"Major tilt (deg)\", nuc_tilt, None, \"cividis\"),\n",
    "]\n",
    "for j, (title, vals, color_vals, cmap_name) in enumerate(nuc_metrics):\n",
    "    ax = axes[0, j+1]\n",
    "    if vals is None:\n",
    "        vals = np.zeros(N)\n",
    "    if color_vals is None:\n",
    "        norm_vals = normalize_for_cmap(vals)\n",
    "        colors = plt.get_cmap(cmap_name)(norm_vals)\n",
    "    else:\n",
    "        colors = color_vals\n",
    "    ax.scatter(nuc_emb[:,0], nuc_emb[:,1], c=colors, s=8)\n",
    "    ax.set_title(f\"Nucleus UMAP — {title}\")\n",
    "    ax.set_xlabel(\"UMAP1\"); ax.set_ylabel(\"UMAP2\")\n",
    "    ax.set_xticks([]); ax.set_yticks([])\n",
    "    if color_vals is None:\n",
    "        sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=plt.Normalize(vmin=np.nanmin(vals), vmax=np.nanmax(vals)))\n",
    "        sm.set_array([])\n",
    "        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)\n",
    "        cbar.ax.tick_params(labelsize=8)\n",
    "\n",
    "cell_metrics = [\n",
    "    (\"True label\", label_numeric, label_colors, None),\n",
    "    (\"Height (slices)\", cell_height, None, \"viridis\"),\n",
    "    (\"Volume (voxels)\", cell_volume, None, \"plasma\"),\n",
    "    (\"Sphericity\", cell_sphericity, None, \"magma\"),\n",
    "    (\"Major tilt (deg)\", cell_tilt, None, \"cividis\"),\n",
    "]\n",
    "for j, (title, vals, color_vals, cmap_name) in enumerate(cell_metrics):\n",
    "    ax = axes[1, j+1]\n",
    "    if vals is None:\n",
    "        vals = np.zeros(N)\n",
    "    if color_vals is None:\n",
    "        norm_vals = normalize_for_cmap(vals)\n",
    "        colors = plt.get_cmap(cmap_name)(norm_vals)\n",
    "    else:\n",
    "        colors = color_vals\n",
    "    ax.scatter(cell_emb[:,0], cell_emb[:,1], c=colors, s=8)\n",
    "    ax.set_title(f\"Cell UMAP — {title}\")\n",
    "    ax.set_xlabel(\"UMAP1\"); ax.set_ylabel(\"UMAP2\")\n",
    "    ax.set_xticks([]); ax.set_yticks([])\n",
    "    if color_vals is None:\n",
    "        sm = plt.cm.ScalarMappable(cmap=cmap_name, norm=plt.Normalize(vmin=np.nanmin(vals), vmax=np.nanmax(vals)))\n",
    "        sm.set_array([])\n",
    "        cbar = fig.colorbar(sm, ax=ax, fraction=0.046, pad=0.04)\n",
    "        cbar.ax.tick_params(labelsize=8)\n",
    "\n",
    "plt.suptitle(\"UMAP (Nucleus / Cell) — Colorings: label, height, volume, sphericity, tilt\", fontsize=18)\n",
    "plt.savefig(\"result_data/figures/UMAP_COLOURINGS.png\", format=\"png\", bbox_inches=\"tight\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "358e8b76-d098-45cd-9002-13cf56b3b778",
   "metadata": {},
   "source": [
    "<p style=\"height: 200px;\"></p>"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38acd4ab-c95a-4dfa-9f65-0ad8949242c2",
   "metadata": {},
   "source": [
    "## Principal Curve"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2fe3cbd-a53d-4d0f-aca7-551cee6f6e5f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Compute mean latent space vectors for each phase label based on the full balanced data, visualize 2D slices\n",
    "\n",
    "\n",
    "latent_path = 'result_data/latent_space_full_64_balanced.json'\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "\n",
    "with open(latent_path, 'r') as f:\n",
    "    latent_data = json.load(f)\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "id_to_label = {entry[\"id\"]: entry[\"label\"] for entry in metadata}\n",
    "\n",
    "label_to_latents = {}\n",
    "for id_, vec in latent_data.items():\n",
    "    if id_ in id_to_label:\n",
    "        label = id_to_label[id_]\n",
    "        label_to_latents.setdefault(label, []).append(vec)\n",
    "\n",
    "label_means = {\n",
    "    label: np.mean(vectors, axis=0)\n",
    "    for label, vectors in label_to_latents.items()\n",
    "}\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "ordered_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half', 'M0']\n",
    "\n",
    "with torch.no_grad():\n",
    "    for label in ordered_labels:\n",
    "        if label not in label_means:\n",
    "            print(f\"Skipping label {label}: no latent data.\")\n",
    "            continue\n",
    "\n",
    "        mean_vec = torch.tensor(label_means[label], dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        recon = model.decode(mean_vec).cpu().squeeze(0).numpy()\n",
    "        \n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "\n",
    "        fig, axs = plt.subplots(1, 2, figsize=(6, 3))\n",
    "        axs[0].imshow(recon_binary[0, :, 32, :], cmap='gray')\n",
    "        axs[0].set_title(f\"{label} — Nucleus\")\n",
    "        axs[0].axis('off')\n",
    "\n",
    "        axs[1].imshow(recon_binary[1, :, 32, :], cmap='gray')\n",
    "        axs[1].set_title(f\"{label} — Cell Body\")\n",
    "        axs[1].axis('off')\n",
    "\n",
    "        plt.suptitle(f\"Mean Shape Reconstruction: {label}\", fontsize=14)\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0758a81-70f5-427d-bb37-5eddccb7f25c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Visualize same phase means as above, but in 3D interactive graphs\n",
    "\n",
    "\n",
    "latent_path = 'result_data/latent_space_full_64_balanced.json'\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "\n",
    "with open(latent_path, 'r') as f:\n",
    "    latent_data = json.load(f)\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "id_to_label = {entry[\"id\"]: entry[\"label\"] for entry in metadata}\n",
    "\n",
    "label_to_latents = {}\n",
    "for id_, vec in latent_data.items():\n",
    "    if id_ in id_to_label:\n",
    "        label = id_to_label[id_]\n",
    "        label_to_latents.setdefault(label, []).append(vec)\n",
    "\n",
    "label_means = {\n",
    "    label: np.mean(vectors, axis=0)\n",
    "    for label, vectors in label_to_latents.items()\n",
    "}\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "ordered_labels = ['M0', 'M1M2', 'M3', 'M4M5', 'M6M7_early', 'M6M7_half', 'M0']\n",
    "\n",
    "\n",
    "def plot_3d_voxel(voxel_data, title=\"3D Reconstruction\"):\n",
    "    \"\"\"Plot a 3D binary voxel volume using Plotly with fixed axis scales.\"\"\"\n",
    "    x, y, z = np.where(voxel_data == 1)\n",
    "    fig = go.Figure(data=go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        marker=dict(size=2, color=z, colorscale='Viridis', opacity=0.8)\n",
    "    ))\n",
    "    fig.update_layout(\n",
    "        scene=dict(\n",
    "            xaxis=dict(title='X', range=[0, 64]),\n",
    "            yaxis=dict(title='Y', range=[0, 64]),\n",
    "            zaxis=dict(title='Z', range=[0, 64]),\n",
    "            aspectmode='cube'\n",
    "        ),\n",
    "        title=title,\n",
    "        margin=dict(l=0, r=0, b=0, t=30)\n",
    "    )\n",
    "    fig.show()\n",
    "\n",
    "\n",
    "with torch.no_grad():\n",
    "    for label in ordered_labels:\n",
    "        if label not in label_means:\n",
    "            print(f\"Skipping label {label}: no latent data.\")\n",
    "            continue\n",
    "\n",
    "        mean_vec = torch.tensor(label_means[label], dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        recon = model.decode(mean_vec).cpu().squeeze(0).numpy()\n",
    "\n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "\n",
    "        print(f\"Label: {label} — Nucleus\")\n",
    "        plot_3d_voxel(recon_binary[0], title=f\"{label} — Nucleus\")\n",
    "\n",
    "        print(f\"Label: {label} — Cell Body\")\n",
    "        plot_3d_voxel(recon_binary[1], title=f\"{label} — Cell Body\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcc8145b-6d6f-43a0-97d9-45b29494da3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Defines implementation for our own principle curve finder; it is a L1 regression regularized by arc length and curvature,\n",
    "#         that must pass through all phase means and complete a full cycle -- 100 points are then sampled from the curve. \n",
    "#         The curve is defined by the parameter t, with cubic splines to represent each dimension of the curve\n",
    "\n",
    "\n",
    "def load_phase_means(latent_path, metadata_path, valid_labels):\n",
    "    with open(latent_path, 'r') as f:\n",
    "        latent_data = json.load(f)\n",
    "    with open(metadata_path, 'r') as f:\n",
    "        metadata = json.load(f)\n",
    "    id_to_label = {entry['id']: entry['label'] for entry in metadata}\n",
    "    vectors_by_phase = {label: [] for label in valid_labels}\n",
    "    for id_, vec in latent_data.items():\n",
    "        label = id_to_label.get(id_)\n",
    "        if label in valid_labels:\n",
    "            vectors_by_phase[label].append(vec)\n",
    "    phase_means = []\n",
    "    for label in valid_labels:\n",
    "        phase_vecs = np.array(vectors_by_phase[label])\n",
    "        phase_means.append(phase_vecs.mean(axis=0))\n",
    "    return np.array(phase_means)\n",
    "\n",
    "\n",
    "def curve_length(points):\n",
    "    diffs = np.diff(points, axis=0)\n",
    "    return np.sum(np.linalg.norm(diffs, axis=1))\n",
    "\n",
    "\n",
    "def curvature_penalty(points):\n",
    "    second_diffs = np.diff(points, n=2, axis=0)\n",
    "    return np.sum(np.linalg.norm(second_diffs, axis=1) ** 2)\n",
    "\n",
    "\n",
    "def initialize_curve_from_means(fixed_points, num_points):\n",
    "    dim = fixed_points.shape[1]\n",
    "    num_phases = len(fixed_points)\n",
    "    points_per_segment = num_points // num_phases\n",
    "    remainder = num_points % num_phases\n",
    "    interp_points = []\n",
    "    for i in range(num_phases):\n",
    "        start = fixed_points[i]\n",
    "        end = fixed_points[(i + 1) % num_phases]\n",
    "        steps = points_per_segment + (1 if i < remainder else 0)\n",
    "        alphas = np.linspace(0, 1, steps + 1, endpoint=False)\n",
    "        if i > 0:\n",
    "            alphas = alphas[1:]\n",
    "        for alpha in alphas:\n",
    "            interp_points.append((1 - alpha) * start + alpha * end)\n",
    "    curve_init = np.array(interp_points[:num_points])\n",
    "    assert curve_init.shape == (num_points, dim)\n",
    "    return curve_init\n",
    "\n",
    "\n",
    "def curve_objective_free(params, fixed_points, num_points, lambda_length, lambda_curvature, lambda_regression, data_points):\n",
    "    dim = fixed_points.shape[1]\n",
    "    num_fixed = len(fixed_points)\n",
    "    num_free = num_points - num_fixed\n",
    "    curve = np.zeros((num_points, dim))\n",
    "    curve[:num_fixed] = fixed_points\n",
    "    curve[num_fixed:] = params.reshape((num_free, dim))\n",
    "    length_term = curve_length(curve)\n",
    "    curvature_term = curvature_penalty(curve)\n",
    "    regression_term = 0                        # Regression error (L1 to nearest point on curve)\n",
    "    for dp in data_points:\n",
    "        dists = np.sum(np.abs(curve - dp), axis=1)\n",
    "        regression_term += np.min(dists)\n",
    "    regression_term /= len(data_points)\n",
    "\n",
    "    total = (\n",
    "        lambda_length * length_term +\n",
    "        lambda_curvature * curvature_term +\n",
    "        lambda_regression * regression_term\n",
    "    )\n",
    "\n",
    "    print(f\"Length: {length_term:.4f}, Curv: {curvature_term:.4f}, Reg: {regression_term:.4f}, Total: {total:.4f}\")\n",
    "    return total\n",
    "\n",
    "\n",
    "def optimize_closed_curve(fixed_points, num_points, max_iter, lambda_length, lambda_curvature, lambda_regression, data_points):\n",
    "    dim = fixed_points.shape[1]\n",
    "    num_fixed = len(fixed_points)\n",
    "    num_free = num_points - num_fixed\n",
    "    curve_init = initialize_curve_from_means(fixed_points, num_points)\n",
    "    free_part_init = curve_init[num_fixed:]\n",
    "    x0 = free_part_init.flatten()\n",
    "\n",
    "    result = minimize(\n",
    "        lambda x: curve_objective_free(x, fixed_points, num_points, lambda_length, lambda_curvature, lambda_regression, data_points),\n",
    "        x0,\n",
    "        method='L-BFGS-B',\n",
    "        options={'maxiter': max_iter, 'disp': True}\n",
    "    )\n",
    "\n",
    "    optimized = np.zeros((num_points, dim))\n",
    "    optimized[:num_fixed] = fixed_points\n",
    "    optimized[num_fixed:] = result.x.reshape((num_free, dim))\n",
    "    return optimized\n",
    "\n",
    "\n",
    "def parametrize_curve(points):\n",
    "    diffs = np.diff(points, axis=0)\n",
    "    segment_lengths = np.linalg.norm(diffs, axis=1)\n",
    "    arc_lengths = np.insert(np.cumsum(segment_lengths), 0, 0)\n",
    "    t_vals = arc_lengths / arc_lengths[-1]\n",
    "    return t_vals\n",
    "\n",
    "\n",
    "def get_curve_interpolator(t_vals, curve_points):\n",
    "    interpolators = []\n",
    "    for dim in range(curve_points.shape[1]):\n",
    "        cs = CubicSpline(t_vals, curve_points[:, dim], bc_type='natural')\n",
    "        interpolators.append(cs)\n",
    "\n",
    "    def gamma(t_query):\n",
    "        t_query = np.asarray(t_query)\n",
    "        return np.stack([cs(t_query) for cs in interpolators], axis=-1)\n",
    "\n",
    "    return gamma\n",
    "\n",
    "\n",
    "def sample_curve(gamma_fn, num_points=100):\n",
    "    ts = np.linspace(0, 0.525, num_points, endpoint=False)\n",
    "    return gamma_fn(ts)\n",
    "\n",
    "\n",
    "def save_curve(points, path):\n",
    "    with open(path, 'w') as f:\n",
    "        json.dump(points.tolist(), f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb0d143d-9efe-453a-b9ac-b99e2cbbf215",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Runs fitting of curve to our full latent space data\n",
    "\n",
    "\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "valid_labels = ['M0','M1M2','M3','M4M5','M6M7_early','M6M7_half']\n",
    "\n",
    "latent_path = 'result_data/latent_space_full_64_balanced.json'\n",
    "output_path = 'result_data/latent_principal_curve_64D_100pts.json'\n",
    "\n",
    "# latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'       # Run after to produce augmented principal curve\n",
    "# output_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'\n",
    "\n",
    "phase_means = load_phase_means(latent_path, metadata_path, valid_labels)\n",
    "\n",
    "with open(latent_path, 'r') as f:\n",
    "    latent_data = json.load(f)\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "id_to_label = {entry['id']: entry['label'] for entry in metadata}\n",
    "\n",
    "data_points = []\n",
    "for id_, vec in latent_data.items():\n",
    "    if id_to_label.get(id_) in valid_labels:\n",
    "        data_points.append(vec)\n",
    "\n",
    "data_points = np.array(data_points)\n",
    "\n",
    "optimized_curve = optimize_closed_curve(\n",
    "    fixed_points=phase_means,\n",
    "    num_points=100,\n",
    "    max_iter=10000,\n",
    "    lambda_length=1,\n",
    "    lambda_curvature=1,\n",
    "    lambda_regression=1,\n",
    "    data_points=data_points\n",
    ")\n",
    "\n",
    "t_vals = parametrize_curve(optimized_curve)\n",
    "gamma = get_curve_interpolator(t_vals, optimized_curve)\n",
    "sampled_points = sample_curve(gamma, num_points=100)\n",
    "save_curve(sampled_points, output_path)\n",
    "\n",
    "print(\"done\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fccb41ce-588d-4ee9-8018-d3dc63883f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creates and saves a smooth, mean-reverting curve with stochastic deviations off of the principal curve\n",
    "\n",
    "\n",
    "def sample_deviated_curve(base_curve, deviation_scale=0.1, reversion_strength=0.05, seed=72):\n",
    "    \"\"\"\n",
    "    Adds smooth, mean-reverting stochastic deviations to a principal curve\n",
    "    \"\"\"\n",
    "    np.random.seed(seed)\n",
    "    num_points, dim = base_curve.shape\n",
    "    deviations = np.zeros((num_points, dim))\n",
    "    \n",
    "    for d in range(dim):\n",
    "        for i in range(1, num_points):\n",
    "            noise = np.random.normal(0, deviation_scale)\n",
    "            drift = -reversion_strength * deviations[i - 1, d]\n",
    "            deviations[i, d] = deviations[i - 1, d] + drift + noise\n",
    "            \n",
    "    deviated_curve = base_curve + deviations\n",
    "    return deviated_curve\n",
    "\n",
    "\n",
    "deviated_curve = sample_deviated_curve(\n",
    "    base_curve=sampled_points,\n",
    "    deviation_scale=0.1,        # adjusts noise level\n",
    "    reversion_strength=0.05,     # controls strength of tendency to stay near the original curve\n",
    "    seed=1337 \n",
    ")\n",
    "\n",
    "deviated_output_path = 'result_data/latent_principal_curve_64D_100pts_deviated.json'\n",
    "save_curve(deviated_curve, deviated_output_path)\n",
    "print(\"Saved deviated curve to:\", deviated_output_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31c7fa57-b3ab-456c-ba3f-b0c021e2841f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creates 3D animation along principal curve (without consideration of scaling factors)\n",
    "\n",
    "\n",
    "with open('result_data/latent_principal_curve_64D_100pts.json', 'r') as f:\n",
    "    curve_points = json.load(f)\n",
    "\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "recon_channel = []\n",
    "with torch.no_grad():\n",
    "    for i, vec in enumerate(curve_points):\n",
    "        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        recon = model.decode(latent).cpu().squeeze(0).numpy()\n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "        recon_channel.append(recon_binary[0])\n",
    "        # recon_channel.append(recon_binary[1])   # for cell channel opposed to nucleus\n",
    "\n",
    "frames = []\n",
    "for i, volume in enumerate(recon_channel):\n",
    "    x, y, z = np.where(volume == 1)\n",
    "    scatter = go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        marker=dict(size=2, color=z, colorscale='Viridis', opacity=0.8),\n",
    "        name=f\"Frame {i}\"\n",
    "    )\n",
    "    frame = go.Frame(data=[scatter], name=str(i))\n",
    "    frames.append(frame)\n",
    "\n",
    "init_x, init_y, init_z = np.where(recon_channel[0] == 1)\n",
    "# init_x, init_y, init_z = np.where(recon_channel[1] == 1)   # for cell channel opposed to nucleus\n",
    "fig = go.Figure(\n",
    "    data=[go.Scatter3d(\n",
    "        x=init_x, y=init_y, z=init_z,\n",
    "        mode='markers',\n",
    "        marker=dict(size=2, color=init_z, colorscale='Viridis', opacity=0.7)\n",
    "    )],\n",
    "    layout=go.Layout(\n",
    "        title=\"3D Cell Nucleus Morphing Along Principal Curve\",\n",
    "        scene=dict(\n",
    "            xaxis=dict(range=[0, 64], title='X'),\n",
    "            yaxis=dict(range=[0, 64], title='Y'),\n",
    "            zaxis=dict(range=[0, 64], title='Z'),\n",
    "            aspectmode='cube'\n",
    "        ),\n",
    "        updatemenus=[dict(\n",
    "            type='buttons',\n",
    "            buttons=[\n",
    "                dict(label='▶️ Play', method='animate', args=[None, {\"frame\": {\"duration\": 300, \"redraw\": True}, \"fromcurrent\": True, \"loop\": True}]),\n",
    "                dict(label='⏸ Pause', method='animate', args=[[None], {\"frame\": {\"duration\": 0, \"redraw\": False}, \"mode\": \"immediate\", \"transition\": {\"duration\": 0}}])\n",
    "            ]\n",
    "        )]\n",
    "    ),\n",
    "    frames=frames\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "fig.write_html(\"result_data/figures/nucleus_morphing_principal_curve.html\")\n",
    "# fig.write_html(\"result_data/figures/cell_morphing_principal_curve.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50ebd93-55bf-4f99-a636-7f2730db19b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creates same 3D animation along principal curve but with mesh visualization\n",
    "\n",
    "\n",
    "with open('result_data/latent_principal_curve_64D_100pts.json', 'r') as f:\n",
    "    curve_points = json.load(f)\n",
    "\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "recon_channel = []\n",
    "with torch.no_grad():\n",
    "    for vec in curve_points:\n",
    "        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        recon = model.decode(latent).cpu().squeeze(0).numpy()\n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "        recon_channel.append(recon_binary[0])   # nucleus channel\n",
    "        # recon_channel.append(recon_binary[1]) # cell channel\n",
    "\n",
    "def volume_to_mesh(volume, level=0.5):\n",
    "    verts, faces, normals, values = measure.marching_cubes(volume, level=level)\n",
    "    x, y, z = verts.T\n",
    "    i, j, k = faces.T\n",
    "    mesh = go.Mesh3d(\n",
    "        x=x, y=y, z=z,\n",
    "        i=i, j=j, k=k,\n",
    "        color='teal',\n",
    "        opacity=0.6,\n",
    "        name=\"mesh\"\n",
    "    )\n",
    "    return mesh\n",
    "\n",
    "frames = []\n",
    "for i, volume in enumerate(recon_channel):\n",
    "    mesh = volume_to_mesh(volume, level=0.5)\n",
    "    frame = go.Frame(data=[mesh], name=str(i))\n",
    "    frames.append(frame)\n",
    "\n",
    "init_mesh = volume_to_mesh(recon_channel[0], level=0.5)   # nucleus channel\n",
    "# init_mesh = volume_to_mesh(recon_channel[1], level=0.5)  # cell channel\n",
    "\n",
    "fig = go.Figure(\n",
    "    data=[init_mesh],\n",
    "    layout=go.Layout(\n",
    "        title=\"3D Cell Nucleus Morphing Along Principal Curve (Mesh Surface)\",\n",
    "        scene=dict(\n",
    "            xaxis=dict(range=[0, 64], title='X'),\n",
    "            yaxis=dict(range=[0, 64], title='Y'),\n",
    "            zaxis=dict(range=[0, 64], title='Z'),\n",
    "            aspectmode='cube'\n",
    "        ),\n",
    "        updatemenus=[dict(\n",
    "            type='buttons',\n",
    "            buttons=[\n",
    "                dict(label='▶️ Play', method='animate',\n",
    "                     args=[None, {\"frame\": {\"duration\": 300, \"redraw\": True},\n",
    "                                  \"fromcurrent\": True, \"loop\": True}]),\n",
    "                dict(label='⏸ Pause', method='animate',\n",
    "                     args=[[None], {\"frame\": {\"duration\": 0, \"redraw\": False},\n",
    "                                    \"mode\": \"immediate\",\n",
    "                                    \"transition\": {\"duration\": 0}}])\n",
    "            ]\n",
    "        )]\n",
    "    ),\n",
    "    frames=frames\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "\n",
    "fig.write_html(\"result_data/figures/nucleus_morphing_principal_curve_mesh.html\")\n",
    "# fig.write_html(\"result_data/figures/cell_morphing_principal_curve_mesh.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "32fcf667-51a0-4f29-b404-fc3100186c6f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Creates 3D animation along principal curve with consideration of scaling factors, changing axis-scale zooms as needed.\n",
    "# Note that the 3D images were provided with higher single-layer pixels than number of layers, so depth axis (z)\n",
    "# is also scaled by average ratio of axes to be compared properly against the other dimensions\n",
    "\n",
    "\n",
    "def get_ratio():\n",
    "    metadata_path = \"result_data/cell_data_64_balanced.json\"\n",
    "    with open(metadata_path, \"r\") as f:\n",
    "        metadata = json.load(f)\n",
    "    size_cell_1_vals = []\n",
    "    size_cell_23_vals = []\n",
    "    for entry in metadata:\n",
    "        if \"size_cell_1\" in entry and \"size_cell_2\" in entry and \"size_cell_3\" in entry:\n",
    "            s1 = entry[\"size_cell_1\"]\n",
    "            s2 = entry[\"size_cell_2\"]\n",
    "            s3 = entry[\"size_cell_3\"]\n",
    "            size_cell_1_vals.append(s1)\n",
    "            size_cell_23_vals.append((s2 + s3) / 2.0)\n",
    "    avg_s1 = np.mean(size_cell_1_vals)\n",
    "    avg_s23 = np.mean(size_cell_23_vals)\n",
    "    ratio = avg_s23 / avg_s1 if avg_s1 != 0 else float(\"inf\")\n",
    "    print(f\"Ratio (avg_s23 / avg_s1): {ratio:.4f}\")\n",
    "    return ratio\n",
    "\n",
    "ratio = get_ratio()\n",
    "\n",
    "with open('result_data/augmented_latent_principal_curve_64D_100pts.json', 'r') as f:\n",
    "    curve_points = json.load(f)\n",
    "    \n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "recon_channel = []\n",
    "scale_vectors = [] # holds (sz, sy, sx)\n",
    "\n",
    "with torch.no_grad():\n",
    "    for vec in curve_points:\n",
    "        vec = np.array(vec, dtype=np.float32)\n",
    "\n",
    "        latent64 = torch.tensor(vec[:64], dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        scale_vec = vec[64:67]\n",
    "        #scale_vec = vec[67:70]  # for cell channel\n",
    "\n",
    "        recon = model.decode(latent64).cpu().squeeze(0).numpy()\n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "        volume = recon_binary[0]   \n",
    "        # volume = recon_binary[1]    # for cell channel\n",
    "\n",
    "        recon_channel.append(volume)\n",
    "        scale_norm = 1 + 0.5 * np.tanh(scale_vec)  \n",
    "        scale_vectors.append(scale_vec*0.07)\n",
    "\n",
    "\n",
    "frames = []\n",
    "for i, (volume, scale) in enumerate(zip(recon_channel, scale_vectors)):\n",
    "    x, y, z = np.where(volume == 1)\n",
    "    scatter = go.Scatter3d(\n",
    "        x=x, y=y, z=z,\n",
    "        mode='markers',\n",
    "        marker=dict(size=2, color=z, colorscale='Viridis', opacity=0.8),\n",
    "        name=f\"Frame {i}\"\n",
    "    )\n",
    "\n",
    "    frame = go.Frame(\n",
    "        data=[scatter],\n",
    "        name=str(i),\n",
    "        layout=go.Layout(\n",
    "            scene=dict(\n",
    "                aspectratio=dict(z = ratio * scale[0], y=scale[1], x=scale[2])\n",
    "            )\n",
    "        )\n",
    "    )\n",
    "    frames.append(frame)\n",
    "\n",
    "init_x, init_y, init_z = np.where(recon_channel[0] == 1)\n",
    "# init_x, init_y, init_z = np.where(recon_channel[1] == 1)   # for cell channel\n",
    "fig = go.Figure(\n",
    "    data=[go.Scatter3d(\n",
    "        x=init_x, y=init_y, z=init_z,\n",
    "        mode='markers',\n",
    "        marker=dict(size=2, color=init_z, colorscale='Viridis', opacity=0.8)\n",
    "    )],\n",
    "    layout=go.Layout(\n",
    "        title=\"3D Cell Nucleus Morphing with Latent-Derived Scaling\",\n",
    "        scene=dict(\n",
    "            xaxis=dict(range=[0, 64], title='X', showgrid=True),\n",
    "            yaxis=dict(range=[0, 64], title='Y', showgrid=True),\n",
    "            zaxis=dict(range=[0, 64], title='Z', showgrid=True),\n",
    "            aspectratio=dict(\n",
    "                z=ratio*scale_vectors[1][0],\n",
    "                y=scale_vectors[1][1],\n",
    "                x=scale_vectors[1][2]\n",
    "            )\n",
    "        ),\n",
    "        updatemenus=[dict(\n",
    "            type='buttons',\n",
    "            buttons=[\n",
    "                dict(label='▶️ Play', method='animate',\n",
    "                     args=[None, {\"frame\": {\"duration\": 300, \"redraw\": True},\n",
    "                                  \"fromcurrent\": True, \"loop\": True}]),\n",
    "                dict(label='⏸ Pause', method='animate',\n",
    "                     args=[[None], {\"frame\": {\"duration\": 0, \"redraw\": False},\n",
    "                                    \"mode\": \"immediate\", \"transition\": {\"duration\": 0}}])\n",
    "            ]\n",
    "        )]\n",
    "    ),\n",
    "    frames=frames\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "fig.write_html(\"result_data/figures/nucleus_morphing_principal_curve_scaled.html\")\n",
    "# fig.write_html(\"result_data/figures/cell_morphing_principal_curve_scaled.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc5fb71d-5143-4455-843c-7151b9dd2a87",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Visualize central 2D slices of reconstructed points along entire principle curve\n",
    "\n",
    "\n",
    "curve_path = 'result_data/latent_principal_curve_64D_100pts.json'\n",
    "model_path = 'result_data/dual_branch_vae_64_balanced.pth'\n",
    "\n",
    "with open(curve_path, 'r') as f:\n",
    "    curve_points = json.load(f)\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "model = DualBranchVAE(latent_dim=64, base_channels=16, stride=2).to(device)\n",
    "model.load_state_dict(torch.load(model_path, map_location=device))\n",
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    for i, vec in enumerate(curve_points):\n",
    "        latent = torch.tensor(vec, dtype=torch.float32).to(device).unsqueeze(0)\n",
    "        recon = model.decode(latent).cpu().squeeze(0).numpy()\n",
    "\n",
    "        recon = np.clip(recon, 1e-10, 1 - 1e-10)\n",
    "        recon_binary = (recon >= 0.15).astype(np.uint8)\n",
    "\n",
    "        fig, axs = plt.subplots(1, 2, figsize=(6, 3))\n",
    "        axs[0].imshow(recon_binary[0, :, 32, :], cmap='gray')\n",
    "        axs[0].set_title(f\"Nucleus – Point {i+1}\")\n",
    "        axs[0].axis('off')\n",
    "\n",
    "        axs[1].imshow(recon_binary[1, :, 32, :], cmap='gray')\n",
    "        axs[1].set_title(f\"Cell Body – Point {i+1}\")\n",
    "        axs[1].axis('off')\n",
    "\n",
    "        plt.suptitle(f\"Principal Curve Reconstruction – Point {i+1}\", fontsize=14)\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "67326670-0154-47fd-8668-506ba9ff52df",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize principal curve UMAP in 3D with randomnly sampled stochastic path\n",
    "\n",
    "\n",
    "latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "curve_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'\n",
    "curve_path_deviated = 'result_data/augmented_latent_principal_curve_64D_100pts_deviated.json'\n",
    "\n",
    "with open(latent_path, 'r') as f:\n",
    "    latent_data = json.load(f)\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "with open(curve_path, 'r') as f:\n",
    "    principal_curve = json.load(f)\n",
    "\n",
    "with open(curve_path_deviated, 'r') as f:\n",
    "    deviated_curve = json.load(f)\n",
    "\n",
    "id_to_label = {entry[\"id\"]: entry[\"label\"] for entry in metadata}\n",
    "label_to_latents = defaultdict(list)\n",
    "all_latents = []\n",
    "all_labels = []\n",
    "\n",
    "for id_, vec in latent_data.items():\n",
    "    if id_ in id_to_label:\n",
    "        label = id_to_label[id_]\n",
    "        label_to_latents[label].append(vec)\n",
    "        all_latents.append(vec)\n",
    "        all_labels.append(label)\n",
    "\n",
    "combined = np.array(all_latents + principal_curve + deviated_curve)\n",
    "\n",
    "umap_model = umap.UMAP(n_components=3, random_state=42)\n",
    "embedding = umap_model.fit_transform(combined)\n",
    "\n",
    "n_data = len(all_latents)\n",
    "n_principal = len(principal_curve)\n",
    "n_deviated = len(deviated_curve)\n",
    "\n",
    "data_points = embedding[:n_data]\n",
    "curve_embedded = embedding[n_data:n_data + n_principal]\n",
    "deviated_embedded = embedding[n_data + n_principal:]\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "label_colors = {\n",
    "    'M0': 'blue',\n",
    "    'M1M2': 'green',\n",
    "    'M3': 'orange',\n",
    "    'M4M5': 'purple',\n",
    "    'M6M7_early': 'brown',\n",
    "    'M6M7_half': 'pink'\n",
    "}\n",
    "\n",
    "for label, color in label_colors.items():\n",
    "    indices = [i for i, l in enumerate(all_labels) if l == label]\n",
    "    if not indices:\n",
    "        continue\n",
    "    points = data_points[indices]\n",
    "    fig.add_trace(go.Scatter3d(\n",
    "        x=points[:, 0], y=points[:, 1], z=points[:, 2],\n",
    "        mode='markers',\n",
    "        marker=dict(size=3, color=color, opacity=0.3),\n",
    "        name=label,\n",
    "        text=[label]*len(points),\n",
    "        hoverinfo='text'\n",
    "    ))\n",
    "\n",
    "fig.add_trace(go.Scatter3d(\n",
    "    x=[p[0] for p in curve_embedded],\n",
    "    y=[p[1] for p in curve_embedded],\n",
    "    z=[p[2] for p in curve_embedded],\n",
    "    mode='lines+markers',\n",
    "    line=dict(color='green', width=4),\n",
    "    marker=dict(size=5, color='green'),\n",
    "    name='Principal Curve',\n",
    "    text=[f\"Curve Point {i+1}\" for i in range(len(curve_embedded))],\n",
    "    hoverinfo='text'\n",
    "))\n",
    "\n",
    "fig.add_trace(go.Scatter3d(\n",
    "    x=[p[0] for p in deviated_embedded],\n",
    "    y=[p[1] for p in deviated_embedded],\n",
    "    z=[p[2] for p in deviated_embedded],\n",
    "    mode='lines+markers',\n",
    "    line=dict(color='red', width=4),\n",
    "    marker=dict(size=5, color='red'),\n",
    "    name='Deviated Curve',\n",
    "    text=[f\"Deviated Point {i+1}\" for i in range(len(deviated_embedded))],\n",
    "    hoverinfo='text'\n",
    "))\n",
    "\n",
    "fig.update_layout(\n",
    "    title=\"3D UMAP: Principal Curve Through Cell Cycle Latent Space\",\n",
    "    scene=dict(\n",
    "        xaxis_title='UMAP-1',\n",
    "        yaxis_title='UMAP-2',\n",
    "        zaxis_title='UMAP-3'\n",
    "    ),\n",
    "    legend=dict(x=0.02, y=0.98),\n",
    "    margin=dict(l=0, r=0, b=0, t=40)\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "fig.write_html(\"result_data/figures/augmented_principal_curve_3D_UMAP.html\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9db936d-b1ab-49f6-a799-b0e1fb6a1adf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Visualize UMAP in 3D with tube of specified variance surrounding principal curve,\n",
    "# Requires quantiles of projection distances to be calculated\n",
    "\n",
    "\n",
    "latent_path = 'result_data/augmented_latent_space_full_64_balanced.json'\n",
    "metadata_path = 'result_data/cell_data_64_balanced.json'\n",
    "curve_path = 'result_data/augmented_latent_principal_curve_64D_100pts.json'\n",
    "\n",
    "tube_percentile = 50          # e.g. 50 => tube covers ~50% of points per section\n",
    "percentiles_to_report = [25, 50, 75, 90, 95]\n",
    "n_chunks = 20                 # number of sections along curve\n",
    "n_theta = 20                  \n",
    "scaling_multiplier = 0.5\n",
    "random_state = 42\n",
    "\n",
    "with open(latent_path, 'r') as f:\n",
    "    latent_data_dict = json.load(f)\n",
    "\n",
    "with open(metadata_path, 'r') as f:\n",
    "    metadata = json.load(f)\n",
    "\n",
    "with open(curve_path, 'r') as f:\n",
    "    principal_curve = np.array(json.load(f))\n",
    "\n",
    "id_to_label = {entry['id']: entry['label'] for entry in metadata}\n",
    "valid_labels = ['M0','M1M2','M3','M4M5','M6M7_early','M6M7_half']\n",
    "\n",
    "all_ids = []\n",
    "all_latents = []\n",
    "all_labels = []\n",
    "for id_, vec in latent_data_dict.items():\n",
    "    lbl = id_to_label.get(id_)\n",
    "    if lbl is None:\n",
    "        continue\n",
    "    all_ids.append(id_)\n",
    "    all_latents.append(np.array(vec, dtype=float))\n",
    "    all_labels.append(lbl)\n",
    "\n",
    "all_latents = np.vstack(all_latents)\n",
    "principal_curve = np.asarray(principal_curve, dtype=float)\n",
    "n_data, D = all_latents.shape\n",
    "n_curve = principal_curve.shape[0]\n",
    "\n",
    "print(f\"Loaded {n_data} data points, latent dim = {D}, curve points = {n_curve}\")\n",
    "\n",
    "combined = np.vstack((all_latents, principal_curve))\n",
    "um = umap.UMAP(n_components=3, random_state=random_state)\n",
    "embedding = um.fit_transform(combined)\n",
    "data_emb = embedding[:n_data]\n",
    "curve_emb = embedding[n_data:]\n",
    "\n",
    "diff = all_latents[:, None, :] - principal_curve[None, :, :]\n",
    "distances_high = np.linalg.norm(diff, axis=2)\n",
    "\n",
    "nearest_idx = np.argmin(distances_high, axis=1)\n",
    "d_high_min = distances_high[np.arange(n_data), nearest_idx]\n",
    "\n",
    "d_low_min = np.linalg.norm(data_emb - curve_emb[nearest_idx], axis=1)\n",
    "\n",
    "eps = 1e-9\n",
    "ratios = d_low_min / (d_high_min + eps)\n",
    "ratios_clean = ratios[np.isfinite(ratios)]\n",
    "if ratios_clean.size == 0:\n",
    "    scale_ratio = 1.0\n",
    "else:\n",
    "    scale_ratio = float(np.median(ratios_clean))\n",
    "print(f\"Scale ratio (median low/high) used to convert radii -> UMAP units: {scale_ratio:.6g}\")\n",
    "\n",
    "radii_high = np.zeros(n_curve, dtype=float)\n",
    "global_percentile_default = np.percentile(d_high_min, tube_percentile)\n",
    "for k in range(n_curve):\n",
    "    mask_k = (nearest_idx == k)\n",
    "    if np.any(mask_k):\n",
    "        radii_high[k] = np.percentile(d_high_min[mask_k], tube_percentile)\n",
    "    else:\n",
    "        radii_high[k] = global_percentile_default\n",
    "\n",
    "radii_umb = radii_high * scale_ratio * scaling_multiplier\n",
    "\n",
    "vec_norms = np.linalg.norm(all_latents, axis=1)\n",
    "typical_size = np.median(vec_norms)\n",
    "print(\"\\nGlobal projection distance percentiles (HIGH-D) and relative to median vector norm:\")\n",
    "for p in percentiles_to_report:\n",
    "    val = np.percentile(d_high_min, p)\n",
    "    print(f\"  {p:2d}th pct: {val:.6g}   (relative = {val/typical_size:.6g} of median vector norm)\")\n",
    "\n",
    "indices_per_chunk = np.array_split(np.arange(n_curve), n_chunks)\n",
    "chunk_mean_high = []\n",
    "chunk_count = []\n",
    "for chunk_idx, inds in enumerate(indices_per_chunk):\n",
    "    assigned_mask = np.isin(nearest_idx, inds)\n",
    "    vals = d_high_min[assigned_mask]\n",
    "    chunk_count.append(vals.size)\n",
    "    if vals.size:\n",
    "        chunk_mean_high.append(float(np.mean(vals)))\n",
    "    else:\n",
    "        chunk_mean_high.append(float(np.nan))\n",
    "\n",
    "print(f\"\\nChunk averages (HIGH-D) over {n_chunks} chunks (NaN => no assigned points):\")\n",
    "for ci, (cnt, meanv) in enumerate(zip(chunk_count, chunk_mean_high)):\n",
    "    print(f\"  chunk {ci+1:2d}: count={cnt:4d}, mean_dist={np.nan if np.isnan(meanv) else round(meanv,6)}\")\n",
    "\n",
    "\n",
    "def make_tube_mesh(curve_pts, radii, n_theta=20):\n",
    "    \"\"\"Return vertices (Vx3) and triangle indices (i,j,k) for a tube around curve_pts\"\"\"\n",
    "    n = len(curve_pts)\n",
    "    thetas = np.linspace(0, 2*np.pi, n_theta, endpoint=False)\n",
    "    verts = []\n",
    "    for i in range(n):\n",
    "        p = curve_pts[i]\n",
    "        \n",
    "        if i == 0:\n",
    "            tangent = curve_pts[1] - curve_pts[0]\n",
    "        elif i == n-1:\n",
    "            tangent = curve_pts[-1] - curve_pts[-2]\n",
    "        else:\n",
    "            tangent = curve_pts[i+1] - curve_pts[i-1]\n",
    "        tangent = tangent.astype(float)\n",
    "        tn = np.linalg.norm(tangent)\n",
    "        if tn < 1e-8:\n",
    "            tangent = np.array([1.0, 0.0, 0.0])\n",
    "            tn = 1.0\n",
    "        tangent = tangent / tn\n",
    "        arb = np.array([0.0, 0.0, 1.0])\n",
    "        if abs(np.dot(arb, tangent)) > 0.9:\n",
    "            arb = np.array([0.0, 1.0, 0.0])\n",
    "        u = np.cross(tangent, arb)\n",
    "        u_norm = np.linalg.norm(u)\n",
    "        if u_norm < 1e-8:\n",
    "            u = np.array([1.0, 0.0, 0.0])\n",
    "            u_norm = 1.0\n",
    "        u = u / u_norm\n",
    "        v = np.cross(tangent, u)\n",
    "        v = v / np.linalg.norm(v)\n",
    "        r = radii[i]\n",
    "        for th in thetas:\n",
    "            verts.append(p + r * (u * np.cos(th) + v * np.sin(th)))\n",
    "    verts = np.array(verts)\n",
    "\n",
    "    tri_i = []\n",
    "    tri_j = []\n",
    "    tri_k = []\n",
    "    for i in range(n - 1):\n",
    "        for j in range(n_theta):\n",
    "            a = i * n_theta + j\n",
    "            b = i * n_theta + (j + 1) % n_theta\n",
    "            c = (i + 1) * n_theta + (j + 1) % n_theta\n",
    "            d = (i + 1) * n_theta + j\n",
    "            tri_i.append(a); tri_j.append(b); tri_k.append(c)\n",
    "            tri_i.append(a); tri_j.append(c); tri_k.append(d)\n",
    "    return verts, (np.array(tri_i), np.array(tri_j), np.array(tri_k))\n",
    "\n",
    "verts, (tri_i, tri_j, tri_k) = make_tube_mesh(curve_emb, radii_umb, n_theta=n_theta)\n",
    "print(f\"\\nTube mesh: vertices={verts.shape[0]}, triangles={tri_i.shape[0]}\")\n",
    "\n",
    "label_colors = {\n",
    "    'M0': 'blue',\n",
    "    'M1M2': 'green',\n",
    "    'M3': 'orange',\n",
    "    'M4M5': 'purple',\n",
    "    'M6M7_early': 'brown',\n",
    "    'M6M7_half': 'pink'\n",
    "}\n",
    "\n",
    "fig = go.Figure()\n",
    "\n",
    "for label, color in label_colors.items():\n",
    "    idxs = [i for i, l in enumerate(all_labels) if l == label]\n",
    "    if not idxs:\n",
    "        continue\n",
    "    pts = data_emb[idxs]\n",
    "    fig.add_trace(go.Scatter3d(\n",
    "        x=pts[:,0], y=pts[:,1], z=pts[:,2],\n",
    "        mode='markers', marker=dict(size=3, color=color, opacity=0.3),\n",
    "        name=label, text=[label]*len(pts), hoverinfo='text'\n",
    "    ))\n",
    "\n",
    "fig.add_trace(go.Scatter3d(\n",
    "    x=curve_emb[:,0], y=curve_emb[:,1], z=curve_emb[:,2],\n",
    "    mode='lines+markers', line=dict(color='red', width=3),\n",
    "    marker=dict(size=4, color='red'), name='Principal Curve'\n",
    "))\n",
    "\n",
    "fig.add_trace(go.Mesh3d(\n",
    "    x=verts[:,0], y=verts[:,1], z=verts[:,2],\n",
    "    i=tri_i, j=tri_j, k=tri_k,\n",
    "    color='rgba(173,216,230,0.6)', opacity=0.40, name=f'Tube ({tube_percentile}th pct)'\n",
    "))\n",
    "\n",
    "fig.update_layout(\n",
    "    title=f\"3D UMAP with Principal Curve + Variable-Radius Tube\",\n",
    "    scene=dict(xaxis_title='UMAP-1', yaxis_title='UMAP-2', zaxis_title='UMAP-3'),\n",
    "    margin=dict(l=0, r=0, b=0, t=40)\n",
    ")\n",
    "\n",
    "fig.show()\n",
    "fig.write_html(\"result_data/figures/augmented_principal_curve_3D_UMAP_Tube.html\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43ef13ea-a342-487b-8eb0-83dd17194a32",
   "metadata": {},
   "source": [
    "## Ciao!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
