import numpy as np

def planning_function(processed_global_state):
    """
    Choose a high-level task for each agent based on the current battle state.

    Args:
        processed_global_state: A 4-tuple
            (available_move_actions, enemy_info, ally_info, own_info).
            Only ``enemy_info`` and ``own_info`` are read here:
              * ``own_info[agent_id][0]`` is used as the agent's health and
                ``own_info[agent_id][5]`` as its alive flag (falsy = dead).
              * ``enemy_info[agent_id][enemy_id][1]`` is used as the distance
                from the agent to that enemy, and ``[5] > 0`` as "enemy is
                alive".

    Returns:
        dict: Maps every agent id in ``own_info`` to one of
            'None'   - the agent is dead,
            'Stop'   - no living enemy is known to the agent,
            'Attack' - closest living enemy is within attack range and the
                       agent's health is above the retreat threshold,
            'Move'   - closest living enemy is out of range (approach), or it
                       is in range but the agent's health is low (retreat).
    """
    attack_range = 6       # max distance at which an enemy counts as in range
    retreat_health = 0.3   # at or below this health the agent moves, not attacks

    _move_feats, enemy_info, _ally_info, own_info = processed_global_state
    llm_tasks = {}

    for agent_id, own_data in own_info.items():
        if not own_data[5]:  # agent is dead
            llm_tasks[agent_id] = 'None'
            continue

        # Distance to the closest living enemy; stays at +inf when there is
        # none.  Only the distance feeds the decision, so no position vectors
        # are built (normalising an unused direction vector used to fail for
        # integer coordinates and yield NaN for coincident positions).
        closest_enemy_dist = float('inf')
        for enemy_data in enemy_info[agent_id].values():
            if enemy_data[5] > 0 and enemy_data[1] < closest_enemy_dist:
                closest_enemy_dist = enemy_data[1]

        if closest_enemy_dist == float('inf'):
            llm_tasks[agent_id] = 'Stop'
        elif closest_enemy_dist > attack_range:
            llm_tasks[agent_id] = 'Move'    # approach the enemy
        elif own_data[0] > retreat_health:
            llm_tasks[agent_id] = 'Attack'
        else:
            llm_tasks[agent_id] = 'Move'    # low health: move away

    return llm_tasks

def compute_reward(processed_state, llm_tasks, tasks):
    """
    Calculate a shaped reward for each agent from the tasks and battle state.

    Args:
        processed_state: A 4-tuple
            (available_move_actions, enemy_info, ally_info, own_info).
            ``own_info[agent_id][0]`` is used as health and ``[5]`` as the
            alive flag; ``enemy_info[agent_id][enemy_id][1]`` as distance and
            ``[5]`` as enemy health; ``ally_info[agent_id][ally_id][1]`` as
            distance to that ally.
        llm_tasks (dict): Task assigned to each agent by the planner.
        tasks (dict): Task actually performed by each agent.

    Returns:
        dict: Reward for each agent id in ``own_info``. Dead agents get 0.
        Living agents get the sum of:
            +0.5  if the performed task equals the assigned task,
            +0.1 * (64 - summed enemy health),
            -0.5 * (1 - own health),
            +0.2  if the closest living enemy is 3..6 away,
            +0.1  if the first listed ally is 2..3 away.

    Raises:
        KeyError: If a living agent is missing from ``llm_tasks``, ``tasks``,
            ``enemy_info`` or ``ally_info``.
    """
    follow_bonus = 0.5
    # Baseline the summed enemy health is subtracted from.
    # NOTE(review): presumably the maximum total enemy health - confirm.
    enemy_health_baseline = 64
    damage_dealt_weight = 0.1
    damage_taken_weight = 0.5
    enemy_band, enemy_band_bonus = (3, 6), 0.2
    ally_band, ally_band_bonus = (2, 3), 0.1

    _move_feats, enemy_info, ally_info, own_info = processed_state
    reward = {}

    for agent_id, own_data in own_info.items():
        if not own_data[5]:  # agent is dead
            reward[agent_id] = 0
            continue

        enemies = enemy_info[agent_id].values()
        agent_reward = 0

        # Reward for carrying out the task the planner assigned.
        if llm_tasks[agent_id] == tasks[agent_id]:
            agent_reward += follow_bonus

        # Reward for damage dealt: the lower the remaining enemy health,
        # the higher the reward.
        total_enemy_health = sum(enemy_data[5] for enemy_data in enemies)
        agent_reward += (enemy_health_baseline - total_enemy_health) * damage_dealt_weight

        # Penalty for damage taken.
        agent_reward -= (1 - own_data[0]) * damage_taken_weight

        # Reward for keeping the closest enemy inside the preferred band.
        # Dead enemies are ignored, matching planning_function's alive check,
        # and an agent with no living enemy simply gets no bonus instead of
        # crashing on min() of an empty sequence.
        closest_enemy_dist = min(
            (enemy_data[1] for enemy_data in enemies if enemy_data[5] > 0),
            default=None,
        )
        if (closest_enemy_dist is not None
                and enemy_band[0] <= closest_enemy_dist <= enemy_band[1]):
            agent_reward += enemy_band_bonus

        # Coordination reward based on the first listed ally only; skipped
        # when the agent has no allies.
        allies = ally_info[agent_id]
        if allies:
            ally_dist = next(iter(allies.values()))[1]
            if ally_band[0] <= ally_dist <= ally_band[1]:
                agent_reward += ally_band_bonus

        reward[agent_id] = agent_reward

    return reward