from abc import ABC, abstractmethod
import torch

class BaseTaskStrategy(ABC):
    """Abstract interface for a per-task strategy.

    Concrete strategies must override both ``get_default_loss`` and
    ``process_outputs``; until they do, the class cannot be instantiated
    (``abc`` raises ``TypeError``).
    """

    @abstractmethod
    def get_default_loss(self, **kwargs) -> torch.nn.Module:
        """Build the loss module this task uses by default.

        Args:
            **kwargs: Implementation-specific options forwarded by the
                caller (NOTE(review): accepted keys are defined by each
                subclass — confirm against implementations).

        Returns:
            A ``torch.nn.Module`` instance.
        """

    @abstractmethod
    def process_outputs(self, logits: torch.Tensor) -> torch.Tensor:
        """Turn raw model ``logits`` into this task's output tensor.

        Args:
            logits: Raw model output. Presumably unnormalized scores; the
                expected shape depends on the subclass — TODO confirm.

        Returns:
            A ``torch.Tensor`` holding the post-processed outputs.
        """
