import numpy as np
# import cv2  # NOTE: required by foreground_segmentation, preprocess and splat_image
import os
import math
import random

import scipy.misc
import scipy.ndimage as ndi

# import skimage  # NOTE: required by add_noise and blur_image (uses skimage.io, skimage.util, skimage.filters)


from PIL import Image
import torch
import torchvision

def foreground_segmentation(img):
    """Estimate a binary foreground mask for ``img`` with OpenCV GrabCut.

    GrabCut is seeded with a rectangle inset by one pixel on every side, so
    the outermost ring of pixels is treated as definite background.

    Args:
        img: 8-bit, 3-channel image array as accepted by ``cv2.grabCut``
            (presumably BGR as produced by OpenCV -- confirm with callers).

    Returns:
        ``np.uint8`` array with the same height and width as ``img``:
        1 for (probable) foreground, 0 for (probable) background.

    NOTE: relies on a module-level ``cv2`` import, which is currently
    commented out at the top of this file.
    """
    height, width = img.shape[:2]
    # grabCut requires the mask to have exactly the image's height/width.
    mask = np.zeros((height, width), np.uint8)
    # Internal GMM state buffers required by grabCut (1x65 float64 each).
    bgd_model = np.zeros((1, 65), np.float64)
    fgd_model = np.zeros((1, 65), np.float64)
    # (x, y, w, h) with a 1-pixel inset; equals (1, 1, 126, 126) for 128x128.
    rect = (1, 1, width - 2, height - 2)
    cv2.grabCut(img, mask, rect, bgd_model, fgd_model, 15, cv2.GC_INIT_WITH_RECT)
    # Labels 0 and 2 are (probable) background -> 0; every other label -> 1.
    return np.where((mask == 2) | (mask == 0), 0, 1).astype(np.uint8)


def preprocess(img, save_path='img_sample.jpg'):
    """Write the Canny edge map of ``img`` restricted to a shrunken foreground mask.

    Steps:
      1. Histogram-equalise the saturation and value channels (HSV space),
         convert back to BGR and apply a 7x7 Gaussian blur.
      2. Segment that enhanced image with ``foreground_segmentation``.
      3. Shrink the foreground mask by a factor of 90/128 and paste it
         centred into an otherwise zero mask of the same size.
      4. Compute Canny edges (thresholds 20/50) of the blurred, equalised
         grayscale image and zero every edge outside the centred mask.
      5. Write the result to ``save_path``.

    Args:
        img: 8-bit BGR image array. Its height/width must match the mask
            returned by ``foreground_segmentation``.
        save_path: output file passed to ``cv2.imwrite``.

    Returns:
        The masked ``uint8`` edge map that was written to ``save_path``.

    NOTE: relies on a module-level ``cv2`` import, which is currently
    commented out at the top of this file.
    """
    hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
    hsv[:, :, 1] = cv2.equalizeHist(hsv[:, :, 1])
    hsv[:, :, 2] = cv2.equalizeHist(hsv[:, :, 2])
    enhanced = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    enhanced = cv2.GaussianBlur(enhanced, (7, 7), 0)

    mask = foreground_segmentation(enhanced)

    # Shrink by 90/128 and centre: a 128x128 mask becomes 90x90 at offset 19.
    mask_h, mask_w = mask.shape[:2]
    inner_h = round(mask_h * 90 / 128)
    inner_w = round(mask_w * 90 / 128)
    top = (mask_h - inner_h) // 2
    left = (mask_w - inner_w) // 2
    shrunk = cv2.resize(mask, (inner_w, inner_h))  # dsize is (width, height)
    centred_mask = np.zeros_like(mask)
    centred_mask[top:top + inner_h, left:left + inner_w] = shrunk

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = cv2.GaussianBlur(gray, (7, 7), 0)
    gray = cv2.equalizeHist(gray)
    edges = cv2.Canny(gray, 20, 50)
    # The mask holds 0/1, so multiplying zeroes every edge outside it.
    edges *= centred_mask

    cv2.imwrite(save_path, edges)
    return edges


def add_noise(img, img_save, mode):
    """Add random noise to an image file and save the result.

    Args:
        img: path of the input image. Pixel values are divided by 255, so an
            8-bit image is assumed.
        img_save: output path.
        mode: a ``skimage.util.random_noise`` mode. ``"gaussian"`` uses
            zero-mean noise with variance 0.02 (on the [0, 1] scale); any
            other mode uses scikit-image's defaults for that mode.

    NOTE: relies on a module-level ``skimage`` import, which is currently
    commented out at the top of this file.
    """
    image = skimage.io.imread(img) / 255.0
    if mode == "gaussian":
        noisy = skimage.util.random_noise(image, mode=mode, mean=0, var=0.02)
    else:
        noisy = skimage.util.random_noise(image, mode=mode)
    # Hand the writer an explicit uint8 image: a float array in [0, 255] is
    # otherwise min/max-rescaled or rejected, depending on the backend.
    out = np.clip(np.rint(noisy * 255.0), 0, 255).astype(np.uint8)
    skimage.io.imsave(img_save, out)


def blur_image(image, img_save, sigma, mode):
    """Apply a Gaussian blur to an image file and save the result.

    Args:
        image: path of the input image. Pixel values are divided by 255, so
            an 8-bit image is assumed.
        img_save: output path.
        sigma: standard deviation of the Gaussian kernel, in pixels.
        mode: blur type; only ``"gaussian"`` is supported.

    Raises:
        ValueError: if ``mode`` is not ``"gaussian"`` (checked before any I/O).

    NOTE: relies on a module-level ``skimage`` import, which is currently
    commented out at the top of this file.
    """
    if mode != "gaussian":
        raise ValueError(f"unsupported blur mode {mode!r}; only 'gaussian' is supported")
    img = skimage.io.imread(image) / 255.0
    # Blur colour channels independently. Without channel_axis, newer
    # scikit-image versions treat an (H, W, C) array as a 3-D volume and
    # also blur across the channel axis.
    channel_axis = -1 if img.ndim == 3 else None
    blurred = skimage.filters.gaussian(img, sigma=sigma, channel_axis=channel_axis)
    # Hand the writer an explicit uint8 image rather than floats in [0, 255].
    out = np.clip(np.rint(blurred * 255.0), 0, 255).astype(np.uint8)
    skimage.io.imsave(img_save, out)


def splat_image(image, img_save, diameter=10):
    """Apply a "splatter" (random pixel jitter) effect to an image file.

    Each output pixel (r, c) copies the input pixel (r + dr, c + dc). The
    offsets dr and dc are drawn independently per pixel as
    ``math.ceil(random.uniform(-0.5, 0.5) * diameter)``. Source positions
    outside the image produce a black pixel.

    Args:
        image: path of the input image, read with
            ``cv2.imread(..., cv2.IMREAD_COLOR)``.
        img_save: output path passed to ``cv2.imwrite``.
        diameter: width of the jitter window in pixels.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image``.

    NOTE: relies on a module-level ``cv2`` import, which is currently
    commented out at the top of this file.
    """
    initial_img = cv2.imread(image, cv2.IMREAD_COLOR)
    if initial_img is None:  # cv2.imread reports failure by returning None
        raise FileNotFoundError(f"could not read image: {image}")

    height, width = initial_img.shape[:2]
    # Draw offsets in row-major pixel order (row offset, then column offset)
    # so a given random.seed yields the same result as a per-pixel loop.
    offsets = np.array(
        [math.ceil(random.uniform(-0.5, 0.5) * diameter) for _ in range(2 * height * width)],
        dtype=np.int64,
    ).reshape(height, width, 2)
    src_rows = np.arange(height)[:, None] + offsets[..., 0]
    src_cols = np.arange(width)[None, :] + offsets[..., 1]
    inside = (src_rows >= 0) & (src_rows < height) & (src_cols >= 0) & (src_cols < width)

    # Vectorised gather (ndarray.itemset no longer exists in NumPy 2.0);
    # pixels whose source lies outside the image stay black.
    splatter_img = np.zeros_like(initial_img)
    splatter_img[inside] = initial_img[src_rows[inside], src_cols[inside]]

    cv2.imwrite(img_save, splatter_img)


def elastic_transform(img, img_save, alpha):
    """Apply a random elastic deformation to an image file and save it.

    Requires the disent-py310 environment (for
    ``torchvision.transforms.ElasticTransform``).

    Args:
        img: input image path, opened with PIL.
        img_save: output path.
        alpha: displacement magnitude passed to ``ElasticTransform``; all
            other transform parameters keep torchvision's defaults.
    """
    elastic_transformer = torchvision.transforms.ElasticTransform(alpha=alpha)
    # The context manager releases the source file even if the transform fails.
    with Image.open(img) as orig_img:
        transformed_img = elastic_transformer(orig_img)
    transformed_img.save(img_save)


def random_posterize(img, img_save, bits):
    """Randomly posterize an image file and save the result.

    The transform's application probability ``p`` is not set here, so
    torchvision's default applies (documented as 0.5). The image is
    therefore saved unchanged on some calls. Requires the disent-py310
    environment.

    Args:
        img: input image path, opened with PIL.
        img_save: output path.
        bits: number of bits kept per channel (0-8); 8 leaves the image
            unchanged.
    """
    posterize_transformer = torchvision.transforms.RandomPosterize(bits=bits)
    # The context manager releases the source file even if the transform fails.
    with Image.open(img) as orig_img:
        transformed_img = posterize_transformer(orig_img)
    transformed_img.save(img_save)


def wood_noise(img, img_save, num_rings, normalising_constant):
    """Darken an image with a concentric "wood ring" pattern and save it.

    The pattern is
    ``|sin(pi * num_rings * (d + noise))| ** (1 / normalising_constant)``.
    Here ``d = sqrt(x^2 + y^2)`` is the distance from the image centre, with
    x and y normalised to [-1, 1] along each axis. ``noise`` is a turbulence
    field bilinearly resized to the image size; it is currently all zeros,
    i.e. no turbulence. Every channel of every pixel is multiplied by the
    pattern (which lies in [0, 1]) and truncated to ``uint8``.

    Args:
        img: input image path, opened with PIL. A direct-colour mode
            (L/RGB/RGBA) is assumed; any alpha channel is scaled as well.
        img_save: output path.
        num_rings: ring frequency; the pattern repeats every
            ``1 / num_rings`` of normalised radius.
        normalising_constant: exponent divisor. Larger values push the
            pattern towards 1, leaving thinner dark rings where the sine is
            near zero.
    """
    with Image.open(img) as pil_img:
        width, height = pil_img.width, pil_img.height
        pixels = np.asarray(pil_img)

    sin_frequency = math.pi * num_rings

    # Turbulence field; use ``epsilon * torch.rand(...)`` for random turbulence.
    epsilon = 2
    noise_resolution = 256
    noise = epsilon * torch.zeros((noise_resolution, noise_resolution))
    # Bilinearly resize the noise to (height, width) so it smoothly matches
    # any image size (a raw 256x256 grid only broadcasts for 256x256 images).
    noise = torch.nn.functional.interpolate(
        noise[None, None], size=(height, width), mode="bilinear", align_corners=False
    )[0, 0]

    # Pixel coordinates; with indexing="xy" both grids have shape (height, width).
    xs, ys = torch.meshgrid(
        torch.arange(0, width).float(), torch.arange(0, height).float(), indexing="xy"
    )

    # Normalise coordinates to [-1, 1] around the image centre.
    x_value = 2 * (xs - width / 2) / width
    y_value = 2 * (ys - height / 2) / height

    dist = torch.sqrt(x_value * x_value + y_value * y_value) + noise
    pattern = torch.abs(torch.sin(sin_frequency * dist)) ** (1 / normalising_constant)
    pattern = pattern.numpy()
    if pixels.ndim == 3:
        pattern = pattern[:, :, None]  # broadcast over colour channels

    # uint8 * float32 -> float32, then truncate back to uint8.
    wood_img = (pixels * pattern).astype(np.uint8)
    Image.fromarray(wood_img).save(img_save)


def attach_stickers(img, sticker_path, img_save, border_x=30, border_y=20, max_y_offset=30):
    """Paste a sticker at a random position near the top of an image and save it.

    The sticker's left edge is drawn uniformly so the sticker stays at least
    ``border_x`` pixels from both the left and right image edges. Its top
    edge is drawn uniformly from ``[border_y, border_y + max_y_offset]``.
    The sticker is converted to RGBA, so its alpha channel (or palette
    transparency) is used as the paste mask. A sticker without transparency
    is pasted fully opaque.

    Args:
        img: input image path, opened with PIL.
        sticker_path: sticker image path, opened with PIL.
        img_save: output path.
        border_x: minimum horizontal distance from either image edge.
        border_y: smallest allowed top offset of the sticker.
        max_y_offset: vertical jitter range added on top of ``border_y``.

    Raises:
        ValueError: if the sticker is too wide to fit between the
            horizontal borders.
    """
    with Image.open(img) as base, Image.open(sticker_path) as sticker_file:
        orig_img = base.copy()
        # RGBA always provides a valid transparency mask for paste().
        sticker = sticker_file.convert("RGBA")

    max_x = orig_img.width - border_x - sticker.width
    if max_x < border_x:
        raise ValueError(
            f"sticker width {sticker.width} does not fit in image width "
            f"{orig_img.width} with horizontal border {border_x}"
        )
    x = random.randint(border_x, max_x)
    y = random.randint(border_y, border_y + max_y_offset)
    orig_img.paste(sticker, (x, y), mask=sticker)

    orig_img.save(img_save)


if __name__ == '__main__':
    # Build a perturbed validation split. For each traffic-sign class, take
    # the first ``images_per_class`` training images (sorted by filename) and
    # save a perturbed copy as <img_save_root>/<attack>/<class>_<filename>.
    traffic_obj = ['deerCrossing', 'warning', 'workersAhead', 'leftCurve']
    img_root = '/data/open-datasets/traffic/train'
    img_save_root = '/data/open-datasets/traffic/val'
    sticker_path = '/data/open-datasets/traffic/sticker.png'
    images_per_class = 25

    # Each attack name maps to the perturbation it applies, so the output
    # folder name always matches the transform actually used.
    attacks = {
        'g_noise': lambda src, dst: add_noise(src, dst, mode='gaussian'),
        'g_blur': lambda src, dst: blur_image(src, dst, sigma=2, mode='gaussian'),
        'splatter': lambda src, dst: splat_image(src, dst, diameter=6),
        'elastic': lambda src, dst: elastic_transform(src, dst, 60.),
        'posterize': lambda src, dst: random_posterize(src, dst, 8),
        'wood': lambda src, dst: wood_noise(src, dst, 3, 10),
        'sticker': lambda src, dst: attach_stickers(src, sticker_path, dst),
    }
    attack = 'sticker'
    apply_attack = attacks[attack]

    img_save_path = os.path.join(img_save_root, attack)
    os.makedirs(img_save_path, exist_ok=True)

    for obj in traffic_obj:
        img_path = os.path.join(img_root, obj)
        # os.listdir order is arbitrary; sort so the chosen subset is reproducible.
        for name in sorted(os.listdir(img_path))[:images_per_class]:
            apply_attack(
                os.path.join(img_path, name),
                os.path.join(img_save_path, obj + '_' + name),
            )
