import glob
import os
import subprocess

from nesim.utils.json_stuff import load_json_as_dict

# --- Configuration -----------------------------------------------------------
# Checkpoint directory (relative to the working directory). Only used when
# `filename` below is a relative path -- see the note on checkpoint_filename.
checkpoints_root = "../../../training/imagenet/resnet18/checkpoints/imagenet/"
# Run identifier joined into checkpoint_filename (previously assigned twice).
run_name = "fake_run_name"
# Checkpoint to visualise.
filename = "/mindhive/nklab3/users/XXXX-1/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet/mixed_shrink_factor_torchvision_recipe_loss_scale_300__bimt_scale_None_from_pretrained_False_apply_every_30_steps_apply_sorted_weights_init_filename_None/best/best_model.ckpt"
# filename = "/research/XXXX-3/repos/nesim/training/imagenet/resnet18/checkpoints/imagenet/" # use this for barlow
outputs_folder = "./orientation_maps/"

# Ensure the output folder exists, then delete .jpg files left by earlier runs.
# This is done in Python rather than via a shell `rm`, so a missing folder or
# an empty glob is not reported as an error.
os.makedirs(outputs_folder, exist_ok=True)
for stale_map in glob.glob(os.path.join(outputs_folder, "*.jpg")):
    os.remove(stale_map)

# NOTE: os.path.join discards every component that precedes an absolute one.
# While `filename` is absolute, this evaluates to `filename` unchanged.
# checkpoints_root and run_name only matter if `filename` is made relative.
checkpoint_filename = os.path.join(checkpoints_root, run_name, filename)
print(checkpoint_filename)

# Layer names are the keys of the per-layer shrink-factor JSON config, in the
# order load_json_as_dict returns them.
layer_names = list(
    load_json_as_dict(
        "../../../training/imagenet/resnet18/layer_wise_shrink_factors.json"
    ).keys()
)

# Generate one orientation map per layer, iterating over layer_names in reverse
# order. Each layer is processed by running generate_orientation_map.py as a
# subprocess.
# NOTE(review): an earlier version computed a per-layer save path
# ("{name}_{run_name}.jpg" in outputs_folder) but never passed it to the
# generator. Presumably generate_orientation_map.py chooses its own output
# path -- confirm.
failed_layers = []
for index, name in enumerate(reversed(layer_names), start=1):
    print(f"Generating {index}/{len(layer_names)}")

    # Build an argument list rather than a shell string, so paths containing
    # spaces or shell metacharacters are passed through verbatim.
    command = [
        "python3",
        "generate_orientation_map.py",
        "--checkpoint-filename",
        checkpoint_filename,
        "--layer-name",
        name,
    ]
    print(" ".join(command))

    # As before, a failing layer does not stop the remaining layers. Unlike
    # os.system's ignored return value, the failure is now recorded.
    result = subprocess.run(command)
    if result.returncode != 0:
        print(f"Layer {name} failed with exit code {result.returncode}")
        failed_layers.append(name)

if failed_layers:
    print(f"{len(failed_layers)} layer(s) failed: {failed_layers}")
print("Done!")
