import warnings

import torch


@torch.no_grad()
def check_latent_space_valid(
    z: torch.Tensor,
    valid_range: float = 4.0,
    warning_threshold: float = 0.1,
    warn: bool = True,
) -> float:
    """
    Return the fraction of samples in ``z`` that leave the box [-valid_range, valid_range].

    A sample counts as "outside" if ANY of its coordinates has absolute value
    strictly greater than ``valid_range``. Under an isotropic standard Gaussian
    latent prior, such samples are unlikely.

    Args:
        z: Batch of latent codes. Dim 0 is the batch dimension; all remaining
            dimensions are flattened and treated as one sample's coordinates
            (e.g. ``(N, D)`` or ``(N, C, H, W)``). A 1-D tensor is treated as
            ``N`` one-dimensional samples.
        valid_range: Half-width of the accepted box around the origin.
        warning_threshold: If the returned fraction exceeds this value (and
            ``warn`` is true), a ``UserWarning`` is emitted.
        warn: Whether to emit the warning at all.

    Returns:
        Fraction (in [0, 1], not a percentage) of samples falling outside the
        range. An empty batch returns 0.0.
    """
    outside = z.abs() > valid_range
    if outside.dim() > 1:
        # Reduce over every non-batch dim so each sample yields one flag,
        # regardless of the latent's shape (vector or feature map).
        outside = outside.flatten(start_dim=1).any(dim=1)
    if outside.numel() == 0:
        # torch.mean of an empty tensor is NaN; no samples means none are outside.
        return 0.0
    fraction_outside_range = float(outside.float().mean())
    if warn and fraction_outside_range > warning_threshold:
        # stacklevel=2 attributes the warning to the caller, not this helper.
        warnings.warn(
            f'{fraction_outside_range:.2f} of all samples fall outside expected latent space area.',
            stacklevel=2,
        )
    return fraction_outside_range
