#!/usr/bin/env python3

import argparse
import shlex
import subprocess
from typing import Union


# Launch settings per model:
#   name -> (node type passed to `run-slurm -c`,
#            whether to use the CPU-offload launcher,
#            minimum runtime in hours)
# Alternatives that have been used before:
#   "llama2-7b": ("1a100", True, 12)
#   "phi-2.7b":  ("1a40-ol", True, 1)
ALL_MODELS = {
    "pyt-70m": ("1a40", False, 1),
    "pyt-1b": ("1a40", False, 1),
    "pyt-12b": ("1a100-ol", True, 12),
    "llama2-7b": ("1a40-ol", True, 1),
    "llama2-13b": ("1a100-ol", True, 12),
    "gpt2-124m": ("1a40", False, 1),
    "gpt2-1.5b": ("1a100", False, 1),
    "phi-1.3b": ("1a40", False, 1),
    "phi-2.7b": ("1a100", False, 1),
    "opt-350m": ("1a40", False, 1),
}

# Models included in every experiment family.
CORE_MODELS = ["pyt-1b", "llama2-13b", "phi-2.7b"]

# Extra models, only used by the "md" sweeps.
# Currently disabled: "pyt-70m", "pyt-12b", "llama2-7b", "phi-1.3b",
# "gpt2-124m" (note left behind: still needs runs for 0, 3, 4 — presumably
# seed ids), "gpt2-1.5b".
NON_CORE_MODELS = ["opt-350m"]

# Sweep values; they appear in config names as `a-<size>` and `h-<level>`.
ALPHABET_SIZES = [2, 4, 7, 13, 26]
ENTROPY_LEVELS = [2, 4, 7, 13]

# Seed ids are given on the command line (a SEED_IDS constant used to live here).

# Launcher prefix for models whose ALL_MODELS entry enables CPU offload.
# Earlier variant: "--devel --torchrun --env='CPU_OFFLOAD=1'"
CPU_OFFLOAD_OPTIOANONYMOUS = "--devel --env='CPU_OFFLOAD=1' torchrun"


def run_experiment_config(
    experiment_id: str,
    config_name: str,
    model_name: str,
    seed_ids: list[int],
    runtime_h: Union[int, str, None] = None,
    *,
    dry_run: bool = False,
) -> None:
    """Launch one experiment configuration through ``./actions.py run-slurm``.

    The launched program is ``src/main.py`` with the override
    ``+<experiment_id>=<config_name>``.

    Args:
        experiment_id: Config group used in the ``+<experiment_id>=...`` override.
        config_name: Config name within that group.
        model_name: Key into ``ALL_MODELS``. It selects the node type (``-c``),
            whether the CPU-offload launcher is used, and the minimum runtime
            in hours.
        seed_ids: Seeds, forwarded as a comma-separated ``--sids`` value.
        runtime_h: Requested runtime.
            - ``None`` uses the model's minimum runtime.
            - An int is a number of hours, raised to at least that minimum.
            - A string (e.g. ``"1d"``) is passed through verbatim, and the
              minimum is NOT enforced.
        dry_run: If true, print the command instead of running it.

    Raises:
        KeyError: ``model_name`` is not in ``ALL_MODELS``.
        ValueError: ``seed_ids`` is empty (it would otherwise yield ``--sids=``).
        subprocess.CalledProcessError: The launch command exits non-zero.
    """
    node_type, cpu_offload, min_runtime = ALL_MODELS[model_name]
    if not seed_ids:
        raise ValueError("seed_ids must contain at least one seed id")

    if runtime_h is None:
        runtime = f"{min_runtime}h"
    elif isinstance(runtime_h, str):
        # Free-form durations are trusted as-is; no minimum is applied.
        runtime = runtime_h
    else:
        runtime = f"{max(min_runtime, runtime_h)}h"

    sid_arg = ",".join(str(sid) for sid in seed_ids)
    # The offload launcher is stored as a shell-style string with quoting.
    # shlex.split turns it into the same tokens /bin/sh would produce.
    launcher = shlex.split(CPU_OFFLOAD_OPTIOANONYMOUS) if cpu_offload else ["torchrun"]

    argv = [
        "./actions.py",
        "run-slurm",
        "-c",
        node_type,
        "-t",
        runtime,
        f"--sids={sid_arg}",
        *launcher,
        "src/main.py",
        f"+{experiment_id}={config_name}",
    ]
    if dry_run:
        print(shlex.join(argv))
        return
    # Passing an argv list (no shell) keeps every value a single argument,
    # whatever characters it contains.
    subprocess.run(argv, check=True)


# Command-line interface: one experiment family plus the seeds to run it with.
parser = argparse.ArgumentParser(
    description=(
        "Launch every configuration of one experiment family through "
        "./actions.py run-slurm."
    ),
)
parser.add_argument(
    "experiment_id",
    choices=[
        "md",
        "md-nlat",
        "md-ut",
        "md-len",
        "md-split",
        "pm",
        "rt",
        "rs",
    ],
    help=(
        "experiment family to launch: md (memorization dynamics), "
        "md-nlat (non-latin alphabets), md-ut (untrained models), "
        "md-len (string lengths), md-split (string partitions), "
        "pm (prefix mappings), rt (repeated training), rs (repeated strings); "
        "all md-* variants use the 'md' config group"
    ),
)
parser.add_argument(
    "seed_ids",
    nargs="+",
    type=int,
    help="one or more integer seed ids, passed to every job as --sids",
)
args = parser.parse_args()


# Each builder yields (experiment_id, config_name, model, runtime_h) tuples in
# submission order. runtime_h is forwarded to run_experiment_config; None
# means "use its default".


def _md_jobs():
    """Memorization dynamics: alphabet-size sweep, then entropy-level sweep."""
    # The alphabet sweep is limited to the non-core models for now; swap in
    # CORE_MODELS to rerun it for the core ones.
    for model in NON_CORE_MODELS:
        for alphabet_size in ALPHABET_SIZES:
            yield "md", f"{model}_a-{alphabet_size}_t-1024", model, None
    for model in CORE_MODELS + NON_CORE_MODELS:
        for entropy_level in ENTROPY_LEVELS:
            yield "md", f"{model}_h-{entropy_level}_t-1024", model, None


def _md_nlat_jobs():
    """Non-latin alphabets."""
    return (
        ("md", f"{model}_a-nlat-{size}", model, None)
        for model in CORE_MODELS
        for size in ALPHABET_SIZES
    )


def _md_ut_jobs():
    """Non-pretrained/untrained models."""
    return (
        ("md", f"{model}_a-{size}_ut", model, None)
        for model in CORE_MODELS
        for size in ALPHABET_SIZES
    )


def _md_len_jobs():
    """Different string lengths (binary alphabet; 26 was also considered)."""
    alphabet_size = 2
    for model in CORE_MODELS:
        for length in (16, 32, 64, 128, 256, 512):
            yield "md", f"{model}_a-{alphabet_size}_t-{length}", model, None


def _md_split_jobs():
    """Different splits/partitions of the same strings (binary alphabet)."""
    alphabet_size = 2
    for model in CORE_MODELS:
        for partitions in (2, 4, 8, 16, 32, 64):
            name = f"{model}_a-{alphabet_size}_t-1024_p-{partitions}"
            yield "md", name, model, None


def _pm_jobs():
    """Prefix mappings, one day per job (earlier: 4d for llama2-13b)."""
    # Seeds used so far: 0, 1, 2, 3, 4.
    # An entropy-level sweep ("<model>_h-<level>_t-1024") was run separately
    # for Pythia-1B with seeds 1 and 2 and is not re-submitted here.
    runtime = "1d"
    for model in CORE_MODELS:
        for size in ALPHABET_SIZES:
            yield "pm", f"{model}_a-{size}_t-1024", model, runtime


def _rt_jobs():
    """Repeated training: 32 repetitions at a=2, 16 at a=26."""
    for model in CORE_MODELS:
        runtime = "12h" if model != "llama2-13b" else "1d"
        for size, repetitions in ((2, 32), (26, 16)):
            yield "rt", f"{model}_a-{size}_t-1024_x{repetitions}", model, runtime


def _rs_jobs():
    """Repeated strings (binary alphabet; 26 was also considered)."""
    alphabet_size = 2
    for model in CORE_MODELS:
        for substring_length in (16, 32, 64, 128, 256, 512):
            name = (
                f"{model}_a-{alphabet_size}_sl-{substring_length}_"
                "ns-1_plo-iterative"
            )
            yield "rs", name, model, None


_JOB_BUILDERS = {
    "md": _md_jobs,
    "md-nlat": _md_nlat_jobs,
    "md-ut": _md_ut_jobs,
    "md-len": _md_len_jobs,
    "md-split": _md_split_jobs,
    "pm": _pm_jobs,
    "rt": _rt_jobs,
    "rs": _rs_jobs,
}

_build_jobs = _JOB_BUILDERS.get(args.experiment_id)
if _build_jobs is not None:
    for job_experiment, job_config, job_model, job_runtime in _build_jobs():
        run_experiment_config(
            job_experiment,
            job_config,
            job_model,
            seed_ids=args.seed_ids,
            runtime_h=job_runtime,
        )
