import torch
from torch.distributions import Categorical

def pick_move(prev, mask, require_prob, pheromone, heuristic, alpha, beta, epsilon=0.1):
    '''
    Choose the next node for every ant using a per-ant epsilon-greedy rule.

    Each ant independently, with probability ``epsilon``, moves greedily to the
    feasible node with the largest masked heuristic value. Otherwise it samples
    a node from the categorical distribution proportional to
    ``pheromone[prev]**alpha * heuristic[prev]**beta * mask``.

    Args:
        prev: LongTensor of shape (n_ants,). The current node of each ant; it
            selects rows of ``pheromone`` and ``heuristic``.
        mask: tensor of shape (n_ants, p_size). Non-zero entries mark feasible
            moves (bool or 0/1 float both work).
        require_prob: if True, also return log-probabilities of the chosen
            actions.
        pheromone, heuristic: matrices row-indexed by ``prev``, presumably of
            shape (p_size, p_size). TODO confirm against callers.
        alpha, beta: exponents applied to the pheromone and heuristic values.
        epsilon: per-ant probability of taking the greedy move. The default
            0.1 matches the previously hard-coded value.

    Returns:
        A tuple ``(actions, log_probs)``. ``actions`` is a LongTensor of shape
        (n_ants,). ``log_probs`` has shape (n_ants,), or is None when
        ``require_prob`` is False.

        NOTE(review): ``log_probs`` is evaluated under the sampling
        Categorical for every ant, including greedy ones. It is therefore not
        the log-probability of the epsilon-greedy behaviour policy.
    '''
    n_ants = prev.shape[0]
    pheromone_vals = pheromone[prev]  # shape: (n_ants, p_size)
    heuristic_vals = heuristic[prev]  # shape: (n_ants, p_size)

    valid = mask != 0  # feasible moves, works for bool and float masks
    weights = (pheromone_vals ** alpha) * (heuristic_vals ** beta) * mask

    # Add a tiny constant so feasible moves whose weights underflowed can still
    # be sampled. Apply it only to feasible moves: smoothing every entry let
    # masked-out nodes be drawn. Rows with no feasible move at all keep the
    # previous behaviour (uniform over all nodes) instead of making
    # Categorical raise on an all-zero row.
    has_valid = valid.any(dim=1, keepdim=True)
    probabilities = torch.where(
        valid | ~has_valid, weights + 1e-10, torch.zeros_like(weights)
    )

    greedy_ants = torch.rand(n_ants, device=prev.device) < epsilon
    actions = torch.zeros(n_ants, dtype=torch.long, device=prev.device)

    if greedy_ants.any():
        # Put -inf on infeasible entries so argmax never lands on a masked
        # node, even when every feasible heuristic value is zero.
        greedy_values = (heuristic_vals * mask).masked_fill(~valid, float('-inf'))
        greedy_actions = greedy_values.argmax(dim=1)
        actions[greedy_ants] = greedy_actions[greedy_ants]

    explore_ants = ~greedy_ants
    dist = None
    if explore_ants.any():
        dist = Categorical(probabilities)
        sampled_actions = dist.sample()
        actions[explore_ants] = sampled_actions[explore_ants]

    if not require_prob:
        return actions, None
    if dist is None:  # every ant was greedy, so no distribution was built yet
        dist = Categorical(probabilities)
    return actions, dist.log_prob(actions)


