import numpy as np

def get_agent_step(l_curr_batch, x_curr_batch, V, X_grid, params):
    """
    Vectorized determination of agent steps for a whole population.

    Agents are processed in groups that share the same current level. For
    each group, every grid point ``x_tilde >= x_current`` is a candidate
    post-improvement attribute. A candidate's cost is the improvement cost
    ``c_plus * (x_tilde - x_current)`` plus the cheapest of the three branch
    costs (demote / stay / promote) at that candidate. Each agent takes the
    candidate with the lowest total cost.

    An agent whose attribute lies above every grid point has no feasible
    grid candidate. Such an agent keeps its current attribute
    (``a_plus == 0``) and only chooses a branch, with the branch costs
    evaluated at its own attribute.

    Args:
        l_curr_batch (np.ndarray): Shape (N,) int, current levels.
        x_curr_batch (np.ndarray): Shape (N,) float, current attributes.
        V (np.ndarray): Shape (L, M), Value matrix.
        X_grid (np.ndarray): Shape (M,), Discretized attribute grid. It is
            passed to ``np.interp`` as the sample points, so it must be
            increasing.
        params (dict): Environment parameters. Keys read: 'beta', 'gamma',
            'c_plus', 'c_minus', 'delta', 'r', 'mu_list'.

    Returns:
        (a_plus, a_minus, l_next, x_next) - all shape (N,)

    Raises:
        IndexError: If any level is outside ``[0, len(params['mu_list']))``.
    """
    gamma = params['gamma']
    c_plus = params['c_plus']
    delta = params['delta']
    mu_list = np.asarray(params['mu_list'])
    n_levels = len(mu_list)

    levels = np.asarray(l_curr_batch)
    x_curr = np.asarray(x_curr_batch, dtype=float)
    grid = np.asarray(X_grid, dtype=float)
    n_agents = len(levels)

    # Negative levels would otherwise wrap around silently through NumPy's
    # negative indexing and produce results for the wrong level.
    if levels.size and (levels.min() < 0 or levels.max() >= n_levels):
        raise IndexError(
            "levels in l_curr_batch must lie in [0, %d)" % n_levels
        )

    # Outputs
    a_plus_out = np.zeros(n_agents)
    a_minus_out = np.zeros(n_agents)
    l_next_out = np.zeros(n_agents, dtype=int)
    x_next_out = np.zeros(n_agents)

    # Branch costs depend only on (level, candidate), so they are computed
    # once per level and shared by every agent currently at that level.
    for level in np.unique(levels):
        level = int(level)
        idx_agents = np.flatnonzero(levels == level)
        x_agents = x_curr[idx_agents]

        # (3, M) stack; rows are 0: demote, 1: stay, 2: promote.
        grid_costs = _branch_costs(level, grid, V, grid, params, mu_list)
        best_branch_cost_grid = grid_costs.min(axis=0)
        best_branch_idx_grid = grid_costs.argmin(axis=0)

        # Total cost = improvement cost + best branch cost, shape (n, M).
        total_costs = (
            c_plus * (grid[None, :] - x_agents[:, None])
            + best_branch_cost_grid[None, :]
        )

        # Agents cannot improve downwards: drop candidates below current x.
        infeasible = grid[None, :] < x_agents[:, None]
        total_costs[infeasible] = np.inf

        best_cand = total_costs.argmin(axis=1)
        x_tilde_star = grid[best_cand]
        branch_choices = best_branch_idx_grid[best_cand]

        # If every candidate is infeasible the row is all-inf and argmin
        # returns index 0, i.e. the lowest grid point, which would imply a
        # negative improvement. Keep those agents at their own attribute
        # and choose the branch from costs evaluated there.
        off_grid = infeasible.all(axis=1)
        if off_grid.any():
            x_off = x_agents[off_grid]
            x_tilde_star[off_grid] = x_off
            branch_choices[off_grid] = _branch_costs(
                level, x_off, V, grid, params, mu_list
            ).argmin(axis=0)

        # 0 -> level - 1, 1 -> level, 2 -> level + 1
        l_next_subset = level + (branch_choices - 1)

        # Gaming closes the gap to the destination level's threshold; the
        # demote branch carries no gaming cost, so its gaming is zero.
        a_minus_subset = np.where(
            branch_choices == 0,
            0.0,
            np.maximum(0, mu_list[l_next_subset] - x_tilde_star),
        )

        # Scatter back to main arrays
        a_plus_out[idx_agents] = x_tilde_star - x_agents
        a_minus_out[idx_agents] = a_minus_subset
        l_next_out[idx_agents] = l_next_subset
        x_next_out[idx_agents] = gamma * x_tilde_star + delta * l_next_subset

    return a_plus_out, a_minus_out, l_next_out, x_next_out


def _branch_costs(level, x_cand, V, X_grid, params, mu_list):
    """Return the (3, K) demote/stay/promote cost stack at candidates x_cand.

    A branch that does not exist (demote from level 0, promote from the top
    level) is filled with ``inf`` so it is never selected by a minimum.
    """
    beta = params['beta']
    gamma = params['gamma']
    c_minus = params['c_minus']
    delta = params['delta']
    r = params['r']

    def future_value(l_dest):
        # np.interp clamps to the edge values of V[l_dest] outside X_grid.
        return np.interp(gamma * x_cand + delta * l_dest, X_grid, V[l_dest])

    unavailable = np.full(len(x_cand), np.inf)

    # Demote: no gaming cost.
    if level > 0:
        l_dest = level - 1
        demote = -r * l_dest + beta * future_value(l_dest)
    else:
        demote = unavailable

    # Stay: gaming covers any shortfall against the current threshold.
    gaming = c_minus * np.maximum(0, mu_list[level] - x_cand)
    stay = gaming - r * level + beta * future_value(level)

    # Promote: gaming covers any shortfall against the next threshold.
    if level < len(mu_list) - 1:
        l_dest = level + 1
        gaming = c_minus * np.maximum(0, mu_list[l_dest] - x_cand)
        promote = gaming - r * l_dest + beta * future_value(l_dest)
    else:
        promote = unavailable

    return np.vstack([demote, stay, promote])