from torch.distributions import Transform
import torch

class IdentityTransform(Transform):
    """Element-wise identity bijection on the real line.

    The forward map returns its input unchanged, the inverse also returns its
    input unchanged, and the log absolute Jacobian determinant is zero for
    every element.
    """

    domain = torch.distributions.constraints.real
    codomain = torch.distributions.constraints.real
    bijective = True
    # The identity is monotonically increasing.
    sign = +1

    def __init__(self, cache_size: int = 0):
        """Create the transform.

        Args:
            cache_size: Forwarded unchanged to ``Transform.__init__``.
        """
        super().__init__(cache_size=cache_size)

    def __eq__(self, other: object) -> bool:
        """Return True for any other ``IdentityTransform`` (or subclass) instance."""
        return isinstance(other, IdentityTransform)

    def __hash__(self) -> int:
        """Hash consistently with ``__eq__``.

        Defining ``__eq__`` alone sets ``__hash__`` to None and makes instances
        unhashable. All instances compare equal, so they must share one hash.
        The base class is hashed, not ``type(self)``, so subclass instances
        (which compare equal via ``isinstance``) also hash equal.
        """
        return hash(IdentityTransform)

    def _call(self, x: torch.Tensor) -> torch.Tensor:
        """Forward map: return ``x`` unchanged."""
        return x

    def _inverse(self, y: torch.Tensor) -> torch.Tensor:
        """Inverse map: return ``y`` unchanged."""
        return y

    def log_abs_det_jacobian(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return a zero tensor shaped like ``x``, since log|d(x)/dx| = log 1 = 0."""
        return torch.zeros_like(x)