import torch


def col_concat(x, y):
    """
    Join x and y side by side along their last dimension.

    Both inputs are handed to torch.cat with dim=-1, so they must agree in
    every dimension except the last.
    """
    pieces = (x, y)
    return torch.cat(pieces, dim=-1)

def reparameterise(x, clamp=("hard", -20, 2), params=False):
    """
    The reparameterisation trick.
    Construct a Gaussian from x, taken to parameterise the mean and log standard deviation.

    The final dimension of x is split into two equal halves: the first half is
    the mean and the second half is the log standard deviation.

    Args:
        x: Tensor whose final dimension has even size.
        clamp: Sequence (mode, lower, upper) controlling how log_std is bounded.
            "hard" clips log_std to [lower, upper] with torch.clamp.
            "soft" squashes log_std towards the bounds using softplus.
            Any other mode leaves log_std untouched (lower/upper are not read).
        params: If True, return the (mean, log_std) tensors instead of a
            distribution object.

    Returns:
        (mean, log_std) if params is True, otherwise
        torch.distributions.Normal(mean, exp(log_std)).

    Raises:
        ValueError: If the final dimension of x has odd size, so it cannot be
            split evenly into mean and log_std.
    """
    width = x.shape[-1]
    if width % 2 != 0:
        # Fail early with a clear message rather than an opaque unpacking/split error.
        raise ValueError(
            f"reparameterise expects an even-sized final dimension, got {width}."
        )
    mean, log_std = torch.split(x, width // 2, dim=-1)
    # Bounding log_std helps to regulate its behaviour outside the training data (see PETS paper Appendix A.1).
    mode = clamp[0]
    if mode == "hard":  # This is used by default for the SAC policy.
        log_std = torch.clamp(log_std, clamp[1], clamp[2])
    elif mode == "soft":  # This is used by default for the PETS model.
        lower, upper = clamp[1], clamp[2]
        softplus = torch.nn.functional.softplus
        # Smoothly push values up towards the lower bound, then down towards the upper bound.
        log_std = lower + softplus(log_std - lower)
        log_std = upper - softplus(upper - log_std)
    if params:
        return mean, log_std
    return torch.distributions.Normal(mean, torch.exp(log_std))

def truncated_normal(tensor, mean, std, a, b):
    """
    Draw samples from a normal distribution truncated to [a, b].
    Adapted from torch.nn.init._no_grad_trunc_normal_.

    Args:
        tensor: Buffer that supplies the uniform noise. It is overwritten in
            place with uniform samples in [0, 1); the truncated-normal samples
            themselves are returned as a new tensor.
        mean: Mean of the underlying normal distribution.
        std: Standard deviation of the underlying normal distribution.
        a: Lower truncation bound.
        b: Upper truncation bound.

    Returns:
        A tensor of samples clamped to lie in [a, b], computed without
        gradient tracking.

    Note:
        (a - mean) / std and (b - mean) / std are passed to torch.erf, so they
        must evaluate to tensors.
    """
    root_two = 2.**.5

    def _standard_normal_cdf(z):
        return (1. + torch.erf(z / root_two)) / 2.

    with torch.no_grad():
        # CDF values at the lower and upper truncation points.
        cdf_low = _standard_normal_cdf((a - mean) / std)
        cdf_high = _standard_normal_cdf((b - mean) / std)
        # Draw uniform noise in [0, 1) into the caller's buffer, then map it
        # into [2*cdf_low - 1, 2*cdf_high - 1] (this creates a new tensor).
        tensor.uniform_()
        sample = 2 * (cdf_low + tensor * (cdf_high - cdf_low)) - 1
        # Inverse-CDF transform to a truncated standard normal, rescale to the
        # requested mean/std, and clamp to guard against numerical overshoot.
        sample.erfinv_().mul_(std * root_two).add_(mean).clamp_(min=a, max=b)
    return sample
