import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
from sklearn_extra.cluster import CLARA
from pathlib import Path
import os
import argparse
from scipy.spatial import distance
from scipy import stats

def normalize(timeseries):
    """Min-max scale `timeseries` so its values span [0, 1].

    Args:
        timeseries: array-like exposing ``.min()`` / ``.max()`` and
            supporting arithmetic (e.g. a NumPy array).

    Returns:
        ``(timeseries - min) / (max - min)``. When every value is identical
        the range is zero; in that case zeros are returned instead of the
        NaNs that a 0/0 division would produce.
    """
    lowest = timeseries.min()
    span = timeseries.max() - lowest
    # Only guard the scalar case; non-scalar reductions (e.g. per-column
    # results of a 2-D container) keep the plain division semantics.
    if np.ndim(span) == 0 and span == 0:
        # Dividing by 1 keeps the same result dtype as the normal path
        # while mapping a flat series to all zeros.
        span = 1
    return (timeseries - lowest) / span

def z_norm(timeseries):
    """Standardize `timeseries` to z-scores via ``scipy.stats.zscore``.

    Args:
        timeseries: array-like of numeric values.

    Returns:
        The z-scored values, computed along axis 0 with the population
        standard deviation (ddof=0).
        NOTE(review): a zero-variance input presumably yields NaN here -
        confirm against the scipy documentation.
    """
    # axis=0 and ddof=0 are scipy's defaults; they are spelled out so the
    # normalization convention is visible at the call site.
    standardized = stats.zscore(timeseries, axis=0, ddof=0)
    return standardized

def normalize_data_per_trace(data):
    """Min-max normalize every entry along the first axis of `data`, in place.

    Prints ``data.shape`` first, then overwrites each ``data[i]`` with
    ``normalize(data[i])``.

    Args:
        data: indexable container with a ``shape`` attribute
            (presumably a 2-D NumPy array of traces - confirm with caller).

    Returns:
        The same `data` object, modified in place.
        NOTE(review): because results are written back into `data`, they take
        on its existing dtype - an integer array would truncate the
        normalized values; verify callers pass floating-point data.
    """
    print(data.shape)
    for idx, trace in enumerate(data):
        data[idx] = normalize(trace)
    return data

def z_normalize_data_per_trace(data):
    """Z-normalize every entry along the first axis of `data`, in place.

    Prints ``data.shape`` first, then overwrites each ``data[i]`` with
    ``z_norm(data[i])``.

    Args:
        data: indexable container with a ``shape`` attribute
            (presumably a 2-D NumPy array of traces - confirm with caller).

    Returns:
        The same `data` object, modified in place.
        NOTE(review): because results are written back into `data`, they take
        on its existing dtype - an integer array would truncate the
        z-scores; verify callers pass floating-point data.
    """
    print(data.shape)
    for idx, trace in enumerate(data):
        data[idx] = z_norm(trace)
    return data

def parse_arguments():
    """Build the command-line parser for the CLARA clustering script.

    Returns:
        argparse.ArgumentParser: the configured parser. Note that this is
        the parser itself, not the parsed namespace; callers must invoke
        ``.parse_args()`` on it.
    """
    parser = argparse.ArgumentParser(
        description='Cluster traces with CLARA and save the results.')
    # Both paths are used unconditionally by the script, so a missing value
    # is reported up front by argparse instead of surfacing later as an
    # obscure error on a None path.
    parser.add_argument('--save_folder', type=str, required=True,
                        help='Directory the clustering results are written '
                             'to (created if missing).')
    parser.add_argument('--data_path', type=str, required=True,
                        help="Path to the archive loaded with numpy; its "
                             "'data' entry holds the traces.")
    parser.add_argument('--num_sample', type=int,
                        help='Use only the first N traces '
                             '(default: all traces).')
    parser.add_argument('--num_cluster', type=int,
                        help='Number of clusters passed to CLARA.')
    parser.add_argument('--normalize', type=int,
                        help='1 = per-trace min-max normalization, '
                             '2 = per-trace z-normalization; omitted or any '
                             'other value = no normalization.')

    return parser

# ---------------------------------------------------------------------------
# Script body: load traces, optionally normalize them, cluster with CLARA and
# save the cluster centers, medoid indices, labels and inertia.
# ---------------------------------------------------------------------------
parser = parse_arguments()
args = parser.parse_args()

# np.load hands back an archive object for .npz files that keeps the file
# open; close it as soon as the array has been read. The getattr guard keeps
# inputs working for which np.load returns an object without close().
loaded = np.load(args.data_path)
try:
    # A num_sample of None slices the full array.
    data = loaded['data'][:args.num_sample]
finally:
    close_archive = getattr(loaded, 'close', None)
    if close_archive is not None:
        close_archive()
print(data.shape)

# The per-trace normalizers write their results back into `data`, which casts
# them to the array's dtype. Promote integer data to float first so the
# normalized values are not truncated.
if args.normalize in (1, 2) and np.issubdtype(data.dtype, np.integer):
    data = data.astype(np.float64)

if args.normalize == 1:
    print('Norm 1')
    data = normalize_data_per_trace(data)
elif args.normalize == 2:
    print('Norm 2')
    data = z_normalize_data_per_trace(data)
print(data.shape)

# Create the output directory before the expensive clustering step so that an
# unusable save location fails immediately rather than after the computation.
Path(args.save_folder).mkdir(parents=True, exist_ok=True)

print('Clustering in progress...')
# Not used by the active CLARA call below (which relies on CLARA's default
# metric); kept for the variant that passes metric=dist_metric.
dist_metric = distance.euclidean
start_time = time.time()
clara = CLARA(n_clusters=args.num_cluster, random_state=0).fit(data)
print('------------Took: {}s --------------'.format(time.time() - start_time))

np.save(os.path.join(args.save_folder, 'clara_cluster_centers.npy'), clara.cluster_centers_)
np.save(os.path.join(args.save_folder, 'clara_medoid_indices.npy'), clara.medoid_indices_)
np.save(os.path.join(args.save_folder, 'clara_labels.npy'), clara.labels_)
np.save(os.path.join(args.save_folder, 'inertia.npy'), clara.inertia_)