"""
Evaluates and creates figures from the Adam comparison experiment directory:
    - Plots the Adam training history (train and test MSE loss curves) and the SWIM training result, with the min/max spread across runs.
    - Tests zero-shot learning by testing the system on larger NxN lattices. (use eval_model.py script!)
    - Demonstrates RF initialization + further training (use rf_adam.py)
"""
import json
import toml
import os
import numpy as np
import matplotlib.pyplot as plt
from argparse import ArgumentParser
from src.utils import median_index, mpl_setup, ADAM_HGN_TRAIN, ADAM_HGN, SWIM_RF_HGN_TRAIN, SWIM_RF_HGN
# mpl_setup() # sets some default params for the matplotlib figures

# Font sizes shared by the axis labels and the legend of the figure below.
labelsize = 12
legend_fontsize = 10
# prefix = "4x4 3D-Lattice System"

# The experiment output directory is the only (mandatory) CLI argument.
argparser = ArgumentParser()
argparser.add_argument("-d", "--dir", type=str, help="Output directory of the experiment", required=True)
args = vars(argparser.parse_args())

# Results of the experiment; later code iterates over `results` and indexes
# every element by metric name.
results_path = os.path.join(args["dir"], "results.json")
# Context manager guarantees the handle is closed even if parsing fails.
with open(results_path) as results_file:
    results = json.load(results_file)
model_out_dir_path = os.path.join(args["dir"], "models")

# Configuration stored next to the results; config["train"]["n_steps"] is
# read further down to build the x-axis of the loss plot.
config_path = os.path.join(args["dir"], "config.toml")
with open(config_path) as config_file:
    config = toml.load(config_file)

print("Configuration:")
print(config)
print('-'*40)

def get_loss_stats(key):
    """Return element-wise min, mean, median and max of ``key`` across runs.

    Args:
        key: Name of the entry to read from every element of the
            module-level ``results``. Each entry must be a sequence so that
            the entries can be stacked row-wise (presumably one loss value
            per training step -- confirm against the experiment script).

    Returns:
        Tuple ``(min, mean, median, max)`` of NumPy arrays, each obtained by
        reducing the stacked entries over axis 0 (the runs).
    """
    histories = [result[key] for result in results]
    # Diagnostics are printed before stacking so they are still visible when
    # np.vstack fails on entries of unequal length.
    print("lst has length", len(histories))
    print("lst[0] has length", len(histories[0]))
    losses = np.vstack(histories)  # one row per element of `results`
    # Local names deliberately avoid shadowing the builtins min/max.
    loss_min = losses.min(axis=0)
    loss_mean = losses.mean(axis=0)
    loss_median = np.median(losses, axis=0)
    loss_max = losses.max(axis=0)
    return loss_min, loss_mean, loss_median, loss_max

def get_arr(key):
    """Gather ``result[key]`` from every element of ``results`` into a NumPy array.

    The first axis of the returned array follows the order of ``results``.
    """
    per_run_values = []
    for run in results:
        per_run_values.append(run[key])
    return np.asarray(per_run_values)

# SWIM metrics, gathered across all runs (same lookup order as the names).
(
    swim_train_rel2,
    swim_train_mse,
    swim_test_rel2,
    swim_test_mse,
    swim_train_times,
) = map(
    get_arr,
    (
        'swim-train-rel2',
        'swim-train-mse',
        'swim-test-rel2',
        'swim-test-mse',
        'swim-train-time',
    ),
)

# Adam metrics, gathered across all runs.
(
    adam_train_rel2,
    adam_train_mse,
    adam_test_rel2,
    adam_test_mse,
    adam_train_times,
) = map(
    get_arr,
    (
        'adam-train-rel2',
        'adam-train-mse',
        'adam-test-rel2',
        'adam-test-mse',
        'adam-train-time',
    ),
)

# Adam best test loss per run, plus min/mean/median/max of the train and test
# loss histories taken across runs.
adam_best_losses = get_arr('adam-best-test-loss')
(adam_train_min,
 adam_train_mean,
 adam_train_median,
 adam_train_max) = get_loss_stats('adam-train-loss')
(adam_test_min,
 adam_test_mean,
 adam_test_median,
 adam_test_max) = get_loss_stats('adam-test-loss')

# 1-based training step indices, used as the x-axis of the loss plot.
step_count = config["train"]["n_steps"]
x = np.arange(step_count) + 1
print("Training step length:", len(x))

# Accuracy
def _print_accuracy(label, rel2, mse):
    """Print one accuracy entry followed by a blank line.

    Args:
        label: Left-hand column text, already padded to the column width.
        rel2: Array of relative L2 errors; its mean, min and max are shown.
        mse: Scalar MSE value shown next to the mean relative L2 error.
    """
    print(f"{label}:   {rel2.mean():.2e} Relative L2        {mse:.2e} MSE")
    print(f"RELATIVE L2     :   min={rel2.min():.2e}     max={rel2.max():.2e}")
    print()


# NOTE: the MSE shown for Adam is the last entry of the mean loss curve
# (adam_*_mean[-1]), whereas for SWIM it is the mean of the swim-*-mse values.
print("==== Mean Train Accuracy")
_print_accuracy("(SWIM) RF-HGN   ", swim_train_rel2, swim_train_mse.mean())
_print_accuracy("Adam-HGN        ", adam_train_rel2, adam_train_mean[-1])

print("==== Mean Test Accuracy")
_print_accuracy("(SWIM) RF-HGN   ", swim_test_rel2, swim_test_mse.mean())
_print_accuracy("Adam-HGN        ", adam_test_rel2, adam_test_mean[-1])

# Train times
print("==== Train times [s]=[seconds]")
print(f"(SWIM) RF-HGN   :   min={swim_train_times.min():.2e} [s], mean={swim_train_times.mean():.2e} [s], max={swim_train_times.max():.2e} [s]")
print(f"Adam-HGN        :   min={adam_train_times.min():.2e} [s], mean={adam_train_times.mean():.2e} [s], max={adam_train_times.max():.2e} [s]")
print()

# ==== LOSS FIGURE: ADAM CURVES WITH MIN/MAX BAND VS. SWIM REFERENCE LINES

fig, ax1 = plt.subplots(1, 1, figsize=(4, 3), dpi=100)
ax1.set_yscale("log")
ax1.set_title("Loss")
ax1.set_xlabel("Training step", fontsize=labelsize)
ax1.set_ylabel(r"MSE", fontsize=labelsize)

# One entry per data split; train is drawn before test so the legend order
# stays: Adam train, SWIM train, Adam test, SWIM test.
split_specs = (
    ("train", ADAM_HGN_TRAIN, SWIM_RF_HGN_TRAIN,
     adam_train_mean, adam_train_min, adam_train_max, swim_train_mse),
    ("test", ADAM_HGN, SWIM_RF_HGN,
     adam_test_mean, adam_test_min, adam_test_max, swim_test_mse),
)
for split, adam_color, swim_color, mean_curve, lower, upper, swim_mse in split_specs:
    # Adam: mean loss curve with a shaded band between the min and max curves.
    ax1.plot(x, mean_curve, c=adam_color, label=f"(Adam) HGN {split}", alpha=0.8)
    ax1.fill_between(x, lower, upper, color=adam_color, alpha=0.2)
    # SWIM: dashed horizontal line at the mean MSE, dotted lines at min and max.
    ax1.axhline(swim_mse.mean(), c=swim_color, linestyle="dashed", label=f"(SWIM) RF-HGN {split}")
    for bound in (swim_mse.min(), swim_mse.max()):
        ax1.axhline(bound, c=swim_color, linestyle="dotted", alpha=0.4)

ax1.grid(True)
ax1.legend(fontsize=legend_fontsize)
fig.tight_layout()

# The figure is written into the experiment directory, once per format.
plot_path = os.path.join(args["dir"], "adam_comparison")
for extension in (".pdf", ".png"):
    target = plot_path + extension
    fig.savefig(target)
    print(f"-> Saved under {target}")

# Report the value median_index returns for the per-run test MSE of each
# method (presumably the index of the median-performing run -- see
# src.utils.median_index to confirm).
midx = median_index(swim_test_mse)
print("median performance model index for swim is", midx)
midx = median_index(adam_test_mse)
print("median performance index for adam is", midx)
# Exit with status 0 without relying on the `exit` helper that the `site`
# module injects for interactive sessions (unavailable under `python -S`).
raise SystemExit(0)
