import os
import os.path
import sys
import h5py
import numpy as np
import matplotlib.pyplot as plt
import ast

import os.path
import sys
import h5py
import math
import gc
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import argparse
#from tensorflow.python.keras.layers import Lambda
#from sklearn.model_selection import train_test_split
#K-center: https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py
# Trace and metadata parameters
from pathlib import Path
from sklearn.cluster import KMeans
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

def set_seeds(seed):
    """Seed the random number generators this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (weight initialisation and ``DataLoader`` shuffling draw from the torch
    generator), so that repeated runs with the same seed are comparable.

    Args:
        seed: Integer seed applied to every generator.
    """
    # Local import: ``random`` is not part of the file-level imports.
    import random

    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so this assignment
    # only influences child processes launched afterwards, not this process.
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seeds(2025)

def parse_arguments():
    """Build the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: the configured parser. The caller is
        responsible for calling ``parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(
        description='Train a profiling model on ASCAD traces.')
    parser.add_argument('--batch_size', type=int, default=256,
                        help='mini-batch size used by the data loaders')
    parser.add_argument('--num_epoch', type=int, default=256,
                        help='number of training epochs')
    parser.add_argument('--num_sample', type=int, default=256,
                        help='number of profiling traces used for training')
    parser.add_argument('--eval_interval', type=int, default=10,
                        help='save a checkpoint every this many epochs')
    parser.add_argument('--sampling', type=str, default='None',
                        help="trace selection strategy; 'random' draws a "
                             "random subset of the profiling traces")
    parser.add_argument('--name', type=str, default='test',
                        help='experiment name')

    return parser

def check_file_exists(file_path):
    """Abort the program when ``file_path`` does not exist.

    The path is normalised first; if nothing exists at that location an
    error message is printed and the process exits with status -1.
    Otherwise the function returns ``None``.
    """
    normalized = os.path.normpath(file_path)
    if not os.path.exists(normalized):
        print("Error: provided file path '%s' does not exist!" % normalized)
        sys.exit(-1)

def load_ascad(ascad_database_file, load_metadata=False):
    """Load profiling and attack traces/labels from an ASCAD HDF5 database.

    Args:
        ascad_database_file: Path to the HDF5 file. The process exits with
            status -1 if the path does not exist or the file cannot be opened.
        load_metadata: When true, also return the metadata entries of both
            groups.

    Returns:
        ``(X_profiling, Y_profiling), (X_attack, Y_attack)`` as NumPy arrays
        (traces are converted to ``int8``), followed by
        ``(profiling_metadata, attack_metadata)`` when ``load_metadata`` is
        true. The metadata entries are returned exactly as h5py hands them
        back (not copied into memory), so in that case the file is
        deliberately left open; otherwise it is closed before returning.
    """
    check_file_exists(ascad_database_file)
    # Open the ASCAD database HDF5 for reading
    try:
        in_file = h5py.File(ascad_database_file, "r")
    except OSError:
        # Only I/O failures are reported here; a bare except would also
        # swallow KeyboardInterrupt and programming errors.
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % ascad_database_file)
        sys.exit(-1)

    keep_open = False
    try:
        # Load profiling traces
        X_profiling = np.array(in_file['Profiling_traces/traces'], dtype=np.int8)
        # Load profiling labels
        Y_profiling = np.array(in_file['Profiling_traces/labels'])
        # Load attacking traces
        X_attack = np.array(in_file['Attack_traces/traces'], dtype=np.int8)
        # Load attacking labels
        Y_attack = np.array(in_file['Attack_traces/labels'])

        if not load_metadata:
            return (X_profiling, Y_profiling), (X_attack, Y_attack)

        metadata = (in_file['Profiling_traces/metadata'], in_file['Attack_traces/metadata'])
        # The returned metadata handles still read from the file.
        keep_open = True
        return (X_profiling, Y_profiling), (X_attack, Y_attack), metadata
    finally:
        # Release the handle on every path that does not hand file-backed
        # objects to the caller (including errors such as a missing key).
        if not keep_open:
            in_file.close()


def load_multi_attack(data_path):
    """Load the ``data`` and ``label`` arrays from a NumPy ``.npz`` archive.

    Args:
        data_path: Path to an archive containing the keys ``'data'`` and
            ``'label'``.

    Returns:
        tuple: ``(data, labels)`` as in-memory NumPy arrays.
    """
    # The context manager closes the underlying archive once both arrays
    # have been read into memory.
    with np.load(data_path) as infile:
        data = infile['data']
        labels = infile['label']

    return data, labels

def random_sampling(data, num_sample, seed=2025):
    """Draw ``num_sample`` distinct indices into ``data`` uniformly at random.

    Args:
        data: Any sized collection; only ``len(data)`` is used.
        num_sample: Number of indices to draw without replacement. NumPy
            raises ``ValueError`` if it exceeds ``len(data)``.
        seed: Seed for the draw. Defaults to 2025, the value that was
            previously hard-coded, so existing callers get the same indices.

    Returns:
        numpy.ndarray: the selected indices.

    Note:
        This reseeds NumPy's *global* generator, so the selection is
        reproducible but later ``np.random`` calls are affected as well.
    """
    np.random.seed(seed)
    rand_ids = np.random.choice(len(data), num_sample, replace=False)
    print(len(rand_ids))
    print('---')
    return rand_ids

def train(args, save_folder, model, train_loader, test_loader, optimizer, criterion, epochs=10):
    """Train ``model`` and record per-epoch training/validation losses.

    Each epoch runs one optimisation pass over ``train_loader`` followed by
    a gradient-free evaluation pass over ``test_loader``. Every batch is
    converted to float, given a channel axis with ``unsqueeze(1)`` and moved
    to the device the model's parameters live on.

    Args:
        args: Object exposing ``eval_interval``; a checkpoint
            ``model_<epoch>.pt`` is written whenever
            ``epoch % args.eval_interval == 0``.
        save_folder: Existing directory receiving the checkpoints, the final
            ``model.pt`` and ``losses.csv``.
        model: The ``torch.nn.Module`` to train (already on its device).
        train_loader: Iterable of ``(traces, targets)`` training batches.
        test_loader: Iterable of ``(traces, targets)`` validation batches.
        optimizer: Optimizer bound to ``model``'s parameters.
        criterion: Loss called as ``criterion(output, target)``.
        epochs: Number of epochs to run.

    Returns:
        The trained ``model`` (left in training mode).
    """
    start_time = time.time()

    # Use the device the model is already on rather than depending on a
    # module-level ``device`` variable being defined by the caller.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses = []
    for epoch in range(epochs):
        train_loss = []
        val_loss = []

        # Training pass: layers such as BatchNorm use batch statistics.
        model.train()
        for batch_idx, (trace_data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            target = target.long().to(device)
            trace_data = trace_data.float().unsqueeze(1).to(device)
            output = model(trace_data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Train Loss: {loss.item()}')
            train_loss.append(loss.item())

        # Validation pass: eval mode keeps validation batches from updating
        # the BatchNorm running statistics, and no_grad avoids building
        # autograd graphs that are never back-propagated.
        model.eval()
        with torch.no_grad():
            for batch_idx, (trace_data, target) in enumerate(test_loader):
                target = target.long().to(device)
                trace_data = trace_data.float().unsqueeze(1).to(device)
                output = model(trace_data)
                loss = criterion(output, target)
                if batch_idx % 100 == 0:
                    print(f'Epoch {epoch+1}/{epochs}, Batch {batch_idx}, Val Loss: {loss.item()}')
                val_loss.append(loss.item())

        if epoch % args.eval_interval == 0:
            save_path = os.path.join(save_folder, 'model_{}.pt'.format(epoch))
            torch.save(model.state_dict(), save_path)

        losses.append({"Epoch": epoch + 1, "Train Loss": np.mean(train_loss), "Validation Loss": np.mean(val_loss)})

    # Hand the model back in training mode, the state callers received
    # before the validation pass switched modes.
    model.train()

    # The final checkpoint name has no placeholder, so no formatting is
    # needed (and ``epoch`` would be undefined when ``epochs`` is 0).
    save_path = os.path.join(save_folder, 'model.pt')
    torch.save(model.state_dict(), save_path)
    df = pd.DataFrame(losses)
    df.to_csv(os.path.join(save_folder, "losses.csv"), index=False)
    print("---Training done in %s seconds ---" % (time.time() - start_time))

    return model
# class to represent dataset
class SCADataset:
    """Map-style dataset pairing traces with their labels.

    ``data`` is an indexable pair: ``data[0]`` holds the traces and
    ``data[1]`` the labels. The dataset length is taken from the first
    dimension of the traces.
    """

    def __init__(self, data):
        traces, labels = data[0], data[1]
        self.x = traces
        self.y = labels
        self.n_samples = traces.shape[0]

    def __len__(self):
        """Return the number of samples, so ``len(dataset)`` works."""
        return self.n_samples

    def __getitem__(self, index):
        """Return the ``(trace, label)`` pair stored at ``index``."""
        sample = self.x[index]
        label = self.y[index]
        return sample, label


import torch
import torch.nn as nn
import torch.nn.functional as F

import torch
import torch.nn as nn

class EFCNN(nn.Module):
    """1-D convolutional classifier built from conv/BatchNorm/SELU stages.

    The network is an encoder stage (1 -> ``filters`` channels), followed by
    ``num_blocks - 1`` backbone stages (``filters`` -> ``filters``), followed
    by a flatten + linear head producing ``num_classes`` raw logits. Every
    stage ends with a stride-2 max-pool, so the head expects
    ``filters * (input_dim // 2 ** num_blocks)`` features. Inputs are
    expected as ``(batch, 1, input_dim)``.
    """

    def __init__(self, input_dim=1400, num_classes=256, num_blocks=3, filters=128):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.filters = filters

        # Encoder stage: lifts the single input channel to `filters` channels.
        self.encoder = nn.Sequential(*self._stage_layers(1, filters))

        # Backbone: the encoder already counts as one of the `num_blocks`
        # pooling stages, hence one block fewer here.
        backbone = [self._make_block(filters) for _ in range(num_blocks - 1)]
        self.blocks = nn.Sequential(*backbone)

        # Head: every stage halves the length, so 2 ** num_blocks in total.
        pooled_length = input_dim // (2 ** num_blocks)
        head_inputs = filters * pooled_length
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(head_inputs, num_classes),  # raw logits, no softmax
        )

    @staticmethod
    def _stage_layers(in_channels, out_channels):
        """Return the layer list of one conv-conv-pool stage, in order."""
        return [
            nn.Conv1d(in_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_channels),
            nn.SELU(),
            nn.Conv1d(out_channels, out_channels, kernel_size=5, padding=2),
            nn.BatchNorm1d(out_channels),
            nn.SELU(),
            nn.MaxPool1d(kernel_size=2, stride=2),
        ]

    def _make_block(self, filters):
        """Build one backbone stage that keeps the channel count at ``filters``."""
        return nn.Sequential(*self._stage_layers(filters, filters))

    def forward(self, x):
        """Run encoder, backbone and classifier head; returns the logits."""
        features = self.encoder(x)
        features = self.blocks(features)
        return self.classifier(features)



'''
# Example usage
if __name__ == "__main__":
    model = EFCNN(input_dim=700)
    dummy_input = torch.randn(16, 1, 700)  # batch size = 16
    output = model(dummy_input)
    print(output.shape)  # Should be [16, 256]
'''

# MAIN: script entry point. Parses the command line, loads the ASCAD traces,
# optionally sub-samples the profiling set, then trains an EFCNN and writes
# checkpoints and the loss log under multi_attack_trained_models/<name>.
parser = parse_arguments()
args = parser.parse_args()

# The database path is hard-coded and resolved relative to the working directory.
fpath = 'ASCAD_variable.h5'
(X_profiling, Y_profiling), (X_attack, Y_attack), (Metadata_profiling, Metadata_attack) = load_ascad(fpath, load_metadata=True)

# Report array shapes and the distinct label values of both splits.
print('X_profiling: ' , X_profiling.shape)
print('Y_profiling: ' , Y_profiling.shape)
print('X_attack: ' , X_attack.shape)
print('Y_attack: ' , Y_attack.shape)
print(np.unique(Y_profiling, return_counts=False))
print(np.unique(Y_attack, return_counts=False))

# Output directory for this experiment; created if it does not exist yet.
save_path = '{}'.format(args.name)
print(save_path)
database_folder_train = os.path.join('multi_attack_trained_models', save_path)
Path(database_folder_train).mkdir(parents=True, exist_ok=True)

# With --sampling random, keep only a random subset of the profiling traces
# and store the chosen indices next to the model for later reference.
if args.sampling == 'random':
    sample_ids = random_sampling(X_profiling, args.num_sample)
    np.save(os.path.join(database_folder_train,'all_ids.npy'), sample_ids)
    X_profiling = X_profiling[sample_ids]
    Y_profiling = Y_profiling[sample_ids]

# Training uses the first num_sample profiling traces.
# NOTE(review): the validation split below is the first 10000 profiling
# traces, i.e. it starts at the same index as the training slice and therefore
# overlaps it (after random sampling it is drawn from the same subset). The
# reported validation loss is thus measured on traces also used for training;
# X_attack / Y_attack are loaded above but not used here. Confirm this is intended.
train_data = [X_profiling[:args.num_sample], Y_profiling[:args.num_sample]]
test_data = [X_profiling[:10000], Y_profiling[:10000]]
SCAdataset = SCADataset(train_data)
SCAdataset_val = SCADataset(test_data)
# Only the training batches are shuffled.
train_loader = DataLoader(SCAdataset, batch_size=args.batch_size, shuffle=True)
val_loader = DataLoader(SCAdataset_val, batch_size=args.batch_size, shuffle=False)

# Alternative architectures tried previously (not defined in this part of the file):
#model = MLPBest()
#model = InceptionNet1D_BN()
model = EFCNN(input_dim=1400, num_classes=256, num_blocks=3, filters=128)
# Define optimizer and loss function
#optimizer = optim.RMSprop(model.parameters(), lr=0.00001)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Use the GPU when one is available, move the model there, then train.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
model = model.to(device)
model = train(args, database_folder_train, model, train_loader, val_loader, optimizer, criterion, epochs=args.num_epoch)
