{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import subprocess\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "from cleanba.environments import BoxobanConfig\n",
    "\n",
    "from learned_planners.interp.collect_dataset import DatasetStore\n",
    "from learned_planners.interp.utils import load_jax_model_to_torch, play_level, plotly_feature_vis, save_video\n",
    "from learned_planners.policies import download_policy_from_huggingface\n",
    "\n",
    "on_cluster = os.path.exists(\"/training\")\n",
    "\n",
    "MODEL_PATH_IN_REPO = \"drc33/bkynosqi/cp_2002944000/\"  # DRC(3, 3) 2B checkpoint\n",
    "MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)\n",
    "LP_DIR = pathlib.Path.cwd().parent.parent\n",
    "if on_cluster:\n",
    "    BOXOBAN_CACHE = pathlib.Path(\"/training/.sokoban_cache/\")\n",
    "else:\n",
    "    BOXOBAN_CACHE = LP_DIR / \"training/.sokoban_cache/\"\n",
    "\n",
    "\n",
    "difficulty = \"unfiltered\"\n",
    "split = \"valid\"\n",
    "\n",
    "boxo_cfg = BoxobanConfig(\n",
    "    cache_path=BOXOBAN_CACHE,\n",
    "    num_envs=1,\n",
    "    max_episode_steps=80,\n",
    "    min_episode_steps=80,\n",
    "    asynchronous=False,\n",
    "    tinyworld_obs=True,\n",
    "    difficulty=difficulty,\n",
    "    split=split,\n",
    ")\n",
    "boxo_env = boxo_cfg.make()\n",
    "cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Future Path Probe Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "from safetensors import safe_open\n",
    "\n",
    "\n",
    "def load_sae(path):\n",
    "    sae = SAE.load_from_pretrained(path)\n",
    "    sae.eval()\n",
    "    path = path if isinstance(path, pathlib.Path) else pathlib.Path(path)\n",
    "    with safe_open(path / \"sparsity.safetensors\", \"pt\", device=\"cpu\") as f:\n",
    "        log_sparsity = f.get_tensor(\"sparsity\")\n",
    "\n",
    "    top_activating_features = th.argsort(log_sparsity, descending=True).tolist()\n",
    "    return sae, top_activating_features\n",
    "\n",
    "\n",
    "if on_cluster:\n",
    "    wandb_ids = [\"ho6ob1tk\"]\n",
    "    sae_files = []\n",
    "    for wandb_id, sae_info in wandb_ids:\n",
    "        command = f\"/training/findsae.sh {wandb_id}\"\n",
    "        file_path = subprocess.run(command, shell=True, capture_output=True, text=True).stdout\n",
    "        file_path = file_path.strip()\n",
    "        sae_files.append(file_path)\n",
    "else:\n",
    "    sae_files = [\"k8_layer2_final\"]\n",
    "    sae_files = [LP_DIR / \"sae\" / file for file in sae_files]\n",
    "\n",
    "for file_name in sae_files:\n",
    "    sae, top_activating_features = load_sae(file_name)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "from functools import partial\n",
    "\n",
    "from learned_planners.interp.utils import predict\n",
    "\n",
    "lfi, li = 0, 7\n",
    "thinking_steps = 0\n",
    "reset_opts = {\"level_file_idx\": lfi, \"level_idx\": li}\n",
    "obs = boxo_env.reset(options=reset_opts)[0]\n",
    "img = np.transpose(obs.squeeze(), (1, 2, 0))\n",
    "plt.imshow(img)\n",
    "plt.show()\n",
    "obs = th.tensor(obs)\n",
    "ds_cache = DatasetStore(\n",
    "    None,\n",
    "    obs,\n",
    "    th.tensor([0]),\n",
    "    False,\n",
    "    th.tensor([0]),\n",
    "    th.zeros(1),\n",
    "    {},\n",
    ")\n",
    "\n",
    "out = play_level(\n",
    "    boxo_env,\n",
    "    policy_th=policy_th,\n",
    "    reset_opts={\"level_file_idx\": lfi, \"level_idx\": li},\n",
    "    thinking_steps=thinking_steps,\n",
    "    internal_steps=False,\n",
    "    sae=sae,\n",
    ")\n",
    "all_obs = out.obs.squeeze(1)\n",
    "ds = DatasetStore(None, all_obs[thinking_steps:], out.rewards, out.solved, out.acts, th.zeros(len(all_obs)), {})\n",
    "\n",
    "thinking_steps_dir = f\"ts{thinking_steps}/\" if thinking_steps > 0 else \"\"\n",
    "sae_acts = out.sae_outs.detach()\n",
    "sae_acts = sae_acts.permute(0, 3, 1, 2)\n",
    "print(sae_acts.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "start = 0\n",
    "num_features = 15\n",
    "feature_labels = [f\"F{top_activating_features[i]}\" for i in range(start,start+num_features)]\n",
    "plotly_feature_vis(sae_acts[:, top_activating_features[start:start+num_features]], out.obs.squeeze(1), feature_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# r = th.randn(2,3,10)\n",
    "r[..., [[0, 1], [4, 4], [6, 6], [8, 8]]].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "r[:, :, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# denominator = th.max((sae_acts[:, 165] > 0).sum(dim=(1, 2)), th.tensor(1))\n",
    "denominator = 1\n",
    "up_acts = sae_acts[:, 165].sum(dim=(1, 2)) / denominator\n",
    "print(list(zip(up_acts, range(len(up_acts)))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Causal Intervention"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "action_features = [[304], [187], [244], [385]]\n",
    "magnitude = [0.5, 1, 0.5, 0.3]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 471: R, L\n",
    "ft_idx = action_features[1]\n",
    "feature = sae.W_dec[ft_idx].detach().sum(dim=0)\n",
    "def patch_in_feature(cache, hook):\n",
    "    cache += magnitude * feature.unsqueeze(0)[..., None, None]\n",
    "    return cache\n",
    "\n",
    "fwd_hooks = [\n",
    "    (\n",
    "        f\"features_extractor.cell_list.2.hook_h.{pos}.{int_pos}\",\n",
    "        # partial(patch_in_box_direction, layer=layer, h_or_c=h_or_c_idx),\n",
    "        patch_in_feature,\n",
    "    )\n",
    "    for pos in range(1)\n",
    "    for int_pos in range(3)\n",
    "]\n",
    "\n",
    "steered_reset_opts = {\"level_file_idx\": 0, \"level_idx\": 2}\n",
    "steered_out = play_level(\n",
    "    boxo_env,\n",
    "    policy_th=policy_th,\n",
    "    reset_opts=steered_reset_opts,\n",
    "    thinking_steps=0,\n",
    "    fwd_hooks=fwd_hooks,\n",
    "    # hook_steps=range(11),\n",
    "    hook_steps=-1,\n",
    "    max_steps=10,\n",
    "    internal_steps=False,\n",
    "    sae=sae,\n",
    ")\n",
    "steered_obs = steered_out.obs.squeeze(1)\n",
    "steered_sae_acts = steered_out.sae_outs.detach()\n",
    "steered_sae_acts = steered_sae_acts.permute(0, 3, 1, 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = 15\n",
    "feature_labels = [f\"F{top_activating_features[i]}\" for i in range(num_features)]\n",
    "plotly_feature_vis(steered_sae_acts[:, top_activating_features[:num_features]], steered_out.obs.squeeze(1), feature_labels)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Moving around in circles in empty env"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "action_features = [[304], [187], [244], [385]]\n",
    "magnitude = [0.5, 1, 0.5, 0.3]\n",
    "\n",
    "walls = [(i, 0) for i in range(10)] + [(i, 9) for i in range(10)] + [(0, i) for i in range(1, 9)] + [(9, i) for i in range(1, 9)]\n",
    "steered_reset_opts = dict(walls=walls, boxes=[], targets=[(1, 1)], player=(5, 5))\n",
    "\n",
    "timestep = 0\n",
    "\n",
    "def move_in_circles_hook(cache, hook):\n",
    "    global timestep\n",
    "    phase = (timestep // (2 * 3)) % 4\n",
    "    phase_to_action = [0, 3, 1, 2] # U, R, D, L\n",
    "    action = phase_to_action[phase]\n",
    "\n",
    "    # action = 0\n",
    "\n",
    "    # print(f\"Phase: {phase}, Action: {action}\")\n",
    "    feature = sae.W_dec[action_features[action]].detach().sum(dim=0)\n",
    "    cache += magnitude[action] * feature.unsqueeze(0)[..., None, None]\n",
    "    timestep += 1\n",
    "    return cache\n",
    "\n",
    "fwd_hooks = [\n",
    "    (\n",
    "        f\"features_extractor.cell_list.2.hook_h.{pos}.{int_pos}\",\n",
    "        move_in_circles_hook,\n",
    "    )\n",
    "    for pos in range(1)\n",
    "    for int_pos in range(3)\n",
    "]\n",
    "\n",
    "steered_out = play_level(\n",
    "    boxo_env,\n",
    "    policy_th=policy_th,\n",
    "    reset_opts=steered_reset_opts,\n",
    "    thinking_steps=0,\n",
    "    fwd_hooks=fwd_hooks,\n",
    "    hook_steps=-1,\n",
    "    max_steps=24,\n",
    "    internal_steps=False,\n",
    "    sae=sae,\n",
    ")\n",
    "steered_obs = steered_out.obs.squeeze(1)\n",
    "steered_sae_acts = steered_out.sae_outs.detach()\n",
    "steered_sae_acts = steered_sae_acts.permute(0, 3, 1, 2)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_features = 15\n",
    "feature_labels = [f\"F{top_activating_features[i]}\" for i in range(num_features)]\n",
    "plotly_feature_vis(steered_sae_acts[:, top_activating_features[:num_features]], steered_out.obs.squeeze(1), feature_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "save_video(\"circles.mp4\", steered_obs)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
