import torch
import torch.nn as nn

class _RMSELoss(nn.Module):
    """Root-mean-squared-error loss: ``sqrt(nn.MSELoss(**kwargs)(prediction, target))``.

    Accepts the same keyword arguments as ``torch.nn.MSELoss`` (e.g.
    ``reduction``). With ``reduction="none"`` the square root is applied
    elementwise, which gives the per-element absolute error.

    Note: the derivative of ``sqrt`` is unbounded at 0, so an exactly-zero
    error can produce non-finite gradients.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.mse = nn.MSELoss(**kwargs)

    def forward(self, prediction, target):
        return torch.sqrt(self.mse(prediction, target))


class LossLoader:
    """Build a PyTorch loss module from a short, case-insensitive name.

    Supported names:
        ``"mse"``  -> ``torch.nn.MSELoss``
        ``"rmse"`` -> square root of ``torch.nn.MSELoss``
        ``"bce"``  -> ``torch.nn.BCEWithLogitsLoss`` (expects raw logits,
                      not probabilities)
        ``"l1"``   -> ``torch.nn.L1Loss``

    Args:
        loss_name: Name of the loss, matched case-insensitively.
        params: Keyword arguments forwarded to the loss constructor. ``None``
            (the default) means ``{"reduction": "mean"}``. The dict is copied,
            so neither the caller nor other instances see later mutations.
    """

    _DEFAULT_PARAMS = {"reduction": "mean"}

    # Dispatch table: name -> callable that builds the loss from **params.
    _LOSSES = {
        "mse": nn.MSELoss,
        "rmse": _RMSELoss,
        "bce": nn.BCEWithLogitsLoss,
        "l1": nn.L1Loss,
    }

    def __init__(self, loss_name="mse", params=None):
        self.loss_name = loss_name.lower()
        if params is None:
            params = self._DEFAULT_PARAMS
        # Always copy: a shared default dict (or an aliased caller dict) would
        # let one instance's mutation silently change every other one.
        self.params = dict(params) if params else {}

    def get_loss(self):
        """Instantiate and return the configured loss module.

        Returns:
            A freshly constructed ``nn.Module`` implementing the loss.

        Raises:
            ValueError: If ``loss_name`` is not one of the supported names.
        """
        try:
            loss_cls = self._LOSSES[self.loss_name]
        except KeyError:
            # Build the list from the table so the message never drifts
            # out of sync with the losses that are actually supported.
            available = ", ".join(self._LOSSES)
            raise ValueError(
                f"The Loss {self.loss_name} does not belong to the set of "
                f"available losses : {available}."
            ) from None
        return loss_cls(**self.params)

# Manual smoke test (uncomment to run); expected output: tensor(4.)
# loss_fn_tmp = LossLoader(loss_name="mse", params={'reduction' : "mean"}).get_loss()
# print(loss_fn_tmp(2*torch.ones(10), torch.zeros(10)))