import numpy as np

def _scalar_feature(value):
    """Return the leading scalar of a feature stored as a scalar or 1-element array."""
    return float(np.ravel(value)[0])


def planning_function(processed_global_state):
    """
    Determine a task label for each agent based on the current battle state.

    Rules for a living agent, evaluated against its closest *visible* enemy:

    * no visible enemy                              -> 'move'
    * enemy closer than 4, or own health below 0.3  -> 'move' (retreat)
    * enemy at distance 4..6 (inclusive)            -> 'attack'
    * enemy farther than 6                          -> 'move' (approach)

    Afterwards, if the first listed ally is more than 10 units away, the task
    is overridden with 'move' (regroup). Dead agents get 'none'.

    Only the task label is produced; no movement direction is returned.

    Args:
        processed_global_state: A tuple
            ``(available_move_actions, enemy_info, ally_info, own_info)``.
            ``own_info[agent_id]`` is read as health ``[0]``, position
            ``[2]``/``[3]`` and alive flag ``[-1]``; these may be scalars or
            1-element arrays. ``enemy_info[agent_id][enemy_id]`` is read as
            distance ``[1][0]`` and visibility flag ``[4][0]``.
            ``ally_info[agent_id][ally_id]`` is read as position ``[7]``/``[8]``.

    Returns:
        dict: Task for each agent: 'move', 'attack', or 'none' for dead agents.
    """
    _move_feats, enemy_info, ally_info, own_info = processed_global_state
    llm_tasks = {}

    for agent_id, own in own_info.items():
        if not own[-1]:  # agent is dead
            llm_tasks[agent_id] = 'none'
            continue

        # Reduce every feature to a plain scalar so the position is a flat (2,)
        # vector. A (2, 1) position would broadcast against the (2,) ally
        # position below and turn the distance into a 2x2 matrix norm.
        own_health = _scalar_feature(own[0])
        own_pos = np.array([_scalar_feature(own[2]), _scalar_feature(own[3])])

        # Distance to the closest visible enemy (None when nothing is visible).
        visible_dists = [
            enemy[1][0]
            for enemy in enemy_info[agent_id].values()
            if enemy[4][0] == 1
        ]
        closest = min(visible_dists, default=None)

        if closest is None:
            task = 'move'  # nothing in sight: reposition
        elif closest < 4 or own_health < 0.3:
            task = 'move'  # too close or too weak: retreat
        elif 4 <= closest <= 6:
            task = 'attack'  # enemy in the preferred engagement range
        else:
            task = 'move'  # too far away: close in

        # Coordinate with the first listed ally, if there is one at all.
        allies = ally_info.get(agent_id, {})
        if allies:
            ally = next(iter(allies.values()))
            ally_pos = np.array([_scalar_feature(ally[7]), _scalar_feature(ally[8])])
            if np.linalg.norm(own_pos - ally_pos) > 10:
                task = 'move'  # drifted too far from the ally: regroup

        llm_tasks[agent_id] = task

    return llm_tasks

def compute_reward(processed_state, llm_tasks, tasks):
    """
    Calculate a reward for each agent from the current state and its tasks.

    Reward terms for a living agent, applied in this order:

    * +1 if the performed task matches the LLM-assigned task;
    * +2 per visible enemy whose health feature is below 1;
    * -5 * (1 - own health) when own health is below 1;
    * +0.5 per visible enemy at distance 4..6 (inclusive);
    * +1 if the first listed ally is at distance 5..10 (inclusive).

    Dead agents receive a flat -10.

    Args:
        processed_state: Tuple ``(move_feats, enemy_info, ally_info, own_info)``
            as returned by ``process_global_state(global_state, n, m)``.
            ``own_info[agent_id]`` is read as health ``[0]`` (scalar or
            1-element array) and alive flag ``[-1]``. Each
            ``enemy_info[agent_id][enemy_id]`` is read as distance ``[1][0]``,
            visibility flag ``[4][0]`` and health ``[5][0]``. Each
            ``ally_info[agent_id][ally_id]`` is read as distance ``[1][0]``.
        llm_tasks (dict): Task assigned to each agent by the planner.
        tasks (dict): Task actually performed by each agent.

    Returns:
        dict: Reward for each agent.
    """
    _move_feats, enemy_info, ally_info, own_info = processed_state
    reward = {}

    for agent_id, own in own_info.items():
        if not own[-1]:  # agent is dead
            reward[agent_id] = -10
            continue

        agent_reward = 0

        # Reward for following the LLM task.
        if llm_tasks[agent_id] == tasks[agent_id]:
            agent_reward += 1

        # Only enemies flagged visible (index 4, as in planning_function) are
        # scored. Without this filter, an out-of-sight enemy whose features are
        # zeroed (presumably; confirm against process_global_state) would count
        # as "damaged" and pay out every step.
        visible = [e for e in enemy_info[agent_id].values() if e[4][0] == 1]

        # NOTE(review): this rewards the current state (wounded enemies in
        # sight), not damage actually dealt during this step.
        agent_reward += 2 * sum(1 for e in visible if e[5][0] < 1)

        # Penalty for damage taken.
        own_health = float(np.ravel(own[0])[0])
        if own_health < 1:
            agent_reward -= (1 - own_health) * 5

        # Reward for keeping visible enemies in the preferred range.
        agent_reward += 0.5 * sum(1 for e in visible if 4 <= e[1][0] <= 6)

        # Reward for staying at a moderate distance from the first ally, if any.
        allies = ally_info.get(agent_id, {})
        if allies:
            ally_dist = next(iter(allies.values()))[1][0]
            if 5 <= ally_dist <= 10:
                agent_reward += 1

        reward[agent_id] = agent_reward

    return reward