import subprocess
import json
import argparse

def run_command(
        partition, model, n_params, dataset, n_prompt, splits=1, batch_size=32,
        program="dump_mwe_logprobs.py", just_print=False, use_sysprompt=False, no_formatting=True, debug=False,
        multi_gpu=False, reversed=False, only_alpha_1=False
):
    """Build one sbatch submission command per split, print it, and optionally run it.

    Each command changes into the slurm directory, then calls
    ``sbatch submit_{partition}.bash`` with the quoted python command line
    as its single argument.

    Args:
        partition: Selects the submission script ``submit_{partition}.bash``.
        model: Forwarded as ``--model``.
        n_params: Forwarded as ``--n_params``.
        dataset: Forwarded as ``--dataset``.
        n_prompt: Forwarded as ``--n_prompts``; ``None`` omits the flag.
        splits: Number of jobs; job ``i`` gets ``--idx`` and ``--n_splits``.
        batch_size: Forwarded as ``--batch_size``.
        program: Python script that each job runs.
        just_print: When true, commands are printed but not executed.
        use_sysprompt: Adds ``--use_sysprompt``, but only while
            ``no_formatting`` is also true.
        no_formatting: When false, adds ``--formatting``.
        debug: Adds ``--debug --n_prompts {3*batch_size}`` and submits only
            the first job.
        multi_gpu: Adds ``--multi_gpu``.
        reversed: Submit split indices in descending order.
        only_alpha_1: Adds ``--only_alpha_1``.

    Returns:
        None.
    """
    cd_slurm = "cd /users/<name>/code_remote/pretraining-attribution/slurm"
    sbatch = f"sbatch submit_{partition}.bash"
    python_command = (
        f"python {program} --model {model} --n_params {n_params}"
        f" --dataset {dataset} --batch_size {batch_size}"
    )

    # Test the parameter itself; the previous check read a module-level
    # ``n_prompts`` and raised NameError for any caller that did not define it.
    if n_prompt is not None:
        python_command += f" --n_prompts {n_prompt}"
    if not no_formatting:
        python_command += " --formatting"

    # NOTE(review): the system prompt flag is dropped whenever formatting is
    # enabled -- confirm that this coupling is intended.
    if no_formatting and use_sysprompt:
        python_command += " --use_sysprompt"

    if only_alpha_1:
        python_command += " --only_alpha_1 "

    if debug:
        # Appended after any earlier --n_prompts, so the flag may appear twice;
        # presumably the program's parser keeps the last occurrence -- verify.
        python_command += f" --debug --n_prompts {3*batch_size}"

    if multi_gpu:
        python_command += " --multi_gpu"

    for i in range(splits):
        if debug and i > 0:
            # Debug runs submit a single job only.
            break
        idx = splits - i - 1 if reversed else i
        command = f"{cd_slurm}; {sbatch} \"{python_command} --idx {idx} --n_splits {splits}\"; cd .."
        print(command)
        if not just_print:
            subprocess.run(command, shell=True)


# Command-line interface for the job submitter. Flags, defaults and types are
# unchanged; only help strings that were copy-pasted from other options have
# been corrected.
parser = argparse.ArgumentParser(description="Submit jobs for multiple datasets")
parser.add_argument("--program", default="dump_mwe_logprobs.py", help="Python script to run in each job")
parser.add_argument("--dataset", default="big_bench", help="Dataset name")
parser.add_argument("--partition", help="Cluster partition name")
parser.add_argument("--model", help="Model name ('llama' or 'gemma')")
parser.add_argument("--n_params", help="Number of parameters")
parser.add_argument("--n_prompts", default=None, help="Number of prompts (overrides the per-dataset default)")
parser.add_argument("--n_splits", type=int, default=1, help="Number of splits; one job is submitted per split")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size")
# Formatting is opt-in: without this flag the jobs run with formatting disabled.
parser.add_argument('--formatting', action="store_true", help='Enable model-native chat formatting')
parser.add_argument('--use_sysprompt', action='store_true', help='Use system prompt if available')
parser.add_argument("--just_print", action='store_true', help="Print commands without executing them")
parser.add_argument("--reversed", action='store_true', help="Submit splits in reverse index order")
parser.add_argument("--debug", action='store_true', help="Debug")
parser.add_argument("--multi_gpu", action='store_true', help="Forward --multi_gpu to the program")
parser.add_argument('--only_alpha_1', action='store_true', help='Only use alpha = 1')

args = parser.parse_args()

# Echo the parsed arguments before anything is submitted.
print(json.dumps(vars(args), indent=4))

# Default number of prompts for each known dataset. None means no default is
# set, in which case only an explicit --n_prompts supplies a value.
datasets = {
    "big_bench": 1000,
    "mmlu": None,
    "fewshot_mmlu": None,
    "fewshot_cot_gsm8k": None,
    "psycho_msj": None,
    "psycho_msj_small": None,
    "ends_justify_means_msj_small": None,
    "machiavellianism_msj_small": None,
    "narcissism_msj_small": None,
    "resource-acquisition_msj_small": None,
    "coqa_icl": None,
    "logiqa_icl": None,
    "winogrande_icl": None,
    "mmlu_icl": None,
    "boolean_icl": None,
    "quickexp_boolean_icl_raw": 200,
    "quickexp_boolean_icl_bare_model_native": 200,
    "quickexp_boolean_icl_conversational_model_native": 200,
}

dataset = args.dataset
if dataset not in datasets:
    # Exit with a usage message instead of a bare KeyError traceback.
    parser.error(f"unknown dataset {dataset!r}; choose from: {', '.join(datasets)}")
n_prompts = datasets[dataset]

# An explicit --n_prompts on the command line overrides the dataset default.
if args.n_prompts is not None:
    n_prompts = args.n_prompts

run_command(
    args.partition,
    args.model,
    args.n_params,
    dataset,
    n_prompts,
    program=args.program,
    just_print=args.just_print,
    # The CLI flag is opt-in, while run_command takes the negated form.
    no_formatting=not args.formatting,
    use_sysprompt=args.use_sysprompt,
    splits=args.n_splits,
    debug=args.debug,
    batch_size=args.batch_size,
    multi_gpu=args.multi_gpu,
    reversed=args.reversed,
    only_alpha_1=args.only_alpha_1,
)


