import os
import os.path
import sys
import h5py
import numpy as np
import matplotlib.pyplot as plt
import ast
import argparse
import tensorflow as tf
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import Flatten, Dense, Input, Conv1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D, AveragePooling1D, BatchNormalization, Activation, Add, add
from tensorflow.keras import backend as K
from tensorflow.keras.applications.imagenet_utils import decode_predictions
from tensorflow.keras.applications.imagenet_utils import preprocess_input
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import load_model
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
from sklearn_extra.cluster import CLARA
from pathlib import Path
import os
import argparse
from scipy.spatial import distance
from scipy import stats

def parse_arguments():
    """Build the command-line parser for the CLARA clustering script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--save_folder``,
        ``--data_path``, ``--num_sample``, ``--num_cluster`` and
        ``--normalize``.
    """
    parser = argparse.ArgumentParser(description='')
    # save_folder, data_path and num_cluster are consumed unconditionally by the
    # script body; requiring them makes argparse fail fast with a usage message
    # instead of an opaque TypeError after the database has been loaded.
    parser.add_argument('--save_folder', type=str, required=True,
                        help='directory where the CLARA result .npy files are written')
    parser.add_argument('--data_path', type=str, required=True,
                        help='path to the ASCAD HDF5 database')
    parser.add_argument('--num_sample', type=int,
                        help='number of samples (only referenced by commented-out code in this script)')
    parser.add_argument('--num_cluster', type=int, required=True,
                        help='number of CLARA clusters')
    parser.add_argument('--normalize', type=int,
                        help='normalization flag (not read by the visible script)')
    return parser

# Parse the command line once at import time; `args` drives the script body.
parser = parse_arguments()
args = parser.parse_args(sys.argv[1:])


def check_file_exists(file_path):
    """Abort the program if ``file_path`` does not exist.

    The path is normalised with ``os.path.normpath`` before the check. On
    failure an error message is written to stderr and the process exits via
    ``sys.exit(-1)``; on success the function returns ``None``.

    Args:
        file_path: path (str or path-like) of the file to check.
    """
    file_path = os.path.normpath(file_path)
    if not os.path.exists(file_path):
        # Diagnostics go to stderr so they are not mixed into normal output.
        print("Error: provided file path '%s' does not exist!" % file_path, file=sys.stderr)
        sys.exit(-1)

def load_sca_model(model_file):
    """Load a Keras model from ``model_file`` or abort the program.

    Args:
        model_file: path to a saved Keras model.

    Returns:
        The model object returned by ``tensorflow.keras.models.load_model``.

    Exits via ``sys.exit(-1)`` (after printing to stderr) when the file does
    not exist or cannot be loaded.
    """
    check_file_exists(model_file)
    try:
        model = load_model(model_file)
    except Exception as err:
        # Catch Exception rather than using a bare ``except`` so that Ctrl-C
        # (KeyboardInterrupt) still propagates, and report the underlying
        # cause instead of silently discarding it.
        print("Error: can't load Keras model file '%s'" % model_file, file=sys.stderr)
        print("Reason: %s" % err, file=sys.stderr)
        sys.exit(-1)
    return model

def load_ascad(ascad_database_file, load_metadata=False):
    """Load the profiling and attack sets from an ASCAD-style HDF5 database.

    Reads the ``Profiling_traces`` and ``Attack_traces`` groups: ``traces``
    are converted to ``np.int8`` numpy arrays and ``labels`` to numpy arrays.

    Args:
        ascad_database_file: path to the HDF5 database.
        load_metadata: when truthy, also return the two ``metadata`` datasets.

    Returns:
        ``((X_profiling, Y_profiling), (X_attack, Y_attack))``, or, when
        ``load_metadata`` is truthy, that pair followed by
        ``(profiling_metadata, attack_metadata)``. The metadata entries are
        h5py dataset objects (not numpy copies), so in that case the HDF5 file
        is deliberately left open; otherwise it is closed before returning.

    Exits via ``sys.exit(-1)`` (after printing to stderr) when the file does
    not exist or cannot be opened as HDF5.
    """
    check_file_exists(ascad_database_file)
    # Open the ASCAD database HDF5 for reading
    try:
        in_file = h5py.File(ascad_database_file, "r")
    except OSError as err:
        # h5py reports unreadable / non-HDF5 files as OSError; a bare
        # ``except`` would also have swallowed KeyboardInterrupt.
        print("Error: can't open HDF5 file '%s' for reading (it might be malformed) ..." % ascad_database_file, file=sys.stderr)
        print("Reason: %s" % err, file=sys.stderr)
        sys.exit(-1)
    try:
        # np.array(...) copies each dataset into memory.
        X_profiling = np.array(in_file['Profiling_traces/traces'], dtype=np.int8)
        Y_profiling = np.array(in_file['Profiling_traces/labels'])
        X_attack = np.array(in_file['Attack_traces/traces'], dtype=np.int8)
        Y_attack = np.array(in_file['Attack_traces/labels'])
        if load_metadata:
            metadata = (in_file['Profiling_traces/metadata'], in_file['Attack_traces/metadata'])
    except Exception:
        # Do not leak the file handle if a dataset is missing or unreadable.
        in_file.close()
        raise
    if load_metadata:
        # The metadata datasets read from the file on access, so keep it open.
        return (X_profiling, Y_profiling), (X_attack, Y_attack), metadata
    # Everything returned is an in-memory copy; release the file handle.
    in_file.close()
    return (X_profiling, Y_profiling), (X_attack, Y_attack)


import matplotlib.pyplot as plt

# Experiment constants. They are not read by the code visible here, but are
# kept as module-level names in case other code relies on them.
num_traces = 2000
target_byte = 2
multilabel = 0
simulated_key = 0
save_file = ""
(X_profiling, Y_profiling), (X_attack, Y_attack), (Metadata_profiling, Metadata_attack) = load_ascad(args.data_path, load_metadata=True)

# CLARA is fitted directly on the raw int8 profiling traces.
data = X_profiling

# Create the output folder *before* clustering so that an invalid or
# unwritable --save_folder fails immediately instead of after the
# (potentially long) CLARA run.
Path(args.save_folder).mkdir(parents=True, exist_ok=True)

print('Clustering in progress...')
# Not passed to CLARA below (CLARA runs with its default metric); kept as a
# module-level name for a custom-metric variant.
# NOTE(review): args.num_sample and args.normalize are parsed but unused here.
dist_metric = distance.euclidean
# perf_counter is monotonic, so the elapsed time is unaffected by clock changes.
start_time = time.perf_counter()
clara = CLARA(n_clusters=args.num_cluster, random_state=0).fit(data)
print('------------Took: {}s --------------'.format(time.perf_counter() - start_time))

# Persist the fitted CLARA attributes, one .npy file per attribute.
np.save(os.path.join(args.save_folder, 'clara_cluster_centers.npy'), clara.cluster_centers_)
np.save(os.path.join(args.save_folder, 'clara_medoid_indices.npy'), clara.medoid_indices_)
np.save(os.path.join(args.save_folder, 'clara_labels.npy'), clara.labels_)
np.save(os.path.join(args.save_folder, 'inertia.npy'), clara.inertia_)