import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from layers.Embed import DataEmbedding
from layers.Conv_Blocks import Inception_Block_V1


def FFT_for_Period(x, k=2):
    """Estimate the ``k`` dominant periods of ``x`` from its FFT amplitudes.

    Args:
        x: tensor of shape ``[B, T, C]``; the FFT is taken along the time axis.
        k: number of periods to return.

    Returns:
        ``(periods, weights)`` where ``periods`` is a NumPy int array of length
        ``k`` (``T // frequency_index``) and ``weights`` is a ``[B, k]`` tensor of
        the channel-averaged amplitudes at the selected frequencies.
    """
    spectrum = torch.abs(torch.fft.rfft(x, dim=1))  # [B, F, C]
    # Average the amplitude over batch and channels to rank frequencies globally.
    amplitude = spectrum.mean(dim=0).mean(dim=-1)
    # Ignore the DC component so a constant offset is never chosen as a period.
    amplitude[0] = 0
    top_idx = torch.topk(amplitude, k).indices.detach().cpu().numpy()
    periods = x.shape[1] // top_idx
    weights = spectrum.mean(dim=-1)[:, top_idx]
    return periods, weights


class TimesBlock(nn.Module):
    """One TimesNet block.

    Folds the 1-D sequence into 2-D tensors using its ``top_k`` dominant FFT
    periods, applies a shared 2-D Inception convolution stack to each, and
    aggregates the results weighted by the softmax of the period amplitudes,
    with a residual connection. Input and output shape: ``[B, T, N]``.
    """

    def __init__(self, configs):
        super(TimesBlock, self).__init__()
        # Kept for compatibility; forward() now uses the actual input length T.
        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.k = configs.top_k
        # parameter-efficient design: one conv stack shared by all k periods
        self.conv = nn.Sequential(
            Inception_Block_V1(configs.d_model, configs.d_ff,
                               num_kernels=configs.num_kernels),
            nn.GELU(),
            Inception_Block_V1(configs.d_ff, configs.d_model,
                               num_kernels=configs.num_kernels)
        )

    def forward(self, x):
        B, T, N = x.size()
        period_list, period_weight = FFT_for_Period(x, self.k)

        res = []
        for i in range(self.k):
            period = int(period_list[i])
            # Pad the time axis up to the next multiple of `period` so it folds
            # evenly; the padding matches x's dtype and device.
            length = -(-T // period) * period  # ceil(T / period) * period
            if length != T:
                padding = x.new_zeros((B, length - T, N))
                out = torch.cat([x, padding], dim=1)
            else:
                out = x
            # reshape 1-D [B, length, N] -> 2-D [B, N, length // period, period]
            out = out.reshape(B, length // period, period,
                              N).permute(0, 3, 1, 2).contiguous()
            # 2D conv: from 1d Variation to 2d Variation
            out = self.conv(out)
            # reshape back to [B, length, N] and drop the padding
            out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
            res.append(out[:, :T, :])
        res = torch.stack(res, dim=-1)  # [B, T, N, k]
        # adaptive aggregation: softmax over the k amplitudes, broadcast over T and N
        period_weight = F.softmax(period_weight, dim=1)  # [B, k]
        res = torch.sum(res * period_weight[:, None, None, :], -1)
        # residual connection
        return res + x


class TimesnetModel(nn.Module):
    """
    Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq

    TimesNet used as a sequence classifier: embed the input, run ``e_layers``
    TimesBlocks (each followed by a shared LayerNorm), then flatten the
    ``[B, seq_len, d_model]`` features and project them to ``num_class`` logits.
    """

    def __init__(self, configs):
        super(TimesnetModel, self).__init__()
        self.configs = configs
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        self.pred_len = configs.pred_len
        blocks = [TimesBlock(configs) for _ in range(configs.e_layers)]
        self.model = nn.ModuleList(blocks)
        self.enc_embedding = DataEmbedding(
            configs.enc_in, configs.d_model, configs.embed, configs.freq, configs.dropout
        )
        self.layer = configs.e_layers
        self.layer_norm = nn.LayerNorm(configs.d_model)
        self.act = F.gelu
        self.dropout = nn.Dropout(configs.dropout)
        flat_dim = configs.d_model * configs.seq_len
        self.projection = nn.Linear(flat_dim, configs.num_class)

    def classification(self, x_enc):
        """Return ``[B, num_class]`` logits for the input batch ``x_enc``."""
        hidden = self.enc_embedding(x_enc, None)  # [B, T, d_model]
        for idx in range(self.layer):
            hidden = self.layer_norm(self.model[idx](hidden))
        # The encoder output carries no non-linearity, so add one before the head.
        hidden = self.dropout(self.act(hidden))
        flat = hidden.reshape(hidden.shape[0], -1)  # (batch_size, seq_length * d_model)
        return self.projection(flat)

    def forward(self, x_enc, mask=None):
        """Classify ``x_enc``; ``mask`` is accepted for API compatibility and ignored."""
        return self.classification(x_enc)  # [B, N]


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import os
from pathlib import Path
import matplotlib.pyplot as plt
import os.path
import sys
import h5py
import math
import gc
import numpy as np
import matplotlib.pyplot as plt
import argparse
from pathlib import Path
from sklearn.cluster import KMeans
import pandas as pd


def _str2bool(value):
    """Parse a command-line boolean such as 'True', 'false', '1' or 'no'."""
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def parse_arguments():
    """Build the argument parser for the TimesNet/CNN side-channel training script.

    Returns:
        An ``argparse.ArgumentParser`` (not yet parsed).
    """
    parser = argparse.ArgumentParser(description='')
    # basic config
    # (The upstream --task_name / --is_training / --model_id / --model options
    #  are intentionally not exposed by this script.)
    parser.add_argument('--seq_len', type=int, default=600, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=48, help='start token length')
    parser.add_argument('--pred_len', type=int, default=0, help='prediction sequence length')
    parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')
    parser.add_argument('--inverse', action='store_true', help='inverse output data', default=False)

    parser.add_argument('--features', type=str, default='S',
                        help='forecasting task, options:[M, S, MS]; M:multivariate predict multivariate, S:univariate predict univariate, MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='location of model checkpoints')

    # model define
    parser.add_argument('--expand', type=int, default=2, help='expansion factor for Mamba')
    parser.add_argument('--d_conv', type=int, default=4, help='conv kernel size for Mamba')
    parser.add_argument('--top_k', type=int, default=5, help='for TimesBlock')
    parser.add_argument('--num_kernels', type=int, default=6, help='for Inception')
    parser.add_argument('--enc_in', type=int, default=1, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=1, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=512, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=2048, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--distil', action='store_false',
                        help='whether to use distilling in encoder, using this argument means not using distilling',
                        default=True)
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--channel_independence', type=int, default=1,
                        help='0: channel dependence 1: channel independence for FreTS model')
    parser.add_argument('--decomp_method', type=str, default='moving_avg',
                        help='method of series decomposition, only support moving_avg or dft_decomp')
    parser.add_argument('--use_norm', type=int, default=1, help='whether to use normalize; True 1 False 0')
    parser.add_argument('--down_sampling_layers', type=int, default=0, help='num of down sampling layers')
    parser.add_argument('--down_sampling_window', type=int, default=1, help='down sampling window size')
    parser.add_argument('--down_sampling_method', type=str, default=None,
                        help='down sampling method, only support avg, max, conv')
    parser.add_argument('--seg_len', type=int, default=96,
                        help='the length of segment-wise iteration of SegRNN')

    # metrics (dtw)
    # `type=bool` would turn any non-empty string (including "False") into True.
    parser.add_argument('--use_dtw', type=_str2bool, default=False,
                        help='the controller of using dtw metric (dtw is time consuming, not suggested unless necessary)')

    # TimeXer
    parser.add_argument('--patch_len', type=int, default=16, help='patch length')

    # experiment / training options
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random, minmax or uncertainty')
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--xType', type=str, help='number of ciphertext')
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int, help='end trace (exclusive)')
    parser.add_argument('--batch_size', type=int, help='batch_size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)
    parser.add_argument('--trained_model_path', type=str)
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)
    parser.add_argument('--all_ids', type=str)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--num_sample', type=int, help='number of training traces to sample', default=5)
    parser.add_argument('--eval_interval', type=int, help='save a checkpoint every N epochs', default=1)
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--mlp', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--use_bp', type=int, default=0)
    parser.add_argument('--lr', type=float, default=0.00001)

    parser.add_argument('--node', type=int, default=256)
    parser.add_argument('--n_layers', type=int, default=6)
    parser.add_argument('--batch_norm', type=int, default=1)
    parser.add_argument('--num_class', type=int, default=3329)
    parser.add_argument('--dropout_rate', type=float, default=0.2)

    return parser

# Parse the command line once at import time. ``args`` is a module-level global
# that later code reads directly (e.g. random_sampling and the main section).
parser = parse_arguments()
args = parser.parse_args()



def set_seeds(seed):
    """Seed the random number generators used by this script.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (``torch.manual_seed`` seeds the CPU and all CUDA devices), so that model
    initialisation and DataLoader shuffling are reproducible as well as the
    NumPy-based trace sampling.

    Args:
        seed: integer seed.
    """
    import random

    # Only affects child processes started after this point; the current
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

# Seed the RNGs from the command line before any data sampling or model creation.
set_seeds(args.seed)

# Value ranges of the targets: bp (ciphertext input), skpv (secret key
# coefficient) and fqmul(skpv, bp). The class counts below are derived from them.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
tracelen = 600
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3; which subkeys these are is decided by code on the m4 (NOT by code on the PC)
work = 'train' #'train'  #'attack'
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',\
'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5']#,\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5']

trained_model_path = args.trained_model_path
# Single .npz archive read by load_meta_trace_files (keys 'data', 'bp', 'label').
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232


# training parameters
train_batch_size = args.batch_size#100#150#200#250#500#640 #80 for mars45 #170 for mars56
period = 8 #8
maxEpochs = args.num_epoch#3072#2048#1536#1280#1024#512#256 #1536
attack_byModel_fileNo = int(attack_byModel_epNo/period)
N_TRACE = 20000
Threshold_Save = 200

#model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True    # False

# training data type
xType = args.xType  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv' 
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
# NOTE: this name is reassigned to the torch device ('cuda'/'cpu') in the main
# section further below.
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    #hyper-parameter contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'    #'20kDatax25'
dataFile_attack = '100kDatax1_test'#'skvp0_0_700points100kDatax1attack' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
model_input_type = '_in[[][]]_tf2' #'in[]_tf2' #'[[][]]_tf2'
#data_type = dataFile_train + model_input_type
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
# Experiment-specific output folder name built from the command-line options;
# checkpoints, losses.csv and all_ids.npy are written under it.
save_path = '{}_{}_{}_{}_{}_{}_{}_{}_{}_{}_lr_{}_mlp_{}'.format(args.name ,args.train_type, args.sampling ,xType, args.start_trace, args.end_trace, args.num_sample, args.seed, args.normalize, args.cummulative, args.lr, args.mlp)
print(save_path)
#database_folder_train = os.path.join('trained_models', save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Create the (legacy) model/log folders if missing; os.mkdir is non-recursive,
# so DLmodel_folder must be created first.
if os.path.isdir(DLmodel_folder) == False:
    os.mkdir(DLmodel_folder)
if os.path.isdir(modelLogFolder) == False:
    os.mkdir(modelLogFolder)
if os.path.isdir(logTrainedModel_byFile_folder) == False:
    os.mkdir(logTrainedModel_byFile_folder)
if os.path.isdir(logTrainedModel_byEp_folder) == False:
    os.mkdir(logTrainedModel_byEp_folder)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# NOTE(review): none of the subMods_* tables below is read by CNNModel or
# TimesnetModel in this file; they appear to describe a configurable model
# variant (cf. MLmodel_detail above) and are kept for reference — confirm
# before relying on them.
# Input BatchNormalization for each PoI size
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# matrix showing number of nodes in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   512,    256,    128,    64,     0,      0]
# Convolutional filter sizes
# matrix showing filter sizes in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [   3,      3,      3,      3,      0,      0] # subModel0

###############################################
# Pooling size in convolutional layers
# matrix showing MaxPooling sizes in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   2,      2,      2,      2,      0,      0] # subModel0

# Pooling stride in convolutional layers
# matrix showing MaxPooling stride in each convolutional layer in each PoI length
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [   3,      3,      3,      3,      0,      0] # subModel0

# BatchNormalization in convolutional layers
# matrix showing BatchNormalization condition in each convolutional layer in each PoI length
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [   1,      1,      1,      1,      0,      0] # subModel0
# Dropout in convolutional layers
# matrix showing Dropout value in each convolutional layer in each PoI length
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   0,      0,      0,      0,      0,      0] # subModel0

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten Convolutional feature map before Fully connected
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected for convolutional value before adding Plaintext
# matrix showing fully-connected condition before adding Plaintext
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [   1024,   512,    256,    128,    0,      0] # subModel0
# BatchNormalization for fully-connected of convolutional value before adding Plaintext
# matrix showing BatchNormalization for fully-connected condition before adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   1,      1,      1,      1,      0,      0] # subModel0

# Dropout for fully-connected of convolutional value before adding Plaintext
# matrix showing Dropout for fully-connected condition before adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [   0.2,    0,      0.2,    0,      0,      0] # subModel0

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext adding here
# NOTE(review): noBPbranch is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01', and no branch matches any other xType (including the default
# None), which leaves subMods_Pext undefined. Neither name is read later in this
# file; noBPbranch is reassigned in the main section.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected for convolutional value after adding Plaintext
# matrix showing fully-connected condition after adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]    # subModel5
# BatchNormalization for fully-connected of convolutional value after adding Plaintext
# matrix showing BatchNormalization for fully-connected condition after adding Plaintext
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5
# Dropout for fully-connected of convolutional value after adding Plaintext
# matrix showing Dropout for fully-connected condition after adding Plaintext
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5

# Softmax for each sub-model if available
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [  0,      0,      0,      0,      0,      0]

# Join the BP branches only when plaintext/ciphertext inputs are used.
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]  

subMods_join_FCs =  [  1024,   1024,   512,    256,    128,    0]
# BatchNormalization for fully-connected of convolutional value after joining PoIs
# matrix showing BatchNormalization for fully-connected condition after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [  1,      1,      1,      1,      1,      0]
# Dropout for fully-connected of convolutional value after joining PoIs
# matrix showing Dropout for fully-connected condition after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [  0.2,        0.2,        0,      0.1,        0,      0]

# Softmax for joined-model if available
subMods_join_classification =   [   1]

class CNNModel(nn.Module):
    """1-D CNN + MLP classifier for raw traces.

    Four stages of Conv1d(kernel 3, padding 'same') -> ReLU -> MaxPool1d(3, stride 3)
    -> BatchNorm1d (512/256/128/64 channels), then ten fully-connected layers
    with ReLU, BatchNorm and dropout, ending in raw logits.

    Args:
        trace_len: samples per input trace. The default of 600 gives the
            original flattened size of 64 * 7 = 448.
        num_classes: number of output logits (default 3329).

    Input shape ``[B, trace_len]``; output shape ``[B, num_classes]``.

    Raises:
        ValueError: if ``trace_len`` is too short to survive four pooling stages.
    """

    def __init__(self, trace_len=600, num_classes=3329):
        super(CNNModel, self).__init__()

        # Sequence length after the four MaxPool1d(kernel_size=3, stride=3)
        # stages ('same' convolutions leave the length unchanged).
        pooled_len = trace_len
        for _ in range(4):
            pooled_len = (pooled_len - 3) // 3 + 1
        if pooled_len < 1:
            raise ValueError(
                'trace_len={} is too short for four pooling stages (need >= 81)'.format(trace_len))

        # Convolutional layers
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=512, kernel_size=3, padding='same')
        self.pool1 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn1 = nn.BatchNorm1d(512)

        self.conv2 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding='same')
        self.pool2 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn2 = nn.BatchNorm1d(256)

        self.conv3 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding='same')
        self.pool3 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn3 = nn.BatchNorm1d(128)

        self.conv4 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding='same')
        self.pool4 = nn.MaxPool1d(kernel_size=3, stride=3)
        self.bn4 = nn.BatchNorm1d(64)

        # Fully connected layers
        self.fc1 = nn.Linear(64 * pooled_len, 1024)
        self.bn5 = nn.BatchNorm1d(1024)
        self.dropout1 = nn.Dropout(0.2)

        self.fc2 = nn.Linear(1024, 512)
        self.bn6 = nn.BatchNorm1d(512)

        self.fc3 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(0.2)

        self.fc4 = nn.Linear(256, 128)
        self.bn8 = nn.BatchNorm1d(128)

        self.fc5 = nn.Linear(128, 1024)
        self.bn9 = nn.BatchNorm1d(1024)
        self.dropout3 = nn.Dropout(0.2)

        self.fc6 = nn.Linear(1024, 1024)
        self.bn10 = nn.BatchNorm1d(1024)
        self.dropout4 = nn.Dropout(0.2)

        self.fc7 = nn.Linear(1024, 512)
        self.bn11 = nn.BatchNorm1d(512)

        self.fc8 = nn.Linear(512, 256)
        self.bn12 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.2)

        self.fc9 = nn.Linear(256, 128)
        self.bn13 = nn.BatchNorm1d(128)

        self.fc10 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map ``[B, trace_len]`` traces to ``[B, num_classes]`` logits."""
        # Add channel dimension: [B, L] -> [B, 1, L]
        x = x.unsqueeze(1)

        # Convolutional layers
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = self.bn1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        x = self.bn2(x)

        x = F.relu(self.conv3(x))
        x = self.pool3(x)
        x = self.bn3(x)

        x = F.relu(self.conv4(x))
        x = self.pool4(x)
        x = self.bn4(x)

        # Flatten
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = F.relu(self.fc1(x))
        x = self.bn5(x)
        x = self.dropout1(x)

        x = F.relu(self.fc2(x))
        x = self.bn6(x)

        x = F.relu(self.fc3(x))
        x = self.bn7(x)
        x = self.dropout2(x)

        x = F.relu(self.fc4(x))
        x = self.bn8(x)

        x = F.relu(self.fc5(x))
        x = self.bn9(x)
        x = self.dropout3(x)

        x = F.relu(self.fc6(x))
        x = self.bn10(x)
        x = self.dropout4(x)

        x = F.relu(self.fc7(x))
        x = self.bn11(x)

        x = F.relu(self.fc8(x))
        x = self.bn12(x)
        x = self.dropout5(x)

        x = F.relu(self.fc9(x))
        x = self.bn13(x)

        x = self.fc10(x)

        return x

import time

def _evaluate_loss(model, loader, criterion, device):
    """Return the per-batch losses of ``model`` on ``loader`` without updating weights."""
    batch_losses = []
    model.eval()
    with torch.no_grad():
        for trace_data, target in loader:
            output = model(trace_data.to(device).unsqueeze(-1))
            batch_losses.append(criterion(output, target.long().to(device)).item())
    model.train()
    return batch_losses


# Training loop (example)
def train(args, save_folder, model, train_loader, test_loader, optimizer, criterion, epochs=10, validate=False):
    """Train ``model``, checkpoint it periodically, and write ``losses.csv``.

    Args:
        args: namespace providing ``eval_interval``. A checkpoint
            ``model_<epoch>.pt`` is saved whenever ``epoch % eval_interval == 0``.
        save_folder: directory that receives the checkpoints and ``losses.csv``.
        model: module called on ``trace_data.unsqueeze(-1)`` that returns logits.
        train_loader: iterable of ``(trace_data, target)`` batches.
        test_loader: validation batches in the same format. It is only used
            when ``validate`` is True.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss function taking ``(logits, target)``.
        epochs: number of passes over ``train_loader``.
        validate: when True, compute the mean validation loss each epoch under
            ``torch.no_grad``. Defaults to False, which matches the previous
            behaviour where validation was disabled. In that case the
            "Validation Loss" column is NaN.

    Returns:
        The trained ``model`` (the same object).
    """
    start_time = time.time()
    # Send batches to wherever the model's parameters live instead of relying
    # on a module-level ``device`` global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    losses = []
    for epoch in range(epochs):
        train_loss = []
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.to(device).unsqueeze(-1)  # [B, T] -> [B, T, 1]
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')
            train_loss.append(loss.item())

        val_loss = _evaluate_loss(model, test_loader, criterion, device) if validate else []

        if epoch % args.eval_interval == 0:
            ckpt_path = os.path.join(save_folder, 'model_{}.pt'.format(epoch))
            torch.save(model.state_dict(), ckpt_path)

        # Record NaN explicitly for empty lists instead of np.mean([]) (which warns).
        losses.append({
            "Epoch": epoch + 1,
            "Train Loss": np.mean(train_loss) if train_loss else float('nan'),
            "Validation Loss": np.mean(val_loss) if val_loss else float('nan'),
        })

    df = pd.DataFrame(losses)
    df.to_csv(os.path.join(save_folder, "losses.csv"), index=False)
    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

def load_multi_attack(data_path):
    """Load traces and labels from an ``.npz`` archive.

    Args:
        data_path: path to an archive containing arrays ``'data'`` and ``'label'``.

    Returns:
        ``(data, labels)`` as in-memory NumPy arrays.
    """
    # Use the NpzFile as a context manager so the underlying file is closed.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a window of traces, ciphertext inputs and key labels from an ``.npz`` file.

    Args:
        data_path: archive containing ``'data'`` (traces along axis 0),
            ``'bp'`` (trace index along axis 1) and ``'label'`` arrays.
        start_trace: first trace index (inclusive).
        end_trace: last trace index (exclusive).

    Returns:
        ``(traces, bp, labels)`` sliced to ``[start_trace:end_trace]``
        (``bp`` is sliced along its second axis).
    """
    # Use the NpzFile as a context manager so the underlying file is closed.
    with np.load(data_path) as data:
        trace_profiling = data['data']
        bp_profiling = data['bp']
        skpv_profiling = data['label']

    return (trace_profiling[start_trace:end_trace],
            bp_profiling[:, start_trace:end_trace],
            skpv_profiling[start_trace:end_trace])

def random_sampling(data, num_sample, seed=None, x_type=None):
    """Draw ``num_sample`` distinct random indices into ``data``.

    The global NumPy RNG is re-seeded first, so the same seed always yields
    the same indices.

    Args:
        data: indexable dataset. When ``x_type == 'wavebp01'`` it is treated as
            a sequence and only ``data[0]`` determines the number of samples.
        num_sample: number of indices to draw (without replacement). It must
            not exceed the dataset length.
        seed: RNG seed. Defaults to the command-line ``args.seed``.
        x_type: input type. Defaults to the command-line ``args.xType``.

    Returns:
        NumPy array of ``num_sample`` unique indices.
    """
    if seed is None:
        seed = args.seed
    if x_type is None:
        x_type = args.xType
    np.random.seed(seed)
    if x_type == 'wavebp01':
        data = data[0]
    rand_ids = np.random.choice(len(data), num_sample, replace=False)
    print(len(rand_ids))
    print('---')
    return rand_ids

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` to the range [0, 1].

    A constant series (max == min) is mapped to all zeros rather than to NaN
    from a 0/0 division.
    """
    lo = timeseries.min()
    span = timeseries.max() - lo
    if span == 0:
        # (timeseries - lo) is already all zeros; avoid dividing by zero.
        span = 1
    return (timeseries - lo) / span

def normalize_data_per_trace(data):
    """Min-max normalise every trace (element along axis 0) of ``data`` in place.

    Each trace is scaled independently with ``normalize``. The same array is
    returned for convenience.
    """
    print(data.shape)
    for idx, trace in enumerate(data):
        data[idx] = normalize(trace)
    return data

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load profiling traces and split them into training and validation sets.

    Training data is the first ``end_trace - start_trace`` traces of the loaded
    window. Validation data is selected by the indices stored in
    ``val_ids.npz`` (key ``'ids'``) when ``is_test`` is False. The loaded
    window extends ``val_num`` traces past ``end_trace`` to leave room for them.

    Args:
        args: namespace; ``args.normalize == 1`` enables per-trace min-max scaling.
        data_path: ``.npz`` archive passed to the loader.
        sKeyNo: sub-key number (only used by the ``is_test`` loader).
        trainPortion, xType, yType: accepted for interface compatibility; not used.
        is_test: select the test-file loader and a contiguous validation slice.
        start_trace, end_trace: training window ``[start_trace, end_trace)``.

    Returns:
        ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``. Labels are
        integer class indices (no one-hot encoding), so the ``*_value`` arrays
        equal their ``y*`` counterparts.
    """
    val_num = 300000
    # Close the archive once the index array has been read.
    with np.load('val_ids.npz') as val_file:
        val_ids = val_file['ids']
    print(len(val_ids))
    print(val_ids[:10])
    end_val_trace = end_trace + val_num
    if is_test:
        # NOTE(review): load_meta_trace_file_from_test is not defined in this
        # file; this branch raises NameError unless it is provided elsewhere.
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_file_from_test(data_path, sKeyNo)
    else:
        (trace_profiling, bp_profiling, skpv_profiling) = load_meta_trace_files(data_path, start_trace, end_val_trace)

    print(trace_profiling.shape)

    # Number of training traces; indices below are relative to the loaded window.
    dataSize = end_trace - start_trace
    print(skpv_profiling[:10])

    y_train_skpv = skpv_profiling
    if args.normalize == 1:
        trace_profiling = normalize_data_per_trace(trace_profiling)
    xTrain = trace_profiling[:dataSize]
    print(xTrain[0][:10])

    if is_test:
        # NOTE(review): dataSize is relative to the loaded window while
        # end_val_trace is an absolute trace index; confirm these bounds.
        xVal = trace_profiling[dataSize:end_val_trace]
        yVal = y_train_skpv[dataSize:end_val_trace]
        yVal_value = skpv_profiling[dataSize:end_val_trace]
    else:
        xVal = trace_profiling[val_ids]
        yVal = y_train_skpv[val_ids]
        yVal_value = skpv_profiling[val_ids]

    yTrain = y_train_skpv[:dataSize]
    yTrain_value = skpv_profiling[:dataSize]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

# class to represent dataset
class SCADataset():
    """Map-style dataset of ``(trace, label)`` pairs usable with a torch DataLoader.

    ``data`` is a pair ``[traces, labels]``. Traces are converted to float32
    once, at construction time. Labels are stored as given.
    """

    def __init__(self, data):
        traces, labels = data[0], data[1]
        self.x = traces.astype(np.float32)
        self.y = labels
        self.n_samples = traces.shape[0]

    def __getitem__(self, index):
        """Return the ``index``-th ``(trace, label)`` pair."""
        sample = self.x[index]
        target = self.y[index]
        return sample, target

    def __len__(self):
        """Return the number of samples (length of the trace array's first axis)."""
        return self.n_samples

#MAIN

xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args,data_path,sKeyNo, trainPortion, xType, yType,False, args.start_trace, args.end_trace)
print(xTrain_original.shape)
print(yTrain_original.shape)
print(xVal.shape)
print(yVal.shape)

# Leftover settings from the CNN/BP experiments; nothing below reads them.
xType = 'wavebp'
noConv1Dbranch = 1
noBPbranch = 2
noLayers = 6
tracelen = 600
NumBPinput = 3330
classes = 3329
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
#model = CNNModel()
model = TimesnetModel(args)
# Use the learning rate from the command line. It is also encoded in
# save_path, so a hard-coded value would mislabel the run.
optimizer = optim.RMSprop(model.parameters(), lr=args.lr)
criterion = nn.CrossEntropyLoss()
# The profiling data was already loaded by create_training_data_optimize above,
# so it is not re-loaded here (that only doubled memory use and went unused).
num_sample = args.num_sample
# Sample a random training subset of the loaded traces and record the chosen
# indices next to the checkpoints.
all_ids = random_sampling(xTrain_original, num_sample = args.num_sample)
np.save(os.path.join(database_folder_train,'all_ids.npy'), all_ids)
train_data = [xTrain_original[all_ids], yTrain_original[all_ids]]
test_data = [xVal, yVal]
SCAdataset = SCADataset(train_data)
SCAdataset_val = SCADataset(test_data)
train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(SCAdataset_val, batch_size=args.batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
print(device)
model = model.to(device)
model = train(args, database_folder_train, model, train_loader, val_loader, optimizer, criterion, epochs=args.num_epoch)