import torch

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy


class RandomImputatorWithRF(ClassificationImputationMethod):
    """Imputation method that smooths a wrapped estimator's class probabilities.

    The wrapped ``probabilities_estimator`` is trained in ``fit``. Its class
    probabilities are then mixed with the uniform distribution over classes:

        p_smoothed = (1 - epsilon) * p + epsilon / n_classes

    NOTE(review): the class name suggests imputed labels are drawn at random from
    these probabilities. No drawing happens in this class; presumably the base
    class does it — confirm.
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy,
                 epsilon: float = 0.4):
        """
        Args:
            probabilities_estimator: estimator trained by ``fit`` whose
                ``estimate_probabilities`` output is smoothed.
            epsilon: weight in [0, 1] given to the uniform distribution. Defaults
                to 0.4, the value previously hard-coded in
                ``estimate_probabilities``.

        Raises:
            ValueError: if ``epsilon`` is outside [0, 1].
        """
        super().__init__()
        if not 0.0 <= epsilon <= 1.0:
            raise ValueError(f"epsilon must be in [0, 1], got {epsilon}")
        self.probabilities_estimator = probabilities_estimator
        self.epsilon = epsilon

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Train the wrapped probabilities estimator.

        ``z_train`` and ``z_val`` are accepted for interface compatibility but are
        not forwarded to the estimator. ``epochs``, ``batch_size`` and ``n_wait``
        are passed positionally after the validation arguments; extra ``kwargs``
        are forwarded unchanged.
        """
        self.probabilities_estimator.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val,
                                         epochs, batch_size, n_wait, **kwargs)

    @property
    def name(self):
        """Identifier derived from the wrapped estimator's name."""
        return f"{self.probabilities_estimator.name}_draw_with_rf"

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """No-op: this method performs no calibration."""
        pass

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return the wrapped estimator's probabilities mixed with a uniform distribution.

        The class dimension is taken to be the last axis of the estimator's output.
        If each row of that output sums to 1, the rows of the result do too.
        """
        probabilities = self.probabilities_estimator.estimate_probabilities(x, y).probabilities
        n_classes = probabilities.shape[-1]
        # Use a scalar for the uniform term so it broadcasts onto the probabilities'
        # own device and dtype. A tensor built on x.device would fail when the
        # estimator returns probabilities on a different device.
        smoothed = probabilities * (1 - self.epsilon) + self.epsilon / n_classes
        return ClassProbabilities(smoothed)
