import functools as ft
from typing import Optional, Tuple, Union

import numpy as np
import scipy.optimize as optimize
import scipy.stats as stats

from .models import BayesianGaussianModel, GaussianGaussianModel, HorseshoeGaussianModel


def fabzCI(
    model: BayesianGaussianModel,
    y: Union[float, np.ndarray],
    alpha: float = 0.05,
    coord: Optional[int] = None,
    xtol: float = np.finfo(float).tiny,
    rtol: float = 8.88179e-16,
    maxiter: int = 10000,
) -> Tuple[float, float]:
    """Confidence interval for theta from ``y ~ N(theta, s2)`` with a theta-dependent tail split.

    The total tail probability ``alpha`` is split between the two tails by a
    weight ``w(theta)``. The lower endpoint solves
    ``theta = y + s * Phi^{-1}(alpha * (1 - w(theta)))`` and the upper endpoint
    solves ``theta = y + s * Phi^{-1}(1 - alpha * w(theta))``.
    ``w`` comes from ``optimal_w``, except for ``GaussianGaussianModel``, where
    the equivalent closed-form fixed-point equations are used directly.

    Args:
        model: Supplies ``s2`` (scalar or per-coordinate array). For a
            ``GaussianGaussianModel`` it also supplies ``t2``.
        y: Observation(s). Broadcast against ``model.s2`` in the array case.
        alpha: Total tail probability of the interval.
        coord: If given, compute the interval for ``model.s2[coord]``.
            ``y`` is indexed the same way when it is array-valued.
        xtol, rtol, maxiter: Tolerances and iteration cap forwarded to the
            scipy root finders (here and in ``optimal_w``).

    Returns:
        ``(lower, upper)`` when the selected ``s2`` is a scalar. When it is an
        array, returns ``[lowers, uppers]``: a list of two tuples, one entry
        per coordinate.
    """
    s2 = model.s2 if coord is None else model.s2[coord]

    if isinstance(s2, np.ndarray) and s2.ndim > 0:
        # Vector case: solve each coordinate separately, then transpose the
        # per-coordinate (lower, upper) pairs into [lowers, uppers].
        y = np.broadcast_to(y, s2.shape)
        return list(
            zip(
                *[
                    fabzCI(
                        model,
                        y,
                        alpha=alpha,
                        coord=i,
                        xtol=xtol,
                        rtol=rtol,
                        maxiter=maxiter,
                    )
                    for i in range(len(s2))
                ]
            )
        )

    # A scalar observation is shared by every coordinate; only index into
    # array-valued observations (indexing a scalar would raise).
    if coord is not None and np.ndim(y) > 0:
        y = y[coord]

    s = np.sqrt(s2)

    # The root of `root(theta, upper)` gives the lower/upper bound of the interval.
    if isinstance(model, GaussianGaussianModel):
        if coord is None:
            t2 = model.t2
        else:
            t2 = np.broadcast_to(model.t2, model.s2.shape)[coord]

        # Closed form for the Gaussian case; the formula has no prior-mean
        # term, so presumably the prior is N(0, t2) — verify in models.
        def root(theta, upper=False):
            if upper:
                z = np.clip(1 - alpha + stats.norm.cdf((y - theta) / s), None, 1)
            else:
                z = np.clip(alpha - stats.norm.cdf((theta - y) / s), 0, None)

            q = stats.norm.ppf(z) if z <= 0.5 else -stats.norm.isf(z)
            return (y + s * q) / (1 + 2 * s2 / t2) - theta

    else:

        def root(theta, upper=False):
            w = optimal_w(
                model, theta, alpha=alpha, s2=s2, xtol=xtol, rtol=rtol, maxiter=maxiter
            )
            if upper:
                z = np.clip(1 - alpha * w, None, 1)
            else:
                z = np.clip(alpha * (1 - w), 0, None)

            q = stats.norm.ppf(z) if z <= 0.5 else -stats.norm.isf(z)
            return y + s * q - theta

    def root_scalar(f, bracket):
        return optimize.root_scalar(
            f, bracket=bracket, xtol=xtol, rtol=rtol, maxiter=maxiter
        )

    def widen(f, x, step, direction):
        """Walk ``x`` in ``direction`` (+1/-1) while ``f(x)`` has that sign.

        ``f`` is decreasing. With direction -1 the walk stops once
        ``f(x) >= 0``; with direction +1 it stops once ``f(x) <= 0``. The step
        doubles every iteration, so the walk still moves when the initial step
        is below the floating-point resolution of ``x``. A constant step would
        loop forever in that case.
        """
        while direction * f(x) > 0:
            x += direction * step
            step *= 2
        return x

    # thetaU >= y + s * z_{1 - alpha} (since w <= 1), and ubroot is decreasing.
    # So start there and:
    #   - nudge a down (as little as possible) until ubroot(a) >= 0,
    #     which only guards against rounding at the start point;
    #   - push b up until ubroot(b) <= 0.
    ubroot = ft.partial(root, upper=True)
    start = y + s * stats.norm.ppf(1 - alpha)
    a = widen(ubroot, start, s * 1e-12, -1)
    b = widen(ubroot, start, s, +1)
    thetaU = root_scalar(ubroot, [a, b]).root

    # thetaL <= y + s * z_{alpha} (since w >= 0), and lbroot is decreasing.
    # So start there and:
    #   - push a down until lbroot(a) >= 0;
    #   - nudge b up (as little as possible) until lbroot(b) <= 0,
    #     which only guards against rounding at the start point.
    lbroot = root
    start = y + s * stats.norm.ppf(alpha)
    a = widen(lbroot, start, s, -1)
    b = widen(lbroot, start, s * 1e-12, +1)
    thetaL = root_scalar(lbroot, [a, b]).root

    return thetaL, thetaU


@ft.partial(np.vectorize, excluded=(0,))
# Bounded: theta takes continuous values produced by root finding, so an
# unbounded cache would grow with every evaluation and keep every model
# alive for the lifetime of the process.
@ft.lru_cache(maxsize=2**16)
def optimal_w(
    model: BayesianGaussianModel,
    theta: float,
    *,
    alpha: float = 0.05,
    s2: Optional[float] = None,
    xtol: float = np.finfo(float).tiny,
    rtol: float = 8.88179e-16,
    maxiter: int = 10000,
    return_aux: bool = False,
) -> Union[float, Tuple[float, optimize.RootResults]]:
    """Solve for the tail-split weight ``w`` in ``[0, 1]`` at ``theta``.

    For a ``GaussianGaussianModel``, ``w`` solves
    ``Phi^{-1}(alpha * w) - Phi^{-1}(alpha * (1 - w)) = 2 * sqrt(s2) * theta / t2``.
    For any other model, ``w`` solves ``model.root(w, theta, alpha, s2=s2) == 0``.
    In both cases the equation is solved with Brent's method on ``[0, 1]``.

    Positive ``theta`` is reduced to the non-positive side through
    ``w(theta) = 1 - w(-theta)``.

    The function is wrapped in ``np.vectorize``, which vectorizes every
    argument except ``model``, and is memoized with ``lru_cache``. The model
    and all arguments therefore need to be hashable.

    Args:
        model: The Bayesian model.
        theta: Point at which to evaluate ``w``.
        alpha: Total tail probability.
        s2: Sampling variance. Defaults to ``model.s2``. If that default is a
            non-scalar ndarray, ``w`` is computed per element.
        xtol, rtol, maxiter: Forwarded to ``scipy.optimize.root_scalar``.
        return_aux: Also return the ``RootResults`` of the solve. For
            ``theta > 0`` these describe the solve at ``-theta``.

    Returns:
        ``w``, or ``(w, RootResults)`` if ``return_aux`` is true. Array-valued
        ``model.s2`` yields an array of ``w`` (and a list of results).
    """
    if theta > 0.0:
        # Symmetry about zero: solve at -theta and reflect the weight.
        res = optimal_w(
            model,
            -theta,
            alpha=alpha,
            s2=s2,
            xtol=xtol,
            rtol=rtol,
            maxiter=maxiter,
            return_aux=return_aux,
        )
        if return_aux:
            return 1 - res[0], res[1]
        return 1 - res

    s2 = model.s2 if s2 is None else s2

    # Per-coordinate variances: solve each coordinate separately.
    if isinstance(s2, np.ndarray) and s2.ndim > 0:
        res = [
            optimal_w(
                model,
                theta,
                alpha=alpha,
                s2=_s2,
                xtol=xtol,
                rtol=rtol,
                maxiter=maxiter,
                return_aux=return_aux,
            )
            for _s2 in s2
        ]
        if return_aux:
            ws, auxs = zip(*res)
            return np.array(ws), list(auxs)
        return np.array(res)

    # Horseshoe: evaluate at unit scale, i.e. w(theta; s2) = w(theta/sqrt(s2); 1).
    # NOTE(review): this relies on the horseshoe model's w being scale-equivariant
    # in s — confirm against the model definition.
    if isinstance(model, HorseshoeGaussianModel) and s2 != 1.0:
        return optimal_w(
            model,
            theta / np.sqrt(s2),
            alpha=alpha,
            s2=1.0,
            xtol=xtol,
            rtol=rtol,
            maxiter=maxiter,
            return_aux=return_aux,
        )

    if isinstance(model, GaussianGaussianModel):
        t2 = model.t2
        # No prior-mean term appears, so presumably the prior is N(0, t2).
        z = 2 * np.sqrt(s2) * theta / t2

        def root(w):
            return stats.norm.ppf(alpha * w) - stats.norm.ppf(alpha * (1 - w)) - z

    else:

        def root(w):
            return model.root(w, theta, alpha, s2=s2)

    w = optimize.root_scalar(
        root,
        bracket=(0.0, 1.0),
        method="brentq",
        xtol=xtol,
        rtol=rtol,
        maxiter=maxiter,
    )

    if return_aux:
        return w.root, w
    return w.root
