from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.PrivilegedConformalPrediction import PrivilegedConformalPrediction
from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_scaler import DataScaler


class BadFullyWeightedCalibration(PrivilegedConformalPrediction):
    """Privileged conformal prediction that uses shrunken mask probabilities.

    Identical to ``PrivilegedConformalPrediction`` except for
    ``get_mask_probabilities``. That method returns the data masker's
    corruption probabilities divided by a fixed factor instead of returning
    them unchanged. NOTE(review): the ``bad_`` name prefix suggests this is
    a deliberately mis-specified weighting baseline. Confirm against the
    experiments that use it.
    """

    # Every corruption probability reported by the masker is divided by this.
    _PROBABILITY_DIVISOR = 4

    def __init__(self, base_y_calibration: Calibration, alpha: float, dataset_name: str, data_scaler: DataScaler,
                 data_masker: DataCorruptionMasker):
        """Forward every argument positionally to ``PrivilegedConformalPrediction``."""
        super().__init__(base_y_calibration, alpha, dataset_name, data_scaler, data_masker)

    def get_mask_probabilities(self, scaled_x, scaled_z):
        """Return the masker's corruption probabilities divided by ``_PROBABILITY_DIVISOR``.

        Args:
            scaled_x: features in scaled space; mapped back via ``data_scaler.unscale_x``.
            scaled_z: privileged information in scaled space; mapped back via
                ``data_scaler.unscale_z``.

        Returns:
            ``data_masker.get_corruption_probabilities(x, z) / 4``, computed on
            the unscaled inputs. Presumably an array-like that supports
            division; the type is whatever the masker returns.
        """
        scaler = self.data_scaler
        # The masker works on the original (unscaled) data, so undo the scaling first.
        original_x = scaler.unscale_x(scaled_x)
        original_z = scaler.unscale_z(scaled_z)
        true_probabilities = self.data_masker.get_corruption_probabilities(original_x, original_z)
        return true_probabilities / self._PROBABILITY_DIVISOR

    @property
    def name(self):
        """Identifier: ``"bad_w2_"`` followed by the base calibration's name."""
        base_name = self.base_y_calibration.name
        return f"bad_w2_{base_name}"
