import os
import os.path as pt
from typing import List, Tuple, Callable

import numpy as np
import torch
import torchvision.transforms as transforms
import tqdm
from kornia.filters import gaussian_blur2d
from skimage.transform import rotate as im_rotate
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision.datasets import VisionDataset

from xad.datasets.bases import TorchvisionDataset
from xad.datasets.color_noise import ceil, floor
from xad.datasets.transformations import torchvision_to_kornia
from xad.utils.logger import Logger


def smooth_noise(img: torch.Tensor, ksize: int, std: float, p: float = 1.0, inplace: bool = True) -> torch.Tensor:
    """
    Blurs a random subset of the given noise images with a Gaussian kernel.
    :param img: torch tensor (n x c x h x w).
    :param ksize: the Gaussian kernel size; an even value is reduced by one to make it odd.
    :param std: the standard deviation of the Gaussian kernel.
    :param p: the probability of blurring each image; on average p * n images are blurred.
    :param inplace: whether to modify img itself instead of a copy.
    :return: the (partially) blurred images, cast back to int for the blurred ones.
    """
    out = img if inplace else img.clone()
    if ksize % 2 == 0:
        ksize -= 1  # the Gaussian kernel needs an odd size
    # one Bernoulli(p) draw per image decides whether that image gets blurred
    selected = torch.from_numpy(np.random.binomial(1, p, size=out.size(0))).bool()
    if selected.any():
        blurred = gaussian_blur2d(out[selected].float(), (ksize, ksize), (std, std))
        out[selected] = blurred.int()
    return out


def confetti_noise(size: torch.Size, p: float = 0.01,
                   blobshaperange: Tuple[Tuple[int, int], Tuple[int, int]] = ((3, 3), (5, 5)),
                   fillval: int = 255, backval: int = 0, ensureblob: bool = True, awgn: float = 0.0,
                   clamp: bool = False, onlysquared: bool = True, rotation: int = 0,
                   colorrange: Tuple[int, int] = None) -> torch.Tensor:
    """
    Generates confetti noise images.

    The noise is based on sampling randomly many rectangles (in the following called blobs) at random positions.
    Additionally, all blobs are of random size (within some range), of random rotation, and of random color.
    The color is randomly chosen per blob, thus consistent within one blob.
    :param size: size of the overall noise image(s), should be (n x h x w) or (n x c x h x w), i.e.
        number of samples, channels, height, width. Blobs are grayscaled for (n x h x w) or c == 1.
    :param p: the probability of inserting a blob per pixel.
        The average number of blobs in the image is p * h * w.
    :param blobshaperange: limits the random size of the blobs. For ((h0, h1), (w0, w1)), all blobs' width
        is ensured to be in {w0, ..., w1}, and height to be in {h0, ..., h1}.
    :param fillval: if the color is not randomly chosen (see colorrange parameter), this sets the color of all blobs.
        This is also the maximum value used for clamping (see clamp parameter). Can be negative.
    :param backval: the background pixel value, i.e. the color of pixels in the noise image that are not part
         of a blob. Also used for clamping.
    :param ensureblob: whether to ensure that there is at least one blob per noise image.
    :param awgn: amount of additive white gaussian noise added to all blobs.
    :param clamp: whether to clamp all noise image to the pixel value range (backval, fillval).
    :param onlysquared: whether to restrict the blobs to be squares only.
    :param rotation: the maximum amount of rotation (in degrees)
    :param colorrange: the range of possible color values for each blob and channel.
        Defaults to None, where the blobs are not colored, but instead parameter fillval is used.
        First value can be negative.
    :raises ValueError: if ensureblob is set while p <= 0, as no blob could ever be sampled.
    :return: int torch tensor containing n noise images. Either (n x c x h x w) or (n x h x w), depending on size.
    """
    assert len(size) == 4 or len(size) == 3, 'size must be n x c x h x w'
    if isinstance(blobshaperange[0], int) and isinstance(blobshaperange[1], int):
        blobshaperange = (blobshaperange, blobshaperange)
    assert len(blobshaperange) == 2
    assert len(blobshaperange[0]) == 2 and len(blobshaperange[1]) == 2
    assert colorrange is None or len(size) == 4 and size[1] == 3
    if ensureblob and p <= 0:
        # torch.rand(...) < p is never true for p <= 0, so the resampling loop below would never terminate
        raise ValueError('ensureblob requires p > 0')
    out_size = size
    colors = []
    if len(size) == 3:
        size = (size[0], 1, size[1], size[2])  # add channel dimension
    else:
        size = tuple(size)  # Tensor(torch.size) -> tensor of shape size, Tensor((x, y)) -> Tensor with 2 elements x & y
    mask = (torch.rand((size[0], size[2], size[3])) < p).unsqueeze(1)  # mask[i, 0, j, k] == 1 for center of blob
    while ensureblob and (mask.view(mask.size(0), -1).sum(1).min() == 0):
        # resample the masks of all images that do not contain a single blob center yet
        idx = (mask.view(mask.size(0), -1).sum(1) == 0).nonzero().squeeze()
        s = idx.size(0) if len(idx.shape) > 0 else 1
        mask[idx] = (torch.rand((s, 1, size[2], size[3])) < p)
    res = torch.empty(size).fill_(backval).int()
    idx = mask.nonzero()  # [(idn, idz, idy, idx), ...] = indices of blob centers
    if idx.reshape(-1).size(0) == 0:
        # no blob at all: pure background, consistent with the background of the non-empty case
        return torch.empty(out_size).fill_(backval).int()

    all_shps = [
        (x, y) for x in range(blobshaperange[0][0], blobshaperange[1][0] + 1)
        for y in range(blobshaperange[0][1], blobshaperange[1][1] + 1) if not onlysquared or x == y
    ]
    picks = torch.FloatTensor(idx.size(0)).uniform_(0, len(all_shps)).int()  # for each blob center pick a shape
    nidx = []
    for n, blobshape in enumerate(all_shps):
        if (picks == n).sum() < 1:
            continue
        # pixel offsets (relative to the blob center) covered by a blob of this shape
        bhs = range(-(blobshape[0] // 2) if blobshape[0] % 2 != 0 else -(blobshape[0] // 2) + 1, blobshape[0] // 2 + 1)
        bws = range(-(blobshape[1] // 2) if blobshape[1] % 2 != 0 else -(blobshape[1] // 2) + 1, blobshape[1] // 2 + 1)
        extends = torch.stack([
            torch.zeros(len(bhs) * len(bws)).long(),
            torch.zeros(len(bhs) * len(bws)).long(),
            torch.arange(bhs.start, bhs.stop).repeat(len(bws)),
            torch.arange(bws.start, bws.stop).unsqueeze(1).repeat(1, len(bhs)).reshape(-1)
        ]).transpose(0, 1)
        nid = idx[picks == n].unsqueeze(1) + extends.unsqueeze(0)  # (blobs x pixels per blob x 4)
        if colorrange is not None:
            # one random RGB color per blob, repeated for all of its pixels (blob-major, matching nid's reshape below)
            col = torch.randint(colorrange[0], colorrange[1], (3, nid.size(0)))
            col = col[:, :, None].repeat(1, 1, nid.size(1)).reshape(3, -1).int()
            colors.append(col)
        nid = nid.reshape(-1, extends.size(1))
        nid = torch.max(torch.min(nid, torch.LongTensor(size) - 1), torch.LongTensor([0, 0, 0, 0]))
        nidx.append(nid)
    idx = torch.cat(nidx)  # all pixel indices that blobs cover, not only center indices
    pixels = idx.unbind(1)  # explicit tuple of per-dimension index tensors for advanced indexing
    shp = res[pixels].shape
    if colorrange is not None:
        colors = torch.cat(colors, dim=1)
        gnoise = (torch.randn(3, *shp) * awgn).int() if awgn != 0 else (0, 0, 0)
        for ch in range(3):
            res[(idx + torch.LongTensor((0, ch, 0, 0))).unbind(1)] = colors[ch] + gnoise[ch]
    else:
        gnoise = (torch.randn(shp) * awgn).int() if awgn != 0 else 0
        res[pixels] = torch.ones(shp).int() * fillval + gnoise
        res = res[:, 0, :, :]
        if len(out_size) == 4:
            res = res.unsqueeze(1).repeat(1, out_size[1], 1, 1)
    if clamp:
        res = res.clamp(backval, fillval) if backval < fillval else res.clamp(fillval, backval)
    mask = mask[:, 0, :, :]
    if rotation > 0:
        res = res.unsqueeze(1) if res.dim() != 4 else res
        res = res.transpose(1, 3).transpose(1, 2)  # n x h x w x c
        # picks was sampled in the order of mask.nonzero(), so both iterate the blob centers identically
        for pick, blbctr in zip(picks, mask.nonzero()):
            rot = np.random.uniform(-rotation, rotation)
            p1, p2 = all_shps[pick]
            # slice ends are exclusive, hence clip at the size itself so that the last row/col stays reachable
            dims = (
                blbctr[0],
                slice(max(blbctr[1] - floor(0.75 * p1), 0), min(blbctr[1] + ceil(0.75 * p1), res.size(1))),
                slice(max(blbctr[2] - floor(0.75 * p2), 0), min(blbctr[2] + ceil(0.75 * p2), res.size(2))),
                ...
            )
            res[dims] = torch.from_numpy(
                im_rotate(
                    res[dims].float().numpy(), rot, order=0, cval=0,
                    # skimage expects the rotation center as (col, row)
                    center=(blbctr[2] - dims[2].start, blbctr[1] - dims[1].start),
                    clip=False
                )
            ).int()
        res = res.transpose(1, 2).transpose(1, 3)
        # only drop the channel dimension added above; a plain squeeze() would also drop n == 1
        res = res.squeeze(1) if len(out_size) != 4 else res
    return res


class ADConfettiDataset(TorchvisionDataset):
    """
    Dataset whose training samples are confetti-corrupted images of another (normal) dataset.

    The training set is a ConfettiDataset created for ``normal_dataset`` and stored under ``root/confetti``.
    Only a single normal class is supported.
    """
    base_folder = 'confetti'

    def __init__(self, normal_dataset: TorchvisionDataset, root: str, normal_classes: List[int], nominal_label: int,
                 train_transform: transforms.Compose, test_transform: transforms.Compose,
                 raw_shape: Tuple[int, int, int], logger: Logger = None, limit_samples: int = np.inf, **kwargs):
        """
        :param normal_dataset: the dataset whose raw training images are corrupted with confetti noise.
        :param root: root directory; the confetti data lives in its ``base_folder`` subdirectory.
        :param normal_classes: the normal class; must contain exactly one element.
        :param nominal_label: forwarded to TorchvisionDataset.
        :param train_transform: transform used for the training set.
        :param test_transform: forwarded to TorchvisionDataset.
        :param raw_shape: raw image shape; its last entry is the resize target of the raw training set.
        :param logger: optional logger, forwarded to TorchvisionDataset.
        :param limit_samples: forwarded to TorchvisionDataset (np.inf, i.e. no limit, by default).
        """
        root = pt.join(root, self.base_folder)
        # NOTE(review): the positional 1 is interpreted by TorchvisionDataset (presumably the number of classes)
        super().__init__(
            root, normal_classes, nominal_label, train_transform, test_transform, 1, raw_shape, logger, limit_samples,
            **kwargs
        )
        self.normal_dataset = normal_dataset
        assert len(normal_classes) == 1, 'ADConfettiDataset only works with *one* normal class'

        self._train_set = ConfettiDataset(
            self.normal_dataset, self.root, self.train_transform, self.target_transform,
        )
        self._train_set = self.create_subset(self._train_set, self._train_set.targets)

    def _get_raw_train_set(self):
        """
        Returns the confetti samples resized to raw_shape[-1], center-cropped and converted to tensors,
        restricted to the samples whose target is one of the normal classes.
        """
        # NOTE(review): the crop size 64 is hard-coded and does not follow raw_shape; confirm it is intended.
        train_set = ConfettiDataset(
            self.normal_dataset, self.root,
            transforms.Compose([
                transforms.Resize(self.raw_shape[-1]), transforms.CenterCrop(64), transforms.ToTensor(),
            ]),
            self.target_transform,
        )
        # NOTE(review): ConfettiDataset assigns target 0 to every sample, so this keeps either all samples
        # (if 0 is a normal class) or none; confirm this is the intended filter.
        return Subset(
            train_set,
            np.flatnonzero(
                np.isin(np.asarray(train_set.targets), self.normal_classes)
            ).tolist()
        )


class ConfettiDataset(VisionDataset):
    """
    Images of a normal dataset corrupted with smoothed confetti noise.

    On first use, the samples are generated from the raw training set of ``normal_dataset`` and cached as a ``.pt``
    file in ``root``. Later instantiations with the same file name load that cache instead.
    ``data`` holds the samples as a uint8 numpy array, presumably channel-last (n x h x w x c) (it is the permuted
    (n x c x h x w) batch output of generation). Every target is 0.
    """

    def __init__(self, normal_dataset: TorchvisionDataset, root: str, transform: Callable, target_transform: Callable,
                 samples: int = 20000, invert_threshold: float = 0.025):
        """
        :param normal_dataset: dataset whose raw training images are corrupted. Its ``train_set``, ``raw_shape``,
            ``normal_classes`` and ``_get_raw_train_set()`` are used.
        :param root: directory in which the generated samples are cached.
        :param transform: sample transform exposing ``.transforms`` (e.g. a transforms.Compose). Its ToTensor steps are
            dropped, the remaining ones are converted with ``torchvision_to_kornia``, and ToTensor is prepended.
        :param target_transform: transform applied to the targets in ``__getitem__``.
        :param samples: number of samples to generate; reduced to 10% if the normal training set has < 2500 images.
        :param invert_threshold: the noise of a sample is negated if it changes the image by a mean absolute
            difference of less than this and the negated noise changes it more.
        """
        if len(normal_dataset.train_set) < 2500:
            samples = int(samples * 0.1)
        super(ConfettiDataset, self).__init__(root, transform=transform, target_transform=target_transform)
        self.transform = transforms.Compose([  # no PIL images
            transforms.ToTensor(),
            torchvision_to_kornia([t for t in self.transform.transforms if not isinstance(t, transforms.ToTensor)]),
        ])
        # NOTE(review): the cache key embeds str(normal_dataset); it must be stable across runs for the cache to hit
        name = f'{samples}_confetti_{normal_dataset}_CLS{"-".join([str(i) for i in normal_dataset.normal_classes])}.pt'
        path = pt.join(root, name)
        if not pt.isfile(path):
            generated_oe = self._generate(normal_dataset, samples, invert_threshold)
            os.makedirs(root, exist_ok=True)
            torch.save(generated_oe, path, pickle_protocol=4)
        else:
            generated_oe = self._load(path)
        self.data = generated_oe
        self.targets = [0] * self.data.shape[0]

    @staticmethod
    def _generate(normal_dataset: TorchvisionDataset, samples: int, invert_threshold: float) -> np.ndarray:
        """Corrupts raw training images of normal_dataset with confetti noise until `samples` samples exist."""
        shape = torch.Size([500, *normal_dataset.raw_shape])
        print(
            f'Creating Confetti dataset for classes {normal_dataset.normal_classes} of {normal_dataset} '
            f'with {samples} samples of shape {shape[1:]}...'
        )
        blobrange = ((8, 8), (54, 54)) if shape[-1] >= 128 else ((4, 4), (18, 18))

        raw_train_set = normal_dataset._get_raw_train_set()
        if len(raw_train_set) == 0:
            raise ValueError('The normal dataset has no raw training samples to create confetti samples from.')
        loader = DataLoader(dataset=raw_train_set, batch_size=500, shuffle=True, num_workers=4, pin_memory=True, )
        epochs = int(np.ceil(samples / len(raw_train_set)))
        total = int(np.ceil(samples / len(raw_train_set) * len(loader)))
        generated_oe = []
        n_generated = 0
        with tqdm.tqdm(total=total, desc='Creating confetti') as pbar:
            for _epoch in range(epochs):
                for x, _, _ in loader:
                    # additive noise: colored blobs (values in [-256, 0)) plus grayscale blobs (value -255)
                    generated_noise_rgb = confetti_noise(
                        x.shape, 0.000018, blobrange, fillval=255, clamp=False, awgn=0, rotation=45, colorrange=(-256, 0)
                    )
                    generated_noise = confetti_noise(
                        x.shape, 0.000012, blobrange, fillval=-255, clamp=False, awgn=0, rotation=45
                    )
                    generated_noise = generated_noise_rgb + generated_noise
                    generated_noise = smooth_noise(generated_noise, 25, 5, 1.0).to(x.device).float().div(255)

                    # invert noise if difference of malformed-original is less than threshold and inverted
                    # difference is larger
                    diff = abs((x + generated_noise).clamp(0, 1) - x).flatten(1).mean(1)
                    diffi = abs((x - generated_noise).clamp(0, 1) - x).flatten(1).mean(1)
                    inv = (diff < invert_threshold) & (diffi > diff)
                    generated_noise[inv] = -generated_noise[inv]

                    oe = (x + generated_noise).clamp(0, 1)
                    # round rather than truncate, so that values like 0.99999 map to 255 instead of 254
                    generated_oe.append(oe.mul(255).round().byte())
                    n_generated += oe.size(0)
                    pbar.update()

                    if n_generated >= samples:
                        break
                if n_generated >= samples:
                    break
        return torch.cat(generated_oe)[:samples].permute(0, 2, 3, 1).numpy()

    @staticmethod
    def _load(path: str) -> np.ndarray:
        """Loads the numpy array that a previous instantiation cached at path."""
        # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle numpy arrays. The cache file is
        # written by this class itself; do not point this at untrusted files, as full unpickling is performed.
        try:
            return torch.load(path, weights_only=False)
        except TypeError:  # torch < 1.13 does not know the weights_only argument
            return torch.load(path)

    def __len__(self):
        # data is a numpy array, whose .size is an int attribute, not a method
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int, int]:
        """Returns the transformed sample, its (transformed) target, and its index."""
        img, target = self.data[index], self.targets[index]

        if self.transform is not None:
            img = self.transform(img).squeeze()

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target, index
