from openmixup.utils import build_from_cfg

import torch
from PIL import Image
from torchvision.transforms import Compose, RandomCrop
import torchvision.transforms.functional as TF

from .registry import DATASETS, PIPELINES
from .base import BaseDataset


def image_to_patches(img, split_per_side=3, patch_jitter=21):
    """Crop ``split_per_side x split_per_side`` jittered patches from an image.

    The image is divided into a regular grid with ``split_per_side`` cells per
    side. From each cell a randomly positioned sub-crop is taken that is
    ``patch_jitter`` pixels smaller than the cell in both height and width.

    Args:
        img (PIL Image): input image.
        split_per_side (int): number of grid cells per image side.
            Defaults to 3.
        patch_jitter (int): how many pixels smaller each patch is than its
            grid cell; this is the room available for the random offset.
            Defaults to 21.

    Returns:
        list[PIL Image]: ``split_per_side ** 2`` patches in row-major order
            (top-left cell first, bottom-right cell last).
    """
    # PIL reports size as (width, height), not (height, width).
    w, h = img.size
    h_grid = h // split_per_side
    w_grid = w // split_per_side
    h_patch = h_grid - patch_jitter
    w_patch = w_grid - patch_jitter
    assert h_patch > 0 and w_patch > 0, (
        'Image of size {}x{} (HxW) is too small for {} splits per side '
        'with a jitter of {} pixels.'.format(h, w, split_per_side, patch_jitter))
    # The crop size is identical for every cell, so build the transform once;
    # it still draws a fresh random offset on each call.
    jitter_crop = RandomCrop((h_patch, w_patch))
    patches = []
    for i in range(split_per_side):  # grid row
        for j in range(split_per_side):  # grid column
            # TF.crop expects (top, left, height, width).
            cell = TF.crop(img, i * h_grid, j * w_grid, h_grid, w_grid)
            patches.append(jitter_crop(cell))
    return patches


@DATASETS.register_module
class RelativeLocDataset(BaseDataset):
    """Dataset for relative patch location prediction.

    Each sample image is split into a 3x3 grid of jittered patches (see
    ``image_to_patches``). Every non-central patch is concatenated with the
    central patch, and the label of each pair is the position of the
    non-central patch among the eight neighbours (0-7, row-major order,
    skipping the centre).

    Args:
        data_source: object providing images via ``get_sample(idx)``; each
            sample must be a PIL Image.
        pipeline: augmentation pipeline config, passed to ``BaseDataset``;
            applied to the whole image before it is split into patches.
        format_pipeline (list[dict]): configs of transforms applied to every
            patch after splitting. The results are joined with ``torch.cat``,
            so presumably the pipeline must end by producing a CHW tensor.
        prefetch (bool): not supported by this dataset; must be falsy.
    """

    def __init__(self, data_source, pipeline, format_pipeline, prefetch=False):
        super().__init__(data_source, pipeline)
        assert not prefetch, 'RelativeLocDataset does not support prefetch.'
        format_pipeline = [build_from_cfg(p, PIPELINES) for p in format_pipeline]
        self.format_pipeline = Compose(format_pipeline)

    def __getitem__(self, idx):
        """Return the 8 (neighbour, centre) patch pairs and their labels.

        Returns:
            dict: ``img`` is a tensor stacking the 8 pairs, each the channel-wise
                concatenation of a neighbour patch and the centre patch
                (8 x 2C x H x W). ``patch_label`` is a LongTensor ``[0, ..., 7]``.
        """
        img = self.data_source.get_sample(idx)
        assert isinstance(img, Image.Image), (
            'The output from the data source must be an Image, got: {}. '
            'Please ensure that the list file does not contain labels.'.format(
                type(img)))
        img = self.pipeline(img)
        patches = [self.format_pipeline(p) for p in image_to_patches(img)]
        # Index of the central patch in the row-major grid (4 for 3x3).
        center_idx = len(patches) // 2
        center = patches[center_idx]
        # Pair every neighbour with the centre patch along dim 0 (channels).
        pairs = [
            torch.cat((patch, center), dim=0)
            for i, patch in enumerate(patches) if i != center_idx
        ]
        # The label of each pair is the neighbour's position in ``pairs``.
        patch_labels = torch.arange(len(pairs), dtype=torch.long)
        return dict(img=torch.stack(pairs), patch_label=patch_labels)  # 8(2C)HW, 8

    def evaluate(self, scores, keyword, logger=None):
        """Evaluation is not supported for this pretext-task dataset."""
        # ``NotImplemented`` is a sentinel constant, not an exception; raising
        # it would produce a confusing TypeError instead of this error.
        raise NotImplementedError
