#!/usr/bin/env python
# -*- coding: utf-8 -*-

import logging
import sys

import numpy as np
import pandas as pd
import tensorflow as tf
import six.moves.cPickle as pickle
from collections import OrderedDict

# Name stem embedded in the log and result file paths built by predict().
MODEL_PATH = "imdb_vd"
# Keep-probability the model was trained with; in this file it only appears in
# a commented-out model-path expression inside predict().
DROPOUT_KEEP_PROB_TRAINED = 0.9
# Keep-probability handed to create_config() at prediction time; also part of
# the log/result file names.
DROPOUT_KEEP_PROB = 0.9

from VD_IMDB import load_imdb_data, load_imdb_data_test, \
    prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    LSTM_Model, accuracy


# Configure the root logger once at import time (INFO level, timestamped lines).
logging.basicConfig(
    level=logging.INFO,
    datefmt='%m/%d/%Y %H:%M:%S',
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
)
# Module logger; not referenced elsewhere in this file (the functions below
# report via print() and per-run log files instead).
LOGGER = logging.getLogger(__name__)


def run_translation(predict, model, logger, batchsize, type='test'):
    """Evaluate ``model`` on a dataset in minibatches and report its error rate.

    The error (``1 - accuracy``) is printed to stdout. The error and the
    accuracy are each written to ``logger`` as their own line, and then
    ``logger`` is flushed.

    Args:
        predict: indexable pair where ``predict[0]`` holds the input sequences
            and ``predict[1]`` holds the labels at the same positions.
        model: callable applied as ``model(x, mask)`` to the output of
            ``prepare_data``; each element of its result must provide
            ``.numpy()``.
        logger: writable, flushable text stream (e.g. a file opened for append).
        batchsize: minibatch size passed to ``get_minibatches_idx``.
        type: label used as the message prefix (e.g. ``'test'`` or ``'valid'``).

    Returns:
        list: one numpy array of logits per example, in batch iteration order.
    """
    inputs, labels = predict[0], predict[1]
    all_logits = []
    accuracies = []
    batch_sizes = []
    for test_batch in get_minibatches_idx(len(inputs), batchsize):
        x_v = [inputs[i] for i in test_batch]
        y_v = [labels[i] for i in test_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        test_logits = model(x_, mask_)
        all_logits.extend(test_logit.numpy() for test_logit in test_logits)
        accuracies.append(accuracy(test_logits, y_))
        batch_sizes.append(len(y_v))

    # Weight each batch by its size. If get_minibatches_idx yields a smaller
    # final batch, an unweighted mean of per-batch accuracies over-weights it.
    # NOTE(review): assumes accuracy() returns the fraction correct in its batch.
    if batch_sizes:
        averaged_accuracy = np.average(accuracies, weights=batch_sizes)
    else:
        # No examples: keep the NaN the unweighted mean produced, without the
        # empty-mean RuntimeWarning (np.average would raise here instead).
        averaged_accuracy = float('nan')

    error_msg = type + ' error : {} '.format(1 - averaged_accuracy)
    print(error_msg)
    # Terminate each log entry with a newline so entries don't run together.
    logger.write(error_msg + '\n')
    logger.write(type + ' acc : {} '.format(averaged_accuracy) + '\n')
    logger.flush()
    return all_logits


def predict(batch_size=8, X=10):
    """Run ``X`` evaluation passes of the saved IMDB model on the test split.

    Each pass ``nr`` in ``1..X`` does the following:

    - reloads the train/valid and test data;
    - rebuilds ``LSTM_Model`` from the parameters returned by ``load_models``;
    - evaluates the model on the test split with ``run_translation``, which
      writes its error/accuracy to the file opened in append mode at
      ``logs/<MODEL_PATH>_<DROPOUT_KEEP_PROB>.predict.<nr>.log``;
    - pickles an ``OrderedDict`` to
      ``results/<MODEL_PATH>_<DROPOUT_KEEP_PROB>.test_logits_labels.<nr>.pkl``.
      The dict holds the test logits, the test labels, an empty
      validation-logits list and the validation labels.

    Args:
        batch_size: minibatch size used for evaluation.
        X: number of passes to run.

    Raises:
        ValueError: if ``load_models`` returns ``None`` for the model file.
    """
    for nr in range(1, X + 1):
        log_path = "logs/{}_{}.predict.{}.log".format(MODEL_PATH, DROPOUT_KEEP_PROB, nr)
        # The context manager guarantees the log handle is closed, including
        # when loading or evaluation raises.
        with open(log_path, "a") as logger:
            print('\nSetting seed')
            # NOTE: seeding is currently disabled, so passes are not seeded here.
            #set_seed(42)

            print('...Loading Data...')
            _train, valid = load_imdb_data()
            test = load_imdb_data_test()
            model_path = 'new_models/VD_IMDB'#+MODEL_PATH+"_"+str(DROPOUT_KEEP_PROB_TRAINED)
            config = create_config(model_path=model_path, dropout_keep_prob=DROPOUT_KEEP_PROB)  # returns dictionary
            print('...Loading Model...')
            param_file = config['model_path'] + '.pkl'
            trained_param = load_models(param_file)
            print('trained_param: %s' % trained_param)
            # Explicit check rather than `assert`, which is stripped under `python -O`.
            if trained_param is None:
                raise ValueError('no trained parameters loaded from %s' % param_file)
            LSTM_model = LSTM_Model(config, trained_param)
            print('...Done...')

            # Validation evaluation is disabled; an empty list is saved in its place.
            #all_valid_logits = run_translation(valid, LSTM_model, logger, batch_size, type='valid')
            all_valid_logits = []
            all_test_logits = run_translation(test, LSTM_model, logger, batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        logits_labels['val_logits'] = all_valid_logits
        logits_labels['val_labels'] = valid[1]
        result_path = "results/{}_{}.test_logits_labels.{}.pkl".format(
            MODEL_PATH, DROPOUT_KEEP_PROB, nr)
        # `with` closes the pickle file even if pickle.dump raises.
        with open(result_path, 'wb') as output:
            pickle.dump(logits_labels, output)


if __name__ == '__main__':
    # Script entry point: run predict() with its default batch size and pass count.
    predict()
