from supplementary.util import covarage_set, ppipp_interval_mean
import numpy as np
from supplementary.fab_ppi.ppi.quantile import ppi_fab_quantile_ci

BIG_M = 1e99


def conformal_interval_quantile(
    *, psi, scores_cal, err, Yhat_test, thetas, ci_constructor, alpha=0.05, M=1
):
    """Confidence set over ``thetas`` built from split-conformal prediction intervals.

    Each test prediction is widened into the interval
    ``[Yhat_test - threshold, Yhat_test + threshold]``. ``threshold`` is the
    split-conformal quantile of the calibration scores at level ``1 - err``.
    ``psi`` is evaluated at both interval endpoints, and the pointwise max / min
    over the two endpoints are fed to ``ci_constructor``. The resulting bounds
    are then padded by ``M * err``.

    Args:
        psi: Callable ``psi(y, thetas)`` evaluated with numpy broadcasting. It is
            called here with an array of shape ``(n_test, 2, 1)`` and ``thetas``.
            Taking max/min over the two endpoints equals the sup/inf over the
            whole interval only if psi is monotone in y (assumption, e.g. a
            quantile indicator; TODO confirm for every psi used).
        scores_cal: Conformity scores on the calibration set.
        err: Target miscoverage of the conformal intervals, in [0, 1].
        Yhat_test: Predictions on the unlabeled test points (1-D, presumably).
        thetas: Grid of candidate parameter values.
        ci_constructor: Callable ``(data, alpha=..., M=...)`` returning a
            3-tuple whose first element is used as the lower bound and whose
            third element is used as the upper bound.
        alpha: Level passed to ``ci_constructor``.
        M: Passed to ``ci_constructor`` and used to pad the bounds by ``M * err``.

    Returns:
        Whatever ``covarage_set(thetas, lower_bound=..., upper_bound=...)``
        returns for the computed bounds.
    """
    # Split-conformal threshold: the ceil((n + 1) * (1 - err))-th smallest of the
    # n calibration scores augmented with "+infinity" (BIG_M). "inverted_cdf"
    # returns exactly this order statistic. The default linear interpolation can
    # fall between two order statistics, which under-covers, or can blend a real
    # score with the BIG_M sentinel.
    threshold = np.quantile(
        np.append(scores_cal, BIG_M), 1 - err, method="inverted_cdf"
    )
    C_inf_data = Yhat_test - threshold
    C_sup_data = Yhat_test + threshold
    # Shape (n_test, 2): column 0 holds the lower endpoint, column 1 the upper.
    C_data = np.stack([C_inf_data, C_sup_data], axis=1)
    # Broadcast against thetas. The endpoint axis (axis=1) is then reduced.
    psi_C = psi(C_data[:, :, None], thetas)
    sup_psi = np.max(psi_C, axis=1)
    inf_psi = np.min(psi_C, axis=1)

    # Upper bound from the sup-envelope, lower bound from the inf-envelope.
    _, _, upper_sup = ci_constructor(sup_psi, alpha=alpha, M=M)
    lower_inf, _, _ = ci_constructor(inf_psi, alpha=alpha, M=M)
    # Pad by M * err to cover test points whose label falls outside their
    # conformal interval (probability at most err).
    lower_bound = lower_inf - M * err
    upper_bound = upper_sup + M * err

    return covarage_set(thetas, lower_bound=lower_bound, upper_bound=upper_bound)


def power_interval_quantile(
    *, psi, Y_cal, Yhat_cal, Yhat_test, thetas, ci_constructor, alpha=0.05, M=1
):
    """Prediction-powered confidence set over ``thetas``.

    The estimate is split into two parts:

    * a rectifier, ``psi(Y_cal) - psi(Yhat_cal)``, estimated on the labeled
      calibration data at level ``alpha / 2``;
    * the prediction term, ``psi(Yhat_test)``, estimated on the unlabeled
      test predictions at the remaining ``alpha - alpha / 2``.

    The lower bounds of the two confidence intervals are summed, and so are
    the upper bounds.

    Args:
        psi: Callable ``psi(y, thetas)`` evaluated with broadcasting on
            ``y[:, None]``.
        Y_cal: Labeled outcomes on the calibration set.
        Yhat_cal: Predictions on the calibration set.
        Yhat_test: Predictions on the unlabeled test set.
        thetas: Grid of candidate parameter values.
        ci_constructor: Callable ``(data, alpha=..., M=...)`` returning a
            3-tuple; its first and third elements are used as the bounds.
        alpha: Total error budget, split evenly between the two intervals.
        M: Passed through to ``ci_constructor``.

    Returns:
        Whatever ``covarage_set(thetas, lower_bound=..., upper_bound=...)``
        returns for the summed bounds.
    """
    half_alpha = alpha / 2

    # psi is evaluated on the test predictions first, then on the calibration
    # labels and predictions, matching the original call order.
    psi_pred_test = psi(Yhat_test[:, None], thetas)
    rectifier = psi(Y_cal[:, None], thetas) - psi(Yhat_cal[:, None], thetas)

    rect_lo, _, rect_hi = ci_constructor(rectifier, alpha=half_alpha, M=M)
    pred_lo, _, pred_hi = ci_constructor(
        psi_pred_test, alpha=alpha - half_alpha, M=M
    )

    return covarage_set(
        thetas,
        lower_bound=rect_lo + pred_lo,
        upper_bound=rect_hi + pred_hi,
    )


def ppipp_interval_quantile(
    *,
    psi,
    Y_cal,
    Yhat_cal,
    Yhat_test,
    thetas,
    ci_constructor,
    alpha=0.05,
    M=1,
    split: bool,
):
    """PPI++ confidence set over ``thetas``, computed one theta at a time.

    ``psi`` is evaluated once on the labels and predictions for the whole
    grid. Column ``j`` of each result is then treated as a mean-estimation
    problem and passed to ``ppipp_interval_mean``. The first and last elements
    it returns become the lower and upper bounds for ``thetas[j]``.

    Args:
        psi: Callable ``psi(y, thetas)`` evaluated with broadcasting on
            ``y[:, None]``. Column ``j`` of its output is assumed to correspond
            to ``thetas[j]``.
        Y_cal: Labeled outcomes on the calibration set.
        Yhat_cal: Predictions on the calibration set.
        Yhat_test: Predictions on the unlabeled test set.
        thetas: Grid of candidate parameter values.
        ci_constructor: Forwarded to ``ppipp_interval_mean``.
        alpha: Forwarded to ``ppipp_interval_mean``.
        M: Forwarded to ``ppipp_interval_mean``.
        split: Forwarded to ``ppipp_interval_mean``.

    Returns:
        Whatever ``covarage_set(thetas, lower_bound=..., upper_bound=...)``
        returns for the per-theta bounds.
    """
    labeled_psi = psi(Y_cal[:, None], thetas)
    pred_cal_psi = psi(Yhat_cal[:, None], thetas)
    pred_test_psi = psi(Yhat_test[:, None], thetas)

    n_thetas = len(thetas)
    # float64 buffers, filled one entry per theta.
    lows = np.empty(n_thetas)
    highs = np.empty(n_thetas)
    for j in range(n_thetas):
        low, *_, high = ppipp_interval_mean(
            Y_cal=labeled_psi[:, j],
            Yhat_cal=pred_cal_psi[:, j],
            Yhat_test=pred_test_psi[:, j],
            ci_constructor=ci_constructor,
            alpha=alpha,
            M=M,
            split=split,
        )
        lows[j] = low
        highs[j] = high

    return covarage_set(thetas, lower_bound=lows, upper_bound=highs)


def fab_interval_quantile(
    *,
    q,
    Y_cal,
    Yhat_cal,
    Yhat_test,
    thetas,
    ci_constructor,
    alpha=0.05,
    M=1,
):
    """FAB-PPI confidence interval for the ``q``-quantile, using a horseshoe prior.

    This delegates to ``ppi_fab_quantile_ci`` and squeezes the two bounds it
    returns. ``thetas``, ``ci_constructor`` and ``M`` are accepted only so the
    signature matches the sibling ``*_interval_quantile`` functions; they are
    not used here.

    NOTE(review): unlike its siblings, this returns ``(lower, upper)`` directly
    rather than ``covarage_set(...)``. Confirm that callers expect this.

    Args:
        q: Quantile level to estimate.
        Y_cal: Labeled outcomes on the calibration set.
        Yhat_cal: Predictions on the calibration set.
        Yhat_test: Predictions on the unlabeled test set.
        thetas: Unused.
        ci_constructor: Unused.
        alpha: Level passed to ``ppi_fab_quantile_ci``.
        M: Unused.

    Returns:
        Tuple ``(lower, upper)``, each passed through ``np.squeeze``.
    """
    bounds = ppi_fab_quantile_ci(
        Y_cal,
        Yhat_cal,
        Yhat_test,
        q=q,
        prior="horseshoe",
        alpha=alpha,
        delta=None,
        point_estimate=True,
        exact_grid=False,
        return_aux=False,
    )
    lower, upper = bounds
    return tuple(np.squeeze(edge) for edge in (lower, upper))


def classical_interval_quantile(*, psi, Y_cal, thetas, ci_constructor, alpha=0.05, M=1):
    """Classical confidence set over ``thetas``, using only the labeled data.

    Args:
        psi: Callable ``psi(y, thetas)`` evaluated with broadcasting on
            ``y[:, None]``.
        Y_cal: Labeled outcomes.
        thetas: Grid of candidate parameter values.
        ci_constructor: Callable ``(data, alpha=..., M=...)`` returning a
            3-tuple; its first and third elements are used as the bounds.
        alpha: Level passed to ``ci_constructor``.
        M: Passed through to ``ci_constructor``.

    Returns:
        Whatever ``covarage_set(thetas, lower_bound=..., upper_bound=...)``
        returns for the computed bounds.
    """
    labeled_psi = psi(Y_cal[:, None], thetas)
    lo, _mid, hi = ci_constructor(labeled_psi, alpha=alpha, M=M)
    return covarage_set(thetas, lower_bound=lo, upper_bound=hi)
