import numpy as np
from sklearn.model_selection import train_test_split
from supplementary.fab_ppi.ppi.mean import ppi_fab_mean_ci

# Finite stand-in for +infinity. It is appended to the calibration scores so the
# conformal quantile is taken over n + 1 values.
BIG_M = 1e99


def conformal_interval_mean(
    *, scores_cal, err, Yhat_test, ci_constructor, alpha=0.05, M=1
):
    """
    Interval for the mean of a binary label built from conformal prediction sets.

    Each test point gets a conformal set C, a subset of {0, 1}, at miscoverage
    ``err``. The mean of inf(C) is lower-bounded and the mean of sup(C) is
    upper-bounded with ``ci_constructor``, each at level ``alpha``. Both bounds
    are then widened by ``M * err``.

    Args:
        scores_cal: calibration nonconformity scores. The set construction below
            gives label 0 the score ``Yhat`` and label 1 the score ``1 - Yhat``,
            so it presumably assumes the score |y - Yhat|. TODO confirm against
            the caller.
        err: conformal miscoverage level, in [0, 1].
        Yhat_test: array of test-point predictions, presumably P(Y = 1).
        ci_constructor: callable ``(x, alpha=..., M=...)`` returning a 3-tuple
            whose first and last entries are lower/upper bounds for mean(x).
        alpha: miscoverage passed to each ``ci_constructor`` call.
        M: forwarded to ``ci_constructor``; also scales the ``err`` widening.

    Returns:
        ``[lower_bound, upper_bound]``.

    Raises:
        ValueError: if ``err`` is not in [0, 1].
    """
    if not 0 <= err <= 1:
        raise ValueError("err must be in [0, 1]")

    # Split-conformal threshold: the ceil((1 - err) * (n + 1))-th smallest value
    # among the n calibration scores plus +inf (BIG_M). np.quantile's default
    # linear interpolation can fall *below* this order statistic and break the
    # 1 - err coverage guarantee, so index the sorted scores directly.
    augmented = np.sort(np.append(scores_cal, BIG_M))
    rank = int(np.ceil((1 - err) * augmented.size))
    threshold = augmented[min(max(rank, 1), augmented.size) - 1]

    ind0 = Yhat_test <= threshold  # label 0 is in C
    ind1 = (1 - Yhat_test) <= threshold  # label 1 is in C
    empty = ~ind0 & ~ind1
    # An empty set falls back to the thresholded prediction 1[Yhat >= 0.5].
    fallback = Yhat_test >= 0.5
    inf_C = np.where(empty, fallback, ~ind0).astype(int)  # 0 iff 0 is in C
    sup_C = np.where(empty, fallback, ind1).astype(int)  # 1 iff 1 is in C

    _, _, upper_sup = ci_constructor(sup_C, alpha=alpha, M=M)
    lower_inf, _, _ = ci_constructor(inf_C, alpha=alpha, M=M)
    lower_bound = lower_inf - M * err
    upper_bound = upper_sup + M * err
    return [lower_bound, upper_bound]


def power_interval_mean(*, Y_cal, Yhat_cal, Yhat_test, ci_constructor, alpha=0.05, M=1):
    """Bound mean(Y) as mean(Yhat_test) minus the prediction bias mean(Yhat_cal - Y_cal).

    Two ``ci_constructor`` calls are made: one for the bias on the calibration
    set, and one for the mean prediction on the test set. The miscoverage
    budget ``alpha`` is split evenly between them.

    Returns:
        ``[lower, upper]``, where lower = pred_lower - bias_upper and
        upper = pred_upper - bias_lower.
    """
    bias_alpha = alpha / 2
    pred_alpha = alpha - bias_alpha
    bias_lo, _, bias_hi = ci_constructor(Yhat_cal - Y_cal, alpha=bias_alpha, M=M)
    pred_lo, _, pred_hi = ci_constructor(Yhat_test, alpha=pred_alpha, M=M)
    return [pred_lo - bias_hi, pred_hi - bias_lo]


def ppipp_interval_mean(
    *, Y_cal, Yhat_cal, Yhat_test, ci_constructor, alpha=0.05, M=1, split: bool, lam=1.0
):
    """
    PPI++ interval for the mean of Y.

    The estimator is ``mean(Y_cal - lam * Yhat_cal) + lam * mean(Yhat_test)``.

    split=False corresponds to asymptotic PPI++. It returns a normal-approximation
    interval with critical value z_{1 - alpha/2}.

    split=True corresponds to nonasymptotic PPI++. The labeled pairs and the
    unlabeled predictions are each split 15% / 85% (``random_state=0``). The 15%
    part is used only to tune ``lam`` (when ``lam is None``); otherwise it is
    discarded. The 85% part feeds two ``ci_constructor`` calls at alpha/2 each.

    Args:
        Y_cal: labels of the calibration set (array).
        Yhat_cal: predictions on the calibration set, aligned with ``Y_cal``.
        Yhat_test: predictions on the unlabeled set.
        ci_constructor: callable ``(x, alpha=..., M=...)`` returning a 3-tuple
            whose first and last entries are lower/upper bounds for mean(x).
            Used only when split=True.
        alpha: miscoverage level.
        M: forwarded to ``ci_constructor``.
        split: selects the nonasymptotic (True) or asymptotic (False) variant.
        lam: weight on the predictions. Defaults to 1.0, which is plain PPI and
            was previously hard-coded. Pass None to use the PPI++ power-tuned
            value cov(Y, Yhat) / ((1 + n/N) * var(Yhat)), estimated on the first
            split (on all data when split=False).

    Returns:
        ``[lower, upper]``.

    Raises:
        statistics.StatisticsError (a ValueError): when split=False and alpha
            is not in (0, 1).
    """
    from statistics import NormalDist  # stdlib normal quantile; avoids a SciPy dependency

    if split:
        Y_cal_1, Y_cal_2, Yhat_cal_1, Yhat_cal_2 = train_test_split(
            Y_cal, Yhat_cal, test_size=0.85, random_state=0
        )
        _, Yhat_test_2 = train_test_split(Yhat_test, test_size=0.85, random_state=0)
    else:
        Y_cal_1 = Y_cal_2 = Y_cal
        Yhat_cal_1 = Yhat_cal_2 = Yhat_cal
        Yhat_test_2 = Yhat_test
    n = len(Y_cal_2)
    N = len(Yhat_test_2)

    if lam is None:
        # Power tuning for the mean: the squared-loss Hessian is 1, so
        # lam* = Cov(Y, Yhat) / ((1 + n/N) * Var(Yhat)). ddof=0, matching np.var.
        cov = np.mean(
            (Y_cal_1 - np.mean(Y_cal_1)) * (Yhat_cal_1 - np.mean(Yhat_cal_1))
        )
        lam = cov / ((1 + n / N) * np.var(Yhat_cal_1))

    # mean(Y_2) + lam * (mean(Yhat_test_2) - mean(Yhat_cal_2))
    #   = mean(Y_2 - lam * Yhat_cal_2) + mean(lam * Yhat_test_2)
    rectified = Y_cal_2 - lam * Yhat_cal_2
    scaled_test = lam * Yhat_test_2

    if split:
        delta = alpha / 2
        r_inf, _, r_sup = ci_constructor(rectified, alpha=delta, M=M)
        e_inf, _, e_sup = ci_constructor(scaled_test, alpha=delta, M=M)
        return [r_inf + e_inf, r_sup + e_sup]

    # Previously the critical value was hard-coded to 2, ignoring alpha.
    z = NormalDist().inv_cdf(1 - alpha / 2)
    point = np.mean(rectified) + np.mean(scaled_test)
    sigma2 = np.var(rectified) + n / N * np.var(scaled_test)
    half_width = z * np.sqrt(sigma2 / n)
    return [point - half_width, point + half_width]


def fab_interval_mean(*, Y_cal, Yhat_cal, Yhat_test, ci_constructor, alpha=0.05, M=1):
    """FAB-PPI interval for the mean, delegated to ``ppi_fab_mean_ci`` with a horseshoe prior.

    ``ci_constructor`` and ``M`` are not used here. Presumably they are accepted
    so this function can be called like the other ``*_interval_mean`` helpers.

    Returns:
        ``(lower, upper)`` tuple, each squeezed with ``np.squeeze``.
    """
    fab_options = dict(
        prior="horseshoe",
        alpha=alpha,
        delta=None,
        lam=None,
        point_estimate=False,
        return_aux=False,
    )
    lower, upper = ppi_fab_mean_ci(Y_cal, Yhat_cal, Yhat_test, **fab_options)
    return tuple(np.squeeze(bound) for bound in (lower, upper))


def classical_interval_mean(*, Y_cal, ci_constructor, alpha=0.05, M=1):
    """Label-only baseline: the interval ``ci_constructor`` gives for ``Y_cal`` alone.

    Returns ``[lower, upper]``, the first and last entries of the 3-tuple that
    ``ci_constructor`` returns. No predictions are used.
    """
    low, _middle, high = ci_constructor(Y_cal, alpha=alpha, M=M)
    interval = [low, high]
    return interval
