import numpy as np
from scipy.optimize import linprog


def solve_min_dot_simplex(lambdas, b, eta):
    """
    Greedily solve a capped linear program over the probability simplex.

    Solves:
       min_a <a, lambdas>
       subject to 0 <= a_i <= b_i * exp(eta), sum_i a_i = 1.

    One unit of mass is handed out to the coordinates in ascending order of
    ``lambdas``; each coordinate receives as much as its cap allows until the
    mass is used up.

    Assumes sum_i b_i * exp(eta) >= 1 (feasibility). This is not checked: if
    the caps cannot absorb the whole unit of mass, the leftover is simply not
    allocated and the result sums to less than 1.

    Args:
        lambdas: 1-D array-like of objective coefficients (lists are accepted).
        b: Sequence or array of per-coordinate cap bases, indexed by the same
            positions as ``lambdas``.
        eta (float): Log-scale factor; every cap is multiplied by ``exp(eta)``.

    Returns:
        numpy.ndarray: The allocation ``a``, in the original coordinate order.
    """
    lambdas = np.asarray(lambdas)
    d = lambdas.shape[0]

    # exp(eta) does not depend on the coordinate, so compute it once.
    cap_scale = np.exp(eta)

    a = np.zeros(d)
    remaining_mass = 1.0

    # Visit coordinates from the cheapest lambda to the most expensive one and
    # write each allocation straight into its original position.
    for i in np.argsort(lambdas):
        allocation = np.minimum(remaining_mass, b[i] * cap_scale)
        a[i] = allocation
        remaining_mass -= allocation

    return a


def scipy_solve_linear_program(b, lambdas, alpha_prime, alpha):
    """
    Solve the following linear program with ``scipy.optimize.linprog`` (HiGHS).

    With variables ``p`` (length k), ``s`` (length k), ``tau`` and ``eta``:

        minimize    sum_i lambdas_i * p_i
        subject to  sum_i p_i = 1
                    p_i <= eta * b_i + s_i      for every i
                    sum_i s_i <= tau
                    eta * alpha_prime + tau <= alpha
                    p, s, tau, eta >= 0

    Args:
        b (list): List of b_i values (prior distribution on the simplex)
        lambdas (list): List of lambda_i values (objective coefficients of p).
        alpha_prime (float): Value of alpha': the current coverage
        alpha (float): Value of alpha: the required coverage

    Returns:
        numpy.ndarray: The optimal values of p (length k). If the solver does
        not report success (e.g. the problem is infeasible), a diagnostic is
        printed and a float-array copy of ``b`` is returned as a fallback.

    Raises:
        ValueError: If ``lambdas`` and ``b`` have different lengths.
    """
    k = len(b)
    if len(lambdas) != k:
        raise ValueError("The length of lambdas must be equal to the length of b.")

    # Variable layout: [p_0..p_{k-1}, s_0..s_{k-1}, tau, eta]
    n_vars = 2 * k + 2
    tau_col = 2 * k
    eta_col = 2 * k + 1

    # Objective: only the p variables carry a cost.
    c = np.zeros(n_vars)
    c[:k] = np.asarray(lambdas, dtype=float)

    # Equality constraint: sum(p_i) = 1
    A_eq = np.zeros((1, n_vars))
    A_eq[0, :k] = 1
    b_eq = np.array([1])

    # Inequality constraints: k per-coordinate rows followed by 2 budget rows.
    A_ub = np.zeros((k + 2, n_vars))
    b_ub = np.zeros(k + 2)

    # p_i <= eta * b_i + s_i  =>  p_i - s_i - eta * b_i <= 0
    identity = np.eye(k)
    A_ub[:k, :k] = identity
    A_ub[:k, k:tau_col] = -identity
    A_ub[:k, eta_col] = -np.asarray(b, dtype=float)

    # sum(s_i) <= tau  =>  sum(s_i) - tau <= 0
    A_ub[k, k:tau_col] = 1
    A_ub[k, tau_col] = -1

    # eta * alpha_prime + tau <= alpha
    A_ub[k + 1, tau_col] = 1
    A_ub[k + 1, eta_col] = alpha_prime
    b_ub[k + 1] = alpha

    # Bounds for variables: p >= 0, s >= 0, tau >= 0, eta >= 0
    bounds = [(0, None)] * n_vars

    result = linprog(
        c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method="highs"
    )

    # Truthiness test rather than `is False`, so a falsy non-bool flag is not
    # silently treated as success.
    if not result.success:
        print("Optimization failed:", result.message)
        print(f"input b{b}, lambdas{lambdas}, alpha_prime{alpha_prime}, alpha{alpha}")
        # Fall back to the prior; return a fresh array so the return type
        # matches the success path and the caller's input is not aliased.
        return np.array(b, dtype=float)
    return result.x[:k]  # The optimal values of p
