import os
import pickle

import numpy as np
import skimage.io

import absl.app
import absl.flags
from absl import logging

from .misc_utils import define_flags_with_default


# Command-line options consumed by main() through absl.flags.FLAGS:
#   input_dir: directory that main() joins with 'train' and 'val' to locate
#              the two partitions to convert.
#   output:    path of the pickle file main() writes.
# NOTE(review): define_flags_with_default lives in misc_utils; presumably it
# registers each keyword as an absl flag with the given default -- confirm.
flags_def = define_flags_with_default(
    input_dir='',
    output='./rainbow_mnist.pkl',
)


def process_data_partition(path):
    """Load every task of one dataset partition into NumPy arrays.

    The directory tree walked is
    ``<path>/tform*/<numeric batch>/<label 0-9>/<image file>``. Entries that
    do not fit this layout (non-``tform`` names, ``tform`` entries that are
    not directories, non-numeric batch names, missing label directories) are
    ignored.

    Directory listings are sorted because ``os.listdir`` returns entries in
    arbitrary order; sorting makes the task order and the image order within
    each task reproducible. The sort is plain lexicographic, so for example
    ``'10'`` comes before ``'2'``.

    Args:
        path: Directory of the partition to read.

    Returns:
        A list with one dict per task, in sorted task-name order. Each dict
        holds ``'images'`` (all images of the task stacked along a new
        leading axis, cast to float32 and divided by 255) and ``'labels'``
        (an int64 array with the label of each image, in the same order).

    Raises:
        ValueError: If a task directory contains no images. ``np.stack`` may
            also raise it when the images of a task differ in shape.
    """
    # Only directories can be tasks; a stray file such as 'tform_x.zip' would
    # otherwise make the os.listdir call below fail.
    tasks = sorted(
        name for name in os.listdir(path)
        if name.startswith('tform') and os.path.isdir(os.path.join(path, name))
    )
    data = []
    for i, task in enumerate(tasks):
        print('Processing task: {} / {}'.format(i, len(tasks)))
        task_path = os.path.join(path, task)
        images = []
        labels = []
        for batch in sorted(x for x in os.listdir(task_path) if x.isnumeric()):
            task_batch_path = os.path.join(task_path, batch)
            for label in range(10):
                task_batch_label_path = os.path.join(task_batch_path, str(label))
                if not os.path.isdir(task_batch_label_path):
                    continue
                for filename in sorted(os.listdir(task_batch_label_path)):
                    image_path = os.path.join(task_batch_label_path, filename)
                    images.append(skimage.io.imread(image_path))
                    labels.append(label)
        if not images:
            # np.stack would fail here with an opaque "need at least one
            # array" message; report which task is the problem instead.
            raise ValueError(
                'No images found for task directory: {}'.format(task_path)
            )
        images = np.stack(images, axis=0).astype(np.float32) / 255.0
        labels = np.array(labels, dtype=np.int64)
        data.append({'images': images, 'labels': labels})
    return data


def main(_):
    """Convert the 'train' and 'val' partitions and pickle the result.

    Each partition is read from the sub-directory of ``FLAGS.input_dir`` with
    the same name, processed by ``process_data_partition``, and stored under
    that name in a dict which is then pickled to ``FLAGS.output``.

    Args:
        _: Positional argument supplied by ``absl.app.run``; unused.
    """
    flags = absl.flags.FLAGS
    dataset = {}
    for split in ('train', 'val'):
        split_dir = os.path.join(flags.input_dir, split)
        dataset[split] = process_data_partition(split_dir)
    with open(flags.output, 'wb') as sink:
        pickle.dump(dataset, sink)


# Script entry point: when executed directly (not imported), hand control to
# absl, which calls main().
if __name__ == '__main__':
    absl.app.run(main)



