from calibration_schemes.AbstractCalibration import Calibration
from calibration_schemes.TwoStagedConformalPrediction import TwoStagedCalibration
from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_scaler import DataScaler
from models.qr_models.OracleQuantileRegression import OracleQuantileRegression


class OracleTwoStagedCalibration(TwoStagedCalibration):
    """Two-staged calibration whose proxy model is an ``OracleQuantileRegression``.

    The only work this subclass does on top of ``TwoStagedCalibration`` is:
    constructing the oracle proxy quantile-regression model from the given
    dimensions, and overriding ``name``. All other arguments are forwarded
    unchanged to the parent constructor.
    """

    def __init__(self,
                 base_proxy_calibration: Calibration,
                 base_y_calibration: Calibration,
                 alpha: float,
                 dataset_name: str,
                 data_scaler: DataScaler,
                 data_masker: DataCorruptionMasker,
                 x_dim: int,
                 y_dim: int,
                 z_dim,
                 device,
                 seed):
        """Create the oracle proxy model and initialise the parent scheme.

        Args:
            base_proxy_calibration: forwarded to the parent constructor
                (first positional argument).
            base_y_calibration: forwarded to the parent constructor; its
                ``name`` is also used to build this scheme's ``name``.
            alpha: passed to both the oracle model and the parent
                (presumably the miscoverage level -- confirm in parent).
            dataset_name: passed to both the oracle model and the parent.
            data_scaler: passed to both the oracle model and the parent.
            data_masker: forwarded to the parent constructor only.
            x_dim, y_dim, z_dim: dimensions given to the oracle model only.
            device: keyword argument for the oracle model only.
            seed: keyword argument for the oracle model only.
        """
        # The oracle model takes the role of the learned proxy quantile
        # regressor that the parent would otherwise receive.
        oracle_proxy_model = OracleQuantileRegression(
            dataset_name,
            x_dim,
            y_dim,
            z_dim,
            alpha,
            data_scaler,
            device=device,
            seed=seed,
        )
        # Arguments are passed positionally; their order must match
        # TwoStagedCalibration.__init__.
        super().__init__(
            base_proxy_calibration,
            base_y_calibration,
            alpha,
            dataset_name,
            data_scaler,
            oracle_proxy_model,
            data_masker,
        )

    @property
    def name(self):
        """Scheme identifier: ``"oracle_w_"`` followed by ``base_y_calibration.name``."""
        y_scheme_name = self.base_y_calibration.name
        return f"oracle_w_{y_scheme_name}"
