import numpy as np

def planning_function(processed_global_state):
    """
    Determine a task for each agent based on the current battle state.

    Decision rule for every living agent:
      * nearest enemy distance inside the band [5, 6] -> 'attack'
      * otherwise (too far, too close, or no enemy entries) -> 'move'
      * if the first listed ally is more than 10 away -> 'move'
        (this overrides the enemy-based decision so agents regroup)

    Args:
        processed_global_state: A tuple containing
            (move_feats, enemy_info, ally_info, own_info). ``own_info``,
            ``enemy_info`` and ``ally_info`` are indexed by agent id;
            ``enemy_info[agent_id]`` and ``ally_info[agent_id]`` are mappings
            from unit id to a feature sequence.
            NOTE(review): index 1 of an enemy/ally feature sequence is treated
            as the distance to that unit and index 5 of ``own_info[agent_id]``
            as the alive flag -- confirm against the state-processing code.

    Returns:
        dict: Task for each agent ('move', 'attack', or 'none' for dead agents).
    """
    move_feats, enemy_info, ally_info, own_info = processed_global_state
    llm_tasks = {}

    for agent_id, own_feats in own_info.items():
        if not own_feats[5]:  # Agent is dead: nothing to plan
            llm_tasks[agent_id] = 'none'
            continue

        # An agent with no enemy entries is treated as infinitely far away,
        # so it falls through to 'move' instead of crashing on an empty min().
        nearest_enemy_dist = min(
            (info[1] for info in enemy_info[agent_id].values()),
            default=float('inf'),
        )

        if nearest_enemy_dist > 6 or nearest_enemy_dist < 5:
            # Out of range (close in) or too close (back off): reposition.
            task = 'move'
        else:
            # Enemy is within the preferred engagement band.
            task = 'attack'

        # Coordinate with the first listed ally only; agents without any
        # ally entry skip this step.
        first_ally = next(iter(ally_info[agent_id].values()), None)
        if first_ally is not None and first_ally[1] > 10:
            task = 'move'  # Too far from the ally: regroup first

        llm_tasks[agent_id] = task

    return llm_tasks

def compute_reward(processed_state, llm_tasks, tasks):
    """
    Calculate per-agent rewards based on the tasks assigned and their outcomes.

    For every living agent the reward is the sum of:
      * +0.5 if the executed task equals the assigned task
      * +(64 - total enemy health) * 0.1
      * -(1 - own health) * 0.5
      * +0.3 if the nearest enemy distance lies in [5, 6]
      * +0.2 if the first listed ally is at distance <= 10
    Dead agents always receive 0.

    Args:
        processed_state: returned from function
            process_global_state(global_state, n, m); a tuple
            (move_feats, enemy_info, ally_info, own_info).
            NOTE(review): the code reads index 1 of enemy/ally features as a
            distance, index 5 of enemy features as health, index 0 of
            ``own_info[agent_id]`` as own health and index 5 as the alive
            flag -- confirm against process_global_state.
        llm_tasks (dict): Tasks assigned to each agent.
        tasks (dict): Tasks actually performed by each agent.

    Returns:
        dict: Reward for each agent.
    """
    # NOTE(review): 64 is presumably the combined starting health of all
    # enemies for the scenario this was written for -- confirm.
    max_total_enemy_health = 64
    follow_bonus = 0.5
    damage_weight = 0.1
    hurt_weight = 0.5
    range_bonus = 0.3
    ally_bonus = 0.2

    move_feats, enemy_info, ally_info, own_info = processed_state
    reward = {}

    for agent_id, own_feats in own_info.items():
        reward[agent_id] = 0

        if not own_feats[5]:  # Agent is dead: reward stays 0
            continue

        # Reward for following the assigned task
        if llm_tasks[agent_id] == tasks[agent_id]:
            reward[agent_id] += follow_bonus

        enemies = enemy_info[agent_id]

        # Reward for damaging enemies: grows as the summed enemy health drops
        total_enemy_health = sum(info[5] for info in enemies.values())
        reward[agent_id] += (max_total_enemy_health - total_enemy_health) * damage_weight

        # Penalty for taking damage
        reward[agent_id] -= (1 - own_feats[0]) * hurt_weight

        # Reward for maintaining the preferred distance from enemies. With no
        # enemy entries the distance defaults to infinity, so no bonus is
        # granted (and min() does not raise on an empty sequence).
        nearest_enemy_dist = min(
            (info[1] for info in enemies.values()),
            default=float('inf'),
        )
        if 5 <= nearest_enemy_dist <= 6:
            reward[agent_id] += range_bonus

        # Reward for staying close to the first listed ally; agents without
        # any ally entry simply get no bonus.
        first_ally = next(iter(ally_info[agent_id].values()), None)
        if first_ally is not None and first_ally[1] <= 10:
            reward[agent_id] += ally_bonus

    return reward