import itertools
import os
import shlex
import shutil
import subprocess
from itertools import product
# ---------------------------------------------------------------------------
# Hyperparameter grid. The script below takes the Cartesian product of these
# lists; each combination becomes one task of the Slurm job array.
# ---------------------------------------------------------------------------
WITH_EX_DROPOUT = [1]  # NOTE(review): not referenced elsewhere in this script
WD = [0.001]  # passed to main_ft.py as --wd
LR = [0.0001]  # passed to main_ft.py as --lr
BSIZE = [64]  # passed to main_ft.py as --batch_size
TRAIN_SIZE = [64 * 10_000]  # passed to main_ft.py as --train_size
EPOCHS = [10]  # passed to main_ft.py as --n_epochs
EXPERIMENT_GROUP_NAME = "finetune_cot"  # prefix of every --wandb_run_id

# Checkpoint(s) to fine-tune from (--model_name_or_path).
MODEL_NAMES = [
    "/data/user_data/gghosal/grokkreason/results/cf_ft/pt_ft_cf_ood/cf_ood_0.25.2000.200.18.0/checkpoint-9500/",
]
# Dataset directory/directories (--data_dir).
DATA_FILES = [
    "./data/procedural/cf_ood_0.25.2000.200.18.0/",
]
# One command-line argument string for main_ft.py per point in the
# hyperparameter grid (Cartesian product of the configuration lists).
# NOTE(review): the wandb run id omits the model, data dir and epoch count,
# so grids over those lists would produce duplicate run ids.
args = [
    (
        f"--model_name_or_path {modelname} --data_dir {dfile} --wd {wd} "
        f"--n_epochs {ep} --lr {lr} --batch_size {bsize} --train_size {ts} "
        f"--wandb_run_id {EXPERIMENT_GROUP_NAME}_lr_{lr}_wd_{wd}"
        f"_batch_{bsize}_data_{ts}"
    )
    for modelname, dfile, wd, lr, bsize, ts, ep in itertools.product(
        MODEL_NAMES, DATA_FILES, WD, LR, BSIZE, TRAIN_SIZE, EPOCHS
    )
]
num_args = len(args)
if num_args == 0:
    # An empty grid would yield "--array=0--1" in the batch script and an
    # IndexError when building the bash array, so stop with a clear message.
    raise SystemExit("No jobs to submit: a hyperparameter list is empty.")

# Bash array literal ARGS=('...' '...'), one element per array task.
# shlex.quote keeps each string a single literal element even if it contains
# quotes or "$"; the element is word-split later, when the batch script
# expands it unquoted, into separate flags for main_ft.py.
args_string = "ARGS=(" + " ".join(shlex.quote(arg) for arg in args) + ")"
print(args_string)
# Text of the Slurm batch script: one array task per ARGS element, at most 8
# tasks running at once (the "%8" throttle), one GPU per task. Log files are
# named per array index via the %a filename pattern.
_slurm_lines = [
    "#!/bin/bash",
    "#SBATCH --job-name=counterfactual_ft",
    "#SBATCH --output=./logs/counterfactual_ft_%a.out",
    "#SBATCH --error=./logs/counterfactual_ft_%a.err",
    "#SBATCH --partition=general",
    f"#SBATCH --array=0-{num_args-1}%8",
    "#SBATCH --time=7:00:00",
    "#SBATCH --gres=gpu:1",
    "#SBATCH --mem=64G",
    "#SBATCH --nodes=1",
    "#SBATCH --exclude=babel-3-5",
    "#SBATCH --requeue",
    "source ../../miniconda3/etc/profile.d/conda.sh",
    "conda activate grokkedtransformer",
    args_string,
    "nvidia-smi",
    "export NCCL_DEBUG=INFO",
    "export NCCL_P2P_DISABLE=1",
    # master_port uses bash arithmetic rather than string concatenation
    # ("1234" + task id), which produced ports above 65535 for task ids >= 10;
    # ids 0-9 still map to 12340-12349 as before.
    # ${ARGS[...]} is deliberately unquoted so bash splits the element into
    # separate CLI flags.
    "CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --nproc_per_node=1"
    " --master_port $((12340 + SLURM_ARRAY_TASK_ID))"
    " main_ft.py ${ARGS[$SLURM_ARRAY_TASK_ID]}",
]
SLURMSCRIPT = "\n".join(_slurm_lines)
# Slurm does not create the directory named in --output/--error; without it
# the tasks cannot write their logs. Create it relative to the current
# working directory, which is where sbatch is invoked below.
os.makedirs("./logs", exist_ok=True)

# Write the batch script into the current working directory.
with open("jobarrayscriptft.sbatch", "w", encoding="utf-8") as f:
    f.write(SLURMSCRIPT)

# Submit without going through a shell; check=True makes a failed submission
# (e.g. sbatch rejecting the script) raise instead of being silently ignored.
subprocess.run(["sbatch", "jobarrayscriptft.sbatch"], check=True)