import abc

import numpy as np
import torch
from tqdm import tqdm

from imputation_methods.imputation_utils import construct_histogram
from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy
from models.regressors.MeanRegressor import MeanRegressor

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities


class ClassificationImputationWithConditionalErrors(ClassificationImputationMethod):
    """Perturb a base estimator's class probabilities with errors sampled per bin of ``y[:, 0]``.

    ``calibrate`` bins the non-deleted calibration samples by ``y[:, 0]`` (edges from
    ``construct_histogram``). For each sample it records two quantities per bin:

    * the estimated probability of the class given in ``y[:, 1]`` (the "label probability");
    * the summed estimated probability of all classes whose estimate is strictly larger
      (the "higher probability").

    ``estimate_probabilities`` works per sample. It draws one recorded (label, higher) pair
    from the sample's bin. It then removes the "higher" mass from the sample's most probable
    classes and the remaining ``1 - label - higher`` mass from its least probable classes.
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy):
        super().__init__()
        self.probabilities_estimator = probabilities_estimator
        # Lists with one tensor per bin; filled by ``calibrate``.
        self.label_probabilities = None
        self.higher_probabilities = None

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Fit the wrapped probability estimator.

        ``z_train`` and ``z_val`` are not used here. ``epochs``, ``batch_size``, ``n_wait`` and
        any extra keyword arguments are forwarded to the estimator's ``fit``.
        """
        self.probabilities_estimator.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val,
                                         epochs=epochs, batch_size=batch_size, n_wait=n_wait,
                                         **kwargs)

    @staticmethod
    def _remove_mass(probabilities: torch.Tensor, order: torch.Tensor, amount: float) -> None:
        """Subtract up to ``amount`` of mass from ``probabilities`` in place.

        Entries are visited in ``order``. Each entry gives up at most its current value, so no
        entry goes below zero. A non-positive ``amount`` leaves the row unchanged.
        """
        remaining = amount
        for k in order.tolist():
            if remaining <= 0:
                break
            taken = min(probabilities[k].item(), remaining)
            probabilities[k] -= taken
            remaining -= taken

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return the base estimator's probabilities perturbed by sampled calibration errors.

        Must be called after ``calibrate``. Each sample whose ``y[:, 0]`` lies in a calibration
        bin is adjusted exactly once; a value on a shared edge uses the first matching bin.
        Samples in these cases keep the base estimator's probabilities:

        * ``y[:, 0]`` lies outside every bin;
        * the sample's bin received no calibration samples.
        """
        estimated_probabilities = self.probabilities_estimator.estimate_probabilities(x, y).probabilities
        y1_bin_edges = self.bin_edges
        # reshape(-1) instead of squeeze() so that a batch of one stays 1-D.
        y1 = y[:, 0].reshape(-1)
        new_probabilities = estimated_probabilities.clone()
        # Bins are closed on both ends, so a value on an interior edge matches two bins;
        # track adjusted samples so mass is never removed from the same row twice.
        already_adjusted = torch.zeros_like(y1, dtype=torch.bool)
        for i in tqdm(range(len(y1_bin_edges) - 1)):
            low, high = y1_bin_edges[i], y1_bin_edges[i + 1]
            idx = (y1 >= low) & (y1 <= high) & ~already_adjusted
            already_adjusted |= idx
            bin_size = len(self.label_probabilities[i])
            if bin_size == 0:
                # No calibration errors to sample from (np.random.randint(0, 0) would raise).
                continue
            for j in torch.nonzero(idx).flatten().tolist():
                sampled_error_idx = np.random.randint(0, bin_size)
                label_probability = self.label_probabilities[i][sampled_error_idx].item()
                higher_probability = self.higher_probabilities[i][sampled_error_idx].item()
                lower_probability = 1 - label_probability - higher_probability
                row = new_probabilities[j]  # view: in-place edits update new_probabilities
                # Take "higher" mass from the most probable classes first ...
                self._remove_mass(row, torch.argsort(estimated_probabilities[j], descending=True),
                                  higher_probability)
                # ... and "lower" mass from the least probable classes first.
                self._remove_mass(row, torch.argsort(estimated_probabilities[j], descending=False),
                                  lower_probability)

        return ClassProbabilities(new_probabilities)

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """Build per-bin pools of calibration errors from the non-deleted calibration samples.

        ``y_cal[:, 0]`` is binned with ``construct_histogram``. ``y_cal[:, 1]`` is the class
        whose estimated probability is recorded. Bins are closed on both ends, so a sample on
        an interior edge contributes to both adjacent pools.
        """
        super().calibrate(x_cal, y_cal, deleted_cal)
        x_cal = x_cal[~deleted_cal]
        y_cal = y_cal[~deleted_cal]
        # reshape(-1) instead of squeeze() so that a single calibration sample stays 1-D.
        y1 = y_cal[:, 0].reshape(-1)
        labels = y_cal[:, 1].reshape(-1).long()
        estimated_probabilities = self.probabilities_estimator.estimate_probabilities(x_cal, y_cal).probabilities
        self.bin_edges = construct_histogram(y1)
        y1_bin_edges = self.bin_edges
        self.label_probabilities = []
        self.higher_probabilities = []
        for i in range(len(y1_bin_edges) - 1):
            low, high = y1_bin_edges[i], y1_bin_edges[i + 1]
            idx = (y1 >= low) & (y1 <= high)
            curr_estimated_probabilities = estimated_probabilities[idx]
            curr_labels = labels[idx]
            label_probability = curr_estimated_probabilities[range(len(curr_labels)), curr_labels]
            # Broadcast the per-sample label probability across classes instead of repeating
            # it to max(y) + 1 columns, which breaks when the top class is absent from y_cal.
            higher_mask = curr_estimated_probabilities > label_probability.unsqueeze(1)
            higher_probabilities = curr_estimated_probabilities.masked_fill(~higher_mask, 0).sum(dim=-1)
            self.label_probabilities.append(label_probability)
            self.higher_probabilities.append(higher_probabilities)

    @property
    def name(self):
        return f"{self.probabilities_estimator.name}_with_conditional_errors"


