import numpy as np

def planning_function(processed_global_state, *, sight_range=9, attack_range=6):
    """
    Choose a high-level task for every agent from the current battle state.

    Only the enemy stored under the key 'enemy_0' is considered. For each
    agent the decision is:

    * agent flagged as dead                         -> 'none'
    * enemy farther than ``sight_range``            -> 'move'
    * ``attack_range`` < distance <= ``sight_range``:
        own health > 0.5 and enemy health > 0.25    -> 'move'
        otherwise                                   -> 'stop'
    * distance <= ``attack_range``:
        own health > 0.3                            -> 'attack'
        otherwise                                   -> 'move'
    * distance not comparable (e.g. NaN)            -> 'stop'

    Args:
        processed_global_state: A 4-tuple
            (available_move_actions, enemy_info, ally_info, own_info).
            Only ``enemy_info`` and ``own_info`` are read:
            ``own_info[agent_id][5]`` is tested for truthiness as the
            alive flag, ``own_info[agent_id][0][0]`` is read as the agent's
            health, and ``enemy_info[agent_id]['enemy_0'][1][0]`` /
            ``[5][0]`` are read as the enemy's distance / health.
            NOTE(review): the health thresholds suggest values normalised
            to [0, 1] -- confirm against the state pre-processing.
        sight_range: Distance above which the enemy counts as out of sight.
            Defaults to 9, the value previously hard-coded.
        attack_range: Distance up to which the enemy counts as attackable.
            Defaults to 6, the value previously hard-coded.

    Returns:
        dict: Maps every key of ``own_info`` to one of 'move', 'attack',
        'stop', or 'none' (dead agents).
    """
    # The move features and ally info are part of the tuple but not needed.
    _, enemy_info, _, own_info = processed_global_state

    # Health thresholds of the heuristic, named so the branches read clearly.
    engage_health = 0.5      # minimum own health to close in on the enemy
    weak_enemy_health = 0.25  # at or below this the enemy is "nearly dead"
    retreat_health = 0.3     # at or below this an agent in range backs off

    llm_tasks = {}

    for agent_id, own in own_info.items():
        if not own[5]:  # Dead agents get no task
            llm_tasks[agent_id] = 'none'
            continue

        enemy = enemy_info[agent_id]['enemy_0']
        enemy_dist = enemy[1][0]
        enemy_health = enemy[5][0]
        own_health = own[0][0]

        if enemy_dist > sight_range:
            # Outside sight range, need to move closer
            task = 'move'
        elif attack_range < enemy_dist <= sight_range:
            # In sight range but outside attack range
            if own_health > engage_health and enemy_health > weak_enemy_health:
                task = 'move'  # Healthy enough to engage
            else:
                task = 'stop'  # Low health or enemy nearly dead, hold distance
        elif enemy_dist <= attack_range:
            # Within attack range: attack if healthy enough, otherwise the
            # agent is told to move (retreat)
            task = 'attack' if own_health > retreat_health else 'move'
        else:
            # Reached only when none of the comparisons above holds, e.g. the
            # distance is NaN; fall back to holding position.
            task = 'stop'

        llm_tasks[agent_id] = task

    return llm_tasks

def compute_reward(processed_state, llm_tasks, tasks, *, sight_range=9, attack_range=6):
    """
    Compute a shaped per-agent reward from the assigned and performed tasks.

    Only the enemy stored under the key 'enemy_0' is considered. An agent
    flagged as dead receives -100 and nothing else. A living agent starts at
    0 and accumulates:

    * +1   if its performed task equals its assigned task
    * +2   if ``attack_range`` < enemy distance <= ``sight_range``
    * +3   if it performed 'attack' with the enemy within ``attack_range``
           and own health > 0.3
    * -1   if it performed 'move' while in the distance band above with
           own health > 0.5
    * +5   if every entry in ``tasks`` is 'attack'
    * +100 if the enemy's health equals 0

    Args:
        processed_state: A 4-tuple
            (available_move_actions, enemy_info, ally_info, own_info), as
            returned from process_global_state(global_state, n, m). Only
            ``enemy_info`` and ``own_info`` are read, with the same indexing
            as in ``planning_function``: ``own_info[agent_id][5]`` is the
            alive flag, ``own_info[agent_id][0][0]`` the agent's health, and
            ``enemy_info[agent_id]['enemy_0'][1][0]`` / ``[5][0]`` the
            enemy's distance / health.
        llm_tasks (dict): Task assigned to each agent.
        tasks (dict): Task actually performed by each agent.
        sight_range: Upper bound of the rewarded distance band. Defaults to
            9, the value previously hard-coded.
        attack_range: Maximum distance for a rewarded attack and lower bound
            of the distance band. Defaults to 6, the value previously
            hard-coded.

    Returns:
        dict: Maps every key of ``own_info`` to that agent's reward.
    """
    # The move features and ally info are part of the tuple but not needed.
    _, enemy_info, _, own_info = processed_state

    # Coordination bonus condition: it depends only on `tasks`, so evaluate
    # it once instead of once per agent. Note that a single non-'attack'
    # entry (including one for a dead agent) disables the bonus for everyone.
    everyone_attacking = all(task == 'attack' for task in tasks.values())

    reward = {}

    for agent_id, own in own_info.items():
        if not own[5]:  # Dead agent: large penalty, no other terms apply
            reward[agent_id] = -100
            continue

        enemy = enemy_info[agent_id]['enemy_0']
        enemy_dist = enemy[1][0]
        enemy_health = enemy[5][0]
        own_health = own[0][0]

        performed = tasks[agent_id]
        in_standoff_band = attack_range < enemy_dist <= sight_range

        agent_reward = 0

        # Reward for following LLM tasks
        if performed == llm_tasks[agent_id]:
            agent_reward += 1

        # Reward for maintaining optimal distance
        if in_standoff_band:
            agent_reward += 2

        # Reward for attacking when appropriate
        if performed == 'attack' and enemy_dist <= attack_range and own_health > 0.3:
            agent_reward += 3

        # Penalty for unnecessary movement.
        # NOTE(review): planning_function assigns 'move' in this same
        # situation (when the enemy's health is above 0.25), so an agent that
        # follows that plan has its +1 cancelled here -- confirm this is the
        # intended shaping.
        if performed == 'move' and in_standoff_band and own_health > 0.5:
            agent_reward -= 1

        # Reward for coordination (all agents attacking simultaneously)
        if everyone_attacking:
            agent_reward += 5

        # Large reward for defeating the enemy. This is an exact comparison,
        # so it relies on the state reporting exactly 0 for a defeated enemy.
        if enemy_health == 0:
            agent_reward += 100

        reward[agent_id] = agent_reward

    return reward