import abc

import numpy as np
import torch

from imputation_methods.imputation_utils import construct_histogram
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy
from models.regressors.MeanRegressor import MeanRegressor

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities


class ConditionalUncertainClassificationImputator(ClassificationImputationMethod):
    """Wrap a class-probability estimator and shrink its output toward uniform per bin.

    During ``calibrate``, the values of the first target column ``y[:, 0]`` are
    split into histogram bins (edges from ``construct_histogram``). For each bin
    a mixing weight ``beta`` in ``[0, 1]`` is computed by comparing two
    distances from the uniform distribution ``1 / n_classes``:

    * how far the empirical label frequencies of ``y[:, 1]`` are, and
    * how far the estimator's mean predicted probabilities are.

    ``estimate_probabilities`` then returns
    ``(1 - beta) * p + beta / n_classes`` for every sample whose ``y[:, 0]``
    falls in that bin. Samples outside the calibrated range are left unchanged.
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy):
        super().__init__()
        self.probabilities_estimator = probabilities_estimator
        # Per-bin mixing weights with the uniform distribution; set by calibrate().
        self.betas = None
        # Histogram edges over y[:, 0]; set by calibrate().
        self.bin_edges = None

    @property
    def name(self):
        return f"uncertain_{self.probabilities_estimator.name}"

    def _bin_masks(self, y1: torch.Tensor):
        """Yield one boolean mask over ``y1`` per histogram bin, in edge order.

        Bins are half-open ``[low, high)`` except the last, which is closed
        ``[low, high]``. With ascending edges, every value inside the histogram
        range therefore lands in exactly one bin. This matches the original
        "last bin wins" assignment used at estimation time, and it stops a
        calibration sample lying on an inner edge from being counted in two
        bins. Values outside ``[bin_edges[0], bin_edges[-1]]`` match no bin.
        """
        n_bins = len(self.bin_edges) - 1
        for i in range(n_bins):
            low, high = self.bin_edges[i], self.bin_edges[i + 1]
            below_high = (y1 <= high) if i == n_bins - 1 else (y1 < high)
            yield (y1 >= low) & below_high

    @staticmethod
    def _bin_error_term(labels_one_hot: torch.Tensor, estimated: torch.Tensor, n_classes: int) -> float:
        """Return the uniform-mixing weight for one bin, clipped to ``[0, 1]``.

        Args:
            labels_one_hot: one-hot true labels of the bin's samples, shape (m, n_classes).
            estimated: estimator probabilities for the same samples, shape (m, n_classes).
            n_classes: number of classes.

        Returns:
            ``max_c (1 - |real_c - 1/n| / |est_c - 1/n|)``, clipped to ``[0, 1]``.
            Returns 0.0 (no adjustment) when the bin is empty or no class gives
            a defined ratio.
        """
        if labels_one_hot.shape[0] == 0:
            # Empty bin: mean() would be NaN, and a NaN beta would poison every
            # test sample in this bin. Leave the estimator untouched instead.
            return 0.0
        real_uncertainty_level = (labels_one_hot.mean(dim=0) - 1 / n_classes).abs()
        estimated_uncertainty_level = (estimated.mean(dim=0) - 1 / n_classes).abs()
        candidates = 1 - real_uncertainty_level / estimated_uncertainty_level
        # 0/0 for a class (both exactly uniform) gives NaN, and torch.max
        # propagates NaN, so drop those classes before taking the max.
        candidates = candidates[~torch.isnan(candidates)]
        if candidates.numel() == 0:
            return 0.0
        beta = candidates.max().item()
        # The ratio is non-negative, so beta <= 1. A negative (or -inf) beta
        # would push probabilities below 1/n past zero, so clip it at 0.
        return min(max(beta, 0.0), 1.0)

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return the wrapped estimator's probabilities mixed with uniform per y[:, 0] bin.

        Raises:
            RuntimeError: if ``calibrate`` has not been called yet.
        """
        if self.bin_edges is None or self.betas is None:
            raise RuntimeError("calibrate() must be called before estimate_probabilities()")
        probabilities = self.probabilities_estimator.estimate_probabilities(x, y).probabilities
        n_labels = probabilities.shape[-1]
        # atleast_1d keeps a batch of one sample 1-D after squeeze().
        y1 = torch.atleast_1d(y[:, 0].squeeze())
        new_probabilities = probabilities.clone()
        for error_term, idx in zip(self.betas, self._bin_masks(y1)):
            new_probabilities[idx] = (1 - error_term) * probabilities[idx] + error_term * (1 / n_labels)

        return ClassProbabilities(new_probabilities)

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """Fit the histogram over y[:, 0] and one mixing weight per bin.

        Rows flagged in ``deleted_cal`` are excluded. ``y_cal[:, 1]`` is treated
        as integer class labels, because it is used as a one-hot index.
        """
        super().calibrate(x_cal, y_cal, deleted_cal)
        x_cal = x_cal[~deleted_cal]
        y_cal = y_cal[~deleted_cal]
        y1 = torch.atleast_1d(y_cal[:, 0].squeeze())
        y2 = torch.atleast_1d(y_cal[:, 1].squeeze())
        estimated_probabilities = self.probabilities_estimator.estimate_probabilities(x_cal, y_cal).probabilities
        n_classes = estimated_probabilities.shape[-1]
        y2_one_hot = torch.zeros_like(estimated_probabilities)
        y2_one_hot[range(y2.shape[0]), y2.long()] = 1
        self.bin_edges = construct_histogram(y1)
        self.betas = [
            self._bin_error_term(y2_one_hot[idx], estimated_probabilities[idx], n_classes)
            for idx in self._bin_masks(y1)
        ]

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Train the wrapped estimator.

        ``z_train`` and ``z_val`` are accepted for interface compatibility but
        are not forwarded to the estimator.
        """
        self.probabilities_estimator.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val, epochs, batch_size, n_wait,
                                         **kwargs)