import itertools
import json
import os
import shlex
import subprocess
import sys

import numpy as np

# Seed NumPy's global RNG so that the randomly drawn "retrain" values, and
# therefore every generated setup, are identical from one launch to the next.
np.random.seed(3545)

# The string "true" wraps every job in a `bsub` submission; any other value
# runs the training script directly.
cluster = "false"

# True  -> a fixed configuration per example, trained with 5 random retrain seeds.
# False -> a hyperparameter grid search with 2 random retrain seeds.
retrainings = False

# Prefix of the per-example output folder names.
folder_prefix = "Retraining" if retrainings else "ModelSelection"

# Jobs are only executed on POSIX platforms; elsewhere the commands are printed.
_IS_POSIX = sys.platform in ("linux", "linux2", "darwin")

# Per-example values used when ``retrainings`` is True. Every other retraining
# hyperparameter is shared by all examples (see _hyperparameter_grid).
_RETRAIN_OVERRIDES = {
    "advection": {"num_sensor": 512, "n_out": 512,
                  "n_hidden_layers_b": 3, "neurons_b": 256, "n_basis": 50},
    "burgers": {"num_sensor": 512, "n_out": 512,
                "n_hidden_layers_b": 4, "neurons_b": 256, "n_basis": 200},
    "shocktube": {"num_sensor": 512, "n_out": 512,
                  "n_hidden_layers_b": 4, "neurons_b": 256, "n_basis": 100},
    "riemann": {"num_sensor": 64, "n_out": 128 * 128,
                "n_hidden_layers_b": 3, "neurons_b": 32, "n_basis": 50},
}

# Hyperparameters forwarded, unrenamed, in the training-properties argument.
_TRAINING_KEYS = ("epochs", "batch_size", "learning_rate", "retrain",
                  "num_sensor", "n_out")


def _hyperparameter_grid(which_example):
    """Return the candidate values of every hyperparameter for one example.

    The result maps each hyperparameter name to a sequence of candidates. Its
    key order fixes the order of ``itertools.product`` and therefore the
    ``Setup_<i>`` numbering. Exactly one ``np.random.randint`` call is made per
    invocation, so the drawn "retrain" seeds depend only on the call order.
    """
    if retrainings:
        over = _RETRAIN_OVERRIDES[which_example]
        return {
            "epochs": [10000],
            "batch_size": [10],
            "learning_rate": [5e-4],
            "retrain": np.random.randint(0, 1000, 5),
            "num_sensor": [over["num_sensor"]],
            "n_out": [over["n_out"]],
            "n_hidden_layers_b": [over["n_hidden_layers_b"]],
            "neurons_b": [over["neurons_b"]],
            "act_string_b": ["leaky_relu"],
            "dropout_rate_b": [0.0],
            "kernel_size": [3],
            "n_hidden_layers_t": [6],
            "neurons_t": [256],
            "act_string_t": ["leaky_relu"],
            "dropout_rate_t": [0.0],
            "n_basis": [over["n_basis"]],
        }
    if which_example != "riemann":
        # Model-selection search space for the 1-D examples.
        return {
            "epochs": [10000],
            "batch_size": [10],
            "learning_rate": [5e-4],
            "retrain": np.random.randint(0, 1000, 2),
            "num_sensor": [128, 256, 512],
            "n_out": [128, 256, 512],
            "n_hidden_layers_b": [3, 4],
            "neurons_b": [256],
            "act_string_b": ["leaky_relu", "softsign"],
            "dropout_rate_b": [0.0],
            "kernel_size": [3],
            "n_hidden_layers_t": [4, 6],
            "neurons_t": [256],
            "act_string_t": ["leaky_relu", "softsign"],
            "dropout_rate_t": [0.0],
            "n_basis": [50, 100, 200],
        }
    # Model-selection search space for "riemann" (n_out values are k*k grids).
    return {
        "epochs": [10000],
        "batch_size": [10],
        "learning_rate": [5e-4],
        "retrain": np.random.randint(0, 1000, 2),
        "num_sensor": [64, 128, 256],
        "n_out": [64 * 64, 128 * 128, 256 * 256],
        "n_hidden_layers_b": [3],
        "neurons_b": [256],
        "act_string_b": ["leaky_relu", "softsign"],
        "dropout_rate_b": [0.0],
        "kernel_size": [3],
        "n_hidden_layers_t": [6],
        "neurons_t": [256],
        "act_string_t": ["leaky_relu", "softsign"],
        "dropout_rate_t": [0.0],
        "n_basis": [50, 100],
    }


def _to_native(value):
    """Convert a NumPy scalar (e.g. an np.int64 retrain seed) to a Python scalar.

    This is required for JSON encoding. Under NumPy >= 2, ``str()`` of a dict
    containing np.int64 renders ``np.int64(123)``, which is not valid JSON.
    """
    return value.item() if isinstance(value, np.generic) else value


def _build_command(which_example, arguments):
    """Return the argv list that runs (or, on the cluster, submits) one job."""
    # "riemann" is trained with TrainSDON2D.py, the other examples with
    # TrainSDON.py -- the same mapping locally and on the cluster.
    is_2d = which_example == "riemann"
    command = ["python3", "TrainSDON2D.py" if is_2d else "TrainSDON.py"] + arguments
    if cluster == "true":
        wall_time = "48:00" if is_2d else "24:00"
        command = ["bsub", "-W", wall_time,
                   "-R", "rusage[mem=8192]",
                   "-R", "rusage[ngpus_excl_p=1]"] + command
    return command


if not _IS_POSIX:
    print("Non-POSIX platform: commands are printed but not executed.",
          file=sys.stderr)

for which_example, folder_name in zip(["advection",
                                       "burgers",
                                       "shocktube",
                                       "riemann"],
                                      [folder_prefix + "AdvectionSDON",
                                       folder_prefix + "BurgersSDON",
                                       folder_prefix + "ShocktubeSDON",
                                       folder_prefix + "RiemannSDON"]):
    grid = _hyperparameter_grid(which_example)
    os.makedirs(folder_name, exist_ok=True)

    for i, setup in enumerate(itertools.product(*grid.values())):
        # Name every value instead of relying on positional tuple indices.
        params = dict(zip(grid, map(_to_native, setup)))
        print(tuple(params.values()))
        print("###################################")

        training_properties_ = {key: params[key] for key in _TRAINING_KEYS}
        branch_architecture_ = {
            "n_hidden_layers": params["n_hidden_layers_b"],
            "neurons": params["neurons_b"],
            "act_string": params["act_string_b"],
            "dropout_rate": params["dropout_rate_b"],
            "kernel_size": params["kernel_size"],
        }
        trunk_architecture_ = {
            "n_hidden_layers": params["n_hidden_layers_t"],
            "neurons": params["neurons_t"],
            "act_string": params["act_string_t"],
            "dropout_rate": params["dropout_rate_t"],
            "n_basis": params["n_basis"],
        }

        # Passed as separate argv entries, so no shell quoting is involved.
        arguments = [
            folder_name + "/Setup_" + str(i),
            json.dumps(training_properties_),
            json.dumps(branch_architecture_),
            json.dumps(trunk_architecture_),
            which_example,
        ]
        command = _build_command(which_example, arguments)
        print(" ".join(shlex.quote(part) for part in command))

        if not _IS_POSIX:
            continue
        # Best effort, like os.system: report failures and move on to the next setup.
        try:
            result = subprocess.run(command, check=False)
        except OSError as err:
            print("Failed to launch setup " + str(i) + ": " + str(err),
                  file=sys.stderr)
            continue
        if result.returncode != 0:
            print("Setup " + str(i) + " exited with code " + str(result.returncode),
                  file=sys.stderr)
