import glob
import os
import shlex

from nesim.utils.json_stuff import load_json_as_dict

# Header shared by every generated SLURM job script: resource requests
# followed by conda environment activation. The per-layer python command is
# appended after this text.
slurm_prefix = """#!/bin/sh
#SBATCH --ntasks-per-node=4
#SBATCH --job-name=orientation-map
#SBATCH -t 2:00:0
#SBATCH --mem=16G
#SBATCH --gres=gpu:a100:1

source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate neuro
"""
slurm_commands_folder = "./slurm"

# The generation loop writes "<index>.sh" files into this folder, so make sure
# it exists. Then delete job scripts left over from a previous run. Using
# glob + os.remove instead of shelling out to `rm` means an empty or missing
# folder is not an error.
os.makedirs(slurm_commands_folder, exist_ok=True)
for stale_script in glob.glob(os.path.join(slurm_commands_folder, "*.sh")):
    os.remove(stale_script)


# Location of the trained model whose layers will be mapped.
checkpoints_root = "../../../training/imagenet/resnet18/checkpoints/imagenet/"
run_name = "bigger_cortical_sheet_torchvision_recipe_shrink_factor_[5.0]_loss_scale_7.248294_layers_all_conv_layers_except_conv1__bimt_scale_None_from_pretrained_False_apply_every_1_steps_apply_sorted_weights_init_filename_None"
filename = "best/best_model.ckpt"
# Folder the generated jobs write their orientation-map images into.
outputs_folder = "./maps/"

# os.system(f"rm {outputs_folder}/*.jpg")
checkpoint_filename = os.path.join(checkpoints_root, run_name, filename)
# Fail fast with a readable error rather than an `assert`. An assert is
# stripped under `python -O`, which would let us emit jobs for a checkpoint
# that does not exist.
if not os.path.exists(checkpoint_filename):
    raise FileNotFoundError(f"Checkpoint not found: {checkpoint_filename}")

# Conv layers to map: both convs of both blocks in stages 1 and 2, ordered
# stage-major, then block, then conv index.
layer_names = [
    f"layer{stage}.{block}.conv{conv}"
    for stage in (1, 2)
    for block in (0, 1)
    for conv in (1, 2)
]

# Write one SLURM job script per layer: "<index>.sh" in slurm_commands_folder.
# Each script runs generate_orientation_map.py for that layer and writes a
# .jpg into outputs_folder.
for command_index, name in enumerate(layer_names):
    # Separate name so the module-level `filename` (checkpoint sub-path) is
    # not silently rebound by the loop.
    script_filename = os.path.join(slurm_commands_folder, f"{command_index}.sh")

    print(f"Generating {command_index + 1}/{len(layer_names)}")
    save_filename = os.path.join(outputs_folder, f"{name}_{run_name}.jpg")

    # Quote every interpolated argument. run_name contains shell glob
    # characters ("[5.0]"), which /bin/sh could otherwise expand or mangle.
    command = (
        "python3 generate_orientation_map.py"
        f" --checkpoint-filename {shlex.quote(checkpoint_filename)}"
        f" --layer-name {shlex.quote(name)}"
        f" --output-filename {shlex.quote(save_filename)}\n"
    )

    shell_command = slurm_prefix + command
    with open(script_filename, "w", encoding="utf-8") as script_file:
        script_file.write(shell_command)

print("Done!")
