import argparse

from lib_dl.ops.run_jobs import EnvConfig, MachineConfig, TaskConfig, run


# Task registry keyed by the task id given on the command line (see main()).
# Every entry is a TaskConfig handed unchanged to run() together with ENV and
# MACHINES.
TASKS = {
    # "im": one "_by" config per model size, single seed.
    "im": TaskConfig(
        experiment="im",
        config_names=[
            f"pyt-{size}_by"
            for size in ("70m", "160m", "410m", "1b", "1.4b", "2.8b", "6.9b", "12b")
        ],
        seed_ids=[0],
        experiment_options=[],
    ),
    # "mhpr": sequence-length sweep ("_sl-<length>") for pyt-1b (pyt-70m was
    # also listed as an alternative model).
    #
    # Other config-name grids that were kept commented out in this task; to
    # run one, build it as a list comprehension and concatenate it onto
    # `config_names` with `+`:
    #   f"{model_id}_ns-{num_sequences}_lr-{learning_rate}"
    #       model_id: pyt-70m, pyt-1b   num_sequences: 1, 16, 256
    #       learning_rate: 5e-6, 5e-7   (alternative: 1e-4, 1e-5, 1e-6)
    #       author's notes, meaning not recorded: "1 1e-6", "16 all", "256 1e-6"
    #   f"{model_id}_lrs-{schedule}_lr-{learning_rate}"
    #       model_id: pyt-70m, pyt-1b   learning_rate: 1e-4, 1e-5, 1e-6
    #       schedule: const, lin, cos
    #   f"{model_id}_ws-{warmup_steps}_lr-{learning_rate}"
    #       model_id: pyt-1b (alternative: pyt-70m, pyt-1b)
    #       learning_rate: 1e-5 (alternative: 1e-5, 1e-6)
    #       warmup_steps: 0, 10, 50, 100
    #   f"{model_id}_dt-{data_type}_tok-{tokenization_id}"
    #       model_id: pyt-70m           data_type: rand, rand-names, sci-names, wiki
    #       tokenization_id: char, def
    "mhpr": TaskConfig(
        experiment="mhpr",
        config_names=[
            f"{model}_sl-{seq_len}"
            for model in ("pyt-1b",)
            for seq_len in (4, 16, 64, 128, 256, 1024)
        ],
        seed_ids=[0],
        experiment_options=[],
    ),
}

# Machine list passed as the second argument to run() by main().
# NOTE(review): MachineConfig is imported above but not referenced anywhere
# else in the visible code, so these entries are presumably MachineConfig
# instances — confirm against the (redacted) definition.
MACHINES = [
    <redacted>
]
# Environment settings (username, experiment directory, path expansion) that
# main() passes as the first argument to run() for every task.
# The concrete values are redacted here.
ENV = EnvConfig(
    username=<redacted>,
    experiment_dir=<redacted>,
    path_expansion=<redacted>,
)


def main(argv=None):
    """Parse the command line and launch the requested experiment task.

    Args:
        argv: Argument list to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, so running the
            script from a shell behaves exactly as before. Passing a list lets
            other code or tests drive the CLI programmatically.

    The single positional argument must be a key of ``TASKS``. Any other
    value makes argparse print a usage error that lists the valid task ids
    and exit with status 2.
    """
    parser = argparse.ArgumentParser(description="Run experiment tasks")
    parser.add_argument(
        "task",
        type=str,
        # Restrict to registered ids so a typo produces a clear usage error
        # rather than an uncaught KeyError traceback from TASKS[...] below.
        choices=sorted(TASKS),
        help="Task id to run",
    )
    args = parser.parse_args(argv)

    task = TASKS[args.task]
    run(ENV, MACHINES, task)


if __name__ == "__main__":
    # Launch the CLI only when executed as a script. Importing this module
    # only defines TASKS, MACHINES, ENV and main(); it starts no jobs, because
    # run() is called only from inside main().
    main()
