import os
import subprocess
import time
import random
import datetime
import itertools
import argparse
import string

# Submission throttling: main() submits while 1 + MAX_JOBS exceeds the adjusted
# squeue count, and waits SLEEP_TIME seconds between polling rounds.
MAX_JOBS = 10
SLEEP_TIME = 100
# Per-job Slurm resources.
JOB_TIME = "4:00:00"
GRES_CONFIG = "gpu:rtxa6000:1"  # alternative: "gpu:rtx2080ti:2"
# The last ":"-separated field of the gres spec is the GPU count; CPUs and memory
# scale with it.
_NUM_GPUS = int(GRES_CONFIG.rsplit(":", 1)[-1])
CPUS_PER_TASK = 4 * _NUM_GPUS
MEMORY_PER_NODE = 32 * _NUM_GPUS  # in GB: passed to srun as --mem={MEMORY_PER_NODE}G
# When True, each generated job script removes its own log and itself after a
# successful run.
DELETE_LOG = False
# Subtracted (together with 1) from the squeue line counts in get_slurm_job_count.
SLURM_NUM_JOB_OFFSET = 1
# Added to the combination index to form each job name.
JOB_ID_OFFSET = 1 * 100004


COMMAND_TEMPLATE = """
accelerate launch --num_processes={num_gpus} --num_machines=1 --mixed_precision=bf16 --dynamo_backend=no --main_process_port={random_port} main.py\
    --model_name="{model_name}"\
    --strong_model_name="{strong_model_name}"\
    --is_easy_to_hard="{is_easy_to_hard}"\
    --dataset_name="{dataset_name}"\
    --adaboost_rounds=5\
    --num_epochs=5\
    --learning_rate=5e-5\
    --train_batch_size={train_batch_size}\
    --pred_batch_size={pred_batch_size}\
    --num_proc={num_cpus}\
    --model_max_length=512\
    --is_token_based_error=yes\
    --is_weight_by_token={is_weight_by_token}\
    --is_completion_only={is_completion_only}\
    --probability_bias={probability_bias}\
    --token_prob_window_size={token_prob_window_size}\
    --logits_top_k={logits_top_k}\
    --is_combine_probs={is_combine_probs}\
    --is_top_k_pooling={is_top_k_pooling}\
    --temp_folder={temp_folder}\
"""

# Hyper-parameter search space. Each key is substituted into COMMAND_TEMPLATE as a
# main.py option and each list holds the values to sweep. Commented-out entries are
# alternatives that are currently disabled.
ARGS_TEMPLATE = {
    "model_name": [
        "EleutherAI/pythia-70m-deduped",
        "EleutherAI/pythia-160m-deduped",
        # "EleutherAI/pythia-410m-deduped",
        # "EleutherAI/pythia-1b-deduped",
        # "EleutherAI/pythia-1.4b-deduped",
        # "EleutherAI/pythia-2.8b-deduped",
        # "EleutherAI/pythia-6.9b-deduped",
    ],
    "strong_model_name": [
        "EleutherAI/pythia-160m-deduped",
        "EleutherAI/pythia-410m-deduped",
        "EleutherAI/pythia-1b-deduped",
        "EleutherAI/pythia-1.4b-deduped",
        # "EleutherAI/pythia-2.8b-deduped",
        # "EleutherAI/pythia-6.9b-deduped",
    ],
    "is_easy_to_hard": [True, False],
    "dataset_name": ["quartz", "arc"],  # ["sciq", "quartz", "quail", "hellaswag", "winogrande"],
    "is_weight_by_token": [True],  # [True, False],
    "is_completion_only": [True],  # [True, False],
    "probability_bias": [0.0],  # [0.0, 0.1, 0.2, 0.3, 0.4],
    "token_prob_window_size": [1],  # [1, 2, 4, 8, 16],
    "logits_top_k": [64],  # [1, 4, 16, 64, 256],
    "is_combine_probs": [True],  # [True, False],
    "is_top_k_pooling": [False],  # [True, False],
}

# Cartesian product of ARGS_TEMPLATE, one dict per combination, keys in template
# order; the first key varies slowest. Job indices and shard boundaries in main()
# are positions in this list.
_ARG_NAMES = tuple(ARGS_TEMPLATE)
ALL_ARGS = [
    dict(zip(_ARG_NAMES, values))
    for values in itertools.product(*(ARGS_TEMPLATE[name] for name in _ARG_NAMES))
]

# Per-(model_name, dataset_name) batch sizes as (train_batch_size, pred_batch_size),
# consumed by create_temp_shell_script. main() refuses to submit a shard containing
# a combination that has no entry here.
_BATCH_SIZES_BY_MODEL = {
    "EleutherAI/pythia-70m-deduped": {"arc": (8, 4), "sciq": (8, 4), "quartz": (8, 4)},
    "EleutherAI/pythia-160m-deduped": {"arc": (8, 4), "sciq": (8, 4), "quartz": (8, 4)},
    "EleutherAI/pythia-410m-deduped": {"sciq": (6, 4)},
}
BATCH_SIZE_DICT = {
    (model, dataset): sizes
    for model, per_dataset in _BATCH_SIZES_BY_MODEL.items()
    for dataset, sizes in per_dataset.items()
}

SCRIPT_HEADER = """
#!/bin/zsh
source /etc/profile.d/modules.sh
source ~/.zshrc
source ./venv/bin/activate
mkdir {temp_folder}
export TORCH_DISTRIBUTED_DEBUG=INFO
export NCCL_DEBUG=INFO
export CUDA_VISIBLE_DEVICES={gpu_device_ids}
"""

SCRIPT_SUCCESS_CHECKING = """
if [ $? -ne 0 ]; then
    echo "Job exited with an error. Aborting script."
    exit 1
fi
"""


def get_slurm_job_count():
    """Return this user's running + pending Slurm job count, offset-adjusted.

    Counts the output lines (header included) of ``squeue --me -t R`` and
    ``squeue --me -t PD``, then subtracts ``SLURM_NUM_JOB_OFFSET + 1``. With
    SLURM_NUM_JOB_OFFSET = 1, this removes exactly the two header lines.

    squeue is run without a shell pipeline, so a failing or missing squeue is
    reported instead of being counted as an empty queue. That would otherwise
    make the caller believe many slots are free.

    Raises:
        RuntimeError: if squeue cannot be started or exits with a non-zero status.
    """
    line_counts = []
    for state in ("R", "PD"):
        try:
            result = subprocess.run(
                ["squeue", "--me", "-t", state],
                text=True,
                capture_output=True,
            )
        except OSError as err:
            raise RuntimeError(f"Could not run squeue: {err}") from err
        if result.returncode != 0:
            raise RuntimeError(
                f"squeue --me -t {state} failed with exit code "
                f"{result.returncode}: {result.stderr.strip()}"
            )
        # Count newline characters, which is exactly what `wc -l` reports.
        line_counts.append(result.stdout.count("\n"))
    running_jobs, pending_jobs = line_counts
    return running_jobs + pending_jobs - SLURM_NUM_JOB_OFFSET - 1


def create_temp_shell_script(index, jobname):
    """Write the executable job script ``./jobs/srun_{jobname}.sh``.

    The script is built for ``ALL_ARGS[index]`` and contains, in order:
    SCRIPT_HEADER (environment setup, a fresh ``/tmp/<random>`` folder,
    CUDA_VISIBLE_DEVICES), the COMMAND_TEMPLATE command collapsed onto one line,
    and SCRIPT_SUCCESS_CHECKING. When DELETE_LOG is set, it ends by removing its
    ``.log`` file and itself.

    Args:
        index: position of the hyper-parameter combination in ALL_ARGS.
        jobname: name used to build the script file name.

    Returns:
        The path of the written script, which has been made executable.

    Raises:
        KeyError: if the combination's (model_name, dataset_name) has no entry in
            BATCH_SIZE_DICT. This is checked before the file is created, so no
            half-written script is left behind.
        OSError: if the script cannot be written or its mode cannot be changed.
    """
    job_args = ALL_ARGS[index]
    num_gpus = int(GRES_CONFIG.split(":")[-1])
    # BATCH_SIZE_DICT values are (train_batch_size, pred_batch_size).
    train_batch_size, pred_batch_size = BATCH_SIZE_DICT[
        job_args["model_name"], job_args["dataset_name"]
    ]

    # Randomly named scratch folder under /tmp; the script's header mkdirs it.
    random_string = "".join(
        random.choices(string.ascii_lowercase + string.digits, k=10)
    )
    temp_folder = f"/tmp/{random_string}"

    # strip() drops the template's leading newline so the shebang is on line 1.
    header = SCRIPT_HEADER.format(
        temp_folder=temp_folder,
        gpu_device_ids=",".join(str(i) for i in range(num_gpus)),
    ).strip()
    command = COMMAND_TEMPLATE.format(
        **job_args,
        random_port=random.randint(10000, 49512),
        num_gpus=num_gpus,
        num_cpus=CPUS_PER_TASK,
        temp_folder=temp_folder,
        train_batch_size=train_batch_size,
        pred_batch_size=pred_batch_size,
    )

    scriptpath = f"./jobs/srun_{jobname}.sh"
    lines = [
        header,
        # str.split() with no argument breaks on any whitespace (tabs included),
        # so re-joining with single spaces yields one clean command line.
        " ".join(command.split()),
        SCRIPT_SUCCESS_CHECKING.strip(),
    ]
    if DELETE_LOG:
        lines.append(f"rm {scriptpath[:-3]}.log")
        lines.append(f"rm {scriptpath}")
    with open(scriptpath, "w") as f:
        f.write("".join(f"{line}\n" for line in lines))

    # Grant execute wherever read is already granted, which is what `chmod +x`
    # gives under the usual umasks. A failure raises here instead of being
    # silently discarded.
    mode = os.stat(scriptpath).st_mode & 0o7777
    os.chmod(scriptpath, mode | ((mode & 0o444) >> 2))
    return scriptpath


def submit_slurm_job(index):
    """Generate the job script for ``ALL_ARGS[index]`` and launch it via ``srun``.

    The srun command line ends with ``&``, so the shell backgrounds it and this
    call returns without waiting for the job. srun's stdout and stderr are
    redirected to the script path with its ``.sh`` suffix replaced by ``.log``.
    """
    # The " " sign flag reserves one of the 7 columns for a leading space, which
    # strip() removes. Positive ids therefore come out zero-padded to at least
    # 6 digits (e.g. 100004 -> "100004").
    job_name = f"{JOB_ID_OFFSET + index: 07d}".strip()
    script = create_temp_shell_script(index, job_name)
    log_file = f"{script[:-3]}.log"
    gres_option = "" if GRES_CONFIG is None else "--gres=" + GRES_CONFIG
    srun_command = (
        f"srun --job-name={job_name} --qos=scavenger --partition=scavenger"
        f" --account=scavenger --time={JOB_TIME} {gres_option} --nodes=1"
        f" --ntasks=1 --cpus-per-task={CPUS_PER_TASK} --mem={MEMORY_PER_NODE}G"
        f' "{script}" > {log_file} 2>&1 &'
    )
    subprocess.run(
        srun_command,
        shell=True,
        executable="/bin/bash",
        text=True,
        capture_output=True,
    )


def _log(message):
    """Print ``message`` prefixed with the local time as ``[YYYY-mm-dd HH:MM:SS]``."""
    print(f'[{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}] {message}')


def main(num_shards, shard_index, real_run):
    """Check, and optionally submit, one shard of the ALL_ARGS combinations.

    ALL_ARGS is cut into ``num_shards`` contiguous slices, and slice
    ``shard_index`` is selected. If any selected combination has no
    (model_name, dataset_name) entry in BATCH_SIZE_DICT, a message is printed and
    nothing is submitted. Unless ``real_run`` is true, this is a dry run.

    On a real run, jobs are submitted in rounds until every selected combination
    has been submitted:
      * each round submits up to ``1 + MAX_JOBS - get_slurm_job_count()`` jobs;
      * it pauses SLEEP_TIME / 5 seconds after each submission;
      * it then waits SLEEP_TIME seconds before the next round, if one is needed.
    If the queue cannot be queried, the round is retried after SLEEP_TIME seconds
    instead of guessing the queue size.

    Args:
        num_shards: total number of shards; must be at least 1.
        shard_index: which shard to run, in ``[0, num_shards)``.
        real_run: actually submit jobs when true; otherwise only report.

    Raises:
        ValueError: if ``num_shards`` or ``shard_index`` is out of range.
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be at least 1, got {num_shards}.")
    if not 0 <= shard_index < num_shards:
        raise ValueError(
            f"shard_index must be in [0, {num_shards}), got {shard_index}."
        )

    # Create necessary directories
    for directory in ["cache", "jobs", "models", "tables", "results"]:
        os.makedirs(directory, exist_ok=True)

    # Count the number of combinations
    print(f"Total {len(ALL_ARGS)} combinations.")
    start = shard_index * len(ALL_ARGS) // num_shards
    stop = (shard_index + 1) * len(ALL_ARGS) // num_shards
    selected_indices = list(range(start, stop))

    # Refuse to submit anything if some combination has no batch-size entry.
    for idx in selected_indices:
        model_name = ALL_ARGS[idx]["model_name"]
        dataset_name = ALL_ARGS[idx]["dataset_name"]
        if (model_name, dataset_name) not in BATCH_SIZE_DICT:
            print(
                f"Combination ({model_name}, {dataset_name}) not found in BATCH_SIZE_DICT."
            )
            return
    print(f"Running shard {shard_index} out of {num_shards}.")
    print(f"Selected {len(selected_indices)} combinations.")
    if not real_run:
        print("Dry run, exiting. Use --real_run to really run.")
        return

    # Slurm job submission loop
    print("Submitting slurm jobs.")
    while selected_indices:
        _log(f"Info: Found {len(selected_indices)} remaining tasks.")
        try:
            job_count = get_slurm_job_count()
        except (RuntimeError, ValueError) as err:
            # An unknown queue size must not be treated as "empty"; retry later.
            _log(
                f"Warning: Could not query slurm jobs ({err}). Wait for {SLEEP_TIME} seconds."
            )
            time.sleep(SLEEP_TIME)
            continue
        available_num_jobs = 1 + MAX_JOBS - job_count
        if available_num_jobs <= 0:
            _log(
                f"Info: Too many slurm jobs running. Wait for {SLEEP_TIME} seconds."
            )
            time.sleep(SLEEP_TIME)
            continue
        batch = selected_indices[:available_num_jobs]
        for index in batch:
            _log(f"Info: Submitting slurm job for index {index}.")
            submit_slurm_job(index)
            time.sleep(SLEEP_TIME / 5)
        del selected_indices[: len(batch)]
        _log(f"Info: Submitted {len(batch)} slurm jobs.")
        # Only wait for the queue to drain if another round is needed.
        if selected_indices:
            time.sleep(SLEEP_TIME)
    _log("Info: No remaining tasks, exiting.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_shards",
        type=int,
        default=1,
        help="Number of shards to run",
    )
    parser.add_argument(
        "--shard_index",
        type=int,
        default=0,
        help="Current shard index",
    )
    parser.add_argument(
        "--real_run",
        action="store_true",
        default=False,
        help="whether to really run things",
    )
    args = parser.parse_args()
    main(args.num_shards, args.shard_index, args.real_run)
