import torch

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy


class RandomImputator(ClassificationImputationMethod):
    """Imputation method that delegates to a wrapped probability estimator.

    Fitting and probability estimation are forwarded to the
    ``probabilities_estimator`` supplied at construction; calibration is a
    no-op.
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy):
        """Store the estimator that all later calls are forwarded to."""
        super().__init__()
        self.probabilities_estimator = probabilities_estimator

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val,
            epochs=1000, batch_size=64, n_wait=20, **kwargs):
        """Fit the wrapped estimator on the training and validation data.

        ``z_train`` and ``z_val`` are accepted as part of the signature but
        are not forwarded to the estimator. ``epochs``, ``batch_size`` and
        ``n_wait`` are passed positionally, in that order, followed by any
        extra keyword arguments.
        """
        estimator = self.probabilities_estimator
        estimator.fit(
            x_train,
            y_train,
            deleted_train,
            x_val,
            y_val,
            deleted_val,
            epochs,
            batch_size,
            n_wait,
            **kwargs
        )

    @property
    def name(self):
        """Return the wrapped estimator's name with a ``_draw`` suffix."""
        base_name = self.probabilities_estimator.name
        return f"{base_name}_draw"

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """Do nothing; this method performs no calibration step."""
        return None

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return the wrapped estimator's class probabilities for ``x`` and ``y``."""
        estimator = self.probabilities_estimator
        probabilities = estimator.estimate_probabilities(x, y)
        return probabilities
