import torch, argparse, os, pickle
from posteriors import recover_docs
import scipy.io
import numpy as np
from torch.utils.data import DataLoader
from models_datasets import AttnCTMDataset, generate_attn_batch, AttentionModel

# Batch size used by the test-set DataLoader.
BATCH_SIZE = 128
# Run on CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def top_k_ovelap(y_true, y_pred, k):
    """Return the mean fraction of shared top-``k`` indices between matching rows.

    For each row ``i``, take the ``k`` column indices with the largest values in
    ``y_true[i]`` and in ``y_pred[i]``. That row's overlap is
    ``|intersection| / k``, and the result is the mean over all rows. The name
    keeps its historical spelling so existing callers keep working.

    Args:
        y_true: 2-D array-like of shape (N, C) holding reference scores.
        y_pred: 2-D array-like with the same shape holding predicted scores.
        k: number of top entries compared per row; valid range is 1 <= k <= C.

    Returns:
        The mean overlap (a float in [0, 1]), or ``nan`` when there are no rows.
        If ``k`` is out of range, a message is printed and ``-1`` is returned,
        which preserves the original error convention.

    Raises:
        ValueError: if the inputs differ in shape or are not 2-D.

    Ties between equal scores are resolved by ``np.argsort``'s ordering, so it
    is unspecified which tied index ends up in the top ``k``.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if y_true.shape != y_pred.shape:
        raise ValueError('shape mismatch: y_true %s vs y_pred %s'
                         % (y_true.shape, y_pred.shape))
    if y_true.ndim != 2:
        raise ValueError('expected 2-D inputs, got shape %s' % (y_true.shape,))
    N, C = y_true.shape
    if not 1 <= k <= C:
        # k == 0 previously divided by zero, and negative k silently
        # compared the wrong slices of the argsort.
        print('error: k=%s must be between 1 and the number of columns (%d)' % (k, C))
        return -1
    if N == 0:
        # The mean over zero rows is undefined (this used to divide by zero).
        return float('nan')

    top_true = np.argsort(y_true, axis=1)[:, -k:]
    top_pred = np.argsort(y_pred, axis=1)[:, -k:]
    # Indices within one row are distinct, so counting pairwise matches gives
    # the size of each row's set intersection.
    shared = (top_pred[:, :, None] == top_true[:, None, :]).any(axis=2).sum(axis=1)
    return float(np.mean(shared / k))

def _predict_word_probabilities(model, documents):
    """Run ``model`` over ``documents`` and return its word probabilities.

    Documents are wrapped in ``AttnCTMDataset``, batched with
    ``generate_attn_batch``, and scored by ``model.get_word_probability``.

    Returns:
        A 2-D NumPy array with one row per document, in the same order as
        ``documents`` (the original script labels it (N, V)).
    """
    # The dataset takes one label per document; the loop below discards it.
    dataset = AttnCTMDataset(documents, [0] * len(documents))
    dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, num_workers=12,
                            # Pinned host memory only helps host-to-GPU copies.
                            pin_memory=device.type == 'cuda',
                            # Keep prediction rows aligned with `documents`.
                            shuffle=False,
                            collate_fn=generate_attn_batch)

    model.to(device)
    model.eval()
    all_preds = []
    with torch.no_grad():
        print("Making predictions")
        for text, _ in dataloader:
            predictions = model.get_word_probability(text.to(device)).squeeze(1).float()
            all_preds.append(predictions.cpu())

    print("Tensorizing result")
    output = torch.cat(all_preds).view(len(dataset), -1)
    # Hand NumPy a real ndarray rather than relying on implicit Tensor coercion.
    return output.numpy()


def _estimate_topic_posterior(A, word_probs):
    """Estimate E(posterior) topic weights from predicted word probabilities.

    Computes the least-squares solution ``Eta`` of ``Eta @ A ~= word_probs``
    with the Moore-Penrose pseudo-inverse, then rescales every row of ``Eta``
    so that it sums to 1.

    NOTE(review): the pseudo-inverse does not enforce non-negativity, so rows
    may contain negative entries or sum to ~0. Confirm this is acceptable.
    """
    # pinv(A.T).T == pinv(A), so this matches the earlier
    # (pinv(A.T) @ word_probs.T).T formulation.
    eta = np.asarray(word_probs) @ np.linalg.pinv(np.asarray(A))
    return eta / eta.sum(axis=1, keepdims=True)


def main():
    """Compare a saved model's posterior estimates with MCMC posteriors.

    Command-line arguments:
        alpha: int that selects ``data/alpha<alpha>.0_60wordsdoc``.
        filename: model file name, resolved under ``savedmodels/``.

    Prints the average total-variation distance (half the L1 distance,
    averaged over documents) and the top-2/4/6 index overlap between the
    estimated posteriors and the MCMC posteriors.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('alpha', type=int, help='alpha')
    parser.add_argument('filename', type=str, help='model weight file path')
    input_args = parser.parse_args()

    weights_file = input_args.filename
    data_path = 'data/alpha%d.0_60wordsdoc' % input_args.alpha

    with open(os.path.join(data_path, "topics.pkl"), 'rb') as f:
        A = pickle.load(f)
    with open(os.path.join(data_path, 'MCMC_Posterior_200.pkl'), 'rb') as f:
        L = pickle.load(f)

    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True,
    # which rejects pickled nn.Module objects. If loading fails, pass
    # weights_only=False (only for trusted files).
    model = torch.load(os.path.join('savedmodels', weights_file), map_location=device)

    # Test documents are stored as separate token and count .mat files.
    token_file = os.path.join(data_path, 'bow_test_tokens_200')
    count_file = os.path.join(data_path, 'bow_test_counts_200')
    tokens = scipy.io.loadmat(token_file)['tokens'].squeeze()
    counts = scipy.io.loadmat(count_file)['counts'].squeeze()
    test_documents = recover_docs(tokens, counts)

    word_probs = _predict_word_probabilities(model, test_documents)
    Eta = _estimate_topic_posterior(A, word_probs)

    print('average TV:', np.sum(np.abs(Eta - L)) / (2 * len(test_documents)))
    for k in (2, 4, 6):
        print('top %d overlap:' % k, top_k_ovelap(L, Eta, k))


if __name__ == '__main__':
    main()

