from typing import Any, Optional, Tuple, Union

import numpy as np
import scipy.optimize as optimize
import scipy.special as sc
import scipy.stats as stats
from sklearn.linear_model import LogisticRegression
from statsmodels.stats.weightstats import _zconfint_generic

from ..fab import fabzCI
from ..models import (
    BayesianGaussianModel,
    GaussianGaussianModel,
    HorseshoeGaussianModel,
)


def ppi_fab_logistic_ci(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    prior: str = "horseshoe",
    alpha: float = 0.1,
    delta: Optional[float] = None,
    point_estimate: bool = False,
    grid_size: int = 200,
    grid_limit: int = 800,
    grid_radius: Optional[float] = None,
    grid_relative: bool = False,
    max_refinements: int = 10,
    return_aux: bool = False,
    init_grid: Optional[np.ndarray] = None,  # [d, 2]
    **kwargs: Any,
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float]]]:
    """Grid-search confidence region for logistic regression with a FAB-PPI rectifier.

    Candidate coefficient vectors ``theta`` are laid out on a regular grid around
    the FAB point estimate (or around ``init_grid``). A candidate is accepted when
    ``_isin_minkowski_sum_fab`` reports that ``-imputed_mean(theta)`` lies inside
    the FAB interval for the rectifier mean widened by a normal interval for the
    imputed term. The grid is widened by powers of two (and densified, capped at
    ``grid_limit`` points) until some candidate is accepted and neither end of the
    flattened grid is accepted. The returned bounds are the coordinate-wise
    min/max of the accepted candidates.

    Args:
        X: Labeled covariates, shape [n, d].
        Y: Labels, shape [n].
        Yhat: Predictions on the labeled data, shape [n].
        X_unlabeled: Unlabeled covariates, shape [N, d].
        Yhat_unlabeled: Predictions on the unlabeled data, shape [N].
        prior: Prior name forwarded to ``ppi_fab_logistic_pointestimate``.
        alpha: Total error level.
        delta: Part of ``alpha`` spent on the rectifier interval. If ``None`` it is
            chosen in ``(0, alpha)`` by minimizing the mean interval width.
        point_estimate: If True, ``delta`` is set to ``alpha`` and only the
            rectifier interval is used.
        grid_size: Initial (approximate) number of grid points.
        grid_limit: Cap on the number of grid points while refining.
        grid_radius: Initial per-coordinate half-width of the grid. Defaults to a
            classical normal-approximation half-width from the labeled data.
            Ignored when ``init_grid`` is given (its widths are used instead).
        grid_relative: If True (and ``grid_radius`` is given), multiply
            ``grid_radius`` by ``|estimate|``.
        max_refinements: Number of grid widenings tried before giving up and
            returning infinite bounds.
        return_aux: If True, also return ``(estimate,)``.
        init_grid: Optional [d, 2] array of initial lower/upper bounds.
        **kwargs: Forwarded to ``ppi_fab_logistic_pointestimate`` (e.g. ``t2``).

    Returns:
        ``(lower, upper)`` arrays of shape [d], or ``((lower, upper), (estimate,))``
        when ``return_aux`` is True.
    """
    N, d = X_unlabeled.shape

    estimate, (_, rectifier_mean, _, model) = ppi_fab_logistic_pointestimate(
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        prior,
        return_aux=True,
        **kwargs,
    )

    if init_grid is not None:
        grid_radius = np.diff(init_grid, axis=-1).squeeze(-1)
    elif grid_radius is None:
        n = X.shape[0]
        grid_radius = -stats.norm.ppf(alpha / 2) * np.sqrt(
            np.diag(logistic_covariance_matrix(X, Y, estimate)) / n
        )
    elif grid_relative:
        # Rebind instead of ``*=`` so a caller-supplied array is never mutated.
        grid_radius = grid_radius * np.abs(estimate)

    grid_radius = np.broadcast_to(grid_radius, d)

    def _build_grid(refinements, size):
        """Return the [num_points, d] grid of candidate thetas for one pass."""
        points_per_dim = int(np.power(size, 1 / d)) + 1
        if init_grid is None:
            scale = 2**refinements
            lower = estimate - grid_radius * scale
            upper = estimate + grid_radius * scale
        else:
            # The first pass uses ``init_grid`` as given; later passes widen it.
            scale = 0 if refinements == 0 else 2**refinements
            lower = init_grid[:, 0] - grid_radius * scale
            upper = init_grid[:, 1] + grid_radius * scale
        axes = [np.linspace(lo, hi, points_per_dim) for lo, hi in zip(lower, upper)]
        return np.array(np.meshgrid(*axes)).T.reshape(-1, d)

    def _accepted(grid, delta):
        """Boolean mask [num_points] of grid points inside the FAB region."""
        mu_theta = sc.expit(X_unlabeled @ grid.T)  # [N, num_points]
        imputed = (
            X_unlabeled[:, None, :]
            * (mu_theta - Yhat_unlabeled[:, None])[..., None]
        )  # [N, num_points, d]
        imputed_mean = np.mean(imputed, axis=0)
        imputed_var = np.var(imputed, axis=0) / N
        return _isin_minkowski_sum_fab(
            model,
            imputed_mean,
            imputed_var,
            rectifier_mean,
            alpha,
            delta,
            point_estimate=point_estimate,
        )

    def _ci(delta):
        # Local copy so every call (e.g. during the delta search) starts at the
        # same resolution instead of inheriting the previous call's refinement.
        size = grid_size
        for refinements in range(max_refinements):
            grid = _build_grid(refinements, size)
            in_region = _accepted(grid, delta)
            size = min(size * 2, grid_limit)
            # NOTE(review): for d > 1 only the first and last flattened points
            # (two grid corners) are checked as "edges".
            if np.any(in_region) and not (in_region[0] or in_region[-1]):
                break
        else:
            unbounded = np.full(d, np.inf)
            return -unbounded, unbounded

        accepted_idx = in_region.nonzero()[0]
        first, last = accepted_idx[0], accepted_idx[-1]
        # Pad the accepted set by one index on each side of the flattened grid
        # (presumably to keep the bounds conservative w.r.t. grid spacing).
        if first > 0:
            in_region[first - 1] = True
        if last < len(in_region) - 1:
            in_region[last + 1] = True
        interval = grid[in_region]
        return interval.min(axis=0), interval.max(axis=0)

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(candidate):
            lower, upper = _ci(candidate)
            return np.mean(upper - lower)

        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    ci = _ci(delta)
    if return_aux:
        return ci, (estimate,)
    return ci


def ppi_split_logistic_ci(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    alpha: float = 0.1,
    delta: Optional[float] = None,
    point_estimate: bool = False,
    grid_size: int = 200,
    grid_limit: int = 800,
    grid_radius: Optional[float] = None,
    grid_relative: bool = False,
    max_refinements: int = 10,
    return_aux: bool = False,
    init_grid: Optional[np.ndarray] = None,  # [d, 2]
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float]]]:
    """Grid-search PPI confidence region for logistic regression with a split alpha.

    The error budget is split into ``delta`` for the rectifier and
    ``alpha - delta`` for the imputed term (see ``_isin_minkowski_sum_split``).
    Candidate coefficient vectors are laid out on a regular grid around the PPI
    point estimate (or around ``init_grid``). The grid is widened by powers of two
    (and densified, capped at ``grid_limit`` points) until some candidate is
    accepted and neither end of the flattened grid is accepted. The returned
    bounds are the coordinate-wise min/max of the accepted candidates.

    Args:
        X: Labeled covariates, shape [n, d].
        Y: Labels, shape [n].
        Yhat: Predictions on the labeled data, shape [n].
        X_unlabeled: Unlabeled covariates, shape [N, d].
        Yhat_unlabeled: Predictions on the unlabeled data, shape [N].
        alpha: Total error level.
        delta: Part of ``alpha`` spent on the rectifier. If ``None`` it is chosen
            in ``(0, alpha)`` by minimizing the mean interval width.
        point_estimate: If True, ``delta`` is set to ``alpha`` and only the
            rectifier interval is used.
        grid_size: Initial (approximate) number of grid points.
        grid_limit: Cap on the number of grid points while refining.
        grid_radius: Initial per-coordinate half-width of the grid. Defaults to a
            classical normal-approximation half-width from the labeled data.
            Ignored when ``init_grid`` is given (its widths are used instead).
        grid_relative: If True (and ``grid_radius`` is given), multiply
            ``grid_radius`` by ``|estimate|``.
        max_refinements: Number of grid widenings tried before giving up and
            returning infinite bounds.
        return_aux: If True, also return ``(estimate,)``.
        init_grid: Optional [d, 2] array of initial lower/upper bounds.

    Returns:
        ``(lower, upper)`` arrays of shape [d], or ``((lower, upper), (estimate,))``
        when ``return_aux`` is True.
    """
    N, d = X_unlabeled.shape

    estimate, (_, rectifier_mean, rectifier_var) = ppi_logistic_pointestimate(
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        return_aux=True,
    )

    if init_grid is not None:
        grid_radius = np.diff(init_grid, axis=-1).squeeze(-1)
    elif grid_radius is None:
        n = X.shape[0]
        grid_radius = -stats.norm.ppf(alpha / 2) * np.sqrt(
            np.diag(logistic_covariance_matrix(X, Y, estimate)) / n
        )
    elif grid_relative:
        # Rebind instead of ``*=`` so a caller-supplied array is never mutated.
        grid_radius = grid_radius * np.abs(estimate)

    grid_radius = np.broadcast_to(grid_radius, d)

    def _build_grid(refinements, size):
        """Return the [num_points, d] grid of candidate thetas for one pass."""
        points_per_dim = int(np.power(size, 1 / d)) + 1
        if init_grid is None:
            scale = 2**refinements
            lower = estimate - grid_radius * scale
            upper = estimate + grid_radius * scale
        else:
            # The first pass uses ``init_grid`` as given; later passes widen it.
            scale = 0 if refinements == 0 else 2**refinements
            lower = init_grid[:, 0] - grid_radius * scale
            upper = init_grid[:, 1] + grid_radius * scale
        axes = [np.linspace(lo, hi, points_per_dim) for lo, hi in zip(lower, upper)]
        return np.array(np.meshgrid(*axes)).T.reshape(-1, d)

    def _accepted(grid, delta):
        """Boolean mask [num_points] of grid points inside the split region."""
        mu_theta = sc.expit(X_unlabeled @ grid.T)  # [N, num_points]
        imputed = (
            X_unlabeled[:, None, :]
            * (mu_theta - Yhat_unlabeled[:, None])[..., None]
        )  # [N, num_points, d]
        imputed_mean = np.mean(imputed, axis=0)
        imputed_var = np.var(imputed, axis=0) / N
        return _isin_minkowski_sum_split(
            imputed_mean,
            imputed_var,
            rectifier_mean,
            rectifier_var,
            alpha,
            delta,
            point_estimate=point_estimate,
        )

    def _ci(delta):
        # Local copy so every call (e.g. during the delta search) starts at the
        # same resolution instead of inheriting the previous call's refinement.
        size = grid_size
        for refinements in range(max_refinements):
            grid = _build_grid(refinements, size)
            in_region = _accepted(grid, delta)
            size = min(size * 2, grid_limit)
            # NOTE(review): for d > 1 only the first and last flattened points
            # (two grid corners) are checked as "edges".
            if np.any(in_region) and not (in_region[0] or in_region[-1]):
                break
        else:
            unbounded = np.full(d, np.inf)
            return -unbounded, unbounded

        accepted_idx = in_region.nonzero()[0]
        first, last = accepted_idx[0], accepted_idx[-1]
        # Pad the accepted set by one index on each side of the flattened grid
        # (presumably to keep the bounds conservative w.r.t. grid spacing).
        if first > 0:
            in_region[first - 1] = True
        if last < len(in_region) - 1:
            in_region[last + 1] = True
        interval = grid[in_region]
        return interval.min(axis=0), interval.max(axis=0)

    if point_estimate:
        delta = alpha

    if delta is None:

        def _ci_length(candidate):
            lower, upper = _ci(candidate)
            return np.mean(upper - lower)

        delta = optimize.minimize_scalar(
            _ci_length, bounds=(0, alpha), method="bounded"
        ).x

    ci = _ci(delta)
    if return_aux:
        return ci, (estimate,)
    return ci


def ppi_logistic_ci(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    alpha: float = 0.1,
    grid_size: int = 200,
    grid_limit: int = 800,
    grid_radius: Optional[float] = None,
    grid_relative: bool = False,
    max_refinements: int = 10,
    return_aux: bool = False,
    init_grid: Optional[np.ndarray] = None,  # [d, 2]
) -> Union[Tuple[float, float], Tuple[Tuple[float, float], Tuple[float]]]:
    """Grid-search PPI confidence region for logistic-regression coefficients.

    Candidate coefficient vectors are laid out on a regular grid around the PPI
    point estimate (or around ``init_grid``). A candidate is accepted when
    ``_isin_minkowski_sum`` reports that ``imputed_mean(theta) + rectifier_mean``
    is within a Bonferroni-corrected normal half-width of zero in every
    coordinate. The grid is widened by powers of two (and densified, capped at
    ``grid_limit`` points) until some candidate is accepted and neither end of the
    flattened grid is accepted.

    Args:
        X: Labeled covariates, shape [n, d].
        Y: Labels, shape [n].
        Yhat: Predictions on the labeled data, shape [n].
        X_unlabeled: Unlabeled covariates, shape [N, d].
        Yhat_unlabeled: Predictions on the unlabeled data, shape [N].
        alpha: Error level.
        grid_size: Initial (approximate) number of grid points.
        grid_limit: Cap on the number of grid points while refining.
        grid_radius: Initial per-coordinate half-width of the grid. Defaults to a
            classical normal-approximation half-width from the labeled data.
            Ignored when ``init_grid`` is given (its widths are used instead).
        grid_relative: If True (and ``grid_radius`` is given), multiply
            ``grid_radius`` by ``|estimate|``.
        max_refinements: Number of grid widenings tried before giving up and
            returning infinite bounds.
        return_aux: If True, also return ``(estimate,)``.
        init_grid: Optional [d, 2] array of initial lower/upper bounds.

    Returns:
        ``(lower, upper)`` arrays of shape [d], or ``((lower, upper), (estimate,))``
        when ``return_aux`` is True (also when the search gives up).
    """
    N, d = X_unlabeled.shape

    estimate, (_, rectifier_mean, rectifier_var) = ppi_logistic_pointestimate(
        X,
        Y,
        Yhat,
        X_unlabeled,
        Yhat_unlabeled,
        return_aux=True,
    )

    if init_grid is not None:
        grid_radius = np.diff(init_grid, axis=-1).squeeze(-1)
    elif grid_radius is None:
        n = X.shape[0]
        grid_radius = -stats.norm.ppf(alpha / 2) * np.sqrt(
            np.diag(logistic_covariance_matrix(X, Y, estimate)) / n
        )
    elif grid_relative:
        # Rebind instead of ``*=`` so a caller-supplied array is never mutated.
        grid_radius = grid_radius * np.abs(estimate)

    grid_radius = np.broadcast_to(grid_radius, d)

    def _build_grid(refinements, size):
        """Return the [num_points, d] grid of candidate thetas for one pass."""
        points_per_dim = int(np.power(size, 1 / d)) + 1
        if init_grid is None:
            scale = 2**refinements
            lower = estimate - grid_radius * scale
            upper = estimate + grid_radius * scale
        else:
            # The first pass uses ``init_grid`` as given; later passes widen it.
            scale = 0 if refinements == 0 else 2**refinements
            lower = init_grid[:, 0] - grid_radius * scale
            upper = init_grid[:, 1] + grid_radius * scale
        axes = [np.linspace(lo, hi, points_per_dim) for lo, hi in zip(lower, upper)]
        return np.array(np.meshgrid(*axes)).T.reshape(-1, d)

    size = grid_size
    for refinements in range(max_refinements):
        grid = _build_grid(refinements, size)

        mu_theta = sc.expit(X_unlabeled @ grid.T)  # [N, num_points]
        imputed = (
            X_unlabeled[:, None, :] * (mu_theta - Yhat_unlabeled[:, None])[..., None]
        )  # [N, num_points, d]
        imputed_mean = np.mean(imputed, axis=0)
        imputed_var = np.var(imputed, axis=0) / N

        in_region = _isin_minkowski_sum(
            imputed_mean,
            imputed_var,
            rectifier_mean,
            rectifier_var,
            alpha,
        )  # [num_points]
        size = min(size * 2, grid_limit)

        # NOTE(review): for d > 1 only the first and last flattened points
        # (two grid corners) are checked as "edges".
        if np.any(in_region) and not (in_region[0] or in_region[-1]):
            accepted_idx = in_region.nonzero()[0]
            first, last = accepted_idx[0], accepted_idx[-1]
            # Pad the accepted set by one index on each side of the flattened
            # grid (presumably to keep the bounds conservative w.r.t. spacing).
            if first > 0:
                in_region[first - 1] = True
            if last < len(in_region) - 1:
                in_region[last + 1] = True
            interval = grid[in_region]
            ci = interval.min(axis=0), interval.max(axis=0)
            break
    else:
        # No bounded region found within ``max_refinements`` widenings.
        unbounded = np.full(d, np.inf)
        ci = -unbounded, unbounded

    # Same return structure on both paths, so callers can always unpack it.
    if return_aux:
        return ci, (estimate,)
    return ci


def ppi_fab_logistic_pointestimate(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    prior: str = "horseshoe",
    return_aux: bool = False,
    **kwargs: Any,
) -> Union[float, Tuple[float, tuple]]:
    """FAB-PPI point estimate of logistic-regression coefficients.

    The labeled rectifier ``X * (Yhat - Y)`` is summarized by its per-coordinate
    mean and the variance of that mean. The mean is replaced by
    ``model.posterior_mean(rectifier_mean)`` under the chosen prior, and the
    estimate is the root of
    ``X_unlabeled^T (expit(X_unlabeled @ theta) - Yhat_unlabeled) / N + rectifier_est``,
    started from the classical logistic fit on the labeled data.

    Args:
        X: Labeled covariates, shape [n, d].
        Y: Labels, shape [n].
        Yhat: Predictions on the labeled data, shape [n].
        X_unlabeled: Unlabeled covariates, shape [N, d].
        Yhat_unlabeled: Predictions on the unlabeled data, shape [N].
        prior: ``"horseshoe"`` (``HorseshoeGaussianModel``) or ``"gaussian"``
            (``GaussianGaussianModel(rectifier_var, t2 * rectifier_var)``).
        return_aux: If True, also return
            ``(imputed_mean, rectifier_mean, rectifier_var, model)``.
        **kwargs: ``t2`` (default 1.0) is read for the Gaussian prior.

    Returns:
        The estimate (shape [d]), optionally with the auxiliary tuple.

    Raises:
        ValueError: If ``prior`` is not one of the supported names.
    """
    # Validate up front; otherwise an unknown prior leaves ``model`` unbound.
    if prior not in ("horseshoe", "gaussian"):
        raise ValueError(
            f"Unknown prior {prior!r}; expected 'horseshoe' or 'gaussian'."
        )

    n, _ = X.shape
    N = len(X_unlabeled)

    rectifier = X * (Yhat - Y)[:, None]  # [n, d]
    rectifier_mean = rectifier.mean(axis=0)
    rectifier_var = rectifier.var(axis=0) / n  # variance of the mean

    if prior == "horseshoe":
        model = HorseshoeGaussianModel(rectifier_var)
    else:
        t2 = kwargs.get("t2", 1.0)
        model = GaussianGaussianModel(rectifier_var, t2 * rectifier_var)

    rectifier_est = model.posterior_mean(rectifier_mean)

    init = logistic(X, Y)

    def root(theta):
        mu_theta = sc.expit(X_unlabeled @ theta)
        return X_unlabeled.T @ (mu_theta - Yhat_unlabeled) / N + rectifier_est

    estimate = optimize.root(root, init, tol=1e-15).x

    mu_theta = sc.expit(X_unlabeled @ estimate)
    imputed_mean = X_unlabeled.T @ (mu_theta - Yhat_unlabeled) / N

    if return_aux:
        return estimate, (
            imputed_mean,
            rectifier_mean,
            rectifier_var,
            model,
        )
    return estimate


def ppi_logistic_pointestimate(
    X: np.ndarray,
    Y: np.ndarray,
    Yhat: np.ndarray,
    X_unlabeled: np.ndarray,
    Yhat_unlabeled: np.ndarray,
    return_aux: bool = False,
) -> Union[float, Tuple[float, tuple]]:
    """PPI point estimate of logistic-regression coefficients.

    Solves
    ``X_unlabeled^T (expit(X_unlabeled @ theta) - Yhat_unlabeled) / N + rectifier_mean = 0``
    for ``theta``, where ``rectifier_mean`` is the labeled mean of
    ``X * (Yhat - Y)``. The root search starts from the classical logistic fit on
    the labeled data.

    Args:
        X: Labeled covariates, shape [n, d].
        Y: Labels, shape [n].
        Yhat: Predictions on the labeled data, shape [n].
        X_unlabeled: Unlabeled covariates, shape [N, d].
        Yhat_unlabeled: Predictions on the unlabeled data, shape [N].
        return_aux: If True, also return
            ``(imputed_mean, rectifier_mean, rectifier_var)``.

    Returns:
        The estimate (shape [d]), optionally with the auxiliary tuple.
    """
    num_labeled, _ = X.shape
    num_unlabeled = len(X_unlabeled)

    residual_terms = X * (Yhat - Y)[:, None]  # [n, d]
    rectifier_mean = residual_terms.mean(axis=0)

    def _imputed_gradient(theta):
        """Unlabeled-data mean of X * (expit(X theta) - Yhat)."""
        fitted = sc.expit(X_unlabeled @ theta)
        return X_unlabeled.T @ (fitted - Yhat_unlabeled) / num_unlabeled

    def _estimating_equation(theta):
        return _imputed_gradient(theta) + rectifier_mean

    starting_point = logistic(X, Y)
    solution = optimize.root(_estimating_equation, starting_point, tol=1e-15)
    estimate = solution.x

    if not return_aux:
        return estimate

    imputed_mean = _imputed_gradient(estimate)
    rectifier_var = residual_terms.var(axis=0) / num_labeled  # variance of the mean
    return estimate, (imputed_mean, rectifier_mean, rectifier_var)


def _isin_minkowski_sum(
    imputed_mean: Union[float, np.ndarray],  # [grid_size, d]
    imputed_var: Union[float, np.ndarray],  # [grid_size, d]
    rectifier_mean: Union[float, np.ndarray],  # [d]
    rectifier_var: Union[float, np.ndarray],  # [d]
    alpha: float,
):
    d = len(rectifier_mean)
    rectified_point_estimate = imputed_mean + rectifier_mean  # [grid_size, d]

    l = stats.norm.ppf(alpha / (2 * d)) * np.sqrt(
        rectifier_var + imputed_var
    )  # [grid_size, d]

    return np.all(np.abs(rectified_point_estimate) <= -l, axis=-1)  # [grid_size]


def _isin_minkowski_sum_split(
    imputed_mean: Union[float, np.ndarray],  # [grid_size, d]
    imputed_var: Union[float, np.ndarray],  # [grid_size, d]
    rectifier_mean: Union[float, np.ndarray],  # [d]
    rectifier_var: Union[float, np.ndarray],  # [d]
    alpha: float,
    delta: float,
    point_estimate: bool = False,
):
    d = len(rectifier_mean)
    rectified_point_estimate = imputed_mean + rectifier_mean  # [grid_size, d]

    l_delta = stats.norm.ppf(delta / (2 * d)) * np.sqrt(rectifier_var)  # [d]

    if point_estimate:
        l = l_delta
    else:
        l_f = stats.norm.ppf((alpha - delta) / (2 / d)) * np.sqrt(imputed_var)
        l = l_f + l_delta  # [grid_size, d]

    return np.all(np.abs(rectified_point_estimate) <= -l, axis=-1)  # [grid_size]


def _isin_minkowski_sum_fab(
    model: BayesianGaussianModel,
    imputed_mean: Union[float, np.ndarray],
    imputed_var: Union[float, np.ndarray],
    rectifier_mean: Union[float, np.ndarray],
    alpha: float,
    delta: float,
    point_estimate: bool = False,
):
    """Mask of grid points consistent with the FAB-PPI region.

    ``-imputed_mean`` must fall inside the FAB interval
    ``fabzCI(model, rectifier_mean, delta / d)`` for every coordinate. Unless
    ``point_estimate`` is set, that interval is widened on both sides by a normal
    half-width at level ``(alpha - delta) / (2 * d)`` scaled by
    ``sqrt(imputed_var)``.

    Returns:
        Boolean array reduced over the last axis (shape [grid_size] for
        [grid_size, d] inputs).
    """
    num_coords = len(rectifier_mean)
    target = -imputed_mean

    fab_lower, fab_upper = fabzCI(model, rectifier_mean, delta / num_coords)  # [d]
    lower = np.asarray(fab_lower)
    upper = np.asarray(fab_upper)

    if not point_estimate:
        # Non-negative slack for the Monte-Carlo error of the imputed term.
        slack = -stats.norm.ppf((alpha - delta) / (2 * num_coords)) * np.sqrt(
            imputed_var
        )  # [grid_size, d]
        lower = lower - slack
        upper = upper + slack

    within = (target >= lower) & (target <= upper)
    return within.all(axis=-1)  # [grid_size]


def logistic(X, Y):
    """Fit an unpenalized logistic regression without intercept; return coef [d].

    Uses scikit-learn's ``LogisticRegression`` with the ``lbfgs`` solver and a very
    tight tolerance, so the result is (close to) the plain maximum-likelihood fit.
    """
    solver_options = dict(
        penalty=None,
        solver="lbfgs",
        max_iter=10000,
        tol=1e-15,
        fit_intercept=False,
    )
    fitted = LogisticRegression(**solver_options)
    fitted.fit(X, Y)
    # ``coef_`` has a leading class axis of length 1 for binary problems.
    return fitted.coef_.squeeze(axis=0)


def logistic_covariance_matrix(X, Y, pointest):
    """Sandwich covariance estimate for logistic-regression coefficients.

    Computes ``V^{-1} C V^{-1}``, where
    ``V = (1/n) * sum_i mu_i (1 - mu_i) x_i x_i^T`` with ``mu = expit(X @ pointest)``,
    and ``C`` is the sample covariance (``np.cov``, ``n - 1`` normalization) of the
    per-sample gradients ``(mu_i - y_i) x_i``.

    Args:
        X: Covariates, shape [n, d].
        Y: Labels, shape [n].
        pointest: Coefficient vector, shape [d], at which to evaluate.

    Returns:
        [d, d] covariance matrix (not divided by ``n``; callers divide by ``n``
        for standard errors).
    """
    n, _ = X.shape
    mu = sc.expit(X @ pointest)  # [n]
    # Vectorized form of sum_i w_i x_i x_i^T / n with w_i = mu_i (1 - mu_i).
    V = (X * (mu * (1 - mu))[:, None]).T @ X / n  # [d, d]
    grads = (mu - Y)[:, None] * X  # [n, d]
    V_inv = np.linalg.inv(V)
    # np.cov returns a 0-d array when d == 1; broadcast it back to [d, d].
    return V_inv @ np.broadcast_to(np.cov(grads.T), V.shape) @ V_inv


def classical_logistic_ci(X, Y, alpha=0.1, alternative="two-sided"):
    """Classical (labeled-data-only) normal-approximation CI for logistic coefficients.

    Fits ``logistic(X, Y)``, takes standard errors
    ``sqrt(diag(logistic_covariance_matrix(...)) / n)``, and passes both to
    statsmodels' ``_zconfint_generic`` with the given ``alpha`` and
    ``alternative``; its result is returned unchanged.
    """
    num_samples, _ = X.shape
    coefficients = logistic(X, Y)  # [d]
    sandwich = logistic_covariance_matrix(X, Y, coefficients)  # [d, d]
    standard_errors = np.sqrt(np.diag(sandwich) / num_samples)
    return _zconfint_generic(coefficients, standard_errors, alpha, alternative)
