import jax
import jax.numpy as jnp
import numpy as np
from jaxopt import OSQP
from scipy.optimize import linprog
from jaxopt import projection
from jaxopt.projection import projection_simplex


@jax.jit
def solve_min_dot_simplex(q, b, eta):
    """
    Solves:
       min_a <a, q>
       subject to a in simplex, 0 <= a_i <= b_i * exp(eta), sum_i a_i = 1.
    Assumes sum_i b_i * exp(eta) >= 1 (feasibility).

    The solution is greedy: coordinates are visited from the smallest q_i to
    the largest, and each one receives as much of the remaining unit mass as
    its cap allows.

    Args:
        q: 1-D array of linear costs.
        b: 1-D array, same length as q; the cap on a_i is b_i * exp(eta).
        eta: scalar log-scale applied to every cap.

    Returns:
        1-D array a, in the original coordinate order. If the caps sum to
        less than 1 the leftover mass is silently discarded, so a sums to
        less than 1 in that case.
    """
    d = q.shape[0]

    # Cheapest coordinates first.
    order = jnp.argsort(q)
    caps_in_order = b[order] * jnp.exp(eta)

    def take_from_budget(remaining, cap):
        # Give this coordinate everything that is left, up to its cap.
        grant = jnp.minimum(remaining, cap)
        return remaining - grant, grant

    # The carry is the mass that is still unassigned; it starts at 1.
    _, granted_in_order = jax.lax.scan(take_from_budget, 1.0, caps_in_order)

    # Undo the sort with a single scatter.
    return jnp.zeros(d).at[order].set(granted_in_order)


def scipy_solve_linear_program(b, lambdas, alpha_prime, alpha):
    """
    Solves the linear program with SciPy's HiGHS backend.

    The decision vector is x = [p_0..p_{k-1}, s_0..s_{k-1}, tau, eta], with
    every entry constrained to be non-negative:

        min   sum_i lambdas_i * p_i
        s.t.  sum_i p_i = 1
              p_i <= eta * b_i + s_i        for every i
              sum_i s_i <= tau
              eta * alpha_prime + tau <= alpha

    Args:
        b (list): List of b_i values (prior distribution on the simplex)
        lambdas (list): List of lambda_i values (objective coefficients of p_i).
        alpha_prime (float): Value of alpha': the current coverage
        alpha (float): Value of alpha: the required coverage

    Returns:
        numpy.ndarray or None: The optimal values of p, or None if the solver
        does not report success (e.g. the problem is infeasible).

    Raises:
        ValueError: If b and lambdas have different lengths.
    """
    k = len(b)
    if len(lambdas) != k:
        raise ValueError("The length of lambdas must be equal to the length of b.")

    n_vars = 2 * k + 2
    tau_col = 2 * k
    eta_col = 2 * k + 1

    # Objective: only the p block carries a cost.
    c = np.zeros(n_vars)
    c[:k] = np.asarray(lambdas, dtype=float)

    # Equality constraint: sum(p_i) = 1
    A_eq = np.zeros((1, n_vars))
    A_eq[0, :k] = 1.0
    b_eq = np.array([1.0])

    # Inequality constraints: k cap rows, one budget row, one coverage row.
    A_ub = np.zeros((k + 2, n_vars))
    b_ub = np.zeros(k + 2)

    # p_i <= eta * b_i + s_i  =>  p_i - s_i - eta * b_i <= 0
    rows = np.arange(k)
    A_ub[rows, rows] = 1.0
    A_ub[rows, k + rows] = -1.0
    A_ub[:k, eta_col] = -np.asarray(b, dtype=float)

    # sum(s_i) <= tau  =>  sum(s_i) - tau <= 0
    A_ub[k, k : 2 * k] = 1.0
    A_ub[k, tau_col] = -1.0

    # eta * alpha_prime + tau <= alpha
    A_ub[k + 1, tau_col] = 1.0
    A_ub[k + 1, eta_col] = alpha_prime
    b_ub[k + 1] = alpha

    # Bounds for variables: p >= 0, s >= 0, tau >= 0, eta >= 0
    bounds = [(0, None)] * n_vars

    result = linprog(
        c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method="highs"
    )

    # linprog leaves result.x as None when no solution was found; slicing it
    # would raise instead of honouring the documented None return.
    if not result.success or result.x is None:
        return None

    return result.x[:k]  # The optimal values of p




def project_onto_feasible_region(p, s, tau, eta, k, b, alpha_prime, alpha):
    """
    Heuristically repairs a candidate (p, s, tau, eta) using Jaxopt's
    projection utilities.

    Every step is applied unconditionally (no data-dependent branching), so
    the function can be traced by jax.jit.

    Args:
        p (jnp.ndarray): Array of p_i values.
        s (jnp.ndarray): Array of s_i values.
        tau (jnp.ndarray): Ignored; tau is recomputed from eta, alpha and
            alpha_prime.
        eta (jnp.ndarray): Value of eta.
        k (int): Number of b_i and lambda_i values (not used in the body).
        b (jnp.ndarray): Array of b_i values.
        alpha_prime (float): Value of alpha'.
        alpha (float): Value of alpha.

    Returns:
        tuple: (p, s, tau, eta) after the repair steps below.
    """
    # eta >= 0 is enforced first because tau and the caps on p depend on it.
    eta = projection.projection_non_negative(eta)

    # Largest tau allowed by eta * alpha_prime + tau <= alpha, floored at 0.
    tau = projection.projection_non_negative(alpha - eta * alpha_prime)

    # s_i >= 0, then projection onto the L1 ball (projection_l1_ball's
    # default radius), then rescaled by tau.
    # NOTE(review): the rescaling by tau is applied even when s is already
    # within the budget -- confirm this is intended.
    unit_slack = projection.projection_l1_ball(projection.projection_non_negative(s))
    s = unit_slack * tau

    # p_i >= 0 and p_i <= eta * b_i + s_i, followed by the simplex projection.
    # NOTE(review): the caps are not re-checked after the simplex projection.
    upper = eta * b + s
    capped = jnp.minimum(projection.projection_non_negative(p), upper)
    p = projection_simplex(capped)

    return p, s, tau, eta


def jax_solve_linear_program(k, b, lambdas, alpha_prime, alpha):
    """
    Solves the linear program with Jaxopt's OSQP solver, then repairs the
    result with project_onto_feasible_region.

    The decision vector is x = [p_0..p_{k-1}, s_0..s_{k-1}, tau, eta] and the
    LP is posed as a QP with a zero quadratic term:

        min   sum_i lambdas_i * p_i
        s.t.  sum_i p_i = 1
              p_i <= eta * b_i + s_i        for every i
              sum_i s_i <= tau
              eta * alpha_prime + tau <= alpha
              p_i >= 0, s_i >= 0, eta >= 0

    tau >= 0 is not added as its own row; it follows from s_i >= 0 together
    with sum_i s_i <= tau.

    Args:
        k (int): Number of b_i and lambda_i values (static argument).
        b (jnp.ndarray): Array of b_i values.
        lambdas (jnp.ndarray): Array of lambda_i values.
        alpha_prime (float): Value of alpha'.
        alpha (float): Value of alpha.

    Returns:
        tuple: (p, s, tau, eta) as returned by project_onto_feasible_region
        applied to the solver's primal solution.
    """
    n_vars = 2 * k + 2
    tau_col = 2 * k
    eta_col = 2 * k + 1

    # Row layout of the inequality system A_ub @ x <= b_ub:
    #   [0, k)            p_i - s_i - eta * b_i <= 0
    #   k                 sum(s_i) - tau <= 0
    #   k + 1             tau + eta * alpha_prime <= alpha
    #   [k + 2, 2k + 2)   -p_i <= 0
    #   [2k + 2, 3k + 2)  -s_i <= 0
    #   3k + 2            -eta <= 0
    # That is 3k + 3 rows. With only 3k + 2 rows the last write would be out
    # of bounds, and JAX drops out-of-bounds scatter updates silently, so the
    # eta >= 0 constraint would never reach the solver.
    n_ineq = 3 * k + 3
    budget_row = k
    coverage_row = k + 1
    p_sign_start = k + 2
    s_sign_start = 2 * k + 2
    eta_sign_row = 3 * k + 2

    # Objective function: Convert linear objective to quadratic form (P=0, q=c)
    P = jnp.zeros((n_vars, n_vars))
    q = jnp.concatenate([lambdas, jnp.zeros(k + 2)])

    # Equality constraint: sum(p_i) = 1
    A_eq = jnp.concatenate([jnp.ones(k), jnp.zeros(k + 2)])[None, :]
    b_eq = jnp.array([1.0])

    # Inequality constraints
    A_ub = jnp.zeros((n_ineq, n_vars))
    b_ub = jnp.zeros(n_ineq)

    # Constraint: p_i <= eta * b_i + s_i  =>  p_i - s_i - eta * b_i <= 0
    idx = jnp.arange(k)
    A_ub = A_ub.at[idx, idx].set(1.0)  # p_i
    A_ub = A_ub.at[idx, k + idx].set(-1.0)  # -s_i
    A_ub = A_ub.at[idx, eta_col].set(-b)  # -eta * b_i

    # Constraint: sum(s_i) <= tau
    A_ub = A_ub.at[budget_row, k : 2 * k].set(1.0)
    A_ub = A_ub.at[budget_row, tau_col].set(-1.0)

    # Constraint: eta * alpha_prime + tau <= alpha
    A_ub = A_ub.at[coverage_row, tau_col].set(1.0)
    A_ub = A_ub.at[coverage_row, eta_col].set(alpha_prime)
    b_ub = b_ub.at[coverage_row].set(alpha)

    # Constraint: p_i >= 0
    A_ub = A_ub.at[p_sign_start : p_sign_start + k, :k].set(-jnp.eye(k))

    # Constraint: s_i >= 0
    A_ub = A_ub.at[s_sign_start : s_sign_start + k, k : 2 * k].set(-jnp.eye(k))

    # Constraint: eta >= 0
    A_ub = A_ub.at[eta_sign_row, eta_col].set(-1.0)

    # Solve using OSQP
    solver = OSQP(maxiter=10_000)
    sol_state = solver.run(
        params_obj=(P, q), params_eq=(A_eq, b_eq), params_ineq=(A_ub, b_ub)
    )

    # params is a (primal, dual_eq, dual_ineq) solution; only the primal
    # vector is used.
    x = sol_state.params.primal
    p, s, tau, eta = x[:k], x[k : 2 * k], x[tau_col], x[eta_col]

    # Project onto the feasible region
    p, s, tau, eta = project_onto_feasible_region(
        p, s, tau, eta, k, b, alpha_prime, alpha
    )

    return p, s, tau, eta
