import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda

import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Activation, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model

import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
from natsort import natsorted

def parse_arguments():
    """Build the command-line parser for this analysis script.

    The parser is returned un-parsed; the caller is expected to invoke
    ``parse_args()`` on it.

    Returns:
        argparse.ArgumentParser: parser exposing ``--train_folder``,
        ``--eval_path``, ``--num_traces``, ``--num_key``, ``--if_infer``,
        ``--normalize``, ``--dwt`` and ``--cummulative``.
    """
    parser = argparse.ArgumentParser(
        description='Analyse the checkpoints and selected sample ids stored '
                    'in the sub-folders of a training run.')
    parser.add_argument('--train_folder', type=str,
                        help='training run folder whose sub-folders are analysed')
    parser.add_argument('--eval_path', type=str,
                        help='.npz file holding the evaluation "data" and "label" arrays')
    parser.add_argument('--num_traces', type=int, default=100,
                        help='number of traces kept per key (default: 100)')
    parser.add_argument('--num_key', type=int, default=300,
                        help='number of keys kept from the evaluation set (default: 300)')
    parser.add_argument('--if_infer', type=int, default=0,
                        help='inference flag (default: 0)')
    parser.add_argument('--normalize', type=int,
                        help='set to 1 to min-max normalise every evaluation trace')
    parser.add_argument('--dwt', type=int,
                        help='DWT preprocessing flag')
    parser.add_argument('--cummulative', type=int,
                        help='cumulative-sum preprocessing flag')
    return parser

# Parse the command line once, at module level; `args` is read by the
# top-level code further down and by get_info_cummulative().
parser = parse_arguments()
args = parser.parse_args()
basepath = 'trained_models/'  # not referenced elsewhere in the visible code

# 'data.npz' is opened from the current working directory. `labels` keeps its
# 'label' array; `data` is rebound later to a slice of its 'data' array.
data = np.load('data.npz')
labels = data['label']

xType = 'wave'                 # selects which inputs mk_rankmat() feeds to model.predict
maxtrc = args.num_traces       # number of traces accumulated per key
# Inclusive [low, high] value ranges; the class counts below are high - low + 1.
# NOTE(review): presumably bp = ciphertext coefficient, skpv = secret-key value,
# fqmul = their product — confirm against the data-generation code.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600                 # not referenced elsewhere in the visible code
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
# Both the model's output classes and the key hypotheses are the skpv values,
# so the two counts are identical (3329).
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses

def get_meanrank(xTest_multi, yTest_vals, model, maxtrc):
    """Evaluate ``model`` on every key and collect the final mean key rank.

    For each index ``k`` of ``yTest_vals`` the traces ``xTest_multi[k]`` are
    scored by ``eval_model`` using a single run over trace indices
    ``0 .. maxtrc-1`` in order.

    Args:
        xTest_multi: per-key trace sets; ``xTest_multi[k]`` is handed to the
            model for key ``k``.
        yTest_vals: per-key ground-truth values.
        model: object exposing ``predict``.
        maxtrc: number of traces accumulated per key.

    Returns:
        numpy.ndarray: one value per key, as returned by ``eval_model``.
    """
    t_begin = time.time()
    # data for multi-label is limited, so only one run is performed
    n_runs = 1
    per_key_rank = []
    for key_idx in range(len(yTest_vals)):
        trace_order = np.zeros((n_runs, maxtrc), 'int')
        trace_order[0] = np.arange(maxtrc)
        rank_for_key = eval_model(
            model,
            n_runs,
            maxtrc,
            trace_order,
            [[xTest_multi[key_idx]]],
            [yTest_vals[key_idx]],
            noHypoKeys,
            noClasses,
        )
        per_key_rank.append(rank_for_key)
    elapsed = time.time() - t_begin
    print('Eval took {}'.format(elapsed))
    return np.array(per_key_rank)


def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from an ``.npz`` archive.

    The archive is opened as a context manager so its file handle is closed
    as soon as both arrays have been read into memory.

    Args:
        data_path: path (or file-like object) of an ``.npz`` archive that
            contains the entries ``'data'`` and ``'label'``.

    Returns:
        tuple: ``(data, labels)`` as NumPy arrays.

    Raises:
        KeyError: if either entry is missing from the archive.
    """
    with np.load(data_path) as infile:
        # Indexing an NpzFile reads the member eagerly, so both arrays remain
        # valid after the archive is closed.
        data = infile['data']
        labels = infile['label']

    return data, labels

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses, x_type=None):
    """Score traces with ``model`` and rank the true key as traces accumulate.

    For every run the traces selected by ``batches[krun]`` are passed to
    ``model.predict``; the log-probabilities are summed trace after trace and,
    after each trace, the number of hypotheses scoring strictly higher than
    the true key is recorded (0 means the true key is on top).

    Key hypotheses map one-to-one onto the model's classes here, so
    ``noHypoKeys`` must equal ``noClasses``.

    Args:
        model: object exposing ``predict(list_of_arrays)`` returning an array
            of shape ``(maxtrc, noClasses)``.
        nruns: number of runs (rows of ``batches``) to evaluate.
        maxtrc: number of traces accumulated per run.
        batches: integer array; ``batches[krun, :]`` holds the trace indices
            used for run ``krun``.
        xTest: nested inputs; ``xTest[0][0]`` is the 3-D waveform array and
            ``xTest[1][k]`` are the 2-D auxiliary inputs used by the
            ``'wavebp*'`` layouts.
        yTest_value: sequence whose first element is the true key.
        noHypoKeys: number of key hypotheses.
        noClasses: number of model output classes.
        x_type: input layout name; ``None`` (default) falls back to the
            module-level ``xType``.

    Returns:
        tuple: ``(rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns,
        lpsums_AllHypoKeys_Nruns)``. The two rank matrices have shape
        ``(nruns, maxtrc)``; the remaining arrays have shape
        ``(maxtrc, n, nruns)`` with ``n`` being ``noClasses`` or
        ``noHypoKeys``.

    Raises:
        ValueError: if the input layout name is not recognised.
    """
    if x_type is None:
        x_type = xType

    # Entries of xTest[1] that accompany the waveform for each input layout.
    aux_inputs_by_type = {
        'wave': (),
        'wavebp0': (0,),
        'wavebp1': (1,),
        'wavebp01': (0, 1),
        'wavebp01next0': (0, 1, 2),
        'wavebp01next01': (0, 1, 2, 3),
    }
    if x_type not in aux_inputs_by_type:
        # Previously an unknown layout surfaced as an UnboundLocalError on `ps`.
        raise ValueError('unknown xType {!r}; expected one of {}'.format(
            x_type, sorted(aux_inputs_by_type)))
    aux_inputs = aux_inputs_by_type[x_type]

    realkey = int(yTest_value[0])
    realClass = realkey  # identity mapping between key hypotheses and classes
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        samp = batches[krun, :]
        model_inputs = [xTest[0][0][samp, :, :]]
        model_inputs.extend(xTest[1][k][samp, :] for k in aux_inputs)
        ps = model.predict(model_inputs)

        lps = np.log(ps)
        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            # Running sum of log-probabilities over the traces seen so far.
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            # Rank = number of hypotheses strictly better than the true one.
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realClass])
        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        # With the identity mapping the per-hypothesis scores are the
        # per-class scores.
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the key rank after the last trace, averaged over all runs.

    All arguments are forwarded unchanged to ``mk_rankmat``; only its first
    result (the per-run, per-trace key-rank matrix) is used.
    """
    rank_outputs = mk_rankmat(
        model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses
    )
    key_ranks = rank_outputs[0]
    # Average across runs (axis 0), then keep the entry for the final trace.
    mean_rank_per_trace = np.mean(key_ranks, 0)
    return mean_rank_per_trace[-1]

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    A constant series has zero span; it is mapped to all zeros instead of
    dividing by zero (which previously produced NaNs and a runtime warning).

    Args:
        timeseries: array-like exposing ``min()`` and ``max()`` (e.g. a NumPy
            array).

    Returns:
        The scaled series, with the same shape as the input.
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    if span == 0:
        # Every sample equals `lowest`, so the numerator is already all zeros;
        # dividing by 1 keeps the result dtype of the regular path.
        span = 1
    return (timeseries - lowest) / span

def z_norm(timeseries):
    """Return the z-scores of ``timeseries`` computed by ``scipy.stats.zscore``."""
    standardized = stats.zscore(timeseries)
    return standardized

def normalize_data_per_trace(data):
    """Apply ``normalize`` to every trace ``data[j][i]``, in place.

    Prints ``data.shape`` first. The input object is modified and also
    returned.
    """
    print(data.shape)
    for trace_group in data:
        n_traces = len(trace_group)
        for trace_idx in range(n_traces):
            scaled = normalize(trace_group[trace_idx])
            trace_group[trace_idx] = scaled
    return data

def z_normalize_data_per_trace(data):
    """Apply ``z_norm`` to every trace ``data[j][i]``, in place.

    Prints ``data.shape`` first. The input object is modified and also
    returned.
    """
    print(data.shape)
    for trace_group in data:
        n_traces = len(trace_group)
        for trace_idx in range(n_traces):
            standardized = z_norm(trace_group[trace_idx])
            trace_group[trace_idx] = standardized
    return data

def dwt_reconstruct(data):
    """Rebuild a signal from its level-1 'db1' approximation coefficients only.

    The signal is decomposed with a single-level discrete wavelet transform,
    the detail coefficients are replaced by zeros, and the inverse transform
    is applied.

    Args:
        data: input signal accepted by ``pywt.dwt``.

    Returns:
        The reconstructed signal returned by ``pywt.idwt``.
    """
    # Imported here because the module never imports PyWavelets at file level;
    # without this the function raised NameError on its first call.
    import pywt

    (cA1, cD1) = pywt.dwt(data, 'db1')
    # Discard the detail band by substituting zeros of matching shape.
    cD1 = np.zeros(cA1.shape)
    trace_profiling = pywt.idwt(cA1, cD1, 'db1')
    # NOTE(review): for odd-length input the reconstruction may be one sample
    # longer than `data` — confirm before assigning the result back in place.
    return trace_profiling

def data_dwt(data):
    """Replace every trace ``data[j][i]`` by ``dwt_reconstruct`` of itself.

    The input object is modified in place and also returned.
    """
    for trace_group in data:
        n_traces = len(trace_group)
        for trace_idx in range(n_traces):
            rebuilt = dwt_reconstruct(trace_group[trace_idx])
            trace_group[trace_idx] = rebuilt
    return data

def cummulative_transform(data):
    """Replace every trace ``data[i, j]`` by its running sum, in place.

    Prints ``data.shape`` first. The sums are accumulated as ``float`` and
    written back into ``data``, which is also returned.
    """
    print(data.shape)
    for outer, trace_group in enumerate(data):
        for inner in range(len(trace_group)):
            running_total = np.cumsum(data[outer, inner], dtype=float)
            data[outer, inner] = running_total
    return data

# Load the evaluation set and keep only the first --num_key keys and the first
# --num_traces traces of each key.
xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
xTest_multi, yTest_multi = xTest_multi[:args.num_key,:args.num_traces,:], yTest_multi[:args.num_key]

# Optional per-trace min-max scaling (--normalize 1).
if args.normalize == 1:
    xTest_multi = normalize_data_per_trace(xTest_multi)

# Append a trailing axis of length 1 (the array becomes 4-D).
xTest_multi = np.expand_dims(xTest_multi, axis = 3)
#test_best = get_meanrank(xTest_multi, yTest_multi, best_model, maxtrc)
print(xTest_multi.shape)
# The triple-quoted string below is disabled code, kept as a bare string
# expression (it has no effect at runtime).
'''
end_model_path = os.path.join('multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy/model_best.keras')
end_model = load_model(end_model_path)
test_end = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
print(test_end)
'''
# From here on `data` is the first 200000 rows of the archive's 'data' array
# rather than the opened archive; get_info() indexes it by sample id.
data = data['data'][:200000]
# Reference values; none of them is read elsewhere in the visible code.
baseline_mr_best = 76.37
baseline_240k_best = 23.1
baseline_160k_best = 36.78

test_key = 1733  # not referenced elsewhere in the visible code

# Folder analysed by get_info() / get_info_cummulative() below.
folder_path = args.train_folder

# Hard-coded experiment folders; none of the names below is read elsewhere in
# the visible code.
#folder_path = 'multi_attack_trained_models/test_baseline_active_uncertainty_wave_10_0_200000_2000_ndex.npy_None_16'
random_path = 'multi_attack_trained_models/test_baseline_active_random_wave_10_0_200000_2000_ndex.npy_None_16'
#random_path = 'multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy'
#random_path = 'multi_attack_trained_models/test_retrain_active_uncertainty_wave_3_0_200000_2000_ndex.npy_None_768'
un_label_path = 'multi_attack_trained_models/test_update_active_uncertainty_balance_label_wave_1_0_200000_2000_ndex.npy_None_128'
un_balance_path_notrain = 'multi_attack_trained_models/test_baseline_fix3_active_uncertainty_balance_wave_10_0_200000_2000_ndex.npy_None_16'
#folder_path = 'multi_attack_trained_models/test_update_active_uncertainty_balance_wave_10_0_200000_2000'
un_balance_update = 'multi_attack_trained_models/'
def get_info(folder_path, infer = False):
    """Compare consecutive checkpoints on a fixed core subset of samples.

    The sub-folders of ``folder_path`` are visited in natural sort order. The
    core subset is the first 50000 ids stored in the first sub-folder's
    ``all_ids.npy``. For each sub-folder, ``model_best_end.keras`` is loaded
    and run on the core samples taken from the module-level ``data``; for
    every sample the position of its ground-truth class (from the module-level
    ``labels``) in the ascending ordering of the predicted scores is recorded,
    so 0 means the lowest score and a larger value means a higher score.

    Each checkpoint is loaded and evaluated once and its result is reused as
    the "previous" entry of the following round. This assumes
    ``model.predict`` is deterministic for a given checkpoint and input.

    Args:
        folder_path: directory containing one sub-folder per training round.
        infer: accepted for interface compatibility; not used.

    Returns:
        tuple: ``(all_path, all_diff, all_curr, all_prev)`` where ``all_path``
        is the sorted list of sub-folder names, ``all_diff`` is always an
        empty list, and ``all_curr[k]`` / ``all_prev[k]`` hold the per-sample
        positions for sub-folder ``k + 1`` and sub-folder ``k`` respectively.

    Raises:
        IndexError: if ``folder_path`` has no sub-folders.
    """
    all_path = natsorted(next(os.walk(folder_path))[1])
    print(all_path)
    all_diff = []
    all_prev = []
    all_curr = []

    first_ids = np.load(os.path.join(folder_path, all_path[0], 'all_ids.npy'))
    core_ids = first_ids[:50000]
    if len(all_path) < 2:
        # Nothing to compare against; avoid gathering the core subset at all.
        return all_path, all_diff, all_curr, all_prev

    # The core subset never changes, so gather it once for all checkpoints.
    subset_data = [data[core_ids]]
    subset_gt = labels[core_ids]

    def _true_class_positions(round_dir):
        """Position of each sample's true class in the ascending score order."""
        model = load_model(os.path.join(round_dir, 'model_best_end.keras'))
        pred = model.predict(subset_data)
        print(pred.shape)
        # argsort of argsort turns scores into per-class positions. It is done
        # row by row so that no (samples x classes) index matrix is allocated.
        return [pred[j].argsort().argsort()[subset_gt[j]]
                for j in range(len(core_ids))]

    prev_rank = _true_class_positions(os.path.join(folder_path, all_path[0]))
    for i in range(1, len(all_path)):
        fname = os.path.join(folder_path, all_path[i])
        label_ids = np.load(os.path.join(fname, 'all_ids.npy'))
        print(len(label_ids))
        print(data.shape)
        print(label_ids[-1])

        pred_rank = _true_class_positions(fname)
        # Copy so the caller's "previous" and "current" lists are independent.
        all_prev.append(list(prev_rank))
        all_curr.append(pred_rank)
        prev_rank = pred_rank
    return all_path, all_diff, all_curr, all_prev


def get_info_cummulative(folder_path, infer = False):
    """Count the distinct sample ids accumulated over the training rounds.

    The sub-folders of ``folder_path`` are visited in natural sort order. Each
    must contain ``all_ids.npy`` and ``medoids_ids.npy``. After every round
    the number of distinct ids seen so far is recorded, separately for the two
    files, and the result is written to ``<folder_path>/analysis_num.csv``
    with the columns:

    * ``all_full`` - distinct ids accumulated from ``medoids_ids.npy``
    * ``all_ss``   - distinct ids accumulated from ``all_ids.npy``

    For every round the running totals (with duplicates) and the distinct
    counts are also printed.

    Args:
        folder_path: directory containing one sub-folder per training round.
        infer: accepted for interface compatibility; not used.

    Returns:
        tuple: ``(all_path, all_diff, all_curr, all_prev)`` where ``all_path``
        is the sorted list of sub-folder names and the other three entries are
        always empty lists.

    Raises:
        IndexError: if ``folder_path`` has no sub-folders.
    """
    all_path = natsorted(next(os.walk(folder_path))[1])
    print(all_path)
    if not all_path:
        raise IndexError('no sub-folders found in {!r}'.format(folder_path))

    all_diff = []
    all_prev = []
    all_curr = []
    all_full = []
    all_ss = []

    # Running state: distinct ids seen so far and total ids including repeats.
    # Merging each round into the running set avoids re-concatenating and
    # re-sorting every earlier round on each iteration.
    seen_medoids = None
    seen_labelled = None
    total_medoids = 0
    total_labelled = 0
    for round_name in all_path:
        fname = os.path.join(folder_path, round_name)
        label_ids = np.load(os.path.join(fname, 'all_ids.npy'))
        medoids_ids = np.load(os.path.join(fname, 'medoids_ids.npy'))

        total_medoids += len(medoids_ids)
        total_labelled += len(label_ids)
        if seen_medoids is None:
            seen_medoids = np.unique(medoids_ids)
            seen_labelled = np.unique(label_ids)
        else:
            seen_medoids = np.union1d(seen_medoids, medoids_ids)
            seen_labelled = np.union1d(seen_labelled, label_ids)

        print(total_medoids)
        print(total_labelled)
        print(len(seen_medoids))
        print(len(seen_labelled))
        print('---')
        all_full.append(len(seen_medoids))
        all_ss.append(len(seen_labelled))

    df = pd.DataFrame({'all_full': all_full, 'all_ss': all_ss})
    df.to_csv(os.path.join(folder_path, 'analysis_num.csv'))

    return all_path, all_diff, all_curr, all_prev

# Cumulative id analysis; writes analysis_num.csv into folder_path.
all_path, all_diff, all_curr, all_prev = get_info_cummulative(folder_path, infer=1)
# NOTE: the script stops here unconditionally, so everything below is
# currently unreachable; remove this call to run the checkpoint comparison.
exit()

all_path, all_diff, all_curr, all_prev = get_info(folder_path, infer=1)

# One 'prev_<k>' / 'curr_<k>' column pair per compared checkpoint pair.
res_dict = {}
for i in range(len(all_prev)):
    res_dict['prev_{}'.format(i)] = all_prev[i]
    res_dict['curr_{}'.format(i)] = all_curr[i]
df = pd.DataFrame(res_dict)
#df.to_csv(os.path.join(folder_path, 'analysis.csv'.format(args.eval_path[:-4])))
# The file name has no '{}' placeholder, so .format() leaves it unchanged.
df.to_csv(os.path.join(folder_path, 'analysis_mean_rank_coreset.csv'.format(args.eval_path[:-4])))
#np.savez('res_random.npz', rd_multi)
#np.savez('res_uncertain.npz', all_multi)
#np.savez('res_bal.npz', bal_multi)
#np.savez('res_bal_update.npz', bal_multi_notrain)
#np.savez('res_bal_best.npz', bal_multi)
#np.savez('res_bal_update_best.npz', bal_multi_notrain)