import torch
import torch.nn.functional as F

class SegCELoss(torch.nn.Module):
    """Cross-entropy loss for dense (per-pixel / per-voxel) classification.

    ``forward`` compares the spatial size of the logits (``pred.shape[2:]``)
    with that of the target (``target.shape[1:]``). If they differ, the
    logits are resized to the target's spatial size with linear
    interpolation before the loss is computed, so the labels themselves are
    never resampled. The interpolation mode follows the number of spatial
    dimensions: ``linear`` (1), ``bilinear`` (2) or ``trilinear`` (3).

    Args:
        ignore_index: Target value that does not contribute to the loss.
            Defaults to 255.
        reduction: Forwarded to ``F.cross_entropy`` (``"mean"``, ``"sum"``
            or ``"none"``). Defaults to ``"mean"``.
        align_corners: Forwarded to ``F.interpolate``. Defaults to False.
    """

    # Linear-family interpolation mode keyed by number of spatial dims.
    _MODES = {1: "linear", 2: "bilinear", 3: "trilinear"}

    def __init__(self, ignore_index=255, reduction="mean", align_corners=False):
        super().__init__()
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.align_corners = align_corners

    def extra_repr(self):
        return (
            f"ignore_index={self.ignore_index}, "
            f"reduction={self.reduction!r}, "
            f"align_corners={self.align_corners}"
        )

    def forward(self, pred, target):
        """Return the cross-entropy between ``pred`` and ``target``.

        Args:
            pred: Logits; presumably shaped ``(N, C, *spatial)``. The code
                treats ``pred.shape[2:]`` as the spatial size.
            target: Class-index labels; presumably shaped ``(N, *spatial)``.
                The code treats ``target.shape[1:]`` as the spatial size.

        Returns:
            The loss as produced by ``F.cross_entropy`` with this module's
            ``ignore_index`` and ``reduction``.

        Raises:
            ValueError: If the spatial sizes differ and resizing is
                impossible, i.e. the spatial ranks differ or are not 1, 2
                or 3.
        """
        target_shape = tuple(target.shape[1:])
        pred_shape = tuple(pred.shape[2:])
        if pred_shape != target_shape:
            mode = self._MODES.get(len(target_shape))
            if mode is None or len(pred_shape) != len(target_shape):
                raise ValueError(
                    f"cannot resize logits with spatial shape {pred_shape} "
                    f"to target spatial shape {target_shape}"
                )
            # Resize the logits, not the labels: interpolating integer class
            # ids would create invalid in-between classes.
            pred = F.interpolate(
                pred,
                size=target_shape,
                mode=mode,
                align_corners=self.align_corners,
            )
        return F.cross_entropy(
            pred,
            target,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )
