import abc

import numpy as np
import torch
from matplotlib import pyplot as plt
from sklearn.isotonic import IsotonicRegression

from models.ClassificationModel import ClassProbabilities
from models.abstract_models.AbstractModel import Model


class DataMaskEstimator(Model):
    def __init__(self, dataset_name: str, x_dim: int, z_dim: int, **kwargs):
        """Record the input configuration of the mask estimator.

        Args:
            dataset_name: Name of the underlying dataset; also used as the
                prefix of ``new_dataset_name``.
            x_dim: Dimension of the ``x`` input. ``x`` is flagged as used only
                when this is positive.
            z_dim: Dimension of the ``z`` input. ``z`` is flagged as used only
                when this is positive.
            **kwargs: Accepted for signature compatibility and ignored here.

        Raises:
            Exception: If both ``x_dim`` and ``z_dim`` are zero, since the
                estimator would then have no input at all.
        """
        super().__init__()

        no_inputs = x_dim == 0 and z_dim == 0
        if no_inputs:
            raise Exception("cannot handle z_dim==0 and x_dim==0")

        # Which of the two inputs this estimator consumes.
        self.use_z = z_dim > 0
        self.use_x = x_dim > 0

        # Raw configuration, kept as passed in.
        self.z_dim = z_dim
        self.dataset_name = dataset_name
        self.x_dim = x_dim
        self.new_dataset_name = f"{dataset_name}_mask_use_z={self.use_z}_use_x={self.use_x}"

        # Calibration state: no correction and no isotonic regressor by default.
        self.correction = 0
        self.iso_reg = None

    def fit(self, x_train, z_train, deleted_train, x_val, z_val, deleted_val, epochs=1000, batch_size=64, n_wait=20,
            **kwargs):
        pass

    def predict(self, x, z):
        estimated_mask_probabilities = self.forward(x, z).probabilities[:, 1]

        if "oracle" in self.name.lower():
            return estimated_mask_probabilities

        estimated_mask_probabilities = torch.clip(estimated_mask_probabilities, min=0, max=0.98)
        if (estimated_mask_probabilities > 0.95).float().sum().item() >= 1:
            print(f"warning: {self.name} estimated mask probability > 0.95. max estiamted: {estimated_mask_probabilities.max().item()} # > 0.95: {(estimated_mask_probabilities > 0.95).float().sum().item()}")

        if (estimated_mask_probabilities > 0.98).float().sum().item() >= 1:
            print(f"error: {self.name} estimated mask probability > 0.98. max estiamted: {estimated_mask_probabilities.max().item()} # > 0.98: {(estimated_mask_probabilities > 0.98).float().sum().item()}")
        return estimated_mask_probabilities

    def get_calibration_errors(self, x_cal, z_cal, deleted_cal):
        estimated_probabilities = self.forward(x_cal, z_cal).probabilities
        estimated_mask_probabilities = estimated_probabilities[:, 1]
        bin_edges = torch.histogram(estimated_mask_probabilities.detach().cpu(), bins=50, density=False).bin_edges.to(x_cal.device)
        calibration_errors = []
        for i in range(len(bin_edges) - 1):
            a, b = bin_edges[i], bin_edges[i + 1]
            bin_y = deleted_cal[(estimated_mask_probabilities < b) & (estimated_mask_probabilities > a)]
            if len(bin_y) < 8:
                # if len(bin_y) >= 3:
                #     print(f"warning: got into a small bin (size {len(bin_y)}) in calibration error computation")
                continue
            bin_p_hat = estimated_mask_probabilities[
                (estimated_mask_probabilities < b) & (estimated_mask_probabilities > a)]
            acc_bin = (bin_y == bin_p_hat.round()).float()
            bin_error = abs(bin_p_hat.mean().item() - acc_bin.mean().item())
            calibration_errors += [bin_error * len(bin_y) / len(x_cal)]
        estimated_marginal_probability = estimated_mask_probabilities.mean().item()
        real_marginal_probability = deleted_cal.float().mean().item()
        marginal_calibration_error = abs(real_marginal_probability - estimated_marginal_probability)
        nll = -torch.log(estimated_probabilities[range(len(x_cal)), deleted_cal.long()]).mean().item()
        return calibration_errors, marginal_calibration_error, nll

    def calibrate_prediction(self, estimated_mask_probabilities):
        if len(estimated_mask_probabilities.shape) == 0 or estimated_mask_probabilities.shape[0] == 0:
            return estimated_mask_probabilities
        if self.iso_reg is not None:
            return torch.Tensor(self.iso_reg.predict(estimated_mask_probabilities.detach().cpu().numpy())).to(
                estimated_mask_probabilities.device)
        else:
            return estimated_mask_probabilities.clone()

    def calibrate(self, x_cal, z_cal, deleted_cal, **kwargs):
        """Calibration hook; calibration is currently disabled, so this is a no-op.

        TODO: notice that calibration is removed. The disabled implementation
        fitted ``IsotonicRegression(out_of_bounds='clip')`` on the class-1
        probabilities of ``self.forward(x_cal, z_cal)`` against ``deleted_cal``,
        stored it in ``self.iso_reg``, and reset ``self.correction`` to 0.

        Args:
            x_cal: Unused.
            z_cal: Unused.
            deleted_cal: Unused.
            **kwargs: Unused.

        Returns:
            None. ``self.iso_reg`` and ``self.correction`` are left untouched.
        """
        return None

    @abc.abstractmethod
    def forward(self, x, z) -> ClassProbabilities:
        pass

    def compute_performance(self, x_test, z_test, full_y_test, deleted_test) -> dict:
        probability_estimate = self.predict(x_test, z_test)
        mask_estimate = probability_estimate.round()
        accuracy = ((mask_estimate - deleted_test.float()).abs() < 1e-2).float().mean().item() * 100
        deleted_accuracy = ((mask_estimate - deleted_test.float()).abs() < 1e-2)[
                               deleted_test].float().mean().item() * 100
        not_deleted_accuracy = ((mask_estimate - deleted_test.float()).abs() < 1e-2)[
                                   ~deleted_test].float().mean().item() * 100
        calibration_errors, marginal_error, nll = self.get_calibration_errors(x_test, z_test, deleted_test)
        if len(calibration_errors) == 0:
            calibration_errors = [torch.nan]
        if len(probability_estimate) == 0:
            probability_estimate = torch.Tensor([np.nan])
        return {
            f"data_masker_accuracy": accuracy,
            f"data_masker_deleted_accuracy": deleted_accuracy,
            f"data_masker_~deleted_accuracy": not_deleted_accuracy,
            f"data_masker_correction": self.correction,
            f"data_masker_marginal_error": marginal_error,
            f"data_masker_ECE": np.mean(calibration_errors),
            f"data_masker_MCE": np.max(calibration_errors),
            f"data_masker_nll": nll,
            f"data_masker_max_estimated_probability": probability_estimate.max().item(),
            f"data_masker_q99_estimated_probability": probability_estimate.quantile(q=0.99).item(),
            f"data_masker_q95_estimated_probability": probability_estimate.quantile(q=0.95).item(),
            f"data_masker_q9_estimated_probability": probability_estimate.quantile(q=0.9).item(),
            f"data_masker_q2_estimated_probability": probability_estimate.quantile(q=0.2).item(),
            f"data_masker_q1_estimated_probability": probability_estimate.quantile(q=0.1).item(),
        }

    @property
    def base_name(self) -> str:
        return f''