import os
import shlex

# Header prepended to every generated job script: an sh shebang, the #SBATCH
# directives (tasks per node, job name, time, memory, one a100 GPU), a blank
# separator line, and the commands that activate the "neuro" conda environment.
slurm_prefix = "".join(
    f"{header_line}\n"
    for header_line in (
        "#!/bin/sh",
        "#SBATCH --ntasks-per-node=4",
        "#SBATCH --job-name=compute-hook-outputs-gpt",
        "#SBATCH -t 24:00:0",
        "#SBATCH --mem=128G",
        "#SBATCH --gres=gpu:a100:1",
        "",
        "source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate",
        "conda activate neuro",
    )
)

# Folder that receives the generated job scripts, and the root folder under
# which each run gets its own hook-output subfolder.
slurm_commands_folder = "./slurm"
hook_outputs_folder = "./hook_outputs"

# Start from a clean script folder. The folder is created if it is missing so
# that the scripts can be written into it later, and stale scripts are removed
# without shelling out to `rm` (which errors when the folder does not exist).
os.makedirs(slurm_commands_folder, exist_ok=True)
for stale_name in os.listdir(slurm_commands_folder):
    stale_path = os.path.join(slurm_commands_folder, stale_name)
    # Same selection as the shell glob `*.sh`: hidden files are not matched,
    # and directories are left alone.
    if (
        stale_name.endswith(".sh")
        and not stale_name.startswith(".")
        and not os.path.isdir(stale_path)
    ):
        os.remove(stale_path)

# Directory (relative to the working directory) that contains one subfolder
# per training run.
checkpoint_root = "../../../training/gpt_neo_125m/checkpoints/"

# Short run key -> name of that run's subfolder under `checkpoint_root`.
# Every key listed here gets a hook-output folder created for it.
# Commented-out entries are runs that are currently disabled.
run_names = {
    # "ours_scale_3000": "apply_nesim_every_n_steps_10_nesim_config_scale_3000_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_300": "apply_nesim_every_n_steps_10_nesim_config_scale_300_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_600": "apply_nesim_every_n_steps_10_nesim_config_scale_600_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_150": "apply_nesim_every_n_steps_10_nesim_config_scale_150_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    "ours_scale_1": "apply_nesim_every_n_steps_10_nesim_config_scale_1_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_10": "apply_nesim_every_n_steps_10_nesim_config_scale_10_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_25": "apply_nesim_every_n_steps_10_nesim_config_scale_25_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    # "ours_scale_50": "apply_nesim_every_n_steps_10_nesim_config_scale_50_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
    "baseline": "apply_nesim_every_n_steps_10_nesim_config_baseline_shrink_factor_[9.0]_layer_names_all_layers_c_proj_checkpoint_every_n_steps_100_num_warmup_steps_3000_batch_size_128_context_length_1024",
}

# Short run key -> step number used to pick the "checkpoint-<step>" subfolder
# of that run. Only keys listed here get a job script generated, and every key
# must also be present in `run_names`.
# NOTE: "baseline" is active in `run_names` but commented out here, so it gets
# a hook-output folder but no job script.
checkpoint_step = {
    # "ours_scale_3000": 4200,
    # "ours_scale_300": 4200,
    # "ours_scale_600": 4200,
    # "ours_scale_150": 4200,
    "ours_scale_1": 61000,
    # "ours_scale_10": 61000,
    # "ours_scale_25": 61000,
    # "ours_scale_50": 4200,
    # "baseline": 4200
}

# Run key -> folder that holds that run's hook outputs. "pretrained" always
# comes first, followed by every key of `run_names`.
hook_output_folders = {
    run_key: os.path.join(hook_outputs_folder, run_key)
    for run_key in ("pretrained", *run_names)
}

# Create every folder up front, including the "pretrained" one: the command
# builder further down requires each run's folder to exist. os.makedirs avoids
# spawning a shell and is a no-op for folders that are already there.
for folder in hook_output_folders.values():
    os.makedirs(folder, exist_ok=True)

# Run key -> value passed as --checkpoint-filename. "pretrained" is a literal
# marker rather than a path; every run listed in `checkpoint_step` points at
# <checkpoint_root>/<run folder>/checkpoint-<step>/pytorch_model.bin.
checkpoint_filenames = {"pretrained": "pretrained"}
checkpoint_filenames.update(
    (
        run_key,
        os.path.join(
            checkpoint_root,
            run_names[run_key],
            f"checkpoint-{step}",
            "pytorch_model.bin",
        ),
    )
    for run_key, step in checkpoint_step.items()
)

# Run key -> JSON file passed as --result-filename.
# NOTE: the seeded "pretrained" value is replaced by the update below whenever
# `checkpoint_filenames` contains a "pretrained" key (it is seeded with one),
# so in practice every run maps to results/result_<key>.json.
result_filenames = {"pretrained": "results_pretrained.json"}
result_filenames.update(
    {run_key: f"results/result_{run_key}.json" for run_key in checkpoint_filenames}
)

# JSON file passed to both scripts via --layer-names-json.
layer_names_filename = "layer_names.json"

# Build one shell snippet per run. Each snippet first runs
# obtain_hook_outputs.py and then compute_effective_dimensionality.py on the
# same hook-output folder, with an echo after each step.
all_commands = []
for run_name in checkpoint_filenames:
    hook_output_folder = hook_output_folders[run_name]
    # Explicit raises instead of `assert`, so the checks survive `python -O`.
    if not os.path.exists(hook_output_folder):
        raise FileNotFoundError(
            f"hook output folder does not exist: {hook_output_folder}"
        )

    if run_name == "pretrained":
        # "pretrained" is passed through as a literal marker, not a file path,
        # so there is nothing to check on disk.
        checkpoint_argument = "pretrained"
    else:
        checkpoint_argument = checkpoint_filenames[run_name]
        if not os.path.exists(checkpoint_argument):
            raise FileNotFoundError(
                f"checkpoint does not exist: {checkpoint_argument}"
            )

    # Quote every interpolated value: run folder names contain "[9.0]", which
    # an unquoted sh word would treat as a glob bracket pattern.
    quoted_hook_folder = shlex.quote(hook_output_folder)
    quoted_layer_names = shlex.quote(layer_names_filename)

    script_lines = [
        "python3 obtain_hook_outputs.py "
        f"--checkpoint-filename {shlex.quote(checkpoint_argument)} "
        f"--hook-output-folder {quoted_hook_folder} "
        f"--layer-names-json {quoted_layer_names}",
        "echo 'hook outputs saved, will now start computing effective dimensionality...'",
        "python3 compute_effective_dimensionality.py "
        f"--hook-output-folder {quoted_hook_folder} "
        f"--result-filename {shlex.quote(result_filenames[run_name])} "
        f"--layer-names-json {quoted_layer_names}",
        "echo 'done!'",
    ]
    all_commands.append("\n".join(script_lines))


# Write each command to its own numbered script (<index>.sh) inside
# `slurm_commands_folder`, prefixed with the shared Slurm header.
# The folder is created if it is missing so the writes below cannot fail on a
# fresh checkout, and stale scripts are removed without shelling out to `rm`.
os.makedirs(slurm_commands_folder, exist_ok=True)
for stale_name in os.listdir(slurm_commands_folder):
    stale_path = os.path.join(slurm_commands_folder, stale_name)
    # Same selection as the shell glob `*.sh`: hidden files are not matched,
    # and directories are left alone.
    if (
        stale_name.endswith(".sh")
        and not stale_name.startswith(".")
        and not os.path.isdir(stale_path)
    ):
        os.remove(stale_path)

slurm_shell_files = []
for command_index, command in enumerate(all_commands):
    script_path = os.path.join(slurm_commands_folder, f"{command_index}.sh")
    with open(script_path, "w") as script_file:
        script_file.write(slurm_prefix + command)
    slurm_shell_files.append(script_path)

print(f"Saved {len(slurm_shell_files)} shell files")
