import torch


class BaseTransform:
    """Interface stub for transform objects.

    Every method other than the constructor raises ``NotImplementedError``;
    subclasses are expected to override all of them.
    """

    def __init__(self):
        # Starts out False. NOTE(review): presumably flipped by subclasses
        # whose transform() alters the input image -- confirm with callers.
        self.image_changed = False

    def transform(self, image_nd, clicks_lists):
        """Forward step taking ``image_nd`` and ``clicks_lists``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def inv_transform(self, prob_map):
        """Inverse step taking ``prob_map``.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def reset(self):
        """Reset hook.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_state(self):
        """State export hook (counterpart of ``set_state``).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def set_state(self, state):
        """State import hook (counterpart of ``get_state``).

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()


class SigmoidForPred(BaseTransform):
    """Transform that leaves its inputs alone and applies a sigmoid on the way back.

    The forward step is an identity; the inverse step runs ``torch.sigmoid``
    over the value it is given. The object keeps no state of its own, so the
    reset and state hooks do nothing.
    """

    def transform(self, image_nd, clicks_lists):
        """Return ``image_nd`` and ``clicks_lists`` unchanged, as a 2-tuple."""
        return (image_nd, clicks_lists)

    def inv_transform(self, prob_map):
        """Return ``torch.sigmoid(prob_map)``."""
        return torch.sigmoid(prob_map)

    def reset(self):
        """Do nothing; there is no state to clear."""

    def get_state(self):
        """Return ``None``; there is no state to export."""
        return None

    def set_state(self, state):
        """Ignore ``state``; there is nothing to restore."""
