import numpy as np
import pandas as pd

# Build an attack set from data.npz: skip the first OFFSET traces, keep the
# N_LABEL most frequent labels and the first N_SAMPLE traces of each, and save
# the chosen trace ids. Exporting the selected traces / 'bp' rows is opt-in
# (those steps were disabled in the original script).

OFFSET = 200000  # traces before this index are skipped; saved ids are relative to it
N_LABEL = 300    # number of most frequent labels to keep (other value tried: 10)
N_SAMPLE = 100   # traces kept per label (other value tried: 40)
# NOTE: per an earlier run, with N_LABEL=10 the selected labels had 116-122
# traces each, which is why <=115 traces per label were used for evaluation.


def select_attack_traces(labels, n_label=N_LABEL, n_sample=N_SAMPLE):
    """Pick the ``n_label`` most frequent labels and ``n_sample`` trace ids for each.

    Args:
        labels: array with one label per trace (1-D).
        n_label: number of most frequent labels to keep.
        n_sample: number of traces kept per label (the first ones in index order).

    Returns:
        (top_labels, ids): ``top_labels`` holds the selected label *values*
        (in ``np.argpartition`` order, not sorted). ``ids`` is an int array of
        shape (n_label, n_sample) whose row i indexes the traces of top_labels[i].

    Raises:
        ValueError: if n_label is not in [1, number of distinct labels], if
            n_sample is negative, or if a selected label has fewer than
            n_sample traces (the rows would be ragged).
    """
    labels = np.asarray(labels)
    unique, counts = np.unique(labels, return_counts=True)
    if not 1 <= n_label <= len(unique):
        # n_label == 0 would otherwise select everything via the [-0:] slice.
        raise ValueError(
            f"n_label must be between 1 and {len(unique)} (distinct labels), got {n_label}"
        )
    if n_sample < 0:
        raise ValueError(f"n_sample must be non-negative, got {n_sample}")

    # argpartition yields positions into unique/counts. Map them back to label
    # values; comparing the positions against `labels` directly is only correct
    # when the labels happen to be exactly 0..K-1.
    top_pos = np.argpartition(counts, -n_label)[-n_label:]
    top_counts = counts[top_pos]
    if (top_counts < n_sample).any():
        raise ValueError(
            f"{int((top_counts < n_sample).sum())} selected label(s) have fewer than "
            f"{n_sample} traces (min {int(top_counts.min())})"
        )
    top_labels = unique[top_pos]
    ids = np.array([np.flatnonzero(labels == lab)[:n_sample] for lab in top_labels])
    return top_labels, ids


def inspect_attack_file(path='attack_multi_data_300key.npz'):
    """Print len() of the first 'data' entry and the first 10 labels of a saved file."""
    with np.load(path) as infile:
        print(len(infile['data'][0]))
        print(infile['label'][:10])


def main(save_set=False, add_cross_device_bp=False):
    """Select attack traces from data.npz and save their ids.

    Args:
        save_set: also save the selected traces, labels and 'bp' columns to
            attack_multi_data_300key_bp.npz (disabled in the original script).
        add_cross_device_bp: also copy the cross-device file, adding the
            selected 'bp' columns (unreachable in the original script).
            NOTE(review): this assumes the cross-device file's traces line up
            with the ids selected here. Confirm before enabling.
    """
    with np.load('data.npz') as infile:
        # 'bp' is indexed with traces along axis 1 (as in the original slicing).
        bp = infile['bp'][:, OFFSET:]
        data_labels = infile['label'][OFFSET:]
        data = infile['data'][OFFSET:] if save_set else None
    print(bp.shape)

    unique, counts = np.unique(data_labels, return_counts=True)
    top_labels, all_ids = select_attack_traces(data_labels, N_LABEL, N_SAMPLE)
    # unique is sorted and top_labels is a subset of it, so searchsorted finds them.
    top_counts = counts[np.searchsorted(unique, top_labels)]
    print(top_labels)
    print(np.min(top_counts))
    print(np.max(top_counts))

    print(len(unique))
    print(unique[:10])
    print(all_ids[0])
    np.savez('attack_300key_ids.npz', ids=all_ids)
    print(all_ids.shape)
    print(top_labels[:10])
    print(all_ids[0][:10])

    # Sanity check: the lowest selected ids should map to selected labels.
    cap_ids = all_ids.ravel()
    print(cap_ids.shape)
    sort_ids = np.sort(cap_ids)
    print(data_labels[sort_ids[:10]])

    if not (save_set or add_cross_device_bp):
        return
    # Shape (n_label, bp.shape[0], n_sample): one block of 'bp' columns per label.
    attack_bp = np.moveaxis(bp[:, all_ids], 1, 0)
    if save_set:
        attack_data = data[all_ids]  # shape (n_label, n_sample, ...)
        print(attack_data.shape)
        np.savez('attack_multi_data_300key_bp.npz',
                 data=attack_data, label=top_labels, bp=attack_bp)
    if add_cross_device_bp:
        with np.load('attack_300key_100_cross_Device3_OldDev.npz') as cross:
            np.savez('attack_300key_100_cross_Device3_OldDev_bp.npz',
                     data=cross['data'], label=cross['label'], bp=attack_bp)


if __name__ == "__main__":
    main()