"""Implement the post training methods."""

from calibration import HistogramCalibrator, PlattBinnerCalibrator
import glest
from glest import Partitioner
from abc import ABC, abstractmethod
from sklearn.isotonic import IsotonicRegression
from sklearn.calibration import _SigmoidCalibration
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier
from sklearn.naive_bayes import GaussianNB
from sklearn.gaussian_process import GaussianProcessClassifier
from sklearn.model_selection import train_test_split
from decision.xp.regrets import compute_regret_GL_UB
from sklearn.tree import (
    DecisionTreeRegressor,
)
# from metacal.metacal import MetaCalCoverageAcc, MetaCalMisCoverage
from glest.helpers import calibration_curve as glest_calibration_curve


class PostTraining(ABC):
    """Common interface shared by every post training method.

    Concrete subclasses must implement both ``fit`` and ``predict_proba``.
    ``S`` is the output of the original model and ``X`` the (optional)
    features it was computed from.
    """

    @abstractmethod
    def fit(self, S, y, X=None):
        """Fit the method from ``S``, the targets ``y`` and optionally ``X``."""

    @abstractmethod
    def predict_proba(self, S, X=None):
        """Return the post-processed output for ``S`` (and optionally ``X``)."""


class Identity(PostTraining):
    """No-op post training method: the input is handed back untouched."""

    def fit(self, S, y, X=None):
        """Do nothing and return ``self``; all arguments are ignored."""
        return self

    def predict_proba(self, S, X=None):
        """Return ``S`` itself (the same object, no copy); ``X`` is ignored."""
        unchanged = S
        return unchanged


class SklearnRecalibration(PostTraining):
    """Recalibration post training method using sklearn.

    ``method`` selects the calibrator fitted on ``(S, y)``:

    * ``"isotonic"``: ``IsotonicRegression(out_of_bounds="clip")``;
    * ``"sigmoid"``: scikit-learn's private ``_SigmoidCalibration``.
    """

    def __init__(self, method: str = "isotonic"):
        super().__init__()
        self.method = method
        # Kept as an assert so construction-time failures keep raising
        # AssertionError; ``fit`` re-validates because asserts are stripped
        # under ``python -O`` and ``method`` can be reassigned afterwards.
        assert method in ["isotonic", "sigmoid"]

    def fit(self, S, y, X=None):
        """Fit the selected calibrator on ``(S, y)`` and return ``self``.

        ``X`` is accepted for interface compatibility and ignored.

        Raises
        ------
        ValueError
            If ``self.method`` is not one of the supported methods.
        """
        if self.method == "isotonic":
            calibrator = IsotonicRegression(out_of_bounds="clip")
        elif self.method == "sigmoid":
            calibrator = _SigmoidCalibration()
        else:
            # Without this branch ``calibrator`` would be unbound below.
            raise ValueError(
                f"Unknown method {self.method!r}; "
                "expected 'isotonic' or 'sigmoid'."
            )

        calibrator.fit(S, y)
        self.calibrator_ = calibrator
        return self

    def predict_proba(self, S, X=None):
        """Return the fitted calibrator's prediction for ``S`` (``X`` ignored)."""
        return self.calibrator_.predict(S)


class PlattRecalibration(PostTraining):
    """Post training method wrapping ``PlattBinnerCalibrator`` without bins.

    NOTE(review): the calibrator is built with ``None`` as its second
    argument (the slot ``PlattBinnerRecalibration`` fills with ``n_bins``);
    confirm that the ``calibration`` package accepts this.
    """

    def __init__(self):
        super().__init__()

    def fit(self, S, y, X=None):
        """Train the calibrator on ``(S, y)`` and return ``self``; ``X`` is ignored."""
        n_samples = len(S)
        self.calibrator_ = platt = PlattBinnerCalibrator(n_samples, None)
        platt.train_calibration(S, y)
        return self

    def predict_proba(self, S, X=None):
        """Return the trained calibrator's output for ``S`` (``X`` ignored)."""
        calibrated = self.calibrator_.calibrate(S)
        return calibrated


class PlattBinnerRecalibration(PostTraining):
    """Post training method wrapping ``PlattBinnerCalibrator``.

    ``n_bins`` is forwarded as the calibrator's second constructor argument.
    """

    def __init__(self, n_bins: int = 10):
        super().__init__()
        self.n_bins = n_bins

    def fit(self, S, y, X=None):
        """Train the calibrator on ``(S, y)`` and return ``self``; ``X`` is ignored."""
        n_samples = len(S)
        self.calibrator_ = binner = PlattBinnerCalibrator(n_samples, self.n_bins)
        binner.train_calibration(S, y)
        return self

    def predict_proba(self, S, X=None):
        """Return the trained calibrator's output for ``S`` (``X`` ignored)."""
        calibrated = self.calibrator_.calibrate(S)
        return calibrated


class HistogramBinningRecalibration(PostTraining):
    """Post training method wrapping ``HistogramCalibrator``.

    ``n_bins`` is forwarded as the calibrator's second constructor argument.
    """

    def __init__(self, n_bins: int = 10):
        super().__init__()
        self.n_bins = n_bins

    def fit(self, S, y, X=None):
        """Train the calibrator on ``(S, y)`` and return ``self``; ``X`` is ignored."""
        n_samples = len(S)
        self.calibrator_ = histogram = HistogramCalibrator(n_samples, self.n_bins)
        histogram.train_calibration(S, y)
        return self

    def predict_proba(self, S, X=None):
        """Return the trained calibrator's output for ``S`` (``X`` ignored)."""
        calibrated = self.calibrator_.calibrate(S)
        return calibrated


class Stacking(PostTraining):
    """Post training method fitting an estimator on ``S`` stacked with ``X``.

    A clone of ``est`` is trained on the matrix whose first column is ``S``
    and whose remaining columns are ``X``.
    """

    def __init__(self, est) -> None:
        super().__init__()
        self.est = est

    @staticmethod
    def _stack(S, X):
        """Return ``X`` with ``S`` prepended as an extra first column."""
        score_column = S[:, None]
        return np.concatenate([score_column, X], axis=1)

    def fit(self, S, y, X):
        """Fit a fresh clone of ``est`` on the stacked inputs; return ``self``."""
        self.est_ = clone(self.est)
        stacked = self._stack(S, X)
        self.est_.fit(stacked, y)
        return self

    def predict_proba(self, S, X):
        """Return column 1 of the fitted estimator's ``predict_proba``."""
        stacked = self._stack(S, X)
        probas = self.est_.predict_proba(stacked)
        return probas[:, 1]


class MetaCalAccRecalibration(PostTraining):
    """Post training method wrapping ``MetaCalCoverageAcc``.

    NOTE(review): the ``metacal`` import at the top of this module is
    commented out, so ``MetaCalCoverageAcc`` is not defined in this file and
    ``fit`` raises ``NameError`` unless the name is provided some other way.
    """

    def __init__(self, target_acc: float = 0.9):
        super().__init__()
        self.target_acc = target_acc

    @staticmethod
    def _as_two_columns(S):
        """Expand a 1-D ``S`` into the columns ``[1 - S, S]``; else return it as is."""
        if S.ndim != 1:
            return S
        column = S[:, None]
        return np.concatenate([1 - column, column], axis=1)

    def fit(self, S, y, X=None):
        """Fit the calibrator on the two-column form of ``S``; ``X`` is ignored."""
        probas = self._as_two_columns(S)
        self.calibrator_ = MetaCalCoverageAcc(self.target_acc)
        self.calibrator_.fit(probas, y)
        return self

    def predict_proba(self, S, X=None):
        """Return column 1 of the calibrator's prediction (``X`` ignored)."""
        probas = self._as_two_columns(S)
        predicted = self.calibrator_.predict(probas)
        return predicted[:, 1]


class MetaCalMisRecalibration(PostTraining):
    """Post training method wrapping ``MetaCalMisCoverage``.

    NOTE(review): the ``metacal`` import at the top of this module is
    commented out, so ``MetaCalMisCoverage`` is not defined in this file and
    ``fit`` raises ``NameError`` unless the name is provided some other way.
    """

    def __init__(self, target_alpha: float = 0.05):
        super().__init__()
        self.target_alpha = target_alpha

    @staticmethod
    def _as_two_columns(S):
        """Expand a 1-D ``S`` into the columns ``[1 - S, S]``; else return it as is."""
        if S.ndim != 1:
            return S
        column = S[:, None]
        return np.concatenate([1 - column, column], axis=1)

    def fit(self, S, y, X=None):
        """Fit the calibrator on the two-column form of ``S``; ``X`` is ignored."""
        probas = self._as_two_columns(S)
        self.calibrator_ = MetaCalMisCoverage(self.target_alpha)
        self.calibrator_.fit(probas, y)
        return self

    def predict_proba(self, S, X=None):
        """Return column 1 of the calibrator's prediction (``X`` ignored)."""
        probas = self._as_two_columns(S)
        predicted = self.calibrator_.predict(probas)
        return predicted[:, 1]


class SigmoidFineTuning(PostTraining):
    """Relearn the sigmoid link between ``X`` and the output.

    This method assumes that, in the original model, ``S`` is obtained from
    ``X`` through a sigmoid function (e.g. the last layer of a neural
    network). It "fine-tunes" the model by fitting a new logistic regression
    on ``X``, as an alternative to a pytorch-based approach. ``S`` itself is
    never used.
    """

    def fit(self, S, y, X):
        """Fit a default ``LogisticRegression`` on ``(X, y)``; return ``self``."""
        self.sigmoid_ = model = LogisticRegression()
        model.fit(X, y)
        return self

    def predict_proba(self, S, X):
        """Return column 1 of the fitted model's ``predict_proba`` on ``X``."""
        probas = self.sigmoid_.predict_proba(X)
        return probas[:, 1]


class GLAR(PostTraining):
    """The Grouping Loss Adaptative Recalibration defined in the paper.

    The partitioner is fitted on a training part of the data; the local
    averages it evaluates on a validation part are then used as predictions.
    The partitioner passed in is fitted in place (it is not copied).
    """

    def __init__(
        self,
        partitioner: glest.Partitioner,
        train_size: float = 0.5,
        random_state: int | None = None,
    ) -> None:
        super().__init__()
        self.partitioner = partitioner
        self.train_size = train_size
        self.random_state = random_state

    def fit(self, S, y, X, val_data=None):
        """Fit the partitioner and store the validation local averages.

        ``val_data``, when given, is an ``(S_val, y_val, X_val)`` tuple used
        as validation set; otherwise ``(S, y, X)`` is split internally.
        """
        train_part, val_part = self.get_train_val_data(S, y, X, val_data=val_data)
        S_fit, y_fit, X_fit = train_part
        S_held, y_held, X_held = val_part

        # The partition is learned on the training part only...
        self.partitioner.fit(X_fit, S_fit, y_fit)

        # ...and its local averages are measured on the validation part.
        held_labels = self.partitioner.predict(X_held, S_held)
        averages, _ = self.partitioner.evaluate(S_held, held_labels, [y_held])
        [local_frac_pos] = averages

        # Kept for later use by ``predict_proba``.
        self.frac_pos_val_ = local_frac_pos

        return self

    def get_train_val_data(self, S, y, X, val_data=None):
        """Return ``(S, y, X)`` triplets for the training and validation parts.

        With ``val_data`` the inputs are used whole as training part;
        otherwise they are split with ``train_test_split`` using
        ``train_size`` and ``random_state``.
        """
        if val_data is not None:
            S_val, y_val, X_val = val_data
            return (S, y, X), (S_val, y_val, X_val)

        X_train, X_val, S_train, S_val, y_train, y_val = train_test_split(
            X, S, y, train_size=self.train_size, random_state=self.random_state
        )
        return (S_train, y_train, X_train), (S_val, y_val, X_val)

    def predict_proba(self, S, X):
        """Map each sample to the stored local average of its partition cell."""
        return self.labels_to_probas(self.partitioner.predict(X, S))

    def labels_to_probas(self, labels):
        """Look up ``labels`` in the local averages stored by ``fit``."""
        return self.frac_pos_to_probas(self.frac_pos_val_, labels)

    @staticmethod
    def frac_pos_to_probas(frac_pos, labels):
        """Return ``frac_pos`` indexed by the two columns of ``labels``."""
        first_idx = labels[:, 0]
        second_idx = labels[:, 1]
        return frac_pos[first_idx, second_idx]


class GLARThresholded(GLAR):
    """GLAR that only corrects where a regret bound exceeds a threshold.

    ``predict_proba`` returns the partition-based local average where the
    mask built by ``mask_probas`` is true, and the isotonic-recalibrated
    value elsewhere.

    ``mode`` selects how the mask is built from the regret bound ``RGL_``:

    * ``"binwise"``: per-bin bound greater than ``GL_min_bin``;
    * ``"all"``: bin-aggregated bound greater than ``GL_min``;
    * ``"both"``: both conditions at once.

    ``GL_min_bin`` defaults to ``GL_min`` when left to ``None``.

    NOTE(review): the annotations allow ``None`` thresholds, but
    ``mask_probas`` compares against them with ``>`` — confirm ``None`` is
    never actually passed.
    """

    def __init__(
        self,
        partitioner: Partitioner,
        train_size: float = 0.5,
        random_state: int | None = None,
        GL_min: float | None = 0,
        GL_min_bin: float | None = None,
        mode: str = "all",
    ) -> None:
        super().__init__(partitioner, train_size, random_state)
        self.GL_min = GL_min
        self.GL_min_bin = GL_min if GL_min_bin is None else GL_min_bin
        self.recalibrator = SklearnRecalibration(method="isotonic")
        self.mode = mode
        assert mode in ["all", "binwise", "both"]

    def fit(self, S, y, X, val_data: tuple | None = None):
        """Fit the recalibrator, the partitioner and the grouping loss estimator.

        ``val_data``, when given, is an ``(S_val, y_val, X_val)`` tuple used
        as validation set; otherwise ``(S, y, X)`` is split internally.
        Returns ``self``.
        """
        (S_train, y_train, X_train), (S_val, y_val, X_val) = self.get_train_val_data(
            S, y, X, val_data=val_data
        )

        # Fit recalibrator on the whole data (training + validation)
        if val_data is None:
            self.recalibrator.fit(S, y, X)
        else:
            self.recalibrator.fit(
                np.concatenate([S_train, S_val]),
                np.concatenate([y_train, y_val]),
                np.concatenate([X_train, X_val]),
            )

        # Fit the partitioner on the training set
        self.partitioner.fit(X_train, S_train, y_train)

        # Eval the local averages and variances on the validation set
        labels_val = self.partitioner.predict(X_val, S_val)
        [frac_pos_val], counts_val = self.partitioner.evaluate(
            S_val, labels_val, [y_val], aggregate="mean"
        )
        [var_val], _ = self.partitioner.evaluate(
            S_val, labels_val, [y_val], aggregate="var"
        )

        # Bin-level averages; empty bins are kept so that rows stay aligned
        # with ``frac_pos_val``.
        frac_pos_bin_val = glest_calibration_curve(
            frac_pos_val, counts_val, None, remove_empty=False, return_mean_bins=False
        )

        self.frac_pos_ = frac_pos_val
        self.frac_pos_bin_ = frac_pos_bin_val
        self.var_ = var_val
        self.counts_ = counts_val
        self.frac_pos_shrinked_ = self.shrink_frac_pos(
            self.frac_pos_, self.frac_pos_bin_, self.var_, self.counts_
        )

        # Estimate GL based on the fitted partitioner
        self.gle_ = glest.GLEstimator(
            S,
            self.partitioner,
            use_fitted=True,
            train_size=self.train_size,
            random_state=self.random_state,
        )
        if val_data is not None:
            # val_data is (S, y, X); it is handed over reordered as (X, y, S).
            test_data = (val_data[2], val_data[1], val_data[0])
        else:
            test_data = val_data
        self.gle_.fit(X, y, test_data=test_data)

        return self

    def predict_proba(self, S, X, t: np.ndarray):
        """Return the thresholded predictions, one column per entry of the mask.

        ``t`` is forwarded to ``compute_regret_GL_UB`` through
        ``mask_probas`` (presumably the decision thresholds — confirm).
        The result has shape (n, k) when the mask has shape (n, k).
        """
        labels = self.partitioner.predict(X, S)
        R = self.frac_pos_to_probas(self.frac_pos_shrinked_, labels)  # (n,)
        C = self.recalibrator.predict_proba(S)  # (n,)
        M = self.mask_probas(labels[:, 0], t)  # (n, k)

        assert not np.isnan(R).any()
        assert not np.isnan(C).any()
        assert not np.isnan(M).any()

        # R where M is true else C; the transposes broadcast the (n,) vectors
        # against the (n, k) mask.
        return np.where(M.T, R, C).T  # (n, k)

    def mask_probas(self, bin_ids: np.ndarray, t: np.ndarray):
        """Return the boolean mask telling where the GLAR correction applies.

        ``bin_ids`` holds the bin index of every sample. The regret bound is
        stored in ``self.RGL_`` as a side effect. For an unrecognised
        ``mode`` an all-false mask of shape ``bin_ids.shape`` is returned.
        """
        GL_bin = self.gle_.GL(reduce_bin=False)
        bin_weights = self.counts_.sum(axis=1)
        self.RGL_ = compute_regret_GL_UB(
            self.frac_pos_bin_, GL_bin, t
        )  # (n_bins, k)

        if self.mode == "binwise":
            return (self.RGL_ > self.GL_min_bin)[bin_ids, :]

        if self.mode in ("all", "both"):
            RGL_agg = np.average(self.RGL_, axis=0, weights=bin_weights)  # (k,)
            # The aggregated test is identical for every bin: broadcast it to
            # (n_bins, k) so that it can be indexed by bin like the per-bin
            # one (indexing the 1-D array with two indices raises IndexError).
            selected = np.broadcast_to(RGL_agg > self.GL_min, self.RGL_.shape)
            if self.mode == "both":
                selected = np.logical_and(self.RGL_ > self.GL_min_bin, selected)
            return selected[bin_ids, :]  # (n, k)

        return np.full(bin_ids.shape, False)

    @staticmethod
    def shrink_frac_pos(
        frac_pos: np.ndarray,
        frac_pos_bin: np.ndarray,
        var: np.ndarray,
        counts: np.ndarray,
    ):
        """Blend the local averages with their bin-level average.

        The blending weight is currently fixed to 1, so the result equals
        ``frac_pos`` wherever ``frac_pos_bin`` is finite. ``var`` and
        ``counts`` are accepted for a data-driven weight but are unused for
        now, so no variance is computed here.
        """
        weights = 1
        return weights * frac_pos + (1 - weights) * frac_pos_bin[:, None]

    @staticmethod
    def weighted_var(a, weights, axis=0):
        """Return the weighted variance of ``a`` along ``axis``.

        Both the mean and the squared deviations are averaged with
        ``weights``; the result is scaled by ``n / (n - 1)`` with ``n`` the
        number of entries along ``axis``. NaN is returned when ``n < 2``,
        where that correction is undefined.
        """
        mean = np.average(a, weights=weights, axis=axis, keepdims=True)
        var = np.average(np.square(a - mean), weights=weights, axis=axis)
        n = a.shape[axis]
        if n < 2:
            return np.full(var.shape, np.nan)
        return n / (n - 1) * var


class FineTuning(PostTraining):
    """Placeholder for fine tuning the model: both methods are no-ops."""

    def fit(self, S, y, X):
        """Do nothing and return ``None``."""
        return None

    def predict_proba(self, S, X):
        """Do nothing and return ``None``."""
        return None


class PartitionerDict:
    """Named ``glest.Partitioner`` instances built on decision trees.

    All partitioners share the same settings and only differ by the
    constraint put on their ``DecisionTreeRegressor``.
    """

    def __init__(self) -> None:
        seed = 0

        # Settings common to every partitioner.
        shared_kwargs = {
            "predict_method": "apply",
            "n_bins": 15,
            "strategy": "quantile",
            "binwise_fit": True,
            "verbose": 10,
            "raise_on_fit_error": True,
        }

        # Tree constraint of each named partitioner.
        # NOTE(review): "depth5" constrains min_samples_leaf=100 rather than
        # a depth — confirm the key name is intentional.
        tree_params = {
            "leaves5": {"max_leaf_nodes": 5},
            "unconstrained": {},
            "depth5": {"min_samples_leaf": 100},
            "depth10": {"max_depth": 10},
        }

        self.partitioners = {
            name: glest.Partitioner(
                estimator=DecisionTreeRegressor(random_state=seed, **params),
                **shared_kwargs,
            )
            for name, params in tree_params.items()
        }

    def __getitem__(self, name: str) -> glest.Partitioner:
        """Return the partitioner registered under ``name``."""
        partitioner = self.partitioners[name]
        return partitioner


class PostTrainingDict:
    """Registry of ready-to-use post training methods, looked up by name.

    On top of the fixed entries, three GLAR-family methods are registered
    for every partitioner name of ``PartitionerDict``: ``glar_<name>``,
    ``glart_both2_<name>`` and ``glart_bw2_<name>``.
    """

    def __init__(self) -> None:
        rs = 0

        self.post_training = {
            "none": Identity(),
            "recal_isotonic": SklearnRecalibration(method="isotonic"),
            "recal_sigmoid": SklearnRecalibration(method="sigmoid"),
            "recal_hist10": HistogramBinningRecalibration(n_bins=10),
            "recal_hist15": HistogramBinningRecalibration(n_bins=15),
            "recal_scalbin10": PlattBinnerRecalibration(n_bins=10),
            "recal_scalbin15": PlattBinnerRecalibration(n_bins=15),
            "recal_metacal_acc90": MetaCalAccRecalibration(target_acc=0.9),
            "recal_metacal_mis05": MetaCalMisRecalibration(target_alpha=0.05),
            "stack_logistic": Stacking(LogisticRegression(random_state=rs)),
            "stack_rf": Stacking(RandomForestClassifier(random_state=rs)),
            "stack_hgb": Stacking(HistGradientBoostingClassifier(random_state=rs)),
            "stack_naive_bayes": Stacking(GaussianNB()),
            "stack_gaussian_process": Stacking(
                GaussianProcessClassifier(random_state=rs)
            ),
            "finetuning_sigmoid": SigmoidFineTuning(),
            "finetuning": FineTuning(),
        }

        # Add glar. The GLAR-family methods fit the partitioner they are
        # given in place, so each family gets its own PartitionerDict:
        # sharing a single partitioner object would let fitting one method
        # silently refit the partitioner used by the two others.
        glar_partitioners = PartitionerDict()
        both_partitioners = PartitionerDict()
        binwise_partitioners = PartitionerDict()
        for name in glar_partitioners.partitioners:
            self.post_training[f"glar_{name}"] = GLAR(
                glar_partitioners[name], random_state=rs
            )
            self.post_training[f"glart_both2_{name}"] = GLARThresholded(
                both_partitioners[name],
                random_state=rs,
                GL_min=0.02,
                mode="both",
            )
            self.post_training[f"glart_bw2_{name}"] = GLARThresholded(
                binwise_partitioners[name],
                random_state=rs,
                GL_min=0.02,
                mode="binwise",
            )

    def __getitem__(self, name: str) -> PostTraining:
        """Return the post training method registered under ``name``."""
        return self.post_training[name]
