import numpy as np


def RSE(pred, true):
    """Root relative squared error.

    Returns the root of the summed squared residuals ``true - pred`` divided
    by the root of the summed squared deviations of ``true`` from its global
    mean. Both sums run over every element of the inputs.
    """
    residual_norm = np.sqrt(np.sum((true - pred) ** 2))
    spread_norm = np.sqrt(np.sum((true - true.mean()) ** 2))
    return residual_norm / spread_norm


def CORR(pred, true):
    """Pearson correlation between ``pred`` and ``true`` along axis 0.

    Deviations are taken from the per-column mean (mean over axis 0). The
    correlation of each column is the summed product of the deviations
    divided by the product of the two deviation norms. For inputs with two or
    more dimensions the per-column correlations are averaged over the last
    axis; for 1-D inputs the single correlation is returned as a float.

    A column with zero variance in either input makes the denominator zero,
    so the corresponding entry is NaN (NumPy emits a runtime warning).
    """
    true_dev = true - true.mean(0)
    pred_dev = pred - pred.mean(0)
    covariance = (true_dev * pred_dev).sum(0)
    # The normaliser is the product of the two sums of squares, not the sum
    # of the element-wise product of squares.
    normaliser = np.sqrt((true_dev ** 2).sum(0) * (pred_dev ** 2).sum(0))
    corr = np.asarray(covariance / normaliser)
    if corr.ndim == 0:
        # 1-D inputs reduce to a scalar; there is no last axis to average.
        return float(corr)
    return corr.mean(-1)


def MAE(pred, true):
    """Mean absolute error: the mean of ``|true - pred|`` over all elements."""
    abs_err = np.abs(true - pred)
    return np.mean(abs_err)


def MSE(pred, true):
    """Mean squared error: the mean of ``(true - pred) ** 2`` over all elements."""
    sq_err = (true - pred) ** 2
    return np.mean(sq_err)


def RMSE(pred, true):
    """Root mean squared error: the square root of ``MSE(pred, true)``."""
    mean_sq_err = MSE(pred, true)
    return np.sqrt(mean_sq_err)


def MAPE(pred, true):
    """Mean absolute percentage error, returned as a fraction (not x100).

    Each residual ``true - pred`` is divided element-wise by ``true``, so any
    zero in ``true`` produces inf or NaN in the result.
    """
    rel_err = (true - pred) / true
    return np.mean(np.abs(rel_err))


def MSPE(pred, true):
    """Mean squared percentage error, returned as a fraction (not x100).

    Each residual ``true - pred`` is divided element-wise by ``true`` before
    squaring, so any zero in ``true`` produces inf or NaN in the result.
    """
    rel_err = (true - pred) / true
    return np.mean(np.square(rel_err))


def R2(pred, true):
    """Coefficient of determination, ``1 - SSE / SST``.

    SSE is the sum of squared residuals ``true - pred``; SST is the sum of
    squared deviations of ``true`` from its global mean. Both sums run over
    every element. A constant ``true`` makes SST zero and the result
    non-finite.
    """
    # Residual sum of squares (SSE).
    residual_ss = np.sum((true - pred) ** 2)
    # Total sum of squares (SST) around the mean of the observations.
    total_ss = np.sum((true - np.mean(true)) ** 2)
    return 1 - (residual_ss / total_ss)


def NSE(pred, true):
    """Nash-Sutcliffe Efficiency (NSE).

    One minus the ratio of the summed squared residuals ``true - pred`` to
    the summed squared deviations of ``true`` from its global mean.
    """
    error_ss = np.sum((true - pred) ** 2)
    variance_ss = np.sum((true - np.mean(true)) ** 2)
    return 1 - (error_ss / variance_ss)


def PBIAS(pred, obs):
    """Percent Bias (PBIAS).

    Computed as ``100 * sum(pred - obs) / sum(obs)`` over all elements, so a
    positive value means the predictions exceed the observations in total.

    Dividing the summed error by the summed observations (instead of
    averaging element-wise ratios) keeps the result finite when individual
    observations are zero.

    Returns NaN when the inputs are empty or the observations sum to zero.
    """
    pred = np.asarray(pred, dtype=float)
    obs = np.asarray(obs, dtype=float)
    if pred.size == 0:
        return np.nan
    obs_total = np.sum(obs)
    if obs_total == 0:
        # The bias is undefined without any observed volume.
        return np.nan
    return np.sum(pred - obs) / obs_total * 100


def KGE(pred, obs):
    """Kling-Gupta Efficiency (KGE).

    Combines three components into
    ``1 - sqrt((r - 1)^2 + (lam - 1)^2 + (gamma - 1)^2)``:

    * ``r``: linear correlation coefficient between ``pred`` and ``obs``
      (0.0 when either series is constant or the coefficient is NaN),
    * ``lam``: ratio of standard deviations ``std(pred) / std(obs)``
      (1.0 when ``std(obs)`` is zero),
    * ``gamma``: ratio of means ``mean(pred) / mean(obs)``
      (1.0 when ``mean(obs)`` is zero).

    Inputs are flattened first so that column vectors and higher-rank arrays
    are scored as one paired series; ``np.corrcoef`` would otherwise treat
    each row of a 2-D input as a separate variable.

    Returns NaN for empty input.
    """
    pred = np.asarray(pred, dtype=float).ravel()
    obs = np.asarray(obs, dtype=float).ravel()
    if pred.size == 0:
        return np.nan

    pred_std = np.std(pred)
    obs_std = np.std(obs)

    # Correlation coefficient.
    if pred_std == 0 or obs_std == 0:
        r = 0.0
    else:
        r = np.corrcoef(pred, obs)[0, 1]
        if np.isnan(r):
            r = 0.0

    # lam: ratio of standard deviations.
    lam = pred_std / obs_std if obs_std != 0 else 1.0

    # gamma: ratio of means.
    obs_mean = np.mean(obs)
    gamma = np.mean(pred) / obs_mean if obs_mean != 0 else 1.0

    return 1 - np.sqrt((r - 1) ** 2 + (lam - 1) ** 2 + (gamma - 1) ** 2)


def FLV(pred, obs):
    """Flow Duration Curve Low Flow Variability (FLV).

    Selects the lowest 30% of the observations (at least one sample), sums
    the absolute differences between the paired predictions and those
    observations, and expresses the sum as a percentage of the total of all
    observations.

    Inputs are flattened first so that column vectors and higher-rank arrays
    are ranked as one series; ``np.argsort`` would otherwise sort along the
    last axis only and pick the wrong samples.

    Returns NaN when the inputs are empty or the observations sum to zero.
    """
    pred = np.asarray(pred).ravel()
    obs = np.asarray(obs).ravel()
    if pred.size == 0:
        return np.nan
    obs_total = np.sum(obs)
    if obs_total == 0:
        return np.nan
    n_low = max(1, int(0.3 * obs.size))  # lowest 30%, at least one sample
    low_idx = np.argsort(obs)[:n_low]  # ascending order: low flows first
    return np.sum(np.abs(pred[low_idx] - obs[low_idx])) / obs_total * 100


def FHV(pred, obs):
    """Flow Duration Curve High Flow Variability (FHV).

    Selects the highest 2% of the observations (at least one sample), sums
    the absolute differences between the paired predictions and those
    observations, and expresses the sum as a percentage of the total of all
    observations.

    Inputs are flattened first so that column vectors and higher-rank arrays
    are ranked as one series; ``np.argsort`` would otherwise sort along the
    last axis only and pick the wrong samples.

    Returns NaN when the inputs are empty or the observations sum to zero.
    """
    pred = np.asarray(pred).ravel()
    obs = np.asarray(obs).ravel()
    if pred.size == 0:
        return np.nan
    obs_total = np.sum(obs)
    if obs_total == 0:
        return np.nan
    n_high = max(1, int(0.02 * obs.size))  # highest 2%, at least one sample
    high_idx = np.argsort(obs)[::-1][:n_high]  # descending order: high flows first
    return np.sum(np.abs(pred[high_idx] - obs[high_idx])) / obs_total * 100


def metric(pred, true):
    """Evaluate every summary metric on one prediction/target pair.

    Returns an 11-tuple in the order
    ``(mae, mse, rmse, mape, mspe, r2, nse, pbias, kge, flv, fhv)``.
    Each entry is the corresponding metric function called as
    ``fn(pred, true)``.
    """
    scorers = (MAE, MSE, RMSE, MAPE, MSPE, R2, NSE, PBIAS, KGE, FLV, FHV)
    return tuple(score(pred, true) for score in scorers)