import os, pickle
import matplotlib.pyplot as plt
import numpy as np
from statsmodels.tsa.stattools import acf

class MainPlot:
    """Bar-chart ESS and cost-normalised ESS for a set of samplers on one model.

    Results are read from ``results/{info}/{method}.pkl`` and the figure is
    written to ``figs/{info}.pdf``.
    """

    def __init__(self, info):
        # Sub-path identifying the model/configuration; used in input and output paths.
        self.model_info = info
        # Palette cycled over methods (index wraps when there are more methods than colours).
        self.color = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink',
                      'tab:grey', 'tab:olive', 'tab:cyan']

    def ess(self, auto_cor, n_samples=None):
        """Estimate the effective sample size from an autocorrelation sequence.

        Uses ESS = n / (1 + 2 * sum_{k>=1} rho_k), truncating the sum at the
        first negative autocorrelation.

        Args:
            auto_cor: autocorrelations at lags 0, 1, 2, ... (the layout returned
                by ``statsmodels.tsa.stattools.acf``). The lag-0 entry, which
                is always 1, is skipped: it is the leading ``1`` of the formula.
            n_samples: number of samples in the chain the autocorrelations were
                computed from. Defaults to ``len(auto_cor)`` for backward
                compatibility; pass the real chain length for a correct estimate.

        Returns:
            The ESS estimate.
        """
        if n_samples is None:
            n_samples = len(auto_cor)
        rho = 0.0
        for r in auto_cor[1:]:
            if r < 0:
                break
            rho += r
        return n_samples / (1 + 2 * rho)

    def _method_ess(self, method, N):
        """Return a numpy array with one ESS estimate per run for ``method``."""
        with open(f'results/{self.model_info}/{method}.pkl', 'rb') as handle:
            # NOTE: pickle can execute arbitrary code; only load trusted result files.
            data = pickle.load(handle)
        estimates = []
        for n in range(N):
            # Each run is stored as (logp, trace, elapse, succ); only the trace is used here.
            _logp, trace, _elapse, _succ = data[n]
            T = len(trace)
            # Analyse only the second half of the chain (first half presumably burn-in).
            kept = trace[T // 2:]
            auto_cor = acf(kept, nlags=T // 4, fft=True)
            estimates.append(self.ess(auto_cor, n_samples=len(kept)))
        return np.array(estimates)

    def run(self, methods, names, count, N=5, interval=2000):
        """Plot mean +/- std ESS and normalised ESS per method, save to a PDF.

        Args:
            methods: result-file stems, one per sampler.
            names: legend labels, parallel to ``methods``.
            count: per-method divisor applied to ESS for the "normalized" panel,
                parallel to ``methods`` (presumably a per-step cost — confirm).
            N: number of runs to read from each results pickle.
            interval: unused (belonged to a removed burn-in plot); kept so
                existing callers keep working.

        Writes ``figs/{model_info}.pdf`` and also ensures the directory
        ``figs/{model_info}`` exists.
        """
        os.makedirs(f'figs/{self.model_info}', exist_ok=True)
        fig, ax = plt.subplots(1, 2, figsize=(16, 4))
        try:
            ax[0].set_title("Effective Sample Size", fontsize=16)
            ax[1].set_title("Normalized Effective Sample Size", fontsize=16)
            for i, method in enumerate(methods):
                ess_runs = self._method_ess(method, N)
                ness_runs = ess_runs / count[i]
                color = self.color[i % len(self.color)]
                ax[0].bar(i, np.mean(ess_runs), yerr=np.std(ess_runs), color=color, ecolor='black', label=names[i])
                ax[1].bar(i, np.mean(ness_runs), yerr=np.std(ness_runs), color=color, ecolor='black', label=names[i])
            for a in ax:
                a.legend()
                a.grid()
                a.set_yscale('log')
            fig.savefig(f'figs/{self.model_info}.pdf')
        finally:
            # run() is typically called once per model in a loop; free each figure.
            plt.close(fig)

if __name__ == '__main__':
    # One row per sampler: (result-file stem, legend label, ESS normalisation divisor).
    # Keeping the three attributes together avoids the parallel lists drifting apart.
    sampler_table = [
        ('Gibbs_R-1', 'Gibbs-1', 2),
        ('Gibbs_R-5', 'Gibbs-5', 32),
        ('RW_R-1', 'RW-1', 2),
        ('RW_R-10', 'RW-10', 2),
        ('LB_R-1', 'LB-1', 4),
        ('GWG_R-1', 'GWG-1', 4),
        ('GWG_R-10', 'GWG-10', 4),
        ('MSA_R-10', 'PAS-10', 13),
        ('MSF_R-10', 'PASF-10', 4),
    ]
    methods, names, count = (list(column) for column in zip(*sampler_table))

    # Model configurations to plot; trim this tuple to run a single one.
    model_configs = (
        'ising/p-50_mu-2.0_sigma-3.0_lamda-1.0',
        'ising/p-100_mu-2.0_sigma-3.0_lamda-1.0',
        'ising/p-150_mu-2.0_sigma-3.0_lamda-1.0',
        'ising/p-200_mu-2.0_sigma-3.0_lamda-1.0',
    )
    for config in model_configs:
        MainPlot(info=config).run(methods, names, count, interval=2000)