import math
import numbers
import numpy as np
import random
import torch
import torch.nn as nn
from torch.autograd import Function
from torch.autograd import Variable
from copy import deepcopy
import cv2

def _rotate_about_centre(chw_array, angle):
    """Rotate one (C, H, W) numpy array by ``angle`` degrees about its own centre.

    Returns a (C, H, W) torch tensor with the same spatial size and dtype as the input.
    """
    hwc = np.ascontiguousarray(chw_array.transpose(1, 2, 0))
    rows, cols = hwc.shape[:2]
    matrix = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
    rotated = cv2.warpAffine(hwc, matrix, (cols, rows))
    # cv2 returns a 2-D array for single-channel input; restore the channel axis
    # so both single- and multi-channel images come back as (C, H, W).
    if rotated.ndim == 2:
        rotated = rotated[:, :, np.newaxis]
    return torch.tensor(rotated.transpose(2, 0, 1))


def random_rotation(x_image, y_image):
    """Rotate two CHW tensors by the same random angle, each about its own centre.

    One integer angle is drawn with ``np.random.randint(-180, 180)`` (degrees) and
    applied to both inputs via ``cv2.warpAffine`` with its default border settings.
    Each image keeps its own spatial size; previously ``y_image`` was rotated about
    ``(W/2, W/2)`` into a ``W x W`` canvas, which was wrong for non-square inputs.

    Args:
        x_image: CPU tensor of shape (C, H, W); callers in this file pass a
            single-channel mask.
        y_image: CPU tensor of shape (C, H, W); callers pass the image.

    Returns:
        ``(x_rot, y_rot)``: rotated tensors, each in (C, H, W) layout.
    """
    x_array = x_image.numpy()
    y_array = y_image.numpy()
    angle = np.random.randint(-180, 180)
    return _rotate_about_centre(x_array, angle), _rotate_about_centre(y_array, angle)

class RandomErasing_small(object):
    """Paste one small random patch into an image and rotate the pasted region.

    The rectangle covers a fraction ``sl``..``sh`` of the image area. Its aspect
    ratio is drawn from ``[r1, 1 / r1]``. It is filled with the same-sized patch
    cut from ``fa_img`` at a random location. For 3-channel images it is instead,
    20% of the time, filled with a constant colour: ``mean[0]`` (p=0.3),
    ``mean[1]`` (p=0.42) or white (p=0.28). The filled region is then rotated by a
    random angle with ``random_rotation`` and composited over the untouched input.
    """

    def __init__(self, probability=1, sl=0.005, sh=0.02, r1=0.05,
                 mean=((0.4914, 0.4822, 0.4465), (0, 0, 0))):
        # NOTE(review): ``probability`` is stored but never read; the patch is always applied.
        self.probability = probability
        self.mean = mean[0]   # first per-channel constant fill colour
        self.mean1 = mean[1]  # second per-channel constant fill colour
        self.sl = sl          # minimum patch area, as a fraction of the image area
        self.sh = sh          # maximum patch area fraction
        self.r1 = r1          # patch aspect ratio is drawn from [r1, 1 / r1]

    def __call__(self, img, fa_img, mask):
        """Erase one patch of ``img``.

        Args:
            img: (C, H, W) CPU tensor to corrupt. It is not modified.
            fa_img: tensor supplying the pasted patch. It needs at least C channels
                and at least H x W spatial extent.
            mask: (1, H, W) tensor, as passed by callers in this file. It is not modified.

        Returns:
            ``(out, out_mask)``. ``out`` equals ``img`` outside the rotated patch.
            ``out_mask`` is ``1 - m``, where ``m`` is the rotated, binarised
            ``1 - mask`` after the patch was zeroed. With an all-ones input mask it
            is 0 inside the rotated patch and 1 elsewhere.

        Raises:
            ValueError: if ``img`` has zero height or width (the size search could
                never succeed).
        """
        height, width = img.size(1), img.size(2)
        if height == 0 or width == 0:
            raise ValueError("RandomErasing_small needs an image with non-zero height and width")
        area = height * width

        # Rejection-sample a patch size that fits strictly inside the image.
        while True:
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < width and h < height:
                break

        # (x1, y1): destination corner in img; (x2, y2): source corner in fa_img.
        x1 = random.randint(0, height - h)
        y1 = random.randint(0, width - w)
        x2 = random.randint(0, height - h)
        y2 = random.randint(0, width - w)

        # Work on copies so the caller's tensors are never modified in place.
        canvas = img.clone()
        erased = mask.clone()
        channels = canvas.size(0)
        if channels == 3 and random.uniform(0, 1) <= 0.2:
            if random.uniform(0, 1) > 0.7:
                fill = self.mean
            elif random.uniform(0, 1) > 0.4:
                fill = self.mean1
            else:
                fill = (1, 1, 1)
            for channel in range(3):
                canvas[channel, x1:x1 + h, y1:y1 + w] = fill[channel]
        else:
            # Previously non-RGB images assigned the whole fa_img here (a shape error)
            # and never marked the mask; copy the matching patch instead.
            canvas[:, x1:x1 + h, y1:y1 + w] = fa_img[:channels, x2:x2 + h, y2:y2 + w]
        erased[0, x1:x1 + h, y1:y1 + w] = 0

        # Rotate the pasted region (and its indicator) together, then composite.
        region, canvas = random_rotation(1 - erased, canvas)
        region = (region > 0.5).float()
        out = region * canvas + (1 - region) * img
        return out, 1 - region

class RandomErasing_big(object):
    """Paste one large random patch into an image and rotate the pasted region.

    The rectangle covers a fraction ``sl``..``sh`` of the image area. Its aspect
    ratio is drawn from ``[r1, 1 / r1]``. It is filled with the same-sized patch
    cut from ``fa_img`` at a random location. For 3-channel images it is instead,
    20% of the time, filled with a constant colour: ``mean[0]`` (p=0.3),
    ``mean[1]`` (p=0.42) or white (p=0.28). The filled region is then rotated by a
    random angle with ``random_rotation`` and composited over the untouched input.
    """

    def __init__(self, probability=1, sl=0.02, sh=0.4, r1=0.05,
                 mean=((0.4914, 0.4822, 0.4465), (0, 0, 0))):
        # NOTE(review): ``probability`` is stored but never read; the patch is always applied.
        self.probability = probability
        self.mean = mean[0]   # first per-channel constant fill colour
        self.mean1 = mean[1]  # second per-channel constant fill colour
        self.sl = sl          # minimum patch area, as a fraction of the image area
        self.sh = sh          # maximum patch area fraction
        self.r1 = r1          # patch aspect ratio is drawn from [r1, 1 / r1]

    def __call__(self, img, fa_img, mask):
        """Erase one patch of ``img``.

        Args:
            img: (C, H, W) CPU tensor to corrupt. It is not modified.
            fa_img: tensor supplying the pasted patch. It needs at least C channels
                and at least H x W spatial extent.
            mask: (1, H, W) tensor, as passed by callers in this file. It is not modified.

        Returns:
            ``(out, out_mask)``. ``out`` equals ``img`` outside the rotated patch.
            ``out_mask`` is ``1 - m``, where ``m`` is the rotated, binarised
            ``1 - mask`` after the patch was zeroed. With an all-ones input mask it
            is 0 inside the rotated patch and 1 elsewhere.

        Raises:
            ValueError: if ``img`` has zero height or width (the size search could
                never succeed).
        """
        height, width = img.size(1), img.size(2)
        if height == 0 or width == 0:
            raise ValueError("RandomErasing_big needs an image with non-zero height and width")
        area = height * width

        # Rejection-sample a patch size that fits strictly inside the image.
        while True:
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if w < width and h < height:
                break

        # (x1, y1): destination corner in img; (x2, y2): source corner in fa_img.
        x1 = random.randint(0, height - h)
        y1 = random.randint(0, width - w)
        x2 = random.randint(0, height - h)
        y2 = random.randint(0, width - w)

        # Work on copies so the caller's tensors are never modified in place.
        canvas = img.clone()
        erased = mask.clone()
        channels = canvas.size(0)
        if channels == 3 and random.uniform(0, 1) <= 0.2:
            if random.uniform(0, 1) > 0.7:
                fill = self.mean
            elif random.uniform(0, 1) > 0.4:
                fill = self.mean1
            else:
                fill = (1, 1, 1)
            for channel in range(3):
                canvas[channel, x1:x1 + h, y1:y1 + w] = fill[channel]
        else:
            # Previously non-RGB images assigned the whole fa_img here (a shape error)
            # and never marked the mask; copy the matching patch instead.
            canvas[:, x1:x1 + h, y1:y1 + w] = fa_img[:channels, x2:x2 + h, y2:y2 + w]
        erased[0, x1:x1 + h, y1:y1 + w] = 0

        # Rotate the pasted region (and its indicator) together, then composite.
        region, canvas = random_rotation(1 - erased, canvas)
        region = (region > 0.5).float()
        out = region * canvas + (1 - region) * img
        return out, 1 - region


def rgb2hsv(rgb):
    """Convert an RGB batch to HSV.

    Args:
        rgb: tensor indexed as ``rgb[:, channel, :, :]``, whose first three
            channels on dim 1 are R, G and B. Values are presumably in [0, 1]
            (TODO confirm with callers).

    Returns:
        Tensor stacking (hue, saturation, value) on dim 1.
        - Hue uses the approximation ``atan2(sqrt(3) * (G - B), 2R - G - B)``,
          wrapped into the unit interval.
        - Saturation is ``(max - min) / max``, taken over dim 1.
        - Value is ``max``, taken over dim 1.
        Every non-finite entry (e.g. the 0/0 saturation of a black pixel) is
        replaced by 0.
    """
    red, green, blue = (rgb[:, idx, :, :] for idx in range(3))
    c_max = rgb.max(dim=1).values
    c_min = rgb.min(dim=1).values

    full_turn = 2 * math.pi
    hue = torch.atan2(math.sqrt(3) * (green - blue), 2 * red - green - blue)
    hue = (hue % full_turn) / full_turn
    saturation = (c_max - c_min) / c_max

    hsv = torch.stack([hue, saturation, c_max], dim=1)
    return hsv.masked_fill(~torch.isfinite(hsv), 0.)

def hsv2rgb(hsv):
    """Convert an HSV batch (channels H, S, V on dim 1) back to RGB.

    Uses the closed form ``V - V*S * clamp(min(k, 4 - k), 0, 1)`` with
    ``k = (n + 6H) mod 6`` and ``n = 5, 3, 1`` for the R, G, B outputs.
    Channel values are presumably in [0, 1] (TODO confirm with callers).

    Returns:
        Tensor with the three RGB channels on dim 1.
    """
    hue = hsv[:, [0]]
    sat = hsv[:, [1]]
    val = hsv[:, [2]]
    chroma = val * sat

    phase = hsv.new_tensor([5, 3, 1]).view(3, 1, 1)
    k = torch.remainder(phase + hue * 6, 6)
    weight = torch.min(k, 4 - k).clamp(0, 1)
    return val - chroma * weight

class RandomHSVFunction(Function):
    """Shift hue and scale saturation/value of an RGB batch, with an identity gradient.

    Forward pass:
    - convert RGB to HSV with ``rgb2hsv``;
    - add ``f_h * 255 / 360`` to hue and wrap it with ``% 1``;
    - multiply saturation by ``f_s`` and value by ``f_v``;
    - clamp all HSV channels to [0, 1] and convert back with ``hsv2rgb``.

    Backward pass: ``grad_output`` goes straight through to ``x``, and the
    factors receive no gradient.
    """

    @staticmethod
    def forward(ctx, x, f_h, f_s, f_v):
        # Nothing is saved on ``ctx``: the straight-through backward needs no state.
        hsv = rgb2hsv(x)
        hue = hsv[:, 0, :, :]
        hsv[:, 0, :, :] = (hue + f_h * 255. / 360.) % 1
        hsv[:, 1, :, :] = hsv[:, 1, :, :] * f_s
        hsv[:, 2, :, :] = hsv[:, 2, :, :] * f_v
        return hsv2rgb(torch.clamp(hsv, 0, 1))

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient slot per forward input; the jitter factors never get one.
        if not ctx.needs_input_grad[0]:
            return None, None, None, None
        return grad_output.clone(), None, None, None


class ColorJitterLayer(nn.Module):
    """Batched colour jitter: contrast plus HSV hue/saturation/brightness.

    Each sample is transformed with probability ``p``; the others pass through
    unchanged. Contrast and the HSV adjustment run in a random order, chosen once
    per call.

    Args:
        p: per-sample probability of applying the jitter.
        brightness, contrast, saturation: either a non-negative number ``v``,
            meaning the factor range ``[max(1 - v, 0), 1 + v]``, or an explicit
            ``(min, max)`` pair with ``0 <= min <= max``.
        hue: either a non-negative number ``v``, meaning ``[-v, v]``, or a
            ``(min, max)`` pair inside ``[-0.5, 0.5]``.

    A range that collapses to the neutral value disables that adjustment.
    """

    def __init__(self, p, brightness, contrast, saturation, hue):
        super(ColorJitterLayer, self).__init__()
        self.prob = p
        self.brightness = self._check_input(brightness, 'brightness')
        self.contrast = self._check_input(contrast, 'contrast')
        self.saturation = self._check_input(saturation, 'saturation')
        self.hue = self._check_input(hue, 'hue', center=0, bound=(-0.5, 0.5),
                                     clip_first_on_zero=False)

    def _check_input(self, value, name, center=1, bound=(0, float('inf')), clip_first_on_zero=True):
        """Normalise a jitter spec into a ``[min, max]`` range, or ``None`` for a no-op.

        Raises:
            ValueError: a negative single number, or a pair outside ``bound``.
            TypeError: anything other than a number or a length-2 list/tuple.
        """
        if isinstance(value, numbers.Number):
            if value < 0:
                raise ValueError("If {} is a single number, it must be non negative.".format(name))
            low = center - value
            if clip_first_on_zero:
                low = max(low, 0)
            value = [low, center + value]
        elif not (isinstance(value, (tuple, list)) and len(value) == 2):
            raise TypeError("{} should be a single number or a list/tuple with lenght 2.".format(name))
        elif not bound[0] <= value[0] <= value[1] <= bound[1]:
            raise ValueError("{} values should be between {}".format(name, bound))

        # A degenerate range sitting exactly on the neutral value means "do nothing".
        return None if value[0] == value[1] == center else value

    def adjust_contrast(self, x):
        """Scale each sample's deviation from its per-channel spatial mean, then clamp to [0, 1]."""
        if self.contrast:
            factor = x.new_empty(x.size(0), 1, 1, 1).uniform_(*self.contrast)
            means = x.mean(dim=(2, 3), keepdim=True)
            x = (x - means) * factor + means
        return x.clamp(0, 1)

    def adjust_hsv(self, x):
        """Draw per-sample hue/saturation/brightness factors and apply them in HSV space."""
        batch = x.size(0)
        f_h = x.new_zeros(batch, 1, 1)
        f_s = x.new_ones(batch, 1, 1)
        f_v = x.new_ones(batch, 1, 1)

        # Draw order (hue, saturation, brightness) is kept fixed.
        if self.hue:
            f_h.uniform_(*self.hue)
        if self.saturation:
            f_s = f_s.uniform_(*self.saturation)
        if self.brightness:
            f_v = f_v.uniform_(*self.brightness)  # brightness scales the HSV value channel

        return RandomHSVFunction.apply(x, f_h, f_s, f_v)

    def transform(self, inputs):
        """Apply contrast and HSV jitter in a coin-flipped order."""
        steps = [self.adjust_contrast, self.adjust_hsv]
        if not np.random.rand() > 0.5:
            steps.reverse()

        for step in steps:
            inputs = step(inputs)
        return inputs

    def forward(self, inputs):
        apply_prob = inputs.new_full((inputs.size(0),), self.prob)
        apply_mask = torch.bernoulli(apply_prob).view(-1, 1, 1, 1)
        return inputs * (1 - apply_mask) + self.transform(inputs) * apply_mask

def get_simclr_augmentation():
    """Return a fresh ``nn.Sequential`` wrapping one always-on ``ColorJitterLayer``.

    The layer uses brightness/contrast/saturation 0.4 and hue 0.1 with ``p=1``,
    so every sample in a batch is jittered.
    """
    jitter = ColorJitterLayer(p=1, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
    return nn.Sequential(jitter)


def color_ji(images, n=3):
    """Apply the SimCLR colour jitter to ``images`` between one and four times.

    One pass always runs, followed by ``np.random.randint(4)`` extra passes.

    Args:
        images: batch accepted by ``get_simclr_augmentation()``.
        n: currently unused; kept for interface compatibility.

    Returns:
        The jittered batch.
    """
    # The jitter module holds only its configuration, so one instance can be reused.
    augment = get_simclr_augmentation()
    jittered = augment(images)
    for _ in range(np.random.randint(4)):
        jittered = augment(jittered)
    return jittered

class Rotation(nn.Module):
    """Rotate an (N, C, H, W) batch by quarter-turns in the (H, W) plane.

    Args:
        max_range: number of distinct rotations; an explicit ``aug_index`` is
            reduced modulo this value.
        prob: in random mode, per-sample probability of keeping the input
            unrotated.
    """

    def __init__(self, max_range=4, prob=0):
        super(Rotation, self).__init__()
        self.max_range = max_range
        self.prob = prob

    def forward(self, input, aug_index=None):
        """Rotate ``input``.

        If ``aug_index`` is given, every sample is rotated by
        ``aug_index % max_range`` quarter-turns. Otherwise a single amount of 1-3
        quarter-turns is drawn for the whole batch, and each sample keeps its
        original orientation with probability ``prob``.

        NOTE(review): in random mode an odd number of turns on non-square images
        makes the keep/rotate mix fail on shape.
        """
        batch_size, _, _, _ = input.size()  # also enforces a 4-D input
        if aug_index is not None:
            return torch.rot90(input, aug_index % self.max_range, (2, 3))

        quarter_turns = np.random.randint(3) + 1
        rotated = torch.rot90(input, quarter_turns, (2, 3))
        keep = torch.bernoulli(input.new_full((batch_size,), self.prob)).view(-1, 1, 1, 1)
        return keep * input + (1 - keep) * rotated

class CutPerm(nn.Module):
    """Jigsaw-style permutation of an (N, C, H, W) batch by swapping image halves.

    Args:
        max_range: number of distinct permutations; an explicit ``aug_index`` is
            reduced modulo this value.
        prob: in random mode, per-sample probability of keeping the input
            unpermuted.
    """

    def __init__(self, max_range=4, prob=0):
        super(CutPerm, self).__init__()
        self.max_range = max_range
        self.prob = prob

    def forward(self, input, aug_index=None):
        """Permute ``input``.

        If ``aug_index`` is given, it is reduced modulo ``max_range`` and applied to
        every sample. Otherwise one non-identity index (1-3) is drawn for the whole
        batch, and each sample keeps its original layout with probability ``prob``.
        """
        batch_size, _, _, _ = input.size()  # also enforces a 4-D input
        if aug_index is not None:
            return self._cutperm(input, aug_index % self.max_range)

        permuted = self._cutperm(input, np.random.randint(3) + 1)
        keep = torch.bernoulli(input.new_full((batch_size,), self.prob)).view(-1, 1, 1, 1)
        return keep * input + (1 - keep) * permuted

    def _cutperm(self, inputs, aug_index):
        """Apply the permutation selected by ``aug_index``.

        If ``aug_index // 2 == 1``, swap the top and bottom halves (dim 2).
        If ``aug_index % 2 == 1``, swap the left and right halves (dim 3).
        """
        _, _, height, width = inputs.size()
        if aug_index // 2 == 1:
            top = height // 2
            inputs = torch.cat((inputs[:, :, top:, :], inputs[:, :, :top, :]), dim=2)
        if aug_index % 2 == 1:
            left = width // 2
            inputs = torch.cat((inputs[:, :, :, left:], inputs[:, :, :, :left]), dim=3)
        return inputs


def Anomaly_generation1(image):
    """Build two synthetic-anomaly batches from a batch of normal images.

    Args:
        image: (N, C, H, W) image batch.

    Returns:
        ``(x_Ano1, x_Ano2, A_Ano2)``:

        * ``x_Ano1``: ``image`` put through CutPerm -> Rotation -> CutPerm (each
          with a random non-identity index), then colour jitter. With probability
          1/2 it is then averaged with a shuffled copy of either ``image`` or a
          colour-jittered ``image``. It stays on ``image``'s device.
        * ``x_Ano2``: each image passed through ``RandomErasing_small`` (p=0.3)
          or ``RandomErasing_big`` (p=0.7). A shuffled ``x_Ano1`` sample is the
          patch source.
        * ``A_Ano2``: the (N, 1, H, W) masks returned by the erasers.

        ``x_Ano2`` and ``A_Ano2`` go to ``image``'s device if it is on CUDA; to
        the default CUDA device if CUDA is available; and to the CPU otherwise.
    """
    x_col = color_ji(image)

    # Global corruption: jigsaw / rotate / jigsaw, then colour jitter.
    oo = CutPerm()(image, np.random.randint(3) + 1)
    oo = Rotation()(oo, np.random.randint(3) + 1)
    oo = CutPerm()(oo, np.random.randint(3) + 1)
    x_Ano1 = color_ji(oo)

    # Shuffled partner batch (raw or colour-jittered), blended in half of the time.
    if np.random.randint(2):
        partner = image[torch.randperm(image.size(0))]
    else:
        partner = x_col[torch.randperm(x_col.size(0))]
    if np.random.randint(2):
        x_Ano1 = (x_Ano1 + partner) / 2

    batch, _, height, width = image.size()
    # Explicit copy: for a CPU ``image``, ``.cpu()`` returns the same tensor, so
    # any in-place writes by the erasers would corrupt the caller's batch.
    real_rows = image.to('cpu', copy=True)
    fake_rows = x_Ano1.cpu()[torch.randperm(x_Ano1.size(0))]
    # Mask size follows the input instead of a hard-coded 256 x 256.
    masks = torch.ones([batch, 1, height, width])

    erased_images, erased_masks = [], []
    for in_re, in_fa, masa in zip(real_rows, fake_rows, masks):
        eraser = RandomErasing_small() if random.uniform(0, 1) > 0.7 else RandomErasing_big()
        in_re, masa = eraser(in_re, in_fa, masa)
        erased_images.append(in_re)
        erased_masks.append(masa)

    # Keep the previous "move to CUDA" behaviour when possible, without crashing on CPU-only hosts.
    if image.is_cuda or not torch.cuda.is_available():
        out_device = image.device
    else:
        out_device = torch.device('cuda')
    x_Ano2 = torch.stack(erased_images, dim=0).to(out_device)
    A_Ano2 = torch.stack(erased_masks, dim=0).to(out_device)

    return x_Ano1, x_Ano2, A_Ano2


def Anomaly_generation(image):
    """Build two synthetic-anomaly batches from a batch of normal images.

    Args:
        image: (N, C, H, W) image batch.

    Returns:
        ``(x_Ano1, x_Ano2, A_Ano2)``:

        * ``x_Ano1``: ``image`` put through CutPerm -> Rotation -> CutPerm (each
          with a random non-identity index) and colour jitter. It is then always
          averaged with a shuffled copy of either ``image`` or a colour-jittered
          ``image``. It stays on ``image``'s device.
        * ``x_Ano2``: each image passed through ``RandomErasing()``, with a
          shuffled ``x_Ano1`` sample as the patch source.
        * ``A_Ano2``: the (N, 1, H, W) masks returned by the eraser.

        ``x_Ano2`` and ``A_Ano2`` go to ``image``'s device if it is on CUDA; to
        the default CUDA device if CUDA is available; and to the CPU otherwise.
    """
    x_col = color_ji(image)

    # Global corruption: jigsaw / rotate / jigsaw, then colour jitter.
    oo = CutPerm()(image, np.random.randint(3) + 1)
    oo = Rotation()(oo, np.random.randint(3) + 1)
    oo = CutPerm()(oo, np.random.randint(3) + 1)
    x_Ano1 = color_ji(oo)

    # Shuffled partner batch (raw or colour-jittered); the blend here is unconditional.
    if np.random.randint(2):
        partner = image[torch.randperm(image.size(0))]
    else:
        partner = x_col[torch.randperm(x_col.size(0))]
    x_Ano1 = (x_Ano1 + partner) / 2

    batch, _, height, width = image.size()
    # Explicit copy: for a CPU ``image``, ``.cpu()`` returns the same tensor, so
    # any in-place writes by the eraser would corrupt the caller's batch.
    real_rows = image.to('cpu', copy=True)
    fake_rows = x_Ano1.cpu()[torch.randperm(x_Ano1.size(0))]
    # Mask size follows the input instead of a hard-coded 224 x 224.
    masks = torch.ones([batch, 1, height, width])

    erased_images, erased_masks = [], []
    for in_re, in_fa, masa in zip(real_rows, fake_rows, masks):
        # NOTE(review): ``RandomErasing`` is not defined in the visible part of this
        # file (only RandomErasing_small / RandomErasing_big are); confirm it exists
        # elsewhere, otherwise this line raises NameError.
        in_re, masa = RandomErasing()(in_re, in_fa, masa)
        erased_images.append(in_re)
        erased_masks.append(masa)

    # Keep the previous "move to CUDA" behaviour when possible, without crashing on CPU-only hosts.
    if image.is_cuda or not torch.cuda.is_available():
        out_device = image.device
    else:
        out_device = torch.device('cuda')
    x_Ano2 = torch.stack(erased_images, dim=0).to(out_device)
    A_Ano2 = torch.stack(erased_masks, dim=0).to(out_device)

    return x_Ano1, x_Ano2, A_Ano2