import os
from pprint import pprint

import torchvision.models as models

from adversarial_robustness_eval import AdversarialRobustnessEval
from nesim.lightning.imagenet import ConvertedImagenetDataset
from nesim.utils.checkpoint import load_and_filter_state_dict_keys
from nesim.utils.json_stuff import dict_to_json

device = "cuda:0"

# ImageNet validation split stored in the "converted" on-disk format.
# Active location is the barlow cluster; the openmind copy lives at
#   /om2/user/mayukh09/datasets/imagenet_converted/validation
validation_dataset = ConvertedImagenetDataset(
    folder="/research/datasets/imagenet_converted/validation",
    slice_name="validation",
)

# Keep only the first 1000 entries of `labels` to shorten the evaluation.
# NOTE(review): this assumes ConvertedImagenetDataset derives its length /
# indexing from `labels` — confirm, otherwise the truncation has no effect.
_label_limit = 1000
validation_dataset.labels = validation_dataset.labels[:_label_limit]

# Attacks run against every model; each name is passed verbatim to
# AdversarialRobustnessEval(attack_name=...). "Pixle" is currently disabled.
attack_names = ["FGSM", "PGD", "OnePixel", "Square", "Jitter"]

# Result label -> how the ResNet-18 for that label is built:
#   a filesystem path -> weights loaded from that checkpoint file
#   "pretrained"      -> torchvision's default pretrained weights
#   None              -> random initialisation
checkpoint_map = dict(
    ours=(
        "/research/XXXX-1/nesim_old/training/imagenet/resnet18/checkpoints/imagenet/"
        "shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__bimt_scale_None_"
        "from_pretrained_False_apply_every_20_steps/best/best_model.ckpt"
    ),
    pretrained="pretrained",
    random_weights=None,
)

# checkpoint label -> {attack name -> result of AdversarialRobustnessEval.run}
all_results = {}


def _build_resnet18(checkpoint_filename):
    """Build a ResNet-18 for one ``checkpoint_map`` entry.

    Args:
        checkpoint_filename: ``"pretrained"`` for torchvision's default
            pretrained weights, ``None`` for random initialisation, or a
            path to a checkpoint whose state dict is read with
            ``load_and_filter_state_dict_keys``.

    Returns:
        The torchvision ResNet-18, not yet moved to ``device``.

    Raises:
        FileNotFoundError: if ``checkpoint_filename`` is a path that does
            not exist.
    """
    if checkpoint_filename == "pretrained":
        print("Loading pretrained model")
        return models.resnet18(weights="DEFAULT")
    if checkpoint_filename is None:
        print("Loading with random weights init")
        return models.resnet18(weights=None)
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not os.path.exists(checkpoint_filename):
        raise FileNotFoundError(f"Invalid path: {checkpoint_filename}")
    print(f"Loading checkpoint: {checkpoint_filename}")
    model = models.resnet18(weights=None)
    model.load_state_dict(load_and_filter_state_dict_keys(checkpoint_filename))
    return model


for checkpoint_name, checkpoint_filename in checkpoint_map.items():
    # nn.Module.to / .eval both return the module itself, so they chain.
    model = _build_resnet18(checkpoint_filename).to(device).eval()

    results_for_checkpoint = {}
    for attack_name in attack_names:
        evaluator = AdversarialRobustnessEval(
            model=model,
            attack_name=attack_name,
            # Use the same device the model was moved to (was hard-coded).
            device=device,
            # Matches torchvision resnet18's default 1000-way ImageNet head.
            num_classes=1000,
            batch_size=128,
        )
        results_for_checkpoint[attack_name] = evaluator.run(
            dataset=validation_dataset
        )

    all_results[checkpoint_name] = results_for_checkpoint

# Persist every checkpoint's per-attack results, then echo them to stdout.
# (`pprint` is now imported at the top of the file instead of mid-script.)
dict_to_json(all_results, filename="results.json")
pprint(all_results)