# global max_val1, max_val2, threshold
import abc
import dataclasses
from abc import ABC
from typing import Optional

import numpy as np
import torch

from data_utils.data_corruption.corruption_type import CorruptionType
from data_utils.data_corruption.covariates_dimension_reducer import CovariatesDimensionReducer, ZMeanReducer
from utils.utils import get_seed, set_seeds


@dataclasses.dataclass
class DataCorruptionInfo:
    """Parameters that define a corruption-probability curve over a reduced covariate ``z``.

    The maskers in this module fill these values with quantiles of the (shifted)
    ``z`` computed on the full data set, then set ``power``, ``val_min`` and
    ``val_max`` after construction.

    The fields are declared as real dataclass fields (previously the class had a
    hand-written ``__init__`` and no fields, so the generated ``__eq__`` treated
    every instance as equal and ``__repr__`` showed nothing). Positional and
    keyword construction is unchanged.
    """

    max_val0: float
    max_val1: float  # used as the scale of z when computing base values
    max_val2: float
    threshold: float
    threshold2: float
    threshold3: float
    power: Optional[float]  # callers pass None and fill it in after fitting
    extreme_max_val: float
    min_val: float
    # Range of the base values used for rescaling; assigned after construction
    # (0-dim tensors in this module). Excluded from equality because comparing
    # tensors with == does not yield a plain bool.
    val_min: Optional[torch.Tensor] = dataclasses.field(default=None, init=False, compare=False)
    val_max: Optional[torch.Tensor] = dataclasses.field(default=None, init=False, compare=False)


class DataCorruptionMasker(ABC):
    """Abstract base for objects that decide which samples get corrupted.

    Subclasses implement :meth:`get_corruption_probabilities`; the default
    :meth:`get_corruption_mask` turns those probabilities into a boolean mask by
    comparing them against uniform random draws.
    """

    @abc.abstractmethod
    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return the per-sample probability of being corrupted."""

    def get_corruption_mask(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor,
                            seed: Optional[int] = 0) -> torch.Tensor:
        """Sample a boolean corruption mask from the per-sample probabilities.

        Args:
            unscaled_x: covariates passed through to ``get_corruption_probabilities``.
            unscaled_z: auxiliary covariates passed through likewise.
            seed: if not None, the global seed is set to this value for the draw and
                the previous seed is restored afterwards; if None, the current RNG
                state is used as-is.

        Returns:
            A bool tensor, True where ``rand < probability`` (squeezed probabilities).
        """
        probability_to_delete = self.get_corruption_probabilities(unscaled_x, unscaled_z).squeeze()
        if seed is None:
            return torch.rand_like(probability_to_delete) < probability_to_delete
        curr_seed = get_seed()
        set_seeds(seed)
        try:
            return torch.rand_like(probability_to_delete) < probability_to_delete
        finally:
            # Restore the caller's global seed even if sampling raises.
            set_seeds(curr_seed)


class DummyDataCorruptionMasker(DataCorruptionMasker):
    """Masker that never corrupts: every sample gets probability 0."""

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        # One zero per row of z, placed on the same device as x.
        n_samples = len(unscaled_z)
        return torch.zeros(n_samples, device=unscaled_x.device)


class OracleDataCorruptionMasker(DataCorruptionMasker):
    """Masker that looks up probabilities from a reference set by nearest neighbour.

    The reference covariates (flattened ``x`` concatenated with ``z``) are stored
    at construction. For each query sample, the reference row with the smallest
    mean absolute difference is found and its entry of ``mask_probability`` is
    returned.
    """

    def __init__(self, unscaled_x, unscaled_z, mask_probability):
        self.unscaled_x = unscaled_x
        self.unscaled_z = unscaled_z
        self.mask_probability = mask_probability
        # Use the same stacking as at query time (including lifting a 1-D z to a
        # column); previously a 1-D z made torch.cat fail here.
        self.covariates = self._stack_covariates(unscaled_x, unscaled_z)

    @staticmethod
    def _stack_covariates(unscaled_x, unscaled_z):
        """Flatten x beyond dim 1, lift a 1-D z to a column, and concatenate on the last dim."""
        if len(unscaled_z.shape) == 1:
            unscaled_z = unscaled_z.unsqueeze(-1)
        if len(unscaled_x.shape) > 2:
            unscaled_x = torch.flatten(unscaled_x, start_dim=1)
        return torch.cat([unscaled_x, unscaled_z], dim=-1)

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return, per query sample, the probability of its nearest reference sample.

        Note: moves ``self.covariates`` to the device of the query covariates.
        """
        test_covariates = self._stack_covariates(unscaled_x, unscaled_z)
        self.covariates = self.covariates.to(test_covariates.device)
        mask_probabilities = []
        # Row-by-row search keeps memory at O(n_reference) instead of an
        # (n_query, n_reference) distance matrix.
        for covariate in test_covariates:
            distances = (covariate.unsqueeze(0) - self.covariates).abs().mean(dim=-1)
            sample_idx = torch.argmin(distances)
            mask_probabilities.append(self.mask_probability[sample_idx].item())
        return torch.Tensor(mask_probabilities).to(unscaled_x.device)


class DefaultDataCorruptionMasker(DataCorruptionMasker):
    """Masker that targets samples with a large reduced covariate score.

    ``covariates_dimension_reducer`` maps ``(x, z)`` to a per-sample score ``z``,
    which is shifted by ``-min_val``. Samples above ``threshold`` (the 77% quantile
    on the full data) get a base value ``1 - exp(-2.5 * z / max_val1)`` rescaled
    into ``[0.7, 0.9]``; samples below it get 0. The base values are raised to a
    power fitted so that the mean probability on the full data is as close as
    possible to ``marginal_masking_ratio``.
    """

    def __init__(self, dataset_name: str, covariates_dimension_reducer: CovariatesDimensionReducer,
                 unscaled_full_x: torch.Tensor,
                 unscaled_full_z: torch.Tensor,
                 marginal_masking_ratio: float = 0.2):
        self.covariates_dimension_reducer = covariates_dimension_reducer
        self.dataset_name = dataset_name
        self.unscaled_full_z = unscaled_full_z.clone()
        self.data_masking_info = self.__compute_data_corruption_info(unscaled_full_x, unscaled_full_z,
                                                                     marginal_masking_ratio)
        self.marginal_masking_ratio = marginal_masking_ratio

    def get_group(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return an int tensor: 1 where the corruption probability exceeds 1e-3, else 0."""
        corruption_probabilities = self.get_corruption_probabilities(unscaled_x, unscaled_z)
        return (corruption_probabilities > 1e-3).to(device=unscaled_x.device, dtype=torch.int)

    def compute_vals_from_parameters(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor,
                                     data_corruption_info: DataCorruptionInfo):
        """Return un-rescaled base values: ``1 - exp(-2.5 * z / max_val1)`` above the threshold, 0 below.

        Inputs are cloned before reduction so the caller's tensors are never modified.
        """
        unscaled_z = unscaled_z.clone()
        unscaled_x = unscaled_x.clone()
        z = self.covariates_dimension_reducer(unscaled_x, unscaled_z)
        z -= data_corruption_info.min_val
        vals = z / data_corruption_info.max_val1 * 2.5
        vals[vals < 0] = 0
        # For vals >= 0, 1 - exp(-vals) is already in [0, 1), so no further clamp is needed.
        vals = 1 - torch.exp(-vals)
        # NOTE(review): samples with z exactly equal to the threshold are neither
        # zeroed here nor rescaled later (which uses z > threshold) — confirm intent.
        vals[z < data_corruption_info.threshold] = 0
        return vals

    @staticmethod
    def _rescale_high_vals(vals, high_val_idx, val_min, val_max):
        """In place, map ``vals[high_val_idx]`` from ``[val_min, val_max]`` onto ``[0.7, 0.9]``."""
        vals[high_val_idx] = (vals[high_val_idx] - val_min) / (val_max - val_min) * 0.2 + 0.7

    @staticmethod
    def _apply_extreme_values(probability_to_delete, z, extreme_max_val):
        """Override probabilities for ``z > extreme_max_val`` with ``0.95 + 0.05 * (1 - extreme_max_val / z)``."""
        extreme_idx = z > extreme_max_val
        if extreme_idx.int().sum():
            print(f"warning! got {extreme_idx.int().sum()} indexes with extreme_max_val")
        z_for_extreme_vals = z / extreme_max_val
        z_for_extreme_vals[z_for_extreme_vals <= 1] = 1
        extreme_idx_new_value = 0.05 * (1 - 1 / z_for_extreme_vals) + 0.95
        probability_to_delete[extreme_idx] = extreme_idx_new_value[extreme_idx]
        return probability_to_delete

    @staticmethod
    def _fit_power(vals, marginal_masking_ratio):
        """Return ``(power, deviation)`` minimising ``|mean(vals ** power) - ratio|`` over a 0.01 grid in (0, 10)."""
        powers = torch.arange(0.01, 10, 0.01).to(vals.device)
        # (n, 1) ** (P,) broadcasts to the same (n, P) grid as an explicit repeat,
        # without materialising the repeated copy of vals.
        deviations = (torch.pow(vals.unsqueeze(1), powers).mean(dim=0) - marginal_masking_ratio).abs()
        power_idx = deviations.argmin().item()
        return powers[power_idx].item(), deviations[power_idx].item()

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return the per-sample corruption probability under the fitted parameters."""
        info = self.data_masking_info
        vals = self.compute_vals_from_parameters(unscaled_x, unscaled_z, info)
        z = self.covariates_dimension_reducer(unscaled_x.clone(), unscaled_z.clone())
        if info.min_val < 0:
            z = z - info.min_val
        high_val_idx = z > info.threshold
        self._rescale_high_vals(vals, high_val_idx, info.val_min, info.val_max)
        probability_to_delete = vals ** info.power
        return self._apply_extreme_values(probability_to_delete, z, info.extreme_max_val)

    def __compute_data_corruption_info(self, unscaled_full_x: torch.Tensor, unscaled_full_z: torch.Tensor,
                                       marginal_masking_ratio: float) -> DataCorruptionInfo:
        """Fit quantile thresholds, the rescaling range and the power on the full data."""
        unscaled_full_z = unscaled_full_z.clone()
        unscaled_full_x = unscaled_full_x.clone()
        z = self.covariates_dimension_reducer(unscaled_full_x, unscaled_full_z)
        # Shift so that the 5% quantile is non-negative (only shifts when it is negative).
        z_min = min(torch.quantile(z, q=0.05).item(), 0)
        if z_min < 0:
            z = z - z_min
        extreme_max_val = torch.quantile(z, q=0.9999).item() * 1.3
        max_val0 = torch.quantile(z, q=0.995).item()
        max_val1 = torch.quantile(z, q=0.95).item()
        max_val2 = torch.quantile(z, q=0.9).item()
        if abs(max_val1 - max_val2) < 0.001:
            print(f"warning: abs(max_val1 - max_val2) is too small: {abs(max_val1 - max_val2)}")

        threshold1 = torch.quantile(z, q=0.77).item()
        threshold2 = torch.quantile(z, q=0.6).item()
        threshold3 = torch.quantile(z, q=0.3).item()
        data_masking_info = DataCorruptionInfo(max_val0, max_val1, max_val2, threshold1, threshold2, threshold3, None,
                                               extreme_max_val, z_min)

        vals = self.compute_vals_from_parameters(unscaled_full_x, unscaled_full_z, data_masking_info)
        high_val_idx = z > threshold1
        data_masking_info.val_max = vals[high_val_idx].max()
        data_masking_info.val_min = vals[high_val_idx].min()
        self._rescale_high_vals(vals, high_val_idx, data_masking_info.val_min, data_masking_info.val_max)

        power, masking_ratio_deviation = self._fit_power(vals, marginal_masking_ratio)
        data_masking_info.power = power
        if masking_ratio_deviation > 0.001:
            print(f"warning: the masking ratio deviation is too large: {np.round(masking_ratio_deviation, 4)}")

        return data_masking_info



class OvercoverageDataCorruptionMasker(DataCorruptionMasker):
    """Masker that targets samples with a small reduced covariate score.

    It reuses the parameters fitted by a ``DefaultDataCorruptionMasker``. Samples
    with shifted ``z`` below ``threshold3`` (the 30% quantile on the full data) get
    base values rescaled into ``[0.7, 0.9]``; samples above it get 0. The base
    values are raised to a power fitted so that the mean probability on the full
    data is as close as possible to ``marginal_masking_ratio``.
    """

    def __init__(self, dataset_name: str, covariates_dimension_reducer: CovariatesDimensionReducer,
                 unscaled_full_x: torch.Tensor,
                 unscaled_full_z: torch.Tensor,
                 marginal_masking_ratio: float = 0.2):
        self.covariates_dimension_reducer = covariates_dimension_reducer
        self.dataset_name = dataset_name
        self.unscaled_full_z = unscaled_full_z.clone()
        self.marginal_masking_ratio = marginal_masking_ratio
        self.default_data_corruption_masker = DefaultDataCorruptionMasker(dataset_name, covariates_dimension_reducer,
                                                                          unscaled_full_x,
                                                                          unscaled_full_z,
                                                                          marginal_masking_ratio)
        data_masking_info = self.default_data_corruption_masker.data_masking_info
        vals = self.compute_vals_from_parameters(unscaled_full_x, unscaled_full_z, data_masking_info)
        # Reduce clones and shift out of place: an in-place `z -=` on the reducer's
        # output could modify the caller's tensors if the reducer returns a view of them.
        z = self.covariates_dimension_reducer(unscaled_full_x.clone(), unscaled_full_z.clone())
        if data_masking_info.min_val < 0:
            z = z - data_masking_info.min_val
        low_val_idx = z < data_masking_info.threshold3
        self.val_max = vals[low_val_idx].max()
        self.val_min = vals[low_val_idx].min()
        vals[low_val_idx] = (vals[low_val_idx] - self.val_min) / (self.val_max - self.val_min) * 0.2 + 0.7

        # (n, 1) ** (P,) broadcasts to the (n, P) grid without a repeated copy of vals.
        powers = torch.arange(0.01, 10, 0.01).to(vals.device)
        deviations = (torch.pow(vals.unsqueeze(1), powers).mean(dim=0) - marginal_masking_ratio).abs()
        power_idx = deviations.argmin().item()
        self.power = powers[power_idx].item()
        masking_ratio_deviation = deviations[power_idx].item()
        if masking_ratio_deviation > 0.001:
            print(f"warning: the masking ratio deviation is too large: {np.round(masking_ratio_deviation, 4)}")

    def compute_vals_from_parameters(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor,
                                     data_corruption_info: DataCorruptionInfo):
        """Return un-rescaled base values: ``1 - exp(-2.5 * z / max_val1)`` below ``threshold3``, 0 above.

        Inputs are cloned before reduction so the caller's tensors are never modified.
        """
        unscaled_z = unscaled_z.clone()
        unscaled_x = unscaled_x.clone()
        z = self.covariates_dimension_reducer(unscaled_x, unscaled_z)
        z -= data_corruption_info.min_val
        vals = z / data_corruption_info.max_val1 * 2.5
        vals[vals < 0] = 0
        # For vals >= 0, 1 - exp(-vals) is already in [0, 1), so no further clamp is needed.
        vals = 1 - torch.exp(-vals)
        # NOTE(review): z exactly equal to threshold3 is neither zeroed here nor
        # rescaled later (which uses z < threshold3) — confirm intent.
        vals[z > data_corruption_info.threshold3] = 0
        return vals

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return the per-sample corruption probability under the fitted parameters."""
        data_masking_info = self.default_data_corruption_masker.data_masking_info
        vals = self.compute_vals_from_parameters(unscaled_x, unscaled_z, data_masking_info)
        z = self.covariates_dimension_reducer(unscaled_x.clone(), unscaled_z.clone())
        if data_masking_info.min_val < 0:
            z = z - data_masking_info.min_val
        low_val_idx = z < data_masking_info.threshold3
        vals[low_val_idx] = (vals[low_val_idx] - self.val_min) / (self.val_max - self.val_min) * 0.2 + 0.7

        probability_to_delete = vals ** self.power
        # NOTE(review): this override of very *high* z is shared with the default
        # masker even though this masker targets low z — confirm it is intended.
        extreme_idx = z > data_masking_info.extreme_max_val
        if extreme_idx.int().sum():
            print(f"warning! got {extreme_idx.int().sum()} indexes with extreme_max_val")
        z_for_extreme_vals = z / data_masking_info.extreme_max_val
        z_for_extreme_vals[z_for_extreme_vals <= 1] = 1
        extreme_idx_new_value = 0.05 * (1 - 1 / z_for_extreme_vals) + 0.95
        probability_to_delete[extreme_idx] = extreme_idx_new_value[extreme_idx]
        return probability_to_delete



class ComplexDataCorruptionMasker(DataCorruptionMasker):
    """Masker whose probability curve in the reduced score ``z`` is non-monotone.

    Samples with shifted ``z`` above ``threshold`` (the median on the full data)
    are candidates. Among them, samples where an oscillating auxiliary score
    ``vals2`` is below 1.2 get a pre-rescaling value of 0. Candidate values are
    rescaled into ``[0.05, 0.85]`` and raised to a power fitted so that the mean
    probability on the full data is as close as possible to ``marginal_masking_ratio``.
    """

    def __init__(self, dataset_name: str, covariates_dimension_reducer: CovariatesDimensionReducer,
                 unscaled_full_x: torch.Tensor,
                 unscaled_full_z: torch.Tensor,
                 marginal_masking_ratio: float = 0.2):
        self.covariates_dimension_reducer = covariates_dimension_reducer
        self.dataset_name = dataset_name
        self.unscaled_full_z = unscaled_full_z.clone()
        self.data_masking_info = self.__compute_data_corruption_info(unscaled_full_x, unscaled_full_z,
                                                                     marginal_masking_ratio)
        self.marginal_masking_ratio = marginal_masking_ratio

    def compute_vals_from_parameters(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor,
                                     data_corruption_info: DataCorruptionInfo):
        """Return un-rescaled base values ``1 - exp(-v)``.

        Here ``v = 2.5 * z / max_val1`` (clamped at 0). ``v`` is zeroed below the
        threshold, and above it wherever ``vals2 < 1.2``. Inputs are cloned before
        reduction.
        """
        unscaled_z = unscaled_z.clone()
        unscaled_x = unscaled_x.clone()
        z = self.covariates_dimension_reducer(unscaled_x, unscaled_z)
        z -= data_corruption_info.min_val
        vals = z / data_corruption_info.max_val1 * 2.5
        vals[vals < 0] = 0
        # Oscillating auxiliary score of z; only used to zero out part of the high-z region.
        vals2 = (torch.arctan(0.3*torch.sqrt(6*torch.sin(z) ** 2)) ** (1/3) - 0.8*torch.tanh(torch.cos(z ** 4))) / ((torch.sigmoid(z / 2)+2) / 2) + 0.5 + torch.sin(z**2 / 5) * torch.cos(z**4 / 8)
        t2 = 1.2
        vals[z < data_corruption_info.threshold] = 0
        vals[(vals2 < t2) & (z > data_corruption_info.threshold)] = 0
        vals = 1 - torch.exp(-vals)
        return vals

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        """Return the per-sample corruption probability under the fitted parameters."""
        info = self.data_masking_info
        vals = self.compute_vals_from_parameters(unscaled_x, unscaled_z, info)
        z = self.covariates_dimension_reducer(unscaled_x.clone(), unscaled_z.clone())
        z = z - info.min_val

        high_val_idx = z > info.threshold
        vals[high_val_idx] = (vals[high_val_idx] - info.val_min) / (info.val_max - info.val_min) * 0.8 + 0.05

        probability_to_delete = vals ** info.power
        # Samples far beyond the fitted range get a probability in (0.95, 1).
        extreme_idx = z > info.extreme_max_val
        if extreme_idx.int().sum():
            print(f"warning! got {extreme_idx.int().sum()} indexes with extreme_max_val")
        z_for_extreme_vals = z / info.extreme_max_val
        z_for_extreme_vals[z_for_extreme_vals <= 1] = 1
        extreme_idx_new_value = 0.05 * (1 - 1 / z_for_extreme_vals) + 0.95
        probability_to_delete[extreme_idx] = extreme_idx_new_value[extreme_idx]
        return probability_to_delete

    def __compute_data_corruption_info(self, unscaled_full_x: torch.Tensor, unscaled_full_z: torch.Tensor,
                                       marginal_masking_ratio: float) -> DataCorruptionInfo:
        """Fit quantile thresholds, the rescaling range and the power on the full data."""
        unscaled_full_z = unscaled_full_z.clone()
        unscaled_full_x = unscaled_full_x.clone()
        z = self.covariates_dimension_reducer(unscaled_full_x, unscaled_full_z)
        z_min = min(torch.quantile(z, q=0.05).item(), 0)
        z = z - z_min
        extreme_max_val = torch.quantile(z, q=0.9999).item() * 1.3
        max_val0 = torch.quantile(z, q=0.995).item()
        max_val1 = torch.quantile(z, q=0.95).item()
        max_val2 = torch.quantile(z, q=0.9).item()
        if abs(max_val1 - max_val2) < 0.001:
            print(f"warning: abs(max_val1 - max_val2) is too small: {abs(max_val1 - max_val2)}")

        threshold1 = torch.quantile(z, q=0.5).item()
        threshold2 = torch.quantile(z, q=0.65).item()
        threshold3 = torch.quantile(z, q=0.3).item()
        data_masking_info = DataCorruptionInfo(max_val0, max_val1, max_val2, threshold1, threshold2, threshold3, None,
                                               extreme_max_val, z_min)

        vals = self.compute_vals_from_parameters(unscaled_full_x, unscaled_full_z, data_masking_info)
        # threshold1 is data_masking_info.threshold, so one index serves both the
        # range estimate and the rescale.
        high_val_idx = z > threshold1
        data_masking_info.val_max = vals[high_val_idx].max()
        data_masking_info.val_min = vals[high_val_idx].min()
        vals[high_val_idx] = (vals[high_val_idx] - data_masking_info.val_min) / (
                data_masking_info.val_max - data_masking_info.val_min) * 0.8 + 0.05

        # (n, 1) ** (P,) broadcasts to the (n, P) grid without a repeated copy of vals.
        powers = torch.arange(0.01, 10, 0.01).to(vals.device)
        deviations = (torch.pow(vals.unsqueeze(1), powers).mean(dim=0) - marginal_masking_ratio).abs()
        power_idx = deviations.argmin().item()
        data_masking_info.power = powers[power_idx].item()
        masking_ratio_deviation = deviations[power_idx].item()
        if masking_ratio_deviation > 0.001:
            print(f"warning: the masking ratio deviation is too large: {np.round(masking_ratio_deviation, 4)}")

        return data_masking_info




class ResponseDataCorruptionMasker(DataCorruptionMasker):
    """Masker used for response corruptions.

    Per-sample probabilities come unchanged from the wrapped base masker; the
    mask is drawn by the inherited ``get_corruption_mask``.
    """

    def __init__(self, base_corruption_masker: DataCorruptionMasker):
        self.base_corruption_masker = base_corruption_masker

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        delegate = self.base_corruption_masker
        return delegate.get_corruption_probabilities(unscaled_x, unscaled_z)


class CovariateDataCorruptionMask(DataCorruptionMasker):
    """Masker that corrupts the features most correlated with the response.

    The 20% of x-columns with the largest absolute Pearson correlation to ``y``
    are selected once at construction. At masking time, a sample-level mask from
    the base masker is combined with this feature selection into an
    ``(n_samples, n_features)`` boolean mask.
    """

    def __init__(self,
                 unscaled_full_x: torch.Tensor,
                 unscaled_full_y: torch.Tensor,
                 base_corruption_masker: DataCorruptionMasker):
        super().__init__()
        self.base_corruption_masker = base_corruption_masker
        if len(unscaled_full_y.shape) == 2:
            unscaled_full_y = unscaled_full_y.mean(dim=-1)
        if unscaled_full_x.shape[1] < 5:
            raise Exception("cannot apply corruptions to feature vectors with dimension less than 5")
        # np.corrcoef cannot consume CUDA tensors or tensors that require grad, so
        # convert to host NumPy arrays first.
        x_np = self._to_numpy(unscaled_full_x)
        y_np = self._to_numpy(unscaled_full_y)
        x_y_correlation = [abs(np.corrcoef(x_np[:, i], y_np)[0, 1]) for i in range(x_np.shape[1])]
        # Constant columns yield NaN correlation; treat them as uncorrelated.
        x_y_correlation = [0 if np.isnan(corr) else corr for corr in x_y_correlation]
        n_features_to_mask = int(0.2 * len(x_y_correlation))
        self.highest_correlation_x_idx = np.argsort(x_y_correlation)[::-1][:n_features_to_mask]
        assert x_y_correlation[self.highest_correlation_x_idx[0]] == np.max(x_y_correlation)

    @staticmethod
    def _to_numpy(values):
        """Return ``values`` as a NumPy array, detaching and moving torch tensors to the CPU."""
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
        return self.base_corruption_masker.get_corruption_probabilities(unscaled_x, unscaled_z)

    def get_corruption_mask(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor, seed: int = 0):
        """Return a bool mask, True at (sample, feature) where both the sample and the feature are selected."""
        one_dimensional_mask = self.base_corruption_masker.get_corruption_mask(unscaled_x, unscaled_z, seed).squeeze()
        feature_idx_mask = torch.zeros(unscaled_x.shape[1], dtype=torch.bool, device=unscaled_x.device)
        feature_idx_mask[self.highest_correlation_x_idx.tolist()] = True
        # (n, 1) & (1, d) broadcasts to the (n, d) mask without materialising repeats.
        return one_dimensional_mask.unsqueeze(1) & feature_idx_mask.unsqueeze(0)


# class PCPFailCorruptionMasker(DataCorruptionMasker):
#
#     def __init__(self, dataset_name: str, covariates_dimension_reducer: CovariatesDimensionReducer,
#                  unscaled_full_x: torch.Tensor,
#                  unscaled_full_z: torch.Tensor,
#                  marginal_masking_ratio: float = 0.2):
#         self.covariates_dimension_reducer = covariates_dimension_reducer
#         self.dataset_name = dataset_name
#         self.unscaled_full_z = unscaled_full_z.clone()
#         self.marginal_masking_ratio = marginal_masking_ratio
#         self.default_data_corruption_masker = DefaultDataCorruptionMasker(dataset_name, covariates_dimension_reducer,
#                                                                           unscaled_full_x,
#                                                                           unscaled_full_z,
#                                                                           marginal_masking_ratio)
#
#
#     def get_corruption_probabilities(self, unscaled_x: torch.Tensor, unscaled_z: torch.Tensor):
#         data_masking_info = self.default_data_corruption_masker.data_masking_info
#         vals = self.default_data_corruption_masker.compute_vals_from_parameters(unscaled_x, unscaled_z, data_masking_info)
#         z = self.covariates_dimension_reducer(unscaled_x.clone(), unscaled_z.clone())
#         if data_masking_info.min_val < 0:
#             z = z - data_masking_info.min_val
#         high_val_idx = (z > data_masking_info.threshold)
#
#         vals[high_val_idx] = (vals[high_val_idx] - data_masking_info.val_min) / (
#                 data_masking_info.val_max - data_masking_info.val_min)
#         vals[high_val_idx] = vals[high_val_idx] * 0.2 + 0.7
#
#         probability_to_delete = vals ** data_masking_info.power
#         extreme_idx = z > data_masking_info.extreme_max_val
#         if extreme_idx.int().sum():
#             print(f"warning! got {extreme_idx.int().sum()} indexes with extreme_max_val")
#         z_for_extreme_vals = z / data_masking_info.extreme_max_val
#         z_for_extreme_vals[z_for_extreme_vals <= 1] = 1
#         extreme_idx_new_value = 0.05 * (1 - 1 / z_for_extreme_vals) + 0.95
#         probability_to_delete[extreme_idx] = extreme_idx_new_value[extreme_idx]
#
#         return probability_to_delete

class DataCorruptionIndicatorFactory:
    """Builds the masker matching a dataset name and corruption type."""

    @staticmethod
    def get_corruption_masker(dataset_name: str, corruption_type: CorruptionType,
                              unscaled_full_x: torch.Tensor, unscaled_full_z: torch.Tensor,
                              unscaled_full_y: torch.Tensor,
                              covariates_reducer: CovariatesDimensionReducer = None) -> DataCorruptionMasker:
        """Create the masker for ``dataset_name`` / ``corruption_type``.

        The base (sample-level) masker is chosen by substring of the dataset
        name: 'overcoverage', then 'pcp_fail', else the default masker. It is then
        wrapped for covariate (X) or response (Y) corruptions.

        Raises:
            Exception: if the corruption type is neither an X nor a Y type.
        """
        reducer = ZMeanReducer() if covariates_reducer is None else covariates_reducer

        if 'overcoverage' in dataset_name:
            base = OvercoverageDataCorruptionMasker(dataset_name, reducer, unscaled_full_x, unscaled_full_z)
        elif 'pcp_fail' in dataset_name:
            base = ComplexDataCorruptionMasker(dataset_name, reducer, unscaled_full_x, unscaled_full_z,
                                               marginal_masking_ratio=0.2)
        else:
            base = DefaultDataCorruptionMasker(dataset_name, reducer, unscaled_full_x, unscaled_full_z)

        if corruption_type in (CorruptionType.MISSING_X, CorruptionType.NOISED_X):
            return CovariateDataCorruptionMask(unscaled_full_x, unscaled_full_y, base)
        if corruption_type in (CorruptionType.MISSING_Y, CorruptionType.NOISED_Y,
                               CorruptionType.DISPERSIVE_NOISED_Y):
            return ResponseDataCorruptionMasker(base)
        raise Exception(f"don't know what data masker to create for corruption type: {corruption_type.name}")
