import pymc3 as pm
import os, pickle, scipy.io, torch
import numpy as np
from torch.utils.data import DataLoader
from models_datasets import AttnCTMDataset, generate_attn_batch
from scipy.special import softmax
from scipy.optimize import minimize, LinearConstraint

# DataLoader batch size used when scoring test documents with the ML model.
BATCH_SIZE = 128
# Use CUDA when torch reports it is available; fall back to the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def posterior_mean_from_draws(s_draws, w_draws):
    """Average the interleaved products ``s[j] * w[c]`` over posterior draws.

    Args:
        s_draws: array of shape (n_draws, K_s) with samples of ``s``.
        w_draws: array of shape (n_draws, n_w) with samples of ``w``.

    Returns:
        1-D array of length ``K_s * n_w`` ordered
        ``[s0*w0, s0*w1, ..., s1*w0, s1*w1, ...]``, averaged over draws.
    """
    s_draws = np.asarray(s_draws, dtype=float)
    w_draws = np.asarray(w_draws, dtype=float)
    # (n, K_s, 1) * (n, 1, n_w) -> (n, K_s, n_w); the C-order reshape then
    # interleaves the w components per topic j, matching the model's s_t order.
    products = s_draws[:, :, np.newaxis] * w_draws[:, np.newaxis, :]
    return products.reshape(products.shape[0], -1).mean(axis=0)


def mcmc_posterior(A, document, *, draws=3000, tune=2000, target_accept=0.95):
    """Estimate the posterior mean topic weights of one document with MCMC.

    The K topics are modelled as K/2 pairs: ``s`` (Dirichlet over K/2 groups)
    is split by ``w`` (a 2-component Dirichlet) into weights
    ``[s0*w0, s0*w1, s1*w0, s1*w1, ...]``. Each observed token is drawn from a
    Categorical whose probabilities are those weights times ``A``.

    Args:
        A: topic matrix of shape (K, V); K must be even.
        document: sequence of observed token ids.
        draws, tune, target_accept: forwarded to ``pm.sample`` (single chain).

    Returns:
        Length-K numpy array: posterior mean of the topic weights.

    Raises:
        ValueError: if K is odd (the pairing above would drop a topic).
    """
    K, _ = A.shape
    if K % 2 != 0:
        raise ValueError('mcmc_posterior expects an even number of topics, got K=%d' % K)
    K_s = K // 2

    with pm.Model():
        s = pm.Dirichlet('s', a=(1 / K_s) * np.ones(K_s))
        w = pm.Dirichlet('w', a=np.ones(2) * 30)
        # Interleaved per topic group: s_j*w_0, s_j*w_1 for j = 0..K_s-1.
        s_t = [s[j] * w[c] for j in range(K_s) for c in range(2)]
        pm.Categorical("Y_obs", p=pm.math.dot(s_t, A), observed=document)
        idata = pm.sample(draws, tune=tune, chains=1, return_inferencedata=True,
                          target_accept=target_accept)

    # Pull each variable's draws out once instead of re-selecting per element.
    s_draws = idata.posterior['s'].sel(chain=0).data
    w_draws = idata.posterior['w'].sel(chain=0).data
    return posterior_mean_from_draws(s_draws, w_draws)

def recover_docs(tokens, counts):
    """Expand bag-of-words documents into flat token sequences.

    Each document is given as a pair of parallel arrays: token ids and how many
    times each one occurs. Either array may be 1-D or a MATLAB-style row
    vector; both are flattened before expansion.

    Args:
        tokens: iterable of per-document token-id arrays.
        counts: iterable of per-document count arrays, aligned with ``tokens``.

    Returns:
        List of 1-D numpy arrays, one per document, in which each token id is
        repeated ``count`` times. An empty document yields an empty array.
    """
    return [
        # Counts from .mat files may be stored as doubles; np.repeat requires integers.
        np.repeat(np.ravel(doc_tokens), np.ravel(doc_counts).astype(np.intp))
        for doc_tokens, doc_counts in zip(tokens, counts)
    ]

def construct_posteriors(data_path, model, N_test, constrained_opt=False,
                         batch_size=BATCH_SIZE, num_workers=12):
    """Compare a model's implied topic posterior against an MCMC reference.

    Steps:
      1. Load the topic matrix ``A`` from ``topics.pkl``.
      2. Load the N_test bag-of-words test documents.
      3. Load a cached MCMC posterior, or compute and cache one.
      4. Run ``model`` to get per-document word probabilities.
      5. Invert them through ``A`` to estimate topic weights ``Eta``.

    Args:
        data_path: directory holding ``topics.pkl``, ``bow_test_tokens_<N>``,
            ``bow_test_counts_<N>`` and the MCMC cache file.
        model: object exposing ``to``, ``eval`` and ``get_word_probability``.
        N_test: number of test documents (also part of the file names).
        constrained_opt: if True, refine ``Eta`` with ``optimize_posterior``.
        batch_size: DataLoader batch size for model inference.
        num_workers: DataLoader worker processes for model inference.

    Returns:
        Tuple ``(Eta, L, TV, accuracy)``:
        - ``Eta``: row-normalized model-implied posterior, shape (N_test, K).
        - ``L``: row-normalized MCMC posterior, shape (N_test, K).
        - ``TV``: mean total-variation distance between matching rows.
        - ``accuracy``: fraction of documents whose argmax topic agrees.
    """
    # NOTE(review): pickle.load can execute arbitrary code; data_path must be trusted.
    with open(os.path.join(data_path, "topics.pkl"), 'rb') as fname:
        A = pickle.load(fname)

    test_documents = _load_test_documents(data_path, N_test)
    L = _load_or_compute_mcmc_posterior(data_path, A, test_documents, N_test)
    output = _predict_word_probabilities(model, test_documents, batch_size, num_workers)

    # Least-squares inversion of output ≈ Eta @ A, since pinv(A^T)^T == pinv(A).
    # output is already a numpy array, so no torch/numpy interop is implied here.
    Eta = output @ np.linalg.pinv(A)

    if constrained_opt:
        Eta = optimize_posterior(output, A, Eta)

    Eta = Eta / Eta.sum(axis=1, keepdims=True)

    TV = np.sum(np.abs(Eta - L)) / (2.0 * N_test)

    predicted_MAP = np.argmax(Eta, axis=1)
    real_MAP = np.argmax(L, axis=1)
    accuracy = np.sum(predicted_MAP == real_MAP) / N_test

    return Eta, L, TV, accuracy


def _load_test_documents(data_path, N_test):
    """Load the N_test bag-of-words test documents as flat token arrays."""
    token_file = os.path.join(data_path, 'bow_test_tokens_' + str(N_test))
    count_file = os.path.join(data_path, 'bow_test_counts_' + str(N_test))

    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()

    assert N_test == len(tokens), (
        'expected %d test documents, found %d in %s' % (N_test, len(tokens), token_file))
    return recover_docs(tokens, counts)


def _load_or_compute_mcmc_posterior(data_path, A, test_documents, N_test):
    """Return the row-normalized MCMC posterior, using/refreshing the on-disk cache."""
    posterior_path = os.path.join(data_path, 'MCMC_Posterior_%d.pkl' % N_test)

    if os.path.isfile(posterior_path):
        print('Using existing MCMC posterior as true posterior')
        # NOTE(review): unpickling a cache file; only safe for trusted data_path.
        with open(posterior_path, 'rb') as f:
            return pickle.load(f)

    K = A.shape[0]
    L = np.zeros((N_test, K))  # one posterior row per document
    print('Running MCMC to calculate E(w|x) for %d test documents...' % N_test)
    for i in range(N_test):
        L[i, :] = mcmc_posterior(A, test_documents[i])

    L = L / L.sum(axis=1, keepdims=True)

    with open(posterior_path, 'wb') as f:
        pickle.dump(L, f)
    return L


def _predict_word_probabilities(model, test_documents, batch_size, num_workers):
    """Run ``model`` over all test documents; return a 2-D numpy array, one row per doc."""
    test_dataset = AttnCTMDataset(test_documents, [0] * len(test_documents))
    test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                                 num_workers=num_workers, pin_memory=True,
                                 shuffle=False,
                                 collate_fn=generate_attn_batch)

    with torch.no_grad():
        all_preds = []
        model.to(device)
        model.eval()
        print("Making predictions")
        for text, _ in test_dataloader:
            predictions = model.get_word_probability(text.to(device)).squeeze(1).float()
            all_preds.append(predictions.cpu())

        print("Tensorizing result")
        output = torch.cat(all_preds).view(len(test_dataset), -1)

    # Hand numpy to the linear algebra below instead of relying on numpy's
    # implicit handling of torch tensors.
    return output.numpy()

def optimize_posterior(predictions, A, Eta):
    """Refine each row of ``Eta`` to better fit ``predictions`` under simplex constraints.

    For every document ``i`` this minimizes the L1 residual
    ``sum(|A^T x - predictions[i]|)`` with SLSQP. The constraints are
    ``sum(x) == 1`` and ``0 <= x <= 1``, and the search starts from ``Eta[i]``.

    Args:
        predictions: 2-D array, one row of model word probabilities per document.
        A: topic matrix with K rows (K = number of topics).
        Eta: 2-D float array of initial topic weights, one row per document.

    Returns:
        ``Eta`` itself, updated in place.
    """
    K = A.shape[0]
    A_T = np.transpose(A)  # columns indexed by topic, so A_T @ x maps weights to words

    # Loop-invariant problem definition: probability simplex plus unit box.
    simplex = LinearConstraint(np.ones(K), lb=1, ub=1)
    bounds = [(0, 1)] * K

    for i, b in enumerate(predictions):
        def loss(x, b=b):
            return np.sum(np.abs(A_T @ x - b))

        # Copy so the baseline used in the comparison below cannot alias Eta's storage.
        x0 = Eta[i].copy()
        result = minimize(fun=loss, x0=x0, bounds=bounds, constraints=simplex)

        # NOTE(review): x0 (e.g. a pseudo-inverse estimate) need not be feasible, so
        # an infeasible start with a smaller residual is kept over the constrained
        # optimum. Confirm this is the intended acceptance rule.
        if loss(result.x) < loss(x0):
            Eta[i] = result.x
    return Eta

