from typing import Any, Optional, Tuple, Union

import numpy as np
import scipy.optimize as optimize
import scipy.stats as stats
from ppi_py.ppi import _calc_lam_glm, _ols

from ..fab import fabzCI
from ..models import GaussianGaussianModel, HorseshoeGaussianModel


def ppi_fab_ols_ci(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    prior: str = "horseshoe",
    alpha: float = 0.1,
    delta: Optional[float] = None,
    lam: Optional[float] = None,
    point_estimate: bool = False,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]]]:
    """Confidence interval for OLS coefficients combining PPI with a FAB
    interval for the rectifier.

    The miscoverage budget ``alpha`` is split in two: ``fabzCI`` builds an
    interval for the rectifier at level ``delta``, and a normal interval is
    built for the imputed term at level ``alpha - delta``. The two are added
    coordinate-wise around ``imputed_mean``.

    Args:
        X: Labeled features.
        Y: Labels for ``X``.
        Yhat: Predictions for ``X``.
        X_unlabeled: Unlabeled features.
        Yhat_unlabeled: Predictions for ``X_unlabeled``.
        prior: Prior used by ``ppi_fab_ols_pointestimate`` ("horseshoe" or
            "gaussian").
        alpha: Total miscoverage level.
        delta: Miscoverage assigned to the rectifier interval. If ``None``, it
            is chosen in ``(0, alpha)`` to minimise the mean interval width
            across coordinates.
        lam: Power-tuning parameter; tuned inside the point estimate when
            ``None``.
        point_estimate: If True, the imputed term contributes no width and
            all of ``alpha`` is spent on the rectifier interval.
        return_aux: If True, also return ``(estimate, lam)``.
        **kwargs: Forwarded to ``ppi_fab_ols_pointestimate`` (e.g. ``t2``).

    Returns:
        ``(lower, upper)``, or ``((lower, upper), (estimate, lam))`` when
        ``return_aux`` is True.
    """
    (
        estimate,
        (
            imputed_mean,
            rectifier_mean,
            imputed_var,
            _,
            lam,
            model,
        ),
    ) = ppi_fab_ols_pointestimate(
        X, Y, Yhat, X_unlabeled, Yhat_unlabeled, prior, lam, return_aux=True, **kwargs
    )

    def _ci(delta):
        l_delta, u_delta = fabzCI(model, rectifier_mean, delta)
        if point_estimate:
            l_f = 0.0
        else:
            # Lower normal quantile (negative for alpha - delta < 1) scaled by
            # the imputed standard error; covers the imputed term at level
            # alpha - delta.
            l_f = stats.norm.ppf((alpha - delta) / 2) * np.sqrt(imputed_var)

        return imputed_mean + l_delta + l_f, imputed_mean + u_delta - l_f

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lower, upper = _ci(delta)
            return np.mean(upper - lower)

        # method="bounded" is explicit: older SciPy releases default to Brent
        # and ignore `bounds`, which could probe delta > alpha, where the
        # normal quantile above is NaN.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    if return_aux:
        return _ci(delta), (estimate, lam)
    return _ci(delta)


def ppi_split_ols_ci(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    alpha: float = 0.1,
    delta: Optional[float] = None,
    lam: Optional[float] = None,
    point_estimate: bool = False,
    return_aux: bool = False,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float, float]]]:
    """Confidence interval for PPI OLS coefficients with a split budget.

    The miscoverage budget ``alpha`` is split in two: a normal interval for
    the rectifier term at level ``delta`` (using ``rectifier_var``), and a
    normal interval for the imputed term at level ``alpha - delta`` (using
    ``imputed_var``). The two half-widths are added around the point estimate.

    Args:
        X: Labeled features.
        Y: Labels for ``X``.
        Yhat: Predictions for ``X``.
        X_unlabeled: Unlabeled features.
        Yhat_unlabeled: Predictions for ``X_unlabeled``.
        alpha: Total miscoverage level.
        delta: Miscoverage assigned to the rectifier interval. If ``None``, it
            is chosen in ``(0, alpha)`` to minimise the mean interval width
            across coordinates.
        lam: Power-tuning parameter; tuned inside the point estimate when
            ``None``.
        point_estimate: If True, the imputed term contributes no width and
            all of ``alpha`` is spent on the rectifier interval.
        return_aux: If True, also return ``(estimate, lam)``.

    Returns:
        ``(lower, upper)``, or ``((lower, upper), (estimate, lam))`` when
        ``return_aux`` is True.
    """
    (
        estimate,
        (
            _,
            _,
            imputed_var,
            rectifier_var,
            lam,
        ),
    ) = ppi_ols_pointestimate(
        X, Y, Yhat, X_unlabeled, Yhat_unlabeled, lam, return_aux=True
    )

    def _ci(delta):
        # Negative lower quantile for the rectifier term at level delta.
        l = stats.norm.ppf(delta / 2) * np.sqrt(rectifier_var)

        if not point_estimate:
            # The imputed term gets the remaining alpha - delta of the budget,
            # so its width must be scaled by the *imputed* variance.
            l = l + stats.norm.ppf((alpha - delta) / 2) * np.sqrt(imputed_var)

        return estimate + l, estimate - l

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(delta):
            lower, upper = _ci(delta)
            return np.mean(upper - lower)

        # method="bounded" is explicit: older SciPy releases default to Brent
        # and ignore `bounds`, which could probe delta > alpha, where the
        # normal quantile above is NaN.
        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    if return_aux:
        return _ci(delta), (estimate, lam)
    return _ci(delta)


def ppi_fab_ols_pointestimate(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    prior: str = "horseshoe",
    lam: Optional[float] = None,
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[float, Tuple[float, tuple]]:
    """PPI OLS point estimate with the rectifier replaced by a posterior mean.

    The unshrunk estimate is split as ``imputed_theta + rectifier_theta``,
    where ``imputed_theta = OLS(X_unlabeled, Yhat_unlabeled)`` and
    ``rectifier_theta = OLS(X, Y - lam * Yhat) + (lam - 1) * imputed_theta``.
    A prior model is built from the rectifier variance, and the returned
    estimate is ``imputed_theta + model.posterior_mean(rectifier_theta)``.

    Args:
        X: Labeled features.
        Y: Labels for ``X``.
        Yhat: Predictions for ``X``.
        X_unlabeled: Unlabeled features.
        Yhat_unlabeled: Predictions for ``X_unlabeled``.
        prior: "horseshoe" (``HorseshoeGaussianModel``) or "gaussian"
            (``GaussianGaussianModel`` with prior variance
            ``t2 * rectifier_theta_var``).
        lam: Power-tuning parameter. If ``None``, it is tuned with
            ``_calc_lam_glm`` (clipped) from gradients at the ``lam = 1``
            pilot estimate.
        return_aux: If True, also return intermediate quantities.
        **kwargs: ``t2`` (default 1.0) scales the Gaussian prior variance.

    Returns:
        ``estimate``, or ``(estimate, (imputed_theta, rectifier_theta,
        imputed_theta_var, rectifier_theta_var, lam, model))`` when
        ``return_aux`` is True. ``rectifier_theta`` is the value before
        the posterior-mean step.

    Raises:
        ValueError: If ``prior`` is not "horseshoe" or "gaussian".
    """
    if prior not in ("horseshoe", "gaussian"):
        raise ValueError(
            f"Unknown prior {prior!r}; expected 'horseshoe' or 'gaussian'."
        )

    # Standard errors are always required: the prior model below is built
    # from the rectifier variance even when aux outputs are not returned.
    imputed_theta, imputed_theta_se = _ols(
        X_unlabeled, Yhat_unlabeled, return_se=True
    )
    imputed_theta_var = np.square(imputed_theta_se)

    if lam is None:
        # Tune lam from gradients evaluated at the lam = 1 pilot estimate.
        pilot_estimate = imputed_theta + _ols(X, Y - Yhat)
        grads, grads_hat, grads_hat_unlabeled, inv_hessian = _ols_get_stats(
            pilot_estimate,
            X,
            Y,
            Yhat,
            X_unlabeled,
            Yhat_unlabeled,
            use_unlabeled=True,
        )
        lam = _calc_lam_glm(
            grads,
            grads_hat,
            grads_hat_unlabeled,
            inv_hessian,
            clip=True,
        )

    rectifier_theta, rectifier_theta_se = _ols(X, Y - lam * Yhat, return_se=True)
    # Fold (lam - 1) * imputed_theta into the rectifier so that
    # imputed_theta + rectifier_theta equals
    # lam * OLS(X_unlabeled, Yhat_unlabeled) + OLS(X, Y - lam * Yhat).
    # Its variance adds (lam - 1)^2 * imputed variance with no cross term,
    # i.e. the labeled and unlabeled fits are treated as independent.
    rectifier_theta = rectifier_theta + (lam - 1.0) * imputed_theta
    rectifier_theta_var = (
        np.square(rectifier_theta_se) + np.square(lam - 1.0) * imputed_theta_var
    )

    if prior == "horseshoe":
        model = HorseshoeGaussianModel(rectifier_theta_var)
    else:  # prior == "gaussian" (validated above)
        t2 = kwargs.get("t2", 1.0)
        model = GaussianGaussianModel(rectifier_theta_var, t2 * rectifier_theta_var)

    estimate = imputed_theta + model.posterior_mean(rectifier_theta)

    if return_aux:
        return estimate, (
            imputed_theta,
            rectifier_theta,
            imputed_theta_var,
            rectifier_theta_var,
            lam,
            model,
        )
    return estimate


def ppi_ols_pointestimate(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    lam: Optional[float] = None,
    return_aux: bool = False,
) -> Union[float, Tuple[float, tuple]]:
    """PPI point estimate of OLS coefficients.

    The estimate is ``OLS(X_unlabeled, lam * Yhat_unlabeled) +
    OLS(X, Y - lam * Yhat)``.

    Args:
        X: Labeled features.
        Y: Labels for ``X``.
        Yhat: Predictions for ``X``.
        X_unlabeled: Unlabeled features.
        Yhat_unlabeled: Predictions for ``X_unlabeled``.
        lam: Power-tuning parameter. If ``None``, it is tuned with
            ``_calc_lam_glm`` (clipped) from gradients at the ``lam = 1``
            pilot estimate.
        return_aux: If True, also return per-term variances.

    Returns:
        ``estimate``, or ``(estimate, (imputed_theta, rectifier_theta,
        imputed_theta_var, rectifier_theta_var, lam))`` when ``return_aux``
        is True. The variances are the diagonals of
        ``inv_hessian @ cov @ inv_hessian`` divided by the respective sample
        size (N for the imputed term, n for the rectifier).
    """
    if lam is None:
        # Tune lam from gradients evaluated at the lam = 1 pilot estimate.
        pilot_estimate = _ols(X_unlabeled, Yhat_unlabeled) + _ols(X, Y - Yhat)
        grads, grads_hat, grads_hat_unlabeled, inv_hessian = _ols_get_stats(
            pilot_estimate,
            X,
            Y,
            Yhat,
            X_unlabeled,
            Yhat_unlabeled,
            use_unlabeled=True,
        )
        lam = _calc_lam_glm(
            grads,
            grads_hat,
            grads_hat_unlabeled,
            inv_hessian,
            clip=True,
        )

    # With lam == 0 the unlabeled data do not enter the estimate, so the
    # Hessian is formed from labeled data only.
    use_unlabeled = lam != 0.0

    imputed_theta = _ols(X_unlabeled, lam * Yhat_unlabeled)
    rectifier_theta = _ols(X, Y - lam * Yhat)
    estimate = imputed_theta + rectifier_theta

    if not return_aux:
        return estimate

    n = len(X)
    N = len(X_unlabeled)
    d = X.shape[1]

    grads, grads_hat, grads_hat_unlabeled, inv_hessian = _ols_get_stats(
        estimate,
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        use_unlabeled=use_unlabeled,
    )

    # np.cov returns a 0-d array when there is a single feature; reshape so
    # the matrix products below also work for d == 1.
    imputed_cov = np.cov(lam * grads_hat_unlabeled.T).reshape(d, d)
    imputed_theta_var = np.diag(inv_hessian @ imputed_cov @ inv_hessian) / N

    rectifier_cov = np.cov(grads.T - lam * grads_hat.T).reshape(d, d)
    rectifier_theta_var = np.diag(inv_hessian @ rectifier_cov @ inv_hessian) / n

    return estimate, (
        imputed_theta,
        rectifier_theta,
        imputed_theta_var,
        rectifier_theta_var,
        lam,
    )


def _ols_get_stats(
    pointest: np.ndarray,
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    use_unlabeled: bool = True,
):
    """Per-sample OLS gradients and the inverse Hessian at ``pointest``.

    Each gradient row is ``x_i * (x_i @ pointest - y_i)``, the gradient of
    ``0.5 * (x_i @ pointest - y_i) ** 2``.

    Args:
        pointest: Coefficient vector at which gradients are evaluated.
        X: Labeled features.
        Y: Labels for ``X``.
        Yhat: Predictions for ``X``.
        X_unlabeled: Unlabeled features.
        Yhat_unlabeled: Predictions for ``X_unlabeled``.
        use_unlabeled: If True, the Hessian averages ``X.T @ X`` and
            ``X_unlabeled.T @ X_unlabeled`` over ``n + N`` rows; otherwise it
            is ``X.T @ X / n``.

    Returns:
        ``(grads, grads_hat, grads_hat_unlabeled, inv_hessian)``: gradients
        with true labels, gradients with predictions on labeled data,
        gradients with predictions on unlabeled data, and the inverse Hessian.
    """

    def _per_sample_grad(features, response):
        residual = features @ pointest - response
        return features * residual[:, None]

    n_labeled = len(X)
    gram = X.T @ X

    if not use_unlabeled:
        hessian = gram / n_labeled
    else:
        n_total = n_labeled + len(X_unlabeled)
        gram_unlabeled = X_unlabeled.T @ X_unlabeled
        hessian = gram / n_total + gram_unlabeled / n_total

    return (
        _per_sample_grad(X, Y),
        _per_sample_grad(X, Yhat),
        _per_sample_grad(X_unlabeled, Yhat_unlabeled),
        np.linalg.inv(hessian),
    )
