
from nesim.utils.json_stuff import load_json_as_dict
from nesim.eval.resnet import ResnetEfficiencyEval
import argparse

def get_conv_layer_names(model):
    """Return the names of conv modules that own parameters, in first-seen order.

    A parameter is selected when its fully-qualified name contains ``"conv"``.
    A trailing ``.weight`` or ``.bias`` suffix is stripped to recover the owning
    module's name. Duplicates, such as a conv's weight and bias, are collapsed.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, such as a ``torch.nn.Module``.

    Returns:
        list[str]: Unique module names, ordered by first appearance.
    """
    # A dict keeps insertion order and gives O(1) duplicate checks,
    # unlike a `name not in list` test on every parameter.
    seen = {}
    for key, _ in model.named_parameters():
        if "conv" not in key:
            continue
        # Strip only a *trailing* suffix. str.replace would also delete
        # ".weight"/".bias" occurring in the middle of a module path.
        for suffix in (".weight", ".bias"):
            if key.endswith(suffix):
                key = key[: -len(suffix)]
                break
        seen[key] = None
    return list(seen)


# Command-line interface: which model to evaluate and which conv layers to use.
parser = argparse.ArgumentParser(description="Process model name")

# --model-name is checked against `possible_model_names` further down.
parser.add_argument('--model-name', type=str, required=True, help='Name of the model')
# `choices` makes argparse reject a bad value immediately with a usage message,
# before any data loading or evaluator construction happens.
parser.add_argument(
    '--layers',
    type=str,
    required=True,
    choices=["all", "end"],
    help='Which conv layers to use: "all" (all conv layers except the first) '
         'or "end" (the last conv layer in each block)',
)

args = parser.parse_args()

print(f'Model Name: {args.model_name}')

# Cap on evaluated batches. None presumably means "no cap" — TODO confirm
# against ResnetEfficiencyEval.run.
MAX_NUM_BATCHES = None

# Model names accepted by this script. "pretrained" is deliberately left out,
# which makes the `args.model_name == "pretrained"` branch further down
# unreachable. Re-add it here to enable that branch.
possible_model_names = [
    "end_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "baseline_scale_None_shrink_factor_3.0",
    # "pretrained"
]

# Validate explicitly: an `assert` is stripped when Python runs with -O.
if args.model_name not in possible_model_names:
    parser.error(f"Invalid model_name: {args.model_name}")

# Named groups of layer names, indexed below by keys such as
# "all_conv_layers_except_first" and "last_conv_layers_in_each_block".
# NOTE(review): this relative path is presumably resolved against the current
# working directory, not this file's location — confirm how the script is launched.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet50/layer_names.json"
)
epoch = 'final'
# Downsample factors 1 through 20 inclusive.
downsample_factors = range(1, 21)

# Evaluator bound to the FFCV validation split and the ResNet-50 checkpoint folder.
resnet_eval = ResnetEfficiencyEval(
    val_dataset_path="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    batch_size=256,
    checkpoints_folder="/research/XXXX-1/toponets_resnet50_imagenet_checkpoints",
    mode="resnet50",
)

# Validate explicitly: an `assert` is stripped when Python runs with -O.
if args.layers not in ("all", "end"):
    parser.error(f"Invalid layers: {args.layers} (expected 'all' or 'end')")

if args.layers == "all":
    if args.model_name == "pretrained":
        # Currently unreachable: "pretrained" is not in possible_model_names.
        # Only parameter names are read from the torchvision ResNet-50, so its
        # weights are irrelevant. The stem conv ("conv1") is excluded to match
        # the "all_conv_layers_except_first" group used for the other models.
        import torchvision.models as models
        topo_layer_names = get_conv_layer_names(model=models.resnet50())
        topo_layer_names.remove("conv1")
    else:
        topo_layer_names = layer_names["all_conv_layers_except_first"]

else:
    topo_layer_names = layer_names["last_conv_layers_in_each_block"]

# Run the evaluation. The result path is results/<model>_sparsify_<layers>.json.
# Pass the module-level settings instead of re-typing their values, so that
# editing `epoch` or `MAX_NUM_BATCHES` actually takes effect.
resnet_eval.run(
    model_name=args.model_name,
    downsample_factors=downsample_factors,
    topo_layer_names=topo_layer_names,
    output_json_filename=f"results/{args.model_name}_sparsify_{args.layers}.json",
    epoch=epoch,
    max_num_batches=MAX_NUM_BATCHES,
)