import os.path
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model

def load_meta_trace_files(database_folder_train, sKeyNo, work, load_metadata=False,
                          file_list=None):
    """Load traces and ``skpv_a_vec0_evenCoeff0`` labels from one or more HDF5 files.

    Parameters
    ----------
    database_folder_train : str
        When ``work == 'train'``: a path prefix to which each file name from
        the file list is appended by plain string concatenation (include a
        trailing separator if it is a directory). Otherwise: the path of the
        single HDF5 file to read.
    sKeyNo : int
        Index along the second axis of the ``skpv_a_vec0_evenCoeff0`` dataset
        selecting which label column to read.
    work : str
        ``'train'`` reads every file in the file list; any other value reads
        only ``database_folder_train`` itself.
    load_metadata : bool
        Currently unused; kept for interface compatibility.
    file_list : sequence of str, optional
        File names used when ``work == 'train'``. When ``None`` (default), the
        module-level ``training_file_list`` is used, looked up at call time.

    Returns
    -------
    tuple
        ``(trace_profiling, skpv_profiling)``. ``trace_profiling`` is a float
        array of every file's ``'wave'`` dataset concatenated along axis 0.
        ``skpv_profiling`` is an int array of the selected label column,
        concatenated in the same file order.

    Exits the process with status -1 if a file is missing or cannot be opened.
    """
    if work == 'train':
        # Fall back to the module-level list so existing callers keep working.
        file_names = training_file_list if file_list is None else file_list
        database_files = [database_folder_train + name for name in file_names]
    else:
        database_files = [database_folder_train]

    traces = []
    labels = []
    for database_file in database_files:
        print('\nLoad database_file =', database_file)
        check_file_exists(database_file)
        # Open the Kyber database HDF5 for reading
        try:
            in_file = h5py.File(database_file, "r")
        except OSError:
            # h5py signals unreadable / malformed files with OSError; a bare
            # except would also swallow KeyboardInterrupt and SystemExit.
            print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
            sys.exit(-1)
        # The context manager guarantees the HDF5 handle is closed once the
        # needed datasets have been copied into numpy arrays.
        with in_file:
            traces.append(np.array(in_file['wave'], dtype=float))
            labels.append(np.array(in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int)))

    trace_profiling = np.concatenate(traces)
    skpv_profiling = np.concatenate(labels)
    return trace_profiling, skpv_profiling

def check_file_exists(file_path):
    """Exit the process with status -1 if *file_path* does not exist.

    Prints an error message naming the missing path before exiting. Returns
    ``None`` when the path exists (``os.path.exists`` accepts files and
    directories alike).
    """
    if not os.path.exists(file_path):
        print("Error: provided file path '%s' does not exist!" % file_path)
        sys.exit(-1)

# Load the test traces and the label column with index 0 from KYBER51.H5.
sKeyNo = 0
test_trace, test_skpv = load_meta_trace_files('KYBER51.H5', sKeyNo, work='test', load_metadata=False)

print(test_skpv[:10])

# Use the NpzFile as a context manager so its underlying file is closed;
# arrays fetched with data[...] are read into memory and stay valid afterwards.
with np.load('data.npz') as data:
    trace_profiling = data['data']
    bp_profiling = data['bp']
    skpv_profiling = data['label']

print(len(skpv_profiling))
print(skpv_profiling[:10])