
import random
import torch
import numpy as np

from train.pytorch_wrapper.vision_utils import pil_to_convnet


class Identity(object):
    """No-op transform that returns its input untouched.

    Handy as a placeholder wherever a transform callable is required but no
    processing should take place.
    """

    def __call__(self, pic):
        # The very same object is handed back (no copy is made).
        return pic

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class ToBinaryTarget(object):
    """Turn a mask image into a float target for single-class segmentation.

    The image is converted with ``pil_to_convnet``, wrapped as a torch tensor,
    cast to float and divided by 255, so a mask stored as 0/255 becomes
    0.0/1.0. Intended target properties:
        - shape (height, width). NOTE(review): the actual shape is whatever
          ``pil_to_convnet`` produces; confirm it carries no channel axis.
        - dtype ``torch.float`` with values in [0, 1] (assumes 8-bit input).
    """

    def __call__(self, pic):
        array = pil_to_convnet(pic)
        as_float = torch.from_numpy(array).float()
        # rescale 8-bit intensities to the unit interval
        return as_float.div(255)

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class PILToTensor(object):
    """Convert an image into a ``torch.float`` tensor without rescaling.

    The image is converted with ``pil_to_convnet`` and cast to float. Unlike
    ``ToBinaryTarget``, pixel values are NOT divided by 255, so values keep
    whatever range ``pil_to_convnet`` yields (presumably [0, 255] for 8-bit
    images; confirm against that helper). The layout (e.g. channel-first) is
    likewise determined by ``pil_to_convnet``.
    """

    def __call__(self, pic):
        array = pil_to_convnet(pic)
        return torch.from_numpy(array).float()

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class ToMultiClassTarget(object):
    """Turn a label image into a ``torch.long`` target for multi-class segmentation.

    Only the first slice along axis 0 of the ``pil_to_convnet`` output is kept,
    and it is cast to ``torch.long``; every pixel is expected to hold a class id.
    NOTE(review): assuming ``pil_to_convnet`` returns channel-first data
    (channels, height, width), the result is a single-sample (height, width)
    map. A (batch_size, height, width) tensor presumably only appears after
    samples are batched. Confirm against the helper and the loss.
    """

    def __call__(self, pic):
        first_slice = pil_to_convnet(pic)[0]
        return torch.from_numpy(first_slice).long()

    def __repr__(self):
        return '{}()'.format(type(self).__name__)


class MultiViewCompose(object):
    """Apply one transform pipeline per input view with synchronized randomness.

    Each input ("view", e.g. an image and its mask) gets its own list of
    transforms. Every pipeline is started from the same random state, so random
    augmentations draw identical numbers for all views.

    Args:
        transforms (list): one entry per input view. Each entry is a list of
            ``(apply, transform)`` pairs. ``transform`` is always called so that
            every pipeline consumes random numbers in the same way, but its
            result is only kept when ``apply`` is truthy.
    """

    def __init__(self, transforms: list):
        self.transforms = transforms

    def __call__(self, inputs: list):
        """Transform ``inputs`` in place and return the same list.

        Before each view, the Python ``random``, NumPy and torch CPU generator
        states are rewound to the values captured on entry. After the call,
        the generators are left in the state reached by the last view.

        Raises:
            IndexError: if there are more inputs than transform pipelines.
        """
        # Capture every RNG that common transforms draw from. torchvision's
        # random transforms sample from torch's generator, so syncing only
        # `random` and NumPy would let image and mask diverge.
        np_state = np.random.get_state()
        py_state = random.getstate()
        torch_state = torch.get_rng_state()

        for i in range(len(inputs)):
            # Rewind all generators so this view replays the same draws.
            np.random.set_state(np_state)
            random.setstate(py_state)
            torch.set_rng_state(torch_state)

            for apply, t in self.transforms[i]:
                # Always call t, even when its result is discarded, so that
                # RNG consumption stays identical across views.
                tmp = t(inputs[i])
                if apply:
                    inputs[i] = tmp
        return inputs

    def __repr__(self):
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += '    {0}'.format(t)
        format_string += '\n)'
        return format_string
