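## how to run this file (options are parsed with argparse; see the parser at the bottom of this file), e.g.
## nohup python run_full_bayesian_approx_parallel_gibbs_ar_efox.py --data <dataset> --iters <n_iters> --k_max <truncation_level> \
##     --c_pri <c> --d_pri <d> --alpha0_a_pri <a> --alpha0_b_pri <b> --gamma0_a_pri <a> --gamma0_b_pri <b> \
##     --M0_type zeros --M0_factor <f> --V0_type identity --V0_factor <f> --S0_type emperical_cov --S0_factor <f> \
##     --short_name <run_name> --cv <fold> &
## add --eval to evaluate an existing run instead of training; outputs are written under ./<dataset>/<run_name>/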

## load packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
from sklearn.metrics import hamming_loss
import sys
sys.path.append('../../code/')
from gibbs_approx_parallel_efox import *
from util import *
import argparse
sys.path.append('../../../')
from utils import evaluate_dist, evaluate_dist_beta, line_plot, plot_state_predictions_test_samples, load_data
import seaborn as sns
import datetime
import os
from multiprocessing import Pool
import wandb
import ast
sns.set()
def gibbs_sample(x_train, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores,
                 beta_vec, alpha0, alpha0_init, rho0, gamma0, alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri):
    """Run one full Gibbs sweep of the truncated sticky HDP-HMM sampler.

    The sweep draws, in this fixed order: state sequences (plus counts and
    sufficient statistics), HDP auxiliary variables, beta, the transition
    distributions, the emission parameters, and finally the hyperparameters.
    All ``sample_*`` helpers come from the star-imported project modules; they
    presumably draw from numpy's global RNG (seeded in ``main``), so the call
    order is kept exactly as is.

    Parameters
    ----------
    x_train : sequence of observation arrays.
    pi_bar, pi_init : current transition / initial-state distributions.
    L : truncation level (maximum number of states).
    lik_params : current emission parameters.
    mode : emission model name passed through to the helpers ('ar' in this script).
    init_suff_stats : prior sufficient statistics from ``init_gibbs_full_bayesian``.
    n_cores : worker count passed through to ``sample_zw_fmp``.
    beta_vec, alpha0, alpha0_init, rho0, gamma0 : current HDP parameters.
    alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri : hyperpriors.

    Returns
    -------
    tuple of 20 items, in this order:
        (zt, n_mat, n_ft, K, uniq, suff_stats, m_mat, m_init, w_vec, m_mat_bar,
         beta_vec, pi_bar, pi_init, lik_params, concentration, alpha0_init,
         gamma0, stick_ratio, rho0, alpha0)
    """
    # -- latent state sequences, their counts and sufficient statistics
    z_seqs, n_trans, n_first, n_active, active_states, stats = sample_zw_fmp(
        x_train, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores)

    # -- HDP auxiliary variables -> beta -> transition / initial distributions
    m_counts, m_first, w_aux, m_bar = sample_m_w_mbar(
        n_trans, n_first, beta_vec, alpha0, alpha0_init, rho0)
    beta_new = sample_beta(m_bar, m_first, gamma0)
    pi_bar_new, pi_init_new = sample_pi(
        n_trans, n_first, alpha0, alpha0_init, rho0, beta_new)

    # -- emission parameters from this sweep's sufficient statistics
    lik_new = sample_lik_params(stats, mode)

    # -- hyperparameters; the incoming alpha0 / rho0 / alpha0_init / gamma0
    #    (not yet updated in this sweep) are what these samplers receive
    conc, alpha0_init_new = sample_concentration(
        m_counts, n_trans, alpha0, rho0, m_first, n_first,
        alpha0_init, alpha0_a_pri, alpha0_b_pri)
    gamma0_new = sample_gamma(n_active, m_bar, m_first, gamma0, gamma0_a_pri, gamma0_b_pri)
    ratio = sample_stick_ratio(w_aux, m_counts, c_pri, d_pri)
    rho0_new, alpha0_new = transform(conc, ratio)  # rho0 = conc*ratio, alpha0 = conc - rho0

    return (z_seqs, n_trans, n_first, n_active, active_states, stats,
            m_counts, m_first, w_aux, m_bar, beta_new,
            pi_bar_new, pi_init_new, lik_new, conc, alpha0_init_new,
            gamma0_new, ratio, rho0_new, alpha0_new)


def is_possemi_def(x, tol=1e-10):
    """Return True if the square matrix ``x`` is positive semi-definite.

    PSD is checked in the quadratic-form sense (v^T x v >= 0 for every v), which
    depends only on the symmetric part of ``x``. Its eigenvalues are computed with
    ``numpy.linalg.eigvalsh``, which returns real values for symmetric input.
    Eigenvalues down to ``-tol * max(1, max|eig|)`` are accepted, so floating-point
    round-off on a (near-)singular covariance is not reported as "not PSD".

    Parameters
    ----------
    x : array_like, shape (D, D)
        Matrix to test.
    tol : float, optional
        Relative tolerance on negative eigenvalues; ``tol=0`` gives an exact check.

    Returns
    -------
    bool
    """
    mat = np.asarray(x)
    if mat.size == 0:
        return True
    eigs = np.linalg.eigvalsh(0.5 * (mat + mat.T))
    scale = max(1.0, float(np.max(np.abs(eigs))))
    return bool(np.all(eigs >= -tol * scale))

# Keys of the per-iteration history saved in (and restored from) the checkpoint .npz.
_HISTORY_KEYS = ('zt', 'hyper', 'loglik', 'val_loglik', 'train_loglik', 'pi_mat', 'pi_init',
                 'uniq', 'lik_params', 'train_hammings', 'validation_hammings',
                 'test_hammings', 'beta_vecs')


def _unique_labels(z_list):
    """Return the sorted unique labels found across a list of label sequences."""
    return np.unique(np.concatenate([np.unique(z) for z in z_list]))


def _build_prior_params(settings, x_train):
    """Build the emission prior dict with keys 'M0', 'V0', 'S0' and 'n0'.

    ``settings`` maps '<mat>_type' / '<mat>_factor' for mat in M0, V0, S0 (the argparse
    dict, or the args dict recovered from a run log). Supported types are 'zeros',
    'identity' and 'emperical_cov' (sic). An unrecognised type leaves that key unset.
    D is taken from the last axis of the first training sequence.
    """
    D = x_train[0].shape[-1]
    prior_params = {}
    for mat in ['M0', 'V0', 'S0']:
        mat_type = settings[mat + '_type']
        mat_factor = settings[mat + '_factor']
        if mat_type == 'zeros':
            prior_params[mat] = np.zeros((D, D)) * mat_factor
        elif mat_type == 'identity':
            prior_params[mat] = np.identity(D) * mat_factor
        elif mat_type == 'emperical_cov':
            # covariance over the rows of all training sequences stacked together
            prior_params[mat] = np.cov(np.concatenate(list(x_train), axis=0).T) * mat_factor
    prior_params['n0'] = D + 2
    return prior_params


def _print_learned_mats(lik_params, L, check_psd=True):
    """Print each state's posterior Sigma and A matrices, optionally flagging non-PSD Sigmas."""
    print('=' * 20)
    print('Learned mats')
    for s in range(L):
        print('State: ', s)
        print('Sigma:')
        print(lik_params['sigma_mat_post'][s])
        if check_psd and not is_possemi_def(lik_params['sigma_mat_post'][s]):
            print('Not PSD!')
        print()
        print('A mat:')
        print(lik_params['a_mat_post'][s])
    print('=' * 20)


def _checkpoint_path(args, name, cv):
    """Path of the checkpoint this run both writes and resumes from."""
    return './%s/%s/checkpoint_%s.npz' % (args.data, name, cv)


def _save_checkpoint(path, history, it, last_zt_test):
    """Write the sample/metric history, the last completed iteration and last test states."""
    os.makedirs(os.path.dirname(path), exist_ok=True)
    np.savez(path, it=it, last_zt_test=last_zt_test, **history)


def _restore_from_checkpoint(path):
    """Read a checkpoint back as (last completed iteration, history of Python lists, last test states)."""
    with np.load(path, allow_pickle=True) as f:
        # Lists (not ndarrays) so later .append calls keep working; missing keys start empty.
        history = {key: (list(f[key]) if key in f.files else []) for key in _HISTORY_KEYS}
        last_zt_test = f['last_zt_test'] if 'last_zt_test' in f.files else None
        last_it = int(f['it'])
    return last_it, history, last_zt_test


def _save_plots(out_dir, cv, it, L, beta_vec, pi_bar, lik_params, state_mapper, zt_test,
                history, gt_state_count, x_train, x_test, z_test, test_lens):
    """Write the periodic diagnostic figures (test-state plot, pies, heatmap, curves, generated sample)."""
    colors = list(sns.color_palette('Set3', L + 1))
    plot_state_predictions_test_samples(n_test_samples=min(10, len(x_test)), test_x=x_test, test_z=z_test,
                                        test_lens=test_lens, test_preds=zt_test, state_mapper=state_mapper,
                                        colors=colors, save_path=f'{out_dir}/IND_ts_sample_inference_{cv}.pdf')

    plot_dir = f'{out_dir}/plots/'
    os.makedirs(plot_dir, exist_ok=True)

    ordered_states = np.argsort(beta_vec)[::-1]
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].pie(np.sort(beta_vec)[::-1], colors=[colors[s] for s in ordered_states])
    axs[0].set_title("Predicted class distribution")
    axs[1].pie(np.sort(gt_state_count)[::-1],
               colors=[colors[state_mapper[s]] for s in np.argsort(gt_state_count)[::-1]])
    axs[1].set_title("Ground truth class distribution mapped with hungarian")
    axs[2].set_title("Ground truth class distribution mapped with beta")  # pie for this panel is disabled
    plt.tight_layout()
    plt.savefig(f'{plot_dir}pie_chart_iter{it}_cv{cv}.pdf')
    plt.close(fig)

    fig = plt.figure()
    sns.heatmap(pi_bar)
    plt.savefig(f'{plot_dir}transition_heatmap_iter{it}_cv{cv}.pdf')
    plt.close(fig)  # close so figures do not accumulate across checkpoints

    line_plot(history['train_loglik'], f'{out_dir}/train_log_like_{cv}.pdf')
    print('Final Test Logp: ', history['loglik'][-1])
    line_plot(history['val_loglik'], f'{out_dir}/val_log_like_{cv}.pdf')
    line_plot(history['loglik'], f'{out_dir}/test_log_like_{cv}.pdf')
    line_plot(history['train_hammings'], f'{out_dir}/train_hamming_{cv}.pdf')
    line_plot(history['validation_hammings'], f'{out_dir}/val_hamming_{cv}.pdf')

    gen_dir = f'{out_dir}/generated_samples_cv{cv}/'
    os.makedirs(gen_dir, exist_ok=True)
    plot_generated_sample(L, lik_params['a_mat_post'], lik_params['sigma_mat_post'], beta_vec,
                          f'{gen_dir}gen_sample_{it}.pdf',
                          x_train_min=min(xx.min() for xx in x_train),
                          x_train_max=max(xx.max() for xx in x_train),
                          x_train_max_len=max(len(x_i) for x_i in x_train), colors=colors)


def _resample_eval_from_checkpoint(args, name, cv, L, x_train, z_train, train_lens,
                                   x_test, z_test, test_lens, n_cores=1):
    """Reload a finished run, do 3 more Gibbs sweeps, and print test Hamming / log-lik (mean +- std).

    NOTE: main() does not call this. It is the evaluation code that used to follow
    ``assert False`` in main's --eval branch (so it never ran); kept for ad-hoc use.
    Hyperprior settings are recovered from the first line of ./<data>/<name>/output_<cv>.txt.
    """
    with open(f'./{args.data}/{name}/output_{cv}.txt', 'r') as file:
        header = file.readline()
    # The line reads "args dict:  {...}". Keep everything after the first ':'. The old
    # str.strip("args dict: ") removed a *set of characters* from both ends, not a prefix.
    original_args = ast.literal_eval(header.split(':', 1)[1].strip())

    c_pri, d_pri = original_args['c_pri'], original_args['d_pri']
    alpha0_a_pri, alpha0_b_pri = original_args['alpha0_a_pri'], original_args['alpha0_b_pri']
    gamma0_a_pri, gamma0_b_pri = original_args['gamma0_a_pri'], original_args['gamma0_b_pri']
    mode = 'ar'
    prior_params = _build_prior_params(original_args, x_train)
    (rho0, alpha0, alpha0_init, gamma0, init_suff_stats, L, beta_vec,
     pi_bar, pi_init, lik_params) = init_gibbs_full_bayesian(
        alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri, prior_params, L, mode)

    # Overwrite the prior draws with the last saved sample.
    with np.load(_checkpoint_path(args, name, cv), allow_pickle=True) as checkpoint:
        alpha0, gamma0, rho0, alpha0_init = checkpoint['hyper'][-1]
        beta_vec = checkpoint['beta_vecs'][-1]
        pi_bar = checkpoint['pi_mat'][-1]
        pi_init = checkpoint['pi_init'][-1]
        lik_params = checkpoint['lik_params'][-1]
    _print_learned_mats(lik_params, L)

    test_hammings, test_logliks = [], []
    for i in range(3):
        print('i: ', i)
        (zt, _, _, _, _, suff_stats, _, _, _, _, beta_vec, pi_bar, pi_init, lik_params,
         _, alpha0_init, gamma0, _, rho0, alpha0) = gibbs_sample(
            x_train, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores,
            beta_vec, alpha0, alpha0_init, rho0, gamma0,
            alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri)
        loglik_testnew = compute_log_marginal_lik_ar_fmp(
            L, pi_bar, pi_init, suff_stats, x_test, n_cores, False, return_mean=False, old_way=False)
        print('new test log liks:')
        print(loglik_testnew)
        _, state_mapper = evaluate_dist(z_true=z_train, z_pred=zt, z_lens=train_lens, k_max=L)
        zt_test, _, _, _, _, _ = sample_zw_fmp(x_test, pi_bar, pi_init, L, lik_params, mode,
                                               init_suff_stats, n_cores)
        test_hamming, _ = evaluate_dist(z_true=z_test, z_pred=zt_test, z_lens=test_lens,
                                        k_max=L, mapping=state_mapper)
        test_hammings.append(test_hamming)
        test_logliks.append(loglik_testnew)
        _print_learned_mats(lik_params, L, check_psd=False)

    test_logliks = np.concatenate(test_logliks)
    test_hammings = np.concatenate(test_hammings)
    # Mean rounded to 3 places and std to 2. Raw strings, because '\p' is not a valid escape.
    print('Test Hamming: ', round(np.mean(test_hammings), 3), r' $\pm$ ', round(np.std(test_hammings), 2))
    print('Test Loglik: ', round(np.mean(test_logliks), 3), r' $\pm$ ', round(np.std(test_logliks), 2))


def main(args, data_load_config, name, cv):
    """Train (or, with ``--eval``, evaluate) the sticky HDP-HMM with AR emissions on one CV fold.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options.
    data_load_config : dict
        Per-dataset options forwarded to ``utils.load_data``.
    name : str
        Run name; outputs go to ``./<args.data>/<name>/``.
    cv : int
        Fold index; also selects the RNG seed (111 * (cv + 1)).

    Training resumes from ``./<args.data>/<name>/checkpoint_<cv>.npz`` if it exists (the
    same path it writes to). It stops after 24 hours of wall-clock time and checkpoints
    every 50 iterations (every 10 when ``args.data == 'har'``), plus once at the end.
    """
    import copy  # local import, so the deepcopies below do not rely on a star import providing `copy`

    iters = args.iters  # number of Gibbs iterations
    n_cores = 1
    L = int(args.k_max)  # truncation level; about twice the expected number of states is usually enough

    args_dict = vars(args)
    # NOTE: keep this format. _resample_eval_from_checkpoint parses this line back out of
    # output_<cv>.txt (presumably this script's redirected stdout).
    print('args dict: ', args_dict)

    (x_train, z_train, train_lens, x_valid, z_valid, valid_lens,
     x_test, z_test, test_lens) = load_data(args.data, data_load_config, normalize=False,
                                             pad_ragged=False, path_to_data='../../../data/')
    print('unique z_train: ', _unique_labels(z_train))
    print('unique z_test: ', _unique_labels(z_test))

    if args.eval:
        post_training_eval(args, name, cv, z_train, train_lens, L, x_test, z_test, test_lens, x_train)
        # Evaluation ends here. This used to be `assert False, 'Finished plotting'`, which
        # exits with a traceback and is skipped entirely under `python -O`.
        return

    start_t = datetime.datetime.now()
    print('unique2 z_test: ', _unique_labels(z_test))

    # Seeds 111, 222, ..., 999 for folds 0-8 (same as before); folds >= 9 no longer raise IndexError.
    np.random.seed(111 * (cv + 1))

    c_pri, d_pri = args.c_pri, args.d_pri
    alpha0_a_pri, alpha0_b_pri = args.alpha0_a_pri, args.alpha0_b_pri
    gamma0_a_pri, gamma0_b_pri = args.gamma0_a_pri, args.gamma0_b_pri
    mode = 'ar'
    prior_params = _build_prior_params(args_dict, x_train)
    print('prior_params: ', prior_params)

    # Parameter glossary:
    #   rho0:        kappa (self-transition parameter)
    #   alpha0:      concentration of the lower-level DP
    #   gamma0:      concentration of the higher-level DP
    #   beta_vec:    beta (plotted below as the predicted class distribution)
    #   pi_init:     initial-state distribution; used for the log prob of the first time step,
    #                which has no previous state
    #   pi_bar:      state-transition distributions
    #   alpha0_init: concentration sampled together with alpha0 by sample_concentration
    #   L:           truncation level for the number of states
    #   K:           number of states in use so far
    #   c_pri, d_pri: if None, alpha0 = alpha0_init ~ gamma(alpha0_a_pri, alpha0_b_pri) and rho0 = 0
    #       (no stickiness). Otherwise concentration ~ gamma(alpha0_a_pri, 1/alpha0_b_pri) and
    #       stick_ratio ~ beta(c_pri, d_pri), with rho0 = concentration * stick_ratio and
    #       alpha0 = concentration * (1 - stick_ratio). Higher concentration => transition
    #       distributions closer to uniform; lower => more mass on fewer states. stick_ratio sets
    #       how much of that variability comes from self-transitions (rho0) versus alpha0.
    # Only initialization from the prior is implemented; args.init_way is not consulted.
    (rho0, alpha0, alpha0_init, gamma0, init_suff_stats, L, beta_vec,
     pi_bar, pi_init, lik_params) = init_gibbs_full_bayesian(
        alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri, prior_params, L, mode)

    out_dir = './%s/%s' % (args.data, name)
    ckpt_path = _checkpoint_path(args, name, cv)
    history = {key: [] for key in _HISTORY_KEYS}
    zt_test, state_mapper = None, None
    start_it = 0
    if os.path.isfile(ckpt_path):
        print('Checkpoint found! Loading now..')
        last_it, history, zt_test = _restore_from_checkpoint(ckpt_path)
        start_it = last_it + 1  # the saved iteration already finished; do not redo it
        if history['hyper']:
            alpha0, gamma0, rho0, alpha0_init = history['hyper'][-1]
            pi_bar = history['pi_mat'][-1]
            pi_init = history['pi_init'][-1]
            lik_params = history['lik_params'][-1]
        if history['beta_vecs']:
            beta_vec = history['beta_vecs'][-1]

    # Ground-truth label histogram for the pie charts (assumes labels are 0..n-1).
    all_train_z = np.concatenate(list(z_train))
    gt_state_count = [np.count_nonzero(all_train_z == z_s)
                      for z_s in range(len(np.unique(all_train_z)))]

    completed_it = start_it - 1
    for it in range(start_it, iters):
        print("Iteration: ", it)
        elapsed = datetime.datetime.now() - start_t
        print('Hours since start: ', round(elapsed.total_seconds() / 3600, 2))
        if elapsed.days >= 1:  # the job has been running for >= 24 hrs
            print('Stopping Early! 24 HR Max Reached!')
            break
        print('L: ', L)

        # One full Gibbs sweep: states, HDP parameters, emission parameters, hyperparameters.
        (zt, _, _, _, uniq, suff_stats, _, _, _, _, beta_vec, pi_bar, pi_init, lik_params,
         _, alpha0_init, gamma0, _, rho0, alpha0) = gibbs_sample(
            x_train, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores,
            beta_vec, alpha0, alpha0_init, rho0, gamma0,
            alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri, c_pri, d_pri)

        if it % 10 == 0:
            v = False  # presumably a verbosity flag for the log-likelihood helper
            loglik_test_arr = compute_log_marginal_lik_ar_fmp(L, pi_bar, pi_init, suff_stats, x_test, n_cores, v, return_mean=False, old_way=True)
            loglik_test_new_arr = compute_log_marginal_lik_ar_fmp(L, pi_bar, pi_init, suff_stats, x_test, n_cores, v, return_mean=False, old_way=False)
            loglik_train = compute_log_marginal_lik_ar_fmp(L, pi_bar, pi_init, suff_stats, x_train, n_cores, v, return_mean=True, old_way=True)
            loglik_val = compute_log_marginal_lik_ar_fmp(L, pi_bar, pi_init, suff_stats, x_valid, n_cores, v, return_mean=True, old_way=True)
            loglik_val_new = compute_log_marginal_lik_ar_fmp(L, pi_bar, pi_init, suff_stats, x_valid, n_cores, v, return_mean=True, old_way=False)
            print('loglik_test_arr mean, std, min, max: ', loglik_test_arr.mean(), loglik_test_arr.std(), loglik_test_arr.min(), loglik_test_arr.max())
            print('loglik_train: ', loglik_train)
            print('loglik_val: ', loglik_val)
            print('loglik_test_new_arr mean, std, min, max: ', loglik_test_new_arr.mean(), loglik_test_new_arr.std(), loglik_test_new_arr.min(), loglik_test_new_arr.max())
            print('loglik_val_new: ', loglik_val_new)

            _print_learned_mats(lik_params, L)

            history['loglik'].append(loglik_test_arr.mean())
            history['train_loglik'].append(loglik_train)
            history['val_loglik'].append(loglik_val)
            print('z_train dtype: ', z_train[0].dtype)
            print('z_train shape: ', z_train[0].shape)
            print('z_train[0]: ', z_train[0])
            print('zt dtype: ', zt[0].dtype)
            print('zt shape: ', zt[0].shape)
            print('zt[0]: ', zt[0])
            train_hamming, state_mapper = evaluate_dist(z_true=z_train, z_pred=zt, z_lens=train_lens, k_max=L)
            train_hamming = np.mean(train_hamming)
            history['train_hammings'].append(train_hamming)
            print('Iteration: ', it, ' Train LL: ', loglik_train, ' Train hamming: ', train_hamming, ' Val LL: ', loglik_val)

            history['pi_mat'].append(copy.deepcopy(pi_bar))
            history['pi_init'].append(copy.deepcopy(pi_init))
            history['uniq'].append(copy.deepcopy(uniq))
            history['beta_vecs'].append(copy.deepcopy(beta_vec))
            history['zt'].append(copy.deepcopy(zt))
            history['hyper'].append(np.array([alpha0, gamma0, rho0, alpha0_init]))
            history['lik_params'].append(copy.deepcopy(lik_params))

            # Test-set state predictions, scored with the train-set label mapping.
            zt_test, _, _, _, _, _ = sample_zw_fmp(x_test, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores)
            test_hamming, _ = evaluate_dist(z_true=z_test, z_pred=zt_test, z_lens=test_lens, k_max=L, mapping=state_mapper)
            test_hamming = np.mean(test_hamming)
            history['test_hammings'].append(test_hamming)
            print("Test Hamming: ", test_hamming)

            # Validation-set state predictions.
            zt_val, _, _, _, _, _ = sample_zw_fmp(x_valid, pi_bar, pi_init, L, lik_params, mode, init_suff_stats, n_cores)
            val_hamming, _ = evaluate_dist(z_true=z_valid, z_pred=zt_val, z_lens=valid_lens, k_max=L, mapping=state_mapper)
            history['validation_hammings'].append(np.mean(val_hamming))

        completed_it = it
        # Every checkpoint iteration is also a multiple of 10, so state_mapper/zt_test are fresh here.
        if it % 50 == 0 or (it % 10 == 0 and args.data == "har"):
            print('unique3 z_test: ', _unique_labels(z_test))
            print('state_mapper: ', state_mapper)
            _save_checkpoint(ckpt_path, history, it, zt_test)
            print('Checkpoint saved!')
            _save_plots(out_dir, cv, it, L, beta_vec, pi_bar, lik_params, state_mapper, zt_test,
                        history, gt_state_count, x_train, x_test, z_test, test_lens)

    _save_checkpoint(ckpt_path, history, completed_it, zt_test)
    if history['train_hammings']:
        print('Final Train hamming: ', history['train_hammings'][-1])

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run S HDP HMM')
    parser.add_argument('--data', type=str, default='sim_hard')
    parser.add_argument('--ds_factor', type=int)
    parser.add_argument('--iters', type=int)
    parser.add_argument('--init_way', type=str)
    parser.add_argument('--k_max', type=int)
    
    # Hyper params
    parser.add_argument('--c_pri', type=float)
    parser.add_argument('--d_pri', type=float)
    parser.add_argument('--alpha0_a_pri', type=float)
    parser.add_argument('--alpha0_b_pri', type=float)
    parser.add_argument('--gamma0_a_pri', type=float)
    parser.add_argument('--gamma0_b_pri', type=float)

    parser.add_argument('--S0_type', type=str) # emperical cov
    parser.add_argument('--S0_factor', type=float)

    parser.add_argument('--M0_type', type=str) # zeros
    parser.add_argument('--M0_factor', type=float)

    parser.add_argument('--V0_type', type=str) # ones
    parser.add_argument('--V0_factor', type=float)
    parser.add_argument('--short_name', type=str)

    parser.add_argument('--cv', type=int)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    data_load_config = {'sim_easy': {'n_train': 100, 'n_valid': 50, 'n_test': 50}, 
                        'sim_hard': {'n_train': 100, 'n_valid': 50, 'n_test': 50},
                        'sim_semi_markov': {'n_train': 60, 'n_valid': 20, 'n_test': 20},
                        'har': {'ds_factor': args.ds_factor},
                        'har_70': {'ds_factor': args.ds_factor},
                       }
    
    main(args, data_load_config, args.short_name, args.cv)
