"""
Loss functions for MDShortcut diffusion models.

This module provides loss functions for training diffusion models on molecular dynamics data,
supporting both position and element prediction losses with configurable weighting.
"""
import torch.nn.functional as F
from torch import nn


class MaterialLoss:
    """Multi-component loss function for material diffusion model training.

    Computes a weighted combination of a position loss and an element loss. Each
    component uses a configurable loss function: L1, L2 (MSE), Huber (smooth L1)
    or KL divergence. Designed for epsilon prediction in diffusion models where
    both atomic positions and element types are predicted.

    Loss functions are always called as ``loss_func(prediction, target)``,
    following the PyTorch ``(input, target)`` convention. The order is irrelevant
    for the symmetric L1/L2/Huber losses but essential for ``'kl'``:
    ``nn.KLDivLoss`` expects its first argument to be log-probabilities (the
    prediction) and its second to be probabilities (the target).

    Attributes:
        position_weight (float): Weight applied to position loss component.
        element_weight (float): Weight applied to element loss component.
        pos_loss_func (callable): Loss function for position predictions.
        el_loss_func (callable): Loss function for element predictions.
    """

    def __init__(self, norm_type, element_norm_type=None, position_weight=1.0, element_weight=1.0):
        """Initialize MaterialLoss with specified loss types and weights.

        Args:
            norm_type (str): Loss function type for positions ('l1', 'l2', 'huber', 'kl').
            element_norm_type (str, optional): Loss function type for elements.
                If None, uses same as norm_type. Defaults to None.
            position_weight (float, optional): Weight for position loss. Defaults to 1.0.
            element_weight (float, optional): Weight for element loss. Defaults to 1.0.

        Raises:
            NotImplementedError: If either loss type name is not recognized.
        """
        self.position_weight = position_weight
        self.element_weight = element_weight
        if element_norm_type is None:
            element_norm_type = norm_type

        self.pos_loss_func = self._get_loss_func(norm_type)
        self.el_loss_func = self._get_loss_func(element_norm_type)

    def __call__(self, pos_tgt, pred_pos, el_tgt, pred_els, **kwargs):
        """Compute the total weighted loss from position and element predictions.

        Args:
            pos_tgt (torch.Tensor): Target position noise, shape (n_atoms, 3).
            pred_pos (torch.Tensor): Predicted position noise, shape (n_atoms, 3).
            el_tgt (torch.Tensor or None): Target element embeddings, shape
                (n_atoms, d_embed). For 'kl' this must be a probability distribution.
            pred_els (torch.Tensor or None): Predicted element embeddings, shape
                (n_atoms, d_embed). For 'kl' this must be log-probabilities.
            **kwargs: Additional arguments (ignored).

        Returns:
            tuple:
                - loss_tot (torch.Tensor): Total weighted loss scalar.
                - (loss_pos, loss_el) (tuple): Individual position and element losses.
                  ``loss_el`` is a zero tensor (same dtype/device as ``loss_pos``)
                  when either ``el_tgt`` or ``pred_els`` is None.
        """
        # PyTorch losses take (input=prediction, target); the order matters for KL.
        loss_pos = self.pos_loss_func(pred_pos, pos_tgt)
        if el_tgt is not None and pred_els is not None:
            loss_el = self.el_loss_func(pred_els, el_tgt)
        else:
            # A 0-d tensor rather than a Python float, so callers can treat both
            # loss components uniformly (.item(), .detach(), logging, ...).
            loss_el = loss_pos.new_zeros(())

        loss_tot = self.position_weight * loss_pos + self.element_weight * loss_el
        return loss_tot, (loss_pos, loss_el)

    @staticmethod
    def _get_loss_func(name):
        """Get loss function by name.

        Args:
            name (str): Loss function name ('l1', 'l2', 'huber', 'kl').

        Returns:
            callable: PyTorch loss function, called as ``loss_func(prediction, target)``.

        Raises:
            NotImplementedError: If loss function name is not recognized.
        """
        if name == 'l1':
            return F.l1_loss
        if name == 'l2':
            return F.mse_loss
        if name == 'huber':
            # smooth_l1_loss with its default beta=1.0 equals the Huber loss with
            # delta=1.0; it behaves quadratically near zero and linearly beyond.
            return F.smooth_l1_loss
        if name == 'kl':
            # Compares two distributions; expects (log_probs_pred, probs_target).
            # 'batchmean' divides the summed divergence by the first dimension.
            return nn.KLDivLoss(reduction='batchmean')
        raise NotImplementedError(f'Unknown loss function: {name}')
