import torch
import math

def adaptive_kl_weight(base_weight, epoch, total_epochs=100, steepness=0.2, midpoint=None):
    """Return a sigmoid-annealed KL weight for the given epoch.

    The weight follows ``base_weight * sigmoid(steepness * (epoch - midpoint))``.
    It is near 0 well before ``midpoint``, exactly ``base_weight / 2`` at
    ``midpoint``, and approaches ``base_weight`` afterwards.

    Args:
        base_weight: Target (maximum) KL weight.
        epoch: Current epoch index.
        total_epochs: Total number of epochs. Used only to derive the default
            midpoint (``total_epochs // 2``).
        steepness: Slope of the sigmoid. Defaults to 0.2, the previously
            hard-coded value.
        midpoint: Epoch at which the weight reaches half of ``base_weight``.
            Defaults to ``total_epochs // 2``.

    Returns:
        The annealed weight, clamped so it never exceeds ``base_weight``.
        For a negative ``base_weight`` the clamp makes the result equal
        ``base_weight``.
    """
    if midpoint is None:
        midpoint = total_epochs // 2
    z = steepness * (epoch - midpoint)
    # Numerically stable sigmoid: only exponentiate a non-positive value, so
    # math.exp can never overflow (the naive exp(-z) raised OverflowError
    # for large negative z, e.g. early epochs with a very large total_epochs).
    if z >= 0:
        adaptive_weight = base_weight / (1.0 + math.exp(-z))
    else:
        ez = math.exp(z)  # in (0, 1); underflows harmlessly to 0.0
        adaptive_weight = base_weight * ez / (1.0 + ez)
    return min(adaptive_weight, base_weight)

def loss_function(recon_x, x, mu, log_var, kl_weight=1.0, epoch=None, total_epochs=100):
    """Compute a VAE-style loss: per-sample reconstruction MSE plus weighted KL.

    Args:
        recon_x: Reconstruction produced by the model (same shape as ``x``).
        x: Target batch; its first dimension is treated as the batch size.
        mu: Latent means.
        log_var: Latent log-variances (presumably log(sigma^2) — confirm with
            the encoder).
        kl_weight: Base KL weight.
        epoch: If given, ``kl_weight`` is annealed via ``adaptive_kl_weight``;
            if ``None``, ``kl_weight`` is used as-is.
        total_epochs: Forwarded to ``adaptive_kl_weight`` when annealing.

    Returns:
        Tuple ``(loss, recon_loss, kl_loss, weight_used)``. The two loss terms
        are summed over elements and divided by the batch size.
    """
    n = x.size(0)

    # Summed squared error, averaged over the batch dimension only.
    reconstruction = torch.nn.functional.mse_loss(recon_x, x, reduction='sum') / n

    # Closed-form KL of N(mu, exp(log_var)) from N(0, I), averaged per sample.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    divergence = -0.5 * kl_terms.sum() / n

    weight = (
        kl_weight
        if epoch is None
        else adaptive_kl_weight(kl_weight, epoch, total_epochs)
    )

    total = reconstruction + weight * divergence
    return total, reconstruction, divergence, weight
