## how to run this file
## nohup python run_full_bayesian_gibbs_poisson.py --data sim_hard --seed 1 --iters 20 &
## --data selects the dataset, --seed (1-based) picks one of the fixed random seeds, --iters is the number of Gibbs iterations; see the argparse set-up at the bottom of this file
## note that when computing the predictive likelihood on test data, the start point is assumed to be given. see more details in ../../code/util.py

## load packages
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as ss
import sys
sys.path.append('../../code/')
from gibbs_poisson import *
from util import *
import argparse

sys.path.append('../../../../')
from utils import evaluate_acc

def _standardize_shift_clip(x):
    """Z-score every feature of ``x``, shift by +2 and zero out negatives.

    The mean and standard deviation are computed per feature (last axis)
    over all other axes of ``x`` itself.  A new array is returned; the
    input is not modified.
    """
    flat = x.reshape((-1, x.shape[-1]))
    x = (x - flat.mean(0)) / flat.std(0)
    # Shift and clip so the values are non-negative -- presumably because
    # the observations are modelled as Poisson counts downstream.
    x += 2
    x[x < 0] = 0
    return x


def _load_sim_dataset(name, n_train=500):
    """Load the simulated dataset ``name`` and split/preprocess it.

    Reads ``<name>_x.npy`` and ``<name>_z.npy`` from the shared data
    directory, uses the first ``n_train`` sequences for training and the
    remainder for testing, and preprocesses the observations of each split
    with ``_standardize_shift_clip``.

    Returns:
        ``(x_train, z_train, x_test, z_test)``.
    """
    data_dir = '../../../.././data/'
    x_all = np.load(data_dir + name + '_x.npy')
    z_all = np.load(data_dir + name + '_z.npy')

    # NOTE(review): the test split is standardised with its *own* mean/std
    # rather than the training statistics.  This mirrors the original
    # behaviour; confirm it is intended before comparing against other runs.
    x_train = _standardize_shift_clip(x_all[:n_train])
    x_test = _standardize_shift_clip(x_all[n_train:])
    return x_train, z_all[:n_train], x_test, z_all[n_train:]


def _mean_log_marginal_lik(sequences, K, z_start, pi_mat, lam_a_pri,
                           lam_b_pri, ysum, ycnt):
    """Average ``compute_log_marginal_lik_poisson`` over ``sequences``.

    Only the second value returned by ``compute_log_marginal_lik_poisson``
    (the log-likelihood) is used.  ``z_start`` is passed as the given start
    state for every sequence.
    """
    logliks = [
        compute_log_marginal_lik_poisson(
            K, yt, z_start, pi_mat, lam_a_pri, lam_b_pri, ysum, ycnt)[1]
        for yt in sequences
    ]
    return np.mean(logliks)


def main(args):
    """Run the full-Bayesian Gibbs sampler on every training sequence.

    For each training sequence an independent Gibbs chain is run for
    ``args.iters`` iterations.  Every 10th iteration the average train
    log-likelihood is recorded; after the chain finishes, the segmentation
    accuracy on that sequence and the average test log-likelihood are
    recorded.  Summary statistics are printed and the train log-likelihood
    curve is saved to ``./training_log_like_<data>.pdf``.

    Args:
        args: namespace with attributes
            ``data``  -- dataset name; ``"sim_hard"`` and ``"sim_easy"`` are
                         supported,
            ``seed``  -- 1-based index into a fixed list of 10 random seeds
                         (wraps modulo 10),
            ``iters`` -- number of Gibbs iterations per training sequence.

    Raises:
        ValueError: if ``seed``/``iters`` are missing, ``iters < 1``, or the
            dataset name is unknown.
        NotImplementedError: for the ``har``, ``har_new`` and
            ``har_processed`` datasets, whose loading code was never wired
            up to produce train/test splits for this sampler.
    """
    if args.seed is None or args.iters is None:
        raise ValueError("--seed and --iters must both be provided")
    iters = args.iters  # number of Gibbs iterations
    if iters < 1:
        raise ValueError("--iters must be at least 1")

    seed_vec = [111, 222, 333, 444, 555, 666, 777, 888, 999, 1000]
    seed = int(args.seed - 1) % 10  # index of the random seed to use
    np.random.seed(seed_vec[seed])  # fix randomness

    if args.data in ("sim_hard", "sim_easy"):
        x_train, z_train, x_test, _ = _load_sim_dataset(args.data)
    elif args.data in ("har", "har_new", "har_processed"):
        # These datasets were previously guarded by an `assert` (stripped
        # under -O) or crashed on undefined names; fail explicitly instead.
        raise NotImplementedError(
            "Dataset '%s' is not supported by this script yet" % args.data)
    else:
        raise ValueError("Dataset doesn't exist")

    # Prior / grid settings (identical for all supported datasets).
    p = 3
    v0_range = [0.01, 0.99]
    v1_range = [0.01, 2]  # [0.001, 10] if p=2
    alpha0_a_pri = 1
    alpha0_b_pri = 0.01
    gamma0_a_pri = 2
    gamma0_b_pri = 1
    v0_num_grid = 30
    v1_num_grid = 30

    # Hyper-prior on the rate parameter of the gamma prior over Poisson
    # firing rates; can be adjusted based on the data.
    lam_b_hyper_pri_shape = 1
    lam_b_hyper_pri_rate = 1

    accs = []
    pred_vecs = []  # predicted zt vectors for each sample
    real_vecs = []  # real zt vectors for each sample
    log_likes = {}  # iteration -> list of avg train log-likelihoods
    test_log_likes = []

    for i, (yt_real, z_real) in enumerate(zip(x_train, z_train)):
        T = len(yt_real)
        # Gamma prior shape for the Poisson firing rate, one per feature.
        lam_a_pri = np.ones(len(yt_real[0]))

        for it in range(iters):
            if it == 0:
                # Summary of the sampler state (from the original author's
                # notes):
                #   rho0, rho1: params of the prior on the kappas, with
                #     rho0 = v0 / v1**p and rho1 = (1 - v0) / v1**p where
                #     v0, v1 are drawn uniformly from v0_range / v1_range.
                #   alpha0: drawn from gamma(alpha0_a_pri, 1/alpha0_b_pri).
                #   gamma0: drawn from gamma(gamma0_a_pri, 1/gamma0_b_pri).
                #   lam_b_pri: (n_features,) vector, each entry iid from
                #     gamma(lam_b_hyper_pri_shape, 1/lam_b_hyper_pri_rate).
                #   K: number of states discovered so far (initialised to 1).
                #   zt: (T,) vector of the state at each time step
                #     (initialised to zeros).
                #   wt: (T,) binary vector, 1 where there is a self
                #     transition at that point (based on kappa), else 0.
                #   beta_vec: prior distribution over states; a scalar in
                #     [0, 1] on initialisation, later a (K,) vector.
                #   beta_new: the remaining mass, i.e. 1 - beta_vec on
                #     initialisation.
                #   kappa_vec: (K,) vector of self-transition probabilities.
                #   kappa_new: kappa for the next state to be added.
                #   n_mat: seems to be a square matrix whose (i, j) entry
                #     counts transitions from state i to j, excluding self
                #     transitions.
                #   ysum: (K, n_features); entry (i, j) is the sum of
                #     feature j over the time steps assigned to state i.
                #   ycnt: (K, n_features); row i is the number of time
                #     steps assigned to state i, repeated per feature.
                (rho0, rho1, alpha0, gamma0, lam_a_pri, lam_b_pri, K, zt, wt,
                 beta_vec, beta_new, kappa_vec, kappa_new, n_mat, ysum,
                 ycnt) = init_gibbs_full_bayesian(
                    p, v0_range, v1_range, alpha0_a_pri, alpha0_b_pri,
                    gamma0_a_pri, gamma0_b_pri, lam_a_pri,
                    lam_b_hyper_pri_shape, lam_b_hyper_pri_rate, T, yt_real)
            else:
                (zt, wt, n_mat, ysum, ycnt, beta_vec, kappa_vec, beta_new,
                 kappa_new, K) = sample_zw(
                    zt, wt, yt_real, n_mat, ysum, ycnt, beta_vec, beta_new,
                    kappa_vec, kappa_new, alpha0, gamma0, lam_a_pri,
                    lam_b_pri, rho0, rho1, K)

            # Updates K to be len(set(zt)) and adjusts the other counting
            # variables / matrices accordingly.
            zt, n_mat, ysum, ycnt, beta_vec, K = decre_K(
                zt, n_mat, ysum, ycnt, beta_vec)

            # num_1_vec: (K,) vector, j-th value is the number of self
            #   transitions for state j.
            # num_0_vec: (K,) vector, j-th value is the number of non-self
            #   transitions for state j.
            kappa_vec, kappa_new, num_1_vec, num_0_vec = sample_kappa(
                zt, wt, rho0, rho1, K)

            # m_mat is a (K, K) matrix; its exact meaning is not visible
            # here -- TODO confirm against gibbs_poisson.sample_m.
            m_mat = sample_m(n_mat, beta_vec, alpha0, K)
            beta_vec, beta_new = sample_beta(m_mat, gamma0)

            # Sample hyperparameters.
            alpha0 = sample_alpha(
                m_mat, n_mat, alpha0, alpha0_a_pri, alpha0_b_pri)
            gamma0 = sample_gamma(
                K, m_mat, gamma0, gamma0_a_pri, gamma0_b_pri)
            rho0, rho1, _posterior_grid = sample_rho(
                v0_range, v1_range, v0_num_grid, v1_num_grid, K, num_1_vec,
                num_0_vec, p)
            lam_mat = sample_lam_mat(lam_a_pri, lam_b_pri, ysum, ycnt)
            lam_b_pri = sample_lam_b_pri(
                lam_b_hyper_pri_shape, lam_b_hyper_pri_rate, lam_a_pri,
                lam_mat, K)

            # Every 10th iteration: record the average log-likelihood over
            # all training sequences under the current chain state.
            if it % 10 == 0:
                pi_mat = sample_pi_our(
                    K, alpha0, beta_vec, beta_new, n_mat, kappa_vec,
                    kappa_new)
                avg_train_ll = _mean_log_marginal_lik(
                    x_train, K, zt[-1], pi_mat, lam_a_pri, lam_b_pri, ysum,
                    ycnt)
                print('LL on train data after training on sample %d after %d iterations: '%(i, it), avg_train_ll)
                log_likes.setdefault(it, []).append(avg_train_ll)

        # pi_mat is only refreshed on every 10th iteration.  If the final
        # iteration was not one of those, it no longer matches the final
        # K / ysum / ycnt, so resample it from the final state before it is
        # used for the test log-likelihood.
        if (iters - 1) % 10 != 0:
            pi_mat = sample_pi_our(
                K, alpha0, beta_vec, beta_new, n_mat, kappa_vec, kappa_new)

        accs.append(evaluate_acc(z_true=[z_real], z_pred=[zt]))
        pred_vecs.append(zt)
        real_vecs.append(z_real)

        # Average log-likelihood of the test data under the model fitted to
        # this particular training sequence.
        test_log_likes.append(_mean_log_marginal_lik(
            x_test, K, zt[-1], pi_mat, lam_a_pri, lam_b_pri, ysum, ycnt))

    print('log_likes keys: ', log_likes.keys())
    print('final train accuracy mean: ', np.mean(accs), ' and std: ', np.std(accs))

    train_log_likes_over_iters = [np.mean(v) for v in log_likes.values()]
    print('Train Log likelihood over iterations: ', train_log_likes_over_iters)
    print('Avg Test Log Likelihood on last iteration: ', np.mean(test_log_likes))

    plt.figure()
    plt.plot(np.array(train_log_likes_over_iters))
    plt.savefig("./training_log_like_%s.pdf" % args.data)
    plt.close()

    print('Train Accuracy over all samples: ')
    lens = np.array([len(a) for a in real_vecs])
    if np.any(lens != lens[0]):
        # Ragged sequences cannot be stacked into a regular array.
        print(evaluate_acc(z_true=np.array(real_vecs, dtype='object'), z_pred=np.array(pred_vecs, dtype='object')))
    else:
        print(evaluate_acc(z_true=np.array(real_vecs), z_pred=np.array(pred_vecs)))
    
            
        
    
        




if __name__ == "__main__":
    # Command-line entry point: parse the run configuration and pass it to
    # main().  --seed and --iters are required because main() performs
    # arithmetic on them and cannot work with a missing (None) value.
    parser = argparse.ArgumentParser(description='Run DS HDP HMM Poisson')
    parser.add_argument('--data', type=str, default='sim_hard',
                        help='dataset name understood by main(), '
                             'e.g. sim_hard or sim_easy')
    parser.add_argument('--seed', type=int, required=True,
                        help='1-based index into the fixed list of random '
                             'seeds (wraps modulo 10)')
    parser.add_argument('--iters', type=int, required=True,
                        help='number of Gibbs iterations per training '
                             'sequence')
    args = parser.parse_args()
    main(args)