#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author: Xiao Jin
In this file we load Linnaeus data
"""
from config import *
import gc
import pathlib
import random as r
import tensorflow as tf
import os

# Default location of the training images (a relative path, so it is resolved
# against the current working directory). Passed to `load_data` below.
train_data_dir = 'train/'
# Location of the test split.
# NOTE(review): no active code visible in this file reads this value.
test_data_dir = 'test/'
# Passed as `num_parallel_calls` to `Dataset.map` in `load_data`, which lets
# tf.data pick the degree of parallelism at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

def preprocess_image(image):
    """Decode JPEG-encoded bytes into a 3-channel float32 image tensor.

    The pixel values are only converted to float32; they are not rescaled
    here, and the image is not resized.

    Args:
        image: scalar string tensor holding the encoded JPEG contents.

    Returns:
        The decoded image as a float32 tensor with three channels.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    return tf.cast(decoded, tf.float32)


def load_and_preprocess_image(path):
    """Read the file at `path` and decode it with `preprocess_image`.

    Args:
        path: filename of the image to read, as accepted by `tf.io.read_file`.

    Returns:
        Whatever `preprocess_image` returns for the file's raw contents.
    """
    return preprocess_image(tf.io.read_file(path))


def load_data(data_path, num_classes=5):
    """Load a class-per-folder image tree and split it into per-worker datasets.

    The directory is expected to look like ``<data_path>/<label>/<image>``.
    Every image is decoded with `load_and_preprocess_image`, scaled to
    ``[0, 1]`` by dividing by 255, and paired with a one-hot label whose
    index is the position of its folder name in the sorted list of label
    folders. All images are stacked into one tensor without resizing, so
    they must share a single size.

    The samples are shuffled with a fixed seed and then cut into
    consecutive slices of ``data_number`` samples, one per worker.
    ``number_of_workers`` and ``data_number`` are module-level names that
    are not defined in this file (presumably supplied by
    ``from config import *``). If fewer than
    ``number_of_workers * data_number`` images are available, the trailing
    workers receive short or empty datasets.

    Args:
        data_path: root directory containing one sub-directory per label.
        num_classes: width of the one-hot label vectors. Defaults to 5.

    Returns:
        A list of ``number_of_workers`` `tf.data.Dataset` objects, each
        yielding ``(images, one_hot_labels)`` in a single batch of up to
        ``data_number`` samples.

    Raises:
        ValueError: if no image files are found under `data_path`, or if
            there are more label folders than `num_classes`.

    Note:
        This reseeds Python's global `random` generator with 0.
    """
    data_root = pathlib.Path(data_path)

    # Directory listing order is filesystem-dependent, so sort first:
    # otherwise the fixed seed below would not give a reproducible
    # permutation. Non-file entries (nested folders) are skipped because
    # they cannot be decoded as images.
    all_image_paths = sorted(
        str(entry) for entry in data_root.glob('*/*') if entry.is_file())
    if not all_image_paths:
        raise ValueError(
            "no image files found under {!r}; expected the layout "
            "<data_path>/<label>/<image>".format(str(data_path)))
    r.seed(0)
    r.shuffle(all_image_paths)
    image_count = len(all_image_paths)

    # Label index = position of the folder name in sorted order.
    label_names = sorted(
        item.name for item in data_root.glob('*/') if item.is_dir())
    if len(label_names) > num_classes:
        # tf.one_hot encodes an out-of-range index as an all-zero row, which
        # would silently corrupt the labels of the extra classes.
        raise ValueError(
            "found {} label folders under {!r} but num_classes is {}".format(
                len(label_names), str(data_path), num_classes))
    label_to_index = {name: index for index, name in enumerate(label_names)}
    all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                        for path in all_image_paths]

    # Decode the images in parallel and pair each one with its label.
    path_ds = tf.data.Dataset.from_tensor_slices(all_image_paths)
    image_ds = path_ds.map(load_and_preprocess_image,
                           num_parallel_calls=AUTOTUNE)
    label_ds = tf.data.Dataset.from_tensor_slices(
        tf.cast(all_image_labels, tf.int64))

    # One batch holding the entire dataset, so it can be pulled into memory
    # in a single step below.
    image_label_ds = tf.data.Dataset.zip((image_ds, label_ds)).batch(image_count)
    image_label_ds = image_label_ds.map(
        lambda x, y: (tf.divide(tf.cast(x, tf.float32), 255.0),
                      tf.reshape(tf.one_hot(y, num_classes),
                                 (-1, num_classes))))

    # The dataset contains exactly one element: the full batch.
    images, labels = next(iter(image_label_ds))
    train_data = images.numpy()
    train_label = labels.numpy()

    # Hand each worker its own consecutive slice of the shuffled samples.
    train_datasets = []
    for n in range(number_of_workers):
        start = n * data_number
        stop = start + data_number
        worker_ds = tf.data.Dataset.from_tensor_slices(
            (train_data[start:stop], train_label[start:stop])
        ).batch(data_number)
        train_datasets.append(worker_ds)
    return train_datasets

# Built eagerly at module level: importing this module runs `load_data` on
# `train_data_dir` and exposes the resulting per-worker datasets under this name.
train_datasets = load_data(train_data_dir)


if __name__ == "__main__":
    # NOTE(review): when this file is run as a script, the module-level call
    # above has already executed, so this repeats the same load a second time.
    train_datasets = load_data(train_data_dir)