import numpy as np

def T_operator_optimized(W_k_vec, X_grid, params):
    """
    Apply one step of the Bellman operator to the agent's value function.

    For every (level, grid point) pair the operator evaluates three options
    -- demotion (level - 1), stay (same level) and promotion (level + 1) --
    and keeps the element-wise minimum. Target levels are clamped to
    ``[0, L - 1]``, so "demotion" at the lowest level and "promotion" at the
    highest level both point back at that same level.

    Parameters
    ----------
    W_k_vec : array-like, shape (L, D)
        Current value-function iterate; row ``l`` holds the values of level
        ``l`` on ``X_grid``.
    X_grid : array-like, shape (D,)
        Grid of the continuous state. It is passed to ``np.interp`` as the
        sample points, which expects it to be increasing (not checked here).
    params : dict
        Must provide the scalars ``'beta'``, ``'gamma'``, ``'c_plus'``,
        ``'c_minus'``, ``'r'``, ``'delta'`` and the sequence ``'mu_list'``
        (one threshold per level; presumably of length L -- TODO confirm
        against callers).

    Returns
    -------
    numpy.ndarray, shape (L, D)
        The updated value function. No monotonicity constraint is applied
        to the result.
    """
    # Accept lists/tuples as well as ndarrays.
    W_k_vec = np.asarray(W_k_vec)
    X_grid = np.asarray(X_grid)
    n_levels = W_k_vec.shape[0]

    beta = params['beta']
    gamma = params['gamma']
    c_plus = params['c_plus']
    c_minus = params['c_minus']
    r = params['r']
    delta = params['delta']
    mu_list = np.asarray(params['mu_list'])

    rho = r + beta * c_plus * delta
    x_row = X_grid[np.newaxis, :]  # (1, D)

    # Target level reached from each current level, clamped to valid range.
    stay_idx = np.arange(n_levels)
    down_idx = np.maximum(stay_idx - 1, 0)
    up_idx = np.minimum(stay_idx + 1, n_levels - 1)

    def _continuation(target_idx):
        """Interpolate W_k of the target level at the next-period state."""
        next_x = gamma * x_row + delta * target_idx[:, np.newaxis]  # (L, D)
        # np.interp is 1-D only, so interpolate row by row; queries outside
        # the grid take the boundary values of the target row.
        return np.array([
            np.interp(next_x[lvl], X_grid, W_k_vec[target_idx[lvl]])
            for lvl in range(n_levels)
        ])

    def _shortfall_cost(mu):
        """Cost c_minus * max(mu - x, 0) for per-level thresholds `mu`."""
        return c_minus * (mu[:, np.newaxis] - x_row).clip(min=0)

    # Term shared by all three options.
    common_term = (1 - beta * gamma) * c_plus * x_row

    # 1. Demotion: no threshold cost term is charged on this branch.
    val_down = (
        common_term
        - rho * down_idx[:, np.newaxis]
        + beta * _continuation(down_idx)
    )

    # 2. Stay: threshold cost evaluated with the current level's mu.
    val_stay = (
        common_term
        + _shortfall_cost(mu_list)
        - rho * stay_idx[:, np.newaxis]
        + beta * _continuation(stay_idx)
    )

    # 3. Promotion: threshold cost uses the next level's mu; the top level
    #    reuses its own mu because there is no level above it.
    mu_next = np.append(mu_list[1:], mu_list[-1])
    val_up = (
        common_term
        + _shortfall_cost(mu_next)
        - rho * up_idx[:, np.newaxis]
        + beta * _continuation(up_idx)
    )

    # Bellman update: best (lowest) of the three options. The reduction
    # allocates a fresh array, so no defensive copy is needed.
    return np.minimum.reduce([val_down, val_stay, val_up])

def solve_fixed_point(params, X_grid, TOL=1e-5, MAX_ITER=2000, verbose=False):
    """
    Solve for the fixed point W(l, x) by repeated application of the operator.

    Starts from an all-zero array and applies ``T_operator_optimized`` until
    the largest absolute change between successive iterates is below ``TOL``
    or ``MAX_ITER`` applications have been made.

    Parameters
    ----------
    params : dict
        Model parameters forwarded to ``T_operator_optimized``;
        ``params['mu_list']`` determines the number of levels L.
    X_grid : sequence, length D
        Grid of the continuous state, forwarded to the operator.
    TOL : float
        Convergence threshold on the sup-norm difference between iterates.
    MAX_ITER : int
        Maximum number of operator applications. With ``MAX_ITER <= 0`` the
        zero initial guess is returned unchanged.
    verbose : bool
        If true, print convergence / warning messages.

    Returns
    -------
    numpy.ndarray, shape (L, D)
        The converged iterate; the last NaN-free iterate if the operator
        produced NaNs; or the last iterate if ``MAX_ITER`` was exhausted.

    Raises
    ------
    ValueError
        If ``params['mu_list']`` is empty.
    """
    L = len(params['mu_list'])

    # L must be at least 1
    if L < 1:
        raise ValueError("mu_list must have at least 1 level (0.0).")

    W_k = np.zeros((L, len(X_grid)))

    # Bound before the loop so the non-convergence warning below is valid
    # even when the loop body never runs (MAX_ITER <= 0).
    diff = np.inf

    for i in range(MAX_ITER):
        W_kplus1 = T_operator_optimized(W_k, X_grid, params)

        # A NaN iterate is unusable; fall back to the last valid one.
        if np.isnan(W_kplus1).any():
            if verbose:
                print("Warning: NaN detected in Value Iteration.")
            return W_k

        diff = np.max(np.abs(W_kplus1 - W_k))

        if diff < TOL:
            if verbose:
                # i is zero-based; i + 1 operator applications were made.
                print(f"Converged in {i + 1} iterations.")
            return W_kplus1
        W_k = W_kplus1

    if verbose:
        print(f"Warning: VI did not converge within {MAX_ITER} iterations (diff={diff:.4f}).")
    return W_k