import torch

from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.TwoStagedConformalPrediction import TwoStagedCalibration
from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_scaler import DataScaler
from models.qr_models.OracleQuantileRegression import OracleQuantileRegression


class BadTwoStagedCalibration(TwoStagedCalibration):
    """Two-staged calibration that is fed deliberately perturbed mask probabilities.

    Identical to :class:`TwoStagedCalibration` built with an
    :class:`OracleQuantileRegression` proxy model, except that the corruption
    probabilities returned by the data masker are randomly perturbed (and
    wrapped back into the unit interval) before use. Judging by ``name``
    ("bad_w_..."), this presumably serves as a "bad weights" baseline —
    confirm against the experiment that uses it.
    """

    def __init__(self, base_proxy_calibration: Calibration, base_y_calibration: Calibration, alpha: float,
                 dataset_name: str, data_scaler: DataScaler, data_masker: DataCorruptionMasker, x_dim: int, y_dim: int, z_dim,
                 device, seed, *, noise_magnitude: float = 0.3):
        """Build the oracle proxy model and initialise the parent calibration.

        Args:
            base_proxy_calibration: Forwarded to the parent calibration.
            base_y_calibration: Forwarded to the parent calibration; its ``name``
                is also used to build this scheme's ``name``.
            alpha: Forwarded to both the oracle model and the parent
                (presumably the miscoverage level — confirm).
            dataset_name: Forwarded to both the oracle model and the parent.
            data_scaler: Forwarded to the oracle model and the parent; used in
                ``get_mask_probabilities`` to unscale ``x`` and ``z``.
            data_masker: Forwarded to the parent; queried for the corruption
                probabilities that get perturbed.
            x_dim, y_dim, z_dim: Dimensions forwarded to the oracle model.
            device: Forwarded to the oracle model.
            seed: Forwarded to the oracle model only. It does NOT seed the noise
                drawn in ``get_mask_probabilities``, which uses torch's global RNG.
            noise_magnitude: Half-width of the uniform noise added to the mask
                probabilities. Defaults to 0.3, the value previously hard-coded.
        """
        # Assigned before the parent initialiser so the attribute exists even if
        # that initialiser already calls get_mask_probabilities.
        self.noise_magnitude = noise_magnitude
        proxy_qr_model = OracleQuantileRegression(dataset_name, x_dim, y_dim, z_dim, alpha, data_scaler, device=device,
                                                  seed=seed)
        super().__init__(base_proxy_calibration, base_y_calibration, alpha,
                         dataset_name, data_scaler, proxy_qr_model,
                         data_masker)

    def get_mask_probabilities(self, scaled_x, scaled_z):
        """Return the masker's corruption probabilities after a random, wrapped perturbation.

        Args:
            scaled_x: Scaled ``x`` inputs; unscaled with ``self.data_scaler`` first.
            scaled_z: Scaled ``z`` inputs; unscaled with ``self.data_scaler`` first.

        Returns:
            A tensor shaped like the masker's output, with values wrapped into
            [0, 1) (up to floating-point rounding). Assumes the masker returns a
            floating-point tensor, as ``torch.rand_like`` requires — TODO confirm.
        """
        unscaled_x = self.data_scaler.unscale_x(scaled_x)
        unscaled_z = self.data_scaler.unscale_z(scaled_z)
        mask_probabilities = self.data_masker.get_corruption_probabilities(unscaled_x, unscaled_z)
        # rand_like is uniform on [0, 1), so this is uniform on
        # [-noise_magnitude, noise_magnitude); drawn from torch's global RNG.
        random_noise = (torch.rand_like(mask_probabilities) - 0.5) * 2 * self.noise_magnitude
        # Tensor `%` is torch.remainder, whose result takes the sign of the divisor,
        # so values pushed past 1 wrap to near 0 and values below 0 wrap to near 1.
        # The result can be exactly 0.0 — NOTE(review): confirm no caller divides by it.
        noised_mask_probabilities = (mask_probabilities + random_noise) % 1
        return noised_mask_probabilities

    @property
    def name(self):
        """Identifier of this scheme: ``"bad_w_"`` followed by the base y-calibration's name."""
        return f"bad_w_{self.base_y_calibration.name}"
