import json
import toml
import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.utils import median_index, mpl_setup
mpl_setup()  # project helper from src.utils; per its name/original note, sets matplotlib figure defaults

# --- Command-line interface -------------------------------------------------
argparser = ArgumentParser()
argparser.add_argument("-d", "--dir", type=str, help="Output directory of the experiment", required=True)
args = vars(argparser.parse_args())

# --- Load experiment artifacts ----------------------------------------------
# `with` guarantees the file handles are closed even if parsing raises.
# JSON and TOML are both UTF-8 by specification, so read them as UTF-8
# instead of relying on the platform's default locale encoding.
results_path = os.path.join(args["dir"], "results.json")
with open(results_path, encoding="utf-8") as results_file:
    results = json.load(results_file)
# NOTE(review): not read anywhere in this script; kept in case it is used externally.
model_out_dir_path = os.path.join(args["dir"], "models")

config_path = os.path.join(args["dir"], "config.toml")
with open(config_path, encoding="utf-8") as config_file:
    config = toml.load(config_file)


# Echo the loaded configuration (pretty-printed as JSON) followed by a
# 40-character separator rule, each on its own line.
print("Configuration:", json.dumps(config, indent=4), "-" * 40, sep="\n")

# Fields of each result record that identify the run rather than an optimizer.
RUN_METADATA_KEYS = {"run_id", "data_seed", "model_seed", "sampling_seed"}

# Group per-run metrics by optimizer name. `keys` keeps optimizer names in
# first-seen order; the three dicts map each name to one list entry per run.
keys = []
all_train_mse, all_test_mse, all_train_time = {}, {}, {}
for result in results:
    for key, value in result.items():
        if key in RUN_METADATA_KEYS:
            continue
        if key not in keys:
            # All four structures are always initialised together, so one
            # membership check covers them all.
            keys.append(key)
            all_train_mse[key] = []
            all_test_mse[key] = []
            all_train_time[key] = []

        # Each optimizer's entry holds its metrics under "<name>-<metric>" keys.
        train_mse = value[f"{key}-train-mse"]
        test_mse = value[f"{key}-test-mse"]
        train_time = value[f"{key}-train-time"]

        all_train_mse[key].append(train_mse)
        all_test_mse[key].append(test_mse)
        all_train_time[key].append(train_time)

def _min_mean_max(values):
    """Return ``(min, mean, max)`` of *values*, ignoring NaN entries.

    If every entry is NaN, NumPy returns NaN for all three and emits a
    RuntimeWarning.
    """
    return np.nanmin(values), np.nanmean(values), np.nanmax(values)


# Print a MIN / MEAN / MAX summary for every optimizer across all runs.
for key in keys:
    min_train_mse, mean_train_mse, max_train_mse = _min_mean_max(all_train_mse[key])
    min_test_mse, mean_test_mse, max_test_mse = _min_mean_max(all_test_mse[key])
    min_train_time, mean_train_time, max_train_time = _min_mean_max(all_train_time[key])

    # Number of runs with a NaN test MSE, to see whether the optimization
    # completely failed for this key.
    nan_count = np.count_nonzero(np.isnan(all_test_mse[key]))

    print()
    print("-"*40)
    print(f"Optimizer = {key}           MIN                     MEAN                                MAX")
    print("-"*40)
    print(f"-> Train MSE            :   {min_train_mse:.2e}                {mean_train_mse:.2e}                         {max_train_mse:.2e}")
    print(f"-> Test  MSE            :   {min_test_mse:.2e}                {mean_test_mse:.2e}                         {max_test_mse:.2e}")
    print(f"-> Train time [seconds] :   {min_train_time:.2f} or {min_train_time:.2e}       {mean_train_time:.2f} or {mean_train_time:.2e}               {max_train_time:.2f} or {max_train_time:.2e}")
    print(f"-> Nan Count            :   {nan_count}")

# The bare `exit` builtin is injected by the `site` module and is missing
# under `python -S` or some embedded interpreters; SystemExit always exists.
raise SystemExit(0)
