from itertools import product
import os 
import subprocess 
import shutil
import itertools
# Hyper-parameter grid; every combination becomes one SLURM array task.
WD = [1e-1, 1e-3, 1e-5]  # values passed as --weight_decay
LR = [1e-6, 1e-5, 1e-4]  # values passed as --learning_rate
EXPERIMENT_GROUP_NAME = "test_preft_zeroshot"
LAYERS = [6]  # values passed as --n_layer
DATA_FILES = ["composition_both_hop.2000.200.18.0"]

# Root directory for fine-tuning results, nested under the checkpoint-31000
# run of the first data file. The trailing "" keeps the final separator.
out_path = os.path.join(
    "/data/user_data/gghosal/grokkreason/results/two_hop_cf",
    DATA_FILES[0],
    "checkpoint-31000",
    "finetune",
    "",
)
os.makedirs(out_path, exist_ok=True)
def clean_results_file(path):
    """Recursively delete every immediate subdirectory of *path*.

    Regular files directly inside *path* are left untouched. The name of each
    removed directory is printed.

    Args:
        path: Directory whose subdirectories should be removed.
    """
    # Iterate over ``path`` itself. The previous version listed the global
    # ``out_path`` while deleting entries joined onto ``path``.
    for entry in os.scandir(path):
        if entry.is_dir():
            print(f"removing {entry.name}")
            shutil.rmtree(entry.path)
# Remove directories left behind by a previous sweep before submitting a new one.
clean_results_file(out_path)

# One CLI argument string per (weight decay, learning rate, n_layer, data file)
# combination; SLURM array task i runs CF_FT.py with args[i].
args = []
for wd, lr, layers, dfile in itertools.product(WD, LR, LAYERS, DATA_FILES):
    # Each sweep point gets its own output directory. All tasks pass
    # --overwrite_output_dir and several run concurrently, so a shared
    # directory would let them overwrite each other's checkpoints.
    run_dir = os.path.join(out_path, dfile, f"wd{wd}_lr{lr}_layers{layers}", "")
    flags = [
        f"--data_dir ./data/{dfile}",
        f"--model_name_or_path /data/user_data/gghosal/grokkreason/results/two_hop_cf/{dfile}/checkpoint-31000/",
        f"--weight_decay {wd}",
        f"--output_dir {run_dir}",
        "--max_seq_length 10",
        "--max_length 10",
        "--block_size 10",
        "--train_batch_size 16",
        "--eval_batch_size 512",
        f"--learning_rate {lr}",
        "--gradient_accumulation_steps 1",
        "--evaluate_during_training",
        "--save_step 500",
        "--save_step_dense 400",
        "--max_steps 0",
        "--scheduler constant_schedule_with_warmup",
        "--num_train_epochs 3",
        "--fp16",
        "--do_eval",
        "--do_train",
        "--predict_during_training",
        "--overwrite_output_dir",
        "--add_tokens",
        f"--n_layer {layers}",
    ]
    # Joining a list avoids fragile backslash continuations (one was missing
    # after --learning_rate, which embedded a newline in every argument string).
    args.append(" ".join(flags))
# Render the per-task argument strings as a bash array literal, e.g.
#   ARGS=("--data_dir ... --n_layer 6" "--data_dir ... --n_layer 6")
# Each element is double-quoted so bash keeps one string per array slot.
# str.join avoids quadratic += and the IndexError the old
# special-cased last element raised on an empty list.
num_args = len(args)
args_string = "ARGS=(" + " ".join(f'"{arg}"' for arg in args) + ")"
print(args_string)
# sbatch job-array script: one task per entry of ARGS, at most 8 running at once.
SLURMSCRIPT = "\n".join([
    "#!/bin/bash",
    "#SBATCH --job-name=grokfinetune",
    "#SBATCH --output=./logs/test_ft_%a.out",
    "#SBATCH --error=./logs/test_ft_%a.err",
    "#SBATCH --partition=general",
    f"#SBATCH --array=0-{num_args-1}%8",
    "#SBATCH --time=7:00:00",
    "#SBATCH --gres=gpu:A6000:1",
    "#SBATCH --mem=64G",
    "#SBATCH --nodes=1",
    "#SBATCH --requeue",
    "#SBATCH --exclude=babel-7-5,babel-1-31",
    "source ../../miniconda3/etc/profile.d/conda.sh",
    "conda activate grokkedtransformer",
    args_string,
    "nvidia-smi",
    "export NCCL_DEBUG=INFO",
    "export NCCL_P2P_DISABLE=1",
    # Use arithmetic for a per-task port. The old string concatenation
    # ("1234" + task id) produced invalid ports such as 123410 for ids >= 10.
    # Ports for ids 0-9 are unchanged (12340-12349).
    "CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1"
    " --master_port $((12340 + SLURM_ARRAY_TASK_ID))"
    " CF_FT.py ${ARGS[$SLURM_ARRAY_TASK_ID]}",
])
# Slurm does not create missing directories for --output/--error paths, and
# the script logs to ./logs/, so make sure it exists before submitting.
os.makedirs("logs", exist_ok=True)

with open("jobarrayscriptcf.sbatch", "w") as script_file:
    script_file.write(SLURMSCRIPT)

# Pass argv as a list so no shell is involved. check=True makes a rejected
# submission raise instead of being silently ignored.
subprocess.run(["sbatch", "jobarrayscriptcf.sbatch"], check=True)