from itertools import product
import os 
import subprocess 
import shutil
import itertools
# --- Experiment identity -----------------------------------------------------
# Names the results sub-directory and prefixes every wandb run id.
EXPERIMENT_GROUP_NAME = "ood_cot_pt_wfact_5x"
# Forwarded to the training script as --wandb_proj.
WANDB_PROJ = "PARAMETRIC_CF"

# --- Sweep axes --------------------------------------------------------------
# One job is generated per element of the cartesian product WD x LAYERS x DATA_FILES.
WD = [1e-1]  # values passed as --weight_decay
LAYERS = [12]  # values passed as --n_layer
# Dataset directory names, resolved under ./data/procedural/.
DATA_FILES = ["ood_cot_rep_5_0.9_w_fact.2000.200.18.0"]
# Not referenced anywhere else in the visible part of this script.
WITH_EX_DROPOUT = [1]

# Root directory that receives one sub-directory per data file.
out_path = f"/data/user_data/gghosal/grokkreason/results/cf_ft/{EXPERIMENT_GROUP_NAME}/"
# Make sure the results root exists; an already-existing directory is accepted.
os.makedirs(out_path, exist_ok=True)
def clean_results_file(path):
    """Delete every immediate subdirectory of ``path``.

    Regular files that sit directly inside ``path`` are left untouched.
    The name of each removed directory is printed.

    Args:
        path: Directory whose subdirectories are removed recursively.

    Raises:
        OSError: If ``path`` cannot be listed or a subdirectory cannot be
            removed (propagated from ``os.listdir`` / ``shutil.rmtree``).
    """
    # List the directory that was passed in rather than a module-level
    # global, so the function works for any ``path``.
    for entry in os.listdir(path):
        full_path = os.path.join(path, entry)
        if os.path.isdir(full_path):
            print(f"removing {entry}")
            shutil.rmtree(full_path)
# Remove run directories left under the results root before queuing new jobs.
clean_results_file(out_path)

# Build one command-line argument string per sweep combination. Each entry
# later becomes one element of the bash ARGS array in the generated script.
# NOTE(review): --output_dir depends only on dfile, so combinations that differ
# only in wd/layers would share a directory (and --overwrite_output_dir is
# set) -- confirm that is intended before sweeping more than one value.
args = []
for wd, layers, dfile in itertools.product(WD, LAYERS, DATA_FILES):
    flags = [
        f"--data_dir ./data/procedural/{dfile}",
        "--model_name_or_path gpt2",
        f"--weight_decay {wd}",
        f"--output_dir {out_path}/{dfile}/",
        "--max_seq_length 25",
        "--max_length 25",
        "--block_size 25",
        "--train_batch_size 512",
        "--eval_batch_size 512",
        "--learning_rate 1e-4",
        "--gradient_accumulation_steps 1",
        "--save_step 500",
        "--save_step_dense 400",
        "--max_steps 20000",
        "--do_train",
        "--scheduler constant_schedule_with_warmup",
        "--fp16",
        f"--wandb_proj {WANDB_PROJ}",
        f"--wandb_run_id {EXPERIMENT_GROUP_NAME}_{wd}_{layers}_{dfile}",
        "--evaluate_during_training",
        "--predict_during_training",
        "--overwrite_output_dir",
        "--init_weights",
        "--add_tokens",
        f"--n_layer {layers}",
    ]
    # Flags are separated by exactly one space.
    args.append(" ".join(flags))
# Number of array tasks to request from Slurm (one per argument string).
num_args = len(args)
if not args:
    # An empty sweep would otherwise produce an invalid job array; fail with a
    # clear message instead of an IndexError further down.
    raise ValueError(
        "no job configurations were generated; check that every sweep list is non-empty"
    )
# Render the argument strings as a bash array literal with one double-quoted
# element per job, e.g. ARGS=("<job 0 flags>" "<job 1 flags>").
args_string = "ARGS=(" + " ".join(f'"{arg}"' for arg in args) + ")"
print(args_string)
# Text of the sbatch job-array script, one list element per script line.
# The array spans indices 0..num_args-1; each task selects its own flags from
# the bash ARGS array via $SLURM_ARRAY_TASK_ID.
SLURMSCRIPT = "\n".join(
    [
        "#!/bin/bash",
        "#SBATCH --job-name=grokkdtf_layers_no_unrel",
        "#SBATCH --output=./logs/grokkdtf_pt_layers_nounrel_%a.out",
        "#SBATCH --error=./logs/grokkdtf_pt_layers_nounrel_%a.err",
        "#SBATCH --partition=general",
        f"#SBATCH --array=0-{num_args - 1}%8",
        "#SBATCH --time=7:00:00",
        "#SBATCH --gres=gpu:1",
        "#SBATCH --mem=64G",
        "#SBATCH --exclude=babel-1-31,babel-5-19",
        "#SBATCH --nodes=1",
        "#SBATCH --requeue",
        "source ../../miniconda3/etc/profile.d/conda.sh",
        "conda activate grokkedtransformer",
        # Bash array definition holding the per-task argument strings.
        args_string,
        "nvidia-smi",
        "export NCCL_DEBUG=INFO",
        "export NCCL_P2P_DISABLE=1",
        # The port is computed arithmetically (12340 + task id). This matches
        # the former "1234<id>" concatenation for ids 0-9 but stays a valid
        # port (<= 65535) for larger arrays. ${ARGS[...]} is left unquoted on
        # purpose so bash splits it into separate arguments for main.py.
        "CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1"
        " --master_port $((12340 + SLURM_ARRAY_TASK_ID))"
        " main.py ${ARGS[$SLURM_ARRAY_TASK_ID]}",
    ]
)
# Write the generated script to disk; the context manager closes the file
# even if the write fails.
with open("jobarrayscript.sbatch", "w") as f:
    f.write(SLURMSCRIPT)
# Submit the job array. The argument-list form needs no shell, and check=True
# raises CalledProcessError if sbatch exits non-zero instead of ignoring it.
subprocess.run(["sbatch", "jobarrayscript.sbatch"], check=True)