from data_utils.data_scaler import DataScaler
from data_utils.get_dataset_utils import get_data_generator
from imputation_methods.ImputationMethod import ImputationMethod


class OracleImputation(ImputationMethod):
    """Imputation method that samples ``y`` from the dataset's own data generator.

    Instead of learning a model, ``predict`` maps its (scaled) inputs back to
    the original space, asks the dataset's data generator for ``y`` given
    ``x`` and ``z``, and returns that sample re-scaled with the same
    ``DataScaler``.
    """

    def __init__(self, dataset_name : str, x_dim: int, z_dim: int, device, data_scaler: DataScaler, seed):
        """Create the oracle for a specific dataset.

        Args:
            dataset_name: Passed to ``get_data_generator`` to select the generator.
            x_dim: Dimension of ``x``; forwarded to ``get_data_generator``.
            z_dim: Dimension of ``z``; forwarded to ``get_data_generator``.
            device: Stored as ``self.device``; not read by this class's own methods.
            data_scaler: Used to unscale ``x``/``z`` before sampling and to
                scale the sampled ``y`` afterwards.
            seed: Forwarded as ``seed=`` to the generator on every ``predict`` call.
        """
        super().__init__()
        self.data_generator = get_data_generator(dataset_name, x_dim, z_dim)
        self.device = device
        self.data_scaler = data_scaler
        self.seed = seed

    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Delegate calibration to the base class unchanged.

        The oracle fits nothing of its own. The override is kept explicitly,
        presumably to satisfy the ``ImputationMethod`` interface (confirm
        whether the base method is abstract before removing it).
        """
        super().calibrate(x_cal, y_cal, z_cal, deleted_cal)

    @property
    def name(self) -> str:
        """Identifier of this imputation method."""
        return "oracle"

    def predict(self, x, z):
        """Sample ``y`` given ``x`` and ``z`` from the true data generator.

        Args:
            x: Scaled ``x`` values; unscaled via ``data_scaler.unscale_x``.
            z: Scaled ``z`` values; unscaled via ``data_scaler.unscale_z``.

        Returns:
            The generator's sample of ``y``, passed through ``data_scaler.scale_y``.
        """
        # The generator operates on the original (unscaled) data space.
        unscaled_x = self.data_scaler.unscale_x(x)
        unscaled_z = self.data_scaler.unscale_z(z)
        # NOTE(review): the same ``self.seed`` is passed on every call, so
        # repeated calls presumably return identical draws for identical
        # inputs. Confirm against the generator's ``get_y_given_x_z``.
        sample = self.data_generator.get_y_given_x_z(unscaled_x, unscaled_z, seed=self.seed)
        # Return the sample in the same scaled space as the inputs.
        return self.data_scaler.scale_y(sample)



