import numpy as np

def planning_function(processed_state):
    """
    Choose one task per agent by staffing each food with its closest agents.

    Foods are handled in the iteration order of ``food_info``. For each food
    that is still present (its entry is not ``None``), agents are considered
    from nearest to farthest (Euclidean distance, ties kept in ``agents_info``
    order) and recruited until their summed levels reach the food's level or
    no free agent is left. A recruited agent at distance <= 1 is told to pick
    up; any other recruited agent is told to move towards that food.

    Args:
        processed_state: A ``(food_info, agents_info)`` pair.
            ``food_info`` maps a food id to ``(position, level)`` or ``None``;
            ``agents_info`` maps an agent id to ``(position, level)``.

    Returns:
        dict: Task string per agent id -- ``"Pickup"``,
        ``"Target food <x>"`` (``<x>`` is the last element of the food id),
        or ``"No op"`` for agents that were not recruited for any food.
    """
    food_info, agents_info = processed_state

    # Pass 1: for every available food, rank all agents by distance to it.
    ranking_by_food = {}
    for food_id, food_data in food_info.items():
        if food_data is None:
            continue
        food_pos, _food_level = food_data
        ranking = [
            (agent_id, np.linalg.norm(np.array(food_pos) - np.array(agent_pos)))
            for agent_id, (agent_pos, _agent_level) in agents_info.items()
        ]
        # list.sort is stable, so equally distant agents keep their dict order.
        ranking.sort(key=lambda entry: entry[1])
        ranking_by_food[food_id] = ranking

    # Pass 2: recruit the closest free agents for each food in turn.
    # An agent already present in llm_tasks has been claimed by an earlier food.
    llm_tasks = {}
    for food_id, ranking in ranking_by_food.items():
        required_level = food_info[food_id][1]
        committed_level = 0
        for agent_id, dist in ranking:
            if agent_id in llm_tasks:
                continue
            committed_level += agents_info[agent_id][1]
            # Distance <= 1 is treated as "next to the food".
            llm_tasks[agent_id] = (
                "Pickup" if dist <= 1 else f"Target food {food_id[-1]}"
            )
            if committed_level >= required_level:
                break

    # Everyone who was not recruited idles.
    for agent_id in agents_info:
        llm_tasks.setdefault(agent_id, "No op")

    return llm_tasks

def compute_reward(processed_state, llm_actions, actions):
    """
    Calculate shaped per-agent rewards for one step.

    Three independent terms are summed for every agent in ``agents_info``:

    * +0.002 if the executed action is one of the actions suggested for it.
    * For each available food, agents that executed action 5 (pickup) at
      distance <= 1 from it get +0.007 each when their summed levels reach
      the food's level, and -0.003 each otherwise.
    * +0.001 for each available food that a movement action (1-4) brings the
      agent strictly closer to (Euclidean distance, computed from the
      pre-move position in ``agents_info``).

    Agent ids that appear in ``llm_actions`` or ``actions`` but are unknown to
    ``agents_info`` (or, for suggestions, have no executed action) are ignored
    instead of raising ``KeyError``.

    Args:
        processed_state: A ``(food_info, agents_info)`` pair. ``food_info``
            maps a food id to ``(position, level)`` or ``None``;
            ``agents_info`` maps an agent id to ``(position, level)``.
        llm_actions (dict): Agent id -> collection of suggested action codes.
        actions (dict): Agent id -> action code actually executed.

    Returns:
        dict: Reward per agent id in ``agents_info`` (``0`` when no term applies).
    """
    food_info, agents_info = processed_state
    reward = {agent_id: 0 for agent_id in agents_info}

    def _distance(pos_a, pos_b):
        # Euclidean distance between two grid positions.
        return np.linalg.norm(np.array(pos_a) - np.array(pos_b))

    # Reward for following LLM suggestions. Suggestions for agents that did
    # not act, or that are not part of this state, carry no reward.
    for agent_id, llm_action in llm_actions.items():
        if agent_id not in actions or agent_id not in reward:
            continue
        if actions[agent_id] in llm_action:
            reward[agent_id] += 0.002

    # Both remaining terms depend on food that is still on the board.
    active_foods = [
        food_data for food_data in food_info.values() if food_data is not None
    ]
    if not active_foods:
        return reward

    # Reward (or penalty) for pickup attempts next to a food.
    pickup_attempts = [
        agent_id
        for agent_id, action in actions.items()
        if action == 5 and agent_id in agents_info
    ]
    if pickup_attempts:
        for food_pos, food_level in active_foods:
            nearby_agents = [
                agent_id
                for agent_id in pickup_attempts
                if _distance(food_pos, agents_info[agent_id][0]) <= 1
            ]
            total_level = sum(agents_info[agent_id][1] for agent_id in nearby_agents)
            # Enough combined level -> bonus; otherwise penalize the
            # uncoordinated attempt.
            delta = 0.007 if total_level >= food_level else -0.003
            for agent_id in nearby_agents:
                reward[agent_id] += delta

    # Small reward for moving towards food.
    for agent_id, action in actions.items():
        if agent_id not in agents_info or action not in (1, 2, 3, 4):
            continue
        agent_pos = agents_info[agent_id][0]

        # The post-move position does not depend on the food, so compute it
        # once per agent: actions 1/2 change index 0, actions 3/4 index 1.
        new_pos = list(agent_pos)
        if action == 1:
            new_pos[0] -= 1
        elif action == 2:
            new_pos[0] += 1
        elif action == 3:
            new_pos[1] -= 1
        else:
            new_pos[1] += 1

        for food_pos, _food_level in active_foods:
            if _distance(new_pos, food_pos) < _distance(agent_pos, food_pos):
                reward[agent_id] += 0.001

    return reward