import abc
from abc import ABC

import torch

from models.ClassificationModel import ClassProbabilities
from models.abstract_models.NetworkLearningModel import NetworkLearningModel


class ClassificationProbabilityEstimatorWithProxy(NetworkLearningModel, ABC):
    """Abstract base for network models that estimate class probabilities.

    Concrete subclasses must implement :meth:`estimate_probabilities`.
    Construction is delegated, unchanged, to the next class in the MRO
    (``NetworkLearningModel``).

    NOTE(review): what "proxy" refers to is not visible in this class —
    presumably ``y`` (or some surrogate signal) serves as the proxy; confirm
    against the concrete subclasses.
    """

    def __init__(self, dataset_name: str, saved_models_path: str, figures_dir: str, seed: int):
        # Pure pass-through: the four arguments are forwarded positionally,
        # in their original order, with no extra state added here.
        super().__init__(
            dataset_name,
            saved_models_path,
            figures_dir,
            seed,
        )

    @abc.abstractmethod
    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Produce a ``ClassProbabilities`` estimate for the given tensors.

        Args:
            x: Input tensor. Shape and dtype are defined by subclasses —
                TODO confirm the expected layout.
            y: Second tensor supplied alongside ``x``; presumably labels or a
                proxy target — verify against the implementations.

        Returns:
            The ``ClassProbabilities`` object built by the concrete subclass.
            This abstract body itself does nothing and returns ``None``.
        """
        ...
