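## how to run this file
## all options are passed as named argparse flags (see the argparse definitions in the
## `if __name__ == "__main__":` block at the bottom for the full list), e.g.
## nohup python run_full_bayesian_approx_parallel_gibbs_ar.py --data sim_hard --iters <n_iters> --init_way prior --k_max <k_max> --cv <fold> --short_name <run_name> [other flags] &
## add --eval to re-evaluate the checkpoint of a finished run instead of training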

## load packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import sys
sys.path.append('../../code/')
from gibbs_approx_parallel import *
from sklearn.metrics import hamming_loss
from util import *
import argparse
import seaborn as sns
import os
import datetime
from multiprocessing import Pool, cpu_count
import ast 
sns.set()

sys.path.append('../../../')
from utils import evaluate_dist, evaluate_dist_beta, line_plot, plot_state_predictions_test_samples, load_data

def gibbs_sample(x_train, pi_bar, kappa_vec, pi_init, L, lik_params, mode, init_suff_stats, n_cores,
                 beta_vec, alpha0, alpha0_init, gamma0, alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri,
                 rho0, rho1, v0_range, v1_range, v0_num_grid, v1_num_grid, p):
    """Perform one full Gibbs sweep and return every quantity it updates.

    The conditional draws are made in this fixed order:

    1. latent sequences and counts (``sample_zw_fmp``),
    2. ``kappa_vec`` (``sample_kappa``),
    3. auxiliary counts ``m_mat`` / ``m_init`` (``sample_m``; uses the incoming
       ``beta_vec``, ``alpha0`` and ``alpha0_init``),
    4. ``beta_vec`` (``sample_beta``; uses the incoming ``gamma0``),
    5. ``pi_bar`` / ``pi_init`` (``sample_pi``; uses the new ``beta_vec``),
    6. likelihood parameters (``sample_lik_params``),
    7. ``alpha0`` / ``alpha0_init`` (``sample_alpha``),
    8. ``gamma0`` (``sample_gamma``),
    9. ``rho0`` / ``rho1`` and the posterior grid (``sample_rho``).

    Do not reorder these calls: each one consumes specific old/new values, and
    (presumably) the samplers draw from the global NumPy RNG, so the order also
    fixes the random stream.

    Returns:
        The 22-tuple ``(zt, wt, n_mat, num_1_vec, num_0_vec, n_ft, K, uniq,
        suff_stats, kappa_vec, m_mat, m_init, beta_vec, pi_bar, pi_init,
        lik_params, alpha0, alpha0_init, gamma0, rho0, rho1, posterior_grid)``.
    """
    # Step 1: latent sequences, transition counts and sufficient statistics.
    (zt, wt, n_mat, num_1_vec, num_0_vec,
     n_ft, K, uniq, suff_stats) = sample_zw_fmp(x_train, pi_bar, kappa_vec, pi_init, L,
                                                lik_params, mode, init_suff_stats, n_cores)

    # Steps 2-5: HDP prior quantities.
    next_kappa = sample_kappa(num_1_vec, num_0_vec, rho0, rho1)
    m_mat, m_init = sample_m(n_mat, n_ft, beta_vec, alpha0, alpha0_init)
    next_beta = sample_beta(m_mat, m_init, gamma0)
    next_pi_bar, next_pi_init = sample_pi(n_mat, n_ft, alpha0, alpha0_init, next_beta)

    # Step 6: emission parameters.
    next_lik_params = sample_lik_params(suff_stats, mode)

    # Steps 7-9: concentration hyperparameters (conditioned on the old values).
    next_alpha0, next_alpha0_init = sample_alpha(m_mat, n_mat, alpha0, m_init, n_ft, alpha0_init,
                                                 alpha0_a_pri, alpha0_b_pri)
    next_gamma0 = sample_gamma(K, m_mat, m_init, gamma0, gamma0_a_pri, gamma0_b_pri)
    next_rho0, next_rho1, posterior_grid = sample_rho(v0_range, v1_range, v0_num_grid, v1_num_grid,
                                                      num_1_vec, num_0_vec, p)

    return (zt, wt, n_mat, num_1_vec, num_0_vec, n_ft, K, uniq, suff_stats,
            next_kappa, m_mat, m_init, next_beta, next_pi_bar, next_pi_init,
            next_lik_params, next_alpha0, next_alpha0_init, next_gamma0,
            next_rho0, next_rho1, posterior_grid)


def _build_prior_params(arg_lookup, x_train):
    """Build the prior hyperparameter dict passed to ``init_gibbs_full_bayesian``.

    For each of ``'M0'``, ``'V0'`` and ``'S0'`` the lookup must hold
    ``<name>_type`` and ``<name>_factor``. The type selects the base matrix
    (``'zeros'``, ``'identity'`` or ``'emperical_cov'`` -- the misspelling is the
    accepted command-line value) and the factor scales it. An unrecognised type
    leaves that key out of the result, as the original inline code did.
    ``'n0'`` is always ``D + 2``.

    Args:
        arg_lookup: mapping with the ``*_type`` / ``*_factor`` entries, e.g.
            ``vars(args)`` or the dict parsed from a previous run's log.
        x_train: list of training sequences. ``D`` is the size of the last axis of
            the first one. The empirical covariance is computed over all of them
            concatenated along axis 0, so each is presumably a ``(T_i, D)`` array.

    Returns:
        dict of ``(D, D)`` arrays keyed ``'M0'`` / ``'V0'`` / ``'S0'`` plus ``'n0'``.
    """
    dim = x_train[0].shape[-1]
    prior_params = {}
    for mat in ['M0', 'V0', 'S0']:
        mat_type = arg_lookup[mat + '_type']
        mat_factor = arg_lookup[mat + '_factor']
        if mat_type == 'zeros':
            prior_params[mat] = np.zeros((dim, dim)) * mat_factor
        elif mat_type == 'identity':
            prior_params[mat] = np.identity(dim) * mat_factor
        elif mat_type == 'emperical_cov':
            prior_params[mat] = np.cov(np.concatenate(x_train, axis=0).T) * mat_factor
    prior_params['n0'] = dim + 2
    return prior_params


def _read_original_args(path):
    """Parse the ``args dict: {...}`` line that ``main`` prints first.

    ``path`` is presumably the redirected stdout of the training run. Unlike
    ``str.strip("args dict: ")``, which strips a *set* of characters, this
    removes the prefix itself.
    """
    with open(path, 'r') as file:
        line = file.readline().strip()
    prefix = 'args dict:'
    if line.startswith(prefix):
        line = line[len(prefix):]
    return ast.literal_eval(line.strip())


def _checkpoint_path(data, name, cv):
    """Return the checkpoint path; the single source of truth for save AND resume."""
    return './%s/%s/checkpoint_' % (data, name) + str(cv) + '.npz'


def _npz_list(npz, key):
    """Return ``npz[key]`` as a Python list (so it can be appended to), or ``[]`` if absent."""
    return list(npz[key]) if key in npz.files else []


def _evaluate_checkpoint(args, name, cv, L, n_cores, x_train, z_train, train_lens,
                         x_test, z_test, test_lens):
    """Re-evaluate the checkpoint of a finished run and print test metrics.

    Steps:

    1. Call ``post_training_eval``.
    2. Rebuild the prior from the arguments logged in
       ``./<data>/<name>/output_<cv>.txt``.
    3. Restore the last stored sample from the checkpoint.
    4. Run 3 more Gibbs sweeps, recording the test log-likelihoods and test
       Hamming scores after each sweep.

    NOTE(review): this runs before ``np.random.seed`` is called in ``main``, so
    its results are not seeded.
    """
    post_training_eval(args, name, cv, z_train, train_lens, L, x_test, z_test, test_lens, x_train)
    original_args = _read_original_args(f'./{args.data}/{name}/output_{cv}.txt')

    p = original_args['p']
    mode = 'ar'
    v0_range = (original_args['v0_range0'], original_args['v0_range1'])
    v1_range = (original_args['v1_range0'], original_args['v1_range1'])
    alpha0_a_pri, alpha0_b_pri = float(original_args['alpha0_a_pri']), float(original_args['alpha0_b_pri'])
    gamma0_a_pri, gamma0_b_pri = float(original_args['gamma0_a_pri']), float(original_args['gamma0_b_pri'])
    v0_num_grid = 4
    v1_num_grid = 4
    prior_params = _build_prior_params(original_args, x_train)

    # Initialisation supplies init_suff_stats and L; the remaining sampler state
    # is replaced by the checkpoint contents below.
    (rho0, rho1, alpha0, alpha0_init, gamma0, init_suff_stats, L, beta_vec, kappa_vec,
     pi_bar, pi_init, lik_params) = init_gibbs_full_bayesian(
        p, v0_range, v1_range, alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri,
        prior_params, L, mode)

    with np.load(_checkpoint_path(args.data, name, cv), allow_pickle=True) as checkpoint:
        beta_vec = checkpoint['beta_vecs'][-1]
        pi_bar = checkpoint['pi_bars'][-1]
        kappa_vec = checkpoint['kappa_vec'][-1]
        pi_init = checkpoint['pi_init'][-1]
        lik_params = checkpoint['lik_params'][-1]
        alpha0, gamma0, rho0, rho1, alpha0_init = checkpoint['hyper'][-1]

    test_hammings = []
    test_logliks = []
    for i in range(3):
        print('i: ', i)
        (zt, wt, n_mat, num_1_vec, num_0_vec, n_ft, K, uniq, suff_stats, kappa_vec, m_mat, m_init,
         beta_vec, pi_bar, pi_init, lik_params, alpha0, alpha0_init, gamma0, rho0, rho1,
         posterior_grid) = gibbs_sample(x_train, pi_bar, kappa_vec, pi_init, L, lik_params, mode,
                                        init_suff_stats, n_cores, beta_vec, alpha0, alpha0_init, gamma0,
                                        alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri,
                                        rho0, rho1, v0_range, v1_range, v0_num_grid, v1_num_grid, p)
        # Full transition matrix: diag(kappa) + diag(1 - kappa) @ pi_bar.
        pi_all = (np.diag(kappa_vec) + np.matmul(np.diag(1 - kappa_vec), pi_bar))
        print('len x_test: ', len(x_test))
        loglik_test = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_test, n_cores, False,
                                                      return_mean=False)
        # The state mapping is learned on the training set and reused on test.
        _, state_mapper = evaluate_dist(z_true=z_train, z_pred=zt, z_lens=train_lens, k_max=L)

        zt_test, _, _, _, _, _, _, _, _ = sample_zw_fmp(x_test, pi_bar, kappa_vec, pi_init, L, lik_params,
                                                        mode, init_suff_stats, n_cores)

        test_hamming, _ = evaluate_dist(z_true=z_test, z_pred=zt_test, z_lens=test_lens, k_max=L,
                                        mapping=state_mapper)
        test_hammings.append(test_hamming)
        test_logliks.append(loglik_test)
    test_hammings = np.concatenate(test_hammings)
    test_logliks = np.concatenate(test_logliks)

    # Raw strings: '\p' is not a valid escape sequence (SyntaxWarning); the value is unchanged.
    print('Test Hamming: ', round(np.mean(test_hammings), 3), r' $\pm$ ', round(np.std(test_hammings), 2))
    print('Test Loglik: ', round(np.mean(test_logliks), 3), r' $\pm$ ', round(np.std(test_logliks), 2))


def main(args, data_load_config, name, cv):
    """Train (or, with ``--eval``, re-evaluate) the Gibbs sampler for one fold.

    Args:
        args: parsed ``argparse`` namespace (see the ``__main__`` block).
        data_load_config: per-dataset options forwarded to ``load_data``.
        name: run name; outputs are written under ``./<args.data>/<name>/``.
        cv: fold index. It selects the random seed (0-8) and suffixes output files.

    Evaluation mode: with ``args.eval`` set, re-evaluates the saved checkpoint
    and returns.

    Training mode: runs ``args.iters`` Gibbs sweeps.
        - If ``./<args.data>/<name>/checkpoint_<cv>.npz`` exists, resumes after
          the iteration it records.
        - Every 10 iterations, logs train/validation/test log-likelihoods and
          Hamming scores and stores a posterior sample.
        - Every 50 iterations (every 10 for ``har``), saves the checkpoint and
          plots.
        - Stops early once 24 hours have elapsed, then saves a final checkpoint.

    Raises:
        ValueError: if ``args.init_way`` is not ``'prior'`` (the only implemented
            initialisation).
    """
    import copy  # used below; not imported explicitly at file level

    start_t = datetime.datetime.now()

    iters = args.iters  # number of iterations
    n_cores = 1
    init_way = args.init_way  # 'prior' (implemented) or 'hmm' (not implemented here)
    L = int(args.k_max)  # in general, twice the number of states is good enough
    args_dict = vars(args)
    print('args dict: ', args_dict)

    x_train, z_train, train_lens, x_valid, z_valid, valid_lens, x_test, z_test, test_lens = load_data(
        args.data, data_load_config, normalize=False, pad_ragged=False, path_to_data='../../../data/')
    if args.eval:
        _evaluate_checkpoint(args, name, cv, L, n_cores, x_train, z_train, train_lens, x_test, z_test, test_lens)
        return

    seed = [111, 222, 333, 444, 555, 666, 777, 888, 999][cv]
    np.random.seed(seed)

    p = args.p
    v0_range = (args.v0_range0, args.v0_range1)
    v1_range = (args.v1_range0, args.v1_range1)
    alpha0_a_pri, alpha0_b_pri = (float(args.alpha0_a_pri), float(args.alpha0_b_pri))
    gamma0_a_pri, gamma0_b_pri = (float(args.gamma0_a_pri), float(args.gamma0_b_pri))
    v0_num_grid = 4
    v1_num_grid = 4

    prior_params = _build_prior_params(args_dict, x_train)
    print('prior_params: ', prior_params)

    print('Min value in train data: ', min([vec.min() for vec in x_train]))
    print('Min value in test data: ', min([vec.min() for vec in x_test]))
    print('Max value in train data: ', max([vec.max() for vec in x_train]))
    print('Max value in test data: ', max([vec.max() for vec in x_test]))

    mode = 'ar'

    # Posterior samples, stored every 10 iterations.
    zt_sample = []
    hyperparam_sample = []
    pi_mat_sample = []
    pi_init_sample = []
    uniq_sample = []
    lik_params_sample = []
    beta_vecs = []
    kappa_vec_sample = []
    post_sample = []
    pi_bars = []

    # Metric histories, stored every 10 iterations.
    train_hammings = []
    train_logps = []
    validation_hammings = []
    validation_logps = []
    test_logps = []
    test_hammings = []

    if init_way != 'prior':
        raise ValueError("Unsupported init_way %r: only 'prior' is implemented" % (init_way,))
    rho0, rho1, alpha0, alpha0_init, gamma0, init_suff_stats, L, beta_vec, kappa_vec, pi_bar, pi_init, lik_params = \
        init_gibbs_full_bayesian(p, v0_range, v1_range, alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri,
                                 prior_params, L, mode)

    run_dir = './%s/%s' % (args.data, name)
    os.makedirs(run_dir, exist_ok=True)  # np.savez does not create missing directories
    ckpt_path = _checkpoint_path(args.data, name, cv)

    zt_test = None  # test-set predictions from the latest logged iteration
    if os.path.isfile(ckpt_path):  # if a checkpoint exists (same path as the saves below)
        print('Checkpoint found! Loading now..')
        with np.load(ckpt_path, allow_pickle=True) as f:
            # 'it' is the last completed iteration, so continue with the next one.
            start_it = int(f['it']) + 1
            pi_bar = f['pi_bars'][-1]
            kappa_vec = f['kappa_vec'][-1]
            pi_init = f['pi_init'][-1]
            lik_params = f['lik_params'][-1]
            alpha0, gamma0, rho0, rho1, alpha0_init = f['hyper'][-1]
            if 'beta_vecs' in f.files:  # absent in checkpoints written by older versions
                beta_vec = f['beta_vecs'][-1]

            # Restore as lists so new entries can be appended and later saves stay cumulative.
            train_hammings = _npz_list(f, 'train_hammings')
            train_logps = _npz_list(f, 'train_loglik')
            validation_hammings = _npz_list(f, 'validation_hammings')
            validation_logps = _npz_list(f, 'val_loglik')
            test_hammings = _npz_list(f, 'test_hammings')
            test_logps = _npz_list(f, 'loglik')

            kappa_vec_sample = _npz_list(f, 'kappa_vec')
            zt_sample = _npz_list(f, 'zt')
            pi_bars = _npz_list(f, 'pi_bars')
            hyperparam_sample = _npz_list(f, 'hyper')
            pi_mat_sample = _npz_list(f, 'pi_mat')
            pi_init_sample = _npz_list(f, 'pi_init')
            uniq_sample = _npz_list(f, 'uniq')
            lik_params_sample = _npz_list(f, 'lik_params')
            beta_vecs = _npz_list(f, 'beta_vecs')
            if 'last_zt_test' in f.files:
                zt_test = f['last_zt_test']

        print('Succesfully loaded checkpoint!')
    else:
        start_it = 0

    # Last fully completed iteration; recorded as 'it' in every checkpoint. It is
    # defined even when the loop below never runs.
    last_it = start_it - 1

    def _save_checkpoint():
        # Reads main's current bindings at call time (closure), so it always
        # writes the latest samples, histories, last_it and zt_test.
        np.savez(ckpt_path, kappa_vec=kappa_vec_sample, zt=zt_sample, pi_bars=pi_bars, hyper=hyperparam_sample,
                 loglik=test_logps, pi_mat=pi_mat_sample, pi_init=pi_init_sample, uniq=uniq_sample,
                 lik_params=lik_params_sample, it=last_it, train_loglik=train_logps,
                 train_hammings=train_hammings, val_loglik=validation_logps,
                 validation_hammings=validation_hammings, test_hammings=test_hammings,
                 last_zt_test=zt_test, beta_vecs=beta_vecs)

    for it in range(start_it, iters):
        print("Iteration: ", it)
        curr_t = datetime.datetime.now()
        elapsed = curr_t - start_t
        print('Hours since start: ', round(elapsed.days * 24 + elapsed.seconds / 3600, 2))
        if elapsed.days >= 1:  # if the job has been running for >= 24 hrs
            print('Stopping Early! 24 HR Max Reached!')
            break
        print('Right before call to sample_zw_fmp')
        # One full sweep; gibbs_sample makes the same calls in the same order as
        # the previously inlined code.
        (zt, wt, n_mat, num_1_vec, num_0_vec, n_ft, K, uniq, suff_stats, kappa_vec, m_mat, m_init,
         beta_vec, pi_bar, pi_init, lik_params, alpha0, alpha0_init, gamma0, rho0, rho1,
         posterior_grid) = gibbs_sample(x_train, pi_bar, kappa_vec, pi_init, L, lik_params, mode,
                                        init_suff_stats, n_cores, beta_vec, alpha0, alpha0_init, gamma0,
                                        alpha0_a_pri, alpha0_b_pri, gamma0_a_pri, gamma0_b_pri,
                                        rho0, rho1, v0_range, v1_range, v0_num_grid, v1_num_grid, p)
        last_it = it

        ## compute loglik and metrics
        if it % 10 == 0:
            # Full transition matrix: diag(kappa) + diag(1 - kappa) @ pi_bar.
            pi_all = (np.diag(kappa_vec) + np.matmul(np.diag(1 - kappa_vec), pi_bar))

            verbose = False
            loglik_test_arr = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_test, n_cores,
                                                              verbose, return_mean=False, old_way=True)
            loglik_test_new_arr = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_test, n_cores,
                                                                  verbose, return_mean=False, old_way=False)
            loglik_train = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_train, n_cores,
                                                           verbose, return_mean=True, old_way=True)
            loglik_val = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_valid, n_cores,
                                                         verbose, return_mean=True, old_way=True)
            loglik_val_new = compute_log_marginal_lik_ar_fmp(L, pi_all, pi_init, suff_stats, x_valid, n_cores,
                                                             verbose, return_mean=True, old_way=False)
            print('loglik_test_arr mean, std, min, max: ', loglik_test_arr.mean(), loglik_test_arr.std(),
                  loglik_test_arr.min(), loglik_test_arr.max())
            print('loglik_train: ', loglik_train)
            print('loglik_val: ', loglik_val)
            print('loglik_test_new_arr mean, std, min, max: ', loglik_test_new_arr.mean(),
                  loglik_test_new_arr.std(), loglik_test_new_arr.min(), loglik_test_new_arr.max())
            print('loglik_val_new: ', loglik_val_new)

            print('=' * 20)
            print('Learned sigma mats')
            for s in range(L):
                print('State: ', s)
                print(lik_params['sigma_mat_post'][s])
                print()
            print('=' * 20)

            loglik_test = loglik_test_arr.mean()

            test_logps.append(loglik_test)
            train_logps.append(loglik_train)
            validation_logps.append(loglik_val)

            print('loglik_test: ', loglik_test)
            print('loglik_train: ', loglik_train)
            print('loglik_val: ', loglik_val)
            z_train = np.array(z_train)
            train_hamming, state_mapper = evaluate_dist(z_true=z_train, z_pred=zt, z_lens=train_lens, k_max=L)
            train_hammings.append(train_hamming)
            print('Iteration: ', it, ' Train LL: ', loglik_train, ' Train hamming: ', train_hamming)

            post_sample.append(posterior_grid)
            pi_mat_sample.append(copy.deepcopy(pi_all))
            pi_init_sample.append(copy.deepcopy(pi_init))
            uniq_sample.append(copy.deepcopy(uniq))
            beta_vecs.append(copy.deepcopy(beta_vec))

            zt_sample.append(copy.deepcopy(zt))
            kappa_vec_sample.append(kappa_vec.copy())
            hyperparam_sample.append(np.array([alpha0, gamma0, rho0, rho1, alpha0_init]))
            lik_params_sample.append(copy.deepcopy(lik_params))
            pi_bars.append(pi_bar)

            # Test-set predictions, scored with the state mapping learned on train.
            zt_test, _, _, _, _, _, _, _, _ = sample_zw_fmp(x_test, pi_bar, kappa_vec, pi_init, L, lik_params,
                                                            mode, init_suff_stats, n_cores)
            test_hamming, _ = evaluate_dist(z_true=z_test, z_pred=zt_test, z_lens=test_lens, k_max=L,
                                            mapping=state_mapper)
            test_hamming = np.mean(test_hamming)
            test_hammings.append(test_hamming)
            print('hamming on test set: ', test_hamming)

            # Validation-set predictions.
            zt_val, _, _, _, _, _, _, _, _ = sample_zw_fmp(x_valid, pi_bar, kappa_vec, pi_init, L, lik_params,
                                                           mode, init_suff_stats, n_cores)
            val_hamming, _ = evaluate_dist(z_true=z_valid, z_pred=zt_val, z_lens=valid_lens, k_max=L,
                                           mapping=state_mapper)
            val_hamming = np.mean(val_hamming)
            print('hamming on validation set: ', val_hamming)
            validation_hammings.append(val_hamming)

        # Every such iteration is also a multiple of 10, so zt_test and
        # state_mapper were refreshed in the block above.
        if it % 50 == 0 or (it % 10 == 0 and args.data == "har"):
            _save_checkpoint()
            print('Checkpoint saved!')

            colors = list(sns.color_palette('Set3', L + 1))
            plot_state_predictions_test_samples(n_test_samples=min(10, len(x_test)), test_x=x_test, test_z=z_test,
                                                test_lens=test_lens, test_preds=zt_test, state_mapper=state_mapper,
                                                colors=colors,
                                                save_path="./%s/%s/IND_ts_sample_inference_%d.pdf" % (args.data, name, cv))

            # pie charts of predicted vs. ground-truth state distribution
            ordered_states = np.argsort(beta_vec)[::-1]
            os.makedirs("./%s/%s/plots/" % (args.data, name), exist_ok=True)

            gt_state_count = [np.count_nonzero(np.concatenate([z_train[i] for i in range(len(z_train))]) == z_s)
                              for z_s in range(len(np.unique(np.concatenate([z_train[i] for i in range(len(z_train))]))))]
            fig, axs = plt.subplots(1, 3, figsize=(15, 5))
            axs[0].pie(np.sort(beta_vec)[::-1], colors=[colors[ii_state] for ii_state in ordered_states])
            axs[0].set_title("Predicted class distribution")

            axs[1].pie(np.sort(gt_state_count)[::-1],
                       colors=[colors[state_mapper[ii_state]] for ii_state in np.argsort(gt_state_count)[::-1]])
            axs[1].set_title("Ground truth class distribution mapped with hungarian")

            plt.tight_layout()
            plt.savefig("./%s/%s/plots/pie_chart_iter%d_cv%d.pdf" % (args.data, name, it, cv))
            plt.close()

            plt.figure()
            sns.heatmap(pi_bar)
            plt.savefig("./%s/%s/plots/transition_heatmap_iter%d_cv%d.pdf" % (args.data, name, it, cv))
            plt.close()  # previously leaked one open figure per checkpoint

            line_plot(train_logps, "./%s/%s/train_log_like_%d.pdf" % (args.data, name, cv))
            print('Final Test Logp: ', test_logps[-1])

            line_plot(validation_logps, "./%s/%s/val_log_like_%d.pdf" % (args.data, name, cv))
            line_plot(test_logps, "./%s/%s/test_log_like_%d.pdf" % (args.data, name, cv))
            line_plot(train_hammings, "./%s/%s/train_hamming_%d.pdf" % (args.data, name, cv))
            line_plot(validation_hammings, "./%s/%s/val_hamming_%d.pdf" % (args.data, name, cv))

            os.makedirs('./%s/%s/generated_samples_cv%d/' % (args.data, name, cv), exist_ok=True)
            plot_generated_sample(L, lik_params['a_mat_post'], lik_params['sigma_mat_post'], beta_vec,
                                  './%s/%s/generated_samples_cv%d/gen_sample_%d.pdf' % (args.data, name, cv, it),
                                  x_train_min=min([xx.min() for xx in x_train]),
                                  x_train_max=max([xx.max() for xx in x_train]),
                                  x_train_max_len=max([len(x_i) for x_i in x_train]), colors=colors)

    ## save results
    # NOTE: after an early stop, the restorable state is the last *logged*
    # sample (at most 9 iterations before last_it).
    _save_checkpoint()

    if train_hammings:
        print('Final Train hamming: ', train_hammings[-1])
    
    
    


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run S HDP HMM')
    parser.add_argument('--data', type=str, default='sim_hard')

    # (flag, type) pairs, registered in exactly this order so the parsed
    # namespace (and the printed "args dict") keeps the same key order.
    cli_options = (
        # run configuration
        ('--ds_factor', int),
        ('--iters', int),
        ('--init_way', str),
        ('--k_max', int),
        # concentration hyperprior parameters (parsed as str, converted with float() in main)
        ('--alpha0_a_pri', str),
        ('--alpha0_b_pri', str),
        ('--gamma0_a_pri', str),
        ('--gamma0_b_pri', str),
        # prior matrix construction: base type and scaling factor
        ('--S0_type', str),
        ('--S0_factor', float),
        ('--M0_type', str),
        ('--M0_factor', float),
        ('--V0_type', str),
        ('--V0_factor', float),
        # settings forwarded to init_gibbs_full_bayesian / sample_rho
        ('--p', int),
        ('--v0_range0', float),
        ('--v0_range1', float),
        ('--v1_range0', float),
        ('--v1_range1', float),
        # run bookkeeping
        ('--short_name', str),
        ('--cv', int),
    )
    for flag, flag_type in cli_options:
        parser.add_argument(flag, type=flag_type)
    parser.add_argument('--eval', action='store_true')
    args = parser.parse_args()

    # Per-dataset options forwarded to load_data.
    data_load_config = dict(
        sim_easy={'n_train': 100, 'n_valid': 50, 'n_test': 50},
        sim_hard={'n_train': 100, 'n_valid': 50, 'n_test': 50},
        sim_semi_markov={'n_train': 60, 'n_valid': 20, 'n_test': 20},
        har={'ds_factor': args.ds_factor},
        har_70={'ds_factor': args.ds_factor},
    )

    main(args, data_load_config, args.short_name, args.cv)
