import os
import pathlib
import pickle
import subprocess
import time
from dataclasses import dataclass
from typing import Callable, Optional, Tuple, Union

import matplotlib
import matplotlib.cm as cm
import numpy as np
import plotly.express as px
import torch as th
from cleanba import cleanba_impala
from cleanba import convlstm as cleanba_convlstm
from cleanba.environments import BoxobanConfig, convert_to_cleanba_config
from gym_sokoban.envs.sokoban_env import CHANGE_COORDINATES
from gymnasium.spaces import MultiDiscrete
from matplotlib import animation
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from sklearn.multioutput import MultiOutputClassifier
from transformer_lens.utils import Slice

from learned_planners.activation_fns import IdentityActConfig, ReLUConfig
from learned_planners.convlstm import CompileConfig, ConvConfig, ConvLSTMCellConfig, ConvLSTMOptions
from learned_planners.policies import ConvLSTMPolicyConfig, NetArchConfig, download_policy_from_huggingface


def jax_to_th(x):
    """Convert an array-like (e.g. a JAX array) to a torch tensor by going through NumPy."""
    as_numpy = np.asarray(x)
    return th.tensor(as_numpy)


def conv_args_process(conv_args):
    """Turn a conv config object into keyword arguments.

    Reads the instance ``__dict__`` of ``conv_args``, drops the ``initialization`` entry and lower-cases the
    ``padding`` / ``padding_mode`` entries when they are strings. All other entries are passed through as-is.

    Returns:
        dict: the filtered attributes, in their original order.
    """
    lowercase_fields = ("padding", "padding_mode")
    dropped_fields = ("initialization",)

    def _normalise(name, value):
        # Only string values are lower-cased; e.g. integer paddings are left alone.
        if name in lowercase_fields and isinstance(value, str):
            return value.lower()
        return value

    return {
        name: _normalise(name, value)
        for name, value in conv_args.__dict__.items()
        if name not in dropped_fields
    }


def jax_to_torch_cfg(jax_cfg):
    """Build the torch ``ConvLSTMPolicyConfig`` that mirrors a cleanba ``ConvLSTMConfig``.

    Args:
        jax_cfg: a ``cleanba_convlstm.ConvLSTMConfig`` (asserted).

    Returns:
        ConvLSTMPolicyConfig: features-extractor options copied field by field from ``jax_cfg``, with
        ``mlp_hiddens`` used for both arguments of ``NetArchConfig``.
    """
    assert isinstance(jax_cfg, cleanba_convlstm.ConvLSTMConfig)
    hidden_sizes = jax_cfg.mlp_hiddens

    # The recurrent cell config is its conv config plus every other attribute forwarded unchanged as a keyword.
    cell_conv = ConvConfig(**conv_args_process(jax_cfg.recurrent.conv))
    cell_kwargs = dict(jax_cfg.recurrent.__dict__)
    cell_kwargs.pop("conv")
    cell_cfg = ConvLSTMCellConfig(cell_conv, **cell_kwargs)

    extractor_cfg = ConvLSTMOptions(
        compile=CompileConfig(),
        embed=[ConvConfig(**conv_args_process(embed_cfg)) for embed_cfg in jax_cfg.embed],
        recurrent=cell_cfg,
        n_recurrent=jax_cfg.n_recurrent,
        repeats_per_step=jax_cfg.repeats_per_step,
        pre_model_nonlin=ReLUConfig() if jax_cfg.use_relu else IdentityActConfig(),
        skip_final=jax_cfg.skip_final,
        residual=jax_cfg.residual,
    )
    return ConvLSTMPolicyConfig(
        features_extractor=extractor_cfg,
        net_arch=NetArchConfig(hidden_sizes, hidden_sizes),
    )


def jax_get(param_name, params):
    """Look up a nested entry of ``params`` using a dot-separated path such as ``"layer.kernel"``.

    Each path component is used as a subscript on the result of the previous one, so a missing
    component raises whatever the container raises (``KeyError`` for dicts).
    """
    node = params
    for component in param_name.split("."):
        node = node[component]
    return node


def copy_params_from_jax(torch_policy, jax_params, jax_args):
    """Copy the parameters of a JAX ConvLSTM policy into ``torch_policy`` in place.

    Args:
        torch_policy: torch policy whose embed convs, recurrent cells, MLP extractor and output heads are overwritten.
        jax_params: parameter tree holding ``network_params``, ``actor_params`` and ``critic_params``.
        jax_args: training args; only ``jax_args.net`` is read, to know how many layers of each kind to copy.

    Conv kernels are transposed with ``(3, 2, 0, 1)`` before being copied (presumably HWIO -> OIHW -- TODO confirm).
    The spatial size of the features entering the first dense layer is hard-coded to 10x10.
    """
    grid_h, grid_w = 10, 10
    net_params = jax_params["network_params"]

    n_cells = jax_args.net.n_recurrent
    n_embed = len(jax_args.net.embed)
    uses_pool_and_inject = jax_args.net.recurrent.pool_and_inject != "no"
    n_dense = len(jax_args.net.mlp_hiddens)
    n_hidden_channels = jax_args.net.recurrent.conv.features

    # Embedding convolutions sit at the even indices of `pre_model`.
    for layer_idx in range(n_embed):
        embed_conv = torch_policy.features_extractor.pre_model[2 * layer_idx]
        embed_kernel = np.asarray(jax_get(f"conv_list_{layer_idx}.kernel", net_params))
        embed_conv.weight.data.copy_(th.tensor(embed_kernel.transpose(3, 2, 0, 1)))
        embed_bias = np.asarray(jax_get(f"conv_list_{layer_idx}.bias", net_params))
        embed_conv.bias.data.copy_(th.tensor(embed_bias))

    # Recurrent cells: (torch attribute name, JAX sub-module name) for each convolution of a cell.
    conv_name_pairs = (("conv_ih", "ih"), ("conv_hh", "hh"), ("fence_conv", "fence"))
    for cell_idx in range(n_cells):
        cell = torch_policy.features_extractor.cell_list[cell_idx]

        for torch_name, jax_name in conv_name_pairs:
            target_conv = getattr(cell, torch_name)
            kernel = np.asarray(jax_get(f"cell_list_{cell_idx}.{jax_name}.kernel", net_params).transpose(3, 2, 0, 1))
            if jax_name == "fence":
                # Collapse axis 1 of the transposed fence kernel so that a single input channel remains.
                kernel = np.sum(kernel, axis=1, keepdims=True)
            target_conv.weight.data.copy_(th.tensor(kernel))

            # A conv without a bias entry in the JAX tree simply keeps its torch bias untouched.
            try:
                bias = np.asarray(jax_get(f"cell_list_{cell_idx}.{jax_name}.bias", net_params))
                target_conv.bias.data.copy_(th.tensor(bias))
            except KeyError:
                pass

        if uses_pool_and_inject:
            projection = np.asarray(jax_get(f"cell_list_{cell_idx}.project", net_params))
            cell.pool_project.data.copy_(th.tensor(projection))

    # Dense layers: the same JAX weights are written to both the policy and the value branch.
    for dense_idx in range(n_dense):
        dense_kernel = jax_get(f"dense_list_{dense_idx}.kernel", net_params).transpose()
        if dense_idx == 0:
            # Reinterpret the flat input axis as (h, w, channels) and reorder it to (channels, h, w).
            dense_kernel = th.tensor(
                np.asarray(dense_kernel.reshape(dense_kernel.shape[0], grid_h, grid_w, n_hidden_channels))
            )
            dense_kernel = dense_kernel.permute(0, 3, 1, 2).reshape(dense_kernel.shape[0], -1)
        else:
            dense_kernel = th.tensor(np.asarray(dense_kernel))
        dense_bias = np.asarray(jax_get(f"dense_list_{dense_idx}.bias", net_params))
        for branch_name in ("policy_net", "value_net"):
            fc_layer = getattr(getattr(torch_policy.mlp_extractor, branch_name), f"fc{dense_idx}")
            fc_layer.weight.data.copy_(dense_kernel)
            fc_layer.bias.data.copy_(th.tensor(dense_bias))

    # Output heads are looked up from the root of `jax_params`, not from `network_params`.
    head_name_pairs = (("action_net", "actor_params.Output"), ("value_net", "critic_params.Output"))
    for torch_name, jax_name in head_name_pairs:
        head = getattr(torch_policy, torch_name)
        head_kernel = np.asarray(jax_get(f"{jax_name}.kernel", jax_params).transpose())
        head_bias = np.asarray(jax_get(f"{jax_name}.bias", jax_params))
        head.weight.data.copy_(th.tensor(head_kernel))
        head.bias.data.copy_(th.tensor(head_bias))


def load_jax_model_to_torch(path, env_cfg):
    """Load a cleanba JAX checkpoint and return it as an equivalent torch policy.

    Args:
        path: checkpoint location handed to ``cleanba_impala.load_train_state``.
        env_cfg: environment config; it is converted with ``convert_to_cleanba_config`` before use.

    Returns:
        tuple: ``(cfg, policy)`` -- the torch policy config and the policy (in eval mode) with the JAX weights copied in.
    """
    cleanba_env_cfg = convert_to_cleanba_config(env_cfg)
    envs = cleanba_env_cfg.make()
    _, _, train_args, train_state, _ = cleanba_impala.load_train_state(path, cleanba_env_cfg)

    torch_cfg = jax_to_torch_cfg(train_args.net)
    policy_cls, policy_kwargs = torch_cfg.policy_and_kwargs(envs)  # type: ignore
    assert isinstance(policy_cls, Callable)

    # For a MultiDiscrete action space, the policy is built on its first component.
    single_action_space = envs.action_space
    if isinstance(single_action_space, MultiDiscrete):
        single_action_space = single_action_space[0]

    torch_policy = policy_cls(
        observation_space=envs.single_observation_space,
        action_space=single_action_space,
        activation_fn=th.nn.ReLU,
        lr_schedule=lambda _: 0.0,
        normalize_images=True,
        **policy_kwargs,
    )
    torch_policy.eval()
    copy_params_from_jax(torch_policy, train_state.params["params"], train_args)
    return torch_cfg, torch_policy


def load_policy(
    local_or_hgf_repo_path: str = "drc33/bkynosqi/cp_2002944000/",
    difficulty: str = "medium",
    split: str = "valid",
):
    """Fetch a policy checkpoint and load it as a torch policy on a single Boxoban environment.

    Args:
        local_or_hgf_repo_path: passed to ``download_policy_from_huggingface`` to obtain the checkpoint path.
        difficulty: Boxoban difficulty forwarded to ``BoxobanConfig``.
        split: Boxoban split forwarded to ``BoxobanConfig``.

    Returns:
        tuple: ``(cfg, policy_th)`` as returned by ``load_jax_model_to_torch``.
    """
    LP_DIR = pathlib.Path(__file__).parent.parent.parent

    # The level cache lives under /training when that directory exists, otherwise three levels above this file.
    cache_root = (
        pathlib.Path("/training/.sokoban_cache/") if os.path.exists("/training") else LP_DIR / "training/.sokoban_cache/"
    )

    checkpoint_path = download_policy_from_huggingface(local_or_hgf_repo_path)
    env_cfg = BoxobanConfig(
        cache_path=cache_root,
        num_envs=1,
        max_episode_steps=120,
        min_episode_steps=120,
        asynchronous=False,
        tinyworld_obs=True,
        difficulty=difficulty,  # type: ignore
        split=split,  # type: ignore
    )
    torch_cfg, torch_policy = load_jax_model_to_torch(checkpoint_path, env_cfg)
    return torch_cfg, torch_policy


def load_probe(path: str | pathlib.Path = "", wandb_id: str = ""):
    if path != "" and wandb_id != "":
        raise ValueError("Cannot specify both probe_path and probe_wandb_id")

    if wandb_id != "":
        command = f"/training/findprobe.sh {wandb_id}"
        path = subprocess.run(command, shell=True, capture_output=True, text=True).stdout
        path = path.strip()

    if not (path and pathlib.Path(path).exists()):
        raise FileNotFoundError(f"Probe file not found at {path}")

    with open(path, "rb") as f:
        probe = pickle.load(f)
    grid_wise = probe.n_features_in_ % 100 != 0
    return probe, grid_wise


def is_probe_multioutput(probe):
    """Return True when ``probe`` is an instance of scikit-learn's ``MultiOutputClassifier``."""
    wraps_multiple_outputs = isinstance(probe, MultiOutputClassifier)
    return wraps_multiple_outputs


def prepare_cache_values(
    cache: dict[str, th.Tensor],
    layer: int,
    hooks: list[str],
    step: int,
    internal_steps: bool = False,
    is_concatenated_cache: bool = False,
) -> list[list[th.Tensor]]:
    """Collect the cached activations of the requested layer(s) and hooks.

    Args:
        cache: activations keyed by hook name.
        layer: cell index to read, or -1 to read cells 0, 1 and 2.
        hooks: hook names to read from each selected cell.
        step: step index used in the cache key (ignored for a concatenated cache).
        internal_steps: read internal steps 0, 1 and 2 instead of only internal step 2.
        is_concatenated_cache: keys carry no step suffix; values are wrapped with ``th.tensor``.

    Returns:
        One inner list per internal step (a single one for a concatenated cache), each ordered by layer, then hook.
    """
    key = "features_extractor.cell_list.{layer}.{hook}.{step}.{internal_step}"
    layer_ids = range(3) if layer == -1 else [layer]

    if is_concatenated_cache:
        flat_key = key.replace(".{step}.{internal_step}", "")
        merged = []
        for layer_id in layer_ids:
            for hook in hooks:
                merged.append(th.tensor(cache[flat_key.format(layer=layer_id, hook=hook)]))
        return [merged]

    per_internal_step = []
    for int_step in [0, 1, 2] if internal_steps else [2]:
        values_at_step = []
        for layer_id in layer_ids:
            for hook in hooks:
                values_at_step.append(cache[key.format(layer=layer_id, step=step, internal_step=int_step, hook=hook)])
        per_internal_step.append(values_at_step)
    return per_internal_step


def predict(cache, probe, train_on, step: int, internal_steps: bool = False, is_concatenated_cache=False):
    """Run a probe on cached policy activations and reshape its predictions.

    Args:
        cache (dict): Activations of the policy.
        probe: Fitted probe exposing ``predict`` (e.g. a scikit-learn classifier).
        train_on (ProbeTrainOn): Probe configuration; its ``layer``, ``hooks`` and ``grid_wise`` fields are read.
        step (int): Step whose activations are read from the cache. Usually 0, since the policy is evaluated
            one environment step at a time.
        internal_steps (bool): Probe every internal step instead of only the last one. Assumes 3 internal steps.
        is_concatenated_cache (bool): Forwarded to ``prepare_cache_values``; the result is additionally squeezed.

    Returns:
        np.ndarray: Predictions shaped ``(steps, batch, h, w)`` for grid-wise probes (with one extra trailing axis
        for a ``MultiOutputClassifier``) and ``(steps, batch)`` otherwise.
    """
    acts_per_step = prepare_cache_values(
        cache, train_on.layer, train_on.hooks, step, internal_steps, is_concatenated_cache
    )

    assert all(len(act.shape) == 4 for acts in acts_per_step for act in acts)
    n_steps = len(acts_per_step)
    batch, _, height, width = acts_per_step[0][0].shape

    if train_on.grid_wise:
        # One probe input per grid square: channels of all hooks are concatenated and moved to the last axis.
        per_step = [th.cat(acts, dim=1) for acts in acts_per_step]
        features = th.stack(per_step, dim=0).permute(0, 1, 3, 4, 2)
        features = features.reshape(-1, features.shape[-1]).cpu()
        probe_preds = probe.predict(features)
        if isinstance(probe, MultiOutputClassifier):
            probe_preds = probe_preds.reshape(n_steps, batch, height, width, -1)
        else:
            probe_preds = probe_preds.reshape(n_steps, batch, height, width)
    else:
        # One probe input per batch element: every activation is flattened before concatenation.
        per_step = []
        for acts in acts_per_step:
            flattened = [act.reshape(act.shape[0], -1) for act in acts]
            per_step.append(th.cat(flattened, dim=1))
        features = th.stack(per_step, dim=0)
        features = features.reshape(-1, features.shape[-1]).cpu()
        probe_preds = probe.predict(features).reshape(n_steps, batch)

    return probe_preds.squeeze() if is_concatenated_cache else probe_preds


def process_cache_for_sae(cache_tensor, grid_wise: bool = False):
    """Flatten a 4-D or 5-D activation array/tensor into 2-D rows.

    With ``grid_wise`` the channel axis (axis 1 for 4-D input, axis 2 for 5-D input) is moved last and every
    other axis is merged, giving rows of length ``channels``. Otherwise the leading one (4-D) or two (5-D)
    axes are merged into the row index and the rest is flattened.

    Accepts NumPy arrays and torch tensors; returns None for any other number of dimensions.
    """
    rank = len(cache_tensor.shape)
    if rank not in (4, 5):
        return None

    if not grid_wise:
        # TODO: check if this is correct since channels should be flattened together
        n_rows = cache_tensor.shape[0] if rank == 4 else cache_tensor.shape[0] * cache_tensor.shape[1]
        return cache_tensor.reshape(n_rows, -1)

    channel_axis = rank - 3
    # Keep the leading axes, then the two spatial axes, then the channel axis.
    axis_order = tuple(range(channel_axis)) + (channel_axis + 1, channel_axis + 2, channel_axis)
    n_channels = cache_tensor.shape[channel_axis]
    if isinstance(cache_tensor, np.ndarray):
        return np.transpose(cache_tensor, axis_order).reshape(-1, n_channels)
    return cache_tensor.permute(*axis_order).reshape(-1, n_channels)


def encode_with_sae(
    sae,
    cache,
    internal_steps=False,
    decode=False,
    is_concatenated_cache=False,
) -> Union[th.Tensor, Tuple[th.Tensor, th.Tensor]]:
    """Encode the activations at ``sae.cfg.hook_name`` with a sparse autoencoder.

    Args:
        sae: autoencoder exposing ``cfg.hook_name``, ``cfg.grid_wise``, ``device``, ``encode`` and ``decode``.
        cache: activations keyed by hook name.
        internal_steps: read internal steps 0-2 of step 0 instead of only internal step 2 (per-step cache only).
        decode: also return the SAE reconstruction.
        is_concatenated_cache: the cache holds one 4-D entry under the bare hook name.

    Returns:
        The feature activations reshaped to ``(*leading_dims, 10, 10, -1)``; with ``decode`` a pair of
        ``(features, reconstruction)`` reshaped the same way.
    """
    if not is_concatenated_cache:
        int_steps = [0, 1, 2] if internal_steps else [2]
        original_act = th.stack([cache[sae.cfg.hook_name + f".0.{i}"] for i in int_steps])
        leading_dims = original_act.shape[:2]
        flat_act = process_cache_for_sae(original_act, grid_wise=sae.cfg.grid_wise)
    else:
        original_act = cache[sae.cfg.hook_name]
        assert len(original_act.shape) == 4
        leading_dims = original_act.shape[:1]
        flat_act = process_cache_for_sae(original_act, grid_wise=sae.cfg.grid_wise)
        if isinstance(flat_act, np.ndarray):
            flat_act = th.tensor(flat_act)

    features = sae.encode(flat_act.to(sae.device))  # type: ignore
    features_on_grid = features.reshape(*leading_dims, 10, 10, -1)
    if not decode:
        return features_on_grid

    # The reconstruction is moved back to the device of the activations it was computed from.
    reconstruction = sae.decode(features).to(original_act.device)
    return features_on_grid, reconstruction.reshape(*leading_dims, 10, 10, -1)


@dataclass
class PlayLevelOutput:
    """Results of one `play_level` rollout; per-step values are stacked along a leading step axis."""

    # Observations fed to the policy at each step (the observation after the last step is dropped).
    obs: th.Tensor
    # Deterministic actions taken from the policy's distribution at each step.
    acts: th.Tensor
    # Detached logits of the action distribution at each step.
    logits: th.Tensor
    # Rewards returned by `env.step`; none are recorded for the initial thinking steps.
    rewards: th.Tensor
    # Per-environment step counter (int32), incremented on steps where the third `env.step` output is False.
    lengths: th.Tensor
    # Per-environment flag, OR-ed over steps with the third `env.step` output (presumably `terminated`).
    solved: th.Tensor
    # Cached activations: hook name -> values stacked over steps.
    cache: dict[str, th.Tensor]
    # One stacked array of predictions per probe.
    probe_outs: Optional[list[np.ndarray]] = None
    # Stacked SAE feature activations; None when no SAE was given.
    sae_outs: Optional[th.Tensor] = None


def play_level(
    env,
    policy_th,
    reset_opts=None,
    probes=None,
    probe_train_ons=None,
    sae=None,
    thinking_steps=0,
    max_steps=120,
    internal_steps=False,
    fwd_hooks=None,
    hook_steps=-1,
    names_filter=None,
):
    """Execute the policy on the environment and the probes/SAE on the policy's activations.

    Args:
        env (gymnasium.Env): Environment to play the level in.
        policy_th (torch.nn.Module): Policy to play the level with.
        reset_opts (dict): Options to reset the environment with. Useful for custom-built levels
            or providing the `level_file_idx` and `level_idx` of a level in Boxoban. Defaults to ``{}``.
        probes (list[sklearn.linear_model.LogisticRegression]): Probes to run on the activations of the policy.
        probe_train_ons (list[ProbeTrainOn]): Corresponding configuration of each probe (same length as `probes`).
        sae: Optional sparse autoencoder run on the cached activations at every step.
        thinking_steps (int): Number of initial iterations in which the policy is run without stepping the
            environment.
        max_steps (int): Maximum number of policy iterations (thinking steps included).
        internal_steps (bool): Forwarded to `predict` and `encode_with_sae` to keep all internal steps.
        fwd_hooks: Forward hooks passed to `run_fn_with_cache`.
        hook_steps: -1 to apply `fwd_hooks` on every iteration, otherwise a container of iteration indices.
        names_filter: Forwarded to `run_fn_with_cache` to select which activations are cached.

    Returns:
        PlayLevelOutput: observations, actions, logits, rewards, episode lengths, solved flags, the stacked
        activation cache, and the probe / SAE outputs.
    """
    # Fresh containers per call instead of shared mutable defaults.
    reset_opts = {} if reset_opts is None else reset_opts
    probes = [] if probes is None else probes
    probe_train_ons = [] if probe_train_ons is None else probe_train_ons

    assert len(probe_train_ons) == len(probes)
    try:
        start_obs = env.reset(options=reset_opts)[0]
    except Exception:
        # Best-effort fallback, presumably for environments whose reset() does not accept `options`.
        # Unlike a bare `except`, this lets KeyboardInterrupt / SystemExit propagate.
        start_obs = env.reset()[0]
    device = policy_th.device
    start_obs = th.tensor(start_obs, device=device)
    all_obs = [start_obs]
    all_acts = []
    all_logits = []
    all_rewards = []
    all_cache = []
    all_sae_outs = []

    all_probe_outs = [[] for _ in probes]
    N = start_obs.shape[0]
    eps_done = th.zeros(N, dtype=th.bool)
    eps_solved = th.zeros(N, dtype=th.bool)
    episode_lengths = th.zeros(N, dtype=th.int32)

    state = policy_th.recurrent_initial_state(N, device=device)
    obs = start_obs
    for i in range(max_steps):
        (distribution, state), cache = run_fn_with_cache(
            policy_th,
            "get_distribution",
            obs,
            state,
            # All-False mask; presumably the episode-start flags expected by the policy -- TODO confirm.
            th.zeros(N, dtype=th.bool, device=device),
            # return_repeats=False,
            fwd_hooks=fwd_hooks if (hook_steps == -1) or (i in hook_steps) else None,
            names_filter=names_filter,
        )
        best_act = distribution.get_actions(deterministic=True)
        all_acts.append(best_act)
        all_logits.append(distribution.distribution.logits.detach())
        if i >= thinking_steps:
            # Only act in the environment once the thinking steps are over.
            obs, r, d, t, _ = env.step(best_act.cpu().numpy())
            obs = th.tensor(obs, device=device)
            episode_lengths[~d] += 1
            eps_done |= d | t
            eps_solved |= d
            all_rewards.append(r)
        all_obs.append(obs)
        all_cache.append(cache)

        for pidx, (probe, probe_train_on) in enumerate(zip(probes, probe_train_ons)):
            probe_out = predict(cache, probe, probe_train_on, step=0, internal_steps=internal_steps)
            if N == 1:
                probe_out = probe_out.squeeze(1)
            if not internal_steps:
                probe_out = probe_out.squeeze(0)
            all_probe_outs[pidx].append(probe_out)
        if sae:
            sae_acts = encode_with_sae(sae, cache, internal_steps=internal_steps)
            all_sae_outs.append(sae_acts.squeeze(1) if internal_steps else sae_acts.squeeze(0).squeeze(0))  # type: ignore
        if eps_done.all().item():
            break
    return PlayLevelOutput(
        # The last appended observation was never fed to the policy, so it is dropped.
        obs=th.stack(all_obs[:-1]),
        acts=th.stack(all_acts),
        logits=th.stack(all_logits),
        rewards=th.tensor(np.array(all_rewards)),
        lengths=episode_lengths,
        solved=eps_solved,
        cache={k: th.stack([cache[k] for cache in all_cache]) for k in all_cache[0].keys()},
        probe_outs=[np.stack(probe_out) for probe_out in all_probe_outs],
        sae_outs=th.stack(all_sae_outs) if sae else None,
    )


def run_fn_with_cache(
    hooked_model,
    fn_name: str,
    *model_args,
    names_filter=None,
    device=None,
    remove_batch_dim=False,
    incl_bwd=False,
    reset_hooks_end=True,
    clear_contexts=False,
    pos_slice=None,
    **model_kwargs,
):
    """Call a method of a hooked model while caching activations.

    Combines the run_with_cache functions from MambaLens and TransformerLens so that an arbitrary method
    (``fn_name``, or the model's ``__call__`` when ``fn_name`` is empty) can be run with a cache.
    ``fwd_hooks`` / ``bwd_hooks`` may be passed among the keyword arguments; they are removed before the
    remaining keyword arguments are forwarded to the model.

    Returns:
        tuple: ``(model_out, cache_dict)``.
    """
    model_kwargs = dict(model_kwargs)
    fwd_hooks = model_kwargs.pop("fwd_hooks", None)
    bwd_hooks = model_kwargs.pop("bwd_hooks", None)

    # need to wrap run_with_cache to setup input_dependent hooks:
    # without a names_filter every input-dependent hook is set up; otherwise the filtered names are
    # turned into placeholder (name, None) forward hooks so that only those get set up.
    setup_all_input_hooks = names_filter is None
    if not setup_all_input_hooks:
        name_fake_hooks = [(name, None) for name in names_filter]
        fwd_hooks = name_fake_hooks if fwd_hooks is None else fwd_hooks + name_fake_hooks

    with hooked_model.input_dependent_hooks_context(
        *model_args, fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks, setup_all_input_hooks=setup_all_input_hooks, **model_kwargs
    ):
        # The placeholder hooks have served their purpose; only real hooks are registered.
        fwd_hooks = [(name, hook) for name, hook in (fwd_hooks or []) if hook is not None]
        bwd_hooks = bwd_hooks or []
        with hooked_model.hooks(fwd_hooks, bwd_hooks, reset_hooks_end, clear_contexts) as hooked_hooked_model:
            pos_slice = Slice.unwrap(pos_slice)

            cache_dict, caching_fwd, caching_bwd = hooked_hooked_model.get_caching_hooks(
                names_filter,
                incl_bwd,
                device,
                remove_batch_dim=remove_batch_dim,
                pos_slice=pos_slice,
            )

            with hooked_hooked_model.hooks(
                fwd_hooks=caching_fwd,
                bwd_hooks=caching_bwd,
                reset_hooks_end=reset_hooks_end,
                clear_contexts=clear_contexts,
            ):
                target = getattr(hooked_hooked_model, fn_name) if fn_name else hooked_hooked_model
                model_out = target(*model_args, **model_kwargs)
                if incl_bwd:
                    model_out.backward()

    return model_out, cache_dict


def get_metrics(preds: np.ndarray, labels: np.ndarray, classification: bool, key_prefix: str = ""):
    """Compute evaluation metrics for probe predictions.

    For classification the smallest value in ``labels`` is treated as the negative class, ``preds`` is
    truncated to ``len(labels)``, and accuracy, precision, recall and F1 are returned. Otherwise the
    mean-squared error is returned under ``"loss"``.

    Returns:
        dict: metric name -> value, with keys prefixed by ``"{key_prefix}/"`` when a prefix is given.
        An empty dict is returned for classification when ``labels`` is empty.
    """
    if not classification:
        mse = th.nn.functional.mse_loss(th.tensor(preds), th.tensor(labels))
        metrics = {"loss": mse.item()}
    else:
        try:
            negative_label = labels.min()
        except ValueError:
            return {}
        preds = preds[: len(labels)]
        assert len(preds.shape) == len(labels.shape), f"{preds.shape} != {labels.shape}"

        predicted_positive = preds != negative_label
        actual_positive = labels != negative_label
        acc = (preds == labels).mean()
        # Precision / recall: agreement restricted to predicted / actual non-negative entries.
        prec = (preds[predicted_positive] == labels[predicted_positive]).mean()
        rec = (preds[actual_positive] == labels[actual_positive]).mean()
        f1 = 2 * prec * rec / (prec + rec)
        metrics = {"accuracy": acc, "precision": prec, "recall": rec, "f1": f1}

    if not key_prefix:
        return metrics
    return {f"{key_prefix}/{k}": v for k, v in metrics.items()}


def get_player_pos(obs):
    """Locate the player in an observation from the pixel values of the player.

    Args:
        obs (np.ndarray): Observation of the level of shape (10, 10, 3), channels last.

    Returns:
        Tuple[int, int]: (row, column) of the single square whose channels are (160 or 219, 212, 56).
    """
    assert isinstance(obs, np.ndarray) and obs.shape == (10, 10, 3)
    ch0, ch1, ch2 = obs[..., 0], obs[..., 1], obs[..., 2]
    is_player = ((ch0 == 160) | (ch0 == 219)) & (ch1 == 212) & (ch2 == 56)
    agent_pos = np.where(is_player)
    # Exactly one square must match.
    assert len(agent_pos[0]) == 1
    return agent_pos[0][0], agent_pos[1][0]


def plt_obs_with_cycle_probe(
    obs,
    probe_pred_prev_timestep,
    probe_pred_curr_timestep,
    gt_curr_timestep,
    last_player_pos,
    show_dot: bool,  # this will be true on internal steps and the first external step
    ax,
):
    """Helper function to plot the level image with the cycle probe predictions.

    Nothing is drawn when the current prediction is falsy. Otherwise the player position is marked with a dot,
    or joined to the previous position with a line, coloured blue without ground truth and green/red for a
    correct/incorrect prediction.

    Returns:
        Tuple[str, tuple]: title suffix and the (x, y) player position to carry over to the next call.
    """
    if not probe_pred_curr_timestep:
        return "", last_player_pos

    row, col = get_player_pos(obs)
    player_pos = (col, row)  # plotting uses (x, y)

    if gt_curr_timestep is None:
        color = "blue"
    elif probe_pred_curr_timestep == gt_curr_timestep:
        color = "green"
    else:
        color = "red"

    draw_dot = show_dot or (not probe_pred_prev_timestep) or last_player_pos == player_pos
    if draw_dot:
        ax.plot(*player_pos, color, marker="o")
    else:
        xs = [last_player_pos[0], player_pos[0]]  # type: ignore
        ys = [last_player_pos[1], player_pos[1]]  # type: ignore
        ax.plot(xs, ys, color=color, linewidth=2)
    return " | In Cycle", player_pos


def plt_obs_with_position_probe(probe_preds, gt_labels, ax, marker="s", s=200, heatmap_color_range=None):
    """Helper function to plot the level image with the position probe predictions.

    With ``heatmap_color_range`` the predictions are shown as a viridis heatmap: drawn with ``imshow`` when ``ax``
    is a matplotlib Axes (the image is returned), otherwise pushed into ``ax`` with ``set_data``.
    Without it, squares predicted as 1 are scattered: blue without ground truth, else green where the
    ground truth is also 1 and red where it is not.
    """
    if heatmap_color_range is not None:
        if not isinstance(ax, matplotlib.axes._axes.Axes):  # type: ignore
            ax.set_data(probe_preds)
            return None
        return ax.imshow(probe_preds, cmap="viridis", vmin=heatmap_color_range[0], vmax=heatmap_color_range[1])

    hit_idx = np.nonzero(probe_preds == 1)
    hit_rows, hit_cols = hit_idx[0], hit_idx[1]
    if gt_labels is None:
        ax.scatter(hit_cols, hit_rows, color="blue", marker=marker, s=s)
        return None

    confirmed = gt_labels[hit_idx] == 1
    for mask, color in ((confirmed, "green"), (~confirmed, "red")):
        ax.scatter(hit_cols[mask], hit_rows[mask], color=color, marker=marker, s=s)


def plt_obs_with_direction_probe(probe_preds, gt_labels, ax, color_scheme=None, vector=False):
    """Helper function to plot the level image with the direction probe predictions.

    Args:
        probe_preds (np.ndarray): Either a 2D grid of direction indices (-1 meaning "no direction"), or a
            3D ``(h, w, 4)`` multi-output grid with one channel per direction.
        gt_labels (Optional[np.ndarray]): Ground truth with the same layout as ``probe_preds``, or None.
        ax: Matplotlib axes to draw on.
        color_scheme (Optional[Sequence[str]]): Colors for [incorrect, correct, no ground truth].
            Defaults to ``["red", "green", "blue"]``.
        vector (bool): Draw in a y-up coordinate system with arrows centred on the squares (offset by 0.5)
            instead of image coordinates.

    Raises:
        ValueError: if ``probe_preds`` is neither 2D nor 3D.
    """
    if color_scheme is None:
        color_scheme = ["red", "green", "blue"]

    if probe_preds.ndim == 2:
        directions_i, directions_j = np.where(probe_preds != -1)
        for i, j in zip(directions_i, directions_j):
            pred_direction_idx = probe_preds[i, j]
            delta_i, delta_j = CHANGE_COORDINATES[pred_direction_idx]
            # 0 -> incorrect, 1 -> correct, 2 -> no ground truth available
            color_idx = 2 if gt_labels is None else (gt_labels[i, j] == pred_direction_idx).astype(int)
            color = color_scheme[color_idx]
            if vector:
                i += 0.5
                j += 0.5
                ax.arrow(j, 10 - i, delta_j, -delta_i, color=color, head_width=0.2, head_length=0.2)
            else:
                ax.arrow(j, i, delta_j, delta_i, color=color, head_width=0.2, head_length=0.2)
    elif probe_preds.ndim == 3:  # multioutput
        assert probe_preds.shape[2] == 4
        grid = np.arange(10, dtype=float)
        if vector:
            grid += 0.5
            probe_preds = probe_preds[::-1]
            # Ground truth is optional: only flip it when present (previously this raised TypeError on None).
            if gt_labels is not None:
                gt_labels = gt_labels[::-1]
        for dir_idx in range(4):
            probe_preds_dir = probe_preds[..., dir_idx]
            gt_labels_dir = gt_labels[..., dir_idx] if gt_labels is not None else None
            delta_i, delta_j = CHANGE_COORDINATES[dir_idx]
            if gt_labels_dir is None:
                color_args, cmap = (), None
            else:
                cmap = ListedColormap(color_scheme[:2])
                color_args = [(gt_labels_dir == probe_preds_dir).astype(int)]
                color_args[0][0, 0] = 0  # to avoid color collapse when preds are correct

            ax.quiver(
                grid,
                grid,
                delta_j * probe_preds_dir,
                -delta_i * probe_preds_dir,
                *color_args,
                cmap=cmap,  # only used when color_args is not empty
                color=color_scheme[2],  # only used when color_args is empty
                scale_units="xy",
                scale=1,
                minshaft=1,
                minlength=0,
            )
    else:
        raise ValueError("probe_preds should be 2D or 3D.")


def plt_obs_with_box_labels(the_labels, ax):
    """Plot the box label as B0 to B3 at the top left of the square.

    ``the_labels`` is a 2D grid in which -1 marks squares without a box; exactly four labelled
    squares are expected (asserted).
    """
    rows, cols = np.nonzero(the_labels != -1)
    labelled_squares = set(zip(rows, cols))
    assert len(labelled_squares) == 4, f"Expected 4 unique box label locations, but found {len(labelled_squares)}"

    for row, col in labelled_squares:
        ax.text(col, row, f"B{the_labels[row, col]}", fontsize=10, color="black", ha="left", va="top")


def plt_obs_with_target_labels(the_labels, ax):
    """Plot the target label as T0 to T3 at the bottom left of the square.

    ``the_labels`` is a 2D grid in which -1 marks squares without a target; exactly four labelled
    squares are expected (asserted).
    """
    rows, cols = np.nonzero(the_labels != -1)
    labelled_squares = set(zip(rows, cols))
    assert len(labelled_squares) == 4, f"Expected 4 unique target label locations, but found {len(labelled_squares)}"

    for row, col in labelled_squares:
        ax.text(col, row, f"T{the_labels[row, col]}", fontsize=10, color="black", ha="left", va="bottom")


# Module-level state used by `save_video`: it resets this to None on every call and its per-frame
# callback declares it `global`, so the value persists between frames.
last_player_pos = None


def save_video(
    filename,
    all_obs,
    all_probes_preds=[],
    all_gt_labels=[],
    all_probe_infos=[],
    overlapped=False,
    show_internal_steps_until=0,
    sae_feature_offset=0,
    base_dir="videos",
    box_labels=None,
    target_labels=None,
    remove_ticks=True,
    truncate_title=-1,
):
    """Save the video of the level given by all_obs. Video will be saved in the folder videos_{probe_type}.

    Args:
        filename (str): Name of the video file (with extension).
        all_obs (np.ndarray): observations of the level of shape (num_steps, 3, 10, 10).
        all_probes_preds (Optional[list[np.ndarray]]): list of predictions from multiple probes.
            The np arrays can be of the shape (num_steps,) or (num_steps, 10, 10) depending on the `probe_type`.
            Default is None.
        all_gt_labels (list[np.ndarray]): list of ground truth labels for the probes.
        all_probe_infos (list[ProbeTrainOn]): list of ProbeTrainOn.
        overlapped (bool): Whether to plot the probes on the same image or side-by-side.
        show_internal_steps_until (int): Number of internal steps to show. Default is 0.
        box_labels (np.ndarray): labels of the boxes in the level of shape (num_steps, 10, 10).
        target_labels (np.ndarray): labels of the targets of shape (10, 10).
    """
    if all_probe_infos:
        assert len(all_probes_preds) == len(all_probe_infos)
    max_len = len(all_obs)
    if all_gt_labels:
        assert len(all_gt_labels) == len(all_probes_preds)
    repeats_per_step = all_probes_preds[0].shape[1] if show_internal_steps_until else 1
    # if show_internal_steps_until:
    #     repeats_per_step = all_probes_preds[0].shape[1]
    #     assert all(preds.shape[1] == repeats_per_step for preds in all_probes_preds)
    #     assert all(preds.shape in [(max_len, repeats_per_step, 10, 10), (max_len,)] for preds in all_probes_preds)
    # else:
    #     assert all(preds.shape in [(max_len, 10, 10), (max_len,)] for preds in all_probes_preds)
    global last_player_pos
    last_player_pos = None
    os.makedirs(base_dir, exist_ok=True)
    title_prefix = ""

    if all_probes_preds is not None:
        try:
            cycle_probe_idx = [info.probe_type for info in all_probe_infos].index("cycle")
            no_cycle = not np.any(all_probes_preds[cycle_probe_idx])
            if no_cycle:
                title_prefix = " | No Cycle"
                filename = filename.replace(".mp4", "_no_cycle.mp4")
        except ValueError:
            pass
    total_subplots = len(all_probes_preds)
    if overlapped or len(all_probes_preds) <= 1:
        fig, axs = plt.subplots(1, 1, figsize=(4, 4))
        axs = [axs]
    else:
        total_subplots += 0 if all_probe_infos else 1
        rows, cols = np.ceil(total_subplots / 4).astype(int), min(4, total_subplots)
        fig, axs = plt.subplots(rows, cols, figsize=(2 * cols + 1, 2 * rows + 1))
        plt.subplots_adjust(left=0.05, top=0.9, right=0.95, bottom=0.05, hspace=0.5, wspace=0.5)  # manually fine-tuned
        axs = axs.flatten()
    if remove_ticks:
        [ax.axis("off") for ax in axs]

    max_fig_dim = max(fig.get_figwidth(), fig.get_figheight())
    heatmap_color_range = None
    if not all_probe_infos and len(all_probes_preds) != 0:  # sae
        heatmap_color_range = (all_probes_preds.min(), all_probes_preds.max())
        norm = plt.Normalize(vmin=heatmap_color_range[0], vmax=heatmap_color_range[1])
        fig.colorbar(cm.ScalarMappable(cmap="viridis", norm=norm), ax=axs)

    all_obs = np.transpose(all_obs, (0, 2, 3, 1))
    total_internal_steps = repeats_per_step * show_internal_steps_until
    total_frames = total_internal_steps + max_len - show_internal_steps_until + 1

    def update_frame(i, title_prefix=title_prefix):
        global last_player_pos
        if i == total_frames - 1:
            if all_gt_labels:
                all_metrics = {}
                for pidx, probe_preds in enumerate(all_probes_preds):
                    probe_preds_external = probe_preds[:, repeats_per_step - 1] if show_internal_steps_until else probe_preds
                    probe_metrics = get_metrics(probe_preds_external, all_gt_labels[pidx], classification=True)  # type: ignore
                    prefix = all_probe_infos[pidx].dataset_name + "/" if total_subplots > 1 else ""
                    probe_metrics = {f"{prefix}{k}": v for k, v in probe_metrics.items()}
                    all_metrics.update(probe_metrics)
                plt.axes().clear()
                plt.text(0.1, 0.1, "\n".join([f"{k}: {v:.2f}" for k, v in all_metrics.items()]))
            else:
                print("No GT labels provided.")
                plt.text(0.1, 0.1, "No GT labels provided.")
            return
        if i < total_internal_steps:
            obs_idx = i // repeats_per_step
            probe_idx = (obs_idx, i % repeats_per_step)
        else:
            obs_idx = show_internal_steps_until + i - total_internal_steps
            # probe_idx = repeats_per_step * (obs_idx + 1) - 1 if show_internal_steps_until else obs_idx
            probe_idx = (obs_idx, repeats_per_step - 1) if show_internal_steps_until else obs_idx
        obs = all_obs[obs_idx]
        if len(all_probes_preds) == 0:
            axs[0].clear()
            axs[0].imshow(obs)
        for pidx, probe_preds in enumerate(all_probes_preds):
            ax = axs[pidx]
            if not all_probe_infos and len(all_probes_preds) != 0:  # sae
                if pidx == 0:
                    ax.clear()
                    ax.imshow(obs)
                    ax.set_title("Observation")
                ax = axs[pidx + 1]
                ax.clear()
            elif (not overlapped) or (pidx == 0):
                ax.clear()
                ax.imshow(obs)

            if (not overlapped) and len(all_probes_preds) > 1:
                if all_probe_infos:
                    title = all_probe_infos[pidx].dataset_name
                    if truncate_title > 0:
                        title = title[:truncate_title] + ("..." if len(title) > truncate_title else "")
                    ax.set_title(title)
                elif len(all_probes_preds) != 0:  # sae
                    ax.set_title(f"Feature {sae_feature_offset + pidx}")
            probe_out = probe_preds[probe_idx]
            try:
                gt_label = all_gt_labels[pidx][obs_idx]
            except IndexError:
                gt_label = None

            if not all_probe_infos:
                plt_obs_with_position_probe(probe_out, gt_label, ax, heatmap_color_range=heatmap_color_range)  # sae
            elif "cycle" == all_probe_infos[pidx].probe_type:
                title_prefix, last_player_pos = plt_obs_with_cycle_probe(
                    obs,
                    probe_preds[(probe_idx[0] - 1, probe_idx[1]) if show_internal_steps_until else probe_idx - 1],  # type: ignore
                    probe_preds[probe_idx],
                    gt_label,
                    last_player_pos,
                    show_dot=(obs_idx == 0) or (show_internal_steps_until > 0 and probe_idx[1] < repeats_per_step - 1),
                    ax=ax,
                )
            elif "position" == all_probe_infos[pidx].probe_type:
                plt_obs_with_position_probe(probe_out, gt_label, ax)
            elif "direction" in all_probe_infos[pidx].probe_type:
                plt_obs_with_direction_probe(probe_out, gt_label, ax, all_probe_infos[pidx].color_scheme)
            else:
                raise ValueError(f"Unknown probe type: {all_probe_infos[pidx].probe_type}")

            # Draw box and target labels. Aids colorblind people
            if box_labels is not None:
                plt_obs_with_box_labels(box_labels[i], ax)
            if target_labels is not None:
                plt_obs_with_target_labels(target_labels, ax)

        internal_step_suffix = ": Internal Step " + str(i % repeats_per_step) if i < total_internal_steps else ""
        plt.suptitle(f"Step {obs_idx}{internal_step_suffix}" + title_prefix, y=0.99)
        return (fig,)

    anim = animation.FuncAnimation(
        fig,
        update_frame,  # type: ignore
        save_count=total_frames,
        repeat=False,
    )
    dpi = np.ceil(720 / max_fig_dim).astype(int)
    dpi = dpi if dpi % 2 == 0 else dpi + 1
    assert anim is not None
    full_path = os.path.join(base_dir, filename)
    os.makedirs(os.path.dirname(full_path), exist_ok=True)
    t0 = time.time()
    anim.save(full_path, fps=2, writer="ffmpeg")
    print(f"Saved video to {full_path} in {time.time() - t0:.2f} seconds.")
    return full_path


def save_video_sae(
    filename,
    all_obs,
    all_probes_preds=None,
    show_internal_steps_until=0,
    sae_feature_indices: int | list[int] = 0,
    base_dir="videos",
    box_labels=None,
    target_labels=None,
):
    """Render observations next to SAE feature heatmaps and save them as a video.

    The video is written to ``os.path.join(base_dir, filename)``. The first panel shows the
    observation; each following panel shows one feature's activations, all sharing a single
    color range (global min/max of ``all_probes_preds``) and one colorbar.

    Args:
        filename (str): Name of the video file (with extension), relative to ``base_dir``.
        all_obs (np.ndarray): Observations, channel-first: (num_steps, channels, height, width).
        all_probes_preds (Optional[np.ndarray | list[np.ndarray]]): Per-feature activations,
            indexed along the first axis by feature. When ``show_internal_steps_until`` is
            non-zero each feature is indexed as ``[step, internal_step]``, otherwise as
            ``[step]``. ``None`` or an empty sequence produces an observation-only video.
        show_internal_steps_until (int): Number of leading environment steps for which every
            internal step gets its own frame. Default is 0 (one frame per environment step).
        sae_feature_indices (int | list[int]): Either the index of the first feature (panels are
            then titled with consecutive indices) or an explicit index per feature.
        base_dir (str): Directory the video is saved under. Default is "videos".
        box_labels: Accepted for signature compatibility; not used by this function.
        target_labels: Accepted for signature compatibility; not used by this function.

    Returns:
        str: Path of the saved video.
    """
    # NOTE: both of these mutate process-wide state (matplotlib rcParams and a module global).
    plt.rcParams.update({"font.size": 18})
    global last_player_pos
    last_player_pos = None
    if base_dir:  # os.makedirs("") raises, so skip an empty base directory
        os.makedirs(base_dir, exist_ok=True)

    # Accept None, a list of equally-shaped arrays, or a stacked array.
    probes = np.asarray([] if all_probes_preds is None else all_probes_preds)
    num_probes = len(probes) if probes.size else 0
    max_len = len(all_obs)
    repeats_per_step = probes[0].shape[1] if (show_internal_steps_until and num_probes) else 1

    total_subplots = num_probes + 1
    cols = min(4, total_subplots)
    rows = int(np.ceil(total_subplots / 4))
    fig, axs = plt.subplots(rows, cols, figsize=(2 * cols + 1, 2 * rows + 1))
    fig.subplots_adjust(left=0.05, top=0.9, right=1.05, bottom=0.05, hspace=0.5, wspace=0.5)  # manually fine-tuned
    axs = np.atleast_1d(axs).ravel()

    # The grid may have more cells than panels (e.g. 6 panels in a 2x4 grid): only the first
    # `total_subplots` axes are drawn on, the remainder are blanked out.
    used_axes = axs[:total_subplots]
    for spare_ax in axs[total_subplots:]:
        spare_ax.axis("off")

    heatmap_color_range = None
    if num_probes:
        heatmap_color_range = (probes.min(), probes.max())
        norm = plt.Normalize(vmin=heatmap_color_range[0], vmax=heatmap_color_range[1])
        fig.colorbar(cm.ScalarMappable(cmap="viridis", norm=norm), ax=axs)

    all_obs = np.transpose(all_obs, (0, 2, 3, 1))  # channel-first -> channel-last for imshow

    # Index of the very first frame, matching what `update_frame(0)` selects.
    first_idx = (0, 0) if show_internal_steps_until else 0
    imshow_outs = [used_axes[0].imshow(all_obs[0])]
    # The value returned by `plt_obs_with_position_probe` is kept and handed back to it on every
    # frame in place of the axes.
    imshow_outs += [
        plt_obs_with_position_probe(probes[pidx][first_idx], None, ax, heatmap_color_range=heatmap_color_range)
        for pidx, ax in enumerate(used_axes[1:])
    ]

    total_internal_steps = repeats_per_step * show_internal_steps_until
    total_frames = total_internal_steps + max_len - show_internal_steps_until

    def feature_id(pidx):
        # An integer is the index of the first feature; a sequence gives one index per feature.
        if isinstance(sae_feature_indices, (int, np.integer)):
            return sae_feature_indices + pidx
        return sae_feature_indices[pidx]

    used_axes[0].set_title("Observation")
    for pidx, ax in enumerate(used_axes[1:]):
        ax.set_title(f"F{feature_id(pidx)}")
    imshow_outs.append(fig.suptitle("", fontsize=18, y=0.99))

    def update_frame(i):
        if i < total_internal_steps:
            # Frames that step through the internal steps of the first few environment steps.
            obs_idx = i // repeats_per_step
            probe_idx = (obs_idx, i % repeats_per_step)
            internal_step_suffix = ": Internal Step " + str(i % repeats_per_step)
        else:
            # One frame per environment step, showing the last internal step when available.
            obs_idx = show_internal_steps_until + i - total_internal_steps
            probe_idx = (obs_idx, repeats_per_step - 1) if show_internal_steps_until else obs_idx
            internal_step_suffix = ""
        imshow_outs[0].set_data(all_obs[obs_idx])
        for pidx, probe_preds in enumerate(probes):
            plt_obs_with_position_probe(
                probe_preds[probe_idx], None, imshow_outs[pidx + 1], heatmap_color_range=heatmap_color_range
            )  # sae
        imshow_outs[-1].set_text(f"Step {obs_idx}{internal_step_suffix}")
        return imshow_outs

    anim = animation.FuncAnimation(
        fig,
        update_frame,  # type: ignore
        save_count=total_frames,
        repeat=False,
    )

    full_path = os.path.join(base_dir, filename)
    out_dir = os.path.dirname(full_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    t0 = time.time()
    try:
        # NOTE(review): no explicit `dpi` is passed, so the resolution is matplotlib's default.
        anim.save(full_path, fps=2, writer="ffmpeg")
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"Saved video to {full_path} in {time.time() - t0:.2f} seconds.")
    return full_path


def plotly_feature_vis(feature_acts, obs, feature_labels):
    """Feature activations visualized with observations along with time slider.

    Activations are min-max normalized over the whole array, colored with the "viridis"
    colormap and shown as one facet per feature next to the observation. The figure is
    displayed with ``fig.show()``; nothing is returned.

    Args:
        feature_acts (np.ndarray): Activations of top features. Shape: (time, num_features, height, width).
        obs (np.ndarray): Observations. Shape: (time, channels, height, width). Truncated to
            ``len(feature_acts)`` time steps if longer.
        feature_labels (Sequence[str]): Labels for the features, one per feature.
    """
    cmap = plt.get_cmap("viridis")
    lo, hi = feature_acts.min(), feature_acts.max()
    span = hi - lo
    if span == 0:
        # Constant activations: avoid 0/0 (NaN image) and map everything to the colormap's low end.
        # Must be float: colormaps interpret integer inputs as lookup-table indices.
        normed = np.zeros_like(feature_acts, dtype=float)
    else:
        normed = (feature_acts - lo) / span

    # (time, C, H, W) -> (time, 1, H, W, C) so the observation becomes facet 0 along axis 1.
    repeated_obs = np.transpose(obs, (0, 2, 3, 1))[:, None, :, :, :]
    # Drop the colormap's alpha channel and scale RGB to 0-255.
    to_plot = np.concatenate([repeated_obs[: len(normed)], cmap(normed)[..., :3] * 255], axis=1)
    labels = ["Observation", *feature_labels]  # works for lists, tuples and arrays alike
    fig = px.imshow(
        to_plot,
        facet_col=1,
        animation_frame=0,
        facet_col_wrap=8,
        binary_string=True,
    )
    # Facet annotations read "<name>=<index>"; replace them with the human-readable label.
    fig.for_each_annotation(lambda a: a.update(text=labels[int(a.text.split("=")[-1])]))
    fig.show()
