import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
#from numba import cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import os.path
import sys
import h5py
import math
import gc
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from sklearn.cluster import KMeans
import torch.nn.functional as F

import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
#from sklearn_extra.cluster import KMedoid
from natsort import natsorted
import pywt
from natsort import natsorted
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from layers.Embed import DataEmbedding
from layers.Conv_Blocks import Inception_Block_V1


def FFT_for_Period(x, k=2):
    """Pick the ``k`` strongest frequency bins of ``x`` and turn them into periods.

    Args:
        x: tensor indexed as [B, T, C]; the real FFT is taken along dim 1.
        k: number of frequency bins to keep.

    Returns:
        ``(period, weights)`` where ``period`` is a NumPy integer array holding
        ``x.shape[1] // bin_index`` for every selected bin, and ``weights`` is
        the channel-averaged amplitude of each batch element at those bins
        (a tensor indexed as [B, k]).
    """
    spectrum = torch.fft.rfft(x, dim=1)
    amplitude = abs(spectrum)
    # One score per frequency bin: average over the batch, then over channels.
    bin_strength = amplitude.mean(0).mean(-1)
    # Zero the DC bin so a constant offset is never selected as a "period".
    bin_strength[0] = 0
    top_bins = torch.topk(bin_strength, k).indices
    top_bins = top_bins.detach().cpu().numpy()
    period = x.shape[1] // top_bins
    return period, amplitude.mean(-1)[:, top_bins]


class TimesBlock(nn.Module):
    """Fold a sequence by its dominant periods, convolve in 2D, and re-aggregate.

    For each of the ``top_k`` periods reported by ``FFT_for_Period`` the input
    is zero-padded to a multiple of the period, reshaped to a 2D grid, passed
    through a shared pair of inception blocks, and flattened back. The
    per-period results are blended with softmax weights derived from the FFT
    amplitudes and added to the input as a residual.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design: the same conv stack serves every period
        expand = Inception_Block_V1(configs.d_model, configs.d_ff,
                                    num_kernels=configs.num_kernels)
        squeeze = Inception_Block_V1(configs.d_ff, configs.d_model,
                                     num_kernels=configs.num_kernels)
        self.conv = nn.Sequential(expand, nn.GELU(), squeeze)

    def forward(self, x):
        batch, steps, channels = x.size()
        total_len = self.seq_len + self.pred_len
        periods, weights = FFT_for_Period(x, self.k)

        branches = []
        for period in periods[:self.k]:
            if total_len % period == 0:
                # Already a whole number of periods: no padding required.
                length = total_len
                folded = x
            else:
                # Zero-pad up to the next multiple of the period.
                length = ((total_len // period) + 1) * period
                pad = torch.zeros(
                    [x.shape[0], (length - total_len), x.shape[2]]).to(x.device)
                folded = torch.cat([x, pad], dim=1)
            # 1D -> 2D: rows are consecutive periods, columns are phase.
            folded = folded.reshape(batch, length // period, period, channels)
            folded = folded.permute(0, 3, 1, 2).contiguous()
            folded = self.conv(folded)
            # 2D -> 1D, then trim the padding again.
            unfolded = folded.permute(0, 2, 3, 1).reshape(batch, -1, channels)
            branches.append(unfolded[:, :total_len, :])
        stacked = torch.stack(branches, dim=-1)

        # adaptive aggregation: amplitude-derived softmax weight per period
        weights = F.softmax(weights, dim=1)
        weights = weights.unsqueeze(1).unsqueeze(1).repeat(1, steps, channels, 1)
        blended = torch.sum(stacked * weights, -1)

        # residual connection
        return blended + x


class TimesnetModel(nn.Module):
    """
    Stack of TimesBlocks followed by a linear classification head.

    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
    """

    def __init__(self, configs):
        super().__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        self.layer = configs.e_layers

        blocks = [TimesBlock(configs) for _ in range(configs.e_layers)]
        self.model = nn.ModuleList(blocks)
        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq,
            configs.dropout)
        self.layer_norm = nn.LayerNorm(configs.d_model)

        self.act = F.gelu
        self.dropout = nn.Dropout(configs.dropout)
        # The head consumes the flattened (seq_len * d_model) representation.
        self.projection = nn.Linear(
            configs.d_model * configs.seq_len, configs.num_class)

    def classification(self, x_enc):
        """Embed ``x_enc``, run every TimesBlock, and return class logits."""
        # Embedding is called without time-mark features (second argument None).
        hidden = self.enc_embedding(x_enc, None)  # [B,T,C]
        for depth in range(self.layer):
            hidden = self.layer_norm(self.model[depth](hidden))

        # The block outputs carry no final non-linearity, so apply one here.
        hidden = self.dropout(self.act(hidden))
        # (batch_size, seq_length * d_model)
        flattened = hidden.reshape(hidden.shape[0], -1)
        return self.projection(flattened)  # (batch_size, num_classes)

    def forward(self, x_enc, mask=None):
        """Return class logits for ``x_enc``; ``mask`` is accepted but unused."""
        return self.classification(x_enc)  # [B, N]

def _str2bool(value):
    """Convert a command-line token such as ``true``/``False``/``1`` to a bool.

    ``type=bool`` must not be used with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def parse_arguments():
    """Build the command-line parser for this evaluation script.

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed); call
        ``parse_args()`` on it to obtain the namespace.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--seq_len', type=int, default=600, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=48, help='start token length')
    parser.add_argument('--pred_len', type=int, default=0, help='prediction sequence length')
    parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
    parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)

    parser.add_argument('--features', type=str, default='S',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

    # model define
    parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba')
    parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba')
    parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
    parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
    parser.add_argument('--enc_in', type=int, default=1, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=1, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--channel_independence', type=int, default=1,
                        help='0: channel dependence 1: channel independence for FreTS model')
    parser.add_argument('--decomp_method', type=str, default='moving_avg',
                        help='method of series decompsition, only support moving_avg or dft_decomp')
    parser.add_argument('--use_norm', type=int, default=1, help='whether to use normalize; True 1 False 0')
    parser.add_argument('--down_sampling_layers', type=int, default=0, help='num of down sampling layers')
    parser.add_argument('--down_sampling_window', type=int, default=1, help='down sampling window size')
    parser.add_argument('--down_sampling_method', type=str, default=None,
                        help='down sampling method, only support avg, max, conv')
    parser.add_argument('--seg_len', type=int, default=96,
                        help='the length of segmen-wise iteration of SegRNN')

    # metrics (dtw)
    # Parsed with _str2bool so that "--use_dtw False" really yields False.
    parser.add_argument('--use_dtw', type=_str2bool, default=False,
                        help='the controller of using dtw metric (dtw is time consuming, not suggested unless necessary)')

    # TimeXer
    parser.add_argument('--patch_len', type=int, default=16, help='patch length')
    parser.add_argument('--model', type=str, default='CNN')

    # evaluation-specific options
    parser.add_argument('--train_folder', type=str, help='eval folder')
    parser.add_argument('--eval_path', type=str, help='eval path')
    parser.add_argument('--xType', type=str, help='type, wavebp01 mean using ciphertext')
    parser.add_argument('--num_traces', type=int, default=100)
    parser.add_argument('--num_key', type=int, default=300)
    parser.add_argument('--normalize', type=int)
    parser.add_argument('--dwt', type=int)
    parser.add_argument('--cummulative', type=int)
    parser.add_argument('--all_ids', type=str, default = 'None')
    parser.add_argument('--all_path', type=int, default=0)
    parser.add_argument('--zmuv', type=int, default=0)
    return parser

# ---- Command-line configuration --------------------------------------------
parser = parse_arguments()
args = parser.parse_args()
xType = args.xType
maxtrc = args.num_traces

basepath = 'trained_models/'

# NOTE(review): `data` / `labels` do not appear to be read again at module
# level in the visible part of this file (functions below use their own
# local `data`) -- confirm before removing.
data = np.load('data.npz')
labels = data['label']

# ---- Value ranges (inclusive bounds) and trace geometry ---------------------
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600

# Class counts are the widths of the inclusive ranges above.
NumFQMULclasses = (fqmul_range[1] - fqmul_range[0]) + 1  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = (skpv_range[1] - skpv_range[0]) + 1     # number of classes for skpv
NumBPinput = (bp_range[1] - bp_range[0]) + 1             # number of input for bp (ciphertext)

# The classifier output and the key hypotheses share the skpv class count.
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses

sKeyNo = 0
trainPortion = 1.0
def get_meanrank(xTest_multi, yTest_vals, model, maxtrc):
    """Evaluate ``model`` on every entry of ``yTest_vals`` and collect the ranks.

    For each index ``i`` the traces belonging to that entry are passed to
    ``eval_model`` together with a single batch containing trace indices
    ``0 .. maxtrc-1``. When ``args.xType == 'wavebp01'`` ``xTest_multi`` is
    treated as a pair ``[traces, bp]``; otherwise it is indexed directly.

    Returns:
        np.ndarray: one ``eval_model`` result per entry of ``yTest_vals``.
    """
    started = time.time()
    nruns = 1  # data for multi-label is limited, so we did 1 run only
    ranks = []
    for key_idx in range(len(yTest_vals)):  # Iterate each key
        # A single deterministic batch: the first `maxtrc` traces in order.
        batches = np.zeros((nruns, maxtrc), 'int')
        batches[0] = np.arange(maxtrc)
        if args.xType == 'wavebp01':
            inputs = [xTest_multi[0][key_idx], xTest_multi[1][key_idx]]
        else:
            inputs = [[xTest_multi[key_idx]]]
        ranks.append(
            eval_model(model, nruns, maxtrc, batches, inputs,
                       [yTest_vals[key_idx]], noHypoKeys, noClasses))
    multi_rank = np.array(ranks)
    print('Eval took {}'.format(time.time()-started))
    return multi_rank


def load_multi_attack(data_path):
    """Load the attack traces and labels from an ``.npz`` archive.

    Args:
        data_path: path of an ``.npz`` archive with ``data`` and ``label``
            entries (and ``bp`` when ``args.xType == 'wavebp01'``).

    Returns:
        ``(data, labels)``. ``data`` is the ``data`` array, or the list
        ``[data, bp]`` when ``args.xType == 'wavebp01'``.
    """
    # The context manager closes the underlying zip file; the arrays are fully
    # read on access, so they remain valid after the archive is closed.
    with np.load(data_path) as infile:
        data = infile['data']
        if args.xType == 'wavebp01':
            data = [data, infile['bp']]
        labels = infile['label']

    return data, labels

def mk_rankmat(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Score traces with ``model`` and track the rank of the true key.

    Each class predicted by the model is used directly as a key hypothesis:
    log-probabilities are accumulated trace by trace, and after every trace
    the rank of the true key is the number of hypotheses whose accumulated
    score is strictly larger.

    Args:
        model: torch module returning class logits.
        nruns: number of runs; row ``krun`` of ``batches`` is used for run ``krun``.
        maxtrc: number of traces accumulated per run.
        batches: integer array indexed as [run, trace] holding trace indices.
        xTest: nested container; ``xTest[0][0]`` is indexed as
            ``[trace, sample, 0]`` to build the model input.
        yTest_value: sequence whose first element is the true key/class.
        noHypoKeys: number of key hypotheses.
        noClasses: number of model output classes.

    Returns:
        tuple: ``(rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns,
        lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns)``.
        The rank matrices are indexed [run, trace]; the remaining arrays are
        indexed [trace, class-or-key, run].
    """
    realkey = int(yTest_value[0])
    rankmat_byKey = np.zeros((nruns, maxtrc), dtype=int)
    rankmat_byClass = np.zeros((nruns, maxtrc), dtype=int)
    ps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllClasses_Nruns = np.zeros((maxtrc, noClasses, nruns))
    lps_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    lpsums_AllHypoKeys_Nruns = np.zeros((maxtrc, noHypoKeys, nruns))
    print("-------------------------------------------")
    print(len(xTest[0][0]))

    # Use the GPU when there is one, but keep the script usable on CPU-only hosts.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    for krun in range(nruns):
        samp = batches[krun, :]
        traces = xTest[0][0][samp, :, 0]
        print(traces.shape)

        pred_batch = torch.from_numpy(traces).to(device).float()
        if args.model == 'Timesnet':
            pred_batch = pred_batch.unsqueeze(-1)
        pred_batch = pred_batch.unsqueeze(1)

        # Pure inference: do not build an autograd graph.
        with torch.no_grad():
            logits = model(pred_batch)
            ps = F.softmax(logits, dim=1).cpu().numpy()
            # log_softmax stays finite where log(softmax(...)) would underflow
            # to -inf and permanently poison the accumulated key scores.
            lps = F.log_softmax(logits, dim=1).cpu().numpy()

        lpsums = np.zeros(noHypoKeys)
        for i in range(maxtrc):
            # The class label is used as the key hypothesis itself.
            lpsums += lps[i]
            lpsums_AllHypoKeys_Nruns[i, :, krun] = lpsums
            # Rank = number of hypotheses scoring strictly better than the truth.
            rankmat_byKey[krun, i] = np.count_nonzero(lpsums > lpsums[realkey])
            rankmat_byClass[krun, i] = np.count_nonzero(lps[i, :] > lps[i, realkey])

        ps_AllClasses_Nruns[:, :, krun] = ps
        lps_AllClasses_Nruns[:, :, krun] = lps
        # Without a leakage-model mapping the per-key scores equal the per-class ones.
        lps_AllHypoKeys_Nruns[:, :, krun] = lps
    return rankmat_byKey, rankmat_byClass, ps_AllClasses_Nruns, lps_AllClasses_Nruns, lps_AllHypoKeys_Nruns, lpsums_AllHypoKeys_Nruns

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    A constant series (max == min) has no spread to scale by; it is returned
    as all zeros instead of dividing by zero and producing NaN.

    Args:
        timeseries: NumPy array of samples.

    Returns:
        Array of the same shape with values in [0, 1].
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    shifted = timeseries - lowest
    if span == 0:
        # `shifted` is already all zeros; multiplying by 1.0 only promotes
        # integer input to float, matching the dtype true division would give.
        return shifted * 1.0
    return shifted / span

def z_norm(timeseries):
    """Standardise ``timeseries`` to zero mean and unit variance.

    Delegates to ``scipy.stats.zscore`` with its default settings spelled out
    (first axis, population standard deviation).
    """
    standardized = stats.zscore(timeseries, axis=0, ddof=0)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalise every trace of a two-level container in place.

    ``data[j][i]`` is replaced by ``normalize(data[j][i])`` for every ``j`` and
    ``i``; the (mutated) container itself is returned.
    """
    for group in data:
        for idx, trace in enumerate(group):
            group[idx] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-score every trace of a two-level container in place.

    ``data[j][i]`` is replaced by ``z_norm(data[j][i])`` for every ``j`` and
    ``i``; the (mutated) container itself is returned.
    """
    for group in data:
        for idx, trace in enumerate(group):
            group[idx] = z_norm(trace)

    return data

def dwt_reconstruct(data):
    """Rebuild ``data`` from its single-level db1 approximation only.

    The signal is decomposed with ``pywt.dwt``; the detail coefficients are
    replaced by zeros of the approximation's shape before inverting with
    ``pywt.idwt``.
    """
    approximation, _discarded_detail = pywt.dwt(data, 'db1')
    zero_detail = np.zeros(approximation.shape)
    return pywt.idwt(approximation, zero_detail, 'db1')

def data_dwt(data):
    """Apply ``dwt_reconstruct`` to every trace of a two-level container in place.

    ``data[j][i]`` is replaced by ``dwt_reconstruct(data[j][i])`` for every
    ``j`` and ``i``; the (mutated) container itself is returned.
    """
    for group in data:
        for idx, trace in enumerate(group):
            group[idx] = dwt_reconstruct(trace)

    return data

def cummulative_transform(data):
    """Replace every ``data[i, j]`` by its running sum, in place.

    The shape of ``data`` is printed first. Sums are accumulated as float and
    written back into ``data`` (so they are cast to ``data``'s dtype); the
    same array object is returned.
    """
    print(data.shape)
    for i, block in enumerate(data):
        for j in range(len(block)):
            running_total = np.cumsum(block[j], dtype=float)
            data[i, j] = running_total
    return data

def eval_model(model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses):
    """Return the mean key rank after the last trace.

    Runs ``mk_rankmat`` with the given arguments, averages its by-key rank
    matrix over runs (axis 0), and returns the value at the final trace index.
    """
    (rank_by_key, _rank_by_class, _ps, _lps,
     _lps_by_key, _lpsums_by_key) = mk_rankmat(
        model, nruns, maxtrc, batches, xTest, yTest_value, noHypoKeys, noClasses)
    mean_rank_per_trace = np.mean(rank_by_key, 0)
    return mean_rank_per_trace[-1]
'''
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        
        # Convolutional layers
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm1d(512)
        
        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm1d(256)
        
        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm1d(128)
        
        self.conv4 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding='same')
        self.pool4 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm1d(64)
        
        # Fully connected layers
        self.fc1 = nn.Linear(448, 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.2)
        
        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)
        
        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)
        
        self.fc4 = nn.Linear(256, 128)
        self.bn8 = nn.BatchNorm1d(128)
        
        self.fc5 = nn.Linear(128, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout3 = nn.Dropout(0.2)
        
        self.fc6 = nn.Linear(1024, 1024)
        self.bn10 = nn.BatchNorm1d(1024)
        self.dropout4 = nn.Dropout(0.2)
        
        self.fc7 = nn.Linear(1024, 512)
        self.bn11 = nn.BatchNorm1d(512)
        
        self.fc8 = nn.Linear(512, 256)
        self.bn12 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.2)
        
        self.fc9 = nn.Linear(256, 128)
        self.bn13 = nn.BatchNorm1d(128)
        
        self.fc10 = nn.Linear(128, 3329)
    
    def forward(self, x):
        # Add channel dimension
        x = x.unsqueeze(1)
        
        # Convolutional layers
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.bn1(x)
        
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.bn2(x)
        
        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = self.bn3(x)
        
        x = F.relu(self.conv4(x))
        x = self.pool4(x)
        x = self.bn4(x)
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.bn5(x)
        x = self.dropout1(x)
        
        x = F.relu(self.fc2(x))
        x = self.bn6(x)
        
        x = F.relu(self.fc3(x))
        x = self.bn7(x)
        x = self.dropout2(x)
        
        x = F.relu(self.fc4(x))
        x = self.bn8(x)
        
        x = F.relu(self.fc5(x))
        x = self.bn9(x)
        x = self.dropout3(x)
        
        x = F.relu(self.fc6(x))
        x = self.bn10(x)
        x = self.dropout4(x)
        
        x = F.relu(self.fc7(x))
        x = self.bn11(x)
        
        x = F.relu(self.fc8(x))
        x = self.bn12(x)
        x = self.dropout5(x)
        
        x = F.relu(self.fc9(x))
        x = self.bn13(x)
        
        x = self.fc10(x)
        
        return x
'''

from torch import nn
import torch.nn.functional as F
 
class CNNModel(nn.Module):
    """1D CNN classifier: four conv/pool/batch-norm stages, then ten linear layers.

    ``forward`` expects input that already has a channel dimension, i.e.
    ``(batch, 1, length)``. The first linear layer takes 448 features, which
    is 64 channels times the length left after four stride-3 poolings.
    The final layer emits 3329 logits.
    """

    def __init__(self):
        super().__init__()

        # --- Convolutional feature extractor (each stage: conv -> pool -> bn) ---
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm1d(512)

        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm1d(256)

        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm1d(128)

        self.conv4 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding='same')
        self.pool4 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm1d(64)

        # --- Dense head (each stage: linear -> relu -> bn [-> dropout]) ---
        self.fc1 = nn.Linear(448, 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc4 = nn.Linear(256, 128)
        self.bn8 = nn.BatchNorm1d(128)

        self.fc5 = nn.Linear(128, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout3 = nn.Dropout(0.2)

        self.fc6 = nn.Linear(1024, 1024)
        self.bn10 = nn.BatchNorm1d(1024)
        self.dropout4 = nn.Dropout(0.2)

        self.fc7 = nn.Linear(1024, 512)
        self.bn11 = nn.BatchNorm1d(512)

        self.fc8 = nn.Linear(512, 256)
        self.bn12 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.2)

        self.fc9 = nn.Linear(256, 128)
        self.bn13 = nn.BatchNorm1d(128)

        # Output logits
        self.fc10 = nn.Linear(128, 3329)

    def forward(self, x):
        """Return class logits for ``x``; the caller supplies the channel dimension."""
        conv_stages = (
            (self.conv1, self.pool1, self.bn1),
            (self.conv2, self.pool2, self.bn2),
            (self.conv3, self.pool3, self.bn3),
            (self.conv4, self.pool4, self.bn4),
        )
        for conv, pool, norm in conv_stages:
            x = norm(pool(F.relu(conv(x))))

        # Flatten everything except the batch dimension.
        x = x.view(x.size(0), -1)

        # ``None`` marks the stages that have no dropout layer.
        dense_stages = (
            (self.fc1, self.bn5, self.dropout1),
            (self.fc2, self.bn6, None),
            (self.fc3, self.bn7, self.dropout2),
            (self.fc4, self.bn8, None),
            (self.fc5, self.bn9, self.dropout3),
            (self.fc6, self.bn10, self.dropout4),
            (self.fc7, self.bn11, None),
            (self.fc8, self.bn12, self.dropout5),
            (self.fc9, self.bn13, None),
        )
        for linear, norm, drop in dense_stages:
            x = norm(F.relu(linear(x)))
            if drop is not None:
                x = drop(x)

        return self.fc10(x)

def ZMUV(power_trace,power_trace_target,n):
    """Align the first ``n`` target traces to the reference set's per-sample statistics.

    For every sample position the target value is centred with the target
    mean, rescaled by ``std(reference) / std(target)``, and shifted to the
    reference mean. Means and standard deviations are taken over axis 0 of the
    full ``power_trace`` / ``power_trace_target`` arrays.

    Args:
        power_trace: reference traces, indexed as [trace, sample].
        power_trace_target: traces to transform, indexed as [trace, sample].
        n: number of leading target traces to transform.

    Returns:
        list: ``n`` rows, each a list of transformed sample values.

    Raises:
        IndexError: if ``n`` exceeds the number of target traces.
    """
    print('[LOG---- APPLYING ZMUVN.............]')
    trainMean = np.mean(power_trace, axis = 0)
    testMean = np.mean(power_trace_target, axis = 0)
    print(np.shape(trainMean))
    print(np.shape(testMean))
    trainStd = np.std(power_trace, axis = 0)
    testStd = np.std(power_trace_target, axis = 0)
    print(np.shape(trainStd))
    print(np.shape(testStd))
    # NOTE(review): a sample position with zero target std gives inf/NaN here.
    coeff=(trainStd/testStd)

    count = max(n, 0)
    if count > len(power_trace_target):
        raise IndexError(
            'n={} exceeds the {} available target traces'.format(n, len(power_trace_target)))
    # One broadcast expression replaces the per-element Python double loop.
    selected = np.asarray(power_trace_target)[:count]
    aligned = (selected - testMean) * coeff + trainMean
    # Keep the list-of-rows return structure callers already consume.
    modified_trace = [list(row) for row in aligned]
    print('[LOG----ZMUVN COMPLETE.............]')

    return modified_trace

def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a window of traces, ``bp`` values and labels from an ``.npz`` archive.

    Args:
        data_path: path of an ``.npz`` archive with ``data``, ``bp`` and
            ``label`` entries.
        start_trace: first trace index (inclusive).
        end_trace: last trace index (exclusive).

    Returns:
        tuple: ``(data[start:end], bp[:, start:end], label[start:end])`` --
        ``bp`` is sliced along its second axis, the other two along their first.
    """
    # Close the archive once the arrays have been read; the arrays stay valid.
    with np.load(data_path) as archive:
        trace_profiling = archive['data']
        bp_profiling = archive['bp']
        skpv_profiling = archive['label']

    window = slice(start_trace, end_trace)
    return (trace_profiling[window], bp_profiling[:, window], skpv_profiling[window])

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load the profiling traces and split them into training and validation parts.

    Traces ``start_trace .. end_trace + 300000`` are loaded from ``data_path``.
    The first ``end_trace - start_trace`` of them form the training part. The
    validation part is selected by the indices stored under ``ids`` in
    ``val_ids.npz`` (or, when ``is_test`` is true, by the traces that follow
    the training part). With ``args.normalize == 1`` every trace is min-max
    normalised before splitting.

    ``trainPortion``, ``xType`` and ``yType`` are accepted for interface
    compatibility and are not used.

    Returns:
        tuple: ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``; the
        ``*_value`` entries are taken from the same label array as ``yTrain``
        / ``yVal``.
    """
    val_num = 300000
    # Read the validation indices and release the archive handle right away.
    with np.load('val_ids.npz') as val_archive:
        val_ids = val_archive['ids']
    print(len(val_ids))
    print(val_ids[:10])
    end_val_trace = end_trace + val_num
    if is_test:
        # NOTE(review): load_meta_trace_file_from_test is not defined or imported
        # in the visible part of this file; this branch would raise NameError.
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)

    print(trace_profiling.shape)

    dataSize = end_trace - start_trace
    print(skpv_profiling[:10])

    # Labels are used as-is (integer classes, no one-hot encoding).
    y_train_skpv = skpv_profiling
    if args.normalize == 1:
        print(trace_profiling.shape)
        trace_profiling = normalize_data_per_trace([trace_profiling])[0]
    xTrain = trace_profiling[:dataSize]
    print(xTrain[0][:10])

    if is_test:
        xVal = trace_profiling[dataSize:end_val_trace]
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        xVal = trace_profiling[val_ids]
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

# ---- Load the attack set and the profiling data -----------------------------
yType = 'skpv'
# Attack traces/labels come from --eval_path; the profiling split comes from
# data.npz (traces 0..200000, non-test mode).
xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args,'data.npz',sKeyNo, trainPortion, xType, yType,False, 0, 200000)

print(xTest_multi[1].shape)
# Truncate the attack set: first axis to --num_key entries, trace axis to
# --num_traces entries. For 'wavebp01' xTest_multi is the pair [traces, bp].
if args.xType == 'wavebp01':
    xTest_multi, yTest_multi = [xTest_multi[0][:args.num_key,:args.num_traces,:], xTest_multi[1][:args.num_key,:2,:args.num_traces,]], yTest_multi[:args.num_key]
else:
    xTest_multi, yTest_multi = xTest_multi[:args.num_key,:args.num_traces,:], yTest_multi[:args.num_key]

print(xTest_multi[1].shape)


#xTest_multi[1] = to_categorical(xTest_multi[1], num_classes=3330)

# ---- Optional preprocessing of the attack traces -----------------------------
# --dwt 1: rebuild every trace from its db1 approximation coefficients only.
if args.dwt == 1:
    xTest_multi = data_dwt(xTest_multi)


# --normalize 1 -> min-max per trace, --normalize 2 -> z-score per trace.
# --cummulative 1 applies the running sum BEFORE normalising,
# --cummulative 2 applies it AFTER normalising; any other value skips it.
if args.normalize == 1:
    if args.cummulative == 1:
        xTest_multi = normalize_data_per_trace(cummulative_transform(xTest_multi))
    elif args.cummulative == 2:
        xTest_multi = cummulative_transform(normalize_data_per_trace(xTest_multi))
    else:
        xTest_multi = normalize_data_per_trace(xTest_multi)
elif args.normalize == 2:
    if args.cummulative == 1:
        xTest_multi = z_normalize_data_per_trace(cummulative_transform(xTest_multi))
    elif args.cummulative == 2:
        xTest_multi = cummulative_transform(z_normalize_data_per_trace(xTest_multi))
    else:
        xTest_multi = z_normalize_data_per_trace(xTest_multi)
    #print(xTrain.shape)
# Append a trailing singleton axis (axis 3) for the non-'wavebp01' inputs.
if args.xType != 'wavebp01':
    xTest_multi = np.expand_dims(xTest_multi, axis = 3)
#test_best = get_meanrank(xTest_multi, yTest_multi, best_model, maxtrc)
print(yTest_multi)
# The triple-quoted block below is a bare string literal holding a disabled
# snippet; it is evaluated as a string and never executed as code.
'''
end_model_path = os.path.join('multi_attack_trained_models/test_1_baseline_none_wave_0_200000_30000_balance_alpha_0.5_50000_index.npy/model_best.keras')
end_model = load_model(end_model_path)
test_end = get_meanrank(xTest_multi, yTest_multi, end_model, maxtrc)
print(test_end)
'''
# Reference values; not read anywhere else in the visible part of this file.
baseline_mr_best = 76.37
baseline_240k_best = 23.1
baseline_160k_best = 36.78

test_key = 1733

# Folder that holds the trained model; the result CSV is written there too.
folder_path = args.train_folder
# NOTE(review): in the 'wavebp01' case xTest_multi is a Python list at this
# point, so `.shape` would raise AttributeError -- confirm that mode is unused.
print(xTest_multi.shape)

def get_info(args, folder_path, xTrain_original, xTest_multi):
    """Load the trained model from ``folder_path`` and evaluate it on the attack set.

    Args:
        args: parsed command-line namespace; ``num_class`` is set to 3329 on it.
        folder_path: directory containing ``model.pth`` (a fully pickled model).
        xTrain_original: profiling traces, used as the reference set for ZMUV.
        xTest_multi: attack traces; indexed as ``[key, trace, sample, 0]`` when
            ZMUV alignment is applied.

    Returns:
        pandas.DataFrame: a single ``Mean_Rank`` column with one row per
        evaluated key, as returned by ``get_meanrank``.
    """
    # Kept so that `args` carries the class count exactly as before.
    args.num_class = 3329
    end_model_path = os.path.join(folder_path, 'model.pth')

    # The checkpoint stores the whole model object, so no model needs to be
    # instantiated first (CNNModel / TimesnetModel only have to be importable).
    # SECURITY: weights_only=False unpickles arbitrary objects -- only load
    # checkpoints from a trusted source.
    model = torch.load(
        end_model_path,
        map_location=None if torch.cuda.is_available() else 'cpu',
        weights_only=False,
    )
    model.eval()

    # Optionally restrict the ZMUV reference set to a saved index subset.
    if args.all_ids != 'None':
        all_ids = np.load(args.all_ids)
        train_data = xTrain_original[all_ids]
    else:
        train_data = xTrain_original

    if args.zmuv == 1:
        print(train_data.shape)
        print(xTest_multi.shape)
        xTest_res = []
        # Align each key's traces to the reference statistics separately.
        for i in range(xTest_multi.shape[0]):
            xTest_perkey = np.array(
                ZMUV(train_data, xTest_multi[i,:,:,0], args.num_traces))
            print(xTest_perkey.shape)
            xTest_res.append(xTest_perkey)

        # Restore the trailing singleton axis removed by the [..., 0] indexing.
        xTest_multi = np.expand_dims(np.array(xTest_res), axis=3)
        print(xTest_multi.shape)

    all_test = get_meanrank(xTest_multi, yTest_multi, model, maxtrc)
    return pd.DataFrame({
        'Mean_Rank': all_test
    })

# Evaluate the trained model and store the per-key mean ranks next to it.
df = get_info(args, folder_path, xTrain_original, xTest_multi)
df.to_csv(
    os.path.join(
        folder_path,
        # eval_path without its last four characters (the file extension)
        'mean_rank_full_{}_{}_zmuv_{}.csv'.format(
            args.eval_path[:-4], args.num_traces, args.zmuv),
    )
)
#np.savez('res_random.npz', rd_multi)
#np.savez('res_uncertain.npz', all_multi)
#np.savez('res_bal.npz', bal_multi)
#np.savez('res_bal_update.npz', bal_multi_notrain)
#np.savez('res_bal_best.npz', bal_multi)
#np.savez('res_bal_update_best.npz', bal_multi_notrain)
