from .base_loss import BaseLoss
from occforecasting.registry import MODELS

@MODELS.register_module()
class VQVAEEmbedLoss(BaseLoss):

    def __init__(self, weight=1.0, input_dict=None, **kwargs):
        super().__init__(weight)

        if input_dict is None:
            self.input_dict = {
                'embed_loss': 'embed_loss'
            }
        else:
            self.input_dict = input_dict
        self.loss_func = self.embed_loss
    
    def embed_loss(self, embed_loss):
        return embed_loss