import math
import numpy as np
import tensorflow as tf
from albumentations import Compose, OneOf, Flip, Transpose, Rotate, RandomRotate90, GridDistortion, ElasticTransform, OpticalDistortion, RandomBrightness, RandomContrast, RandomGamma


class Generator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that serves batches of scaled images and binary masks.

    Each batch is a contiguous slice of ``x_set`` / ``y_set``. When
    augmentation is enabled, every image/mask pair of the batch is passed
    together through the albumentations pipeline built in ``__init__``.
    Masks get a trailing axis and are binarised against ``thresh``; images
    are divided by 255.
    """

    def __init__(self, x_set, y_set, thresh=50, batch_size=1, augment=True):
        """Store the data and build the augmentation pipeline.

        Args:
            x_set: Sliceable collection of images (sliced along its first axis).
            y_set: Sliceable collection of masks, aligned with ``x_set``.
            thresh: Mask values strictly greater than this become 1, the
                rest become 0.
            batch_size: Number of samples per batch.
            augment: Truthy to apply the augmentation pipeline to each sample.
        """
        # Run the base-class initialiser so any state it sets up exists.
        super().__init__()

        self.x = x_set
        self.y = y_set
        self.thresh = thresh
        self.batch_size = batch_size
        self.augment = augment

        # Four groups, in order: flip/transpose, rotation, geometric
        # distortion, and brightness/contrast/gamma.
        # NOTE(review): border_mode=0 presumably selects constant-value
        # border filling - confirm against the albumentations/OpenCV docs.
        self.augmentations = Compose(
        [
            OneOf([
                  Flip(p=0.6),
                  Transpose(p=0.4),
            ], p=0.8),

            OneOf([
                  Rotate(border_mode=0, p=0.4),
                  RandomRotate90(p=0.6),
            ], p=0.8),

            OneOf([
                  GridDistortion(border_mode=0, p=0.5),
                  ElasticTransform(border_mode=0, p=0.5),
                  OpticalDistortion(border_mode=0, p=0.5)
            ], p=0.6),

            OneOf([
                  RandomBrightness(p=0.4),
                  RandomContrast(p=0.4),
                  RandomGamma(p=0.2)
            ], p=0.8),
        ])

    def __len__(self):
        """Return the number of batches; the last one may be partial."""
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        """Return batch number ``idx`` as an ``(images, masks)`` tuple.

        Args:
            idx: Zero-based batch index.

        Returns:
            A tuple ``(batch_x / 255, batch_y / 1)`` where ``batch_y`` has a
            trailing axis of size 1 and holds only 0s and 1s.
        """
        start = idx * self.batch_size
        stop = start + self.batch_size
        batch_x = self.x[start:stop]
        batch_y = self.y[start:stop]

        # Truthiness check rather than ``is True`` so flags such as
        # ``np.True_`` or ``1`` also enable augmentation.
        if self.augment:
            augmented = [
                self.augmentations(image=image, mask=mask)
                for image, mask in zip(batch_x, batch_y)
            ]
            batch_x = [sample['image'] for sample in augmented]
            batch_y = [sample['mask'] for sample in augmented]

        # Convert unconditionally so list-backed inputs also work when
        # augmentation is disabled (a no-op for arrays).
        batch_x = np.asarray(batch_x)
        batch_y = np.expand_dims(np.asarray(batch_y), -1)
        batch_y = (batch_y > self.thresh).astype(np.uint8)

        # True division turns both integer arrays into floating point.
        return batch_x / 255, batch_y / 1


class Dataset:
    """Wrap a ``Generator`` in a repeating, prefetching ``tf.data.Dataset``."""

    def __init__(self, images, masks, thresh, batch_size, augment):
        """Store the arguments later forwarded to ``Generator``.

        Args:
            images: Image array; its ``shape[1:4]`` is used as the per-sample
                image shape, so it must expose a 4-element ``shape``.
            masks: Mask array; its ``shape[1:3]`` is used as the per-sample
                mask height and width.
            thresh: Mask binarisation threshold passed to ``Generator``.
            batch_size: Batch size passed to ``Generator``.
            augment: Augmentation flag passed to ``Generator``.
        """
        self.images = images
        self.masks = masks
        self.thresh = thresh
        self.batch_size = batch_size
        self.augment = augment

    def generator_func(self):
        """Build and return a fresh ``Generator`` over the stored data."""
        return Generator(self.images, self.masks, self.thresh, self.batch_size, self.augment)

    def _output_shapes(self):
        """Return the ``(image_batch_shape, mask_batch_shape)`` pair.

        The leading dimension is ``None`` because the batch length is not
        fixed here. The mask shape uses the mask's own height and width
        (``shape[1]`` and ``shape[2]``) so non-square masks are declared
        correctly, followed by a channel axis of size 1.
        """
        image_shape = (
            None,
            self.images.shape[1],
            self.images.shape[2],
            self.images.shape[3],
        )
        mask_shape = (None, self.masks.shape[1], self.masks.shape[2], 1)
        return image_shape, mask_shape

    def get_generator(self):
        """Return a repeating, prefetching ``tf.data.Dataset`` of batches.

        Returns:
            A dataset built with ``from_generator`` from ``generator_func``,
            declared as two ``tf.float32`` tensors with the shapes given by
            ``_output_shapes``.
        """
        dataset = tf.data.Dataset.from_generator(
            self.generator_func,
            output_types=(tf.float32, tf.float32),
            output_shapes=self._output_shapes(),
        )
        dataset = dataset.repeat()
        dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
        return dataset
