import optuna
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import sys

# Report on one or more Optuna studies.
#
# Usage: python <this script> STUDY_NAME [STUDY_NAME ...]
# The storage URL is read from the SQLURI environment variable.
#
# For every study named on the command line this script:
#   1. shows the plotly parallel-coordinate plot of the study,
#   2. draws the running best value and each trial's value against the
#      trial index on a fresh matplotlib figure,
#   3. prints the best value, the top rows sorted by value, and the whole
#      sorted trials dataframe.
# All matplotlib figures are displayed together once every study is processed.
#
# NOTE(review): "best" is taken to be the maximum value throughout
# (max / cummax / descending sort) -- confirm the studies are maximised.

# Number of top-scoring trials printed for each study.
N_TOP_TRIALS = 5

# Copy-pasted in the same order as learn_policy_optuna.
# Currently unused here; kept as a reference list of the hyper-parameter names.
keys = {
    'lr_decay': None,
    'pred_weight': None,
    'entropy_beta': None,
    'entropy_beta_decay': None,
    'dlm_noise': None,
    'curriculum_thresh': None,
    'stoch_decay': None,
    'gumbel_noise_begin': None,
    'dropout_prob_begin': None,
    'tau_begin': None,
    'last_tau_begin': None,
}

for sn in sys.argv[1:]:
    study = optuna.load_study(study_name=sn, storage=os.getenv("SQLURI"))

    # Interactive plotly figure; shown immediately.
    parallel_fig = optuna.visualization.plot_parallel_coordinate(study)
    parallel_fig.show()

    df = study.trials_dataframe()

    # Everything below reads the 'value' column; without it (e.g. a study
    # that has no trials yet) the original code crashed with AttributeError.
    if 'value' not in df.columns:
        print(f"study {sn!r}: no 'value' column in trials dataframe, skipping",
              file=sys.stderr)
        continue

    # New matplotlib figure so each study gets its own axes.
    plt.figure()

    print('best val:', round(df.value.max(), 4))
    ax = sns.lineplot(x=df.index, y=df.value.cummax())
    ax.set_xlabel('trial number')
    sns.scatterplot(x=df.index, y=df.value, color='red')
    ax.set_ylabel('score')
    ax.legend(['best value', "trial's value"])

    df = df.sort_values('value', ascending=False)

    # Slice instead of indexing rows 0..4 one by one, so a study with fewer
    # than N_TOP_TRIALS trials prints what it has instead of raising
    # IndexError.
    for i, row in enumerate(np.array(df)[:N_TOP_TRIALS]):
        print(i, row)

    print()
    print()

    print(df)

# Blocks until the matplotlib windows are closed (in interactive backends).
plt.show()
