import os

from utils import ENVLIST, Runner

# GPU ids handed to Runner below (presumably the devices runs are scheduled
# on). Use [0, 1, 2, 3, 4, 5, 6, 7] to schedule on GPUs 0-7.
GPUS = [3, 4, 5, 6, 7]
# Exported to every run as the MODE environment variable.
MODE = "test"
# Launcher script handed to Runner for every run.
SCRIPT = "./run.sh"
# Experiment switches, exported to every run as the strings "True"/"False".
CORRUPT_TEST = False
OOD_TEST = False
COV_EXPERIMENT = True
# Meta-training iterations, exported to every run as METATRAIN_ITERS.
# Lower this (e.g. to 100) for a quick smoke test of the whole pipeline.
METATRAIN_ITERS = 60_000


def get_baseline_miniimagenet_runs() -> ENVLIST:
    """Return run environments for the baseline models on 5-way mini-ImageNet.

    The sweep covers 1-shot (batch size 4) and 5-shot (batch size 2) tasks for
    the ``protonet``, ``protonet-sn``, ``proto-ddu`` and ``proto-sngp`` models.
    Each entry is a fresh copy of ``os.environ`` plus the variables the
    launcher script reads. GPU and DATADIR are not set here and must be
    provided elsewhere.

    Returns:
        Eight environment dicts, ordered by shot setting, then by model.
    """
    baseline_models = ["protonet", "protonet-sn", "proto-ddu", "proto-sngp"]

    def build(shot, batch, model):
        run_env = os.environ.copy()
        # A single ordered update: new keys are inserted in this order.
        run_env.update({
            "DATASET": "miniimagenet",
            "METATRAIN_ITERS": str(METATRAIN_ITERS),
            "INFERENCE_STYLE": "distance",
            "RUN": "0",
            "WAY": "5",
            "SHOT": str(shot),
            "TRAIN_QUERY_SHOTS": "15",
            "VAL_QUERY_SHOTS": str(shot),
            "BATCH_SIZE": str(batch),
            "MODE": MODE,
            "MODEL": model,
            "FORWARD_TYPE": "softmax",
            "BETA_BIAS": "False",
            "OOD_TEST": str(OOD_TEST),
            "COV_EXPERIMENT": str(COV_EXPERIMENT),
            "CORRUPT_TEST": str(CORRUPT_TEST),
            "SAVE_BEST_VAL": "True",
            # Not used by these models, but the script expects them to exist.
            "ENCODER_TYPE": "diag",
            "PMA_TYPE": "no-residual",
            "RANK": "1",
            "T": "0.05",
        })
        return run_env

    return [
        build(shot, batch, model)
        for shot, batch in [(1, 4), (5, 2)]
        for model in baseline_models
    ]


def get_mahalanobis_miniimagenet_runs() -> ENVLIST:
    """Return run environments for ``proto-mahalanobis`` on 5-way mini-ImageNet.

    The sweep is the cross product of three settings:

    - 1-shot (batch size 4) and 5-shot (batch size 2) tasks;
    - one head setting: ``softmax`` forward type, no beta bias,
      ``softmax-sample`` inference;
    - five (ENCODER_TYPE, RANK) settings: ``diag`` with rank 1, and
      ``low-rank`` with ranks 1, 2, 4 and 8.

    Each entry is a fresh copy of ``os.environ`` plus the variables the
    launcher script reads. GPU and DATADIR are not set here and must be
    provided elsewhere.

    Returns:
        Ten environment dicts, ordered by shot setting, then encoder setting.
    """
    shot_batch_pairs = [(1, 4), (5, 2)]
    # (FORWARD_TYPE, BETA_BIAS, INFERENCE_STYLE) combinations.
    head_variants = [("softmax", False, "softmax-sample")]
    # (ENCODER_TYPE, RANK) combinations.
    encoder_variants = [
        ("diag", 1),
        ("low-rank", 1),
        ("low-rank", 2),
        ("low-rank", 4),
        ("low-rank", 8),
    ]

    collected = []
    for shot, batch in shot_batch_pairs:
        for forward_type, beta_bias, inference_style in head_variants:
            for encoder, rank in encoder_variants:
                run_env = os.environ.copy()
                # Keyword arguments keep their order, which fixes the
                # insertion order of any new keys.
                run_env.update(
                    DATASET="miniimagenet",
                    METATRAIN_ITERS=str(METATRAIN_ITERS),
                    INFERENCE_STYLE=inference_style,
                    RUN="0",
                    WAY="5",
                    SHOT=str(shot),
                    TRAIN_QUERY_SHOTS="15",
                    VAL_QUERY_SHOTS=str(shot),
                    BATCH_SIZE=str(batch),
                    MODE=MODE,
                    MODEL="proto-mahalanobis",
                    FORWARD_TYPE=forward_type,
                    BETA_BIAS=str(beta_bias),
                    OOD_TEST=str(OOD_TEST),
                    COV_EXPERIMENT=str(COV_EXPERIMENT),
                    CORRUPT_TEST=str(CORRUPT_TEST),
                    SAVE_BEST_VAL="True",
                    ENCODER_TYPE=encoder,
                    PMA_TYPE="no-residual",
                    RANK=str(rank),
                    T="1.0",
                )
                collected.append(run_env)
    return collected


def get_mahalanobis_omniglot_runs() -> ENVLIST:
    """Build the Omniglot run environments for the ``proto-mahalanobis`` model.

    The sweep is the cross product of four settings:

    - 5-way (batch size 32) and 20-way (batch size 16) tasks;
    - 1 and 5 shots;
    - one head setting: ``softmax`` forward type, no beta bias,
      ``softmax-sample`` inference;
    - five (ENCODER_TYPE, RANK) settings: ``diag`` with rank 1, and
      ``low-rank`` with ranks 1, 2, 4 and 8.

    Each entry is a copy of ``os.environ`` extended with the variables the
    launcher script reads.

    Returns:
        A list of 20 environment dicts (2 way settings x 2 shots x 5 encoder
        settings), ordered way-major, then shot, then encoder setting.
    """
    envs: ENVLIST = []
    # (WAY, BATCH_SIZE) per task setting.
    way_batch_settings = [(5, 32), (20, 16)]
    # (FORWARD_TYPE, BETA_BIAS, INFERENCE_STYLE) combinations. Explicit tuples
    # keep each combination's fields together, unlike zip() over parallel
    # lists, which silently drops entries if the list lengths diverge.
    head_settings = [("softmax", False, "softmax-sample")]
    # (ENCODER_TYPE, RANK) combinations.
    encoder_settings = [
        ("diag", 1),
        ("low-rank", 1),
        ("low-rank", 2),
        ("low-rank", 4),
        ("low-rank", 8),
    ]

    for way, batch in way_batch_settings:
        for shot in [1, 5]:
            for forward_type, beta_bias, inference_style in head_settings:
                for encoder, rank in encoder_settings:
                    env = os.environ.copy()
                    env["DATASET"] = "omniglot"
                    env["METATRAIN_ITERS"] = str(METATRAIN_ITERS)
                    env["RUN"] = str(0)
                    env["WAY"] = str(way)
                    env["INFERENCE_STYLE"] = inference_style
                    env["SHOT"] = str(shot)
                    # NOTE(review): the mini-ImageNet runs in this file use the
                    # opposite split (TRAIN_QUERY_SHOTS=15,
                    # VAL_QUERY_SHOTS=shot); confirm the asymmetry is intended.
                    env["TRAIN_QUERY_SHOTS"] = str(shot)
                    env["VAL_QUERY_SHOTS"] = str(15)
                    env["MODEL"] = "proto-mahalanobis"
                    env["FORWARD_TYPE"] = forward_type
                    env["BETA_BIAS"] = str(beta_bias)
                    env["BATCH_SIZE"] = str(batch)
                    env["MODE"] = MODE
                    env["OOD_TEST"] = str(OOD_TEST)
                    env["COV_EXPERIMENT"] = str(COV_EXPERIMENT)
                    env["CORRUPT_TEST"] = str(CORRUPT_TEST)
                    env["SAVE_BEST_VAL"] = str(False)

                    env["ENCODER_TYPE"] = encoder
                    env["PMA_TYPE"] = "no-residual"
                    env["RANK"] = str(rank)
                    env["T"] = str(1.0)

                    envs.append(env)
    return envs


def get_baseline_omniglot_runs() -> ENVLIST:
    """Build the Omniglot run environments for the baseline models.

    The sweep covers 5-way (batch size 32) and 20-way (batch size 16) tasks,
    1 and 5 shots, and the ``protonet``, ``protonet-sn``, ``proto-ddu`` and
    ``proto-sngp`` models. Each entry is a copy of ``os.environ`` extended
    with the variables the launcher script reads.

    Returns:
        A list of 16 environment dicts (2 way settings x 2 shots x 4 models),
        ordered way-major, then shot, then model.
    """
    envs: ENVLIST = []
    # (WAY, BATCH_SIZE) per task setting.
    for way, batch in [(5, 32), (20, 16)]:
        for shot in [1, 5]:
            for model in ["protonet", "protonet-sn", "proto-ddu", "proto-sngp"]:
                env = os.environ.copy()
                env["DATASET"] = "omniglot"
                env["METATRAIN_ITERS"] = str(METATRAIN_ITERS)
                env["RUN"] = str(0)
                env["WAY"] = str(way)
                env["INFERENCE_STYLE"] = "distance"
                env["SHOT"] = str(shot)
                # NOTE(review): the mini-ImageNet runs in this file use the
                # opposite split (TRAIN_QUERY_SHOTS=15, VAL_QUERY_SHOTS=shot);
                # confirm the asymmetry is intended.
                env["TRAIN_QUERY_SHOTS"] = str(shot)
                env["VAL_QUERY_SHOTS"] = str(15)
                env["MODEL"] = model
                env["FORWARD_TYPE"] = "softmax"
                env["BETA_BIAS"] = str(False)
                env["BATCH_SIZE"] = str(batch)
                env["MODE"] = MODE
                env["OOD_TEST"] = str(OOD_TEST)
                env["CORRUPT_TEST"] = str(CORRUPT_TEST)
                env["COV_EXPERIMENT"] = str(COV_EXPERIMENT)
                env["SAVE_BEST_VAL"] = str(False)

                # not used in these models, but still needs to be in the environment for the script
                env["ENCODER_TYPE"] = "diag"
                env["PMA_TYPE"] = "no-residual"
                env["RANK"] = str(1)
                env["T"] = str(1.0)

                envs.append(env)
    return envs


if __name__ == "__main__":
    # Generate every experiment family up front.
    omniglot_baseline = get_baseline_omniglot_runs()
    miniimagenet_baseline = get_baseline_miniimagenet_runs()
    omniglot_mahalanobis = get_mahalanobis_omniglot_runs()
    miniimagenet_mahalanobis = get_mahalanobis_miniimagenet_runs()

    # Queue order: mini-ImageNet (Mahalanobis, then baselines), followed by
    # Omniglot (Mahalanobis, then baselines).
    queued_runs = [
        *miniimagenet_mahalanobis,
        *miniimagenet_baseline,
        *omniglot_mahalanobis,
        *omniglot_baseline,
    ]
    Runner(SCRIPT, GPUS, queued_runs).run()
