import tensorflow as tf
import pickle

def load_natural_ds(image_size=20):
    """Unpickle the pre-built natural-image dataset for one patch size.

    Args:
        image_size: Patch size in pixels. Only 20, 40 and 50 have a
            dataset file on disk.

    Returns:
        The unpickled object. The pipeline builders in this module index it
        with "train_images", "train_labels", "test_images" and "test_labels".

    Raises:
        ValueError: If ``image_size`` is not 20, 40 or 50.
    """
    known_files = (
        (20, "/home/data/datasets/fakelabeled_natural_commonfiltered_640000_20px.pkl"),
        (40, "/home/data/datasets/fakelabeled_natural_commonfiltered_640000_40px.pkl"),
        (50, "/home/data/datasets/fakelabeled_natural_commonfiltered_640000_50px.pkl"),
    )
    for size, candidate in known_files:
        if image_size == size:
            pkl_path = candidate
            break
    else:
        raise ValueError('Image size can be either 20px, 40px or 50px')

    # NOTE: pickle can execute arbitrary code while loading; these files must be trusted.
    with open(pkl_path, 'rb') as fh:
        return pickle.load(fh)


def load_texture_ds(image_size=20):
    """Unpickle the pre-built labelled texture dataset for one patch size.

    Args:
        image_size: Patch size in pixels. Only 20, 40 and 50 have a
            dataset file on disk.

    Returns:
        The unpickled object. The pipeline builder in this module indexes it
        with "train_images", "train_labels", "test_images" and "test_labels".

    Raises:
        ValueError: If ``image_size`` is not 20, 40 or 50.
    """
    size_to_path = (
        (20, "/home/data/datasets/labeled_texture_oatleathersoilcarpetbubbles_commonfiltered_640000_20px.pkl"),
        (40, "/home/data/datasets/labeled_texture_oatleathersoilcarpetbubbles_commonfiltered_640000_40px.pkl"),
        (50, "/home/data/datasets/labeled_texture_oatleathersoilcarpetbubbles_commonfiltered_640000_50px.pkl"),
    )
    selected = next(
        (file_path for size, file_path in size_to_path if image_size == size),
        None,
    )
    if selected is None:
        raise ValueError('Image size can be either 20px, 40px or 50px')

    # NOTE: pickle can execute arbitrary code while loading; these files must be trusted.
    with open(selected, 'rb') as stream:
        loaded = pickle.load(stream)
    return loaded

def load_ds_from_file(path):
    """Unpickle and return the dataset object stored at ``path``.

    Args:
        path: Anything ``open`` accepts that points at a pickle file.

    Returns:
        The unpickled object, unchanged.

    Note:
        ``pickle.load`` can execute arbitrary code; only pass trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)

def get_natural_ds(batch_size=32, image_size=20, subsample=1):
    """Build shuffled, batched train/test pipelines for the natural-image dataset.

    Args:
        batch_size: Number of examples per batch.
        image_size: Patch size in pixels, forwarded to ``load_natural_ds``
            (supported: 20, 40, 50).
        subsample: Keep every ``subsample``-th training example; must be >= 1.
            The test split is always used in full.

    Returns:
        ``(train_ds, test_ds)``: datasets of the ``(input, target)`` pairs
        produced by ``preprocess_natural``, shuffled, batched to
        ``batch_size`` and prefetched.

    Raises:
        ValueError: If ``subsample`` < 1 (checked before any file is read),
            or if ``image_size`` is unsupported.
    """
    # Validate up front: a zero slice step used to fail only after the whole
    # pickle had already been loaded from disk.
    if subsample < 1:
        raise ValueError(f'subsample must be a positive integer, got {subsample!r}')

    data = load_natural_ds(image_size=image_size)

    ds_train = tf.data.Dataset.from_tensor_slices(
        (data["train_images"][::subsample], data["train_labels"][::subsample]))
    ds_test = tf.data.Dataset.from_tensor_slices(
        (data["test_images"], data["test_labels"]))

    # 10e6 == 10,000,000 elements. This is presumably larger than either split
    # (the file names mention 640000 samples), which gives a full uniform shuffle.
    shuffle_buffer = int(10e6)

    # Shuffle *before* the element-wise cast, so the buffer holds the raw
    # samples rather than float32 copies. That is smaller if the stored dtype
    # is narrower (e.g. uint8; unverified). The cast is stateless, so the
    # shuffled stream is distributed the same either way.
    ds_train_proc = (ds_train.shuffle(shuffle_buffer)
                     .map(preprocess_natural)
                     .batch(batch_size)
                     .prefetch(tf.data.AUTOTUNE))
    # NOTE(review): evaluation does not need a shuffled test split; kept to
    # preserve the existing iteration behaviour for callers.
    ds_test_proc = (ds_test.shuffle(shuffle_buffer)
                    .map(preprocess_natural)
                    .batch(batch_size)
                    .prefetch(tf.data.AUTOTUNE))

    return ds_train_proc, ds_test_proc
    
def get_texture_ds(batch_size=32, image_size=20, subsample=1):
    """Build shuffled, batched train/test pipelines for the texture dataset.

    Args:
        batch_size: Number of examples per batch.
        image_size: Patch size in pixels, forwarded to ``load_texture_ds``
            (supported: 20, 40, 50).
        subsample: Keep every ``subsample``-th training example; must be >= 1.
            The default of 1 uses the full training split, as before. This
            matches the parameter of the same name on ``get_natural_ds``. The
            test split is always used in full.

    Returns:
        ``(train_ds, test_ds)``: datasets of the ``(image, label)`` pairs
        produced by ``preprocess_texture``, shuffled, batched to
        ``batch_size`` and prefetched.

    Raises:
        ValueError: If ``subsample`` < 1 (checked before any file is read),
            or if ``image_size`` is unsupported.
    """
    # Fail fast, before the (large) pickle is loaded.
    if subsample < 1:
        raise ValueError(f'subsample must be a positive integer, got {subsample!r}')

    data = load_texture_ds(image_size=image_size)

    # With subsample == 1, [::1] selects every element, so the default
    # behaviour is unchanged.
    ds_train = tf.data.Dataset.from_tensor_slices(
        (data["train_images"][::subsample], data["train_labels"][::subsample]))
    ds_test = tf.data.Dataset.from_tensor_slices(
        (data["test_images"], data["test_labels"]))

    # 10e6 == 10,000,000 elements. This is presumably larger than either split
    # (the file names mention 640000 samples), which gives a full uniform shuffle.
    shuffle_buffer = int(10e6)

    # Shuffle before the stateless float32 cast, so the buffer holds raw
    # samples rather than cast copies.
    ds_train_proc = (ds_train.shuffle(shuffle_buffer)
                     .map(preprocess_texture)
                     .batch(batch_size)
                     .prefetch(tf.data.AUTOTUNE))
    # NOTE(review): shuffling the test split is not needed for evaluation;
    # kept to preserve the existing iteration behaviour for callers.
    ds_test_proc = (ds_test.shuffle(shuffle_buffer)
                    .map(preprocess_texture)
                    .batch(batch_size)
                    .prefetch(tf.data.AUTOTUNE))
    return ds_train_proc, ds_test_proc

def preprocess_natural(sample, label):
    """Cast one image to float32 and emit it as both model input and target.

    The incoming ``label`` is discarded. Presumably the natural dataset's
    labels are placeholders ("fakelabeled" in its file name) and the model
    reconstructs its input. Confirm this against the training code.
    """
    del label  # intentionally unused
    as_float = tf.cast(sample, tf.float32)
    return as_float, as_float

def preprocess_texture(sample, label):
    """Cast one texture image to float32 and pass its label through unchanged."""
    return tf.cast(sample, tf.float32), label