import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# (notebook template comment, truncated in the original; os.path is provided by "import os" below)
import sys
import h5py
import math
import gc
import numpy as np
#from numba import cuda
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, ReLU, Dropout, Concatenate, Multiply, BatchNormalization #, AveragePooling1D, Add, GlobalAveragePooling1D, GlobalMaxPooling1D
from tensorflow.keras.utils import plot_model   #, get_source_inputs
from tensorflow.keras.utils import get_file
from tensorflow.keras import backend as K
#from keras_applications.imagenet_utils import decode_predictions
#from keras_applications.imagenet_utils import preprocess_input
#from keras_applications.imagenet_utils import _obtain_input_shape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import matplotlib.pyplot as plt
import argparse
import os
from pathlib import Path
import pywt

def parse_arguments():
    """Build the command-line parser used by this script.

    Every option is an optional string that defaults to ``None`` when omitted.

    Returns:
        argparse.ArgumentParser: the configured (not yet parsed) parser.
    """
    parser = argparse.ArgumentParser(
        description='Compare the key-rank curves of a baseline model and a wavelet-input model')
    parser.add_argument('--train_folder', type=str,
                        help='output folder; a "log" sub-folder is created inside it')
    parser.add_argument('--xType', type=str,
                        help='model input layout: wave, wavebp0, wavebp1, wavebp01, '
                             'wavebp01next0 or wavebp01next01')
    parser.add_argument('--model_name', type=str, help='name of model')
    parser.add_argument('--baseline', type=str,
                        help='path of the saved model attacked with the raw traces')
    parser.add_argument('--active', type=str,
                        help='path of the saved model attacked with the wavelet-transformed traces')
    return parser

# Parse the command line once, at import time; `args` is read by the
# module-level configuration further down (args.xType, args.train_folder,
# args.baseline, args.active).
parser = parse_arguments()
args = parser.parse_args()


# Trace and metadata parameters (ranges are inclusive: class counts add 1)
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1             # number of input for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3 and which subkeys are they are decided by code in m4 (NOT by code in PC)
work = 'train'  # 'train' or 'attack'

# Attack parameters
attack_byModel_epNo = 232
nruns_default = 20
maxtrc_default = 200
testPortion = 1

# Training parameters
train_batch_size = 400  # 80 for mars45, 170 for mars56
period = 8
maxEpochs = 2048  # previously tried: 3072, 1536, 1280, 1024, 512, 256
attack_byModel_fileNo = int(attack_byModel_epNo / period)

# Model hyper-parameters
noConv1Dbranch = 1
noBPbranch = 4
noLayers = 6  # if newly train
GPU_clear = True  # False

# Training data type
xType = args.xType  # 'wave', 'wavebp0', 'wavebp1', 'wavebp01', 'wavebp01next0' or 'wavebp01next01'
yType = 'skpv'      # 'fqmul0', 'fqmul1' or 'skpv'
trainPortion = 0.8
def check_file_exists(file_path):
    """Terminate the process with status -1 when `file_path` does not exist.

    Returns None (and does nothing else) when the path exists.
    """
    if os.path.exists(file_path):
        return
    print("Error: provided file path '%s' does not exist!" % file_path)
    sys.exit(-1)

def load_meta_trace_file(database_file, sKeyNo, load_metadata=False):
    """Load the traces and their labels from an HDF5 trace database.

    Args:
        database_file: path of the HDF5 file to read.
        sKeyNo: column index selected in every per-coefficient dataset.
            Column ``sKeyNo + 1`` of the two ``bp_b_vec0_*`` datasets is read
            as well, so that column must exist.
        load_metadata: kept for backward compatibility; the same tuple is
            returned whatever its value.

    Returns:
        tuple ``(trace_profiling, bp_profiling, skpv_profiling, fqmul_profiling)``:
            * trace_profiling: the ``wave`` dataset as a float array.
            * bp_profiling: list of four int arrays, in this order:
              ``bp_b_vec0_evenCoeff0[:, sKeyNo]``, ``bp_b_vec0_oddCoeff1[:, sKeyNo]``,
              ``bp_b_vec0_evenCoeff0[:, sKeyNo + 1]``, ``bp_b_vec0_oddCoeff1[:, sKeyNo + 1]``.
            * skpv_profiling: ``skpv_a_vec0_evenCoeff0[:, sKeyNo]`` as ints.
            * fqmul_profiling: list of two arrays,
              ``a_vec0_evenCoeff_by_b_vec0_evenCoeff[:, sKeyNo]`` and
              ``a_vec0_evenCoeff_by_b_vec0_oddCoeff[:, sKeyNo]``.

    The process exits with status -1 when the file is missing or cannot be
    opened as HDF5.
    """
    print('\nLoad database_file =', database_file)
    check_file_exists(database_file)
    # Open the Kyber database HDF5 for reading
    try:
        in_file = h5py.File(database_file, "r")
    except OSError:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % database_file)
        sys.exit(-1)

    # Everything is copied into numpy arrays inside the `with`, so the file
    # handle can be released before returning.
    with in_file:
        def int_column(dataset, column):
            # One column of a 2-D dataset, converted to Python-int dtype.
            return np.array(in_file[dataset][:, column].astype(int))

        trace_profiling = np.array(in_file['wave'], dtype=float)
        skpv_profiling = int_column('skpv_a_vec0_evenCoeff0', sKeyNo)
        bp_profiling = [
            int_column('bp_b_vec0_evenCoeff0', sKeyNo),
            int_column('bp_b_vec0_oddCoeff1', sKeyNo),
            int_column('bp_b_vec0_evenCoeff0', sKeyNo + 1),
            int_column('bp_b_vec0_oddCoeff1', sKeyNo + 1),
        ]
        fqmul_profiling = [
            np.array(in_file['a_vec0_evenCoeff_by_b_vec0_evenCoeff'][:, sKeyNo]),
            np.array(in_file['a_vec0_evenCoeff_by_b_vec0_oddCoeff'][:, sKeyNo]),
        ]

    return (trace_profiling, bp_profiling, skpv_profiling, fqmul_profiling)

def _bp_one_hot(values, label, index):
    """One-hot encode one bp column into an int array of shape (N, NumBPinput, 1)."""
    count = values.shape[0]
    one_hot = np.zeros((count, NumBPinput)).astype(int)
    print(label + '.shape =', one_hot.shape, '    bp_profiling[%d] =' % index, values)
    one_hot[np.arange(count), values] = 1
    return one_hot.reshape((count, NumBPinput, 1))


def create_training_data_form(database_folder_train_file, sKeyNo, trainPortion, xType, yType):
    """Load a trace database and split it into model inputs and labels.

    Args:
        database_folder_train_file: path of the HDF5 database, forwarded to
            ``load_meta_trace_file``.
        sKeyNo: sub-key column index, forwarded to ``load_meta_trace_file``.
        trainPortion: fraction of the rows put in the training split. The
            validation split starts where the training split ends; when the
            training split covers every row, the validation split is the
            last row only.
        xType: input layout, one of 'wave', 'wavebp0', 'wavebp1', 'wavebp01',
            'wavebp01next0', 'wavebp01next01'.
        yType: label choice, one of 'fqmul0', 'fqmul1', 'skpv'.

    Returns:
        tuple ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``.
        ``xTrain``/``xVal`` are lists of input groups: the first group holds
        the traces reshaped to (N, samples, 1); when xType selects bp inputs,
        a second group holds their one-hot encodings, each (N, NumBPinput, 1).
        ``yTrain``/``yVal`` are the one-hot labels and ``y*_value`` the raw
        label values for the same rows.

    Raises:
        ValueError: if ``xType`` or ``yType`` is not one of the supported
            names (checked before the database is loaded).
    """
    # Indices into bp_profiling that feed the second input group.
    bp_indices_by_xtype = {
        'wave': (),
        'wavebp0': (0,),
        'wavebp1': (1,),
        'wavebp01': (0, 1),
        'wavebp01next0': (0, 1, 2),
        'wavebp01next01': (0, 1, 2, 3),
    }
    # Names used in the progress print of each one-hot array.
    bp_labels = (
        'bp0_1hot_profiling',
        'bp1_1hot_profiling',
        'bp0_1hot_profiling_next_sKeyNo',
        'bp1_1hot_profiling_next_sKeyNo',
    )
    if xType not in bp_indices_by_xtype:
        raise ValueError("unsupported xType %r; expected one of %s"
                         % (xType, ', '.join(bp_indices_by_xtype)))
    if yType not in ('fqmul0', 'fqmul1', 'skpv'):
        raise ValueError("unsupported yType %r; expected 'fqmul0', 'fqmul1' or 'skpv'" % (yType,))

    (trace_profiling, bp_profiling, skpv_profiling, fqmul_profiling) = load_meta_trace_file(database_folder_train_file, sKeyNo)
    traces = trace_profiling.reshape((trace_profiling.shape[0], trace_profiling.shape[1], 1))
    dataSize = traces.shape[0]
    trainSize = math.floor(dataSize * trainPortion)
    valLoc = trainSize
    if valLoc == dataSize:
        # Keep the validation split non-empty when every row is used for training.
        valLoc = dataSize - 1

    # Only the bp inputs requested by xType are encoded: each one-hot array is
    # (N, NumBPinput) ints, so building unused ones is costly.
    bp_one_hots = [
        _bp_one_hot(bp_profiling[index], bp_labels[index], index)
        for index in bp_indices_by_xtype[xType]
    ]

    train_rows = slice(0, trainSize)
    val_rows = slice(valLoc, None)

    xTrain = [[traces[train_rows]]]
    xVal = [[traces[val_rows]]]
    if bp_one_hots:
        xTrain.append([one_hot[train_rows] for one_hot in bp_one_hots])
        xVal.append([one_hot[val_rows] for one_hot in bp_one_hots])

    # Category creation
    if yType == 'fqmul0':
        label_values, num_classes = fqmul_profiling[0], NumFQMULclasses
    elif yType == 'fqmul1':
        label_values, num_classes = fqmul_profiling[1], NumFQMULclasses
    else:  # 'skpv'
        label_values, num_classes = skpv_profiling, NumSKPVclasses
    labels_one_hot = to_categorical(label_values, num_classes=num_classes)

    yTrain = labels_one_hot[train_rows, :]
    yVal = labels_one_hot[val_rows, :]
    yTrain_value = label_values[train_rows]
    yVal_value = label_values[val_rows]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value
    
def plot_meanrank(rankmat, maxtrc, label):
    """Add a mean-rank curve to the current matplotlib figure.

    The curve is the mean of `rankmat` over its first axis, plotted against
    the trace counts 1..maxtrc and tagged with `label` for the legend.
    """
    trace_counts = np.arange(1, maxtrc + 1)
    mean_rank = np.mean(rankmat, axis=0)
    plt.xlabel('number of traces')
    plt.ylabel('mean rank')
    plt.plot(trace_counts, mean_rank, label=label)

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Rank the true key over repeated attack runs.

    Args:
        model: object exposing ``predict(list_of_arrays)`` and ``name``.
        nruns: number of attack runs (rows of ``batches`` that are used).
        maxtrc: number of traces per run; ``batches[k]`` must select exactly
            this many rows so the predictions fit the result arrays.
        batches: integer array whose row ``k`` holds the trace indices of run ``k``.
        xTest: list of input groups (each a list of arrays indexed by trace
            on their first axis). All arrays of all groups are passed to the
            model in order, so the layout must match the model's inputs.
        yTest_value: label values; only ``yTest_value[0]`` is read and used
            as the true key for every trace.
        noHypoKeys: number of key hypotheses; must equal ``noClasses`` since
            the class log-probabilities are accumulated directly per key.
        noClasses: number of model output classes.

    Returns:
        tuple of six arrays:
            * rankmat_byKey (nruns, maxtrc): number of keys whose cumulative
              log-probability is strictly greater than the true key's after
              each trace.
            * rankmat_byClass (nruns, maxtrc): same count on the single-trace
              log-probabilities.
            * ps_AllClasses_Nruns (maxtrc, noClasses, nruns): predictions.
            * lps_AllClasses_Nruns (maxtrc, noClasses, nruns): their logs.
            * lps_AllHypoKeys_Nruns (maxtrc, noHypoKeys, nruns): same logs.
            * lpsums_AllHypoKeys_Nruns (maxtrc, noHypoKeys, nruns): cumulative
              sums of the logs over the traces of each run.
    """
    realkey = int(yTest_value[0])
    realClass = realkey  # the class label is the key itself
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))
    for krun in range(nruns):
        if krun == 0:
            print('%s  run %d of %d' % (model.name, krun+1, nruns))
        samp = batches[krun, :]

        # Feed every array of every input group, restricted to this run's
        # traces. This follows the layout of xTest itself instead of a global
        # input-type switch, so single-bp layouts index the right array.
        ps = model.predict([arr[samp] for group in xTest for arr in group])
        lps = np.log(ps)

        # Running sum over traces, accumulated in float64.
        lpsums = np.cumsum(lps, axis=0, dtype=np.float64)

        # Rank = number of candidates scoring strictly higher than the true one.
        rankmat_byKey[krun, :] = np.count_nonzero(lpsums > lpsums[:, [realkey]], axis=1)
        rankmat_byClass[krun, :] = np.count_nonzero(lps > lps[:, [realClass]], axis=1)

        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
        lpsums_AllHypoKeys_Nruns[:, :, krun] = lpsums
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

# Attack configuration. nruns_default, maxtrc_default, testPortion and yType
# repeat the values already assigned near the top of the file.
nruns_default = 20
maxtrc_default = 200
# testPortion is passed as the train fraction when the attack set is built,
# so 1 puts every trace of the attack file in the test split.
testPortion = 1
database_folder_attack_file = "KYBER51.H5"
model_index = 18  # NOTE(review): not referenced in the visible code

#basepath = 'trained_models/Active_MD'
# Output location: <train_folder>/log is created if it does not exist yet.
basepath = args.train_folder
save_folder = os.path.join(basepath, 'log')
Path(save_folder).mkdir(parents=True, exist_ok=True)
#xType = None
# training data type
#xType = 'wavebp01next01'  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv'
#Test batch generation
# The string literal below is disabled code (never executed): it derived
# xType from a model_type tag.
'''
if '0c' in model_type:
    xType = 'wave'
if '1c' in model_type:
    xType = 'wavebp0'
if '2c' in model_type:
    xType = 'wavebp01'
if '3c' in model_type:
    xType = 'wavebp01next0'
if '4c' in model_type:
    xType = 'wavebp01next01'
'''

def wavelet_transform(data):
    """Replace every trace by its single-level 'db1' DWT approximation.

    Args:
        data: array-like of traces indexed by trace on the first axis; each
            trace is flattened to 1-D before the transform (the script passes
            an (N, samples, 1) array).

    Returns:
        numpy array of shape (N, n_coeffs, 1) holding the approximation
        coefficients (cA) of each trace; the detail coefficients are dropped.
    """
    traces = np.asarray(data)
    print(traces[0].shape)
    # One call transforms all traces along their sample axis instead of
    # looping over them in Python.
    flat = traces.reshape(traces.shape[0], -1)
    approx, _ = pywt.dwt(flat, 'db1', axis=-1)
    print(approx[0].shape)
    new_data = np.expand_dims(approx, axis=2)
    print(new_data.shape)
    return new_data

import os


# Build the attack set: testPortion is used as the train fraction, so the
# "train" outputs returned here are the traces that get attacked.
xTest, yTest, xVal_, yVal_, yTest_value, yVal_value = create_training_data_form(
    database_folder_attack_file, sKeyNo, testPortion, xType, yType
)
print('len(xTest[0][0]) =', len(xTest[0][0]))
for row in (0, 1):
    print('yTest[%d][1730:1740] =' % row, yTest[row][1730:1740])
print('yTest[1][1733] =', yTest[1][1733])
print('yTest_value =', yTest_value)

# One row of trace indices per attack run, drawn without replacement.
batches = np.zeros((nruns_default, maxtrc_default), 'int')
for i in range(nruns_default):
    batches[i, :] = np.random.choice(len(xTest[0][0]), maxtrc_default, replace=False)
print("batches.shape =", batches.shape)
print("batches =", batches)


# Disabled code kept as a bare string literal (never executed): it attacked a
# single checkpoint and saved its mean-rank plot under save_folder.
'''
plt.figure(figsize=(10,7))
plt.rcParams.update({'font.size': 25})
plt.grid()
plt.ylim(0,50)
plt.xlim(0,50)
plot_meanrank(rankmat_byKey, maxtrc_default, model.name)
#plt.legend()   
plt.title('Model comparison using %d test runs' % nruns_default)
plt.tight_layout()  
#plt.savefig(attackLogFolder + logFilename + 'attack' + str(nruns) + 'run' + str(maxtrc) + 'trc_fileNo' + str(attack_byModel_fileNo) + '.png')
#plt.savefig(model_path + 'attack' + str(nruns) + 'run' + str(maxtrc) + '.png')
#plt.savefig(trainedFolder + saveFileName + '.pdf')
#plt.show(block = False)
print(os.path.join(save_folder, "out.png"))
plt.savefig(os.path.join(save_folder, "out.png"))

#Plot single
ep =  654
model_path = os.path.join(basepath, 'cp-{:04d}.keras'.format(ep))
model = load_model(model_path)
print('Attack start')
rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns = mk_rankmat(model, nruns_default, maxtrc_default, batches, xTest, yTest_value, noHypoKeys, noClasses)
plot_data = ['model_type', rankmat_byKey]
print('FINISH attacking')

plt.figure(figsize=(10,7))
plt.rcParams.update({'font.size': 25})
plt.grid()
plt.ylim(0,50)
plot_meanrank(rankmat_byKey, maxtrc_default, model.name)
#plt.legend()   
plt.title('Model comparison using %d test runs' % nruns_default)
plt.tight_layout()  
#plt.savefig(attackLogFolder + logFilename + 'attack' + str(nruns) + 'run' + str(maxtrc) + 'trc_fileNo' + str(attack_byModel_fileNo) + '.png')
#plt.savefig(model_path + 'attack' + str(nruns) + 'run' + str(maxtrc) + '.png')
#plt.savefig(trainedFolder + saveFileName + '.pdf')
#plt.show(block = False)
plt.savefig(os.path.join(save_folder, "out"+str(ep)+".png"))
'''

# Attack 1: the model given by --baseline, fed with the traces as loaded.
model_baseline = load_model(args.baseline)
print('Attack start')
rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns = mk_rankmat(model_baseline, nruns_default, maxtrc_default, batches, xTest, yTest_value, noHypoKeys, noClasses)
# Only element [1] (the by-key rank matrix) is read later; 'model_type' is a placeholder tag.
plot_data_baseline = ['model_type', rankmat_byKey]
print('FINISH attacking')

# Attack 2: the model given by --active, fed with the wavelet approximation
# of the same traces. xTest[0][0] is overwritten, so the raw traces are no
# longer available after this point. Both attacks use the same `batches`.
model_active= load_model(args.active)
xTest[0][0] = wavelet_transform(xTest[0][0])
print('Attack start')
# The result names below are rebound; the baseline ranks survive only in plot_data_baseline.
rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns = mk_rankmat(model_active, nruns_default, maxtrc_default, batches, xTest, yTest_value, noHypoKeys, noClasses)
plot_data_active = ['model_type', rankmat_byKey]
print('FINISH attacking')

#out = np.load(os.path.join(basepath, 'output_200k_2c.npz'))
#model_type = str(out['model_type'])
#print(model_type)
#data = out['data']
#dt = [model_type, data]

# NOTE(review): nt is only referenced by the commented-out plt.plot call below.
nt = np.arange(200) + 1

# Plot both mean-rank curves on one figure.
plt.figure(figsize=(10,7))
plt.rcParams.update({'font.size': 25})
plt.grid()
plt.ylim(0,200)
plot_meanrank(plot_data_baseline[1], maxtrc_default, 'raw signal')
plot_meanrank(plot_data_active[1], maxtrc_default, 'wavelet')
#plt.plot(nt, np.mean(data, 0), label = "200k_original", c='r')
plt.legend()   
plt.title('Model comparison using %d test runs' % nruns_default)
plt.tight_layout()  
#plt.savefig(attackLogFolder + logFilename + 'attack' + str(nruns) + 'run' + str(maxtrc) + 'trc_fileNo' + str(attack_byModel_fileNo) + '.png')
#plt.savefig(model_path + 'attack' + str(nruns) + 'run' + str(maxtrc) + '.png')
#plt.savefig(trainedFolder + saveFileName + '.pdf')
#plt.show(block = False)
# NOTE(review): the figure is written to the current working directory, not
# to save_folder created earlier — confirm this is intended.
plt.savefig("compare.png")
