import networkx as nx
import argparse
import random
import os, pickle
import matplotlib.pyplot as plt
from SIP.model import Ising, BMM, FHMM
from SIP.sampler import *

class MainExp(object):
    """Driver for sampler experiments on a single model.

    Call one of the ``init_*`` methods first: it builds ``self.model`` and sets
    ``self.model_info``, which becomes the sub-path under ``results/`` and
    ``figs/``. Then call ``run`` (plot diagnostics of one run per method) or
    ``eva`` (repeat runs, pickle raw outputs and append a summary to disk).
    """

    # Sentinel default for ``seed``: resolves to ``cmd_args.seed`` when the
    # instance is created, rather than requiring ``cmd_args`` to exist at the
    # moment this class is defined.
    _SEED_FROM_CLI = object()

    def __init__(self, seed=_SEED_FROM_CLI):
        """Create an experiment with no model attached yet.

        Args:
            seed: Seed for ``self.rng`` (a ``np.random.RandomState``).
                Defaults to ``cmd_args.seed``.
        """
        if seed is self._SEED_FROM_CLI:
            seed = cmd_args.seed
        self.model = None
        self.rng = np.random.RandomState(seed)
        # One plot colour per method; ``run`` wraps around past the last entry.
        self.color = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink',
                      'tab:grey', 'tab:olive', 'tab:cyan']

    def init_grid_ising(self, p=100, mu=2, sigma=3, lamda=1, seed=0, device=torch.device("cpu")):
        """Attach an ``Ising`` model and record its parameters in ``model_info``."""
        self.model = Ising(p=p, mu=mu, sigma=sigma, lamda=lamda, seed=seed, device=device)
        # NOTE(review): ``seed`` is not part of the path, so different seeds
        # share one results/figs directory.
        self.model_info = f"ising/p-{p}_mu-{mu}_sigma-{sigma}_lamda-{lamda}"

    def init_bmm(self, p=100, m=10, seed=0, device=torch.device('cpu')):
        """Attach a ``BMM`` model and record its parameters in ``model_info``."""
        self.model = BMM(p=p, m=m, seed=seed, device=device)
        self.model_info = f"bmm/p-{p}_m-{m}"

    def init_fhmm(self, L=1000, K=10, seed=0, device=torch.device('cpu')):
        """Attach an ``FHMM`` model; ``model_info`` is the constant ``"fhmm"``."""
        self.model = FHMM(L=L, K=K, seed=seed, device=device)
        # NOTE(review): L and K are not encoded here, so every FHMM
        # configuration writes into the same results/figs directory.
        self.model_info = "fhmm"

    def _run(self, method='gibbs_r-1', T=1000, *args, **kwargs):
        """Run one sampler on ``self.model`` and return its output.

        Args:
            method: Spec ``"<NAME>_R-<int>"``, e.g. ``"GWG_R-3"``. NAME is
                matched case-insensitively against RW, Gibbs, LB, MT, GWG,
                MSA and MSF; the integer after ``-`` is passed as ``R``.
            T: Number of steps forwarded to ``sampler.sample``.
            *args, **kwargs: Accepted for call compatibility; ignored.

        Returns:
            Whatever ``sampler.sample`` returns; the callers in this class
            unpack it as ``(logp, trace, elapse, succ)``.

        Raises:
            NotImplementedError: If NAME is not one of the known samplers.
        """
        config = method.split(sep='_')
        # Lower-cased keys so the default 'gibbs_r-1' resolves to Gibbs
        # (previously it fell through to NotImplementedError).
        samplers = {
            'rw': RWSampler,
            'gibbs': GibbsSampler,
            'lb': LBSampler,
            'mt': MTSampler,
            'gwg': GWGSampler,
            'msa': MSASampler,
            'msf': MSFSampler,
        }
        sampler_cls = samplers.get(config[0].lower())
        if sampler_cls is None:
            raise NotImplementedError(f"Unknown sampler {config[0]!r} in method {method!r}")
        sampler = sampler_cls(R=int(config[1].split(sep='-')[1]))
        return sampler.sample(self.model, T=T, method=method)

    def eva(self, methods, T=1000, device=torch.device("cpu")):
        """Benchmark each method over ``cmd_args.N`` repeated runs.

        Writes under ``results/<model_info>/``:
          * ``data.pkl``: pickled ``self.model`` (only if not already present);
          * ``<method>.pkl``: list of ``[logp, trace, elapse, succ]`` per run;
          * ``res.txt``: one appended summary line per method (mean/std ESS of
            log_p and trace, mean elapsed time, mean ``succ`` divided by ``T``).

        Args:
            methods: Method specs understood by ``_run``.
            T: Number of sampling steps per run.
            device: Unused; kept for interface compatibility.
        """
        out_dir = f'results/{self.model_info}'
        os.makedirs(out_dir, exist_ok=True)
        model_path = f'{out_dir}/data.pkl'
        # An existing data.pkl is never overwritten.
        if not os.path.isfile(model_path):
            with open(model_path, 'wb') as handle:
                pickle.dump(self.model, handle)
        log = ''
        for method in methods:
            ESS_p, ESS_d, TIME, SUCC, res = [], [], [], [], []
            for _ in range(cmd_args.N):
                logp, trace, elapse, succ = self._run(method, T=T)
                # Autocorrelations use only the second half of each chain.
                auto_cor_p = acf(logp[T // 2:], nlags=T // 4, fft=True)
                auto_cor_d = acf(trace[T // 2:], nlags=T // 4, fft=True)
                ESS_p.append(self.ess(auto_cor_p))
                ESS_d.append(self.ess(auto_cor_d))
                TIME.append(elapse)
                SUCC.append(succ)
                res.append([logp, trace, elapse, succ])
            log += f"{method}: \tESS_p: {np.mean(ESS_p):.1f}, {np.std(ESS_p):.1f}, " \
                   f"ESS_d: {np.mean(ESS_d):.1f}, {np.std(ESS_d):.1f} " \
                   f"TIME: {np.mean(TIME):.0f}, succ: {np.mean(SUCC) / T:.4f}\n"
            with open(f'{out_dir}/{method}.pkl', 'wb') as handle:
                pickle.dump(res, handle)
        with open(f'{out_dir}/res.txt', 'a') as handle:
            handle.write(log)

    def ess(self, auto_cor):
        """Return an effective-sample-size score from an autocorrelation sequence.

        Sums ``auto_cor`` from index 0 up to (excluding) the first negative
        value, then returns ``len(auto_cor) / (1 + 2 * sum)``.

        NOTE(review): the textbook estimator excludes lag 0 (which ``acf``
        reports as 1) and uses the chain length as numerator; this formula is
        kept unchanged so reported numbers remain comparable.
        """
        rho = 0
        for value in auto_cor:
            if value < 0:
                break
            rho += value
        return len(auto_cor) / (1 + 2 * rho)

    def run(self, methods, T=1000, burn_in=2000, interval=2000, max_lag=100):
        """Run each method once and show a 3x3 diagnostic figure.

        Row 0 / row 1: autocorrelation (first ``interval`` lags), last
        ``interval`` values, and ESS bars for log_p / trace respectively.
        Row 2: first ``burn_in`` log_p values with a 'GT' line at
        ``self.model.energy(self.model.X)``, first ``burn_in`` trace values,
        and accept-rate bars (``succ / len(trace)``).

        Args:
            methods: Sequence of method specs understood by ``_run``.
            T: Number of sampling steps per run.
            burn_in: Number of leading steps shown in row 2.
            interval: Window length for the ACF and trailing-trace plots.
            max_lag: Upper bound on the number of ACF lags (also capped at
                ``T // 2 - 1``).
        """
        # NOTE(review): this directory is created but the figure is only shown,
        # never saved into it; add ``fig.savefig`` if files are wanted.
        os.makedirs(f'figs/{self.model_info}', exist_ok=True)
        fig, ax = plt.subplots(3, 3, figsize=(24, 18))
        titles = {
            (0, 0): "AutoCorrelation: log_p",
            (0, 1): "Trace Plot: log_p",
            (0, 2): "ESS: log_p",
            (1, 0): "AutoCorrelation: trace",
            (1, 1): "Trace Plot: trace",
            (1, 2): "ESS: trace",
            (2, 0): "Burn in: log_p",
            (2, 1): "Reconstruction Error",
            (2, 2): "Accept Rate",
        }
        for (row, col), title in titles.items():
            ax[row, col].set_title(title, fontsize=16)

        palette = self.color

        def plot_series(axis, series, label, color):
            # x follows the series' real length, so windows longer than the
            # available data no longer raise a shape mismatch.
            axis.plot(np.arange(len(series)), series, label=label, color=color, alpha=0.5)

        n_lags = min(T // 2 - 1, max_lag)
        Succ, Ess_p, Ess_d = [], [], []
        for i, method in enumerate(methods):
            color = palette[i % len(palette)]  # wrap instead of IndexError
            logp, trace, _, succ = self._run(method, T=T)
            auto_cor_p = acf(logp[T // 2:], nlags=n_lags, fft=True)
            auto_cor_d = acf(trace[T // 2:], nlags=n_lags, fft=True)
            ess_p = self.ess(auto_cor_p)
            ess_d = self.ess(auto_cor_d)
            plot_series(ax[0, 0], auto_cor_p[:interval], method, color)
            plot_series(ax[0, 1], logp[-interval:], method, color)
            plot_series(ax[1, 0], auto_cor_d[:interval], method, color)
            plot_series(ax[1, 1], trace[-interval:], method, color)
            plot_series(ax[2, 0], logp[:burn_in], method, color)
            plot_series(ax[2, 1], trace[:burn_in], method, color)
            Succ.append(succ / len(trace))
            Ess_p.append(ess_p)
            Ess_d.append(ess_d)
        energy = self.model.energy(self.model.X).item()
        # Colour after the last method; no longer depends on the leaked loop
        # index, which was undefined when ``methods`` was empty.
        ax[2, 0].plot(np.arange(burn_in), [energy] * burn_in, label='GT',
                      color=palette[len(methods) % len(palette)], alpha=0.5)
        for i, method in enumerate(methods):
            color = palette[i % len(palette)]
            # Labelled bars so the legends below identify each method.
            ax[2, 2].bar(i, Succ[i], color=color, label=method)
            ax[0, 2].bar(i, Ess_p[i], color=color, label=method)
            ax[1, 2].bar(i, Ess_d[i], color=color, label=method)
        for idx in range(3):
            ax[0, idx].legend()
            ax[2, idx].legend()
        fig.suptitle(self.model_info)
        plt.show()


if __name__ == "__main__":
    # Use the CUDA device chosen on the command line when a GPU is available.
    device = torch.device(
        f"cuda:{cmd_args.device}" if torch.cuda.is_available() else "cpu"
    )
    myExp = MainExp()

    # Each builder reads its arguments only when invoked, so only the
    # selected model's command-line options are touched.
    model_builders = {
        'ising': lambda: myExp.init_grid_ising(
            p=cmd_args.p, mu=cmd_args.mu, sigma=cmd_args.sigma,
            lamda=cmd_args.lamda, seed=cmd_args.seed, device=device),
        'bmm': lambda: myExp.init_bmm(
            p=cmd_args.p, m=cmd_args.m, seed=cmd_args.seed, device=device),
        'fhmm': lambda: myExp.init_fhmm(
            L=cmd_args.L, K=cmd_args.K, seed=cmd_args.seed, device=device),
    }
    build_model = model_builders.get(cmd_args.model)
    if build_model is None:
        raise NotImplementedError
    build_model()

    methods = [
        'GWG_R-1', 'GWG_R-3', 'MSF_R-3',
        # 'Gibbs_R-1', 'Gibbs_R-3', 'RW_R-1', 'RW_R-5'
    ]
    myExp.run(methods, T=10000, burn_in=5000, interval=1000, max_lag=2000)
    # myExp.eva([cmd_args.method], T=cmd_args.T, device=device)


