{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import datetime\n",
    "import os\n",
    "import pathlib\n",
    "import pickle\n",
    "import subprocess\n",
    "from functools import partial\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "from cleanba.environments import BoxobanConfig, SokobanConfig\n",
    "\n",
    "from learned_planners.interp.collect_dataset import DatasetStore\n",
    "from learned_planners.interp.train_probes import TrainOn\n",
    "from learned_planners.interp.utils import (\n",
    "    load_jax_model_to_torch,\n",
    "    play_level,\n",
    "    plt_obs_with_direction_probe,\n",
    "    plt_obs_with_position_probe,\n",
    "    predict,\n",
    "    save_video,\n",
    ")\n",
    "from learned_planners.policies import download_policy_from_huggingface\n",
    "\n",
    "# Cluster machines are recognised by the presence of a /training directory.\n",
    "on_cluster = os.path.exists(\"/training\")\n",
    "\n",
    "# DRC(3, 3) 2B checkpoint\n",
    "MODEL_PATH_IN_REPO = \"drc33/bkynosqi/cp_2002944000/\"\n",
    "MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)\n",
    "\n",
    "# Two directories above the kernel's working directory.\n",
    "LP_DIR = pathlib.Path.cwd().parent.parent\n",
    "BOXOBAN_CACHE = pathlib.Path(\"/training/.sokoban_cache/\") if on_cluster else LP_DIR / \"training/.sokoban_cache/\"\n",
    "\n",
    "# Level set used by every cell below.\n",
    "difficulty, split = \"unfiltered\", \"valid\"\n",
    "\n",
    "# Single synchronous environment; the same config is also handed to the model loader.\n",
    "boxo_cfg = BoxobanConfig(\n",
    "    cache_path=BOXOBAN_CACHE,\n",
    "    num_envs=1,\n",
    "    max_episode_steps=120,\n",
    "    min_episode_steps=120,\n",
    "    asynchronous=False,\n",
    "    tinyworld_obs=True,\n",
    "    difficulty=difficulty,\n",
    "    split=split,\n",
    ")\n",
    "boxo_env = boxo_cfg.make()\n",
    "cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Future Path Probe Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build two parallel lists: `probe_files` (where each probe is stored) and `probe_infos`\n",
    "# (the TrainOn spec paired with it). `probes[i]` is loaded from `probe_files[i]`.\n",
    "if on_cluster:\n",
    "    wandb_ids_and_infos = [\n",
    "        (\"dirnsbf3\", TrainOn(layer=-1, dataset_name=\"agents_future_direction_map\")),\n",
    "        (\"vb6474rg\", TrainOn(layer=-1, dataset_name=\"boxes_future_direction_map\")),\n",
    "        (\"42qs0bh1\", TrainOn(layer=-1, dataset_name=\"next_target\")),\n",
    "        (\"6e1w1bb6\", TrainOn(layer=-1, dataset_name=\"next_box\")),\n",
    "    ]\n",
    "    probe_files, probe_infos = [], []\n",
    "    for wandb_id, probe_info in wandb_ids_and_infos:\n",
    "        command = f\"/training/findprobe.sh {wandb_id}\"\n",
    "        # check=True surfaces a failing lookup here instead of as an obscure open(\"\") error below.\n",
    "        result = subprocess.run(command, shell=True, capture_output=True, text=True, check=True)\n",
    "        file_name = result.stdout.strip()\n",
    "        if not file_name:\n",
    "            raise FileNotFoundError(\n",
    "                f\"findprobe.sh returned no probe file for wandb run {wandb_id!r}; stderr: {result.stderr.strip()!r}\"\n",
    "            )\n",
    "        probe_files.append(file_name)\n",
    "        probe_infos.append(probe_info)\n",
    "else:\n",
    "    probe_name_infos = [\n",
    "        (\"agents_future_direction_map_l_all.pkl\", TrainOn(layer=-1, dataset_name=\"agents_future_direction_map\")),\n",
    "        (\"boxes_future_direction_map_l_all.pkl\", TrainOn(layer=-1, dataset_name=\"boxes_future_direction_map\")),\n",
    "        # (\"boxes_future_direction_map_sparse_l_1.pkl\", TrainOn(layer=1, dataset_name=\"boxes_future_direction_map\")),\n",
    "        (\"next_target_l_all.pkl\", TrainOn(layer=-1, dataset_name=\"next_target\")),\n",
    "        (\"next_box_l_all.pkl\", TrainOn(layer=-1, dataset_name=\"next_box\")),\n",
    "    ]\n",
    "    probe_files = [LP_DIR / \"probes\" / file for file, _ in probe_name_infos]\n",
    "    probe_infos = [info for _, info in probe_name_infos]\n",
    "\n",
    "probes = []\n",
    "for file_name in probe_files:\n",
    "    # NOTE: pickle.load can execute arbitrary code; only load probe files from a trusted source.\n",
    "    with open(file_name, \"rb\") as f:\n",
    "        probes.append(pickle.load(f))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Causal Intervention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Next box"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Steer the agent by adding the `next_box` probe direction to the hooked activations at one\n",
    "# box's square, then record the rollout as a video.\n",
    "lfi, li = 0, 7\n",
    "reset_opts = {\"level_file_idx\": lfi, \"level_idx\": li}\n",
    "obs = boxo_env.reset(options=reset_opts)[0]\n",
    "img = np.transpose(obs.squeeze(), (1, 2, 0))\n",
    "plt.imshow(img)\n",
    "plt.show()\n",
    "obs = th.tensor(obs)\n",
    "ds_cache = DatasetStore(\n",
    "    None,\n",
    "    obs,\n",
    "    th.tensor([0]),\n",
    "    False,\n",
    "    th.tensor([0]),\n",
    "    th.zeros(1),\n",
    "    {},\n",
    ")\n",
    "box_positions = ds_cache.get_box_positions().numpy()[0]\n",
    "\n",
    "focus_box = box_positions[1]\n",
    "\n",
    "# Select the probe by the dataset it was trained on rather than by list position: index 1 of\n",
    "# `probes` is the `boxes_future_direction_map` probe, not the `next_box` probe steered with here.\n",
    "box_probe_idx = next((i for i, info in enumerate(probe_infos) if info.dataset_name == \"next_box\"), None)\n",
    "if box_probe_idx is None:\n",
    "    raise ValueError(\"No probe trained on 'next_box' was loaded.\")\n",
    "box_probe = probes[box_probe_idx]\n",
    "box_probe_info = probe_infos[box_probe_idx]\n",
    "print(\"Focus box:\", focus_box, \"box_probe_info:\", box_probe_info)\n",
    "\n",
    "coef = th.tensor(box_probe.coef_.squeeze())\n",
    "if coef.ndim != 1:\n",
    "    # The slicing below treats `coef` as one flat vector of per-segment weights.\n",
    "    raise ValueError(f\"Expected a single coefficient vector for the next_box probe, got shape {tuple(coef.shape)}.\")\n",
    "# One (h, c) pair of segments per probed layer; layer == -1 means all three layers.\n",
    "num_layers = 3 if box_probe_info.layer == -1 else 1\n",
    "n_segments = 2 * num_layers\n",
    "per_segment_neurons = coef.shape[0] // n_segments\n",
    "magnitude = 0.1\n",
    "\n",
    "\n",
    "def patch_in_box(cache, hook, layer, h_or_c):\n",
    "    \"\"\"Forward hook that adds the scaled probe direction at `focus_box`.\n",
    "\n",
    "    Args:\n",
    "        cache: hooked activation; indexed as ``cache[:, :, row, col]`` and updated in place.\n",
    "        hook: hook point supplied by the hooking machinery (unused).\n",
    "        layer: index of the hooked cell, used to pick the coefficient segment.\n",
    "        h_or_c: 0 for the ``hook_h`` activation, 1 for ``hook_c``.\n",
    "\n",
    "    Returns:\n",
    "        The same `cache` object after the in-place update.\n",
    "    \"\"\"\n",
    "    segment_idx = 2 * layer + h_or_c if box_probe_info.layer == -1 else h_or_c\n",
    "    segment = coef[per_segment_neurons * segment_idx : per_segment_neurons * (segment_idx + 1)]\n",
    "    cache[:, :, focus_box[0], focus_box[1]] += segment * magnitude\n",
    "    return cache\n",
    "\n",
    "\n",
    "state = policy_th.recurrent_initial_state(1)\n",
    "eps_start = th.tensor([0.0], dtype=th.bool)\n",
    "hook_h_cs = [\"hook_h\", \"hook_c\"]\n",
    "hooked_layers = range(3) if box_probe_info.layer == -1 else [box_probe_info.layer]\n",
    "fwd_hooks = [\n",
    "    (\n",
    "        f\"features_extractor.cell_list.{layer}.{h_or_c_name}.{pos}.{int_pos}\",\n",
    "        partial(patch_in_box, layer=layer, h_or_c=h_or_c_idx),\n",
    "    )\n",
    "    for pos in range(1)\n",
    "    for int_pos in range(3)\n",
    "    for layer in hooked_layers\n",
    "    for h_or_c_idx, h_or_c_name in enumerate(hook_h_cs)\n",
    "]\n",
    "# NOTE(review): the agent-direction cell reads play_level's result through attributes\n",
    "# (`out.obs`, `out.probe_outs`); confirm this 5-way unpacking still matches its return type.\n",
    "steered_obs, steered_acts, steered_rewards, steered_solved, steered_probe_outputs = play_level(\n",
    "    boxo_env,\n",
    "    policy_th=policy_th,\n",
    "    reset_opts=reset_opts,\n",
    "    probes=[box_probe],\n",
    "    probe_train_ons=[box_probe_info],\n",
    "    thinking_steps=0,\n",
    "    fwd_hooks=fwd_hooks,\n",
    "    hook_steps=-1,\n",
    "    max_steps=30,\n",
    ")\n",
    "timestamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "name = f\"steered_videos/{box_probe_info.dataset_name}_{difficulty}_mag{magnitude}_{lfi}_{li}_{timestamp}.mp4\"\n",
    "save_video(name, steered_obs, steered_probe_outputs, all_probe_infos=[box_probe_info])\n",
    "# (best_act_before, best_val_before, best_log_prob_before, _), cache_before = policy_th.run_with_cache(obs, state, eps_start)\n",
    "# (best_act_rwh, best_val, best_log_prob, _) = policy_th.run_with_hooks(obs, state, eps_start, fwd_hooks=fwd_hooks)\n",
    "\n",
    "\n",
    "# pred_before = predict(cache_before, box_probe, box_probe_info, 0)\n",
    "# pred_after = predict(cache, box_probe, box_probe_info, 0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### Agents future direction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebuild the environment and policy for the agent-direction experiment.\n",
    "# `difficulty` and `split` come from the setup cell so the levels played here always match\n",
    "# the `difficulty` label that is written into the steered video's filename.\n",
    "boxo_cfg = BoxobanConfig(\n",
    "    cache_path=BOXOBAN_CACHE,\n",
    "    num_envs=1,\n",
    "    max_episode_steps=120,\n",
    "    min_episode_steps=120,\n",
    "    asynchronous=False,\n",
    "    tinyworld_obs=True,\n",
    "    difficulty=difficulty,\n",
    "    split=split,\n",
    ")\n",
    "boxo_env = boxo_cfg.make()\n",
    "cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)\n",
    "lfi, li = 0, 0\n",
    "reset_opts = {\"level_file_idx\": lfi, \"level_idx\": li}\n",
    "obs = boxo_env.reset(options=reset_opts)[0]\n",
    "# Move the leading axis of the squeezed observation last so it can be shown with imshow.\n",
    "img = np.transpose(obs.squeeze(), (1, 2, 0))\n",
    "# plt.imshow(img)\n",
    "\n",
    "\n",
    "# Integer codes for the four moves; -1 in a direction map means \"no direction\".\n",
    "U, D, L, R = 0, 1, 2, 3\n",
    "\n",
    "# actions that the agent takes\n",
    "directions = [\n",
    "    L, U, U, L, U, R, R, R, U, R, D, L, L, L, U, U, R, R, U, R, D, L, L, L, D, D, D, D, D, D,\n",
    "    L, U, U, U, L, U, R, R, R, R, U, R, D, L, L, L, L, L, D, D, L, L, D, R, R, D, R, U, U, U,\n",
    "    L, U, R, R, D, R, U, U, L, U, R, R,\n",
    "]  # fmt: skip\n",
    "\n",
    "\n",
    "def construct_direction_map(directions, start_square, grid_shape=(10, 10)):\n",
    "    \"\"\"Trace a sequence of moves over a grid and record the first move made from each square.\n",
    "\n",
    "    Args:\n",
    "        directions: sequence of move codes (U, D, L or R). Any other value leaves the\n",
    "            position unchanged.\n",
    "        start_square: ``(row, col)`` tuple the walk starts from.\n",
    "        grid_shape: shape of the returned map; defaults to the 10x10 grid used in this notebook.\n",
    "\n",
    "    Returns:\n",
    "        A ``(direction_map, next_square)`` pair. ``direction_map`` holds -1 everywhere except\n",
    "        on visited squares, which hold the first move taken from them. ``next_square`` is the\n",
    "        position after the first move, or None when ``directions`` is empty.\n",
    "\n",
    "    No bounds checking is performed: positions are used directly as array indices.\n",
    "    \"\"\"\n",
    "    direction_map = -1 * np.ones(grid_shape)\n",
    "    deltas = {U: (-1, 0), D: (1, 0), L: (0, -1), R: (0, 1)}\n",
    "    square = start_square\n",
    "    next_square = None\n",
    "    for d in directions:\n",
    "        # Only the first departure from a square is recorded; later revisits keep it.\n",
    "        if direction_map[square] == -1:\n",
    "            direction_map[square] = d\n",
    "        d_row, d_col = deltas.get(d, (0, 0))\n",
    "        square = (square[0] + d_row, square[1] + d_col)\n",
    "        if next_square is None:\n",
    "            next_square = square\n",
    "    return direction_map, next_square\n",
    "# Read the agent's starting square from the reset observation.\n",
    "obs = th.tensor(obs)\n",
    "ds_cache = DatasetStore(None, obs, th.tensor([0]), False, th.tensor([0]), th.zeros(1), {})\n",
    "agent_positions = ds_cache.get_agent_positions()\n",
    "start_pos = tuple(agent_positions[0].tolist())\n",
    "print(start_pos)\n",
    "\n",
    "# Turn the hard-coded action list into a per-square direction map and draw it over the level.\n",
    "my_direction_map, _ = construct_direction_map(directions, start_pos)\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(img)\n",
    "plt_obs_with_direction_probe(my_direction_map, my_direction_map, ax=ax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "show_internal_steps_until = 10\n",
    "\n",
    "# Find the probe by the dataset it was trained on. `probe_infos` is built on both the cluster\n",
    "# and the local path, whereas `probe_name_infos` only exists on the local path.\n",
    "agent_directions_idx = next(\n",
    "    (i for i, info in enumerate(probe_infos) if info.dataset_name == \"agents_future_direction_map\"),\n",
    "    None,\n",
    ")\n",
    "if agent_directions_idx is None:\n",
    "    raise ValueError(\"No probe trained on 'agents_future_direction_map' was loaded.\")\n",
    "agent_directions_probe = probes[agent_directions_idx]\n",
    "agent_directions_probe_info = probe_infos[agent_directions_idx]\n",
    "\n",
    "# box_probe_idx = [i for i, (name, _) in enumerate(probe_name_infos) if \"next_box\" in name]\n",
    "# box_probe = probes[box_probe_idx[0]]\n",
    "# box_probe_info = probe_infos[box_probe_idx[0]]\n",
    "\n",
    "coef = th.tensor(agent_directions_probe.coef_)\n",
    "# One (h, c) pair of segments per probed layer; layer == -1 means all three layers.\n",
    "num_layers = 3 if agent_directions_probe_info.layer == -1 else 1\n",
    "n_segments = 2 * num_layers\n",
    "bd_per_segment_neurons = coef.shape[1] // n_segments\n",
    "# bd_magnitude = 0.2\n",
    "alpha = 6.0\n",
    "# Cell-level copies of the initial map; the inspection cells below display them.\n",
    "my_map_as_idx = (th.tensor(my_direction_map) + 1).to(th.int64)\n",
    "coef_from_my_map = th.index_select(coef, 0, my_map_as_idx.view(-1)).view(10, 10, -1).permute(2, 0, 1)\n",
    "\n",
    "# Per-run copy of the plan. The hook consumes this copy instead of the notebook-level\n",
    "# `directions` / `start_pos`, so re-running this cell always starts from the full plan.\n",
    "steering_state = {\"directions\": list(directions), \"pos\": start_pos}\n",
    "\n",
    "\n",
    "def patch_in_box_direction(cache, hook, layer, h_or_c):\n",
    "    \"\"\"Forward hook that pushes activations toward the planned direction map.\n",
    "\n",
    "    For every square, the probe coefficients of that square's planned direction are added to\n",
    "    `cache` with the smallest non-negative scale that brings their dot product with the\n",
    "    activation up to ``alpha / n_segments``.\n",
    "\n",
    "    Args:\n",
    "        cache: hooked activation, updated in place. Channels are assumed to be on dim 1 and\n",
    "            the 10x10 grid on the last two dims -- TODO confirm against the hook point.\n",
    "        hook: hook point; its ``name`` decides when the plan advances by one action.\n",
    "        layer: index of the hooked cell, used to pick the coefficient segment.\n",
    "        h_or_c: 0 for the ``hook_h`` activation, 1 for ``hook_c``.\n",
    "\n",
    "    Returns:\n",
    "        The same `cache` object after the in-place update.\n",
    "    \"\"\"\n",
    "    segment_idx = 2 * layer + h_or_c if agent_directions_probe_info.layer == -1 else h_or_c\n",
    "    remaining = steering_state[\"directions\"]\n",
    "    direction_map, next_pos = construct_direction_map(remaining, steering_state[\"pos\"])\n",
    "    # Advance the plan once per step: only when the last segment's \"0.2\" hook fires.\n",
    "    if segment_idx + 1 == n_segments and \"0.2\" in hook.name and len(remaining) > 0:\n",
    "        remaining.pop(0)\n",
    "        steering_state[\"pos\"] = next_pos\n",
    "    # Shift by one so that -1 (no direction) selects coefficient row 0.\n",
    "    map_as_idx = (th.tensor(direction_map) + 1).to(th.int64)\n",
    "    coef_for_map = th.index_select(coef, 0, map_as_idx.view(-1)).view(10, 10, -1).permute(2, 0, 1)\n",
    "\n",
    "    # cache += coef_for_map[bd_per_segment_neurons * segment_idx : bd_per_segment_neurons * (segment_idx + 1)].unsqueeze(0) * bd_magnitude\n",
    "    segment_start = bd_per_segment_neurons * segment_idx\n",
    "    steering_vector = coef_for_map[segment_start : segment_start + bd_per_segment_neurons].unsqueeze(0)\n",
    "    dot_product = th.sum(steering_vector * cache, dim=1, keepdim=True)\n",
    "    squared_norm = th.sum(steering_vector**2, dim=1, keepdim=True)\n",
    "    magnitude = th.max(th.tensor(0.0), ((alpha / n_segments) - dot_product) / squared_norm)\n",
    "    cache += magnitude * steering_vector\n",
    "    return cache\n",
    "\n",
    "\n",
    "hook_h_cs = [\"hook_h\", \"hook_c\"]\n",
    "fwd_hooks = [\n",
    "    (\n",
    "        f\"features_extractor.cell_list.{layer}.{h_or_c_name}.{pos}.{int_pos}\",\n",
    "        partial(patch_in_box_direction, layer=layer, h_or_c=h_or_c_idx),\n",
    "    )\n",
    "    for pos in range(1)\n",
    "    for int_pos in range(3)\n",
    "    for layer in (range(3) if agent_directions_probe_info.layer == -1 else [agent_directions_probe_info.layer])\n",
    "    for h_or_c_idx, h_or_c_name in enumerate(hook_h_cs)\n",
    "]\n",
    "\n",
    "\n",
    "# nb_coef = th.tensor(box_probe.coef_.squeeze())\n",
    "# nb_num_layers = 3 if box_probe_info.layer == -1 else 1\n",
    "# nb_n_segments = (2 * nb_num_layers)\n",
    "# nb_per_segment_neurons = nb_coef.shape[0] // nb_n_segments\n",
    "\n",
    "# # nb_magnitude = 0.2\n",
    "# nb_alpha = 60.0\n",
    "# focus_box = (4, 5)\n",
    "\n",
    "# def patch_in_box(cache, hook, layer, h_or_c):\n",
    "#     segment_idx = 2 * layer + h_or_c if box_probe_info.layer == -1 else h_or_c\n",
    "#     steering_vector = nb_coef[nb_per_segment_neurons * segment_idx : nb_per_segment_neurons * (segment_idx + 1)].unsqueeze(0)\n",
    "#     dot_product = th.sum(steering_vector * cache[:, :, focus_box[0], focus_box[1]], dim=1, keepdim=True)  # (s, b)\n",
    "#     magnitude = th.max(th.tensor(0.0), (2 * (nb_alpha / nb_n_segments) - dot_product) / th.sum(nb_coef ** 2, dim=0, keepdim=True))\n",
    "#     cache[:, :, focus_box[0], focus_box[1]] += magnitude * steering_vector\n",
    "#     cache[:, :, :, :] -= magnitude * steering_vector.unsqueeze(2).unsqueeze(3) / 2\n",
    "#     # cache[:, :, focus_box[0], focus_box[1]] += (\n",
    "#     #     nb_coef[nb_per_segment_neurons * segment_idx : nb_per_segment_neurons * (segment_idx + 1)] * nb_magnitude\n",
    "#     # )\n",
    "#     # # every other position apart from the focus box should be subtracted\n",
    "#     # cache[:, :, :, :] -= (\n",
    "#     #     nb_coef[nb_per_segment_neurons * segment_idx : nb_per_segment_neurons * (segment_idx + 1)].unsqueeze(0).unsqueeze(2).unsqueeze(3)\n",
    "#     #     * nb_magnitude / 2  # broadcasting\n",
    "#     # )\n",
    "\n",
    "#     return cache\n",
    "#\n",
    "# fwd_hooks += [\n",
    "#     (\n",
    "#         f\"features_extractor.cell_list.{layer}.{h_or_c_name}.{pos}.{int_pos}\",\n",
    "#         partial(patch_in_box, layer=layer, h_or_c=h_or_c_idx),\n",
    "#     )\n",
    "#     for pos in range(1)\n",
    "#     for int_pos in range(3)\n",
    "#     for layer in (range(3) if box_probe_info.layer == -1 else [box_probe_info.layer])\n",
    "#     for h_or_c_idx, h_or_c_name in enumerate(hook_h_cs)\n",
    "# ]\n",
    "\n",
    "# Roll out the level with the steering hooks installed.\n",
    "rollout_kwargs = dict(\n",
    "    policy_th=policy_th,\n",
    "    reset_opts=reset_opts,\n",
    "    probes=probes,\n",
    "    probe_train_ons=probe_infos,\n",
    "    thinking_steps=0,\n",
    "    fwd_hooks=fwd_hooks,\n",
    "    # hook_steps=[0, 1, 2, 3],\n",
    "    max_steps=60,\n",
    "    internal_steps=show_internal_steps_until > 0,\n",
    ")\n",
    "out = play_level(boxo_env, **rollout_kwargs)\n",
    "steered_obs = out.obs.squeeze(1)\n",
    "\n",
    "# Save the rollout and all probe outputs under a timestamped name.\n",
    "timestamp = datetime.datetime.now().strftime(\"%Y%m%d-%H%M%S\")\n",
    "name = f\"steered_videos/{timestamp}_{agent_directions_probe_info.dataset_name}_{difficulty}_mag{alpha}_{lfi}_{li}.mp4\"\n",
    "save_video(\n",
    "    name,\n",
    "    steered_obs,\n",
    "    out.probe_outs,\n",
    "    all_probe_infos=probe_infos,\n",
    "    show_internal_steps_until=show_internal_steps_until,\n",
    ")\n",
    "# (best_act_before, best_val_before, best_log_prob_before, _), cache_before = policy_th.run_with_cache(obs, state, eps_start)\n",
    "# (best_act_rwh, best_val, best_log_prob, _) = policy_th.run_with_hooks(obs, state, eps_start, fwd_hooks=fwd_hooks)\n",
    "\n",
    "\n",
    "# pred_before = predict(cache_before, box_probe, box_probe_info, 0)\n",
    "# pred_after = predict(cache, box_probe, box_probe_info, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch inspection: show the index map that the steering cell above left in the kernel.\n",
    "my_map_as_idx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch inspection: shape of the per-square coefficients selected in the steering cell above.\n",
    "coef_from_my_map.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch inspection: shape of the agent-direction probe's raw coefficient array.\n",
    "agent_directions_probe.coef_.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualization on Custom Level"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom level: a square room of side `dim_room` enclosed by border walls.\n",
    "dim_room = 10\n",
    "# soko_env = SokobanConfig(num_envs=1, asynchronous=False, tinyworld_obs=True, max_episode_steps=120, dim_room=(dim_room, dim_room)).make()\n",
    "last_idx = dim_room - 1\n",
    "# Top and bottom rows first, then the left and right columns without their corners.\n",
    "walls = [(row, col) for row in (0, last_idx) for col in range(dim_room)]\n",
    "walls.extend((row, col) for col in (0, last_idx) for row in range(1, last_idx))\n",
    "boxes = [(2, 3), (1, 6)]\n",
    "targets = [(1, 3), (2, 2)]\n",
    "player = (3, 3)\n",
    "level_options = {\"walls\": walls, \"boxes\": boxes, \"targets\": targets, \"player\": player}\n",
    "obs = boxo_env.reset(options=level_options)[0]\n",
    "img = np.transpose(obs[0], (1, 2, 0))\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Score all four actions on the same observation under three successive recurrent states:\n",
    "# the initial state, and the states returned by one and two `run_with_cache` calls.\n",
    "state0 = policy_th.recurrent_initial_state(1)\n",
    "all_actions = th.arange(4)\n",
    "act_map = [\"U\", \"D\", \"L\", \"R\"]\n",
    "obs_th = th.tensor(obs)\n",
    "eps_start = th.tensor([0.0], dtype=th.bool)\n",
    "\n",
    "values, log_prob, ent = policy_th.evaluate_actions(obs_th, all_actions, state0, eps_start)\n",
    "print(log_prob, act_map[log_prob.argmax().item()])\n",
    "\n",
    "(best_act, best_val, best_log_prob, state1), cache0 = policy_th.run_with_cache(obs_th, state0, eps_start)\n",
    "values, log_prob, ent = policy_th.evaluate_actions(obs_th, all_actions, state1, eps_start)\n",
    "print(log_prob, act_map[log_prob.argmax().item()])\n",
    "\n",
    "(best_act, best_val, best_log_prob, state2), cache1 = policy_th.run_with_cache(obs_th, state1, eps_start)\n",
    "values, log_prob, ent = policy_th.evaluate_actions(obs_th, all_actions, state2, eps_start)\n",
    "print(log_prob, act_map[log_prob.argmax().item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scratch inspection: list the keys of the cache returned by `policy_th.run_with_cache` above.\n",
    "cache0.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same custom layout as above, this time on a Sokoban environment with a 10x10 room.\n",
    "dim_room = 10\n",
    "soko_cfg = SokobanConfig(\n",
    "    num_envs=1,\n",
    "    asynchronous=False,\n",
    "    tinyworld_obs=True,\n",
    "    max_episode_steps=120,\n",
    "    dim_room=(dim_room, dim_room),\n",
    ")\n",
    "soko_env = soko_cfg.make()\n",
    "last_idx = dim_room - 1\n",
    "# Top and bottom rows first, then the left and right columns without their corners.\n",
    "walls = [(row, col) for row in (0, last_idx) for col in range(dim_room)]\n",
    "walls.extend((row, col) for col in (0, last_idx) for row in range(1, last_idx))\n",
    "boxes = [(2, 3), (1, 6)]\n",
    "targets = [(1, 3), (2, 2)]\n",
    "player = (4, 3)\n",
    "level_options = {\"walls\": walls, \"boxes\": boxes, \"targets\": targets, \"player\": player}\n",
    "obs = soko_env.reset(options=level_options)[0]\n",
    "img = np.transpose(obs[0], (1, 2, 0))\n",
    "\n",
    "plt.imshow(img)\n",
    "plt.show()\n",
    "\n",
    "# state0 = policy_th.recurrent_initial_state(1)\n",
    "# all_actions = th.arange(4)\n",
    "# act_map = [\"U\", \"D\", \"L\", \"R\"]\n",
    "# values, log_prob, ent = policy_th.evaluate_actions(th.tensor(obs), all_actions, state0, th.tensor([0.0], dtype=th.bool))\n",
    "# print(log_prob, act_map[log_prob.argmax().item()])\n",
    "\n",
    "# (best_act, best_val, best_log_prob, state1), cache0 = policy_th.run_with_cache(th.tensor(obs), state0, th.tensor([0.0], dtype=th.bool))\n",
    "# obs, _, _, _, _ = soko_env.step([best_act.item()])\n",
    "# values, log_prob, ent = policy_th.evaluate_actions(th.tensor(obs), all_actions, state1, th.tensor([0.0], dtype=th.bool))\n",
    "# print(log_prob, act_map[log_prob.argmax().item()])\n",
    "\n",
    "# (best_act, best_val, best_log_prob, state2), cache1 = policy_th.run_with_cache(th.tensor(obs), state1, th.tensor([0.0], dtype=th.bool))\n",
    "# obs, _, _, _, _ = soko_env.step([best_act.item()])\n",
    "# values, log_prob, ent = policy_th.evaluate_actions(th.tensor(obs), all_actions, state2, th.tensor([0.0], dtype=th.bool))\n",
    "# print(log_prob, act_map[log_prob.argmax().item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Step the Sokoban env once with `best_act` and keep the first two returned values.\n",
    "# NOTE(review): `best_act` is whatever an earlier policy cell left in the kernel (the visible\n",
    "# ones ran on a `boxo_env` observation), so this cell depends on execution order -- confirm.\n",
    "o, r, _, _, _ = soko_env.step([best_act.item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare action scores under the initial state with those under a partially reset next state.\n",
    "state0 = policy_th.recurrent_initial_state(1)\n",
    "all_actions = th.arange(4)\n",
    "act_map = [\"U\", \"D\", \"L\", \"R\"]\n",
    "obs_th = th.tensor(obs)\n",
    "eps_start = th.tensor([0.0], dtype=th.bool)\n",
    "\n",
    "values, log_prob, ent = policy_th.evaluate_actions(obs_th, all_actions, state0, eps_start)\n",
    "print(log_prob, act_map[log_prob.argmax().item()])\n",
    "\n",
    "(best_act, best_val, best_log_prob, state1), cache0 = policy_th.run_with_cache(obs_th, state0, eps_start)\n",
    "\n",
    "# Layers 0 and 1: put both entries of the state pair back to their initial values.\n",
    "for layer in (0, 1):\n",
    "    state1[layer] = (state0[layer][0], state0[layer][1])\n",
    "# Layer 2: keep the updated first entry, reset only the second one.\n",
    "layer = 2\n",
    "state1[layer] = (state1[layer][0], state0[layer][1])\n",
    "\n",
    "values, log_prob, ent = policy_th.evaluate_actions(obs_th, all_actions, state1, eps_start)\n",
    "print(log_prob, act_map[log_prob.argmax().item()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from learned_planners.interp.train_probes import TrainOn\n",
    "from learned_planners.interp.utils import plt_obs_with_position_probe, predict\n",
    "\n",
    "# Custom level: four boxes in a row, each directly below its target, player near the bottom.\n",
    "dim_room = 10\n",
    "last_idx = dim_room - 1\n",
    "# Top and bottom rows first, then the left and right columns without their corners.\n",
    "walls = [(row, col) for row in (0, last_idx) for col in range(dim_room)]\n",
    "walls.extend((row, col) for col in (0, last_idx) for row in range(1, last_idx))\n",
    "boxes = [(2, 3), (2, 4), (2, 5), (2, 6)]\n",
    "targets = [(1, 3), (1, 4), (1, 5), (1, 6)]\n",
    "player = (8, 3)\n",
    "level_options = {\"walls\": walls, \"boxes\": boxes, \"targets\": targets, \"player\": player}\n",
    "obs = boxo_env.reset(options=level_options)[0]\n",
    "if isinstance(obs, th.Tensor):\n",
    "    obs = obs.numpy()\n",
    "img = np.transpose(obs[0], (1, 2, 0))\n",
    "\n",
    "# One policy step from the initial recurrent state, keeping the activation cache.\n",
    "state0 = policy_th.recurrent_initial_state(1)\n",
    "eps_start = th.tensor([0.0], dtype=th.bool)\n",
    "(best_act, best_val, best_log_prob, state1), cache0 = policy_th.run_with_cache(th.tensor(obs), state0, eps_start)\n",
    "\n",
    "# NOTE(review): `probe` and `probe_best` are not defined in any visible cell of this notebook;\n",
    "# this cell relies on them already existing in the kernel -- confirm where they come from.\n",
    "train_on = TrainOn(layer=2)\n",
    "probe_preds = predict(cache0, probe, train_on, grid_wise=True, step=0, internal_steps=False)\n",
    "probe_preds = probe_preds.squeeze(0).squeeze(0)\n",
    "\n",
    "plt_obs_with_position_probe(img, probe_preds)\n",
    "plt.show()\n",
    "\n",
    "train_on_best = TrainOn(layer=-1)\n",
    "probe_best_preds = predict(cache0, probe_best, train_on_best, grid_wise=True, step=0, internal_steps=False)\n",
    "probe_best_preds = probe_best_preds.squeeze(0).squeeze(0)\n",
    "\n",
    "plt_obs_with_position_probe(img, probe_best_preds)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
