import pickle, numpy as np
import matplotlib.pyplot as plt

if __name__ == '__main__':
    with open('results/rbm/results.pkl', 'rb') as handle:
        data = pickle.load(handle)
    log_mmds = data['log_mmds']
    ess_raw = [data['ess']]
    for i in range(1, 7):
        with open(f'results/rbm/results{i}.pkl', 'rb') as handle:
            data = pickle.load(handle)
        ess_raw.append(data['ess'])


    fig, ax = plt.subplots(1, 2, figsize=(16, 5))
    methods = ['bg-1', 'bg-2', 'hb-10-1', 'gwg', 'gwg-3', 'gwg-5', 'mscorrect-3', 'mscorrect-5']
    names = ['Gibbs-1', 'Gibbs-2', 'HB-10-1', 'GWG-1', 'GWG-3', 'GWG-5', 'PAFS-3', 'PAFS-5']
    color = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink',
             'tab:grey', 'tab:olive', 'tab:cyan']

    ess = {}
    for key in methods:
        ess[key] = np.stack([ess_raw[i][key] for i in range(7)], axis=0)
    for i in range(len(methods)):
        ax[0].plot(np.arange(4000) * 10, log_mmds[methods[i]], label=names[i], color=color[i])
        ax[1].bar(i, np.mean(ess[methods[i]]), yerr=np.std(np.mean(ess[methods[i]], axis=1)), color=color[i], label=names[i])
    for j in range(2):
        ax[j].legend(fontsize=12)
        ax[j].tick_params(labelsize=12)
        ax[j].tick_params(labelsize=12)
        ax[j].grid()
    ax[0].set_title(f"RBM log(MMD)", fontsize=18)
    ax[1].set_title(f"RBM ESS", fontsize=18)
    ax[1].set_yscale('log')
    plt.savefig('figs/rbm.pdf')
    # plt.show()