#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import sys
import logging
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict


from LSTM_IMDB import prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    LSTM_Model, accuracy
    
# Configure root logging at import time (INFO and above, timestamped lines),
# then grab this module's logger.
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
LOGGER = logging.getLogger(__name__)

def load_imdb_data_test(n_words=5000, sort_by_len=True,
                        path="data/imdb_test.pkl"):
    """Load the pickled IMDB test split and prepare it for evaluation.

    Args:
        n_words: vocabulary cut-off; every word index ``>= n_words`` is
            replaced by ``1`` (presumably the UNK index -- confirm against the
            training-side loader).
        sort_by_len: if true, reorder the examples (and their labels in
            lockstep) by ascending sequence length.
        path: pickle file holding a ``(sequences, labels)`` pair. Defaults to
            the location this script has always used.

    Returns:
        ``(test_set_x, test_set_y)`` tuple.
    """
    # NOTE(review): unpickling can execute arbitrary code; only load trusted
    # files from this path.
    with open(path, 'rb') as f:
        test_set_x, test_set_y = pkl.load(f)

    # Collapse out-of-vocabulary word indices to 1.
    test_set_x = [[1 if w >= n_words else w for w in sen] for sen in test_set_x]

    if sort_by_len:
        # Indices of the sequences ordered by length; applied to x and y alike
        # so each label stays paired with its sequence.
        sorted_index = sorted(range(len(test_set_x)),
                              key=lambda i: len(test_set_x[i]))
        test_set_x = [test_set_x[i] for i in sorted_index]
        test_set_y = [test_set_y[i] for i in sorted_index]

    return test_set_x, test_set_y
    
    



def run_translation(predict, model, logger, batchsize, type='test'):
    """Run ``model`` over a labelled dataset in minibatches and report accuracy.

    Args:
        predict: ``(xs, ys)`` pair of parallel sequences -- token-index
            sequences and their labels.
        model: callable invoked as ``model(x, mask)`` on the output of
            ``prepare_data``; it must return an iterable of per-example logit
            tensors exposing ``.numpy()``.
        logger: writable file-like object; one error line and one accuracy
            line are written to it, then it is flushed.
        batchsize: number of examples per minibatch.
        type: label prefixed to the report lines (e.g. ``'test'``).

    Returns:
        List of per-example logits as numpy arrays, in batch iteration order.
        NOTE(review): this only lines up with ``predict[1]`` if
        ``get_minibatches_idx`` does not shuffle -- confirm.
    """
    xs, ys = predict[0], predict[1]
    batches = get_minibatches_idx(len(xs), batchsize)
    accuracies = []
    batch_sizes = []
    all_logits = []
    for batch in batches:
        x_v = [xs[i] for i in batch]
        y_v = [ys[i] for i in batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        test_logits = model(x_, mask_)
        all_logits.extend(logit.numpy() for logit in test_logits)
        # Assumes `accuracy` returns the fraction correct within the batch
        # (the report below prints 1 - it as the error rate).
        accuracies.append(accuracy(test_logits, y_))
        batch_sizes.append(len(x_v))

    if accuracies:
        # Weight by batch size so a short final batch does not skew the
        # overall accuracy.
        averaged_accuracy = np.average(accuracies, weights=batch_sizes)
    else:
        averaged_accuracy = float('nan')

    error_line = '{} error : {} '.format(type, 1 - averaged_accuracy)
    print(error_line)
    # Terminate each entry so appended reports do not run together.
    logger.write(error_line + '\n')
    logger.write('{} acc : {} \n'.format(type, averaged_accuracy))
    logger.flush()
    return all_logits


def predict(batch_size=8, X=10):
    """Evaluate the saved LSTM model on the IMDB test split ``X`` times.

    For each run ``nr`` in ``1..X`` the model is rebuilt from
    ``models/LSTM_IMDB.pkl``. The error/accuracy summary is appended to
    ``logs/lstm_imdb_predict.<nr>.log``. The per-example logits and the test
    labels are pickled to ``results/lstm_test_logits_labels.<nr>.pkl``.

    Args:
        batch_size: minibatch size passed to ``run_translation``.
        X: number of evaluation runs.

    Raises:
        ValueError: if ``load_models`` returns ``None``.
    """
    print('...Loading Data...')
    # The test split is identical for every run, so load it only once.
    test = load_imdb_data_test()
    model_path = 'models/LSTM_IMDB.pkl'

    # Make sure the output directories exist before any file is opened.
    for directory in ('logs', 'results'):
        if not os.path.isdir(directory):
            os.makedirs(directory)

    for nr in range(1, X + 1):
        # NOTE(review): set_seed(42) was disabled in the original, so runs may
        # differ if model construction involves randomness -- confirm intended.
        config = create_config()
        trained_param = load_models(model_path)
        if trained_param is None:
            raise ValueError('could not load trained parameters from %s'
                             % model_path)
        LSTM_model = LSTM_Model(config, trained_param)
        print('...Done...')

        log_path = os.path.join('logs', 'lstm_imdb_predict.%d.log' % nr)
        with open(log_path, 'a') as logger:
            all_test_logits = run_translation(test, LSTM_model, logger,
                                              batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        out_path = os.path.join('results',
                                'lstm_test_logits_labels.%d.pkl' % nr)
        with open(out_path, 'wb') as output:
            pickle.dump(logits_labels, output)


if __name__ == "__main__":
    # Only when executed as a script (not on import): run all prediction passes
    # with predict()'s default settings.
    predict()
