import torch

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy


class Top1Imputation(ClassificationImputationMethod):
    """Imputation method that hardens another estimator's class probabilities into one-hot vectors.

    Wraps a ``ClassificationProbabilityEstimatorWithProxy``. It delegates training to that estimator and replaces
    each predicted distribution with a one-hot vector placing all mass on the most probable class (top-1).
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy):
        super().__init__()
        # Underlying estimator whose predicted probabilities are converted to one-hot vectors.
        self.probabilities_estimator = probabilities_estimator

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64,
            n_wait=20, **kwargs):
        """Fit the wrapped probability estimator.

        ``z_train`` and ``z_val`` are accepted but not used or forwarded. Presumably they exist to match the
        base-class signature; confirm. The remaining arguments are passed positionally, in the order
        (x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs, batch_size, n_wait), followed by
        ``**kwargs``.
        """
        self.probabilities_estimator.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs, batch_size,
                                         n_wait, **kwargs)

    @property
    def name(self):
        """Name of the wrapped estimator with a ``_top1`` suffix appended."""
        return f"{self.probabilities_estimator.name}_top1"

    def estimate_probabilities(self, x, y) -> ClassProbabilities:
        """Return one-hot class probabilities at the wrapped estimator's argmax.

        The last dimension of the wrapped estimator's ``probabilities`` tensor is treated as the class axis. Any
        number of leading dimensions is supported. Ties resolve to the index chosen by ``torch.argmax``. The
        output keeps the dtype and device of the input probabilities.
        """
        probabilities = self.probabilities_estimator.estimate_probabilities(x, y).probabilities
        top1 = probabilities.argmax(dim=-1, keepdim=True)
        # Scatter along the class axis instead of indexing with range(len(...)). This is identical for (N, C)
        # input but also works for extra leading dimensions and for an empty batch.
        one_hot = torch.zeros_like(probabilities).scatter_(-1, top1, 1)
        return ClassProbabilities(one_hot)

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """No-op: this method performs no calibration and does not call the wrapped estimator."""
        pass
