import torch
import torchvision
from io import BytesIO
from PIL import Image
import torchvision.transforms.functional as F
import numpy as np
from utils import quantization
import random
import kornia

# Shared PIL <-> tensor converters; jpeg_compression uses both for its
# encode/decode round trip.
transform_to_pil, transform_to_tensor = (
    torchvision.transforms.ToPILImage(),
    torchvision.transforms.ToTensor(),
)


def jpeg_compression(img, quality_factor):
    """Degrade ``img`` by an in-memory JPEG encode/decode round trip.

    Each image is converted to PIL, saved as JPEG with ``quality_factor``,
    decoded again and converted back to a tensor.

    Args:
        img: image tensor. A 4-D input is treated as a batch along dim 0 and
            every item is compressed separately; a lower-rank input is treated
            as a single image. Presumably float values in [0, 1] — confirm
            against callers.
        quality_factor: JPEG ``quality`` forwarded to PIL's ``Image.save``.

    Returns:
        A 4-D tensor (batch dim first) of decoded images on ``img``'s device.
    """
    device = img.device
    batch = img.detach().cpu()
    if batch.dim() < 4:
        # Single image: add a batch dim so the loop below sees one item.
        batch = batch.unsqueeze(0)
    if batch.is_floating_point():
        # ToPILImage scales floats by 255 and casts to uint8, so out-of-range
        # values would wrap around instead of saturating; clamp them first.
        batch = batch.clamp(0.0, 1.0)

    decoded = []
    for single in batch:
        # Close both the buffer and the decoded PIL image once converted.
        with BytesIO() as buffer:
            transform_to_pil(single).save(buffer, "JPEG", quality=quality_factor)
            buffer.seek(0)
            with Image.open(buffer) as jpeg:
                decoded.append(transform_to_tensor(jpeg))
    return torch.stack(decoded, dim=0).to(device)


def gaussian_noise(img, mean=0, sigma=10):
    """Add i.i.d. Gaussian noise to ``img`` and quantize the result.

    Args:
        img: image tensor; the /255 scaling below suggests it is in [0, 1]
            — confirm against callers.
        mean: noise mean on a 0-255 intensity scale (divided by 255). It was
            previously accepted but ignored; the default of 0 is unchanged.
        sigma: noise standard deviation on a 0-255 intensity scale.

    Returns:
        ``quantization(img + noise)``.
    """
    # Draw directly on img's device instead of generating on the CPU and copying.
    noise = torch.randn(img.shape, device=img.device).mul_(sigma / 255).add_(mean / 255)
    return quantization(img + noise)


def poisson_noise(img, lam):
    """Add Poisson-distributed offsets (in 1/255 steps) to ``img`` and quantize.

    One sample per element is drawn with ``numpy.random.poisson(lam)``. The
    integer counts are divided by 255 before being added, so the offset is
    never negative.
    """
    counts = np.random.poisson(lam=lam, size=img.shape)
    offset = torch.from_numpy(counts) / 255
    return quantization(img + offset.to(img.device))


def gaussian_blur(img, kernel_size=5):
    """Blur ``img`` with torchvision's Gaussian filter, then quantize.

    No sigma is passed, so torchvision's default sigma for ``kernel_size``
    applies.
    """
    blurred = F.gaussian_blur(img, kernel_size=kernel_size)
    return quantization(blurred)


class MedianBlur(torch.nn.Module):
    """Module wrapper: square median filter via kornia, followed by quantization.

    Args:
        kernel_size: side length of the square median window.
    """

    def __init__(self, kernel_size):
        super().__init__()
        # Side length of the square median window.
        self.kernel = kernel_size

    def forward(self, img, _=None):
        """Median-filter ``img`` and return the quantized result.

        The second positional argument is accepted and ignored.
        """
        window = (self.kernel,) * 2
        return quantization(kornia.filters.median_blur(img, window))
    
def median_blur(img, kernel_size=5):
    """Functional form of ``MedianBlur``: median-filter ``img`` and quantize.

    A fresh ``MedianBlur(kernel_size)`` module is built for every call.
    """
    return MedianBlur(kernel_size)(img)


def salt_and_pepper(img, prob):
    """Apply salt-and-pepper noise without modifying the input tensor.

    One uniform draw is made per pixel. It is shared by every channel of that
    pixel and is independent across batch items.
      - Draws below ``prob / 2`` become 1 (salt).
      - Draws above ``1 - prob / 2`` become 0 (pepper).
      - If both hold, which is only possible when ``prob > 1``, pepper wins
        (the same precedence as before).

    Args:
        img: 4-D tensor; dim 0 is treated as batch, dim 1 as channels and
            dims 2-3 as height/width. Any channel count is supported.
        prob: total probability that a pixel is replaced.

    Returns:
        A new tensor with the same shape, dtype and device as ``img``.
    """
    batch, _, height, width = img.shape
    # Draw on img's device so the masks never cross devices.
    draw = torch.rand(batch, 1, height, width, device=img.device)
    salt = draw < prob / 2
    pepper = draw > (1 - prob / 2)
    # masked_fill is out-of-place and broadcasts the (B, 1, H, W) masks over channels.
    return img.masked_fill(salt, 1.0).masked_fill(pepper, 0.0)


def hue_shifting(img, factor=0.1):
    """Shift the hue by exactly ``+factor`` or ``-factor``, then quantize.

    The sign is picked with ``random.randint(0, 1)``. ColorJitter receives the
    degenerate range ``(shift, shift)``, so the shift is fixed rather than
    sampled from an interval.
    """
    sign = 1 if random.randint(0, 1) == 0 else -1
    shift = factor * sign
    jitter = torchvision.transforms.ColorJitter(hue=(shift, shift))
    return quantization(jitter(img))


def brightness_shifting(img, factor=0.2):
    """Scale brightness by exactly ``1 + factor`` or ``1 - factor``, then quantize.

    The sign is picked with ``random.randint(0, 1)``. ColorJitter receives the
    degenerate range ``(scale, scale)``, so the scale is fixed.
    """
    sign = 1 if random.randint(0, 1) == 0 else -1
    scale = 1 + factor * sign
    jitter = torchvision.transforms.ColorJitter(brightness=(scale, scale))
    return quantization(jitter(img))


def contrast_shifting(img, factor=0.2):
    """Scale contrast by exactly ``1 + factor`` or ``1 - factor``, then quantize.

    The sign is picked with ``random.randint(0, 1)``. ColorJitter receives the
    degenerate range ``(scale, scale)``, so the scale is fixed.
    """
    sign = 1 if random.randint(0, 1) == 0 else -1
    scale = 1 + factor * sign
    jitter = torchvision.transforms.ColorJitter(contrast=(scale, scale))
    return quantization(jitter(img))


def saturation_shifting(img, factor=0.2):
    """Scale saturation by exactly ``1 + factor`` or ``1 - factor``, then quantize.

    The sign is picked with ``random.randint(0, 1)``. ColorJitter receives the
    degenerate range ``(scale, scale)``, so the scale is fixed.
    """
    sign = 1 if random.randint(0, 1) == 0 else -1
    scale = 1 + factor * sign
    jitter = torchvision.transforms.ColorJitter(saturation=(scale, scale))
    return quantization(jitter(img))


def dropout(img, wm_img, prop=0.3):
    """Swap a random subset of ``wm_img`` pixels for the matching ``img`` pixels.

    Each (H, W) location independently takes its value from ``img`` with
    probability ``prop``; otherwise it keeps its value from ``wm_img``. The
    same spatial mask is expanded over every batch item and channel.

    Args:
        img: tensor supplying replacement pixels (the cover image).
        wm_img: 4-D tensor whose dims 2 and 3 define the mask size.
        prop: probability that a location is replaced.
    """
    spatial_shape = wm_img.shape[2:]
    pick_cover = np.random.choice([0.0, 1.0], spatial_shape, p=[1 - prop, prop])
    cover_mask = torch.tensor(pick_cover, device=wm_img.device, dtype=torch.float).expand_as(wm_img)
    return img * cover_mask + wm_img * (1 - cover_mask)


def cropout(img, wm_img, prop=0.3):
    """Keep a random rectangle of ``wm_img`` and fill everything else from ``img``.

    The kept window's sides are ``sqrt(1 - prop)`` of the height/width (each
    truncated to an int), so roughly ``1 - prop`` of the area comes from
    ``wm_img``.

    Args:
        img: tensor supplying the pixels outside the window; must broadcast
            against ``wm_img``.
        wm_img: 4-D tensor; dims 2 and 3 are used as height and width.
        prop: approximate fraction of the area replaced by ``img``,
            presumably in [0, 1].

    Returns:
        The blended tensor.
    """
    side_ratio = pow(1 - prop, 0.5)
    img_height, img_width = wm_img.shape[2], wm_img.shape[3]
    crop_height = int(img_height * side_ratio)
    crop_width = int(img_width * side_ratio)

    # numpy's randint excludes ``high``. The +1 makes the bottom/right-aligned
    # window reachable and stops a full-size window (prop == 0) from raising.
    h_start = np.random.randint(0, img_height - crop_height + 1)
    w_start = np.random.randint(0, img_width - crop_width + 1)

    keep_mask = torch.zeros_like(wm_img)
    keep_mask[:, :, h_start:h_start + crop_height, w_start:w_start + crop_width] = 1
    return wm_img * keep_mask + img * (1 - keep_mask)
    

def crop(wm_img, prop=0.96):
    """Keep a random rectangle of ``wm_img`` and paint everything else with 1.

    The window's sides are ``sqrt(prop)`` of the height/width (each truncated
    to an int), so it covers roughly ``prop`` of the area. The output keeps
    the input's shape.

    Args:
        wm_img: 4-D tensor; dims 2 and 3 are used as height and width.
        prop: approximate fraction of the area to keep, presumably in [0, 1].

    Returns:
        A new tensor: ``wm_img`` inside the window, 1 outside it.
    """
    side_ratio = pow(prop, 0.5)
    img_height, img_width = wm_img.shape[2], wm_img.shape[3]
    crop_height = int(img_height * side_ratio)
    crop_width = int(img_width * side_ratio)

    # numpy's randint excludes ``high``. The +1 makes the bottom/right-aligned
    # window reachable and stops prop == 1.0 from raising ValueError.
    h_start = np.random.randint(0, img_height - crop_height + 1)
    w_start = np.random.randint(0, img_width - crop_width + 1)
    h_end = h_start + crop_height
    w_end = w_start + crop_width

    cropped = torch.ones_like(wm_img)
    cropped[:, :, h_start:h_end, w_start:w_end] = wm_img[:, :, h_start:h_end, w_start:w_end]
    return cropped


def resize_wm_img(img, scale_ratio, img_size):
    """Rescale ``img`` to ``int(scale_ratio * img_size)``, resize back to ``img_size``, quantize.

    Both sizes are passed to ``torchvision.transforms.Resize`` as given.
    """
    scaled_size = int(scale_ratio * img_size)
    shrink = torchvision.transforms.Resize(scaled_size)
    restore = torchvision.transforms.Resize(img_size)
    return quantization(restore(shrink(img)))


def rotate_wm_img(img, angle=90):
    """Rotate ``img`` by ``angle`` degrees, then rotate it back by ``-angle``.

    Uses torchvision's functional ``rotate`` with its default arguments. The
    result presumably differs from the input only where the round trip loses
    information (e.g. clipped corners or interpolation) — confirm for your angles.
    """
    turned = F.rotate(img, angle)
    return F.rotate(turned, -angle)


def crop_resize(wm_img, crop_prop, img_size):
    """Cut a random window out of ``wm_img`` and resize it to ``img_size``.

    The window's sides are ``sqrt(crop_prop)`` of the height/width (each
    truncated to an int), so it covers roughly ``crop_prop`` of the area.
    Unlike ``crop``, pixels outside the window are discarded and the window
    itself is rescaled.

    Args:
        wm_img: 4-D tensor; dims 2 and 3 are used as height and width.
        crop_prop: approximate fraction of the area to keep, presumably in (0, 1].
        img_size: size forwarded unchanged to ``torchvision.transforms.Resize``.

    Returns:
        The resized window.
    """
    side_ratio = pow(crop_prop, 0.5)
    img_height, img_width = wm_img.shape[2], wm_img.shape[3]
    crop_height = int(img_height * side_ratio)
    crop_width = int(img_width * side_ratio)

    # numpy's randint excludes ``high``. The +1 makes the bottom/right-aligned
    # window reachable and stops crop_prop == 1 from raising ValueError.
    h_start = np.random.randint(0, img_height - crop_height + 1)
    w_start = np.random.randint(0, img_width - crop_width + 1)

    window = wm_img[:, :, h_start:h_start + crop_height, w_start:w_start + crop_width]

    # NOTE(review): an int img_size makes Resize match the shorter edge (aspect
    # ratio preserved), not force a square; pass (h, w) if an exact size is needed.
    return torchvision.transforms.Resize(img_size)(window)




