# %%
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import torch as th

from learned_planners.interp.train_probes import ActivationsDataset, DatasetStore, set_seed  # noqa: F401  # pyright: ignore
from learned_planners.interp.utils import save_video

# %%


# Per-hook eigendecomposition of the channel-by-channel activation second-moment
# matrix, cached to disk so the (expensive) dataset pass only runs once.
SAVE_PATH = Path("/training/activations_dataset/hard/0_think_step") / "pca.pt"

if SAVE_PATH.exists():
    # The cache stores torch tensors (see th.save below). Convert them back to
    # numpy so `all_pca` holds the same value types whether it was loaded here or
    # freshly computed in the else-branch; downstream code uses numpy ops on it.
    all_pca = {
        k: tuple(t.numpy() for t in v)
        for k, v in th.load(SAVE_PATH, weights_only=True).items()
    }
else:
    keys = [".*hook_i$", ".*hook_j$", ".*hook_f$", ".*hook_o$"] + [".*hook_h$", ".*hook_c$"]
    acts_ds = ActivationsDataset(
        Path("/training/activations_dataset/hard/0_think_step"),
        labels_type="reward",
        keys=keys,
        num_data_points=int(1e5),
        fetch_all_boxing_data_points=True,
        gamma_value=0.99,
        balance_classes=False,
        skip_first_n=0,
        skip_walls=False,
        multioutput=False,
    )

    # all_pca[key] = (eigenvalues, eigenvectors) as numpy arrays, eigenvalues in
    # ascending order (np.linalg.eigh convention); eigenvectors are the columns.
    all_pca = {}
    for key in acts_ds.data[0].keys():
        print("Dealing with", key)
        # Stacked activations for this hook; the einsum below requires 4-D
        # (n, c, h, w). The name is historical: it holds whichever hook `key` is.
        hook_h = np.asarray([d[key] for d in acts_ds.data])
        # Channel x channel second moment averaged over every (sample, h, w)
        # position. NOTE(review): activations are not mean-centred, so this is
        # an uncentred "PCA" — confirm that is intended.
        cov = np.einsum("nchw,ndhw->cd", hook_h, hook_h) / (np.size(hook_h) / hook_h.shape[1])
        vals, vecs = np.linalg.eigh(cov)
        all_pca[key] = vals, vecs

    th.save({k: tuple(map(th.from_numpy, v)) for k, v in all_pca.items()}, SAVE_PATH)


# pca = sklearn.decomposition.PCA(n_components=hook_h.shape[1], copy=True, svd_solver="covariance_eigh")
# pca.fit(hook_h)


# %%
# Facet-plot one hook's activations for a single stored level, with the
# observation shown as an extra facet in front of the activation channels.
key = "features_extractor.cell_list.2.hook_f"
level = DatasetStore.load(SAVE_PATH.parent / "idx_1455.pkl")
# To view activations projected onto the PCA eigenbasis instead, use:
# projected = np.einsum("nchw,cd->ndhw", level.model_cache[key], all_pca[key][1])
projected = level.model_cache[key]

time_slice = slice(None, None, None)
cmap = plt.get_cmap("viridis")
projected = projected[time_slice]
# Min-max normalise with one shared scale over all steps and channels. Guard the
# constant-activation case so it renders flat instead of producing NaNs.
lo, hi = projected.min(), projected.max()
span = hi - lo
normed = (projected - lo) / (span if span != 0 else 1)

# level.obs is channels-first; move channels last and add a facet axis. Each
# observation is repeated 3 times along time — presumably model_cache holds 3
# activation frames per environment step (TODO confirm).
repeated_obs = np.repeat(np.transpose(level.obs, (0, 2, 3, 1))[:, None, :, :, :], 3, axis=0)[time_slice, ...]
# Facet 0 is the observation; facets 1.. are the colour-mapped channels (RGB, 0-255).
to_plot = np.concatenate([repeated_obs[: len(normed)], cmap(normed)[..., :3] * 255], axis=1)
print(to_plot.shape)


px.imshow(
    to_plot,
    facet_col=1,
    animation_frame=0,
    facet_col_wrap=8,
    binary_string=True,
).show()

# %%
fig, axes = plt.subplots(4, 4, figsize=(10, 10))

# level.obs is a stack of channels-first frames (4-D), but imshow only accepts a
# single 2-D/3-D image, so show the first frame converted to (H, W, C).
ax = axes[0, 0]
ax.imshow(np.transpose(level.obs[0], (1, 2, 0)))


# Render the first `timesteps` observations alongside their activations. The
# flat activation stack is regrouped as (env_step, 3, ...), assuming 3 activation
# frames per observation (TODO confirm), and only the first 10 channels are kept.
# The reshape raises if `projected` was cut by a non-trivial time_slice.
timesteps = 4
anim = save_video(
    "plot/interp/outputs/out1.mp4",
    level.obs[:timesteps],
    np.reshape(projected, (level.obs.shape[0], 3, *projected.shape[1:]))[:timesteps, :, :10],
    overlapped=False,
    show_internal_steps_until=10,
    # sae_feature_offset=0,
)
