# pylint: disable=missing-function-docstring
import math
import numpy as np
import numpy.ma as ma
from .quantile import QuantileForecasts


def rmse(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Return the root mean squared error between forecasts and targets over all entries."""
    residuals = y_pred - y_true
    return np.sqrt(np.mean(residuals ** 2))


def abs_error_sum(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """Return the sum of absolute differences between forecasts and targets."""
    absolute_errors = np.abs(y_pred - y_true)
    return np.sum(absolute_errors)


def abs_target_sum(y_true: np.ndarray) -> float:
    """Return the sum of the absolute target values."""
    magnitudes = np.abs(y_true)
    return magnitudes.sum()


def abs_target_mean(y_true: np.ndarray) -> float:
    """Return the mean of the absolute target values."""
    magnitudes = np.abs(y_true)
    return magnitudes.mean()


def naive_error(y_past: ma.masked_array, seasonality: int) -> np.ndarray:
    """
    Compute the per-series mean absolute error of the seasonal naive forecast.

    The seasonal naive forecast predicts each value with the value ``seasonality`` steps
    earlier. Its in-sample mean absolute error is commonly used as the scaling term of the
    mean absolute scaled error (MASE).

    Args:
        y_past: Historical values indexed as [series, time]. Masked entries are excluded
            from the mean (any difference involving a masked value is itself masked).
        seasonality: Seasonal lag in time steps; must be at least 1.

    Returns:
        One error per series (the mean is taken over axis 1), returned as a plain array
        with the mask stripped.

    Raises:
        ValueError: If ``seasonality`` is smaller than 1.
    """
    if seasonality < 1:
        # A lag of 0 produces mismatched slices, and a negative lag silently compares the
        # wrong windows (e.g. the last value against the first), so reject both up front.
        raise ValueError(f"seasonality must be at least 1, got {seasonality}")
    error = np.abs(y_past[:, seasonality:] - y_past[:, :-seasonality]).mean(1)
    return ma.getdata(error)


def mase(y_pred: np.ndarray, y_true: np.ndarray, error: np.ndarray) -> float:
    """
    Compute the mean absolute scaled error (MASE), averaged over all series.

    Args:
        y_pred: Point forecasts indexed as [series, time].
        y_true: Ground-truth values with the same layout as ``y_pred``.
        error: Scaling term broadcast against the per-series mean absolute errors,
            typically one value per series (e.g. the seasonal naive error of the history).

    Returns:
        The mean over series of ``mean_t |y_pred - y_true| / error``. A zero scaling term
        makes that series' ratio ``inf`` (or ``nan`` for 0/0), which propagates to the mean.
    """
    per_series_mae = np.abs(y_pred - y_true).mean(1)
    scaled = per_series_mae / error
    return scaled.mean()


def smape(y_pred: np.ndarray, y_true: np.ndarray) -> float:
    """
    Compute the symmetric mean absolute percentage error (sMAPE), averaged over all series.

    Each entry contributes ``|y_true - y_pred| / ((|y_true| + |y_pred|) / 2)``. Entries are
    averaged over time (axis 1) first and then over series.

    Args:
        y_pred: Point forecasts indexed as [series, time].
        y_true: Ground-truth values with the same layout as ``y_pred``.

    Returns:
        The sMAPE as a scalar. Entries where both target and forecast are 0 contribute 0.
    """
    num = np.abs(y_true - y_pred)
    denom = (np.abs(y_true) + np.abs(y_pred)) / 2
    # The denominator is 0 only when both target and forecast are 0, in which case the
    # numerator is 0 too. Replacing it with inf turns that 0/0 into 0 instead of NaN; this
    # is a convention, not a mathematical identity. `denom` is a fresh array, so the
    # in-place write does not touch the inputs.
    denom[denom == 0] = math.inf
    return np.mean(num / denom, axis=1).mean()


def mean_weighted_quantile_loss(y_pred: QuantileForecasts, y_true: np.ndarray) -> float:
    """
    Compute the weighted quantile loss, averaged over all quantile levels.

    For every quantile level q, the quantile (pinball) loss
    ``2 * |(y_hat - y) * (1{y <= y_hat} - q)|`` is summed over all series and time steps
    and normalized by the sum of absolute target values. The per-quantile results are then
    averaged.

    Args:
        y_pred: Quantile forecasts. ``y_pred.values`` presumably has shape
            [num_time_series, num_quantiles, num_timesteps], and ``y_pred.quantiles`` holds
            one entry per level convertible via ``float`` (TODO confirm against
            ``QuantileForecasts``).
        y_true: Ground-truth values, assumed to have shape [num_time_series, num_timesteps].

    Returns:
        The mean of the weighted quantile losses across quantile levels.
    """
    quantiles = np.array([float(q) for q in y_pred.quantiles])
    # Insert a size-1 quantile axis and let broadcasting align it with the predictions,
    # rather than materializing one copy of the targets per quantile level.
    y_true_b = y_true[:, None]
    quantile_losses = 2 * np.sum(
        np.abs(
            (y_pred.values - y_true_b) * ((y_true_b <= y_pred.values) - quantiles[:, None])
        ),
        axis=-1,
    )  # shape [num_time_series, num_quantiles]
    denom = np.sum(np.abs(y_true))  # scalar
    weighted_losses = quantile_losses.sum(0) / denom  # shape [num_quantiles]
    return weighted_losses.mean()
