import paddle
import paddle.vision.transforms as transforms


class PrefetchLoader:
    """A data loader wrapper that normalizes image batches one step ahead.

    While iterating, each batch is fetched from the wrapped loader and
    normalized in place (``(img - mean * 255) / (std * 255)``) *before* the
    previously fetched batch is handed to the caller, so the wrapper always
    stays one batch ahead of the consumer.

    Batches are expected to be dict-like and to carry either an ``"img"``
    entry or, for the semi-supervised loader, both ``"img1"`` and ``"img2"``.

    Modified from https://github.com/open-mmlab/OpenSelfSup.

    Args:
        loader: The iterable data loader to wrap. ``len()``, ``.sampler`` and
            ``.dataset`` are forwarded to it.
        mean: Per-channel means (three values); multiplied by 255 before use.
        std: Per-channel standard deviations (three values); multiplied by
            255 before use.
    """

    def __init__(self, loader, mean, std):
        self.loader = loader
        self._mean = mean
        self._std = std

    def _normalize(self, img):
        """Subtract ``self.mean`` and divide by ``self.std`` in place; return ``img``."""
        return img.subtract_(self.mean).divide_(self.std)

    def __iter__(self):
        """Yield every batch of the wrapped loader with its images normalized.

        Yields nothing when the wrapped loader is empty.

        Raises:
            KeyError: If a batch has no ``"img"`` entry and lacks ``"img1"``
                or ``"img2"``.
        """
        # The statistics are rebuilt on every pass and reshaped to
        # [1, 3, 1, 1] so they broadcast against the image tensors.
        # NOTE(review): the 255 scaling suggests raw 0-255 pixel values in
        # N x 3 x H x W layout -- confirm against the dataset pipeline.
        self.mean = paddle.to_tensor([x * 255 for x in self._mean]).reshape([1, 3, 1, 1])
        self.std = paddle.to_tensor([x * 255 for x in self._std]).reshape([1, 3, 1, 1])

        pending = None
        has_pending = False

        for batch in self.loader:
            if "img" in batch:
                batch["img"] = self._normalize(batch["img"])
            else:
                # Semi-supervised loader
                batch["img1"] = self._normalize(batch["img1"])
                batch["img2"] = self._normalize(batch["img2"])

            # Hand out the previous batch only after the current one has
            # been fetched and normalized (one-batch lookahead).
            if has_pending:
                yield pending

            pending = batch
            has_pending = True

        # Flush the last batch; an empty loader leaves nothing to flush.
        if has_pending:
            yield pending

    def __len__(self):
        """Return the number of batches reported by the wrapped loader."""
        return len(self.loader)

    @property
    def sampler(self):
        """The wrapped loader's ``sampler``."""
        return self.loader.sampler

    @property
    def dataset(self):
        """The wrapped loader's ``dataset``."""
        return self.loader.dataset


def prefetch_transform(transform):
    """Remove ``ToTensor`` and ``Normalize`` in ``transform``.

    Transforms are matched by name: any entry whose type string contains
    ``"ToTensor"`` or ``"Normalize"`` is dropped from the pipeline.

    Args:
        transform: A composed transform exposing its steps as a
            ``transforms`` sequence.

    Returns:
        A ``(transform, mean, std)`` tuple: a new ``Compose`` holding the
        remaining steps in their original order, plus the ``mean`` and ``std``
        attributes of the removed ``Normalize`` step (the last one, if
        several are present).

    Raises:
        KeyError: If ``transform`` contains no ``Normalize`` step.
    """
    kept = []
    mean = std = None
    found_normalize = False

    for t in transform.transforms:
        type_name = str(type(t))
        if "Normalize" in type_name:
            mean, std = t.mean, t.std
            found_normalize = True
        elif "ToTensor" not in type_name:
            # Each substring needs its own membership test:
            # `"ToTensor" or "Normalize" in name` is always truthy.
            kept.append(t)

    if not found_normalize:
        raise KeyError("No Normalize in transform: {}".format(transform))

    return transforms.Compose(kept), mean, std
