import os
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
#from tensorflow.keras.datasets import cifar10
# Side length, in pixels, of the square images produced by the crop helpers.
SIZE = 224
# Share of the shorter image side kept by the evaluation center crop.
CROP_FRACTION = 0.875

# Per-channel ImageNet mean / standard deviation on a [0, 1] pixel scale.
# NOTE(review): presumably RGB channel order -- confirm against the decoder.
MEAN_IMAGENET = tf.constant((0.485, 0.456, 0.406), dtype=tf.float32, shape=[3])
STD_IMAGENET = tf.constant((0.229, 0.224, 0.225), dtype=tf.float32, shape=[3])

# Reciprocals kept around so scaling can be done with a multiply.
DIVISOR = tf.constant(1.0 / 255.0, dtype=tf.float32)
STD_DIVISOR = 1.0 / STD_IMAGENET
def normalize(image):
    """Mean-center an image using the ImageNet channel means.

    ``image`` is cast to float32 and ``MEAN_IMAGENET * 255`` is subtracted
    per channel (broadcast over the last axis). No rescaling to [0, 1] and
    no division by the standard deviation is performed.
    """
    pixels = tf.cast(image, tf.float32)
    mean_on_255_scale = MEAN_IMAGENET * 255
    return pixels - mean_on_255_scale

def normalize_vgg(image):
    """Subtract the per-channel ImageNet mean (0-255 scale) from ``image``.

    The input is first cast to float32. Channel order is left untouched.
    NOTE(review): Keras' own VGG preprocessing also flips RGB->BGR; confirm
    that is not needed for the model this feeds.
    """
    channel_mean = 255 * MEAN_IMAGENET
    return tf.cast(image, tf.float32) - channel_mean
def normalize_clipped(image):
    """Cast ``image`` to float32 and clamp every value into [0, 255].

    No mean subtraction is applied. The clamp removes any values that
    earlier steps (e.g. contrast/brightness jitter or bicubic resizing)
    pushed outside the 8-bit range.
    """
    as_float = tf.cast(image, tf.float32)
    return tf.clip_by_value(as_float, clip_value_min=0., clip_value_max=255.)
def _distorted_bounding_box_crop(image,
                                 bbox,
                                 min_object_covered=0.1,
                                 aspect_ratio_range=(0.75, 1.33),
                                 area_range=(0.05, 1.0),
                                 max_attempts=100,
                                 scope=None):
    """Crop a randomly sampled window out of ``image``.

    The window is drawn by ``tf.image.sample_distorted_bounding_box`` under
    the given coverage / aspect-ratio / area constraints relative to
    ``bbox``. ``use_image_if_no_bounding_boxes=True`` lets the whole image
    act as the reference box when ``bbox`` is empty.

    Args:
        image: image tensor whose first two dimensions are spatial
            (assumed (H, W, C) -- the sampler's begin/size are unpacked as
            3-element vectors).
        bbox: float32 boxes of shape [1, N, 4] in normalized coordinates.
        min_object_covered, aspect_ratio_range, area_range, max_attempts:
            forwarded to the sampler.
        scope: unused; kept only for signature compatibility.

    Returns:
        The cropped image tensor.
    """
    crop_begin, crop_size, _ = tf.image.sample_distorted_bounding_box(
        tf.shape(image),
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    top, left, _ = tf.unstack(crop_begin)
    height, width, _ = tf.unstack(crop_size)
    return tf.image.crop_to_bounding_box(image, top, left, height, width)


def _center_crop(image, image_size=224):
    """Take a centered square crop of ``image`` and resize it to ``image_size``.

    The crop side is ``image_size / (image_size + crop_padding)`` times the
    shorter spatial side, where ``crop_padding`` is ``image_size`` scaled by
    ``1 / CROP_FRACTION - 1`` and rounded to an integer. With the default
    ``CROP_FRACTION = 0.875`` and ``image_size = 224`` this is 224 / 256 of
    the shorter side.

    Args:
        image: image tensor with spatial dimensions first (assumed (H, W, C)).
        image_size: side length of the returned square image.

    Returns:
        An ``image_size`` x ``image_size`` image resized with bicubic
        interpolation. ``tf.image.resize`` returns float32, and bicubic
        filtering may overshoot the input value range.
    """
    shape = tf.shape(image)
    image_height = shape[0]
    image_width = shape[1]

    # Integer padding (rather than using CROP_FRACTION directly) so the
    # effective fraction is image_size / (image_size + whole pixels).
    crop_padding = round(image_size * (1 / CROP_FRACTION - 1))
    crop_size = tf.cast(
        (image_size / (image_size + crop_padding))
        * tf.cast(tf.minimum(image_height, image_width), tf.float32),
        tf.int32)

    # The "+ 1" rounds the offset up when the leftover margin is odd.
    offset_height = ((image_height - crop_size) + 1) // 2
    offset_width = ((image_width - crop_size) + 1) // 2

    image = tf.image.crop_to_bounding_box(image, offset_height, offset_width,
                                          crop_size, crop_size)
    return tf.image.resize([image], [image_size, image_size], method='bicubic')[0]

def _at_least_x_are_equal(a, b, x):
    """Return a boolean scalar tensor: do ``a`` and ``b`` agree in >= ``x`` positions?

    ``a`` and ``b`` are compared elementwise; matches are counted as int32.
    """
    n_matches = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
    return n_matches >= x

def _random_crop(image, image_size):
    """Random resized crop with a center-crop fallback.

    A window covering 8%-100% of the image area with aspect ratio in
    [3/4, 4/3] is sampled (at most 10 attempts). It is then resized to
    ``image_size`` x ``image_size`` with bicubic interpolation.

    If the sampled window has exactly the input's shape (all 3 shape entries
    equal), nothing was actually cropped. In that case the image goes through
    ``_center_crop`` instead. The 3-entry comparison assumes a rank-3
    (H, W, C) image.
    """
    whole_image_box = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    cropped = _distorted_bounding_box_crop(
        image,
        whole_image_box,
        min_object_covered=0.1,
        aspect_ratio_range=(3. / 4, 4. / 3.),
        area_range=(0.08, 1.0),
        max_attempts=10,
        scope=None)

    crop_is_noop = _at_least_x_are_equal(tf.shape(image), tf.shape(cropped), 3)

    def _fallback_center_crop():
        return _center_crop(image, image_size)

    def _resize_sampled_crop():
        return tf.image.resize([cropped], [image_size, image_size], method='bicubic')[0]

    return tf.cond(crop_is_noop, _fallback_center_crop, _resize_sampled_crop)

def random_size(img, size_min=256, size_max=480):
    """Rescale ``img`` so its shorter side is about a random length in [size_min, size_max).

    Both spatial sides are multiplied by the same factor, so the aspect ratio
    is kept up to int32 truncation. ``tf.image.resize`` is called with its
    default interpolation method.
    """
    spatial = tf.shape(img)[:2]
    target_short_side = tf.random.uniform((), minval=size_min, maxval=size_max)
    shortest = tf.cast(tf.reduce_min(spatial), tf.dtypes.float32)
    scale = target_short_side / shortest
    resized_hw = tf.cast(tf.cast(spatial, tf.dtypes.float32) * scale, tf.dtypes.int32)
    return tf.image.resize(img, resized_hw)
   
def h_ratio(sub_shape, coeff):
    """Keep ``sub_shape[0]`` and scale ``sub_shape[1]`` by ``coeff``.

    With ``sub_shape = (height, width)`` (as built by ``random_ratio``) this
    stretches the width. Returns a pair whose scaled entry is truncated to
    int32.
    """
    first, second = sub_shape[0], sub_shape[1]
    scaled_second = tf.cast(tf.cast(second, tf.dtypes.float32) * coeff, tf.dtypes.int32)
    return (first, scaled_second)

def w_ratio(sub_shape, coeff):
    """Scale ``sub_shape[0]`` by ``coeff`` and keep ``sub_shape[1]``.

    With ``sub_shape = (height, width)`` (as built by ``random_ratio``) this
    stretches the height. Returns a pair whose scaled entry is truncated to
    int32.
    """
    first, second = sub_shape[0], sub_shape[1]
    scaled_first = tf.cast(tf.cast(first, tf.dtypes.float32) * coeff, tf.dtypes.int32)
    return (scaled_first, second)

def random_ratio(img, minval=0.8, maxval=1.25):
    """Randomly stretch the longer spatial side of ``img`` by a factor in [minval, maxval).

    When the image is taller than wide, the height is scaled. Otherwise
    (including square images) the width is scaled. The other side is
    unchanged, so the aspect ratio is jittered.
    """
    hw = tf.shape(img)[:2]
    factor = tf.random.uniform((), minval=minval, maxval=maxval)
    taller_than_wide = tf.less(hw[1], hw[0])
    new_hw = tf.cond(taller_than_wide,
                     true_fn=lambda: w_ratio(hw, factor),
                     false_fn=lambda: h_ratio(hw, factor))
    return tf.image.resize(img, new_hw)


def random_chanel_contrast(img, minval=0.8, maxval=1.2):
    """Multiply each of the 3 channels of ``img`` by its own random gain in [minval, maxval).

    The gains are a float32 length-3 vector broadcast against the last axis,
    so ``img`` must be float32 with 3 channels last.
    """
    per_channel_gain = tf.random.uniform([3], minval=minval, maxval=maxval)
    return img * per_channel_gain
    
def random_chanel_brightness(img, val=1):
    """Add an independent random offset in [-val, val) to each of the 3 channels of ``img``.

    The offsets are a float32 length-3 vector broadcast against the last
    axis. No clipping is applied afterwards.
    """
    per_channel_shift = tf.random.uniform([3], minval=-val, maxval=val)
    return img + per_channel_shift

def random_pca(image, pca_std=0.1):
    """Add PCA-based color noise to ``image`` and clip to [0, 255].

    A per-image RGB offset ``offset[i] = sum_j eigvec[i, j] * alpha[j] * eigval[j]``
    is added to every pixel. Here ``alpha ~ N(0, pca_std)`` and
    eigval/eigvec are hard-coded color-covariance statistics (on the 0-255
    scale). This matches the AlexNet "PCA lighting" augmentation. The
    length-3 offset broadcasts against the last axis of ``image``.
    """
    eig_values = tf.convert_to_tensor([[55.46], [4.794], [1.148]])  # column vector (3, 1)
    eig_vectors = tf.convert_to_tensor([[-0.5836, -0.6948, 0.4203],
                                        [-0.5808, -0.0045, -0.8140],
                                        [-0.5675, 0.7192, 0.4009]])
    alpha = tf.random.normal((3,), 0, pca_std)
    # Broadcasting scales column j of eig_vectors by alpha[j]; the matmul
    # then weights each column by eig_values[j] and sums -> shape (3, 1).
    rgb_shift = tf.matmul(alpha * eig_vectors, eig_values)
    shifted = image + tf.squeeze(rgb_shift)
    return tf.clip_by_value(shifted, 0, 255)


def augment_train(preproc_func, contrast_min=1., contrast_max=1., bright=0., cutout=False):
    """Build the per-example training transform for ``tf.data.Dataset.map``.

    The returned ``augment_func(x, label)`` applies these steps in order:
    1. cast ``x`` to int32 and take a random resized 224x224 crop;
    2. per-channel contrast jitter, only if ``contrast_min != contrast_max``;
    3. per-channel brightness jitter, only if ``bright != 0``;
    4. a random horizontal flip;
    5. ``preproc_func``;
    6. an optional 64x64 random cutout filled with 0 (``tfa``).
    The label becomes a 1000-way one-hot vector.
    """
    def augment_func(x, label):
        image = _random_crop(tf.cast(x, tf.int32), 224)

        if contrast_min != contrast_max:
            image = random_chanel_contrast(image, contrast_min, contrast_max)
        if bright != 0.:
            image = random_chanel_brightness(image, val=bright)

        image = preproc_func(tf.image.random_flip_left_right(image))

        if cutout:
            # random_cutout works on batches, so add and then drop a batch axis.
            as_batch = tf.expand_dims(image, axis=0)
            as_batch = tfa.image.random_cutout(as_batch, (64, 64), constant_values=0.)
            image = tf.squeeze(as_batch, axis=0)

        one_hot = tf.one_hot(label, 1000, name='label', axis=-1)
        return image, one_hot

    return augment_func

def augment_test(preproc_func):
    """Build the evaluation transform for ``tf.data.Dataset.map``.

    The returned ``augment_func(x, label)`` casts ``x`` to int32 and takes a
    centered 224x224 crop. It then applies ``preproc_func`` and one-hot
    encodes the label over 1000 classes. No randomness is involved.
    """
    def augment_func(x, label):
        cropped = _center_crop(tf.cast(x, tf.int32), 224)
        return preproc_func(cropped), tf.one_hot(label, 1000, name='label', axis=-1)

    return augment_func
    
def batch_test_set(dataset, batch_size, scale_func, aug=False):
    """Map ``scale_func`` over ``dataset`` in parallel, then batch and prefetch.

    The dataset is neither repeated nor shuffled. ``aug`` is accepted for
    compatibility and is not used.
    """
    autotune = tf.data.experimental.AUTOTUNE
    return (dataset
            .map(scale_func, num_parallel_calls=autotune)
            .batch(batch_size)
            .prefetch(buffer_size=autotune))

def batch_train_set(dataset, batch_size, scale_func, shuffle=0):
    """Turn ``dataset`` into an endless, batched training stream.

    The pipeline steps are:
    1. repeat indefinitely;
    2. if ``shuffle > 0``, shuffle with a buffer of that size (and print it).
       Because shuffling follows ``repeat()``, the buffer can mix examples
       across epoch boundaries;
    3. map ``scale_func`` in parallel;
    4. batch and prefetch.
    """
    autotune = tf.data.experimental.AUTOTUNE
    pipeline = dataset.repeat()
    if shuffle > 0:
        print("shuffle", shuffle)
        pipeline = pipeline.shuffle(shuffle, reshuffle_each_iteration=True)
    pipeline = pipeline.map(scale_func, num_parallel_calls=autotune)
    return pipeline.batch(batch_size).prefetch(buffer_size=autotune)


def imagenet_dataset(batch_size,
                     preprocess="VGG",
                     shuffle=0,
                     contrast_min=1.,
                     contrast_max=1.,
                     cutout=False,
                     bright=0.,
                     shuffle_files=True,
                     write_dir='/mnt/terminus/imagenet/extracted/',
                     compute_train_val=False, verbose=False):
    """Build batched ImageNet-2012 train/validation pipelines from a local TFDS copy.

    Args:
        batch_size: number of examples per batch in every returned dataset.
        preprocess: "VGG" (normalize_vgg), "clipped" (normalize_clipped),
            "resnet" (normalize), or any callable applied to each cropped image.
        shuffle: training shuffle-buffer size; 0 disables shuffling.
        contrast_min, contrast_max: range of the per-channel contrast gain;
            equal values disable that augmentation.
        cutout: if True, training images also get random cutout
            (see augment_train).
        bright: maximum per-channel brightness shift; 0 disables it.
        shuffle_files: forwarded to ``tfds.load``.
        write_dir: root directory; TFDS data is read from ``<write_dir>/data``
            and nothing is downloaded (``download=False``).
        compute_train_val: if True, also return the train split processed
            with the evaluation pipeline.
        verbose: print the main configuration values.

    Returns:
        ``(train, val, info)``, or ``(train, val, train_val, info)`` when
        ``compute_train_val`` is True. ``train`` repeats indefinitely.

    Raises:
        ValueError: ``preprocess`` is a string other than the known names.
        TypeError: ``preprocess`` is neither a string nor callable.
    """
    named_preprocessors = {
        "VGG": normalize_vgg,
        "clipped": normalize_clipped,
        "resnet": normalize,
    }
    # Validate up front: an unknown name used to be treated as the function
    # itself and only failed later, inside dataset.map, with an opaque error.
    if isinstance(preprocess, str):
        if preprocess not in named_preprocessors:
            raise ValueError(
                f"Unknown preprocess {preprocess!r}; expected one of "
                f"{sorted(named_preprocessors)} or a callable.")
        preproc_func = named_preprocessors[preprocess]
    elif callable(preprocess):
        preproc_func = preprocess
    else:
        raise TypeError(
            f"preprocess must be a string or a callable, got {type(preprocess).__name__}")

    if verbose:
        print("batch_size :", batch_size)
        print("preprocess :", preprocess)
        print("shuffle :", shuffle)
        print("contrast_min :", contrast_min)
        print("contrast_max :", contrast_max)
        print("bright :", bright)

    datasets, info = tfds.load(name="imagenet2012",
                               with_info=True,
                               as_supervised=True,
                               download=False,
                               shuffle_files=shuffle_files,
                               data_dir=os.path.join(write_dir, 'data'))

    train = batch_train_set(datasets['train'],
                            batch_size,
                            augment_train(preproc_func,
                                          contrast_min=contrast_min,
                                          contrast_max=contrast_max,
                                          cutout=cutout,
                                          bright=bright),
                            shuffle=shuffle)
    val = batch_test_set(datasets['validation'], batch_size, augment_test(preproc_func))
    if compute_train_val:
        train_val = batch_test_set(datasets['train'], batch_size, augment_test(preproc_func))
        return train, val, train_val, info
    return train, val, info