# Names exported by a star-import of this module.
# NOTE(review): `WassersteinLoss` and `safe_div` are defined in this file but
# are not listed here -- confirm that leaving them out is intentional.
__all__ = ["SMAPE", "MAPE"]

import ot
import torch
from torch import nn


def safe_div(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Divide ``a`` by ``b`` element-wise, replacing non-finite results by 0.

    Every position where ``a / b`` would be NaN or +/-inf (for example a
    division by zero) yields 0 instead. Those positions also receive a zero
    gradient with respect to both ``a`` and ``b``, so the result can be used
    inside a loss without NaN gradients leaking into the backward pass.

    Args:
        a: Numerator tensor.
        b: Denominator tensor, broadcastable with ``a``.

    Returns:
        A new tensor with the broadcast shape of ``a`` and ``b`` holding the
        quotient, with 0 at every position where the quotient is not finite.
    """
    # Probe division: only used to locate the invalid positions, so it is kept
    # out of the autograd graph.
    with torch.no_grad():
        invalid = ~torch.isfinite(a / b)

    # Masking the result *after* dividing would still make the backward pass
    # of the division evaluate 0 / 0 or 0 * inf (i.e. NaN) at the masked
    # positions. Masking the operands first makes those positions compute
    # 0 / 1, and `torch.where` routes a zero gradient to the replaced entries.
    safe_a = torch.where(invalid, torch.zeros_like(a), a)
    safe_b = torch.where(invalid, torch.ones_like(b), b)
    return safe_a / safe_b


class WassersteinLoss(nn.Module):
    """Batch-averaged ``ot.emd2`` value between rows of ``true`` and ``pred``.

    NOTE(review): the cost matrix passed to ``ot.emd2`` is filled with ones.
    If ``ot.emd2`` is the usual optimal-transport cost, every transport plan
    then costs the same, so the value would not depend on how the mass of
    ``pred`` is distributed -- confirm the intended ground metric.
    """

    def forward(self, true: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """Return the mean over the first dimension of the per-row losses.

        Args:
            true: Tensor indexed as (row, position); each row is passed to
                ``ot.emd2`` as its first argument.
            pred: Tensor indexed as (row, position); each row is passed to
                ``ot.emd2`` as its second argument.

        Returns:
            Scalar tensor: the mean of the per-row ``ot.emd2`` values.
        """
        # Uniform cost between every position of `true` and every position
        # of `pred`, built once and shared by all rows.
        cost = torch.ones(
            true.shape[1],
            pred.shape[1],
            dtype=true.dtype,
            device=true.device,
        )
        n_rows = true.shape[0]
        # One slot per row of `true`; uses the default floating dtype.
        per_row = torch.zeros(n_rows, device=true.device)
        for row, (true_row, pred_row) in enumerate(zip(true, pred)):
            per_row[row] = ot.emd2(true_row, pred_row, cost)
        return torch.mean(per_row)


class SMAPE(nn.Module):
    """Symmetric mean absolute percentage error, returned as a fraction."""

    def forward(self, true: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``|true - pred| / ((|true| + |pred|) / 2)``.

        The ratio is computed with ``safe_div``, so elements whose ratio is
        not finite (e.g. both values are zero) contribute 0 to the mean. The
        result is not multiplied by 100.

        Args:
            true: Target values.
            pred: Predicted values.

        Returns:
            Scalar tensor: the mean over all elements.
        """
        abs_error = (true - pred).abs()
        # Average magnitude of target and prediction: the symmetric scale.
        scale = (true.abs() + pred.abs()) / 2
        return safe_div(abs_error, scale).mean()


class MAPE(nn.Module):
    """Mean absolute percentage error, returned as a fraction."""

    def forward(self, true: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
        """Return the mean of ``|(true - pred) / true|``.

        The ratio is computed with ``safe_div``, so elements whose ratio is
        not finite (e.g. a zero target) contribute 0 to the mean. The result
        is not multiplied by 100.

        Args:
            true: Target values; used as the denominator.
            pred: Predicted values.

        Returns:
            Scalar tensor: the mean over all elements.
        """
        relative_error = safe_div(true - pred, true)
        return relative_error.abs().mean()
