import itertools
import json
import os
import shlex
import sys

import numpy as np

# Seed the global NumPy RNG so that the "retrain" seeds sampled further down
# come out the same on every run of this script.
np.random.seed(3545)

# Jobs go through the `bsub` scheduler only when this equals the string "true".
cluster = "false"

# True  -> one fixed architecture per example, trained with 5 sampled seeds.
# False -> model selection over a width/modes/n_layers grid with 3 seeds.
retrainings = False

folder_prefix = "Retraining" if retrainings else "ModelSelection"

# Launch one TrainFNO.py run per hyperparameter setup, for every example.
#
# For each example, the cartesian product of the training grid and the
# architecture grid is enumerated. Setup number i is written to
# "<folder_prefix><Example>FNO/Setup_i", and the training/architecture
# dictionaries are passed to TrainFNO.py as JSON strings.

# Folder name (appended to folder_prefix) collecting the runs of each example.
EXAMPLE_FOLDERS = {
    "advection": "AdvectionFNO",
    "burgers": "BurgersFNO",
    "shocktube": "ShocktubeFNO",
    "riemann": "RiemannFNO",
}

# Architecture used for each example when retrainings is True,
# given as (width, modes, n_layers).
RETRAIN_ARCHITECTURES = {
    "advection": (64, 16, 2),
    "burgers": (32, 20, 4),
    "shocktube": (32, 8, 4),
    "riemann": (64, 16, 4),
}

# Commands are only assembled and launched on these platforms.
POSIX_PLATFORMS = ("linux", "linux2", "darwin")

for which_example, folder_suffix in EXAMPLE_FOLDERS.items():
    folder_name = folder_prefix + folder_suffix

    # Exactly one draw from the global RNG per example, in the fixed order
    # above, so the seeded sequence of "retrain" seeds is reproducible.
    # The seeds are converted to built-in int: under NumPy >= 2, repr() of an
    # np.int64 is "np.int64(...)", which would not be valid JSON, and
    # json.dumps rejects NumPy integers outright.
    n_retrain = 5 if retrainings else 3
    training_grid = {
        "epochs": [10000],
        "batch_size": [10],
        "learning_rate": [5e-4],
        "retrain": [int(seed) for seed in np.random.randint(0, 1000, n_retrain)],
    }

    if retrainings:
        # KeyError here means an example has no retraining architecture.
        width, modes, n_layers = RETRAIN_ARCHITECTURES[which_example]
        architecture_grid = {
            "width": [width],
            "modes": [modes],
            "n_layers": [n_layers],
        }
    else:
        architecture_grid = {
            "width": [32, 64],
            "modes": [8, 16, 20],
            "n_layers": [2, 3, 4],
        }

    search_space = {**training_grid, **architecture_grid}
    os.makedirs(folder_name, exist_ok=True)

    # The product order fixes the Setup_i numbering; keep the key order stable.
    keys = list(search_space)
    settings = list(itertools.product(*search_space.values()))

    for i, setup in enumerate(settings):
        print(setup)
        print("###################################")

        if sys.platform not in POSIX_PLATFORMS:
            # Elsewhere the setups are only listed; nothing is executed.
            continue

        config = dict(zip(keys, setup))
        training_properties_ = {key: config[key] for key in training_grid}
        net_architecture_ = {key: config[key] for key in architecture_grid}

        if cluster == "true":
            wall_time = "48:00" if which_example == "riemann" else "24:00"
            launcher = ("bsub -W " + wall_time
                        + " -R 'rusage[mem=16384]' -R 'rusage[ngpus_excl_p=1]'"
                        + " python3 TrainFNO.py")
        else:
            launcher = "python3 TrainFNO.py"

        # shlex.quote makes each argument a single shell word, so the JSON
        # (which contains spaces and double quotes) reaches TrainFNO.py intact.
        arguments = [
            shlex.quote(folder_name + "/Setup_" + str(i)),
            shlex.quote(json.dumps(training_properties_)),
            shlex.quote(json.dumps(net_architecture_)),
            which_example,
        ]
        string_to_exec = " ".join([launcher] + arguments)
        print(string_to_exec)
        os.system(string_to_exec)
