import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda

import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
from natsort import natsorted
import pywt
from natsort import natsorted

def parse_arguments():
    """Build the command-line parser for this evaluation script.

    Returns:
        argparse.ArgumentParser: the configured parser. The caller is
        responsible for calling ``parse_args`` on it.
    """
    parser = argparse.ArgumentParser(
        description='Evaluate trained models on a multi-key attack set and '
                    'write the mean rank / GE per key to a CSV file.')
    parser.add_argument('--train_folder', type=str,
                        help='folder containing the trained model files; '
                             'the result CSV is written there as well')
    parser.add_argument('--eval_path', type=str,
                        help='path of the .npz archive holding the attack set')
    parser.add_argument('--xType', type=str, help='type, wavebp01 mean using ciphertext')
    parser.add_argument('--num_traces', type=int, default=100,
                        help='number of traces kept per key from the attack set')
    parser.add_argument('--num_key', type=int, default=300,
                        help='number of keys kept from the attack set')
    parser.add_argument('--normalize', type=int,
                        help='1: per-trace min-max scaling, 2: per-trace z-score; '
                             'any other value disables normalisation')
    parser.add_argument('--dwt', type=int,
                        help='1: replace each trace by its db1 approximation-only reconstruction')
    parser.add_argument('--cummulative', type=int,
                        help='only used together with --normalize: '
                             '1 = cumulative sum before normalising, 2 = after')
    parser.add_argument('--all_ids', type=str,
                        help='path of the index file (loaded with numpy.load) '
                             'selecting the training traces')
    parser.add_argument('--all_path', type=int, default=0,
                        help="0: evaluate only 'model_best_end.keras'; otherwise "
                             "evaluate every file whose name contains 'model_best'")
    parser.add_argument('--zmuv', type=int, default=0,
                        help='1: rescale the attack traces to the training mean/std (ZMUV)')
    return parser

# The command line is parsed at import time; several functions below read
# `args` and `xType` as module-level globals.
parser = parse_arguments()
args = parser.parse_args()
basepath = 'trained_models/'

data = np.load('data.npz')
labels = data['label']

xType = args.xType
# The number of traces per batch is fixed here; --num_traces is not consulted.
maxtrc = 10
tracelen = 600

# Inclusive [low, high] value ranges.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]

# Number of distinct values in each inclusive range.
NumFQMULclasses = 1 + fqmul_range[1] - fqmul_range[0]  # classes for fqmul(skpv, bp)
NumSKPVclasses = 1 + skpv_range[1] - skpv_range[0]     # classes for skpv
NumBPinput = 1 + bp_range[1] - bp_range[0]             # inputs for bp (ciphertext)

noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0

def get_meanrank(xTest_multi, yTest_vals, model, maxtrc):
    """Average the final key rank reported by ``eval_model`` for every key.

    For each key, ``nruns`` rounds are executed. Every round draws ``nruns``
    fresh batches of ``maxtrc`` distinct trace indices and passes them to
    ``eval_model``; the two values it returns are averaged over the rounds.

    Args:
        xTest_multi: attack traces, indexed by key first.
        yTest_vals: true label of each key; its length sets the number of
            keys evaluated.
        model: trained model, forwarded to ``eval_model``.
        maxtrc: number of traces drawn (without replacement) per batch.

    Returns:
        tuple: ``(multi_rank, multi_rank_GE)``, two 1-D arrays holding one
        averaged value per key.

    NOTE(review): for xType 'wavebp01' the script builds ``xTest_multi`` as a
    two-element list, so key-first indexing does not select a key on that
    path - confirm whether it is meant to be supported.
    """
    start_time = time.time()
    nruns = 5
    multi_rank = []
    multi_rank_GE = []
    # The key index and the run index need distinct names: with a shared loop
    # variable the run index silently replaced the key index, so every key
    # was evaluated on the data of keys 0..nruns-1.
    for key_idx in range(len(yTest_vals)):
        curr_rank = 0
        curr_rank_GE = 0
        n_traces = len(xTest_multi[key_idx])
        for run in range(nruns):
            print('{}/nrus'.format(run))
            batches = np.zeros((nruns, maxtrc), 'int')
            for j in range(nruns):
                batches[j, :] = np.random.choice(n_traces, maxtrc, False)
            test_rank, test_rank_GE = eval_model(
                model, nruns, maxtrc, batches,
                [[xTest_multi[key_idx]]], [yTest_vals[key_idx]],
                noHypoKeys, noClasses)
            print(batches.shape)
            print(test_rank, test_rank_GE)
            curr_rank += test_rank
            curr_rank_GE += test_rank_GE
        multi_rank.append(curr_rank / nruns)
        multi_rank_GE.append(curr_rank_GE / nruns)
    multi_rank = np.array(multi_rank)
    multi_rank_GE = np.array(multi_rank_GE)
    print('Eval took {}'.format(time.time() - start_time))
    return multi_rank, multi_rank_GE


def load_multi_attack(data_path):
    """Load the attack set stored in an ``.npz`` archive.

    Args:
        data_path: path of an archive containing the entries ``data`` and
            ``label`` (and ``bp`` when the global ``args.xType`` is
            ``'wavebp01'``).

    Returns:
        tuple: ``(data, labels)``. ``data`` is the ``data`` array, or the
        list ``[data, bp]`` when ``args.xType == 'wavebp01'``.
    """
    # The context manager closes the archive's file handle; the arrays are
    # fully read on access, so they stay valid after the block.
    with np.load(data_path) as infile:
        data = infile['data']
        if args.xType == 'wavebp01':
            data = [data, infile['bp']]
        labels = infile['label']

    return data, labels

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Rank the true key after each additional trace, for ``nruns`` batches.

    For every run the model predicts on the traces selected by
    ``batches[krun]``; log-probabilities are accumulated trace by trace and
    the rank of the true key is the number of hypotheses whose accumulated
    score is strictly larger. Class ``k`` is used directly as the score of
    key hypothesis ``k`` (no intermediate mapping is applied).

    Args:
        model: object exposing ``predict``.
        nruns: number of batches to evaluate.
        maxtrc: number of traces per batch.
        batches: integer array indexed ``[run, trace]`` with trace indices.
        xTest: nested model inputs; which elements are used depends on the
            module-level ``xType``.
        yTest_value: sequence whose first element is the true key.
        noHypoKeys: number of key hypotheses.
        noClasses: number of model output classes.

    Returns:
        tuple: ``(rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns,
        lpsums_AllHypoKeys_Nruns)``. The rank matrices are indexed
        ``[run, trace]``; the remaining arrays ``[trace, class/key, run]``.

    Raises:
        ValueError: if the module-level ``xType`` is not one of the
            supported input layouts.
    """
    realkey = int(yTest_value[0])
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        samp = batches[krun, :]

        if xType == 'wave':
            ps = model.predict([xTest[0][0][samp, :, :]])
        elif xType == 'wavebp0':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :]])
        elif xType == 'wavebp1':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][1][samp, :]])
        elif xType == 'wavebp01':
            ps = model.predict([[xTest[0][samp, :]], [xTest[1][0][samp, :]], [xTest[1][1][samp, :]]])
        elif xType == 'wavebp01next0':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :], xTest[1][2][samp, :]])
        elif xType == 'wavebp01next01':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :], xTest[1][2][samp, :], xTest[1][3][samp, :]])
        else:
            # Previously an unknown xType fell through and failed later on an
            # unbound `ps`; fail here with a message naming the cause.
            raise ValueError('unsupported xType: {!r}'.format(xType))

        # np.log yields -inf for classes predicted with probability 0.
        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            # Rank = number of hypotheses scoring strictly above the true key.
            rankmat_byKey[krun, i] = int(np.sum(lpsums > lpsums[realkey]))
            # Same count, but using this single trace only.
            rankmat_byClass[krun, i] = int(np.sum(lps[i, :] > lps[i, realkey]))
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        # Hypothesis scores equal the class log-probabilities here.
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def mk_rankmat_GE(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Compute the position of the true key in the sorted score list per trace count.

    For every run the model predicts on the traces selected by
    ``batches[krun]``. After ``i + 1`` traces the per-class log-probabilities
    of those traces are summed, the classes are sorted by decreasing sum, and
    the position of the true key in that ordering is stored.

    Args:
        model: object exposing ``predict``.
        nruns: number of batches to evaluate.
        maxtrc: number of traces per batch.
        batches: integer array indexed ``[run, trace]`` with trace indices.
        xTest: nested model inputs; which elements are used depends on the
            module-level ``xType``.
        yTest_value: sequence whose first element is the true key.
        noHypoKeys: number of key hypotheses.
        noClasses: number of model output classes.

    Returns:
        tuple: same six-element layout as ``mk_rankmat``. Only the first
        element (``rankmat_byKey``, indexed ``[run, trace]``) is filled; the
        other five arrays are returned as allocated, i.e. all zeros.

    Raises:
        ValueError: if the module-level ``xType`` is not one of the
            supported input layouts.
    """
    realkey = int(yTest_value[0])
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        samp = batches[krun, :]

        if xType == 'wave':
            ps = model.predict([xTest[0][0][samp, :, :]])
        elif xType == 'wavebp0':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :]])
        elif xType == 'wavebp1':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][1][samp, :]])
        elif xType == 'wavebp01':
            ps = model.predict([[xTest[0][samp, :]], [xTest[1][0][samp, :]], [xTest[1][1][samp, :]]])
        elif xType == 'wavebp01next0':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :], xTest[1][2][samp, :]])
        elif xType == 'wavebp01next01':
            ps = model.predict([xTest[0][0][samp, :, :], xTest[1][0][samp, :], xTest[1][1][samp, :], xTest[1][2][samp, :], xTest[1][3][samp, :]])
        else:
            # Previously an unknown xType fell through and failed later on an
            # unbound `ps`; fail here with a message naming the cause.
            raise ValueError('unsupported xType: {!r}'.format(xType))

        lps = np.log(ps)
        for i in range(maxtrc):
            # Include trace i itself. The previous slice `lps[:i]` summed zero
            # traces at i == 0 and never used the last trace of the batch.
            log_likelihood = np.sum(lps[:i + 1, :], axis=0)
            ranked = np.argsort(log_likelihood)[::-1]
            rankmat_byKey[krun, i] = list(ranked).index(realkey)

    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def ranking_curve(preds, key, plaintext, target_byte, rank_root, leakage_model='HW', trace_num_max=50):
    """Estimate the rank of the correct key byte as a function of trace count.

    The attack is repeated ``num_averaged`` (5) times on a random selection of
    ``trace_num_max`` traces. For every key guess the predicted probability of
    the label implied by that guess is collected; log-scores are accumulated
    over traces and the position of the correct key in the sorted list is
    recorded.

    Args:
        preds: per-trace class probabilities, indexed ``preds[trace, label]``.
        key: key bytes; ``key[target_byte]`` is the correct value.
        plaintext: 2-D array of plaintext bytes, indexed ``[trace, byte]``.
        target_byte: index of the attacked byte.
        rank_root: unused; kept so existing callers keep working.
        leakage_model: ``'ID'`` uses the S-box output as label, ``'HW'`` looks
            the S-box output up in ``HW_byte``.
        trace_num_max: number of traces used per attack.

    Returns:
        numpy.ndarray: length ``trace_num_max``; element ``i`` is the rank of
        the correct key after ``i + 1`` traces, averaged over the attacks.

    Raises:
        ValueError: if ``leakage_model`` is neither ``'ID'`` nor ``'HW'``.

    NOTE(review): ``Sbox`` and ``HW_byte`` are not defined anywhere in the
    visible file, so this function cannot run as-is - confirm where they are
    supposed to come from.
    """
    # `random` is not imported at module level in this file.
    import random

    if leakage_model not in ('ID', 'HW'):
        # Previously an unknown model left `label` unbound.
        raise ValueError('unsupported leakage_model: {!r}'.format(leakage_model))

    # The rank is averaged over this many attacks.
    num_averaged = 5
    guessing_entropy = np.zeros((num_averaged, trace_num_max))

    real_key = key[target_byte]
    plaintext = plaintext[:, target_byte]
    print('-------------')
    print(real_key)
    # `attack` rather than `time`, which would shadow the imported module.
    for attack in range(num_averaged):
        # Select the attack traces randomly.
        random_index = list(range(plaintext.shape[0]))
        random.shuffle(random_index)
        random_index = random_index[0:trace_num_max]

        # score_mat[i, k]: probability assigned to the label implied by key
        # guess k for the i-th selected trace. Indexing errors propagate to
        # the caller; the old handler itself failed on `e.message`.
        score_mat = np.zeros((trace_num_max, 256))
        for key_guess in range(256):
            for i in range(trace_num_max):
                trace_idx = random_index[i]
                sout = Sbox[int(plaintext[trace_idx]) ^ key_guess]
                label = sout if leakage_model == 'ID' else HW_byte[sout]
                score_mat[i, key_guess] = preds[trace_idx, label]
        # The small offset keeps log() finite for zero probabilities.
        score_mat = np.log(score_mat + 1e-40)

        for i in range(trace_num_max):
            log_likelihood = np.sum(score_mat[0:i + 1, :], axis=0)
            ranked = np.argsort(log_likelihood)[::-1]
            guessing_entropy[attack, i] = list(ranked).index(real_key)

    guessing_entropy = np.mean(guessing_entropy, axis=0)

    return guessing_entropy

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` so its values span [0, 1].

    Args:
        timeseries: NumPy array; the minimum and maximum are taken over all
            of its elements.

    Returns:
        numpy.ndarray: ``(timeseries - min) / (max - min)``. A constant
        input (``max == min``) yields an all-zero float array instead of the
        NaNs produced by dividing 0 by 0.
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    if span == 0:
        return np.zeros_like(timeseries, dtype=float)
    return (timeseries - lowest) / span

def z_norm(timeseries):
    """Return the z-score of ``timeseries`` as computed by ``scipy.stats.zscore``.

    The library defaults are spelled out: statistics are taken along axis 0
    with zero delta degrees of freedom.
    """
    standardized = stats.zscore(timeseries, axis=0, ddof=0)
    return standardized

def normalize_data_per_trace(data):
    """Apply ``normalize`` to every trace of every group, writing back in place.

    ``data`` is indexed ``[group][trace]``; the same (mutated) object is
    returned.
    """
    for group in data:
        for position, trace in enumerate(group):
            group[position] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Apply ``z_norm`` to every trace of every group, writing back in place.

    ``data`` is indexed ``[group][trace]``; the same (mutated) object is
    returned.
    """
    for group in data:
        for position, trace in enumerate(group):
            group[position] = z_norm(trace)

    return data

def dwt_reconstruct(data):
    """Rebuild ``data`` from its single-level 'db1' approximation coefficients only.

    The detail coefficients returned by the forward transform are discarded
    and replaced by zeros of the same shape as the approximation band before
    the inverse transform is applied.
    """
    approximation, _discarded_detail = pywt.dwt(data, 'db1')
    zero_detail = np.zeros(approximation.shape)
    return pywt.idwt(approximation, zero_detail, 'db1')

def data_dwt(data):
    """Apply ``dwt_reconstruct`` to every trace of every group, writing back in place.

    ``data`` is indexed ``[group][trace]``; the same (mutated) object is
    returned.
    """
    for group in data:
        for position, trace in enumerate(group):
            group[position] = dwt_reconstruct(trace)

    return data

def cummulative_transform(data):
    """Replace each ``data[i, j]`` by its running (cumulative) sum, in place.

    The shape of ``data`` is printed first. The sums are computed in float
    and written back into ``data``, which is also returned.
    """
    print(data.shape)
    for row_idx, group in enumerate(data):
        for col_idx, trace in enumerate(group):
            data[row_idx, col_idx] = np.cumsum(trace, dtype=float)
    return data

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the mean final rank from ``mk_rankmat`` and from ``mk_rankmat_GE``.

    Both helpers are run on the same batches (``mk_rankmat_GE`` first); only
    their first return value, the ``[run, trace]`` rank matrix, is used. Each
    matrix is averaged over the runs, both averaged curves are printed, and
    the value at the last trace position of each curve is returned as
    ``(mean_rank, mean_rank_GE)``.
    """
    shared_args = (model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)
    ge_rank_matrix = mk_rankmat_GE(*shared_args)[0]
    key_rank_matrix = mk_rankmat(*shared_args)[0]

    mr = np.mean(key_rank_matrix, 0)
    mr_GE = np.mean(ge_rank_matrix, 0)
    for line in ('------------', mr, mr_GE, '--------------'):
        print(line)
    return mr[-1], mr_GE[-1]

def ZMUV(power_trace, power_trace_target, n):
    """Rescale target traces so their per-sample mean/std match the reference set.

    Means and standard deviations are taken over axis 0 of each input. Every
    sample ``x`` at position ``j`` of a target trace becomes
    ``(x - testMean[j]) * trainStd[j] / testStd[j] + trainMean[j]``, where a
    small constant is added to both standard deviations to avoid division by
    zero.

    Args:
        power_trace: reference traces, one trace per row.
        power_trace_target: traces to transform, one trace per row. The
            statistics use all rows, not only the first ``n``.
        n: number of leading target traces to transform.

    Returns:
        list: ``n`` lists, each holding the transformed samples of one trace.

    Raises:
        IndexError: if ``n`` exceeds the number of target traces.
    """
    print('[LOG---- APPLYING ZMUVN.............]')
    trainMean = np.mean(power_trace, axis=0)
    testMean = np.mean(power_trace_target, axis=0)
    print(np.shape(trainMean))
    print(np.shape(testMean))
    trainStd = np.std(power_trace, axis=0)
    testStd = np.std(power_trace_target, axis=0)
    print(np.shape(trainStd))
    print(np.shape(testStd))
    ep = 1e-15
    trainStd += ep
    testStd += ep
    coeff = trainStd / testStd
    modified_trace = []
    for i in range(n):
        # One vectorised expression per trace replaces the per-sample Python
        # loop; list() keeps the original list-of-lists return structure.
        trace = np.asarray(power_trace_target[i])
        modified_trace.append(list((trace - testMean) * coeff + trainMean))
    print('[LOG----ZMUVN COMPLETE.............]')

    return modified_trace

def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a slice of traces, ``bp`` values and labels from an ``.npz`` archive.

    Args:
        data_path: path of an archive with the entries ``data``, ``bp`` and
            ``label``.
        start_trace: first trace index to keep (inclusive).
        end_trace: last trace index to keep (exclusive).

    Returns:
        tuple: ``(traces, bp, labels)`` where ``traces`` and ``labels`` are
        sliced along their first axis and ``bp`` along its second axis.
    """
    # The context manager closes the archive's file handle; the arrays are
    # fully read on access, so the slices stay valid after the block.
    with np.load(data_path) as archive:
        trace_profiling = archive['data']
        bp_profiling = archive['bp']
        skpv_profiling = archive['label']

    return (trace_profiling[start_trace:end_trace],
            bp_profiling[:, start_trace:end_trace],
            skpv_profiling[start_trace:end_trace])

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load profiling traces and split them into a training and a validation part.

    Args:
        args: parsed command-line arguments; only ``args.normalize`` is read.
        data_path: archive passed on to the loader.
        sKeyNo: forwarded to the test loader when ``is_test`` is true.
        trainPortion: unused; kept so existing callers keep working.
        xType: unused; kept so existing callers keep working.
        yType: unused; kept so existing callers keep working.
        is_test: selects the test loader and a contiguous validation slice;
            otherwise the validation rows are picked through the indices
            stored in ``val_ids.npz``.
        start_trace: first trace index requested from the loader.
        end_trace: end of the training range; ``val_num`` (300000) further
            traces are requested beyond it.

    Returns:
        tuple: ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``.
        Labels are returned as loaded, so ``yTrain``/``yTrain_value`` and
        ``yVal``/``yVal_value`` hold the same values.

    NOTE(review): ``load_meta_trace_file_from_test`` is not defined in the
    visible file, so the ``is_test`` path cannot run as-is - confirm where it
    should come from.
    """
    val_num = 300000
    # Close the index archive as soon as the ids have been read.
    with np.load('val_ids.npz') as val_archive:
        val_ids = val_archive['ids']
    print(len(val_ids))
    print(val_ids[:10])
    end_val_trace = end_trace + val_num
    if is_test:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)

    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    print(skpv_profiling[:10])

    # Labels are used as plain values (no one-hot encoding).
    y_train_skpv = skpv_profiling
    if args.normalize == 1:
        print(trace_profiling.shape)
        # NOTE(review): the normalised values are written back into the loaded
        # array, so they are cast to its dtype - confirm the traces are float.
        trace_profiling = normalize_data_per_trace([trace_profiling])[0]
    xTrain = trace_profiling[:dataSize]
    print(xTrain[0][:10])

    if is_test:
        xVal = trace_profiling[dataSize:end_val_trace]
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        # val_ids index into the loaded slice, i.e. relative to start_trace.
        xVal = trace_profiling[val_ids]
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

yType = 'skpv'
# Load the evaluation set once. The same archive used to be read a second
# time a few lines below although the first copy was neither used nor
# modified in between.
xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
# The training traces are loaded unconditionally; below they are only
# consumed by the ZMUV branch.
xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args, 'data.npz', sKeyNo, 1.0, xType, yType, False, 0, 200000)
all_ids = np.load(args.all_ids)
train_data = xTrain_original[all_ids]

print(xTest_multi[1].shape)
# Keep only the first num_key keys and the first num_traces traces per key.
if args.xType == 'wavebp01':
    xTest_multi, yTest_multi = [xTest_multi[0][:args.num_key, :args.num_traces, :], xTest_multi[1][:args.num_key, :2, :args.num_traces, ]], yTest_multi[:args.num_key]
else:
    xTest_multi, yTest_multi = xTest_multi[:args.num_key, :args.num_traces, :], yTest_multi[:args.num_key]

print(xTest_multi[1].shape)

if args.dwt == 1:
    xTest_multi = data_dwt(xTest_multi)

# --normalize picks the per-trace scaler (1: min-max, 2: z-score);
# --cummulative decides whether the running sum is taken before (1) or
# after (2) scaling. Without a valid --normalize nothing is applied.
if args.normalize in (1, 2):
    per_trace_scaler = normalize_data_per_trace if args.normalize == 1 else z_normalize_data_per_trace
    if args.cummulative == 1:
        xTest_multi = per_trace_scaler(cummulative_transform(xTest_multi))
    elif args.cummulative == 2:
        xTest_multi = cummulative_transform(per_trace_scaler(xTest_multi))
    else:
        xTest_multi = per_trace_scaler(xTest_multi)

if args.zmuv == 1:
    print(train_data.shape)
    print(xTest_multi.shape)
    # Rescale the traces of every key against the training statistics.
    xTest_res = [np.array(ZMUV(train_data, xTest_multi[i], args.num_traces))
                 for i in range(xTest_multi.shape[0])]
    xTest_multi = np.array(xTest_res)
    print(xTest_multi.shape)


if args.xType != 'wavebp01':
    # Add a trailing axis of length 1 to the trace array.
    xTest_multi = np.expand_dims(xTest_multi, axis=3)
print(yTest_multi)

baseline_mr_best = 76.37
baseline_240k_best = 23.1
baseline_160k_best = 36.78

test_key = 1733

folder_path = args.train_folder

def get_info(args, folder_path):
    """Evaluate the saved model(s) in ``folder_path`` and tabulate the ranks.

    Uses the module-level ``xTest_multi``, ``yTest_multi`` and ``maxtrc``.

    Args:
        args: parsed command-line arguments; only ``args.all_path`` is read.
        folder_path: directory containing the saved model files.

    Returns:
        pandas.DataFrame: one row per key. With ``args.all_path == 0`` only
        ``model_best_end.keras`` is evaluated and the columns are
        ``Mean_Rank`` and ``GE``. Otherwise every file whose name contains
        ``model_best`` is evaluated (in natural sort order) and contributes
        the columns ``<file>_Mean_Rank`` and ``<file>_GE``.
    """
    print(folder_path)
    all_path = natsorted(os.listdir(folder_path))
    all_path = [i for i in all_path if 'model_best' in i]
    print(all_path)
    if args.all_path == 0:
        end_model_path = os.path.join(folder_path, 'model_best_end.keras')
        end_model = load_model(end_model_path, custom_objects={"tf": tf})
        all_test, all_test_GE = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
        return pd.DataFrame({
            'Mean_Rank': all_test,
            'GE': all_test_GE
        })

    rank_dict = {}
    for fpath in all_path:
        end_model_path = os.path.join(folder_path, fpath)
        # Same custom_objects as the single-model branch, for consistency.
        end_model = load_model(end_model_path, custom_objects={"tf": tf})
        # get_meanrank returns two per-key arrays. Storing the tuple as one
        # column did not yield a per-key table, so each array gets its own
        # column.
        mean_rank, mean_rank_GE = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
        rank_dict['{}_Mean_Rank'.format(fpath)] = mean_rank
        rank_dict['{}_GE'.format(fpath)] = mean_rank_GE
    print(rank_dict)
    return pd.DataFrame(rank_dict)

df = get_info(args, folder_path)
print(df)
# Build the CSV name from the file stem of the evaluation path only. Slicing
# off the last four characters kept any directory components (producing a
# path into a non-existent sub-folder) and assumed a four-character extension.
eval_name = os.path.splitext(os.path.basename(args.eval_path))[0]
df.to_csv(os.path.join(folder_path, 'mean_rank_full_{}_{}_zmuv_{}.csv'.format(eval_name, args.num_traces, args.zmuv)))
#np.savez('res_random.npz', rd_multi)
#np.savez('res_uncertain.npz', all_multi)
#np.savez('res_bal.npz', bal_multi)
#np.savez('res_bal_update.npz', bal_multi_notrain)
#np.savez('res_bal_best.npz', bal_multi)
#np.savez('res_bal_update_best.npz', bal_multi_notrain)
