import argparse

from lib_dl.ops.run_jobs import EnvConfig, MachineConfig, TaskConfig, run


# Launch configurations, keyed by the task id passed on the command line
# (looked up in `main`). Each entry names the experiment, the config names to
# run, the seed ids, and any extra experiment options.
TASKS = {
    "im": TaskConfig(
        experiment="im",
        # One "pyt-<size>_by" config per model size, listed smallest first.
        config_names=[
            f"pyt-{size}_by"
            for size in (
                "70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b",
            )
        ],
        seed_ids=[0],
        experiment_options=[],
    ),
    "mhpr": TaskConfig(
        experiment="mhpr",
        # Only the sequence-length sweep is currently active. The commented-out
        # sweeps are kept as a menu: re-enable one by uncommenting it together
        # with the `+` that concatenates it onto the active list.
        config_names=(
            # [
            #     f"{model_id}_ns-{num_sequences}_lr-{learning_rate}"
            #     for model_id in ["pyt-70m", "pyt-1b"]
            #     # for model_id in ["pyt-70m"]
            #     # for model_id in ["pyt-1b"]
            #     for num_sequences in [1, 16, 256]
            #     # for learning_rate in [1e-4, 1e-5, 1e-6]
            #     for learning_rate in [5e-6, 5e-7]
            #     # 1 1e-6
            #     # 16 all
            #     # 256 1e-6
            # ]
            # +
            # [
            #     f"{model_id}_lrs-{schedule}_lr-{learning_rate}"
            #     for model_id in ["pyt-70m", "pyt-1b"]
            #     for learning_rate in [1e-4, 1e-5, 1e-6]
            #     for schedule in ["const", "lin", "cos"]
            # ]
            # +
            # [
            #     f"{model_id}_ws-{warmup_steps}_lr-{learning_rate}"
            #     # for model_id in ["pyt-70m", "pyt-1b"]
            #     for model_id in ["pyt-1b"]
            #     # for learning_rate in [1e-5, 1e-6]
            #     for learning_rate in [1e-5]
            #     for warmup_steps in [0, 10, 50, 100]
            # ]
            # +
            [
                f"{model}_sl-{seq_len}"
                # for model in ("pyt-70m", "pyt-1b")
                for model in ("pyt-1b",)
                for seq_len in (4, 16, 64, 128, 256, 1024)
                # for seq_len in (1024,)
            ]
            # +
            # [
            #     f"{model_id}_dt-{data_type}_tok-{tokenization_id}"
            #     for model_id in ["pyt-70m"]#, "pyt-1b"]
            #     for data_type in ["rand", "rand-names", "sci-names", "wiki"]
            #     # for data_type in ["wiki"]
            #     for tokenization_id in ["char", "def"]
            # ]
        ),
        seed_ids=[0],
        experiment_options=[],
    ),
}

# Hosts that `run` dispatches jobs to. Every active host uses device ids 0 and 1.
#
# Previously used hosts, currently disabled. Hosts using device_ids=[0, 1] can
# be re-enabled by adding their name to the tuple below; hosts with other
# device ids need their own explicit MachineConfig(...) entry.
#   volta01 .. volta05, volta07, volta09 .. volta12   device_ids=[0, 1]
#   volta06                                          device_ids=[0]
#   volta08                                          device_ids=[1]
#   ANONYMOUS-2a40-02 .. ANONYMOUS-2a40-05           device_ids=[0, 1]
#   ANONYMOUS-2a40-08                                device_ids=[0, 1]
#   ANONYMOUS-2a100-01 .. ANONYMOUS-2a100-08         device_ids=[0, 1]
MACHINES = [
    # A new [0, 1] list is built per host, so entries never share state.
    MachineConfig(name=host, device_ids=[0, 1])
    for host in (
        "ANONYMOUS-2a40-01",
        "ANONYMOUS-2a40-06",
        "ANONYMOUS-2a40-07",
    )
]
# Account and filesystem settings handed to `run` alongside MACHINES.
ENV = EnvConfig(
    experiment_dir="/ANONYMOUS/llm-1/work/ANONYMOUS/llms",
    username="ANONYMOUS",
    # Colon-separated directories (poetry and user-local bins); presumably
    # added to PATH on the remote hosts by run_jobs — confirm there.
    path_expansion="/home/ANONYMOUS/.poetry/bin:/home/ANONYMOUS/.local/bin",
)


def main(argv=None):
    """Parse a task id from the command line and launch it on MACHINES.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so plain ``main()`` behaves as before.

    The task id must be one of the keys of ``TASKS``; anything else is
    rejected by argparse with a usage message (exit status 2) instead of
    surfacing as a bare ``KeyError`` traceback.
    """
    parser = argparse.ArgumentParser(description="Run experiment tasks")
    parser.add_argument(
        "task",
        type=str,
        # Restricting to known ids lets argparse validate the input and list
        # the valid choices in --help and in its error message.
        choices=sorted(TASKS),
        help="Task id to run",
    )
    args = parser.parse_args(argv)

    task = TASKS[args.task]
    run(ENV, MACHINES, task)


if __name__ == "__main__":
    main()
