import multiprocessing as mp
import os
import shlex

import fire
import torch


def run(gpu_queue, script_format: str, task_name: str):
    """Run one task's shell command on a GPU borrowed from ``gpu_queue``.

    Blocks until a GPU id can be taken from the queue, fills the command
    template, executes it through the shell, and always hands the GPU id back
    to the queue afterwards.

    Args:
        gpu_queue: Queue holding the ids of currently free GPUs.
        script_format: Shell command template; it is formatted with the
            keyword arguments ``gpu_id`` and ``task_name``.
        task_name: Name of the task to run; substituted into the template.

    Raises:
        RuntimeError: If the shell command finishes with a non-zero status.
    """
    gpu_id = gpu_queue.get()
    try:
        print(f"Running {task_name} on GPU {gpu_id}.")
        command = script_format.format(gpu_id=gpu_id, task_name=task_name)
        status = os.system(command)
    finally:
        # Return the GPU even when formatting or execution blows up; otherwise
        # the pool of free GPUs shrinks and later tasks may wait forever.
        gpu_queue.put(gpu_id)

    # os.system never raises on command failure, so surface it explicitly to
    # let the caller's error handling see that the task did not succeed.
    if status != 0:
        raise RuntimeError(
            f"Task {task_name} on GPU {gpu_id} failed "
            f"(os.system returned status {status})."
        )


def main(
        model_name: str,
        report_output_path: str,
        eval_mode: str = "all",
        override_model_path: str | None = None,
        save_individual_results_path: str | None = None,
):
    entrypoint_script_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), "eval.py")
    mp.set_start_method("spawn", force=True)

    if eval_mode == "prompt_harmfulness":
        task_names = [
            "wildguardtest_prompt",
            "harmbench",
            "openai_mod",
            "toxicchat",
            "aegis_safety_dataset",
            "simplesafetytests",
        ]
    elif eval_mode == "response_refusal":
        task_names = [
            "wildguardtest_refusal",
            "xstest_response_refusal",
        ]
    elif eval_mode == "response_harmfulness":
        task_names = [
            "wildguardtest_response",
            "harmbench:response",
            "beavertails",
            "saferlhf",
            "xstest_response_harm",
        ]
    elif eval_mode == "all":
        task_names = [
            "harmbench",
            "openai_mod",
            "toxicchat",
            "aegis_safety_dataset",
            "simplesafetytests",
            "harmbench:response",
            "beavertails",
            "saferlhf",
            "xstest_response_harm",
            "xstest_response_refusal",

            "wildguardtest_prompt",
            "wildguardtest_refusal",
            "wildguardtest_response",
        ]
    else:
        raise ValueError(f"Unknown eval_mode: {eval_mode}")

    # you can run in parallel.
    num_devices = torch.cuda.device_count()
    print(f"Number of devices: {num_devices}")
    print(f"List of tasks: {task_names}")

    with mp.Manager() as manager:
        gpu_queue = manager.Queue()
        for i in range(num_devices):
            gpu_queue.put(i)

        with mp.Pool(processes=num_devices) as pool:
            res = []
            script_format = (
                "CUDA_VISIBLE_DEVICES={gpu_id} \\\n"
                f"python {entrypoint_script_path} classifiers \\\n"
                f"  --model_name {model_name} \\\n"
                f"  --override_model_path {override_model_path} \\\n"
                "   --tasks {task_name} \\\n"
                f"  --report_output_path {report_output_path} \\\n"
                f"  --save_individual_results_path {save_individual_results_path}"
            )
            for task_name in task_names:
                res.append(
                    pool.apply_async(run, (gpu_queue, script_format, task_name)))

            for task_name, r in zip(task_names, res):
                try:
                    r.get()
                except Exception as e:
                    print(f"Error in {task_name}.")
                    print(e)


# Script entry point: python-fire exposes main's parameters as command-line
# arguments when this file is executed directly.
if __name__ == "__main__":
    fire.Fire(main)
