from typing import Any, Optional, Tuple, Union

import numpy as np
import scipy.optimize as optimize
import scipy.stats as stats

from ..fab import fabzCI
from ..models import (
    BayesianGaussianModel,
    GaussianGaussianModel,
    HorseshoeGaussianModel,
)


def ppi_fab_quantile_ci(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    q: float = 0.5,
    prior: str = "horseshoe",
    alpha: float = 0.1,
    delta: Optional[float] = None,
    point_estimate: bool = False,
    exact_grid: bool = False,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float]]]:
    """Confidence interval for the ``q``-quantile using a FAB rectifier interval.

    The rectified CDF ingredients are evaluated on a grid by
    ``ppi_fab_quantile_pointestimate``. A grid point is accepted when ``q``
    passes the test in ``_isin_minkowski_sum_fab``, which combines a
    ``fabzCI`` interval for the rectifier (under the prior ``model``) with a
    normal interval for the imputed CDF.

    Args:
        Y: Labeled outcomes (must be 1-D).
        Yhat: Predictions paired row-by-row with ``Y``.
        Yhat_unlabeled: Predictions on the unlabeled points.
        q: Target quantile level.
        prior: Prior for the rectifier, ``"horseshoe"`` or ``"gaussian"``.
        alpha: Total miscoverage level.
        delta: Part of ``alpha`` spent on the rectifier interval; the rest,
            ``alpha - delta``, goes to the imputed term. When ``None`` it is
            chosen in ``(0, alpha)`` to minimize the interval width.
        point_estimate: If True, ignore the uncertainty of the imputed term.
            ``delta`` is then forced to ``alpha``.
        exact_grid: Use every observed value as a grid point instead of a
            linearly spaced grid.
        return_aux: Also return ``(estimate,)``.
        **kwargs: Forwarded to ``ppi_fab_quantile_pointestimate`` (e.g. ``t2``
            for the gaussian prior).

    Returns:
        ``(lower, upper)``, the smallest and largest accepted grid points, or
        ``(nan, nan)`` if no grid point is accepted. With ``return_aux``, the
        result is ``((lower, upper), (estimate,))``.
    """
    (
        estimate,
        (
            grid,
            imputed_mean,
            rectifier_mean,
            imputed_var,
            _rectifier_var,  # unused here; it was passed to the prior when ``model`` was built
            model,
        ),
    ) = ppi_fab_quantile_pointestimate(
        Y,
        Yhat,
        Yhat_unlabeled,
        q,
        prior,
        exact_grid=exact_grid,
        return_aux=True,
        **kwargs,
    )

    def _ci(delta):
        in_region = _isin_minkowski_sum_fab(
            model,
            imputed_mean,
            imputed_var,
            rectifier_mean,
            alpha,
            delta,
            null=q,
            point_estimate=point_estimate,
        )
        interval = grid[in_region]  # grid is sorted, so the ends are the extremes

        # Test emptiness by size: ``np.any`` on the values would misreport a
        # region whose only accepted grid point is 0.0 as empty.
        if interval.size == 0:
            return np.nan, np.nan
        return interval[0], interval[-1]

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lower, upper = _ci(delta)
            width = upper - lower
            # An empty region yields NaN; give the optimizer a comparable value.
            return np.inf if np.isnan(width) else width

        # ``method="bounded"`` is explicit because some SciPy versions ignore
        # (or reject) ``bounds`` under the default Brent method.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    ci = _ci(delta)
    if return_aux:
        return ci, (estimate,)
    return ci


def ppi_split_quantile_ci(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    q: float = 0.5,
    alpha: float = 0.1,
    delta: Optional[float] = None,
    point_estimate: bool = False,
    exact_grid: bool = False,
    return_aux: bool = False,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float]]]:
    """Confidence interval for the ``q``-quantile with a split normal region.

    The rectified CDF ingredients are evaluated on a grid by
    ``ppi_quantile_pointestimate``. A grid point is accepted when ``q`` lies
    within the normal Minkowski-sum region built by ``_isin_minkowski_sum``:
    level ``delta`` goes to the rectifier and ``alpha - delta`` to the
    imputed term.

    Args:
        Y: Labeled outcomes (must be 1-D).
        Yhat: Predictions paired row-by-row with ``Y``.
        Yhat_unlabeled: Predictions on the unlabeled points.
        q: Target quantile level.
        alpha: Total miscoverage level.
        delta: Part of ``alpha`` spent on the rectifier. When ``None`` it is
            chosen in ``(0, alpha)`` to minimize the interval width.
        point_estimate: If True, ignore the uncertainty of the imputed term.
            ``delta`` is then forced to ``alpha``.
        exact_grid: Use every observed value as a grid point instead of a
            linearly spaced grid.
        return_aux: Also return ``(estimate,)``.

    Returns:
        ``(lower, upper)``, the smallest and largest accepted grid points, or
        ``(nan, nan)`` if no grid point is accepted. With ``return_aux``, the
        result is ``((lower, upper), (estimate,))``.
    """
    (
        estimate,
        (
            grid,
            imputed_mean,
            rectifier_mean,
            imputed_var,
            rectifier_var,
        ),
    ) = ppi_quantile_pointestimate(
        Y,
        Yhat,
        Yhat_unlabeled,
        q,
        exact_grid=exact_grid,
        return_aux=True,
    )

    def _ci(delta):
        in_region = _isin_minkowski_sum(
            imputed_mean,
            imputed_var,
            rectifier_mean,
            rectifier_var,
            alpha,
            delta,
            null=q,
            point_estimate=point_estimate,
        )
        interval = grid[in_region]  # grid is sorted, so the ends are the extremes

        # Previously an empty region raised IndexError; mirror the FAB variant.
        if interval.size == 0:
            return np.nan, np.nan
        return interval[0], interval[-1]

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lower, upper = _ci(delta)
            width = upper - lower
            # An empty region yields NaN; give the optimizer a comparable value.
            return np.inf if np.isnan(width) else width

        # ``method="bounded"`` is explicit because some SciPy versions ignore
        # (or reject) ``bounds`` under the default Brent method.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    ci = _ci(delta)
    if return_aux:
        return ci, (estimate,)
    return ci


def ppi_fab_quantile_pointestimate(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    q: float = 0.5,
    prior: str = "horseshoe",
    exact_grid: bool = False,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[float, Tuple[float, tuple]]:
    """Point estimate of the ``q``-quantile with a prior-shrunk rectifier.

    The empirical CDF of ``Yhat_unlabeled`` is evaluated on a grid. The
    per-grid-point rectifier mean ``mean(1{Y < g} - 1{Yhat < g})`` is replaced
    by the posterior mean of the prior ``model``. The estimate is the first
    grid point where ``imputed_mean + rectifier_est`` is closest to ``q``.

    Args:
        Y: Labeled outcomes (must be 1-D).
        Yhat: Predictions paired row-by-row with ``Y``.
        Yhat_unlabeled: Predictions on the unlabeled points.
        q: Target quantile level.
        prior: ``"horseshoe"`` (``HorseshoeGaussianModel``) or ``"gaussian"``
            (``GaussianGaussianModel`` with prior variance
            ``t2 * rectifier_var``).
        exact_grid: Use every observed value as a grid point instead of a
            1000-point linearly spaced grid.
        return_aux: Also return the grid-level ingredients.
        **kwargs: Only ``t2`` (default 1.0, gaussian prior) is read.

    Returns:
        ``estimate``, or with ``return_aux``
        ``(estimate, (grid, imputed_mean, rectifier_mean, imputed_var,
        rectifier_var, model))``. Both variances are the variance of the
        per-point indicators divided by the sample size.

    Raises:
        ValueError: If ``prior`` is not a supported name.
    """
    assert len(Y.shape) == 1
    # Validate before the O(n * len(grid)) work. Previously an unknown prior
    # only surfaced later as an UnboundLocalError on ``model``.
    if prior not in ("horseshoe", "gaussian"):
        raise ValueError(
            f"Unknown prior {prior!r}; expected 'horseshoe' or 'gaussian'."
        )
    n = len(Y)

    grid = np.concatenate([Y, Yhat, Yhat_unlabeled])
    if exact_grid:
        grid = np.sort(grid)
    else:
        grid = np.linspace(grid.min(), grid.max(), 1000)  # coarse grid, chosen for speed

    imputed = (Yhat_unlabeled[:, None] < grid).astype(float)
    imputed_mean = imputed.mean(axis=0)

    rectifier = (Y[:, None] < grid).astype(float) - (Yhat[:, None] < grid).astype(float)
    rectifier_mean = rectifier.mean(axis=0)
    rectifier_var = rectifier.var(axis=0) / n

    if prior == "horseshoe":
        model = HorseshoeGaussianModel(rectifier_var)  # [grid]
    else:  # "gaussian", guaranteed by the check above
        t2 = kwargs.get("t2", 1.0)
        model = GaussianGaussianModel(rectifier_var, t2 * rectifier_var)

    rectifier_est = model.posterior_mean(rectifier_mean)

    rectified_cdf = imputed_mean + rectifier_est  # [grid]
    # With no axis, np.argmin always returns a single (first) index, so a
    # plain int() works on every platform's integer width.
    minimizer = int(np.argmin(np.abs(rectified_cdf - q)))
    estimate = grid[minimizer]

    if return_aux:
        N = len(Yhat_unlabeled)
        imputed_var = imputed.var(axis=0) / N  # [grid]
        return estimate, (
            grid,
            imputed_mean,
            rectifier_mean,
            imputed_var,
            rectifier_var,
            model,
        )
    return estimate


def ppi_quantile_pointestimate(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    q: float = 0.5,
    exact_grid: bool = False,
    return_aux: bool = False,
) -> Union[float, Tuple[float, tuple]]:
    """Prediction-powered point estimate of the ``q``-quantile.

    The empirical CDF of ``Yhat_unlabeled`` is evaluated on a grid and
    corrected by the rectifier mean ``mean(1{Y < g} - 1{Yhat < g})``. The
    estimate is the first grid point where the rectified CDF is closest to
    ``q``.

    Args:
        Y: Labeled outcomes (must be 1-D).
        Yhat: Predictions paired row-by-row with ``Y``.
        Yhat_unlabeled: Predictions on the unlabeled points.
        q: Target quantile level.
        exact_grid: Use every observed value (sorted) as a grid point instead
            of a 5000-point linearly spaced grid.
        return_aux: Also return the grid-level ingredients.

    Returns:
        ``estimate``, or with ``return_aux``
        ``(estimate, (grid, imputed_mean, rectifier_mean, imputed_var,
        rectifier_var))``. Each variance is the variance of the per-point
        indicators divided by the corresponding sample size.
    """
    assert len(Y.shape) == 1

    grid = np.concatenate([Y, Yhat, Yhat_unlabeled])
    if exact_grid:
        grid = np.sort(grid)
    else:
        grid = np.linspace(grid.min(), grid.max(), 5000)

    imputed = (Yhat_unlabeled[:, None] < grid).astype(float)
    imputed_mean = imputed.mean(axis=0)

    rectifier = (Y[:, None] < grid).astype(float) - (Yhat[:, None] < grid).astype(float)
    rectifier_mean = rectifier.mean(axis=0)

    rectified_cdf = imputed_mean + rectifier_mean  # [grid]
    # With no axis, np.argmin always returns a single (first) index, so a
    # plain int() works on every platform's integer width.
    minimizer = int(np.argmin(np.abs(rectified_cdf - q)))
    estimate = grid[minimizer]

    if not return_aux:
        return estimate

    n = len(Y)
    N = len(Yhat_unlabeled)
    imputed_var = imputed.var(axis=0) / N  # [grid]
    rectifier_var = rectifier.var(axis=0) / n  # [grid]
    return estimate, (
        grid,
        imputed_mean,
        rectifier_mean,
        imputed_var,
        rectifier_var,
    )


def _isin_minkowski_sum(
    imputed_mean: Union[float, np.ndarray],
    imputed_var: Union[float, np.ndarray],
    rectifier_mean: Union[float, np.ndarray],
    rectifier_var: Union[float, np.ndarray],
    alpha: float,
    delta: float,
    null: float = 0.0,
    point_estimate: bool = False,
):
    """Pointwise test of whether ``null`` lies in a normal Minkowski-sum region.

    The region is centred at ``imputed_mean + rectifier_mean``. Its half-width
    is the two-sided normal half-width of the rectifier at level ``delta``
    plus, unless ``point_estimate`` is set, that of the imputed term at level
    ``alpha - delta``.

    Returns:
        Boolean (array) that is True where the distance from the centre to
        ``null`` is strictly less than the half-width.
    """
    offset = imputed_mean + rectifier_mean - null

    # -ppf(p/2) is positive for p < 1, so each term below is a non-negative width.
    half_width = -stats.norm.ppf(delta / 2) * np.sqrt(rectifier_var)
    if not point_estimate:
        half_width = half_width - stats.norm.ppf((alpha - delta) / 2) * np.sqrt(imputed_var)

    return np.abs(offset) < half_width


def _isin_minkowski_sum_fab(
    model: BayesianGaussianModel,
    imputed_mean: Union[float, np.ndarray],
    imputed_var: Union[float, np.ndarray],
    rectifier_mean: Union[float, np.ndarray],
    alpha: float,
    delta: float,
    null: float = 0.0,
    point_estimate: bool = False,
):
    """Pointwise test of whether ``null`` lies in a FAB Minkowski-sum region.

    The rectifier interval ``(lower, upper)`` comes from ``fabzCI`` at level
    ``delta``. Unless ``point_estimate`` is set, it is widened on both sides
    by the two-sided normal half-width of the imputed term at level
    ``alpha - delta``. The test accepts where
    ``lower < null - imputed_mean < upper`` (strict).

    NOTE(review): ``fabzCI`` is assumed to return per-element lower/upper
    bounds broadcastable against ``imputed_mean``; confirm in ``..fab``.
    """
    target = -(imputed_mean - null)

    ci_low, ci_high = fabzCI(model, rectifier_mean, delta)
    lower = np.asarray(ci_low)
    upper = np.asarray(ci_high)

    if not point_estimate:
        # ppf((alpha - delta) / 2) is non-positive, so adding it lowers the
        # lower bound and subtracting it raises the upper bound.
        margin = stats.norm.ppf((alpha - delta) / 2) * np.sqrt(imputed_var)
        lower = margin + lower
        upper = upper - margin

    return np.logical_and(target > lower, target < upper)
