import subprocess
import argparse

def run_command(
        partition, model, n_params, dataset, n_prompt, splits=1, batch_size=32,
        program="dump_mwe_logprobs.py", just_print=False, use_sysprompt=False, no_formatting=False, debug=False,
        multi_gpu=False,
):
    """Build and submit one ``sbatch`` job per split for a single dataset.

    Every job runs ``python <program> ...`` through ``submit_<partition>.bash``
    from the slurm directory. Each shell command is printed before it is
    (optionally) executed.

    Args:
        partition: Partition name; selects the ``submit_<partition>.bash`` script.
        model: Value forwarded as ``--model``.
        n_params: Value forwarded as ``--n_params``.
        dataset: Value forwarded as ``--dataset``.
        n_prompt: Value forwarded as ``--n_prompts``; the flag is omitted
            when this is ``None``.
        splits: Number of jobs to submit; job ``i`` receives
            ``--idx i --n_splits splits``.
        batch_size: Value forwarded as ``--batch_size``.
        program: Python script executed by each job.
        just_print: If true, only print the commands without running them.
        use_sysprompt: Append ``--use_sysprompt``. Only takes effect when
            ``no_formatting`` is also set.
        no_formatting: Append ``--no_formatting``.
        debug: Append ``--debug --n_prompts max(20, 3 * batch_size)`` and
            submit only the first split.
        multi_gpu: Append ``--multi_gpu``.
    """
    cd_slurm = "cd /users/<name>/code_remote/pretraining-attribution/slurm"
    sbatch = f"sbatch submit_{partition}.bash"

    parts = [
        f"python {program} --model {model} --n_params {n_params} --dataset {dataset} --batch_size {batch_size}",
    ]
    # Test the function's own parameter; the previous code read a
    # differently named module-level variable here.
    if n_prompt is not None:
        parts.append(f"--n_prompts {n_prompt}")
    if no_formatting:
        parts.append("--no_formatting")
        # NOTE(review): --use_sysprompt is only forwarded together with
        # --no_formatting; confirm this coupling is intentional.
        if use_sysprompt:
            parts.append("--use_sysprompt")
    if debug:
        # May add a second --n_prompts after the one above; presumably the
        # target program's parser lets the later one win -- TODO confirm.
        parts.append(f"--debug --n_prompts {max(20, 3 * batch_size)}")
    if multi_gpu:
        parts.append("--multi_gpu")
    python_command = " ".join(parts)

    # In debug mode only the first split is submitted (--n_splits still
    # reports the full split count).
    job_indices = range(splits)[:1] if debug else range(splits)
    for i in job_indices:
        # The trailing "cd .." only affects the spawned shell, not this process.
        command = f"{cd_slurm}; {sbatch} \"{python_command} --idx {i} --n_splits {splits}\"; cd .."
        print(command)
        if not just_print:
            subprocess.run(command, shell=True)


# Command-line interface: one invocation submits jobs for every dataset
# listed in ``datasets`` below.
parser = argparse.ArgumentParser(description="Submit jobs for multiple datasets")
parser.add_argument("--program", default="dump_mwe_logprobs.py", help="Python script each job runs")
parser.add_argument("--partition", help="Cluster partition name")
parser.add_argument("--model", help="Model name ('llama' or 'gemma')")
parser.add_argument("--n_params", help="Number of parameters")
parser.add_argument("--n_prompts", type=int, default=None, help="Number of prompts")
parser.add_argument("--n_splits", type=int, default=1, help="Number of splits (one job is submitted per split)")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
parser.add_argument('--no_formatting', action='store_true', help='Disable model-native chat formatting')
parser.add_argument('--use_sysprompt', action='store_true', help='Use system prompt if available')
parser.add_argument("--just_print", action='store_true', help="Print commands without executing them")
parser.add_argument("--debug", action='store_true', help="Debug")
parser.add_argument("--multi_gpu", action='store_true', help="Pass --multi_gpu to the program")

# Only parse arguments and submit jobs when executed as a script, so that
# importing this module has no side effects.
if __name__ == "__main__":
    args = parser.parse_args()

    # Maps dataset name -> default number of prompts (None: do not pass
    # --n_prompts unless it is given on the command line).
    datasets = {"model_written_evals": None}

    for dataset, n_prompts in datasets.items():
        # A command-line --n_prompts overrides the per-dataset default.
        if args.n_prompts is not None:
            n_prompts = args.n_prompts
        run_command(
            args.partition,
            args.model,
            args.n_params,
            dataset,
            n_prompts,
            program=args.program,
            just_print=args.just_print,
            no_formatting=args.no_formatting,
            use_sysprompt=args.use_sysprompt,
            splits=args.n_splits,
            debug=args.debug,
            batch_size=args.batch_size,
            multi_gpu=args.multi_gpu,
        )


