# %%
import dataclasses
import re
from copy import deepcopy
from functools import partial

import numpy as np
import plotly.io as pio
import sklearn.linear_model
import torch as th
from cleanba.environments import BoxobanConfig
from matplotlib import pyplot as plt
from stable_baselines3.common.distributions import CategoricalDistribution
from transformer_lens.hook_points import HookPoint

from learned_planners import BOXOBAN_CACHE, LP_DIR
from learned_planners.interp.channel_group import get_group_channels, get_group_connections, layer_groups
from learned_planners.interp.collect_dataset import join_cache_across_steps
from learned_planners.interp.offset_fns import offset_yx
from learned_planners.interp.plot import plotly_feature_vis
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import load_jax_model_to_torch, load_probe, parse_level, play_level
from learned_planners.interp.weight_utils import visualize_top_conv_inputs
from learned_planners.policies import download_policy_from_huggingface

# set_plotly_renderer("emacs")
pio.renderers.default = "notebook"
th.set_printoptions(sci_mode=False, precision=2)

# %%

MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000"  # DRC(3, 3) 2B checkpoint
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)

# boxes_direction_probe_file = Path(
#     "/training/TrainProbeConfig/05-probe-boxes-future-direction/wandb/run-20240813_184417-vb6474rg/local-files/probe_l-all_x-all_y-all_c-all.pkl"
# )

boxo_cfg = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=1,
    max_episode_steps=200,
    min_episode_steps=200,
    asynchronous=False,
    tinyworld_obs=True,
    split=None,
    difficulty="hard",
)
model_cfg, model = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)

orig_state_dict = deepcopy(model.state_dict())


envs = boxo_cfg.make()


def restore_model():
    model.load_state_dict(orig_state_dict)


# %% Print model's hook points
def recursive_children(model):
    for c in model.children():
        yield c
        yield from recursive_children(c)


[c.name for c in recursive_children(model) if isinstance(c, HookPoint)]

# %%
# probe = pd.read_pickle("probe-future-directions.pkl")
probe, _ = load_probe("probes/best/boxes_future_direction_map_l-all.pkl")
try:
    new_coef = np.concatenate([probe.estimators_[i].coef_ for i in range(len(probe.estimators_))], 0)
    new_intercept = np.concatenate([probe.estimators_[i].intercept_ for i in range(len(probe.estimators_))], 0)

    probe = sklearn.linear_model.LogisticRegression()
    probe.coef_ = new_coef
    probe.intercept_ = new_intercept
    probe.classes_ = np.arange(new_coef.shape[0])
    probe.n_features_in_ = new_coef.shape[1]
except AttributeError:
    print("Skipping constructing the probe as it is already the right kind")

probe_info = TrainOn(dataset_name="boxes_future_direction")


# %%

combined_probe, combined_intercepts = th.load(LP_DIR / "learned_planners/notebooks/action_l2_probe.pt", weights_only=True)
aggregation_weight, aggregation_bias = th.load(LP_DIR / "learned_planners/notebooks/aggregation.pt", weights_only=True)


def bigger_levels_get_distribution(self, obs, carry, episode_starts, feature_extractor_kwargs=None):
    _, new_carry = model._recurrent_extract_features(obs, carry, episode_starts)
    new_h, new_c = new_carry[2]
    probe_in = th.cat([new_h, new_c], dim=2).squeeze(0)
    # Get the channels which represent actions
    # actions_per_location = new_h[:, [29, 8, 27, 3], :, :].squeeze(0)
    actions_per_location = th.einsum("nchw,oc->nohw", probe_in, combined_probe) + combined_intercepts[None, :, None, None]

    # Aggregate
    num_action1 = actions_per_location.mean((2, 3))
    num_action2 = actions_per_location.max(dim=2, keepdim=False).values.max(dim=2, keepdim=False).values
    num_action3 = (actions_per_location > 0).float().mean((2, 3))
    actions = (
        num_action1 * aggregation_weight[0]
        + num_action2 * aggregation_weight[1]
        + num_action3 * aggregation_weight[2]
        + aggregation_bias
    )
    return CategoricalDistribution(actions.shape[-1]).proba_distribution(action_logits=actions), new_carry


def bigger_recurrent_initial_state(self, dim_room, N, device=None):
    return [(th.zeros([N, 1, 32, *dim_room], device=device), th.zeros([N, 1, 32, *dim_room], device=device)) for _ in range(3)]


# %%


def long_level_for_speed(pushes=1, direction=1):
    assert pushes > 4 and pushes % 2 == 1, "pushes must be odd and greater than 4"
    num_rows = 5 + pushes

    side_walls = "#" * ((num_rows - 8) // 2)
    level_map = (
        f"""
{side_walls}########{side_walls}
{side_walls}  @##   {side_walls}
{side_walls}  $##.  {side_walls}
"""
        + f"{side_walls}   ##   {side_walls}\n" * (pushes - 1)
        + f"""{side_walls}        {side_walls}
{side_walls}        {side_walls}
{side_walls}########{side_walls}
"""
    ).strip()

    y, x = -3, (map_dim // 2) - 1
    # print(y, x)
    if direction == 0:
        level_map = "\n".join(level_map.split("\n")[::-1])
        y = 2
    elif direction == 2:
        lines = level_map.strip().split("\n")
        # Pad lines to make them all equal length
        max_len = max(len(line) for line in lines)
        lines = [line.ljust(max_len) for line in lines]
        # Rotate 90 degrees clockwise
        rotated_lines = ["".join(lines[len(lines) - j - 1][i] for j in range(len(lines))) for i in range(max_len)]
        level_map = "\n".join(rotated_lines)
        y, x = x, 2
    elif direction == 3:
        lines = level_map.strip().split("\n")
        # Pad lines to make them all equal length
        max_len = max(len(line) for line in lines)
        lines = [line.ljust(max_len) for line in lines]
        # Rotate 90 degrees counter-clockwise
        rotated_lines = ["".join(lines[j][max_len - i - 1] for j in range(len(lines))) for i in range(max_len)]
        level_map = "\n".join(rotated_lines[::-1])
        y, x = x, y
        # print(y, x)

    rows = level_map.strip().split("\n")
    cols = len(rows[0])
    assert len(rows) == num_rows and cols == num_rows, f"Expected {num_rows}x{num_rows} but got {len(rows)}x{cols}"

    # return level_map, num_rows, [(y, x), (y + 1, x) if direction > 1 else (y, x + 1)]
    horz = int(direction > 1)
    vert = int(direction <= 1)
    # print(y, x)
    return level_map, num_rows, [y - horz, y, y + horz], [x - vert, x, x + vert]


# %%

# print_hook = True


# def strengthen_act_hook(inp, hook):
#     global print_hook
#     if print_hook:
#         print("strengthen hook applied")
#         print_hook = False
#     # inp[inp > 0.5] = 1
#     # inp[inp < -0.5] = -1
#     inp[inp > 0.45] += 0.3
#     inp[inp < -0.45] -= 0.3
#     inp = inp.clamp(-0.9, 0.9)
#     return inp


# fwd_hooks = [
#     (
#         f"features_extractor.cell_list.{layer}.{hook_type}.{pos}.{int_step}",
#         strengthen_act_hook,
#     )
#     for layer in [0, 1, 2]
#     for pos in [0]
#     for int_step in [0, 1, 2]
#     for hook_type in ["hook_h"]
# ]


def is_connected(lc1, lc2):
    """Checks if lc1 is fed as input to lc2"""
    l1, c1 = lc1
    l2, c2 = lc2

    return l1 == l2 or ((l1 + 1) % 3 == l2)


layer_values = {}

# for k, v in toy_cache.items():
#     if m := re.match("^.*([0-9]+)\\.hook_([h])$", k):
#         layer_values[int(m.group(1))] = v

desired_groups_box = ["B up", "B down", "B left", "B right"]
box_channels = get_group_channels(desired_groups_box)

desired_groups_agent = ["A up", "A down", "A left", "A right"]
agent_channels = get_group_channels(desired_groups_agent)
# combine box and agent channels
group_channels = [b + a for b, a in zip(box_channels, agent_channels)]
group_channels += get_group_channels(["Misc plan", "T", "No label"])
# inp,out
group_connections = get_group_connections(group_channels)
restore_model()
# ablate connections from B down to B right
factor = 1.3
for g1 in group_connections:
    for g2 in g1:
        for inplc, outlc in g2:
            inpl, inc = inplc
            outl, outc = outlc
            outc_ijfo = [idx * 32 + outc for idx in range(4)]
            if inpl == outl:
                model.features_extractor.cell_list[outl].conv_hh.weight.data[outc_ijfo, inc] *= factor
            else:
                model.features_extractor.cell_list[outl].conv_ih.weight.data[outc_ijfo, inc + 32] *= factor

# for layer in range(3):
#     model.features_extractor.cell_list[layer].conv_ih.weight.data[:, 32:64] *= factor
#     # model.features_extractor.cell_list[layer].conv_ih.weight.data[:, 32:64] *= factor
#     model.features_extractor.cell_list[layer].conv_hh.weight.data *= factor


# %%
restore_model()
# %%
pushes, direction = 15, 3
level_map, map_dim, ys, xs = long_level_for_speed(pushes=pushes, direction=direction)
print(level_map, ys, xs)
print(len(level_map.split("\n")), len(level_map.split("\n")[0]))


def flatten(l):
    return [flatten(item) for sublist in l for item in sublist] if isinstance(l, list) else l


thinking_steps = 0
fwd_hooks = []
max_steps = thinking_steps + map_dim * 3
dim_room = (map_dim, map_dim)
internal_steps = False

if getattr(model.get_distribution, "__name__", None) == "get_distribution":
    old_get_distribution = model.get_distribution
model.get_distribution = partial(bigger_levels_get_distribution, model)  # type: ignore

envs = dataclasses.replace(boxo_cfg, dim_room=dim_room).make()
if getattr(model.recurrent_initial_state, "__name__", None) == "recurrent_initial_state":
    old_recurrent_initial_state = model.recurrent_initial_state
model.recurrent_initial_state = partial(bigger_recurrent_initial_state, model, dim_room)


reset_opts = parse_level(level_map)
toy_out = play_level(
    envs,
    model,
    # reset_opts=dict(walls=walls, boxes=boxes, targets=targets, player=player),
    reset_opts=reset_opts,
    # reset_opts=level_reset_opt,
    thinking_steps=thinking_steps,
    fwd_hooks=fwd_hooks,
    max_steps=max_steps,
    hook_steps=list(range(thinking_steps, max_steps)) if thinking_steps > 0 else -1,
    probes=[probe],
    probe_train_ons=[probe_info],
    probe_logits=True,
    internal_steps=internal_steps,
)

# relevant_channel = [(0, 13), (1, 17), (2, 20), (2, 9)][direction]
relevant_channel = [(2, 9), (2, 9), (1, 17), (1, 17)][direction]

toy_cache = toy_out.cache
toy_cache = {k: v.squeeze(1) for k, v in toy_cache.items() if len(v.shape) == 5}
toy_obs = toy_out.obs.squeeze(1)
toy_obs = toy_out.obs.squeeze(1).numpy()

channel_acts = toy_cache[f"features_extractor.cell_list.{relevant_channel[0]}.hook_h"][:, relevant_channel[1]]


y_and_x = np.array([offset_yx(y, x, [relevant_channel[1]], relevant_channel[0]) for y, x in zip(ys, xs)])
ys, xs = y_and_x[:, 0], y_and_x[:, 1]
channel_acts = channel_acts[:, ys, xs].sum(1)
if direction > 1:
    channel_acts = -1 * channel_acts

plt.plot(channel_acts[: int(dim_room[0] * 3 * 1.1)])
plt.show()

repeats = 3 if internal_steps else 1
probe_outs = toy_out.probe_outs[0]
probe_outs = np.moveaxis(np.reshape(probe_outs, (len(probe_outs) * repeats, *probe_outs.shape[-3:])), -1, 1)
# assert probe_outs.shape == (45, 5, 10, 10)
fig = plotly_feature_vis(probe_outs, np.repeat(toy_obs, repeats, 0), "alternative box dirs probe").show()

first_active_time = np.where(channel_acts[1:] - channel_acts[:-1] > 0.3)[0][0] + 1
speed = pushes / first_active_time
print(f"Speed: {speed:.2f} squares per tick", "sqs", pushes, "ticks", first_active_time)


# %% Visualize all channels


for k, v in toy_cache.items():
    if m := re.match("^.*hook_([h])$", k):
        fig = plotly_feature_vis(v, np.repeat(toy_obs, 3, 0), k, m.group(1).upper())
        fig.update_layout(height=800)
        fig.show()

# %% Visualize features grouped by type

layer_values = {}

for k, v in toy_cache.items():
    if m := re.match("^.*([0-9]+)\\.hook_([c])$", k):
        layer_values[int(m.group(1))] = v

desired_groups = ["B up", "B down", "B left", "B right"]

channels = []
labels = []

for group in desired_groups:
    for layer in layer_groups[group]:
        channels.append(layer_values[layer["layer"]][:, layer["idx"], :, :])
        labels.append(f"{group} L{layer['layer']}H{layer['idx']}")

channels = np.stack(channels, 1)

fig = plotly_feature_vis(channels, np.repeat(toy_obs, 3, 0), feature_labels=labels)
fig.update_layout(height=800)
fig.show()


# %%
plot_layer, plot_channel = 2, 9
# plot_layer, plot_channel = 0, 17
# tick = reps - 1
tick = 0
show_ticks = True
# keys = [f"features_extractor.cell_list.{plot_layer}.hook_{k.lower()}" for k in ["H", "C", "I", "J", "F", "O"]]
keys = [f"features_extractor.cell_list.{plot_layer}.hook_{k.lower()}" for k in ["H", "J", "O"]]

toy_all_channels_for_lc = np.stack([toy_cache[key][:, plot_channel] for key in keys], axis=1)
# if not show_ticks:
#     toy_all_channels_for_lc = toy_all_channels_for_lc[tick::3]

# repeat obs 3 along first dimension
fig = plotly_feature_vis(
    toy_all_channels_for_lc,
    np.repeat(toy_obs, 3, 0),
    feature_labels=[k.rsplit(".")[-1] for k in keys],
)
fig.show()

# %% conv ih and hh

plot_layer, plot_channel = 2, 9

keys = [f"features_extractor.cell_list.{plot_layer}.hook_{k.lower()}" for k in ["conv_ih", "conv_hh"]]

toy_all_channels_for_lc = np.stack([toy_cache[key][:, 32 * ijo + plot_channel] for key in keys for ijo in [0, 1, 3]], axis=1)
# repeat obs 3 along first dimension
fig = plotly_feature_vis(
    toy_all_channels_for_lc,
    np.repeat(toy_obs, 3, 0),
    feature_labels=[k.rsplit(".")[-1][5:] + "_" + "ijfo"[ijo] for k in keys for ijo in [0, 1, 3]],
)
fig.show()

# %% Visualize

plot_layer, plot_channel, ih, ijfo, inp_types = 2, 9, False, "o", "lh"


def ijfo_idx(ijfo):
    return ["i", "j", "f", "o"].index(ijfo)


toy_all_channels_for_lc, top_channels, values = visualize_top_conv_inputs(
    plot_layer,
    plot_channel,
    out_type=ijfo,
    model=model,
    cache=toy_cache,
    ih=ih,
    num_channels=6 + 1 * 8,
    inp_types=inp_types,
    top_channel_sum=True,
)
plot_channel = 32 * ijfo_idx(ijfo) + plot_channel
toy_all_channels_for_lc = toy_all_channels_for_lc.numpy()
fig = plotly_feature_vis(
    toy_all_channels_for_lc,
    np.repeat(toy_obs, 3, 0),
    feature_labels=[f"{c}: {v:.2f}" for c, v in zip(top_channels, values)],  # + ["ih" if ih else "hh"],
    common_channel_norm=True,
)
fig.show()

# %% Activation patching


def abs_hook(inp, hook):
    inp[:, 17, 4:6, 7:11] = inp[:, 17, 4:6, 7:11].abs()
    return inp


def abs_hook_nine(inp, hook):
    inp[:, 9, 4:6, 7:11] = inp[:, 9, 4:6, 7:11].abs()
    return inp


def abs_hook_l1(inp, hook):
    inp[:, 13, 5:6, 7:11] = inp[:, 13, 5:6, 7:11].abs()
    return inp


fwd_hooks = [
    (
        f"features_extractor.cell_list.{layer}.{hook_type}.{pos}.{int_step}",
        abs_hook,
    )
    for layer in [0]
    for pos in [0]
    for int_step in [0, 1, 2]
    for hook_type in ["hook_h"]
]
fwd_hooks += [
    (
        f"features_extractor.cell_list.{layer}.{hook_type}.{pos}.{int_step}",
        abs_hook_nine,
    )
    for layer in [2]
    for pos in [0]
    for int_step in [0, 1, 2]
    for hook_type in ["hook_h"]
]

fwd_hooks += [
    (
        f"features_extractor.cell_list.{layer}.{hook_type}.{pos}.{int_step}",
        abs_hook_l1,
    )
    for layer in [1]
    for pos in [0]
    for int_step in [0, 1, 2]
    for hook_type in ["hook_h"]
]

toy_out = play_level(
    envs,
    model,
    reset_opts=reset_opts,
    thinking_steps=thinking_steps,
    fwd_hooks=fwd_hooks,
    max_steps=max_steps,
    hook_steps=list(range(4, 100)),
    probes=[probe],
    probe_train_ons=[probe_info],
    probe_logits=True,
    internal_steps=internal_steps,
)
toy_cache = join_cache_across_steps([toy_out.cache])
toy_cache = {
    k: np.transpose(v.squeeze(2), (1, 0, 2, 3, 4)).reshape(-1, *v.shape[-3:])
    for k, v in toy_cache.items()
    if len(v.shape) == 6
}
toy_obs = toy_out.obs.squeeze(1)
toy_obs = toy_out.obs.squeeze(1).numpy()


repeats = 3 if internal_steps else 1
probe_outs = toy_out.probe_outs[0]
probe_outs = np.moveaxis(np.reshape(probe_outs, (len(probe_outs) * repeats, *probe_outs.shape[-3:])), -1, 1)
# assert probe_outs.shape == (45, 5, 10, 10)
fig = plotly_feature_vis(probe_outs, np.repeat(toy_obs, repeats, 0), "alternative box dirs probe").show()
