from collections import OrderedDict

import torch


def sample_posterior(moments, latents_scale, latents_bias=None):
    """Draw a reparameterized sample from packed posterior moments.

    ``moments`` is split into two chunks along dimension 1; the first is
    used as the mean and the second as the standard deviation. The sample
    ``mean + std * eps`` (``eps`` drawn with ``torch.randn_like``) is then
    mapped through ``z * latents_scale + latents_bias``.

    Args:
        moments: Tensor holding mean and std stacked along dimension 1.
            Presumably shaped (batch, 2 * channels, height, width) --
            confirm against the caller.
        latents_scale: Multiplicative factor applied to the sample; anything
            that broadcasts against the sample (scalar or tensor).
        latents_bias: Optional additive offset that broadcasts against the
            sample. When ``None`` no offset is added, which is equivalent to
            a zero bias for any number of latent channels.

    Returns:
        The scaled (and optionally shifted) sample tensor.
    """
    mean, std = torch.chunk(moments, 2, dim=1)
    z = mean + std * torch.randn_like(mean)
    z = z * latents_scale
    # A missing bias means "no shift". Skipping the add avoids building a
    # hard-coded 4-channel zero tensor that fails to broadcast for latents
    # with a different channel count.
    if latents_bias is not None:
        z = z + latents_bias
    return z


def sample_posterior_2(mean, std, latents_scale, latents_bias=None):
    """Draw a reparameterized sample from a separate mean and std.

    Computes ``mean + std * eps`` with ``eps`` drawn by
    ``torch.randn_like(mean)``, then maps the result through
    ``z * latents_scale + latents_bias``.

    Args:
        mean: Tensor of posterior means; its shape, dtype and device
            determine those of the noise.
        std: Standard deviation, broadcastable against ``mean``.
        latents_scale: Multiplicative factor applied to the sample; anything
            that broadcasts against the sample (scalar or tensor).
        latents_bias: Optional additive offset that broadcasts against the
            sample. When ``None`` no offset is added, which is equivalent to
            a zero bias for any number of latent channels.

    Returns:
        The scaled (and optionally shifted) sample tensor.
    """
    z = mean + std * torch.randn_like(mean)
    z = z * latents_scale
    # A missing bias means "no shift". Skipping the add avoids building a
    # hard-coded 4-channel zero tensor that fails to broadcast for latents
    # with a different channel count.
    if latents_bias is not None:
        z = z + latents_bias
    return z


def update_ema(ema_model, model, accelerator, decay=0.9999):
    """Blend the current model weights into the EMA model in place.

    For every parameter name present in both models, the EMA tensor becomes
    ``decay * ema + (1 - decay) * model``. Parameters whose name exists in
    only one of the two models are left untouched.

    Args:
        ema_model: Module holding the averaged weights; modified in place.
        model: Module providing the current weights; not modified.
        accelerator: Unused by this function; kept for call-site
            compatibility.
        decay: Weight kept from the previous EMA value at each update.

    Returns:
        None.
    """
    live_params = dict(model.named_parameters())
    blend = 1 - decay

    # Walk the EMA side once and skip names the live model does not have.
    for name, averaged in ema_model.named_parameters():
        if name not in live_params:
            continue
        current = live_params[name]
        averaged.data.mul_(decay).add_(current.data, alpha=blend)


def requires_grad(model, flag=True):
    """Set the ``requires_grad`` flag on every parameter of a model.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        flag: Value assigned to each parameter's ``requires_grad``.

    Returns:
        None.
    """
    # A single pass is enough; the assignment is idempotent.
    for p in model.parameters():
        p.requires_grad = flag
