import os, pickle
import matplotlib.pyplot as plt
from SIP.model import Ising, BMM, FHMM
from SIP.sampler import *

class MainExp(object):
    """Driver for the FHMM sampler-comparison experiment.

    Builds an FHMM model, runs each configured sampler several times, and
    plots the across-run mean (with a +/- one standard deviation band) of the
    energy and error traces against the number of energy-function evaluations.
    """

    def __init__(self, seed=cmd_args.seed):
        """Create the experiment driver.

        Args:
            seed: seed for this object's ``np.random.RandomState``. NOTE: the
                default ``cmd_args.seed`` is evaluated once, when the class is
                defined, not on every call.
        """
        self.model = None
        self.rng = np.random.RandomState(seed)
        # One colour per plotted curve: methods use indices 0..n-1, the
        # ground-truth reference line uses index n.
        self.color = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink',
                      'tab:grey', 'tab:olive', 'tab:cyan']

    def init_fhmm(self, L=1000, K=10, s=0.5, a=0.1, b=0.05, seed=0, device=torch.device('cpu')):
        """Construct the FHMM target and store it on ``self.model``.

        The arguments are forwarded to ``FHMM`` as ``L``, ``K``,
        ``sigma2=s``, ``alpha=a``, ``beta=b``, ``seed`` and ``device``.
        """
        self.model = FHMM(L=L, K=K, sigma2=s, alpha=a, beta=b, seed=seed, device=device)
        self.model_info = f"fhmm"

    def _run(self, method='gibbs_r-1', T=1000, *args, **kwargs):
        """Run one sampler on ``self.model`` for ``T`` steps.

        Args:
            method: string of the form ``"<Name>_R-<int>"``, e.g.
                ``"Gibbs_R-3"``. ``<Name>`` is matched case-insensitively
                against RW, Gibbs, LB, MT, GWG, MSA and PAFS (PAFS maps to
                ``MSFSampler``). The integer after ``R-`` is passed to the
                sampler constructor as ``R``. The full string is forwarded
                unchanged to ``sampler.sample``.
            T: number of sampler steps, forwarded to ``sampler.sample``.
            *args, **kwargs: accepted for call compatibility; ignored.

        Returns:
            Whatever ``sampler.sample(self.model, T=T, method=method)``
            returns. ``run`` unpacks it as ``logp, trace, elapse, succ``.

        Raises:
            NotImplementedError: if ``<Name>`` is not a known sampler.
        """
        config = method.split(sep='_')
        # Case-insensitive, so the lower-case default ('gibbs_r-1') resolves
        # to GibbsSampler instead of falling through to NotImplementedError.
        name = config[0].lower()
        # Each sampler class name is looked up only when its branch is taken.
        if name == 'rw':
            sampler_cls = RWSampler
        elif name == 'gibbs':
            sampler_cls = GibbsSampler
        elif name == 'lb':
            sampler_cls = LBSampler
        elif name == 'mt':
            sampler_cls = MTSampler
        elif name == 'gwg':
            sampler_cls = GWGSampler
        elif name == 'msa':
            sampler_cls = MSASampler
        elif name == 'pafs':
            sampler_cls = MSFSampler
        else:
            raise NotImplementedError(f"Unknown sampler in method {method!r}")
        R = int(config[1].split(sep='-')[1])
        sampler = sampler_cls(R=R)
        return sampler.sample(self.model, T=T, method=method)

    @staticmethod
    def _plot_band(axis, x, runs, label, color):
        """Plot the mean of ``runs`` over axis 0, with a +/- one-std band.

        ``runs`` is expected to hold one row per repeated run.
        """
        mean = np.mean(runs, axis=0)
        std = np.std(runs, axis=0)
        axis.plot(x, mean, label=label, color=color, alpha=0.8)
        axis.fill_between(x, mean - std, mean + std, color=color, alpha=0.3)

    def run(self, l, k, s, a, b, device=None):
        """Compare all configured samplers on one FHMM setting and save a plot.

        Args:
            l, k, s, a, b: FHMM ``L``, ``K``, ``sigma2``, ``alpha`` and
                ``beta`` (see ``init_fhmm``).
            device: torch device for the model. When ``None``, uses
                ``cuda:<cmd_args.device>`` if CUDA is available, else CPU.

        Side effects:
            Replaces ``self.model`` and writes ``figs/fhmm.pdf``, creating
            ``figs/`` if needed. Then closes the figure. The file name does not
            encode the configuration, so repeated calls overwrite it.
        """
        if device is None:
            device = torch.device(f"cuda:{cmd_args.device}" if torch.cuda.is_available() else "cpu")
        # Use self (not a module-level instance) so any MainExp instance works.
        self.init_fhmm(L=l, K=k, s=s, a=a, b=b, seed=cmd_args.seed, device=device)
        fig, ax = plt.subplots(1, 2, figsize=(16, 6))
        ax[0].set_title(f"Energy (Negative Log Joint Density)", fontsize=16)
        ax[1].set_title(f"Log of Reconstruction Error", fontsize=16)
        names = ['Gibbs-1', 'Gibbs-3', 'RW-1', 'RW-3', 'GWG-1', 'GWG-3', 'PAFS-3']
        methods = ['Gibbs_R-1', 'Gibbs_R-3', 'RW_R-1', 'RW_R-3', 'GWG_R-1', 'GWG_R-3', 'PAFS_R-3']
        # Sampler steps per method. T[i] * size[i] == 10000 for every method,
        # so all curves cover the same x-range.
        T = [5000, 1250, 5000, 5000, 2500, 2500, 2500]
        # Converts sampler steps to the x-axis unit ("Energy Functions
        # Evaluations"); presumably evaluations per step -- confirm per sampler.
        size = [2, 8, 2, 2, 4, 4, 4]
        n_repeats = 5
        for i, method in enumerate(methods):
            Error = []
            Energy = []
            for _ in range(n_repeats):
                logp, trace, _elapse, _succ = self._run(method, T=T[i])
                Error.append(trace)
                Energy.append(logp)
            x = np.arange(T[i]) * size[i]
            self._plot_band(ax[0], x, np.stack(Energy, axis=0).squeeze(), names[i], self.color[i])
            self._plot_band(ax[1], x, np.stack(Error, axis=0).squeeze(), names[i], self.color[i])

        # Horizontal reference: energy of self.model.X (presumably the true
        # latent state of the FHMM -- TODO confirm against SIP.model).
        energy = self.model.energy(self.model.X).item()
        n_evals = T[0] * size[0]
        ax[0].plot(np.arange(n_evals), [energy] * n_evals, ls='--', label='GroundTruth',
                   color=self.color[len(methods)], alpha=0.5)

        for axis in ax:
            axis.legend(fontsize=14, loc='upper right')
            axis.set_xlabel("Energy Functions Evaluations", fontsize=16)
            axis.grid()
            axis.tick_params(labelsize=12)
        # savefig does not create missing directories.
        os.makedirs('figs', exist_ok=True)
        fig.savefig('figs/fhmm.pdf')
        # run() is called once per sweep configuration; free each figure.
        plt.close(fig)

if __name__ == '__main__':
    # Script entry point: sweep the FHMM hyper-parameter grid and produce one
    # comparison plot per configuration. `device` and `myExp` stay module-level
    # names because other code in this module may reference them directly.
    use_cuda = torch.cuda.is_available()
    device = torch.device(f"cuda:{cmd_args.device}" if use_cuda else "cpu")
    myExp = MainExp()

    # Hyper-parameter grid. Values explored previously (currently disabled):
    # alpha 0.1/0.15, beta 0.9/0.95, sigma2 0.5/1.
    L = [1000]
    K = [10]
    alpha = [0.05]
    beta = [0.85]
    sigma2 = [0.25]

    # Cartesian product in the order L -> K -> sigma2 -> alpha -> beta
    # (beta varies fastest).
    grid = (
        (l, k, s, a, b)
        for l in L
        for k in K
        for s in sigma2
        for a in alpha
        for b in beta
    )
    for l, k, s, a, b in grid:
        myExp.run(l, k, s, a, b)





