import os

from utils import ENVLIST, Runner

# GPU ids handed to Runner (second argument) in the __main__ block.
# For a quick single-GPU debug run, set GPUS = [0] instead.
GPUS = [0, 1, 3, 5]
# Exported to every run as the MODE environment variable.
MODE = "train"
# Script passed to Runner as the command for each run; presumably invoked
# once per environment dict (see Runner in utils to confirm).
SCRIPT = "./run.sh"


def get_cifar_runs() -> ENVLIST:
    """Return one environment mapping per CIFAR training run in the sweep.

    The sweep is the cross product of the model names, the
    (dataset, batch size) pairs and five run indices. It is iterated with the
    model outermost and the run index innermost. Each entry is a fresh copy
    of ``os.environ`` with the sweep settings written over it as strings.

    GPU and DATADIR are deliberately not set here; they should be set
    elsewhere (inherited from the environment or supplied by the caller).
    """
    models = ("wide-sn-resnet-28-10-cifar",)
    dataset_batch_sizes = (("cifar10", 128), ("cifar100", 128))
    runs_per_config = 5

    configs: ENVLIST = []
    for model_name in models:
        for dataset, batch_size in dataset_batch_sizes:
            for run_idx in range(runs_per_config):
                child_env = os.environ.copy()
                # Same keys, same order as before; values are stringified
                # because environment variables must be strings.
                child_env.update(
                    {
                        "DATASET": dataset,
                        "RUN": str(run_idx),
                        "BATCH_SIZE": str(batch_size),
                        "OOD_TEST": str(False),
                        "CORRUPT_TEST": str(False),
                        "SAVE_BEST_VAL": str(True),
                        "MODE": MODE,
                        "MODEL": model_name,
                        "P": str(0.1),
                        "FILTERWISE_DROPOUT": str(True),
                    }
                )
                configs.append(child_env)
    return configs


if __name__ == "__main__":
    # Build every run's environment up front, then give the launch script,
    # the GPU ids and the run list to the project's Runner and start it.
    run_envs = get_cifar_runs()
    runner = Runner(SCRIPT, GPUS, run_envs)
    runner.run()
