import abc

import numpy as np
import torch

from data_utils.data_corruption.data_corruption_masker import DataCorruptionMasker
from data_utils.data_scaler import DataScaler
from models.ClassificationModel import ClassProbabilities
from models.data_mask_estimators.DataMaskEstimator import DataMaskEstimator
from models.data_mask_estimators.OracleDataMasker import OracleDataMasker


class OracleDataMaskerWithDelta(DataMaskEstimator):
    """Wrap an ``OracleDataMasker`` and tag it with a ``delta`` value.

    The wrapped oracle is built from the same constructor arguments, minus
    ``delta``. Predictions are delegated to it verbatim.

    NOTE(review): ``delta`` currently only affects ``name``. ``forward``
    returns the oracle's output unchanged. Earlier (now removed) commented-out
    code hinted at post-processing the oracle's ``probabilities`` using
    ``delta``. Confirm whether that adjustment is still intended.
    """

    def __init__(
        self,
        data_scaler: DataScaler,
        data_masker: DataCorruptionMasker,
        dataset_name: str,
        x_dim: int,
        z_dim: int,
        delta: float = 1,
    ):
        """Initialise the base estimator and build the wrapped oracle.

        Args:
            data_scaler: Forwarded to ``OracleDataMasker``.
            data_masker: Forwarded to ``OracleDataMasker``.
            dataset_name: Forwarded to both the base class and the oracle.
            x_dim: Forwarded to both the base class and the oracle.
            z_dim: Forwarded to both the base class and the oracle.
            delta: Value stored on the instance and embedded in ``name``.
        """
        super().__init__(dataset_name, x_dim, z_dim)
        self.oracle_model = OracleDataMasker(
            data_scaler, data_masker, dataset_name, x_dim, z_dim
        )
        self.delta: float = delta

    def forward(self, x, z) -> ClassProbabilities:
        """Return the wrapped oracle's prediction for ``(x, z)`` unchanged."""
        oracle = self.oracle_model
        return oracle.forward(x, z)

    @property
    def name(self) -> str:
        """Identifier built from ``base_name`` and ``delta`` rounded to 4 decimals.

        ``base_name`` is not set in this class; presumably it comes from
        ``DataMaskEstimator``.
        """
        rounded_delta = np.round(self.delta, 4)
        return f"{self.base_name}oracle_with_delta={rounded_delta}"
