{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def get_project_root():\n",
    "    \"\"\"Return the top-level directory of the enclosing git repository as a `Path`.\n",
    "\n",
    "    Runs `git rev-parse --show-toplevel`; `subprocess.check_output` raises\n",
    "    `CalledProcessError` if the command exits with a non-zero status.\n",
    "    \"\"\"\n",
    "    git_cmd = [\"git\", \"rev-parse\", \"--show-toplevel\"]\n",
    "    raw_output = subprocess.check_output(git_cmd)\n",
    "    toplevel = raw_output.strip().decode(\"utf-8\")\n",
    "    return Path(toplevel)\n",
    "\n",
    "# Make project-local packages importable from this notebook. The membership\n",
    "# check keeps re-running the cell from appending duplicate sys.path entries.\n",
    "project_root = get_project_root()\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.append(str(project_root))\n",
    "\n",
    "# embeddings path\n",
    "dataset = \"waymo\"\n",
    "data_dir = f\"{dataset}_data\"\n",
    "# base_path is the parent directory of the git repository root\n",
    "base_path = os.path.normpath(os.path.join(project_root, \"..\"))\n",
    "\n",
    "# output dir\n",
    "out_reldir = f\"out/control-vectors/{dataset}/\"\n",
    "out_path = os.path.join(base_path, out_reldir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "from utils.embs_all import load_embeddings\n",
    "from utils.embs_contrastive import load_contrastive_embed_pairs\n",
    "\n",
    "\n",
    "# load data\n",
    "data_path = os.path.join(base_path, \"data\", data_dir)\n",
    "paths_inputs = sorted(glob(f\"{data_path}/input*\"))\n",
    "paths_embeds = sorted(glob(f\"{data_path}/target_embs*\"))\n",
    "\n",
    "# fail early with a clear message instead of a confusing error further down\n",
    "if not paths_inputs or not paths_embeds:\n",
    "    raise FileNotFoundError(f\"no input*/target_embs* files found in {data_path}\")\n",
    "\n",
    "# stack embeddings wrt types\n",
    "embs = load_embeddings(paths_inputs, paths_embeds)\n",
    "\n",
    "# trim and stack contrastive pairs of embeddings\n",
    "contrastive_embs = load_contrastive_embed_pairs(embs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from utils.embs_all import EmbeddingData\n",
    "\n",
    "\n",
    "def stack(data, average=False):\n",
    "    \"\"\"Return `data` unchanged, or the stacked per-entry means when `average` is True.\n",
    "\n",
    "    The default is the identity, which is the behaviour that was in effect\n",
    "    before (the averaging lambda was immediately overwritten by an identity one).\n",
    "    \"\"\"\n",
    "    if average:\n",
    "        return torch.stack([torch.mean(entry, dim=0) for entry in data])\n",
    "    return data\n",
    "\n",
    "\n",
    "# (feature, label for the first half, label for the second half) of each contrastive set\n",
    "contrastive_splits = [\n",
    "    (\"speed\", \"high\", \"low\"),\n",
    "    (\"acceleration\", \"accelerate\", \"decelerate\"),\n",
    "    (\"direction\", \"right\", \"left\"),\n",
    "    (\"agent\", \"vehicle\", \"pedestrian\"),\n",
    "]\n",
    "\n",
    "# slice data for nc metrics\n",
    "control_features = {}\n",
    "for feature, first_label, second_label in contrastive_splits:\n",
    "    pairs = contrastive_embs[feature]\n",
    "    half = len(pairs) // 2\n",
    "\n",
    "    control_features[first_label] = EmbeddingData()\n",
    "    control_features[first_label].data = stack(pairs[:half])\n",
    "\n",
    "    control_features[second_label] = EmbeddingData()\n",
    "    control_features[second_label].data = stack(pairs[half:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sanity check: print the shape of one split, then show its mean as the cell output\n",
    "print(control_features[\"high\"].data.shape)\n",
    "control_features[\"high\"].mean"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "\n",
    "# calculate true global mean\n",
    "idx_layer = 2\n",
    "tmp = []\n",
    "# loop over data files; list(...) gives tqdm a length so it can show real progress\n",
    "for (path_input, path_embs) in tqdm(list(zip(paths_inputs, paths_embeds))):\n",
    "    # NOTE: torch.load unpickles the file, so only load embedding files from a trusted source\n",
    "    e = torch.load(path_embs, map_location=torch.device('cpu'))\n",
    "    layer_embs = e[idx_layer]\n",
    "    # loop over batches (48 default); the batch count is the length of the selected\n",
    "    # layer, not len(e), because the first index of `e` selects the layer\n",
    "    for idx_batch in range(len(layer_embs)):\n",
    "        # loop over timesteps (11 default)\n",
    "        for idx_ts in range(len(layer_embs[idx_batch])):\n",
    "            tmp.append(layer_embs[idx_batch][idx_ts])\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# number of entries collected in tmp\n",
    "len(tmp)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# shape of the first collected entry\n",
    "tmp[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Stack into a new name rather than rebinding `tmp`, so this cell can be re-run\n",
    "# without trying to stack an already-stacked tensor.\n",
    "global_embs = torch.stack(tmp)\n",
    "control_features[\"global\"] = EmbeddingData()\n",
    "control_features[\"global\"].data = global_embs\n",
    "control_features[\"global\"].mean = torch.mean(global_embs, dim=0)\n",
    "# population variance (unbiased=False)\n",
    "control_features[\"global\"].var = torch.var(global_embs, dim=0, unbiased=False)\n",
    "global_embs.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import itertools\n",
    "import torch\n",
    "import numpy as np\n",
    "from pprint import pprint\n",
    "\n",
    "\n",
    "def pairwise_l2_distances(class_means):\n",
    "    \"\"\"Compute the L2 distance between every unordered pair of class means.\n",
    "\n",
    "    Args:\n",
    "        class_means: indexable collection of tensors, one per class.\n",
    "\n",
    "    Returns:\n",
    "        1-D tensor holding one distance per pair, in the order produced by\n",
    "        `itertools.combinations` over the class indices: (0, 1), (0, 2), ...\n",
    "    \"\"\"\n",
    "    n_classes = len(class_means)\n",
    "    distances = [\n",
    "        torch.norm(class_means[first] - class_means[second], p=2)\n",
    "        for first, second in itertools.combinations(range(n_classes), 2)\n",
    "    ]\n",
    "    return torch.stack(distances)\n",
    "\n",
    "\n",
    "def fill_distances_in_matrix(mean_tensor, labels, verbose=True):\n",
    "    \"\"\"Build an upper-triangular matrix of pairwise L2 distances between class means.\n",
    "\n",
    "    Args:\n",
    "        mean_tensor: indexable collection of class means, one entry per class.\n",
    "        labels: accepted for interface compatibility; not used by the computation.\n",
    "        verbose: when True, print the rounded matrix.\n",
    "\n",
    "    Returns:\n",
    "        Square numpy array, rounded to 2 decimals, with the distance between\n",
    "        classes r < c stored at [r, c]; the diagonal and lower triangle stay zero.\n",
    "    \"\"\"\n",
    "    size = len(mean_tensor)\n",
    "    matrix = np.zeros((size, size))\n",
    "\n",
    "    upper_pairs = itertools.combinations(range(size), 2)\n",
    "    flat_distances = pairwise_l2_distances(mean_tensor)\n",
    "    for (row, col), dist in zip(upper_pairs, flat_distances):\n",
    "        matrix[row, col] = dist\n",
    "\n",
    "    matrix = np.round(matrix, 2)\n",
    "    if verbose:\n",
    "        print(matrix)\n",
    "\n",
    "    return matrix\n",
    "\n",
    "def make_symm(m):\n",
    "    \"\"\"Mirror a triangular matrix across its diagonal.\n",
    "\n",
    "    Adds the matrix to its transpose and subtracts the diagonal once so that\n",
    "    diagonal entries are not counted twice.\n",
    "    \"\"\"\n",
    "    diagonal_part = np.diag(m.diagonal())\n",
    "    return m + m.T - diagonal_part"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "\n",
    "labels = list(control_features.keys())\n",
    "\n",
    "# take the embedding width from the data instead of hard-coding 128\n",
    "emb_dim = control_features[labels[0]].mean.shape[-1]\n",
    "mean_tensor = torch.empty((len(labels), emb_dim))\n",
    "vars_tensor = torch.empty((len(labels), emb_dim))\n",
    "\n",
    "# Loop through each key and store its mean as one row of mean_tensor\n",
    "for i, key in enumerate(labels):\n",
    "    print(key)\n",
    "    print(control_features[key].mean)\n",
    "    print(\"---------\")\n",
    "    mean_tensor[i] = control_features[key].mean\n",
    "    #vars_tensor[i] = control_features[key].var\n",
    "\n",
    "# upper-triangular pairwise L2 distances, scaled so the largest distance is 1\n",
    "distances = fill_distances_in_matrix(mean_tensor, labels, verbose=False)\n",
    "normalized_distances = distances / np.max(distances)\n",
    "normalized_distances = np.round(normalized_distances, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from future_motion.utils.similarity.vector import VectorComparison\n",
    "\n",
    "\n",
    "# offset of each feature-cluster mean from the global mean\n",
    "global_mean = control_features[\"global\"].mean\n",
    "cluster_names = [\"high\", \"low\", \"accelerate\", \"decelerate\", \"right\", \"left\", \"vehicle\", \"pedestrian\"]\n",
    "feature_center_v = [control_features[name].mean - global_mean for name in cluster_names]\n",
    "(\n",
    "    v_speed_high, v_speed_low,\n",
    "    v_acceleration, v_deceleration,\n",
    "    v_right, v_left,\n",
    "    v_vehicle, v_pedestrian,\n",
    ") = feature_center_v\n",
    "\n",
    "# calculate vector similarity\n",
    "n_clusters = len(feature_center_v)\n",
    "class_indices = list(range(n_clusters))\n",
    "class_pairs = list(itertools.combinations(class_indices, 2))\n",
    "\n",
    "# only the upper triangle (r < c) is filled; everything else stays zero\n",
    "cluster_similarity = np.zeros([n_clusters, n_clusters])\n",
    "for r, c in class_pairs:\n",
    "    vc = VectorComparison(feature_center_v[r], feature_center_v[c])\n",
    "    cluster_similarity[r, c] = vc.cos_sim_deg()\n",
    "\n",
    "cluster_similarity = np.round(cluster_similarity, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "# import tikzplotlib\n",
    "\n",
    "# Pairwise distances matrix (9x9)\n",
    "distance_matrix = np.asarray(normalized_distances)\n",
    "distance_matrix = make_symm(distance_matrix)\n",
    "\n",
    "# Cosine similarity matrix (excluding 'global', 8x8)\n",
    "cosine_similarity_matrix = np.asarray(cluster_similarity)\n",
    "cosine_similarity_matrix = make_symm(cosine_similarity_matrix)\n",
    "\n",
    "# Labels for the clusters. The order must match the row order of both matrices,\n",
    "# i.e. the order in which control_features and feature_center_v were built\n",
    "# ('high' before 'low', 'accelerate' before 'decelerate').\n",
    "labels = ['high', 'low', 'accelerate', 'decelerate', 'right', 'left', 'vehicle', 'pedestrian']\n",
    "\n",
    "# Exclude the last row and column ('global') from the distance matrix\n",
    "adjusted_distance_matrix = distance_matrix[:-1, :-1]\n",
    "\n",
    "# Convert matrices to DataFrames\n",
    "dist_df = pd.DataFrame(adjusted_distance_matrix, index=labels, columns=labels)\n",
    "cos_sim_df = pd.DataFrame(cosine_similarity_matrix, index=labels, columns=labels)\n",
    "\n",
    "# Mask the strictly lower triangle (k=-1 leaves the diagonal visible)\n",
    "mask = np.tril(np.ones_like(dist_df, dtype=bool), k=-1)\n",
    "\n",
    "# Plot the heatmap: colour encodes the normalized distance, annotations the angle\n",
    "plt.figure(figsize=(10, 8))\n",
    "sns.heatmap(\n",
    "    dist_df,\n",
    "    annot=cos_sim_df,\n",
    "    fmt='.1f',\n",
    "    cmap='coolwarm',\n",
    "    mask=mask,\n",
    "    square=True,\n",
    "    cbar_kws={\"shrink\": 0.8},\n",
    "    annot_kws={\"size\": 8, \"color\": \"white\"},\n",
    "    linewidths=.5\n",
    ")\n",
    "\n",
    "plt.title('Combined Heatmap of Pairwise Distances (Color) and Embedding Cluster Cosine Similarities (Annotations)')\n",
    "plt.xticks(rotation=45, ha='right')\n",
    "plt.yticks(rotation=0)\n",
    "plt.tight_layout()\n",
    "# tikzplotlib.save(\"heatmap.tex\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "# import tikzplotlib  # Uncomment if you want to use tikzplotlib\n",
    "\n",
    "# Pairwise distances matrix (9x9)\n",
    "distance_matrix = np.asarray(normalized_distances)\n",
    "distance_matrix = make_symm(distance_matrix)\n",
    "\n",
    "# Cosine similarity matrix (excluding 'global', 8x8)\n",
    "cosine_similarity_matrix = np.asarray(cluster_similarity)\n",
    "cosine_similarity_matrix = make_symm(cosine_similarity_matrix)\n",
    "\n",
    "# Labels for the clusters. The order must match the row order of both matrices,\n",
    "# i.e. the order in which control_features and feature_center_v were built\n",
    "# ('high' before 'low', 'accelerate' before 'decelerate').\n",
    "labels = ['high', 'low', 'accelerate', 'decelerate', 'right', 'left', 'vehicle', 'pedestrian']\n",
    "\n",
    "# Exclude the last row and column ('global') from the distance matrix\n",
    "adjusted_distance_matrix = distance_matrix[:-1, :-1]\n",
    "\n",
    "# Convert matrices to DataFrames\n",
    "dist_df = pd.DataFrame(adjusted_distance_matrix, index=labels, columns=labels)\n",
    "cos_sim_df = pd.DataFrame(cosine_similarity_matrix, index=labels, columns=labels)\n",
    "\n",
    "# Mask the strictly lower triangle (k=-1 leaves the diagonal visible)\n",
    "mask = np.tril(np.ones_like(cos_sim_df, dtype=bool), k=-1)\n",
    "\n",
    "# Plot the heatmap\n",
    "plt.figure(figsize=(10, 8))\n",
    "\n",
    "# Create a custom colormap that goes from similar (small angles) to dissimilar (large angles)\n",
    "cmap = sns.color_palette(\"coolwarm\", as_cmap=True)\n",
    "\n",
    "# Colour encodes the angle in degrees, annotations the normalized distance\n",
    "ax = sns.heatmap(\n",
    "    cos_sim_df,\n",
    "    annot=dist_df,\n",
    "    fmt='.2f',  # Display distances with two decimal places\n",
    "    cmap=cmap,\n",
    "    mask=mask,\n",
    "    square=True,\n",
    "    vmin=0,  # Set the minimum of the heatmap to 0 degrees\n",
    "    vmax=180,  # Set the maximum of the heatmap to 180 degrees\n",
    "    cbar_kws={\"shrink\": 0.8, \"label\": \"Angle (Degrees)\"},\n",
    "    annot_kws={\"size\": 8, \"color\": \"black\"},\n",
    "    linewidths=.5\n",
    ")\n",
    "\n",
    "# Move x-axis labels to the top\n",
    "ax.xaxis.tick_top()\n",
    "ax.xaxis.set_label_position('top')\n",
    "\n",
    "plt.title('Combined Heatmap of Embedding Cluster Cosine Similarities (Color) and Distances (Annotations)')\n",
    "plt.xticks(rotation=45, ha='right')\n",
    "plt.yticks(rotation=0)\n",
    "plt.tight_layout()\n",
    "# tikzplotlib.save(\"heatmap.tex\")  # Uncomment if you want to use tikzplotlib\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy.stats import spearmanr\n",
    "import torch\n",
    "#import tikzplotlib\n",
    "\n",
    "\n",
    "# (category, cluster) keys into `embs`, in plotting order; the single list\n",
    "# drives the embedding matrix, the distance tensor and the tick labels\n",
    "cluster_keys = [\n",
    "    (\"speed\", \"low\"),\n",
    "    (\"speed\", \"moderate\"),\n",
    "    (\"speed\", \"high\"),\n",
    "    (\"speed\", \"backwards\"),\n",
    "    (\"acceleration\", \"decelerate\"),\n",
    "    (\"acceleration\", \"constant\"),\n",
    "    (\"acceleration\", \"accelerate\"),\n",
    "    (\"direction\", \"stationary\"),\n",
    "    (\"direction\", \"straight\"),\n",
    "    (\"direction\", \"right\"),\n",
    "    (\"direction\", \"left\"),\n",
    "    (\"agent\", \"vehicle\"),\n",
    "    (\"agent\", \"pedestrian\"),\n",
    "    (\"agent\", \"cyclist\"),\n",
    "]\n",
    "cluster_means = [embs[category][name].mean for category, name in cluster_keys]\n",
    "\n",
    "embedding_matrix = np.array(cluster_means)\n",
    "\n",
    "# Labels for the clusters\n",
    "labels = [name for _category, name in cluster_keys]\n",
    "\n",
    "# Step 1: Compute the Spearman correlation matrix\n",
    "n_clusters = embedding_matrix.shape[0]\n",
    "spearman_corr = np.zeros((n_clusters, n_clusters))\n",
    "\n",
    "for i in range(n_clusters):\n",
    "    for j in range(n_clusters):\n",
    "        corr, _ = spearmanr(embedding_matrix[i], embedding_matrix[j])\n",
    "        spearman_corr[i, j] = corr\n",
    "\n",
    "\n",
    "# Step 2: Pairwise distances matrix (normalized values as before)\n",
    "# the embedding width comes from the data instead of a hard-coded 128\n",
    "mean_tensor = torch.empty((len(labels), cluster_means[0].shape[-1]))\n",
    "for row, cluster_mean in enumerate(cluster_means):\n",
    "    mean_tensor[row] = cluster_mean\n",
    "\n",
    "\n",
    "distance_matrix = fill_distances_in_matrix(mean_tensor, labels, verbose=False)\n",
    "distance_matrix = make_symm(distance_matrix)\n",
    "distance_matrix /= np.max(distance_matrix)\n",
    "\n",
    "\n",
    "# Step 3: Convert matrices to DataFrames for easier plotting\n",
    "spearman_df = pd.DataFrame(spearman_corr, index=labels, columns=labels)\n",
    "dist_df = pd.DataFrame(distance_matrix, index=labels, columns=labels)\n",
    "\n",
    "# Step 4: Mask the strictly lower triangle (k=-1 leaves the diagonal visible)\n",
    "mask = np.tril(np.ones_like(spearman_df, dtype=bool), k=-1)\n",
    "\n",
    "# Step 5: Plot the heatmap\n",
    "plt.figure(figsize=(10, 8))\n",
    "\n",
    "# Colormap for Spearman correlation\n",
    "cmap = sns.color_palette(\"coolwarm\", as_cmap=True)\n",
    "\n",
    "sns.heatmap(\n",
    "    spearman_df,\n",
    "    annot=dist_df,\n",
    "    fmt='.3f',\n",
    "    cmap=cmap,\n",
    "    mask=mask,\n",
    "    square=True,\n",
    "    cbar_kws={\"shrink\": 0.8, \"label\": \"Spearman Correlation\"},\n",
    "    annot_kws={\"size\": 8, \"color\": \"black\"},\n",
    "    linewidths=.5\n",
    ")\n",
    "\n",
    "plt.title('Combined Heatmap of Spearman Correlation (Color) and Distances (Annotations)')\n",
    "plt.xticks(rotation=45, ha='right')\n",
    "plt.yticks(rotation=0)\n",
    "plt.tight_layout()\n",
    "# tikzplotlib.save(f\"{out_path}/spearman.tex\")\n",
    "plt.show()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "words",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
