
from typing import Tuple
import torch

from multihead_llpedm.loss import MultiHeadLLPEDMLoss
from multihead_clsedm.model import MultiHeadCLSEDM

class MultiHeadLLPEDM(MultiHeadCLSEDM):
    """Subclass of ``MultiHeadCLSEDM`` whose loss is a ``MultiHeadLLPEDMLoss``.

    Only construction of the loss and ``loss_func`` are defined here. All
    other behaviour is inherited from ``MultiHeadCLSEDM``.

    Every keyword argument is forwarded unchanged to both the base-class
    constructor and the ``MultiHeadLLPEDMLoss`` constructor.
    """

    def __init__(self, **kwargs):
        # Build the parent first. The loss is assigned afterwards, so it
        # replaces any ``loss`` attribute the parent constructor may have set.
        super().__init__(**kwargs)
        llp_criterion = MultiHeadLLPEDMLoss(**kwargs)
        self.loss = llp_criterion

    def loss_func(self, x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate the configured loss on one batch.

        Args:
            x: Tensor handed to the loss after ``self.net``. Presumably the
                model input batch; confirm against ``MultiHeadLLPEDMLoss``.
            y: Tensor handed to the loss as its third argument. Presumably
                targets or labels; confirm against ``MultiHeadLLPEDMLoss``.

        Returns:
            Whatever ``self.loss(self.net, x, y)`` returns. It is annotated
            as a pair of tensors.
        """
        # ``self.net`` is an attribute not set in this class. It is
        # presumably provided by ``MultiHeadCLSEDM``; verify in the base class.
        criterion = self.loss
        return criterion(self.net, x, y)