""" Code for loading data. """
import numpy as np
import os
import random
import tensorflow as tf

from tensorflow.python.platform import flags
import os
import cv2
import copy


FLAGS = flags.FLAGS

class Dataloader(object):
    """Index an on-disk image dataset and build TF1 input pipelines for it.

    The dataset is laid out as ``<root>/<split>/<class folder>/<image>``.
    Class folders are named with an ``n`` followed by a number and images
    are named ``<prefix>_<number>.<ext>``. Classes are labelled ``0..K-1`` in
    increasing numeric order of their folder number. Within every class the
    ``_HOLDOUT_PER_CLASS`` highest-numbered images form the validation list
    and the remaining images form the training list.

    Attributes set by ``__init__``:
        train_image_labels / val_image_labels: lists of ``[path, label]``.
        num_train / num_val: lengths of the two lists.
        train_pointer / val_pointer: read cursors used by ``get_batch_data``.
        image_lists, out_images, iterator: path placeholder, next-batch
            tensor and iterator initializer of the training pipeline.
        image_lists_val, out_images_val, iterator_val: the same three
            objects for the validation pipeline.
    """

    # Number of highest-numbered images of every class kept for validation.
    _HOLDOUT_PER_CLASS = 100

    def __init__(self, data, root='/'):
        """Scan the dataset folders and create both input pipelines.

        Args:
            data: dataset selector. ``'tiered1'`` indexes ``<root>/train``;
                any other value indexes ``<root>/val`` and ``<root>/test``
                together as one label space.
            root: directory the split folders are appended to. Defaults to
                ``'/'``, the value that used to be hard-coded.
        """
        self.data = data

        if data == 'tiered1':
            folder_labels = self._index_folders([root + '/train'])
        else:
            folder_labels = self._index_folders([root + '/val', root + '/test'])

        train_images = []
        val_images = []
        for path, label in folder_labels.values():
            t_images, v_images = self._split_folder_images(path, label)
            train_images.extend(t_images)
            val_images.extend(v_images)

        self.train_image_labels = train_images
        self.num_train = len(self.train_image_labels)
        random.shuffle(self.train_image_labels)
        self.train_pointer = 0

        self.val_image_labels = val_images
        self.num_val = len(self.val_image_labels)
        self.val_pointer = 0

        self.image_size = 84

        (self.image_lists,
         self.out_images,
         self.iterator) = self._build_pipeline(self.read_image)

        (self.image_lists_val,
         self.out_images_val,
         self.iterator_val) = self._build_pipeline(self.read_image_val)

    @staticmethod
    def _folder_number(name):
        """Return the integer that follows the first ``n`` in a folder name."""
        return int(name.split('n')[1])

    @staticmethod
    def _image_number(name):
        """Return the integer between the first ``_`` and the next ``.``."""
        return int(name.split('_')[1].split('.')[0])

    @staticmethod
    def _index_folders(roots):
        """Map every class folder under ``roots`` to ``[path, label]``.

        The real directory names are kept and only *sorted* by their numeric
        part. Rebuilding the name from the parsed integer dropped the ``n``
        prefix for ids with eight significant digits (e.g. ``n15075141``) and
        produced paths that do not exist.

        Args:
            roots: list of directories whose sub-folders are the classes.
                When a folder name occurs under several roots, the first
                root in the list supplies its path.

        Returns:
            dict ``{folder name: [root + '/' + folder name, label]}`` in
            increasing numeric order; ``label`` is the position of the
            folder in that order.
        """
        names = []
        owner = {}
        for root in roots:
            for name in os.listdir(root):
                names.append(name)
                owner.setdefault(name, root)

        names.sort(key=Dataloader._folder_number)

        folder_labels = {}
        for label, name in enumerate(names):
            folder_labels[name] = [owner[name] + '/' + name, label]
        return folder_labels

    @staticmethod
    def _split_folder_images(path, label):
        """List one class folder and split it into train / validation.

        File names are used exactly as they appear on disk and ordered by
        their numeric index, so zero-padded indices or a different extension
        still yield paths that exist.

        Args:
            path: class folder to list.
            label: integer label attached to every image of the folder.

        Returns:
            ``(train, val)`` lists of ``[image path, label]``. ``val`` holds
            the last ``_HOLDOUT_PER_CLASS`` images in numeric order; a folder
            with fewer images than that contributes nothing to ``train``.
        """
        names = sorted(os.listdir(path), key=Dataloader._image_number)
        entries = [[path + '/' + name, label] for name in names]
        holdout = Dataloader._HOLDOUT_PER_CLASS
        return entries[:-holdout], entries[-holdout:]

    def _build_pipeline(self, map_fn):
        """Create one ``tf.data`` pipeline fed by a placeholder of paths.

        Args:
            map_fn: function mapping a scalar string tensor (image path) to
                an image tensor.

        Returns:
            ``(paths placeholder, next-batch tensor, iterator initializer)``.
            Batches contain ``FLAGS.batch_size * 2`` images.
        """
        paths = tf.placeholder(dtype=tf.string, shape=[None, ])
        dataset = tf.data.Dataset.from_tensor_slices(paths)
        dataset = dataset.map(map_fn, num_parallel_calls=40)
        dataset = dataset.batch(FLAGS.batch_size * 2)
        iterator = dataset.make_initializable_iterator()
        return paths, iterator.get_next(), iterator.initializer

    def get_batch_data(self, batch_size, train=True):
        """Return the next batch of image paths and one-hot labels.

        Args:
            batch_size: number of ``[path, label]`` entries to take.
            train: read from the training list when true, otherwise from
                the validation list.

        Returns:
            ``(batch_files, batch_labels, val_end)``: a tuple of paths, a
            one-hot label array (351 columns for ``'tiered1'``, 257
            otherwise) and a flag that is true only when this call returned
            the last validation batch.

        Raises:
            ValueError: if the selected list has no entries to return.
        """
        val_end = False
        if train:
            start = self.train_pointer
            batch_image_labels = self.train_image_labels[start:start + batch_size]
            # Wrap around while two batches' worth of data is still required,
            # so that a returned training batch is never a short remainder.
            if start + 2 * batch_size >= self.num_train:
                self.train_pointer = 0
                random.shuffle(self.train_image_labels)
            else:
                self.train_pointer = start + batch_size
        else:
            start = self.val_pointer
            if start + batch_size >= self.num_val:
                # Last validation batch: hand out whatever is left.
                batch_image_labels = self.val_image_labels[start:]
                self.val_pointer = 0
                val_end = True
            else:
                batch_image_labels = self.val_image_labels[start:start + batch_size]
                self.val_pointer = start + batch_size

        if not batch_image_labels:
            raise ValueError('no images available for the requested batch')

        batch_files, batch_labels = zip(*batch_image_labels)
        num_classes = 351 if self.data == 'tiered1' else 257
        batch_labels = make_one_hot(np.array(batch_labels), num_classes)

        return batch_files, batch_labels, val_end

    def _parser(self, image_path):
        """Load one image from disk as an RGB ``uint8`` array.

        Args:
            image_path: encoded (bytes) path of the image file.

        Returns:
            The decoded image in RGB channel order. Images whose path
            contains ``'val'`` or ``'test'`` are resized to 56x56.

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        image_path = image_path.decode()
        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None; check before any
        # further OpenCV call so the offending path is reported.
        if image is None:
            raise IOError('could not read image: {}'.format(image_path))

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # NOTE(review): substring match on the whole path, so a root directory
        # containing 'val' or 'test' also triggers the resize - confirm intent.
        if ('val' in image_path) or ('test' in image_path):
            image = cv2.resize(image, (56, 56))

        return image

    def read_image(self, image_path):
        """Graph op: load ``image_path`` and scale it to float32 by 1/256."""
        image = tf.py_func(self._parser, inp=[image_path], Tout=tf.uint8)
        image = tf.cast(image, tf.float32)
        image = image / 256

        return image

    def read_image_val(self, image_path):
        """Graph op for validation images; same processing as ``read_image``."""
        return self.read_image(image_path)


def make_one_hot(data, classes):
    """Return the one-hot encoding of integer labels.

    Args:
        data: array-like of class indices.
        classes: number of columns (classes) in the encoding.

    Returns:
        Integer array with a trailing axis of length ``classes``; entry
        ``[i, c]`` is 1 when ``data[i] == c`` and 0 otherwise. Labels outside
        ``0..classes-1`` give an all-zero row.
    """
    data = np.asarray(data)
    # np.int_ is the concrete dtype the abstract np.integer used to resolve
    # to; passing the abstract type as a dtype is deprecated in NumPy.
    return (np.arange(classes) == data[:, None]).astype(np.int_)



