import tensorflow as tf
import numpy as np
from tensorflow.data import Dataset
from skimage.filters import gaussian

#gaussian = None

"""==============================================================================="""

def parse_images(image_path):
    """Load one JPEG file into a fixed-size float image tensor.

    Args:
        image_path: scalar string tensor holding the path of the image file.

    Returns:
        float32 tensor of shape (224, 224, 3). ``convert_image_dtype``
        rescales the decoded uint8 pixels to [0, 1] before the resize.
    """
    raw_bytes = tf.io.read_file(image_path)
    # channels=3 forces an RGB result regardless of the file's own channel count.
    pixels = tf.image.convert_image_dtype(
        tf.image.decode_jpeg(raw_bytes, channels=3), tf.float32
    )
    return tf.image.resize(pixels, size=[224, 224])

"""==============================================================================="""

def sample_gaussian_density_map(index, size=56, count_range=(5, 80), sigma=1.5):
    """Sample a random, Gaussian-smoothed point map.

    Each of the ``size x size`` cells is independently set to 1 with
    probability ``n / size**2``. ``n`` is drawn uniformly from
    ``range(*count_range)``, so on average about ``n`` points are placed.
    The binary map is then blurred with a Gaussian of standard deviation
    ``sigma`` (reflect padding at the borders).

    Args:
        index: element index supplied by ``Dataset.map``. It is unused; the map
            is freshly random on every call.
        size: side length of the square output map.
        count_range: ``(low, high)`` bounds (high exclusive) for the expected
            number of points.
        sigma: Gaussian blur standard deviation, in cells, used for both axes.

    Returns:
        float32 numpy array of shape ``(size, size, 1)``.
    """
    n_points = np.random.choice(range(*count_range), size=1)[0]
    # Clip so a large count on a small grid cannot yield an invalid probability.
    p = min(n_points / (size * size), 1.0)
    points = np.random.choice([0, 1], p=[1 - p, p], size=(size, size))
    density = gaussian(points.astype(np.float64), sigma=(sigma, sigma), mode="reflect")
    # The tf.py_function wrapper declares Tout=tf.float32. Cast here so the
    # returned array already has the declared dtype instead of relying on
    # implicit conversion.
    return np.expand_dims(density, axis=-1).astype(np.float32)

"""==============================================================================="""

def sample_biased_gaussian_density_map(index, x=0, y=0, xoff=56, yoff=28):
    """Sample a Gaussian-smoothed point map whose points lie in one sub-window.

    Points are only placed inside the window ``[x:x+xoff, y:y+yoff]`` of a
    56x56 grid. With the defaults this is all rows and the first 28 columns.
    Inside the window, each cell is set to 1 with probability
    ``n / (xoff*yoff)``, where ``n`` is drawn uniformly from ``range(5, 80)``.
    The map is then blurred with a Gaussian (sigma 1.5, reflect padding), so
    some density spreads slightly past the window edge.

    Args:
        index: element index supplied by ``Dataset.map``; unused.
        x, y: row / column of the window's top-left corner.
        xoff, yoff: window height (rows) and width (columns).

    Returns:
        float32 numpy array of shape ``(56, 56, 1)``.

    Raises:
        ValueError: if the window is empty or does not fit inside the grid.
    """
    grid = 56
    if (x < 0 or y < 0 or xoff <= 0 or yoff <= 0
            or x + xoff > grid or y + yoff > grid):
        raise ValueError(
            f"window [{x}:{x + xoff}, {y}:{y + yoff}] does not fit a {grid}x{grid} grid"
        )
    dmap1 = np.zeros((grid, grid))
    # Fix: p used to be drawn twice, and the first draw was silently discarded.
    # Clip so a small window cannot yield an invalid probability.
    p = min(np.random.choice(range(5, 80), size=1)[0] / (xoff * yoff), 1.0)
    # Fix: the column slice used to end at ``yoff`` instead of ``y + yoff``,
    # which raised a broadcast error for any y > 0.
    dmap1[x:x + xoff, y:y + yoff] = np.random.choice([0, 1], p=[1 - p, p], size=(xoff, yoff))
    densitymap1 = gaussian(dmap1.astype(np.float64), sigma=(1.5, 1.5), mode="reflect")
    # Match the Tout=tf.float32 declared by the tf.py_function wrapper.
    return np.expand_dims(densitymap1, axis=-1).astype(np.float32)

"""==============================================================================="""

def tf_biased_density_map(index):
    """Tensor-level wrapper around ``sample_biased_gaussian_density_map``.

    Runs the NumPy sampler through ``tf.py_function`` so it can be used inside
    ``Dataset.map``. It then pins the static shape, which ``py_function``
    cannot infer.

    Args:
        index: element index tensor, forwarded to the sampler.

    Returns:
        float32 tensor with static shape (56, 56, 1).
    """
    outputs = tf.py_function(
        func=sample_biased_gaussian_density_map, inp=[index], Tout=[tf.float32]
    )
    density = outputs[0]
    density.set_shape((56, 56, 1))
    return density

"""==============================================================================="""

def tf_density_map(index):
    """Tensor-level wrapper around ``sample_gaussian_density_map``.

    Runs the NumPy sampler through ``tf.py_function`` so it can be used inside
    ``Dataset.map``. It then pins the static shape, which ``py_function``
    cannot infer.

    Args:
        index: element index tensor, forwarded to the sampler.

    Returns:
        float32 tensor with static shape (56, 56, 1).
    """
    (density,) = tf.py_function(sample_gaussian_density_map, [index], [tf.float32])
    density.set_shape([56, 56, 1])
    return density


"""==============================================================================="""

def parse_file(examples, file_path, key):
    """Build full file paths from one field of each example record.

    Joins ``file_path`` and ``example[key]`` with a single ``/``. Trailing
    slashes on ``file_path`` and leading slashes on the relative path are
    stripped first.

    Args:
        examples: iterable of mappings that each contain ``key``.
        file_path: base directory.
        key: name of the field holding the path relative to ``file_path``.

    Returns:
        list of joined path strings, in the order of ``examples``.
    """
    def join_one(example):
        return f"{file_path.rstrip('/')}/{example[key].lstrip('/')}"

    return list(map(join_one, examples))

"""==============================================================================="""

def get_image_ds(images):
    """Turn a list of image paths into a dataset of decoded images.

    Args:
        images: list (or 1-D tensor) of image file paths.

    Returns:
        ``tf.data.Dataset`` yielding one ``parse_images`` result per path,
        in input order.
    """
    img_ds = Dataset.from_tensor_slices(images)
    # Decode in parallel, like every other pipeline in this module
    # (-1 == tf.data.AUTOTUNE). Output order is still deterministic by default.
    return img_ds.map(parse_images, num_parallel_calls=-1)

"""==============================================================================="""

class AbstractDSIterator(object):
    """
    Abstract dataset iterator for generating image datasets.

    Stores the raw example records and a base directory. Subclasses implement
    ``build_dataset`` to turn them into a batched dataset.
    """
    def __init__(self, examples, file_path):
        # Raw example records; each subclass defines their expected structure.
        self.examples = examples
        # Base directory that relative image paths are joined onto.
        self.file_path = file_path

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return a batched dataset. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, on this base class.
        """
        # Fix: this was spelled ``NotImplmentedError``, which raised NameError
        # instead of NotImplementedError.
        raise NotImplementedError

"""==============================================================================="""
class ImgSingletDSIterator(AbstractDSIterator):
    """
    Dataset iterator over a flat list of image paths.

    Produces an infinitely repeating dataset of batches of single decoded
    images, with no labels. Paths are reshuffled on every pass over the data.
    """
    def __init__(self, examples, file_path):
        super().__init__(examples, file_path)
        # Here ``examples`` is a list of relative path strings, not dicts, so
        # parse_file does not apply. Join each one onto file_path directly.
        self.img = [file_path.rstrip("/") + "/" + e.lstrip("/") for e in examples]

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return an endless dataset of image batches.

        Args:
            batch_size: number of images per batch.
            drop_remainder: drop the final short batch of each pass if True.
        """
        img_ds = Dataset.from_tensor_slices(self.img).shuffle(1024)
        img_ds = img_ds.map(parse_images, num_parallel_calls=-1)
        img_ds = img_ds.batch(batch_size, drop_remainder=drop_remainder)
        # Repeat *before* prefetch. With prefetch inside the repeat, its buffer
        # was recreated empty at every epoch boundary. At the end of the
        # pipeline it keeps filling across epochs.
        img_ds = img_ds.repeat()
        return img_ds.prefetch(5)

"""==============================================================================="""

class RankDSIterator(AbstractDSIterator):
    """
    Dataset iterator for siamese-network ranking.

    Each example record provides ``image_i``, ``image_j``, ``count_i`` and
    ``count_j``. Each dataset element is ``((image_i, image_j), rank)``, where
    ``rank`` is 1 when ``count_i > count_j`` and 0 otherwise (ties give 0).
    """
    def __init__(self, examples, file_path):
        super().__init__(examples, file_path)
        self.img_i = parse_file(examples, file_path, "image_i")
        self.img_j = parse_file(examples, file_path, "image_j")
        self.rank = [int(ex["count_i"] > ex["count_j"]) for ex in examples]

    def make_rankpair_ds(self, img_i, img_j, rank):
        """Decode both paths of a pair and return ``((image_i, image_j), rank)``."""
        pair = (parse_images(img_i), parse_images(img_j))
        return pair, rank

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return a shuffled, batched, prefetched dataset of ranked image pairs.

        Args:
            batch_size: number of pairs per batch.
            drop_remainder: drop the final short batch if True.
        """
        columns = tuple(
            Dataset.from_tensor_slices(column)
            for column in (self.img_i, self.img_j, self.rank)
        )
        # Shuffle the cheap (path, path, label) triples before decoding.
        pairs = Dataset.zip(columns).shuffle(1024)
        pairs = pairs.map(self.make_rankpair_ds, num_parallel_calls=-1)
        return pairs.batch(batch_size, drop_remainder=drop_remainder).prefetch(5)

"""==============================================================================="""

class DMapCountDSIterator(AbstractDSIterator):
    """
    Dataset iterator yielding batches of ``(image, count, density_map)``.

    The density map is sampled at random by ``tf_density_map`` for each
    element. It is not derived from the image.
    """
    def __init__(self, examples, file_path):
        super().__init__(examples, file_path)
        self.images = parse_file(examples, file_path, "image")
        self.counts = [int(record["count"]) for record in examples]

    def make_count_ds(self, x, y, z):
        """Decode the image path ``x``; pass the count ``y`` and map ``z`` through."""
        return parse_images(x), y, z

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return a shuffled, batched, prefetched ``(image, count, dmap)`` dataset.

        Args:
            batch_size: number of elements per batch.
            drop_remainder: drop the final short batch if True.
        """
        indices = np.arange(0, len(self.images))
        density_ds = Dataset.from_tensor_slices(indices).map(
            tf_density_map, num_parallel_calls=-1
        )
        merged = Dataset.zip(
            (
                Dataset.from_tensor_slices(self.images),
                Dataset.from_tensor_slices(self.counts),
                density_ds,
            )
        ).shuffle(1024)
        merged = merged.map(self.make_count_ds, num_parallel_calls=-1)
        merged = merged.batch(batch_size, drop_remainder=drop_remainder)
        return merged.prefetch(5)


"""==============================================================================="""

class CountDSIterator(AbstractDSIterator):
    """
    Dataset iterator yielding batches of ``(image, count)`` pairs.

    Example records provide an ``image`` path and an integer-convertible
    ``count``.
    """
    def __init__(self, examples, file_path):
        super().__init__(examples, file_path)
        self.images = parse_file(examples, file_path, "image")
        self.counts = [int(d["count"]) for d in examples]

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return a shuffled, batched, prefetched ``(image, count)`` dataset.

        Args:
            batch_size: number of pairs per batch.
            drop_remainder: drop the final short batch if True.
        """
        path_ds = Dataset.from_tensor_slices(self.images)
        count_ds = Dataset.from_tensor_slices(self.counts)
        # Shuffle (path, count) pairs *before* decoding, as the other iterators
        # in this module do. Shuffling after decoding made the 1024-element
        # buffer hold 1024 decoded 224x224x3 float32 images (~600 MB).
        counting_ds = Dataset.zip((path_ds, count_ds)).shuffle(1024)
        counting_ds = counting_ds.map(
            lambda path, count: (parse_images(path), count), num_parallel_calls=-1
        )
        counting_ds = counting_ds.batch(batch_size, drop_remainder=drop_remainder)
        return counting_ds.prefetch(5)

"""==============================================================================="""

class DMapDSIterator(RankDSIterator):
    """
    Ranking iterator that also attaches a randomly sampled density map to
    every pair.

    Elements are ``(density_map, image_i, image_j, rank)``. Subclasses pick
    the sampler by overriding ``density_map_func``. Paths and rank labels are
    built by the inherited ``RankDSIterator.__init__``.
    """

    def make_rankpair_ds(self, images, dmaps):
        """Decode a pair of paths and put its density map first.

        Args:
            images: ``(path_i, path_j, rank)`` triple from the zipped dataset.
            dmaps: density map sampled for this element.
        """
        path_i, path_j, label = images
        decoded_i = parse_images(path_i)
        decoded_j = parse_images(path_j)
        return dmaps, decoded_i, decoded_j, label

    def density_map_func(self, index):
        """Sample the density map for element ``index`` (unbiased sampler)."""
        return tf_density_map(index)

    def build_dataset(self, batch_size, drop_remainder=False):
        """Return a batched, prefetched dataset of ``(dmap, img_i, img_j, rank)``.

        Args:
            batch_size: number of elements per batch.
            drop_remainder: drop the final short batch if True.
        """
        pair_ds = Dataset.zip(
            (
                Dataset.from_tensor_slices(self.img_i),
                Dataset.from_tensor_slices(self.img_j),
                Dataset.from_tensor_slices(self.rank),
            )
        ).shuffle(1024)
        # Only the pairs are shuffled; the maps are freshly random per element.
        index_ds = Dataset.from_tensor_slices(np.arange(0, len(self.img_i)))
        dmap_ds = index_ds.map(self.density_map_func, num_parallel_calls=-1)
        merged = Dataset.zip((pair_ds, dmap_ds))
        merged = merged.map(self.make_rankpair_ds, num_parallel_calls=-1)
        return merged.batch(batch_size, drop_remainder=drop_remainder).prefetch(5)
    
class BiasedDmapDSIterator(DMapDSIterator):
    """
    Variant of ``DMapDSIterator`` whose density maps come from
    ``tf_biased_density_map`` instead of ``tf_density_map``. Those maps have
    their points confined to a sub-window of the grid.
    """

    def density_map_func(self, index):
        """Sample the spatially biased density map for element ``index``."""
        return tf_biased_density_map(index)
        
        

