{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def get_project_root():\n",
    "    \"\"\"Return the top-level directory of the enclosing git repository as a Path.\n",
    "\n",
    "    Runs `git rev-parse --show-toplevel` in the kernel's current working directory;\n",
    "    subprocess.check_output raises CalledProcessError if git exits non-zero\n",
    "    (e.g. when that directory is not inside a git repository).\n",
    "    \"\"\"\n",
    "    root = subprocess.check_output([\"git\", \"rev-parse\", \"--show-toplevel\"]).strip().decode(\"utf-8\")\n",
    "    return Path(root)\n",
    "\n",
    "# Make the project root importable (e.g. for `future_motion`) even when the kernel\n",
    "# was started from a subdirectory of the repository. Guarded so that re-running\n",
    "# this cell does not append a duplicate entry to sys.path every time.\n",
    "project_root = get_project_root()\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.append(str(project_root))\n",
    "\n",
    "# Output dir: control vectors are read from (parent of project_root)/out/control-vectors/{dataset}/\n",
    "dataset = \"waymo\"\n",
    "out_reldir = f\"out/control-vectors/{dataset}/\"\n",
    "base_path = os.path.normpath(os.path.join(project_root, \"..\"))\n",
    "out_path = os.path.join(base_path, out_reldir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "# Load PCA-based control vectors, keyed by the second \"_\"-separated token of the\n",
    "# file name (presumably the concept, e.g. \"speed\"). The listing is sorted because\n",
    "# os.listdir returns entries in arbitrary order, which would otherwise make the\n",
    "# row/column order of the comparison tables machine-dependent.\n",
    "pca_files = sorted(\n",
    "    fname for fname in os.listdir(out_path) if fname.startswith(\"pca\") and fname.endswith(\".pt\")\n",
    ")\n",
    "pca_control_vectors = {}\n",
    "for fname in pca_files:\n",
    "    file_path = os.path.join(out_path, fname)\n",
    "    # torch.load is pickle-based: only load trusted files.\n",
    "    pca_control_vectors[fname.split(\"_\")[1]] = torch.load(file_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load SAE-based control vectors into a nested dict {hidden_dim: {name: vector}}.\n",
    "# hidden_dim is the first \"_\"-separated token of the file name without its \"sae\"\n",
    "# prefix (presumably e.g. \"sae512_speed_...\" gives \"512\"); name is the second token.\n",
    "# setdefault keeps vectors already loaded for the same hidden dim -- re-creating the\n",
    "# inner dict for every file would silently keep only the last file per hidden dim.\n",
    "sae_files = sorted(\n",
    "    fname for fname in os.listdir(out_path) if fname.startswith(\"sae\") and fname.endswith(\".pt\")\n",
    ")\n",
    "sae_control_vectors = {}\n",
    "for fname in sae_files:\n",
    "    dim = fname.split(\"_\")[0][3:]\n",
    "    file_path = os.path.join(out_path, fname)\n",
    "    sae_control_vectors.setdefault(dim, {})[fname.split(\"_\")[1]] = torch.load(file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "import numpy as np\n",
    "from collections import defaultdict\n",
    "from future_motion.utils.similarity.vector import VectorComparison\n",
    "\n",
    "\n",
    "def inspect_vector(control_vectors1: Dict[str, np.ndarray], control_vectors2: Dict[str, np.ndarray]):\n",
    "    \"\"\"Compare every control vector of one set with every control vector of another.\n",
    "\n",
    "    Builds all ordered pairs (including self-pairs when both arguments are the same\n",
    "    dict) and returns a defaultdict keyed by key1 + \"_\" + key2, whose values are\n",
    "    {\"cos_sim_deg\": VectorComparison(cv1, cv2).cos_sim_deg()} -- judging by the name\n",
    "    and the recorded tables in this notebook, the angle between the vectors in degrees.\n",
    "\n",
    "    NOTE(review): annotated as np.ndarray, but the callers pass whatever torch.load\n",
    "    returned for each .pt file; confirm which types VectorComparison accepts.\n",
    "    \"\"\"\n",
    "    pairs = [\n",
    "        (name1, vec1, name2, vec2)\n",
    "        for name1, vec1 in control_vectors1.items()\n",
    "        for name2, vec2 in control_vectors2.items()\n",
    "    ]\n",
    "    similarities = defaultdict(dict)\n",
    "    for name1, vec1, name2, vec2 in pairs:\n",
    "        similarities[\"_\".join((name1, name2))][\"cos_sim_deg\"] = VectorComparison(vec1, vec2).cos_sim_deg()\n",
    "    return similarities"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "# Pairwise comparison of all PCA control vectors with each other: one row per\n",
    "# key1 + \"_\" + key2 pair, displayed as this cell's output.\n",
    "comparison = {}\n",
    "comparison[\"PCA-PCA\"] = inspect_vector(pca_control_vectors, pca_control_vectors)\n",
    "comparison_df = pd.DataFrame(comparison[\"PCA-PCA\"]).T\n",
    "comparison_df"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**PCA w/ PCA**\n",
    "\n",
    "| **Angle (°)**  | speed      | acceleration | direction  | agent      |\n",
    "|----------------|------------|--------------|------------|------------|\n",
    "| speed          | 0.0        | 11.458136    | 122.603544 | 10.865894  |\n",
    "| acceleration   |            | 0.0          | 126.78761  | 6.82372    |\n",
    "| direction      |            |              | 0.0        | 128.655917 |\n",
    "| agent          |            |              |            | 0.0        |\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set to True to print the per-hidden-dim tables (their values are also recorded\n",
    "# in the markdown cell below).\n",
    "show_tables = False\n",
    "\n",
    "# Results are stored per SAE hidden dim: assigning to a single key inside the loop\n",
    "# would keep only the comparisons of the last hidden dim.\n",
    "comparison[\"PCA-SAE\"] = {}\n",
    "comparison[\"SAE-SAE\"] = {}\n",
    "for hidden_dim, sae_cv in sae_control_vectors.items():\n",
    "\n",
    "    # Compare PCA-based and SAE-based control vectors\n",
    "    comparison[\"PCA-SAE\"][hidden_dim] = inspect_vector(pca_control_vectors, sae_cv)\n",
    "\n",
    "    # Compare SAE-based control vectors with each other\n",
    "    comparison[\"SAE-SAE\"][hidden_dim] = inspect_vector(sae_cv, sae_cv)\n",
    "\n",
    "    if show_tables:\n",
    "        print(hidden_dim)\n",
    "        print(pd.DataFrame(comparison[\"PCA-SAE\"][hidden_dim]).T)\n",
    "        print(pd.DataFrame(comparison[\"SAE-SAE\"][hidden_dim]).T)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**PCA w/ PCA**\n",
    "\n",
    "| **Angle (°)**  | speed      | acceleration | direction  | agent      |\n",
    "|----------------|------------|--------------|------------|------------|\n",
    "| speed          | 0.0        | 11.458136    | 122.603544 | 10.865894  |\n",
    "| acceleration   |            | 0.0          | 126.78761  | 6.82372    |\n",
    "| direction      |            |              | 0.0        | 128.655917 |\n",
    "| agent          |            |              |            | 0.0        |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (512 hidden-dim)**\n",
    "\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 20.7  | 28.6         | 123.8     | 23.4  |\n",
    "| acceleration   | 19.1  | 23.0         | 128.5     | 18.6  |\n",
    "| direction      | 115.9 | 116.6        | 13.7      | 120.8 |\n",
    "| agent          | 19.4  | 24.4         | 130.2     | 18.3  |\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (512 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 0.0   | 10.2         | 121.8     | 7.6   |\n",
    "| acceleration   |       | 0.0          | 123.7     | 7.6   |\n",
    "| direction      |       |              | 0.0       | 126.9 |\n",
    "| agent          |       |              |           | 0.0   |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (256 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 21.5  | 26.8         | 123.8     | 23.3  |\n",
    "| acceleration   | 20.3  | 21.0         | 128.7     | 18.7  |\n",
    "| direction      | 114.7 | 116.9        | 13.7      | 120.1 |\n",
    "| agent          | 20.8  | 23.1         | 130.2     | 18.7  |\n",
    "\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (256 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 0.0   | 9.9          | 120.9     | 7.9   |\n",
    "| acceleration   |       | 0.0          | 123.7     | 7.2   |\n",
    "| direction      |       |              | 0.0       | 126.3 |\n",
    "| agent          |       |              |           | 0.0   |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (128 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 19.7  | 25.3         | 124.3     | 21.6  |\n",
    "| acceleration   | 19.2  | 20.0         | 128.8     | 17.5  |\n",
    "| direction      | 115.2 | 117.1        | 12.1      | 120.5 |\n",
    "| agent          | 19.5  | 21.8         | 130.4     | 17.1  |\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (128 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 0.0   | 9.5          | 120.6     | 7.8   |\n",
    "| acceleration   |       | 0.0          | 122.9     | 7.0   |\n",
    "| direction      |       |              | 0.0       | 125.8 |\n",
    "| agent          |       |              |           | 0.0   |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (64 hidden-dim)**\n",
    "\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 18.1  | 23.7         | 124.7     | 19.3  |\n",
    "| acceleration   | 19.3  | 19.9         | 128.9     | 16.5  |\n",
    "| direction      | 115.0 | 116.6        | 13.3      | 120.5 |\n",
    "| agent          | 19.8  | 21.9         | 130.5     | 16.4  |\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (64 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 0.0   | 9.7          | 121.0     | 8.0   |\n",
    "| acceleration   |       | 0.0          | 123.2     | 7.5   |\n",
    "| direction      |       |              | 0.0       | 126.3 |\n",
    "| agent          |       |              |           | 0.0   |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (32 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 14.7  | 18.8         | 126.4     | 15.5  |\n",
    "| acceleration   | 18.0  | 15.5         | 130.3     | 14.1  |\n",
    "| direction      | 114.4 | 116.9        | 10.9      | 120.2 |\n",
    "| agent          | 18.1  | 17.6         | 132.0     | 13.4  |\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (32 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 0.0   | 9.8          | 120.3     | 8.3   |\n",
    "| acceleration   |       | 0.0          | 122.8     | 7.0   |\n",
    "| direction      |       |              | 0.0       | 125.8 |\n",
    "| agent          |       |              |           | 0.0   |\n",
    "\n",
    "---\n",
    "\n",
    "**PCA w/ SAE (16 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent |\n",
    "|----------------|-------|--------------|-----------|-------|\n",
    "| speed          | 23.5  | 25.1         | 126.6     | 21.8  |\n",
    "| acceleration   | 28.4  | 26.0         | 128.9     | 23.5  |\n",
    "| direction      | 110.2 | 111.9        | 24.6      | 116.6 |\n",
    "| agent          | 28.0  | 26.8         | 131.0     | 22.5  |\n",
    "\n",
    "---\n",
    "\n",
    "**SAE w/ SAE (16 hidden-dim)**\n",
    "\n",
    "| **Angle (°)**  | speed | acceleration | direction | agent  |\n",
    "|----------------|-------|--------------|-----------|--------|\n",
    "| speed          | 0.0   | 9.5          | 124.1     | 9.3    |\n",
    "| acceleration   |       | 0.0          | 125.2     | 7.5    |\n",
    "| direction      |       |              | 0.0       | 129.3  |\n",
    "| agent          |       |              |           | 0.0    |"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "words",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
