import numpy as np
import pandas as pd
import h5py
import numpy as np

def load_meta_trace_files(database_file, sKeyNo = 0, load_metadata=False):
    """Load the traces and one column of labels from an HDF5 database.

    Args:
        database_file: File name (or any value accepted by ``h5py.File``)
            of the HDF5 file to read. It must contain the datasets
            ``'wave'`` and ``'skpv_a_vec0_evenCoeff0'``.
        sKeyNo: Column of ``'skpv_a_vec0_evenCoeff0'`` to return as labels.
        load_metadata: Accepted for backward compatibility; currently
            ignored.

    Returns:
        A ``(trace_profiling, skpv_profiling)`` tuple: the whole ``'wave'``
        dataset converted to a float array, and column ``sKeyNo`` of
        ``'skpv_a_vec0_evenCoeff0'`` converted to an int array.
    """
    # The context manager guarantees the HDF5 handle is released, including
    # when a dataset is missing or the column index is out of range.
    with h5py.File(database_file, "r") as in_file:
        # Both arrays are fully materialised in memory here, so they remain
        # valid after the file is closed.
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = np.array(
            in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int)
        )

    # The bp_b_vec0_* datasets are deliberately not read: nothing from them
    # is returned, and indexing column sKeyNo + 1 would fail needlessly when
    # sKeyNo is the last column.
    return trace_profiling, skpv_profiling

# Maximum number of traces kept for each selected label.
N_SAMPLE = 15

# Reuse the 'label' array stored in an existing .npz file.
infile = np.load('attack_300key_100_cross_Device3_OldDev.npz')
max_inds = infile['label']
print(max_inds[:10])

# Print how many distinct label values the array contains.
uniq, counts = np.unique(max_inds, return_counts=True)
print(len(counts))

# Load attack data
trace_profiling, data_labels = load_meta_trace_files('KYBER50.h5')
print(trace_profiling.shape)
print(len(data_labels))
data = trace_profiling

# Only the first 50 labels are used, both for selection and in the output.
selected_labels = max_inds[:50]

attack_data = []
for max_id in selected_labels:
    # Indices of the traces whose label equals this one, capped at N_SAMPLE.
    matching_ids = np.nonzero(data_labels == max_id)[0]
    trace_ids = matching_ids[:N_SAMPLE]
    print(len(trace_ids))
    attack_data.append(data[trace_ids])

attack_data = np.array(attack_data)
print(attack_data.shape)

np.savez(
    'attack_300key_100_cross_Device4.npz',
    data=attack_data,
    label=selected_labels,
)
