import os
import pickle

import numpy as np
import random
import skimage.io
from scipy.misc import imread, imresize

import absl.app
import absl.flags
from absl import logging

from misc_utils import define_flags_with_default

# Default values for the flags that main() reads (FLAGS.input_dir and
# FLAGS.output); registering them is delegated to
# misc_utils.define_flags_with_default. With the empty default input_dir, the
# 'train'/'val' split directories are resolved relative to the current
# working directory.
flags_def = define_flags_with_default(**{
    'input_dir': '',
    'output': './miniimagenet.pkl',
})

def _load_scaled_image(image_path, load):
    """Read one image with ``load`` and return it as float32 divided by 255.

    Raises:
        ValueError: if the image's maximum value is not above 1, i.e. it does
            not look like 0-255 pixel data (this includes an all-black image).
    """
    # np.array(..., dtype=...) always copies, so the in-place division below
    # never mutates an array owned by the reader.
    image = np.array(load(image_path), dtype=np.float32)
    # Explicit check instead of `assert`: asserts are stripped under
    # `python -O`, which would let already-normalised images be divided
    # by 255 a second time without any error.
    if not np.max(image) > 1.0:
        raise ValueError(
            'Expected pixel values in 0-255 but max <= 1 in: {}'.format(image_path))
    image /= 255.
    return image


def process_data_partition(tasks, reader=None):
    """Load every file of every task directory as a scaled float32 image.

    Args:
        tasks: sequence of directory paths, one per task. Every entry in each
            directory is read as an image.
        reader: optional callable mapping a file path to an array-like image.
            Defaults to the module-level ``imread`` (``scipy.misc.imread``).
            That function was removed in SciPy 1.2, so on newer environments
            pass a replacement, e.g. ``skimage.io.imread``.

    Returns:
        A list with one ``{'images': array}`` dict per task, in the order of
        ``tasks``. ``array`` stacks the task's images (files sorted by name)
        along a new axis 0, as float32 divided by 255.

    Raises:
        ValueError: if a task directory is empty, or if an image's maximum
            value is not above 1.
    """
    load = imread if reader is None else reader
    data = []
    for i, task_path in enumerate(tasks):
        # 1-based so the final progress line reads "N / N".
        print('Processing task: {} / {}'.format(i + 1, len(tasks)))
        # os.listdir order is arbitrary; sort so the image order (and hence
        # the pickled output) is reproducible across runs and filesystems.
        filenames = sorted(os.listdir(task_path))
        if not filenames:
            raise ValueError('No images found in task directory: {}'.format(task_path))
        images = [_load_scaled_image(os.path.join(task_path, name), load)
                  for name in filenames]
        # Every image is already float32, so the stacked array is float32 too;
        # no extra cast (and full copy) is needed.
        data.append({'images': np.stack(images, axis=0)})
    return data


def _list_tasks(input_dir, split):
    """Return sorted paths of the entries in ``<input_dir>/<split>`` starting with 'n'."""
    split_dir = os.path.join(input_dir, split)
    # The 'n' prefix presumably selects ImageNet synset-id class folders
    # (e.g. n01532829) -- confirm against the dataset layout. Sorted because
    # os.listdir order is arbitrary and the task order ends up in the pickle.
    return [os.path.join(split_dir, name)
            for name in sorted(os.listdir(split_dir))
            if name.startswith('n')]


def main(_):
    """Build the train/val task partitions and pickle them to FLAGS.output.

    For each split, every entry of ``<FLAGS.input_dir>/<split>`` whose name
    starts with 'n' is one task, loaded by ``process_data_partition``. The
    file written is ``{'train': [...], 'val': [...]}``.
    """
    FLAGS = absl.flags.FLAGS
    # List both splits before the (slow) image loading, so that a missing
    # split directory fails immediately.
    # NOTE(review): only 'train' and 'val' are exported; any 'test' split in
    # input_dir is ignored -- confirm this is intended.
    train_tasks = _list_tasks(FLAGS.input_dir, 'train')
    val_tasks = _list_tasks(FLAGS.input_dir, 'val')
    data = {
        'train': process_data_partition(train_tasks),
        'val': process_data_partition(val_tasks),
    }
    with open(FLAGS.output, 'wb') as fout:
        pickle.dump(data, fout)


if __name__ == '__main__':
    # Hand control to absl, which parses the command line into
    # absl.flags.FLAGS and then calls main(argv).
    absl.app.run(main=main)




