import torch.nn as nn
import torch.nn.functional as F
import torch


def compute_loss(pred, true, cfg):
    """
    Compute loss and prediction score

    Args:
        pred (torch.tensor): Unnormalized prediction (logits for
            ``cross_entropy``, raw output for ``mse``).
        true (torch.tensor): Ground-truth targets. For multiclass
            classification these are class indices (``F.nll_loss`` requires
            an integer/long tensor); for binary or multilabel classification
            they are 0/1 labels; for ``mse`` they are real values.
        cfg: Configuration object. Reads ``cfg.loss.name`` (``"cross_entropy"``
            or ``"mse"``) and ``cfg.loss.reduction``, which is passed as the
            ``reduction`` argument of whichever loss is used.

    Returns: Tuple ``(loss, score)``. ``score`` is the log-softmax of ``pred``
        for multiclass, the sigmoid of ``pred`` for binary/multilabel, and
        ``pred`` itself (after squeezing) for ``mse``.

    Raises:
        ValueError: If ``cfg.loss.name`` is not a supported loss.
    """
    # Read the loss selection from one place so every branch (and the error
    # message) agrees on which config key is authoritative.
    loss_name = cfg.loss.name
    reduction = cfg.loss.reduction

    # squeeze(-1) only removes a trailing dimension of size 1,
    # e.g. (N, 1) -> (N,); (N, C) with C > 1 is left untouched.
    pred = pred.squeeze(-1) if pred.ndim > 1 else pred
    true = true.squeeze(-1) if true.ndim > 1 else true

    if loss_name == "cross_entropy":
        # multiclass: a score per class and a single label per sample
        if pred.ndim > 1 and true.ndim == 1:
            pred = F.log_softmax(pred, dim=-1)
            # Honour the configured reduction like the other losses do.
            return F.nll_loss(pred, true, reduction=reduction), pred
        # binary or multilabel: an independent sigmoid per output
        true = true.float()
        loss = F.binary_cross_entropy_with_logits(pred, true, reduction=reduction)
        return loss, torch.sigmoid(pred)

    if loss_name == "mse":
        true = true.float()
        return F.mse_loss(pred, true, reduction=reduction), pred

    raise ValueError("Loss func {} not supported".format(loss_name))
