import itertools


def expand_tensor_dims_as(in_tensor, x):
    """Append trailing singleton dimensions to ``in_tensor`` so its rank matches ``x``.

    The leading dimensions of ``in_tensor`` must equal the corresponding
    leading dimensions of ``x``; the result is a view of ``in_tensor`` with
    shape ``in_tensor.shape + (1,) * (x.dim() - in_tensor.dim())``, which
    broadcasts against ``x``.

    Args:
        in_tensor: Tensor whose shape is a prefix of ``x.shape``.
        x: Reference tensor whose rank is matched.

    Returns:
        A view of ``in_tensor`` with ``x.dim()`` dimensions.

    Raises:
        AssertionError: If ``in_tensor`` has more dimensions than ``x`` or its
            shape is not a prefix of ``x.shape``.
    """
    # Without this check zip() would silently truncate and [1] * negative == [],
    # returning a tensor of higher rank than x unchanged.
    assert in_tensor.dim() <= x.dim(), f"Shapes do not match: {in_tensor.shape} vs {x.shape}"
    for d1, d2 in zip(in_tensor.shape, x.shape):
        assert d1 == d2, f"Shapes do not match: {in_tensor.shape} vs {x.shape}"
    return in_tensor.view(list(in_tensor.shape) + [1] * (x.dim() - in_tensor.dim()))


def convert_net_output(out, sde, x, t, src, dst):
    """Convert a network prediction from one parameterization to another.

    The conversions implement the relation ``x = x0 + sigma * eps`` with
    ``sigma = sde.sigma(t)``, so that converting eps -> x0 -> eps (and back)
    round-trips exactly, and both sources yield the same score
    ``-(x - x0) / sigma**2 == -eps / sigma``.

    NOTE(review): ``sde.scale(t)`` is not used. If the SDE scales the data
    (e.g. ``x = scale * x0 + sigma * eps``) the x0 formulas would need to
    account for it — confirm against the SDE definitions. It is also assumed
    that ``sde.sigma(t)`` broadcasts against ``out`` and ``x`` — TODO confirm.

    Args:
        out: Network output in the ``src`` parameterization.
        sde: Object providing ``sigma(t)``.
        x: The noisy input the network was evaluated on.
        t: Time passed to ``sde.sigma``.
        src: ``"eps"`` or ``"x0"``.
        dst: ``"eps"``, ``"x0"`` or ``"score"``.

    Returns:
        ``out`` expressed in the ``dst`` parameterization (``out`` itself when
        ``src == dst``).

    Raises:
        AssertionError: If ``src`` or ``dst`` is not a supported name.
    """
    assert src in ["eps", "x0"]
    assert dst in ["eps", "x0", "score"]
    if src == dst:
        return out
    sigma = sde.sigma(t)
    if src == "eps" and dst == "x0":
        # Inverse of the x0 -> eps branch: x0 = x - sigma * eps.
        return x - sigma * out
    elif src == "eps" and dst == "score":
        return -out / sigma
    elif src == "x0" and dst == "eps":
        return (x - out) / sigma
    elif src == "x0" and dst == "score":
        return -(x - out) / (sigma ** 2)
    else:
        # Unreachable given the asserts above; kept as a defensive guard.
        raise ValueError(f"Cannot convert from {src} to {dst}")


def infinite_loader(dataloader):
    """Iterate over ``dataloader`` indefinitely, restarting it after each pass.

    A re-iterable ``dataloader`` is re-entered with ``iter()`` on every pass
    instead of being replayed from a cache as ``itertools.cycle`` would do.
    Items are therefore not held in memory, and any work the iterable does in
    ``__iter__`` (such as reshuffling, if it does so) happens again each pass.
    One-shot iterators (e.g. generators) cannot be restarted, so they still
    fall back to ``itertools.cycle``.

    Args:
        dataloader: Any iterable.

    Returns:
        An iterator over the items, repeated pass after pass. It stops only if
        a pass produces no items, which matches ``itertools.cycle`` on empty
        input.
    """
    from collections.abc import Iterator

    if isinstance(dataloader, Iterator):
        # An exhausted iterator cannot be restarted; cycle's cached copy is
        # the only way to repeat it.
        return itertools.cycle(dataloader)
    return _iterate_passes(dataloader)


def _iterate_passes(iterable):
    """Yield every item of ``iterable``, re-iterating it whenever it is exhausted."""
    while True:
        produced = False
        for item in iterable:
            produced = True
            yield item
        if not produced:
            # An empty pass would otherwise loop forever without yielding.
            return


def get_register_fn(_CLASSES):
    """Build a class decorator that records classes in the ``_CLASSES`` mapping.

    The returned ``register_fn`` works both bare (``@register_fn``) and with a
    keyword (``@register_fn(name="alias")``), and may also be called directly
    (``register_fn(SomeClass, name="alias")``). Each class is stored under
    ``name`` when given, otherwise under its ``__name__``, and is returned
    unchanged so decoration has no other effect.

    Args:
        _CLASSES: Mutable mapping that receives the registered classes.

    Returns:
        The ``register_fn`` decorator.
    """

    def register_fn(cls=None, *, name=None):
        """Register ``cls`` (or return a decorator that will) in the registry.

        Raises:
            ValueError: If the chosen key is already present in the registry.
        """

        def _register(target):
            key = target.__name__ if name is None else name
            if key in _CLASSES:
                raise ValueError(f"Already registered model with name: {key}")
            _CLASSES[key] = target
            return target

        # Called without a class: act as a decorator factory.
        return _register if cls is None else _register(cls)

    return register_fn


def splitit(total_size, split_size):
    """Yield chunk sizes that partition ``total_size`` into pieces of at most ``split_size``.

    Every chunk equals ``split_size`` except possibly the last, which carries
    the remainder, so the yielded sizes always sum to ``total_size``. Nothing
    is yielded when ``total_size`` is 0.

    Because this is a generator, the argument checks run on the first
    iteration, not at call time.

    Raises:
        AssertionError: If ``total_size`` is negative or ``split_size`` is not
            positive.
    """
    assert total_size >= 0, f"total_size must be non-negative, got {total_size}"
    assert split_size > 0, f"split_size must be positive, got {split_size}"
    starts = range(0, total_size, split_size)
    for offset in starts:
        left = total_size - offset
        # Same result as min(left, split_size): left unless split_size is strictly smaller.
        yield split_size if split_size < left else left