import numpy as np

def planning_function(processed_global_state):
    """
    Pick one task label per agent from the processed battle state.

    Args:
        processed_global_state: A 4-tuple unpacked as
            (move_feats, enemy_info, ally_info, own_info). Each element is
            indexed by agent id; ally_info is unpacked but not used here.

    Returns:
        dict: Maps every agent id in own_info to one of:
            'none'   - own_info[agent][5] is falsy (agent treated as dead),
            'attack' - nearest flagged enemy is at distance 6..9 inclusive,
            'move'   - nearest flagged enemy is closer than 6, or no flagged
                       enemy is within 9 and move_feats[agent] is truthy,
            'stop'   - no flagged enemy is within 9 and move_feats[agent]
                       is falsy.
    """
    move_feats, enemy_info, ally_info, own_info = processed_global_state
    assignments = {}

    for unit, status in own_info.items():
        # Slot 5 of the agent's own record is the liveness flag.
        if not status[5]:
            assignments[unit] = 'none'
            continue

        # Scan for the smallest distance among enemies whose [4][0] flag is 1.
        nearest = float('inf')
        for foe in enemy_info[unit].values():
            if foe[4][0] != 1:
                continue
            gap = foe[1][0]
            if gap < nearest:
                nearest = gap

        # Map the distance band to a task.
        if nearest < 6:
            task = 'move'
        elif nearest <= 9:
            task = 'attack'
        else:
            # Nothing within 9: move if any move feature is available.
            task = 'move' if move_feats[unit] else 'stop'

        assignments[unit] = task

    return assignments

def compute_reward(processed_state, llm_tasks, tasks):
    """
    Calculate a per-agent reward from the assigned and performed tasks.

    Args:
        processed_state: A 4-tuple unpacked as
            (move_feats, enemy_info, ally_info, own_info). Each element is
            indexed by agent id; move_feats is unpacked but not used here.
        llm_tasks (dict): Task assigned to each agent, keyed by agent id.
        tasks (dict): Task each agent actually performed, keyed by agent id.

    Returns:
        dict: Reward for every agent id in own_info. An agent whose
        own_info[agent][5] flag is falsy receives a flat -10. Otherwise the
        reward is the sum of:
            +1 if the performed task equals the assigned task,
            +2 per enemy whose [0][0] flag is 1, when the performed task
               is 'attack',
            -1 per enemy whose [1][0] distance is below 3,
            +1 if the first listed ally's [1][0] distance is in 3..6.

    Raises:
        KeyError: If an agent id from own_info is missing from llm_tasks,
            tasks, enemy_info or ally_info.
    """
    move_feats, enemy_info, ally_info, own_info = processed_state
    reward = {}

    for agent_id, own_data in own_info.items():
        if not own_data[5]:  # If agent is dead
            reward[agent_id] = -10
            continue

        agent_reward = 0
        performed_task = tasks[agent_id]

        # Reward for following LLM task
        if llm_tasks[agent_id] == performed_task:
            agent_reward += 1

        enemies = enemy_info[agent_id].values()

        # Reward for attacking while enemies are flagged as in range
        enemies_in_range = sum(1 for enemy_data in enemies if enemy_data[0][0] == 1)
        if performed_task == 'attack' and enemies_in_range > 0:
            agent_reward += 2 * enemies_in_range

        # Penalty for being too close to enemies.
        # NOTE(review): this counts every listed enemy regardless of its
        # flags - confirm that entries for dead/unseen enemies cannot carry
        # a small distance value.
        close_enemies = sum(1 for enemy_data in enemies if enemy_data[1][0] < 3)
        agent_reward -= close_enemies

        # Reward for coordinated positioning, judged against the first listed
        # ally only. An agent with no allies simply gets no bonus instead of
        # raising IndexError.
        allies = ally_info[agent_id]
        if allies:
            first_ally = next(iter(allies.values()))
            if 3 <= first_ally[1][0] <= 6:
                agent_reward += 1

        reward[agent_id] = agent_reward

    return reward