{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Analysis of an SAE trained on the activations of a Sokoban model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import pathlib\n",
    "import subprocess\n",
    "from dataclasses import dataclass, field\n",
    "from typing import Any, Dict\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch as th\n",
    "from cleanba.environments import BoxobanConfig\n",
    "from gym_sokoban.envs.sokoban_env import CHANGE_COORDINATES\n",
    "from sklearn.metrics import precision_recall_curve, precision_recall_fscore_support\n",
    "from torch.nn import functional as F\n",
    "\n",
    "from learned_planners.interp.collect_dataset import DatasetStore\n",
    "from learned_planners.interp.utils import get_player_pos, load_jax_model_to_torch, play_level\n",
    "from learned_planners.policies import download_policy_from_huggingface\n",
    "\n",
    "on_cluster = os.path.exists(\"/training\")\n",
    "\n",
    "MODEL_PATH_IN_REPO = \"drc33/bkynosqi/cp_2002944000/\"  # DRC(3, 3) 2B checkpoint\n",
    "MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)\n",
    "LP_DIR = pathlib.Path.cwd().parent.parent\n",
    "if on_cluster:\n",
    "    BOXOBAN_CACHE = pathlib.Path(\"/training/.sokoban_cache/\")\n",
    "else:\n",
    "    BOXOBAN_CACHE = LP_DIR / \"training/.sokoban_cache/\"\n",
    "\n",
    "\n",
    "difficulty = \"medium\"\n",
    "split = \"valid\"\n",
    "\n",
    "boxo_cfg = BoxobanConfig(\n",
    "    cache_path=BOXOBAN_CACHE,\n",
    "    num_envs=1,\n",
    "    max_episode_steps=80,\n",
    "    min_episode_steps=80,\n",
    "    asynchronous=False,\n",
    "    tinyworld_obs=True,\n",
    "    difficulty=difficulty,\n",
    "    split=split,\n",
    ")\n",
    "boxo_env = boxo_cfg.make()\n",
    "cfg_th, policy_th = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sae_lens import SAE\n",
    "from safetensors import safe_open\n",
    "\n",
    "\n",
    "def load_sae(path):\n",
    "    sae = SAE.load_from_pretrained(path)\n",
    "    sae.eval()\n",
    "    path = path if isinstance(path, pathlib.Path) else pathlib.Path(path)\n",
    "    with safe_open(path / \"sparsity.safetensors\", \"pt\", device=\"cpu\") as f:\n",
    "        log_sparsity = f.get_tensor(\"sparsity\")\n",
    "\n",
    "\n",
    "    sorted_indices = th.argsort(log_sparsity, descending=True)\n",
    "    features_by_activation = sorted_indices.tolist()\n",
    "\n",
    "    # Return the SAE object and the sorted indices (original positions)\n",
    "    return sae, features_by_activation\n",
    "\n",
    "\n",
    "if on_cluster:\n",
    "    wandb_ids = [\"ho6ob1tk\"]\n",
    "    sae_files = []\n",
    "    for wandb_id, sae_info in wandb_ids:\n",
    "        command = f\"/training/findsae.sh {wandb_id}\"\n",
    "        file_path = subprocess.run(command, shell=True, capture_output=True, text=True).stdout\n",
    "        file_path = file_path.strip()\n",
    "        sae_files.append(file_path)\n",
    "else:\n",
    "    sae_files = [\"k8_layer2_final\"]\n",
    "    sae_files = [LP_DIR / \"sae\" / file for file in sae_files]\n",
    "\n",
    "\n",
    "# We use a single SAE so these are global variables\n",
    "sae, features_by_activation = load_sae(sae_files[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Select Level & thinking steps. Load level observations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_frame_image(img):\n",
    "    plt.subplots(figsize=(2,2)) \n",
    "    plt.imshow(img)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class to contain the puzzle level and the SAE\n",
    "\n",
    "@dataclass\n",
    "class RunContainer:\n",
    "    lfi: int\n",
    "    li: int\n",
    "    thinking_steps: int = 0\n",
    "    show_image: bool = True\n",
    "    reset_opts: Dict[str, Any] = field(init=False)\n",
    "    model_obs: np.ndarray = field(init=False)\n",
    "    model_ds: Any = field(init=False)\n",
    "    sae_acts: np.ndarray = field(init=False)\n",
    "    num_features: int = field(init=False)\n",
    "    repeated_obs: np.ndarray = field(init=False)\n",
    "    activation_levels: np.ndarray = field(init=False)\n",
    "    enhanced_activation_levels: np.ndarray = field(init=False)\n",
    "    to_plot: np.ndarray = field(init=False)\n",
    "    num_frames: int = field(init=False)\n",
    "\n",
    "    def __post_init__(self):\n",
    "        self.reset_opts = {\"level_file_idx\": self.lfi, \"level_idx\": self.li}\n",
    "        self.model_obs = self._initialize_model_obs()\n",
    "        self._show_frame_image_if_needed()\n",
    "        self._initialize_model_ds()\n",
    "        self._process_sae_activations()\n",
    "\n",
    "    def _initialize_model_obs(self):\n",
    "        return boxo_env.reset(options=self.reset_opts)[0]\n",
    "\n",
    "    def _show_frame_image_if_needed(self):\n",
    "        if self.show_image:\n",
    "            img = np.transpose(self.model_obs.squeeze(), (1, 2, 0))\n",
    "            show_frame_image(img)\n",
    "\n",
    "    def _initialize_model_ds(self):\n",
    "        out = play_level(\n",
    "            boxo_env,\n",
    "            policy_th=policy_th,\n",
    "            reset_opts=self.reset_opts,\n",
    "            thinking_steps=self.thinking_steps,\n",
    "            internal_steps=False,\n",
    "            sae=sae,\n",
    "        )\n",
    "        all_obs = out.obs.squeeze(1)\n",
    "        self.model_ds = DatasetStore(None, all_obs[self.thinking_steps:], out.rewards, out.solved, out.acts, th.zeros(len(all_obs)), {})\n",
    "        self.sae_acts = out.sae_outs.detach().permute(0, 3, 1, 2)\n",
    "\n",
    "    def _process_sae_activations(self):\n",
    "        self.num_features = len(features_by_activation)\n",
    "        projected = self.sae_acts[:, features_by_activation]\n",
    "        normed = (projected - projected.min()) / (projected.max() - projected.min())\n",
    "\n",
    "        obs_transposed = np.transpose(self.model_ds.obs, (0, 2, 3, 1))\n",
    "        self._create_repeated_obs(obs_transposed)\n",
    "\n",
    "        cmap = plt.get_cmap(\"viridis\")\n",
    "        normed_extended = cmap(normed)[..., :3] * 255\n",
    "        temp_plot = np.concatenate([self.repeated_obs, normed_extended], axis=1)\n",
    "        to_plot_filtered = temp_plot[:, 1:, :, :, :].astype(np.uint8)\n",
    "        self.activation_levels = np.mean(to_plot_filtered, axis=-1)\n",
    "\n",
    "        self._enhance_activation_levels()\n",
    "        self._create_final_plot()\n",
    "\n",
    "    def _create_repeated_obs(self, obs_transposed):\n",
    "        if self.thinking_steps == 0:\n",
    "            self.repeated_obs = np.repeat(obs_transposed[:, None, :, :, :], 1, axis=0)\n",
    "        else:\n",
    "            repeated_first = np.repeat(obs_transposed[0:1], self.thinking_steps, axis=0)\n",
    "            self.repeated_obs = np.concatenate([repeated_first, obs_transposed], axis=0)\n",
    "            self.repeated_obs = self.repeated_obs[:, np.newaxis, ...]\n",
    "\n",
    "    def _enhance_activation_levels(self):\n",
    "        self.enhanced_activation_levels = np.power(self.activation_levels, 0.5)\n",
    "        self.enhanced_activation_levels = (self.enhanced_activation_levels - self.enhanced_activation_levels.min()) / (self.enhanced_activation_levels.max() - self.enhanced_activation_levels.min()) * 255\n",
    "\n",
    "    def _create_final_plot(self):\n",
    "        cmap = plt.get_cmap(\"viridis\")\n",
    "        normed_extended = cmap(self.enhanced_activation_levels / 255)[..., :3] * 255\n",
    "        self.to_plot = np.concatenate([self.repeated_obs, normed_extended], axis=1)\n",
    "        self.num_frames = self.to_plot.shape[0]\n",
    "\n",
    "    def remove_frames(self, num_to_remove):\n",
    "        self.repeated_obs = self.repeated_obs[num_to_remove:]\n",
    "        self.to_plot = self.to_plot[num_to_remove:]\n",
    "        self.num_frames = self.to_plot.shape[0]\n",
    "        self.activation_levels = self.activation_levels[num_to_remove:]\n",
    "        self.enhanced_activation_levels = self.enhanced_activation_levels[num_to_remove:]\n",
    "\n",
    "    def grid_size(self) -> int:\n",
    "        return self.model_ds.obs.shape[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Level 10 generates the same path with and without thinking steps - making comparison easier\n",
    "run1 = RunContainer(lfi=0, li=10, thinking_steps=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot model and top K SAE features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the ranks. A rank of <0 is the model. A rank of >=0 is an SAE rank \n",
    "def create_model_sae_plot_core(runs_and_ranks, this_plot_rows, this_plot_cols, num_frames, this_scale=1) -> plt.Figure:\n",
    "    num_subplots = len(runs_and_ranks)\n",
    "    this_plot = np.zeros((num_frames, num_subplots, *runs_and_ranks[0][0].to_plot.shape[2:]))\n",
    "    \n",
    "    for i, (run, rank) in enumerate(runs_and_ranks):\n",
    "        this_plot[:num_frames, i] = run.to_plot[:num_frames, rank+1]\n",
    "    \n",
    "    fig = px.imshow(\n",
    "        this_plot,\n",
    "        facet_col=1,\n",
    "        animation_frame=0,\n",
    "        facet_col_wrap=this_plot_cols,\n",
    "        binary_string=True,\n",
    "        width=50 + 125 * this_plot_cols * this_scale,\n",
    "        height=50 + 150 * this_plot_rows * this_scale,\n",
    "    )\n",
    "\n",
    "    unique_runs = []\n",
    "    for run, _ in runs_and_ranks:\n",
    "        if run not in unique_runs:\n",
    "            unique_runs.append(run)\n",
    "    \n",
    "    fig.layout.annotations = ()\n",
    "    for i, (run, rank) in enumerate(runs_and_ranks):\n",
    "        run_index = unique_runs.index(run) + 1\n",
    "        \n",
    "        if rank < 0:\n",
    "            annotation_text = f\"Model Rn {run_index}\"\n",
    "        else:\n",
    "            orig_feat_idx = features_by_activation[rank]\n",
    "            annotation_text = f\"Ft {orig_feat_idx} Rk {rank} Rn {run_index}\"\n",
    "        \n",
    "        col = i % this_plot_cols\n",
    "        row = i // this_plot_cols\n",
    "        \n",
    "        # Calculate the x-coordinate for the annotation\n",
    "        x_start = col / this_plot_cols\n",
    "        x_end = (col + 1) / this_plot_cols\n",
    "        x_center = (x_start + x_end) / 2\n",
    "        \n",
    "        # Calculate the y-coordinate for the annotation\n",
    "        y_top = 1 - 1.05*row / this_plot_rows\n",
    "        y_bottom = 1 - (1.05*row + 1) / this_plot_rows\n",
    "        y_height = y_top - y_bottom\n",
    "        y_position = y_bottom + y_height\n",
    "        \n",
    "        fig.add_annotation(\n",
    "            x=x_center,\n",
    "            y=y_position,\n",
    "            text=annotation_text,\n",
    "            showarrow=False,\n",
    "            font=dict(size=10),\n",
    "            xref=\"paper\",\n",
    "            yref=\"paper\",\n",
    "            xanchor=\"center\",\n",
    "            yanchor=\"bottom\"\n",
    "        )\n",
    "\n",
    "    fig.update_layout(\n",
    "        margin=dict(l=20, r=20, t=50, b=20),  # Increased top margin\n",
    "        sliders=[{\n",
    "            'currentvalue': {\n",
    "                'visible': True,\n",
    "                'prefix': \"Frame: \",\n",
    "                'xanchor': 'right',\n",
    "                'font': {'size': 20}\n",
    "            },\n",
    "            'steps': [{'args': [[frame], {'frame': {'duration': 300, 'redraw': True}, 'mode': 'immediate'}],\n",
    "                       'label': str(frame), 'method': 'animate'} for frame in range(num_frames)]\n",
    "        }]\n",
    "    )\n",
    "\n",
    "    fig.update_layout(coloraxis_showscale=False)\n",
    "    fig.update_xaxes(showticklabels=False)\n",
    "    fig.update_yaxes(showticklabels=False)\n",
    "    return fig\n",
    "    \n",
    "\n",
    "def create_model_sae_plot(run, num_plot_rows, num_plot_cols) -> plt.Figure:\n",
    "    num_features_shown = num_plot_cols * num_plot_rows - 1\n",
    "    runs_and_ranks = [(run,-1)] + [(run, i) for i in range(num_features_shown)]\n",
    "    \n",
    "    return create_model_sae_plot_core(runs_and_ranks, num_plot_rows, num_plot_cols, num_frames=run.num_frames), num_features_shown\n",
    "\n",
    "\n",
    "def create_model_sae_rank_plot(runs_and_ranks, plot_cols=9, scale=2, num_frames=-1) -> plt.Figure:\n",
    "    num_features_shown = len(runs_and_ranks)\n",
    "    this_plot_cols = min(plot_cols, num_features_shown)\n",
    "    this_plot_rows = (num_features_shown + this_plot_cols - 1) // this_plot_cols\n",
    "\n",
    "    if num_frames == -1:\n",
    "        num_frames = runs_and_ranks[0][0].num_frames\n",
    "    \n",
    "    return create_model_sae_plot_core(runs_and_ranks, this_plot_rows, this_plot_cols, num_frames=num_frames, this_scale=scale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_plot_rows = 5\n",
    "num_plot_cols= 9\n",
    "\n",
    "fig, num_features_shown = create_model_sae_plot(run1, num_plot_rows, num_plot_cols) # Not all SAE features are shown\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search for features where specific squares are activated / not activated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These correspond to gym_sokoban > envs > render_utils.py > room_to_tiny_world_rgb constants \n",
    "unsolved_target_G = 126 # Unsolved targets are Pink. Pink has 126 in Green channel\n",
    "solved_target_G = 95  # Solved targets are Orange. Orange has 95 in Green channel\n",
    "unsolved_block_G = 121 # Unsolved targets are Brown. Brown has 121 in Green channel\n",
    "player_G = 212 # Player and player on target both have 212 in Green channel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Returns the locations, in each frame, of the model squares that have the specified color(s). Shows a sample frame\n",
    "def calc_locations(run, color1:int, color2:int=-1):\n",
    "    main_model_G = run.to_plot[:, 0, :, :, 1].astype(np.uint8)  # Is (frames, height, width, Green RGB channel)\n",
    "\n",
    "    if color2 == -1:\n",
    "        color2 = color1\n",
    "    mask = np.logical_or(main_model_G == color1, main_model_G == color2)\n",
    "\n",
    "    the_locations = [] # Will be [frames, 4 to 8, 2]\n",
    "    for frame in range(run.num_frames):\n",
    "        frame_targets = np.column_stack(np.where(mask[frame]))\n",
    "        the_locations.append(frame_targets)\n",
    "\n",
    "    if run.show_image:\n",
    "        show_frame = 14\n",
    "        print(f\"Model squares with G={color1} or G={color2} in frame {show_frame} are shown in red:\")\n",
    "        img = run.to_plot[show_frame, 0, :, :, :].astype(np.uint8).copy()\n",
    "        img[mask[show_frame]] = [255, 0, 0] \n",
    "        show_frame_image(img)\n",
    "\n",
    "    return the_locations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def feature_rank(the_rank):\n",
    "    if the_rank < 0:\n",
    "        return \"Neg Rank\"\n",
    "    if the_rank >= len(features_by_activation):\n",
    "        return f\"Rank OOB{the_rank}\"\n",
    "\n",
    "    return f\"Rank: {the_rank} (F{features_by_activation[the_rank]})\"\n",
    "\n",
    "# Find best feature matching a location based on binary activation \n",
    "def calc_feature_scores_by_location_deprecated(run, title_prefix, test_locations, cutoff=64):\n",
    "\n",
    "    gts = np.zeros((run.num_frames, run.grid_size(), run.grid_size()), dtype=bool)\n",
    "    for frame, locations in enumerate(test_locations):\n",
    "        gts[frame, locations[:, 0], locations[:, 1]] = True\n",
    "    \n",
    "    sae_preds = run.activation_levels / 255.0  # Normalize \n",
    "    sae_preds = (sae_preds >= cutoff/255.0).astype(int) # Binarize\n",
    "\n",
    "    best_rank = -1\n",
    "    best_f1 = 0\n",
    "    best_offset = (0, 0)\n",
    "    \n",
    "    for rank in range(sae_preds.shape[1]):\n",
    "        for dy, dx in [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)]:\n",
    "            pred_offset = np.roll(sae_preds[:, rank], (dy, dx), axis=(1, 2))\n",
    "            \n",
    "            # Calc all frames together, ensuring support used for precision calculation is consistent\n",
    "            pred_flat = pred_offset.reshape(-1)\n",
    "            gt_flat = gts.reshape(-1)\n",
    "            \n",
    "            precision, recall, f1, _ = precision_recall_fscore_support(gt_flat, pred_flat, average='binary', zero_division=0)\n",
    "            \n",
    "            if f1 > best_f1:\n",
    "                best_f1 = f1\n",
    "                best_rank = rank\n",
    "                best_offset = (dy, dx)\n",
    "                best_precision = precision\n",
    "                best_recall = recall\n",
    "\n",
    "            if f1 > 0.9 and run.show_image:\n",
    "                print(f\"{title_prefix}: Candidate {feature_rank(rank)}, Offset ({dy},{dx}), F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}\")\n",
    "\n",
    "    print(f\"{title_prefix}: Best Feature {feature_rank(best_rank)}, Offset: {best_offset}, F1: {best_f1:.4f}, Precision: {best_precision:.4f}, Recall: {best_recall:.4f}\")\n",
    "    \n",
    "    best_pred_offset = np.roll(sae_preds[:, best_rank], best_offset, axis=(1, 2))\n",
    "    best_pred_flat = best_pred_offset.reshape(-1)\n",
    "    gt_flat = gts.reshape(-1)\n",
    "    \n",
    "    precision, recall, _ = precision_recall_curve(gt_flat, best_pred_flat)\n",
    "    \n",
    "    if run.show_image:\n",
    "        fig, ax = plt.subplots(figsize=(4, 3))\n",
    "        ax.plot(recall, precision)\n",
    "        ax.set_xlabel(\"Recall\")\n",
    "        ax.set_ylabel(\"Precision\")\n",
    "        ax.set_title(f\"PR Curve for {feature_rank(best_rank)}, Offset {best_offset}, F1 {best_f1:.4f}\")\n",
    "        ax.set_xlim([0.0, 1.0])\n",
    "        ax.set_ylim([0.0, 1.05])\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "    \n",
    "    return best_rank, best_offset, best_f1, best_precision, best_recall"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_runs_feature_scores_by_location(run_locations_list, title_prefix, cutoff=64):\n",
    "    gts_list = []\n",
    "    sae_preds_list = []\n",
    "    \n",
    "    for run, test_locations in run_locations_list:\n",
    "        gts_run = np.zeros((run.num_frames, run.grid_size(), run.grid_size()), dtype=bool)\n",
    "        for frame, locations in enumerate(test_locations):\n",
    "            gts_run[frame, locations[:, 0], locations[:, 1]] = True\n",
    "\n",
    "        sae_preds_run = run.activation_levels / 255.0  # Normalize \n",
    "        sae_preds_run = (sae_preds_run >= cutoff/255.0).astype(int)  # Binarize\n",
    "\n",
    "        gts_list.append(gts_run)\n",
    "        sae_preds_list.append(sae_preds_run)\n",
    "    \n",
    "    # Concatenate ground truths and predictions across all runs\n",
    "    gts = np.concatenate(gts_list, axis=0)\n",
    "    sae_preds = np.concatenate(sae_preds_list, axis=0)\n",
    "    \n",
    "    best_rank = -1\n",
    "    best_f1 = 0\n",
    "    best_offset = (0, 0)\n",
    "    \n",
    "    for rank in range(sae_preds.shape[1]):\n",
    "        for dy, dx in [(0, 0), (0, 1), (0, -1), (1, 0), (-1, 0)]:\n",
    "            pred_offset = np.roll(sae_preds[:, rank], (dy, dx), axis=(1, 2))\n",
    "            \n",
    "            # Calc all frames together, ensuring support used for precision calculation is consistent\n",
    "            pred_flat = pred_offset.reshape(-1)\n",
    "            gt_flat = gts.reshape(-1)\n",
    "            \n",
    "            precision, recall, f1, _ = precision_recall_fscore_support(gt_flat, pred_flat, average='binary', zero_division=0)\n",
    "            \n",
    "            if f1 > best_f1:\n",
    "                best_f1 = f1\n",
    "                best_rank = rank\n",
    "                best_offset = (dy, dx)\n",
    "                best_precision = precision\n",
    "                best_recall = recall\n",
    "\n",
    "            if f1 > 0.9 and run_locations_list[0][0].show_image:\n",
    "                print(f\"{title_prefix}: Candidate {feature_rank(rank)}, Offset ({dy},{dx}), F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}\")\n",
    "\n",
    "    print(f\"{title_prefix}: Best Feature {feature_rank(best_rank)}, Offset: {best_offset}, F1: {best_f1:.4f}, Precision: {best_precision:.4f}, Recall: {best_recall:.4f}\")\n",
    "    \n",
    "    best_pred_offset = np.roll(sae_preds[:, best_rank], best_offset, axis=(1, 2))\n",
    "    best_pred_flat = best_pred_offset.reshape(-1)\n",
    "    gt_flat = gts.reshape(-1)\n",
    "    \n",
    "    precision_curve, recall_curve, _ = precision_recall_curve(gt_flat, best_pred_flat)\n",
    "    \n",
    "    if run_locations_list[0][0].show_image:\n",
    "        fig, ax = plt.subplots(figsize=(4, 3))\n",
    "        ax.plot(recall_curve, precision_curve)\n",
    "        ax.set_xlabel(\"Recall\")\n",
    "        ax.set_ylabel(\"Precision\")\n",
    "        ax.set_title(f\"{title_prefix} PR Curve for {feature_rank(best_rank)}, Offset {best_offset}, F1 {best_f1:.4f}\")\n",
    "        ax.set_xlim([0.0, 1.0])\n",
    "        ax.set_ylim([0.0, 1.05])\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "    \n",
    "    return best_rank, best_offset, best_f1, best_precision, best_recall \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find best feature matching a location based on binary activation \n",
    "def calc_feature_scores_by_location(run, title_prefix, test_locations, cutoff=64):\n",
    "    return calc_runs_feature_scores_by_location([(run,target_locations)], title_prefix, cutoff)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Static \"Target squares\" concept "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "target_locations = calc_locations(run1, unsolved_target_G, solved_target_G)\n",
    "target_feat_rank_run1, _, _, _, _ = calc_feature_scores_by_location(run1, \"Target\", target_locations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### \"Unsolved targets and blocks\" concept"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unsolved_locations = calc_locations(run1, unsolved_target_G, unsolved_block_G)\n",
    "unsolved_feat_rank_run1, _, _, _, _ = calc_feature_scores_by_location(run1, \"Unsolved\", unsolved_locations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### \"Solved blocks on targets\" concept\n",
    " Feat 179 (Rank 43) has maximum score. While Feat 278 has a high score, it implements the \"Target squares\" concept, so Feat 179 is the only high-scoring candidate    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "solved_locations = calc_locations(run1, solved_target_G)\n",
    "solved_feat_rank_run1, _, _, _, _ = calc_feature_scores_by_location(run1, \"Solved\", solved_locations)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    " ### Show Model with \"Target squares\", \"Unsolved targets and blocks\", \"Solved blocks on targets\" features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "runs_and_ranks = [(run1,-1), (run1,target_feat_rank_run1), (run1,unsolved_feat_rank_run1), (run1,solved_feat_rank_run1)]\n",
    "fig = create_model_sae_rank_plot(runs_and_ranks)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Search for features where many squares are activated / not-activated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These move directions correspond to gym_sokoban > envs > sokoban_env.py constants\n",
    "up_direction = 0\n",
    "down_direction = 1\n",
    "left_direction = 2\n",
    "right_direction = 3\n",
    "noop_direction = 4\n",
    "num_directions = 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### \"Predict Agent Move Direction\" concept"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_player_pos(run):\n",
    "\n",
    "    # Get the player position in each frame \n",
    "    player_pos=[]\n",
    "    for frame in range(run.num_frames):\n",
    "        obs = np.array(run.repeated_obs[frame,0,:,:])\n",
    "        player_pos.append(get_player_pos(obs))\n",
    "\n",
    "    # Calculate the direction player will move in the next frame\n",
    "    player_direction = []\n",
    "    for frame in range(1, run.num_frames):\n",
    "        prev_pos = player_pos[frame-1]\n",
    "        this_pos = player_pos[frame]\n",
    "        if prev_pos[0] < this_pos[0]:\n",
    "            player_direction += [down_direction]\n",
    "        elif prev_pos[0] > this_pos[0]:\n",
    "            player_direction += [up_direction]\n",
    "        elif prev_pos[1] < this_pos[1]:\n",
    "            player_direction += [right_direction]\n",
    "        elif prev_pos[1] > this_pos[1]:\n",
    "            player_direction += [left_direction]\n",
    "        else:\n",
    "            player_direction += [noop_direction] # No movement\n",
    "    player_direction += [noop_direction] # Last step has no future direction\n",
    "\n",
    "    if run.show_image:\n",
    "        # Show first 4 steps of model with \"next direction\" in red square to demonstrate correctness\n",
    "        fig, axes = plt.subplots(1, 4, figsize=(8, 2.5))  # 1 row, 4 columns\n",
    "        fig.suptitle(\"Player next direction in red\")\n",
    "        for frame in range(4):\n",
    "            img = run.to_plot[run.thinking_steps+frame, 0, :, :, :].astype(np.uint8).copy()\n",
    "            player_delta = CHANGE_COORDINATES[player_direction[run.thinking_steps+frame]]\n",
    "            player_next_pos = tuple(p + d for p, d in zip(player_pos[run.thinking_steps+frame], player_delta))\n",
    "            img[player_next_pos] = [255, 0, 0] \n",
    "            axes[frame].imshow(img)\n",
    "            axes[frame].set_title(f\"Frame {frame}\")\n",
    "            axes[frame].axis('off')  # Turn off axis numbers\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "\n",
    "    return player_pos, player_direction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "def one_hot(actions: th.Tensor):\n",
    "    assert len(actions.shape) == 1\n",
    "    return F.one_hot(actions, num_classes=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_runs_feature_direction_f1_scores(run_player_directions_list, cutoff):\n",
    "    ft_name = [\"UP\", \"DOWN\", \"LEFT\", \"RIGHT\"]\n",
    "    \n",
    "    gts_list = []\n",
    "    sae_preds_list = []\n",
    "    \n",
    "    for run, player_directions in run_player_directions_list:\n",
    "        gts = one_hot(th.from_numpy(np.array(player_directions)).long())\n",
    "        sae_preds = run.activation_levels / run.activation_levels.max(axis=0, keepdims=True)\n",
    "        sae_preds = np.mean(sae_preds, axis=(2, 3))  \n",
    "        sae_preds[sae_preds < cutoff/255] = 0   \n",
    "        gts_list.append(gts)\n",
    "        sae_preds_list.append(sae_preds)\n",
    "    \n",
    "    # Concatenate ground truths and predictions across all runs\n",
    "    gts = th.cat(gts_list, dim=0)\n",
    "    sae_preds = np.concatenate(sae_preds_list, axis=0)\n",
    "    \n",
    "    fig, axs = plt.subplots(2, 2, figsize=(6, 5))\n",
    "    fig.suptitle(\"Precision-Recall Curves for Actions\")\n",
    "    \n",
    "    best_features = []\n",
    "    best_f1s = []\n",
    "    best_ps = []\n",
    "    best_rs = []\n",
    "    best_ts = []\n",
    "    return_results = (len(run_player_directions_list) == 1)\n",
    "    show_image = run_player_directions_list[0][0].show_image\n",
    "\n",
    "    for direction in range(4):\n",
    "        best_f1 = 0\n",
    "        best_feature_rank = -1\n",
    "        best_p, best_r, best_t = None, None, None\n",
    "        \n",
    "        for rank in range(sae_preds.shape[1]):\n",
    "            p, r, t = precision_recall_curve(gts[:, direction].numpy(), sae_preds[:, rank])\n",
    "            f1 = 2 * p * r / (p + r + 1e-10)\n",
    "            max_f1_for_feature = np.max(f1)\n",
    "            \n",
    "            if max_f1_for_feature > best_f1:\n",
    "                best_f1 = max_f1_for_feature\n",
    "                best_feature_rank = rank\n",
    "                best_p, best_r, best_t = p, r, t\n",
    "        \n",
    "        # Find the index of the best threshold for the best feature\n",
    "        best_threshold_idx = np.argmax(2 * best_p * best_r / (best_p + best_r + 1e-10))\n",
    "        \n",
    "        print(f\"Action {ft_name[direction]}: Best Feature {feature_rank(best_feature_rank)}, F1: {best_f1:.4f}, Precision: {best_p[best_threshold_idx]:.4f}, Recall: {best_r[best_threshold_idx]:.4f}, Threshold {best_t[best_threshold_idx]:.4f}\")\n",
    "        \n",
    "        if return_results:\n",
    "            best_features.append(best_feature_rank)\n",
    "            best_f1s.append(best_f1)\n",
    "            best_ps.append(best_p[best_threshold_idx])\n",
    "            best_rs.append(best_r[best_threshold_idx])\n",
    "            best_ts.append(best_t[best_threshold_idx])\n",
    "    \n",
    "        if show_image:\n",
    "            # Plot the precision-recall curve\n",
    "            ax = axs[direction // 2, direction % 2]\n",
    "            ax.plot(best_r, best_p, label=f'Best F1: {best_f1:.2f}')\n",
    "            ax.set_title(ft_name[direction])\n",
    "            ax.set_xlabel('Recall')\n",
    "            ax.set_ylabel('Precision')\n",
    "            ax.legend()\n",
    "\n",
    "    if show_image:\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "    plt.close()\n",
    "    \n",
    "\n",
    "    if return_results:\n",
    "        return best_features, best_f1s, best_ps, best_rs, best_ts\n",
    "    else:\n",
    "        return None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_feature_direction_f1_scores(run, player_directions, cutoff):\n",
    "    \"\"\"Score a single run by delegating to the multi-run direction F1 search.\n",
    "\n",
    "    The run and its player directions are wrapped in a one-element list and passed,\n",
    "    with `cutoff` unchanged, to `calc_runs_feature_direction_f1_scores`; whatever that\n",
    "    call returns is returned as-is.\n",
    "    \"\"\"\n",
    "    single_run_pairs = [(run, player_directions)]\n",
    "    return calc_runs_feature_direction_f1_scores(single_run_pairs, cutoff)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# calc_player_pos returns two values for run1; the second (player_direction1) is\n",
    "# what the direction F1 search for run1 is given as its labels.\n",
    "player_pos1, player_direction1 = calc_player_pos(run1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the per-direction feature search on run1. 64 is passed as `cutoff`; its exact\n",
    "# meaning is defined by the scorer, not here.\n",
    "best_ranks, best_f1s, best_ps, best_rs, best_ts = calc_feature_direction_f1_scores(run1, player_direction1, 64)\n",
    "# best_ranks must hold exactly four entries; the names assume the order up, down, left, right.\n",
    "up_feat_rank_run1, down_feat_rank_run1, left_feat_rank_run1, right_feat_rank_run1 = best_ranks"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Show model with the found \"Up\", \"Down\", \"Left\", \"Right\" features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One (run, rank) pair per panel: rank -1 first (presumably the model view rather than an\n",
    "# SAE feature -- confirm in create_model_sae_rank_plot), then run1's up/down/left/right features.\n",
    "fig = create_model_sae_rank_plot([(run1,-1), (run1,up_feat_rank_run1),(run1,down_feat_rank_run1),(run1,left_feat_rank_run1),(run1,right_feat_rank_run1)])\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Run 2: Same model but with thinking steps "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same level as run1 (same lfi and li), but with 8 thinking steps and image display off.\n",
    "run2 = RunContainer(lfi=run1.lfi, li=run1.li, thinking_steps=8, show_image=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Repeat the run1 direction-feature search for run2 (the thinking run), with the same cutoff of 64.\n",
    "player_pos2, player_direction2 = calc_player_pos(run2)\n",
    "best_features, best_f1s, best_ps, best_rs, best_ts = calc_feature_direction_f1_scores(run2, player_direction2, 64)\n",
    "# The first four entries are taken as the up, down, left and right feature ranks, in that order.\n",
    "up_feat_rank_run2, down_feat_rank_run2, left_feat_rank_run2, right_feat_rank_run2 = best_features[0], best_features[1], best_features[2], best_features[3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare Run1 (NoThinking) and Run2 (Thinking) SAE features in early steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# [feature rank, label] pairs for the run1 features identified so far. Later cells use the\n",
    "# rank to pick features to plot and the label to annotate printed results.\n",
    "identified_features_run1 = [\n",
    "    [target_feat_rank_run1, \"Targets\"], \n",
    "    [unsolved_feat_rank_run1, \"Unsolved\"],\n",
    "    [solved_feat_rank_run1, \"Solved\"],\n",
    "    [up_feat_rank_run1, \"Agent up\"],\n",
    "    [down_feat_rank_run1, \"Agent Down\"],\n",
    "    [left_feat_rank_run1, \"Agent Left\"],\n",
    "    [right_feat_rank_run1, \"Agent Right\"]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_feature_similarity(activation_levels1, activation_levels2):\n",
    "    \"\"\"Compare per-feature activations of two runs over their leading frames.\n",
    "\n",
    "    Only the first `run2.thinking_steps` frames of each input are used. The notebook\n",
    "    globals `run1` and `run2` supply the feature counts and the frame count.\n",
    "\n",
    "    Returns a 3-tuple:\n",
    "      - mean activation per (frame, feature) for the first input,\n",
    "      - the same for the second input,\n",
    "      - one np.corrcoef correlation per feature between the two inputs' flattened\n",
    "        activations over those frames.\n",
    "    \"\"\"\n",
    "    frame_count = run2.thinking_steps\n",
    "\n",
    "    # Keep the leading frames and merge everything after the feature axis into one axis.\n",
    "    flat1 = activation_levels1[:frame_count].reshape(frame_count, run1.num_features, -1)\n",
    "    flat2 = activation_levels2[:frame_count].reshape(frame_count, run2.num_features, -1)\n",
    "\n",
    "    avg_run1 = np.mean(flat1, axis=-1)\n",
    "    avg_run2 = np.mean(flat2, axis=-1)\n",
    "\n",
    "    # Correlate each feature's activations in run1 against the same feature in run2.\n",
    "    correlations = []\n",
    "    for feat in range(run1.num_features):\n",
    "        series1 = flat1[:, feat].flatten()\n",
    "        series2 = flat2[:, feat].flatten()\n",
    "        correlations.append(np.corrcoef(series1, series2)[0, 1])\n",
    "    similarity = np.array(correlations)\n",
    "\n",
    "    return avg_run1, avg_run2, similarity"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_feature_similarity(similarity, thinking:bool):\n",
    "    \"\"\"Draw a bar chart of the per-feature correlation for the first features.\n",
    "\n",
    "    Bars cover ranks 0 .. num_features_shown-1 (a notebook global). `thinking` only\n",
    "    selects which suffix appears in the plot title.\n",
    "    \"\"\"\n",
    "    shown_ranks = range(num_features_shown)\n",
    "    if thinking:\n",
    "        title_suffix = \"Run1-First vs Run2-Post-Thinking Steps\"\n",
    "    else:\n",
    "        title_suffix = \"First Steps\"\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(15, 5))\n",
    "    ax.bar(shown_ranks, similarity[shown_ranks])\n",
    "    ax.set_xlabel('Feature Rank')\n",
    "    ax.set_ylabel('Feature correlation (similarity)')\n",
    "    ax.set_title(f'Feature correlation (similarity) across runs ({title_suffix})')\n",
    "\n",
    "    # One tick per bar, labelled through feature_rank().\n",
    "    tick_positions = np.arange(len(shown_ranks))\n",
    "    tick_labels = [f'Feat {feature_rank(i)}' for i in shown_ranks]\n",
    "    ax.set_xticks(tick_positions)\n",
    "    ax.set_xticklabels(tick_labels, rotation=90)\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def top_changing_features(avg_frame_run1, avg_frame_run2, offset=0, top_n=8):\n",
    "    \"\"\"Print the features whose mean activation changed most, in percent, from run1 to run2.\n",
    "\n",
    "    Args:\n",
    "        avg_frame_run1: per-frame feature activations for run1; averaged over axis 0 here.\n",
    "        avg_frame_run2: the same for run2.\n",
    "        offset: unused; kept so existing calls keep working.\n",
    "        top_n: how many features to list for decreases and for increases.\n",
    "\n",
    "    Features whose run1 mean is exactly zero have no defined percent change (the division\n",
    "    would give inf or NaN), so they are left out of both rankings instead of crowding out\n",
    "    the real results.\n",
    "\n",
    "    Reads the notebook globals `identified_features_run1` and `features_by_activation`.\n",
    "    Returns None; the rankings are printed.\n",
    "    \"\"\"\n",
    "    avg_run1 = np.asarray(np.mean(avg_frame_run1, axis=0))\n",
    "    avg_run2 = np.asarray(np.mean(avg_frame_run2, axis=0))\n",
    "\n",
    "    # Only divide where run1 has a non-zero baseline; everything else stays NaN.\n",
    "    has_baseline = avg_run1 != 0\n",
    "    percent_change = np.full(avg_run1.shape, np.nan)\n",
    "    percent_change[has_baseline] = (avg_run2[has_baseline] - avg_run1[has_baseline]) / avg_run1[has_baseline] * 100\n",
    "\n",
    "    # Rank only features with a finite percent change, in ascending order.\n",
    "    ranked = np.flatnonzero(np.isfinite(percent_change))\n",
    "    ranked = ranked[np.argsort(percent_change[ranked])]\n",
    "\n",
    "    num_excluded = percent_change.size - ranked.size\n",
    "    if num_excluded:\n",
    "        print(f\"Note: {num_excluded} features excluded because their % change is undefined (zero run1 activation).\")\n",
    "\n",
    "    def get_feature_description(idx):\n",
    "        for rank, desc in identified_features_run1:\n",
    "            if rank == idx:\n",
    "                return f\" ({desc})\"\n",
    "        return \"\"\n",
    "\n",
    "    def print_ranking(indices, change_label):\n",
    "        for i, idx in enumerate(indices, 1):\n",
    "            orig_feat_idx = features_by_activation[idx]\n",
    "            pct_diff = percent_change[idx]\n",
    "            desc = get_feature_description(idx)\n",
    "            print(f\"{i}. Feature {orig_feat_idx} (Rank {idx}): % {change_label} = {pct_diff:.1f}%, Run1: {avg_run1[idx]:.2f}, Run2: {avg_run2[idx]:.2f} {desc}\")\n",
    "\n",
    "    print(f\"Top {top_n} features where run2 showed a significant % decrease compared to run1:\")\n",
    "    print_ranking(ranked[:top_n], \"Change\")\n",
    "\n",
    "    print(f\"\\nTop {top_n} features with the greatest % activation increase from run1 to run2:\")\n",
    "    print_ranking(ranked[::-1][:top_n], \"Increase\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compare features in first 8 steps of each run"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run1 moves from frame 0. run2 spends its first `thinking_steps` frames thinking, so its\n",
    "# first moves are the `thinking_steps` frames after that.\n",
    "run1_moves = player_direction1[:run2.thinking_steps]\n",
    "run2_moves = player_direction2[run2.thinking_steps:run2.thinking_steps*2]\n",
    "\n",
    "# Direction codes: 0=Up, 1=Down, 2=Left, 3=Right (the original note also lists a code 4\n",
    "# without stating its meaning).\n",
    "print(\"Run 1 agent moves:\", player_direction1, run1_moves )\n",
    "print(\"Run 2 agent moves:\", player_direction2, run2_moves )\n",
    "\n",
    "# np.array_equal gives a single True/False for lists, arrays and CPU tensors alike, and is\n",
    "# False when the two slices differ in length; a plain == would be elementwise for arrays\n",
    "# and make the `if` below raise.\n",
    "comparison_useful = bool(np.array_equal(run1_moves, run2_moves))\n",
    "if not comparison_useful:\n",
    "    print()\n",
    "    print( \"WARNING: Run1 (no-thinking) and Run2 (thinking) generated different paths so comparing feature level activations is less useful\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Warn when any direction's best-matching feature differs between run1 and run2.\n",
    "direction_ranks_run1 = (up_feat_rank_run1, down_feat_rank_run1, left_feat_rank_run1, right_feat_rank_run1)\n",
    "direction_ranks_run2 = (up_feat_rank_run2, down_feat_rank_run2, left_feat_rank_run2, right_feat_rank_run2)\n",
    "if any(rank1 != rank2 for rank1, rank2 in zip(direction_ranks_run1, direction_ranks_run2)):\n",
    "    print(\"WARNING: The features best aligned with concepts Up, Down, Left, Right differ between run1 and run2.\")\n",
    "    print(f\"{up_feat_rank_run1} vs {up_feat_rank_run2}, {down_feat_rank_run1} vs {down_feat_rank_run2}, {left_feat_rank_run1} vs {left_feat_rank_run2}, {right_feat_rank_run1} vs {right_feat_rank_run2}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualizing the SAE feature differences between two (no-thinking and thinking) runs\n",
    "\n",
    "- The 1st row shows frames 0 to 7 of run1 (model has zero thinking steps).\n",
    "- The 2nd row shows frames 0 to 7 of run2 (model has 8 thinking steps).\n",
    "- The 3rd row shows frames 8 to 15 of run2 (that is, the first 8 moving steps).\n",
    "- Each column shows one feature.\n",
    "\n",
    "Insights:\n",
    "- Some features are constant in row 2 (while they change in row 1 and row 3) over the 8 slider steps.\n",
    "- For levels where the model generates the same agent movements, the 3rd row allows comparison between the no-thinking and thinking models \"agent movements\" over the 8 slider steps."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_frames = run2.thinking_steps # Number of frames shown per plot (the slider range)\n",
    "num_ranks = 8 # Number of SAE features plotted side by side\n",
    "    \n",
    "# run3 is built from the same level and thinking-step count as run2 (intended to reproduce\n",
    "# it), and then has its first num_frames frames removed. It supplies the third row of plots.\n",
    "run3 = RunContainer(lfi=run2.lfi, li=run2.li, thinking_steps=num_frames, show_image=False)\n",
    "run3.remove_frames(num_frames)\n",
    "# With those frames gone, record run3 as having no thinking steps.\n",
    "run3.thinking_steps = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize first 8 frames of no-thinking / thinking runs for features with (draft) known purposes\n",
    "\n",
    "- In the 1st row (no thinking steps, frames 0 to 7), the agent-move features change.\n",
    "- The 2nd row (with thinking steps, frames 0 to 7) is largely constant.\n",
    "- In the 3rd row (with thinking steps, frames 8 to 15), the agent-move features change."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Identified features:\", identified_features_run1)\n",
    "num_cols = 8 # Panels per row: the rank -1 panel plus one per identified feature\n",
    "num_rows = 3 # One row each for run1, run2 and run3\n",
    "\n",
    "# Build the (run, rank) list row by row: the rank -1 panel first, then the identified features.\n",
    "runs_and_ranks = []\n",
    "for row in range(num_rows):\n",
    "    if row == 0:\n",
    "        run = run1\n",
    "    elif row == 1:\n",
    "        run = run2\n",
    "    else:\n",
    "        run = run3\n",
    "    runs_and_ranks.append((run, -1))\n",
    "    for identified_feature in identified_features_run1:\n",
    "        runs_and_ranks.append((run, identified_feature[0]))\n",
    "\n",
    "fig = create_model_sae_rank_plot(runs_and_ranks, num_cols, scale=1.2, num_frames=num_frames)\n",
    "fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Visualize features in no-thinking / thinking runs that are active in thinking steps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 24 manually selected feature ranks (between 11 and 83), grouped into three sets of 8.\n",
    "# As per https://docs.google.com/spreadsheets/d/1XCmvMbXh399_t0j8N3R99fhSL9n6KuRB8HTJwN2fSg0/edit?usp=sharing \n",
    "feat_sets = [[11,12,18,20,24,30,31,33],\n",
    "    [38,40,41,44,45,48,55,56],\n",
    "    [57,58,60,61,62,63,71,83]]\n",
    "\n",
    "# One extra column per row for the rank -1 panel.\n",
    "num_cols = len(feat_sets[0])+1\n",
    "# One figure per feature set; iterate over however many sets are listed above.\n",
    "for feat_set in range(len(feat_sets)):\n",
    "    interesting_features1 = feat_sets[feat_set]\n",
    "    runs_and_ranks = []\n",
    "    for row in range(num_rows):\n",
    "        run = run1 if row == 0 else run2 if row == 1 else run3\n",
    "        runs_and_ranks += [(run,-1)]\n",
    "        for interesting_feature in interesting_features1:\n",
    "            runs_and_ranks += [(run,interesting_feature)]\n",
    "\n",
    "    fig = create_model_sae_rank_plot(runs_and_ranks, num_cols, scale=1.1, num_frames=num_frames)\n",
    "    fig.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot many features in blocks of 8 for manual inspection\n",
    "\n",
    "Manual inspection, using data below, led to the above interesting feature sets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def show_many_sae_features(num_blocks=30, num_cols=8):\n",
    "    \"\"\"Plot consecutive blocks of SAE features for run1, run2 and run3 for manual inspection.\n",
    "\n",
    "    Args:\n",
    "        num_blocks: how many consecutive feature blocks to plot (one figure each).\n",
    "        num_cols: SAE features per block; every row also gets one leading rank -1 panel.\n",
    "\n",
    "    Reads the notebook globals `run1`, `run2`, `run3` and `num_rows`.\n",
    "    \"\"\"\n",
    "    frames_shown = run2.thinking_steps\n",
    "\n",
    "    for j in range(num_blocks):\n",
    "        # Feature-rank offset of this block. Stepping by the block width keeps blocks from\n",
    "        # overlapping or skipping ranks even if thinking_steps differs from num_cols.\n",
    "        offset = j * num_cols\n",
    "\n",
    "        runs_and_ranks = []\n",
    "        for row in range(num_rows):\n",
    "            run = run1 if row == 0 else run2 if row == 1 else run3\n",
    "            runs_and_ranks += [(run,-1)]\n",
    "            for i in range(num_cols):\n",
    "                runs_and_ranks += [(run,offset+i)]\n",
    "\n",
    "        # The block selects feature ranks; every row shows the same number of frames.\n",
    "        print( f\"{j}th comparison. Feature ranks {offset}-{offset+num_cols-1}. Row 1: Run 1 frames 0-{frames_shown-1}. Row 2: Run 2 frames 0-{frames_shown-1}. Row 3: Run 2 frames {frames_shown}-{frames_shown*2-1}\")\n",
    "        fig = create_model_sae_rank_plot(runs_and_ranks, num_cols+1, scale=1.1, num_frames=frames_shown)\n",
    "        fig.show()\n",
    "\n",
    "# Disabled by default: this produces one figure per block.\n",
    "if False:\n",
    "    show_many_sae_features()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Search for the concepts target/solved/unsolved/up/down/right/left across many runs\n",
    "Write up is here: https://docs.google.com/document/d/16tKtPqgNxmtvezZEendkJxVg414Vt0EqxiEKVblBlIo/edit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the list of runs searched by the multi-run cells below.\n",
    "run_num = 0\n",
    "runs = []\n",
    "\n",
    "# Disabled alternative: a small grid over thinking steps, lfi and li values (24 runs).\n",
    "if False:\n",
    "    for thinkingsteps in [0, 4]:\n",
    "        for lfi in [0, 1, 2]:\n",
    "            for li in [10, 13, 17, 21]:  # Random levels\n",
    "                run_num += 1\n",
    "                run = RunContainer(lfi=lfi, li=li, thinking_steps=thinkingsteps, show_image=False)\n",
    "                runs.append(run)\n",
    "else:\n",
    "    # Active path: 250 runs with lfi=0, li=0..249 and no thinking steps.\n",
    "    for li in range(250):\n",
    "        run_num += 1\n",
    "        run = RunContainer(lfi=0, li=li, thinking_steps=0, show_image=False)\n",
    "        runs.append(run)    \n",
    "        print(f\"Created run {run_num}\")\n",
    "\n",
    "# Only the first run has show_image switched on.\n",
    "runs[0].show_image = True\n",
    "print(f\"Created {run_num} runs\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pair every run with its player directions (the second value returned by calc_player_pos).\n",
    "run_player_directions_list = []\n",
    "for run in runs:\n",
    "    _, player_direction = calc_player_pos(run)\n",
    "    run_player_directions_list.append((run, player_direction))\n",
    "\n",
    "# Direction-feature search over all runs at once, with 64 as the cutoff argument.\n",
    "calc_runs_feature_direction_f1_scores(run_player_directions_list, 64)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Target\" search: pair each run with calc_locations over unsolved_target_G and solved_target_G,\n",
    "# then score features against those locations. The five returned values are discarded.\n",
    "run_locations_list = []\n",
    "for run in runs:\n",
    "    run_locations_list.append((run, calc_locations(run, unsolved_target_G, solved_target_G)))\n",
    "_, _, _, _, _ = calc_runs_feature_scores_by_location(run_locations_list, \"Target\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Unsolved\" search: pair each run with calc_locations over unsolved_target_G and unsolved_block_G,\n",
    "# then score features against those locations. The five returned values are discarded.\n",
    "run_locations_list = []\n",
    "for run in runs:\n",
    "    run_locations_list.append((run, calc_locations(run, unsolved_target_G, unsolved_block_G)))\n",
    "_, _, _, _, _ = calc_runs_feature_scores_by_location(run_locations_list, \"Unsolved\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# \"Solved\" search: pair each run with calc_locations over solved_target_G only,\n",
    "# then score features against those locations. The five returned values are discarded.\n",
    "run_locations_list = []\n",
    "for run in runs:\n",
    "    run_locations_list.append((run, calc_locations(run, solved_target_G)))\n",
    "_, _, _, _, _ = calc_runs_feature_scores_by_location(run_locations_list, \"Solved\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
