import argparse
import json
import os
from sys import meta_path
import numpy as np


def read_file(filename):
    """Load and concatenate every array stored back-to-back in *filename*.

    The file is read as a sequence of arrays written one after another
    (e.g. repeated ``np.save`` calls on the same handle). Arrays are loaded
    in order until the end of the file and joined along axis 0.

    Args:
        filename: Path to the file to read.

    Returns:
        One ``np.ndarray`` formed by concatenating all stored arrays along
        the first axis.

    Raises:
        ValueError: If the file contains no arrays (zero bytes).
    """
    # The size is fixed while we read; query it once instead of per iteration.
    file_size = os.path.getsize(filename)
    datapoints = []
    with open(filename, 'rb') as file:
        while file.tell() < file_size:
            datapoints.append(np.load(file))

    if not datapoints:
        # np.concatenate([]) would raise an opaque error; name the file instead.
        raise ValueError(f'{filename!r} contains no arrays')
    return np.concatenate(datapoints)


def normalize(input):
    """Return the feature count and per-column mean/std of a sample matrix.

    Despite the name, *input* is not rescaled; only the statistics needed
    to normalize it later are computed.

    Args:
        input: Array whose axis 0 indexes samples; ``input.shape[1]`` is
            reported as the feature count.

    Returns:
        Tuple ``(num_features, mean, std)``. ``mean`` and ``std`` are reduced
        over axis 0 (``np.std`` default, i.e. population std).
    """
    column_mean = np.mean(input, axis=0)
    column_std = np.std(input, axis=0)
    num_features = input.shape[1]
    return num_features, column_mean, column_std


def save(meta_path, obs_size, obs_mean, obs_std):
    """Serialize observation statistics to a JSON file.

    Args:
        meta_path: Destination path; opened in ``'w'`` mode, so any existing
            file is overwritten.
        obs_size: Feature count, stored under ``obs_size``.
        obs_mean: Per-feature means; converted with ``.tolist()``.
        obs_std: Per-feature standard deviations; converted with ``.tolist()``.
    """
    payload = dict(
        obs_size=obs_size,
        obs_mean=obs_mean.tolist(),
        obs_std=obs_std.tolist(),
    )
    with open(meta_path, 'w') as out:
        json.dump(payload, out)


def normalize_all(args, horizon=5, num_formats=10, num_clients=10):
    """Compute and save normalization metadata for each horizon step.

    For each step ``i`` in ``range(horizon)``:
      1. Load ``hidden-{client_idx}_input{i+1}.npy`` from ``args.input``
         for every client in ``range(num_clients)``.
      2. Stack them along axis 0.
      3. Keep the first ``num_formats`` entries of axis 1.
      4. Flatten each sample to a vector.
      5. Write its size/mean/std to ``cpp-meta-hidden-{i}.json`` in
         ``args.saving_dir``.

    Args:
        args: Parsed CLI namespace; ``args.input`` and ``args.saving_dir``
            are read.
        horizon: Number of steps. Input files are numbered ``1..horizon``;
            output files are numbered ``0..horizon-1``.
        num_formats: Number of leading entries kept along axis 1.
        num_clients: Number of per-client input files per step (was
            previously hard-coded to 10).
    """
    for i in range(horizon):
        all_data = [
            read_file(os.path.join(args.input,
                                   f'hidden-{client_idx}_input{i + 1}.npy'))
            for client_idx in range(num_clients)
        ]
        data = np.concatenate(all_data)
        # The slice needs >= 3 dims; presumably (samples, formats, features).
        # TODO confirm against the producer of these files.
        data = data[:, :num_formats, :]
        data = np.reshape(data, (data.shape[0], -1))
        obs_size, obs_mean, obs_std = normalize(data)
        # os.path.join adds a separator only when needed, so --saving-dir
        # works with or without a trailing slash.
        meta_file = os.path.join(args.saving_dir, f'cpp-meta-hidden-{i}.json')
        save(meta_file, obs_size, obs_mean, obs_std)

    # Disabled: normalization of the MPC buffer-format file.
    # # normalize mpc
    # data = read_file(args.buffer_format_file)
    # obs_size, obs_mean, obs_std = normalize(data)
    # meta_path = os.path.join(args.saving_dir, 'cpp-meta-mpc.json')
    # save(meta_path, obs_size, obs_mean, obs_std)


if __name__ == "__main__":
    # CLI entry point: compute per-step normalization metadata (mean/std).
    parser = argparse.ArgumentParser(
        description="Compute per-feature mean/std normalization metadata "
                    "(cpp-meta-hidden-*.json) from recorded datapoints"
    )
    parser.add_argument(
        "--input",
        default="./data_points",
        help="directory containing hidden-<client>_input<step>.npy files",
    )
    parser.add_argument(
        "--buffer-format-file",
        default="./data_points/hidden_mpc.npy",
        help="datapoint file for MPC statistics (currently unused)",
    )
    parser.add_argument(
        "--saving-dir",
        default='./data_points/',
        help="directory where the cpp-meta-*.json files are written",
    )
    args = parser.parse_args()

    normalize_all(args)
