import torch
import numpy as np
import torch.nn.functional as F

from . import BaseMetric


class MeanSquareError(BaseMetric):
    """Mean squared error (MSE) between predictions and targets."""

    def __init__(self, name='mse'):
        """Create the metric, forwarding ``name`` to ``BaseMetric``."""
        super().__init__(name)

    def call(self, output, target):
        """Return the mean of the element-wise squared differences.

        Args:
            output: Predicted tensor.
            target: Ground-truth tensor.

        Returns:
            The result of ``F.mse_loss`` with mean reduction.
        """
        mse = F.mse_loss(input=output, target=target, reduction='mean')
        return mse


class RootMeanSquareError(BaseMetric):
    """Root mean squared error (RMSE) between predictions and targets."""

    def __init__(self, name='rmse'):
        """Create the metric, forwarding ``name`` to ``BaseMetric``."""
        super().__init__(name)

    def call(self, output, target):
        """Return the square root of the mean squared difference.

        Args:
            output: Predicted tensor.
            target: Ground-truth tensor.

        Returns:
            ``sqrt(mean((output - target) ** 2))``, reduced over all elements.
        """
        residual = output - target
        mean_squared = torch.mean(torch.pow(residual, 2))
        return torch.sqrt(mean_squared)


class MeanAbsoluteError(BaseMetric):
    """Mean absolute error (MAE) between predictions and targets."""

    def __init__(self, name='mae'):
        """Create the metric, forwarding ``name`` to ``BaseMetric``."""
        super().__init__(name)

    def call(self, output, target):
        """Return the mean of the element-wise absolute differences.

        Args:
            output: Predicted tensor.
            target: Ground-truth tensor.

        Returns:
            ``mean(|output - target|)``, reduced over all elements.
        """
        abs_error = (output - target).abs()
        return abs_error.mean()


class SymmetricMeanAbsolutePercentageError(BaseMetric):
    """Symmetric mean absolute percentage error (sMAPE).

    Computed as ``mean(|output - target| / (|output| + |target|))``. There is
    no factor of 2 and no scaling by 100, so each term lies in ``[0, 1]``.
    """

    def __init__(self, name='smape'):
        """Create the metric, forwarding ``name`` to ``BaseMetric``."""
        super(SymmetricMeanAbsolutePercentageError, self).__init__(name)

    def call(self, output, target):
        """Return the sMAPE of ``output`` against ``target``.

        Elements where both ``output`` and ``target`` are exactly zero
        contribute 0 to the mean instead of the NaN that a raw 0/0 division
        would produce.

        Args:
            output: Predicted tensor.
            target: Ground-truth tensor.

        Returns:
            The mean symmetric relative error over all elements.
        """
        abs_error = torch.abs(output - target)
        scale = torch.abs(output) + torch.abs(target)
        # scale is zero only where output and target are both zero, in which
        # case abs_error is zero as well. Dividing by 1 there yields 0 rather
        # than NaN, and leaves every other element's value untouched.
        safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
        return torch.mean(abs_error / safe_scale)


class MeanAbsolutePercentageError(BaseMetric):
    """Mean absolute percentage error (MAPE).

    Computed as ``mean(|output - target| / (|target| + 1e-6))``. The result is
    a fraction; it is not multiplied by 100.
    """

    # The default name was 'smape', which collided with
    # SymmetricMeanAbsolutePercentageError; this metric is 'mape'.
    def __init__(self, name='mape'):
        """Create the metric, forwarding ``name`` to ``BaseMetric``."""
        super(MeanAbsolutePercentageError, self).__init__(name)

    def call(self, output, target):
        """Return the MAPE of ``output`` against ``target``.

        Args:
            output: Predicted tensor.
            target: Ground-truth tensor.

        Returns:
            The mean relative absolute error over all elements.
        """
        abs_error = torch.abs(output - target)
        # The small constant keeps the division finite where target is zero.
        return torch.mean(abs_error / (torch.abs(target) + 1e-6))
