#!/usr/bin/env python
# -*- coding: utf-8 -*-



import os
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
#from peentree import *
import tensorflow.contrib.eager as tfe
import pandas as pd

from BBB_IMDB import prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    LSTM_Model, accuracy


def load_imdb_data_test(n_words=5000, sort_by_len=True, path="data/imdb_test.pkl"):
    """Load the pickled IMDB test split and clip its vocabulary.

    The pickle at *path* must unpack into a ``(sequences, labels)`` pair,
    where each sequence is an iterable of integer word indices.

    Args:
        n_words: vocabulary cut-off; every index ``>= n_words`` is replaced by
            1 (presumably the UNK index -- confirm against the training code).
        sort_by_len: if true, reorder the samples by ascending sequence length
            (stable sort), permuting the labels in lockstep so pairs stay aligned.
        path: location of the pickle file; defaults to the original
            hard-coded location.

    Returns:
        A ``(test_set_x, test_set_y)`` tuple.
    """
    # Use a context manager so the file handle is closed deterministically
    # (the original left it open).
    with open(path, 'rb') as f:
        test_set_x, test_set_y = pkl.load(f)

    # Map out-of-vocabulary indices to 1.
    test_set_x = [[1 if w >= n_words else w for w in sen] for sen in test_set_x]

    if sort_by_len:
        sorted_index = sorted(range(len(test_set_x)),
                              key=lambda i: len(test_set_x[i]))
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]

    return test_set_x, test_set_y
    
    
def run_translation(predict, model, logger, batchsize, type='test'):
    """Evaluate *model* on a dataset in minibatches and report its accuracy.

    Args:
        predict: ``(sequences, labels)`` pair indexable by sample position
            (the shape returned by ``load_imdb_data_test``).
        model: callable invoked as ``model(x, mask)`` that returns a 3-tuple;
            only the first element (the logits) is used. Each element obtained
            by iterating the logits must expose ``.numpy()`` (eager tensors do).
        logger: writable text file-like object that receives one
            ``error`` line and one ``acc`` line, then is flushed.
        batchsize: minibatch size forwarded to ``get_minibatches_idx``.
        type: label prefixed to the printed/logged lines (e.g. ``'test'``).
            It shadows the ``type`` builtin but is kept for caller compatibility.

    Returns:
        List of numpy arrays, one per element along the first axis of each
        batch's logits (presumably one row per sample), in the order the
        batches were visited. NOTE(review): this only lines up with
        ``predict[1]`` if ``get_minibatches_idx`` does not shuffle -- confirm.
    """
    test_batches = get_minibatches_idx(len(predict[0]), batchsize)
    accuracies = []
    all_logits = []
    for test_batch in test_batches:
        y_v = [predict[1][i] for i in test_batch]
        x_v = [predict[0][i] for i in test_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        test_logits, _, _ = model(x_, mask_)
        all_logits.extend(test_logit.numpy() for test_logit in test_logits)
        accuracies.append(accuracy(test_logits, y_))
    # NOTE(review): this is an unweighted mean over batches. If `accuracy`
    # returns a per-batch fraction and the last batch is smaller, the result
    # is slightly biased. Confirm whether a sample-weighted mean is intended.
    averaged_accuracy = np.mean(accuracies)
    print(type + ' error : {} '.format(1 - averaged_accuracy))
    # Terminate each entry with a newline. Callers may open the log in append
    # mode (predict() uses mode "a"), so without it successive entries and
    # runs would run together on a single line.
    logger.write(type + ' error : {} \n'.format(1 - averaged_accuracy))
    logger.write(type + ' acc : {} \n'.format(averaged_accuracy))
    logger.flush()
    return all_logits


def predict(batch_size=8, X=10):
    """Run the saved BBB IMDB model over the test split *X* times and save outputs.

    For each run ``nr`` in ``1..X``:
      * error/accuracy lines are appended to ``logs/bbb_imdb_predict.<nr>.log``;
      * an OrderedDict with keys ``'test_logits'`` and ``'test_labels'`` is
        pickled to ``results/bbb_imdb_test_logits_labels.<nr>.pkl``.

    Args:
        batch_size: evaluation minibatch size.
        X: number of repeated prediction runs; nothing happens if ``X <= 0``.
    """
    if X <= 0:
        # Matches the original, which did no work for an empty range.
        return

    print('...Loading Data...')
    # The test split is only read downstream, so load (and sort) it once
    # instead of once per run.
    test = load_imdb_data_test()
    model_path = 'models/BBB_IMDB.pkl'

    for x in range(X):
        nr = x + 1
        log_path = "logs/bbb_imdb_predict." + str(nr) + ".log"
        # NOTE: set_seed is commented out, so despite this message no seed is
        # set here (presumably so that runs differ).
        print('\nSetting seed')
        #set_seed(42)

        # The config, parameters and model are rebuilt for every run, as in the
        # original. NOTE(review): if LSTM_Model samples weights at construction
        # time (Bayes-by-Backprop), hoisting this out of the loop would change
        # results. Do not hoist without confirming.
        config = create_config()
        trained_param = load_models(model_path)
        # Wrap in a 1-tuple so a tuple-valued trained_param is formatted
        # instead of raising "not all arguments converted".
        print('trained_param: %s' % (trained_param,))
        assert trained_param is not None
        LSTM_model = LSTM_Model(config, trained_param)
        print('...Done...')

        # `with` closes the log even if evaluation raises. The original leaked
        # one open handle per run.
        with open(log_path, "a") as logger:
            all_test_logits = run_translation(test, LSTM_model, logger, batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        result_path = "results/" + "bbb_imdb_test_logits_labels." + str(nr) + ".pkl"
        with open(result_path, 'wb') as output:
            pickle.dump(logits_labels, output)


if __name__ == "__main__":
    # Script entry point: run the repeated test-set prediction using the
    # default arguments of predict().
    predict()

