import torch
import torch.nn as nn

import loss.loss as L

class LossWarper(nn.Module):
    """Config-driven aggregator that evaluates several loss terms and sums them.

    ``cfg`` is a mapping. Every key containing the substring ``"loss"`` names
    a callable in the ``loss.loss`` module (imported in this file as ``L``),
    and its value must provide:

    * ``"kwargs"`` -- mapping of extra keyword arguments for that callable;
    * ``"scale"``  -- weight applied to that term when building the total.

    Keys that do not contain ``"loss"`` are ignored.
    """

    def __init__(self, cfg) -> None:
        """Store the loss configuration; the module owns no parameters."""
        super().__init__()
        self.cfg = cfg

    def forward(self, x, y):
        """Evaluate every configured loss and return them with their weighted sum.

        Args:
            x: First argument forwarded to each loss callable. Must support
                ``x["embedding_pool"]`` whenever a loss whose name contains
                ``"orth"`` is configured.
            y: Second argument forwarded to each loss callable.

        Returns:
            dict: one entry per evaluated loss name holding the *unscaled*
            value returned by the loss callable, plus ``"all_loss"`` holding
            the sum of ``value * scale`` over the evaluated losses.  When no
            loss is evaluated, ``"all_loss"`` is the Python float ``0.0``.

        Raises:
            AttributeError: if a configured loss name is not found by
                ``get_loss_fn``.
            KeyError: if a loss entry lacks ``"kwargs"`` or ``"scale"``.
        """
        loss_info = {}
        total = 0.0

        for loss_name, loss_cfg in self.cfg.items():
            if "loss" not in loss_name:
                continue
            # "orth" losses are skipped while the embedding pool is empty.
            # Explicit length check: works for any sized container, including
            # ones that do not define an unambiguous truth value.
            if "orth" in loss_name and len(x["embedding_pool"]) == 0:
                continue

            # Resolve through get_loss_fn so there is a single lookup point.
            loss_fn = self.get_loss_fn(loss_name)
            value = loss_fn(x, y, **loss_cfg["kwargs"])
            loss_info[loss_name] = value

            # Accumulate out-of-place.  ``total += ...`` would become an
            # in-place tensor add after the first term, pinning the total to
            # that term's dtype/shape: later higher-precision terms would be
            # silently downcast, and terms that promote the dtype category or
            # broadcast to a larger shape would raise.
            total = total + value * loss_cfg["scale"]

        loss_info["all_loss"] = total
        return loss_info

    def get_loss_fn(self, loss_name):
        """Return the attribute called ``loss_name`` from the ``loss.loss`` module.

        Raises:
            AttributeError: if the module has no attribute with that name.
        """
        return getattr(L, loss_name)