import abc

import torch


class ErrorSampler(abc.ABC):
    """Abstract interface for objects that sample errors for test inputs.

    Concrete subclasses must provide :meth:`sample_error` and the
    :attr:`name` property; the class cannot be instantiated until both are
    overridden. :meth:`fit` and :meth:`calibrate` are optional hooks whose
    base implementations do nothing.
    """

    def __init__(self):
        """Create the sampler. The base class stores no state."""

    def fit(
        self,
        x_train,
        z_train,
        y_train,
        errors_train,
        deleted_train,
        x_val,
        z_val,
        y_val,
        errors_val,
        deleted_val,
        **kwargs,
    ):
        """Optional training hook; the base implementation is a no-op.

        Args:
            x_train, z_train, y_train, errors_train, deleted_train:
                Training-split inputs. Presumably tensors aligned along the
                sample axis -- this class does not enforce it; confirm
                against the concrete subclasses.
            x_val, z_val, y_val, errors_val, deleted_val:
                The validation-split counterparts of the arguments above.
            **kwargs: Extra subclass-specific options; ignored here.

        Returns:
            None.
        """
        return None

    def calibrate(self, x_cal, z_cal, y_cal, errors_cal, deleted_cal, **kwargs):
        """Optional calibration hook; the base implementation is a no-op.

        Args:
            x_cal, z_cal, y_cal, errors_cal, deleted_cal:
                Calibration-split inputs (presumably laid out like the
                arguments of :meth:`fit` -- confirm against subclasses).
            **kwargs: Extra subclass-specific options; ignored here.

        Returns:
            None.
        """
        return None

    @abc.abstractmethod
    def sample_error(self, x_test: torch.Tensor, z_test: torch.Tensor) -> torch.Tensor:
        """Return sampled errors for the given test inputs.

        Must be overridden by every concrete subclass. The abstract body
        itself does nothing and returns ``None`` if reached via ``super()``.

        Args:
            x_test: Test-time tensor input.
            z_test: Second test-time tensor input, passed alongside
                ``x_test``.

        Returns:
            A tensor of sampled errors (shape is defined by the subclass).
        """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """String name of this sampler; must be overridden by subclasses."""


