"""Scatter plot of activations of channels in a group against each other.

Also computes the correlation matrix of the activations of channels in a group.
"""
# %%
import dataclasses
import re
from copy import deepcopy
from pathlib import Path

import numpy as np
import plotly.express as px
import plotly.io as pio
import torch as th
from cleanba.environments import BoxobanConfig
from plotly.subplots import make_subplots

from learned_planners import BOXOBAN_CACHE
from learned_planners.interp.channel_group import get_group_channels
from learned_planners.interp.collect_dataset import join_cache_across_steps
from learned_planners.interp.offset_fns import INV_CHANNEL_OFFSET_FNS
from learned_planners.interp.plot import plotly_feature_vis
from learned_planners.interp.train_probes import TrainOn
from learned_planners.interp.utils import load_jax_model_to_torch, load_probe
from learned_planners.notebooks.emacs_plotly_render import set_plotly_renderer
from learned_planners.policies import download_policy_from_huggingface

# Configure how plotly figures are displayed in this session.
# NOTE(review): the project helper is asked for the "emacs" renderer, and the very next
# line sets plotly's default renderer to "notebook". Presumably the second assignment
# wins for `fig.show()` — confirm which renderer is actually intended.
set_plotly_renderer("emacs")
pio.renderers.default = "notebook"
# Print tensors in fixed-point notation with two decimal places.
th.set_printoptions(sci_mode=False, precision=2)


# %%
# Checkpoint identifier passed to `download_policy_from_huggingface`; the returned value
# is what `load_jax_model_to_torch` receives below.
MODEL_PATH_IN_REPO = "drc33/bkynosqi/cp_2002944000"  # DRC(3, 3) 2B checkpoint
MODEL_PATH = download_policy_from_huggingface(MODEL_PATH_IN_REPO)

# NOTE(review): this absolute path is not referenced anywhere else in the visible part of
# the file; the probe used below is loaded from a different, relative path.
boxes_direction_probe_file = Path(
    "/training/TrainProbeConfig/05-probe-boxes-future-direction/wandb/run-20240813_184417-vb6474rg/local-files/probe_l-all_x-all_y-all_c-all.pkl"
)

# Base environment configuration. Later cells derive variants from it with
# `dataclasses.replace` (only `num_envs` is overridden there).
boxo_cfg = BoxobanConfig(
    cache_path=BOXOBAN_CACHE,
    num_envs=2,
    max_episode_steps=120,
    min_episode_steps=120,
    asynchronous=False,
    tinyworld_obs=True,
    # split=None,
    # difficulty="hard",
    split="valid",
    difficulty="unfiltered",
)
model_cfg, model = load_jax_model_to_torch(MODEL_PATH, boxo_cfg)

# Deep-copied snapshot of the freshly loaded weights, used by `restore_model` to undo
# any later in-place changes to `model`.
orig_state_dict = deepcopy(model.state_dict())


def restore_model() -> None:
    """Load the module-level ``orig_state_dict`` snapshot back into the global ``model``."""
    model.load_state_dict(orig_state_dict)


# `load_probe` returns a pair; only the first element is kept here.
# NOTE(review): neither `probe` nor `probe_info` is used in the visible part of this file.
probe, _ = load_probe("probes/best/boxes_future_direction_map_l-all.pkl")
probe_info = TrainOn(dataset_name="boxes_future_direction_map")
# %%


def run_policy_reset(num_steps, envs, policy):
    """Reset ``envs`` and roll ``policy`` out, collecting the observations seen.

    The rollout takes at most ``num_steps - 1`` environment steps, so together with
    the reset observation up to ``num_steps`` observations are recorded. It stops
    early once every environment has reported ``term`` or ``trunc`` at least once.

    Only environments whose length counter equals ``num_steps - 1`` are kept in the
    returned observations, i.e. those that did not terminate before the final step.
    If the rollout stops early, no counter can reach that value and the returned
    observation tensor has zero environments.

    The whole rollout runs under ``th.no_grad()``: the actions are only used to
    step the environments, so no autograd graph needs to be kept alive through the
    recurrent carry.

    Args:
        num_steps: Number of observations to collect per environment (reset
            observation included).
        envs: Vectorised environment exposing ``num_envs``, ``reset()`` and
            ``step(actions)``, where ``step`` returns
            ``(obs, reward, term, trunc, info)``.
        policy: Recurrent policy exposing ``recurrent_initial_state(n)`` and
            callable as ``policy(obs, carry, episode_starts)``, returning a 4-tuple
            whose first element is the action and last element is the new carry.

    Returns:
        A pair ``(obs, episode_lengths)``. ``obs`` is the stacked observations with
        time on axis 0 and the kept environments on axis 1. ``episode_lengths`` is
        an int32 tensor with one entry for every environment (not only the kept
        ones): the number of steps taken up to and including its first ``term``.
    """
    new_obs, _ = envs.reset()
    obs = [new_obs]
    carry = policy.recurrent_initial_state(envs.num_envs)
    # The policy is never told about episode starts during the rollout.
    all_false = th.zeros(envs.num_envs, dtype=th.bool)
    eps_done = np.zeros(envs.num_envs, dtype=bool)
    eps_solved = np.zeros(envs.num_envs, dtype=bool)
    episode_lengths = np.zeros(envs.num_envs, dtype=np.int32)
    with th.no_grad():
        for _ in range(num_steps - 1):
            action, _value, _extra, carry = policy(th.as_tensor(new_obs), carry, all_false)
            new_obs, _, term, trunc, _ = envs.step(action.detach().cpu().numpy())
            obs.append(new_obs)
            eps_done |= term | trunc
            # Count this step for every env that had not terminated before it, so the
            # terminating step itself is included in the length.
            episode_lengths[~eps_solved] += 1
            eps_solved |= term
            if np.all(eps_done):
                break
    episode_lengths = th.as_tensor(episode_lengths)
    # Keep only the environments that were still unterminated going into the last step.
    obs = th.as_tensor(np.stack(obs))[:, episode_lengths == num_steps - 1]
    print("Number of envs:", obs.shape[1])
    return obs, episode_lengths


# Number of observations per rollout (passed as `num_steps` to `run_policy_reset`).
seq_len = 30

# Same Boxoban configuration as `boxo_cfg`, but with 512 parallel environments.
envs = dataclasses.replace(boxo_cfg, num_envs=512).make()


# %%
# Roll the policy out to collect observations, then re-run the model over the stored
# sequence while caching activations.
clean_obs, clean_length = run_policy_reset(seq_len, envs, model)
# `run_policy_reset` filters environments, so axis 1 of `clean_obs` is the number kept.
num_eps = clean_obs.shape[1]
zero_carry = model.recurrent_initial_state(num_eps)
# Flag only the first timestep of every sequence as an episode start.
eps_start = th.zeros((seq_len, num_eps), dtype=th.bool)
eps_start[0, :] = True

(clean_actions, clean_values, clean_log_probs, _), clean_cache = model.run_with_cache(clean_obs, zero_carry, eps_start)


# %% SINGLE ENV

# Rebind `envs` to a single-environment variant of `boxo_cfg`.
# NOTE(review): `envs` is not used again in the visible part of this file.
# envs = dataclasses.replace(boxo_cfg, num_envs=1, difficulty="unfiltered").make()
envs = dataclasses.replace(boxo_cfg, num_envs=1).make()

# %% Visualize all layer and channels
# Select one entry along the first axis of every cached tensor and join the per-step
# cache entries for it.
env_idx = 1
single_env_cache = {name: acts[env_idx] for name, acts in clean_cache.items()}
toy_cache = join_cache_across_steps([single_env_cache])
# Repeat each observation three times along the time axis.
# presumably one copy per internal tick of the recurrent cell — TODO confirm
toy_obs = np.repeat(clean_obs[:, env_idx], 3, axis=0)

# Show one figure per cache entry whose name ends in "hook_h".
for hook_name, hook_acts in toy_cache.items():
    if re.match("^.*hook_([h])$", hook_name) is None:
        continue
    fig = plotly_feature_vis(hook_acts, toy_obs, hook_name)
    fig.update_layout(height=800)
    fig.show()

# %%

grouped_channels = get_group_channels("box")
start_pos = 0
last_tick = False

# Internal ticks to read per step: only tick 2 when `last_tick` is set, otherwise 0..2.
internal_ticks = [2] if last_tick else range(3)

# activations_by_group[g][c] is the flattened activation vector of the c-th
# (layer, channel) pair in group g, gathered over all selected steps and ticks.
activations_by_group = []
for group in grouped_channels:
    group_acts = []
    for layer, channel in group:
        per_tick = [
            clean_cache[f"features_extractor.cell_list.{layer}.hook_h.{pos}.{int_pos}"][:, channel]
            for pos in range(start_pos, seq_len)
            for int_pos in internal_ticks
        ]
        stacked = th.stack(per_tick, dim=0)
        # Apply the per-channel inverse offset function before comparing channels.
        acts = INV_CHANNEL_OFFSET_FNS[layer][channel](stacked, last_dim_grid=True)
        group_acts.append(acts.flatten().numpy())
    activations_by_group.append(group_acts)

# Pick one group and compute the absolute correlation between its channels.
# NOTE(review): assumes the groups are ordered UP, DOWN, LEFT, RIGHT — confirm against
# `get_group_channels`.
move_dir = 1
move_dir_title = ("UP", "DOWN", "LEFT", "RIGHT")[move_dir]
selected_group = grouped_channels[move_dir]
grp_size = len(selected_group)
down_activations = np.array(activations_by_group[move_dir])
correlation_matrix = np.abs(np.corrcoef(down_activations))

channel_titles = [f"L{layer}H{channel}" for layer, channel in selected_group]

# %%
# Zero everything strictly above the diagonal (in place) so each pair appears once.
n_channels = correlation_matrix.shape[0]
upper_rows, upper_cols = np.triu_indices(n_channels, k=1)
correlation_matrix[upper_rows, upper_cols] = 0

# Dropping the first row and the last column removes the all-zero row and column and
# moves the diagonal out of view: rows are channels 1.., columns are channels ..-2.
lower_block = np.abs(correlation_matrix[1:, :-1])
fig = px.imshow(lower_block, x=channel_titles[:-1], y=channel_titles[1:])
fig.show()

# %% Scatter plot of activations

# Lower-triangular grid of pairwise scatter plots for the selected group.
# Row r (1-based) is titled with channel r and column c with channel c - 1, which
# matches the layout of the correlation heatmap.
n_panels = grp_size - 1
fig = make_subplots(rows=n_panels, cols=n_panels, row_titles=channel_titles[1:], column_titles=channel_titles[:-1])
fig.update_layout(title=move_dir_title + " Group")

# Plot a fixed-seed random subset of at most 10000 points per channel.
rng = np.random.default_rng(42)
n_points = down_activations.shape[1]
sample_indices = rng.choice(n_points, min(n_points, 10000), replace=False)
down_activations_sampled = down_activations[:, sample_indices]

for i in range(1, grp_size):
    for j in range(i):
        # The x-axis carries the column's channel (j) and the y-axis the row's
        # channel (i), so each panel agrees with its column and row titles.
        fig.add_trace(
            px.scatter(x=down_activations_sampled[j], y=down_activations_sampled[i]).data[0],
            row=i,
            col=j + 1,
        )
fig.show()

# %% Visualize group activations
# Select one entry along the first axis of every cached tensor and join the per-step
# cache entries for it.
env_idx = 7
single_env_cache = {name: acts[env_idx] for name, acts in clean_cache.items()}
toy_cache = join_cache_across_steps([single_env_cache])

# Number of internal ticks stored per step; this must match how the activations
# were gathered into `down_activations`.
if last_tick:
    reps = 1
else:
    reps = 3
rep_seq_len = (seq_len - start_pos) * reps
toy_obs = np.repeat(clean_obs[start_pos:seq_len, env_idx], reps, axis=0)

# Undo the flattening: (channels, ticks, envs, 10, 10). Then select the chosen env and
# put the tick axis first.
# NOTE(review): the 10x10 grid size is hard-coded here.
n_group_channels = down_activations.shape[0]
per_env = down_activations.reshape(n_group_channels, rep_seq_len, num_eps, 10, 10)[:, :, env_idx]
down_activations_reshaped = per_env.transpose(1, 0, 2, 3)

fig = plotly_feature_vis(down_activations_reshaped, toy_obs, feature_labels=channel_titles)
fig.update_layout(title=move_dir_title + " Group")
fig.show()
