import os
from train import *

# --- Experiment configuration ------------------------------------------------
# Disabled override, kept for reference only: datasets["num_work"] = 1

dataset_name = "abalone"

# Settings stored on this dataset's own entry of the shared `datasets` config,
# applied in the same order as listed.
_dataset_overrides = (
    ("epoch", 1000),
    ("slow-star", True),
)
for _key, _value in _dataset_overrides:
    datasets[dataset_name][_key] = _value

# Top-level switches stored directly on the `datasets` config object
# (imported from train), applied in the same order as listed.
_global_overrides = (
    ("decrease", False),
    ("slow", 200),
    ("stop_c", 0),
    ("inf", 0),
    ("weight_decay", 0),
    ("drop_out", 0),
)
for _key, _value in _global_overrides:
    datasets[_key] = _value

model_type = "RwR_mlp"
loss_types = ["RwR_loss_mae"]

# Metric names passed to training and evaluated on the test set after each run.
metrics = ["RwR_Risk_Evaluation", "A_loss", "R_loss", "Reject_Rate", "R_A", "A_R"]

# Metrics that are always scored on the checkpoint stored under
# best_model["RwR_Risk_Evaluation"]. A metric outside this set is scored on
# best_model[metric] instead. Every entry of the current `metrics` list is in
# this set, so that fallback is not taken today.
_RISK_SELECTED_METRICS = frozenset(
    ["RwR_Risk_Evaluation", "A_loss", "R_loss", "Reject_Rate", "R_A", "A_R", "real"]
)

# Grid search over the hyper-parameters "c" and "optim_rate". Each grid point
# is repeated for runs 1..10, and every metric is then summarised as
# [mean, std] across the runs.
for c in [3.0, 4.0, 5.0, 6.0]:
    # The train helpers are given only dataset_name, so they presumably read
    # hyper-parameters from this shared config. Write the grid value there.
    datasets[dataset_name]["c"] = c
    for optim_rate in [0.01, 0.001]:
        datasets[dataset_name]["optim_rate"] = optim_rate

        # history[loss_type][metric] -> list of per-run test results.
        history = {loss: {metric: [] for metric in metrics} for loss in loss_types}

        for run in range(1, 11):
            # load_data(dataset_name, run) presumably loads split/fold `run`.
            # TODO confirm against train.load_data.
            load_data(dataset_name, run)
            for loss_type in loss_types:
                # Re-apply the same `seed` (from train) before every training run.
                up_seed(seed)
                # best_model is indexed by metric name below. Presumably it maps a
                # selection criterion to its best checkpoint (verify in train).
                best_model = train_dataset_model(
                    dataset_name,
                    loss_type,
                    model_type=model_type,
                    metrics=metrics,
                    print_show=True,
                )
                for metric in metrics:
                    checkpoint_key = (
                        "RwR_Risk_Evaluation" if metric in _RISK_SELECTED_METRICS else metric
                    )
                    loss_test = test_model(best_model[checkpoint_key], dataset_name, model_type, [metric])
                    history[loss_type][metric].append(loss_test[metric])

                # Read c / optim_rate back from the config (not the loop locals)
                # so the log shows whatever training left there.
                print(
                    "dataset_name = {},c={},optim_rate={},time={},loss_type = {} complete".format(
                        dataset_name,
                        datasets[dataset_name]["c"],
                        datasets[dataset_name]["optim_rate"],
                        run,
                        loss_type,
                    )
                )

            print(
                "dataset_name = {},c={},optim_rate={},time={} complete".format(
                    dataset_name,
                    datasets[dataset_name]["c"],
                    datasets[dataset_name]["optim_rate"],
                    run,
                )
            )

            # Running log of the raw per-run results collected so far.
            for loss_type in loss_types:
                print("loss_type = {}:\t{}".format(loss_type, history[loss_type]))

        print(
            "dataset_name = {},c={},optim_rate={} complete".format(
                dataset_name,
                datasets[dataset_name]["c"],
                datasets[dataset_name]["optim_rate"],
            )
        )

        # Collapse each per-run list into [mean, std] in place.
        for loss_type in loss_types:
            for metric in metrics:
                results = np.array(history[loss_type][metric])
                history[loss_type][metric] = [results.mean(), results.std()]

        for loss_type in loss_types:
            print("loss_type = {}:\t{}".format(loss_type, history[loss_type]))
    print("dataset_name = {},c={} complete".format(dataset_name, datasets[dataset_name]["c"]))
