#!/usr/bin/python
# -*- coding: utf-8 -*-
"""
@author: Xiao Jin
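In this file we load training data. NOTE: load_data currently loads CIFAR-10
via tf.keras.datasets rather than Linnaeus data; the JPEG helpers below read
image files from disk.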
"""
from config import *
import gc
import pathlib
import random as r
import tensorflow as tf
import os

# Data directory names. NOTE(review): load_data currently ignores the path it
# is given, and test_data_dir is not referenced in the visible code.
train_data_dir, test_data_dir = 'train/', 'test/'

# Sentinel letting tf.data choose parallelism / prefetch sizes at runtime
# (not referenced in the visible code).
AUTOTUNE = tf.data.experimental.AUTOTUNE

def preprocess_image(image):
    """Decode JPEG bytes into a float32 RGB image tensor.

    Args:
        image: Encoded JPEG bytes (e.g. the result of ``tf.io.read_file``).

    Returns:
        A float32 tensor with 3 channels. Pixel values are not rescaled
        (they keep the decoded 0-255 range) and the image is not resized.
    """
    decoded = tf.image.decode_jpeg(image, channels=3)
    return tf.cast(decoded, tf.float32)


def load_and_preprocess_image(path):
    """Read the JPEG file at ``path`` and decode it with ``preprocess_image``."""
    return preprocess_image(tf.io.read_file(path))


def load_data(data_path):
    """Load the CIFAR-10 training split, cut into four spatial quadrants.

    Each training image is split into a 2x2 grid of equal patches (16x16 for
    CIFAR-10's 32x32 images). Worker ``k`` gets grid row ``k // 2`` and
    column ``k % 2``.

    Args:
        data_path: Currently unused; kept for interface compatibility. The
            data always comes from ``tf.keras.datasets.cifar10``.

    Returns:
        A ``tf.data.Dataset`` of 5-tuples ``(q0, q1, q2, q3, labels)``,
        batched by ``data_number`` (provided by ``config``). Each ``qk``
        holds float32 pixels scaled to [0, 1]; ``labels`` holds float32
        one-hot vectors of length 10.
    """
    # Keras also returns the test split; it is not needed here, so it is
    # discarded instead of being wrapped in an unused Dataset.
    (train_images, train_labels), _ = tf.keras.datasets.cifar10.load_data()

    # Apply the preprocessing directly on tensors. This gives the same result
    # as the former single-batch Dataset.map + zip(*dataset) round trip.
    images = tf.divide(tf.cast(train_images, tf.float32), 255.0)
    # Keras CIFAR-10 labels carry a trailing axis of size 1; reshape the
    # one-hot output so each row is a flat 10-vector.
    labels = tf.reshape(tf.one_hot(train_labels, 10), (-1, 10))

    # Patch size derived from the image shape (16 for 32x32 CIFAR images).
    patch_h = images.shape[1] // 2
    patch_w = images.shape[2] // 2

    quadrants = []
    for worker_index in range(4):
        row, col = divmod(worker_index, 2)
        quadrants.append(
            images[:,
                   patch_h * row: patch_h * (row + 1),
                   patch_w * col: patch_w * (col + 1),
                   :]
        )

    return tf.data.Dataset.from_tensor_slices(
        (*quadrants, labels)
    ).batch(data_number)

# Built at import time; presumably other modules read ``train_datasets``
# from here. NOTE(review): importing this module triggers the CIFAR-10 load.
train_datasets = load_data(train_data_dir)


if __name__ == "__main__":
    # The module-level load above has already run; calling load_data again
    # here would only repeat the expensive load.
    print('Done')