"""
Loss functions for training normalizing flows.

This module provides loss functions for training normalizing flows,
including the reverse KL divergence, REINFORCE estimator, and the
stochastic modulation loss for SESaMo.
"""

import logging
from typing import Optional

from torch import Tensor

logger = logging.getLogger("SESaMo")


class ReverseKL:
    """
    Per-sample reverse KL objective for a normalizing flow.

    With q the flow density and the target defined through an action S,
    the objective is L = E_q[S(x) + log q(x)]. Calling an instance returns
    the per-sample integrand ``S(x) + log q(x)``; no averaging is done here.
    """

    def __init__(self):
        """Create the loss object and log that it was set up."""
        logger.info("Initialized ReverseKL Loss")

    def __call__(self, actions: Tensor, log_probs: Tensor) -> Tensor:
        """
        Evaluate the per-sample reverse KL integrand.

        Parameters
        ----------
        actions : Tensor
            Action values S(x) of the flow samples. Shape: (batch,).
        log_probs : Tensor
            Flow log-densities log q(x) of the same samples. Shape: (batch,).

        Returns
        -------
        Tensor
            Elementwise sum ``actions + log_probs``. Shape: (batch,).
        """
        per_sample_kl = actions + log_probs
        return per_sample_kl
    
    
class Reinforce:
    """
    REINFORCE-style surrogate loss for normalizing flows.

    The per-sample signal ``S(x) + log q(x)`` is centered by its batch mean,
    detached from the autograd graph, and used as a weight on ``log q(x)``.
    """

    def __init__(self):
        """Create the loss object and log that it was set up."""
        logger.info("Initialized Reinforce Loss")

    def __call__(self, actions: Tensor, log_probs: Tensor) -> Tensor:
        """
        Evaluate the per-sample REINFORCE surrogate with a mean baseline.

        Parameters
        ----------
        actions : Tensor
            Action values S(x) of the flow samples. Shape: (batch,).
        log_probs : Tensor
            Flow log-densities log q(x) of the same samples. Shape: (batch,).

        Returns
        -------
        Tensor
            ``(kl - mean(kl)).detach() * log_probs`` with
            ``kl = actions + log_probs``. Shape: (batch,).
        """
        signal = actions + log_probs
        baseline = signal.mean()
        # The centered signal is the negated reward; detaching it means
        # gradients flow only through the log_probs factor.
        centered = (signal - baseline).detach()
        return centered * log_probs
    
    
class StochmodLoss:
    """
    Loss function for stochastic modulation layers in SESaMo.

    Combines the per-sample reverse KL term with a REINFORCE-style term for
    the stochastic modulation component and an optional regularization
    penalty.

    This implements Equations (18) and (19) from the SESaMo paper.
    """

    def __init__(self):
        """Initialize the stochastic modulation loss."""
        logger.info("Initialized loss for SESaMo")

    def __call__(
        self,
        actions: Tensor,
        log_probs: Tensor,
        log_prob_stochmod: Tensor,
        regularization: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Compute the SESaMo loss.

        Parameters
        ----------
        actions : Tensor
            Action values S(x) for samples. Shape: (batch,).
        log_probs : Tensor
            Log probabilities log q(x) of samples. Shape: (batch,).
        log_prob_stochmod : Tensor
            Log probabilities from the stochastic modulation. Shape: (batch,).
        regularization : Tensor, optional
            Regularization penalty terms. Shape: (batch,). When omitted
            (``None``), no penalty is added.

        Returns
        -------
        Tensor
            Loss values for each sample. Shape: (batch,).
        """
        kl = actions + log_probs
        # Mean-centered reward; detached below so it only weights the
        # gradient of log_prob_stochmod and is not differentiated itself.
        rewards = -(kl - kl.mean())
        loss = kl - rewards.detach() * log_prob_stochmod
        # The default of None previously reached `+ regularization` and
        # raised TypeError; only add the penalty when one is provided.
        if regularization is not None:
            loss = loss + regularization
        return loss