import collections
import tensorflow as tf
import tensorflow_datasets as tfds
import logging
import tensorflow_io as tfio

def scale_lab_to_01(lab_image):
  """Rescale each channel of a CIELAB image independently.

  The L channel is divided by 100. The a and b channels are shifted by +110
  and then divided by 220. The three rescaled channels are stacked back
  together along the last axis.

  Args:
    lab_image: tensor whose last axis holds the (L, a, b) channels.

  Returns:
    Tensor with the same shape as `lab_image`. Values land in [0, 1] only when
    L lies in [0, 100] and a, b lie in [-110, 110].
  """
  lightness = lab_image[..., 0] / 100.0
  # a and b share the same affine map.
  a_chan, b_chan = ((lab_image[..., idx] + 110) / 220.0 for idx in (1, 2))
  return tf.stack([lightness, a_chan, b_chan], axis=-1)

#preprocess each sample individually. transform images to multiple color spaces and concat them
def preprocess_clevrtex(features, resolution, apply_crop=False,
                     get_properties=True):
  image = tf.cast(features["image"], dtype=tf.float32)
  image = ((image / 255.0) - 0.5) * 2.0
  image_float = tf.cast(features["image"], tf.float32) / 255.0

  #use build in tensorflow function to transform images to HSV and LAB color space.
  image_hsv = tf.image.rgb_to_hsv(image_float) 
  image_lab_2 = tfio.experimental.color.rgb_to_lab(image_float)
  image_lab_2 = (scale_lab_to_01(image_lab_2) -0.5) * 2
  gray = tfio.experimental.color.rgb_to_grayscale(image_float)

  #apply center crop for all images
  crop = ((119 - 96, 119 + 96), (159 - 96, 159 + 96)) 
  image = image[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1], :]
  image_hsv = image_hsv[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1], :]
  image_lab_2 = image_lab_2[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1], :]
  gray = gray[crop[0][0]:crop[0][1], crop[1][0]:crop[1][1], :]
  
  #resize all images to fit 128x128 resolution
  image = tf.image.resize(
      image, resolution, method=tf.image.ResizeMethod.BILINEAR)
  image_hsv = tf.image.resize(
      image_hsv, resolution, method=tf.image.ResizeMethod.BILINEAR)
  image_lab_2 = tf.image.resize(
      image_lab_2, resolution, method=tf.image.ResizeMethod.BILINEAR)
  gray = tf.image.resize(
      gray, resolution, method=tf.image.ResizeMethod.BILINEAR)
  #normalize images to fit in range [-1.1]
  image_rgbhsv = tf.concat([image, image_hsv], axis=2)
  image_rgbsv = tf.concat([image, image_hsv[:,:,1:3]], axis=2)
  image_rgbs = tf.concat([image, image_hsv[:,:,1:2]] , axis=2)
  features = {"image": image,"hsv": image_hsv,"image_rgbhsv": image_rgbhsv,"image_rgbsv": image_rgbsv,"image_rgbs": image_rgbs, "gray": (gray -0.5) * 2, "lab2": image_lab_2}
  return features

#start building clevrtex,
def build_clevrtex(split="train", resolution=(128, 128)):
  ds, info = tfds.load("clevr_tex:2.0.0", with_info=True, split=split, shuffle_files=True)
  def _preprocess_fn(x, resolution):
    return preprocess_clevrtex(
        x, resolution)
  ds = ds.map(lambda x: _preprocess_fn(x, resolution))
  return ds

#return batched clevrtex version
def build_clevrtex_iterator(batch_size, split, **kwargs):
  #hardcode split to fit train
  split = "train"
  ds = build_clevrtex(split=split, **kwargs)
  ds = ds.repeat(-1)
  ds = ds.batch(batch_size, drop_remainder=True)
  return iter(ds)

