import pymc3 as pm
import os, pickle, scipy.io, torch
import numpy as np
from torch.utils.data import DataLoader
from models_datasets import PredictWordDataset, generate_batch

# Batch size passed to the DataLoader that feeds test documents to the model.
BATCH_SIZE = 512
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def mcmc_posterior(A, document):
    """Estimate the posterior mean of the topic weights of one document by MCMC.

    The model places a Dirichlet prior with concentration ``1/K`` per
    component on the weight vector ``w`` and treats every observed token as a
    Categorical draw with probabilities ``w.dot(A)``.

    Args:
        A: 2-D array whose first dimension ``K`` is the length of ``w``
            (presumably a topics-by-vocabulary matrix -- confirm with caller).
        document: Observed values handed to the Categorical likelihood.

    Returns:
        The mean over the draws of chain 0 of the sampled ``w`` values.
    """
    num_topics, _ = A.shape
    prior_concentration = np.ones(num_topics) / num_topics

    with pm.Model():
        # topic distribution vector
        w = pm.Dirichlet('w', a=prior_concentration)
        pm.Categorical('Y_obs', p=w.dot(A), observed=document)
        # A single chain: 3000 tuning steps followed by 2000 kept draws.
        idata = pm.sample(
            draws=2000,
            tune=3000,
            chains=1,
            target_accept=0.9,
            return_inferencedata=True,
        )

    w_draws = idata.posterior['w'].sel(chain=0)
    return w_draws.mean(axis=0).data

def recover_docs(tokens, counts):
    """Expand bag-of-words (token, count) pairs into flat token sequences.

    Each document is rebuilt by repeating every token id as many times as its
    count, keeping the tokens in their original order.

    Args:
        tokens: Iterable with one array-like of token ids per document.
        counts: Iterable with one array-like of integer counts per document,
            aligned element-wise with ``tokens``.

    Returns:
        A list with one 1-D NumPy array per document. A document without any
        tokens yields an empty array instead of raising.

    Raises:
        ValueError: If a document's tokens and counts differ in length.
    """
    docs = []
    for doc_tokens, doc_counts in zip(tokens, counts):
        # Flatten first so that 1-D entries and (1, n)-shaped entries are
        # handled the same way.
        flat_tokens = np.ravel(doc_tokens)
        flat_counts = np.ravel(doc_counts)
        if flat_tokens.size != flat_counts.size:
            raise ValueError(
                'tokens and counts of a document differ in length: %d vs %d'
                % (flat_tokens.size, flat_counts.size)
            )
        # One vectorised repeat replaces a per-token repeat + concatenate and
        # also works for zero-length documents.
        docs.append(np.repeat(flat_tokens, flat_counts))
    return docs

def _load_or_compute_mcmc_posterior(data_path, A, tokens, counts, N_test):
    """Return the row-normalised MCMC posterior means, cached in ``data_path``."""
    K, _ = A.shape
    posterior_path = os.path.join(data_path, 'MCMC_Posterior_%d.pkl' % N_test)

    if os.path.isfile(posterior_path):
        print('Using existing MCMC posterior as true posterior')
        # NOTE(security): pickle.load executes arbitrary code from the file;
        # only point data_path at trusted directories.
        with open(posterior_path, 'rb') as f:
            return pickle.load(f)

    # No cached result yet: run MCMC once per test document.
    test_documents = recover_docs(tokens, counts)
    if len(test_documents) != N_test:
        raise ValueError(
            'expected %d test documents but recovered %d'
            % (N_test, len(test_documents))
        )

    L = np.zeros((N_test, K))  # one posterior-mean row per document
    print('Running MCMC to calculate E(w|x) for %d test documents...' % N_test)
    for i, doc in enumerate(test_documents):
        L[i, :] = mcmc_posterior(A, doc)

    # Re-normalise so every row sums to one.
    L = L / L.sum(axis=1, keepdims=True)

    with open(posterior_path, 'wb') as f:
        pickle.dump(L, f)
    return L


def _predict_word_probabilities(model, tokens, counts, N_test):
    """Run ``model`` over the test set and return its outputs, one row per document."""
    # The labels are unused here, so a dummy 0 is supplied for each document.
    test_dataset = PredictWordDataset(tokens, counts, [0] * N_test)
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, num_workers=12, pin_memory=True,
                                 shuffle=False,
                                 collate_fn=generate_batch)

    with torch.no_grad():
        all_preds = []
        model.to(device)
        model.eval()
        print("Making predictions")
        for text, offset, _ in test_dataloader:
            predictions = model.get_word_probability(text.to(device), offset.to(device)).squeeze(1).float()
            all_preds.append(predictions.cpu())

        print("Tensorizing result")
        output = torch.cat(all_preds).view(len(test_dataset), -1)

    # Convert explicitly rather than relying on NumPy coercing a torch tensor.
    return output.numpy()


def construct_posteriors(data_path, model, N_test):
    """Compare the model-implied topic posterior with an MCMC reference.

    Loads ``topics.pkl`` and the bag-of-words test files from ``data_path``,
    obtains the MCMC posterior means (computed once and cached as
    ``MCMC_Posterior_<N_test>.pkl`` in ``data_path``), and maps the model's
    word-probability outputs through the pseudo-inverse of the transposed
    topic matrix to obtain the model's estimate.

    Args:
        data_path: Directory containing ``topics.pkl`` and the
            ``bow_test_tokens_<N_test>`` / ``bow_test_counts_<N_test>`` files.
        model: Module providing ``get_word_probability(text, offset)``; it is
            moved to ``device`` and put in eval mode.
        N_test: Number of test documents; must equal the number of token
            entries loaded.

    Returns:
        A tuple ``(Eta, L, TV, accuracy)``: the row-normalised model estimate,
        the row-normalised MCMC posterior means, the summed absolute
        difference between them divided by ``2 * N_test``, and the fraction
        of documents whose arg-max component agrees.

    Raises:
        AssertionError: If ``N_test`` differs from the number of loaded
            token entries.
        ValueError: If fewer than ``N_test`` documents can be recovered when
            the MCMC posterior has to be computed.
    """
    # NOTE(security): pickle.load executes arbitrary code from the file; only
    # point data_path at trusted directories.
    with open(os.path.join(data_path, "topics.pkl"), 'rb') as fname:
        A = pickle.load(fname)

    token_file = os.path.join(data_path, 'bow_test_tokens_' + str(N_test))
    count_file = os.path.join(data_path, 'bow_test_counts_' + str(N_test))

    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()

    assert N_test == len(tokens), 'N_test does not match the number of test documents'

    # Reference posterior from MCMC (loaded from cache when available).
    L = _load_or_compute_mcmc_posterior(data_path, A, tokens, counts, N_test)

    # Posterior implied by the ML model's predictions.
    word_probs = _predict_word_probabilities(model, tokens, counts, N_test)

    # Estimate E(posterior): solve for the weights through pinv(A^T).
    Eta = np.transpose(np.dot(np.linalg.pinv(np.transpose(A)),
                              np.transpose(word_probs)))
    Eta = Eta / Eta.sum(axis=1, keepdims=True)

    TV = np.sum(np.abs(Eta - L)) / (2.0 * N_test)

    predicted_MAP = np.argmax(Eta, axis=1)
    real_MAP = np.argmax(L, axis=1)
    accuracy = np.sum(predicted_MAP == real_MAP) / N_test

    return Eta, L, TV, accuracy
