import torch
import torchvision.transforms.functional as TF


class ToDevice:
    """Transform that moves its input tensor with ``Tensor.to(device)``.

    Args:
        device: value forwarded unchanged as the single positional argument
            of ``Tensor.to`` (for example ``'cpu'`` or a ``torch.device``).
            Defaults to ``None``.
    """

    def __init__(self, device=None):
        self.device = device

    def __call__(self, tensor_pic):
        target = self.device
        moved = tensor_pic.to(target)
        return moved

    def __repr__(self):
        name = type(self).__name__
        return f'{name}(device={self.device})'

class NormaliseMinMax(object):
    """Rescale a tensor linearly so its values span ``[0, 1]``.

    The minimum and maximum are taken over the whole tensor (every element,
    not per channel) on each call. They are stored on the instance as
    ``min`` and ``max`` so that ``repr`` reports the statistics of the most
    recent input.

    A constant tensor (``max == min``) would otherwise compute ``0 / 0`` and
    return a tensor full of NaN. It is mapped to all zeros instead.
    """

    # Overwritten by __call__. They are defined at class level so that repr()
    # works before the first call (e.g. when a transforms pipeline is printed).
    min = None
    max = None

    def __call__(self, tensor):
        self.min = tensor.min()
        self.max = tensor.max()
        span = self.max - self.min
        if span == 0:
            # Constant input: (tensor - min) is all zeros; divide by 1 rather
            # than 0 to avoid NaN while keeping the usual true-division dtype.
            span = torch.ones_like(span)
        return (tensor - self.min) / span

    def __repr__(self):
        return self.__class__.__name__ + '(min={0}, max={1})'.format(self.min, self.max)

class NormaliseMean(object):
    """Rescale a tensor linearly so its values span ``[0, 1]``.

    NOTE(review): despite the name, this does not subtract a mean. It
    performs a global min-max rescale, identical to NormaliseMinMax. The
    behaviour is kept for existing callers; confirm the intended
    normalisation before relying on the name.

    The minimum and maximum are taken over the whole tensor on each call and
    stored on the instance as ``min`` and ``max`` for ``repr``. A constant
    tensor (``max == min``) is mapped to all zeros instead of NaN.
    """

    # Overwritten by __call__. They are defined at class level so that repr()
    # works before the first call.
    min = None
    max = None

    def __call__(self, tensor):
        self.min = tensor.min()
        self.max = tensor.max()
        span = self.max - self.min
        if span == 0:
            # Constant input: avoid 0/0 -> NaN; (tensor - min) is all zeros.
            span = torch.ones_like(span)
        return (tensor - self.min) / span

    def __repr__(self):
        return self.__class__.__name__ + '(min={0}, max={1})'.format(self.min, self.max)

class NormaliseMeanImageNet(object):
    """Rescale a tensor linearly so its values span ``[0, 1]``.

    NOTE(review): the name suggests per-channel ImageNet mean/std
    normalisation (for example via ``TF.normalize``; ``TF`` is imported at
    the top of the file but unused in the code shown). The implementation,
    however, is the same global min-max rescale as NormaliseMinMax. The
    behaviour is kept for existing callers; confirm intent.

    The minimum and maximum are taken over the whole tensor on each call and
    stored on the instance as ``min`` and ``max`` for ``repr``. A constant
    tensor (``max == min``) is mapped to all zeros instead of NaN.
    """

    # Overwritten by __call__. They are defined at class level so that repr()
    # works before the first call.
    min = None
    max = None

    def __call__(self, tensor):
        self.min = tensor.min()
        self.max = tensor.max()
        span = self.max - self.min
        if span == 0:
            # Constant input: avoid 0/0 -> NaN; (tensor - min) is all zeros.
            span = torch.ones_like(span)
        return (tensor - self.min) / span

    def __repr__(self):
        return self.__class__.__name__ + '(min={0}, max={1})'.format(self.min, self.max)