# from data_utils.data_scaler import DataScaler
# from data_utils.get_dataset_utils import get_data_generator
# from imputation_methods.ImputationMethod import ImputationMethod
# from imputation_methods.regression_imputations.RegressorImputationWithErrorSampling import \
#     RegressorImputationWithErrorSampling
# from models.regressors.FullRegressorWithLinear import FullRegressorWithLinear
#
#
# class OracleErrorImputation(ImputationMethod):
#
#     def __init__(self, dataset_name : str, x_dim: int, z_dim: int, device, data_scaler: DataScaler, seed,
#                  saved_models_path, hidden_dims, dropout, batch_norm, lr, wd, figures_dir):
#         super().__init__()
#         self.data_generator = get_data_generator(dataset_name, x_dim, z_dim)
#         self.device = device
#         self.data_scaler = data_scaler
#         self.seed = seed
#         self.regressor = FullRegressorWithLinear(dataset_name, saved_models_path, x_dim,
#                                            z_dim,
#                                            hidden_dims,
#                                            dropout,
#                                            batch_norm, lr, wd, device,
#                                            figures_dir=figures_dir,
#                                            seed=seed)
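#         # TODO: the error sampler was never specified -- the assignment below has no
#         #       right-hand side, so this class cannot be uncommented as-is.
#         self.error_sampler =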
#         self.imputator = RegressorImputationWithErrorSampling(self.regressor, self.error_sampler)
#
#     def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
#             **kwargs):
#         self.imputator.fit(x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
#                            epochs=epochs, batch_size=batch_size, n_wait=n_wait, **kwargs)
#
#     def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
#         super().calibrate(x_cal, y_cal, z_cal, deleted_cal)
#         self.imputator.calibrate(x_cal, y_cal, z_cal, deleted_cal)
#
#     @property
#     def name(self):
#         return self.imputator.name
#
#     def predict(self, x, z):
#         return self.imputator.predict(x,z)
#
#
#
