
import numpy as np
from PIL import Image

from torch.utils.data.dataset import Dataset
from torchvision import transforms

from train.pytorch_wrapper.vision_utils import compute_boundary_weights, get_weighted_crop_params


class XyDataset(Dataset):
    """
    Generic paired dataset for supervised models of the form f(X) -> y.

    Parameters
    ----------
        X: indexable object, input samples (X[index] is one sample)
        y: indexable object, targets (y[index] is one target); len(y) is the dataset length
        transform: callable or None, applied to each input sample on access
        target_transform: callable or None, applied to each target on access

    Returns (per item)
    ------------------
        {'inputs': {'X': sample}, 'targets': {'y': target}}

    Examples
    --------
        - transform can load a sample from a file referenced by X[index]
        - transform can apply data augmentation
    """

    def __init__(self, X, y, transform=None, target_transform=None):
        self.X, self.y = X, y
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        sample, target = self.X[index], self.y[index]

        # optional per-sample processing (loading, augmentation, encoding, ...)
        sample = sample if self.transform is None else self.transform(sample)
        target = target if self.target_transform is None else self.target_transform(target)

        return {'inputs': {'X': sample},
                'targets': {'y': target}}

    def __len__(self):
        return len(self.y)


class SegmentationDataset(Dataset):
    """
    Semantic segmentation dataset.

    Each item is {'inputs': {'X': image}, 'targets': {'y': mask}} or, when a weight map is
    available, {'inputs': {'X': image}, 'targets': {'y': mask, 'w': weights}}.

    Parameters:
    -----------
        X: list of PIL images, input images (or whatever load_from_file expects)
        y: list of PIL images, segmentation targets (or whatever load_from_file expects)
        transform: pytorch_wrapper.transforms.MultiViewCompose, augmentation transformations applied jointly
            to the list [image, mask] or [image, mask, weights]
        boundary_weights: float, if > 0.0 boundary weights are computed for segmentation masks
            (passed as ``factor`` to compute_boundary_weights) and cached per index
        weighted_sampling: dictionary, over-sample regions containing foreground pixels
            {"power": 0.0, "crop_size": None}
            {"power": 5.0, "crop_size": (256, 256)}
            keys that are omitted fall back to the defaults of the first example
        load_from_file: callable, load data from file using this function (useful for large datasets);
            called as load_from_file(X[index], y[index]) and must return a triple (X, y, w) where w may be None
    """

    def __init__(self, X, y, transform, boundary_weights=0.0, weighted_sampling=None, load_from_file=None):
        super().__init__()
        self.X = X
        self.y = y
        self.transform = transform
        self.load_from_file = load_from_file

        self.boundary_weights = boundary_weights
        # index -> uint8 PIL weight map, so boundary weights are computed only once per sample
        self.boundary_cache = dict()

        # merge user settings over the defaults so a partial dict (e.g. only "power") still works
        defaults = {"power": 0.0, "crop_size": None}
        self.weighted_sampling = {**defaults, **(weighted_sampling or {})}

    def _get_boundary_weights(self, index, y):
        """Return the cached boundary weight map for mask ``y`` as a uint8 PIL image with values in [1, 255]."""
        if index not in self.boundary_cache:
            y_array = np.array(y).astype(np.float32)
            w = compute_boundary_weights(y_array, factor=self.boundary_weights)
            # map the largest weight to 255; clip at 1 so small positive weights are not truncated to 0
            # by the uint8 cast (a zero would break the min-based normalization in __getitem__)
            w = np.clip(w / w.max() * 255, 1, 255)
            self.boundary_cache[index] = Image.fromarray(w.astype(np.uint8))
        return self.boundary_cache[index]

    def _weighted_crop(self, X, y, w):
        """Crop X, y and (if present) w with one window chosen by get_weighted_crop_params on the mask y."""
        i, j, th, tw = get_weighted_crop_params(y, output_size=self.weighted_sampling["crop_size"],
                                                power=self.weighted_sampling["power"])

        X = transforms.functional.crop(X, i, j, th, tw)
        y = transforms.functional.crop(y, i, j, th, tw)
        # crop w whenever it exists: it may come from load_from_file even if boundary_weights == 0
        if w is not None:
            w = transforms.functional.crop(w, i, j, th, tw)
        return X, y, w

    @staticmethod
    def _normalize_weights(w):
        """Rescale a transformed weight map so that its smallest positive entry equals 1."""
        # out-of-place division: in-place true division fails on integer tensors/arrays
        w = w / 255
        # ignore zeros (e.g. padding introduced by transforms) to avoid dividing by zero
        positive = w[w > 0]
        if len(positive) > 0:
            w = w / positive.min()
        return w

    def __getitem__(self, index):
        X, y, w = self.X[index], self.y[index], None

        # load data from files
        if self.load_from_file is not None:
            X, y, w = self.load_from_file(X, y)

        # compute boundary weights if requested (replaces any w returned by load_from_file)
        if self.boundary_weights > 0.0:
            w = self._get_boundary_weights(index, y)

        # apply weighted random crop
        if self.weighted_sampling["power"] > 0:
            X, y, w = self._weighted_crop(X, y, w)

        # apply joint PIL transforms; the weight map is only transformed/returned when it exists
        if w is None:
            X, y = self.transform([X, y])
            return {'inputs': {'X': X},
                    'targets': {'y': y}}

        X, y, w = self.transform([X, y, w])
        return {'inputs': {'X': X},
                'targets': {'y': y, 'w': self._normalize_weights(w)}}

    def __len__(self):
        return len(self.X)


class SemiSupervisedDataSet(Dataset):
    """
    Dataset pairing each labelled sample with one randomly drawn unlabelled sample.

    Parameters
    ----------
        X: indexable object, labelled inputs
        Z: indexable object, unlabelled inputs; one element is drawn with np.random.randint per item
        y: indexable object, targets for X; len(y) is the dataset length
        transform: callable or None, applied to both the labelled and the unlabelled input
        target_transform: callable or None, applied to the target

    Returns (per item)
    ------------------
        {'inputs': {'X': x, 'Z': z}, 'targets': {'y': target, 'X': x, 'Z': z}}
    """

    def __init__(self, X, Z, y, transform=None, target_transform=None):
        self.X, self.Z, self.y = X, Z, y
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        # supervised pair
        labelled, target = self.X[index], self.y[index]

        # unsupervised sample, drawn independently of ``index``
        unlabelled = self.Z[np.random.randint(0, len(self.Z))]

        if self.transform is not None:
            labelled, unlabelled = self.transform(labelled), self.transform(unlabelled)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return {'inputs': {'X': labelled, 'Z': unlabelled},
                'targets': {'y': target, 'X': labelled, 'Z': unlabelled}}

    def __len__(self):
        return len(self.y)


if __name__ == "__main__":
    # This module only defines dataset classes; running it directly does nothing.
    pass
