import abc

import torch

from models.ClassificationModel import ClassProbabilities
from utils.utils import compute_ece, compute_model_certainty, compute_average_entropy, compute_entropy_deleted_correlation


class ImputationMethod(abc.ABC):
    """Abstract base class for imputation methods.

    Keeps simple bookkeeping of how many times the method has been fitted and
    calibrated, printing a warning on repeated or out-of-order calls. This
    bookkeeping only happens when subclasses call ``super().fit(...)`` /
    ``super().calibrate(...)`` from their overrides.
    """

    def __init__(self):
        self.calibrated_count = 0  # number of calibrate() calls that reached the base implementation
        self.fit_count = 0  # number of fit() calls that reached the base implementation

    @property
    @abc.abstractmethod
    def name(self):
        """Name of the method, used in the warning messages printed by this class."""
        pass

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Record a fit call, warning if the method was already fitted before.

        The base implementation performs no training; the data and
        hyperparameter arguments are accepted so subclasses can override this
        method with the same signature and call ``super().fit(...)``.
        """
        if self.fit_count > 0:
            print(
                f"warning: {self.name} network learning model was fitted {self.fit_count} times already and is now fitted once again.")
        self.fit_count += 1

    @abc.abstractmethod
    def predict(self, x, z):
        """Return imputed values for the given inputs (implemented by subclasses)."""
        pass

    @abc.abstractmethod
    def calibrate(self, x_cal, y_cal, z_cal, deleted_cal):
        """Record a calibration call, warning about suspicious call sequences.

        Abstract, so subclasses must override it; overrides should call
        ``super().calibrate(...)`` to get the warnings and the counter update.
        Warns when calibrating before any fit, and when calibrating again.
        """
        if self.fit_count == 0:
            print(f"warning: {self.name} imputation is calibrated without being fit")
        if self.calibrated_count > 0:
            print(f"warning: {self.name} imputation was already calibrated {self.calibrated_count} times and is now called again")
        self.calibrated_count += 1

    def compute_performance(self, x_test, y_test, z_test, full_y_test, deleted_test, test_calibrated_intervals):
        """Return performance metrics of the method; the base class reports none (empty dict)."""
        return {}


class ClassificationImputationMethod(ImputationMethod, abc.ABC):
    """Imputation method whose estimates are class probabilities.

    Subclasses implement :meth:`estimate_probabilities`. Prediction (by
    sampling) and the calibration/uncertainty diagnostics in
    :meth:`compute_performance` are built on top of it.
    """

    def __init__(self):
        super().__init__()

    @abc.abstractmethod
    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return the estimated class probabilities given ``x`` and ``y``.

        The returned object must expose a ``.probabilities`` tensor, which is
        what the methods of this class read.
        """
        pass

    def predict(self, x, y):
        """Sample one class index per row from the estimated class probabilities.

        Assumes ``.probabilities`` is a ``(n_rows, n_classes)`` tensor — TODO
        confirm. ``torch.multinomial`` with one sample per row returns shape
        ``(n_rows, 1)``. Squeezing only the last dimension yields
        ``(n_rows,)`` even when ``n_rows == 1``; a bare ``squeeze()`` would
        collapse that case to a 0-d scalar.
        """
        estimated_probabilities = self.estimate_probabilities(x, y).probabilities
        return torch.multinomial(estimated_probabilities, 1).squeeze(-1)

    def compute_performance(self, x, y, full_y, deleted, model_prediction: ClassProbabilities):
        """Compute calibration and uncertainty diagnostics of the estimated probabilities.

        Metrics are reported over all rows and split by the ``deleted`` mask.

        Args:
            x: Inputs passed to :meth:`estimate_probabilities`.
            y: Labels passed to :meth:`estimate_probabilities`; column 1 is the
                ECE target for the ``'y2 ece'`` metric.
            full_y: Labels whose column 1 is the ECE target for the masked
                splits and for ``'full y2 ece'``.
            deleted: Row mask selecting the "deleted" rows; ``~deleted``
                selects the rest. Assumed to be a boolean mask — TODO confirm:
                on an integer tensor ``~`` is a bitwise complement, not a
                logical negation.
            model_prediction: Unused; kept for interface compatibility.

        Returns:
            A dict mapping metric names to the values returned by the
            ``compute_*`` helpers.
        """
        probabilities = self.estimate_probabilities(x, y).probabilities
        # Negate the mask once instead of once per metric.
        not_deleted = ~deleted

        y2_ece = compute_ece(y[:, 1], probabilities)
        not_deleted_y2_ece = compute_ece(full_y[not_deleted, 1], probabilities[not_deleted])
        deleted_y2_ece = compute_ece(full_y[deleted, 1], probabilities[deleted])
        full_y2_ece = compute_ece(full_y[:, 1], probabilities)

        # NOTE: OCE counterparts of the ECE metrics above (compute_oce) are currently disabled.

        entropy = compute_average_entropy(probabilities)
        not_deleted_entropy = compute_average_entropy(probabilities[not_deleted])
        deleted_entropy = compute_average_entropy(probabilities[deleted])

        entropy_deleted_correlation = compute_entropy_deleted_correlation(probabilities, deleted)

        model_certainty_level = compute_model_certainty(probabilities)
        not_deleted_model_certainty_level = compute_model_certainty(probabilities[not_deleted])
        deleted_model_certainty_level = compute_model_certainty(probabilities[deleted])
        return {
            'y2 ece': y2_ece,
            '~deleted y2 ece': not_deleted_y2_ece,
            'deleted y2 ece': deleted_y2_ece,
            'full y2 ece': full_y2_ece,

            'entropy': entropy,
            '~deleted entropy': not_deleted_entropy,
            'deleted entropy': deleted_entropy,

            'entropy deleted correlation': entropy_deleted_correlation,

            'model certainty level': model_certainty_level,
            '~deleted model certainty level': not_deleted_model_certainty_level,
            'deleted model certainty level': deleted_model_certainty_level,
        }

