import math

import numpy as np
import scipy.stats as stats
from utils import  scipy_solve_linear_program, solve_min_dot_simplex


def YKbaseline_rescaleRes(
    cal_preds,
    cal_y,
    cal_sigma,
    test_pred,
    test_sigma,
    test_y,
    alpha,
    # M by N, N, M by N, single, single, single,
):
    """Select one model by estimated interval width and evaluate its interval.

    Scores are rescaled residuals ``|pred - y| / (sigma + 1e-8)``. For each
    model the k-th smallest calibration score, ``k = ceil((n + 1)(1 - alpha))``,
    is its conformal quantile ``S[m]``. The selected model minimises
    ``S[m]`` times that model's average sigma over calibration and test points,
    and its interval ``test_pred[m] +/- S[m] * test_sigma[m]`` is evaluated.

    Args:
        cal_preds: (M, n) calibration predictions, one row per model.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scale estimates on calibration points.
        test_pred: (M,) for a single test point, or (M, n_test) for a batch.
        test_sigma: same shape as ``test_pred``.
        test_y: scalar / size-1 array, or (n_test,) for a batch.
        alpha: target miscoverage level.

    Returns:
        ``(cover, length, [0, 1])``. For a single point ``cover`` is a boolean
        (whether ``test_y`` lies in the interval); for a batch it is the
        fraction of covered points. ``length`` is the (mean) full interval
        width as a Python float. When ``k > n`` the quantile is infinite and
        ``(1, inf, [-inf, inf])`` is returned.
    """
    M, n = cal_preds.shape  # M models, n calibration points
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        # Too few calibration points for this level: the quantile is +inf.
        return 1, np.inf, [-np.inf, np.inf]

    # k-th smallest rescaled residual for every model at once
    sorted_scores = np.sort(np.abs(cal_preds - cal_y) / (cal_sigma + 1e-8), axis=1)
    S = sorted_scores[:, k - 1]

    # Per-model average sigma over calibration + test points. test_sigma is
    # (M,) or (M, n_test); reshape so that row m holds model m's test sigmas.
    test_sigma_rows = np.asarray(test_sigma).reshape(M, -1)
    n_test = test_sigma_rows.shape[1]
    leng_factor = (np.sum(cal_sigma, axis=1) + np.sum(test_sigma_rows, axis=1)) / (
        n + n_test
    )

    # Model selection: smallest estimated width
    mhat = np.argmin(S * leng_factor)

    if test_y.size == 1:
        sigma_hat = test_sigma[mhat]
        cover = np.abs(test_y - test_pred[mhat]) <= S[mhat] * sigma_hat
        length = 2 * S[mhat] * sigma_hat
    else:
        sigma_hat = test_sigma[mhat, :]  # selected model's sigmas, one per test point
        cover = np.mean(np.abs(test_y - test_pred[mhat]) <= S[mhat] * sigma_hat)
        # Mean full width 2 * S * sigma over the batch (factor 2 applied once)
        length = np.mean(2 * S[mhat] * sigma_hat)

    return cover, length.item(), [0, 1]


def YK_adj_rescaleRes(
    cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha
):
    """Width-based model selection with an adjusted (more conservative) level.

    Same procedure as the baseline selection (k-th smallest rescaled residual
    per model, select the model minimising quantile x average sigma), but the
    quantile index is computed from an adjusted level ``til_alpha``. That level
    subtracts a correction term depending on ``n`` and ``log(2M)``.

    Args:
        cal_preds: (M, n) calibration predictions.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scales on calibration points.
        test_pred: (M,) for one test point or (M, n_test) for a batch.
        test_sigma: same shape as ``test_pred``.
        test_y: scalar / size-1 array, or (n_test,).
        alpha: target miscoverage level.

    Returns:
        ``(cover, length, [0, 1])`` where ``cover`` is the fraction of covered
        test points and ``length`` the mean full width (Python float). When
        the adjusted quantile index exceeds ``n``, ``(1, inf, [-inf, inf])``.
    """
    M, n = cal_preds.shape
    til_alpha = (
        n * alpha / (n + 1)
        + 1 / (n + 1)
        - n * (1 / (3 * np.sqrt(n)) + np.sqrt(np.log(2 * M) / (2 * n))) / (n + 1)
    )

    k = math.ceil((n + 1) * (1 - til_alpha))

    if k > n:  # Return infinite interval
        return 1, np.inf, [-np.inf, np.inf]

    sorted_scores = np.sort(np.abs(cal_preds - cal_y) / (cal_sigma + 1e-8), axis=1)
    S = sorted_scores[:, k - 1]

    # Per-model average sigma over calibration + test points (row m = model m)
    test_sigma_rows = np.asarray(test_sigma).reshape(M, -1)
    n_test = test_sigma_rows.shape[1]
    leng_factor = (np.sum(cal_sigma, axis=1) + np.sum(test_sigma_rows, axis=1)) / (
        n + n_test
    )
    mhat = np.argmin(S * leng_factor)

    sigma_hat = test_sigma[mhat]
    # Average over the test point(s); for a single point this is a no-op mean.
    cover = np.mean(np.abs(test_y - test_pred[mhat]) <= S[mhat] * sigma_hat)
    length = np.mean(2 * S[mhat] * sigma_hat)

    return cover, length.item(), [0, 1]


def YKsplit_rescaleRes(
    cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha, split_portion=0.5
):
    """Split-version model selection: select on one part, calibrate on the other.

    The first ``ceil(n * split_portion)`` calibration points pick the model
    minimising (k1-th smallest rescaled residual) x (mean sigma). The remaining
    points give that model's k2-th smallest rescaled residual ``Res``, and the
    interval ``test_pred[m] +/- Res * test_sigma[m]`` is evaluated.

    Args:
        cal_preds: (M, n) calibration predictions.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scales on calibration points.
        test_pred: (M,) for one test point or (M, n_test) for a batch.
        test_sigma: same shape as ``test_pred``.
        test_y: scalar / size-1 array, or (n_test,).
        alpha: target miscoverage level.
        split_portion: fraction of calibration points used for selection.

    Returns:
        ``(cover, length, [0, 1])``: fraction of covered test points and mean
        full width (Python float). When the calibration part is too small for
        the requested level, ``(1, inf, [-inf, inf])``.
    """
    M, n = cal_preds.shape  # M models, total n calibration points

    # Split data into selection and calibration portions
    n1 = math.ceil(n * split_portion)
    n2 = n - n1
    k1 = math.ceil((n1 + 1) * (1 - alpha))
    k2 = math.ceil((n2 + 1) * (1 - alpha))
    if k2 > n2:
        # The calibration quantile is +inf for any selected model.
        return 1, np.inf, [-np.inf, np.inf]

    preds_sel, preds_cal = cal_preds[:, :n1], cal_preds[:, n1:]
    sigma_sel, sigma_cal = cal_sigma[:, :n1], cal_sigma[:, n1:]
    y_sel, y_cal = cal_y[:n1], cal_y[n1:]

    # First stage: model selection on the selection portion
    sorted_scores_sel = np.sort(np.abs(preds_sel - y_sel) / (sigma_sel + 1e-8), axis=1)
    S = sorted_scores_sel[:, k1 - 1] * np.mean(sigma_sel, axis=1)
    mhat = np.argmin(S)

    # Second stage: calibrate the selected model on the held-out portion
    # (same +1e-8 guard against zero sigma as the selection stage)
    scores_cal = np.abs(preds_cal[mhat] - y_cal) / (sigma_cal[mhat] + 1e-8)
    Res = np.sort(scores_cal)[k2 - 1]

    sigma_hat = test_sigma[mhat]
    mu_hat = test_pred[mhat]
    # Average over the test point(s); a no-op mean for a single point.
    cover = np.mean(np.abs(test_y - mu_hat) <= Res * sigma_hat)
    length = np.mean(2 * Res * sigma_hat)

    return cover, length.item(), [0, 1]


def ModSel_rescaleRes(
    cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha
):
    """Model-selection conformal set for one test point or a batch.

    For a single test point this forwards to ``single_ModSel_rescaleRes`` and
    returns its 5-tuple ``(cover, length, Interv, connect, calM)`` unchanged.
    For a batch, ``test_pred`` / ``test_sigma`` are expected as (M, n_test)
    (column i belongs to test point i). The single-point routine is run per
    column and ``(mean cover, mean length, [0, 1])`` is returned.
    """
    if test_y.size == 1:
        return single_ModSel_rescaleRes(
            cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha
        )

    c_list, l_list = [], []
    for i, y_i in enumerate(np.ravel(test_y)):
        # Pass only the i-th test point's predictions, scales and target.
        cover, length, _interv, _connect, _calM = single_ModSel_rescaleRes(
            cal_preds, cal_y, cal_sigma, test_pred[:, i], test_sigma[:, i], y_i, alpha
        )
        c_list.append(cover)
        l_list.append(length)
    return np.mean(c_list), np.mean(l_list), [0, 1]


def single_ModSel_rescaleRes(
    cal_preds, cal_y, cal_sigma, test_pred, test_sigma, test_y, alpha
):
    """Model-selection conformal set for a single test point.

    For each model m, sorted rescaled residuals give two order statistics:
    the (k-1)-th (``S_lo``, 0 when k == 1) and the k-th (``S_hi``), with
    ``k = ceil((n + 1)(1 - alpha))``. With the per-model factor ``leng_factor``,
    ``u = min_m S_hi[m] * leng_factor[m]``. Every model with
    ``S_lo[m] * leng_factor[m] <= u`` is retained. Each retained model
    contributes the interval ``test_pred[m] +/- u * test_sigma[m] / leng_factor[m]``,
    and the prediction set is the union of these intervals.

    Args:
        cal_preds: (M, n) calibration predictions.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scales on calibration points.
        test_pred: (M,) predictions for the test point.
        test_sigma: (M,) scales for the test point.
        test_y: the test target (scalar).
        alpha: target miscoverage level.

    Returns:
        ``(cover, length, Interv, connect, calM)`` where ``cover`` is 1 if
        ``test_y`` lies in the union and 0 otherwise, and ``length`` is the total
        length of the union. ``Interv = [lefts, rights]`` holds the endpoints
        of the disjoint components in increasing order. ``connect`` is 1 when
        the union is a single interval, and ``calM`` is the number of retained
        models. When ``k > n`` the set is the whole line:
        ``(1, inf, [[-inf], [inf]], 1, M)``.
    """
    M, n = cal_preds.shape  # M models, n calibration points
    k = math.ceil((n + 1) * (1 - alpha))
    if k > n:
        # k-th order statistic does not exist: every quantile is +inf.
        return 1, np.inf, [[-np.inf], [np.inf]], 1, M

    test_pred = np.ravel(test_pred)
    test_sigma = np.ravel(test_sigma)

    # Sorted rescaled residuals, one row per model
    sorted_scores = np.sort(np.abs(cal_y - cal_preds) / (cal_sigma + 1e-8), axis=1)
    S_lo = sorted_scores[:, k - 2] if k >= 2 else np.zeros(M)  # k=1 edge case
    S_hi = sorted_scores[:, k - 1]
    leng_factor = (np.sum(cal_sigma + 1e-8, axis=1) + test_sigma) / (n + 1)  # TODO

    u = np.min(S_hi * leng_factor)
    find_mdl = np.where(S_lo * leng_factor - u <= 0)[0]
    calM = len(find_mdl)

    mu_test = test_pred[find_mdl]
    slope = leng_factor[find_mdl] / test_sigma[find_mdl]

    # test_y is covered iff it lies inside at least one retained interval
    cover = 0 if np.min(np.abs(mu_test - test_y) * slope) > u else 1

    quantl = u / slope  # half-width of each retained interval
    sort_ind = np.argsort(mu_test - quantl)
    Lefts = (mu_test - quantl)[sort_ind]
    Rights = (mu_test + quantl)[sort_ind]

    # Merge intervals sorted by left endpoint. The running right end takes the
    # max over the component so nested intervals cannot truncate it; only a
    # strict gap (left > running right) starts a new component.
    left, right = [Lefts[0]], []
    cur_right = Rights[0]
    for lo, hi in zip(Lefts[1:], Rights[1:]):
        if lo > cur_right:
            right.append(cur_right)
            left.append(lo)
            cur_right = hi
        else:
            cur_right = max(cur_right, hi)
    right.append(cur_right)

    length = sum(r - l for l, r in zip(left, right))
    connect = 1 if len(left) == 1 else 0
    Interv = [left, right]

    return cover, length, Interv, connect, calM


def stable_conformal(
    cal_preds,
    cal_y,
    cal_sigma,
    test_pred,
    test_sigma,
    test_y,
    alpha,
    eta,
    b_prior,  # stable selection stuff
):
    """Conformal prediction with a randomized (stable) model selection.

    The miscoverage level is shrunk to ``alpha / exp(eta)``. Each model's
    k-th smallest rescaled calibration residual ``S[m]`` then gives half-widths
    ``test_sigma[m] * S[m]``. For each test point, ``solve_min_dot_simplex``
    (imported from ``utils``) returns a distribution over models, and coverage
    and full width are averaged under that distribution.

    Args:
        cal_preds: (M, n) calibration predictions.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scales on calibration points.
        test_pred: (M,) for one test point or (M, n_test) for a batch.
        test_sigma: same shape as ``test_pred``.
        test_y: scalar / size-1 array, or (n_test,).
        alpha: target miscoverage level.
        eta: stability parameter, also passed to the selection solver.
        b_prior: prior weights over models, passed to the selection solver.

    Returns:
        ``(mean expected coverage, mean expected length, [0, 1])``. If the
        quantile index exceeds ``n``, ``(1, inf, [-inf, inf])``.
    """
    adjusted_alpha = alpha / np.exp(eta)
    M, n = cal_preds.shape
    k = math.ceil((n + 1) * (1 - adjusted_alpha))
    if k > n:
        # Quantile is +inf for every model: whole-line intervals.
        return 1, np.inf, [-np.inf, np.inf]

    sorted_scores = np.sort(np.abs(cal_preds - cal_y) / (cal_sigma + 1e-8), axis=1)
    S = sorted_scores[:, k - 1]  # k-th smallest score for each model

    # Normalize shapes: one column per test point, one target per column.
    test_y = np.ravel(test_y)
    num_test = test_y.size
    test_sigma = np.asarray(test_sigma).reshape(M, num_test)
    test_pred = np.asarray(test_pred).reshape(M, num_test)

    coverage_list = []
    length_list = []
    for i in range(num_test):
        sizes = test_sigma[:, i] * S  # per-model half-widths at test point i
        selection_dist = solve_min_dot_simplex(sizes, b_prior, eta)
        # Compare against this test point's own target only.
        coverage_vec = np.abs(test_pred[:, i] - test_y[i]) <= sizes
        coverage_list.append(np.sum(selection_dist * coverage_vec))
        length_list.append(2 * np.sum(selection_dist * sizes))

    return np.sum(coverage_list) / num_test, np.sum(length_list) / num_test, [0, 1]


def adaptive_stable_conformal(
    cal_preds,
    cal_y,
    cal_sigma,
    test_pred,
    test_sigma,
    test_y,
    alpha_prime,
    alpha,
    b_prior=None,  # stable selection stuff
):
    """Conformal prediction with an LP-based randomized model selection.

    Each model's quantile ``S[m]`` is its k-th smallest rescaled calibration
    residual, with ``k = ceil((n + 1)(1 - alpha_prime))``. For each test
    point, ``scipy_solve_linear_program`` (imported from ``utils``) receives
    the prior, the full widths ``2 * test_sigma[m] * S[m]``, ``alpha_prime``
    and ``alpha``, and returns a distribution over models. Coverage and full
    width are averaged under that distribution.

    Args:
        cal_preds: (M, n) calibration predictions.
        cal_y: (n,) calibration targets.
        cal_sigma: (M, n) per-model scales on calibration points.
        test_pred: (M,) for one test point or (M, n_test) for a batch.
        test_sigma: same shape as ``test_pred``.
        test_y: scalar / size-1 array, or (n_test,).
        alpha_prime: level used for the per-model quantile.
        alpha: level passed to the selection LP.
        b_prior: prior weights over models; uniform when ``None``.

    Returns:
        ``(mean expected coverage, mean expected length, [0, 1])``. If the
        quantile index exceeds ``n``, ``(1, inf, [-inf, inf])``.
    """
    M, n = cal_preds.shape
    if b_prior is None:
        b_prior = np.ones(M) / M
    k = math.ceil((n + 1) * (1 - alpha_prime))
    if k > n:
        # Quantile is +inf for every model: whole-line intervals.
        return 1, np.inf, [-np.inf, np.inf]

    sorted_scores = np.sort(np.abs(cal_preds - cal_y) / (cal_sigma + 1e-8), axis=1)
    S = sorted_scores[:, k - 1]  # k-th smallest score for each model

    # Normalize shapes: one column per test point, one target per column.
    test_y = np.ravel(test_y)
    num_test = test_y.size
    test_sigma = np.asarray(test_sigma).reshape(M, num_test)
    test_pred = np.asarray(test_pred).reshape(M, num_test)

    coverage_list = []
    length_list = []
    for i in range(num_test):
        sizes = test_sigma[:, i] * S  # per-model half-widths at test point i
        selection_dist = scipy_solve_linear_program(
            b_prior, 2 * sizes, alpha_prime, alpha
        )
        # Compare against this test point's own target only.
        coverage_vec = np.abs(test_pred[:, i] - test_y[i]) <= sizes
        coverage_list.append(np.sum(selection_dist * coverage_vec))
        length_list.append(2 * np.sum(selection_dist * sizes))

    return np.sum(coverage_list) / num_test, np.sum(length_list) / num_test, [0, 1]


def calibrate_after_selection_resampling(
    cal_preds,
    cal_y,
    cal_sigma,
    test_pred,
    test_sigma,
    test_y,
    alpha,
    b_prior,
    alpha_pre_selection,
    alpha_post_selection,
    N_resamples=1,
    preliminary_gamma=0.9,
    aux_split_ratio=0.5,
):
    """Calibrate a randomized model selection by resampling on auxiliary data.

    Steps:
      1. Split the calibration columns into an "initial" part and an auxiliary
         part (the last ``floor(n * aux_split_ratio)`` columns).
      2. On the initial part, take each model's ``k_prelim``-th smallest
         rescaled residual (``prelim_scores``). These drive the selection
         inputs ``2 * prelim_scores * sigma``.
      3. For every aux point and model, record the rank of the aux score among
         that model's initial scores (number of initial scores <= it).
      4. At each aux point, draw ``N_resamples`` models from the selection
         distribution returned by ``scipy_solve_linear_program`` (imported from
         ``utils``) and pool the ranks of the drawn models.
      5. The ``ceil((1 - alpha)(pool + N_resamples))``-th smallest pooled rank
         r selects the threshold ``sorted_initial_scores[:, r]``.
      6. For each test point, average coverage and full width under the
         selection distribution.

    Returns:
        ``(mean coverage, mean length)`` as Python floats. When the pooled
        ranks are too few for the requested level, or the selected rank is
        beyond every initial score, the threshold is +inf and ``(1.0, inf)``
        is returned.
    """
    M, n_cal_total = cal_preds.shape

    # --- Input validation and batch detection ---
    single_test_point = (test_y.ndim == 0) or (len(test_y) == 1)
    if single_test_point:
        test_pred = test_pred.reshape(-1, 1)
        test_sigma = test_sigma.reshape(-1, 1)
        test_y = np.array([test_y]).flatten()
    n_test = test_y.shape[0]

    # --- Step 1: split into initial-calibration and auxiliary columns ---
    n_aux = math.floor(n_cal_total * aux_split_ratio)
    n_initial_cal = n_cal_total - n_aux
    initial_cal_preds, aux_preds = np.split(cal_preds, [n_initial_cal], axis=1)
    initial_cal_sigma, aux_sigma = np.split(cal_sigma, [n_initial_cal], axis=1)
    initial_cal_y, aux_y = np.split(cal_y, [n_initial_cal])

    # --- Step 2: preliminary per-model quantile on the initial part ---
    k_prelim = math.ceil(n_initial_cal * preliminary_gamma)
    k_prelim = max(1, min(k_prelim, n_initial_cal))
    initial_scores = np.abs(initial_cal_preds - initial_cal_y[None, :]) / (
        initial_cal_sigma + 1e-8
    )
    sorted_initial_scores = np.sort(initial_scores, axis=1)
    prelim_scores = sorted_initial_scores[:, k_prelim - 1]

    # --- Step 3: rank of each aux score among the same model's initial scores ---
    aux_scores = np.abs(aux_preds - aux_y[None, :]) / (aux_sigma + 1e-8)
    # side="right" counts initial scores <= the aux score (ties included).
    rank_k_all_models = np.stack(
        [
            np.searchsorted(sorted_initial_scores[m], aux_scores[m], side="right")
            for m in range(M)
        ]
    )

    # --- Step 4: resample selected models at aux points and pool their ranks ---
    effective_rank_pool = []
    for i in range(n_aux):
        # Use exactly the same selection input (incl. the factor 2) as for the
        # test points below, so aux points replicate the test-time selection.
        xi = 2 * prelim_scores * aux_sigma[:, i]
        p_xi = scipy_solve_linear_program(
            b_prior, xi, alpha_pre_selection, alpha_post_selection
        )
        s_i_all_samples = np.random.choice(M, size=N_resamples, p=p_xi)
        effective_rank_pool.extend(rank_k_all_models[s_i_all_samples, i])

    # --- Step 5: final calibrated rank ---
    rank_rep = np.sort(np.asarray(effective_rank_pool))
    M_pool = rank_rep.size
    k_final_rank = math.ceil((1 - alpha) * (M_pool + N_resamples))
    if k_final_rank > M_pool:
        # Not enough pooled ranks for this level: infinite threshold.
        return 1.0, np.inf
    final_rank_index = int(rank_rep[k_final_rank - 1])
    if final_rank_index >= n_initial_cal:
        # The required rank exceeds every initial score: threshold is +inf
        # (clipping to the largest finite score would under-cover).
        return 1.0, np.inf
    aux_cal_scores = sorted_initial_scores[:, final_rank_index]

    # --- Step 6: evaluate on the test point(s) ---
    all_coverage = np.zeros(n_test)
    all_lengths = np.zeros(n_test)
    for i in range(n_test):
        test_sigmas_point = test_sigma[:, i]
        xi_test = 2 * prelim_scores * test_sigmas_point
        p_xi_test = scipy_solve_linear_program(
            b_prior, xi_test, alpha_pre_selection, alpha_post_selection
        )
        half_widths = test_sigmas_point * aux_cal_scores
        all_lengths[i] = 2 * np.sum(p_xi_test * half_widths)
        errors = np.abs(test_pred[:, i] - test_y[i])
        all_coverage[i] = np.sum((errors <= half_widths) * p_xi_test)

    return float(np.mean(all_coverage)), float(np.mean(all_lengths))


# def calibrate_after_selection_resampling(
#     cal_preds,
#     cal_y,
#     cal_sigma,
#     test_pred,
#     test_sigma,
#     test_y,
#     alpha,
#     b_prior,
#     alpha_prime,
#     alpha_target_for_selection,
#     N_resamples=10,
#     preliminary_gamma=0.9,
#     aux_split_ratio=0.5,
# ):
#     M, n_cal_total = cal_preds.shape

#     # --- Input Validation and Batch Detection ---
#     single_test_point = (test_y.ndim == 0) or (len(test_y) == 1)
#     if single_test_point:
#         test_pred = test_pred.reshape(-1, 1)
#         test_sigma = test_sigma.reshape(-1, 1)
#         test_y = np.array([test_y]).flatten()
#     n_test = test_y.shape[0]

#     n_aux = math.floor(n_cal_total * aux_split_ratio)
#     n_initial_cal = n_cal_total - n_aux
#     initial_cal_preds, aux_preds = np.split(cal_preds, [n_initial_cal], axis=1)
#     initial_cal_sigma, aux_sigma = np.split(cal_sigma, [n_initial_cal], axis=1)
#     initial_cal_y, aux_y = np.split(cal_y, [n_initial_cal])

#     # --- Step 1: Calculate initial quantiles (q_gamma_tilde) ---

#     initial_scores = np.abs(initial_cal_preds - initial_cal_y[None, :]) / (
#         initial_cal_sigma + 1e-8
#     )
#     sorted_initial_scores = np.sort(initial_scores, axis=1)
#     k = math.ceil((n_initial_cal + 1) * preliminary_gamma)
#     scores_init_all_models = sorted_initial_scores[:, k - 1]
#     print("scores_init_all_models", scores_init_all_models)

#     # --- Step 2: Calculate minimum confidence levels (gamma_k,i) for aux data ---
#     aux_scores = np.abs(aux_preds - aux_y[None, :]) / (aux_sigma + 1e-8)

#     effective_scores_pool = []
#     for i in range(n_aux):  # Loop over internal auxiliary points ONLY
#         xi = scores_init_all_models * aux_sigma[:, i]
#         p_xi = scipy_solve_linear_program(
#             b_prior, xi, alpha_prime, alpha_target_for_selection
#         )
#         s_i_all_samples = np.random.choice(M, size=N_resamples, p=p_xi)
#         effective_scores_pool.extend(aux_scores[s_i_all_samples, i])

#     # --- Step 4: Compute the final calibrated quantile q_hat_alpha ---
#     Gamma_rep = np.array(effective_scores_pool)
#     M_pool = len(Gamma_rep)
#     k_final_rank = math.ceil((1 - alpha) * (M_pool + N_resamples))
#     scores_after_selection = np.sort(Gamma_rep)[k_final_rank - 1]
#     print("scores_after_selection", scores_after_selection)
#     # --- Step 5: Apply to Test Point(s) ---
#     # (Logic unchanged, but uses q_hat_alpha derived from new gamma_k_i_aux)
#     all_coverage = np.zeros(n_test)
#     all_lengths = np.zeros(n_test)

#     # Determine the rank index corresponding to the final quantile level q_hat_alpha
#     # This is relative to the *initial calibration set* used for q_gamma_tilde
#     for i in range(n_test):
#         test_preds_point = test_pred[:, i]
#         test_sigmas_point = test_sigma[:, i]
#         test_y_point = test_y[i]
#         xi_test = 2 * scores_init_all_models * test_sigmas_point

#         # AdaMinSE Selection for the selection mixture

#         print("xi_test. selection length", xi_test)
#         p_xi_test = scipy_solve_linear_program(
#             b_prior, xi_test, alpha_prime, alpha_target_for_selection
#         )
#         print("p_xi_test", p_xi_test)
#         lengths_test = 2 * test_sigmas_point * scores_after_selection
#         all_lengths[i] = np.sum(p_xi_test * lengths_test)
#         print("actual Lengths", lengths_test)
#         errors = 2 * np.abs(test_preds_point - test_y_point)
#         all_coverage[i] = np.sum((errors <= lengths_test) * p_xi_test)
#     # Return average coverage and length
#     return float(np.mean(all_coverage)), float(np.mean(all_lengths))


# def calibrate_after_selection_resampling_no_split(
#     cal_preds,
#     cal_y,
#     cal_sigma,
#     test_pred,
#     test_sigma,
#     test_y,
#     alpha,
#     b_prior,
#     alpha_prime,
#     alpha_target_for_selection,
#     N_resamples=10,
#     preliminary_gamma=0.9,
#     aux_split_ratio=0.5,
# ):
#     M, n_cal_total = cal_preds.shape
#     k_prelim = math.ceil(n_cal_total * preliminary_gamma)
#     k_prelim = max(1, min(k_prelim, n_cal_total))

#     # --- Input Validation and Batch Detection ---
#     single_test_point = (test_y.ndim == 0) or (len(test_y) == 1)
#     if single_test_point:
#         test_pred = test_pred.reshape(-1, 1)
#         test_sigma = test_sigma.reshape(-1, 1)
#         test_y = np.array([test_y]).flatten()
#     n_test = test_y.shape[0]
#     # --- Step 2: Calculate minimum confidence levels   aux data ---
#     all_scores = np.abs(cal_preds - cal_y[None, :]) / (cal_sigma + 1e-8)
#     sorted_initial_scores = np.sort(all_scores, axis=1)
#     prelim_scores = all_scores[:, k_prelim - 1]

#     sorted_all_scores = np.sort(all_scores, axis=1)

#     # Generate rank/quantile matrix.
#     rank_k_all_models = np.zeros((M, n_cal_total))
#     gamma_k_i_aux = np.zeros((M, n_cal_total))  # Initialize array for results
#     for k in range(M):  # Iterate through models
#         # Compare each aux score for model k against all initial scores for model k
#         ranks_k = np.sum(all_scores[k, None, :] <= all_scores[k, :, None], axis=1)
#         # This should be 1 by N
#         # Normalize rank to get empirical p-value/quantile level
#         rank_k_all_models[k, :] = ranks_k
#         gamma_k_i_aux[k, :] = (ranks_k + 1) / (n_cal_total + 1)

#     # resampling with ppolling
#     effective_scores_pool = []
#     effective_rank_pool = []
#     for i in range(n_cal_total):  # Loop over internal auxiliary points ONLY
#         xi = 2 * prelim_scores * cal_sigma[:, i]
#         p_xi = scipy_solve_linear_program(
#             b_prior, xi, alpha_prime, alpha_target_for_selection
#         )
#         s_i_all_samples = np.random.choice(M, size=N_resamples, p=p_xi)
#         E_i_all_samples = gamma_k_i_aux[s_i_all_samples, i]
#         effective_rank_pool.extend(rank_k_all_models[s_i_all_samples, i])
#         effective_scores_pool.extend(E_i_all_samples)

#     # --- Step 4: Compute the final calibrated quantile q_hat_alpha ---
#     rank_rep = np.sort(np.array(effective_rank_pool))
#     Gamma_rep = np.array(effective_scores_pool)
#     M_pool = len(Gamma_rep)
#     k_final_rank = math.ceil((1 - alpha) * (M_pool + N_resamples))
#     final_rank_index = rank_rep[k_final_rank - 1].astype(int)
#     final_rank_index = rank_rep[k_final_rank - 1].astype(int)
#     print("final_rank_index", final_rank_index)
#     print("final_rank_index", final_rank_index)
#     print("final_rank_index", final_rank_index)
#     print("final_rank_index", final_rank_index)

#     all_coverage = np.zeros(n_test)
#     all_lengths = np.zeros(n_test)

#     aux_cal_scores = sorted_initial_scores[:, final_rank_index]
#     print("aux_cal_scores", aux_cal_scores)
#     for i in range(n_test):
#         test_preds_point = test_pred[:, i]
#         test_sigmas_point = test_sigma[:, i]
#         test_y_point = test_y[i]
#         xi_test = 2 * prelim_scores * test_sigmas_point

#         print("xi_test. selection length", xi_test)
#         # AdaMinSE Selection for the selection mixture
#         # print("xi_test. selection length", xi_test)
#         p_xi_test = scipy_solve_linear_program(
#             b_prior, xi_test, alpha_prime, alpha_target_for_selection
#         )
#         print("p_xi_test", p_xi_test)
#         lengths_test = test_sigmas_point * aux_cal_scores
#         all_lengths[i] = 2 * np.sum(p_xi_test * lengths_test)
#         # print("actual Lengths", lengths_test)

#         print(
#             "alpha_prime: {0}".format(alpha_prime),
#             "prelim_gammag: {0} Prelim Match:".format(preliminary_gamma),
#             np.argmin(xi_test) == np.argmin(lengths_test),
#         )
#         errors = np.abs(test_preds_point - test_y_point)
#         all_coverage[i] = np.sum((errors <= lengths_test) * p_xi_test)
#     # Return average coverage and length
#     return float(np.mean(all_coverage)), float(np.mean(all_lengths))


# def calibrate_after_selection_resampling(
#     cal_preds,
#     cal_y,
#     cal_sigma,
#     test_pred,
#     test_sigma,
#     test_y,
#     alpha,
#     b_prior,
#     alpha_prime,
#     alpha_target_for_selection,
#     N_resamples=10,
#     preliminary_gamma=0.9,
#     aux_split_ratio=0.5,
# ):
#     M, n_cal_total = cal_preds.shape

#     # --- Input Validation and Batch Detection ---
#     single_test_point = (test_y.ndim == 0) or (len(test_y) == 1)
#     if single_test_point:
#         test_pred = test_pred.reshape(-1, 1)
#         test_sigma = test_sigma.reshape(-1, 1)
#         test_y = np.array([test_y]).flatten()
#     n_test = test_y.shape[0]
#     all_scores = np.abs(cal_preds - cal_y[None, :]) / (
#         cal_sigma + 1e-8)


#     k_prelim = math.ceil((n_cal_total+1) * preliminary_gamma)
#     sorted_all_scores = np.sort(all_scores, axis=1)
#     prelim_scores = sorted_all_scores[:, k_prelim - 1]
#     #Generate rank/quantile matrix.
#     rank_k_all_models = np.zeros((M, n_cal_total))
#     for k in range(M):
#         ranks_k =  np.argsort(np.argsort(all_scores[k, :]))
#         rank_k_all_models[k, :] = ranks_k
#     # print("rank_k_all_models", np.bincount(rank_k_all_models.astype(int).flatten()))
#     #TODO check this too.....
#     effective_rank_pool = []
#     # print(np.bincount(rank_k_all_models.flatten().astype(int)))
#     for i in range(n_cal_total): # Loop over internal auxiliary points ONLY
#         xi = 2*prelim_scores * all_scores[:, i]
#         p_xi = scipy_solve_linear_program(b_prior, xi, alpha_prime, alpha_target_for_selection)
#         # s_i_all_samples = np.random.choice(M, size=N_resamples, p=p_xi)
#         s_i_all_samples = np.random.choice(M, size=N_resamples)
#         if i<10:
#             print("selected models", s_i_all_samples)
#             print("scores selected samples", all_scores[s_i_all_samples,i])
#             print("rank_k_all_models", rank_k_all_models[s_i_all_samples,i])
#             print("scores corresponding to ranks", sorted_all_scores[s_i_all_samples,rank_k_all_models[s_i_all_samples,i].astype(int)])
#             print("cover inside loop",
#                 all_scores[s_i_all_samples,i]<=sorted_all_scores[s_i_all_samples,rank_k_all_models[s_i_all_samples,i].astype(int)])
#         effective_rank_pool.extend(rank_k_all_models[s_i_all_samples,i])

#     # --- Step 4: Compute the final calibrated quantile q_hat_alpha ---
#     rank_rep = np.sort(np.array(effective_rank_pool))
#     M_pool = len(effective_rank_pool)
#     k_final_rank = math.ceil((1 - alpha) * (M_pool+N_resamples))

#     final_rank_index =(rank_rep[k_final_rank - 1].astype(int))+2
#     print("final_rank_index", final_rank_index)

#     aux_cal_scores = sorted_all_scores[:,final_rank_index]

#     all_coverage = np.zeros(n_test)
#     all_lengths = np.zeros(n_test)


#     print("aux_cal_scores", aux_cal_scores)
#     for i in range(n_test):
#         test_preds_point = test_pred[:, i]
#         test_sigmas_point = test_sigma[:, i]
#         test_y_point = test_y[i]
#         xi_test = 2*prelim_scores * test_sigmas_point


#         print("xi_test. selection length", xi_test)
#         # AdaMinSE Selection for the selection mixture
#         # print("xi_test. selection length", xi_test)
#         p_xi_test = scipy_solve_linear_program(
#             b_prior, xi_test, alpha_prime, alpha_target_for_selection
#         )
#         print("p_xi_test", p_xi_test)
#         lengths_test = test_sigmas_point*aux_cal_scores
#         all_lengths[i] = 2*np.sum(p_xi_test*lengths_test)
#         print("actual Lengths", lengths_test)

#         print("alpha_prime: {0}".format(alpha_prime), "prelim_gammag: {0} Prelim Match:".format(preliminary_gamma),
#               np.argmin(xi_test)==np.argmin(lengths_test))
#         errors = np.abs(test_preds_point - test_y_point)
#         all_coverage[i] = np.sum((errors <= lengths_test) * p_xi_test)
#     # Return average coverage and length
#     return float(np.mean(all_coverage)), float(np.mean(all_lengths))
