import os
import shlex


# Header prepended to every generated Slurm script: resource requests followed
# by conda environment activation. The interpreter is bash (not /bin/sh)
# because `source` is a bash builtin, not a POSIX sh one; on systems where
# /bin/sh is dash the activation line would fail.
slurm_prefix = """#!/bin/bash
#SBATCH -n 64
#SBATCH --ntasks-per-node=64
#SBATCH --job-name=compute-hook-outputs
#SBATCH -t 24:00:0
#SBATCH --mem=128G
#SBATCH --gres=gpu:a100:1

source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate neuro
"""

# Folder that receives the generated Slurm scripts. Create it if missing, then
# delete any existing *.sh files in it so scripts left over from a previous
# run of this generator do not sit next to the freshly written ones. This is
# done in-process instead of via `rm` so it neither spawns a shell nor prints
# an error when there is nothing to delete.
slurm_commands_folder = "./slurm"
os.makedirs(slurm_commands_folder, exist_ok=True)
for stale_name in os.listdir(slurm_commands_folder):
    stale_path = os.path.join(slurm_commands_folder, stale_name)
    if stale_name.endswith(".sh") and os.path.isfile(stale_path):
        os.remove(stale_path)

# Per-run folder for the hook outputs. The key order here is the order in
# which the runs are iterated when the shell commands are built.
hook_output_folders = {
    run_label: f"hook_outputs/{run_label}"
    for run_label in ("ours", "pretrained", "eshed")
}

# Value passed as --checkpoint-filename for each run. "pretrained" and "eshed"
# are plain labels rather than file paths — presumably interpreted by
# obtain_hook_outputs.py; TODO confirm there.
checkpoint_filenames = dict(
    pretrained="pretrained",
    eshed="eshed",
    ours=(
        "../../../training/imagenet/resnet18/checkpoints/imagenet/"
        "torchvision_recipe_shrink_factor_[5.0]_loss_scale_150_layers_all_conv_layers__"
        "bimt_scale_None_from_pretrained_False_apply_every_30_steps_apply_sorted_weights_"
        "init_filename_None/best/best_model-v1.ckpt"
    ),
)

# Per-run JSON file passed as --result-filename.
result_filenames = {
    run_label: f"results_{run_label}.json"
    for run_label in ("ours", "pretrained", "eshed")
}

# JSON file passed as --layer-names-json to both python steps.
layer_names_filename = "layer_names.json"

# Build one shell snippet per run: first obtain_hook_outputs.py, then
# compute_effective_dimensionality.py on the same hook output folder.
all_commands = []
for run_name, hook_output_folder in hook_output_folders.items():
    # Explicit raise instead of `assert`: the check survives `python -O` and
    # the error names the missing folder.
    if not os.path.exists(hook_output_folder):
        raise FileNotFoundError(
            f"Hook output folder for run {run_name!r} not found: {hook_output_folder}"
        )

    # Shell-quote every interpolated value. The "ours" checkpoint path contains
    # `[5.0]`, which an unquoted shell word would treat as a glob pattern.
    checkpoint_arg = shlex.quote(checkpoint_filenames[run_name])
    hook_folder_arg = shlex.quote(hook_output_folder)
    layer_names_arg = shlex.quote(layer_names_filename)
    result_arg = shlex.quote(result_filenames[run_name])

    command_lines = [
        # Stop at the first failing step. Without this, a crash in the first
        # script is followed by the "saved" message, by the second script
        # running anyway, and by a zero exit status from the final `echo`.
        "set -e",
        "python3 obtain_hook_outputs.py "
        f"--checkpoint-filename {checkpoint_arg} "
        f"--hook-output-folder {hook_folder_arg} "
        f"--layer-names-json {layer_names_arg}",
        "echo 'hook outputs saved, will now start computing effective dimensionality...'",
        "python3 compute_effective_dimensionality.py "
        f"--hook-output-folder {hook_folder_arg} "
        f"--result-filename {result_arg} "
        f"--layer-names-json {layer_names_arg}",
        "echo 'done!'",
    ]
    # Trailing newline so the generated script ends with a complete line.
    all_commands.append("\n".join(command_lines) + "\n")


# Write each command, prefixed with the Slurm header, to its own numbered
# script (0.sh, 1.sh, ...) in slurm_commands_folder. Stale *.sh files were
# already cleared earlier in this script, and nothing writes .sh files in
# between, so no second cleanup pass is needed here.
os.makedirs(slurm_commands_folder, exist_ok=True)
slurm_shell_files = []
for command_index, command in enumerate(all_commands):
    filename = os.path.join(slurm_commands_folder, f"{command_index}.sh")
    with open(filename, "w", encoding="utf-8") as file:
        file.write(slurm_prefix + command)
    slurm_shell_files.append(filename)

print(f"Saved {len(slurm_shell_files)} shell files")
