from data_utils.datasets.dataset import Dataset
from data_utils.datasets.synthetic_dataset_generator import SyntheticDataGenerator
from data_utils.get_dataset_utils import get_data_generator
from error_sampler.ErrorSampler import ErrorSampler
from models.CVAE_GAN import CVAE_GAN, CVAE_GAN_MODE
from models.regressors.MeanRegressor import MeanRegressor


class OracleErrorSampler(ErrorSampler):
    """Error sampler that uses the dataset's data generator as an oracle.

    Instead of learning an error distribution, this sampler queries the
    data generator returned by ``get_data_generator`` for ``y`` given
    ``(x, z)``. It returns the residual between that value and the wrapped
    regressor's mean prediction.
    """

    def __init__(self, dataset_name: str, x_dim: int, z_dim: int, regressor: MeanRegressor):
        """Create the oracle sampler.

        Args:
            dataset_name: Name passed to ``get_data_generator`` to select the
                data generator used as the oracle.
            x_dim: Dimension of ``x``, forwarded to ``get_data_generator``.
            z_dim: Dimension of ``z``, forwarded to ``get_data_generator``.
            regressor: Model whose ``predict_mean(x, z)`` output is subtracted
                from the oracle's ``y``.
        """
        super().__init__()
        self.regressor = regressor
        self.data_generator = get_data_generator(dataset_name, x_dim, z_dim)

    def fit(self, x_train, z_train, y_train, errors_train, deleted_train, x_val, z_val, y_val, errors_val, deleted_val,
            **kwargs):
        """Do nothing.

        The oracle reads ``y`` directly from the data generator, so there is
        nothing to train. The signature is kept so this class can be used
        anywhere an ``ErrorSampler`` is fitted.
        """
        pass

    def sample_error(self, x_test, z_test):
        """Return the residual ``y - prediction`` for the given inputs.

        Args:
            x_test: Inputs passed to both the regressor and the data generator.
            z_test: Inputs passed to both the regressor and the data generator.

        Returns:
            ``get_y_given_x_z(x_test, z_test) - predict_mean(x_test, z_test)``,
            with both operands squeezed first.
        """
        # NOTE(review): .squeeze() drops singleton dims (e.g. an (n, 1) column
        # becomes (n,)); this assumes both outputs are array-like with matching
        # shapes after squeezing. For a single sample the result is 0-d.
        # TODO confirm against the regressor/generator return types.
        model_predictions = self.regressor.predict_mean(x_test, z_test).squeeze()
        y = self.data_generator.get_y_given_x_z(x_test, z_test).squeeze()
        # Sign convention: positive error means the model under-predicts.
        error = y - model_predictions
        return error

    @property
    def name(self) -> str:
        """Identifier string for this sampler."""
        # Plain literal: the original f-string had no placeholders (F541).
        return "oracle_error_sampler"
