import os

from utils import ENVLIST, Runner

# GPU ids handed to Runner (presumably the devices jobs are scheduled on).
GPUS = list(range(2, 8))
# Forwarded verbatim as the MODE environment variable of every job.
MODE = "test"
# Script path handed to Runner.
SCRIPT = "./run.sh"
# Evaluation flags; each job receives them stringified ("True"/"False").
CORRUPT_TEST = False
OOD_TEST = False


def get_miniimagenet_runs() -> ENVLIST:
    """Return the job environments for the miniImageNet sweep.

    The sweep covers first-order on/off, the (shot, batch size) pairs
    (1, 4) and (5, 2), and run indices 0-4, giving 2 * 2 * 5 = 20
    environments. They are returned in nested-loop order: first_order
    outermost, run innermost.

    Each environment is a copy of the current process environment with the
    experiment settings added as string values. GPU and DATADIR are
    deliberately not set here and must be supplied elsewhere.
    """

    def build(first_order: bool, shot: int, batch: int, run: int) -> dict:
        env = os.environ.copy()
        env.update(
            {
                "DATASET": "miniimagenet",
                "RUN": str(run),
                "WAY": "5",
                "SHOT": str(shot),
                "FIRST_ORDER": str(first_order),
                "TRAIN_QUERY_SHOTS": "15",
                # NOTE(review): validation query size equals `shot` here,
                # while get_omniglot_runs uses 15 validation queries and
                # `shot` training queries. Confirm this asymmetry is intended.
                "VAL_QUERY_SHOTS": str(shot),
                "INNER_TRAIN_STEPS": "5",
                "INNER_VAL_STEPS": "10",
                "INNER_LR": "0.01",
                "BATCH_SIZE": str(batch),
                "MODE": MODE,
                "OOD_TEST": str(OOD_TEST),
                "CORRUPT_TEST": str(CORRUPT_TEST),
                "SAVE_BEST_VAL": "True",
            }
        )
        return env

    return [
        build(first_order, shot, batch, run)
        for first_order in (True, False)
        for shot, batch in ((1, 4), (5, 2))
        for run in range(5)
    ]


def get_omniglot_runs() -> ENVLIST:
    """Return the job environments for the Omniglot sweep.

    The sweep covers first-order on/off, two (way, batch size, inner train
    steps, inner val steps, inner lr) configurations, shot in {1, 5}, and
    run indices 0-4, giving 2 * 2 * 2 * 5 = 40 environments. They are
    returned in nested-loop order: first_order outermost, run innermost.

    Each environment is a copy of the current process environment with the
    experiment settings added as string values.
    """
    # (way, batch size, inner train steps, inner val steps, inner lr)
    task_configs = ((5, 32, 1, 3, 0.4), (20, 16, 5, 5, 0.1))

    def build(first_order, way, batch, trainstep, valstep, inner_lr, shot, run):
        env = os.environ.copy()
        env.update(
            {
                "DATASET": "omniglot",
                "RUN": str(run),
                "WAY": str(way),
                "SHOT": str(shot),
                "FIRST_ORDER": str(first_order),
                "TRAIN_QUERY_SHOTS": str(shot),
                "VAL_QUERY_SHOTS": "15",
                "INNER_TRAIN_STEPS": str(trainstep),
                "INNER_VAL_STEPS": str(valstep),
                "INNER_LR": str(inner_lr),
                "BATCH_SIZE": str(batch),
                "MODE": MODE,
                "OOD_TEST": str(OOD_TEST),
                "CORRUPT_TEST": str(CORRUPT_TEST),
                "SAVE_BEST_VAL": "False",
            }
        )
        return env

    return [
        build(first_order, way, batch, trainstep, valstep, inner_lr, shot, run)
        for first_order in (True, False)
        for way, batch, trainstep, valstep, inner_lr in task_configs
        for shot in (1, 5)
        for run in range(5)
    ]


if __name__ == "__main__":
    # Omniglot environments come first in the list handed to Runner,
    # followed by the miniImageNet ones.
    all_runs = [*get_omniglot_runs(), *get_miniimagenet_runs()]
    Runner(SCRIPT, GPUS, all_runs).run()
