from typing import Any, Optional, Tuple, Union

import numpy as np
import scipy.optimize as optimize
import scipy.stats as stats
from ppi_py.utils import reshape_to_2d
from statsmodels.stats.weightstats import _zconfint_generic

from ..fab import fabzCI
from ..models import (
    GaussianGaussianModel,
    HorseshoeGaussianModel,
    ScaledHorseshoeGaussianModel,
)


def ppi_fab_mean_ci(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    prior: str = "horseshoe",
    alpha: float = 0.1,
    delta: Optional[float] = None,
    lam: Optional[float] = None,
    point_estimate: bool = False,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]]]:
    """Confidence interval for the mean, splitting the error budget in two.

    The interval is centred on ``imputed_mean`` (the mean of
    ``Yhat_unlabeled``) and is built from two parts:

    * an interval for the rectifier obtained from ``fabzCI`` at level
      ``delta``, using the shrinkage model from ``ppi_fab_mean_pointestimate``;
    * unless ``point_estimate`` is set, a normal interval for the imputed mean
      at level ``alpha - delta``.

    The two levels sum to ``alpha``.

    Args:
        Y: Labels of the labelled sample.
        Yhat: Predictions on the labelled sample.
        Yhat_unlabeled: Predictions on the unlabelled sample.
        prior: Prior name forwarded to ``ppi_fab_mean_pointestimate``.
        alpha: Total miscoverage budget.
        delta: Budget given to the rectifier interval. When ``None`` it is
            chosen in ``(0, alpha)`` to minimise the (dimension-averaged)
            interval length.
        lam: Tuning parameter forwarded to the point estimator; estimated
            from the data when ``None``.
        point_estimate: If True, no width is added for the imputed mean and
            the whole budget ``alpha`` goes to the rectifier interval.
        return_aux: If True, also return ``(estimate, lam)``.
        **kwargs: Prior hyper-parameters forwarded to the point estimator.

    Returns:
        ``(lower, upper)``, or ``((lower, upper), (estimate, lam))`` when
        ``return_aux`` is True.
    """
    Y = reshape_to_2d(Y)
    Yhat = reshape_to_2d(Yhat)
    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)

    (
        estimate,
        (
            imputed_mean,
            rectifier_mean,
            imputed_var,
            _,
            lam,
            model,
        ),
    ) = ppi_fab_mean_pointestimate(
        Y, Yhat, Yhat_unlabeled, prior, lam, return_aux=True, **kwargs
    )

    def _ci(delta):
        # Offsets of the rectifier interval, added to imputed_mean below.
        l_delta, u_delta = fabzCI(model, rectifier_mean, delta)
        if point_estimate:
            l_f = 0.0
        else:
            # norm.ppf(q / 2) < 0 for q < 1, so l_f is a negative half-width:
            # the lower end adds it and the upper end subtracts it.
            l_f = stats.norm.ppf((alpha - delta) / 2) * np.sqrt(imputed_var)

        return imputed_mean + l_delta + l_f, imputed_mean + u_delta - l_f

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lo, hi = _ci(delta)
            return np.mean(hi - lo)

        # Request the bounded solver explicitly. Older SciPy releases ignore
        # ``bounds`` under the default Brent method, which could step outside
        # (0, alpha) and make norm.ppf return NaN.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    if return_aux:
        return _ci(delta), (estimate, lam)
    return _ci(delta)


def ppi_split_mean_ci(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    alpha: float = 0.1,
    delta: Optional[float] = None,
    lam: Optional[float] = None,
    point_estimate: bool = False,
    return_aux: bool = False,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]]]:
    """Split-budget normal confidence interval for the mean.

    The interval is centred on the estimate from ``ppi_mean_pointestimate``,
    ``lam * mean(Yhat_unlabeled) + mean(Y - lam * Yhat)``. It is widened by:

    * a normal half-width for the rectifier term at level ``delta``, using
      ``rectifier_var``;
    * unless ``point_estimate`` is set, a normal half-width for the imputed
      term at level ``alpha - delta``, using ``imputed_var``.

    The two levels sum to ``alpha``.

    Args:
        Y: Labels of the labelled sample.
        Yhat: Predictions on the labelled sample.
        Yhat_unlabeled: Predictions on the unlabelled sample.
        alpha: Total miscoverage budget.
        delta: Budget given to the rectifier term. When ``None`` it is chosen
            in ``(0, alpha)`` to minimise the (dimension-averaged) length.
        lam: Tuning parameter forwarded to the point estimator; estimated
            from the data when ``None``.
        point_estimate: If True, no width is added for the imputed term and
            the whole budget ``alpha`` goes to the rectifier term.
        return_aux: If True, also return ``(estimate, lam)``.

    Returns:
        ``(lower, upper)``, or ``((lower, upper), (estimate, lam))`` when
        ``return_aux`` is True.
    """
    Y = reshape_to_2d(Y)
    Yhat = reshape_to_2d(Yhat)
    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)

    (
        estimate,
        (
            _,
            _,
            imputed_var,
            rectifier_var,
            lam,
        ),
    ) = ppi_mean_pointestimate(Y, Yhat, Yhat_unlabeled, lam, return_aux=True)

    def _ci(delta):
        # norm.ppf(q / 2) < 0 for q < 1, so half_width is negative and the
        # interval is (estimate + half_width, estimate - half_width).
        half_width = stats.norm.ppf(delta / 2) * np.sqrt(rectifier_var)

        if not point_estimate:
            # The imputed term gets the remaining alpha - delta budget and is
            # scaled by its own standard error, not the rectifier's.
            half_width = half_width + stats.norm.ppf((alpha - delta) / 2) * np.sqrt(
                imputed_var
            )

        return estimate + half_width, estimate - half_width

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lo, hi = _ci(delta)
            return np.mean(hi - lo)

        # Request the bounded solver explicitly. Older SciPy releases ignore
        # ``bounds`` under the default Brent method, which could step outside
        # (0, alpha) and make norm.ppf return NaN.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    if return_aux:
        return _ci(delta), (estimate, lam)
    return _ci(delta)


def ppi_fab_mean_pointestimate(
    Y,
    Yhat,
    Yhat_unlabeled,
    prior: str = "horseshoe",
    lam: Optional[float] = None,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[float, Tuple[float, tuple]]:
    """Point estimate of the mean with a shrunk rectifier.

    The estimate is ``imputed_mean + model.posterior_mean(rectifier_mean)``,
    where:

    * ``imputed_mean`` is the mean of ``Yhat_unlabeled``;
    * ``rectifier_mean`` is the mean of
      ``Y - lam * Yhat + (lam - 1) * imputed_mean``;
    * ``model`` is chosen by ``prior``.

    Args:
        Y: Labels of the labelled sample.
        Yhat: Predictions on the labelled sample.
        Yhat_unlabeled: Predictions on the unlabelled sample.
        prior: One of ``"horseshoe"``, ``"gaussian"``,
            ``"scaled_horseshoe"`` or ``"scaled_gaussian"``.
        lam: Tuning parameter. When ``None`` it is set to
            ``cov(Y, Yhat) / ((1 + n / N) * var(Yhat_all))``, clipped to
            ``[0, 1]``, or to 0 if some coordinate of the predictions has
            zero variance.
        return_aux: If True, also return intermediate quantities.
        **kwargs: ``t2`` for the Gaussian priors and ``l2`` for
            ``"scaled_horseshoe"``; both default to 1.0.

    Returns:
        ``estimate``, or ``(estimate, (imputed_mean, rectifier_mean,
        imputed_var, rectifier_var, lam, model))`` when ``return_aux`` is
        True. ``imputed_var`` is the estimated variance of ``imputed_mean``
        and ``rectifier_var`` that of ``rectifier_mean``.

    Raises:
        ValueError: If ``prior`` is not one of the supported names.
    """
    Y = reshape_to_2d(Y)
    Yhat = reshape_to_2d(Yhat)
    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)

    n = len(Y)
    N = len(Yhat_unlabeled)

    imputed_mean = np.mean(Yhat_unlabeled, axis=0)  # [d]
    # Added to the (n - 1) / (N - 1) denominators so a single sample does not
    # divide by zero.
    eps = np.finfo(imputed_mean.dtype).eps

    if lam is None:
        # TODO: check carefully generalisation to multivariate
        Yhat_all = np.concatenate([Yhat, Yhat_unlabeled])
        var_all = np.var(Yhat_all, ddof=1, axis=0)
        if np.all(var_all > 0):
            cov = np.sum(
                (Y - np.mean(Y, axis=0)) * (Yhat - np.mean(Yhat, axis=0)), axis=0
            ) / (n - 1 + eps)
            lam = np.clip(cov / (1 + n / N) / var_all, 0, 1)
        else:
            # Some prediction coordinate is constant: fall back to lam = 0.
            lam = 0.0

    rectifier = Y - lam * Yhat + (lam - 1.0) * imputed_mean  # [n, d]
    rectifier_mean = np.mean(rectifier, axis=0)  # [d]
    # Variance of rectifier_mean: the labelled term plus the contribution of
    # imputed_mean through its (lam - 1) coefficient.
    rectifier_var = (
        np.var(Y - lam * Yhat, ddof=1, axis=0) / n
        + (lam - 1.0) ** 2 * np.var(Yhat_unlabeled, ddof=1, axis=0) / N
    )  # [d]

    if prior == "horseshoe":
        model = HorseshoeGaussianModel(rectifier_var)
    elif prior == "gaussian":
        # Second argument is scaled by rectifier_var here, whereas
        # "scaled_gaussian" passes t2 through unchanged.
        t2 = kwargs.get("t2", 1.0)
        model = GaussianGaussianModel(rectifier_var, t2 * rectifier_var)
    elif prior == "scaled_horseshoe":
        l2 = kwargs.get("l2", 1.0)
        model = ScaledHorseshoeGaussianModel(rectifier_var, l2)
    elif prior == "scaled_gaussian":
        t2 = kwargs.get("t2", 1.0)
        model = GaussianGaussianModel(rectifier_var, t2)
    else:
        raise ValueError(
            f"Unknown prior {prior!r}; expected one of 'horseshoe', 'gaussian', "
            "'scaled_horseshoe', 'scaled_gaussian'."
        )

    rectifier_est = model.posterior_mean(rectifier_mean)  # [d]

    estimate = imputed_mean + rectifier_est  # [d]

    if return_aux:
        # Variance of imputed_mean, i.e. of the plain (unscaled) mean of
        # Yhat_unlabeled, which enters the estimate with coefficient 1.
        imputed_var = (
            np.sum(np.square(Yhat_unlabeled - imputed_mean), axis=0)
            / (N - 1 + eps)
            / N
        )
        return estimate, (
            imputed_mean,
            rectifier_mean,
            imputed_var,
            rectifier_var,
            lam,
            model,
        )
    return estimate


def ppi_mean_pointestimate(
    Y: np.ndarray,
    Yhat: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    lam: Optional[float] = None,
    return_aux: bool = False,
) -> Union[float, Tuple[float, tuple]]:
    """Prediction-powered point estimate of the mean.

    Computes ``lam * mean(Yhat_unlabeled) + mean(Y - lam * Yhat)`` column-wise.

    Args:
        Y: Labels of the labelled sample.
        Yhat: Predictions on the labelled sample.
        Yhat_unlabeled: Predictions on the unlabelled sample.
        lam: Tuning parameter. When ``None`` it is set to
            ``cov(Y, Yhat) / ((1 + n / N) * var(Yhat_all))``, clipped to
            ``[0, 1]``, or to 0 if some coordinate of the predictions has
            zero variance.
        return_aux: If True, also return intermediate quantities.

    Returns:
        ``estimate``, or ``(estimate, (imputed_mean, rectifier_mean,
        imputed_var, rectifier_var, lam))`` when ``return_aux`` is True.
        Here ``imputed_mean`` is the lam-scaled mean of ``Yhat_unlabeled``;
        ``imputed_var`` and ``rectifier_var`` are the estimated variances of
        ``imputed_mean`` and ``rectifier_mean``.
    """
    Y = reshape_to_2d(Y)
    Yhat = reshape_to_2d(Yhat)
    Yhat_unlabeled = reshape_to_2d(Yhat_unlabeled)

    n = len(Y)
    N = len(Yhat_unlabeled)

    imputed_mean = np.mean(Yhat_unlabeled, axis=0)
    # Added to the (n - 1) / (N - 1) denominators so a single sample does not
    # divide by zero.
    eps = np.finfo(imputed_mean.dtype).eps

    if lam is None:
        Yhat_all = np.concatenate([Yhat, Yhat_unlabeled])
        var_all = np.var(Yhat_all, ddof=1, axis=0)
        if np.all(var_all > 0):
            cov = np.sum(
                (Y - np.mean(Y, axis=0)) * (Yhat - np.mean(Yhat, axis=0)), axis=0
            ) / (n - 1 + eps)
            lam = np.clip(cov / (1 + n / N) / var_all, 0, 1)
        else:
            # Some prediction coordinate is constant: fall back to lam = 0.
            lam = 0.0

    # Out-of-place so a lam that broadcasts to a larger shape does not fail.
    imputed_mean = lam * imputed_mean

    rectifier = Y - lam * Yhat
    rectifier_mean = np.mean(rectifier, axis=0)

    estimate = imputed_mean + rectifier_mean

    if return_aux:
        imputed_var = (
            np.sum(np.square(lam * Yhat_unlabeled - imputed_mean), axis=0)
            / (N - 1 + eps)
            / N
        )
        rectifier_var = (
            np.sum(np.square(rectifier - rectifier_mean), axis=0) / (n - 1 + eps) / n
        )
        return estimate, (
            imputed_mean,
            rectifier_mean,
            imputed_var,
            rectifier_var,
            lam,
        )
    return estimate


def classical_mean_ci(Y, alpha=0.1, alternative="two-sided"):
    """Normal-approximation confidence interval for the column means of ``Y``.

    Uses only the labelled data. The standard error is ``std / sqrt(n)``
    with NumPy's default ``ddof=0`` standard deviation.

    Args:
        Y: Labelled observations (reshaped to 2-D).
        alpha: Miscoverage level.
        alternative: Passed through to statsmodels' ``_zconfint_generic``.

    Returns:
        Whatever ``_zconfint_generic`` returns for the given alternative.
    """
    samples = reshape_to_2d(Y)
    num_samples = len(samples)
    center = samples.mean(axis=0)
    std_err = samples.std(axis=0) / np.sqrt(num_samples)
    return _zconfint_generic(center, std_err, alpha, alternative)
