#!/usr/bin/env python3.7

import shutil
import numpy as np
import os
from sklearn.cluster import KMeans
import pickle
import argparse
import yaml
import json


def read_file(filename):
    """Load every array stored back-to-back in ``filename`` and join them.

    The file is treated as a sequence of consecutive ``np.load`` records.
    Records are read until the end of the file is reached or a record fails
    to load; in the latter case the error is printed and the arrays read so
    far are kept (best effort).

    Args:
        filename: Path of the file to read.

    Returns:
        The loaded arrays concatenated along their first axis.

    Raises:
        ValueError: If not a single array could be read from the file.
    """
    datapoints = []
    with open(filename, 'rb') as file:
        # The file is only read here, so its size is looked up once instead
        # of issuing a stat call on every loop iteration.
        file_size = os.path.getsize(filename)
        while file.tell() < file_size:
            try:
                datapoint = np.load(file)
            except Exception as ex:
                # Deliberate best effort: report the unreadable record and
                # keep whatever was successfully loaded before it.
                print(ex)
                break
            datapoints.append(datapoint)

    if not datapoints:
        # np.concatenate([]) would raise a ValueError without saying which
        # file was at fault; raise the same type with a useful message.
        raise ValueError(f"no arrays could be read from {filename!r}")

    return np.concatenate(datapoints)


def normalize(input):
    """Standardize ``input`` using statistics taken along axis 0.

    Args:
        input: Two-dimensional array-like of values.

    Returns:
        A tuple ``(normalized_input, mean, std)`` where ``mean`` and ``std``
        are the axis-0 mean and standard deviation of ``input`` and
        ``normalized_input`` is ``(input - mean) / (std + 1e-9)``.
    """
    mean = np.mean(input, axis=0)
    std = np.std(input, axis=0)

    # Lift both statistics to row vectors so they broadcast over every row.
    centered = input - mean[None, :]
    # The small epsilon keeps the division finite when the deviation is zero.
    scale = std[None, :] + 1e-9

    return centered / scale, mean, std


def get_input(args):
    """Load and stack the data of every file in ``args.data_dir``.

    Each directory entry is read with ``read_file`` and the results are
    stacked vertically. The last column of the stacked array is then dropped
    (the original code labels it "format").

    Args:
        args: Object exposing a ``data_dir`` attribute, the directory whose
            entries are loaded.

    Returns:
        The stacked two-dimensional array without its last column.

    Raises:
        ValueError: If ``args.data_dir`` contains no entries.
    """
    raw_inputs = []
    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the stacked result (and any later truncation of it)
    # reproducible between runs and machines.
    for filename in sorted(os.listdir(args.data_dir)):
        full_name = os.path.join(args.data_dir, filename)
        data = read_file(full_name)
        raw_inputs.append(data)
        print('processed', full_name, data.shape)

    if not raw_inputs:
        # np.vstack([]) would raise a ValueError without naming the
        # directory; raise the same type with a useful message.
        raise ValueError(f"no data files found in {args.data_dir!r}")

    raw_inputs = np.vstack(raw_inputs)
    raw_inputs = raw_inputs[:, :-1]  # ignore format (last column)
    print('total', raw_inputs.shape)
    return raw_inputs


def cluster(args, clusters, max_samples=300000):
    """Fit k-means on the normalized input data and save the results.

    The data returned by ``get_input`` is truncated to its first
    ``max_samples`` rows, normalized with ``normalize`` and clustered. Four
    files are written into ``args.saving_dir``:

    * ``clusters.pkl``  - the pickled fitted ``KMeans`` object
    * ``mean.txt``      - the mean returned by ``normalize``
    * ``std.txt``       - the standard deviation returned by ``normalize``
    * ``clusters.json`` - cluster centers keyed by cluster index

    Args:
        args: Object exposing ``data_dir`` (read by ``get_input``) and
            ``saving_dir`` (existing directory the outputs are written to).
        clusters: Number of clusters to fit.
        max_samples: Maximum number of leading rows used for fitting.

    Returns:
        The fitted ``KMeans`` instance.
    """
    vector = get_input(args)
    vector = vector[:max_samples]
    print(vector.shape)
    vector, mean, std = normalize(vector)

    kmeans = KMeans(n_clusters=clusters)
    kmeans.fit(vector)

    # Build paths with os.path.join so a saving_dir given without a trailing
    # separator still places the files inside the directory.
    with open(os.path.join(args.saving_dir, "clusters.pkl"), 'wb') as f:
        pickle.dump(kmeans, f)

    np.savetxt(os.path.join(args.saving_dir, "mean.txt"), mean)
    np.savetxt(os.path.join(args.saving_dir, "std.txt"), std)

    # Save the cluster centers as a JSON file (index -> coordinates).
    cluster_points = {
        i: point.tolist()
        for i, point in enumerate(kmeans.cluster_centers_)
    }
    with open(os.path.join(args.saving_dir, 'clusters.json'), 'w') as f:
        json.dump(cluster_points, f)

    return kmeans


def check_dir(args):
    """Recreate ``args.saving_dir`` as an empty directory.

    If the directory already exists it is removed together with all of its
    contents before being created again.

    Args:
        args: Object exposing a ``saving_dir`` attribute.
    """
    target = args.saving_dir

    # Wipe any previous output so stale files cannot linger.
    if os.path.isdir(target):
        shutil.rmtree(target)

    os.makedirs(target)


if __name__ == "__main__":
    # Command-line interface: where to read data from, where to write the
    # results to, and which settings file supplies the cluster count.
    parser = argparse.ArgumentParser(description="Run k-means")
    parser.add_argument("--data-dir", default="./data_points/")
    parser.add_argument("--saving-dir", default='./weights/kmeans/')
    parser.add_argument("--yaml-settings", default='./src/settings.yml')
    args = parser.parse_args()

    with open(args.yaml_settings, 'r') as fh:
        yaml_settings = yaml.safe_load(fh)

    # The number of clusters comes from the first experiment's ABR config.
    first_experiment = yaml_settings["experiments"][0]
    abr_config = first_experiment['fingerprint']['abr_config']
    clusters = int(abr_config['num_of_contexts'])

    # Start from an empty output directory, then fit and save the model.
    check_dir(args)
    kmeans = cluster(args, clusters)
