{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import subprocess\n",
    "from pathlib import Path\n",
    "import gdown\n",
    "import zipfile\n",
    "\n",
    "\n",
    "def get_project_root():\n",
    "    \"\"\"Return the top-level directory of the enclosing git repository as a Path.\"\"\"\n",
    "    git_cmd = [\"git\", \"rev-parse\", \"--show-toplevel\"]\n",
    "    # git prints the path followed by a newline; strip it before decoding\n",
    "    raw_output = subprocess.check_output(git_cmd)\n",
    "    toplevel = raw_output.strip().decode(\"utf-8\")\n",
    "    return Path(toplevel)\n",
    "\n",
    "# locate the repository root; the data directory is created next to it\n",
    "project_root = get_project_root()\n",
    "\n",
    "\n",
    "# create data directory for embeddings\n",
    "base_path = os.path.normpath(os.path.join(project_root, \"..\"))\n",
    "data_path = os.path.join(base_path, \"data\")\n",
    "os.makedirs(data_path, exist_ok=True)\n",
    "\n",
    "# Google Drive file ids of the zipped datasets\n",
    "dataset_file_ids = {\n",
    "    \"waymo\": \"1FbMXOT5Upqhm51ZxPHVz6g64KK2Cgbc6\",\n",
    "    \"argo\": \"1s4pKBaz8bb3ZvRwyDFwl-YUvN-TAPLCp\",\n",
    "}\n",
    "\n",
    "# download dataset if it doesn't exist\n",
    "dataset = \"argo\"  # \"waymo\" / \"argo\"\n",
    "if dataset not in dataset_file_ids:\n",
    "    # fail early instead of silently downloading the argo archive for an unknown name\n",
    "    raise ValueError(f\"Unknown dataset {dataset!r}; expected one of {sorted(dataset_file_ids)}\")\n",
    "dataset_path = os.path.join(data_path, f\"{dataset}_data\")\n",
    "if not os.path.exists(dataset_path):\n",
    "    file_id = dataset_file_ids[dataset]\n",
    "\n",
    "    download_url = f\"https://drive.google.com/uc?id={file_id}\"\n",
    "    zip_path = dataset_path + \".zip\"\n",
    "    # some gdown versions return None on a failed download instead of raising\n",
    "    if gdown.download(download_url, zip_path, quiet=False) is None:\n",
    "        raise RuntimeError(f\"Download of {download_url} failed\")\n",
    "\n",
    "    # unzip into the data directory, then drop the archive\n",
    "    with zipfile.ZipFile(zip_path, 'r') as zf:\n",
    "        zf.extractall(data_path)\n",
    "    os.remove(zip_path)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load target embeddings; torch.Size([48, 11, 128])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 211/211 [00:00<00:00, 620.50it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "from glob import glob\n",
    "from utils.embs_all import load_embeddings\n",
    "from utils.embs_contrastive import load_contrastive_embed_pairs\n",
    "\n",
    "\n",
    "# collect the dataset files in sorted order\n",
    "input_pattern = f\"{dataset_path}/input*\"\n",
    "embed_pattern = f\"{dataset_path}/target_embs*\"\n",
    "paths_inputs = sorted(glob(input_pattern))\n",
    "paths_embeds = sorted(glob(embed_pattern))\n",
    "\n",
    "# build the embeddings from the input / target-embedding files\n",
    "embs = load_embeddings(paths_inputs, paths_embeds)\n",
    "\n",
    "# derive the contrastive embedding pairs from them\n",
    "contrastive_embs = load_contrastive_embed_pairs(embs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "\n",
    "# get all embeddings (one entry per file)\n",
    "# a comprehension keeps the loop variable local, so the `embs` produced by the\n",
    "# loading cell above is not overwritten here\n",
    "embs_all = [\n",
    "    torch.load(path_embs, map_location=torch.device('cpu'))\n",
    "    for path_embs in tqdm(paths_embeds)\n",
    "]\n",
    "\n",
    "# we have 3 modules; hence 3 hidden states in embs_all\n",
    "# 48 batch size, 11 past (Waymo), 128 hidden state size\n",
    "len(embs_all), embs_all[0][0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get the last module's hidden state\n",
    "# zip() has no length, so pass the number of pairs to tqdm explicitly\n",
    "n_pairs = min(len(paths_inputs), len(paths_embeds))\n",
    "last_hidden_states = []\n",
    "for _path_input, path_embs in tqdm(zip(paths_inputs, paths_embeds), total=n_pairs):\n",
    "    # use a dedicated name so the `embs` from the loading cell is not overwritten\n",
    "    file_embs = torch.load(path_embs, map_location=torch.device('cpu'))\n",
    "    # take the last module (layer) and the last past time step (the current one) of it\n",
    "    last_hidden_states.append(file_embs[-1][:, -1])\n",
    "embs_all_last = torch.stack(last_hidden_states, dim=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# wrap the embeddings in a TensorDataset; the autoencoder target equals its input\n",
    "train_dataset = TensorDataset(embs_all_last, embs_all_last)\n",
    "# never request more workers than there are CPUs (PyTorch warns about oversubscription)\n",
    "num_workers = min(15, os.cpu_count() or 1)\n",
    "# keep 'drop_last=False' to get exactly the same loss values as in the paper\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=False, num_workers=num_workers)  # adjust batch size as needed"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create control vectors using PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# output dir\n",
    "out_reldir = f\"out/control-vectors/{dataset}/\"\n",
    "out_path = os.path.join(base_path, out_reldir)\n",
    "# exist_ok avoids the check-then-create race and matches the later makedirs calls\n",
    "os.makedirs(out_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from future_motion.utils.interpretability.control_vectors import fit_control_vector\n",
    "\n",
    "\n",
    "# in this notebook idx_layer only labels the saved files\n",
    "idx_layer = 2\n",
    "PCA_control_vectors = {}\n",
    "for key in contrastive_embs.keys():\n",
    "    PCA_control_vectors[key] = fit_control_vector(contrastive_embs[key])\n",
    "    # out_path already ends with a separator, so join instead of concatenating \"/\"\n",
    "    pca_file = os.path.join(out_path, f\"pca_{key}_layer{idx_layer}.pt\")\n",
    "    torch.save(torch.tensor(PCA_control_vectors[key]), f=pca_file)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create control vectors using SAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from future_motion.utils.interpretability.sparse_autoencoder import SparseAutoencoder, SAE\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "from pytorch_lightning.callbacks import ModelCheckpoint\n",
    "import pytorch_lightning as pl\n",
    "\n",
    "\n",
    "autoencoders = {}\n",
    "# presumably the embedding width (128 hidden state size) - confirm against SAE\n",
    "d_mlp = 128\n",
    "# single source of truth for the epoch budget passed to both the SAE and the Trainer\n",
    "max_epochs = 10000\n",
    "\n",
    "# log / checkpoint directories are the same for every hidden size, so create them once\n",
    "log_dir = os.path.join(out_path, \"logs\")\n",
    "checkpoint_dir = os.path.join(out_path, \"checkpoints\")\n",
    "os.makedirs(log_dir, exist_ok=True)\n",
    "os.makedirs(checkpoint_dir, exist_ok=True)\n",
    "\n",
    "# train one sparse autoencoder per hidden size\n",
    "for n_hidden in [512, 256, 128, 64, 32, 16]:\n",
    "\n",
    "    model = SAE(d_mlp, n_hidden, max_epochs=max_epochs)\n",
    "\n",
    "    logger = TensorBoardLogger(log_dir, name=\"sae_training\")\n",
    "    # keep the three checkpoints with the lowest monitored \"loss\"\n",
    "    checkpoint_callback = ModelCheckpoint(\n",
    "        monitor=\"loss\",\n",
    "        dirpath=checkpoint_dir,\n",
    "        filename=f\"sae{n_hidden}-\" + \"{epoch:02d}-{loss:.4f}\",\n",
    "        save_top_k=3,\n",
    "        mode=\"min\"\n",
    "    )\n",
    "\n",
    "    # Create the trainer\n",
    "    trainer = pl.Trainer(\n",
    "        max_epochs=max_epochs,\n",
    "        logger=logger,\n",
    "        callbacks=[checkpoint_callback],\n",
    "        accelerator=\"auto\",\n",
    "        devices=\"auto\"\n",
    "    )\n",
    "\n",
    "    # Train the model\n",
    "    trainer.fit(model=model, train_dataloaders=train_loader)\n",
    "\n",
    "    # Only load and save the checkpoint on the main process (global_rank == 0)\n",
    "    if trainer.global_rank == 0:\n",
    "        # Load the best checkpoint\n",
    "        best_model_path = checkpoint_callback.best_model_path\n",
    "        best_model = SAE.load_from_checkpoint(best_model_path)\n",
    "        autoencoders[n_hidden] = best_model\n",
    "\n",
    "        print(f\"Best model loaded from {best_model_path}\")\n",
    "        print(f\"Best loss: {checkpoint_callback.best_model_score.item():.6f}\")\n",
    "\n",
    "        # Save model to out_path\n",
    "        # NOTE(review): the file name says \"waymo\" regardless of `dataset`; kept unchanged\n",
    "        # in case other code loads these files by name - confirm before renaming\n",
    "        model_save_path = os.path.join(out_path, f\"sae_waymo_n{n_hidden}.pth\")\n",
    "        torch.save(best_model.state_dict(), model_save_path)\n",
    "        print(f\"Best model saved to {model_save_path}\")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Training for 10.000 epochs - SAE\n",
    "(Google Colab T4-instance; seed not set / cuda=12.5 / torch=2.6.0+cu124 / sklearn=2.0.2 / python=3.11.11 (final))\n",
    "\n",
    "| Hidden Dim | Epoch         | Total Loss   | L1 Loss       | L2 Loss        | Total Reconst Loss |\n",
    "|------------|---------------|--------------|---------------|----------------|--------------------|\n",
    "| 512        | 9805/10000    | 4.005656276  | 1.524447083   | 8270.697265625 | 0.001645120210014  |\n",
    "| 256        | 9845/10000    | 3.724161590  | 1.376968503   | 7823.977050781 | 0.001388887991197  |\n",
    "| 128        | 9820/10000    | 4.139010770  | 1.556326985   | 8608.9453125   | 0.001653907238506  |\n",
    "| 64         | 9348/10000    | 4.561335734  | 1.892843366   | 8894.974609375 | 0.001926084747538  |\n",
    "| 32         | 9864/10000    | 7.141473430  | 3.902811527   | 10795.541015625| 0.004311752039939  |\n",
    "| 16         | 9956/10000    | 17.441959654 | 13.368986130  | 13576.573242188| 0.014228038489819  |\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from future_motion.utils.interpretability.control_vectors import fit_control_vector\n",
    "\n",
    "\n",
    "# Create control vectors with SAE: one per (hidden size, contrastive key) pair\n",
    "SAE_control_vectors = {}\n",
    "for hidden_dims, autoencoder in autoencoders.items():\n",
    "    SAE_control_vectors[hidden_dims] = {}\n",
    "    for key in contrastive_embs.keys():\n",
    "        print(key)\n",
    "        cv = fit_control_vector(contrastive_embs[key], autoencoder=autoencoder, verbose_explained_variance=True)\n",
    "        SAE_control_vectors[hidden_dims][key] = cv\n",
    "        # idx_layer is defined in the PCA cell and is only part of the file name\n",
    "        # out_path already ends with a separator, so join instead of concatenating \"/\"\n",
    "        sae_file = os.path.join(out_path, f\"sae{hidden_dims}_{key}_layer{idx_layer}.pt\")\n",
    "        torch.save(torch.tensor(cv), f=sae_file)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "future-motion",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
