#!/usr/bin/env python3
"""Compute the discretization bins by randomly sampling files."""

import os
import random
import sys
import numpy as np
from load_data import discretize, load_single_file

# The number of values pixels are discretized to
DISCRETE_VALUES = 4

# Stop loading files once at least this many observations have been collected
MAX_OBSERVATIONS = 100_000


def parse_index(lines):
    """Return the sorted file paths listed in an index file.

    Lines containing an ``f:`` marker contribute the (whitespace-stripped)
    text after the first marker; all other lines are ignored. Splitting only
    on the first marker keeps paths that themselves contain ``f:`` intact.
    """
    return sorted(line.partition('f:')[2].strip() for line in lines if 'f:' in line)


def resolve_data_path(path):
    """Map an indexed path to its downscaled copy under ``~/ais``.

    Only the last two components of ``path`` (``<config>/<file_id>``) are used.
    ``_downscaled`` is appended to the config directory name unless it already
    contains ``downscaled``.

    Raises:
        ValueError: if ``path`` has fewer than two ``/``-separated components.
    """
    parts = path.split('/')
    if len(parts) < 2:
        raise ValueError(f'Expected a path ending in <config>/<file_id>, got {path!r}')
    config, file_id = parts[-2:]
    if 'downscaled' not in config:
        config = f'{config}_downscaled'
    return os.path.expanduser(f'~/ais/{config}/{file_id}')


def load_observations(paths, limit=MAX_OBSERVATIONS):
    """Load observations from ``paths`` in order until at least ``limit`` rows are collected.

    The per-file arrays are kept in a list and concatenated once at the end.
    Concatenating inside the loop would re-copy every earlier observation on
    each iteration (quadratic work) while reaching the same peak memory.

    Raises:
        ValueError: if ``paths`` is empty.
    """
    chunks = []
    total = 0
    for path in paths:
        new_obs = load_single_file(path)[0]
        chunks.append(new_obs)
        total += len(new_obs)
        # Check after every file (including the first) so no unneeded file is loaded
        if total >= limit:
            break
    if not chunks:
        raise ValueError('No files available to sample observations from')
    return np.concatenate(chunks)


def main(argv=None):
    """Print the color bins computed from a random sample of indexed files.

    ``argv`` defaults to ``sys.argv[1:]``; its first element is the path of an
    index file whose ``f:`` lines list the data files to sample from.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        sys.exit(f'usage: {sys.argv[0]} INDEX_FILE')

    # Sample random files from the index to calculate the pixel discretization bins
    print('Calculating color bins')
    with open(argv[0]) as index_file:
        all_files = parse_index(index_file)
    if not all_files:
        sys.exit(f'No "f:" entries found in {argv[0]}')
    # Visit files in random order so the sample isn't biased toward the first configs
    random.shuffle(all_files)
    sample = [resolve_data_path(path) for path in all_files]

    observations = load_observations(sample)

    # Calculate and print out the color bins and then free up the memory
    color_bins = discretize(observations, discrete_values=DISCRETE_VALUES, return_bins=True)[1]
    del observations
    print('Color bins:', color_bins.tolist())


if __name__ == '__main__':
    main()
