import torch
from torch import Generator

# Largest signed 64-bit integer (2**63 - 1); used below as the exclusive upper
# bound when drawing a derived seed with torch.randint.
_MAX_SEED: int = torch.iinfo(torch.int64).max


@torch.compiler.disable()
def derive_new_seed(previous_seed: int):
    """Deterministically derive a new integer seed from ``previous_seed``.

    A generator is seeded with ``previous_seed`` via ``get_generator`` and a
    single int64 value is drawn from it, so the same input always yields the
    same output.

    Args:
        previous_seed: Seed used to initialise the temporary generator.

    Returns:
        A Python ``int`` in ``[0, _MAX_SEED)``. ``randint``'s ``high`` is
        exclusive.
    """
    rng = get_generator(previous_seed)
    return torch.randint(
        0, _MAX_SEED, size=(), generator=rng, dtype=torch.int64
    ).item()


@torch.compiler.disable()
def get_generator(seed: int, device=None):
    gen = Generator(device).manual_seed(seed)
    return gen


def jitter(x: float, std: float, size: int, device, gen: Generator):
    """Scale ``x`` by multiplicative log-normal noise.

    ``size`` standard-normal samples are drawn from ``gen`` on ``device``.
    Each sample is clamped to [-5, 5] and scaled by ``std``. ``x`` is then
    multiplied by ``exp`` of the result, so each factor lies in
    [exp(-5 * std), exp(5 * std)].

    Args:
        x: Value to perturb (a float or anything that multiplies with a tensor).
        std: Standard deviation of the noise in log space; ``<= 0`` disables it.
        size: Number of noise samples to draw.
        device: Device on which the noise tensor is created.
        gen: Random generator consumed for the noise draw.

    Returns:
        ``x`` itself (not a tensor, and ``gen`` is not advanced) when
        ``std <= 0``. Otherwise a tensor equal to ``x * exp(std * noise)``.
    """
    if std <= 0.0:
        return x

    # Clamp the raw normals so a single extreme draw cannot blow up the factor.
    noise = torch.randn(size, device=device, generator=gen).clamp(min=-5.0, max=5.0)
    factor = torch.exp(std * noise)
    return x * factor
