from typing import Tuple

import torch

from multihead_noisyedm.loss import MultiHeadNoisyEDMLoss
from multihead_clsedm.model import MultiHeadCLSEDM

class MultiHeadNoisyEDM(MultiHeadCLSEDM):
    """``MultiHeadCLSEDM`` variant whose loss is ``MultiHeadNoisyEDMLoss``.

    Every keyword argument given to the constructor is forwarded unchanged
    to both the ``MultiHeadCLSEDM`` initializer and the
    ``MultiHeadNoisyEDMLoss`` constructor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Loss function, built from the same keyword arguments as the base model.
        self.loss = MultiHeadNoisyEDMLoss(**kwargs)

    def loss_func(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        P: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Evaluate ``self.loss`` on ``self.net`` with the given tensors.

        Args:
            x: First tensor forwarded to the loss.
            y: Second tensor forwarded to the loss.
            P: Third tensor forwarded to the loss. Its meaning is defined by
                ``MultiHeadNoisyEDMLoss`` (not visible here; TODO confirm).

        Returns:
            Whatever ``self.loss(self.net, x, y, P)`` returns. It is annotated
            as a pair of tensors.
        """
        # NOTE(review): ``self.net`` is not assigned in this class; presumably
        # it is set by ``MultiHeadCLSEDM.__init__``. Confirm.
        criterion = self.loss
        network = self.net
        return criterion(network, x, y, P)