import torch
import numpy as np


def marginal_prob_std(t, sigma):
    """
    Return the standard deviation of the perturbation at time ``t``.

    Computes ``sqrt((sigma**(2 * t) - 1) / (2 * log(sigma)))``.

    Args:
        t: Time value(s): a tensor, Python scalar, or array-like.
        sigma: Base of the exponential noise schedule. ``log(sigma)`` is the
            denominator, so ``sigma`` must be positive and different from 1
            (``sigma == 1`` yields NaN).

    Returns:
        A floating-point tensor on the same device as ``t``. A floating (or
        complex) ``t`` keeps its dtype. An integer or bool ``t`` is promoted
        to ``torch.get_default_dtype()``.
    """
    t = torch.as_tensor(t)
    # An integer `t` would make the final dtype cast truncate the (generally
    # non-integer) std, e.g. 9.845 -> 9, so compute in floating point.
    if not torch.is_floating_point(t) and not torch.is_complex(t):
        t = t.to(torch.get_default_dtype())
    device = t.device
    dtype = t.dtype

    sigma_t = sigma ** (2 * t)
    # Cast back in case `sigma` (e.g. a 0-dim float64 tensor) promoted the result.
    return torch.sqrt((sigma_t - 1.0) / (2 * np.log(sigma))).to(
        device=device, dtype=dtype
    )


def diff_coeff(t, sigma):
    """
    Return the diffusion coefficient ``g(t) = sigma**t``.

    Args:
        t: Time value(s): a tensor, Python scalar, or array-like.
        sigma: Base of the exponential noise schedule.

    Returns:
        A floating-point tensor on the same device as ``t``. An integer or
        bool ``t`` is promoted to ``torch.get_default_dtype()``.
    """
    t = torch.as_tensor(t)
    # With integer `t` and integer `sigma`, torch would use integer pow.
    # Integer pow overflows silently for moderate `t` and rejects negative
    # exponents, so compute in floating point instead.
    if not torch.is_floating_point(t) and not torch.is_complex(t):
        t = t.to(torch.get_default_dtype())
    return sigma ** t