import numpy as np

# Public API: the metric functions exported by ``from <this module> import *``.
__all__ = ['mse', 'rmse', 'mll', 'mae', 'mre', 'smll', 'smse', 'srmse', 'r2']


def mre(mean, data):
    """Mean relative error.

    The absolute errors are summed along axis 0 and divided by the summed
    absolute reference values along axis 0. The resulting ratio(s) are then
    averaged.

    Args:
        mean (tensor): Mean of prediction.
        data (tensor): Reference data.

    Returns:
        tensor: Mean relative error.
    """
    # Built-in ``abs`` dispatches to ``__abs__``. NumPy arrays have no
    # ``.abs()`` method, so calling the method form would raise
    # AttributeError on them.
    abs_error = abs(mean - data).sum(0)
    abs_scale = abs(data).sum(0)
    return (abs_error / abs_scale).mean()


def mse(mean, data):
    """Mean squared error.

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: Average of the squared residuals.
    """
    residuals = mean - data
    squared = residuals ** 2
    return squared.mean()


def smse(mean, data):
    """Standardised mean squared error.

    The prediction's mean squared error is divided by that of the trivial
    predictor that always outputs the mean of ``data``.

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: Ratio of model MSE to baseline MSE.
    """
    model_error = ((mean - data) ** 2).mean()
    baseline_error = ((data.mean() - data) ** 2).mean()
    return model_error / baseline_error


def rmse(mean, data):
    """Root mean squared error.

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: Square root of the mean squared residual.
    """
    mean_squared = ((mean - data) ** 2).mean()
    return mean_squared ** .5


def srmse(mean, data):
    """Standardised root mean squared error.

    The prediction's RMSE is divided by the RMSE of the trivial predictor
    that always outputs the mean of ``data``.

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: Ratio of model RMSE to baseline RMSE.
    """
    model_rmse = ((mean - data) ** 2).mean() ** .5
    baseline_rmse = ((data.mean() - data) ** 2).mean() ** .5
    return model_rmse / baseline_rmse


def mae(mean, data):
    """Mean absolute error.

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: Average of the absolute residuals.
    """
    abs_residuals = np.abs(mean - data)
    return abs_residuals.mean()


def mll(mean, variance, data):
    """Mean log loss.

    This is the average Gaussian negative log-density of ``data`` under
    independent normals with the given means and variances.

    Args:
        mean (tensor): Predicted mean.
        variance (tensor): Predicted variance.
        data (tensor): Observed reference values.

    Returns:
        tensor: Average negative log-density.
    """
    normaliser = 0.5 * np.log(2 * np.pi * variance)
    fit = 0.5 * (mean - data) ** 2 / variance
    return (normaliser + fit).mean()


def smll(mean, variance, data):
    """Standardised mean log loss.

    The mean log loss of the prediction is reduced by that of a baseline
    Gaussian whose mean and (population, ``ddof=0``) variance are those of
    ``data``.

    Args:
        mean (tensor): Predicted mean.
        variance (tensor): Predicted variance.
        data (tensor): Observed reference values.

    Returns:
        tensor: Model mean log loss minus baseline mean log loss.
    """
    def _mean_nll(mu, var):
        # Same expression as ``mll``, evaluated against ``data``.
        return (0.5 * np.log(2 * np.pi * var) +
                0.5 * (mu - data) ** 2 / var).mean()

    model_loss = _mean_nll(mean, variance)
    baseline_loss = _mean_nll(data.mean(), data.var(ddof=0))
    return model_loss - baseline_loss


def r2(mean, data):
    """R-squared (coefficient of determination).

    Args:
        mean (tensor): Predicted mean.
        data (tensor): Observed reference values.

    Returns:
        tensor: One minus the ratio of residual mean square to the
        population (``ddof=0``) variance of ``data``.
    """
    residual_ms = ((data - mean) ** 2).mean()
    total_var = data.var(ddof=0)
    return 1 - residual_ms / total_var
