import numpy as np

from easy_tpp.utils.const import PredOutputIndex
from easy_tpp.utils.metrics import MetricsHelper


@MetricsHelper.register(name='rmse', direction=MetricsHelper.MINIMIZE, overwrite=False)
def rmse_metric_function(predictions, labels, **kwargs):
    """Compute rmse metrics of the time predictions.

    Args:
        predictions (np.array): model predictions.
        labels (np.array): ground truth.

    Returns:
        float: average rmse of the time predictions.
    """
    seq_mask = kwargs.get('seq_mask')
    pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
    label = labels[PredOutputIndex.TimePredIndex][seq_mask]

    pred = np.reshape(pred, [-1])
    label = np.reshape(label, [-1])
    return np.sqrt(np.mean((pred - label) ** 2))


@MetricsHelper.register(name='mrae', direction=MetricsHelper.MINIMIZE, overwrite=False)
def mrae_metric_function(predictions, labels, **kwargs):
    """Compute median relative absolute error metrics of the time predictions.
    Formula: abs(dt_hat - dt) / dt.
    Note dt is always positive, so this computation is always valid.

    Args:
        predictions (np.array): model predictions.
        label (np.array): ground truth.

    Returns:
        float:  median relative absolute error of the time predictions.
    """
    seq_mask = kwargs.get('seq_mask')
    pred = predictions[PredOutputIndex.TimePredIndex][seq_mask]
    label = labels[PredOutputIndex.TimePredIndex][seq_mask]

    pred = np.reshape(pred, [-1])
    label = np.reshape(label, [-1])
    assert np.all(label > 0)
    return np.median(abs(pred - label)/(label + np.finfo(np.float32).eps))


@MetricsHelper.register(name='acc', direction=MetricsHelper.MAXIMIZE, overwrite=False)
def acc_metric_function(predictions, labels, **kwargs):
    """Compute accuracy ratio metrics of the type predictions.

    Args:
        predictions (np.array): model predictions.
        labels (np.array): ground truth.

    Returns:
        float: accuracy ratio of the type predictions.
    """
    seq_mask = kwargs.get('seq_mask')
    pred = predictions[PredOutputIndex.TypePredIndex][seq_mask]
    label = labels[PredOutputIndex.TypePredIndex][seq_mask]
    pred = np.reshape(pred, [-1])
    label = np.reshape(label, [-1])
    return np.mean(pred == label)
