import torch
import math

def adaptive_kl_weight(base_weight, epoch, total_epochs=100, steepness=0.2):
    """
    Adaptive KL weight calculation.

    Scales ``base_weight`` by a logistic (sigmoid) schedule centred on
    ``total_epochs // 2``. The weight starts near zero early in training
    and rises towards ``base_weight`` after the midpoint.

    Args:
        base_weight: Base KL weight (the value the schedule approaches).
        epoch: Current training epoch.
        total_epochs: Total number of training epochs; the sigmoid midpoint
            is ``total_epochs // 2``.
        steepness: Slope of the sigmoid. Larger values give a sharper
            transition around the midpoint. Defaults to 0.2.

    Returns:
        ``base_weight * sigmoid(steepness * (epoch - midpoint))``, capped
        with ``min(..., base_weight)``.
    """
    midpoint = total_epochs // 2
    z = steepness * (epoch - midpoint)
    # Numerically stable logistic: only ever exponentiate a non-positive
    # number so math.exp cannot overflow. The naive 1 / (1 + exp(-z)) raised
    # OverflowError for epochs far before the midpoint (e.g. epoch=0,
    # total_epochs=10000 -> exp(1000)).
    if z >= 0:
        adaptive_weight = base_weight / (1 + math.exp(-z))
    else:
        # exp(z) <= 1 here; for very negative z it underflows harmlessly to 0.0.
        e = math.exp(z)
        adaptive_weight = base_weight * e / (1 + e)
    return min(adaptive_weight, base_weight)

def loss_function(recon_x, x, mu, log_var, kl_weight=1.0, epoch=None, total_epochs=100):
    """
    Evaluate the CVAE objective: reconstruction error plus a weighted KL term.

    Both terms are summed over all elements and then divided by the size of
    the first dimension of ``x`` (the batch).

    Args:
        recon_x: Reconstructed sequence produced by the decoder.
        x: Target sequence; its first dimension is treated as the batch size.
        mu: Encoder mean.
        log_var: Encoder log variance.
        kl_weight: Base KL weight. Used as-is when ``epoch`` is None,
            otherwise passed through the ``adaptive_kl_weight`` schedule.
        epoch: Current training epoch, or None to disable annealing.
        total_epochs: Total number of training epochs (annealing only).

    Returns:
        Tuple of (total loss, reconstruction loss, KL loss, KL weight used).
    """
    n_samples = x.size(0)

    # Summed squared error, averaged over the batch.
    reconstruction = torch.nn.functional.mse_loss(recon_x, x, reduction='sum') / n_samples

    # Closed-form KL divergence of N(mu, exp(log_var)) from N(0, I), averaged over the batch.
    kl_terms = 1 + log_var - mu.pow(2) - log_var.exp()
    divergence = -0.5 * kl_terms.sum() / n_samples

    # Anneal the KL weight only when the caller supplies the current epoch.
    weight = kl_weight if epoch is None else adaptive_kl_weight(kl_weight, epoch, total_epochs)

    total = reconstruction + weight * divergence
    return total, reconstruction, divergence, weight
