import abc

import numpy as np
import torch

from imputation_methods.imputation_utils import construct_histogram
from models.regressors.MeanRegressor import MeanRegressor

from imputation_methods.ImputationMethod import ImputationMethod, ClassificationImputationMethod
from models.ClassificationModel import ClassificationModel, ClassProbabilities

from models.imputation_classifiers.ClassificationProbabilityEstimatorWithProxy import \
    ClassificationProbabilityEstimatorWithProxy


class ClassificationImputationWithConditionalTransitions(ClassificationImputationMethod):
    """Correct a proxy-based probability estimator with per-bin transition matrices.

    Column 0 of ``y`` is split into histogram bins (edges come from
    ``construct_histogram``).  For every bin, ``calibrate`` estimates a matrix
    whose entry ``[pred, real]`` is the fraction of calibration samples with
    arg-max prediction ``pred`` whose column-1 value of ``y`` equals ``real``.
    ``estimate_probabilities`` then replaces each sample's probability vector
    by the row of its bin's matrix selected by the sample's arg-max prediction.
    """

    def __init__(self, probabilities_estimator: ClassificationProbabilityEstimatorWithProxy):
        """Store the wrapped estimator; transition matrices are built by ``calibrate``."""
        super().__init__()
        self.probabilities_estimator = probabilities_estimator
        # One (n_classes, n_classes) tensor per histogram bin; None until calibrate() runs.
        self.transition_matrices = None

    def fit(self, x_train, y_train, z_train, deleted_train, x_val, y_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        """Fit the wrapped probability estimator.

        ``z_train`` and ``z_val`` are accepted for interface compatibility but
        are not forwarded to the estimator.  ``epochs``, ``batch_size`` and
        ``n_wait`` are passed through as given (previously the defaults were
        always used regardless of the caller's values).
        """
        self.probabilities_estimator.fit(x_train, y_train, deleted_train, x_val, y_val, deleted_val,
                                         epochs=epochs, batch_size=batch_size, n_wait=n_wait,
                                         **kwargs)

    def estimate_probabilities(self, x: torch.Tensor, y: torch.Tensor) -> ClassProbabilities:
        """Return calibrated class probabilities for every row of ``x``.

        Samples whose ``y[:, 0]`` falls outside every calibrated bin keep the
        wrapped estimator's raw probabilities.

        Raises:
            RuntimeError: if ``calibrate`` has not been called yet.
        """
        if self.transition_matrices is None:
            raise RuntimeError("calibrate() must be called before estimate_probabilities()")

        estimated_probabilities = self.probabilities_estimator.estimate_probabilities(x, y).probabilities
        bin_edges = self.bin_edges
        # reshape(-1) instead of squeeze() so a single-row batch stays 1-D.
        y1 = y[:, 0].reshape(-1)
        predicted_labels = estimated_probabilities.argmax(dim=-1).long()
        new_probabilities = estimated_probabilities.clone()
        for i, transition_matrix in enumerate(self.transition_matrices):
            idx = self._bin_mask(y1, bin_edges, i)
            if not idx.any().item():
                continue
            # Row `pred` of the matrix is the calibrated distribution P(real | pred) for this bin.
            rows = transition_matrix[predicted_labels[idx]]
            new_probabilities[idx] = rows.to(new_probabilities.dtype)

        return ClassProbabilities(new_probabilities)

    def calibrate(self, x_cal, y_cal, deleted_cal):
        """Build the histogram bins and one transition matrix per bin.

        Only rows where ``deleted_cal`` is False are used.  The number of
        classes is taken from the width of the estimator's probability output
        so that each matrix row can be written directly into a probability
        vector in ``estimate_probabilities``.
        """
        super().calibrate(x_cal, y_cal, deleted_cal)
        x_cal = x_cal[~deleted_cal]
        y_cal = y_cal[~deleted_cal]
        y1 = y_cal[:, 0].reshape(-1)
        y2 = y_cal[:, 1].reshape(-1)
        estimated_probabilities = self.probabilities_estimator.estimate_probabilities(x_cal, y_cal).probabilities
        # Using y_cal.max() here would mix in column 0 and could miss classes absent
        # from the calibration set; the probability width is the authoritative count.
        n_classes = estimated_probabilities.shape[-1]
        predicted_labels = estimated_probabilities.argmax(dim=-1)
        self.bin_edges = construct_histogram(y1, min_bin_size=100)
        bin_edges = self.bin_edges

        transition_matrices = []
        for i in range(len(bin_edges) - 1):
            idx = self._bin_mask(y1, bin_edges, i)
            # P(real | pred) = # pred with real / total pred
            transition_matrices.append(
                self._build_transition_matrix(predicted_labels[idx], y2[idx], n_classes, x_cal.device)
            )
        self.transition_matrices = transition_matrices

    @staticmethod
    def _bin_mask(values, bin_edges, i):
        """Boolean mask of ``values`` inside bin ``i``: [low, high), closed on the right for the last bin.

        Half-open bins keep a value lying on an interior edge in exactly one
        bin (the upper one), so calibration counts are not duplicated and
        inference uses the same assignment.
        """
        low, high = bin_edges[i], bin_edges[i + 1]
        is_last_bin = i == len(bin_edges) - 2
        below_high = (values <= high) if is_last_bin else (values < high)
        return (values >= low) & below_high

    @staticmethod
    def _build_transition_matrix(predicted_labels, true_labels, n_classes, device):
        """Return the row-normalised (pred, real) co-occurrence matrix for one bin.

        Rows for predicted labels that never occur fall back to the identity
        row, i.e. the prediction is trusted as-is.
        """
        classes = torch.arange(n_classes, device=predicted_labels.device)
        # One-hot by equality so labels that match no class index are simply not counted.
        pred_one_hot = (predicted_labels.unsqueeze(1) == classes).float()
        true_one_hot = (true_labels.unsqueeze(1) == classes).float()
        counts = pred_one_hot.t() @ true_one_hot
        row_totals = pred_one_hot.sum(dim=0)

        transition = torch.eye(n_classes, device=device)
        seen = row_totals > 0
        normalised = counts[seen] / row_totals[seen].unsqueeze(1)
        transition[seen.to(transition.device)] = normalised.to(transition)
        return transition

    @property
    def name(self):
        """Identifier derived from the wrapped estimator's name."""
        return f"{self.probabilities_estimator.name}_with_conditional_transition_matrix"
