
from nesim.utils.json_stuff import load_json_as_dict
from nesim.eval.resnet import ResnetEfficiencyEval
import argparse
import numpy as np
# Model names this script can evaluate; ``--model-name`` must be one of these.
# Commented-out entries are kept so earlier sweeps can be re-enabled quickly.
possible_model_names = [
    ## end_topo
    # "end_topo_scale_1.0_shrink_factor_3.0",
    # "end_topo_scale_5.0_shrink_factor_3.0",
    # "end_topo_scale_10.0_shrink_factor_3.0",
    # "end_topo_scale_50.0_shrink_factor_3.0",
    ## all_topo
    "all_topo_scale_0.5_shrink_factor_3.0",
    "all_topo_scale_1_shrink_factor_3.0",
    "all_topo_scale_5_shrink_factor_3.0",
    "all_topo_scale_10.0_shrink_factor_3.0",
    "all_topo_scale_20.0_shrink_factor_3.0",
    "all_topo_scale_50.0_shrink_factor_3.0",
    ## baseline
    "baseline_scale_None_shrink_factor_3.0",
]

# Maximum number of evaluation batches. None presumably means "no cap / full
# validation set" — TODO confirm against ResnetEfficiencyEval.
MAX_NUM_BATCHES = None

# Command-line interface. ``choices`` makes argparse reject invalid values with
# a proper usage error (unlike ``assert``, this still runs under ``python -O``).
parser = argparse.ArgumentParser(description="Process model name")
parser.add_argument(
    "--model-name",
    type=str,
    required=True,
    choices=possible_model_names,
    metavar="MODEL_NAME",  # keep the usage line short; choices still validated
    help="Name of the model",
)
parser.add_argument(
    "--layers",
    type=str,
    required=True,
    choices=["all", "end"],
    help=(
        "Which layers to sparsify: 'all' = all conv layers except the first, "
        "'end' = the last conv layer in each block"
    ),
)

args = parser.parse_args()

print(f"Model Name: {args.model_name}")

# Fixed experiment settings.
epoch = "final"
fractions_of_masked_weights = [0, 0.2, 0.4, 0.6, 0.8]

# Named groups of ResNet-18 layer names; the keys looked up further down pick
# which layers are treated as topographic.
# NOTE(review): this path is relative to the current working directory, so the
# script presumably has to be launched from its own folder — confirm.
layer_names = load_json_as_dict(
    "../../../../training/imagenet/resnet18/layer_names.json"
)
def get_downsample_factor_from_fraction(fraction):
    """Return the downsample factor for a given fraction of masked weights.

    The factor is ``1 / (1 - fraction)``, e.g. 0 -> 1.0, 0.5 -> 2.0,
    0.8 -> ~5.0.

    Args:
        fraction: Fraction of weights to mask, in the half-open range ``[0, 1)``.

    Returns:
        The downsample factor as a float (always >= 1 for valid input).

    Raises:
        ValueError: If ``fraction`` is outside ``[0, 1)``. A value of 1 would
            divide by zero, values above 1 would give a negative factor, and
            negative values would give a factor below 1.
    """
    # Chained comparison is also False for NaN, so NaN is rejected too.
    if not 0 <= fraction < 1:
        raise ValueError(f"fraction must be in [0, 1), got {fraction!r}")
    return 1 / (1 - fraction)

# One downsample factor per masking fraction, in the same order as
# ``fractions_of_masked_weights``.
downsample_factors = list(
    map(get_downsample_factor_from_fraction, fractions_of_masked_weights)
)

# Map each accepted ``--layers`` value to the ``layer_names`` key whose layers
# are passed to the evaluator as ``topo_layer_names``.
LAYER_GROUP_KEYS = {
    "all": "all_conv_layers_except_first",
    "end": "last_conv_layers_in_each_block",
}

# Validate before constructing the evaluator so a bad flag fails fast; an
# explicit check (unlike ``assert``) is not stripped under ``python -O``.
if args.layers not in LAYER_GROUP_KEYS:
    parser.error(
        f"--layers must be one of {sorted(LAYER_GROUP_KEYS)}, got {args.layers!r}"
    )

topo_layer_names = layer_names[LAYER_GROUP_KEYS[args.layers]]

resnet_eval = ResnetEfficiencyEval(
    val_dataset_path="/research/datasets/imagenet_ffcv/val_500_0.50_90.ffcv",
    batch_size=256,
    checkpoints_folder="/home/XXXX-4/repos/nesim/training/imagenet/resnet18/checkpoints",
)

# Use the module-level settings (``epoch``, ``MAX_NUM_BATCHES``) rather than
# duplicated literals, so changing them at the top actually takes effect.
resnet_eval.run(
    model_name=args.model_name,
    downsample_factors=downsample_factors,
    topo_layer_names=topo_layer_names,
    output_json_filename=f"results/{args.model_name}_sparsify_{args.layers}.json",
    epoch=epoch,
    max_num_batches=MAX_NUM_BATCHES,
)