import os

from utils import ENVLIST, Runner

# GPU ids passed to the Runner: every id from 0 to 7 except 4.
GPUS = [gpu_id for gpu_id in range(8) if gpu_id != 4]
# Exported unchanged as the MODE variable of every run.
MODE = "test"
# Script handed to the Runner (presumably launched once per environment).
SCRIPT = "./run.sh"
# Evaluation switches; each is exported to every run as the string "True"/"False".
CORRUPT_TEST = True
OOD_TEST = False
COV_EXPERIMENT = False
# Meta-training iterations per run; a small value such as 100 can be
# substituted for a quick smoke test.
METATRAIN_ITERS = 60_000


def get_baseline_miniimagenet_runs() -> ENVLIST:
    """Return the environments for the baseline 5-way miniImageNet sweep.

    Sweeps 5 runs x two (shot, batch size) settings x four baseline models,
    giving 40 environments ordered by run, then shot, then model. Each one is
    a copy of ``os.environ`` extended with the variables the script reads.
    GPU and DATADIR are not set here; they are expected to be set elsewhere.
    """
    shot_batch_settings = ((1, 4), (5, 2))
    baseline_models = ("protonet", "protonet-sn", "proto-ddu", "proto-sngp")

    result: ENVLIST = []
    for repetition in range(5):
        for n_shot, batch_size in shot_batch_settings:
            for model_name in baseline_models:
                env = os.environ.copy()
                env.update(
                    {
                        "DATASET": "miniimagenet",
                        "METATRAIN_ITERS": str(METATRAIN_ITERS),
                        "INFERENCE_STYLE": "distance",
                        "RUN": str(repetition),
                        "WAY": str(5),
                        "SHOT": str(n_shot),
                        "TRAIN_QUERY_SHOTS": str(15),
                        "VAL_QUERY_SHOTS": str(n_shot),
                        "BATCH_SIZE": str(batch_size),
                        "MODE": MODE,
                        "MODEL": model_name,
                        "FORWARD_TYPE": "softmax",
                        "BETA_BIAS": str(False),
                        "OOD_TEST": str(OOD_TEST),
                        "COV_EXPERIMENT": str(COV_EXPERIMENT),
                        "CORRUPT_TEST": str(CORRUPT_TEST),
                        "SAVE_BEST_VAL": str(True),
                        # The baseline models do not use the next four settings,
                        # but the script still requires them in the environment.
                        "ENCODER_TYPE": "diag",
                        "PMA_TYPE": "no-residual",
                        "RANK": str(1),
                        "T": str(0.05),
                    }
                )
                result.append(env)
    return result


def get_mahalanobis_miniimagenet_runs() -> ENVLIST:
    """Return the environments for the proto-mahalanobis 5-way miniImageNet sweep.

    Sweeps 5 runs x two (shot, batch size) settings x one output-head
    configuration x five encoder configurations, giving 50 environments
    ordered by run, shot, head, then encoder. Each one is a copy of
    ``os.environ`` extended with the variables the script reads.
    GPU and DATADIR are not set here; they are expected to be set elsewhere.
    """
    shot_batch_settings = ((1, 4), (5, 2))
    # (forward type, beta bias, inference style)
    head_configs = [("softmax", False, "softmax-sample")]
    # (encoder type, rank)
    encoder_configs = [
        ("diag", 1),
        ("low-rank", 1),
        ("low-rank", 2),
        ("low-rank", 4),
        ("low-rank", 8),
    ]

    result: ENVLIST = []
    for repetition in range(5):
        for n_shot, batch_size in shot_batch_settings:
            for fwd, use_beta_bias, inference in head_configs:
                for encoder_kind, encoder_rank in encoder_configs:
                    env = os.environ.copy()
                    env.update(
                        {
                            "DATASET": "miniimagenet",
                            "METATRAIN_ITERS": str(METATRAIN_ITERS),
                            "INFERENCE_STYLE": inference,
                            "RUN": str(repetition),
                            "WAY": str(5),
                            "SHOT": str(n_shot),
                            "TRAIN_QUERY_SHOTS": str(15),
                            "VAL_QUERY_SHOTS": str(n_shot),
                            "BATCH_SIZE": str(batch_size),
                            "MODE": MODE,
                            "MODEL": "proto-mahalanobis",
                            "FORWARD_TYPE": fwd,
                            "BETA_BIAS": str(use_beta_bias),
                            "OOD_TEST": str(OOD_TEST),
                            "COV_EXPERIMENT": str(COV_EXPERIMENT),
                            "CORRUPT_TEST": str(CORRUPT_TEST),
                            "SAVE_BEST_VAL": str(True),
                            "ENCODER_TYPE": encoder_kind,
                            "PMA_TYPE": "no-residual",
                            "RANK": str(encoder_rank),
                            "T": str(1.0),
                        }
                    )
                    result.append(env)

    return result


def get_mahalanobis_omniglot_runs() -> ENVLIST:
    """Return the environments for the proto-mahalanobis Omniglot sweep.

    Sweeps 5 runs x two (way, batch size) settings x shots {1, 5} x one
    output-head configuration x five encoder configurations, giving 100
    environments ordered by run, way, shot, head, then encoder. Each one is a
    copy of ``os.environ`` extended with the variables the script reads.
    """
    # NOTE(review): the sweep table used to carry unused `trainstep`, `valstep`
    # and `inner_lr` values (1/3/0.4 for 5-way, 5/5/0.1 for 20-way). They were
    # never written to the environment, so they are dropped here. Re-add them
    # as environment variables if the script ever needs them.
    way_batch_settings = [(5, 32), (20, 16)]
    shots = [1, 5]
    # (forward type, beta bias, inference style)
    head_configs = [("softmax", False, "softmax-sample")]
    # (encoder type, rank)
    encoder_configs = [
        ("diag", 1),
        ("low-rank", 1),
        ("low-rank", 2),
        ("low-rank", 4),
        ("low-rank", 8),
    ]

    envs: ENVLIST = []
    for run in range(5):
        for way, batch in way_batch_settings:
            for shot in shots:
                for forward_type, beta_bias, inference_style in head_configs:
                    for encoder, rank in encoder_configs:
                        env = os.environ.copy()
                        env.update(
                            {
                                "DATASET": "omniglot",
                                "METATRAIN_ITERS": str(METATRAIN_ITERS),
                                "RUN": str(run),
                                "WAY": str(way),
                                "INFERENCE_STYLE": inference_style,
                                "SHOT": str(shot),
                                # NOTE(review): the train query count follows the shot
                                # and the val query count is 15 here, while the
                                # miniImageNet sweeps use the opposite assignment;
                                # confirm this is intended.
                                "TRAIN_QUERY_SHOTS": str(shot),
                                "VAL_QUERY_SHOTS": str(15),
                                "MODEL": "proto-mahalanobis",
                                "FORWARD_TYPE": forward_type,
                                "BETA_BIAS": str(beta_bias),
                                "BATCH_SIZE": str(batch),
                                "MODE": MODE,
                                "OOD_TEST": str(OOD_TEST),
                                "COV_EXPERIMENT": str(COV_EXPERIMENT),
                                "CORRUPT_TEST": str(CORRUPT_TEST),
                                "SAVE_BEST_VAL": str(False),
                                "ENCODER_TYPE": encoder,
                                "PMA_TYPE": "no-residual",
                                "RANK": str(rank),
                                "T": str(1.0),
                            }
                        )
                        envs.append(env)
    return envs


def get_baseline_omniglot_runs() -> ENVLIST:
    """Return the environments for the baseline Omniglot sweep.

    Sweeps 5 runs x two (way, batch size) settings x shots {1, 5} x four
    baseline models, giving 80 environments ordered by run, way, shot, then
    model. Each one is a copy of ``os.environ`` extended with the variables
    the script reads.
    """
    # NOTE(review): the sweep table used to carry unused `trainstep`, `valstep`
    # and `inner_lr` values (1/3/0.4 for 5-way, 5/5/0.1 for 20-way). They were
    # never written to the environment, so they are dropped here. Re-add them
    # as environment variables if the script ever needs them.
    way_batch_settings = [(5, 32), (20, 16)]
    shots = [1, 5]
    models = ["protonet", "protonet-sn", "proto-ddu", "proto-sngp"]

    envs: ENVLIST = []
    for run in range(5):
        for way, batch in way_batch_settings:
            for shot in shots:
                for model in models:
                    env = os.environ.copy()
                    env.update(
                        {
                            "DATASET": "omniglot",
                            "METATRAIN_ITERS": str(METATRAIN_ITERS),
                            "RUN": str(run),
                            "WAY": str(way),
                            "INFERENCE_STYLE": "distance",
                            "SHOT": str(shot),
                            # NOTE(review): the train query count follows the shot
                            # and the val query count is 15 here, while the
                            # miniImageNet sweeps use the opposite assignment;
                            # confirm this is intended.
                            "TRAIN_QUERY_SHOTS": str(shot),
                            "VAL_QUERY_SHOTS": str(15),
                            "MODEL": model,
                            "FORWARD_TYPE": "softmax",
                            "BETA_BIAS": str(False),
                            "BATCH_SIZE": str(batch),
                            "MODE": MODE,
                            "OOD_TEST": str(OOD_TEST),
                            "CORRUPT_TEST": str(CORRUPT_TEST),
                            "COV_EXPERIMENT": str(COV_EXPERIMENT),
                            "SAVE_BEST_VAL": str(False),
                            # Not used by these models, but the script still
                            # requires them in the environment.
                            "ENCODER_TYPE": "diag",
                            "PMA_TYPE": "no-residual",
                            "RANK": str(1),
                            "T": str(1.0),
                        }
                    )
                    envs.append(env)
    return envs


if __name__ == "__main__":
    # Every sweep is built up front, so choosing which ones to launch only
    # requires editing `active` below.
    sweeps = {
        "baseline_omniglot": get_baseline_omniglot_runs(),
        "baseline_miniimagenet": get_baseline_miniimagenet_runs(),
        "mahalanobis_omniglot": get_mahalanobis_omniglot_runs(),
        "mahalanobis_miniimagenet": get_mahalanobis_miniimagenet_runs(),
    }

    # Full sweep, in launch order: mahalanobis_miniimagenet, baseline_miniimagenet,
    # mahalanobis_omniglot, baseline_omniglot.
    active = ["baseline_miniimagenet", "baseline_omniglot"]
    runs = [env for sweep_name in active for env in sweeps[sweep_name]]

    Runner(SCRIPT, GPUS, runs).run()
