# %%
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import torch as th

from learned_planners.interp.plot import save_video
from learned_planners.interp.train_probes import ActivationsDataset, DatasetStore, set_seed  # noqa: F401  # pyright: ignore

# %%


# Root of the cached activations dataset. "pca.pt" inside it holds a per-key PCA
# cache produced by the (commented-out) recipe at the end of this cell.
DATASET_DIR = Path("/training/activations_dataset/hard/0_think_step")
SAVE_PATH = DATASET_DIR / "pca.pt"

# Set to True to also load the cached per-key PCA into ``all_pca``. The
# activations dataset is loaded unconditionally because the cells below read
# ``acts_ds`` to fit a joint PCA; gating it behind the cache would leave
# ``acts_ds`` undefined.
LOAD_CACHED_PCA = False
if LOAD_CACHED_PCA and SAVE_PATH.exists():
    all_pca = th.load(SAVE_PATH, weights_only=True)

# Regexes selecting which hooks to load. Only the ``hook_o`` activations are
# consumed by the PCA cells below; the other hooks are loaded but unused there.
keys = [".*hook_i$", ".*hook_j$", ".*hook_f$", ".*hook_o$", ".*hook_h$", ".*hook_c$"]
acts_ds = ActivationsDataset(
    DATASET_DIR,
    labels_type="reward",
    keys=keys,
    num_data_points=int(1e5),
    fetch_all_boxing_data_points=True,
    gamma_value=0.99,
    balance_classes=False,
    skip_first_n=0,
    skip_walls=False,
    multioutput=False,
)

# Recipe that produced the per-key PCA cache at SAVE_PATH, kept for reference:
# all_pca = {}
# for key in acts_ds.data[0].keys():
#     print("Dealing with", key)
#     hook_h = np.asarray([d[key] for d in acts_ds.data])
#     cov = np.einsum("nchw,ndhw->cd", hook_h, hook_h) / (np.size(hook_h) / hook_h.shape[1])
#     vals, vecs = np.linalg.eigh(cov)
#     all_pca[key] = vals, vecs
# th.save({k: tuple(map(th.from_numpy, v)) for k, v in all_pca.items()}, SAVE_PATH)


# pca = sklearn.decomposition.PCA(n_components=hook_h.shape[1], copy=True, svd_solver="covariance_eigh")
# pca.fit(hook_h)

# %%
# Joint PCA over the ``hook_o`` activations of every layer. The matching hooks'
# channels are stacked along axis 1, in the iteration order of the first data
# point's keys; ``project`` must stack a cache in the same order. The stacked
# data are standardized per channel before the channel covariance is
# eigendecomposed.
per_layer = []
for key in acts_ds.data[0].keys():
    if "hook_o" not in key:
        continue
    print("Dealing with", key)
    per_layer.append(np.asarray([d[key] for d in acts_ds.data]))

h_all_layers = np.concatenate(per_layer, axis=1)
del per_layer  # only the stacked copy is needed; free the per-layer arrays

# Per-channel statistics over samples and spatial positions (axes 0, 2, 3).
mean_h = np.mean(h_all_layers, axis=(0, 2, 3), keepdims=True)
std_h = np.std(h_all_layers, axis=(0, 2, 3), keepdims=True)
# A constant channel has zero std; dividing by 1 instead maps it to 0 rather
# than NaN, which would otherwise poison ``cov`` and ``eigh``.
std_h = np.where(std_h > 0, std_h, 1.0)
h_all_layers -= mean_h
h_all_layers /= std_h

# Channel covariance, averaged over every (sample, h, w) position.
n_positions = h_all_layers.size // h_all_layers.shape[1]
cov = np.einsum("nchw,ndhw->cd", h_all_layers, h_all_layers) / n_positions

# eigh returns ascending eigenvalues; flip so index 0 is the largest-variance
# component. Column i of ``vecs`` is the eigenvector for ``vals[i]``.
vals, vecs = np.linalg.eigh(cov)
vals = vals[::-1]
vecs = vecs[:, ::-1]

# %%
# Scree plot of the joint-PCA eigenvalues, largest first.
plt.plot(vals)
plt.xlabel("principal component index")
plt.ylabel("eigenvalue (variance of standardized data)")
plt.title("hook_o joint PCA spectrum")
plt.show()


# %%
def project(cache, mean=None, std=None, components=None, key_substring="hook_o"):
    """Project cached activations onto a PCA basis.

    Every entry of ``cache`` whose key contains ``key_substring`` is
    concatenated along axis 1, in the cache's own key order. That order must
    match the order used when the basis was fitted. The stacked array is then
    standardized with ``mean``/``std`` and projected onto the columns of
    ``components``. The input arrays are not modified.

    Args:
        cache: mapping from hook name to 4-D activation arrays with channels on
            axis 1 (presumably (N, C, H, W) -- TODO confirm for
            ``DatasetStore.model_cache``).
        mean: per-channel mean broadcastable against the stacked activations.
            Defaults to the module-level ``mean_h``.
        std: per-channel scale broadcastable against the stacked activations.
            Defaults to the module-level ``std_h``.
        components: (C_total, D) matrix whose columns are the basis vectors.
            Defaults to the module-level ``vecs``.
        key_substring: substring selecting which cache entries to stack.

    Returns:
        Array of shape (N, D, H, W) with the projection coefficients.

    Raises:
        ValueError: if no key of ``cache`` contains ``key_substring``.
    """
    # Resolve defaults lazily so the fitted globals are read at call time.
    if mean is None:
        mean = mean_h
    if std is None:
        std = std_h
    if components is None:
        components = vecs

    selected = [cache[key] for key in cache.keys() if key_substring in key]
    if not selected:
        raise ValueError(f"no cache entries contain {key_substring!r}")

    stacked = np.concatenate(selected, axis=1)
    # Out-of-place so the dtype of the cached arrays never restricts the result.
    standardized = (stacked - mean) / std
    return np.einsum("nchw,cd->ndhw", standardized, components)


# Project one saved level's activations onto the joint PCA basis.
level = DatasetStore.load(SAVE_PATH.parent / "idx_1455.pkl")
# Alternatives (single-key PCA from ``all_pca`` / raw activations):
# key = "features_extractor.cell_list.2.hook_f"
# projected = np.einsum("nchw,cd->ndhw", level.model_cache[key], all_pca[key][1])
# projected = level.model_cache[key]
projected = project(level.model_cache)

# Time steps to visualize; slice(None) keeps all of them.
time_slice = slice(None, None, None)
cmap = plt.get_cmap("viridis")
projected = projected[time_slice]

# Min-max normalize each component to [0, 1] over time and space so it can be
# colour-mapped. A constant component would give 0/0, so use a span of 1 there
# (it maps to 0 instead of NaN).
comp_min = projected.min(axis=(0, 2, 3), keepdims=True)
comp_span = projected.max(axis=(0, 2, 3), keepdims=True) - comp_min
comp_span = np.where(comp_span > 0, comp_span, 1.0)
normed = (projected - comp_min) / comp_span

# ``level.obs`` is channel-first; convert to channel-last and repeat each frame
# 3 times along time so it lines up with ``projected`` (presumably 3 internal
# steps per observation -- TODO confirm).
obs_hwc = np.transpose(level.obs, (0, 2, 3, 1))
repeated_obs = np.repeat(obs_hwc[:, None, :, :, :], 3, axis=0)[time_slice, ...]
# Facet 0 is the observation; facets 1.. are RGB renderings of each component.
to_plot = np.concatenate([repeated_obs[: len(normed)], cmap(normed)[..., :3] * 255], axis=1)
print(to_plot.shape)


px.imshow(
    to_plot[:, :16],  # observation + first 15 components
    facet_col=1,
    animation_frame=0,
    facet_col_wrap=8,
    binary_string=True,
).show()

# %%
# Figure grid for manual inspection; only the top-left panel is filled here.
fig, axes = plt.subplots(4, 4, figsize=(10, 10))

ax = axes[0, 0]
# ``level.obs`` is a 4-D stack of channel-first frames, but ``imshow`` only
# accepts a single (H, W[, 3|4]) image. Show the first frame, channel-last.
ax.imshow(np.transpose(level.obs[0], (1, 2, 0)))


# Write a video (presumably via ``save_video``'s own rendering) of the first
# ``timesteps`` observations alongside the first 10 principal components.
# ``projected`` is regrouped as (obs, 3 steps per obs, components, H, W).
timesteps = 4
anim = save_video(
    "plot/interp/outputs/out1.mp4",
    level.obs[:timesteps],
    np.reshape(projected, (level.obs.shape[0], 3, *projected.shape[1:]))[:timesteps, :, :10],
    overlapped=False,
    show_internal_steps_until=10,
    # sae_feature_offset=0,
)
