# %%
import dataclasses
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict

import numpy as np
import torch as th
from cleanba.environments import BoxobanConfig
from matplotlib import pyplot as plt
from tqdm import tqdm
from transformer_lens.hook_points import HookPoint

from learned_planners.interp.utils import load_jax_model_to_torch
from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer
from learned_planners.policies import download_policy_from_huggingface

# Select the "emacs" plotly renderer through the project helper.
set_plotly_renderer("emacs")

# %%
MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000"  # DRC(3, 3) 2B checkpoint
# The value returned here is what `load_jax_model_to_torch` receives below.
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)

# Disabled alternative: derive the cache directory from this file's location, falling back to the working
# directory when `__file__` is undefined. NOTE(review): re-enabling it also requires `import os`, which this
# file does not import.
# try:
#     BOXOBAN_CACHE = Path(__file__).parent.parent.parent / ".sokoban_cache"
# except NameError:
#     BOXOBAN_CACHE = Path(os.getcwd()) / ".sokoban_cache"
BOXOBAN_CACHE = Path("/opt/sokoban_cache")

# NOTE(review): hard-coded absolute path to a pickled probe; it is not referenced again in the visible part of
# this script.
boxes_direction_probe_file = Path(
    "/training/TrainProbeConfig/05-probe-boxes-future-direction/wandb/run-20240813_184417-vb6474rg/local-files/probe_l-all_x-all_y-all_c-all.pkl"
)

# Base environment configuration. It is used to load the model here and, with `num_envs` replaced by 512, to
# build the rollout environments further down.
boxo_cfg = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=2,
    max_episode_steps=120,
    min_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
    split="train",
    difficulty="medium",
)
model_cfg, model = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)


# Deep copy of the unmodified weights: the weight-pruning cells start from a copy of this snapshot and
# `restore_model` loads it back into `model`.
orig_state_dict = deepcopy(model.state_dict())


def restore_model() -> None:
    """Load the weights saved in the module-level `orig_state_dict` back into the global `model`."""
    model.load_state_dict(orig_state_dict)


# %%


def run_policy_reset(num_steps, envs, policy):
    """Reset `envs`, roll `policy` forward and return the observation seen at every step.

    Args:
        num_steps: number of observations to collect: the one returned by `envs.reset()` plus the ones
            returned by `num_steps - 1` calls to `envs.step`.
        envs: vectorized environment exposing `num_envs`, `reset()` and `step(actions)`.
        policy: recurrent policy exposing `recurrent_initial_state(n)` and callable as
            `policy(obs, carry, episode_starts)`; it must return a 4-tuple whose first item is the action
            tensor and whose last item is the next carry.

    Returns:
        A tensor stacking the collected observations along a new leading (time) dimension.

    Raises:
        AssertionError: if any environment reports termination or truncation during the rollout.
    """
    new_obs, _ = envs.reset()
    obs = [new_obs]
    carry = policy.recurrent_initial_state(envs.num_envs)
    all_false = th.zeros(envs.num_envs, dtype=th.bool)
    # Only the observations are kept, so recording an autograd graph for the policy would be wasted memory.
    with th.no_grad():
        for _ in range(num_steps - 1):
            action, _value, _log_probs, carry = policy(th.as_tensor(new_obs), carry, all_false)
            new_obs, _, term, trunc, _ = envs.step(action.detach().cpu().numpy())
            assert not (np.any(term) or np.any(trunc)), "an episode ended in the middle of the rollout"
            obs.append(new_obs)
    return th.as_tensor(np.stack(obs))


# Number of observations recorded per environment in each rollout.
seq_len = 10

# Same configuration as `boxo_cfg` but with 512 parallel environments.
envs = dataclasses.replace(boxo_cfg, num_envs=512).make()
# Two separate rollouts, each starting from its own `envs.reset()`; the second provides the "corrupted" inputs.
clean_obs = run_policy_reset(seq_len, envs, model)
corrupted_obs = run_policy_reset(seq_len, envs, model)

# %%

zero_carry = model.recurrent_initial_state(envs.num_envs)
# Episode-start flags: True only at the first time step of the recorded sequence.
eps_start = th.zeros((seq_len, envs.num_envs), dtype=th.bool)
eps_start[0, :] = True

# Clean reference run: its actions and log-probs are the baseline every pruned model is compared against below.
(clean_actions, clean_values, clean_log_probs, _), clean_cache = model.run_with_cache(
    clean_obs, zero_carry, eps_start, deterministic=True
)
# Create corrupted activations from other random levels. The hope is that, over a large data set, they will
# change the output enough times, and in different enough ways, that we can correctly attribute to things in the
# latest layer.
_, corrupted_cache = model.run_with_cache(corrupted_obs, zero_carry, eps_start)

# %%
# Regexes selecting which cache keys receive attributions: names ending in `hook_<v>.<digit>.<digit>`. Both
# trailing indices must be single digits, so a key with an index of 10 or more would not match. The commented-out
# variant additionally selects the `h` and `c` hooks.
# key_pattern = [rf".*hook_{v}\.\d\.\d$" for v in ["h", "c", "i", "j", "f", "o"]]
key_pattern = [rf".*hook_{v}\.\d\.\d$" for v in ["i", "j", "f", "o"]]

# %%


def interpolate_rnn_inputs(alpha, am1_tuple, a_tuple):
    """Linearly blend the first element of two 3-item model-input tuples.

    The first element of the result is `(1 - alpha) * am1_tuple[0] + alpha * a_tuple[0]`; the second and third
    elements are taken unchanged from `am1_tuple`. Both arguments must unpack into exactly three items.
    """
    start, second, third = am1_tuple
    end, _, _ = a_tuple

    blended = (1 - alpha) * start + alpha * end
    return (blended, second, third)


def add_attributions(
    attributions: Dict[str, th.Tensor],
    loss_fn,
    model,
    clean_inputs: Any,
    corrupted_inputs: Any,
    clean_cache: Dict[str, th.Tensor],
    corrupted_minus_clean_cache: Dict[str, th.Tensor],
    *,
    ablate_at_every_hook: bool = False,
    n_gradients_to_integrate: int = 5,
) -> Dict[str, th.Tensor]:
    """Accumulate gradient-times-difference attributions for the hooked activations of `model`.

    For `n_gradients_to_integrate` evenly spaced values of `alpha` in [0, 1] (only `alpha = 0` when a single
    gradient is requested), the loss is evaluated on inputs interpolated between `clean_inputs` and
    `corrupted_inputs` and back-propagated. Each hooked activation then adds
    `grad * corrupted_minus_clean_cache[name]`, summed over dimension 0 and divided by
    `n_gradients_to_integrate`, to `attributions[name]`.

    Only keys of `clean_cache` that match one of the module-level `key_pattern` regexes are hooked. The
    attribution strength is added to `attributions` and returned, which makes it possible to compute the
    attribution over several minibatches by passing the same dict to successive calls.

    Args:
        attributions: running totals; updated in place, new keys are created on first use.
        loss_fn: callable receiving a 3-tuple of model inputs and returning a tensor to call `.backward()` on.
        model: hooked model exposing `input_dependent_hooks_context`, `hooks` and `zero_grad`.
        clean_inputs: 3-tuple of model inputs for the clean run.
        corrupted_inputs: 3-tuple of model inputs for the corrupted run; only its first element is used.
        clean_cache: activations of the clean run, keyed by hook name.
        corrupted_minus_clean_cache: corrupted minus clean activations, keyed by hook name.
        ablate_at_every_hook: if True, every hooked activation is replaced in the forward pass by
            `clean + alpha * (corrupted - clean)` while gradients still flow to the original activation.
        n_gradients_to_integrate: number of interpolation points; must be at least 1.

    Returns:
        The `attributions` dict that was passed in.
    """
    assert n_gradients_to_integrate >= 1

    # Interpolation coefficient; rebound by the loop below and read by the forward hook through its closure.
    alpha = 0.0

    def set_corrupted_hook(inputs: th.Tensor, hook: HookPoint):
        if not ablate_at_every_hook:
            return None
        name = str(hook.name)
        desired = clean_cache[name] + alpha * corrupted_minus_clean_cache[name]
        # The forward value becomes `desired`, but the detached correction keeps the gradient w.r.t. `inputs`.
        return inputs + (desired - inputs).detach()

    def save_gradient_hook(grad: th.Tensor, hook: HookPoint):
        name = str(hook.name)
        attr = (grad.detach() * corrupted_minus_clean_cache[name]).sum(0).detach() / n_gradients_to_integrate
        try:
            attributions[name].add_(attr)
        except KeyError:
            attributions[name] = attr

    keys = [fk for fk in clean_cache if any(re.match(k, fk) for k in key_pattern)]

    fwd_hooks = [(k, set_corrupted_hook) for k in keys]
    bwd_hooks = [(k, save_gradient_hook) for k in keys]
    last_step = n_gradients_to_integrate - 1
    with model.input_dependent_hooks_context(*clean_inputs, fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
        with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=bwd_hooks):
            for k in range(n_gradients_to_integrate):
                alpha = k / max(1, last_step)

                # Check that computation is right: the first point is the clean input and, when more than one
                # gradient is integrated, the last point is the corrupted input.
                assert 0.0 <= alpha <= 1.0
                if k == 0:
                    assert alpha == 0.0
                if n_gradients_to_integrate > 1 and k == last_step:
                    assert alpha == 1.0

                loss_fn(interpolate_rnn_inputs(alpha, clean_inputs, corrupted_inputs)).backward()
                model.zero_grad()
    return attributions


def loss_fn(inputs):
    """Return the summed log-probability that the policy assigns to the module-level `clean_actions`.

    `inputs` is an `(obs, init_carry, eps_start)` triple that is forwarded to `model.get_distribution`.
    """
    observations, initial_carry, episode_starts = inputs
    action_dist, _ = model.get_distribution(observations, initial_carry, episode_starts)
    log_probs = action_dist.distribution.log_prob(clean_actions)
    return log_probs.sum()


# TODO: modify this call so caches only contain keys that we want to look at (hook_h, hook_c, hook_ijfo, etc.)
attributions = add_attributions(
    {},
    loss_fn,
    model,
    clean_inputs=(clean_obs, zero_carry, eps_start),
    corrupted_inputs=(corrupted_obs, zero_carry, eps_start),
    clean_cache={k: v.detach() for k, v in clean_cache.items()},
    # The difference must be corrupted - clean, as the parameter name says: `add_attributions` uses
    # `clean + alpha * difference` to move from the clean activations towards the corrupted ones.
    corrupted_minus_clean_cache={k: (corrupted_cache[k] - v).detach() for k, v in clean_cache.items()},
    ablate_at_every_hook=False,
    n_gradients_to_integrate=2,
)

# %%
# Accumulated attribution tensor per hook; keys are attribution keys with their last two dotted fields removed.
mean_attributions = {}
# Scalar attribution per channel, keyed by short names such as "L0I31". It is rebound further down to a list of
# (short_name, value) pairs sorted by descending value.
mean_attributions_channels = {}
# Hook key -> list of channel indices to overwrite with mean activations; rebuilt by the pruning cells.
prune_channels = {}
# When True, absolute attributions are accumulated, so contributions of opposite sign do not cancel.
do_abs = True


def shorten_key(key):
    """Abbreviate a dotted hook key, e.g. "features_extractor.cell_list.0.hook_i" -> "L0I".

    The third dotted field is parsed as the integer layer index. The fourth field, with its first five
    characters (the "hook_" prefix in the example) removed, is upper-cased and appended.
    """
    fields = key.split(".")
    layer_idx = int(fields[2])
    hook_label = fields[3][5:].upper()
    return "L" + str(layer_idx) + hook_label


def unshorten_key(short_key):
    """Expand a short channel name such as "L0I31" into "features_extractor.cell_list.0.hook_i".

    Only the second character (layer index, a single digit) and the third character (hook letter, lower-cased)
    are read; anything after them, such as the channel number, is ignored.
    """
    layer_idx = int(short_key[1])
    hook_letter = short_key[2].lower()
    return "features_extractor.cell_list.%d.hook_%s" % (layer_idx, hook_letter)


# Sum the attributions of all hooks that share the same key once the last two dotted fields are dropped.
for k, v in attributions.items():
    hook_key = k.rsplit(".", 2)[0]
    contribution = v.abs() if do_abs else v
    if hook_key in mean_attributions:
        mean_attributions[hook_key] += contribution
    else:
        # Store a copy: with `do_abs` False, `contribution` is the tensor held by `attributions`, and the
        # in-place `+=` above would otherwise overwrite that entry.
        mean_attributions[hook_key] = contribution.clone()
# One scalar per channel: the mean of the accumulated attribution over that channel's remaining dimensions.
for k, v in mean_attributions.items():
    for c in range(v.shape[0]):
        mean_attributions_channels[shorten_key(k) + f"{c}"] = v[c].mean().item()

# From here on this is a list of (short_name, value) pairs, largest value first.
mean_attributions_channels = sorted(mean_attributions_channels.items(), key=lambda x: x[1], reverse=True)
# %%
print(mean_attributions_channels[:10])
# %%

# Mean of every 4-D cached activation over dims 0, 2 and 3; singleton dims are kept so it broadcasts back.
mean_activations = {k: v.mean(dim=(0, 2, 3), keepdims=True) for k, v in clean_cache.items() if len(v.shape) == 4}
# Prune the channels with the lowest attributions by substituting them with their mean activations.
# top_k_channels_list = np.array([32, 64, 128, 256, 300, 375, 425, 500, 576])
top_k_channels_list = np.linspace(0, len(mean_attributions_channels), 12, dtype=int)
accs, kls = [], []
for top_k_channels in tqdm(top_k_channels_list):
    # Keep the `top_k_channels` highest-ranked channels; every channel after them in the ranking is pruned.
    prune_channels = {}
    for short_name, _score in mean_attributions_channels[top_k_channels:]:
        prune_channels.setdefault(unshorten_key(short_name), []).append(int(short_name[3:]))

    def prune_hook(inputs: th.Tensor, hook: HookPoint):
        """Overwrite the pruned channels of this activation, in place, with their mean activation."""
        assert hook.name is not None
        hook_key = hook.name.rsplit(".", 2)[0]
        channels = prune_channels[hook_key]
        inputs[:, channels] = mean_activations[hook.name][:, channels]
        return inputs

    fwd_hooks = [(k, prune_hook) for k in clean_cache.keys() if k.rsplit(".", 2)[0] in prune_channels]
    # The hook context receives the same positional inputs as the model call, as in `add_attributions`;
    # unpacking `clean_obs` itself would pass one tensor per time step instead. No gradients are needed here.
    with th.no_grad():
        with model.input_dependent_hooks_context(clean_obs, zero_carry, eps_start, fwd_hooks=fwd_hooks, bwd_hooks=[]):
            with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=[]):
                prune_actions, _, prune_log_probs, _ = model(clean_obs, zero_carry, eps_start, deterministic=True)
    # compute kl divergence using the log probs
    # NOTE(review): with the default reduction, `kl_div` averages over all elements; confirm that this is the
    # intended estimate for these log-prob tensors.
    kl_div = th.nn.functional.kl_div(prune_log_probs, clean_log_probs, log_target=True)
    # accuracy of prune_actions: fraction of actions identical to the clean run
    accuracy = (prune_actions == clean_actions).float().mean().item()
    print(f"Accuracy: {accuracy:.2f}, KL Divergence: {kl_div.item():.2f}")
    accs.append(accuracy)
    kls.append(kl_div.item())

# x axis: number of pruned channels (total number of channels minus the number kept).
plt.plot(len(mean_attributions_channels) - top_k_channels_list, accs)
plt.plot(len(mean_attributions_channels) - top_k_channels_list, kls)
plt.legend(["Accuracy", "KL Divergence"])
plt.xlabel("Number of channels pruned (starting from lowest attributions)")
plt.ylabel("Value")
plt.show()

# %% Prune specific channels

# Mean of every 4-D cached activation over dims 0, 2 and 3; singleton dims are kept so it broadcasts back.
mean_activations = {k: v.mean(dim=(0, 2, 3), keepdims=True) for k, v in clean_cache.items() if len(v.shape) == 4}

# Short names ("L<layer><hook letter><channel>") of the channels to replace with their mean activation.
prune_channels_list = ["L0H31"]
# Group the channels by hook. A dict comprehension would keep only the last channel listed for each hook.
prune_channels = {}
for short_name in prune_channels_list:
    prune_channels.setdefault(unshorten_key(short_name), []).append(int(short_name[3:]))


def prune_hook(inputs: th.Tensor, hook: HookPoint):
    """Overwrite the channels listed in `prune_channels` for this hook, in place, with their mean activation."""
    assert hook.name is not None
    hook_key = hook.name.rsplit(".", 2)[0]
    channels = prune_channels[hook_key]
    inputs[:, channels] = mean_activations[hook.name][:, channels]
    return inputs


fwd_hooks = [(k, prune_hook) for k in clean_cache.keys() if k.rsplit(".", 2)[0] in prune_channels]
print("Fwd hooks", len(fwd_hooks))
# The hook context receives the same positional inputs as the model call, as in `add_attributions`; unpacking
# `clean_obs` itself would pass one tensor per time step instead. No gradients are needed here.
with th.no_grad():
    with model.input_dependent_hooks_context(clean_obs, zero_carry, eps_start, fwd_hooks=fwd_hooks, bwd_hooks=[]):
        with model.hooks(fwd_hooks=fwd_hooks, bwd_hooks=[]):
            prune_actions, _, prune_log_probs, _ = model(clean_obs, zero_carry, eps_start, deterministic=True)
# compute kl divergence using the log probs
kl_div = th.nn.functional.kl_div(prune_log_probs, clean_log_probs, log_target=True)
# accuracy of prune_actions: fraction of actions identical to the clean run
accuracy = (prune_actions == clean_actions).float().mean().item()
print(f"Accuracy: {accuracy:.2f}, KL Divergence: {kl_div.item():.2f}")


# %%

# Names of the input groups of `conv_ih` and of the output groups of the convolutions. The helpers below treat
# each name as a block of 32 consecutive channels, located at 32 * (index of the name in its list).
inp_types, out_types = ["e", "lh", "ch"], ["i", "j", "f", "o"]
# Suffix of the cache hook from which each input group's activations are read.
inp_type_to_hook = {"e": "hook_layer_input", "lh": "hook_prev_layer_hidden", "ch": "hook_pool_project"}


def get_conv_weights(layer, out, inp, out_type="o", inp_type="lh", ih=True):
    """Slice the `conv_ih` (or, with `ih=False`, `conv_hh`) weights of one cell of the global `model`.

    Args:
        layer: index into `model.features_extractor.cell_list`.
        out: an `int` selects that channel inside the `out_type` block of 32 output channels; any other value
            (e.g. `None`) selects the whole block.
        inp: for `ih=True`, an `int` selects that channel inside the `inp_type` block of 32 input channels and
            any other value selects the whole block. For `ih=False`, an `int` is used as given and any other
            value selects all input channels.
        out_type: one of `out_types`.
        inp_type: one of `inp_types`; only validated and used when `ih` is True.
        ih: choose `conv_ih` (True) or `conv_hh` (False).

    Returns:
        `weight.data[out_selection, inp_selection]` of the chosen convolution.
    """
    assert out_type in out_types
    if ih:
        assert inp_type in inp_types
        inp_start = 32 * inp_types.index(inp_type)
        inp_sel = inp + inp_start if isinstance(inp, int) else slice(inp_start, inp_start + 32)
    else:
        inp_sel = inp if isinstance(inp, int) else slice(None)

    out_start = 32 * out_types.index(out_type)
    out_sel = out + out_start if isinstance(out, int) else slice(out_start, out_start + 32)

    cell = model.features_extractor.cell_list[layer]
    conv = cell.conv_ih if ih else cell.conv_hh
    return conv.weight.data[out_sel, inp_sel]


def top_weights(layer, out, out_type="o", inp_type="lh", ih=True):
    """Rank input channels by the largest absolute kernel weight feeding output channel `out`.

    The weights come from `get_conv_weights(layer, out, None, out_type, inp_type, ih)`. The returned indices
    are positions inside that slice, ordered from the largest to the smallest peak absolute weight.
    """
    kernels = get_conv_weights(layer, out, None, out_type, inp_type, ih)
    magnitudes = kernels.abs()
    peak_per_channel = magnitudes.max(dim=1).values.max(dim=1).values
    return peak_per_channel.argsort(descending=True)


def top_weights_out(layer, inp, out_type="o", inp_type="lh"):
    """Rank the `out_type` output channels of the next cell's `conv_ih` by how strongly they read channel `inp`.

    The next cell is `(layer + 1) % 3`. `inp` is offset into the `inp_type` block of 32 input channels, and the
    ranking uses the largest absolute kernel weight, in descending order. The returned indices are relative to
    the start of the `out_type` block.
    """
    in_channel = 32 * inp_types.index(inp_type) + inp
    out_start = 32 * out_types.index(out_type)
    next_conv = model.features_extractor.cell_list[(layer + 1) % 3].conv_ih
    kernels = next_conv.weight.data[out_start : out_start + 32, in_channel]
    peak_per_channel = kernels.abs().max(dim=1).values.max(dim=1).values
    return peak_per_channel.argsort(descending=True)


def get_top_conv_inputs(layer, out_type="o", inp_type="lh", ih=True, cache=clean_cache):
    """Compute the separate contribution of every input channel to every output channel of a convolution.

    Each (output channel, input channel) 3x3 kernel is applied on its own to the cached input activations via
    a grouped convolution. For every output channel, the input channels are then ordered by their largest
    absolute contribution over all samples and spatial positions.

    Args:
        layer: index of the cell whose convolution is analysed.
        out_type: output block passed to `get_conv_weights`.
        inp_type: input block; with `ih=True` it also selects the cache hook through `inp_type_to_hook`.
        ih: analyse `conv_ih` (True) or `conv_hh` (False, inputs read from `hook_input_h`).
        cache: activation cache keyed by hook name. The default is the `clean_cache` object that exists when
            this function is defined.

    Returns:
        `(conv_output, top_channels)`. `conv_output` has shape (samples, out channels, in channels, height,
        width) with the input-channel axis reordered by descending peak absolute contribution. `top_channels`
        has shape (out channels, in channels) and holds the original input-channel index at each position.
    """
    conv_weights = get_conv_weights(layer, None, None, out_type, inp_type, ih)
    n_out, n_in = conv_weights.shape[:2]
    if ih:
        hook_type = inp_type_to_hook[inp_type]
        inputs = cache[f"features_extractor.cell_list.{layer}.{hook_type}"]
    else:
        inputs = cache[f"features_extractor.cell_list.{layer}.hook_input_h"]
    # One single-channel kernel per (output channel, input channel) pair.
    conv_kernels = conv_weights.reshape(n_out * n_in, 1, 3, 3)
    if len(inputs.shape) == 5:
        # Merge the two leading dimensions into a single sample dimension.
        inputs = inputs.reshape(inputs.shape[0] * inputs.shape[1], *inputs.shape[2:])
    # `th.tensor(existing_tensor)` warns and makes an extra copy. Detaching is enough because `repeat` already
    # returns new memory; `as_tensor` still accepts array inputs.
    inputs = th.as_tensor(inputs).detach().repeat(1, n_out, 1, 1)
    conv_output = th.nn.functional.conv2d(inputs, conv_kernels, padding=1, groups=len(conv_kernels))
    conv_output = conv_output.reshape(conv_output.shape[0], n_out, n_in, *conv_output.shape[-2:])
    # Peak absolute contribution of each input channel to each output channel.
    values = conv_output.abs().max(dim=0).values.max(dim=2).values.max(dim=2).values
    top_channels = values.argsort(dim=1, descending=True)
    conv_output = th.take_along_dim(conv_output, top_channels[None, ..., None, None], dim=2)
    return conv_output, top_channels


# clean_cache = join_cache_across_steps([clean_cache])
def _top_conv_inputs_for_layer(layer):
    """Collect `get_top_conv_inputs` results of one layer for every output type, input type and convolution."""
    per_out_type = []
    for out_type in out_types:
        per_inp_type = []
        for inp_type in inp_types:
            hh_result = get_top_conv_inputs(layer, out_type=out_type, inp_type=inp_type, ih=False, cache=clean_cache)
            ih_result = get_top_conv_inputs(layer, out_type=out_type, inp_type=inp_type, ih=True, cache=clean_cache)
            per_inp_type.append([hh_result, ih_result])
        per_out_type.append(per_inp_type)
    return per_out_type


# Indexed as important_channels[layer][out_type index][inp_type index][0 for conv_hh, 1 for conv_ih].
important_channels = [_top_conv_inputs_for_layer(layer) for layer in range(3)]

# %%

# Number of input channels (out of those from index 32 onwards) kept per output channel of every `conv_ih`.
dependencies = np.arange(64, 1, -10)
accs, kls = [], []

try:
    for num_dependencies in dependencies:
        new_sd = deepcopy(orig_state_dict)

        for k, v in new_sd.items():
            if "features_extractor.cell_list" in k and "weight" in k and "conv_ih" in k:
                pruned_part = v.data[:, 32:]
                # Strength of each (output channel, input channel) connection: largest absolute kernel weight.
                strength = pruned_part.abs().max(dim=2).values.max(dim=2).values
                # Ascending order: everything except the `num_dependencies` strongest inputs is zeroed.
                indices_to_zero = strength.argsort(axis=1)[:, :-num_dependencies]

                keep = th.ones_like(strength, dtype=th.bool)
                keep.scatter_(1, indices_to_zero, False)
                # Zero out the pruned connections in the copied state dict.
                v.data[:, 32:] = pruned_part * keep[:, :, None, None]
        model.load_state_dict(new_sd)
        with th.no_grad():
            prune_actions, _, prune_log_probs, _ = model(clean_obs, zero_carry, eps_start, deterministic=True)
        # compute kl divergence using the log probs
        kl_div = th.nn.functional.kl_div(prune_log_probs, clean_log_probs, log_target=True)
        # accuracy of prune_actions: fraction of actions identical to the clean run
        accuracy = (prune_actions == clean_actions).float().mean().item()
        print(f"Channels kept: {num_dependencies}, Accuracy: {accuracy:.2f}, KL Divergence: {kl_div.item():.2f}")
        accs.append(accuracy)
        kls.append(kl_div.item())
finally:
    # Do not leave the pruned weights in `model` for the cells that follow.
    restore_model()

plt.plot(dependencies, accs)
plt.plot(dependencies, kls)
plt.legend(["Accuracy", "KL Divergence"])
plt.xlabel(
    "Number of input channels kept for (prev_h, pool) -> conv_ih connection\n (pruning channels with lowest abs weights)"
)
plt.ylabel("Value")
plt.show()


# %%

# NOTE(review): this cell repeats the weight-pruning sweep of the preceding cell.
# Number of input channels (out of those from index 32 onwards) kept per output channel of every `conv_ih`.
dependencies = np.arange(64, 1, -10)
accs, kls = [], []

try:
    for num_dependencies in dependencies:
        new_sd = deepcopy(orig_state_dict)
        for k, v in new_sd.items():
            if "features_extractor.cell_list" in k and "weight" in k and "conv_ih" in k:
                pruned_part = v.data[:, 32:]
                # Strength of each (output channel, input channel) connection: largest absolute kernel weight.
                strength = pruned_part.abs().max(dim=2).values.max(dim=2).values
                # Ascending order: everything except the `num_dependencies` strongest inputs is zeroed.
                indices_to_zero = strength.argsort(axis=1)[:, :-num_dependencies]

                keep = th.ones_like(strength, dtype=th.bool)
                keep.scatter_(1, indices_to_zero, False)
                # Zero out the pruned connections in the copied state dict.
                v.data[:, 32:] = pruned_part * keep[:, :, None, None]
        model.load_state_dict(new_sd)
        with th.no_grad():
            prune_actions, _, prune_log_probs, _ = model(clean_obs, zero_carry, eps_start, deterministic=True)
        # compute kl divergence using the log probs
        kl_div = th.nn.functional.kl_div(prune_log_probs, clean_log_probs, log_target=True)
        # accuracy of prune_actions: fraction of actions identical to the clean run
        accuracy = (prune_actions == clean_actions).float().mean().item()
        print(f"Channels kept: {num_dependencies}, Accuracy: {accuracy:.2f}, KL Divergence: {kl_div.item():.2f}")
        accs.append(accuracy)
        kls.append(kl_div.item())
finally:
    # Do not leave the pruned weights in `model` once the sweep is over.
    restore_model()

plt.plot(dependencies, accs)
plt.plot(dependencies, kls)
plt.legend(["Accuracy", "KL Divergence"])
plt.xlabel(
    "Number of input channels kept for (prev_h, pool) -> conv_ih connection\n (pruning channels with lowest abs weights)"
)
plt.ylabel("Value")
plt.show()
