import h5py
import numpy as np

def load_meta_trace_files(database_file, sKeyNo = 0, load_metadata=False):
    """Load the trace matrix and one label column from an HDF5 file.

    Args:
        database_file: Path of the HDF5 file to read. It must contain the
            datasets 'wave' and 'skpv_a_vec0_evenCoeff0'.
        sKeyNo: Column of 'skpv_a_vec0_evenCoeff0' to return.
        load_metadata: Not used by this function; kept so existing calls
            that pass it keep working.

    Returns:
        A tuple (trace_profiling, skpv_profiling): the whole 'wave' dataset
        as a float array, and column ``sKeyNo`` of 'skpv_a_vec0_evenCoeff0'
        cast to int.
    """
    # The context manager closes the HDF5 handle even if a dataset is
    # missing or the column index is out of range.
    with h5py.File(database_file, "r") as in_file:
        # Both values are materialised as in-memory arrays inside the
        # `with` block, so they stay valid after the file is closed.
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = np.array(
            in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int)
        )

    return trace_profiling, skpv_profiling

# Merge data

# The 21k and 30k captures are loaded here only to report their shapes.
traces_21k, cross_labels_21k = load_meta_trace_files('raw_data_21000.h5')
traces_30k, cross_labels_30k = load_meta_trace_files('raw_data_30000.h5')
print(traces_21k.shape)
print(traces_30k.shape)
# Disabled alternative: build the working set from the two captures above
# instead of reading 'KYBER50.h5'.
#traces = np.concatenate((traces_21k, traces_30k[1:, :]))
#cross_labels = np.concatenate((cross_labels_21k, cross_labels_30k[1:]))
traces, cross_labels = load_meta_trace_files('KYBER50.h5')

# Position data[KEY_NUM][SAMP_NUM] of the single trace that is exported and
# plotted at the end of the script.
KEY_NUM = 0
SAMP_NUM = 10

print(cross_labels[:10])

# Read the arrays inside `with` blocks so the .npz archives are closed as
# soon as the needed members have been loaded into memory.
with np.load('attack_multi_data_300key.npz') as archive:
    labels = archive['label']
    data = archive['data']
with np.load('attack_300key_ids.npz') as ids_archive:
    data_ids = ids_archive['ids']

# Flatten all ids into one ascending sequence; the grouping step pairs
# sorted_ids[i] with traces[i].
sorted_ids = np.sort(np.hstack(data_ids))
print(data_ids.shape)
print(sorted_ids.shape)
print(traces.shape)
print(cross_labels.shape)

original_data = data[KEY_NUM][SAMP_NUM]
print(original_data.shape)
# Group the traces (and their ids) into one bucket per entry of `labels`:
# bucket k collects every trace whose label equals labels[k].
all_data = [[] for _ in range(len(labels))]
all_data_ids = [[] for _ in range(len(labels))]

print(labels.shape)
print(cross_labels.shape)

print('----------------')
for i, trace in enumerate(traces):
    # np.where returns a tuple of index arrays, which cannot index a list.
    # Take the first matching position along the first axis as the bucket.
    matches = np.where(labels == cross_labels[i])[0]
    if matches.size == 0:
        raise ValueError(
            'label {!r} of trace {} not found in labels'.format(cross_labels[i], i)
        )
    curr_index = matches[0]
    all_data[curr_index].append(trace)
    all_data_ids[curr_index].append(sorted_ids[i])


# Sanity check: the regrouped ids are compared with the ids stored on disk.
all_data_ids = np.array(all_data_ids)
print(all_data_ids.shape)
print(data_ids.shape)
print((all_data_ids==data_ids).all())
print(all_data_ids[0][:10])
print(data_ids[0][:10])

# Report the bucket sizes; sum_len accumulates the total number of traces.
sum_len = 0
for key_traces in all_data:
    print(len(key_traces))
    sum_len += len(key_traces)

def compare_sum(a, b):
    """Return the absolute difference between the element sums of a and b."""
    total_a = np.sum(a)
    total_b = np.sum(b)
    return np.absolute(total_a - total_b)

from scipy.spatial import distance

# Euclidean distance between every grouped trace all_data[i][j] and the
# trace stored at the same position data[i][j].
all_comp = [
    distance.euclidean(cross_trace, data[i][j])
    for i, key_traces in enumerate(all_data)
    for j, cross_trace in enumerate(key_traces)
]

print('Mean euclidean distance, cross device and original, same key')
print(np.mean(all_comp))

# Disabled experiment, kept inside a bare string literal so it never runs:
# mean euclidean distance between all_data[i][j] and data[k][j] for other k.
# NOTE(review): the guard `k != j` compares the index k over all_data with the
# sample index j; for a "different key" comparison `k != i` looks intended --
# confirm before re-enabling.
'''
all_comp = []

for i in range(len(all_data)):
    #for each key
    for j in range(len(all_data[i])):
        #For each sample in key, calculate distance to other trace in other key
        for k in range(len(all_data)):
            if k != j:
                comp = distance.euclidean(all_data[i][j], data[k][j]) #for a[i,j]#a[k,l]
                all_comp.append(comp)

print('Mean euclidean distance, cross device and original, different key')
print(np.mean(all_comp))
'''

#--------------------------------
# Disabled experiment, also kept as a bare string: reloads `data` from
# 'attack_multi_data_300key_noise_0.4.npz' and averages compare_sum over
# matching positions of all_data and data.
'''
data = np.load('attack_multi_data_300key_noise_0.4.npz')
data = data['data']
all_comp = []
for i in range(len(all_data)):
    for j in range(len(all_data[i])):
        comp = compare_sum(all_data[i][j], data[i][j])
        all_comp.append(comp)

print('Mean variation, synthetic noise')
print(np.mean(all_comp))
'''
# Upper bound on the number of traces kept per bucket.
NUM_TRACES = 100

# Report how many traces each bucket holds before truncation.
for curr_data in all_data:
    print(len(curr_data))

# Keep at most NUM_TRACES traces per bucket and store them with the labels.
save_data = np.array([bucket[:NUM_TRACES] for bucket in all_data])
print(save_data.shape)
print(labels.shape)
np.savez('attack_300key_100_cross_Device4.npz', data=save_data, label=labels)

# Grouped trace at the same (KEY_NUM, SAMP_NUM) position as original_data.
selected_bucket = all_data[KEY_NUM]
cross_data = selected_bucket[SAMP_NUM]
print(cross_data.shape)

import pandas as pd

# Write the selected grouped trace and the original trace to CSV files
# named after KEY_NUM.
cross_df = pd.DataFrame(data=cross_data)
original_df = pd.DataFrame(data=original_data)
for frame, csv_pattern in (
    (cross_df, 'cross_data_key_{}.csv'),
    (original_df, 'original_data_key_{}.csv'),
):
    frame.to_csv(csv_pattern.format(KEY_NUM))
import matplotlib.pyplot as plt

# Overlay both traces in one figure and save it as an image.
for series, legend_name in (
    (original_data, 'original'),
    (cross_data, 'cross device'),
):
    plt.plot(series, label=legend_name)
plt.legend()
plt.savefig('Test_Compare.png')