import os

from utils import ENVLIST, Runner

# Sweep-wide settings shared by every run built in this file.

# Passed to Runner as its second argument; presumably the GPU device ids the
# runs are scheduled on -- confirm against utils.Runner.
GPUS = [2, 3, 4, 5, 6, 7]
# Exported unchanged to every run as the MODE environment variable.
MODE = "test"
# Passed to Runner as its first argument; presumably the script launched for
# each run -- confirm against utils.Runner.
SCRIPT = "./run.sh"
# Exported to every run via str(), i.e. as "True"/"False", in the
# CORRUPT_TEST and OOD_TEST environment variables respectively.
CORRUPT_TEST = True
OOD_TEST = False


def get_miniimagenet_runs() -> ENVLIST:
    """Build the environment dicts for the miniImageNet sweep.

    One environment is produced for every combination of filter count
    (32, 64), norm type, (shot, inner eval batch size) pair and run index
    (0-4), which is 20 environments with the current settings. Each one is
    a copy of the current process environment with the run settings added
    as string values.

    Returns:
        The environments, ordered with the run index varying fastest.
    """
    # Only "transductive" is enabled; "reptile-norm" is the other candidate
    # value and is currently disabled.
    norm_types = ("transductive",)
    # (shot, inner eval batch size)
    shot_settings = ((1, 5), (5, 15))

    sweep: ENVLIST = []
    for n_filters in (32, 64):
        for norm in norm_types:
            for n_shot, eval_batch in shot_settings:
                for run_idx in range(5):
                    # GPU and DATADIR should be set elsewhere
                    job_env = os.environ.copy()
                    job_env.update(
                        {
                            "DATASET": "miniimagenet",
                            "RUN": str(run_idx),
                            "WAY": str(5),
                            "SHOT": str(n_shot),
                            "FILTERS": str(n_filters),
                            "TRAIN_QUERY_SHOTS": str(15),
                            "VAL_QUERY_SHOTS": str(n_shot),
                            "INNER_TRAIN_STEPS": str(8),
                            "INNER_VAL_STEPS": str(50),
                            "INNER_TRAIN_BATCH_SIZE": str(10),
                            "INNER_EVAL_BATCH_SIZE": str(eval_batch),
                            "INNER_LR": str(1e-3),
                            "BATCH_SIZE": str(5),
                            "METATRAIN_ITERS": str(100000),
                            "MODE": MODE,
                            "NORM_TYPE": norm,
                            "OOD_TEST": str(OOD_TEST),
                            "CORRUPT_TEST": str(CORRUPT_TEST),
                            "SAVE_BEST_VAL": str(True),
                        }
                    )
                    sweep.append(job_env)
    return sweep


def get_omniglot_runs() -> ENVLIST:
    """Build the environment dicts for the Omniglot sweep.

    One environment is produced for every combination of norm type, way
    configuration (5-way and 20-way, each with its own hyperparameters),
    shot count (1, 5) and run index (0-4), which is 20 environments with
    the current settings. Each one is a copy of the current process
    environment with the run settings added as string values.

    Returns:
        The environments, ordered with the run index varying fastest.
    """
    # Only "transductive" is enabled; "reptile-norm" is the other candidate
    # value and is currently disabled.
    norm_types = ("transductive",)
    # Each row: (way, batch size, inner train steps, inner val steps,
    #            inner lr, metatrain iters, inner train batch, inner eval batch)
    way_configs = (
        (5, 5, 5, 50, 1e-3, 100000, 10, 5),
        (20, 5, 10, 50, 5e-4, 200000, 20, 10),
    )

    sweep: ENVLIST = []
    for norm in norm_types:
        for config in way_configs:
            (
                n_way,
                batch_size,
                train_steps,
                val_steps,
                lr,
                n_iters,
                train_batch,
                eval_batch,
            ) = config
            for n_shot in (1, 5):
                for run_idx in range(5):
                    job_env = os.environ.copy()
                    job_env.update(
                        {
                            "DATASET": "omniglot",
                            "RUN": str(run_idx),
                            "WAY": str(n_way),
                            "SHOT": str(n_shot),
                            "FILTERS": str(64),
                            "TRAIN_QUERY_SHOTS": str(10),
                            "VAL_QUERY_SHOTS": str(n_shot),
                            "INNER_TRAIN_STEPS": str(train_steps),
                            "INNER_VAL_STEPS": str(val_steps),
                            "INNER_TRAIN_BATCH_SIZE": str(train_batch),
                            "INNER_EVAL_BATCH_SIZE": str(eval_batch),
                            "INNER_LR": str(lr),
                            "BATCH_SIZE": str(batch_size),
                            "METATRAIN_ITERS": str(n_iters),
                            "MODE": MODE,
                            "NORM_TYPE": norm,
                            "OOD_TEST": str(OOD_TEST),
                            "CORRUPT_TEST": str(CORRUPT_TEST),
                            "SAVE_BEST_VAL": str(False),
                        }
                    )
                    sweep.append(job_env)
    return sweep


if __name__ == "__main__":
    # Omniglot environments are placed ahead of the miniImageNet ones in the
    # combined list that is handed to the Runner.
    runs = get_omniglot_runs() + get_miniimagenet_runs()
    Runner(SCRIPT, GPUS, runs).run()
