import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm as tqdm
import pywt
from pywt import wavedec
import h5py

def load_meta_trace_file_from_test(database_file, sKeyNo, load_metadata=False):
    """Load traces and per-trace label columns for one key index from an HDF5 file.

    Parameters
    ----------
    database_file : str or path-like
        Path to the HDF5 file. It must contain the datasets ``wave``,
        ``skpv_a_vec0_evenCoeff0``, ``bp_b_vec0_evenCoeff0`` and
        ``bp_b_vec0_oddCoeff1``.
    sKeyNo : int
        Column index to read from the label datasets. The two ``bp_*``
        datasets are also read at column ``sKeyNo + 1``, so that column
        must exist too.
    load_metadata : bool, optional
        Currently unused; kept so existing call sites keep working.

    Returns
    -------
    tuple
        ``(trace_profiling, bp_profiling, skpv_profiling)``:

        - ``trace_profiling``: the whole ``wave`` dataset as a float array.
        - ``bp_profiling``: list of four int arrays, in this order:
          even coeff @ sKeyNo, odd coeff @ sKeyNo,
          even coeff @ sKeyNo+1, odd coeff @ sKeyNo+1.
        - ``skpv_profiling``: int array from ``skpv_a_vec0_evenCoeff0``
          column ``sKeyNo``.

    Raises
    ------
    KeyError
        If one of the required datasets is missing.
    """
    print('\nLoad database_file =', database_file)

    # The context manager closes the HDF5 handle even if a dataset is missing;
    # the original left the file open for the lifetime of the process.
    with h5py.File(database_file, "r") as in_file:

        def _int_column(name, col):
            # Read a single column of dataset `name` and cast it to int.
            return np.array(in_file[name][:, col].astype(int))

        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = _int_column('skpv_a_vec0_evenCoeff0', sKeyNo)
        bp_profiling = [
            _int_column('bp_b_vec0_evenCoeff0', sKeyNo),
            _int_column('bp_b_vec0_oddCoeff1', sKeyNo),
            _int_column('bp_b_vec0_evenCoeff0', sKeyNo + 1),
            _int_column('bp_b_vec0_oddCoeff1', sKeyNo + 1),
        ]

    return (trace_profiling, bp_profiling, skpv_profiling)

def get_all_coeffs(input_signal, num_level, wavelet='db1'):
    """Return every approximation and detail band of a multilevel DWT.

    ``wavedec`` yields only the deepest approximation ``cA{num_level}`` plus
    the details ``cD{num_level} .. cD1``. The intermediate approximations
    ``cA{num_level-1} .. cA1`` are rebuilt here with single-level inverse
    transforms, working from the deepest level upward.

    Parameters
    ----------
    input_signal : array_like
        1-D signal to decompose.
    num_level : int
        Number of decomposition levels.
    wavelet : str or pywt.Wavelet, optional
        Wavelet to use. The default ``'db1'`` (Haar) matches the previous
        hard-coded behaviour.

    Returns
    -------
    dict
        Maps ``'cA1'..'cA{num_level}'`` and ``'cD1'..'cD{num_level}'`` to
        1-D coefficient arrays. ``cAk`` always has the same length as ``cDk``.
    """
    # With 'db1' (orthonormal Haar) each approximation coefficient is
    # (x[2i] + x[2i+1]) / sqrt(2), i.e. the pairwise mean multiplied by sqrt(2).
    coeffs = wavedec(input_signal, wavelet, level=num_level)

    # coeffs is [cA_n, cD_n, cD_{n-1}, ..., cD_1], so coeffs[-k] is cD_k.
    all_data = {'cA{}'.format(num_level): coeffs[0]}
    for level in range(1, num_level + 1):
        all_data['cD{}'.format(level)] = coeffs[-level]

    for level in range(num_level, 1, -1):
        approx = pywt.idwt(all_data['cA{}'.format(level)],
                           all_data['cD{}'.format(level)],
                           wavelet)
        # If the level below had odd length, wavedec padded it, so idwt returns
        # one extra trailing sample. Trim it (exactly as pywt.waverec does) so
        # cA_{level-1} matches cD_{level-1}; otherwise the next idwt call
        # raises on the cA/cD length mismatch.
        target_len = len(all_data['cD{}'.format(level - 1)])
        all_data['cA{}'.format(level - 1)] = approx[:target_len]

    return all_data

# To inspect a previously generated output instead of regenerating it:
#     print(np.load('data_wavelet.npz')['cA2'].shape)

NUM_SAMPLES = 500000
# Number of DWT levels; also determines which bands are saved below.
NUM_LEVEL = 3

# np.load on an .npz returns an NpzFile that keeps the archive open until it
# is closed; the context manager releases it once the arrays are read.
with np.load('data.npz') as infile:
    data = infile['data'][:NUM_SAMPLES]
    bp = infile['bp'][:NUM_SAMPLES]
    labels = infile['label'][:NUM_SAMPLES]

# Alternative source: read directly from the HDF5 capture instead:
#     data, bp, labels = load_meta_trace_file_from_test('KYBER51.H5', sKeyNo=0, load_metadata=False)
# NOTE(review): that loader returns `bp` as a list of four arrays, so the 'bp'
# array saved below would have a different shape than data.npz's 'bp' — confirm
# downstream consumers before switching.

# Band names in save order: cA{n}..cA1 followed by cD{n}..cD1.
band_keys = (['cA{}'.format(i) for i in range(NUM_LEVEL, 0, -1)]
             + ['cD{}'.format(i) for i in range(NUM_LEVEL, 0, -1)])
all_wavelet = {key: [] for key in band_keys}

for sample in tqdm(data):
    output = get_all_coeffs(sample, NUM_LEVEL)
    for key, bands in all_wavelet.items():
        bands.append(output[key])

# Each band is stacked into a (num_samples, band_length) array; all samples
# share one length, so the per-band lengths are uniform.
np.savez('data_wavelet.npz', label=labels, bp=bp,
         **{key: np.array(bands) for key, bands in all_wavelet.items()})