import torch
import torch.nn as nn
import torch.nn.functional as F

from torch_geometric.graphgym.config import cfg
from torch_geometric.graphgym.register import register_loss


@register_loss('cross_entropy_with_ignore_index')
def cross_entropy_with_ignore_index(pred, true):
    """Cross-entropy / BCE loss that skips ignored targets.

    Only active when ``cfg.model.loss_fun`` is
    ``'cross_entropy_with_ignore_index'``; otherwise returns ``None``
    (presumably so GraphGym falls through to other registered losses —
    confirm against the dispatcher).

    Three cases are distinguished by tensor rank:

    * multiclass (``pred.ndim > 1`` and ``true.ndim == 1``): log-softmax +
      NLL loss; targets equal to ``cfg.dataset.ignore_index`` are excluded.
    * multilabel (``pred.ndim > 1`` otherwise): BCE-with-logits; examples
      whose target vector sums to zero are excluded.
    * binary (``pred.ndim == 1``): BCE-with-logits; entries equal to
      ``cfg.dataset.ignore_index`` are excluded.

    Both loss paths use ``cfg.model.size_average`` as the reduction.

    Args:
        pred: Raw logits.
        true: Targets (class indices in the multiclass case).

    Returns:
        ``(loss, prediction)`` where ``prediction`` is the log-softmax output
        in the multiclass case and ``sigmoid(pred)`` otherwise, or ``None``
        when a different loss function is configured.
    """
    if cfg.model.loss_fun != 'cross_entropy_with_ignore_index':
        return None

    ignore_index = cfg.dataset.ignore_index
    reduction = cfg.model.size_average

    if pred.ndim > 1 and true.ndim == 1:
        # Multiclass: honour the configured reduction here as well, so this
        # branch agrees with the BCE branch below (previously it always
        # averaged, even when size_average='sum').
        log_probs = F.log_softmax(pred, dim=-1)
        loss = F.nll_loss(log_probs, true, ignore_index=ignore_index,
                          reduction=reduction)
        return loss, log_probs

    # Binary or multilabel.
    if pred.ndim > 1:
        # Multilabel: ignore examples that have an all-zero target vector.
        mask = true.sum(dim=-1) == 0
    else:
        # Binary: ignore entries equal to the ignore index.
        mask = true == ignore_index
    keep = ~mask

    # Match the logits' dtype instead of forcing float32, so non-float32
    # models do not hit a dtype mismatch inside the loss.
    target = true.to(pred.dtype)
    bce_loss = nn.BCEWithLogitsLoss(reduction=reduction)
    return bce_loss(pred[keep], target[keep]), torch.sigmoid(pred)


@register_loss('mse_with_ignore_index')
def mse_with_ignore_index(pred, true):
    """Mean-squared-error loss that skips targets equal to the ignore index.

    Only active when ``cfg.model.loss_fun`` is ``'mse_with_ignore_index'``;
    otherwise returns ``None``. Entries of ``true`` equal to
    ``cfg.dataset.ignore_index`` (and the matching entries of ``pred``) are
    excluded. The reduction is ``cfg.model.size_average``.

    Args:
        pred: Predicted values; assumed to have the same shape as ``true``
            (TODO confirm — required for the element-wise mask to line up).
        true: Regression targets.

    Returns:
        ``(loss, pred)``, or ``None`` when a different loss function is
        configured.
    """
    if cfg.model.loss_fun != 'mse_with_ignore_index':
        return None

    ignore_index = cfg.dataset.ignore_index
    keep = ~(true == ignore_index)

    # Cast targets to the prediction's dtype rather than always float32;
    # mismatched dtypes (e.g. float64 pred vs float32 target) can fail in
    # the MSE backward pass.
    target = true.to(pred.dtype)
    mse_loss = nn.MSELoss(reduction=cfg.model.size_average)
    return mse_loss(pred[keep], target[keep]), pred
