from nesim.eval.resnet import EvalSuite
import os
import torch
from nesim.utils.json_stuff import load_json_as_dict, dict_to_json
from nesim.experiments.resnet import create_model_and_scaler, create_val_loader
from lightning import seed_everything
from nesim.eval.resnet import load_resnet18_checkpoint

# Fix global RNG state (via Lightning's ``seed_everything``) so repeated runs
# of this evaluation are comparable.
seed_everything(0)

# Keyword arguments for the (non-distributed, gpu=0) FFCV validation loader
# that every checkpoint below is evaluated on.
VAL_LOADER_CONFIG = dict(
    val_dataset="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    num_workers=16,
    batch_size=128,
    resolution=224,
    distributed=False,
    gpu=0,
)
val_dataloader = create_val_loader(**VAL_LOADER_CONFIG)

# Shared evaluation helper; its ``compute_accuracy`` is called once per model.
eval_suite = EvalSuite(dataloader=val_dataloader)

# Checkpoint names to evaluate, in evaluation order. The baseline must come
# first: the loop below seeds every topographic plot series with it.
# Scale strings are spelled exactly as in the checkpoint names (note that the
# all_topo family uses "1" and "5", not "1.0" and "5.0").
model_names = [
    "baseline_scale_None_shrink_factor_3.0",
    *(
        f"all_topo_scale_{scale}_shrink_factor_3.0"
        for scale in ("0.5", "1", "5", "10.0", "20.0", "50.0")
    ),
    *(
        f"end_topo_scale_{scale}_shrink_factor_3.0"
        for scale in ("0.5", "1.0", "5.0", "10.0", "20.0", "50.0")
    ),
    *(
        f"eshed_layers_scale_{scale}_shrink_factor_3.0"
        for scale in ("5.0", "20.0")
    ),
]
# Flat mapping: model name -> top-1 validation accuracy (for every model).
results = {}

# Per-family plot series. Each series holds parallel lists: "acc" (accuracy)
# and "tau" (topographic scale, or the string "Baseline" for the baseline).
# A comprehension gives each series its own fresh lists, so appending to one
# never affects the other.
plot_data = {
    series_key: {"acc": [], "tau": []}
    for series_key in ("resnet18_end_topo", "resnet18_all_topo")
}

# Directory holding the trained ResNet-18 checkpoints (machine-specific path).
CHECKPOINTS_FOLDER = "/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints"

# Maps a model-name prefix to the ``plot_data`` series it contributes to.
# Models matching none of these prefixes (e.g. "eshed_layers_*") are recorded
# only in ``results``.
TOPO_SERIES = {
    "end_topo": "resnet18_end_topo",
    "all_topo": "resnet18_all_topo",
}


def _parse_tau(model_name):
    """Return the topographic scale encoded in ``model_name`` as a float.

    Names have the form ``"<family>_topo_scale_<tau>_shrink_factor_<s>"``,
    so the scale is the fourth ``_``-separated field.
    """
    return float(model_name.split("_")[3])


for model_name in model_names:
    model = load_resnet18_checkpoint(
        checkpoints_folder=CHECKPOINTS_FOLDER,
        model_name=model_name,
        epoch="final",
    )
    model.eval()

    # Accuracy evaluation needs no gradients; disabling autograd saves memory.
    # This is harmless if compute_accuracy already disables them internally.
    with torch.no_grad():
        val_acc = eval_suite.compute_accuracy(
            model=model,
            max_num_batches=None,
        )

    if model_name.startswith("baseline"):
        # The baseline is the shared reference point (tau="Baseline") at the
        # start of every topographic series, so it must be evaluated first.
        # This is an explicit check rather than an assert, so it survives `python -O`.
        if any(plot_data[key]["acc"] for key in TOPO_SERIES.values()):
            raise RuntimeError(
                f"Baseline model {model_name!r} must precede all topographic "
                "models in model_names."
            )
        for series_key in TOPO_SERIES.values():
            plot_data[series_key]["acc"].append(val_acc)
            plot_data[series_key]["tau"].append("Baseline")
    else:
        for prefix, series_key in TOPO_SERIES.items():
            if model_name.startswith(prefix):
                plot_data[series_key]["acc"].append(val_acc)
                plot_data[series_key]["tau"].append(_parse_tau(model_name))
                break

    print(f"Model: {model_name} Val acc Top 1: {val_acc}")
    print("\n")
    results[model_name] = val_acc
    
# Persist both outputs to the working directory:
#   results.json        - per-family plot series (acc / tau lists)
#   result_simple.json  - flat model name -> accuracy mapping
for payload, out_filename in (
    (plot_data, "results.json"),
    (results, "result_simple.json"),
):
    dict_to_json(dictionary=payload, filename=out_filename)