import pymc3 as pm
import os, pickle, scipy.io, torch
import numpy as np
from torch.utils.data import DataLoader
from models_datasets import AttnCTMDataset, generate_attn_batch
from scipy.special import softmax
from scipy.optimize import minimize, LinearConstraint

# Mini-batch size used when running the neural model over the test documents.
BATCH_SIZE = 128
# Run the model on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def true_posterior(topics, documents, two_targets=False):
    """Compute the exact per-document topic posterior.

    For each document the result is the normalised product over its words of
    ``topics[k, word]``. This is the posterior over a single topic ``k`` under
    a uniform prior when every word is drawn from that one topic. The product
    is evaluated in log space and normalised with a softmax, so long documents
    do not underflow to 0/0 = NaN.

    Args:
        topics: (K, V) array indexed as ``topics[k, word_id]`` (presumably the
            word distribution of each topic).
        documents: sequence of word-id sequences, one per document.
        two_targets: if True, return ``kron(p, p)`` for each document's
            posterior ``p`` instead of ``p`` itself.

    Returns:
        (len(documents), K) array of posteriors, or (len(documents), K*K)
        when ``two_targets`` is True.
    """
    topics = np.asarray(topics, dtype=float)
    K = topics.shape[0]
    with np.errstate(divide='ignore'):
        # Zero probabilities become -inf, i.e. those topics get zero posterior mass.
        log_topics = np.log(topics)

    p_mat = np.empty((len(documents), K))
    for i, doc in enumerate(documents):
        word_ids = np.asarray(doc, dtype=np.intp)
        # Sum of log-probabilities == log of the product; softmax normalises stably.
        p_mat[i] = softmax(log_topics[:, word_ids].sum(axis=1))

    if two_targets:
        # Row i of the flattened outer product equals np.kron(p_mat[i], p_mat[i]).
        n_docs = p_mat.shape[0]
        return np.einsum('ij,ik->ijk', p_mat, p_mat).reshape(n_docs, K * K)
    return p_mat

def mcmc_posterior_wp(prior_type, A, document):
    """Estimate a document's posterior mean topic proportions by MCMC.

    One chain is drawn with ``pm.sample`` (3000 tuning + 2000 kept draws)
    under one of three priors on the topic proportions:

    * ``'lda'``: symmetric Dirichlet(1/K) on ``w``.
    * ``'ctm'``: ``w ~ MvNormal(0, sigma)``, with ``sigma`` loaded from
      ``cov_matrix.npy`` in the current working directory. The proportions
      are ``softmax(w)``.
    * ``'pam'``: Dirichlet(1/K_s) over ``K_s = K // 2`` super-topics ``s``.
      Each super-topic is split into two sub-topics by a shared
      Dirichlet(30, 30) weight ``w``.

    Args:
        prior_type: ``'lda'``, ``'ctm'`` or ``'pam'``.
        A: (K, V) topic matrix; ``proportions.dot(A)`` is used as the
            categorical word distribution.
        document: sequence of observed word ids.

    Returns:
        Length-K array with the posterior mean of the topic proportions.
        Returns None only if ``prior_type`` is invalid and assertions are
        disabled.

    Raises:
        AssertionError: ``prior_type`` is not a supported prior.
        ValueError: ``prior_type == 'pam'`` and K is odd.
    """
    assert prior_type in ['lda', 'ctm', 'pam'], f'Invalid argument prior_type={str(prior_type)}'

    K, _ = A.shape
    # Sampler settings shared by all priors; only target_accept differs per prior.
    sample_kwargs = dict(draws=2000, tune=3000, chains=1, return_inferencedata=True)

    if prior_type == 'lda':
        with pm.Model():
            w = pm.Dirichlet('w', a=np.ones(K) / K)  # topic distribution vector
            pm.Categorical('Y_obs', p=w.dot(A), observed=document)
            idata = pm.sample(target_accept=0.9, **sample_kwargs)
        return idata.posterior['w'].sel(chain=0).mean(axis=0).data

    if prior_type == 'ctm':
        sigma = np.load('cov_matrix.npy')
        with pm.Model():
            w = pm.MvNormal('w', mu=[0] * K, cov=sigma, shape=(K,))  # unnormalised log-proportions
            proportions = pm.math.exp(w) / pm.math.sum(pm.math.exp(w), axis=0)
            pm.Categorical('Y_obs', p=proportions.dot(A), observed=document)
            idata = pm.sample(target_accept=0.9, **sample_kwargs)
        # Map every draw onto the simplex before averaging.
        return np.mean(softmax(idata.posterior['w'].sel(chain=0).data, axis=1), axis=0)

    if prior_type == 'pam':
        if K % 2:
            raise ValueError(f'PAM prior needs an even number of topics, got K={K}')
        K_s = K // 2
        with pm.Model():
            s = pm.Dirichlet('s', a=(1 / K_s) * np.ones(K_s))
            w = pm.Dirichlet('w', a=np.ones(2) * 30)

            # Sub-topic proportions interleaved as [s_0*w_0, s_0*w_1, s_1*w_0, ...].
            s_t = []
            for j in range(K_s):
                s_t.append(s[j] * w[0])
                s_t.append(s[j] * w[1])

            pm.Categorical("Y_obs", p=pm.math.dot(s_t, A), observed=document)
            idata = pm.sample(target_accept=0.95, **sample_kwargs)

        posterior = idata.posterior.sel(chain=0)
        s_draws = np.asarray(posterior['s'].data)  # one row per kept draw, K_s columns
        w_draws = np.asarray(posterior['w'].data)  # one row per kept draw, 2 columns
        # Same interleaving as s_t: column 2j+c holds s_j * w_c.
        topic_draws = (s_draws[:, :, None] * w_draws[:, None, :]).reshape(s_draws.shape[0], K)
        return topic_draws.mean(axis=0)

    # Only reachable when assertions are disabled (python -O).
    return None


def vi_posterior(vi_type, A, document, two_targets=False):
    """Approximate a document's posterior mean topic proportions with ADVI.

    A mean-field approximation is fitted with ``pm.fit(method='advi')`` under
    one of three priors on the topic proportions:

    * ``'lda'``: symmetric Dirichlet(1/K) on ``w``.
    * ``'ctm'``: ``w ~ MvNormal(0, sigma)``, with ``sigma`` loaded from
      ``cov_matrix.npy`` in the current working directory. The proportions
      are ``softmax(w)``.
    * ``'pam'``: Dirichlet(1/K_s) over ``K_s = K // 2`` super-topics ``s``.
      Each super-topic is split into two sub-topics by a shared
      Dirichlet(30, 30) weight ``w``.

    2000 draws are taken from the fitted approximation, mapped to topic
    proportions, and averaged.

    Args:
        vi_type: ``'lda'``, ``'ctm'`` or ``'pam'``.
        A: (K, V) topic matrix; ``proportions.dot(A)`` is used as the
            categorical word distribution.
        document: sequence of observed word ids.
        two_targets: if True, average ``kron(theta, theta)`` over the draws
            instead of ``theta``.

    Returns:
        Length-K array (length K*K when ``two_targets``). Returns None only if
        ``vi_type`` is invalid and assertions are disabled.

    Raises:
        AssertionError: ``vi_type`` is not a supported prior.
        ValueError: ``vi_type == 'pam'`` and K is odd.
    """
    assert vi_type in ['lda', 'ctm', 'pam'], f'Invalid argument vi_type={str(vi_type)}'

    def summarise(draws):
        # draws: one row of topic proportions per approximation draw.
        if two_targets:
            # Flattened row i of the outer-product stack == np.kron(draws[i], draws[i]).
            pair_draws = np.einsum('ni,nj->nij', draws, draws).reshape(draws.shape[0], -1)
            return pair_draws.mean(axis=0)
        return draws.mean(axis=0)

    K, _ = A.shape

    if vi_type == 'lda':
        with pm.Model():
            w = pm.Dirichlet('w', a=np.ones(K) / K)  # topic distribution vector
            pm.Categorical('Y_obs', p=w.dot(A), observed=document)
            mean_field = pm.fit(method='advi')
        # Dirichlet draws already lie on the simplex, so no softmax is applied
        # (softmax would pull them toward uniform).
        return summarise(np.asarray(mean_field.sample(2000)['w']))

    if vi_type == 'ctm':
        sigma = np.load('cov_matrix.npy')
        with pm.Model():
            w = pm.MvNormal('w', mu=[0] * K, cov=sigma, shape=(K,))  # unnormalised log-proportions
            proportions = pm.math.exp(w) / pm.math.sum(pm.math.exp(w), axis=0)
            pm.Categorical('Y_obs', p=proportions.dot(A), observed=document)
            mean_field = pm.fit(method='advi')
        # Map the unconstrained draws onto the simplex.
        return summarise(softmax(mean_field.sample(2000)['w'], axis=1))

    if vi_type == 'pam':
        if K % 2:
            raise ValueError(f'PAM prior needs an even number of topics, got K={K}')
        K_s = K // 2
        with pm.Model():
            s = pm.Dirichlet('s', a=(1 / K_s) * np.ones(K_s))
            w = pm.Dirichlet('w', a=np.ones(2) * 30)

            # Sub-topic proportions interleaved as [s_0*w_0, s_0*w_1, s_1*w_0, ...].
            s_t = []
            for j in range(K_s):
                s_t.append(s[j] * w[0])
                s_t.append(s[j] * w[1])

            pm.Categorical("Y_obs", p=pm.math.dot(s_t, A), observed=document)
            mean_field = pm.fit(method='advi')

        estimates = mean_field.sample(2000)
        s_draws = np.asarray(estimates['s'])
        w_draws = np.asarray(estimates['w'])
        # Same interleaving as s_t: column 2j+c holds s_j * w_c.
        topic_draws = (s_draws[:, :, None] * w_draws[:, None, :]).reshape(s_draws.shape[0], K)
        return summarise(topic_draws)

    # Only reachable when assertions are disabled (python -O).
    return None


def recover_docs(tokens, counts):
    """Expand bag-of-words documents into flat sequences of word ids.

    Args:
        tokens: iterable with one entry per document holding its word ids.
        counts: iterable aligned with ``tokens`` holding how often each id occurs.

    Returns:
        List of 1-D numpy arrays, one per document. In each array, every word
        id is repeated ``count`` times and ids keep their original order.
        A document with no tokens yields an empty array instead of raising.
    """
    # np.ravel accepts both (n,) and (1, n) per-document rows, e.g. as produced by loadmat.
    return [
        np.repeat(np.ravel(doc_tokens), np.ravel(doc_counts))
        for doc_tokens, doc_counts in zip(tokens, counts)
    ]


def _write_baseline_posteriors(path, estimate, documents, dim):
    """Evaluate ``estimate(doc)`` per document, row-normalise the results and pickle them to ``path``."""
    L = np.zeros((len(documents), dim))
    for i, doc in enumerate(documents):
        L[i, :] = estimate(doc)
    L = L / L.sum(axis=1)[:, np.newaxis]
    with open(path, 'wb') as f:
        pickle.dump(L, f)


def _clip_and_normalise_rows(Eta):
    """Zero out negative entries and rescale each row to sum to 1.

    A row with no positive mass becomes uniform instead of 0/0 = NaN. A NaN
    row would otherwise turn the aggregate TV distance into NaN.
    """
    Eta = np.maximum(Eta, 0)
    row_sums = Eta.sum(axis=1, keepdims=True)
    empty = row_sums[:, 0] == 0
    Eta[empty] = 1.0 / Eta.shape[1]
    row_sums[empty] = 1.0
    return Eta / row_sums


def construct_posteriors(data_path, model, N_test, constrained_opt=False, two_targets=False, re_calculate=False):
    """Estimate topic posteriors from a trained model and score them against the exact posterior.

    Pipeline:
      1. Load the topic matrix ``A`` from ``<data_path>/topics.pkl``. Load
         ``N_test`` test documents from ``bow_test_tokens_<N_test>`` and
         ``bow_test_counts_<N_test>``.
      2. Single-target mode only (``two_targets=False``): compute VI and MCMC
         posterior means under LDA, CTM and PAM priors for every document.
         Row-normalise them and pickle them into ``data_path`` as
         ``<prior>_prior_VI_Posterior_<N>.pkl`` or
         ``<prior>_prior_MCMC_Posterior_<N>.pkl``. Existing files are kept
         unless ``re_calculate`` is set.
      3. Compute the exact posterior with ``true_posterior``, cached as
         ``[Two_targets_]True_Posterior_<N>.pkl``.
      4. Run ``model.get_word_probability`` over the documents. Recover the
         posterior ``Eta`` with the pseudo-inverse of ``A.T`` (``kron(A, A)``
         in two-target mode), optionally refined by ``optimize_posterior``.
         Negative entries are clipped and rows renormalised.

    Args:
        data_path: directory holding the input files and the posterior caches.
        model: object exposing ``to``, ``eval`` and ``get_word_probability``
            (presumably a torch module).
        N_test: number of test documents; must match the token file.
        constrained_opt: refine ``Eta`` row by row with ``optimize_posterior``.
        two_targets: work with ``kron(p, p)`` posteriors instead of ``p``.
        re_calculate: ignore cached pickles and recompute them.

    Returns:
        Tuple ``(Eta, L, TV, accuracy)``:
          * ``Eta``: estimated posteriors.
          * ``L``: exact posteriors.
          * ``TV``: mean total-variation distance between their rows.
          * ``accuracy``: fraction of documents whose argmax topic agrees.
    """
    from functools import partial

    # NOTE(review): pickle.load can execute arbitrary code; data_path must be trusted.
    with open(os.path.join(data_path, "topics.pkl"), 'rb') as fname:
        A = pickle.load(fname)
    K, _ = A.shape

    token_file = os.path.join(data_path, 'bow_test_tokens_' + str(N_test))
    count_file = os.path.join(data_path, 'bow_test_counts_' + str(N_test))
    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()

    assert N_test == len(tokens)
    test_documents = recover_docs(tokens, counts)

    if not two_targets:
        # The VI / MCMC baselines only produce K-dimensional posterior means,
        # so they are computed in single-target mode only.
        for vi_type in ['lda', 'ctm', 'pam']:
            vi_path = os.path.join(data_path, '%s_prior_VI_Posterior_%d.pkl' % (vi_type, N_test))
            if re_calculate or not os.path.exists(vi_path):
                print(f'Running Variational Inference assuming {vi_type} to calculate E(w|x) for {N_test} test documents...')
                _write_baseline_posteriors(vi_path, partial(vi_posterior, vi_type, A), test_documents, K)

        for prior_type in ['lda', 'ctm', 'pam']:
            mcmc_path = os.path.join(data_path, '%s_prior_MCMC_Posterior_%d.pkl' % (prior_type, N_test))
            if re_calculate or not os.path.exists(mcmc_path):
                print(f'Running MCMC assuming {prior_type} prior to calculate E(w|x) for {N_test} test documents...')
                _write_baseline_posteriors(mcmc_path, partial(mcmc_posterior_wp, prior_type, A), test_documents, K)

    # Exact (not sampled) topic posterior, cached on disk.
    true_posterior_file_name = 'True_Posterior_%d.pkl' % N_test
    if two_targets:
        true_posterior_file_name = 'Two_targets_' + true_posterior_file_name
    true_posterior_path = os.path.join(data_path, true_posterior_file_name)

    if re_calculate or not os.path.exists(true_posterior_path):
        L = true_posterior(A, test_documents, two_targets=two_targets)
        with open(true_posterior_path, 'wb') as f:
            pickle.dump(L, f)
    else:
        print('Using existing posterior as true posterior')
        with open(true_posterior_path, 'rb') as f:
            L = pickle.load(f)

    # Word probabilities predicted by the ML model.
    test_dataset = AttnCTMDataset(test_documents, [0] * N_test)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                                 shuffle=False,
                                 collate_fn=generate_attn_batch)

    with torch.no_grad():
        all_preds = []
        model.to(device)
        model.eval()
        print("Making predictions")
        for text, _ in test_dataloader:
            predictions = model.get_word_probability(text.to(device)).squeeze(1).float()
            all_preds.append(predictions.cpu())

        print("Tensorizing result")
        # Convert once so all linear algebra below operates on numpy arrays.
        output = torch.cat(all_preds).view(len(test_dataset), -1).numpy()

    # Estimate E(posterior): least-squares solve of output ~= Eta @ A via pinv(A^T).
    if two_targets:
        A = np.kron(A, A)
    Eta = np.transpose(np.dot(np.linalg.pinv(np.transpose(A)), np.transpose(output)))

    if constrained_opt:
        Eta = optimize_posterior(output, A, Eta)
    Eta = _clip_and_normalise_rows(Eta)

    TV = np.sum(np.abs(Eta - L)) / (2.0 * N_test)

    predicted_MAP = np.argmax(Eta, axis=1)
    real_MAP = np.argmax(L, axis=1)
    accuracy = np.sum(predicted_MAP == real_MAP) / N_test

    return Eta, L, TV, accuracy


def optimize_posterior(predictions, A, Eta):
    """Refine posterior estimates row by row with a constrained L1 fit.

    For each row ``i`` this solves, starting from ``Eta[i]``::

        min_x  sum(|A^T x - predictions[i]|)
        s.t.   sum(x) == 1,  0 <= x_k <= 1

    using ``scipy.optimize.minimize``. With bounds and constraints and no
    explicit method, minimize uses SLSQP. The solution is kept only if it
    strictly lowers the L1 loss of the starting point.

    Args:
        predictions: (N, V) array of predicted word probabilities.
        A: (K, V) topic matrix (or its Kronecker square).
        Eta: (N, K) initial estimates; modified in place.

    Returns:
        ``Eta`` (the same array object) after refinement.
    """
    K = A.shape[0]
    A_t = np.transpose(A)
    # The simplex constraint and box bounds are the same for every row; build them once.
    simplex = LinearConstraint(np.ones((1, K)), lb=1, ub=1)
    bounds = [(0, 1)] * K

    for i, b in enumerate(predictions):
        def loss(x, b=b):
            return np.sum(np.abs(np.dot(A_t, x) - b))

        x0 = Eta[i]
        result = minimize(fun=loss, x0=x0, bounds=bounds, constraints=simplex)

        if loss(result.x) < loss(x0):
            Eta[i] = result.x
    return Eta
    
