import numpy as np
from functions.value_iter import solve_fixed_point

# Numerical tolerance shared by this module: forwarded to solve_fixed_point as
# its TOL argument and used as the default value_tol in check_single_step.
TOLERANCE = 0.001

def check_single_step(current_mus, next_mu_candidate, base_params, value_tol=TOLERANCE):
    """
    Test whether 'next_mu_candidate' is an acceptable next threshold after
    the levels already in 'current_mus'.

    The solver is run on a copy of 'base_params' whose 'mu_list' entry is
    'current_mus' extended by the candidate. On the row of the solution that
    corresponds to the last level of 'current_mus', the value at the grid
    point nearest the candidate is compared with the value at the grid point
    nearest gamma * candidate + delta * len(current_mus).

    Args:
        current_mus (list): Thresholds accepted so far.
        next_mu_candidate (float): Proposed next threshold.
        base_params (dict): Solver parameters; 'gamma' and 'delta' are read
            here. The dict itself is not modified (a shallow copy is used).
        value_tol (float): Upper bound on the value difference for the step
            to be accepted.

    Returns:
        bool-like: True when (value at candidate) - (value at start point)
        is strictly below 'value_tol'.

    NOTE(review): the solver is always called with the module-level
    TOLERANCE, not with 'value_tol' -- confirm this is intended.
    """
    test_mus = current_mus + [next_mu_candidate]
    trial_params = base_params.copy()
    trial_params['mu_list'] = test_mus

    # The grid always reaches at least 1.5x the candidate (and never less
    # than 5.0), so the candidate itself lies inside it.
    grid_top = max(next_mu_candidate * 1.5, 5.0)
    grid = np.linspace(0, grid_top, 5000)

    # presumably a 2-D array indexed as [level, grid point] -- it is only
    # ever indexed that way below.
    values = solve_fixed_point(trial_params, grid, TOL=TOLERANCE)

    # Number of levels below the candidate, and the row index of the last one.
    levels_below = len(test_mus) - 1
    row = levels_below - 1

    gamma = trial_params['gamma']
    delta = trial_params['delta']
    start_attr = gamma * next_mu_candidate + delta * levels_below

    # Snap both attribute values to their nearest grid points.
    start_col = np.argmin(np.abs(grid - start_attr))
    end_col = np.argmin(np.abs(grid - next_mu_candidate))

    gain = values[row, end_col] - values[row, start_col]
    return gain < value_tol


def find_optimal_thresholds_greedy(params, target_M=None, max_L=50, search_cap=60.0, epsilon=0.05):
    """
    Build the threshold ladder greedily, one level at a time.

    Starting from mu_1 = 0, each iteration bisects the interval
    [prev_mu, max(search_cap, 2 * prev_mu)] for the largest next threshold
    accepted by check_single_step and appends it to the ladder. The loop
    ends when the latest threshold is >= target_M, when no step larger than
    epsilon is accepted, or after max_L iterations.

    Args:
        params (dict): System parameters (beta, gamma, etc.), forwarded
            unchanged to check_single_step.
        target_M (float): The target attribute value to reach.
                          If None, runs until max_L or feasibility limit.
        max_L (int): Safety limit on the number of levels added after the
            initial 0.0, so the result has at most max_L + 1 entries.
        search_cap (float): Initial upper bound of the search interval for
            the next threshold; it grows as the ladder grows.
        epsilon (float): Binary search precision. Must be strictly positive.

    Returns:
        list: The sequence of thresholds [0.0, mu_2, mu_3, ...]. The last
        entry may overshoot target_M.

    Raises:
        ValueError: If epsilon is not strictly positive (including NaN);
            the bisection loop could otherwise never terminate.
    """
    # `not epsilon > 0` (rather than `epsilon <= 0`) also rejects NaN.
    if not epsilon > 0:
        raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}")

    # No target means "run until stuck or until max_L levels were added".
    if target_M is None:
        target_M = float('inf')

    optimal_mus = [0.0]  # Start at mu_1 = 0

    for level in range(1, max_L + 1):
        prev_mu = optimal_mus[-1]

        # Stopping condition 1: reached the target attribute.
        if prev_mu >= target_M:
            break

        low = prev_mu
        high = max(search_cap, prev_mu * 2)  # Dynamic expansion of search space
        best_next_mu = prev_mu

        # Bisection for the largest accepted next step. This assumes that
        # acceptance is monotone: every candidate below an accepted one is
        # accepted as well.
        while (high - low) > epsilon:
            mid = (low + high) / 2

            if check_single_step(optimal_mus, mid, params):
                best_next_mu = mid
                low = mid  # Accepted: try to go higher
            else:
                high = mid  # Rejected: too high, incentive breaks

        # Stopping condition 2: no advance beyond the search precision.
        if best_next_mu <= prev_mu + epsilon:
            print(f"  Converged/Stuck at Level {level}. Max attribute: {prev_mu:.2f}")
            break

        optimal_mus.append(best_next_mu)
        # Keep the search interval ahead of the ladder when it moves fast.
        search_cap = max(search_cap, best_next_mu * 1.5)

    return optimal_mus