import os
import argparse
from nesim.losses.nesim_loss import (
    NesimConfig,
)
from nesim.experiments.mnist import get_untrained_model
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
import lightning as L

from nesim.lightning.mnist import MNISTLightningModule, MNISTHyperParams
from lightning.pytorch import seed_everything
from torchvision.datasets import MNIST
import torchvision.transforms as transforms
from nesim.utils.json_stuff import dict_to_json
import torch

# Command-line interface: which checkpoint to evaluate, how to load it, and
# where to write the JSON results.
parser = argparse.ArgumentParser(description="evaluate model on MNIST data")

parser.add_argument(
    "--checkpoint-path",
    type=str,
    # Required: every code path below reads it, and a missing value would
    # otherwise surface as an obscure TypeError inside torch.load.
    required=True,
    help="Path to model checkpoint file",
)

parser.add_argument(
    "--lightning-checkpoint",
    action="store_true",
    help="if true, will use load_and_filter_state_dict_keys() to load model",
)

parser.add_argument(
    "--results-path",
    type=str,
    # Required so a missing value is reported immediately instead of only
    # after the full validation run has finished.
    required=True,
    help="Path to save evaluation results as json",
)

parser.add_argument(
    "--hidden-size",
    type=int,
    default=1024,
    help=(
        "hidden size of the architecture the state dict is loaded into "
        "(only used with --lightning-checkpoint)"
    ),
)

args = parser.parse_args()

if args.lightning_checkpoint:
    # Build the architecture first, then fill it with the (key-filtered)
    # weights from the checkpoint. --hidden-size must match the checkpoint.
    model = get_untrained_model(hidden_size=args.hidden_size)
    model.load_state_dict(
        load_and_filter_state_dict_keys(
            args.checkpoint_path,
        ),
    )
else:
    # The checkpoint is a pickled model object. Map it onto the CPU so that
    # GPU-saved checkpoints also load on CPU-only hosts; the Lightning Trainer
    # moves the module to the selected accelerator later anyway.
    # SECURITY: torch.load unpickles arbitrary objects -- only pass trusted
    # checkpoint files.
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which refuses to
    # unpickle full nn.Module objects; pass weights_only=False there if needed.
    model = torch.load(args.checkpoint_path, map_location="cpu")

# Seed every RNG (Python, NumPy, torch) so the evaluation run is reproducible.
seed_everything(0)

# Create the results directory before the (potentially slow) validation run,
# so a bad/missing output directory does not waste the whole evaluation.
results_dir = os.path.dirname(args.results_path)
if results_dir:
    os.makedirs(results_dir, exist_ok=True)

# Tensor conversion followed by normalization with the commonly used MNIST
# pixel mean/std.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

validation_dataset = MNIST(
    root="./data", train=False, transform=transform, download=True
)
# NOTE(review): only the validation split is evaluated below; the train split
# is constructed because MNISTLightningModule takes it as an argument --
# confirm whether it can be omitted to avoid the extra download.
train_dataset = MNIST(root="./data", train=True, transform=transform, download=True)

# NOTE(review): presumably only batch_size affects validation; the remaining
# values look like training settings required by MNISTHyperParams -- confirm.
hyperparams = MNISTHyperParams(
    lr=1e-3,
    batch_size=64,
    weight_decay=1e-5,
    save_checkpoint_every_n_steps=10,
    apply_nesim_every_n_steps=1,
)

# Wrap the loaded model without any nesim/BIMT regularisation and without
# wandb logging; this script only evaluates.
lightning_module = MNISTLightningModule(
    model=model,
    hyperparams=hyperparams,
    nesim_config=None,
    checkpoint_dir="./checkpoints/mnist",
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    wandb_log=False,
    bimt_config=None,
)

trainer = L.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=1,
    logger=None,
    default_root_dir=lightning_module.checkpoint_dir,
)

# trainer.validate returns one metrics dict per validation dataloader; there
# is a single dataloader here, so take the first entry.
val_results = trainer.validate(lightning_module)[0]
# Record which checkpoint produced these metrics alongside them.
val_results["checkpoint_path"] = args.checkpoint_path

dict_to_json(dictionary=val_results, filename=args.results_path)
