import argparse
import gc
import math
import os
import os.path
import sys
import time
from pathlib import Path

import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, TensorDataset

def parse_arguments():
    """Build the command-line parser for this script.

    The parser is returned un-parsed so the caller decides when (and on which
    argument list) to call ``parse_args``.

    Returns:
        argparse.ArgumentParser: parser with every option of the script
        registered. Flag names, types and defaults are part of the script's
        interface; only the help texts are descriptive.
    """
    parser = argparse.ArgumentParser(description='')

    # Experiment selection
    parser.add_argument('--train_type', type=str, help='baseline or active')
    parser.add_argument('--sampling', type=str, help='random or uncertainty')
    parser.add_argument('--xType', type=str,
                        help='model input type: wave, wavebp0, wavebp1, wavebp01, '
                             'wavebp01next0 or wavebp01next01')
    parser.add_argument('--start_trace', type=int, help='start trace')
    parser.add_argument('--end_trace', type=int, help='end trace')

    # Optimisation
    parser.add_argument('--batch_size', type=int, help='training batch size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)

    # Paths
    parser.add_argument('--trained_model_path', type=str)
    parser.add_argument('--eval_path', type=str)
    parser.add_argument('--all_ids', type=str)

    # Iteration / sampling schedule
    parser.add_argument('--num_iteration', type=int, help='iteration_num', default=5)
    parser.add_argument('--num_sample', type=int, default=100)
    parser.add_argument('--schedule_iteration', type=int, nargs='+', help='when to train with sampling data')
    parser.add_argument('--resume_it', type=int)
    parser.add_argument('--eval_interval', type=int,
                        help='interval (in epochs) between saved checkpoints', default=1)
    parser.add_argument('--update_sampling', type=str,
                        help='sampling update setting (part of the output folder name)')
    parser.add_argument('--subtrain_interval', type=int, help='sub-training interval', default=1)
    parser.add_argument('--name', type=str, help='experiment name', default='test')
    parser.add_argument('--medoids_path', type=str)
    parser.add_argument('--num_trace', type=int, help='number of traces', default=5)
    parser.add_argument('--start_key', type=int, help='start key', default=0)
    parser.add_argument('--end_key', type=int, help='end key', default=300)
    parser.add_argument('--sim_metric', type=str)
    parser.add_argument('--sampling_file', type=str)

    # Model update / loss options
    parser.add_argument('--update_type', type=str, default='transfer', help='Choose transfer, muzv, ada, mmd')
    parser.add_argument('--union_type', type=str, default='none')
    parser.add_argument('--loss_type', type=str, default='none')
    parser.add_argument('--loss_alpha', type=float, default=0.7)
    parser.add_argument('--normalize', type=int, default=0)
    parser.add_argument('--transformer', type=int, default=0)
    parser.add_argument('--cummulative', type=int, default=0)
    parser.add_argument('--seed', type=int, default=2024)
    parser.add_argument('--add_num', type=int, default=0)

    return parser

# Parse the command line once, at import time. `args` is read as a module-level
# global by the configuration statements that follow.
parser = parse_arguments()
args = parser.parse_args()

def set_seeds(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random``, NumPy and PyTorch so that weight initialisation,
    dropout and data shuffling are repeatable for a given ``seed``.

    Args:
        seed: integer seed value.
    """
    import random  # local import: `random` is not imported at module level

    # Only inherited by child processes; the hash seed of the running
    # interpreter is fixed at start-up and is not changed by this assignment.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # Seeds the torch generators (CPU and, when present, CUDA devices).
    torch.manual_seed(seed)

# Apply the --seed command-line value before any other work is done.
set_seeds(args.seed)

# Value ranges given as [min, max]; the class counts below are computed as
# max - min + 1, i.e. both ends are treated as inclusive.
bp_range = [0, 3329]
skpv_range = [0, 3328]
fqmul_range = [-1828, 1664]
# Presumably the number of samples per trace (the data file names end in
# "600samples") -- TODO confirm. The network classes in this file hard-code
# fc1 with 448 inputs, which is 64 channels x 7 samples left after four
# stride-3 poolings of a 600-sample input.
tracelen = 600
NumFQMULclasses = fqmul_range[1] - fqmul_range[0] + 1;  # number of classes for fqmul(skpv, bp)
NumSKPVclasses = skpv_range[1] - skpv_range[0] + 1;     # number of classes for skpv
NumBPinput = bp_range[1] - bp_range[0] + 1;             # number of input for bp (ciphertext)
# Both evaluate to 3329, the same value as the hard-coded output width of the
# final linear layer (fc10) in the network classes of this file.
noClasses = NumSKPVclasses
noHypoKeys = NumSKPVclasses
sKeyNo = 0  # Note: sKeyNo is in range 0 to 3 and which subkeys are they are decided by code in m4 (NOT by code in PC)
work = 'train' #'train'  #'attack'
# HDF5 training files; the remaining 100k-trace files are commented out below.
training_file_list = ['Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data000000to099999_600samples.h5',\
'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data100000to199999_600samples.h5']
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data200000to299999_600samples.h5']#,\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data300000to399999_600samples.h5',\
#'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1_100kDatax5_h5_data400000to499999_600samples.h5']

# Taken verbatim from --trained_model_path (None when the flag is omitted).
trained_model_path = args.trained_model_path
data_path = 'data.npz'
nruns_default = 10
maxtrc_default = 200
testPortion = 1
attack_byModel_epNo = 232


# training parameters
train_batch_size = args.batch_size#100#150#200#250#500#640 #80 for mars45 #170 for mars56
period = 8 #8
maxEpochs = args.num_epoch#3072#2048#1536#1280#1024#512#256 #1536
# Epoch number divided by `period`, truncated to an int (232 / 8 -> 29).
attack_byModel_fileNo = int(attack_byModel_epNo/period)
N_TRACE = 20000
Threshold_Save = 200

#model hyper-parameters
noConv1Dbranch = 1
noLayers = 6    # if newly train
noClassificationLayer = 1
GPU_clear = True    # False

# training data type
# xType selects which plaintext-extension table is chosen further down; it is
# None when --xType is not given on the command line.
xType = args.xType  #'wave' #'wavebp0' #'wavebp1' #'wavebp01' #'wavebp01next0' #'wavebp01next01'
yType = 'skpv'    #'fqmul0' #'fqmul1' #'skpv'
trainPortion = 0.8

# Database and logs for model and training progress (epochs)
attackModel = 'Kyber512_indcpa_dec_poly_frombytes_mul_skpv0_1_bp0_1'
# NOTE(review): here `device` is a plain label used to build folder names.
# Code further down (e.g. train_ada) passes a global named `device` to
# tensor `.to(device)`; this string is not a torch device, so `device` is
# presumably reassigned later in the file -- confirm.
device = 'm4_CWLite'
attackModel_dev = attackModel + '_' + device
attackModel_dev_folder = '../' + attackModel_dev + '/'

MLmodelStruct = '4C4FC_2BP4FC4FC_J4FCSM'
#MLmodel_detail = '3C[512_128_64]_2BP4FC[1024_512_256_128]4FC[1024_512_256_128]_J4FC[1024_512_256_128]SM'
MLmodel_detail = '4C/512_256_128_64/_2BP4FC/1024_512_256_128/4FC/1024_512_256_128/_J4FC/1024_512_256_128/SM'

hyper_ver = 'hy0001010101_skpv0'    #hyper-parameter contains 5 groups: Conv1D, FC for Conv1D, BP0, BP1, FC for joined BPs
#dataFile_train = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
dataFile_train_folder = '100kDatax5_train'#'skvp0_0_700points100kDatax5train' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'    #'20kDatax25'
dataFile_attack = '100kDatax1_test'#'skvp0_0_700points100kDatax1attack' #'skvp0_0_100kDatax5' #'skvp0_0_100kDatax1_attack'  #'20kDatax25'
model_input_type = '_in[[][]]_tf2' #'in[]_tf2' #'[[][]]_tf2'
#data_type = dataFile_train + model_input_type
data_type = '100kDataxN' + str(len(training_file_list)) + model_input_type
#database_folder_train = attackModel_dev_folder + attackModel + '_' + dataFile_train_folder + '_h5/'
# Run-specific folder name assembled from the command-line arguments; options
# that were not supplied appear as the literal text "None".
save_path = 'Phase_3_{}_{}_{}_{}_{}_{}_{}_{}_key_{}_{}_trace_{}_{}_{}_loss_{}'.format(args.name ,args.train_type, xType, args.loss_type , args.start_trace, args.end_trace, args.update_sampling, args.num_sample, args.start_key, args.end_key, args.num_trace, args.union_type, args.update_type, args.loss_alpha)
print(save_path)
#database_folder_train = os.path.join('trained_models', save_path)
# Created immediately (with parents) so later code can write into it.
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
database_folder_attack = attackModel_dev_folder + attackModel + '_' + dataFile_attack + '_h5/'
logFilename = MLmodelStruct + '_' + hyper_ver
DLmodel_name = logFilename
#DLmodel_folder = attackModel_dev_folder + logFilename + '_' + data_type + '/'
DLmodel_folder = 'models/'
modelLogFolder = DLmodel_folder + 'log' + DLmodel_name + '/'
logTrainedModel_byFile_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byDataFile/'
#logTrainedModel_byEp_folder = DLmodel_folder + 'trained' + DLmodel_name + '_byEpoch/'
# Same path as the by-file folder, so the last existence check below is
# always satisfied by the one before it.
logTrainedModel_byEp_folder = logTrainedModel_byFile_folder
attackLogFolder = DLmodel_folder + 'log' + DLmodel_name + '_attack/'
# Create the log/model folders relative to the current working directory.
# The parent ('models/') is created first because os.mkdir makes one level only.
if os.path.isdir(DLmodel_folder) == False:
    os.mkdir(DLmodel_folder)
if os.path.isdir(modelLogFolder) == False:
    os.mkdir(modelLogFolder)
if os.path.isdir(logTrainedModel_byFile_folder) == False:
    os.mkdir(logTrainedModel_byFile_folder)
if os.path.isdir(logTrainedModel_byEp_folder) == False:
    os.mkdir(logTrainedModel_byEp_folder)
print('DLmodel_folder =', DLmodel_folder)
print('modelLogFolder =', modelLogFolder)
print('logTrainedModel_byFile_folder =', logTrainedModel_byFile_folder)
print('logTrainedModel_byEp_folder =', logTrainedModel_byEp_folder)


################################################################################################
####################################### MODELS STRUCTURE #######################################
################################################################################################
# NOTE(review): the network classes visible in this file (CNNModel, RevGrad)
# hard-code their layer sizes and do not read the subMods_* tables below;
# confirm whether these tables are still consumed elsewhere.
# Input BatchNormalization for each PoI size
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_inputBNorms =   [   1,      0,      0,      0,      0,      0]
###################### MULTI CONVOLUTIONAL-SIZE CONVOLUTION ######################
# Convolutional nodes
# list giving the number of nodes in each convolutional layer (0 = layer unused)
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_NoConvNodes =   [   512,    256,    128,    64,     0,      0]
# Convolutional filter sizes
# list giving the filter size of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convKernelSizes = [   3,      3,      3,      3,      0,      0] # subModel0

###############################################
# Pooling size in convolutional layers
# list giving the MaxPooling size of each convolutional layer
# NOTE(review): listed as 2 here, whereas CNNModel/RevGrad use
# MaxPool1d(kernel_size=3, stride=3).
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolSizes = [   2,      2,      2,      2,      0,      0] # subModel0

# Pooling stride in convolutional layers
# list giving the MaxPooling stride of each convolutional layer
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_convPoolStrides = [   3,      3,      3,      3,      0,      0] # subModel0

# BatchNormalization in convolutional layers
# list giving the BatchNormalization flag of each convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convBNorms = [   1,      1,      1,      1,      0,      0] # subModel0
# Dropout in convolutional layers
# list giving the Dropout value of each convolutional layer
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_convDrops = [   0,      0,      0,      0,      0,      0] # subModel0

###################### MULTI CONVOLUTIONAL-SIZE FULLY-CONNECTED ######################
# Flatten Convolutional feature map before Fully connected
#                           subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_convFeatFlat = [    1,      0,      0,      0,      0,      0]
# Fully-connected for convolutional value before adding Plaintext
# list giving the width of each fully-connected layer before adding Plaintext
#                   layer0  layer1  layer2  layer3  layer4  layer5
subMods_FCs = [   1024,   512,    256,    128,    0,      0] # subModel0
# BatchNormalization for fully-connected of convolutional value before adding Plaintext
# list giving the BatchNormalization flag of each of those fully-connected layers
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_BNorms = [   1,      1,      1,      1,      0,      0] # subModel0

# Dropout for fully-connected of convolutional value before adding Plaintext
# list giving the Dropout value of each of those fully-connected layers
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_FC_Drops = [   0.2,    0,      0.2,    0,      0,      0] # subModel0

###################### MULTI_CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENSION ######################
# Plaintext adding here
# Select the plaintext-extension flags from the --xType value.
# NOTE(review): `noBPbranch` is only assigned for 'wave', 'wavebp01' and
# 'wavebp01next01'; for the other listed values it stays undefined, and for any
# unlisted value (including None when --xType is omitted) `subMods_Pext` is
# undefined as well. Confirm whether later code reads them.
if xType == 'wave':
    noBPbranch = 0
    #                       sub0    sub1    sub2    sub3    sub4    sub5
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0] # conv1D branch 0
elif xType == 'wavebp1':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01':
    noBPbranch = 2
    subMods_Pext =  [  1,      1,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next0':
    subMods_Pext =  [  0,      0,      0,      0,      0,      0]  # conv1D branch 0
elif xType == 'wavebp01next01':
    noBPbranch = 4
    subMods_Pext =  [  1,      1,      1,      1,      0,      0]  # conv1D branch 0


###################### (MULTI CONVOLUTIONAL-SIZE + PLAINTEXT-EXTENDED) FULLY-CONNECTED ######################
# Fully-connected for convolutional value after adding Plaintext
# matrix (one row per sub-model) giving fully-connected layer widths after adding Plaintext
#                       layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FCs = [   [   1024,   1024,   512,    256,    128,    0], # subModel0
                        [   1024,   1024,   512,    256,    128,    0], # subModel1
                        [   1024,   1024,   512,    256,    128,    0], # subModel2
                        [   1024,   1024,   512,    256,    128,    0], # subModel3
                        [   0,      0,      0,      0,      0,      0], # subModel4
                        [   0,      0,      0,      0,      0,      0]]    # subModel5
# BatchNormalization for fully-connected of convolutional value after adding Plaintext
# matrix (one row per sub-model) giving the BatchNormalization flag of each of those layers
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_BNorms = [ [   1,      1,      1,      1,      1,      0], # subModel0
                            [   1,      1,      1,      1,      1,      0], # subModel1
                            [   1,      1,      1,      1,      1,      0], # subModel2
                            [   1,      1,      1,      1,      1,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5
# Dropout for fully-connected of convolutional value after adding Plaintext
# matrix (one row per sub-model) giving the Dropout value of each of those layers
#                               layer0  layer1  layer2  layer3  layer4  layer5
subMods_Pext_FC_Drops = [  [   0.2,    0.2,    0,      0.1,    0,      0], # subModel0
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel1
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel2
                            [   0.2,    0.2,    0,      0.1,    0,      0], # subModel3
                            [   0,      0,      0,      0,      0,      0], # subModel4
                            [   0,      0,      0,      0,      0,      0]]    # subModel5

# Softmax for each sub-model if available
#                               subMod0 subMod1 subMod2 subMod3 subMod4 subMod5
subMods_classification =    [  0,      0,      0,      0,      0,      0]

# Join flag: 0 only for the plain 'wave' input, 1 for every other xType value.
if xType == 'wave':
    subMods_join =  [   0]
else:
    subMods_join =  [   1]

# Widths of the fully-connected layers after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FCs =  [  1024,   1024,   512,    256,    128,    0]
# BatchNormalization for fully-connected of convolutional value after joining PoIs
# list giving the BatchNormalization flag of each fully-connected layer after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_BNorms =    [  1,      1,      1,      1,      1,      0]
# Dropout for fully-connected of convolutional value after joining PoIs
# list giving the Dropout value of each fully-connected layer after joining PoIs
#                           layer0  layer1  layer2  layer3  layer4  layer5
subMods_join_FC_Drops = [  0.2,        0.2,        0,      0.1,        0,      0]

# Softmax for joined-model if available
subMods_join_classification =   [   1]

class CNNModel(nn.Module):
    """1-D CNN classifier: four conv stages, nine hidden FC layers, 3329 logits.

    Each conv stage is Conv1d(kernel 3, 'same' padding) -> ReLU ->
    MaxPool1d(3, stride 3) -> BatchNorm1d. Each hidden FC layer is
    Linear -> ReLU -> BatchNorm1d, optionally followed by Dropout(0.2).
    Sub-modules are registered as conv1..4 / pool1..4 / bn1..13 / fc1..10 /
    dropout1..5, in that creation order, so parameter names and ordering in
    ``state_dict()`` are stable.

    The first FC layer expects 448 flattened features (64 channels x 7 samples),
    which is what a 600-sample input yields after the four pooling stages.
    """

    # (in_channels, out_channels) for conv1..conv4
    _CONV_CHANNELS = ((1, 512), (512, 256), (256, 128), (128, 64))

    # (in_features, out_features, dropout index or None) for fc1..fc9;
    # the batch-norm that follows fc<k> is bn<k + 4>.
    _FC_SPECS = (
        (448, 1024, 1),
        (1024, 512, None),
        (512, 256, 2),
        (256, 128, None),
        (128, 1024, 3),
        (1024, 1024, 4),
        (1024, 512, None),
        (512, 256, 5),
        (256, 128, None),
    )

    def __init__(self):
        super().__init__()

        # Convolutional stages
        for stage, (c_in, c_out) in enumerate(self._CONV_CHANNELS, start=1):
            setattr(self, f'conv{stage}',
                    nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding='same'))
            setattr(self, f'pool{stage}', nn.MaxPool1d(kernel_size=3, stride=3))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(c_out))

        # Hidden fully connected layers
        for k, (n_in, n_out, drop_id) in enumerate(self._FC_SPECS, start=1):
            setattr(self, f'fc{k}', nn.Linear(n_in, n_out))
            setattr(self, f'bn{k + 4}', nn.BatchNorm1d(n_out))
            if drop_id is not None:
                setattr(self, f'dropout{drop_id}', nn.Dropout(0.2))

        # Output layer (raw logits, no softmax)
        self.fc10 = nn.Linear(128, 3329)

    def forward(self, x):
        """Return class logits of shape (batch, 3329) for ``x`` of shape (batch, length)."""
        # Insert the single input channel expected by Conv1d
        h = x.unsqueeze(1)

        for stage in range(1, len(self._CONV_CHANNELS) + 1):
            h = F.relu(getattr(self, f'conv{stage}')(h))
            h = getattr(self, f'pool{stage}')(h)
            h = getattr(self, f'bn{stage}')(h)

        # Flatten channels x samples into one feature vector per trace
        h = h.view(h.size(0), -1)

        for k, (_, _, drop_id) in enumerate(self._FC_SPECS, start=1):
            h = F.relu(getattr(self, f'fc{k}')(h))
            h = getattr(self, f'bn{k + 4}')(h)
            if drop_id is not None:
                h = getattr(self, f'dropout{drop_id}')(h)

        return self.fc10(h)

#MMD Loss

class RBF(nn.Module):
    """Sum of Gaussian (RBF) kernels evaluated at several bandwidths.

    Args:
        n_kernels: number of bandwidths that are summed.
        mul_factor: base of the geometric bandwidth scale; the multipliers are
            ``mul_factor ** (i - n_kernels // 2)`` for ``i`` in ``range(n_kernels)``.
        bandwidth: fixed base bandwidth. When ``None`` it is estimated per call
            as the mean squared pairwise distance over the off-diagonal pairs.
    """

    def __init__(self, n_kernels=5, mul_factor=2.0, bandwidth=None):
        super().__init__()
        multipliers = mul_factor ** (torch.arange(n_kernels) - n_kernels // 2)
        # Registered as a buffer so it follows the module through .to()/.cuda()
        # instead of being pinned to CUDA at construction time (which fails on
        # CPU-only hosts). Non-persistent keeps state_dict() unchanged.
        self.register_buffer('bandwidth_multipliers', multipliers, persistent=False)
        self.bandwidth = bandwidth

    def get_bandwidth(self, L2_distances):
        """Return the base bandwidth for a square matrix of squared distances.

        With fewer than two samples and no fixed bandwidth the estimate divides
        by zero and is not finite.
        """
        if self.bandwidth is None:
            n_samples = L2_distances.shape[0]
            # detach(): the bandwidth estimate is treated as a constant and
            # takes no part in back-propagation.
            return L2_distances.detach().sum() / (n_samples ** 2 - n_samples)

        return self.bandwidth

    def forward(self, X):
        """Return the (n, n) kernel matrix for the rows of ``X`` (shape (n, d))."""
        L2_distances = torch.cdist(X, X) ** 2
        # Follow the input's device even if the module itself was never moved.
        multipliers = self.bandwidth_multipliers.to(X.device)
        bandwidths = (self.get_bandwidth(L2_distances) * multipliers)[:, None, None]
        return torch.exp(-L2_distances[None, ...] / bandwidths).sum(dim=0)


class MMDLoss(nn.Module):
    """Maximum Mean Discrepancy between two batches under a kernel.

    Args:
        kernel: callable mapping an (n, d) tensor to its (n, n) kernel matrix.
            When ``None`` a fresh ``RBF()`` is created for this instance. The
            default is built lazily so that defining or importing this class
            does not instantiate a kernel, and instances do not share one.
    """

    def __init__(self, kernel=None):
        super().__init__()
        self.kernel = RBF() if kernel is None else kernel

    def forward(self, X, Y):
        """Return ``mean(K_XX) - 2 * mean(K_XY) + mean(K_YY)`` as a scalar tensor.

        ``X`` and ``Y`` must have the same feature dimension; their batch sizes
        may differ.
        """
        # One kernel evaluation on the stacked batch, then slice out the blocks.
        K = self.kernel(torch.vstack([X, Y]))

        X_size = X.shape[0]
        XX = K[:X_size, :X_size].mean()
        XY = K[:X_size, X_size:].mean()
        YY = K[X_size:, X_size:].mean()
        return XX - 2 * XY + YY

# Training loop (example)
def train_transfer(args, save_folder, model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Fine-tune only the last layer (``model.fc10``) of a pre-trained model.

    Every parameter is frozen except those of ``model.fc10``. Each epoch runs
    one pass over ``train_loader`` followed by a validation pass over
    ``test_loader``. Validation runs in eval mode without gradients so it
    cannot alter BatchNorm statistics or apply dropout.

    Args:
        args: namespace providing ``eval_interval``; a checkpoint
            ``model_<epoch>.pt`` is written whenever
            ``epoch % args.eval_interval == 0`` (epoch is 0-based).
        save_folder: existing directory that receives the checkpoints,
            ``model_end.pt`` and ``losses.csv``.
        model: network exposing an ``fc10`` sub-module. Batches are moved to
            the device its parameters live on.
        train_loader: iterable of ``(trace_data, target)`` training batches.
        test_loader: iterable of ``(trace_data, target)`` validation batches.
        optimizer: optimizer already bound to the model parameters.
        criterion: loss called as ``criterion(output, target)`` with ``target``
            cast to ``long``.
        epochs: number of epochs.

    Returns:
        The same ``model`` instance, left in training mode.
    """
    start_time = time.time()
    # Use the device the model already lives on rather than a module global.
    device = next(model.parameters()).device

    # Freeze the whole network, then unfreeze only the final layer.
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc10.parameters():
        param.requires_grad = True

    def _mean(values):
        # Mean of a possibly empty list (an empty loader yields NaN, no warning).
        return float(np.mean(values)) if values else float('nan')

    losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = []
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.to(device)
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')
            train_loss.append(loss.item())

        model.eval()
        val_loss = []
        with torch.no_grad():
            for batch_idx, (trace_data, target) in enumerate(test_loader):
                target = target.long().to(device)
                trace_data = trace_data.to(device)
                output = model(trace_data)
                loss = criterion(output, target)
                if batch_idx % 100 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Val Loss: {loss.item()}')
                val_loss.append(loss.item())

        if epoch % args.eval_interval == 0:
            save_path = os.path.join(save_folder, 'model_{}.pt'.format(epoch))
            torch.save(model.state_dict(), save_path)

        losses.append({"Epoch": epoch + 1, "Train Loss": _mean(train_loss), "Validation Loss": _mean(val_loss)})

    # Hand the model back in training mode, as callers previously received it.
    model.train()
    save_path = os.path.join(save_folder, 'model_end.pt')
    torch.save(model.state_dict(), save_path)

    df = pd.DataFrame(losses)
    df.to_csv(os.path.join(save_folder, "losses.csv"), index=False)
    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

def train_mmd(args, save_folder, model, train_loader, target_loader, test_loader, optimizer, criterion, mmd_loss, lambda_, epochs=10):
    """Train on labelled source data with an MMD domain-adaptation penalty.

    For every source batch one target batch is drawn and the optimised loss is
    ``criterion(output, target) + lambda_ * mmd_loss(output, output_target)``,
    where both outputs come from ``model``. Target batches are taken from one
    persistent iterator that restarts when exhausted, so the whole target set
    is visited instead of the first batch being re-read on every step.

    Args:
        args: unused; kept for a uniform training-function signature.
        save_folder: existing directory receiving ``model_start.pt`` (weights
            before training) and ``model_end.pt`` (weights after training).
        model: network to train. Batches are moved to the device its
            parameters live on.
        train_loader: iterable of labelled source ``(trace_data, target)`` batches.
        target_loader: iterable of target-domain ``(trace_data, _)`` batches;
            the second element is ignored.
        test_loader: unused (validation is disabled in this function).
        optimizer: optimizer already bound to the model parameters.
        criterion: supervised loss, called with ``target`` cast to ``long``.
        mmd_loss: module called as ``mmd_loss(source_output, target_output)``.
        lambda_: weight of the MMD term.
        epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model`` instance.

    Raises:
        StopIteration: if ``target_loader`` yields no batches.
    """
    start_time = time.time()
    # Use the device the model already lives on rather than a module global.
    device = next(model.parameters()).device

    save_path = os.path.join(save_folder, 'model_start.pt')
    torch.save(model.state_dict(), save_path)

    model.train()
    mmd_loss.to(device)

    target_iter = iter(target_loader)
    for epoch in range(epochs):
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            # Next target-domain batch; start over once the loader is exhausted.
            try:
                target_data, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(target_loader)
                target_data, _ = next(target_iter)
            target_data = target_data.to(device)

            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.to(device)
            output = model(trace_data)
            output_target = model(target_data)

            # Supervised loss plus the MMD-based domain adaptation penalty
            loss = criterion(output, target) + lambda_*mmd_loss(output, output_target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')

    save_path = os.path.join(save_folder, 'model_end.pt')
    torch.save(model.state_dict(), save_path)

    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

def train_otf(args, save_folder, model, train_loader, test_loader, target_loader, optimizer, criterion, mmd_loss, lambda_, epochs=10):
    """Train ``model`` directly on the labelled batches of ``target_loader``.

    Leftover debugging statements that printed the first label batch and then
    terminated the process have been removed, so the loop now actually trains.

    NOTE: the loader parameters are ordered ``train_loader, test_loader,
    target_loader`` here, which differs from ``train_mmd``
    (``train_loader, target_loader, test_loader``); pass them by keyword when
    in doubt.

    Args:
        args: unused; kept for a uniform training-function signature.
        save_folder: existing directory receiving ``model_start.pt`` (weights
            before training) and ``model_end.pt`` (weights after training).
        model: network to train. Batches are moved to the device its
            parameters live on.
        train_loader: unused.
        test_loader: unused (validation is disabled in this function).
        target_loader: iterable of ``(trace_data, target)`` batches used for
            the supervised updates.
        optimizer: optimizer already bound to the model parameters.
        criterion: supervised loss, called with ``target`` cast to ``long``.
        mmd_loss: only moved to the model's device; not part of the loss.
        lambda_: unused.
        epochs: number of passes over ``target_loader``.

    Returns:
        The same ``model`` instance.
    """
    start_time = time.time()
    # Use the device the model already lives on rather than a module global.
    device = next(model.parameters()).device

    save_path = os.path.join(save_folder, 'model_start.pt')
    torch.save(model.state_dict(), save_path)

    model.train()
    mmd_loss.to(device)

    for epoch in range(epochs):
        for batch_idx, (trace_data, target) in enumerate(target_loader):
            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.to(device)
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')

    save_path = os.path.join(save_folder, 'model_end.pt')
    torch.save(model.state_dict(), save_path)

    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd import Variable
import math
import pdb


# Module-level padding size; not referenced in the visible part of this file --
# TODO confirm where (or whether) it is used.
pad_size = 5


def init_weights_normal(m):
    """Apply Xavier (Glorot) normal initialisation to a layer's weight.

    Intended for ``nn.Module.apply``: only ``nn.Linear`` and ``nn.Conv1d``
    modules are touched, and only their ``weight`` (biases are left as they
    are). Any other module type is ignored.

    Args:
        m: the module visited by ``apply``.
    """
    # In-place variant; the un-suffixed `xavier_normal` alias is deprecated.
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        torch.nn.init.xavier_normal_(m.weight)


def init_weights_uniform(m):
    """Apply Xavier (Glorot) uniform initialisation to a layer's weight.

    Intended for ``nn.Module.apply``: only ``nn.Linear`` and ``nn.Conv1d``
    modules are touched, and only their ``weight`` (biases are left as they
    are). Any other module type is ignored.

    Args:
        m: the module visited by ``apply``.
    """
    # In-place variant; the un-suffixed `xavier_uniform` alias is deprecated.
    if isinstance(m, (nn.Linear, nn.Conv1d)):
        torch.nn.init.xavier_uniform_(m.weight)


class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass, gradient times -alpha backward."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scaling factor for the backward pass.
        ctx.alpha = alpha
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign first, then scale; the tensor stays on the left of the
        # product. `alpha` itself receives no gradient.
        flipped = grad_output.neg()
        scaled = flipped * ctx.alpha
        return scaled, None

class RevGrad(nn.Module):
    """Domain-adversarial variant of the CNN with a gradient-reversal branch.

    The backbone registers the sub-modules conv1..4 / pool1..4 / bn1..13 /
    fc1..10 / dropout1..5 in that creation order. ``forward`` only runs the
    backbone up to fc8 -> bn12 -> dropout5 (256 features); fc9, bn13 and fc10
    are created but not used by ``forward``.

    Two heads consume the 256-d features:
      * ``class_classifier``: Linear(256, 128) -> ReLU -> Linear(128, 3329)
      * ``domain_classifier``: the same shape, fed through ``ReverseLayerF``.

    NOTE(review): the domain head emits 3329 logits, while the visible training
    code uses domain labels 0 and 1 -- confirm this width is intended.
    """

    # (in_channels, out_channels) for conv1..conv4
    _CONV_CHANNELS = ((1, 512), (512, 256), (256, 128), (128, 64))

    # (in_features, out_features, dropout index or None) for fc1..fc9;
    # the batch-norm that follows fc<k> is bn<k + 4>.
    _FC_SPECS = (
        (448, 1024, 1),
        (1024, 512, None),
        (512, 256, 2),
        (256, 128, None),
        (128, 1024, 3),
        (1024, 1024, 4),
        (1024, 512, None),
        (512, 256, 5),
        (256, 128, None),
    )

    # forward() stops after this many hidden FC layers (fc1..fc8)
    _N_FC_IN_FORWARD = 8

    def __init__(self):
        super().__init__()
        act_func = nn.ReLU

        # Convolutional stages
        for stage, (c_in, c_out) in enumerate(self._CONV_CHANNELS, start=1):
            setattr(self, f'conv{stage}',
                    nn.Conv1d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding='same'))
            setattr(self, f'pool{stage}', nn.MaxPool1d(kernel_size=3, stride=3))
            setattr(self, f'bn{stage}', nn.BatchNorm1d(c_out))

        # Hidden fully connected layers
        for k, (n_in, n_out, drop_id) in enumerate(self._FC_SPECS, start=1):
            setattr(self, f'fc{k}', nn.Linear(n_in, n_out))
            setattr(self, f'bn{k + 4}', nn.BatchNorm1d(n_out))
            if drop_id is not None:
                setattr(self, f'dropout{drop_id}', nn.Dropout(0.2))

        self.fc10 = nn.Linear(128, 3329)

        # source domain classifier block
        label_head = nn.Sequential()
        label_head.add_module('c_fc1', nn.Linear(256, 128))
        label_head.add_module('c_relu', act_func())
        label_head.add_module('c_out', nn.Linear(128, 3329))
        self.class_classifier = label_head

        # domain discriminator block
        domain_head = nn.Sequential()
        domain_head.add_module('d_fc1', nn.Linear(256, 128))
        domain_head.add_module('d_relu1', act_func())
        domain_head.add_module('d_fc2', nn.Linear(128, 3329))
        self.domain_classifier = domain_head

        # Weight initialization of both heads (the backbone keeps its defaults)
        for head in (self.class_classifier, self.domain_classifier):
            head.apply(init_weights_normal)

    def forward(self, x, alpha):
        """Return ``(features, class_output, domain_output)`` for a batch of traces.

        Args:
            x: input of shape (batch, length); a channel axis is inserted here.
            alpha: scale applied to the reversed gradient of the domain branch.
        """
        # Insert the single input channel expected by Conv1d
        feats = x.unsqueeze(1)

        for stage in range(1, len(self._CONV_CHANNELS) + 1):
            feats = F.relu(getattr(self, f'conv{stage}')(feats))
            feats = getattr(self, f'pool{stage}')(feats)
            feats = getattr(self, f'bn{stage}')(feats)

        # Flatten channels x samples into one feature vector per trace
        feats = feats.view(feats.size(0), -1)

        used_specs = self._FC_SPECS[:self._N_FC_IN_FORWARD]
        for k, (_, _, drop_id) in enumerate(used_specs, start=1):
            feats = F.relu(getattr(self, f'fc{k}')(feats))
            feats = getattr(self, f'bn{k + 4}')(feats)
            if drop_id is not None:
                feats = getattr(self, f'dropout{drop_id}')(feats)

        # The domain head sees the same features, with its gradient reversed.
        reverse_feature = ReverseLayerF.apply(feats, alpha)
        class_output = self.class_classifier(feats)
        domain_output = self.domain_classifier(reverse_feature)

        return feats, class_output, domain_output

def train_ada(args, save_folder, model, train_loader, target_loader, test_loader, optimizer, criterion, mmd_loss, lambda_, epochs=10):
    """Train ``model`` with a class loss plus a source/target domain loss.

    For every source batch one target batch is drawn; the total loss is the
    cross-entropy of the class head on the source batch plus the
    cross-entropy of the domain head on both batches (source labelled 1,
    target labelled 0).

    Args:
        args: unused; kept so the call signature stays unchanged.
        save_folder: directory receiving ``model_start.pt`` (weights before
            training) and ``model_end.pt`` (weights after training).
        model: module called as ``model(batch, alpha=alpha)`` and returning
            ``(features, class_output, domain_output)``.
        train_loader: iterable of ``(trace_batch, label_batch)`` from the
            source domain.
        target_loader: iterable of ``(trace_batch, _)`` from the target
            domain; its labels are ignored.
        test_loader: unused; kept so the call signature stays unchanged.
        optimizer: optimizer stepped once per source batch. It must have been
            built over ``model``'s parameters for training to have an effect.
        criterion: unused; kept so the call signature stays unchanged.
        mmd_loss: unused; kept so the call signature stays unchanged.
        lambda_: unused; kept so the call signature stays unchanged.
        epochs: number of passes over ``train_loader``.

    Returns:
        The same ``model`` object, after training.
    """
    start_time = time.time()
    torch.save(model.state_dict(), os.path.join(save_folder, 'model_start.pt'))
    model.train()
    loss_class = torch.nn.CrossEntropyLoss()
    loss_domain = torch.nn.CrossEntropyLoss()

    # Put batches and domain labels on the device the model already lives on,
    # instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    len_source_loader = len(train_loader)

    # One persistent iterator over the target loader. Re-creating the iterator
    # on every step would return the first batch every time whenever the
    # loader is not shuffled.
    target_iter = iter(target_loader)

    for epoch in range(epochs):
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            # Training progress in [0, 1) drives the schedule of alpha, which
            # grows from 0 towards 1 as training advances.
            p = float(batch_idx + epoch * len_source_loader) / epochs / len_source_loader
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # Next target-domain batch; restart the loader once exhausted.
            try:
                target_data, _ = next(target_iter)
            except StopIteration:
                target_iter = iter(target_loader)
                target_data, _ = next(target_iter)
            target_data = target_data.to(run_device)

            optimizer.zero_grad()
            target = target.long().to(run_device)
            trace_data = trace_data.to(run_device)

            # Domain labels: 1 for source samples, 0 for target samples.
            dlabel_src = torch.ones(trace_data.shape[0], dtype=torch.long, device=run_device)
            dlabel_tgt = torch.zeros(target_data.shape[0], dtype=torch.long, device=run_device)

            # Source batch: class loss and domain loss.
            _, clabel_src, dlabel_pred_src = model(trace_data, alpha=alpha)
            label_loss = loss_class(clabel_src, target)
            domain_loss_src = loss_domain(dlabel_pred_src, dlabel_src)

            # Target batch: domain loss only (its class labels are not used).
            _, _, dlabel_pred_tgt = model(target_data, alpha=alpha)
            domain_loss_tgt = loss_domain(dlabel_pred_tgt, dlabel_tgt)

            loss_total = label_loss + domain_loss_src + domain_loss_tgt
            loss_total.backward()
            optimizer.step()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss_total.item()}')

    # Written to its own file even when epochs == 0, so the starting
    # checkpoint is never overwritten.
    torch.save(model.state_dict(), os.path.join(save_folder, 'model_end.pt'))

    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model

def load_multi_attack(data_path):
    """Read the ``'data'`` and ``'label'`` entries stored at ``data_path``.

    Args:
        data_path: path handed to ``np.load``; the loaded object must expose
            the keys ``'data'`` and ``'label'``.

    Returns:
        Tuple ``(data, labels)``.
    """
    archive = np.load(data_path)
    return archive['data'], archive['label']

def load_meta_trace_files(data_path, start_trace, end_trace):
    """Load a window of traces, ``bp`` values and labels from ``data_path``.

    Args:
        data_path: path handed to ``np.load``; the loaded object must expose
            the keys ``'data'``, ``'bp'`` and ``'label'``.
        start_trace: first index of the window (inclusive).
        end_trace: end index of the window (exclusive).

    Returns:
        Tuple ``(traces, bp, labels)``. ``traces`` and ``labels`` are sliced
        along their first axis; ``bp`` is sliced along its second axis.
    """
    archive = np.load(data_path)
    window = slice(start_trace, end_trace)

    traces = archive['data']
    bp_values = archive['bp']
    labels = archive['label']

    return traces[window], bp_values[:, window], labels[window]

from sklearn.preprocessing import StandardScaler
from scipy import stats

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    Args:
        timeseries: array-like object providing ``min()`` and ``max()`` and
            supporting element-wise arithmetic (e.g. a NumPy array).

    Returns:
        ``(timeseries - min) / (max - min)``. A constant series, whose range
        is zero, is mapped to all zeros instead of dividing by zero and
        producing NaNs.
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    if span == 0:
        # Every value equals the minimum, so the shifted series is already
        # all zeros; dividing by 1 keeps the result dtype of the normal path.
        span = 1
    return (timeseries - lowest) / span

def normalize_data_per_trace(data):
    """Apply ``normalize`` to each entry along the first axis, in place.

    Args:
        data: indexable container of traces; each ``data[i]`` is overwritten
            with its normalized version.

    Returns:
        The same ``data`` object after modification.
    """
    print(data.shape)
    for row_idx, trace in enumerate(data):
        data[row_idx] = normalize(trace)

    return data

def create_training_data_optimize(args, data_path, sKeyNo, trainPortion, xType, yType, is_test, start_trace, end_trace):
    """Load traces and labels and split them into training and validation parts.

    Args:
        args: namespace; only ``args.normalize`` is read (``1`` enables
            per-trace normalization).
        data_path: path of the trace file to load.
        sKeyNo: forwarded to ``load_meta_trace_file_from_test`` when
            ``is_test`` is true; otherwise unused.
        trainPortion: only used to compute a training size that is not used
            afterwards.
        xType: unused.
        yType: unused.
        is_test: selects the loader and how the validation part is chosen.
        start_trace: first trace index to load when ``is_test`` is false.
        end_trace: end of the training window; 300000 further traces are
            loaded after it when ``is_test`` is false.

    Returns:
        Tuple ``(xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value)``.
        The training part is the first ``end_trace - start_trace`` loaded
        entries. The validation part is the slice following the training part
        when ``is_test`` is true, otherwise the entries selected by the ids
        stored in ``val_ids.npz`` (indices are relative to the loaded window).
        ``yTrain_value`` / ``yVal_value`` select from the same label array as
        ``yTrain`` / ``yVal``.
    """
    held_out_count = 300000
    held_out_ids = np.load('val_ids.npz')['ids']
    print(len(held_out_ids))
    print(held_out_ids[:10])
    val_stop = end_trace + held_out_count

    if not is_test:
        traces, _bp, labels = load_meta_trace_files(data_path, start_trace, val_stop)
    else:
        traces, _bp, labels = load_meta_trace_file_from_test(data_path, sKeyNo)

    print(traces.shape)

    n_train = end_trace - start_trace
    _train_size = math.floor(n_train * trainPortion)  # computed but never used
    print(labels[:10])

    if args.normalize == 1:
        traces = normalize_data_per_trace(traces)

    xTrain = traces[:n_train]
    print(xTrain[0][:10])

    # Validation entries: a contiguous tail in test mode, stored ids otherwise.
    val_selector = slice(n_train, val_stop) if is_test else held_out_ids
    xVal = traces[val_selector]
    yVal = labels[val_selector]
    yVal_value = labels[val_selector]

    yTrain = labels[:n_train]
    yTrain_value = labels[:n_train]

    return xTrain, yTrain, xVal, yVal, yTrain_value, yVal_value

# class to represent dataset
class SCADataset():
    """Indexable pairing of samples and labels.

    ``data`` is a two-element sequence: ``data[0]`` holds the samples (stored
    as ``float32``) and ``data[1]`` holds the labels (stored unchanged).
    """

    def __init__(self, data):
        samples, labels = data[0], data[1]
        self.x = samples.astype(np.float32)
        self.y = labels
        # Number of samples is the length of the first axis of the samples.
        self.n_samples = samples.shape[0]

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        sample = self.x[index]
        label = self.y[index]
        return sample, label

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

def random_sampling(data, num_sample):
    """Draw ``num_sample`` distinct random indices into ``data``.

    NumPy's global generator is re-seeded with ``args.seed`` on every call.
    When ``args.xType`` is ``'wavebp01'`` the indices refer to ``data[0]``
    rather than to ``data`` itself.

    Args:
        data: sized container to sample indices from.
        num_sample: number of indices to draw without replacement.

    Returns:
        Array of the chosen indices.
    """
    np.random.seed(args.seed)
    pool = data[0] if args.xType == 'wavebp01' else data
    chosen = np.random.choice(len(pool), num_sample, replace=False)
    print(len(chosen))
    print('---')
    return chosen


import time
#MAIN

# Driver: load the profiling data, then adapt the pre-trained model once per
# attacked key using the strategy selected by args.update_type.
xTrain_original, yTrain_original, xVal, yVal, yTrain_value, yVal_value = create_training_data_optimize(args,data_path,sKeyNo, trainPortion, xType, yType,False, args.start_trace, args.end_trace)
print(xTrain_original.shape)
print(yTrain_original.shape)
print(xVal.shape)
print(yVal.shape)

# Experiment constants.
# NOTE(review): xType is reassigned here after already being passed to
# create_training_data_optimize above -- confirm this is intended.
xType = 'wavebp'
noConv1Dbranch = 1
noBPbranch = 2
noLayers = 6
tracelen = 600
NumBPinput = 3330
classes = 3329
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)
# Define loss functions
criterion = nn.CrossEntropyLoss()
lambda_ = 0.1  # Penalty coefficient for MMD
mmd_loss = MMDLoss()
trace_profiling, bp_profiling, skpv_profiling = load_meta_trace_files(data_path, args.start_trace, args.end_trace)
num_sample = args.num_sample
# Source-domain training set: the subset of profiling traces listed in args.all_ids.
all_ids = np.load(args.all_ids)
train_data = [xTrain_original[all_ids], yTrain_original[all_ids]]
print('Ids length', len(all_ids))
test_data = [xVal, yVal]
SCAdataset = SCADataset(train_data)
SCAdataset_val = SCADataset(test_data)
train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(SCAdataset_val, batch_size=args.batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Attack traces: keep the selected key range and the first args.num_trace
# traces per key, then normalize each entry.
xTest_multi, yTest_multi = load_multi_attack(args.eval_path)
xTest_multi, yTest_multi = xTest_multi[args.start_key:args.end_key,:args.num_trace,:], yTest_multi[args.start_key:args.end_key]
xTest_multi = normalize_data_per_trace(xTest_multi)

# Training with different keys: every iteration starts again from the
# pre-trained checkpoint and adapts it to one key's traces.
for it in range(len(xTest_multi)):
    print('Iteration: ', it)
    all_active_ids = np.load(args.all_ids)  # Load existing samples
    # Every trace of this key carries the same label.
    target_labels = np.ones(xTest_multi.shape[1])*yTest_multi[it]
    target_data = [xTest_multi[it,:,:], target_labels]
    SCA_target_dataset = SCADataset(target_data)
    target_loader = DataLoader(SCA_target_dataset, batch_size=4, shuffle=False)
    model = CNNModel()
    optimizer = optim.RMSprop(model.parameters(), lr=0.00001)
    model.load_state_dict(torch.load(args.trained_model_path, weights_only=True))
    model = model.to(device)
    database_folder_train_it = os.path.join(database_folder_train, 'it_'+str(it))
    Path(database_folder_train_it).mkdir(parents=True, exist_ok=True)
    save_model_name = (database_folder_train_it+'/model_best')
    save_ids_name = database_folder_train_it+'/all_ids.npy'
    if args.update_type == 'transfer':
        disjoint_sampled_ids = np.load(args.sampling_file)
        disjoint_sampled_ids = disjoint_sampled_ids[args.start_key:args.end_key,:]
        print(disjoint_sampled_ids.shape)
        sampled_ids = disjoint_sampled_ids[it].astype(int)
        train_data = [xTrain_original[sampled_ids], yTrain_original[sampled_ids]]
        SCA_target_dataset = SCADataset(train_data)
        target_loader = DataLoader(SCA_target_dataset, batch_size=args.batch_size, shuffle=True)
        model = train_transfer(args, database_folder_train_it, model, target_loader, target_loader, optimizer, criterion, epochs=args.num_epoch)
    elif args.update_type == 'otf':
        model = train_otf(args, database_folder_train_it, model, train_loader, val_loader, target_loader, optimizer, criterion, mmd_loss, lambda_, epochs=args.num_epoch)
    elif args.update_type == 'ada':
        new_model = RevGrad()
        new_model.to(device)
        # this works, but could be dangerous, if you are not careful:
        # strict=False silently ignores keys that do not match between the
        # two architectures.
        new_model.load_state_dict(model.state_dict(), strict=False)
        # The optimizer created above tracks the parameters of `model`, not
        # of `new_model`; stepping it would leave the network being trained
        # untouched. Build one over the parameters that receive gradients.
        ada_optimizer = optim.RMSprop(new_model.parameters(), lr=0.00001)
        # Dedicated loader for the adversarial target batches. drop_last
        # discards a trailing partial batch when at least one full batch
        # exists, because batch-norm layers in training mode typically reject
        # a single-sample batch.
        ada_target_loader = DataLoader(SCA_target_dataset, batch_size=4, shuffle=False, drop_last=len(SCA_target_dataset) >= 4)
        # train_ada takes (..., train_loader, target_loader, test_loader, ...):
        # the attacked key's traces go in the target slot and the validation
        # loader in the test slot.
        model = train_ada(args, database_folder_train_it, new_model, train_loader, ada_target_loader, val_loader, ada_optimizer, criterion, mmd_loss, lambda_, epochs=args.num_epoch)
    elif args.update_type == 'mmd':
        model = train_mmd(args, database_folder_train_it, model, train_loader, val_loader, target_loader, optimizer, criterion, mmd_loss, lambda_, epochs=args.num_epoch)
    else:
        print('WRONG TYPE')
    np.save(save_ids_name, all_active_ids)
