import numpy as np
import pymc3 as pm
import os, pickle, argparse
import scipy.io
from scipy.special import softmax

def recover_docs(tokens, counts):
    """Expand bag-of-words documents into flat arrays of repeated token ids.

    ``tokens`` and ``counts`` are parallel sequences with one entry per
    document; each entry is itself a pair of parallel sequences giving a
    token and how many times it occurs. The result is a list with one
    array per document in which every token appears ``count`` times, in
    the order the tokens were listed.
    """
    return [
        np.concatenate(
            [np.repeat(token, count) for token, count in zip(doc_tokens, doc_counts)]
        )
        for doc_tokens, doc_counts in zip(tokens, counts)
    ]


def mcmc_PURE(topics, documents, n_docs=200):
    """Score every document against every topic under a single-topic model.

    For each document the score of topic ``k`` is the product of
    ``topics[k, word]`` over the document's word ids. Each row is then
    rescaled by ``K / (row_sum + 1e-5)``.

    Parameters
    ----------
    topics : array of shape (K, V)
        One row per topic; columns are indexed by word id.
    documents : iterable of (doc, anything) pairs
        ``doc`` is an iterable of word ids; the second element of each
        pair is ignored. The iterable is consumed once.
    n_docs : int, optional
        Number of rows in the returned matrix (default 200, the value
        that used to be hard-coded).

    Returns
    -------
    numpy.ndarray of shape (n_docs, K)
        Row ``i`` holds the rescaled scores of document ``i``. Rows
        beyond the number of supplied documents stay at 1.

    Raises
    ------
    IndexError
        If ``documents`` yields more than ``n_docs`` entries.
    """
    K, _ = topics.shape
    p_mat = np.ones((n_docs, K))

    for row, (doc, _) in enumerate(documents):
        # NOTE(review): a plain running product can underflow to 0 for long
        # documents; the epsilon below then leaves the row at ~0 instead of
        # a normalised distribution. Confirm whether log-space is wanted.
        for ind in doc:
            p_mat[row, :] = p_mat[row, :] * topics[:, ind]

        # NOTE(review): this rescales the row to sum to roughly K, not 1.
        # Left unchanged because downstream results depend on it.
        p_mat[row, :] = p_mat[row, :] * (K / (np.sum(p_mat[row, :]) + 0.00001))

    return p_mat

def mcmc_LDA(A, document):
    """Estimate a document's topic weights under a Dirichlet prior.

    The weight vector ``w`` gets a Dirichlet prior whose concentration is
    ``1/K`` in every component, and the entries of ``document`` are
    modelled as categorical draws with probabilities ``w.dot(A)``.

    Parameters
    ----------
    A : array of shape (K, V)
        Topic matrix; row ``k`` is multiplied by weight ``w[k]``.
    document : array-like
        Observed values passed to the categorical likelihood.

    Returns
    -------
    numpy.ndarray
        ``w`` averaged over axis 0 (the draws) of the single sampled chain.
    """
    n_topics, _ = A.shape
    concentration = np.ones(n_topics) / n_topics

    with pm.Model():
        w = pm.Dirichlet('w', a=concentration)  # topic distribution vector
        word_probs = w.dot(A)
        pm.Categorical('Y_obs', p=word_probs, observed=document)
        idata = pm.sample(draws=2000, tune=3000, chains=1, return_inferencedata=True, target_accept=0.9)

    chain_draws = idata.posterior['w'].sel(chain=0)
    return chain_draws.mean(axis=0).data


# Covariance matrix read once at import time from the current working
# directory; mcmc_CTM passes it as ``cov`` of its MvNormal prior.
sigma = np.load('cov_matrix.npy')

def mcmc_CTM(A, document):
    """Estimate a document's topic weights under a multivariate-normal prior.

    A latent vector ``w`` gets a zero-mean MvNormal prior with the
    module-level covariance ``sigma``. Topic weights are the softmax of
    ``w`` and the entries of ``document`` are modelled as categorical
    draws with probabilities ``softmax(w).dot(A)``.

    The latent dimension is taken from ``A`` instead of being fixed at
    20, so ``sigma`` must be a K x K matrix for the same K.

    Parameters
    ----------
    A : array of shape (K, V)
        Topic matrix; row ``k`` is multiplied by the k-th topic weight.
    document : array-like
        Observed values passed to the categorical likelihood.

    Returns
    -------
    numpy.ndarray
        Softmax of every posterior draw of ``w`` (single chain), averaged
        over the draws.
    """
    K, _ = A.shape
    m = pm.Model()
    with m:
        # Latent vector; its dimension follows the topic matrix.
        w = pm.MvNormal('w', mu=[0] * K, cov=sigma, shape=(K,))
        # Softmax written out explicitly: exp(w) / sum(exp(w)).
        topic_weights = pm.math.exp(w) / pm.math.sum(pm.math.exp(w), axis=0)
        pm.Categorical('Y_obs', p=topic_weights.dot(A), observed=document)
        idata = pm.sample(draws=2000, tune=3000, chains=1, return_inferencedata=True, target_accept=0.9)

    # Draws of chain 0 have one row per draw, so softmax runs along axis 1.
    draws = idata.posterior['w'].sel(chain=0).data
    return np.mean(softmax(draws, axis=1), axis=0)

def mcmc_PAM(A, document):
    """Estimate a document's topic weights under a two-level prior.

    ``s`` is a Dirichlet vector over ``K // 2`` groups and ``w`` is a
    Dirichlet vector of length 2. The K topic weights are the interleaved
    products ``[s[0]*w[0], s[0]*w[1], s[1]*w[0], s[1]*w[1], ...]`` and the
    entries of ``document`` are modelled as categorical draws with
    probabilities given by those weights times ``A``.

    Parameters
    ----------
    A : array of shape (K, V)
        Topic matrix. The construction yields ``2 * (K // 2)`` weights,
        so K is expected to be even.
    document : array-like
        Observed values passed to the categorical likelihood.

    Returns
    -------
    numpy.ndarray of length ``2 * (K // 2)``
        The interleaved products averaged over the draws of the single
        sampled chain.
    """
    K, _ = A.shape
    K_s = K // 2
    m = pm.Model()
    with m:
        s = pm.Dirichlet('s', a=(1 / K_s) * np.ones(K_s))
        w = pm.Dirichlet('w', a=np.ones(2) * 30)

        # Interleave s[j] * w[0] and s[j] * w[1] for every group j.
        s_t = []
        for j in range(K_s):
            s_t.append(s[j] * w[0])
            s_t.append(s[j] * w[1])

        pm.Categorical("Y_obs", p=pm.math.dot(s_t, A), observed=document)
        idata = pm.sample(3000, tune=2000, chains=1, return_inferencedata=True, target_accept=0.95)

    # Pull the draws out of the InferenceData once; looking them up inside a
    # per-draw loop repeats the same xarray selection thousands of times.
    s_draws = idata.posterior['s'].sel(chain=0).data  # one row per draw, K_s columns
    w_draws = idata.posterior['w'].sel(chain=0).data  # one row per draw, 2 columns

    # Broadcasting gives, for every draw, a (K_s, 2) table of s[j] * w[k];
    # flattening it row-major reproduces the interleaved order used above.
    per_draw = (s_draws[:, :, None] * w_draws[:, None, :]).reshape(len(s_draws), 2 * K_s)
    return per_draw.mean(axis=0)

# Topic matrices for each document-generating model, loaded at import time
# from sibling source directories. infer() indexes each collection with
# 2 * i for i in 0..4.
A_ctm = np.load('../src_CTM/TopicMatrices.npy')
A_pam = np.load('../src_PAM/TopicMatrices.npy')

# NOTE: pickle.load can execute arbitrary code; only use trusted files here.
with open('../src_LDA/TopicsMatrix.pkl', 'rb') as f:
    A_lda = pickle.load(f)

with open('../src_pure/TopicsMatrix.pkl', 'rb') as f:
    A_pure = pickle.load(f)


def get_documents(doc_type, alpha):
    """Load the test documents and stored posteriors for one setting.

    Reads ``bow_test_tokens`` / ``bow_test_counts`` (MATLAB files) and
    ``MCMC_Posterior_200.pkl`` from ``testdocs/<doc_type>/alpha<alpha>``.

    Returns a one-shot ``zip`` iterator pairing each of the first 200
    expanded documents with the corresponding entry of the pickled
    posterior collection.
    """
    data_path = 'testdocs/' + doc_type + '/alpha%d' % alpha

    def load_field(file_stem, key):
        # Read one variable from a .mat file and drop singleton dimensions.
        return scipy.io.loadmat(os.path.join(data_path, file_stem))[key].squeeze()

    tokens = load_field('bow_test_tokens', 'tokens')
    counts = load_field('bow_test_counts', 'counts')

    documents = recover_docs(tokens, counts)[:200]

    posterior_file = os.path.join(data_path, 'MCMC_Posterior_200.pkl')
    with open(posterior_file, 'rb') as f:
        posteriors = list(pickle.load(f))

    return zip(documents, posteriors)


# input argument = pure, lda, ctm, pam
def _pure_assumption_tv(topics, documents, alpha, drop_nan_rows):
    """Run mcmc_PURE and return (posterior matrix, averaged distance to the reference)."""
    posterior = mcmc_PURE(topics, documents)

    # NOTE(review): the reference posterior is always read from testdocs/pure,
    # whichever doc_type produced `documents` -- confirm this is intended.
    with open('testdocs/pure/alpha%d/MCMC_Posterior_200.pkl' % alpha, 'rb') as f:
        truth = pickle.load(f)

    if drop_nan_rows:
        # Ignore documents whose posterior row contains a NaN.
        keep = ~np.isnan(posterior).any(axis=1)
        diff = posterior[keep] - truth[keep]
    else:
        diff = posterior - truth

    # Half the summed absolute difference, divided by the 200-document count.
    # NOTE(review): the divisor stays 200 even when NaN rows were dropped.
    return posterior, np.sum(np.abs(diff)) / (2 * 200)


def infer(doc_type):
    """Run posterior inference on one family of test documents.

    For alpha in 1, 3, 5, 7, 9 the documents generated by ``doc_type`` are
    analysed under every *other* model assumption:

    * For 'lda', 'ctm' and 'pam' documents the pure-topic scorer is run
      on the whole batch and its averaged distance to the reference
      posterior is stored under the 'pure' key of every record.
    * Each document is then sampled with the remaining MCMC models
      (mcmc_LDA / mcmc_CTM / mcmc_PAM) and half the L1 distance between
      the estimate and the stored posterior is recorded under the model
      name.

    Per alpha, the posterior matrices are written as ``.npy`` files and
    the list of per-document records is pickled, all under ``output/``.
    File names are kept exactly as before, including the differing
    patterns between document types.

    Parameters
    ----------
    doc_type : str
        One of 'pure', 'lda', 'ctm', 'pam'.

    Raises
    ------
    ValueError
        If ``doc_type`` is not one of the four supported values.
    """
    # doc_type -> (topic matrices, assumed models in sampling order,
    #              posterior .npy name template, results pickle name template)
    layouts = {
        'pure': (A_pure, ('lda', 'ctm', 'pam'), 'output/pure_%s_%d.npy', 'output/pure_output_%d.pkl'),
        'lda': (A_lda, ('ctm', 'pam'), 'output/lda_%s_%d.npy', 'output/lda_output_%d.pkl'),
        'ctm': (A_ctm, ('lda', 'pam'), 'output/ctm_%s_alpha%d.npy', 'output/ctm_output_%d.pkl'),
        'pam': (A_pam, ('ctm', 'lda'), 'output/pam_%s_alpha%d.npy', 'output/PAM_output_%d.pkl'),
    }
    if doc_type not in layouts:
        raise ValueError('doc_type must be one of pure, lda, ctm, pam; got %r' % (doc_type,))

    topic_sets, assumed_models, npy_template, results_template = layouts[doc_type]
    samplers = {'lda': mcmc_LDA, 'ctm': mcmc_CTM, 'pam': mcmc_PAM}

    # Pure documents are never scored under the pure assumption themselves.
    uses_pure_step = doc_type != 'pure'

    # get_documents returns one-shot iterators: this batch feeds the pure
    # scorer, and a fresh one is loaded below for the per-document sampling.
    document_list = []
    if uses_pure_step:
        for alpha in range(1, 10, 2):
            document_list.append(get_documents(doc_type, alpha))

    for i in range(5):
        alpha = 2 * i + 1
        topics = topic_sets[2 * i]
        posteriors = {name: np.zeros((200, 20)) for name in assumed_models}

        if uses_pure_step:
            # Only the 'pam' documents filter NaN rows before averaging.
            pure_posterior, avg_pure_tv = _pure_assumption_tv(
                topics, document_list[i], alpha, drop_nan_rows=(doc_type == 'pam')
            )

        doc_results = []
        for ind, (doc, true_post) in enumerate(get_documents(doc_type, alpha)):
            new_record = {}
            if uses_pure_step:
                new_record['pure'] = avg_pure_tv

            for name in assumed_models:
                posterior = samplers[name](topics, doc)
                posteriors[name][ind, :] = posterior
                # Half the L1 distance to the stored posterior.
                new_record[name] = np.sum(np.abs(posterior - true_post)) / 2

            doc_results.append(new_record)

        # Save. The pure matrix written here is the one actually computed;
        # the 'pam' case previously wrote an untouched all-zero array.
        if uses_pure_step:
            np.save(npy_template % ('pure', alpha), pure_posterior)
        for name in assumed_models:
            np.save(npy_template % (name, alpha), posteriors[name])

        with open(results_template % alpha, 'wb') as f:
            pickle.dump(doc_results, f)

if __name__ == '__main__':
    # Command-line entry point: the single positional argument selects which
    # family of test documents infer() processes.
    parser = argparse.ArgumentParser()
    # The argument is a model name, so it must be parsed as a string; parsing
    # it as an int rejected every documented value.
    parser.add_argument('model_type', type=str, choices=['pure', 'lda', 'ctm', 'pam'],
                        help='one of pure, lda, ctm, pam')

    input_args = parser.parse_args()
    model_type = input_args.model_type

    infer(model_type)