import argparse
import os
from multiprocessing import Pool, Queue, current_process
from subprocess import call

import torch

# Dataset names accepted by ``--datasets``; also the default set when the
# option is omitted.
AVAILABLE_DATASETS = [
    "assist2009",
    "algebra2005",
    "bridge2algebra2006",
    "nips_task34",
    "statics2011",
    "assist2015",
    "poj",
    "ednet",
]


def train(config: tuple) -> None:
    """Run one training job in a subprocess on a GPU taken from the shared queue.

    The function reads the module-level name ``queue`` (a queue of CUDA device
    ids). It blocks until an id is available, launches ``train.py`` from the
    current working directory with ``CUDA_VISIBLE_DEVICES`` set to that id,
    waits for the subprocess to exit, and always puts the id back on the queue.

    Args:
        config: A ``(cfg, dataset, fold)`` triple. ``cfg`` is appended to the
            command line as ``+{cfg}``, ``dataset`` as ``data={dataset}`` and
            ``fold`` as ``data.val_fold_idx={fold}``.
    """
    # Unpack before borrowing a device so a malformed task cannot leak a GPU id.
    cfg, dataset, fold = config
    process_id = current_process().ident

    # Build the argument list directly instead of splitting a command string:
    # a working directory containing spaces must stay a single argument.
    # Whitespace inside ``cfg`` still separates arguments, as it did before.
    argv = [
        "python",
        os.path.join(os.getcwd(), "train.py"),
        "+with=ray_single_gpu",
        *f"+{cfg}".split(),
        f"data={dataset}",
        f"data.val_fold_idx={fold}",
    ]
    cmd = " ".join(argv)

    gpu_id = queue.get()

    # Start computation in subprocess
    try:
        print(f"{process_id}: starting process on GPU {gpu_id}")
        print(f"{process_id}: {cmd}")

        returncode = call(
            argv,
            cwd=os.getcwd(),
            env=dict(os.environ)
            | {"CUDA_VISIBLE_DEVICES": f"{gpu_id}", "HYDRA_FULL_ERROR": "1"},
        )

        # Report failures explicitly instead of labelling every run "finished".
        # The sweep itself keeps going so the remaining tasks still run.
        if returncode != 0:
            print(f"{process_id}: failed with exit code {returncode}")
        else:
            print(f"{process_id}: finished")
    finally:
        # Hand the device back even when the subprocess could not be started.
        queue.put(gpu_id)


def _init_worker(device_queue) -> None:
    """Expose the shared GPU-id queue as the global ``queue`` in a pool worker.

    ``train`` looks up the module-level name ``queue``. That name is only
    assigned inside the ``__main__`` guard, so a worker that was not created by
    ``fork`` would not have it. Running this as the pool initializer makes the
    queue available regardless of the multiprocessing start method.

    Args:
        device_queue: The multiprocessing queue holding the CUDA device ids.
    """
    global queue
    queue = device_queue


if __name__ == "__main__":

    description = f"""Sweep hyperparameters.
    Available datasets: {AVAILABLE_DATASETS}\n\n"""

    parser = argparse.ArgumentParser(
        description=description,
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "config",
        type=str,
        help="name of config to execute (example: sweep/ktst=benchmark_mean)",
    )
    parser.add_argument(
        "--datasets",
        type=str,
        default="",
        help="comma-separated dataset names (default: all datasets)",
    )
    parser.add_argument(
        "--folds",
        type=str,
        default="",
        help="comma-separated list of folds to compute (default: 0,1,2,3,4)",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="",
        help="comma-separated list of CUDA device ids that can be used for computations (default: all available devices)",
    )
    # No ``type=`` here: BooleanOptionalAction takes no value, and passing
    # ``type`` to it is deprecated since Python 3.12.
    parser.add_argument(
        "--dry",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="list computations that configuration would trigger (default: --no-dry)",
    )

    args = parser.parse_args()

    # Check for CUDA only after parsing so that ``--help`` works on any machine.
    if not torch.cuda.is_available():
        parser.error("CUDA is not available")
    device_count = torch.cuda.device_count()

    # Parse comma separated inputs
    config = args.config
    datasets = args.datasets.split(",") if args.datasets else AVAILABLE_DATASETS
    try:
        folds = (
            [int(x) for x in args.folds.split(",")] if args.folds else list(range(5))
        )
    except ValueError:
        parser.error(f"--folds must be comma-separated integers, got {args.folds!r}")
    try:
        devices = (
            [int(x) for x in args.devices.split(",")]
            if args.devices
            else list(range(device_count))
        )
    except ValueError:
        parser.error(
            f"--devices must be comma-separated integers, got {args.devices!r}"
        )

    # Validate input with explicit errors; ``assert`` is skipped under ``-O``.
    unknown_datasets = [d for d in datasets if d not in AVAILABLE_DATASETS]
    if unknown_datasets:
        parser.error(
            f"unknown datasets {unknown_datasets}; available: {AVAILABLE_DATASETS}"
        )
    invalid_folds = [f for f in folds if f not in range(5)]
    if invalid_folds:
        parser.error(f"invalid folds {invalid_folds}; valid folds are 0-4")
    invalid_devices = [d for d in devices if d not in range(device_count)]
    if invalid_devices:
        parser.error(
            f"invalid devices {invalid_devices}; {device_count} CUDA device(s) available"
        )

    # Determine list of outstanding compute tasks (one per dataset/fold pair)
    tasks = [(config, d, f) for d in datasets for f in folds]

    if args.dry:
        for i, task in enumerate(tasks):
            print(f"{i:3}: {task}")
    else:
        # Initialize the queue with the device ids
        queue = Queue()
        for d in devices:
            queue.put(d)

        # One worker per listed device id. The initializer hands the queue to
        # every worker, and the ``with`` block cleans the pool up even if
        # iterating over the results raises.
        with Pool(
            processes=len(devices),
            initializer=_init_worker,
            initargs=(queue,),
        ) as pool:
            for _ in pool.imap_unordered(train, tasks):
                pass
            pool.close()
            pool.join()
