## Import libraries.
import os
import torch
from Arguments import arguments
from plot_loss_inv import plot_loss_inv
from Learn import learn
from learning_result_and_plot import result_and_plot
from save_results import save
from save_interface_figure import save_interface_figure

def select_device(gpu_id):
    """Return the torch device to run on.

    Prefers CUDA, then Apple MPS, and falls back to the CPU.

    Args:
        gpu_id: Device index appended to the ``cuda``/``mps`` device string.

    Returns:
        A ``torch.device``.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{gpu_id}")
    if torch.backends.mps.is_available():
        # NOTE(review): MPS typically exposes a single device, so a non-zero
        # gpu_id may be rejected once tensors are moved there -- confirm.
        return torch.device(f"mps:{gpu_id}")
    return torch.device("cpu")


def select_folder_name(args):
    """Return the results sub-folder name (with a trailing '/') for this run.

    Three hard-coded ``args.custom_coeffs`` presets map to the named folders
    ``spike1/``, ``spike2/`` and ``square/``. Any other value uses
    ``args.PSI_list[args.arg_conformal]``.
    """
    if args.custom_coeffs == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1]:
        return 'spike1/'
    if args.custom_coeffs == [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.1]:
        return 'spike2/'
    if args.custom_coeffs == [0.0, 0.0, 0.1]:
        return 'square/'
    return args.PSI_list[args.arg_conformal] + "/"


def make_experiment_dir(args, exp_name, root="results/"):
    """Create and return the directory ``<root><folder>/<args_type>/<exp_name>``.

    All missing parent directories are created. The returned path has no
    trailing '/'; it is handed verbatim to the learning/plotting/saving
    helpers.

    Args:
        args: Parsed arguments; reads ``custom_coeffs``, ``PSI_list``,
            ``arg_conformal`` and ``args_type``.
        exp_name: Name of the experiment directory (last path component).
        root: Results root; must end with '/'.

    Returns:
        The created experiment directory path as a string.

    Raises:
        ValueError: If the experiment directory already exists, so results of
            an earlier run are never overwritten.
    """
    parent = root + select_folder_name(args) + args.args_type + "/"
    # makedirs rather than mkdir: the <folder> level may not exist yet, and a
    # plain mkdir of the nested path would fail with FileNotFoundError.
    os.makedirs(parent, exist_ok=True)
    path = parent + exp_name
    # Create-and-detect in one step instead of exists()-then-mkdir().
    try:
        os.mkdir(path)
    except FileExistsError as err:
        raise ValueError(f"Exist path: {path}") from err
    return path


def main():
    """Run one experiment: train, plot, evaluate, and save the error summary."""
    args = arguments()

    # Disabled guard, kept for reference:
    # if args.import_flag == "True" and args.import_title == 0:
    #     print("warning!")
    #     exit()

    exp_name = '2024_PINN_neutral_test_%d' % args.title
    device = select_device(args.gpu_id)
    path = make_experiment_dir(args, exp_name)

    model, after_learning, learning_error = learn(args, device, path)

    plot_loss_inv(args, model, path, exp_name)

    save_interface_figure(args, after_learning, path, exp_name)

    dict_result_error = result_and_plot(model, after_learning, path, exp_name)

    errors_to_save = {
        "prediction_error": learning_error["prediction_error"],
        "Relative_L2_error": dict_result_error["Relative_L2_error"],
        "Background_L2_error": dict_result_error["Background_L2_error"],
    }

    save(args, path, exp_name, errors_to_save)


if __name__ == '__main__':
    main()

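# TODO: the random seed is not set anywhere yet (Seed = ?).


# TODO: sketch of a multi-seed sweep; `run` and `Inputs` are not defined in this file.
# for i in range(10):
#     Seed = ...  # choose/draw the seed for run i (torch.rand() as first sketched needs a size argument)
#     run(Seed, Inputs)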
