{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def get_project_root():\n",
    "    # get the absolute path to the root of the git repo\n",
    "    root = subprocess.check_output([\"git\", \"rev-parse\", \"--show-toplevel\"]).strip().decode(\"utf-8\")\n",
    "    return Path(root)\n",
    "\n",
    "# get project root and append it to path\n",
    "project_root = get_project_root()\n",
    "sys.path.append(str(project_root))\n",
    "\n",
    "# embeddings path\n",
    "dataset = \"waymo\"\n",
    "data_dir = f\"{dataset}_data\"\n",
    "base_path = os.path.normpath(os.path.join(project_root, \"..\"))\n",
    "\n",
    "# output dir\n",
    "out_reldir = f\"out/control-vectors/{dataset}/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 204/204 [00:00<00:00, 540.22it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "from utils.embs_all import load_embeddings\n",
    "from utils.embs_contrastive import load_contrastive_embed_pairs\n",
    "\n",
    "\n",
    "# load data\n",
    "data_path = os.path.join(base_path, \"data\", data_dir)\n",
    "paths_inputs = sorted(glob(f\"{data_path}/input*\"))\n",
    "paths_embeds = sorted(glob(f\"{data_path}/target_embs*\"))\n",
    "\n",
    "# stack embeddings wrt types\n",
    "embs = load_embeddings(paths_inputs, paths_embeds)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "# Step 1: Labels for the clusters\n",
    "labels = ['low', 'moderate', 'high', 'decelerate', 'constant', 'accelerate', 'stationary', 'straight', 'right', 'left', 'vehicle', 'pedestrian', 'cyclist' ]\n",
    "\n",
    "# Step 2: Pairwise distances matrix (normalized values as before)\n",
    "mean_tensor = torch.empty((len(labels), 128))\n",
    "mean_tensor[0] = embs[\"speed\"][\"low\"].mean\n",
    "mean_tensor[1] = embs[\"speed\"][\"moderate\"].mean\n",
    "mean_tensor[2] = embs[\"speed\"][\"high\"].mean\n",
    "mean_tensor[3] = embs[\"acceleration\"][\"decelerate\"].mean\n",
    "mean_tensor[4] = embs[\"acceleration\"][\"constant\"].mean\n",
    "mean_tensor[5] = embs[\"acceleration\"][\"accelerate\"].mean\n",
    "mean_tensor[6] = embs[\"direction\"][\"stationary\"].mean\n",
    "mean_tensor[7] = embs[\"direction\"][\"straight\"].mean\n",
    "mean_tensor[8] = embs[\"direction\"][\"right\"].mean\n",
    "mean_tensor[9] = embs[\"direction\"][\"left\"].mean\n",
    "mean_tensor[10] = embs[\"agent\"][\"vehicle\"].mean\n",
    "mean_tensor[11] = embs[\"agent\"][\"pedestrian\"].mean\n",
    "mean_tensor[12] = embs[\"agent\"][\"cyclist\"].mean\n",
    "\n",
    "vars_tensor = torch.empty((len(labels), 128))\n",
    "vars_tensor[0] = embs[\"speed\"][\"low\"].var\n",
    "vars_tensor[1] = embs[\"speed\"][\"moderate\"].var\n",
    "vars_tensor[2] = embs[\"speed\"][\"high\"].var\n",
    "vars_tensor[3] = embs[\"acceleration\"][\"decelerate\"].var\n",
    "vars_tensor[4] = embs[\"acceleration\"][\"constant\"].var\n",
    "vars_tensor[5] = embs[\"acceleration\"][\"accelerate\"].var\n",
    "vars_tensor[6] = embs[\"direction\"][\"stationary\"].var\n",
    "vars_tensor[7] = embs[\"direction\"][\"straight\"].var\n",
    "vars_tensor[8] = embs[\"direction\"][\"right\"].var\n",
    "vars_tensor[9] = embs[\"direction\"][\"left\"].var\n",
    "vars_tensor[10] = embs[\"agent\"][\"vehicle\"].var\n",
    "vars_tensor[11] = embs[\"agent\"][\"pedestrian\"].var\n",
    "vars_tensor[12] = embs[\"agent\"][\"cyclist\"].var\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Within-Class Variance**\n",
    "\n",
    "$\n",
    "\\text{Var}_{\\text{within}}(c) = \\frac{1}{N_c} \\sum_{i \\in c} \\| x_i - \\mu_c \\|^2\n",
    "$\n",
    "\n",
    "$\n",
    "S_w = \\frac{1}{C} \\sum_{c=1}^{C} \\text{Var}_{\\text{within}}(c)\n",
    "$\n",
    "\n",
    "$S_w$: Average within-class Variance\n",
    "$C$: total number of classes\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "**Between-Class Distances**\n",
    "\n",
    "Squared Euclidean distance between each pair of class means ($d_b(c_1, c_2)$):\n",
    "\n",
    "$\n",
    "d_b(c_1, c_2) = \\| \\mu_{c_1} - \\mu_{c_2} \\|^2\n",
    "$\n",
    "\n",
    "Average between-class squared distance ($S_b$):\n",
    "\n",
    "$\n",
    "S_b = \\frac{2}{C(C - 1)} \\sum_{c_1 < c_2} d_b(c_1, c_2)\n",
    "$\n",
    "\n",
    "The factor $\\frac{2}{C(C - 1)}$ ensures averaging over all unique class pairs.\n",
    "\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "\n",
    "**Class-Distance Normalized Variance (CDNV)**\n",
    "\n",
    "$\n",
    "\\text{CDNV} = \\frac{S_w}{S_b}\n",
    "$\n",
    "\n",
    " A lower CDNV value indicates that the within-class variance is small relative to the between-class variance, implying better class separability in your embedding space.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.9498)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from future_motion.utils.interpretability.neural_collapse import CDNV\n",
    "\n",
    "\n",
    "CDNV(mean_tensor, vars_tensor)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "words",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
