"""Runs the network by replacing the activations of a channel with another channel's activations.

This is used to check if the activations of a channel can be replaced with another channel's activations
from the same group without affecting the output of the network.
"""
# %%
import dataclasses
import re
from copy import deepcopy
from functools import partial
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.io as pio
import torch as th
from cleanba.environments import BoxobanConfig
from plotly.subplots import make_subplots
from transformer_lens.hook_points import HookPoint

from learned_planners import BOXOBAN_CACHE
from learned_planners.interp.channel_group import get_group_channels
from learned_planners.interp.collect_dataset import join_cache_across_steps
from learned_planners.interp.offset_fns import INV_CHANNEL_OFFSET_FNS
from learned_planners.interp.plot import plotly_feature_vis
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import get_cache_and_probs, load_jax_model_to_torch, load_probe, play_level
from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer
from learned_planners.policies import download_policy_from_huggingface

# Print tensors in fixed-point notation with two decimals.
th.set_printoptions(precision=2, sci_mode=False)
# Configure the project's Emacs plotly renderer helper, then explicitly set plotly's default
# renderer to "notebook" (this assignment runs last, so it wins over anything the helper set).
set_plotly_renderer("emacs")
pio.renderers.default = "notebook"


# %%
# DRC(3, 3) 2B checkpoint, downloaded from the Hugging Face hub.
MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000"
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)

# Path of a boxes-future-direction probe from a wandb training run. It is not referenced elsewhere
# in this file (the probe that is actually loaded comes from probes/best/).
boxes_direction_probe_file = Path(
    "/training/TrainProbeConfig/05-probe-boxes-future-direction",
    "wandb/run-20240813_184417-vb6474rg/local-files/probe_l-all_x-all_y-all_c-all.pkl",
)

# Number of environment steps per rollout.
seq_len = 30

# Hard Boxoban levels; min and max episode steps are both set to 120.
boxo_cfg = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    split=None,
    difficulty="hard",
    # Alternative level set:
    # split="valid",
    # difficulty="unfiltered",
    num_envs=512,
    min_episode_steps=120,
    max_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
)
model_cfg, model = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)

# Deep copy of the freshly loaded weights so the model can be reset later.
orig_state_dict = deepcopy(model.state_dict())


def restore_model():
    """Load the weight snapshot taken right after the model was loaded back into ``model``."""
    model.load_state_dict(state_dict=orig_state_dict)


# Dataset descriptor and the corresponding pre-trained probe (the second return value is unused).
probe_info = TrainOn(dataset_name="boxes_future_direction_map")
probe, _ = load_probe("probes/best/boxes_future_direction_map_l-all.pkl")
# %%


def run_policy_reset(num_steps, envs, policy):
    """Reset ``envs`` and roll out ``policy`` to collect up to ``num_steps`` observations.

    Args:
        num_steps: Number of observations to collect per environment; the policy takes
            ``num_steps - 1`` actions.
        envs: Vectorized environment exposing ``num_envs``, ``reset() -> (obs, info)`` and
            ``step(actions) -> (obs, reward, terminated, truncated, info)``.
        policy: Recurrent policy exposing ``recurrent_initial_state(n)``; calling it with
            ``(obs, carry, episode_starts)`` returns a 4-tuple whose first element is the action
            and whose last element is the new carry.

    Returns:
        obs: Tensor of stacked observations with time first (``obs[t]`` is the batch observed
            after ``t`` steps). It is restricted to the environments whose episode neither
            terminated nor truncated during the rollout.
        episode_lengths: For every environment (unfiltered), the number of steps taken up to
            and including the step that solved the level. Environments that were never
            solved count every step taken.
    """
    new_obs, _ = envs.reset()
    obs = [new_obs]
    carry = policy.recurrent_initial_state(envs.num_envs)
    # The carry is never reset: every step is fed as a non-episode-start.
    all_false = th.zeros(envs.num_envs, dtype=th.bool)
    eps_done = np.zeros(envs.num_envs, dtype=bool)
    eps_solved = np.zeros(envs.num_envs, dtype=bool)
    episode_lengths = np.zeros(envs.num_envs, dtype=np.int32)
    for _ in range(num_steps - 1):
        action, _value, _, carry = policy(th.as_tensor(new_obs), carry, all_false)
        new_obs, _, term, trunc, _ = envs.step(action.detach().cpu().numpy())
        obs.append(new_obs)
        eps_done |= term | trunc
        # Count this step for every env not solved before it, including the step that solves it.
        episode_lengths[~eps_solved] += 1
        eps_solved |= term
        if np.all(eps_done):
            break
    episode_lengths = th.as_tensor(episode_lengths)
    # Keep only environments whose episode never ended during the rollout. Once an episode
    # terminates or truncates, the observation returned by that step (a fresh level, if the env
    # auto-resets) and everything after it no longer belong to the same level. This also
    # excludes envs that end on the very last step.
    keep = th.as_tensor(~eps_done)
    obs = th.as_tensor(np.stack(obs))[:, keep]
    print("Number of envs:", obs.shape[1])
    return obs, episode_lengths


envs = dataclasses.replace(boxo_cfg, num_envs=512).make()


# %%
# Roll out the unmodified policy for `seq_len` steps; run_policy_reset filters which episodes are kept.
clean_obs, clean_length = run_policy_reset(seq_len, envs, model)
num_eps = clean_obs.shape[1]

# Initial recurrent state and episode-start mask (True only at the first step). These are only
# consumed by the commented-out run_with_cache alternative below.
zero_carry = model.recurrent_initial_state(num_eps)
eps_start = th.zeros((seq_len, num_eps), dtype=th.bool)
eps_start[0] = True

# Alternative producing an unjoined, per-step cache:
# joined = False
# (clean_actions, clean_values, clean_log_probs, _), clean_cache = model.run_with_cache(clean_obs, zero_carry, eps_start)

# `joined` tells later cells how clean_cache is keyed and indexed.
joined = True
clean_cache, clean_probs = get_cache_and_probs(clean_obs, model)

# %%


def channel_replacement_hook(inputs: th.Tensor, hook: "HookPoint", from_channel, to_channel, coefs):
    """Overwrite channel(s) ``to_channel`` of ``inputs`` with scaled copies of ``from_channel``.

    ``inputs`` is indexed as ``(batch, channel, H, W)``: channels are selected on axis 1 and
    ``coefs`` is broadcast over axes 0, 2 and 3.

    Args:
        inputs: Activation tensor; modified in place and also returned.
        hook: The hook point invoking this function (unused).
        from_channel: Source channel index, or a sequence of indices.
        to_channel: Destination channel index, or a sequence of the same length as ``from_channel``.
        coefs: Scale factor(s): a scalar, or one value per channel pair. A single value applies
            to every pair.

    Returns:
        The modified ``inputs`` tensor.
    """
    # Normalize indices to 1-D so the channel axis is always kept. A scalar index would drop
    # it, and a multi-element ``coefs`` would then broadcast against the batch axis instead.
    src = th.as_tensor(from_channel, device=inputs.device).reshape(-1)
    dst = th.as_tensor(to_channel, device=inputs.device).reshape(-1)
    # Cast coefficients to the activation dtype and lay them out along the channel axis.
    scale = th.as_tensor(coefs, dtype=inputs.dtype, device=inputs.device).reshape(1, -1, 1, 1)
    # The right-hand side is computed (advanced indexing copies) before the write, so
    # overlapping source/destination channels read the original values.
    inputs[:, dst] = inputs[:, src] * scale
    return inputs


# Patch only the first step unless full_seq is set, in which case every step is patched.
full_seq = False
hook_positions = range(seq_len if full_seq else 1)

# (layer, source channel, destination channel, coefficients) for each replacement.
replacements = [
    (0, 2, 28, [-1]),
    (1, 17, 19, [0.6]),
]

# One hook per (replacement, step, internal tick), ordered replacement-major, then step, then tick.
fwd_hooks = [
    (
        f"features_extractor.cell_list.{layer}.hook_h.{pos}.{tick}",
        partial(channel_replacement_hook, from_channel=src, to_channel=dst, coefs=coefs),
    )
    for layer, src, dst, coefs in replacements
    for pos in hook_positions
    for tick in range(3)
]
# %%
patched_cache, patched_probs = get_cache_and_probs(clean_obs, model, fwd_hooks)

# %%
# Fraction of positions where the greedy (argmax) action is unchanged by the patch.
clean_actions = clean_probs.argmax(-1)
patched_actions = patched_probs.argmax(-1)
action_match = (clean_actions == patched_actions).float()

acc = action_match.mean()
print(f"Accuracy: {acc.item() * 100:.2f}%")

# Agreement averaged over dim 0 (presumably time, leaving one value per episode), then the ten
# episodes with the lowest agreement.
eps_wise_acc = action_match.mean(0)
argsort_acc = th.argsort(eps_wise_acc)
worst = argsort_acc[:10]
print(f"Bottom 10 acc: {eps_wise_acc[worst]}")
print(f"Bottom 10 envs: {worst}")

# kl = th.nn.functional.kl_div(th.nn.functional.log_softmax(patched_probs, dim=-1), clean_probs, reduction="batchmean")
# print(f"KL Divergence: {kl.item()}")

# %%
num_envs = 256

# Play the levels twice, once without hooks and once with them, building a fresh environment for
# each run. The assert below checks that both runs saw the same levels.
envs = dataclasses.replace(boxo_cfg, num_envs=num_envs).make()
clean_play = play_level(envs, model, fwd_hooks=None)
print("Processed clean.")

envs = dataclasses.replace(boxo_cfg, num_envs=num_envs).make()
patched_play = play_level(envs, model, fwd_hooks=fwd_hooks)

info_matches = [(clean_val == patched_play.info[key]).all() for key, clean_val in clean_play.info.items()]
assert all(info_matches), "Different levels played"

for label, play in (("Clean", clean_play), ("Patched", patched_play)):
    print(f"{label} solved: {play.solved.sum().item()}")

# %%
# Levels whose solved status flipped when the hooks were applied.
patch_unsolved = clean_play.solved & ~patched_play.solved
patch_newly_solved = patched_play.solved & ~clean_play.solved
print(f"Num solve->unsolved: {patch_unsolved.sum().item()}")
print(f"Num unsolved->solved: {patch_newly_solved.sum().item()}")

# %% Play level


# %% Visualize all layer and channels
env_idx = 1
# Per-hook activations of a single environment, with time on the leading axis.
if joined:
    # Joined cache values are indexed [time, env, ...] (cf. the [slice, :, channel] indexing used
    # when grouping activations), so the environment must be selected on axis 1.
    toy_cache = {k: v[:, env_idx] for k, v in clean_cache.items()}
else:
    # Unjoined caches hold one [env, ...] tensor per (step, tick) key: select the env, then join.
    toy_cache = join_cache_across_steps([{k: v[env_idx] for k, v in clean_cache.items()}])
# Repeat each observation once per internal tick so it lines up with the activations.
toy_obs = np.repeat(clean_obs[:, env_idx], 3, axis=0)
for k, v in toy_cache.items():
    # Only plot the hidden-state hooks (keys ending in "hook_h").
    if re.match("^.*hook_([h])$", k):
        fig = plotly_feature_vis(v, toy_obs, k)
        fig.update_layout(height=800)
        fig.show()

# %%
# %%
start_pos = 0  # first environment step included in the analysis
last_tick = False  # if True, keep only the last internal tick (index 2) of each step
clean = True  # analyse the clean run (True) or the patched run (False)

grouped_channels = get_group_channels("box")
move_dir = 1  # index into ["UP", "DOWN", "LEFT", "RIGHT"]
channel_titles = [f"L{layer}H{channel}" for layer, channel in grouped_channels[move_dir]]
move_dir_title = ["UP", "DOWN", "LEFT", "RIGHT"][move_dir]
grp_size = len(grouped_channels[move_dir])

# Both branches below read from the run selected by `clean`.
source_cache = clean_cache if clean else patched_cache

# A joined cache stores step `pos`, tick `t` at time index 3 * pos + t (the layout the stride-3
# last_tick slice relies on). Honour start_pos there the same way the unjoined branch does
# through its key range, so the flattened length matches (seq_len - start_pos) * ticks.
if last_tick:
    joined_time_slice = slice(3 * start_pos + 2, None, 3)
else:
    joined_time_slice = slice(3 * start_pos, None)
ticks = [2] if last_tick else range(3)

activations_by_group = []
for group in grouped_channels:
    activations_by_group.append([])
    for layer, channel in group:
        if joined:
            hook_acts = source_cache[f"features_extractor.cell_list.{layer}.hook_h"]
            # Stacking one array only prepends a length-1 axis; flatten() below absorbs it.
            acts = np.stack([hook_acts[joined_time_slice, :, channel]], axis=0)
        else:
            # Stack all positions and internal ticks for this channel along a new leading axis.
            acts = th.stack(
                [
                    source_cache[f"features_extractor.cell_list.{layer}.hook_h.{pos}.{int_pos}"][:, channel]
                    for pos in range(start_pos, seq_len)
                    for int_pos in ticks
                ],
                dim=0,
            ).numpy()
        acts = INV_CHANNEL_OFFSET_FNS[layer][channel](acts, last_dim_grid=True)
        activations_by_group[-1].append(acts.flatten())

# One row per channel of the selected direction's group; absolute Pearson correlation between rows.
down_activations = np.array(activations_by_group[move_dir])
correlation_matrix = np.abs(np.corrcoef(down_activations))


if clean:
    clean_down_activations = down_activations
else:
    patched_down_activations = down_activations


# %%
# Show only the strictly-lower triangle. The plotted window [1:, :-1] still contains the diagonal
# (always 1), so it is masked too (k=0) to keep it from dominating the color scale. Work on a
# copy so re-running this cell leaves correlation_matrix intact.
lower_corr = correlation_matrix.copy()
lower_corr[np.triu_indices(lower_corr.shape[0], k=0)] = 0

fig = px.imshow(
    lower_corr[1:, :-1],
    x=channel_titles[:-1],
    y=channel_titles[1:],
)
fig.show()

# %% Scatter plot of activations

fig = make_subplots(
    rows=grp_size - 1,
    cols=grp_size - 1,
    row_titles=channel_titles[1:],
    column_titles=channel_titles[:-1],
)
fig.update_layout(title=move_dir_title + " Group")

# Subsample (fixed seed, without replacement) to at most 10k points so the plots stay responsive.
rng = np.random.default_rng(42)
n_points = down_activations.shape[1]
sample_indices = rng.choice(n_points, min(n_points, 10000), replace=False)
down_activations_sampled = down_activations[:, sample_indices]

# One panel per channel pair i > j: channel i's row (1-based row i), channel j's column (col j + 1).
for i in range(1, grp_size):
    for j in range(i):
        pair_trace = px.scatter(x=down_activations_sampled[i], y=down_activations_sampled[j]).data[0]
        fig.add_trace(pair_trace, row=i, col=j + 1)
fig.show()

# %% # %% Visualize group activations
env_idx = 198
# Activations of every hook for one environment (assumes the time-major joined cache).
toy_cache = {name: acts[:, env_idx] for name, acts in clean_cache.items()}

reps = 1 if last_tick else 3
rep_seq_len = (seq_len - start_pos) * reps
# One observation copy per internal tick kept in down_activations.
toy_obs = np.repeat(clean_obs[start_pos:seq_len, env_idx], reps, axis=0)

# Undo the flattening: (channel, time, env, 10, 10) -> pick env -> (time, channel, 10, 10).
grid_side = 10
per_env = down_activations.reshape(down_activations.shape[0], rep_seq_len, num_eps, grid_side, grid_side)[:, :, env_idx]
down_activations_reshaped = per_env.transpose(1, 0, 2, 3)

fig = plotly_feature_vis(down_activations_reshaped, toy_obs, feature_labels=channel_titles)
fig.update_layout(title=move_dir_title + " Group")
fig.show()


# %% Visualize group activations from play level

env_idx = 1
toy_cache = join_cache_across_steps([dict(clean_play.cache)])

# The joined cache holds 3 internal ticks per step. Keep only the last tick when last_tick is
# set, so the number of activation frames matches the (repeated) observations.
reps = 1 if last_tick else 3
tick_slice = slice(2, None, 3) if last_tick else slice(None)
toy_obs = np.repeat(clean_play.obs[:, env_idx], reps, axis=0)

feature_acts = []
for layer, channel in grouped_channels[move_dir]:
    acts = toy_cache[f"features_extractor.cell_list.{layer}.hook_h"][tick_slice, env_idx, channel]
    acts = INV_CHANNEL_OFFSET_FNS[layer][channel](acts, last_dim_grid=True)
    feature_acts.append(acts)
# Stack channels on axis 1 to get (time, feature, H, W), the layout passed to plotly_feature_vis
# elsewhere in this file (time first, one feature per label).
feature_acts = np.stack(feature_acts, axis=1)

fig = plotly_feature_vis(feature_acts, toy_obs, feature_labels=channel_titles)
fig.update_layout(title=move_dir_title + " Group")
fig.show()
