import torch
import numpy as np


# Build a validity mask from null_val; positions holding the null value are
# treated as invalid samples.
def get_MASK(labels, null_val=np.nan):
    """Return a float mask of valid ``labels`` entries, rescaled to mean 1.

    An entry is invalid when it is NaN (if ``null_val`` is NaN) or when it
    equals ``null_val`` (otherwise). With a non-NaN ``null_val``, NaN labels
    compare unequal and therefore count as *valid*.

    Valid entries get the value ``1 / fraction_valid`` and invalid entries
    get 0. So for a ``loss`` tensor of the same shape as ``labels``,
    ``mean(loss * mask)`` equals the mean of ``loss`` over the valid entries.

    Args:
        labels: ground-truth tensor.
        null_val: sentinel marking missing labels (NaN by default).

    Returns:
        float32 tensor with the shape of ``labels``. If no entry is valid, the
        mask is all zeros (rather than the 0/0 = NaN the division produces).
    """
    if np.isnan(null_val):
        mask = ~torch.isnan(labels)
    else:
        mask = labels != null_val
    mask = mask.float()
    mask /= torch.mean(mask)
    # All-invalid input: the mean is 0 and every element became 0/0 = NaN.
    # Report that as "nothing valid" (all zeros) instead of leaking NaN.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    return mask


# LOSS: MSE (L2 loss). Converges smoothly when the loss is small, but squaring
# can over-penalise large errors.
def get_loss_MSE(preds, labels, null_val=np.nan):
    """Masked mean squared error.

    Args:
        preds: predictions; must broadcast against ``labels``.
        labels: ground truth; entries matching ``null_val`` are ignored.
        null_val: sentinel marking missing labels (NaN by default).

    Returns:
        Scalar tensor: the mean of ``(preds - labels) ** 2`` weighted by the
        mask from ``get_MASK``. Any remaining NaN terms count as 0.
    """
    mask = get_MASK(labels=labels, null_val=null_val)
    # NaN mask entries (no valid labels at all) -> 0.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    # Replace masked-out labels with a finite placeholder *before* the
    # arithmetic. Zeroing NaN losses afterwards only repairs the forward value.
    # In backward, the square's derivative at a NaN label is still 0 * NaN = NaN,
    # which would poison the gradient of ``preds``.
    safe_labels = torch.where(mask > 0, labels, torch.zeros_like(labels))
    loss = (preds - safe_labels) ** 2
    loss = loss * mask
    # Zero any NaN terms that remain (e.g. from non-finite predictions).
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


# LOSS: RMSE, the square root of the masked MSE, in the same units as the labels.
def get_loss_RMSE(preds, labels, null_val=np.nan):
    """Masked root-mean-square error.

    Arguments are forwarded unchanged to ``get_loss_MSE``. Note that the
    derivative of sqrt(x) is unbounded at x = 0, so backpropagating through an
    MSE of exactly 0 produces non-finite gradients.
    """
    mse_value = get_loss_MSE(preds=preds, labels=labels, null_val=null_val)
    rmse_value = torch.sqrt(mse_value)
    return rmse_value


# LOSS: MAE (L1 loss). Very simple and less sensitive to outliers than MSE,
# but not differentiable at zero error; use with care.
def get_loss_MAE(preds, labels, null_val=np.nan):
    """Masked mean absolute error.

    Args:
        preds: predictions; must broadcast against ``labels``.
        labels: ground truth; entries matching ``null_val`` are ignored.
        null_val: sentinel marking missing labels (NaN by default).

    Returns:
        Scalar tensor: the mean of ``|preds - labels|`` weighted by the mask
        from ``get_MASK``. Any remaining NaN terms count as 0.
    """
    mask = get_MASK(labels=labels, null_val=null_val)
    # NaN mask entries (no valid labels at all) -> 0.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    # Replace masked-out labels with a finite placeholder before the
    # arithmetic, so NaN labels never enter the autograd graph. The NaN
    # clean-up below only fixes the forward value, not the gradient.
    safe_labels = torch.where(mask > 0, labels, torch.zeros_like(labels))
    loss = torch.abs(preds - safe_labels)
    loss = loss * mask
    # Zero any NaN terms that remain (e.g. from non-finite predictions).
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


# LOSS: MAPE. Common in regression analysis and independent of the scale
# (units) of the target, but division by zero labels must be watched.
# It tends to favour models that under-predict.
def get_loss_MAPE(preds, labels, null_val=np.nan):
    """Masked mean absolute percentage error, in percent.

    Args:
        preds: predictions; must broadcast against ``labels``.
        labels: ground truth; entries matching ``null_val`` are ignored.
        null_val: sentinel marking missing labels (NaN by default). Zero
            labels that are *not* masked out still produce an infinite term.
            Pass ``null_val=0.0`` to exclude them.

    Returns:
        Scalar tensor: the mean of ``100 * |(preds - labels) / labels|``
        weighted by the mask from ``get_MASK``. Any remaining NaN terms count
        as 0.
    """
    mask = get_MASK(labels=labels, null_val=null_val)
    # NaN mask entries (no valid labels at all) -> 0.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    # Replace masked-out labels with 1 before dividing. A masked-out label is
    # NaN or equal to null_val (often 0), and the division's backward
    # (grad / label) would then give 0/0 or 0/NaN = NaN gradients even though
    # the forward value is zeroed below. 1 keeps the denominator finite and
    # non-zero.
    safe_labels = torch.where(mask > 0, labels, torch.ones_like(labels))
    loss = torch.abs((preds - safe_labels) / safe_labels) * 100
    loss = loss * mask
    # Zero any NaN terms that remain (e.g. from non-finite predictions).
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)



# LOSS: Huber. Errors above the threshold delta are penalised linearly
# (MAE-like); errors at or below it quadratically (MSE-like).
def get_loss_HUBER(preds, labels, threshold=1.0, null_val=np.nan):
    """Masked Huber loss.

    Per element, with ``e = |preds - labels|`` and ``d = threshold``:
    ``0.5 * e**2`` if ``e <= d``, else ``d * (e - 0.5 * d)``.

    Args:
        preds: predictions; must broadcast against ``labels``.
        labels: ground truth; entries matching ``null_val`` are ignored.
        threshold: switch point between the quadratic and linear regimes.
            It defaults to 1.0, the same default as ``torch.nn.HuberLoss``,
            so the function can be called like the other losses.
        null_val: sentinel marking missing labels (NaN by default).

    Returns:
        Scalar tensor: the mask-weighted mean of the per-element Huber loss.
        Any remaining NaN terms count as 0.
    """
    mask = get_MASK(labels=labels, null_val=null_val)
    # NaN mask entries (no valid labels at all) -> 0.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    # Replace masked-out labels with a finite placeholder before the
    # arithmetic. Otherwise the backward of ``error ** 2`` at NaN labels
    # computes 0 * NaN = NaN and poisons the gradient of ``preds``.
    safe_labels = torch.where(mask > 0, labels, torch.zeros_like(labels))
    error = torch.abs(preds - safe_labels)
    loss = torch.where(
        error <= threshold,
        0.5 * error ** 2,
        threshold * (error - 0.5 * threshold),
    )
    loss = loss * mask
    # Zero any NaN terms that remain (e.g. from non-finite predictions).
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return torch.mean(loss)


def get_loss_hyber(preds, labels, null_val=np.nan):
    """Hybrid loss: the equal-weight average of masked MSE and masked MAE.

    The name presumably abbreviates "hybrid". All arguments are forwarded
    unchanged to ``get_loss_MSE`` and ``get_loss_MAE``. The function returns
    ``0.5 * MSE + 0.5 * MAE`` as a scalar tensor.
    """
    shared = {"preds": preds, "labels": labels, "null_val": null_val}
    squared_term = get_loss_MSE(**shared)
    absolute_term = get_loss_MAE(**shared)
    return 0.5 * squared_term + 0.5 * absolute_term

# Lookup table from a loss-name string to the matching loss function.
# Note: get_loss_HUBER's signature includes a ``threshold`` parameter that the
# other functions do not have.
type2loss = dict(
    mse=get_loss_MSE,
    rmse=get_loss_RMSE,
    mae=get_loss_MAE,
    mape=get_loss_MAPE,
    huber=get_loss_HUBER,
    hyber=get_loss_hyber,
)

def get_loss(preds, labels, null_val=np.nan):
    """Compute the four standard metrics in one call.

    Args:
        preds: predictions; must broadcast against ``labels``.
        labels: ground truth.
        null_val: sentinel for missing labels, forwarded to every metric.
            The default NaN keeps the previous behaviour. Passing e.g. ``0.0``
            excludes zero labels, which also avoids MAPE's division by zero
            at those entries.

    Returns:
        Tuple ``(mse, rmse, mae, mape)`` of scalar tensors; ``mape`` is in
        percent.
    """
    mse = get_loss_MSE(preds, labels, null_val=null_val)
    rmse = get_loss_RMSE(preds, labels, null_val=null_val)
    mae = get_loss_MAE(preds, labels, null_val=null_val)
    mape = get_loss_MAPE(preds, labels, null_val=null_val)
    return mse, rmse, mae, mape
