#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
import tensorflow.contrib.eager as tfe
import pandas as pd
# Eager execution (TF 1.x contrib API) is switched on at import time, before
# any TensorFlow ops are created: run_translation calls `.numpy()` on the
# model's outputs, which only works on eager tensors.
tfe.enable_eager_execution()


from BBB_BIH import load_mit_bih_a_testdata, \
    prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    LSTM_Model, load_mit_bih_a_data, accuracy

def run_translation(predict, model, logger, batchsize, type='test'):
    """Evaluate ``model`` on ``predict`` in mini-batches and report its error.

    Args:
        predict: Pair ``(inputs, labels)``; ``predict[0][i]`` is the i-th
            input sequence and ``predict[1][i]`` its label.
        model: Callable invoked as ``model(x, mask)``; its output is iterated
            row by row and each row must expose ``.numpy()`` (eager tensor).
        logger: Writable file-like object; receives one error line and one
            accuracy line, then is flushed.
        batchsize: Mini-batch size passed to ``get_minibatches_idx``.
        type: Label prefixed to the printed/logged lines (e.g. ``'test'``).
            (Shadows the builtin, but the name is part of the public
            signature, so it is kept.)

    Returns:
        list of numpy arrays, one per row of each batch's model output, in
        the order the index batches were produced.
    """
    # NOTE(review): the caller pairs the returned logits positionally with
    # predict[1]; this assumes get_minibatches_idx does not shuffle by
    # default — verify.
    test_batches = get_minibatches_idx(len(predict[0]), batchsize)
    accuracies = []
    all_logits = []
    for test_batch in test_batches:
        y_v = [predict[1][i] for i in test_batch]
        x_v = [predict[0][i] for i in test_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        # Append a trailing feature axis of size 1 for the LSTM input.
        x_ = x_.reshape((x_.shape[0], x_.shape[1], 1))
        test_logits = model(x_, mask_)
        all_logits.extend(test_logit.numpy() for test_logit in test_logits)
        accuracies.append(accuracy(test_logits, y_))
    # NOTE(review): unweighted mean over batches. If the final batch is
    # smaller, this differs slightly from the per-example accuracy.
    averaged_accuracy = np.mean(accuracies)
    error = 1 - averaged_accuracy
    print(type+' error : {} '.format(error))
    # Terminate each record with a newline so that consecutive records (and
    # repeated runs appending to the same file) do not run together on one
    # line.
    logger.write(type+' error : {}\n'.format(error))
    logger.write(type+' acc : {}\n'.format(averaged_accuracy))
    logger.flush()
    return all_logits

def predict(batch_size=256, X=10):
    """Run the saved BBB LSTM model over the MIT-BIH test set ``X`` times.

    For each repetition ``nr`` in ``1..X``, this reloads the test data and the
    pickled model parameters. It evaluates them with ``run_translation``,
    which prints and logs the error/accuracy. It then pickles the collected
    logits together with the test labels.

    Args:
        batch_size: Mini-batch size passed to ``run_translation``.
        X: Number of repeated prediction runs.

    Side effects:
        Appends to ``logs/bbb_bih.predict.<nr>.log`` and (over)writes
        ``bih_results/bbb_test_logits_labels.<nr>.pkl`` for each run. Neither
        directory is created here, so both must already exist.
    """
    model_path = 'models/BBB_BIH.pkl'
    for x in range(X):
        nr = x + 1
        log_path = "logs/bbb_bih.predict."+str(nr)+".log"
        # `with` closes the per-run log even if evaluation raises; previously
        # one file handle was leaked per repetition.
        with open(log_path, "a") as logger:
            print('\nSetting seed')
            # NOTE(review): despite the message, no seed is set here (set_seed
            # is imported but never called) — confirm this is intended.

            print('...Loading Data...')
            test = load_mit_bih_a_testdata()
            config = create_config()
            print('...Loading Model...')
            # Data and model are rebuilt on every repetition, as before.
            # NOTE(review): presumably only needed if loading/constructing the
            # Bayesian model is stochastic; hoisting may be possible — verify.
            trained_param = load_models(model_path)
            LSTM_model = LSTM_Model(config, trained_param)
            print('...Done...')

            all_test_logits = run_translation(test, LSTM_model, logger,
                                              batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        # NOTE(review): logits and labels are paired positionally; this assumes
        # run_translation visits examples in their original order.
        out_path = "bih_results/"+"bbb_test_logits_labels."+str(nr)+".pkl"
        with open(out_path, 'wb') as output:
            pickle.dump(logits_labels, output)


def _main():
    """Script entry point: run ``predict`` with its default settings."""
    predict()


if __name__ == '__main__':
    _main()

