import itertools
import json
import os
import shlex
import sys

import numpy as np

# Seed NumPy's global RNG so the sampled "retrain" values are reproducible.
np.random.seed(3545)

# Set to "true" to submit each run with bsub instead of launching it directly.
cluster = "false"

# False -> hyper-parameter search over architectures;
# True  -> retrain a single fixed configuration per example.
retrainings = False

# Output folders are prefixed according to the mode selected above.
folder_prefix = "Retraining" if retrainings else "ModelSelection"

# sys.platform prefixes on which the generated shell command is executed.
_POSIX_PLATFORMS = ("linux", "darwin")


def _search_space(example, retrain_mode):
    """Return ``(training_grid, architecture_grid)`` for one example.

    Each grid maps a hyper-parameter name to the list of candidate values that
    are combined with ``itertools.product``.

    Args:
        example: one of "advection", "burgers", "shocktube", "riemann".
        retrain_mode: if true, use a single fixed architecture per example;
            otherwise return the model-selection search space.

    Raises:
        KeyError: in retrain mode, if ``example`` is not one of the four above.
    """
    training_grid = {
        "epochs": [10000],
        "batch_size": [10],
        "learning_rate": [5e-4],
        # Exactly one draw from NumPy's global RNG per call, so the seeded
        # sequence of retrain values is unchanged. .tolist() yields plain
        # Python ints: NumPy integer scalars are not JSON-serializable, and
        # under NumPy >= 2 their repr is "np.int64(...)".
        "retrain": np.random.randint(0, 1000, 5).tolist(),
    }
    if retrain_mode:
        activation, atrous, start = {
            "advection": ("leaky_relu", 0, 8),
            "burgers": ("leaky_relu", 0, 16),
            "shocktube": ("leaky_relu", 0, 16),
            "riemann": ("relu", 1, 16),
        }[example]
        architecture_grid = {
            "activation": [activation],
            "atrous": [atrous],
            "start": [start],
            "opt": ["adam"],
        }
    elif example == "riemann":
        architecture_grid = {
            "activation": ["relu"],
            "atrous": [1],
            "start": [8, 16, 32],
            "opt": ["adam"],
        }
    else:
        architecture_grid = {
            "activation": ["leaky_relu", "softsign", "sin"],
            "atrous": [0],
            "start": [8, 16, 32],
            "opt": ["adam"],
        }
    return training_grid, architecture_grid


def _build_command(setup_folder, training_properties, net_architecture, example, on_cluster):
    """Return the argument vector that launches TrainConv.py for one setup.

    The two property dicts are passed as JSON strings. When ``on_cluster`` is
    true the command is prefixed with a ``bsub`` submission using ``-W 48:00``
    for "riemann" (``-W 24:00`` otherwise), ``rusage[mem=16384]`` and
    ``rusage[ngpus_excl_p=1]``.
    """
    argv = []
    if on_cluster:
        wall_time = "48:00" if example == "riemann" else "24:00"
        argv += ["bsub", "-W", wall_time,
                 "-R", "rusage[mem=16384]",
                 "-R", "rusage[ngpus_excl_p=1]"]
    argv += ["python3", "TrainConv.py",
             setup_folder,
             json.dumps(training_properties),
             json.dumps(net_architecture),
             example]
    return argv


# Accept either the string "true" or a boolean True.
_on_cluster = str(cluster).lower() == "true"

for which_example, folder_name in zip(
    ["advection", "burgers", "shocktube", "riemann"],
    [
        folder_prefix + "AdvectionConv",
        folder_prefix + "BurgersConv",
        folder_prefix + "ShocktubeConv",
        folder_prefix + "RiemannConv",
    ],
):
    training_grid, architecture_grid = _search_space(which_example, retrainings)
    ndic = {**training_grid, **architecture_grid}
    param_names = list(ndic)

    os.makedirs(folder_name, exist_ok=True)

    for i, setup in enumerate(itertools.product(*ndic.values())):
        print(setup)
        print("###################################")

        # Map values back to their names rather than relying on positional
        # indices, so adding or reordering keys cannot misassign parameters.
        params = dict(zip(param_names, setup))
        training_properties_ = {key: params[key] for key in training_grid}
        net_architecture_ = {key: params[key] for key in architecture_grid}

        argv = _build_command(
            folder_name + "/Setup_" + str(i),
            training_properties_,
            net_architecture_,
            which_example,
            _on_cluster,
        )

        if not sys.platform.startswith(_POSIX_PLATFORMS):
            # Jobs are only launched on Linux/macOS; report the skip instead of
            # silently doing nothing.
            print("Skipping launch: unsupported platform " + sys.platform)
            continue

        # shlex.quote keeps each argument (notably the JSON strings) a single
        # shell word, so bsub / python3 receive exactly `argv`.
        string_to_exec = " ".join(shlex.quote(arg) for arg in argv)
        print(string_to_exec)
        status = os.system(string_to_exec)
        if status != 0:
            print("Command exited with status " + str(status))
