import torch


def mixed_init(z_shape, device=None, *, p=0.5, dtype=None):
    """Sample a standard-normal tensor and zero out a random subset of entries.

    Each element is first drawn from N(0, 1). It is then multiplied by an
    independent Bernoulli(p) mask, so it is kept with probability ``p`` and
    set to zero otherwise.

    Args:
        z_shape: Output shape, given as an int or a sequence of ints
            (tuple, list, ``torch.Size``).
        device: Device on which to allocate the tensors (``None`` means the
            torch default).
        p: Probability that an entry is kept, i.e. that its mask value is 1.
            The default of 0.5 matches the original hard-coded value.
        dtype: Floating dtype of the result (``None`` means the torch default).

    Returns:
        A tensor of shape ``z_shape``. Its entries are N(0, 1) samples where
        the mask is 1 and zero where it is 0.

    Raises:
        ValueError: If ``p`` is not in the closed interval [0, 1].
    """
    if not 0.0 <= p <= 1.0:
        raise ValueError(f"p must be in [0, 1], got {p}")

    # Accept a bare int as well as a sequence. Passing one tuple (rather than
    # unpacking with *) also makes an empty shape produce a 0-dim tensor.
    shape = (z_shape,) if isinstance(z_shape, int) else tuple(z_shape)

    z_init = torch.randn(shape, device=device, dtype=dtype)
    # bernoulli_ overwrites every element, so an uninitialised buffer is
    # enough. empty_like already inherits z_init's device and dtype.
    mask = torch.empty_like(z_init).bernoulli_(p)

    return z_init * mask

