import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import tensorflow as tf
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
import pandas as pd

def parse_arguments():
    """Build the command-line parser used by this script.

    Returns:
        argparse.ArgumentParser: a parser exposing ``--trained_model_path``
        (string, required) and ``--num_iteration`` (int, default 5).  The
        caller is responsible for calling ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(description='')
    # The value is handed to os.path.join right after parsing, so a missing
    # option must be rejected here with a usage message instead of surfacing
    # later as a TypeError on None.
    parser.add_argument('--trained_model_path', type=str, required=True,
                        help='directory holding the trained model files')
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)

    return parser

def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from a NumPy ``.npz`` archive.

    Args:
        data_path: location of the archive, passed unchanged to ``np.load``.

    Returns:
        tuple: ``(data, labels)`` read from the archive entries ``data`` and
        ``label`` respectively.

    Raises:
        KeyError: if the archive lacks one of the two entries.
    """
    # The archive object keeps its file handle open until it is closed; the
    # context manager releases it once both arrays have been read.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
    """Load traces and secret-key labels from an HDF5 database.

    Args:
        database_file: path of the HDF5 file to read.
        sKeyNo: column index selected from the ``skpv_a_vec0_evenCoeff0``
            dataset.
        load_metadata: accepted for interface compatibility; it is not used.

    Returns:
        tuple: ``(trace_profiling, skpv_profiling)``.  The first is the
        ``wave`` dataset converted to a float array; the second is column
        ``sKeyNo`` of ``skpv_a_vec0_evenCoeff0`` converted to int.

    Raises:
        KeyError: if one of the two datasets is missing from the file.

    Note:
        If the file cannot be opened, an error message is printed and the
        process exits with status -1.
    """
    print('\nLoad database_file =', database_file)
    # Open the Kyber database HDF5 for reading.  Only open failures are
    # handled here; anything else (e.g. KeyboardInterrupt) must propagate.
    try:
        in_file = h5py.File(database_file, "r")
    except OSError:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        sys.exit(-1)

    # Copy both datasets into memory, then close the file handle.
    with in_file:
        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = np.array(in_file['skpv_a_vec0_evenCoeff0'][:, sKeyNo].astype(int))

    return trace_profiling, skpv_profiling

    

# For each supported input layout: which entries of xTest[1] are fed to the
# model in addition to the trace array xTest[0][0].
_AUX_INPUT_INDICES = {
    'wave': (),
    'wavebp0': (0,),
    'wavebp1': (1,),
    'wavebp01': (0, 1),
    'wavebp01next0': (0, 1, 2),
    'wavebp01next01': (0, 1, 2, 3),
}


def _select_model_inputs(xTest, samp, x_type):
    """Return the list of model inputs for the traces indexed by ``samp``."""
    try:
        aux_indices = _AUX_INPUT_INDICES[x_type]
    except KeyError:
        raise ValueError('Unsupported xType: %r' % (x_type,)) from None
    inputs = [xTest[0][0][samp, :, :]]
    # xTest[1] is only touched when the layout actually needs auxiliary inputs.
    inputs.extend(xTest[1][k][samp, :] for k in aux_indices)
    return inputs


def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Compute rank matrices and per-trace scores of ``model`` over several runs.

    The input layout is selected by the module-level variable ``xType``.

    Args:
        model: object whose ``predict`` method takes a list of input arrays
            and returns one row of class probabilities per selected trace.
        nruns: number of runs to evaluate.
        maxtrc: number of traces processed in each run.
        batches: 2-D integer array; row ``krun`` holds the trace indices used
            in run ``krun``.
        xTest: nested sequence; ``xTest[0][0]`` is the 3-D trace array and
            ``xTest[1][k]`` are the auxiliary 2-D inputs required by some
            layouts.
        yTest_value: sequence whose first element is the real key value; it
            is used both as the key index and as the class index.
        noHypoKeys: number of key hypotheses; the per-trace score rows are
            accumulated into a vector of this length.
        noClasses: number of classes produced by the model.

    Returns:
        tuple of six arrays:
            rankmat_byKey (nruns, maxtrc): after each trace, the number of
                hypotheses whose accumulated log-probability is strictly
                greater than that of the real key (0 means the real key
                ranks first).
            rankmat_byClass (nruns, maxtrc): the same count computed on the
                single-trace log-probabilities.
            ps_AllClasses_Nruns (maxtrc, noClasses, nruns): model outputs.
            lps_AllClasses_Nruns (maxtrc, noClasses, nruns): their logarithms.
            lps_AllHypoKeys_Nruns (maxtrc, noHypoKeys, nruns): the same
                log-probabilities, stored per hypothesis.
            lpsums_AllHypoKeys_Nruns (maxtrc, noHypoKeys, nruns): running sums
                of the log-probabilities over the traces of each run.

    Raises:
        ValueError: if ``xType`` is not one of the supported layouts.
    """
    realkey = int(yTest_value[0])
    # Classes map one-to-one onto key hypotheses, so the real class is the key.
    realClass = realkey
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        samp = batches[krun, :]
        ps = model.predict(_select_model_inputs(xTest, samp, xType))

        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            # Accumulate the evidence of trace i on top of all previous traces.
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realClass])
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        # Hypotheses and classes coincide, so the per-hypothesis scores are
        # the per-class log-probabilities themselves.
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the key-rank matrix of ``model``.

    Forwards every argument unchanged to ``mk_rankmat`` and keeps only the
    first of the six arrays it returns.

    Returns:
        The ``rankmat_byKey`` array produced by ``mk_rankmat``.
    """
    rankmat_byKey, *_ = mk_rankmat(
        model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)
    return rankmat_byKey

# ---------------------------------------------------------------------------
# Script configuration: parse the command line, load the evaluation data and
# define the module-level settings read by the functions in this file.
# ---------------------------------------------------------------------------
parser = parse_arguments()
args = parser.parse_args()

#best_model_path = os.path.join(args.trained_model_path, 'model_best.keras')
# Model file expected inside the directory given by --trained_model_path.
end_model_path = os.path.join(args.trained_model_path, 'model_best_end.keras')
#xTest_multi, yTest_multi = load_multi_attack('attack_multi_data_300key.npz')
# Traces and labels for key column 0; the file name is resolved relative to
# the current working directory.
xTest, yTest = load_meta_trace_file('KYBER51.H5', 0, load_metadata=False)
print(xTest.shape)
print(yTest.shape)
print(yTest[0])

# Multi-key attack set, likewise read from the current working directory.
xTest_multi, yTest_multi = load_multi_attack('attack_multi_data_300key.npz')
print('---------------------')
print(xTest_multi.shape)
print(yTest_multi.shape)

# Input layout read by mk_rankmat; 'wave' feeds only the trace array.
xType = 'wave'
# Number of traces per run and number of runs passed to the evaluation.
maxtrc = 200
nruns = 20
# Inclusive [min, max] value ranges used to derive the counts below.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
# NOTE(review): tracelen, NumFQMULclasses and NumBPinput are not referenced
# in the visible part of this file -- confirm whether they are still needed.
tracelen = 600
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
# Classes and key hypotheses are both the skpv values.
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses

def get_meanrank_multi(xTest_multi, yTest_vals, model, maxtrc):
    """Return the mean key rank of ``model`` over several per-key trace sets.

    Each key is evaluated with ``eval_model`` in a single run that uses the
    first ``maxtrc`` traces of its set, in order.  The numbers of hypotheses
    and classes are taken from the module-level ``noHypoKeys`` and
    ``noClasses``.

    Args:
        xTest_multi: indexable collection; ``xTest_multi[i]`` holds the
            traces belonging to key ``i``.
        yTest_vals: the key value for each entry of ``xTest_multi``.
        model: model passed on to ``eval_model``.
        maxtrc: number of traces evaluated per key.

    Returns:
        The mean of all rank values over keys, runs and traces.
    """
    start_time = time.perf_counter()
    nruns = 1  # data for multi-label is limited, so we did 1 run only
    # The trace selection is identical for every key, so build it once.
    batches = np.zeros((nruns, maxtrc), 'int')
    batches[0] = np.arange(maxtrc)
    multi_rank = np.array([
        eval_model(model, nruns, maxtrc, batches, [[xTest_multi[i]]], [key_value], noHypoKeys, noClasses)
        for i, key_value in enumerate(yTest_vals)  # one evaluation per key
    ])
    print(multi_rank.shape)
    print('Eval took {}'.format(time.perf_counter() - start_time))
    return np.mean(multi_rank)

# ---------------------------------------------------------------------------
# Evaluate the final model on the target traces and store the mean rank.
# ---------------------------------------------------------------------------
end_model = load_model(end_model_path)
# Add a third axis so the traces can be indexed as [samples, :, :] by
# mk_rankmat.
xTest = np.expand_dims(xTest, axis = 2)
# Every run needs its own trace selection.  Filling only row 0 would leave the
# remaining rows at index 0, i.e. those runs would evaluate trace 0 repeated
# maxtrc times and distort the mean rank.  Each run therefore draws maxtrc
# distinct trace indices at random (unseeded, so results vary between
# invocations).
rng = np.random.default_rng()
batches = np.zeros((nruns, maxtrc), 'int')
for run in range(nruns):
    batches[run] = rng.choice(len(xTest), size=maxtrc, replace=False)
test_rank = eval_model(end_model, nruns, maxtrc, batches, [[xTest]], yTest, noHypoKeys, noClasses)
print(test_rank.shape)
print(test_rank)
# Average over runs: one mean rank per number of traces used.
mean_rank = np.mean(test_rank,0)
print(mean_rank)
df = pd.DataFrame({
    'Mean_Rank': mean_rank,
    })
df.to_csv(os.path.join(args.trained_model_path, 'mean_rank_target.csv'))