import typing as tp

import numpy as np
from sklearn.base import BaseEstimator

from generative_prediction_sets.simulation_data import get_k_min

# Type variable for the wrapped estimator. It is bound to scikit-learn's BaseEstimator so
# that the generic wrapper and the factory functions below carry the concrete estimator
# type through their annotations.
EstimatorType = tp.TypeVar("EstimatorType", bound=BaseEstimator)


class TransformedEstimator(BaseEstimator, tp.Generic[EstimatorType]):
    """A wrapper for an estimator that transforms the inputs, outputs, and prediction outputs.

    Parameters:
        estimator: The wrapped estimator. ``fit`` calls ``estimator.fit`` on this very
            object, so it is fitted in place.
        transform: Callable ``(X, y) -> (X, y)`` applied to the data before fitting. It is
            also called with ``y=None`` before predicting, in which case only the
            transformed ``X`` is used.
        predict_fn: Callable ``(estimator, X) -> raw_outputs`` used to obtain predictions
            from the wrapped estimator. Defaults to calling ``estimator.predict(X)``.
        invert_outputs: Callable applied to the raw outputs of ``predict_fn`` before they
            are returned from ``predict``. Defaults to the identity.
    """

    def __init__(
        self,
        estimator: EstimatorType,
        transform: tp.Callable[
            [np.ndarray, np.ndarray | None], tuple[np.ndarray, np.ndarray | None]
        ] = lambda x, y: (x, y),
        predict_fn: tp.Callable[
            [EstimatorType, np.ndarray], np.ndarray
        ] = lambda estimator, X: estimator.predict(X),
        invert_outputs: tp.Callable[[np.ndarray], np.ndarray] = lambda x: x,
    ):
        self.estimator = estimator
        self.transform = transform
        # Every constructor parameter must be stored under its own name: BaseEstimator's
        # get_params() (and therefore clone() and repr()) looks each one up with
        # getattr(self, <parameter name>), and set_params() writes it back the same way.
        self.predict_fn = predict_fn
        self.invert_outputs = invert_outputs
        super().__init__()

    @property
    def predict_method(self):
        """Backward-compatible alias for ``predict_fn``."""
        return self.predict_fn

    @predict_method.setter
    def predict_method(self, value):
        self.predict_fn = value

    def fit(self, X: np.ndarray, y: np.ndarray):
        """Transform ``(X, y)`` and fit the wrapped estimator on the result.

        Returns:
            This wrapper, to allow call chaining.
        """
        X, y = self.transform(X, y)
        self.estimator.fit(X, y)

        return self

    def predict(self, X: np.ndarray):
        """Predict for ``X`` via ``predict_fn`` and post-process with ``invert_outputs``."""
        # No targets exist at prediction time, so the transform receives y=None and its
        # second return value is discarded.
        X, _ = self.transform(X, None)
        raw_outputs = self.predict_fn(self.estimator, X)
        return self.invert_outputs(raw_outputs)


def explode_bernoulli_trials(X: np.ndarray, y: np.ndarray | None = None):
    """If y is a set of bernoulli trials (i.e. a 2D array), explode it into a 1D array, repeating X for each trial.

    If y is None, simply return X unchanged."""
    if y is None:
        return X, y

    flat_Y = y.flatten()
    flat_X = np.repeat(X, y.shape[1], axis=0)

    return flat_X, flat_Y


def BernoulliPEstimator(
    base_estimator: EstimatorType,
) -> TransformedEstimator[EstimatorType]:
    """Build an estimator of the success probability from sets of Bernoulli trials.

    The returned wrapper expects ``y`` to be a 2D array of Bernoulli trials when fitting.
    The trials are exploded with ``explode_bernoulli_trials`` before the base estimator is
    fitted, and predictions are taken from column 1 of the base estimator's
    ``predict_proba`` output.

    Raises:
        ValueError: If ``base_estimator`` has no ``predict_proba`` attribute.
    """
    supports_proba = hasattr(base_estimator, "predict_proba")
    if not supports_proba:
        raise ValueError("Base estimator must have a predict_proba method.")

    def _second_column_probability(fitted, features):
        # Keep only column 1 of the probability matrix (presumably the positive class
        # for a 0/1 target -- depends on the base estimator's class ordering).
        probabilities = fitted.predict_proba(features)
        return probabilities[:, 1]

    return TransformedEstimator(
        base_estimator,
        transform=explode_bernoulli_trials,
        predict_fn=_second_column_probability,
    )


def BernoulliMapEstimator(
    base_estimator: EstimatorType,
) -> TransformedEstimator[EstimatorType]:
    """Build an estimator fitted on a smoothed per-row success rate of Bernoulli trials.

    When fitting, each row of the 2D trial array ``y`` is collapsed to the single target
    ``(row sum + 1) / (number of columns + 2)``. ``X`` is passed through untouched, and
    predictions come from the base estimator's ``predict``.
    """

    def _smoothed_rate_transform(X: np.ndarray, y: np.ndarray | None = None):
        # At prediction time there are no targets, so y stays None.
        if y is not None:
            successes = np.sum(y, axis=1)
            n_trials = y.shape[1]
            y = (successes + 1) / (n_trials + 2)
        return X, y

    return TransformedEstimator(base_estimator, transform=_smoothed_rate_transform)


def BernoulliFirstKMapEstimator(
    base_estimator: EstimatorType,
) -> TransformedEstimator[EstimatorType]:
    """Build an estimator fitted on targets derived from ``get_k_min`` of the trials.

    When fitting, the 2D trial array ``y`` is reduced with
    ``get_k_min(y, no_success_value=y.shape[1])`` and each resulting value ``k`` becomes
    the target ``2 / (k + 2)``. ``X`` is passed through untouched, and predictions come
    from the base estimator's ``predict``.

    NOTE(review): ``get_k_min`` is defined in ``simulation_data``; presumably it returns,
    per row, the number of trials up to the first success -- confirm against that module.
    """

    def _first_success_transform(X: np.ndarray, y: np.ndarray | None = None):
        # At prediction time there are no targets, so y stays None.
        if y is not None:
            n_trials = y.shape[1]
            # Rows without any success are assigned the total number of trials.
            first_k = get_k_min(y, no_success_value=n_trials)
            y = 2 / (first_k + 2)
        return X, y

    return TransformedEstimator(base_estimator, transform=_first_success_transform)
