"""
The goal is to resume a crashed training
"""

import os
import shlex

slurm_prefix = """#!/bin/sh
#SBATCH -n 64
#SBATCH --ntasks-per-node=64
#SBATCH --job-name=gpt-neo-125m-training
#SBATCH -t 48:00:0
#SBATCH --mem=32G
#SBATCH --gres=gpu:a100:1

source /mindhive/nklab3/users/XXXX-1/conda_stuff/bin/activate
conda activate neuro
export filepath=/mindhive/nklab3/users/XXXX-1/repos/nesim/training/gpt_neo_125m/train.py
export gpu=0
"""
resume_filename = "resume_training.sh"

nesim_config = "nesim_configs/scale_0.06_shrink_factor_[9.0]_layer_names_index_10.json"

checkpoint_every_n_steps = 100
num_warmup_steps = 100
batch_size = 128
context_length = 256
apply_nesim_every_n_steps = 10

## IMPORTANT
resume_from_checkpoint = "checkpoints/apply_nesim_every_n_steps_10_nesim_config_scale_0.06_shrink_factor_\[9.0\]_layer_names_index_10_checkpoint_every_n_steps_100_num_warmup_steps_7000_batch_size_128_context_length_256/checkpoint-9300"
resume_wandb_id = "airgco4j"

command = f"python3 train.py "
command += f"--nesim-config {nesim_config} "
command += f"--checkpoint-every-n-steps {checkpoint_every_n_steps} "
command += f"--num-warmup-steps {num_warmup_steps} "
command += f"--batch-size {batch_size} "
command += f"--context-length {context_length} "
command += f"--resume-from-checkpoint {resume_from_checkpoint} "
command += f"--resume-wandb-id {resume_wandb_id} "
command += f"--apply-nesim-every-n-steps {apply_nesim_every_n_steps}"

os.system(f"rm {resume_filename}")
shell_command = slurm_prefix + command
with open(resume_filename, "w") as file:
    file.write(shell_command)

print("Now run the following command:")
print(f"sbatch {resume_filename}")
