from nesim.vis.weights_magnitude import WeightsMagnitudeViewer
from nesim.vis.video import generate_video
import torch.nn as nn
import os

import os
import argparse

from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.experiments.mnist import MNISTHyperParams, MnistTrainingConfig
from utils import get_run_name

# Output directory for the rendered video.
videos_save_dir = "./videos/pca"

parser = argparse.ArgumentParser(
    description=(
        "Renders a video of the weight magnitudes of one layer of a simple "
        "3 layer MLP, using checkpoints saved while training on MNIST"
    )
)
parser.add_argument(
    "--nesim-config", type=str, required=True, help="Path to the nesim config json file"
)
parser.add_argument(
    "--hidden-size", type=int, required=True, help="hidden size of the model"
)
parser.add_argument(
    "--nesim-apply-after-n-steps",
    type=int,
    required=True,
    help=(
        "apply nesim every N training steps (must match the training run, since "
        "it is part of the run name); also used as the stride between "
        "checkpoints rendered as video frames"
    ),
)
args = parser.parse_args()

# The value is used as a range() step when selecting checkpoints: a zero step
# raises ValueError and a negative step yields no frames at all, so reject
# both up front with a proper usage message.
if args.nesim_apply_after_n_steps <= 0:
    parser.error("--nesim-apply-after-n-steps must be a positive integer")

# The run name is derived from the same CLI values used at training time, so
# it selects which run's checkpoint directory is read.
run_name = get_run_name(
    nesim_config=args.nesim_config,
    hidden_size=args.hidden_size,
    nesim_apply_after_n_steps=args.nesim_apply_after_n_steps,
)

# Checkpoints for this configuration live in a run-specific subdirectory.
checkpoint_dir = f"./checkpoints/mnist/{run_name}"

# Hyperparameters; presumably these mirror the training run that produced the
# checkpoints (TODO confirm). The nesim stride comes from the CLI flag.
hyperparams = MNISTHyperParams(
    apply_nesim_every_n_steps=args.nesim_apply_after_n_steps,
    batch_size=64,
    lr=1e-3,
    save_checkpoint_every_n_steps=20,
    weight_decay=1e-5,
)
nesim_config = NesimConfig.from_json(args.nesim_config)

# Top-level experiment configuration. Within this script only `hidden_size`
# is read back (to size the model); the other fields are carried along as-is.
experiment_config = MnistTrainingConfig(
    checkpoint_dir=checkpoint_dir,
    data_dir="./data",
    hidden_size=args.hidden_size,
    hyperparams=hyperparams,
    max_epochs=10,
    nesim_config=nesim_config,
    wandb_log=False,
)

# MLP whose weights are loaded from the checkpoints. Module indices matter:
# the viewer selects a layer by its index name, and the saved state_dict keys
# presumably use the same indices (TODO confirm). H = hidden_size.
#   0 Flatten      1 Linear(784 -> H)  2 ReLU  3 Dropout(0.1)
#   4 Linear(H -> H)                   5 ReLU  6 Dropout(0.1)
#   7 Linear(H -> 10)
model = nn.Sequential(
    *(
        nn.Flatten(),
        nn.Linear(28 * 28, experiment_config.hidden_size),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(experiment_config.hidden_size, experiment_config.hidden_size),
        nn.ReLU(),
        nn.Dropout(0.1),
        nn.Linear(experiment_config.hidden_size, 10),
    )
)


# One checkpoint per video frame, every `apply_nesim_every_n_steps` steps.
# NOTE(review): 9340 is a hard-coded exclusive upper bound on the train step
# index; presumably close to the total number of steps in the training run —
# confirm. Also, `save_checkpoint_every_n_steps` is 20; if checkpoints are only
# written at that interval, a stride that is not a multiple of 20 will point
# at files that do not exist — verify against the training loop.
checkpoint_filenames = [
    os.path.join(checkpoint_dir, "all", f"train_step_idx_{train_step_idx}.pth")
    for train_step_idx in range(0, 9340, hyperparams.apply_nesim_every_n_steps)
]


# Despite the variable name, this renders weight magnitudes, not PCA.
# layer_name "4" is the hidden -> hidden nn.Linear in `model`.
# NOTE(review): device is fixed to "cuda:0", so this requires a GPU.
weights_pca = WeightsMagnitudeViewer(
    model=model,
    checkpoint_filenames=checkpoint_filenames,
    layer_name="4",
    device="cuda:0",
)
print(f"writing: {len(weights_pca)} frames")

# The output directory is not created anywhere else; make sure it exists so
# writing the video does not fail on a fresh checkout.
os.makedirs(videos_save_dir, exist_ok=True)
generate_video(
    list_of_pil_images=[weights_pca[idx] for idx in range(len(weights_pca))],
    framerate=60,
    filename=os.path.join(videos_save_dir, "mnist_magnitude.mp4"),
    size=(512, 512),
)
