import numpy as np
import coma_utils as cu
from scipy.optimize import linprog  # Import linprog
from tqdm import tqdm

# Finite stand-in cost used wherever an interval length is NaN or +/-inf, so the
# LP objective (and any code feeding lengths into it) stays finite.
LARGE_LAMBDA = 1e5


# --- LP Solver Function (Provided & Modified) ---
def scipy_solve_linear_program(b, lambdas, alpha_prime, alpha):
    """Solve the AdaMinSE selection LP with ``scipy.optimize.linprog`` (HiGHS).

    The decision vector is ``x = [p (k), u (k), s, t]`` and the program is::

        minimise    lambdas . p
        subject to  sum(p)                 = 1
                    p_i - u_i - b_i * t   <= 0      for every model i
                    sum(u) - s            <= 0
                    s + alpha_prime * t   <= alpha
                    p, u, s >= 0,  t >= 1

    Args:
        b: Prior weights over the ``k`` models (array-like). Entries must be
            finite and not below ``-1e-9``; tiny negatives are clipped to 0 and
            the vector is renormalised when its sum is not within 1e-6 of 1.
        lambdas: Per-model costs (array-like, length ``k``). NaN and +/-inf
            are replaced by ``LARGE_LAMBDA`` before solving.
        alpha_prime: Miscoverage level of the individual models, in [0, 1].
        alpha: Target miscoverage after selection, in [0, 1].

    Returns:
        np.ndarray of length ``k``: the selection distribution ``p``
        (non-negative, sums to 1). Early exits:
          * uniform distribution if ``alpha``/``alpha_prime`` lie outside
            [0, 1], if ``b`` has non-finite or negative entries, or if
            ``sum(b) <= 1e-9``;
          * the cleaned prior ``b`` if ``linprog`` raises ``ValueError`` or
            does not report success (e.g. infeasible, or stopped by the
            0.5 s time limit).

    Raises:
        ValueError: If ``b`` and ``lambdas`` have different lengths.
    """
    # Accept plain lists as well as arrays (comparisons below need an ndarray).
    b = np.asarray(b, dtype=float)
    lambdas = np.asarray(lambdas, dtype=float)

    k = len(b)
    if len(lambdas) != k:
        raise ValueError("Length mismatch: lambdas and b")
    uniform = np.ones(k) / k if k > 0 else np.array([])

    if not (0 <= alpha_prime <= 1 and 0 <= alpha <= 1):
        return uniform
    if not np.all(np.isfinite(b)) or np.any(b < -1e-9):
        return uniform

    b_sum = np.sum(b)
    if b_sum <= 1e-9:
        return uniform
    b = np.maximum(0.0, b)
    if abs(b_sum - 1.0) > 1e-6:
        b = b / b_sum

    costs = np.nan_to_num(
        lambdas, nan=LARGE_LAMBDA, posinf=LARGE_LAMBDA, neginf=LARGE_LAMBDA
    )

    # Variable layout: p -> [0, k), u -> [k, 2k), s -> 2k, t -> 2k + 1.
    n_vars = 2 * k + 2
    s_idx, t_idx = 2 * k, 2 * k + 1
    idx = np.arange(k)

    c_lp = np.concatenate([costs, np.zeros(k + 2)])

    A_eq_lp = np.zeros((1, n_vars))
    A_eq_lp[0, :k] = 1.0  # sum(p) = 1
    b_eq_lp = np.array([1.0])

    A_ub_lp = np.zeros((k + 2, n_vars))
    A_ub_lp[idx, idx] = 1.0  # p_i
    A_ub_lp[idx, k + idx] = -1.0  # - u_i
    A_ub_lp[idx, t_idx] = -b  # - b_i * t
    A_ub_lp[k, k : 2 * k] = 1.0  # sum(u)
    A_ub_lp[k, s_idx] = -1.0  # - s
    A_ub_lp[k + 1, s_idx] = 1.0  # s
    A_ub_lp[k + 1, t_idx] = alpha_prime  # + alpha' * t
    b_ub_lp = np.zeros(k + 2)
    b_ub_lp[-1] = alpha

    bounds_lp = [(0, None)] * (2 * k + 1) + [(1, None)]  # t >= 1

    try:
        result = linprog(
            c_lp,
            A_ub=A_ub_lp,
            b_ub=b_ub_lp,
            A_eq=A_eq_lp,
            b_eq=b_eq_lp,
            bounds=bounds_lp,
            method="highs",
            options={"presolve": True, "time_limit": 0.5},
        )
    except ValueError as e:
        print(f"LP Setup Error: {e}")
        return b  # Fallback to prior b

    if not result.success:
        return b  # Fallback to prior b

    # Clip solver round-off below zero and renormalise onto the simplex.
    p_optimal = np.maximum(0.0, result.x[:k])
    p_sum = np.sum(p_optimal)
    return p_optimal / p_sum if p_sum > 1e-9 else uniform


# --- Run Standard COMA Aggregation ---
def run_coma_aggregation(
    results_list,
    y_test,
    n_pred,
    loss_matrix,
    raw_ind_coverage,
    raw_ind_lengths,
    eta_hedge,
):
    """
    Performs standard COMA aggregation using AdaHedge and Hedge weights.

    This function takes the results from individual ACI model runs and aggregates
    their prediction intervals using two standard COMA strategies:
    1.  Weights derived from the AdaHedge algorithm run on interval lengths.
    2.  Weights derived from the Hedge (Exponential Weights) algorithm run on interval lengths.

    For each strategy (AdaHedge/Hedge based weights), it calculates metrics based on
    two interpretations over the n_pred time steps:
    a)  **Majority Vote Set:** Built by `cu.majority_vote(intervals, weights, rho=0.5)`.
        Metrics: `cov_maj` (`cu.covr_fun` of the set), `loss_maj` (`cu.loss_fun` of the set).
    b)  **Expected Value (Mixture):** Treats weights as a probability distribution over
        the base models. Calculates the expected coverage and expected length under
        this distribution.
        Metrics: `exp_cov` (sum w_i * cov_i, NaN coverage counted as 0),
                 `exp_len` (sum w_i * len_i over models with w_i != 0; NaN/+-inf
                 lengths count as infinite, so a zero-weight invalid model no longer
                 turns the expectation into NaN via 0 * inf).

    Args:
        results_list (list): List of result dictionaries from individual model ACI runs.
                             Each dict should contain 'piAdapt' (n_pred x 2 array of intervals).
        y_test (np.ndarray): Array of true target values for the prediction period (length n_pred).
        n_pred (int): Number of prediction steps.
        loss_matrix (np.ndarray): (n_pred x num_models) matrix of interval lengths
                                  (used to calculate COMA weights). Inf indicates invalid intervals.
        raw_ind_coverage (np.ndarray): (n_pred x num_models) matrix of raw coverage indicators (0/1/NaN)
                                       from individual models. Used for `exp_cov`.
        raw_ind_lengths (np.ndarray): (n_pred x num_models) matrix of raw interval lengths
                                      from individual models. Inf indicates invalid. Used for `exp_len`.
        eta_hedge (float): Fixed learning rate for the Hedge algorithm.

    Returns:
        tuple: (coma_results, base_weights_ada, base_weights_hedge)
            coma_results (dict): Dictionary containing results for COMA strategies. Keys:
                'coma_ada': Results using AdaHedge weights ('w').
                'coma_hedge': Results using Hedge weights ('w').
                Each sub-dictionary contains metric arrays (length n_pred):
                    'cov_maj', 'loss_maj', 'exp_cov', 'exp_len'.
            base_weights_ada (np.ndarray): (n_pred x num_models) array of weights 'w' from AdaHedge.
            base_weights_hedge (np.ndarray): (n_pred x num_models) array of weights 'w' from Hedge.
    """
    # 1. Base COMA weights over the whole horizon (loss_matrix assumed valid).
    adahedge_res = cu.adahedge(loss_matrix[:n_pred])
    hedge_res = cu.hedge(loss_matrix[:n_pred], eta=eta_hedge)
    weights_ada_prior = adahedge_res["weights"]
    weights_hedge_prior = hedge_res["weights"]

    # 2. Storage: one length-n_pred array per metric and strategy.
    metric_names = ("cov_maj", "loss_maj", "exp_cov", "exp_len")
    coma_results = {
        strategy: {name: np.zeros(n_pred) for name in metric_names}
        for strategy in ("coma_ada", "coma_hedge")
    }

    # 3. Aggregation loop.
    print("   Performing step-by-step COMA aggregation...")
    for i in tqdm(range(n_pred), desc="COMA Aggregation"):
        intervals_t = np.array([res["piAdapt"][i, :] for res in results_list])
        # Strategy-independent per-step inputs, cleaned once.
        cov_ind_t = np.nan_to_num(raw_ind_coverage[i, :])
        len_ind_t = np.nan_to_num(
            raw_ind_lengths[i, :], nan=np.inf, posinf=np.inf, neginf=np.inf
        )
        y_target = y_test[i]

        for strategy, prior_weights in (
            ("coma_ada", weights_ada_prior[i, :]),
            ("coma_hedge", weights_hedge_prior[i, :]),
        ):
            res_dict = coma_results[strategy]
            pi_m = cu.majority_vote(intervals_t, prior_weights, rho=0.5)
            res_dict["cov_maj"][i] = cu.covr_fun(pi_m, y_target)
            res_dict["loss_maj"][i] = cu.loss_fun(pi_m)

            if not np.isfinite(res_dict["loss_maj"][i]):
                print(f"\nDEBUG: COMA Inf Length at step {i} for strategy '{strategy}'")
                print(f"  Input Intervals (intervals_t):\n{intervals_t}")
                print(f"  Input Weights (prior_weights): {prior_weights}")
                print(f"  Output Majority Vote Set (pi_m): {pi_m}")

            res_dict["exp_cov"][i] = np.sum(prior_weights * cov_ind_t)
            # Skip zero-weight models: 0 * inf would otherwise yield NaN.
            support = prior_weights != 0
            res_dict["exp_len"][i] = np.sum(prior_weights[support] * len_ind_t[support])

    print("COMA Aggregation finished.")
    return coma_results, weights_ada_prior, weights_hedge_prior


def run_adacoma_aggregation(
    results_list,
    y_test,
    n_pred,
    raw_ind_coverage,
    raw_ind_lengths,
    base_weights_ada,
    base_weights_hedge,
    alpha_prime,
    alpha_final,
):
    """
    Performs AdaCOMA aggregation using base COMA weights as priors for AdaMinSE.

    At every step the row of each base-weight matrix is used as the prior ``b`` of
    ``scipy_solve_linear_program`` (costs = that step's interval lengths, with
    NaN/+-inf mapped to ``LARGE_LAMBDA``). The resulting selection weights ``p``
    drive both the majority-vote set and the expected-value metrics.

    Args:
        results_list (list): List of result dictionaries from individual model ACI runs.
                             Each dict should contain 'piAdapt' (n_pred x 2 array).
        y_test (np.ndarray): Array of true target values for the prediction period (length n_pred).
        n_pred (int): Number of prediction steps.
        raw_ind_coverage (np.ndarray): (n_pred x num_models) matrix of raw coverage indicators (0/1/NaN).
                                       Used for `exp_cov` (NaN counted as 0).
        raw_ind_lengths (np.ndarray): (n_pred x num_models) matrix of raw interval lengths (lambdas).
                                      Inf indicates invalid. Used for LP and `exp_len`.
        base_weights_ada (np.ndarray): (n_pred x num_models) COMA weights from AdaHedge (prior b).
        base_weights_hedge (np.ndarray): (n_pred x num_models) COMA weights from Hedge (prior b).
        alpha_prime (float): Target miscoverage used for individual models (input `alpha'` to LP).
        alpha_final (float): Target miscoverage desired after selection (input `alpha` to LP).

    Returns:
        dict: adacoma_results dictionary containing results for AdaCOMA strategies. Keys:
            'adacoma_ada': Results using AdaHedge prior for LP.
            'adacoma_hedge': Results using Hedge prior for LP.
            Each strategy sub-dictionary contains:
                'cov_maj': Coverage indicator for deterministic majority vote set using 'p'.
                'loss_maj': Length of deterministic majority vote set using 'p'.
                'exp_cov': Expected coverage under the selection distribution 'p'.
                'exp_len': Expected length under 'p', summed over models with p_i != 0;
                           NaN/+-inf lengths count as infinite (so a model the LP never
                           selects cannot turn the expectation into NaN via 0 * inf).
                'p_weights': (n_pred x num_models) array of selection weights 'p' from LP.
                (Metric arrays are length n_pred).
            Returns None if there are no models or n_pred <= 0.
    """
    print("\nRunning AdaCOMA Aggregation ...")
    num_models = len(results_list)
    if num_models == 0 or n_pred <= 0:
        return None

    metric_names = ("cov_maj", "loss_maj", "exp_cov", "exp_len")
    adacoma_results = {}
    for strategy in ("adacoma_ada", "adacoma_hedge"):
        entry = {name: np.zeros(n_pred) for name in metric_names}
        entry["p_weights"] = np.zeros((n_pred, num_models))
        adacoma_results[strategy] = entry

    print("   Performing step-by-step AdaCOMA aggregation...")
    for i in tqdm(range(n_pred), desc="AdaCOMA Aggregation"):
        intervals_t = np.array([res["piAdapt"][i, :] for res in results_list])
        # LP costs: invalid (non-finite) lengths become a large finite penalty.
        lambdas_t_finite = np.nan_to_num(
            raw_ind_lengths[i, :],
            nan=LARGE_LAMBDA,
            posinf=LARGE_LAMBDA,
            neginf=LARGE_LAMBDA,
        )
        # Strategy-independent metric inputs, cleaned once per step. Non-finite
        # lengths (including -inf) are invalid -> infinite, matching COMA.
        cov_ind_t = np.nan_to_num(raw_ind_coverage[i, :])
        len_ind_t = np.nan_to_num(
            raw_ind_lengths[i, :], nan=np.inf, posinf=np.inf, neginf=np.inf
        )
        y_target = y_test[i]

        for strategy, prior_weights in (
            ("adacoma_ada", base_weights_ada[i, :]),
            ("adacoma_hedge", base_weights_hedge[i, :]),
        ):
            res_dict = adacoma_results[strategy]
            p_t = scipy_solve_linear_program(
                b=prior_weights,
                lambdas=lambdas_t_finite,
                alpha_prime=alpha_prime,
                alpha=alpha_final,
            )
            res_dict["p_weights"][i, :] = p_t

            # Majority-vote metrics using the LP selection weights.
            pi_m = cu.majority_vote(intervals_t, p_t, rho=0.5)
            res_dict["cov_maj"][i] = cu.covr_fun(pi_m, y_target)
            res_dict["loss_maj"][i] = cu.loss_fun(pi_m)

            # Expected-value metrics under p_t; zero-probability models are
            # skipped so that 0 * inf does not produce NaN.
            res_dict["exp_cov"][i] = np.sum(p_t * cov_ind_t)
            support = p_t != 0
            res_dict["exp_len"][i] = np.sum(p_t[support] * len_ind_t[support])

            if not np.isfinite(res_dict["loss_maj"][i]):
                print(f"\nDEBUG: COMA Inf Length at step {i} for strategy '{strategy}'")
                print(f"  Input Intervals (intervals_t):\n{intervals_t}")
                print(f"  Input Weights (prior_weights): {prior_weights}")
                # p_t, not the prior, is what majority_vote actually received.
                print(f"  LP Selection Weights (p_t): {p_t}")
                print(f"  Output Majority Vote Set (pi_m): {pi_m}")

    print("AdaCOMA Aggregation finished.")
    return adacoma_results


def run_fixed_aggregation(
    results_list,
    y_test,
    n_pred,
    raw_ind_coverage,
    raw_ind_lengths,
    base_weights_ada,
    base_weights_hedge,
    alpha_prime,
    alpha_final,
):
    """
    Performs AdaCOMA aggregation with a fixed *uniform* prior instead of COMA weights.

    This is the AdaCOMA pipeline (``run_adacoma_aggregation``) with the LP prior
    ``b`` set to the uniform distribution 1/num_models at every step. The
    supplied ``base_weights_ada`` / ``base_weights_hedge`` are accepted only for
    signature compatibility with ``run_adacoma_aggregation`` and are NOT used.
    Because both output strategies receive the same prior and costs, the
    'adacoma_ada' and 'adacoma_hedge' entries are expected to coincide; both are
    kept so the result has the same layout as ``run_adacoma_aggregation``.

    Args:
        results_list (list): List of result dictionaries from individual model ACI runs.
                             Each dict should contain 'piAdapt' (n_pred x 2 array).
        y_test (np.ndarray): Array of true target values for the prediction period (length n_pred).
        n_pred (int): Number of prediction steps.
        raw_ind_coverage (np.ndarray): (n_pred x num_models) matrix of raw coverage indicators (0/1/NaN).
                                       Used for `exp_cov`.
        raw_ind_lengths (np.ndarray): (n_pred x num_models) matrix of raw interval lengths (lambdas).
                                      Inf indicates invalid. Used for LP and `exp_len`.
        base_weights_ada (np.ndarray): Ignored (kept for interface compatibility).
        base_weights_hedge (np.ndarray): Ignored (kept for interface compatibility).
        alpha_prime (float): Target miscoverage used for individual models (input `alpha'` to LP).
        alpha_final (float): Target miscoverage desired after selection (input `alpha` to LP).

    Returns:
        dict | None: Same structure as ``run_adacoma_aggregation``: keys
            'adacoma_ada' and 'adacoma_hedge', each holding 'cov_maj', 'loss_maj',
            'exp_cov', 'exp_len' (length n_pred) and 'p_weights'
            (n_pred x num_models), all computed with the uniform prior.
            None if there are no models or n_pred <= 0.
    """
    num_models = len(results_list)
    # Uniform prior row 1/num_models for every step. max(..., 1) / max(..., 0)
    # only guard the degenerate cases, for which the delegate returns None anyway.
    uniform_prior = np.ones((max(n_pred, 0), num_models)) / max(num_models, 1)
    return run_adacoma_aggregation(
        results_list,
        y_test,
        n_pred,
        raw_ind_coverage,
        raw_ind_lengths,
        uniform_prior,
        uniform_prior,
        alpha_prime,
        alpha_final,
    )
