from smlm.utils.torch import are_broadcastable


def stretch(t, extent, clamp: bool):
    """
    Stretch normalized input from [0, 1] to the specified range.

    Computes ``t * (hi - lo) + lo`` with ``lo = extent[..., 0]`` and
    ``hi = extent[..., 1]``.

    Args:
        t (Tensor): Normalized input tensor, nominally in [0, 1].
        extent (Tensor): Target range tensor whose last dimension (size 2)
            holds the (lower, upper) bounds; ``extent.shape[:-1]`` must be
            broadcastable with ``t.shape``.
        clamp (bool): Whether or not to clip ``t`` to [0, 1] before
            stretching, which keeps the output between the two bounds.

    Returns:
        Tensor: The stretched values.

    Raises:
        ValueError: If the last dimension of ``extent`` is not of size 2, or
            if ``t.shape`` and ``extent.shape[:-1]`` are not broadcastable.
    """
    if extent.shape[-1] != 2:
        raise ValueError(
            f"extent last dimension must be of size 2; found {extent.shape[-1]}"
        )
    # Only the leading dims of extent broadcast against t; the last dim
    # holds the (lower, upper) pair.
    batch_shape = extent.shape[:-1]
    if not are_broadcastable(t.shape, batch_shape):
        raise ValueError(
            f"input ({t.shape}) and extent without its last dim ({batch_shape}) must be broadcastable."
        )

    lo = extent[..., 0]
    hi = extent[..., 1]
    t = t.clip(0.0, 1.0) if clamp else t
    return t * (hi - lo) + lo


def normalize(t, extent, clamp: bool):
    """
    Normalize a tensor to [0, 1] using the image extent.

    Computes ``(t - lo) / (hi - lo)`` with ``lo = extent[..., 0]`` and
    ``hi = extent[..., 1]``; this is the inverse of ``stretch``.

    Args:
        t (Tensor): Input tensor expressed in the coordinates of ``extent``.
        extent (Tensor): Source range tensor whose last dimension (size 2)
            holds the (lower, upper) bounds; ``extent.shape[:-1]`` must be
            broadcastable with ``t.shape``.
        clamp (bool): Whether or not to clamp the result to [0, 1].

    Returns:
        Tensor: The normalized values.

    Raises:
        ValueError: If the last dimension of ``extent`` is not of size 2, or
            if ``t.shape`` and ``extent.shape[:-1]`` are not broadcastable.
    """
    if extent.shape[-1] != 2:
        raise ValueError(
            f"extent last dimension must be of size 2; found {extent.shape[-1]}"
        )
    # Only the leading dims of extent broadcast against t; the last dim
    # holds the (lower, upper) pair.
    batch_shape = extent.shape[:-1]
    if not are_broadcastable(t.shape, batch_shape):
        raise ValueError(
            f"input ({t.shape}) and extent without its last dim ({batch_shape}) must be broadcastable."
        )
    lo = extent[..., 0]
    hi = extent[..., 1]
    # NOTE(review): a degenerate extent (hi == lo) divides by zero here;
    # it is not guarded.
    t = (t - lo) / (hi - lo)
    t = t.clip(0.0, 1.0) if clamp else t
    return t


def stretch_unitbox(t, extent, clamp: bool):
    """
    Stretch normalized input from [-1, 1] to the specified range.

    Args:
        t (Tensor): Normalized input tensor, nominally in [-1, 1].
        extent (Tensor): Target range tensor whose last dimension (size 2)
            holds the (lower, upper) bounds; see ``stretch``.
        clamp (bool): Whether or not to clamp values within the specified
            range (equivalently, clip ``t`` to [-1, 1] before stretching).

    Returns:
        Tensor: The stretched values.

    Raises:
        ValueError: Propagated from ``stretch`` when ``extent`` is malformed
            or not broadcastable with ``t``.
    """
    # Map [-1, 1] -> [0, 1]. This map is affine and increasing, so clipping
    # to [0, 1] inside `stretch` is the same as clipping the original input
    # to [-1, 1]; `clamp` can therefore be forwarded unchanged (it was
    # previously ignored).
    t = 0.5 * (t + 1.0)
    return stretch(t, extent=extent, clamp=clamp)


def normalize_unitbox(t, extent, clamp: bool):
    """
    Normalize a tensor to [-1, 1] using the image extent.

    This applies ``normalize`` and then the affine map ``u -> 2u - 1``,
    which sends [0, 1] onto [-1, 1].

    Args:
        t (Tensor): Input tensor expressed in the coordinates of ``extent``.
        extent (Tensor): Source range tensor whose last dimension (size 2)
            holds the (lower, upper) bounds; see ``normalize``.
        clamp (bool): Whether or not to clamp values in [-1, 1].

    Returns:
        Tensor: The normalized values.
    """
    unit = normalize(t, extent=extent, clamp=clamp)
    return 2.0 * unit - 1.0
