import subprocess
import argparse

def run_command(
    partition,
    model,
    n_params,
    dataset,
    n_prompt,
    program="run.py",
    just_print=False,
    use_sysprompt=False,
    no_formatting=False,
    splits=1,
):
    """Build, print and optionally submit one sbatch job for a single dataset.

    The job runs ``python <program> ...`` through ``submit_<partition>.bash``
    from the project's ``slurm`` directory.

    Args:
        partition: Suffix selecting the ``submit_<partition>.bash`` script.
        model: Value forwarded as ``--model``.
        n_params: Value forwarded as ``--n_params``.
        dataset: Value forwarded as ``--dataset``.
        n_prompt: Value forwarded as ``--n_prompts``; the flag is omitted
            entirely when this is ``None``.
        program: Python script executed inside the job.
        just_print: If true, only print the shell command; do not run it.
        use_sysprompt: Request ``--use_sysprompt`` (see note in the body).
        no_formatting: Forward ``--no_formatting``.
        splits: Currently unused; kept so existing callers keep working.

    Returns:
        None. The assembled shell command is printed to stdout.
    """
    cd_slurm = "cd /users/<name>/code_remote/pretraining-attribution/slurm"
    sbatch = f"sbatch submit_{partition}.bash"
    python_command = f"python {program} --model {model} --n_params {n_params} --dataset {dataset}"

    # Test the function's own parameter. The previous spelling (`n_prompts`)
    # silently resolved to whatever module-level variable of that name
    # happened to exist, or raised NameError when none did.
    if n_prompt is not None:
        python_command += f" --n_prompts {n_prompt}"

    if no_formatting:
        python_command += " --no_formatting"
        # --use_sysprompt is only forwarded together with --no_formatting.
        # NOTE(review): the output file name below includes the sysprompt
        # suffix regardless of no_formatting - confirm the nesting is intended.
        if use_sysprompt:
            python_command += " --use_sysprompt"

    if program == "model_generations.py":
        # Generations are appended to a per-configuration text file.
        formatting_str = 'no_formatting' if no_formatting else 'with_formatting'
        sysprompt_str = '_use_sysprompt' if use_sysprompt else ''
        python_command += f" >> example_generations/{model}_{n_params}_{dataset}_{formatting_str}{sysprompt_str}.txt"

    # The python command (including any ">>" redirect) is handed to the
    # submit script as one double-quoted argument.
    command = f"{cd_slurm}; {sbatch} \"{python_command}\"; cd .."
    print(command)
    if not just_print:
        subprocess.run(command, shell=True)


# Command-line interface: one invocation submits a job per dataset.
parser = argparse.ArgumentParser(description="Submit jobs for multiple datasets")
# The help text here previously duplicated the --partition description.
parser.add_argument("--program", default="run.py", help="Python script to run in each job")
parser.add_argument("--partition", help="Cluster partition name")
parser.add_argument("--model", help="Model name ('llama' or 'gemma')")
parser.add_argument("--n_params", help="Number of parameters")
# When given, this overrides the per-dataset prompt count for every dataset.
parser.add_argument("--n_prompts", default=None, help="Number of prompts")
parser.add_argument('--no_formatting', action='store_true', help='Disable model-native chat formatting')
parser.add_argument('--use_sysprompt', action='store_true', help='Use system prompt if available')
parser.add_argument("--just_print", action='store_true', help="Print commands without executing them")

args = parser.parse_args()

# Datasets to submit, mapped to their default prompt count. A count of None
# means no --n_prompts flag is added for that dataset unless the CLI
# supplies one.
if args.program == "dump_mwe_logprobs.py":
    # This program is only run on the model-written-evals dataset.
    datasets = {"model_written_evals": None}
else:
    datasets = dict(
        conjugate_prompting=520,
        honest_llama=None,
        natural_qa=1000,
        hh_rlhf=1000,
        openwebtext=1000,
    )

# Submit (or just print) one job per dataset, in dictionary order.
for dataset, n_prompts in datasets.items():
    # A --n_prompts value from the command line wins over the dataset default.
    n_prompts = n_prompts if args.n_prompts is None else args.n_prompts
    run_command(
        partition=args.partition,
        model=args.model,
        n_params=args.n_params,
        dataset=dataset,
        n_prompt=n_prompts,
        program=args.program,
        just_print=args.just_print,
        use_sysprompt=args.use_sysprompt,
        no_formatting=args.no_formatting,
    )


