import torch.nn.functional as F
from torch.nn import L1Loss as MAE
from torch.nn import MSELoss as MSE

# Plain criterion modules with reduction disabled: calling them yields the
# unreduced, element-wise loss tensor and leaves aggregation to the caller.
loss_MSE = MSE(reduction='none')
loss_MAE = MAE(reduction='none')


def _horizon_loss_MSE(y_pred, y_true):
    """Squared error averaged over dimension 1.

    Presumably dim 1 is the forecast-horizon axis (per the name) — confirm
    against callers. The result keeps every other dimension unreduced.
    """
    squared_err = F.mse_loss(y_pred, y_true, reduction='none')
    return squared_err.mean(dim=1)


def _horizon_loss_MAE(y_pred, y_true):
    """Absolute error averaged over dimension 1.

    Presumably dim 1 is the forecast-horizon axis (per the name) — confirm
    against callers. The result keeps every other dimension unreduced.
    """
    abs_err = F.l1_loss(y_pred, y_true, reduction='none')
    return abs_err.mean(dim=1)

def _horizon_loss_MSE_dro(y_pred, y_true):
    """Worst-case (maximum) squared error over dimension 1.

    Presumably dim 1 is the forecast-horizon axis (per the name) — confirm
    against callers. The result keeps every other dimension unreduced.
    """
    squared_err = F.mse_loss(y_pred, y_true, reduction='none')
    # max(dim).values (not amax) so that, on ties, the gradient flows to a
    # single index exactly as max(1)[0] does.
    return squared_err.max(dim=1).values

def _horizon_loss_MAE_dro(y_pred, y_true):
    """Worst-case (maximum) absolute error over dimension 1.

    Presumably dim 1 is the forecast-horizon axis (per the name) — confirm
    against callers. The result keeps every other dimension unreduced.
    """
    abs_err = F.l1_loss(y_pred, y_true, reduction='none')
    # max(dim).values (not amax) so that, on ties, the gradient flows to a
    # single index exactly as max(1)[0] does.
    return abs_err.max(dim=1).values


def get_criterions(args):
    """Select the MSE and MAE criteria matching the run configuration.

    Args:
        args: Configuration object. Only ``args.loss`` (checked for the
            substring ``'dro'``) and ``args.horizon`` are read.

    Returns:
        dict: ``{'mse': ..., 'mae': ...}``. Each value is a callable
        ``(y_pred, y_true)`` returning an unreduced loss tensor:

        - ``horizon > 1`` and ``'dro'`` in ``args.loss``: maximum error over
          dim 1 (worst case).
        - ``horizon > 1`` otherwise: mean error over dim 1.
        - ``horizon <= 1``: element-wise error (``reduction='none'``).
    """
    # Same evaluation order as before: the loss-name check comes first.
    use_dro = 'dro' in args.loss
    multi_step = args.horizon > 1

    if use_dro and multi_step:
        print('-> Minimizing worst-sample error')
        mse_loss, mae_loss = _horizon_loss_MSE_dro, _horizon_loss_MAE_dro
    elif multi_step:
        mse_loss, mae_loss = _horizon_loss_MSE, _horizon_loss_MAE
    else:
        # NOTE(review): with horizon <= 1 a 'dro' loss setting is silently
        # ignored and the plain element-wise losses are used — confirm this
        # is intended.
        mse_loss, mae_loss = loss_MSE, loss_MAE

    return {'mse': mse_loss, 'mae': mae_loss}
