import json

import torch

from adversarial_superposition.constants import DEVICE, MODEL_DIR, RESULTS_DIR
from adversarial_superposition.modulo.utils.helpers import (
    plot_eigenvalues,
    plot_pca_components,
)
from adversarial_superposition.modulo.utils.utils import Config, get_model

experiment_key = "acc55bfa"

# JSON config saved next to this experiment's results; built once and reused
# for both opening the file and reporting which config is in use.
config_path = RESULTS_DIR / f"toy_models/{experiment_key}/config.json"

# JSON text is UTF-8 by specification, so pass the encoding explicitly rather
# than depending on the platform's locale default.
with open(config_path, "r", encoding="utf-8") as f:
    # NOTE(review): from_dict is called on a fresh Config instance and its
    # return value is used as the config -- confirm it returns the populated
    # object rather than mutating in place and returning None.
    config = Config().from_dict(json.load(f))

print(f"Using the model config from: {config_path}")

model = get_model(config)

# File holding this run's saved model snapshots; a single entry is picked out
# below by key.
checkpoint_path = (
    MODEL_DIR / f"toy_models/{experiment_key}/last_run_saved_model_checkpoints.pt"
)

# Key of the snapshot to analyse. NOTE(review): presumably a training
# step/epoch index -- confirm against the script that saved the checkpoints.
checkpoint_step = 9950

# NOTE: torch.load unpickles the file, so only point it at checkpoints you
# trust. map_location remaps the stored tensors onto DEVICE while loading.
all_checkpoints = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(all_checkpoints[checkpoint_step])
del all_checkpoints  # drop the reference to the other snapshots

# Number of leading input columns of the first layer to analyse.
# NOTE(review): presumably the modulus of the toy modular task; it is
# hard-coded here, so confirm it matches the loaded config.
p = 113

# First layer's weight matrix, detached from autograd and copied to CPU.
first_layer_weight = model.layers[0].weight.detach().cpu()

# Keep only the first p columns, then transpose so each row corresponds to one
# of those columns. NOTE(review): presumably the input block for the first
# operand -- confirm against get_model's input layout.
weights_p1 = first_layer_weight[:, :p].T

# Fit/plot PCA components (0, 1) of weights_p1 via the project helper, then
# plot the eigenvalues of the returned PCA object.
pca, pca_components = plot_pca_components(weights_p1, components=(0, 1))
plot_eigenvalues(pca)
