import os
import os.path
import sys
import h5py
import numpy as np
import matplotlib.pyplot as plt
import ast

import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def set_seeds(seed):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch
    (``torch.manual_seed`` seeds all devices), so that weight
    initialisation, ``DataLoader`` shuffling and NumPy sampling repeat
    across runs.

    Args:
        seed: integer seed shared by all generators.

    Note:
        ``PYTHONHASHSEED`` is exported for Python processes started later;
        it cannot change hash randomisation of the already-running
        interpreter.
    """
    import random  # local import: only this function needs it

    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seeds(2025)

def parse_arguments():
    """Build the command-line parser for this training script.

    Returns:
        argparse.ArgumentParser defining ``--batch_size``, ``--num_epoch``,
        ``--num_sample``, ``--eval_interval``, ``--sampling`` and ``--name``.
        The caller is responsible for invoking ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--batch_size', type=int, help='mini-batch size', default=256)
    parser.add_argument('--num_epoch', type=int, help='number of training epochs', default=256)
    parser.add_argument('--num_sample', type=int, help='number of profiling traces used for training', default=256)
    parser.add_argument('--eval_interval', type=int, help='save a checkpoint every N epochs', default=10)
    parser.add_argument('--sampling', type=str,
                        help="'random' draws num_sample traces at random; any other value uses the first num_sample traces",
                        default='None')
    parser.add_argument('--name', type=str, help='experiment name', default='test')

    return parser

def check_file_exists(file_path):
    """Exit the program with status -1 when ``file_path`` does not exist.

    The path is normalised first, and the normalised form is what gets
    checked and echoed in the error message. Returns ``None`` otherwise.
    """
    normalized = os.path.normpath(file_path)
    if not os.path.exists(normalized):
        print("Error: provided file path '%s' does not exist!" % normalized)
        sys.exit(-1)

def load_ascad(ascad_database_file, load_metadata=False):
    """Load profiling and attack traces/labels from an ASCAD HDF5 database.

    Args:
        ascad_database_file: path to the ``.h5`` file. The program exits via
            ``check_file_exists`` if it is missing, and with status -1 if
            h5py cannot open it.
        load_metadata: when True, also return the metadata datasets.

    Returns:
        ``(X_profiling, Y_profiling), (X_attack, Y_attack)``, plus a third
        tuple ``(profiling_metadata, attack_metadata)`` when
        ``load_metadata`` is True. Traces are read as ``np.int8``; labels
        keep their stored dtype.

    Note:
        The metadata entries are lazy h5py datasets, so the file is left
        open in that case (they could not be read after closing). Otherwise
        the file is closed before returning.
    """
    check_file_exists(ascad_database_file)
    # Open the ASCAD database HDF5 for reading; h5py reports missing,
    # unreadable or malformed files as OSError.
    try:
        in_file = h5py.File(ascad_database_file, "r")
    except OSError:
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % ascad_database_file)
        sys.exit(-1)

    keep_open = False
    try:
        X_profiling = np.array(in_file['Profiling_traces/traces'], dtype=np.int8)
        Y_profiling = np.array(in_file['Profiling_traces/labels'])
        X_attack = np.array(in_file['Attack_traces/traces'], dtype=np.int8)
        Y_attack = np.array(in_file['Attack_traces/labels'])
        if not load_metadata:
            return (X_profiling, Y_profiling), (X_attack, Y_attack)
        metadata = (in_file['Profiling_traces/metadata'], in_file['Attack_traces/metadata'])
        # The returned datasets read lazily from the file, so it must stay open.
        keep_open = True
        return (X_profiling, Y_profiling), (X_attack, Y_attack), metadata
    finally:
        if not keep_open:
            in_file.close()


def load_multi_attack(data_path):
    """Load traces and labels from a NumPy ``.npz`` archive.

    Args:
        data_path: path to an ``.npz`` file containing the arrays ``data``
            and ``label``.

    Returns:
        ``(data, labels)`` as in-memory ndarrays.
    """
    # np.load on an .npz returns an NpzFile that holds the file handle open;
    # the context manager closes it once both arrays have been read.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def random_sampling(data, num_sample, seed=2025):
    """Pick ``num_sample`` distinct row indices of ``data`` at random.

    Draws from a private ``np.random.RandomState(seed)``. The indices are
    identical to ``np.random.seed(seed); np.random.choice(...)``, but
    NumPy's global RNG is no longer reset as a side effect.

    Args:
        data: sized collection (e.g. an array of traces) to sample from.
        num_sample: number of indices to draw, without replacement.
        seed: seed of the local generator (default 2025, the previous
            hard-coded value).

    Returns:
        1-D integer ndarray of ``num_sample`` unique indices in
        ``[0, len(data))``.

    Raises:
        ValueError: if ``num_sample`` exceeds ``len(data)``.
    """
    rng = np.random.RandomState(seed)
    rand_ids = rng.choice(len(data), num_sample, replace=False)
    print(len(rand_ids))
    print('---')
    return rand_ids

def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when it is empty."""
    return float(np.mean(values)) if values else float('nan')


def _validate(model, loader, criterion, device, epoch, epochs):
    """Return the per-batch losses of ``model`` on ``loader`` without training it.

    Puts ``model`` in eval mode (BatchNorm uses and does not update its
    running statistics) and disables autograd for the whole pass.
    """
    model.eval()
    val_loss = []
    with torch.no_grad():
        for batch_idx, (trace_data, target) in enumerate(loader):
            target = target.long().to(device)
            trace_data = trace_data.float().unsqueeze(1).to(device)
            loss = criterion(model(trace_data), target)
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Val Loss: {loss.item()}')
            val_loss.append(loss.item())
    return val_loss


def train(args, save_folder, model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Train ``model`` for ``epochs`` epochs, validating after every epoch.

    Each batch is cast to float and given a channel axis at dim 1 via
    ``unsqueeze(1)`` (traces presumably arrive as ``[B, L]``), and targets
    are cast to ``long``. Validation runs in eval mode without gradients.

    Side effects written into ``save_folder``:
      * ``model_{epoch}.pt`` whenever ``epoch % args.eval_interval == 0``;
      * ``model.pt`` with the final weights;
      * ``losses.csv`` with the mean train/validation loss per epoch
        (NaN for an epoch whose loader yielded no batches).

    Args:
        args: namespace providing ``eval_interval``.
        save_folder: existing directory for checkpoints and the loss log.
        model: network to train; batches are moved to the device of its
            parameters.
        train_loader: iterable of ``(traces, labels)`` training batches.
        test_loader: iterable of ``(traces, labels)`` validation batches.
        optimizer: optimizer over ``model``'s parameters.
        criterion: loss taking ``(logits, targets)``.
        epochs: number of passes over ``train_loader``.

    Returns:
        The trained ``model`` (left in train mode).
    """
    start_time = time.time()
    # Follow the model's own device instead of a module-level global.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses = []
    for epoch in range(epochs):
        model.train()
        train_loss = []
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.float().unsqueeze(1).to(device)
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')
            train_loss.append(loss.item())

        val_loss = _validate(model, test_loader, criterion, device, epoch, epochs)

        if epoch % args.eval_interval == 0:
            save_path = os.path.join(save_folder, 'model_{}.pt'.format(epoch))
            torch.save(model.state_dict(), save_path)

        losses.append({"Epoch": epoch + 1,
                       "Train Loss": _mean_or_nan(train_loss),
                       "Validation Loss": _mean_or_nan(val_loss)})

    # Restore train mode, matching the state callers previously received.
    model.train()
    torch.save(model.state_dict(), os.path.join(save_folder, 'model.pt'))
    # Explicit columns keep a valid header even when epochs == 0.
    df = pd.DataFrame(losses, columns=["Epoch", "Train Loss", "Validation Loss"])
    df.to_csv(os.path.join(save_folder, "losses.csv"), index=False)
    print("---Training done in %s seconds ---" % (time.time() - start_time))
    return model

class SCADataset:
    """Minimal map-style dataset pairing side-channel traces with labels.

    ``data`` is indexable with ``data[0]`` = traces and ``data[1]`` = labels.
    ``dataset[i]`` returns ``(traces[i], labels[i])``, and ``len(dataset)``
    is the size of the traces' first axis, captured at construction.
    """

    def __init__(self, data):
        traces, labels = data[0], data[1]
        self.x = traces
        self.y = labels
        self.n_samples = traces.shape[0]

    def __getitem__(self, index):
        """Return the ``(trace, label)`` pair stored at ``index``."""
        sample = (self.x[index], self.y[index])
        return sample

    def __len__(self):
        """Return the number of traces recorded at construction time."""
        return self.n_samples

class MLPBest(nn.Module):
    """Fully-connected classifier built from Linear -> BatchNorm1d -> ReLU blocks.

    For ``layer_nb >= 2`` the network has ``layer_nb`` Linear layers in
    total: ``layer_nb - 1`` hidden blocks of width ``node``, then a final
    Linear producing ``num_classes`` logits. Smaller ``layer_nb`` values
    still give one hidden block plus the output layer.

    Args:
        node: width of every hidden layer.
        layer_nb: total number of Linear layers (see above).
        input_dim: number of features (trace samples) per example.
        num_classes: number of output logits.
    """

    def __init__(self, node=200, layer_nb=6, input_dim=1400, num_classes=256):
        super(MLPBest, self).__init__()

        layers = []
        layers.append(nn.Linear(input_dim, node))
        layers.append(nn.BatchNorm1d(node))
        layers.append(nn.ReLU())

        for _ in range(layer_nb - 2):
            layers.append(nn.Linear(node, node))
            layers.append(nn.BatchNorm1d(node))
            layers.append(nn.ReLU())

        layers.append(nn.Linear(node, num_classes))  # final layer
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``[B, num_classes]`` logits.

        Accepts ``[B, input_dim]`` or any batch whose trailing dimensions
        flatten to ``input_dim``. This includes the ``[B, 1, input_dim]``
        tensors produced by ``unsqueeze(1)`` in ``train``, which previously
        made BatchNorm1d treat the singleton axis as channels and fail.
        """
        return self.model(torch.flatten(x, start_dim=1))

#LUNet
import torch
import torch.nn as nn

class ConvLSTMBlock(nn.Module):
    """Halve a channel-first sequence with a strided conv, then run an LSTM over it.

    Pipeline: Conv1d(kernel_size=2, stride=2) -> BatchNorm1d -> single-layer
    batch-first LSTM with hidden size ``out_channels``. Shapes:
    ``[B, in_channels, L] -> [B, out_channels, (L - 2) // 2 + 1]``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=2, stride=2)
        self.bn = nn.BatchNorm1d(out_channels)
        self.lstm = nn.LSTM(input_size=out_channels, hidden_size=out_channels, batch_first=True)

    def forward(self, x):
        features = self.bn(self.conv(x))
        # The batch-first LSTM expects [B, L, C]; swap the last two axes and back.
        sequence, _state = self.lstm(features.transpose(1, 2))
        return sequence.transpose(1, 2)

class LUNet(nn.Module):
    """Six stacked ConvLSTMBlocks, global average pooling, and a two-layer head.

    Every encoder halves the time axis with flooring (Conv1d kernel 2,
    stride 2). For 1400-sample traces the lengths are
    1400 -> 700 -> 350 -> 175 -> 87 -> 43 -> 21, and the channels are
    1 -> 8 -> 16 -> 32 -> 64 -> 128 -> 128.
    Input ``[B, 1, L]`` needs ``L >= 64`` so the last stage still has at
    least two steps. Output is ``[B, num_classes]`` logits.
    """

    def __init__(self, num_classes=256):
        super().__init__()
        self.enc1 = ConvLSTMBlock(1, 8)
        self.enc2 = ConvLSTMBlock(8, 16)
        self.enc3 = ConvLSTMBlock(16, 32)
        self.enc4 = ConvLSTMBlock(32, 64)
        self.enc5 = ConvLSTMBlock(64, 128)
        self.enc6 = ConvLSTMBlock(128, 128)

        # Collapse the time axis: [B, 128, L'] -> [B, 128, 1].
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        head = [
            nn.Flatten(),                  # [B, 128]
            nn.Linear(128, 128),
            nn.SELU(),
            nn.Linear(128, num_classes),   # logits
        ]
        self.fc = nn.Sequential(*head)

    def forward(self, x):
        encoders = (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6)
        for encoder in encoders:
            x = encoder(x)
        return self.fc(self.global_pool(x))

import torch
import torch.nn as nn
import torch.nn.functional as F

class Inception1D(nn.Module):
    """1-D Inception module: four length-preserving branches concatenated on channels.

    Branches (each producing ``out_channels`` channels):
      * ``branch1``     - 1x1 conv;
      * ``branch3``     - 1x1 conv -> 3-tap conv (padding 1);
      * ``branch5``     - 1x1 conv -> 5-tap conv (padding 2);
      * ``pool_branch`` - 3-tap average pool (stride 1) -> 1x1 conv.
    ``[B, in_channels, L] -> [B, 4 * out_channels, L]``.

    NOTE(review): no activation follows any convolution, so the module is
    affine in its input.
    """

    def __init__(self, in_channels, out_channels):
        super(Inception1D, self).__init__()
        c_in, c_out = in_channels, out_channels
        self.branch1 = nn.Conv1d(c_in, c_out, kernel_size=1, padding=0)
        self.branch3 = nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=1),
            nn.Conv1d(c_out, c_out, kernel_size=3, padding=1),
        )
        self.branch5 = nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=1),
            nn.Conv1d(c_out, c_out, kernel_size=5, padding=2),
        )
        self.pool_branch = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(c_in, c_out, kernel_size=1),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch3, self.branch5, self.pool_branch)
        return torch.cat([branch(x) for branch in branches], dim=1)

class Inception1D_BN(nn.Module):
    """Inception1D variant with BatchNorm at the end of three of its four branches.

    ``branch3``, ``branch5`` and ``pool_branch`` end in BatchNorm1d; the plain
    1x1 ``branch1`` does not. Output is ``[B, 4 * out_channels, L]`` for an
    input of ``[B, in_channels, L]``.

    NOTE(review): the missing BatchNorm on ``branch1`` looks unintended, but
    adding it would rename state_dict keys (``branch1.weight`` ->
    ``branch1.0.weight``) and break existing checkpoints.
    """

    def __init__(self, in_channels, out_channels):
        super(Inception1D_BN, self).__init__()
        c_in, c_out = in_channels, out_channels
        self.branch1 = nn.Conv1d(c_in, c_out, kernel_size=1, padding=0)
        self.branch3 = nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=1),
            nn.Conv1d(c_out, c_out, kernel_size=3, padding=1),
            nn.BatchNorm1d(c_out),
        )
        self.branch5 = nn.Sequential(
            nn.Conv1d(c_in, c_out, kernel_size=1),
            nn.Conv1d(c_out, c_out, kernel_size=5, padding=2),
            nn.BatchNorm1d(c_out),
        )
        self.pool_branch = nn.Sequential(
            nn.AvgPool1d(kernel_size=3, stride=1, padding=1),
            nn.Conv1d(c_in, c_out, kernel_size=1),
            nn.BatchNorm1d(c_out),
        )

    def forward(self, x):
        branches = (self.branch1, self.branch3, self.branch5, self.pool_branch)
        return torch.cat([branch(x) for branch in branches], dim=1)

class InceptionNet1D(nn.Module):
    """Stack of 1-D Inception modules followed by a small MLP classifier.

    Layout:
      * five Inception1D modules at full resolution (``input_channels`` -> 32,
        then 32 -> 32 channels);
      * MaxPool/2 -> Inception1D(32 -> 64) -> MaxPool/2 -> Inception1D(64 -> 64);
      * global average pooling over time;
      * Linear(64, 128) -> ReLU -> Linear(128, num_classes).
    Input ``[B, input_channels, L]`` with ``L >= 4`` (two halving pools).
    Output is ``[B, num_classes]`` logits.
    """

    def __init__(self, input_channels=1, num_classes=256):
        super(InceptionNet1D, self).__init__()
        # Inception1D(c, 8) emits 4 * 8 = 32 channels.
        self.incep1 = Inception1D(input_channels, 8)
        self.incep2 = Inception1D(32, 8)
        self.incep3 = Inception1D(32, 8)
        self.incep4 = Inception1D(32, 8)
        self.incep5 = Inception1D(32, 8)

        # Two halvings of the time axis; Inception1D(c, 16) emits 64 channels.
        self.downsample = nn.Sequential(
            nn.MaxPool1d(kernel_size=2, stride=2),
            Inception1D(32, 16),
            nn.MaxPool1d(kernel_size=2, stride=2),
            Inception1D(64, 16),
        )

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        for stage in (self.incep1, self.incep2, self.incep3, self.incep4, self.incep5):
            x = stage(x)
        pooled = self.global_pool(self.downsample(x))   # [B, 64, 1]
        return self.fc(self.flatten(pooled))            # [B, num_classes]

class InceptionNet1D_BN(nn.Module):
    """InceptionNet1D whose five full-resolution modules are Inception1D_BN.

    Layout:
      * five Inception1D_BN modules (``input_channels`` -> 32, then 32 -> 32);
      * MaxPool/2 -> Inception1D(32 -> 64) -> MaxPool/2 -> Inception1D(64 -> 64);
      * global average pooling;
      * Linear(64, 128) -> ReLU -> Linear(128, num_classes).
    Input ``[B, input_channels, L]`` with ``L >= 4``. Output is
    ``[B, num_classes]`` logits.

    NOTE(review): the downsampling stage uses the non-BN Inception1D.
    Switching it to Inception1D_BN would add BatchNorm parameters and break
    loading of existing checkpoints.
    """

    def __init__(self, input_channels=1, num_classes=256):
        super(InceptionNet1D_BN, self).__init__()
        # Inception1D_BN(c, 8) emits 4 * 8 = 32 channels.
        self.incep1 = Inception1D_BN(input_channels, 8)
        self.incep2 = Inception1D_BN(32, 8)
        self.incep3 = Inception1D_BN(32, 8)
        self.incep4 = Inception1D_BN(32, 8)
        self.incep5 = Inception1D_BN(32, 8)

        # Two halvings of the time axis; Inception1D(c, 16) emits 64 channels.
        self.downsample = nn.Sequential(
            nn.MaxPool1d(kernel_size=2, stride=2),
            Inception1D(32, 16),
            nn.MaxPool1d(kernel_size=2, stride=2),
            Inception1D(64, 16),
        )

        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        for stage in (self.incep1, self.incep2, self.incep3, self.incep4, self.incep5):
            x = stage(x)
        pooled = self.global_pool(self.downsample(x))   # [B, 64, 1]
        return self.fc(self.flatten(pooled))            # [B, num_classes]

# Example usage (shape smoke test):
#     model = InceptionNet1D()
#     dummy_input = torch.randn(16, 1, 700)  # batch size = 16
#     output = model(dummy_input)
#     print(output.shape)  # Should be [16, 256]

#MAIN
# Load ASCAD, pick disjoint training/validation profiling traces, then train.
parser = parse_arguments()
args = parser.parse_args()

fpath = 'ASCAD_variable.h5'
(X_profiling, Y_profiling), (X_attack, Y_attack), (Metadata_profiling, Metadata_attack) = load_ascad(fpath, load_metadata=True)

print('X_profiling: ' , X_profiling.shape)
print('Y_profiling: ' , Y_profiling.shape)
print('X_attack: ' , X_attack.shape)
print('Y_attack: ' , Y_attack.shape)
print(np.unique(Y_profiling, return_counts=False))
print(np.unique(Y_attack, return_counts=False))

save_path = args.name
print(save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)

# Maximum number of held-out profiling traces used for the validation loss.
num_val = 10000

if args.sampling == 'random':
    sample_ids = random_sampling(X_profiling, args.num_sample)
    np.save(os.path.join(database_folder_train,'all_ids.npy'), sample_ids)
else:
    # Same selection as slicing X_profiling[:num_sample].
    sample_ids = np.arange(min(args.num_sample, len(X_profiling)))

# Validate on profiling traces that were NOT used for training. Previously
# the first 10000 (sampled) profiling traces were used, which overlapped
# with, or equalled, the training set.
val_ids = np.setdiff1d(np.arange(len(X_profiling)), sample_ids)[:num_val]

train_data = [X_profiling[sample_ids], Y_profiling[sample_ids]]
test_data = [X_profiling[val_ids], Y_profiling[val_ids]]
SCAdataset = SCADataset(train_data)
SCAdataset_val = SCADataset(test_data)
train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(SCAdataset_val, batch_size=args.batch_size, shuffle=False)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

#model = MLPBest()
# Move the model before building the optimizer so it tracks the final parameters.
model = InceptionNet1D_BN().to(device)
optimizer = optim.RMSprop(model.parameters(), lr=0.00001)
criterion = nn.CrossEntropyLoss()

# train() writes checkpoints and losses.csv into database_folder_train. Its
# return value is not assigned back, so `model` keeps the trained network
# regardless of what train() returns.
train(args, database_folder_train, model, train_loader, val_loader, optimizer, criterion, epochs=args.num_epoch)