## Utility functions, e.g. for computing the predictive log-likelihood (forward algorithm), sampling transition matrices and plotting results.
from gibbs_approx_parallel_efox import sample_lik_params
import numpy as np
from numpy.random import choice, normal, dirichlet, beta, gamma, multinomial, exponential, binomial
import scipy.stats as ss
import scipy.special as ssp
from multiprocessing import Pool
from functools import partial
from scipy.special import comb
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal, multivariate_t
import os, sys
home_dir = os.path.expanduser('~')
utils_path = os.path.join(home_dir, 'hierarchical-ts-modelling')
sys.path.append(utils_path)
from utils import evaluate_dist, plot_state_predictions_test_samples

def nbinom(k, n, p): # replaced all instances of ss.nbinom.pmf() with nbinom()
    '''
    Negative-binomial pmf, a drop-in replacement for ss.nbinom.pmf(k, n, p):

        Gamma(k + n) / (Gamma(n) * k!) * p**n * (1 - p)**k

    i.e. the probability of k failures before the n-th success with success
    probability p. It is evaluated in log space with gammaln so that
      * real-valued shapes 0 < n < 1 work (scipy.special.comb(k+n-1, n-1)
        returns 0 whenever n - 1 < 0, which made the pmf 0), and
      * large counts do not overflow the binomial coefficient.
    All arguments broadcast against each other.

    input:
        k = number of failures (counts), scalar or array
        n = shape / number of successes (> 0), scalar or array
        p = success probability in [0, 1], scalar or array
    output:
        pmf value(s) with the broadcast shape of the inputs
    '''
    k = np.asarray(k, dtype=float)
    n = np.asarray(n, dtype=float)
    p = np.asarray(p, dtype=float)
    # xlogy / xlog1py return 0 for a zero multiplier, so p == 1 with k == 0 is handled.
    log_pmf = (ssp.gammaln(k + n) - ssp.gammaln(n) - ssp.gammaln(k + 1)
               + ssp.xlogy(n, p) + ssp.xlog1py(k, -p))
    return np.exp(log_pmf)

def multivariate_t_distribution(X,mu,Sigma,df):
    '''
    Multivariate Student-t density evaluated at a single point:

        p(X) = Gamma((df+d)/2) / (Gamma(df/2) * (df*pi)^(d/2) * |Sigma|^(1/2))
               * [1 + (X-mu)^T Sigma^{-1} (X-mu) / df]^(-(df+d)/2)

    input:
        X = point to evaluate (d dimensional numpy array or scalar)
        mu = location (d dimensional numpy array or scalar)
        Sigma = scale matrix (dxd numpy array, or scalar when d == 1)
        df = degrees of freedom (> 0)
    output:
        the density at X
    '''
    # Promote scalars so the documented scalar case works (X.shape[0] needs rank >= 1).
    X = np.atleast_1d(np.asarray(X, dtype=float))
    mu = np.atleast_1d(np.asarray(mu, dtype=float))
    Sigma = np.atleast_2d(np.asarray(Sigma, dtype=float))
    d = X.shape[0]  # num features
    Xm = X - mu

    # Mahalanobis term via a linear solve rather than an explicit inverse (better conditioned).
    maha = np.dot(Xm, np.linalg.solve(Sigma, Xm))
    _, logdet_sigma = np.linalg.slogdet(Sigma)

    # log of the normalising constant: log Gamma(df/2) + 0.5 log|df*pi*Sigma| - log Gamma((df+d)/2)
    logz = (ssp.loggamma(df / 2) + 0.5 * logdet_sigma + 0.5 * d * np.log(df * np.pi)
            - ssp.loggamma((df + d) / 2))
    logp = -0.5 * (df + d) * np.log1p(maha / df) - logz
    return np.exp(logp)

def multivariate_gaussian(X, mu, Sigma, df=None):
    '''
    Multivariate normal density N(X; mu, Sigma) evaluated at a single point.

    input:
        X = point to evaluate (d dimensional numpy array)
        mu = mean (d dimensional numpy array)
        Sigma = covariance matrix (dxd numpy array)
        df = ignored; presumably kept so the signature mirrors
             multivariate_t_distribution(X, mu, Sigma, df)
    output:
        the density at X
    '''
    Xm = X - mu
    _, logdet = np.linalg.slogdet(2 * np.pi * Sigma)
    # Solve instead of inverting; a continuous density may legitimately exceed 1,
    # so there is no "p > 1" sanity check (the old one printed false alarms).
    logp = -0.5 * logdet - 0.5 * np.dot(Xm, np.linalg.solve(Sigma, Xm))
    return np.exp(logp)






def compute_beta_param(mean_val, var_val):
    '''
    Method-of-moments Beta(alpha, beta) parameters for a target mean and variance.

    The concentration alpha + beta = mean*(1-mean)/var - 1 is split in the ratio
    mean : (1 - mean). The result is only a valid Beta when
    0 < var < mean*(1-mean); otherwise the parameters are not positive.

    output:
        (alpha_val, beta_val)
    '''
    concentration = (mean_val * (1 - mean_val) / var_val) - 1
    return mean_val * concentration, (1 - mean_val) * concentration

def compute_real_transition_mat(zt_real, wt_real):
    '''
    Count observed state-to-state transitions, skipping flagged time points.

    Entry [a, b] counts the time points t >= 1 with zt_real[t-1] == a,
    zt_real[t] == b and wt_real[t] == 0. The matrix is square with
    max(zt_real) + 1 rows.
    '''
    n_states = max(zt_real) + 1
    counts = np.zeros((n_states, n_states))
    labels = np.asarray(zt_real)
    keep = np.asarray(wt_real)[1:len(labels)] == 0
    np.add.at(counts, (labels[:-1][keep], labels[1:][keep]), 1)
    return counts

def compute_confusion_mat(zt, zt_real):
    '''
    Count co-occurrences of true and predicted state labels.

    Entry [r, c] is the number of time points t < len(zt) with zt_real[t] == r
    and zt[t] == c. Only the first len(zt) entries of zt_real are used.

    Rows are sized by the largest true label (max(zt_real) + 1) rather than by
    the number of distinct true labels: the latter indexes past the end when
    labels are non-contiguous (e.g. {0, 2}). For contiguous labels 0..n-1 the
    shape is unchanged.

    output:
        conf_mat of shape (max(zt_real) + 1, max(zt) + 1)
    '''
    pred = np.asarray(zt)
    true = np.asarray(zt_real)
    conf_mat = np.zeros((true.max() + 1, pred.max() + 1))
    np.add.at(conf_mat, (true[:len(pred)], pred), 1)
    return conf_mat

def estimate_y(zt, yt):
    '''
    Piecewise-constant fit: every observation is replaced by the mean of all
    observations assigned to the same state.

    input:
        zt = numpy array of state labels, one per row of yt
        yt = numpy array of observations (the mean is taken over all entries
             of the rows belonging to a state)
    output:
        array shaped like yt holding the per-state means
    '''
    state_means = {state: yt[zt == state].mean() for state in np.unique(zt)}
    fitted = np.zeros(yt.shape)
    for state, mean in state_means.items():
        fitted[zt == state] = mean
    return fitted

## sample pi|z, w, kappa, alpha, beta
def sample_pi_our(K, alpha0, beta_vec, beta_new, n_mat, kappa_vec, kappa_new):
    '''
    Sample a (K+1)x(K+1) transition matrix with per-state self-transition mixing.

    Base row j < K ~ Dirichlet(alpha0*[beta_vec, beta_new] + [n_mat[j], 0]);
    the last row (the not-yet-instantiated state) ~ Dirichlet(alpha0*[beta_vec, beta_new]).
    Dirichlet weights below 0.01 are raised to 0.01 before sampling. Rows are
    drawn in order 0..K. With kappa = [kappa_vec, kappa_new], returned row j is
    (1 - kappa[j]) * base_row_j + kappa[j] * e_j.

    output:
        prob_mat = (K+1)x(K+1) row-stochastic transition matrix
    '''
    base = alpha0 * beta_vec
    tail = alpha0 * beta_new
    weight_rows = [np.hstack((base + n_mat[j], tail)) for j in range(K)]
    weight_rows.append(np.hstack((base, tail)))

    pi_mat = np.zeros((K + 1, K + 1))
    for row, weights in enumerate(weight_rows):
        weights[weights < 0.01] = 0.01  ## clip step
        pi_mat[row] = dirichlet(weights, size=1)[0]

    kappa_all = np.hstack((kappa_vec, kappa_new))
    return pi_mat * (1 - kappa_all)[:, None] + np.diag(kappa_all)

## sample pi|z, alpha, beta, rho0
def sample_pi_efox(K, alpha0, beta_vec, beta_new, n_mat, rho0):
    '''
    Sample a (K+1)x(K+1) transition matrix with a sticky self-transition bonus.

    Row j ~ Dirichlet(alpha0*[beta_vec, beta_new] + [n_mat[j], 0] + rho0*e_j),
    where the count term is absent for the last row (the not-yet-instantiated
    state). Weights below 0.01 are raised to 0.01 before sampling. Rows are
    drawn in order 0..K.

    output:
        pi_mat = (K+1)x(K+1) row-stochastic transition matrix
    '''
    pi_mat = np.zeros((K + 1, K + 1))
    for j in range(K + 1):
        counts = n_mat[j] if j < K else 0
        weights = np.hstack((alpha0 * beta_vec + counts, alpha0 * beta_new))
        weights[j] += rho0  # sticky bonus on the diagonal
        weights[weights < 0.01] = 0.01  ## clip step
        pi_mat[j] = dirichlet(weights, size=1)[0]
    return pi_mat

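## compute log marginal likelihood p(y|pi,alpha,beta,sigma,z,w,kappa,yold)
# via the forward pass of the forward-backward algorithm, with forward messages
#   a_t(k)     = p(y1, ..., yt, zt=k)
#   a_{t+1}(j) = sum_k(a_t(k) * pi(k,j)) * p(yt+1 | zt+1=j)
# (the code normalises a_t at every step and sums the log normalisers)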

def compute_log_marginal_lik_gaussian(K, yt, zt, prob_mat, mu0, sigma0, sigma0_pri, ysum, ycnt):
    '''
    Scaled forward pass for a 1-D Gaussian-emission HMM with conjugate state means.

    Emissions are y ~ N(mu_k, sigma0^2) with prior mu_k ~ N(mu0, sigma0_pri^2).
    States 0..K-1 use the posterior predictive given ysum[k] and ycnt[k]. State K
    (not yet instantiated) uses the prior predictive N(mu0, sigma0^2 + sigma0_pri^2).

    input:
        zt = state preceding yt[0], or -1 when yt starts a brand-new sequence
             (the first observation is then forced into state 0)
        prob_mat = (K+1)x(K+1) transition matrix
    output:
        a_mat = (T+1)x(K+1) forward messages, each row t+1 normalised to sum to 1
        log_marginal_lik = sum of the log per-step normalisers
    '''
    n_steps = len(yt)
    a_mat = np.zeros((n_steps + 1, K + 1))
    c_vec = np.zeros(n_steps)
    if zt != -1:
        a_mat[0, zt] = 1

    # Posterior of each instantiated state's mean ...
    post_var = 1 / (1 / (sigma0_pri ** 2) + ycnt / (sigma0 ** 2))
    post_mean = ((mu0 / (sigma0_pri ** 2)) + (ysum / (sigma0 ** 2))) * post_var
    # ... and the predictive mean / standard deviation, with the prior appended for state K.
    pred_std = np.hstack((np.sqrt((sigma0 ** 2) + post_var), np.sqrt((sigma0 ** 2) + (sigma0_pri ** 2))))
    pred_mean = np.hstack((post_mean, mu0))

    for t, y in enumerate(yt):
        row = a_mat[t + 1]  # view: writes go straight into a_mat
        if t == 0 and zt == -1:
            row[0] = ss.norm.pdf(y, pred_mean[0], pred_std[0])
        else:
            for j in range(K + 1):
                row[j] = sum(a_mat[t, :] * prob_mat[:, j]) * ss.norm.pdf(y, pred_mean[j], pred_std[j])
        c_vec[t] = sum(row)
        row /= c_vec[t]

    return a_mat, sum(np.log(c_vec))

## compute log marginal likelihood p(y|pi,alpha,beta,sigma,z,w,kappa,yold)
def compute_log_marginal_lik_multinomial(K, yt, zt, prob_mat, dir0, ysum): 
    '''
    Scaled forward pass for an HMM with Dirichlet-multinomial emissions.

    States 0..K-1 emit count vectors with the Dirichlet-multinomial posterior
    predictive under Dirichlet(dir0 + ysum[k]). State K (not yet instantiated)
    uses the prior Dirichlet(dir0). For counts x with total n and Dirichlet
    parameter a (with sum A) the pmf is

        n! / prod(x_i!) * Gamma(A) / Gamma(A + n) * prod(Gamma(a_i + x_i) / Gamma(a_i))

    input:
        yt = T x m array of counts; each row uses its own total n
        zt = state preceding yt[0], or -1 when yt starts a brand-new sequence
             (the first observation is then forced into state 0)
        prob_mat = (K+1)x(K+1) transition matrix
        dir0 = length-m Dirichlet prior; ysum = K x m per-state count sums
    output:
        a_mat = (T+1)x(K+1) forward messages, each row t+1 normalised to sum to 1
        log_marginal_lik = sum of the log per-step normalisers
    '''
    T = len(yt)
    a_mat = np.zeros((T + 1, K + 1))
    c_vec = np.zeros(T)
    if zt != -1:
        a_mat[0, zt] = 1

    # Dirichlet parameters: posterior for states 0..K-1, prior for the new state K.
    dir_params = np.vstack((dir0 + ysum, dir0))  # (K+1, m)
    dir_totals = dir_params.sum(axis=1)
    log_norm = ssp.gammaln(dir_totals) - np.sum(ssp.gammaln(dir_params), axis=1)

    for t in range(T):
        counts = np.asarray(yt[t], dtype=float)
        n_t = counts.sum()
        # The multinomial coefficient needs log n! = gammaln(n + 1), not gammaln(n).
        log_emission = (log_norm - ssp.gammaln(dir_totals + n_t)
                        + np.sum(ssp.gammaln(dir_params + counts), axis=1)
                        + ssp.gammaln(n_t + 1) - np.sum(ssp.gammaln(counts + 1)))
        emission = np.exp(log_emission)
        if t == 0 and zt == -1:
            a_mat[1, 0] = emission[0]
        else:
            a_mat[t + 1] = np.matmul(a_mat[t], prob_mat) * emission
        c_vec[t] = a_mat[t + 1].sum()
        a_mat[t + 1] /= c_vec[t]

    log_marginal_lik = np.sum(np.log(c_vec))
    return a_mat, log_marginal_lik

## compute log marginal likelihood p(y|pi,alpha,beta,sigma,z,w,kappa,yold)
def compute_log_marginal_lik_poisson(K, yt, zt, prob_mat, lam_a_pri, lam_b_pri, ysum, ycnt):
    '''
    Scaled forward pass for an HMM with independent Poisson emissions per feature.

    Each feature of state k has a Gamma posterior with parameters
    (lam_a_pri + ysum[k], lam_b_pri + ycnt[k]). Its predictive is the negative
    binomial nbinom(y, a, b / (b + 1)), i.e. a (shape, rate) parameterisation.
    State K (not yet instantiated) uses the prior (lam_a_pri, lam_b_pri).

    input:
        yt = T x m array of counts
        zt = state preceding yt[0], or -1 when yt starts a brand-new sequence
             (the first observation is then forced into state 0)
        prob_mat = (K+1)x(K+1) transition matrix
    output:
        a_mat = (T+1)x(K+1) forward messages, each row t+1 normalised to sum to 1
        log_marginal_lik = sum of the log per-step normalisers divided by T
                           (average log-likelihood per time step)
    '''
    T = len(yt)
    a_mat = np.zeros((T + 1, K + 1))  # row t+1 is the normalised p(z_t = j | y_0..y_t)
    c_vec = np.zeros(T)
    if zt != -1:
        a_mat[0, zt] = 1

    ## compute lambda posterior
    lam_a_post = lam_a_pri + ysum
    lam_b_post = lam_b_pri + ycnt

    for t in range(T):
        # Emission probabilities of yt[t] under every instantiated state plus the prior
        # (new state). Computed once per time step instead of once per (t, j) pair, and
        # also used for the first step so K == 0 falls back to the prior.
        log_emit_old = np.sum(np.log(nbinom(yt[t], lam_a_post, lam_b_post / (lam_b_post + 1))), axis=1)
        log_emit_new = np.sum(np.log(nbinom(yt[t], lam_a_pri, lam_b_pri / (lam_b_pri + 1))))
        emission = np.exp(np.hstack((log_emit_old, log_emit_new)))
        if t == 0 and zt == -1:
            a_mat[1, 0] = emission[0]
        else:
            a_mat[t + 1] = np.matmul(a_mat[t], prob_mat) * emission
        c_vec[t] = sum(a_mat[t + 1, :])
        a_mat[t + 1, :] /= c_vec[t]

    # Deliberately divided by T: the average log-likelihood over time steps.
    log_marginal_lik = sum(np.log(c_vec)) / T
    return a_mat, log_marginal_lik

def compute_log_marginal_lik_ar(K,yt,zt,prob_mat,M0,V0,S0,n0,s_ybar_ybar_inv,s_y_y_plus_s0,s_y_ybar,s_y_cond_ybar_plus_s0,dff):
    '''
    Scaled forward pass for an HMM with vector-autoregressive (AR(1)) emissions.

    y_0 is only conditioned on. For t >= 1 the emission of y_t given y_{t-1} is a
    multivariate t:
      * state k < K: location s_y_ybar[k] @ s_ybar_ybar_inv[k] @ y_{t-1},
        scale s_y_cond_ybar_plus_s0[k] * (1 + y_{t-1}' s_ybar_ybar_inv[k] y_{t-1}) / dff[k],
        dof dff[k];
      * new state K: location M0 @ y_{t-1}, scale S0 * (1 + y_{t-1}' V0 y_{t-1}) / (n0+1-m),
        dof n0+1-m, where m is the number of features.
    s_y_y_plus_s0 is accepted but not used.

    input:
        zt = initial state, or -1 meaning "start in state 0"
        prob_mat = (K+1)x(K+1) transition matrix
    output:
        a_mat = T x (K+1) forward messages (rows normalised)
        log_marginal_lik = sum of the log normalisers for t = 1..T-1
    '''
    n_steps = len(yt)
    n_feat = yt.shape[1]
    a_mat = np.zeros((n_steps, K + 1))
    c_vec = np.zeros(n_steps - 1)
    a_mat[0, zt if zt != -1 else 0] = 1
    new_dof = n0 + 1 - n_feat

    for t in range(1, n_steps):
        prev = yt[t - 1]
        proj = np.matmul(s_ybar_ybar_inv, prev)  # K x m
        nun = 1 / np.matmul(proj, prev)          # length K
        emission = [
            multivariate_t_distribution(
                yt[t],
                np.matmul(s_y_ybar[k], proj[k]),
                s_y_cond_ybar_plus_s0[k] * (1 + (1 / nun[k])) / dff[k],
                dff[k],
            )
            for k in range(K)
        ]
        nu_new = 1 / np.sum(np.matmul(V0, prev) * prev)
        emission.append(multivariate_t_distribution(
            yt[t], np.matmul(M0, prev), S0 * (1 + (1 / nu_new)) / new_dof, new_dof))
        emission = np.array(emission)

        for j in range(K + 1):
            a_mat[t, j] = sum(a_mat[t - 1, :] * prob_mat[:, j]) * emission[j]
        c_vec[t - 1] = sum(a_mat[t, :])
        a_mat[t, :] /= c_vec[t - 1]

    return a_mat, sum(np.log(c_vec))

def compute_log_marginal_lik_ar_all(K,yt_ls,prob_mat,M0,V0,S0,n0,s_ybar_ybar_inv,s_y_y_plus_s0,s_y_ybar,s_y_cond_ybar_plus_s0,dff):
    '''
    Log marginal likelihood of every sequence in yt_ls.

    Each sequence is scored with compute_log_marginal_lik_ar as a brand-new
    sequence (zt = -1), using the shared model parameters.

    output:
        numpy array with one log-likelihood per sequence
    '''
    shared = (prob_mat, M0, V0, S0, n0, s_ybar_ybar_inv, s_y_y_plus_s0,
              s_y_ybar, s_y_cond_ybar_plus_s0, dff)
    return np.array([compute_log_marginal_lik_ar(K, yt, -1, *shared)[1] for yt in yt_ls])

def compute_log_marginal_lik_poisson_approx(L, yt, zt, prob_mat, suff_stats):
    '''
    Scaled forward pass for a fixed number L of states with Poisson emissions.

    Per-feature emissions use nbinom(y, a, b / (b + 1)) with
    a = suff_stats['ysum'] and b = suff_stats['ycnt'] (L x m arrays are assumed,
    since the log-probabilities are summed over axis 1).

    input:
        yt = T x m array of counts
        zt = state preceding yt[0], or -1 when yt starts a brand-new sequence
             (the first observation is then forced into state 0)
        prob_mat = L x L transition matrix
    output:
        a_mat = (T+1) x L forward messages, each row t+1 normalised to sum to 1
        log_marginal_lik = sum of the log per-step normalisers
    '''
    T = len(yt)
    a_mat = np.zeros((T + 1, L))
    c_vec = np.zeros(T)
    if zt != -1:
        a_mat[0, zt] = 1

    lam_a_post = suff_stats['ysum']
    lam_b_post = suff_stats['ycnt']
    success_prob = lam_b_post / (lam_b_post + 1)  # loop-invariant

    for t in range(T):
        # Emission probability of yt[t] under each state; computed once per step
        # (it was previously recomputed for every destination state j).
        emission = np.exp(np.sum(np.log(nbinom(yt[t], lam_a_post, success_prob)), axis=1))
        if t == 0 and zt == -1:
            a_mat[1, 0] = emission[0]
        else:
            a_mat[t + 1] = np.matmul(a_mat[t], prob_mat) * emission
        c_vec[t] = sum(a_mat[t + 1, :])
        a_mat[t + 1, :] /= c_vec[t]

    log_marginal_lik = sum(np.log(c_vec))
    return a_mat, log_marginal_lik

def compute_log_marginal_lik_ar_approx(L, yt, zt, prob_mat, suff_stats):
    '''
    Scaled forward pass for L states with AR(1) multivariate-t emissions.

    y_0 is only conditioned on. For t >= 1 the emission of y_t under state s is a
    multivariate t with location s_y_ybar[s] @ s_ybar_ybar_inv[s] @ y_{t-1},
    scale s_y_cond_ybar_plus_s0[s] * (1 + y_{t-1}' s_ybar_ybar_inv[s] y_{t-1}) / dof[s],
    and dof = suff_stats['dff'] + 1 - m (m = number of features). All arrays are
    read from suff_stats.

    input:
        zt = initial state, or -1 meaning "start in state 0"
        prob_mat = L x L transition matrix
    output:
        a_mat = T x L forward messages (rows normalised)
        log_marginal_lik = sum of the log normalisers for t = 1..T-1
    '''
    n_steps = len(yt)
    n_feat = yt.shape[1]
    a_mat = np.zeros((n_steps, L))
    c_vec = np.zeros(n_steps - 1)
    a_mat[0, zt if zt != -1 else 0] = 1

    dof = suff_stats['dff'] + 1 - n_feat
    scale_base = suff_stats['s_y_cond_ybar_plus_s0']

    for t in range(1, n_steps):
        prev = yt[t - 1]
        proj = np.matmul(suff_stats['s_ybar_ybar_inv'], prev)  # L x m
        nun = 1 / np.matmul(proj, prev)                        # length L
        emission = np.array([
            multivariate_t_distribution(
                yt[t],
                np.matmul(suff_stats['s_y_ybar'][s], proj[s]),
                scale_base[s] * (1 + (1 / nun[s])) / dof[s],
                dof[s],
            )
            for s in range(L)
        ])
        for j in range(L):
            a_mat[t, j] = sum(a_mat[t - 1, :] * prob_mat[:, j]) * emission[j]
        c_vec[t - 1] = sum(a_mat[t, :])
        a_mat[t, :] /= c_vec[t - 1]

    return a_mat, sum(np.log(c_vec))

def compute_log_marginal_lik_ar_approx_parallel(L,verbose, prob_mat, pi_init, suff_stats, yt_ls, old_way=True):
    '''
    Forward-algorithm log marginal likelihood log p(y_0, ..., y_{T-1}) of every sequence.

    input:
        L = number of states
        verbose = print progress / debug output when True
        prob_mat = L x L transition matrix (row i is p(z_t | z_{t-1} = i))
        pi_init = length-L initial state distribution
        suff_stats = dict with 's_ybar_ybar_inv', 's_y_ybar', 'dff' and
                     's_y_cond_ybar_plus_s0' (per-state arrays)
        yt_ls = iterable of T x n_feat arrays
        old_way = if True, emissions are the multivariate-t posterior predictive
                  built from suff_stats; otherwise (A, Sigma) are drawn once with
                  sample_lik_params and emissions are N(A y_{t-1}, Sigma)
    output:
        numpy array with one log marginal likelihood per sequence

    The first observation of each sequence is scored with a zero vector as its
    "previous" observation in both modes. Indexing yt[t-1] at t == 0 would wrap
    around to the LAST observation of the sequence.
    '''
    if verbose:
        print('Entered compute log marginal lik ar apx parallel')
    # The AR parameters are only needed by the Gaussian path; do not spend random
    # draws (or time) sampling them when they are unused.
    lik_params = None if old_way else sample_lik_params(suff_stats, mode='ar')

    log_marginal_lik_ls = []
    Ts = []
    for yt in yt_ls:
        T = len(yt)
        Ts.append(T)
        m_multi = yt.shape[1]
        dff = suff_stats['dff'] + 1 - m_multi  # per-state degrees of freedom
        a_mat = np.zeros((T, L))  # row t: p(z_t | y_0..y_t) after normalisation
        c_vec = np.zeros(T)       # c_vec[t] = p(y_t | y_0..y_{t-1})
        for t in range(T):
            y_prev = yt[t - 1] if t > 0 else np.zeros_like(yt[t])
            if old_way:
                tmp = np.matmul(suff_stats['s_ybar_ybar_inv'], y_prev)  # L x n_feat
                quad = np.matmul(tmp, y_prev)  # length L; the reciprocal of "nun"
                mun = [np.matmul(suff_stats['s_y_ybar'][ik], tmp[ik]) for ik in range(L)]
                yt_dist = np.array([
                    multivariate_t_distribution(
                        yt[t],
                        mun[ik],
                        suff_stats['s_y_cond_ybar_plus_s0'][ik] * (1 + quad[ik]) / dff[ik],
                        dff[ik],
                    )
                    for ik in range(L)
                ])
            else:
                A, Sigma = lik_params['a_mat_post'], lik_params['sigma_mat_post']
                yt_dist = np.array([
                    multivariate_normal(mean=np.matmul(A[ik], y_prev), cov=Sigma[ik]).pdf(yt[t])
                    for ik in range(L)
                ])

            if t == 0:
                # No previous state: p(z_0) * p(y_0 | z_0)
                a_mat[t] = pi_init * yt_dist
            else:
                # sum_k p(z_{t-1}=k | y_0..y_{t-1}) * p(z_t | z_{t-1}=k), times p(y_t | z_t)
                a_mat[t] = np.matmul(a_mat[t - 1], prob_mat) * yt_dist
            c_vec[t] = a_mat[t].sum()
            a_mat[t] /= c_vec[t]  # normalise for numerical stability

        log_marginal_lik_ls.append(np.sum(np.log(c_vec)))

    log_marginal_lik_ls = np.array(log_marginal_lik_ls)
    if verbose:
        print('=-=-=-')
        print('log_marginal_lik_ls: ', log_marginal_lik_ls)
        print('Ts: ', Ts)
    return log_marginal_lik_ls

def compute_log_marginal_lik_ar_fmp(L,prob_mat, pi_init, suff_stats,yt_all_ls,n_cores, verbose, return_mean=True, old_way=True):
    '''
    Run compute_log_marginal_lik_ar_approx_parallel over all sequences with a process pool.

    input:
        yt_all_ls = list (or array) of T_i x n_feat sequences; lengths may differ
        n_cores = number of worker processes / contiguous chunks
        return_mean = return the mean instead of the per-sequence values
        (remaining arguments are forwarded unchanged)
    output:
        mean log marginal likelihood over sequences, or the per-sequence array
    '''
    # Split by index: np.array_split on the sequences themselves stacks them into
    # one array, which fails for sequences of different lengths.
    chunks = [[np.asarray(yt_all_ls[i]) for i in idx]
              for idx in np.array_split(np.arange(len(yt_all_ls)), n_cores)]
    func = partial(compute_log_marginal_lik_ar_approx_parallel, L, verbose, prob_mat, pi_init, suff_stats, old_way=old_way)
    print('len yt_all_ls: ', len(yt_all_ls))
    print('n_cores: ', n_cores)
    # The context manager shuts the workers down even if one of them raises.
    with Pool(processes=n_cores) as pool:
        results = pool.map(func, chunks)
    log_lik = np.concatenate(results, axis=0)  # of shape (n_samples,)
    if return_mean:
        return log_lik.mean()  # avg log likelihood of all observations in all samples
    return log_lik

def compute_log_marginal_lik_poisson_approx_parallel(L, prob_mat, pi_init, suff_stats, yt_ls):
    '''
    Forward-algorithm log marginal likelihood log p(y_0, ..., y_{T-1}) of every
    sequence under Poisson emissions.

    Per-feature emissions use nbinom(y, a, b / (b + 1)) with a = suff_stats['ysum']
    and b = suff_stats['ycnt'] (L x m arrays are assumed, since the
    log-probabilities are summed over axis 1).

    input:
        prob_mat = L x L transition matrix
        pi_init = length-L initial state distribution
        yt_ls = iterable of T x m count arrays
    output:
        numpy array with one log marginal likelihood per sequence
    '''
    lam_a_post = suff_stats['ysum']
    lam_b_post = suff_stats['ycnt']
    success_prob = lam_b_post / (lam_b_post + 1)  # loop-invariant

    log_marginal_lik_ls = []
    for yt in yt_ls:
        T = len(yt)
        a_mat = np.zeros((T, L))
        # One normaliser per time step, indexed by t. (A T-1 sized buffer indexed
        # with t-1 lost the t=0 normaliser and crashed for T == 1.)
        c_vec = np.zeros(T)
        for t in range(T):
            yt_dist = np.exp(np.sum(np.log(nbinom(yt[t], lam_a_post, success_prob)), axis=1))
            if t == 0:
                a_mat[t] = pi_init * yt_dist
            else:
                a_mat[t] = np.matmul(a_mat[t - 1], prob_mat) * yt_dist
            c_vec[t] = a_mat[t].sum()
            a_mat[t] /= c_vec[t]
        log_marginal_lik_ls.append(np.sum(np.log(c_vec)))

    return np.array(log_marginal_lik_ls)

def compute_log_marginal_lik_poisson_fmp(L,prob_mat, pi_init, suff_stats,yt_all_ls,n_cores):
    '''
    Run compute_log_marginal_lik_poisson_approx_parallel over all sequences with a process pool.

    input:
        yt_all_ls = list (or array) of T_i x m count sequences; lengths may differ
        n_cores = number of worker processes / contiguous chunks
    output:
        total log marginal likelihood summed over all sequences
    '''
    # Split by index so ragged lists of sequences are supported.
    chunks = [[np.asarray(yt_all_ls[i]) for i in idx]
              for idx in np.array_split(np.arange(len(yt_all_ls)), n_cores)]
    func = partial(compute_log_marginal_lik_poisson_approx_parallel, L, prob_mat, pi_init, suff_stats)

    # The context manager shuts the workers down even if one of them raises.
    with Pool(processes=n_cores) as pool:
        results = pool.map(func, chunks)

    log_lik = np.concatenate(results, axis=0)
    return log_lik.sum()


def compute_log_obs_lik_ar_approx_parallel(L, prob_mat, pi_init, lik_params, yt_ls):
    '''
    Log-likelihood log p(y_1, ..., y_{T-1} | y_0) of each sequence under Gaussian
    AR(1) emissions.

    State k emits y_t ~ N(A_k y_{t-1}, Sigma_k) with A_k = lik_params['a_mat_post'][k]
    and Sigma_k = lik_params['sigma_mat_post'][k]. The first observation is only
    conditioned on; pi_init is used as the state distribution at t = 1.

    input:
        L = number of states
        prob_mat = L x L transition matrix
        pi_init = length-L initial state distribution
        yt_ls = iterable of T x d arrays
    output:
        numpy array with one log-likelihood per sequence
    '''
    A = lik_params['a_mat_post']
    Sigma = lik_params['sigma_mat_post']
    results = []
    for seq in yt_ls:
        n_steps = len(seq)
        alpha = np.zeros((n_steps, L))
        norms = np.zeros(n_steps - 1)
        # AR means for every state and every step t = 1..T-1: (n_states, T-1, d)
        means = np.transpose(np.matmul(A, seq[:-1].T), [0, 2, 1])
        n_states = means.shape[0]

        emis = np.zeros((n_steps, n_states))
        for k in range(n_states):
            for s in range(1, n_steps):
                emis[s, k] = ss.multivariate_normal.pdf(seq[s], mean=means[k, s - 1], cov=Sigma[k])

        for s in range(1, n_steps):
            prior = pi_init if s == 1 else np.matmul(alpha[s - 1].reshape(1, -1), prob_mat).reshape(-1)
            alpha[s] = prior * emis[s]
            norms[s - 1] = sum(alpha[s])
            alpha[s] /= norms[s - 1]
        results.append(sum(np.log(norms)))

    return np.array(results)

def compute_log_obs_lik_ar_fmp(L,prob_mat, pi_init, lik_params,yt_all_ls,n_cores):
    '''
    Run compute_log_obs_lik_ar_approx_parallel over all sequences with a process pool.

    input:
        yt_all_ls = list (or array) of T_i x d sequences; lengths may differ
        n_cores = number of worker processes / contiguous chunks
    output:
        total log-likelihood summed over all sequences
    '''
    # Split by index so ragged lists of sequences are supported.
    chunks = [[np.asarray(yt_all_ls[i]) for i in idx]
              for idx in np.array_split(np.arange(len(yt_all_ls)), n_cores)]
    func = partial(compute_log_obs_lik_ar_approx_parallel, L, prob_mat, pi_init, lik_params)

    # The context manager shuts the workers down even if one of them raises.
    with Pool(processes=n_cores) as pool:
        results = pool.map(func, chunks)

    log_lik = np.concatenate(results, axis=0)
    return log_lik.sum()


def plot_generated_sample(L, A_mats, Sigma_mats, beta, save_path, x_train_min, x_train_max, x_train_max_len, colors=None):
    '''
    Simulate one AR(1) trajectory from the posterior and save a plot of it.

    States are visited in order of decreasing beta, and state i is held for
    int(x_train_max_len * beta[i]) steps. The starting point is the mean of 40
    standard-normal draws. Each later point is the mean of 40 draws from
    N(A_mats[i] @ previous_point, Sigma_mats[i]). The per-feature standard
    deviation of those draws is shaded around the curve, and each step is
    coloured by its state.

    input:
        L = number of states (sets the default palette size L+1)
        A_mats, Sigma_mats = per-state AR matrices and noise covariances
        beta = state weights (visit order and durations)
        save_path = output figure path
        x_train_min, x_train_max = y-axis range (padded by 2 on each side)
        x_train_max_len = total length scale of the simulated trajectory
        colors = optional list of colours indexed by state
    '''
    n_feat = A_mats[0].shape[0]
    states = []
    sample = []
    stds = []
    for i in list(np.argsort(beta)[::-1]):
        for _ in range(int(x_train_max_len * beta[i])):
            states.append(i)
            if not sample:
                start = np.stack([np.random.multivariate_normal(np.zeros((n_feat,)), np.identity(n_feat))
                                  for _ in range(40)])
                sample.append(np.mean(start, axis=0))
                stds.append(np.std(start, axis=0))
            # Continue from the most recent point. (Indexing sample[j-1] with a
            # per-state counter j restarted from the initial point at every state change.)
            draws = np.stack([np.random.multivariate_normal(A_mats[i].dot(sample[-1]), Sigma_mats[i])
                              for _ in range(40)])
            sample.append(np.mean(draws, axis=0))
            stds.append(np.std(draws, axis=0))

    states = np.stack(states)
    sample = np.stack(sample)
    stds = np.stack(stds)

    if colors is None:
        colors = list(sns.color_palette('Set3', L+1))

    fig, ax = plt.subplots(1, 1, figsize=(8, 2))
    for i in range(sample.shape[-1]):
        ax.plot(sample[:, i], label="Feature %d" % i)
        ax.fill_between(np.arange(sample.shape[0]), sample[:, i] - stds[:, i], sample[:, i] + stds[:, i], color='b', alpha=0.4)

    # sample has one more point than states (the starting point), so the spans cover each step
    for t in range(sample.shape[0] - 1):
        ax.axvspan(t, t + 1, alpha=0.4, color=colors[states[t]])
    plt.ylim(x_train_min-2, x_train_max+2)
    plt.title("Generated Time series sample from the posterior", fontsize=16, fontweight='bold')

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()


def post_training_eval(args, name, cv, z_train, train_lens, L, x_test, z_test, test_lens, x_train):
    '''
    Load a saved training checkpoint and write evaluation plots.

    Reads './<args.data>/<name>/checkpoint_<cv>.npz' (the last entries of its
    'lik_params', 'beta_vecs' and 'zt' arrays, plus 'last_zt_test') and a colour
    list from '../../<args.data>_color_codes.pkl'. It then saves into
    './<args.data>/<name>/eval/' (created if missing):
        gen_sample.pdf                - a trajectory from plot_generated_sample
        pie_chart.pdf                 - predicted (beta) vs ground-truth state proportions
        IND_test_sample_inference.pdf - predictions for up to 10 test sequences

    input:
        args = namespace; only args.data (dataset / directory name) is read here
        name = run name (sub-directory under args.data)
        cv = fold index used in the checkpoint file name
        z_train, train_lens = ground-truth training state sequences and their lengths
        L = number of model states
        x_test, z_test, test_lens = test sequences, their ground-truth states and lengths
        x_train = training sequences (used only for the plot's y-range and length)
    output:
        None (files are written to disk)

    NOTE(review): both np.load calls use allow_pickle=True, which unpickles the
    files; only point this at checkpoints / colour files you created yourself.
    '''
    print('Entering eval mode!')
    print('Loading checkpoint..')
    checkpoint = np.load('./' + args.data + '/' + name + '/checkpoint_' + str(cv) + '.npz', allow_pickle=True)
    
    # All training sequences stacked along time (for the global min / max)
    x_train_reshaped = np.concatenate([x_train[i] for i in range(len(x_train))])


    # Dataset-specific colours, one per ground-truth state (most frequent first, as used below)
    new_colors = [color for color in np.load('../../%s_color_codes.pkl'%args.data, allow_pickle=True)]
    if not os.path.exists('./' + args.data + '/' + name + '/eval/'):
        os.makedirs('./' + args.data + '/' + name + '/eval/')
    # Last saved sample of the chain
    lik_params = checkpoint['lik_params'][-1]
    beta = checkpoint['beta_vecs'][-1]
    zt_train = checkpoint['zt'][-1]
    # state_mapper: its keys are compared with ground-truth labels and its values
    # index predicted-state colours (see the usages below); presumably a Hungarian
    # matching computed by utils.evaluate_dist -- confirm there.
    _, state_mapper = evaluate_dist(z_true=z_train, z_pred=zt_train, z_lens=train_lens, k_max=L)

    # Palette position follows beta rank: the state with the largest beta gets the first colour
    colors_pre = list(sns.color_palette('Set3', L+1))
    colors = colors_pre.copy()
    ordered_states = np.argsort(beta)[::-1]
    for st_ind, st in enumerate(ordered_states):
        colors[st] = colors_pre[st_ind]

        # TODO (original notes): use state mapper instead of below;
        # search first 10 for repeat. Keep track of ones thrown away, so they can be swapped back in later if a repeat happens.
        # Plot the 10 test samples with new colors
    # Ground-truth labels ordered from most to least frequent in the training set
    z_train_long = np.concatenate([z_train[i] for i in range(len(z_train))])
    unique_elements, counts = np.unique(z_train_long, return_counts=True)
    sorted_indices = np.argsort(counts)
    sorted_elements = unique_elements[sorted_indices][::-1]

    # Colours replaced below; collected but not used afterwards
    thrown_away = []
    print('state mapper: ', state_mapper)
    print('sorted_elements: ', sorted_elements)
    print('new_colors: ', new_colors)
    # Give the predicted state matched to the i-th most frequent ground-truth label
    # the i-th dataset colour.
    # NOTE(review): raises IndexError if new_colors has more entries than there are
    # distinct ground-truth labels.
    for i, color in enumerate(new_colors):
        for key, value in state_mapper.items():
            if key == sorted_elements[i]:
                thrown_away.append(colors[value])
                colors[value] = color
                break
    
    # Plot generated data
    plot_generated_sample(L=L, A_mats=lik_params['a_mat_post'], Sigma_mats=lik_params['sigma_mat_post'], beta=beta, save_path='./%s/%s/eval/gen_sample.pdf'%(args.data, name), x_train_min=x_train_reshaped.min(), x_train_max=x_train_reshaped.max(), x_train_max_len=max([len(x_i) for x_i in x_train]), colors=colors)
    

    # Pie plot
    # Ground-truth state counts for labels 0..n_unique-1.
    # NOTE(review): assumes ground-truth labels are contiguous from 0; other labels are not counted.
    gt_state_count = [np.count_nonzero(np.concatenate([z_train[i] for i in range(len(z_train))])==z_s) for z_s in range(len(np.unique(np.concatenate([z_train[i] for i in range(len(z_train))]))))]
    fig, axs = plt.subplots(1, 2, figsize=(10, 5))
    axs[0].pie(np.sort(beta)[::-1], colors=[colors[ii_state] for ii_state in ordered_states])
    axs[0].set_title("Predicted class distribution")
    axs[1].pie(np.sort(gt_state_count)[::-1],colors=[colors[state_mapper[ii_state]] for (ii_state) in np.argsort(gt_state_count)[::-1]])
    axs[1].set_title("Ground truth class distribution mapped with hungarian")


    plt.tight_layout()
    plt.savefig("./%s/%s/eval/pie_chart.pdf" % (args.data, name))
    plt.close()

    # Plot test samples (at most 10), coloured consistently with the plots above
    zt_test = checkpoint['last_zt_test']
    plot_state_predictions_test_samples(n_test_samples=min(10, len(x_test)), test_x=x_test, test_z=z_test, test_lens=test_lens, test_preds=zt_test, state_mapper=state_mapper, colors=colors, save_path="./%s/%s/eval/IND_test_sample_inference.pdf" % (args.data, name))