import os

from utils import ENVLIST, Runner

# Script path handed to Runner (presumably launched once per environment -- confirm in utils.Runner).
SCRIPT = "./run.sh"
# Exported verbatim as the MODE variable in every generated environment.
MODE = "train"
# GPU ids handed to Runner alongside the script and environments.
GPUS = [0, 1, 2, 4, 7]


def get_miniimagenet_runs() -> ENVLIST:
    """Return one environment dict per miniImageNet job.

    Two (shot, batch size) settings are swept -- (1, 4) and (5, 2) -- and
    each is repeated for RUN = 0..4, giving 10 environments. Every task is
    5-way. Each environment is a fresh copy of the current process
    environment with the experiment settings layered on top. All values are
    strings, so they can be passed straight to a subprocess.

    GPU and DATADIR are deliberately not set here; they should be set
    elsewhere.
    """
    return [
        {
            **os.environ,
            "DATASET": "miniimagenet",
            "RUN": str(run_idx),
            "WAY": "5",
            "SHOT": str(n_shot),
            "TRAIN_QUERY_SHOTS": "15",
            # Validation uses as many query examples per class as support shots.
            "VAL_QUERY_SHOTS": str(n_shot),
            "BATCH_SIZE": str(batch_size),
            "MODE": MODE,
            "OOD_TEST": "False",
            "CORRUPT_TEST": "False",
            "SAVE_BEST_VAL": "True",
        }
        for n_shot, batch_size in ((1, 4), (5, 2))
        for run_idx in range(5)
    ]


def get_omniglot_runs() -> ENVLIST:
    """Return one environment dict per Omniglot job.

    Two task settings are swept: 5-way with batch size 32 and 20-way with
    batch size 16. Each is crossed with SHOT in {1, 5} and repeated for
    RUN = 0..4, giving 20 environments. Each environment is a fresh copy of
    the current process environment with the experiment settings layered on
    top, all as strings.
    """
    # NOTE(review): the step counts and inner learning rate in each sweep
    # tuple are unpacked but never written into the environment. Confirm
    # whether run.sh is supposed to receive them.
    sweep = ((5, 32, 1, 3, 0.4), (20, 16, 5, 5, 0.1))
    return [
        {
            **os.environ,
            "DATASET": "omniglot",
            "RUN": str(run_idx),
            "WAY": str(n_way),
            "SHOT": str(n_shot),
            # Training query size follows the shot count; validation is fixed at 15.
            "TRAIN_QUERY_SHOTS": str(n_shot),
            "VAL_QUERY_SHOTS": "15",
            "BATCH_SIZE": str(batch_size),
            "MODE": MODE,
            "OOD_TEST": "False",
            "CORRUPT_TEST": "False",
            "SAVE_BEST_VAL": "False",
        }
        for n_way, batch_size, _train_step, _val_step, _inner_lr in sweep
        for n_shot in (1, 5)
        for run_idx in range(5)
    ]


if __name__ == "__main__":
    # Build every job environment and hand the combined list to Runner.
    # Omniglot jobs come first, then miniImageNet jobs.
    all_runs = [*get_omniglot_runs(), *get_miniimagenet_runs()]
    Runner(SCRIPT, GPUS, all_runs).run()
