import json
import os
import sys
import tqdm

import GPUtil
import torch

import rgd
import testproblems


# Hand-picked hyperparameters for a single "one-shot" run, keyed by optimizer
# name; ``run_oneshot`` looks up the entry for the optimizer it runs.
hps_oneshot = dict(
    cm=dict(lr=1e-2, momentum=0.99),
    adam=dict(lr=1e-3, beta1=0.9, beta2=0.999),
    rgd=dict(lr=1e-2, momentum=0.99, alpha=1, delta=10),
    rgd_eu=dict(lr=1e-2, momentum=0.99, delta=10),
)


def run_oneshot(opt="rgd", problem="imagenet_resnet50"):
    """Run one training job using the hand-picked ``hps_oneshot`` settings.

    Results are written under ``results/oneshot/<problem>/<opt>``.

    Args:
        opt: Optimizer name; must be a key of ``hps_oneshot``.
            Defaults to ``"rgd"``.
        problem: Name of the test problem, looked up in
            ``testproblems.problems``. Defaults to ``"imagenet_resnet50"``.

    Raises:
        KeyError: If ``opt`` has no entry in ``hps_oneshot``. An unknown
            ``problem`` fails in the ``testproblems.problems`` lookup
            (exception type depends on that container, not visible here).
    """
    if opt not in hps_oneshot:
        raise KeyError(
            f"no one-shot hyperparameters for optimizer {opt!r}; "
            f"choose from {sorted(hps_oneshot)}"
        )
    hps = hps_oneshot[opt]
    results_dir = f"results/oneshot/{problem}/{opt}"
    testproblems.problems[problem].run(opt, hps, results_dir)


def run_jobs(job_file, results_root="results/small_budget"):
    """Run every job listed in a JSON job file, one after another.

    The file must contain a JSON array of objects. Each object is read for the
    keys ``"testproblem"``, ``"optimizer_name"`` and ``"hyperparams"`` (a
    mapping). A job's results go to
    ``<results_root>/<testproblem>/<optimizer_name>/<hp_string>``, where
    ``hp_string`` is ``k1__v1__k2__v2...`` built from the hyperparameters in
    the order they appear in the file.

    Args:
        job_file: Path to the JSON job file (decoded as UTF-8).
        results_root: Base directory for all job results.
            Defaults to ``"results/small_budget"``.
    """
    # JSON text is UTF-8 by spec; don't depend on the platform's locale
    # encoding, which is what open() uses when no encoding is given.
    with open(job_file, encoding="utf-8") as f:
        jobs = json.load(f)

    for job in tqdm.tqdm(jobs):
        problem = job["testproblem"]
        opt = job["optimizer_name"]
        hps = job["hyperparams"]

        # Values are formatted with str(); NOTE(review): a value containing
        # "/" would add an extra directory level to the results path.
        hp_string = "__".join(f"{k}__{v}" for k, v in hps.items())
        results_dir = f"{results_root}/{problem}/{opt}/{hp_string}"

        testproblems.problems[problem].run(opt, hps, results_dir)


def main():
    """Restrict CUDA to idle GPUs, then dispatch on the first CLI argument.

    ``oneshot`` runs ``run_oneshot()``; any other argument is treated as the
    path of a JSON job file and passed to ``run_jobs``.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} oneshot | JOB_FILE")

    # GPUtil's device IDs are nvidia-smi indices, which follow PCI bus order,
    # whereas CUDA enumerates FASTEST_FIRST by default. Ask CUDA for PCI order
    # so the IDs below select the intended GPUs (unless the caller already set
    # an order explicitly).
    os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")

    # Every GPU whose load and memory usage are both at most 5%.
    devices = GPUtil.getAvailable(limit=float("inf"), maxLoad=0.05, maxMemory=0.05)
    if not devices:
        print(
            "warning: no idle GPU found; CUDA_VISIBLE_DEVICES will be empty "
            "and CUDA will see no devices",
            file=sys.stderr,
        )
    # Plain comma-separated list, the form documented for CUDA_VISIBLE_DEVICES.
    # NOTE(review): this only takes effect if CUDA has not been initialised
    # yet in this process; `import torch` alone does not initialise it, but
    # confirm rgd/testproblems do not touch the GPU at import time.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in devices)

    job = sys.argv[1]
    if job == "oneshot":
        run_oneshot()
    else:
        run_jobs(job)


if __name__ == "__main__":
    main()
