import torch

class LossAggregator:
    """Evaluate one loss object per task and collect the per-task results."""

    def __init__(self, losses, ignore_val=-1):
        # Sequence of loss objects, each exposing ``compute(outputs, targets)``.
        self.losses = losses
        # NOTE(review): stored for API symmetry but never read by this class.
        self.ignore_val = ignore_val

    def compute(self, outputs, targets, **kwargs):
        """Compute the loss(es) of ``outputs`` against ``targets``.

        With exactly one loss, ``outputs`` and ``targets`` are passed through
        unchanged together with ``kwargs``, and that loss's result is returned
        as-is.

        With several losses, loss ``i`` receives ``outputs[i]`` and
        ``targets[i]``. The results are stacked with ``torch.stack`` into one
        tensor holding one entry per loss. Entries of ``outputs``/``targets``
        beyond ``len(self.losses)`` are ignored, and ``kwargs`` are not
        forwarded on this path. An empty ``losses`` sequence makes
        ``torch.stack`` raise.
        """
        if len(self.losses) == 1:
            return self.losses[0].compute(outputs, targets, **kwargs)
        # kwargs are deliberately not forwarded here: losses whose compute()
        # accepts only (outputs, targets) would raise TypeError otherwise.
        task_losses = [
            loss.compute(outputs[i], targets[i])
            for i, loss in enumerate(self.losses)
        ]
        return torch.stack(task_losses)

class MaskedBCEWithLogitsLoss:
    """Binary cross-entropy on logits that skips entries labelled ``ignore_val``.

    For targets with two or more dims, a position is kept only if every
    target along dim 1 differs from ``ignore_val``. For 1-D targets, each
    entry is masked on its own. The loss is the mean over all kept elements.
    """

    def __init__(self, ignore_val=-1):
        # Target value marking an entry to exclude from the loss.
        self.ignore_val = ignore_val
        # Unreduced criterion so the mean is taken over kept elements only.
        self.criterion = torch.nn.BCEWithLogitsLoss(reduction="none")

    def compute(self, outputs, targets):
        """Return the mean BCE-with-logits over non-ignored entries.

        Args:
            outputs: Logits, indexable by the boolean mask built from
                ``targets`` (same leading shape as ``targets``).
            targets: Binary targets; entries equal to ``ignore_val`` are
                excluded (row-wise along dim 1 when ``targets`` has >= 2 dims).

        Returns:
            A scalar tensor. If nothing survives the mask, the result is 0
            (still attached to ``outputs``' autograd graph) rather than NaN.
        """
        valid = targets != self.ignore_val
        if valid.dim() > 1:
            # Drop a row as soon as any of its labels is missing.
            valid = valid.all(1)

        preds = outputs[valid]
        labels = targets[valid].float()
        if preds.numel() == 0:
            # torch.mean over zero elements is NaN and would poison the total
            # loss; the sum of the empty selection is 0 and keeps the graph.
            return preds.sum()
        return torch.mean(self.criterion(preds, labels))


class CrossEntropyLoss:
    """Wrapper around ``torch.nn.CrossEntropyLoss`` with mean reduction."""

    def __init__(self, ignore_val=-1, label_smoothing=0):
        # NOTE(review): ignore_val is stored but not forwarded to the
        # criterion, so PyTorch's default ignore_index applies. Confirm
        # whether ignored targets are meant to use this sentinel.
        self.ignore_val = ignore_val
        self.criterion = torch.nn.CrossEntropyLoss(
            reduction="mean", label_smoothing=label_smoothing
        )

    def compute(self, outputs, targets):
        """Return the mean cross-entropy of ``outputs`` against ``targets``.

        Singleton dimensions of ``targets`` are squeezed, e.g. ``(N, 1)``
        class-index targets become ``(N,)``. ``squeeze()`` also removes a
        batch (or spatial) dimension of size 1, which leaves the target with
        too few dims for ``outputs``. In that case the index-target shape
        ``outputs.shape[:1] + outputs.shape[2:]`` is restored.
        """
        squeezed = targets.squeeze()
        if outputs.dim() > 1 and squeezed.dim() < outputs.dim() - 1:
            # Batch of one: squeeze() collapsed (1, 1) -> (), which the
            # criterion rejects for (1, C) logits; put the lost dims back.
            squeezed = squeezed.reshape(outputs.shape[:1] + outputs.shape[2:])
        return self.criterion(outputs, squeezed)
