#!/usr/bin/env python
# -*- coding: utf-8 -*-


import os
import sys
import random
import pdb
import numpy as np
import tensorflow as tf
import pickle as pkl
import six.moves.cPickle as pickle
from collections import OrderedDict
import pandas as pd
#from peentree import *


from ST_BIH import load_mit_bih_a_testdata, \
    prepare_data, get_minibatches_idx, set_seed, load_models, create_config, \
    SRLSTM_Model, load_mit_bih_a_data, accuracy

def run_translation(predict, model, logger, batchsize, type='test'):
    """Evaluate ``model`` on a dataset in minibatches and log its error/accuracy.

    Args:
        predict: Dataset pair ``(inputs, labels)``; ``predict[0][i]`` and
            ``predict[1][i]`` are the i-th input sequence and its label.
            (Inside this body the name shadows the module-level ``predict``
            function.)
        model: Callable invoked as ``model(x, mask)``; its result is iterated
            per example and each element must expose ``.numpy()``.
        logger: Writable text file-like object; one line each for the error
            and the accuracy is written, then the object is flushed.
        batchsize: Minibatch size passed to ``get_minibatches_idx``.
        type: Label prefixed to the printed/logged messages (e.g. 'test').

    Returns:
        List of per-example logits as NumPy arrays, in minibatch order.
    """
    inputs, labels = predict[0], predict[1]
    batches = get_minibatches_idx(len(inputs), batchsize)
    accuracies = []
    batch_sizes = []
    all_logits = []
    for batch in batches:
        x_v = [inputs[i] for i in batch]
        y_v = [labels[i] for i in batch]
        x_, mask_, y_ = prepare_data(x_v, y_v, maxlen=None)
        # Add a trailing feature axis of size 1: (d0, d1) -> (d0, d1, 1).
        x_ = x_.reshape((x_.shape[0], x_.shape[1], 1))
        test_logits = model(x_, mask_)
        all_logits.extend(logit.numpy() for logit in test_logits)
        accuracies.append(accuracy(test_logits, y_))
        batch_sizes.append(len(batch))

    if accuracies and sum(batch_sizes) > 0:
        # Weight each batch by its size so a short final batch does not skew
        # the overall figure (assumes `accuracy` returns a per-batch fraction).
        averaged_accuracy = np.average(accuracies, weights=batch_sizes)
    else:
        # No data: keep the original NaN result (np.mean of an empty list).
        averaged_accuracy = float('nan')

    print(type + ' error : {} '.format(1 - averaged_accuracy))
    # Terminate each log message with a newline so entries stay on their own lines.
    logger.write(type + ' error : {} '.format(1 - averaged_accuracy) + '\n')
    logger.write(type + ' acc : {} '.format(averaged_accuracy) + '\n')
    logger.flush()
    return all_logits

def predict(batch_size=256, X=10):
    """Run the trained model on the MIT-BIH test set ``X`` times and save results.

    Run ``nr`` (1..X) does the following:

    - Load the test data and the trained parameters from
      ``models/ST-tau_BIH_c5.pkl``.
    - Build an ``SRLSTM_Model`` and evaluate it with ``run_translation``,
      appending error/accuracy to ``logs/st_bih.predict.<nr>.log``.
    - Pickle an ``OrderedDict`` with keys ``'test_logits'`` and
      ``'test_labels'`` to ``bih_results/st_test_logits_labels.<nr>.pkl``.

    Args:
        batch_size: Minibatch size used for evaluation.
        X: Number of prediction runs.
    """
    model_path = 'models/ST-tau_BIH_c5.pkl'
    # Create output directories up front so a run cannot fail only after the
    # (expensive) evaluation has already completed.
    os.makedirs('logs', exist_ok=True)
    os.makedirs('bih_results', exist_ok=True)

    for x in range(X):
        nr = x + 1
        log_path = "logs/st_bih.predict." + str(nr) + ".log"
        # `with` guarantees the log file is closed even if evaluation raises.
        with open(log_path, "a") as logger:
            print('\nSetting seed')
            # NOTE: seeding is disabled; re-enable set_seed(12345) for
            # reproducible runs.

            print('...Loading Data...')
            test = load_mit_bih_a_testdata()
            config = create_config()  # returns dictionary
            print('...Loading Model...')
            trained_param = load_models(model_path)
            # SRLSTM_Model is the model class imported from ST_BIH; the
            # previous `LSTM_Model` name was never defined (NameError).
            lstm_model = SRLSTM_Model(config, trained_param)
            print('...Done...')

            all_test_logits = run_translation(test, lstm_model, logger, batch_size)

        logits_labels = OrderedDict()
        logits_labels['test_logits'] = all_test_logits
        logits_labels['test_labels'] = test[1]
        with open("bih_results/" + "st_test_logits_labels." + str(nr) + ".pkl", 'wb') as output:
            pickle.dump(logits_labels, output)


if __name__ == "__main__":
    # Script entry point: run the prediction loop with its default settings.
    predict()

