# Transformations

import torch

def standardize(x, sigma_x=None):
    """
    Transform the data to have zero mean and unit standard deviation (z-score).

    A 1-D tensor is standardized over all of its elements; a tensor with any
    other number of dimensions is standardized along dimension 0.

    :param x: torch tensor to be standardized
    :param sigma_x: optional standard deviation to divide by. If None, it is
        computed from x with ``torch.std`` (default settings). A value
        supplied by the caller is used as-is.
    :return: 3-tuple w/ standardized x, the mean of x, std of x

    When the standard deviation is computed here and is exactly zero for some
    feature (constant data), that feature is divided by 1 instead of 0, so its
    standardized values are 0 rather than NaN. The returned std still reports
    the true value (0), which keeps ``standardize_inv`` consistent: 0 * 0 + mean
    gives back the original constant.
    """
    reduce_all = x.dim() == 1

    mu_x = torch.mean(x) if reduce_all else torch.mean(x, dim=0)

    if sigma_x is None:
        sigma_x = torch.std(x) if reduce_all else torch.std(x, dim=0)
        # Guard against 0/0 -> NaN for constant features; only the divisor is
        # patched, the reported sigma_x is left untouched.
        divisor = torch.where(sigma_x == 0, torch.ones_like(sigma_x), sigma_x)
    else:
        # Caller-provided scale: respect it exactly.
        divisor = sigma_x

    x_standardize = (x - mu_x) / divisor

    return x_standardize, mu_x, sigma_x


def standardize_inv(x, mu, sigma):
    """
    Undo a z-score standardization.

    :param x: standardized values
    :param mu: mean to restore
    :param sigma: standard deviation to restore
    :return: x rescaled by sigma and then shifted by mu
    """
    rescaled = x * sigma
    return rescaled + mu