import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda

import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid


# Base path for trained models (not referenced in the visible code).
basepath = 'trained_models/'

# Label array of the sample pool; get_info() indexes it with the sample ids
# stored next to each trained model.
data = np.load('data.npz')
labels = data['label']

# Model input layout; mk_rankmat() dispatches on this value.
xType = 'wave'
# Number of traces whose log-probabilities are accumulated per evaluation.
maxtrc = 100
tracelen = 600

# Inclusive [low, high] value ranges.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]

# Sizes of the inclusive ranges above.
NumFQMULclasses = 1 + fqmul_range[1] - fqmul_range[0]  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = 1 + skpv_range[1] - skpv_range[0]  # number of classes for skpv
NumBPinput = 1 + bp_range[1] - bp_range[0]  # number of input for bp (ciphertext)
noClasses = noHypoKeys = NumSKPVclasses

def get_meanrank(xTest_multi, yTest_vals, model, maxtrc):
    """Return the mean over all test keys of the rank reported by eval_model.

    For every entry of ``yTest_vals`` the matching entry of ``xTest_multi``
    is evaluated in a single run that uses the trace indices
    ``0 .. maxtrc - 1`` in order. The elapsed wall-clock time is printed.

    Args:
        xTest_multi: Indexable collection with one trace set per key.
        yTest_vals: Sequence with the true key value for each trace set.
        model: Model object forwarded to ``eval_model``.
        maxtrc: Number of traces used per evaluation.

    Returns:
        ``np.mean`` of the per-key values returned by ``eval_model``.
    """
    t_begin = time.time()
    nruns = 1  # data for multi-label is limited, so only one run is done
    per_key_rank = []
    for key_idx, key_val in enumerate(yTest_vals):
        # One row per run; the single run takes the first maxtrc traces.
        batches = np.zeros((nruns, maxtrc), 'int')
        batches[0] = np.arange(maxtrc)
        per_key_rank.append(
            eval_model(model, nruns, maxtrc, batches, [[xTest_multi[key_idx]]],
                       [key_val], noHypoKeys, noClasses)
        )
    print('Eval took {}'.format(time.time() - t_begin))
    return np.mean(np.array(per_key_rank))


def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from an ``.npz`` archive.

    Args:
        data_path: Path of the archive, passed to ``np.load``.

    Returns:
        Tuple ``(data, labels)`` holding the arrays stored under the keys
        ``'data'`` and ``'label'``.

    Raises:
        KeyError: If one of the two keys is missing from the archive.
    """
    # The archive object keeps the underlying file open until it is closed.
    # Each array is read completely on item access, so both remain valid
    # after the ``with`` block has closed the file.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

# For each supported ``xType``: indices into ``xTest[1]`` of the auxiliary
# inputs that are passed to the model after the waveform ``xTest[0][0]``.
_XTYPE_AUX_INPUTS = {
    'wave': (),
    'wavebp0': (0,),
    'wavebp1': (1,),
    'wavebp01': (0, 1),
    'wavebp01next0': (0, 1, 2),
    'wavebp01next01': (0, 1, 2, 3),
}


def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Compute rank matrices and log-probability tensors for ``nruns`` runs.

    The model inputs are selected by the module-level ``xType``. In every
    run the log of the model output is accumulated trace by trace, and the
    rank of the real key is the number of hypotheses whose accumulated
    log-probability is strictly larger than that of the real key.

    Args:
        model: Object with a ``predict(list_of_arrays)`` method.
        nruns: Number of runs.
        maxtrc: Number of traces per run.
        batches: Array indexed as ``batches[krun, :]`` giving the trace
            indices used in run ``krun``.
        xTest: Nested inputs; ``xTest[0][0]`` is indexed as ``[samp, :, :]``
            and ``xTest[1][k]`` (only for xTypes with auxiliary inputs) as
            ``[samp, :]``.
        yTest_value: Sequence whose first element is the real key.
        noHypoKeys: Number of key hypotheses.
        noClasses: Number of model output classes. The accumulation adds a
            row of ``noClasses`` values onto ``noHypoKeys`` sums, so the two
            must be broadcast-compatible.

    Returns:
        Tuple of six arrays:
        ``rankmat_byKey`` (nruns, maxtrc): rank of the real key from the
        accumulated log-probabilities after each trace;
        ``rankmat_byClass`` (nruns, maxtrc): rank of the real class from the
        single trace alone;
        ``ps_AllClasses_Nruns`` (maxtrc, noClasses, nruns): model outputs;
        ``lps_AllClasses_Nruns`` (maxtrc, noClasses, nruns): their logs;
        ``lps_AllHypoKeys_Nruns`` (maxtrc, noHypoKeys, nruns): the same logs
        (class index and key hypothesis coincide here);
        ``lpsums_AllHypoKeys_Nruns`` (maxtrc, noHypoKeys, nruns): running
        sums of the logs.

    Raises:
        ValueError: If ``xType`` is not one of the supported input layouts.
    """
    # Fail with a clear message instead of hitting an unbound ``ps`` later.
    if xType not in _XTYPE_AUX_INPUTS:
        raise ValueError('Unsupported xType: {!r}'.format(xType))
    aux_indices = _XTYPE_AUX_INPUTS[xType]

    realkey = int(yTest_value[0])
    # The class of the real key is the key itself (no intermediate mapping).
    realClass = realkey
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        samp = batches[krun, :]

        # Waveform first, then the auxiliary inputs required by xType.
        model_inputs = [xTest[0][0][samp, :, :]]
        model_inputs.extend(xTest[1][k][samp, :] for k in aux_indices)
        ps = model.predict(model_inputs)

        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            # count_nonzero runs in native code; the builtin sum() would
            # iterate the boolean array element by element.
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realClass])
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the rank of the real key after the last trace, averaged over runs.

    All arguments are forwarded unchanged to ``mk_rankmat``.

    Returns:
        Mean over the run axis of the per-key rank matrix, taken at the last
        trace index.
    """
    # Only the per-key rank matrix (first element of the result) is needed.
    rankmat_byKey = mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)[0]
    # Average over runs (axis 0), then keep the value for the final trace.
    return np.mean(rankmat_byKey, 0)[-1]

# Load the multi-key attack set and insert a new axis of length 1 at
# position 3 of the trace array.
xTest_multi, yTest_multi = load_multi_attack('attack_multi_data_300key.npz')
xTest_multi = np.expand_dims(xTest_multi, axis = 3)
#test_best = get_meanrank(xTest_multi, yTest_multi, best_model, maxtrc)
print(yTest_multi)
# Disabled one-off evaluation of a single saved model; kept inside a string
# literal so that it does not run.
'''
end_model_path = os.path.join('multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy/model_best.keras')
end_model = load_model(end_model_path)
test_end = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
print(test_end)
'''
# Reference values for the "best" plot further down. baseline_mr_best is not
# referenced in the visible code.
baseline_mr_best = 76.37
baseline_240k_best = 23.1
baseline_160k_best = 36.78

# NOTE(review): not referenced in the visible code; get_info() uses the
# literal 1733 rather than this name.
test_key = 1733

# Result folders that are passed to get_info() further down.
folder_path = 'multi_attack_trained_models/test_baseline_active_uncertainty_wave_10_0_200000_2000_ndex.npy_None_16'
random_path = 'multi_attack_trained_models/test_baseline_active_random_wave_10_0_200000_2000_ndex.npy_None_16'
#random_path = 'multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy'
#random_path = 'multi_attack_trained_models/test_retrain_active_uncertainty_wave_3_0_200000_2000_ndex.npy_None_768'
un_label_path = 'multi_attack_trained_models/test_update_active_uncertainty_balance_label_wave_1_0_200000_2000_ndex.npy_None_128'
un_balance_path_notrain = 'multi_attack_trained_models/test_baseline_fix3_active_uncertainty_balance_wave_10_0_200000_2000_ndex.npy_None_16'
un_balance_path = 'multi_attack_trained_models/test_update_active_uncertainty_balance_wave_10_0_200000_2000'
# Not referenced in the visible code.
un_balance_update = 'multi_attack_trained_models/'
def get_info(folder_path, infer = False, num_labels = 3329, tracked_label = 1733):
    """Collect per-run statistics from every sub-folder of ``folder_path``.

    The immediate sub-folders are visited in sorted order. For each one the
    function reads ``attack_rank.csv`` and ``all_ids.npy``, counts how often
    each label occurs among the module-level ``labels`` selected by the
    stored ids, and obtains a multi-attack rank either from
    ``attack_rank_multi.csv`` or by evaluating ``model_best_end.keras``.

    Args:
        folder_path: Directory whose immediate sub-folders are the runs.
        infer: If True, load ``model_best_end.keras`` from each run and
            evaluate it with ``get_meanrank`` on the module-level
            ``xTest_multi`` / ``yTest_multi`` using ``maxtrc`` traces.
            If False, take the minimum of the column means of
            ``attack_rank_multi.csv``.
        num_labels: Length of the per-label count vector.
        tracked_label: Label whose count is reported in ``all_occ``.

    Returns:
        Tuple ``(all_labels, all_multi, all_rank, all_std, all_occ)``:
        ``all_labels`` is an array with one count vector per run,
        ``all_multi`` an array with one multi-attack rank per run,
        ``all_rank`` an array with the column-wise minimum of
        ``attack_rank.csv`` per run, ``all_std`` a list with the standard
        deviation of each count vector, and ``all_occ`` a list with the
        count of ``tracked_label`` per run.

    Raises:
        FileNotFoundError: If ``folder_path`` is not an existing directory.
    """
    # os.walk() yields nothing for a missing directory, which would surface
    # as a bare StopIteration from next(); report the real problem instead.
    if not os.path.isdir(folder_path):
        raise FileNotFoundError('Result folder not found: {}'.format(folder_path))
    all_path = sorted(next(os.walk(folder_path))[1])

    all_labels = []
    all_multi = []
    all_rank = []
    all_std = []
    all_occ = []
    for fpath in all_path:
        fname = os.path.join(folder_path, fpath)
        result_rank = np.min(pd.read_csv(os.path.join(fname, 'attack_rank.csv'), index_col=0).to_numpy(), axis = 0)
        all_rank.append(result_rank)

        # Histogram of the labels of the samples this run was trained on.
        label_ids = np.load(os.path.join(fname, 'all_ids.npy'))
        all_counts = np.zeros(num_labels)
        unique, counts = np.unique(labels[label_ids], return_counts=True)
        all_counts[unique] = counts
        all_occ.append(all_counts[tracked_label])
        all_std.append(np.std(all_counts))
        all_labels.append(all_counts)

        if infer:
            end_model = load_model(os.path.join(fname, 'model_best_end.keras'))
            all_multi.append(get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc))
        else:
            # The CSV is only needed when the model is not re-evaluated.
            result_multi = np.mean(pd.read_csv(os.path.join(fname, 'attack_rank_multi.csv'), index_col=0).to_numpy(), axis = 0)
            all_multi.append(np.min(result_multi))

    print(len(all_labels))

    all_labels = np.array(all_labels)
    all_rank = np.array(all_rank)
    all_multi = np.array(all_multi)
    print(all_labels.shape)

    return all_labels, all_multi, all_rank, all_std, all_occ


# Disabled per-key analysis plots; kept inside a string literal so that they
# do not run.
'''
for i in range(len(yTest_multi)):
#for i in range(2):
    plt.figure(figsize=(10,7))
    plt.grid()
    print('-------------')
    ##print(yTest_multi[i])
    #print(all_num_samples)
    plot_vals = all_labels[:,yTest_multi[i]] / np.mean(all_labels, axis = 1) #/ all_num_samples
    print(plot_vals)
    plot_multi= 1 - all_multi[:,i] / 3329 #/ np.mean(all_multi, axis = 1) #/ all_num_samples
    print(plot_multi)
    plt.plot(plot_vals, label='ratio of label'+str(i)+' count / mean counts')
    plt.plot(plot_multi, label = 'ranking score')
    plt.xlabel('Num. Iteration')
    plt.ylabel('Label counts')
    plt.legend()
    plt.savefig('Label_Analysis_{}.png'.format(i))
    plt.clf()

    #plt.plot(np.mean(all_labels, axis = 1), label='mean' + str(i))

#plt.plot(np.min(all_labels, axis = 1), label ='Min count')
#plt.plot(np.max(all_labels, axis = 1), label ='Max count')

for i in range(len(yTest_multi)):
    print('-------------')
    ##print(yTest_multi[i])
    #print(all_num_samples)
    plot_vals = all_multi[:,i] #/ all_num_samples
    print(plot_vals)
    plt.plot(plot_vals, label=str(i))
'''
# Gather the per-run statistics of every strategy folder.
all_labels, all_multi, all_rank, all_std, all_occ = get_info(folder_path)
rd_labels, rd_multi, rd_rank, rd_std, rd_occ = get_info(random_path)
un_labels, un_multi, un_rank, un_std, un_occ = get_info(un_label_path)
bal_labels, bal_multi, bal_rank, bal_std, bal_occ = get_info(un_balance_path, infer=False)
bal_labels_notrain, bal_multi_notrain, bal_rank_notrain, bal_std_notrain, bal_occ_notrain = get_info(un_balance_path_notrain, infer=False)
# Disabled np.savez calls for the res_*.npz result files that are loaded below.
#np.savez('res_random.npz', rd_multi)
#np.savez('res_uncertain.npz', all_multi)
#np.savez('res_bal.npz', bal_multi)
#np.savez('res_bal_update.npz', bal_multi_notrain)
#np.savez('res_bal_best.npz', bal_multi)
#np.savez('res_bal_update_best.npz', bal_multi_notrain)
#exit()
# The multi-attack ranks computed by get_info() above are replaced here by
# the cached arrays; un_multi is the only one that keeps its computed value.
all_multi = np.load('res_uncertain.npz')['arr_0']
rd_multi = np.load('res_random.npz')['arr_0']
# Only the first 10 entries of this cached array are kept.
bal_multi = np.load('res_bal.npz')['arr_0'][:10]
bal_multi_notrain = np.load('res_bal_update.npz')['arr_0']

bal_multi_best = np.load('res_bal_best.npz')['arr_0']
bal_multi_notrain_best = np.load('res_bal_update_best.npz')['arr_0']

# Reference values; baseline_mr is prepended as element 0 of each curve
# below (the *_best curves are left without it).
baseline_mr = 65.15
baseline_240k = 6.8
baseline_160k = 13.6
all_multi = np.hstack((baseline_mr,all_multi))
rd_multi = np.hstack((baseline_mr,rd_multi))
bal_multi = np.hstack((baseline_mr,bal_multi))
un_multi = np.hstack((baseline_mr,un_multi))
bal_multi_notrain = np.hstack((baseline_mr,bal_multi_notrain))
print(bal_multi_notrain)
print('----------------')
print(bal_multi)
print(bal_multi.shape)
print(all_rank.shape)
# Disabled label-count comparison plot; kept inside a string literal so that
# it does not run.
'''
label_ids = np.load(os.path.join(random_path, 'all_ids.npy'))
all_counts = np.zeros(3329)
trained_labels = labels[label_ids]
unique, counts = np.unique(trained_labels, return_counts=True)
for value, number in zip(unique,counts):
    all_counts[value] = number

plt.figure(figsize=(10,7))
plt.grid()
plt.plot(all_labels[0], label = 'uncertainty')
plt.plot(all_counts, label = 'pre-train baseline')
plt.ylabel('Standard Deviation of Label Counts')
plt.xlabel('Num. Iteration')
plt.legend()
plt.savefig('Label_Plot.png')
plt.clf()
exit()
'''
# Standard deviation of the per-label sample counts for each run.
plt.figure(figsize=(10,7))
plt.grid()
plt.plot(all_std, label = 'uncertainty')
plt.plot(rd_std, label = 'random')
plt.plot(bal_std, label = 'uncertainty balance', c='r')
plt.ylabel('Standard Deviation of Label Counts')
plt.xlabel('Num. Iteration')
plt.legend()
plt.savefig('Label_STD.png')
plt.clf()

# Attack rank (column-wise minimum of attack_rank.csv) for each run.
plt.figure(figsize=(10,7))
plt.grid()
plt.ylabel('Attack Mean Rank')
plt.plot(all_rank, label = 'uncertainty')
plt.plot(rd_rank, label = 'random')
#plt.plot(un_rank, label = 'uncertainty label')
plt.plot(bal_rank, label = 'uncertainty balance', c='r')
#plt.axhline(y=17.4, color='r', linestyle='--', label = '6k samples, minmax')
#plt.axhline(y=155.3, linestyle='--', label = '8k samples, minmax')
plt.axhline(y=20.1, linestyle='--', label = '50k samples, random')
#plt.axhline(y=15.7, linestyle='--', label = '30k samples, initial model')
# The line is drawn at baseline_240k, so the legend entry names the 240k
# baseline, as in the other figures of this script. NOTE(review): if the
# 160k baseline was intended here, change y to baseline_160k instead.
plt.axhline(y=baseline_240k, linestyle='--', label = '240k samples, baseline')
plt.xlabel('Num. Iteration x 250 samples')
plt.legend()
plt.savefig('Label_rank_4_strat.png')
plt.clf()


#'baseline_minmax_wave_0_200000_cA2__16000_mm'
#for num_sample in num_iteration:

def _plot_retrain_vs_update(retrain_curve, update_curve, ref_160k, ref_240k, out_png):
    """Plot two multi-attack rank curves with two baseline lines and save.

    Args:
        retrain_curve: Values drawn in red as 'uncertainty balance retrain'.
        update_curve: Values drawn as 'uncertainty balance update schedule'.
        ref_160k: Height of the dashed red '160k samples, baseline' line.
        ref_240k: Height of the dashed green '240k samples, baseline' line.
        out_png: File name passed to ``plt.savefig``.
    """
    plt.figure(figsize=(10, 7))
    plt.grid()
    plt.plot(retrain_curve, label='uncertainty balance retrain', c='r')
    plt.plot(update_curve, label='uncertainty balance update schedule')
    plt.axhline(y=ref_160k, linestyle='--', label='160k samples, baseline', c='r')
    plt.axhline(y=ref_240k, linestyle='--', label='240k samples, baseline', c='g')
    plt.xlabel('Num. Iteration')
    plt.legend()
    plt.savefig(out_png)
    plt.clf()


# Plot min rank: curves with baseline_mr prepended, then the "best" variants.
_plot_retrain_vs_update(
    bal_multi, bal_multi_notrain,
    baseline_160k, baseline_240k,
    'Label_rank_uncertain_multi_update.png',
)
_plot_retrain_vs_update(
    bal_multi_best, bal_multi_notrain_best,
    baseline_160k_best, baseline_240k_best,
    'Label_rank_uncertain_multi_update_best.png',
)

# Positions in bal_multi that are highlighted in the figure below
# (presumably the iterations at which the model was retrained; confirm).
retrain_index = [2,4,6,8]
# Plot min rank, with the y axis limited to 0..200.
plt.figure(figsize=(10,7))
plt.grid()
plt.ylim(0,200)
plt.plot(all_multi, label = 'uncertainty')
plt.plot(rd_multi, label = 'random')
#plt.plot(un_multi, label = 'uncertainty label w/ balance', c='b')
plt.plot(bal_multi, label = 'uncertainty balance', c='r')
# Pass the indices as x coordinates so that every highlighted value sits on
# the iteration it belongs to; without them the four values would be drawn
# at x = 0..3. Markers keep the points distinguishable from the red curve.
plt.plot(retrain_index, bal_multi[retrain_index], 'o', label = 'uncertainty balance, retrain epoch', c='r')
#plt.plot(bal_multi_notrain, label = 'uncertainty balance update schedule')
#plt.axhline(y=17.4, color='r', linestyle='--', label = '6k samples, minmax')
#plt.axhline(y=155.3, linestyle='--', label = '8k samples, minmax')
#plt.axhline(y=0.3, linestyle='--', label = '8k samples, random')
#plt.axhline(y=7, linestyle='--', label = '200k samples')
#plt.axhline(y=319.6, linestyle='--', label = '15k samples, initial model')
#plt.axhline(y=20.1, linestyle='--', label = '50k samples, baseline', c ='r')
#plt.axhline(y=15.7, linestyle='--', label = '30k samples, initial model')
plt.axhline(y=baseline_160k, linestyle='--', label = '160k samples, baseline', c='r')
plt.axhline(y=baseline_240k, linestyle='--', label = '240k samples, baseline', c='g')
plt.xlabel('Num. Iteration')
plt.legend()
plt.savefig('Label_rank_uncertain_multi_zoom.png')
plt.clf()


# Per-run count of the label tracked by get_info(), for the random and the
# uncertainty strategy, drawn through the Axes interface.
_occ_fig, _occ_ax = plt.subplots(figsize=(10, 7))
_occ_ax.grid()
_occ_ax.plot(rd_occ, label='Random')
_occ_ax.plot(all_occ, label='Uncertainty')
_occ_ax.set_xlabel('Num. Samples')
_occ_ax.set_ylabel('Number of sample within label')
_occ_ax.legend()
_occ_fig.savefig('Label_Distribution_uncertainty_minmax.png')
