#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import cPickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict

from LSTM_BIH import load_mit_bih_a_testdata, \
    prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    LSTM_Model, load_mit_bih_a_data, accuracy

def run_translation(predict, model, logger, batchsize, type='test'):
    """Evaluate ``model`` on a dataset in minibatches and log its error rate.

    Args:
        predict: Dataset pair; ``predict[0]`` holds the input sequences and
            ``predict[1]`` the labels, both indexable by integer position.
        model: Callable invoked as ``model(x, mask)``; each element of the
            returned logits must expose ``.numpy()`` (e.g. eager TF tensors).
        logger: Writable file-like object; the error and accuracy are written
            to it (one value per line) and the stream is flushed.
        batchsize: Minibatch size passed to ``get_minibatches_idx``.
        type: Label prefixed to printed/logged messages (e.g. ``'test'``).
            Shadows the builtin, but the name is kept for keyword callers.

    Returns:
        list: One numpy array of logits per example, in the order the
        minibatches were visited. NOTE(review): this only lines up with
        ``predict[1]`` if ``get_minibatches_idx`` does not shuffle by
        default -- confirm.
    """
    test_batches = get_minibatches_idx(len(predict[0]), batchsize)
    accuracies = []
    batch_sizes = []
    all_logits = []
    for test_batch in test_batches:
        y_v = [predict[1][i] for i in test_batch]
        x_v = [predict[0][i] for i in test_batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        # Add a trailing feature axis of size 1: (d0, d1) -> (d0, d1, 1).
        x_ = x_.reshape((x_.shape[0], x_.shape[1], 1))
        test_logits = model(x_, mask_)
        all_logits.extend(test_logit.numpy() for test_logit in test_logits)
        accuracies.append(accuracy(test_logits, y_))
        batch_sizes.append(len(test_batch))

    if sum(batch_sizes) > 0:
        # Weight each batch by its size so a short final batch does not skew
        # the overall figure. Assumes ``accuracy`` returns the fraction of
        # correct examples within the batch -- TODO confirm.
        averaged_accuracy = np.average(accuracies, weights=batch_sizes)
    else:
        # No data: report NaN explicitly rather than via a NumPy warning.
        averaged_accuracy = float('nan')

    print(type + ' error : {} '.format(1 - averaged_accuracy))
    logger.write(type + ' error : {}\n'.format(1 - averaged_accuracy))
    logger.write(type + ' acc : {}\n'.format(averaged_accuracy))
    logger.flush()
    return all_logits

def predict(batch_size=256, X=10):
    """Run the saved LSTM on the MIT-BIH test set ``X`` times and save results.

    For each run ``nr`` in ``1..X``:
    - load the test data and the model from ``models/LSTM_BIH.pkl``;
    - evaluate via ``run_translation``, logging to
      ``logs/lstm_bih.predict.<nr>.log``;
    - pickle an ``OrderedDict`` with keys ``'test_logits'`` and
      ``'test_labels'`` to ``bih_results/lstm_test_logits_labels.<nr>.pkl``.

    Args:
        batch_size: Minibatch size used during evaluation.
        X: Number of repeated runs.
    """
    # Make sure the output directories exist (Python 2: no exist_ok).
    for out_dir in ('logs', 'bih_results'):
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)

    for x in range(X):
        nr = x + 1
        log_path = "logs/lstm_bih.predict." + str(nr) + ".log"
        print('\nSetting seed')
        # NOTE(review): seeding is disabled, so the message above is
        # misleading; presumably intentional so repeated runs differ -- confirm.
        #set_seed(12345)

        # NOTE(review): data and model are reloaded on every run; hoisting
        # them out of the loop would be faster if loading is deterministic.
        print('...Loading Data...')
        test = load_mit_bih_a_testdata()
        model_path = 'models/LSTM_BIH.pkl'
        config = create_config()  # returns dictionary
        print('...Loading Model...')
        trained_param = load_models(model_path)
        LSTM_model = LSTM_Model(config, trained_param)
        print('...Done...')

        # The context manager closes the log file even if evaluation fails.
        with open(log_path, "a") as logger:
            all_test_logits = run_translation(test, LSTM_model, logger, batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        out_path = "bih_results/" + "lstm_test_logits_labels." + str(nr) + ".pkl"
        with open(out_path, 'wb') as output:
            pickle.dump(logits_labels, output)


if __name__ == '__main__':
    # Script entry point: run the repeated evaluation with predict()'s
    # default arguments (batch_size=256, X=10).
    predict()
