import torch
from torch import nn


class LogNLLLoss(nn.Module):
    """Negative log-likelihood loss for inputs given as probabilities.

    ``nn.NLLLoss`` expects log-probabilities, so this wrapper applies
    ``torch.log`` to ``y_pred`` first. A small ``epsilon`` is added before
    the log so that a zero probability yields a large finite penalty
    instead of ``log(0) = -inf``.

    Args:
        reduction: Forwarded to ``nn.NLLLoss`` (``'mean'``, ``'sum'`` or
            ``'none'``).
        epsilon: Constant added to ``y_pred`` before taking the log.
            Defaults to ``1e-16``. NOTE(review): in float16 this default is
            below the smallest representable value and rounds to 0, so pass
            a larger value for half-precision inputs.
    """

    def __init__(self, reduction='mean', epsilon=1e-16):
        super().__init__()
        self.loss = nn.NLLLoss(reduction=reduction)
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        """Return the NLL of ``y_true`` under ``y_pred``.

        Args:
            y_pred: Class probabilities, laid out as ``nn.NLLLoss`` expects
                its input (presumably ``(batch, classes)``; confirm against
                callers).
            y_true: Integer class indices, laid out as ``nn.NLLLoss``
                expects its target.

        Returns:
            The loss tensor, reduced according to ``reduction``.
        """
        log_probs = torch.log(y_pred + self.epsilon)
        return self.loss(log_probs, y_true)


def to_loss(model):
    """Return the training loss module appropriate for ``model``.

    Args:
        model: Model identifier string.

    Returns:
        ``LogNLLLoss()`` for ``'DTN'``, ``'DTN-D'``, ``'DTN-S'`` and ``'SDT'``.
        These models presumably output probabilities; confirm against the
        model definitions. ``nn.CrossEntropyLoss()`` for ``'MLP'`` and
        ``'DNDT'``, which applies log-softmax itself and therefore expects
        raw scores.

    Raises:
        ValueError: If ``model`` is not one of the names above.
    """
    if model in {'DTN', 'DTN-D', 'DTN-S', 'SDT'}:
        return LogNLLLoss()
    if model in {'MLP', 'DNDT'}:
        return nn.CrossEntropyLoss()
    # Name the offending value so misconfigured runs are easy to diagnose.
    raise ValueError(
        f"Unknown model {model!r}; expected one of "
        "'DTN', 'DTN-D', 'DTN-S', 'SDT', 'MLP', 'DNDT'."
    )
