import random
import numpy as np
import PIL.Image as Image
import torch

# Seed Python's built-in `random` generator as soon as this module is imported,
# so values drawn through `random.*` follow a fixed sequence. This call does
# not seed NumPy's or torch's generators.
random.seed(1729)


class AddUniformNoise(object):
    """Randomly overwrite pixels of an image with uniform noise.

    Args:
        snr (float): Signal-to-noise ratio. It is stored on the instance but
            ``__call__`` does not read it; the noise fraction is drawn anew
            on every call.
        p (float): Probability of applying the noise on a given call.
    """

    def __init__(self, snr, p=0.9):
        # The check passes as soon as either argument is a float.
        assert isinstance(snr, float) or isinstance(p, float)
        self.snr, self.p = snr, p

    def __call__(self, img):
        """Return a noisy copy of ``img`` with probability ``p``.

        Args:
            img: Array-like whose three dimensions are unpacked as
                ``(c, h, w)``.
        Returns:
            A ``torch.Tensor`` built from the noisy NumPy copy when noise is
            applied; otherwise ``img`` itself, untouched.
        """
        if not random.uniform(0, 1) < self.p:
            return img

        pixels = np.array(img).copy()
        channels, height, width = pixels.shape

        # Fraction of spatial positions to corrupt: roughly 0.1, 0.2 or 0.3.
        noise_pct = random.randint(1, 3) * 0.1
        half = noise_pct / 2.

        # Label each spatial position once (0 keeps the pixel, 1 and 2 both
        # replace it), then reuse that labelling for every channel.
        labels = np.random.choice(
            (0, 1, 2),
            size=(1, height, width),
            p=[1 - noise_pct, half, half],
        )
        labels = np.repeat(labels, channels, axis=0)

        # Replacement values are drawn independently per channel from [0, 1).
        fill = np.random.rand(channels, height, width)
        for label in (1, 2):
            hit = labels == label
            pixels[hit] = fill[hit]
        return torch.from_numpy(pixels)


class AddPepperNoise(object):
    """Add zero-mean Gaussian noise to a tensor.

    Despite the class name, no salt-and-pepper masking is performed: the
    transform adds ``torch.randn`` noise scaled by ``snr``.

    Args:
        snr (float, optional): Used as the standard deviation of the added
            noise. When ``None`` the transform returns its input untouched.
        p (float, optional): Probability value. It is kept on the instance
            as ``p`` but ``__call__`` does not consult it.
    """

    def __init__(self, snr=None, p=None):
        # At least one of the two arguments has to be a float.
        assert isinstance(snr, float) or (isinstance(p, float))
        self.snr = snr
        # Keep the probability on the instance instead of discarding it.
        self.p = p
        self.std = snr
        self.mean = 0.

    def __call__(self, img):
        """Return ``img`` plus Gaussian noise, or ``img`` itself if disabled.

        Args:
            img (torch.Tensor): Input to perturb; it must provide ``size()``
                and support addition with a tensor of that size.
        Returns:
            ``img + noise * std + mean`` where ``noise`` comes from
            ``torch.randn(img.size())``; ``img`` unchanged when ``std`` is
            ``None``.
        """
        if self.std is None:
            return img
        return img + torch.randn(img.size()) * self.std + self.mean


class BlockMnist(object):
    """Join an image with a stored "null block" using ``torch.hstack``.

    Args:
        test (bool): When true, the side on which the image is placed is read
            sequentially from ``null_block_loc.npy``; otherwise it is drawn
            at random on every call.
    """

    def __init__(self, test=False):
        self.test = test
        if test:
            # Cursor into the stored per-sample placement flags.
            self.count = 0
            self.loc_arr = np.load('null_block_loc.npy')
        # The null block is loaded from disk in both modes.
        self.null_block = torch.load('null_block.pt')

    def __call__(self, img):
        """Return ``img`` stacked with the null block.

        Args:
            img (torch.Tensor): Tensor that ``torch.hstack`` can join with
                the stored null block.
        Returns:
            torch.Tensor: ``img`` followed by the null block when the
            placement flag is truthy, the null block followed by ``img``
            otherwise.
        """
        if not self.test:
            top = random.randint(0, 1)
        else:
            top = self.loc_arr[self.count]
            self.count += 1
            # NOTE(review): 10000 is presumably the number of stored flags
            # (one per test sample) - confirm against null_block_loc.npy.
            if self.count == 10000:
                self.count = 0

        # For tensors with more than one dimension hstack joins along
        # dimension 1 (assumed to be height of a (c, h, w) image - TODO confirm).
        ordered = (img, self.null_block) if top else (self.null_block, img)
        return torch.hstack(ordered)
