import itertools
import json
import os
import shlex
import sys

import numpy as np

# Seed NumPy's global RNG so the "retrain" values drawn for every example are
# reproducible from run to run.
np.random.seed(3545)

# String flag compared against "true" when jobs are launched: "true" submits
# each job through `bsub`, anything else runs `python3 TrainFNN.py` directly.
cluster = "false"

# True -> train only the fixed per-example architectures ("Retraining*" folders);
# False -> sweep the full architecture grid ("ModelSelection*" folders).
retrainings = False

folder_prefix = "Retraining" if retrainings else "ModelSelection"

def _single_architecture(n_hidden_layers, neurons):
    """Return a one-point architecture grid using leaky ReLU and no dropout."""
    return {
        "n_hidden_layers_t": [n_hidden_layers],
        "neurons_t": [neurons],
        "act_string_t": ["leaky_relu"],
        "dropout_rate_t": [0.0],
    }


# Fixed architecture trained for each example when `retrainings` is True.
_RETRAINING_ARCHITECTURES = {
    "advection": _single_architecture(4, 128),
    "burgers": _single_architecture(8, 256),
    "shocktube": _single_architecture(8, 256),  # LaxSod
    "riemann": _single_architecture(8, 256),
}

# Architecture grid swept for every example when `retrainings` is False.
_MODEL_SELECTION_ARCHITECTURE = {
    "n_hidden_layers_t": [4, 6, 8],
    "neurons_t": [128, 256],
    "act_string_t": ["leaky_relu", "tanh", "sin"],
    "dropout_rate_t": [0.0],
}

# Keys of the per-setup training-properties dict, in the order they are sent.
_TRAINING_KEYS = ("epochs", "batch_size", "learning_rate", "retrain")

# Jobs are only launched on these platforms; elsewhere the setups are listed.
_CAN_LAUNCH = sys.platform in ("linux", "linux2", "darwin")

if cluster == "true":
    _launcher = ("bsub -W 48:00 -R 'rusage[mem=16384]' "
                 "-R 'rusage[ngpus_excl_p=1]' python3 TrainFNN.py")
else:
    _launcher = "python3 TrainFNN.py"

if not _CAN_LAUNCH:
    print("Warning: jobs are only launched on Linux/macOS; on " + sys.platform
          + " the setups are printed but not run.", file=sys.stderr)

for which_example, folder_name in zip(["advection",
                                       "burgers",
                                       "shocktube",
                                       "riemann"],
                                      [folder_prefix + "AdvectionFNN",
                                       folder_prefix + "BurgersFNN",
                                       folder_prefix + "ShocktubeFNN",
                                       folder_prefix + "RiemannFNN"]):

    # Exactly one RNG draw per example (in both modes), so the retrain values
    # are fully determined by the seed set at the top of the file.
    # .tolist() converts np.int64 -> int: under NumPy >= 2 repr(np.int64(7)) is
    # "np.int64(7)", which would otherwise leak into the serialized argument.
    training_properties_ = {
        "epochs": [10000],
        "batch_size": [10],
        "learning_rate": [5e-4],
        "retrain": np.random.randint(0, 1000, 5).tolist(),
    }
    if retrainings:
        net_architecture_ = _RETRAINING_ARCHITECTURES[which_example]
    else:
        net_architecture_ = _MODEL_SELECTION_ARCHITECTURE

    ndic = {**training_properties_, **net_architecture_}

    os.makedirs(folder_name, exist_ok=True)

    # One job per point of the Cartesian product of all grid values.
    for i, setup in enumerate(itertools.product(*ndic.values())):
        print(setup)
        print("###################################")

        params = dict(zip(ndic, setup))
        training_properties = {key: params[key] for key in _TRAINING_KEYS}
        # The trainer expects the architecture keys without the "_t" suffix.
        net_architecture = {
            "n_hidden_layers": params["n_hidden_layers_t"],
            "neurons": params["neurons_t"],
            "act_string": params["act_string_t"],
            "dropout_rate": params["dropout_rate_t"],
        }

        if not _CAN_LAUNCH:
            continue

        # argv for TrainFNN.py: output folder, training-properties dict and
        # architecture dict as JSON text (presumably json.loads'ed by the
        # trainer - TODO confirm), then the example name.
        arguments = [
            folder_name + "/Setup_" + str(i),
            json.dumps(training_properties),
            json.dumps(net_architecture),
            which_example,
        ]
        string_to_exec = " ".join([_launcher] + [shlex.quote(arg) for arg in arguments])
        print(string_to_exec)
        status = os.system(string_to_exec)
        if status != 0:
            print("Warning: command exited with status " + str(status),
                  file=sys.stderr)
