import torch
import torch.nn as nn


class ExpLoss(nn.Module):
    """Bounded exponential loss: ``1 - exp(-|pred - target|)``.

    The per-element loss is 0 when ``pred == target`` and approaches (but never
    exceeds) 1 as the absolute error grows, so large errors saturate instead of
    dominating the total.
    """

    @staticmethod
    def forward(pred, target, reduction="mean"):
        """Compute the exponential loss between ``pred`` and ``target``.

        Args:
            pred: Predicted tensor.
            target: Target tensor; must be broadcast-compatible with ``pred``
                (the two are subtracted element-wise).
            reduction: ``"mean"`` (default) averages over all elements,
                ``"sum"`` sums them, and ``"none"`` returns the element-wise loss.

        Returns:
            A scalar tensor for ``"mean"`` / ``"sum"``; otherwise a tensor with
            the broadcast shape of ``pred - target``.

        Raises:
            NotImplementedError: If ``reduction`` is not a supported value.
        """
        diff = (pred - target).abs()
        # -expm1(-x) equals 1 - exp(-x) but stays accurate for tiny x, where
        # exp(-x) rounds to 1.0 and the naive form collapses to exactly 0.
        loss = -torch.expm1(-diff)
        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        if reduction == "none":
            return loss
        raise NotImplementedError(
            f"Unsupported reduction {reduction!r}; expected 'mean', 'sum' or 'none'."
        )
