import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import copy
import time
from tqdm import tqdm as tqdm
from scipy.spatial import distance
from sklearn_extra.cluster import CLARA
from pathlib import Path
import os
import argparse
from scipy.spatial import distance
from scipy import stats
import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

def normalize(timeseries):
    """Min-max scale ``timeseries`` into the range [0, 1].

    Parameters
    ----------
    timeseries : array-like supporting ``.min()``, ``.max()`` and arithmetic
        (e.g. a NumPy array). The minimum and maximum are taken over all
        elements.

    Returns
    -------
    Same container type as the input, with values ``(x - min) / (max - min)``.
    A constant series (``max == min``) is mapped to all zeros instead of
    the NaNs that ``0 / 0`` would produce.
    """
    lo = timeseries.min()
    span = timeseries.max() - lo
    if span == 0:
        # Every element equals `lo`, so the numerator is already all zeros;
        # dividing by 1 keeps the float result without a 0/0 NaN.
        span = 1
    return (timeseries - lo) / span

def z_norm(timeseries):
    """Standardize ``timeseries`` to zero mean and unit standard deviation.

    Delegates to ``scipy.stats.zscore`` with its defaults (population std,
    ``ddof=0``). A constant, non-empty series has zero spread, for which
    ``zscore`` would return all NaN. This function returns all zeros for
    that case instead.

    Parameters
    ----------
    timeseries : array-like of numbers (intended to be a single 1-D trace).

    Returns
    -------
    numpy.ndarray of z-scores, same shape as the input.
    """
    values = np.asarray(timeseries)
    # np.ptp (max - min) is exactly 0 only for a truly constant series, unlike
    # a std check, which can be tiny-but-nonzero because of float rounding.
    if values.size and np.ptp(values) == 0:
        return np.zeros_like(values, dtype=float)
    return stats.zscore(timeseries)

def normalize_data_per_trace(data):
    """Min-max normalize each trace of ``data`` independently to [0, 1].

    Each element along the first axis (``data[i]``) is passed through
    ``normalize``. The shape of ``data`` is printed as a progress/debug aid.

    Parameters
    ----------
    data : numpy.ndarray whose first axis indexes traces.

    Returns
    -------
    numpy.ndarray with every trace normalized. Floating-point input is
    modified in place and returned. Non-float input is first converted to a
    float copy, so the fractional results are not truncated.
    """
    print(data.shape)
    # Assigning float results back into an integer/bool array would silently
    # truncate them (e.g. 0.5 -> 0), so promote such input to float first.
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    for i, trace in enumerate(data):
        data[i] = normalize(trace)

    return data

def z_normalize_data_per_trace(data):
    """Z-score normalize each trace of ``data`` independently.

    Each element along the first axis (``data[i]``) is passed through
    ``z_norm``. The shape of ``data`` is printed as a progress/debug aid.

    Parameters
    ----------
    data : numpy.ndarray whose first axis indexes traces.

    Returns
    -------
    numpy.ndarray with every trace standardized. Floating-point input is
    modified in place and returned. Non-float input is first converted to a
    float copy, so the z-scores are not truncated to integers.
    """
    print(data.shape)
    # Z-scores are fractional; writing them into an integer/bool array would
    # silently truncate them, so promote such input to float first.
    if not np.issubdtype(data.dtype, np.floating):
        data = data.astype(float)
    for i, trace in enumerate(data):
        data[i] = z_norm(trace)

    return data

def parse_arguments():
    """Build the command-line parser for this clustering script.

    ``--data_path``, ``--num_cluster`` and ``--save_folder`` are required.
    The script cannot run without them, and marking them required makes a
    missing value fail immediately. Otherwise a missing ``--save_folder``
    would only fail after clustering had already finished.

    Returns
    -------
    argparse.ArgumentParser (not yet parsed).
    """
    parser = argparse.ArgumentParser(
        description='Cluster traces with KMeans and save the index of the '
                    'trace nearest to each cluster centroid.')
    parser.add_argument('--save_folder', type=str, required=True,
                        help='Directory to write kmeans_medoid_indices.npy '
                             'into (created if missing).')
    parser.add_argument('--data_path', type=str, required=True,
                        help="Path to a NumPy file with a 'data' entry "
                             'holding the traces.')
    parser.add_argument('--num_sample', type=int,
                        help='Use only the first N traces (default: all).')
    parser.add_argument('--num_cluster', type=int, required=True,
                        help='Number of KMeans clusters.')
    parser.add_argument('--normalize', type=int,
                        help='Per-trace normalization: 1 = min-max, '
                             '2 = z-score, anything else/omitted = none.')

    return parser

parser = parse_arguments()
args = parser.parse_args()

# Create the output folder before doing any heavy work, so an unusable
# --save_folder fails immediately instead of after the clustering finishes.
save_dir = Path(args.save_folder)
save_dir.mkdir(parents=True, exist_ok=True)

# Read the 'data' entry and keep the first --num_sample traces (all of them
# when --num_sample is omitted, because [:None] is a full slice). For .npz
# input np.load returns an NpzFile that holds an open file handle, so close
# it once the array has been read.
loaded = np.load(args.data_path)
try:
    data = loaded['data'][:args.num_sample]
finally:
    if hasattr(loaded, 'close'):
        loaded.close()
print(data.shape)

# --normalize: 1 -> per-trace min-max, 2 -> per-trace z-score,
# any other value (or omitted) -> raw values.
if args.normalize == 1:
    print('Norm 1')
    data = normalize_data_per_trace(data)
elif args.normalize == 2:
    print('Norm 2')
    data = z_normalize_data_per_trace(data)
print(data.shape)

# NOTE(review): KMeans.fit expects a 2-D (n_samples, n_features) array;
# presumably 'data' is (traces, time_steps). Confirm against the data source.
print('Clustering in progress...')
start_time = time.perf_counter()
kmeans = KMeans(n_clusters=args.num_cluster, random_state=0).fit(data)
elapsed = time.perf_counter() - start_time
print('------------Took: {}s --------------'.format(elapsed))

# For each centroid, find the index of the nearest actual trace. This is a
# centroid-nearest representative rather than a true medoid. Two centroids
# can share the same nearest trace, so the saved indices may repeat.
closest_indices, _ = pairwise_distances_argmin_min(
    kmeans.cluster_centers_, data)
np.save(os.path.join(args.save_folder, 'kmeans_medoid_indices.npy'),
        closest_indices)