import numpy as np
import math, os, pickle, scipy.io
import torch
from torch.utils.data import DataLoader
from models_datasets import PredictWordDataset, generate_batch
import sklearn.metrics, scipy.stats


# Batch size used by the prediction DataLoader in construct_posteriors.
BATCH_SIZE = 2048
# Use the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def expand(tokens, counts):
    """Expand bag-of-words documents into shuffled sequences of token ids.

    Document ``i`` is described by ``tokens[i]`` (token ids) and ``counts[i]``
    (occurrences of each id). Every id is repeated by its count, the repeats
    are concatenated, and the result is shuffled with
    ``np.random.permutation``. The order therefore depends on NumPy's global
    RNG state.

    Args:
        tokens: iterable of per-document token-id arrays (anything that
            squeezes to 0-d or 1-d, e.g. the (1, n) rows read by ``loadmat``).
        counts: iterable of per-document count arrays aligned with ``tokens``.

    Returns:
        A lazy iterator yielding one 1-D ndarray per document. A document
        with no tokens yields an empty array instead of raising.
    """
    def _as_list(x):
        # squeeze() turns a single-entry row into a 0-d array whose tolist()
        # is a bare scalar; wrap it so every document becomes a list.
        x = np.asarray(x).squeeze()
        return [x.tolist()] if x.ndim == 0 else x.tolist()

    def _expand_one(doc_tokens, doc_counts):
        pieces = [np.repeat(tok, cnt) for tok, cnt in zip(doc_tokens, doc_counts)]
        # np.concatenate([]) raises ValueError, so empty documents need a guard.
        flat = np.concatenate(pieces) if pieces else np.array([], dtype=int)
        return np.random.permutation(flat)

    return (_expand_one(_as_list(t), _as_list(c)) for t, c in zip(tokens, counts))


def translate(doc, id2word):
    """Look up every token id of ``doc`` in ``id2word`` and cast each result to int.

    Returns a new list with one entry per element of ``doc``.
    """
    translated = []
    for token_id in doc:
        translated.append(int(id2word[token_id]))
    return translated


def likelihood_matrix(topics, id2word, tokens, counts):
    """Compute per-document topic likelihoods, each row rescaled to sum to K.

    For document d and topic k the unnormalised value is
    ``prod over word occurrences w in d of topics[k, w]``. Each row is then
    rescaled so that its entries sum to K, the number of rows of ``topics``.

    Args:
        topics: 2-D array-like of shape (K, V). Its entries are assumed to be
            non-negative, since they are used as per-topic word probabilities.
        id2word: lookup passed to ``translate`` to turn token ids into column
            indices of ``topics``.
        tokens, counts: per-document token-id / count arrays, as accepted by
            ``expand``.

    Returns:
        ndarray of shape (n_docs, K).
    """
    topics = np.asarray(topics)
    K, _ = topics.shape
    # Work in log space. Multiplying one probability per word occurrence
    # underflows to 0.0 for long documents, and the normalisation would then
    # produce NaN. log(0) -> -inf is intended and maps back to weight 0.
    with np.errstate(divide="ignore"):
        log_topics = np.log(topics)

    ## Expand documents
    documents = [translate(doc, id2word) for doc in expand(tokens, counts)]

    p_mat = np.ones((len(documents), K))
    for i, doc in enumerate(documents):
        # Sum of log-probabilities over all word occurrences, per topic.
        log_lik = log_topics[:, doc].sum(axis=1)
        # Shift so the largest entry becomes exp(0) = 1 before exponentiating.
        row = np.exp(log_lik - log_lik.max())
        p_mat[i, :] = row * (K / np.sum(row))

    return p_mat


def likelihood_vector(tok, cou, id2word, topics):
    """Return the normalised topic likelihoods of a single document.

    The unnormalised value for topic k is
    ``prod over word occurrences w of topics[k, w]``. The returned vector is
    scaled to sum to 1.

    Args:
        tok, cou: token-id and count arrays of one document, as accepted
            (per document) by ``expand``.
        id2word: lookup passed to ``translate`` to map token ids to columns
            of ``topics``.
        topics: 2-D array-like of shape (K, V). Its entries are assumed to be
            non-negative per-topic word probabilities.

    Returns:
        1-D ndarray of length K whose entries sum to 1.
    """
    ## Expand the document
    doc = translate(next(iter(expand([tok], [cou]))), id2word)
    # Log space avoids the product underflowing to 0.0 (which made the final
    # division return NaN). Only the columns this document uses are logged.
    # log(0) -> -inf is intended and maps back to weight 0.
    with np.errstate(divide="ignore"):
        log_lik = np.log(np.asarray(topics)[:, doc]).sum(axis=1)
    # Shift so the largest entry becomes exp(0) = 1 before exponentiating.
    p_vec = np.exp(log_lik - log_lik.max())
    return p_vec / np.sum(p_vec)


def _load_pickle(path):
    """Unpickle and return the object stored at ``path``."""
    # NOTE(review): pickle can execute arbitrary code while loading; only
    # point data_path at trusted files.
    with open(path, 'rb') as fname:
        return pickle.load(fname)


def _predict_word_probabilities(model, tokens, counts, doc_topics, batch_size, num_workers):
    """Run ``model.get_word_probability`` over all documents; return a (n_docs, -1) CPU tensor."""
    dataset = PredictWordDataset(tokens, counts, doc_topics)
    dataloader = DataLoader(dataset=dataset, batch_size=batch_size, num_workers=num_workers,
                            # Pinned host memory only helps host->GPU copies.
                            pin_memory=(device.type == "cuda"), shuffle=False,
                            collate_fn=generate_batch)
    all_preds = []
    with torch.no_grad():
        model.to(device)
        model.eval()
        print("Making predictions")
        for text, offset, _ in dataloader:
            predictions = model.get_word_probability(text.to(device), offset.to(device)).squeeze(1).float()
            all_preds.append(predictions.cpu())

        print("Tensorizing result")
        output = torch.cat(all_preds)
    return output.view(len(dataset), -1)


def _estimate_posterior(topics, word_probs):
    """Recover per-document topic weights from word probabilities via the pseudo-inverse; rows sum to 1."""
    # For each document n, pinv gives the least-squares solution of
    # topics.T @ eta_n ~= word_probs[n]. Nothing here constrains the entries
    # to be non-negative.
    eta = np.dot(np.linalg.pinv(np.transpose(topics)), np.transpose(word_probs)).T
    return eta / eta.sum(axis=1)[:, np.newaxis]


def construct_posteriors(data_path, model, batch_size=BATCH_SIZE, num_workers=12):
    """Compare model-derived topic posteriors with likelihood-based ones.

    Loads ``id2word.pkl``, ``topics.pkl`` and ``train_topics.pkl``, plus the
    ``bow_train_tokens`` / ``bow_train_counts`` .mat files, from ``data_path``.
    It then builds two matrices, both with rows normalised to sum to 1:

    * ``L_data`` holds the per-document topic likelihoods from
      ``likelihood_matrix``.
    * ``Eta`` holds the per-document topic weights recovered from the model's
      predicted word probabilities with the pseudo-inverse of the topic matrix.

    Args:
        data_path: directory containing the files listed above.
        model: object providing ``to``, ``eval`` and
            ``get_word_probability(text, offset)``.
        batch_size: DataLoader batch size (defaults to ``BATCH_SIZE``).
        num_workers: number of DataLoader worker processes.

    Returns:
        A tuple ``(Eta, L_data, TV, accuracy)``:

        * ``TV`` is the mean total-variation distance between corresponding
          rows of ``Eta`` and ``L_data``.
        * ``accuracy`` is the fraction of documents whose argmax over ``Eta``
          equals the label stored in ``train_topics.pkl``.
    """
    id2word = _load_pickle(os.path.join(data_path, "id2word.pkl"))
    topics = _load_pickle(os.path.join(data_path, "topics.pkl"))  # topic matrix
    # Generating-topic index of each training document.
    doc_topics = _load_pickle(os.path.join(data_path, "train_topics.pkl"))

    tokens = scipy.io.loadmat(os.path.join(data_path, 'bow_train_tokens'))['tokens'].squeeze()
    counts = scipy.io.loadmat(os.path.join(data_path, 'bow_train_counts'))['counts'].squeeze()

    # likelihood_matrix scales rows to sum to K; renormalise them to sum to 1.
    L_data = likelihood_matrix(topics, id2word, tokens, counts)
    L_data = L_data / L_data.sum(axis=1)[:, np.newaxis]

    # Convert to numpy explicitly instead of relying on implicit tensor
    # conversion inside np.transpose / np.dot.
    output = _predict_word_probabilities(model, tokens, counts, doc_topics,
                                         batch_size, num_workers).numpy()
    Eta = _estimate_posterior(topics, output)

    N = len(tokens)
    predicted_classes = np.argmax(Eta, axis=1)
    # Flatten the labels so that an (N, 1) label array compares element-wise
    # instead of broadcasting to (N, N).
    labels = np.asarray(doc_topics).reshape(-1)
    accuracy = np.sum(predicted_classes == labels) / N

    # Half the L1 distance per row, averaged over documents.
    TV = np.sum(np.abs(Eta - L_data)) / (2.0 * N)

    return Eta, L_data, TV, accuracy
