# inspired by https://github.com/asprenger/keras_fc_densenet.git
"""
Preprocess images and labels from the CamVid dataset (the 701 still images). 
The images are resized. The labels are mapped to unique int IDs and cropped. 
"""
import os
import glob
import argparse
from PIL import Image
import numpy as np

def open_image(image_path, image_size):
    """Load an image and resize it to `image_size` if it has another size.

       Resizing uses nearest-neighbour sampling (`Image.NEAREST`), which picks
       existing pixels instead of interpolating new values.

       Args:
           image_path: path of image file
           image_size: a tuple (width, height)
       Return:
           array of shape (height, width, bands) for multi-band images such
           as RGB, or (height, width) for single-band images
    """
    # The context manager closes the underlying file as soon as the pixel
    # data has been copied into the array, instead of leaving it to the
    # garbage collector.
    with Image.open(image_path) as source:
        width, height = source.size
        if (width, height) == (image_size[0], image_size[1]):
            # Convert inside the `with` block: PIL loads pixel data lazily.
            return np.array(source)
        print('resizing')
        resized = source.resize(image_size, Image.NEAREST)
        return np.array(resized)

def load_label_colors(label_colors_path):
    """Load the `label_colors.txt` file from the CamVid dataset.

       Every non-blank line is expected to have the form `R G B<TAB>name`.
       Blank lines are skipped, and an empty file yields empty results.

       Args:
           label_colors_path: path of the label colors file
       Return:
           label_codes: list of label codes, tuples with RGB values
           label_names: list of label code names
           label_code2id: dict that maps label codes to unique IDs
       Raises:
           ValueError: if a non-blank line does not consist of exactly two
               tab-separated fields, or a color value is not an integer
    """
    def parse_code(line):
        code, name = line.strip().split("\t")
        return tuple(int(channel) for channel in code.split()), name

    label_codes = []
    label_names = []
    # `with` guarantees the file is closed even if a line fails to parse.
    with open(label_colors_path) as label_file:
        for line in label_file:
            if not line.strip():
                continue
            code, name = parse_code(line)
            label_codes.append(code)
            label_names.append(name)

    # assign IDs to codes: the ID is the position of the code in the file
    # (if a code occurs more than once, its last position wins)
    label_code2id = {code: index for index, code in enumerate(label_codes)}

    return label_codes, label_names, label_code2id

def conv_label(label, label_code2id):
    """Map all color codes in `label` to code IDs.

       Each distinct color is looked up once, not once per pixel. A color
       that is missing from `label_code2id` is reported once and mapped to
       the ID of the color (0, 0, 0).

       Args:
           label: array of shape [height, width, 3]
           label_code2id: dict that maps label codes to unique IDs
       Return:
           uint8 array of shape [height, width]
       Raises:
           KeyError: if an unknown color is found and (0, 0, 0) is not a key
               of `label_code2id` either
    """
    assert len(label.shape) == 3
    assert label.shape[2] == 3
    height = label.shape[0]
    width = label.shape[1]
    if height == 0 or width == 0:
        return np.zeros((height, width), 'uint8')

    # Reduce the image to its distinct colors; `inverse` holds, for every
    # pixel, the index of its color in `codes`.
    codes, inverse = np.unique(label.reshape(-1, 3), axis=0, return_inverse=True)

    code_ids = np.zeros(len(codes), 'uint8')
    for index, code in enumerate(codes):
        try:
            code_ids[index] = label_code2id[tuple(code.tolist())]
        except KeyError:
            print('Unknown label code: %s' % code)
            code_ids[index] = label_code2id[(0, 0, 0)]

    # Flatten `inverse` first: its shape for `axis=0` differs between NumPy
    # releases.
    return code_ids[inverse.reshape(-1)].reshape(height, width)

def main(input_path, output_path, image_height, image_width):
    """Convert the CamVid splits under `input_path` into `.npz` archives.

       For each of the splits 'test', 'train' and 'val', the PNG images in
       `<input_path>/<split>` and the same-named label images in
       `<input_path>/<split>annot` are loaded (and resized if their size
       differs from the requested one), stacked, and saved with `np.savez` as
       `<output_path>/camvid-<height>x<width>-<split>`. The mean and standard
       deviation of the images divided by 255 are printed for every split.

       Args:
           input_path: directory that contains the split sub-directories
           output_path: directory for the archives (created if missing)
           image_height: target image height
           image_width: target image width
       Raises:
           FileNotFoundError: if a split directory contains no PNG image
           ValueError: if the label values of a split are not exactly
               0 .. numClasses-1
    """
    # NOTE(review): this location is hard-coded and independent of
    # `input_path`; the color table is only used to report the class count.
    label_colors_path = os.path.join('/data/Data/camvid', 'label_colors.txt')

    out_path = {'train': os.path.join(output_path, 'camvid-%dx%d-train' % (image_height, image_width)),
            'val': os.path.join(output_path, 'camvid-%dx%d-val' % (image_height, image_width)),
            'test': os.path.join(output_path, 'camvid-%dx%d-test' % (image_height, image_width))}

    os.makedirs(output_path, exist_ok=True)

    image_size = (image_width, image_height)
    _, _, label_code2id = load_label_colors(label_colors_path)
    print('Found %d classes' % len(label_code2id))

    for dataSplit in ['test', 'train', 'val']:
        print('loading', dataSplit)
        images_path = os.path.join(input_path, dataSplit)
        labels_path = os.path.join(input_path, dataSplit + 'annot')
        # glob returns files in arbitrary order; sort so that the sample
        # order inside the archive is reproducible.
        image_paths = sorted(glob.glob(os.path.join(images_path, '*.png')))

        num_images = len(image_paths)
        print('Found %d images' % num_images)
        if num_images == 0:
            # Fail with a clear message instead of an obscure axis error
            # when the empty label list is expanded below.
            raise FileNotFoundError('No PNG images found in %s' % images_path)

        # load all images in memory
        images = []
        labels = []
        for image_path in image_paths:
            print(image_path)
            image = open_image(image_path, image_size)
            images.append(image)
            # the label image has the same file name as its image
            label_path = os.path.join(labels_path, os.path.basename(image_path))
            label = open_image(label_path, image_size)
            # labels are already converted: Sky, Building, Pole, Road, Pavement, Tree, SignSymbol, Fence, Car, Pedestrian, Bicyclist, Unlabelled
            # (so `conv_label` is not applied here)
            labels.append(label)

        images = np.array(images)
        # insert a length-1 axis at position 3 of the stacked label array
        labels = np.expand_dims(np.array(labels), 3)

        # The label values must be exactly 0 .. numClasses-1; the identity
        # lookup table below relies on that.
        labelList = np.unique(labels)
        numClasses = len(labelList)
        labelMap = np.arange(numClasses, dtype=np.uint8)
        if not np.array_equal(labelMap, labelList):
            raise ValueError('Unexpected label values in split %s: %s' % (dataSplit, labelList))
        # Remap the highest label ID to 255 (presumably the 'Unlabelled'
        # class listed last above -- TODO confirm); all other IDs stay as is.
        labelMap[-1] = 255
        labels = labelMap[labels]

        print('images: %s, %s' % (images.shape, images.dtype))
        print('labels: %s, %s, %s' % (labels.shape, labels.dtype, np.unique(labels)))

        print('Writing:', out_path[dataSplit])
        # np.savez appends the '.npz' extension to the file name
        np.savez(out_path[dataSplit], images=images, labels=labels)

        # Report the statistics under the name of the split they belong to.
        split_title = dataSplit.capitalize()
        images = images / 255.
        print('%s image mean: %f' % (split_title, np.mean(images)))
        print('%s image std.: %f' % (split_title, np.std(images)))

if __name__ == '__main__':
    # Command-line entry point: each option corresponds to one parameter
    # of main() (dashes in option names become underscores).
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--input-path',
        default='/home//repos/SegNet-Tutorial/CamVid',
        help='input directory',
    )
    cli.add_argument(
        '--output-path',
        default='/home////data/camvid',
        help='output directory',
    )
    cli.add_argument(
        '--image-height',
        type=int,
        default=360,
        help='height for resizing images during loading',
    )
    cli.add_argument(
        '--image-width',
        type=int,
        default=480,
        help='width for resizing images during loading',
    )
    options = cli.parse_args()
    # Forward the parsed options to main() as keyword arguments.
    main(**vars(options))
